In [None]:
# ライブラリのインポート

from itertools import product
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from sklearn import datasets
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import cross_val_score

print("libraries imported")

In [None]:
# Load the digits dataset from sklearn.datasets.
digits = datasets.load_digits()

# Take a look at the data: container type, shape and contents.
print("Data")
for feature_info in (digits.data.__class__, digits.data.shape, digits.data):
    print(feature_info)
print("Classes ", digits.target_names)
for target_info in (digits.target.__class__, digits.target):
    print(target_info)

# Show a couple of samples: the label, then the 64 features drawn as an 8x8 image.
print("Sample Data")
for sample_index in (100, 101):
    print(digits.target[sample_index])
    plt.imshow(digits.data[sample_index].reshape(8, 8), cmap=cm.gray_r, interpolation='nearest')
    plt.show()

In [None]:
## The data used from here on
# X: 64-dimensional feature vectors; y: labels 0 through 9
X, y = digits.data, digits.target

In [None]:
# scikit-learnのモデルに準じたクラスを実装してみる
from sklearn.base import BaseEstimator
from collections import Counter

# k-NN法の実装
# see https://ja.wikipedia.org/wiki/K%E8%BF%91%E5%82%8D%E6%B3%95
class MyNeaestNeiborClassifier(BaseEstimator, ClassifierMixin):
    """k-nearest-neighbour (k-NN) classifier following the scikit-learn estimator API.

    See https://ja.wikipedia.org/wiki/K%E8%BF%91%E5%82%8D%E6%B3%95 for the method.
    The class name's spelling is kept as-is so existing references keep working.

    Parameters
    ----------
    neighbor_size : int, default 4
        Number of closest training samples (k) whose labels are put to a vote.
    """

    def __init__(self, neighbor_size=4):
        # Stored verbatim so BaseEstimator.get_params() / sklearn.base.clone()
        # (used by cross_val_score) can rebuild the estimator.
        self.neighbor_size = neighbor_size

    def fit(self, X, Y):
        """Memorise the training set; k-NN has no other training step.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        Y : array-like of shape (n_samples,)

        Returns
        -------
        self
        """
        # asarray lets plain lists work with the vectorised distance below.
        self.X = np.asarray(X)
        self.Y = np.asarray(Y)
        return self

    def predict(self, newX):
        """Predict labels by majority vote among the nearest training samples.

        Parameters
        ----------
        newX : array-like of shape (n_features,) or (n_samples, n_features)

        Returns
        -------
        A single label for a 1-D sample, or an ndarray with one label per row
        for a 2-D input.

        Raises
        ------
        ValueError
            If ``neighbor_size`` < 1 or the model was fitted on no samples.
        """
        newX = np.asarray(newX)
        if newX.ndim == 2:
            return np.array([self._predict_one(row) for row in newX])
        return self._predict_one(newX)

    def _predict_one(self, newX):
        """Return the most frequent label among the ``neighbor_size`` nearest samples."""
        if self.neighbor_size < 1:
            raise ValueError(
                "neighbor_size must be >= 1, got %r" % (self.neighbor_size,))
        if len(self.Y) == 0:
            raise ValueError("cannot predict: the model was fitted on no samples")
        distance_and_label = self._distance_and_label(newX)
        # Sort on distance only (labels are never compared); sorted() is stable,
        # so equally distant samples keep their training-set order.
        nearest = sorted(distance_and_label, key=lambda pair: pair[0])
        votes = Counter(label for _, label in nearest[:self.neighbor_size])
        # most_common() lists equal counts in first-encountered order, so a tie
        # is won by the label whose member is closest to newX.
        return votes.most_common(1)[0][0]

    def _distance_and_label(self, newX):
        """Return ``(distance, label)`` pairs, one per training sample."""
        # One vectorised norm over all rows instead of a Python-level loop;
        # it is the same Euclidean distance that _distance() computes.
        distances = np.linalg.norm(self.X - np.asarray(newX), axis=1)
        return list(zip(distances, self.Y))

    def _distance(self, x1, x2):
        """Euclidean (L2) distance between two feature vectors."""
        return np.linalg.norm(x1 - x2)

    def score(self, X, y):
        """Return the fraction of samples in X whose predicted label equals y."""
        y = np.asarray(y)
        hit = 0
        for x, label in zip(X, y):
            if self.predict(x) == label:
                hit += 1
        return float(hit) / y.shape[0]

In [None]:
# Evaluate the classifier implemented above.

clf = MyNeaestNeiborClassifier()
clf.fit(digits.data, digits.target)

# Spot-check a few samples: print the prediction, then draw the sample.
# NOTE: these samples are also part of the training data passed to fit(),
# so each one is among its own nearest neighbours (distance 0).
for sample_index in (100, 101, 102):
    print("prediction", clf.predict(X[sample_index]))
    plt.imshow(digits.data[sample_index].reshape(8, 8), cmap=cm.gray_r, interpolation='nearest')
    plt.show()

# 5-fold cross-validated accuracy, reported as mean +/- two standard deviations.
scores = cross_val_score(clf, X, y, cv=5)
mean_accuracy = scores.mean()
two_std = scores.std() * 2
print("Accuracy: %0.2f (+/- %0.2f)" % (mean_accuracy, two_std))