In [1]:
import numpy as np

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', parser='auto')

from sklearn.model_selection import train_test_split

# Shuffle-split into 60,000 training and 10,000 test samples (random_state fixes
# the shuffle), then turn every piece into a plain integer NumPy array.
x_train, x_test, y_train, y_test = (
    part.values.astype(int)
    for part in train_test_split(
        mnist.data,
        mnist.target,
        train_size=60000,
        test_size=10000,
        random_state=42,
    )
)

In [2]:
from sklearn.cluster import KMeans

# Number of k-means clusters ("shards"); each shard gets its own small
# per-class prototype classifier further below.
shard_count = 50

# Partition the training images into shards with k-means. random_state makes the
# clustering -- and therefore every shard-level model and reported accuracy --
# reproducible across kernel restarts.
kmeans = KMeans(n_clusters=shard_count, init="k-means++", n_init=1, random_state=42)
kmeans.fit(x_train)

# Shard id of every sample = index of its nearest k-means centroid.
shards_train = kmeans.predict(x_train)
shards_test = kmeans.predict(x_test)

In [3]:
# Flattened MNIST image size (28 x 28 pixels) and number of digit classes.
input_count, output_count = 28 * 28, 10

from numba import njit, prange

@njit(fastmath=True)
def fit(shard, S, T, max_epochs=10000):
    """Train the class prototypes of a single shard (diff classifier).

    The prototype of class c is M[c] = S[c] / T[c] (running sum divided by a
    count). Each training sample of the shard is classified by its nearest
    prototype (squared Euclidean distance). On a miss, the sample is added to
    the true class (sum and count), and half of it is subtracted from the
    wrongly predicted class's sum, clipped at 0. Passes over the shard repeat
    until one full pass has no misses, or ``max_epochs`` passes have run.

    Parameters
    ----------
    shard : int
        Shard id; selects the training samples with ``shards_train == shard``.
    S : 2-D float array, shape (output_count, input_count) as passed by fit_all
        Per-class running sums, updated in place.
    T : 2-D float array, shape (output_count, 1) as passed by fit_all
        Per-class counts, updated in place.
    max_epochs : int, optional
        Upper bound on the number of passes (default 10000, the previously
        hard-coded limit).

    Returns
    -------
    int
        Number of misses in the last pass: 0 when a pass ran without misses,
        positive when ``max_epochs`` was exhausted before converging.
    """
    # NOTE: x_train, y_train and shards_train are read as module globals.
    # Numba freezes globals as compile-time constants, so this cell must be
    # re-run after those arrays change.

    # Select this shard's samples once, instead of scanning the whole training
    # set on every epoch. Ascending index order keeps the original pass order.
    members = np.nonzero(shards_train == shard)[0]

    # Keep the prototypes in sync incrementally: a miss only changes rows y and
    # p of S/T, so recomputing the full S / T for every sample is unnecessary.
    M = S / T
    misses = 0
    for _ in range(max_epochs):
        misses = 0
        for i in members:
            X, y = x_train[i], y_train[i]
            p = ((M - X)**2).sum(axis=1).argmin()
            if p != y:
                S[y] += X
                T[y] += 1
                S[p] = (S[p] - X / 2).clip(0)
                M[y] = S[y] / T[y]
                M[p] = S[p] / T[p]
                misses += 1
        if misses == 0:
            break
    return misses

# Per-shard, per-class running sums of the training samples (start empty).
S = np.zeros((shard_count, output_count, input_count), dtype=np.float64)
# Per-shard, per-class counts; starting at 1 (and only ever incremented) keeps
# S / T free of division by zero.
T = np.ones(S.shape[:2] + (1,), dtype=np.float64)

@njit(fastmath=True, parallel=True)
def fit_all(S, T):
    """Fit every shard independently, distributing shards across threads.

    S[k] and T[k] are the sum / count slices of shard k; each prange iteration
    writes only its own slice, so parallel iterations never overlap.
    """
    for k in prange(shard_count):
        shard_sums = S[k]
        shard_counts = T[k]
        fit(k, shard_sums, shard_counts)


# Train all shards in place; afterwards S and T hold the learned sums / counts.
fit_all(S, T)

In [4]:
# Class prototypes of every shard: sum of assigned samples divided by count.
M = S / T

# A sample counts as a hit when the nearest prototype of its (k-means) shard
# carries its true label.
total_test = sum(
    1
    for X, y, shard in zip(x_test, y_test, shards_test)
    if y == ((M[shard] - X)**2).sum(axis=1).argmin()
)
total_train = sum(
    1
    for X, y, shard in zip(x_train, y_train, shards_train)
    if y == ((M[shard] - X)**2).sum(axis=1).argmin()
)

print("accuracy train: %f; accuracy test: %f" % (total_train/len(x_train), total_test/len(x_test)))

accuracy train: 1.000000; accuracy test: 0.947900


In [5]:
from sklearn.neighbors import KDTree

# Route samples to shards with a KD-tree over the k-means centroids. This is
# the same nearest-centroid rule kmeans.predict applies, so the accuracies
# should match the previous cell.
tree = KDTree(kmeans.cluster_centers_, leaf_size=2)

# One batched query per split instead of one sklearn call per sample; with the
# default k=1, query() returns an (n_samples, 1) array of centroid indices.
train_shard_ids = tree.query(x_train, return_distance=False)[:, 0]
test_shard_ids = tree.query(x_test, return_distance=False)[:, 0]

total_train = 0
for X, y, shard in zip(x_train, y_train, train_shard_ids):
    p = ((M[shard] - X)**2).sum(axis=1).argmin()
    if p == y:
        total_train += 1

total_test = 0
for X, y, shard in zip(x_test, y_test, test_shard_ids):
    p = ((M[shard] - X)**2).sum(axis=1).argmin()
    if p == y:
        total_test += 1

print("accuracy train: %f; accuracy test: %f" % (total_train/len(x_train), total_test/len(x_test)))


accuracy train: 1.000000; accuracy test: 0.947900


In [6]:
from sklearn.neighbors import BallTree

# Same shard routing as before, but through a BallTree over the k-means
# centroids; later cells reuse this `tree`.
tree = BallTree(kmeans.cluster_centers_, leaf_size=2)

# One batched query per split instead of one sklearn call per sample; with the
# default k=1, query() returns an (n_samples, 1) array of centroid indices.
train_shard_ids = tree.query(x_train, return_distance=False)[:, 0]
test_shard_ids = tree.query(x_test, return_distance=False)[:, 0]

total_train = 0
for X, y, shard in zip(x_train, y_train, train_shard_ids):
    p = ((M[shard] - X)**2).sum(axis=1).argmin()
    if p == y:
        total_train += 1

total_test = 0
for X, y, shard in zip(x_test, y_test, test_shard_ids):
    p = ((M[shard] - X)**2).sum(axis=1).argmin()
    if p == y:
        total_test += 1

print("accuracy train: %f; accuracy test: %f" % (total_train/len(x_train), total_test/len(x_test)))

accuracy train: 1.000000; accuracy test: 0.947900


In [7]:
# Robustness probe: perturb every pixel by -1 or +1 with equal probability.
# Seeded so the reported noisy accuracies are reproducible.
# NOTE: x_train / x_test are modified IN PLACE -- later cells evaluate on the
# noisy data, and re-running this cell stacks a second round of noise.
# Pixel values can leave their original range (no clipping is applied).
noise_rng = np.random.default_rng(42)
x_train += noise_rng.choice([-1, 1], size=x_train.shape)
x_test += noise_rng.choice([-1, 1], size=x_test.shape)

# Route the noisy samples with the centroid `tree` from the previous cell,
# using one batched query per split.
train_shard_ids = tree.query(x_train, return_distance=False)[:, 0]
test_shard_ids = tree.query(x_test, return_distance=False)[:, 0]

total_train = 0
for X, y, shard in zip(x_train, y_train, train_shard_ids):
    p = ((M[shard] - X)**2).sum(axis=1).argmin()
    if p == y:
        total_train += 1

total_test = 0
for X, y, shard in zip(x_test, y_test, test_shard_ids):
    p = ((M[shard] - X)**2).sum(axis=1).argmin()
    if p == y:
        total_test += 1

print("accuracy train: %f; accuracy test: %f" % (total_train/len(x_train), total_test/len(x_test)))

accuracy train: 0.999867; accuracy test: 0.948100


In [8]:
# One BallTree per shard over that shard's class prototypes M[shard]; row c of
# M[shard] belongs to class c, so the nearest-row index is the predicted digit.
subtree = [BallTree(M[shard], leaf_size=2) for shard in range(shard_count)]


def predict_with_trees(x):
    """Predict a digit for every row of ``x`` with the two-level tree lookup.

    Rows are routed to a shard by the nearest k-means centroid (``tree``,
    built in an earlier cell), then classified by the nearest class prototype
    in that shard's BallTree. Rows are grouped by shard so each tree is queried
    once per call rather than once per sample.

    Returns an integer array of predicted labels, one per row of ``x``.
    """
    shard_ids = tree.query(x, return_distance=False)[:, 0]
    predictions = np.empty(len(x), dtype=np.intp)
    for shard in np.unique(shard_ids):
        in_shard = shard_ids == shard
        predictions[in_shard] = subtree[shard].query(x[in_shard], return_distance=False)[:, 0]
    return predictions


total_train = int((predict_with_trees(x_train) == y_train).sum())
total_test = int((predict_with_trees(x_test) == y_test).sum())

print("accuracy train: %f; accuracy test: %f" % (total_train/len(x_train), total_test/len(x_test)))

accuracy train: 0.999867; accuracy test: 0.948100
