In [None]:
import numpy as np
import copy

from lut import *

## Toy example from the paper

In [None]:
# Seven 3-bit training samples, written as bit strings (one row of X each),
# and their binary labels y.
X = np.array(
    [[int(bit) for bit in row] for row in ("000", "000", "000", "001", "100", "110", "110")],
    dtype=bool,
)

y = np.array([int(label) for label in "0111001"], dtype=bool)

In [None]:
# First LUT of the toy example, trained on columns 0 and 1 of X.
# NOTE(review): presumably Lut(2) is a 2-input lookup table and `cols` selects
# which columns of X feed it — confirm against lut.py.
# NOTE(review): train() receives no labels (y is never passed) — confirm that
# Lut.train does not need them.
lut_0 = Lut(2)
lut_0.train(X, cols=[0, 1])
lut_0

In [None]:
# Second LUT of the toy example, trained on columns 0 and 2 of X.
# NOTE(review): as for lut_0, train() receives no labels y — confirm intended.
lut_1 = Lut(2)
lut_1.train(X, cols=[0, 2])
lut_1

In [None]:
# Build the training set for the next LUT level from lut_0 and lut_1.
# NOTE(review): presumably each column of new_X is one LUT's output on X —
# confirm in lut.py (training_set_from_luts).
new_X = training_set_from_luts([lut_0, lut_1], X)
new_X

In [None]:
# Next-level LUT trained on new_X (no `cols` given).
# NOTE(review): the name skips lut_2, and train() again receives no labels y —
# confirm both are intended.
lut_3 = Lut(2)
lut_3.train(new_X)
lut_3

In [None]:
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

# Colours of matplotlib's default property cycle, in cycle order.
color_list = plt.rcParams["axes.prop_cycle"].by_key()["color"]

from matplotlib.ticker import MaxNLocator

## MNIST

In [None]:
# Load MNIST (OpenML "mnist_784"), caching the raw arrays in MNIST.npz so the
# download only happens on the first run.
MNIST_CACHE = "MNIST.npz"
try:
    # `with` closes the underlying .npz (zip) file once the arrays are read.
    # NOTE(review): allow_pickle=True is only needed for object arrays; the
    # cache written below holds numeric arrays, so it can likely be dropped.
    with np.load(MNIST_CACHE, allow_pickle=True) as data:
        X_ = data["X"]
        y_ = data["y"]
except FileNotFoundError:
    from sklearn.datasets import fetch_openml

    X_, y_ = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
    y_ = np.array([int(label) for label in y_])  # store labels as ints
    np.savez(MNIST_CACHE, X=X_, y=y_)

# Binarise the MNIST pixels: rescale every column to [0, 1], threshold at 0.5.
# Scale X_ (the MNIST pixels loaded above), not the toy X from the first section.
# The results get their own names so the toy X / y used by the cells below
# (which enumerate all 2**X.shape[1] bit patterns) are left untouched.
scaler = MinMaxScaler(feature_range=(0, 1))
X_tf = scaler.fit_transform(X_)
X_mnist = X_tf > 0.5

# Binary target: labels 0-4 -> True, everything else -> False.
y_mnist = np.isin(y_, [0, 1, 2, 3, 4])

In [None]:
# Sample count and input width; the pattern tables below are sized by these.
N, bits = X.shape[0], X.shape[1]

def get_bit_pattern(bits):
    """Return every ``bits``-wide binary pattern, one per row, in counting order.

    Row ``i`` holds the binary representation of ``i`` with the most
    significant bit in column 0, so the result has shape ``(2 ** bits, bits)``
    and dtype ``bool``. ``bits == 0`` yields the single empty pattern, with
    shape ``(1, 0)``.
    """
    # Vectorised: shift each row index right by (bits-1, ..., 0) and keep the
    # lowest bit, instead of formatting and parsing a string for every row.
    codes = np.arange(2 ** bits)[:, None]
    shifts = np.arange(bits - 1, -1, -1)
    return ((codes >> shifts) & 1).astype(bool)
        
# The table of all 2**bits patterns, then that table stacked once per sample,
# so row r lines up with row r of np.repeat(X, 2 ** bits, axis=0) below.
bit_pattern = get_bit_pattern(bits)
bit_pattern_tiled = np.tile(bit_pattern, reps=(N, 1))

In [None]:
# Inspect the binary feature matrix used by the cells below.
# NOTE(review): confirm X holds the intended dataset here (toy data vs. MNIST).
X

In [None]:
# Toy labels re-encoded as -2 (False) / +2 (True), so that summing them per
# bit pattern with np.bincount(weights=y) in get_lut gives a signed vote.
y = np.where(np.array([0, 1, 1, 1, 0, 0, 1], dtype=bool), 2, -2)
y

In [None]:
def get_lut(X, y, bit_pattern_tiled, N, bits, rng=None):
    """Fit a lookup table over all 2**bits input patterns by signed majority vote.

    Parameters
    ----------
    X : array of shape (N, bits)
        Binary input samples.
    y : array of shape (N,)
        Signed per-sample votes: negative for class False, positive for class
        True (this notebook encodes labels as -2 / +2).
    bit_pattern_tiled : array of shape (N * 2**bits, bits)
        The table of all 2**bits patterns stacked N times, e.g.
        ``np.tile(patterns, (N, 1))``.
    N : int
        Number of samples (rows of X).
    bits : int
        Number of input bits (columns of X).
    rng : numpy.random.Generator, optional
        Randomness source for tie-breaking. Defaults to the global
        ``np.random`` state, as before.

    Returns
    -------
    numpy.ndarray of bool, shape (2**bits,)
        Entry p is True if the votes of the samples matching pattern p sum to
        a positive value, False if negative, and a random bit if they sum to
        zero (a tie, or a pattern that never occurs in X).

    Raises
    ------
    AssertionError
        If X and bit_pattern_tiled have different widths.
    ValueError
        If some row of X does not match exactly one pattern.
    """
    assert X.shape[1] == bit_pattern_tiled.shape[1]
    if rng is None:
        rng = np.random
    n_patterns = 2 ** bits

    # Compare every sample with every pattern: np.repeat places each row of X
    # against one full copy of the pattern table inside bit_pattern_tiled.
    matches = np.all(
        bit_pattern_tiled == np.repeat(X, n_patterns, axis=0), axis=1
    ).reshape((N, n_patterns))
    if not np.all(matches.sum(axis=1) == 1):
        raise ValueError("every row of X must match exactly one bit pattern")
    pattern_idx = matches.argmax(axis=1)

    # Net vote per pattern. Decide on its sign rather than on exact +-2 values,
    # so patterns seen several times (sums like -4 or 6) are voted correctly.
    votes = np.bincount(pattern_idx, weights=y, minlength=n_patterns)
    table = votes > 0
    ties = votes == 0
    table[ties] = rng.choice([0, 1], size=ties.sum()).astype(bool)
    return table

In [None]:
# Fit a LUT over all 2**bits patterns from X and the +-2 labels y.
# Patterns whose votes sum to zero are filled via np.random.choice inside
# get_lut, so re-running this cell can change the output.
get_lut(X, y, bit_pattern_tiled, N, bits)

### Inference

In [None]:
# Inference for one sample: index of the row of bit_pattern equal to X[3]
# (broadcasting compares X[3] against every pattern row).
np.flatnonzero((bit_pattern == X[3]).all(axis=1)).item()

In [None]:
# Pattern index (column of the sample-by-pattern match matrix) for every row of X.
np.nonzero(
    (bit_pattern_tiled == np.repeat(X, 2 ** bits, axis=0))
    .all(axis=1)
    .reshape((N, 2 ** bits))
)[1]