In [None]:
import heapq
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm

In [None]:
def moving_average(a, n=100):
    """Trailing running mean of ``a``.

    For every index ``i`` the mean is taken over ``a[max(0, i - n):i + 1]``:
    the current element plus up to ``n`` elements before it. A full window
    therefore holds ``n + 1`` values, and windows near the start are shorter.

    Args:
        a: Sliceable sequence of numbers that ``np.sum`` accepts.
        n: How many preceding elements are averaged together with the
            current one.

    Returns:
        ``np.ndarray`` holding one averaged value per element of ``a``.
    """
    def window_mean(end):
        # Clamp the left edge so the first windows simply contain fewer items.
        window = a[max(0, end - n):end + 1]
        return np.sum(window) / float(len(window))

    return np.array([window_mean(end) for end in range(len(a))])

In [None]:
class KNNClassifier:
    """k-nearest-neighbour classifier using Hamming distance.

    Samples are kept in insertion order in ``self.data`` as ``(sample, label)``
    pairs. The distance between two samples is the number of positions at
    which they differ (``torch.sum(new_sample != sample)``).
    """

    def __init__(self, k=1):
        """Create an empty classifier that votes among the ``k`` nearest samples."""
        self.k = k
        self.data = []

    def add_sample(self, sample, label):
        """Store ``sample`` with its ``label`` (labels must be hashable)."""
        self.data.append((sample, label))

    def classify(self, new_sample):
        """Return the majority label among the ``k`` stored samples nearest to ``new_sample``.

        Ties in distance are resolved in favour of the sample stored first;
        ties in the vote are resolved in favour of the label of the nearest
        neighbour. Returns ``None`` when no neighbour is available (no samples
        stored yet, or ``k < 1``).
        """
        if not self.data:
            return None

        # Plain Python numbers keep the comparisons cheap and unambiguous.
        scored = (
            (torch.sum(new_sample != sample).item(), label)
            for sample, label in self.data
        )
        # nsmallest is stable, so equally distant samples keep insertion order.
        nearest = heapq.nsmallest(self.k, scored, key=lambda pair: pair[0])
        if not nearest:
            return None

        # Dict insertion order is nearest-first, and max() returns the first
        # maximal key, so a tied vote goes to the closest neighbour's label.
        votes = {}
        for _, label in nearest:
            votes[label] = votes.get(label, 0) + 1
        return max(votes, key=votes.get)

In [None]:
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only open a dataset.pkl that comes from a trusted source.
# The context manager closes the file handle once the pickle has been read.
with open('dataset.pkl', 'rb') as dataset_file:
    stream = pickle.load(dataset_file)['noisy']

In [None]:
# Online evaluation: at every stream position, predict the current item from
# the preceding `window` items, then add that (readout, target) pair to the
# classifier's stored samples.
classifier = KNNClassifier()
window = 10    # number of preceding stream items that form the readout
x = []         # stream positions at which a prediction was scored
correct = []   # 1 where the scored prediction matched the target, else 0
for cur in tqdm(range(window, len(stream))):
    readout = torch.tensor(stream[cur-window:cur])
    target = stream[cur]
    prediction = classifier.classify(readout)

    # A position is scored only when the *next* stream item is >= 10.
    # The last position has no successor, so it is skipped rather than
    # indexing past the end of the stream.
    if cur + 1 < len(stream) and stream[cur+1] >= 10:
        x.append(cur)
        correct.append(int(prediction == target))

    classifier.add_sample(readout, target)

In [None]:
# Moving-average accuracy of the Hamming-distance KNN over the stream.
plt.figure(figsize=(15, 8))
# The label gives plt.legend() an artist to show; without one the legend is empty.
plt.plot(x, moving_average(correct), label='KNN (Hamming distance)')
plt.xlabel('stream position')
plt.ylabel('accuracy (moving average)')
plt.legend()
plt.xticks(np.linspace(0, 20000, 11))
plt.xlim(0, 20000)
plt.ylim(0, 1)
plt.grid()
plt.savefig('knn_hamming.pdf', dpi=150)
plt.show()

In [None]:
class Encoder():
    """Maps hashable symbols to fixed random vectors and back.

    Each new symbol is assigned a vector drawn from ``torch.rand`` the first
    time it is encoded; the assignment is cached in ``self.encodings`` so
    later calls return the same tensor.
    """

    def __init__(self, e_size=25):
        """Create an empty encoder whose vectors have ``e_size`` components."""
        self.encodings = {}
        self.e_size = e_size

    def encode(self, x):
        """Return the vector for ``x``, creating and caching one if needed."""
        if x not in self.encodings:
            # Use the instance's own size, not a module-level `e_size` global,
            # so the encoder works regardless of notebook execution order.
            self.encodings[x] = torch.rand(self.e_size)
        return self.encodings[x]

    def decode(self, v):
        """Return the known symbol whose vector is nearest to ``v``.

        Nearness is Euclidean distance. Returns ``None`` when no symbol has
        been encoded yet.
        """
        nearest = None
        best = float('inf')
        for x, e in self.encodings.items():
            dist = (torch.sum((v - e).pow(2))).pow(0.5)
            if dist < best:
                best = dist
                nearest = x
        return nearest

    def precode(self, stream):
        """Assign vectors to every symbol in ``stream`` ahead of time."""
        for i in stream:
            self.encode(i)

In [None]:
class KNNClassifier:
    """k-nearest-neighbour classifier using L1 (sum of absolute differences) distance.

    Samples are kept in insertion order in ``self.data`` as ``(sample, label)``
    pairs. The distance between two samples is
    ``torch.sum(torch.abs(new_sample - sample))``.
    """

    def __init__(self, k=1):
        """Create an empty classifier that votes among the ``k`` nearest samples."""
        self.k = k
        self.data = []

    def add_sample(self, sample, label):
        """Store ``sample`` with its ``label`` (labels must be hashable)."""
        self.data.append((sample, label))

    def classify(self, new_sample):
        """Return the majority label among the ``k`` stored samples nearest to ``new_sample``.

        Ties in distance are resolved in favour of the sample stored first;
        ties in the vote are resolved in favour of the label of the nearest
        neighbour. Returns ``None`` when no neighbour is available (no samples
        stored yet, or ``k < 1``).
        """
        if not self.data:
            return None

        # Plain Python numbers keep the comparisons cheap and unambiguous.
        scored = (
            (torch.sum(torch.abs(new_sample - sample)).item(), label)
            for sample, label in self.data
        )
        # nsmallest is stable, so equally distant samples keep insertion order.
        nearest = heapq.nsmallest(self.k, scored, key=lambda pair: pair[0])
        if not nearest:
            return None

        # Dict insertion order is nearest-first, and max() returns the first
        # maximal key, so a tied vote goes to the closest neighbour's label.
        votes = {}
        for _, label in nearest:
            votes[label] = votes.get(label, 0) + 1
        return max(votes, key=votes.get)

In [None]:
e_size = 25  # length of each symbol's random encoding vector

# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only open a dataset.pkl that comes from a trusted source.
# The context manager closes the file handle once the pickle has been read.
with open('dataset.pkl', 'rb') as dataset_file:
    stream = pickle.load(dataset_file)['noisy']

# NOTE(review): encodings come from torch.rand with no seed set in this
# notebook, so the vectors (and the resulting curve) differ between runs.
encoder = Encoder(e_size)
# Assign vectors to the symbols 0..9 up front; any other symbol met in the
# stream gets its vector lazily on first encode().
encoder.precode([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [None]:
# Online evaluation on encoded inputs: at every stream position, predict the
# current item from the concatenated encodings of the preceding `window`
# items, then add that (readout, target) pair to the classifier's stored samples.
classifier = KNNClassifier()
window = 10    # number of preceding stream items that form the readout
x = []         # stream positions at which a prediction was scored
correct = []   # 1 where the scored prediction matched the target, else 0
for cur in tqdm(range(window, len(stream))):
    inpt = stream[cur-window:cur]
    # Concatenating the per-symbol vectors sizes the readout from the encoder
    # itself instead of a hard-coded 10*25.
    readout = torch.cat([encoder.encode(s) for s in inpt])

    target = stream[cur]
    prediction = classifier.classify(readout)

    # A position is scored only when the *next* stream item is >= 10.
    # The last position has no successor, so it is skipped rather than
    # indexing past the end of the stream.
    if cur + 1 < len(stream) and stream[cur+1] >= 10:
        x.append(cur)
        correct.append(int(prediction == target))

    classifier.add_sample(readout, target)

In [None]:
# Moving-average accuracy of the KNN that compares encoded readouts.
plt.figure(figsize=(15, 8))
# The label gives plt.legend() an artist to show; without one the legend is empty.
plt.plot(x, moving_average(correct), label='KNN (absolute-error distance)')
plt.xlabel('stream position')
plt.ylabel('accuracy (moving average)')
plt.legend()
plt.xticks(np.linspace(0, 20000, 11))
plt.xlim(0, 20000)
plt.ylim(0, 1)
plt.grid()
plt.savefig('knn_mae.pdf', dpi=150)
plt.show()