In [1]:
import logging
logging.basicConfig(level=logging.DEBUG)
from sklearn.metrics import accuracy_score
from sklearn.svm import SVC
import utilities
import os

In [2]:
X_train, y_train, X_test, y_test = utilities.get_ucr_dataset('../UCRArchive_2018/','ItalyPowerDemand')

In [3]:
print(X_train.shape,y_train.shape,X_test.shape)

(67, 24) (67,) (1029, 24)


### Implement 1-NN-DTW

In [4]:
from scipy.spatial import distance
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import classification_report
import numpy as np

In [5]:
def DTW(a, b):   
    an = a.size
    bn = b.size
    pointwise_distance = distance.cdist(a.reshape(-1,1),b.reshape(-1,1))
    cumdist = np.matrix(np.ones((an+1,bn+1)) * np.inf)
    cumdist[0,0] = 0

    for ai in range(an):
        for bi in range(bn):
            minimum_cost = np.min([cumdist[ai, bi+1],
                                   cumdist[ai+1, bi],
                                   cumdist[ai, bi]])
            cumdist[ai+1, bi+1] = pointwise_distance[ai,bi] + minimum_cost

    return cumdist[an, bn]

In [6]:
#train
parameters = {'n_neighbors':[1]}
clf = GridSearchCV(KNeighborsClassifier(metric=DTW), parameters, cv=3, verbose=1)
clf.fit(X_train, y_train)


Fitting 3 folds for each of 1 candidates, totalling 3 fits


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   3 out of   3 | elapsed:   18.5s finished


GridSearchCV(cv=3,
             estimator=KNeighborsClassifier(metric=<function DTW at 0x000001FB789BB5E0>),
             param_grid={'n_neighbors': [1]}, verbose=1)

In [7]:
#evaluate
y_pred = clf.predict(X_test[:100])
print(classification_report(y_test[:100], y_pred))

              precision    recall  f1-score   support

         1.0       0.86      0.95      0.90        39
         2.0       0.96      0.90      0.93        61

    accuracy                           0.92       100
   macro avg       0.91      0.93      0.92       100
weighted avg       0.92      0.92      0.92       100

