In [237]:
import numpy as np 

In [238]:
a = np.array([[np.nan, 1, -1],
             [3, 2, 1],
             [4, 1, 2],
             [np.nan, 2, 3]])

In [239]:
a

array([[nan,  1., -1.],
       [ 3.,  2.,  1.],
       [ 4.,  1.,  2.],
       [nan,  2.,  3.]])

In [240]:
class KNN:
    def __init__(self, k=3):
        self.k = k
        self.train_X = None
        self.train_y = None
    
    def fit(self, X, y):
        self.train_X = X
        self.train_y = y
    
    def predict(self, X):
        if X.shape[0] == 0:
            return X.squeeze()
        d = self._euclid_dist(self.train_X, X)
        knl = self._k_nearest_labels(d, self.train_y)
        print(knl)
        if self.k == 1:
            return knl.squeeze()[:, np.newaxis]
        else:
            return np.array([np.argmax(np.bincount(y.squeeze().astype(np.int64))) for y in knl])
    
    def _euclid_dist(self, X_known, X_unknown):
        sqrt = np.sqrt
        sm = np.sum
        return np.array([sqrt(sm((x - X_unknown) ** 2, axis=1)) for x in X_known]).T
    
    def _k_nearest_labels(self, dists, y_known):
        num_pred = dists.shape[0]
        n_nearest = []
        closest_y = None
        for j in range(num_pred):
            dst = dists[j]
            closest_y = y_known[np.argsort(dst)][:self.k]        
            n_nearest.append(closest_y)
        return np.asarray(n_nearest)

In [241]:
def fill_knn(predictors, target, k=3):
    not_nan_idx = ~np.isnan(target)
    if len(target[~not_nan_idx]) == 0:
        return None
    if len(predictors.shape) == 1:
        predictors = predictors[:, np.newaxis]
    if len(target.shape) == 1:
        target = target[:, np.newaxis]
    X_train, y_train = predictors[not_nan_idx, :], target[not_nan_idx, :]
    knn = KNN(k)
    knn.fit(X_train, y_train)
    X_test = predictors[~not_nan_idx, :]
    print("X_test.shape:", X_test.shape)
    target[~not_nan_idx] = knn.predict(X_test)

In [242]:
fill_knn(a[:, 2], a[:, 0], k=1)

X_test.shape: (2, 1)
X.shape: (2, 1)
d.shape: (2, 2)
KNL shape: (2, 1, 1)
[[[3.]]

 [[4.]]]
KNL SQUEEZE SHAPE: (2,)


In [243]:
a

array([[ 3.,  1., -1.],
       [ 3.,  2.,  1.],
       [ 4.,  1.,  2.],
       [ 4.,  2.,  3.]])