- A drawback of basic "majority voting" classification: when the class distribution is skewed, examples of a more frequent class dominate the prediction for the new example. This is because they tend to be common among the k nearest neighbors due to their large number.
- One way to overcome this problem is to weight the classification, taking into account the distance from the test point to each of its k nearest neighbors.
- That is, the class (or value, in regression problems) of each of the k nearest points is multiplied by a weight proportional to the inverse of the distance from that point to the test point.

# Naive implementation
- Time: $O(Nd)$ where $N$ is the number of points in the set, and $d$ is the dimensionality of each point.
- Space: O(1).

In [2]:
import torch
import utils.plots as plot_utis

def compute_distance(point1, point2):
    squared_distance = ((point1 - point2)**2).sum()
    return torch.sqrt(squared_distance)

def nn_classify(inputs, labels, new_point, k):
    idxs = range(len(inputs))
    key = lambda idx: compute_distance(inputs[idx], new_point)
    idxs = sorted(idxs, key=key)[:k]
    inputs[idxs], labels[idxs]

def nn_weighted_classify(inputs, labels, new_point, k):
    idxs = range(len(inputs))
    distances = torch.tensor([compute_distance(input, new_point) for input in inputs])
    idxs = sorted(idxs, key=lambda idx: distances[idx])[:k]
    inputs[idxs], labels[idxs], distances[idxs]

inputs = torch.randn(50, 2)
labels = torch.randint(0, 3, [50])
new_point = torch.tensor([0, 0])
k = 10

nn_points, nn_labels = nn_classify(inputs, labels, new_point, 10)
# label = max(set(nn_labels), key=nn_labels.count) # works for python lists
label = nn_labels.mode().values.item()

fig = plot_utis.figure(showlegend=False)
fig.add_scatter(x=inputs[:, 0], y=inputs[:, 1],
                mode='markers', marker=dict(size=10, color=labels))
fig.add_annotation(x=0, y=0, text='point', arrowhead=7)
fig.add_scatter(x=nn_points[:, 0], y=nn_points[:, 1], 
                mode='markers', marker=dict(symbol='x-open', size=15, color=nn_labels))
fig.show(renderer='notebook')

TypeError: cannot unpack non-iterable NoneType object