In [24]:
import numpy as np
from pathlib import Path
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.model_selection import GridSearchCV, train_test_split, RepeatedStratifiedKFold
from sklearn.metrics import accuracy_score, log_loss
from sklearn.gaussian_process.kernels import RBF, DotProduct, Matern, RationalQuadratic, WhiteKernel, ExpSineSquared

grid = dict()
# grid['kernel'] = [1*RBF(), 1*DotProduct(), 1*Matern(),  1*RationalQuadratic(),1*Matern()+1*WhiteKernel(noise_level=0.5)]
grid['kernel'] = [1*RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e2)), 
                  1*DotProduct(), 1*Matern(),  1*RationalQuadratic(),
                  1*Matern()+1*WhiteKernel(noise_level=0.5),
                  1*ExpSineSquared(),
                  1.0 * ExpSineSquared(length_scale=1.0, periodicity=3.0, length_scale_bounds=(0.1, 10.0), periodicity_bounds=(1.0, 10.0)),
                  1*Matern(length_scale=1.0, length_scale_bounds=(1e-1, 10.0), nu=1.5)+1*WhiteKernel(noise_level=0.5),
                ]


dataDir = Path.cwd().parent.parent.parent/'Data/processed'
ts_dataset = pd.read_csv(dataDir/"memory_24_meanCost_minPara_65%.csv", index_col="id")

In [67]:
X = ts_dataset.iloc[:, 0:len(ts_dataset.columns)-1].copy()
y = pd.DataFrame(ts_dataset.iloc[:, -1])
# split dataset into training and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=56)
model = GaussianProcessClassifier()
cv = RepeatedStratifiedKFold(n_splits=5, n_repeats=5, random_state=5)
# exhausive search over different kernels
search = GridSearchCV(estimator=model, param_grid=grid, scoring='accuracy', cv=cv, n_jobs=-1)
# fit model into the dataset
result = search.fit(X, y)

print("Mean cross-validated score of the best_estimator: ", result.best_score_)
print("Best estimator parameters: ", result.best_params_)

means = result.cv_results_['mean_test_score']
params = result.cv_results_['params']
for mean, param in zip(means, params):
    print("Accuracy %.3f with: %r" % (mean, param))

Mean cross-validated score of the best_estimator:  0.9840000000000001
Best estimator parameters:  {'kernel': 1**2 * DotProduct(sigma_0=1)}
Accuracy 0.932 with: {'kernel': 1**2 * RBF(length_scale=1)}
Accuracy 0.984 with: {'kernel': 1**2 * DotProduct(sigma_0=1)}
Accuracy 0.956 with: {'kernel': 1**2 * Matern(length_scale=1, nu=1.5)}
Accuracy 0.956 with: {'kernel': 1**2 * RationalQuadratic(alpha=1, length_scale=1)}
Accuracy 0.940 with: {'kernel': 1**2 * Matern(length_scale=1, nu=1.5) + 1**2 * WhiteKernel(noise_level=0.5)}
Accuracy nan with: {'kernel': 1**2 * ExpSineSquared(length_scale=1, periodicity=1)}
Accuracy nan with: {'kernel': 1**2 * ExpSineSquared(length_scale=1, periodicity=3)}
Accuracy 0.900 with: {'kernel': 1**2 * Matern(length_scale=1, nu=1.5) + 1**2 * WhiteKernel(noise_level=0.5)}


28 fits failed out of a total of 200.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
1 fits failed with the following error:
Traceback (most recent call last):
  File "d:\Toolbox\python\lib\site-packages\sklearn\model_selection\_validation.py", line 680, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "d:\Toolbox\python\lib\site-packages\sklearn\gaussian_process\_gpc.py", line 715, in fit
    self.base_estimator_.fit(X, y)
  File "d:\Toolbox\python\lib\site-packages\sklearn\gaussian_process\_gpc.py", line 224, in fit
    self._constrained_optimization(
  File "d:\Toolbox\python\lib\site-packages\sklearn\gaussian_process\_gpc.py", line 469, in _constrained_optimization
    opt_res = scipy.optimize.minimize(
  F

In [68]:
y.value_counts()

label
0        15
1         9
dtype: int64

In [73]:
y_test.value_counts()

label
0        3
1        2
dtype: int64

In [70]:
search.best_params_

{'kernel': 1**2 * DotProduct(sigma_0=1)}

In [72]:
best_kernel = 1**2 * DotProduct(sigma_0=1)
# best_kernel = 1**2 * RationalQuadratic(alpha=1, length_scale=1)
# best_kernel = 1**2 * Matern(length_scale=1, nu=1.5)
# best_kernel = 1**2 * Matern(length_scale=1, nu=1.5) + 1**2 * WhiteKernel(noise_level=0.5)

model = GaussianProcessClassifier(kernel=best_kernel, random_state=5, n_jobs=-1)
model.fit(X_train, y_train.values.ravel())
print("Mean accuracy on training data: ", model.score(X_train, y_train))
print("Prediction on test data: ", model.predict(X_test))
pred_test = model.predict(X_test)
print("Prediction accuracy on test data: ", accuracy_score(y_test, pred_test))

Mean accuracy on training data:  1.0
Prediction on test data:  [0 0 0 1 1]
Prediction accuracy on test data:  1.0


In [23]:
y_test

Unnamed: 0_level_0,label
id,Unnamed: 1_level_1
14,0
9,0
27,0
20,1
29,0
31,1
19,1
23,0


In [20]:
X_test

Unnamed: 0_level_0,"Total_MilkProduction__cwt_coefficients__coeff_1__w_5__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_0__w_2__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_12__w_20__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_12__w_5__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_10__w_10__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_13__w_20__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_7__w_5__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_8__w_5__widths_(2, 5, 10, 20)","Total_MilkProduction__fft_coefficient__attr_""imag""__coeff_12","Total_MilkProduction__fft_coefficient__attr_""angle""__coeff_3",...,milking_times__energy_ratio_by_chunks__num_segments_10__segment_focus_0.1,milking_times__absolute_sum_of_changes.1,milking_times__number_cwt_peaks__n_1.1,"milking_times__fft_aggregated__aggtype_""skew"".1","milking_times__fft_aggregated__aggtype_""variance"".1",milking_times__range_count__max_1000000000000.0__min_0.1,BreedName_1,BreedName_2,BreedName_4,BreedName_99
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
14,-0.941324,-1.984503,-0.103774,1.174812,-0.44922,-0.063459,1.26798,1.591201,-0.972685,0.95797,...,-0.390558,1.298683,0.031513,-0.24011,0.507893,0.346793,0.0,1.0,0.0,0.0
9,-0.911597,-0.149092,-0.416218,0.250451,-0.704385,-0.379315,-0.53184,-0.37644,-0.429435,0.403844,...,-0.315649,-0.586027,0.535715,1.792902,-0.209928,0.346793,1.0,0.0,0.0,0.0
27,-0.943162,-0.136274,-0.890937,0.689897,-0.885752,-0.861434,0.244769,0.631335,0.642924,0.747909,...,2.638429,0.827506,0.031513,-0.386302,0.048515,0.346793,1.0,0.0,0.0,0.0
20,0.698688,1.526522,-1.399463,-1.64547,-1.080614,-1.424495,-0.999164,-1.216966,-0.156442,0.680938,...,0.774766,0.73327,0.031513,-1.363089,0.129965,0.122603,0.0,0.0,0.0,1.0
29,0.287826,-0.508069,-0.234384,-2.18491,-0.901808,-0.199416,-0.597324,-0.946084,2.156876,-0.052619,...,-0.001069,0.73327,1.544121,-0.403901,0.772745,0.458887,0.0,1.0,0.0,0.0
31,0.574778,0.071772,-0.615627,1.5778,0.613386,-0.665367,0.137427,0.408723,-0.265203,-0.85259,...,-0.129655,1.675625,0.535715,-0.867384,0.674173,0.346793,0.0,1.0,0.0,0.0
19,0.505665,-1.909379,-0.069375,-0.454321,0.305638,-0.103548,0.92895,0.802927,0.308731,0.260966,...,-0.791358,0.356328,-0.47269,-0.76706,0.522058,0.234698,0.0,1.0,0.0,0.0
23,-0.150251,0.542351,-1.840801,0.492908,-0.582237,-1.910677,-0.640838,-0.485917,-0.404893,-0.558293,...,-0.523991,-1.057205,-2.489501,-2.271925,-1.992803,-2.007193,0.0,0.0,1.0,0.0


In [22]:
X

Unnamed: 0_level_0,"Total_MilkProduction__cwt_coefficients__coeff_1__w_5__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_0__w_2__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_12__w_20__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_12__w_5__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_10__w_10__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_13__w_20__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_7__w_5__widths_(2, 5, 10, 20)","Total_MilkProduction__cwt_coefficients__coeff_8__w_5__widths_(2, 5, 10, 20)","Total_MilkProduction__fft_coefficient__attr_""imag""__coeff_12","Total_MilkProduction__fft_coefficient__attr_""angle""__coeff_3",...,milking_times__energy_ratio_by_chunks__num_segments_10__segment_focus_0.1,milking_times__absolute_sum_of_changes.1,milking_times__number_cwt_peaks__n_1.1,"milking_times__fft_aggregated__aggtype_""skew"".1","milking_times__fft_aggregated__aggtype_""variance"".1",milking_times__range_count__max_1000000000000.0__min_0.1,BreedName_1,BreedName_2,BreedName_4,BreedName_99
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1,-1.015706,0.202499,-0.538062,0.600443,-0.69881,-0.528962,-1.066386,-0.848579,-1.242208,0.331626,...,-0.448776,-0.491792,0.031513,0.125394,0.74867,0.458887,0.0,1.0,0.0,0.0
2,-0.370088,-0.74293,1.263955,1.009445,0.560211,1.278024,0.138514,0.190373,2.100237,0.248487,...,1.701128,-0.586027,0.031513,0.508881,0.25506,0.346793,1.0,0.0,0.0,0.0
3,0.947747,-0.098435,-0.601347,-0.665421,-0.382099,-0.596389,-1.459548,-1.654593,2.178404,-0.900529,...,-0.223737,-0.774498,-0.47269,-0.285542,-0.273607,0.122603,1.0,0.0,0.0,0.0
4,-0.696069,-0.04521,-0.139337,-0.227413,-0.476032,-0.144297,-0.875224,-0.861702,-0.792088,-0.016373,...,-0.331777,-0.491792,0.535715,1.372743,-0.19928,0.458887,1.0,0.0,0.0,0.0
5,-0.460647,-1.344521,0.424572,0.196696,-0.154607,0.443588,0.428958,0.437358,-0.657242,0.440132,...,-0.719809,-1.99956,-1.985298,0.492485,-2.539226,-2.567666,1.0,0.0,0.0,0.0
6,-0.41277,0.925172,0.95214,-0.345148,-0.211234,0.986574,0.268532,0.258788,0.862191,0.660661,...,-0.893962,0.450564,1.039918,0.089975,0.621241,0.458887,0.0,1.0,0.0,0.0
7,1.770969,-0.632369,2.027828,-0.932173,2.427939,1.980716,1.446899,0.978331,-0.403626,0.911964,...,1.19223,0.167857,1.039918,-0.702109,1.041626,0.346793,0.0,1.0,0.0,0.0
8,-1.126232,0.417633,0.031177,0.720838,-0.388279,0.051566,-0.210772,0.017072,1.674054,0.443476,...,-0.854286,1.675625,1.039918,-0.936993,1.189265,0.458887,1.0,0.0,0.0,0.0
9,-0.911597,-0.149092,-0.416218,0.250451,-0.704385,-0.379315,-0.53184,-0.37644,-0.429435,0.403844,...,-0.315649,-0.586027,0.535715,1.792902,-0.209928,0.346793,1.0,0.0,0.0,0.0
10,1.227416,0.863289,0.840074,-0.885345,0.801816,0.781734,0.148674,-0.090367,-2.47718,-2.64342,...,0.353272,-0.11485,0.031513,-0.349275,0.001149,-0.101586,0.0,1.0,0.0,0.0
