In [15]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from collections import defaultdict

In [16]:
class TrainDataSet(object):
    """Labelled training table split into a label vector and a feature matrix.

    Column 0 of ``data`` is taken as the label of each row; the remaining
    columns are the features.

    Parameters
    ----------
    data : array_like
        2-D table (rows = samples). A 1-D input is treated as a single row.
    """

    def __init__(self, data):
        # ndmin=2 turns a single 1-D row into a one-sample table so the
        # column slicing below works for it too.
        data = np.array(data, ndmin=2)

        self.labels = data[:, 0]      # first column: label per row
        self.data_set = data[:, 1:]   # remaining columns: features per row

    def __repr__(self):
        return repr(self.labels) + "\n" + repr(self.data_set)

    def get_data_num(self):
        """Return the number of samples (rows)."""
        return self.labels.size

    def get_labels(self, *args):
        """Return all labels, or ``labels[args[0]]`` when an index is given.

        ``args[0]`` may be anything numpy accepts as an index (int, slice,
        index array, boolean mask).
        """
        # *args is always a tuple (never None), so test for emptiness.
        if not args:
            return self.labels
        return self.labels[args[0]]

    def get_data_set(self):
        """Return the full feature matrix."""
        return self.data_set

    def get_data_set_partial(self, *args):
        """Return the whole feature matrix, or ``data_set[args[0]]`` when an index is given."""
        # *args is always a tuple (never None), so test for emptiness.
        if not args:
            return self.data_set
        return self.data_set[args[0]]

    def get_label(self, i):
        """Return the label of row ``i``."""
        return self.labels[i]

    def get_data(self, i, j=None):
        """Return feature row ``i``, or the single feature ``data_set[i][j]`` when ``j`` is given.

        A single definition with an optional ``j`` replaces two same-named
        methods, the second of which silently shadowed the first.
        """
        if j is None:
            return self.data_set[i, :]
        return self.data_set[i][j]

In [18]:
size = 28  # NOTE(review): not used in the visible cells; presumably the 28x28 image side length -- confirm
# Both CSVs have a header row (skipped). TrainDataSet treats column 0 of
# train.csv as the label; test_small.csv rows are compared directly against
# the training feature columns, so they must contain features only.
# ndmin=2 keeps each result a 2-D (rows x columns) table even when a file holds
# a single data row; otherwise np.loadtxt returns a 1-D array and per-row
# iteration downstream would walk individual values instead of samples.
master_data = np.loadtxt('train.csv', delimiter=',', skiprows=1, ndmin=2)
test_data = np.loadtxt('test_small.csv', delimiter=',', skiprows=1, ndmin=2)

train_data_set = TrainDataSet(master_data)

In [19]:
def get_list_sorted_by_val(k_result, k_dist):
    """Aggregate the k nearest neighbours by label.

    Parameters
    ----------
    k_result : sequence
        Labels of the k nearest neighbours.
    k_dist : sequence of float
        Distances of those neighbours, aligned element-wise with ``k_result``.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_labels, 3)``; each row is ``[label, count, summed distance]``.
        Despite the function name the rows are NOT sorted; the caller is
        responsible for ordering them. Empty input yields shape ``(0, 3)`` so
        column indexing such as ``result[:, 1]`` still works.
    """
    label_counts = defaultdict(int)
    distance_sums = defaultdict(float)

    # Number of neighbours carrying each label.
    for label in k_result:
        label_counts[label] += 1

    # Sum of the neighbour distances for each label.
    for idx, dist in enumerate(k_dist):
        distance_sums[k_result[idx]] += dist

    rows = [[label, count, distance_sums[label]]
            for label, count in label_counts.items()]

    # reshape keeps the result 2-D even when there are no rows.
    return np.array(rows).reshape(-1, 3)

In [20]:
k = 5                  # number of nearest neighbours that vote
predicted_list = []    # predicted digit label for each test sample
k_result_list  = []    # labels of the k nearest training samples, per test sample
k_distances_list = []  # distances to those k neighbours, per test sample

assert k >= 1, "k must be a positive integer"

train_vectors = train_data_set.get_data_set()

# execute k-nearest neighbor method
for i in range(len(test_data)):

    # Euclidean distance from this test sample to every training sample.
    # Broadcasting subtracts the 1-D sample from each training row, so no
    # np.tile copy of the sample is needed.
    diff_data = test_data[i] - train_vectors
    distances = np.sqrt((diff_data ** 2).sum(axis=1))
    ind = distances.argsort()                      # training indices, nearest first
    k_result = train_data_set.get_labels(ind[:k])  # labels of the k nearest neighbours
    k_dist   = distances[ind[:k]]                  # and their distances

    k_distances_list.append(k_dist)
    k_result_list.append(k_result)

    # Rows of (label, count, summed distance) aggregated over the k neighbours,
    # ordered by descending count.
    result_list = get_list_sorted_by_val(k_result, k_dist)
    candidate = result_list[result_list[:, 1].argsort()[::-1]]

    # Majority vote. If several labels share the highest count, choose the one
    # whose neighbours have the smallest summed distance; on an exact distance
    # tie the first row in `candidate` order wins (np.argmin returns the first
    # minimum). Avoids rebinding the builtin `min` for later cells.
    top_count = candidate[:, 1].max()
    top_rows = candidate[candidate[:, 1] == top_count]
    label_top = top_rows[np.argmin(top_rows[:, 2]), 0]

    # store the result
    predicted_list.append(label_top)

In [21]:
# Display the results.
# Each print() takes exactly one pre-formatted string, so the output is
# identical on Python 2 (print statement) and Python 3 (print function).
print("[Predicted Data List]")
for i, label in enumerate(predicted_list):
    print("%d\t%s" % (i, label))

print("[Detail Predicted Data List]")
print("index k units of neighbors, distances for every k units")
for i, (neighbors, dists) in enumerate(zip(k_result_list, k_distances_list)):
    print("%d\t%s\t%s" % (i, neighbors, dists))

[Predicted Data List]
0	2.0
1	0.0
2	9.0
3	9.0
4	3.0
5	7.0
6	0.0
7	3.0
8	0.0
9	3.0
10	5.0
11	7.0
12	4.0
13	0.0
14	4.0
15	3.0
16	3.0
17	1.0
18	9.0
19	0.0
20	9.0
21	1.0
22	1.0
23	5.0
24	7.0
25	4.0
26	2.0
27	7.0
28	4.0
29	7.0
30	7.0
31	5.0
32	4.0
33	2.0
34	6.0
35	2.0
36	5.0
37	5.0
38	1.0
39	6.0
40	7.0
41	7.0
42	4.0
43	9.0
44	8.0
45	7.0
46	8.0
47	2.0
48	6.0
49	7.0
50	6.0
51	8.0
52	8.0
53	3.0
54	8.0
55	2.0
56	1.0
57	2.0
58	2.0
59	0.0
60	4.0
61	1.0
62	7.0
63	0.0
64	0.0
65	0.0
66	1.0
67	9.0
68	0.0
69	1.0
70	6.0
71	5.0
72	8.0
73	8.0
74	2.0
75	8.0
76	8.0
77	9.0
78	2.0
79	3.0
80	5.0
81	4.0
82	1.0
83	0.0
84	9.0
85	2.0
86	4.0
87	3.0
88	6.0
89	7.0
90	2.0
91	0.0
92	6.0
93	6.0
94	1.0
95	4.0
96	3.0
97	9.0
98	7.0
99	4.0
100	0.0
101	9.0
102	2.0
103	0.0
104	7.0
105	3.0
106	0.0
107	5.0
108	0.0
109	8.0
110	0.0
111	0.0
112	4.0
113	7.0
114	1.0
115	7.0
116	1.0
117	1.0
118	3.0
119	3.0
120	3.0
121	7.0
122	2.0
123	8.0
124	6.0
125	3.0
126	8.0
127	7.0
128	8.0
129	4.0
130	3.0
131	5.0
132	6.0
133	0.0
134	0.0
135	0.0
