In [1]:
import pickle

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

from matplotlib.ticker import NullFormatter

import numpy as np
import torch
from sklearn import metrics

In [None]:
import faiss

In [None]:
class FaissKNeighbors:
    """Exact (brute-force L2) k-nearest-neighbour classifier backed by FAISS.

    Labels are expected to be non-negative integers. They may be stored as
    floats (e.g. when built with ``np.append``); they are cast to ``int64``
    before vote counting because ``np.bincount`` rejects float input.
    """

    def __init__(self, k):
        """
        Args:
            k: number of neighbours retrieved for every query.
        """
        self.index = None  # faiss.IndexFlatL2, created by fit()
        self.y = None      # labels of the indexed vectors, in index insertion order
        self.k = k

    def fit(self, X, y):
        """Index the reference vectors ``X`` (n, d) together with their labels ``y`` (n,)."""
        # FAISS requires C-contiguous float32; astype alone keeps the input's layout.
        X = np.ascontiguousarray(X, dtype=np.float32)
        self.index = faiss.IndexFlatL2(X.shape[1])
        self.index.add(X)
        self.y = np.asarray(y)

    def _search(self, X):
        """Return the (n, k) indices of the k nearest indexed vectors for each row of ``X``.

        Raises:
            ValueError: if FAISS could not return k neighbours. FAISS pads missing
                results with -1, which would otherwise silently select the last label.
        """
        _, indices = self.index.search(np.ascontiguousarray(X, dtype=np.float32), k=self.k)
        if (indices < 0).any():
            raise ValueError(
                f"k={self.k} exceeds the number of indexed vectors ({self.index.ntotal})")
        return indices

    @staticmethod
    def _majority(labels):
        """Most frequent label in a 1-D label array (ties go to the smallest label)."""
        return np.argmax(np.bincount(labels.astype(np.int64)))

    def predict(self, X):
        """Predict one label per row of ``X`` by majority vote over its k neighbours.

        The (n, k) neighbour labels are also stored in ``self.votes``.
        """
        # Each row holds the labels of that query's k neighbours (not vote counts).
        votes = self.get_votes(X)
        self.votes = votes
        return np.array([self._majority(row) for row in votes])

    def get_each_prediction_and_votes(self, X):
        """Majority-vote predictions for every neighbourhood size 1..k, plus the raw votes.

        Returns:
            pred_array: float array of shape (n, k); column j is the prediction
                obtained from only the j+1 nearest neighbours.
            votes: (n, k) labels of the k nearest neighbours of each query.
        """
        votes = self.get_votes(X)
        # Preallocated float64, matching the dtype of the former np.append-based build.
        pred_array = np.empty((votes.shape[0], self.k))
        for one_k in range(self.k):
            pred_array[:, one_k] = [self._majority(row[:one_k + 1]) for row in votes]
        return pred_array, votes

    def get_votes(self, X):
        """Return the (n, k) labels of the k nearest indexed vectors for each row of ``X``."""
        return self.y[self._search(X)]

In [None]:
### Load SoftMax
# Load per-class softmax outputs, then build a small kNN reference set (the first
# `moreplot` samples of each of the first `n_plot_classes` classes) and one query
# sample per plotted class.

model_name = "resnet"
train_name = "IN"
test_name = "IN"


classes = ['airplane', 'bear', 'bicycle', 'bird',
           'boat', 'bottle', 'car', 'cat',
           'chair', 'clock', 'dog', 'elephant',
           'keyboard', 'knife', 'oven', 'truck']

n_plot_classes = 5  # only the first 5 classes are visualized
moreplot = 10       # reference (training) samples taken per class
test_offset = 1     # position of the query sample within each class

load_path = f'./softmax_dict/{test_name}_to_{model_name}-{train_name}_softmax_dict.pkl'
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
with open(load_path, mode='rb') as f:
    dicts = pickle.load(f)

# Stack every class in `classes` order; y[i] is the integer class index of row X[i].
X = np.concatenate([np.asarray(dicts[c], dtype=float) for c in classes], axis=0)
y = np.concatenate([np.full(dicts[c].shape[0], i, dtype=int) for i, c in enumerate(classes)])

# Select rows by label rather than by a hard-coded stride of 80 rows per class,
# so the selection stays correct whatever the class sizes are.
class_rows = [np.flatnonzero(y == c) for c in range(n_plot_classes)]

train_rows = np.concatenate([rows[:moreplot] for rows in class_rows])
train_X = X[train_rows]
train_y = y[train_rows]  # integer labels, as np.bincount in FaissKNeighbors requires

# NOTE(review): test_offset < moreplot, so each query is also in train_X and its own
# vector is retrieved as its nearest neighbour (k=1). Confirm this is intended.
test_index = np.array([rows[test_offset] for rows in class_rows])
test_X = X[test_index]
test_y = y[test_index]

In [None]:
# One panel per query sample. Column j shows the label of the query's (j+1)-th
# nearest reference sample, so reading left to right shows how the neighbourhood's
# label composition changes as k grows.
print(" ===> Get kNN predictions and votes")

k = train_X.shape[0]  # use every reference sample as a neighbour

knn_model = FaissKNeighbors(k=k)
knn_model.fit(train_X, train_y)

cm = "tab10"
cm_num = int(np.max(train_y)) + 1  # number of classes; pins the colour scale to 0..cm_num-1
n_queries = len(test_X)

print("make heatmap")
# Grid width follows the number of queries instead of a hard-coded 5.
fig, axes = plt.subplots(1, n_queries, figsize=(20, 4), squeeze=False)
for i, (ax, one_test_X) in enumerate(zip(axes[0], test_X)):
    v = knn_model.get_votes(one_test_X[np.newaxis, :])  # shape (1, k)
    a = ax.pcolor(v, cmap=cm, vmin=0, vmax=cm_num - 1)
    ax.invert_yaxis()
    ax.yaxis.set_visible(False)
    ax.set_xlabel("k")

fig.tight_layout()
# Squash each panel into a thin horizontal strip.
fig.subplots_adjust(
    top=0.65,
    bottom=0.5)
fig.patch.set_alpha(0)  # transparent background for the saved image
fig.savefig('./kSpace_single_smaple.png')
plt.show()

In [None]:
# All queries stacked into one panel: row i is test_X[i], and column j is the label of
# its (j+1)-th nearest reference sample.
print(" ===> Get kNN predictions and votes")

k = train_X.shape[0]  # use every reference sample as a neighbour

knn_model = FaissKNeighbors(k=k)
knn_model.fit(train_X, train_y)

# A single batched FAISS search replaces one search plus one np.append per query;
# row order matches test_X.
v_all = knn_model.get_votes(test_X)

print("make heatmap")
cm = "tab10"
cm_num = int(np.max(train_y)) + 1  # number of classes; pins the colour scale to 0..cm_num-1

fig, ax = plt.subplots(figsize=(12, 4))
a = ax.pcolor(v_all, cmap=cm, vmin=0, vmax=cm_num - 1)
ax.invert_yaxis()
ax.yaxis.set_visible(False)
ax.set_xlabel("k")

fig.tight_layout()
fig.patch.set_alpha(0)  # transparent background for the saved image
fig.savefig('./kSpace_aggregated_smaple.png')
plt.show()