In [None]:
import numpy as np
import copy

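Throughout this notebook, the last column of every dataset holds the binary label $y$; all other columns are binary input features.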

In [None]:
# Toy dataset, one string per row: three binary features followed by the label y (last character).
X = np.array(
    [[ch == "1" for ch in row] for row in ("0000", "0001", "0001", "0011", "1000", "1100", "1101")],
    dtype=bool,
)

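`cnt` stores, for every input bit pattern, how often $y^0$ and $y^1$ occur. For a LUT with $b$ input bits it has shape `[2] * (b + 1)`: the first $b$ dimensions index the bit pattern (for the 3-bit toy data above, e.g. `cnt[0, 1, 0]` for the pattern `010`), and the last dimension holds the counts, with `cnt[..., 0]` the count of $y^0$ and `cnt[..., 1]` the count of $y^1$.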

In [None]:
def count_recursive(cnt, x):
    """
    Increase by one the count stored in ``cnt`` for a single example ``x``.

    ``cnt`` has shape ``[2] * (bits + 1)``. The first ``bits`` dimensions index the
    bit pattern (all but the last entry of ``x``). The last dimension (size 2)
    selects y0 (index 0) or y1 (index 1), taken from the last entry of ``x``.
    Despite its historical name, the update is done with a single tuple index,
    and ``x`` is no longer modified.

    Parameters
    ==========
    cnt: np.ndarray
        Array that stores counts for a given dataset; updated in place.
    x: sequence
        Single training example (list, tuple or 1-D array of 0/1 or bool values).
        The last entry is for y. Left unchanged.

    Raises
    ======
    ValueError
        If ``x`` is empty.
    """
    if len(x) == 0:
        raise ValueError("x must contain at least the label entry")
    # Indexing with a tuple walks all dimensions at once and, unlike popping
    # entries off ``x``, leaves the caller's sequence intact.
    cnt[tuple(int(v) for v in x)] += 1

In [None]:
class Lut:
    """
    Binary lookup-table (LUT) classifier.

    A LUT with ``bits`` inputs stores one boolean output for each of the
    ``2 ** bits`` input bit patterns. Training counts how often y=0 and y=1 occur
    for every pattern and stores the majority label. Ties, including patterns never
    seen during training (0 vs. 0), are broken randomly and flagged in ``rnd``.

    Attributes
    ==========
    bits: int
        Number of input bits.
    lut: numpy.ndarray or None
        Boolean array of shape ``[2] * bits`` with the learned outputs; None until trained.
    rnd: numpy.ndarray or None
        Boolean array with the same shape as ``lut``; True where the entry was a random tie-break.
    cnt: numpy.ndarray or None
        uint32 array of shape ``[2] * (bits + 1)``; ``cnt[pattern + (y,)]`` is the number
        of training examples with that pattern and label.
    cols: list or None
        Column indices of the data the LUT reads, in bit order (the first column is the
        most significant bit of the flat pattern index used by ``__repr__``).
    """

    def __init__(self, bits):
        self.bits = bits
        self.lut = None
        self.rnd = None
        self.cnt = None
        self.cols = None

    def train(self, training_set, cols=None, rng=None):
        """
        Train the lut given a training set.

        Parameters
        ==========
        training_set: numpy.ndarray
            Array of shape (N, M) with 0/1 (or bool) entries. The last column is for y.
        cols: sequence of int, optional
            Indices of the ``self.bits`` feature columns to train on. Any sequence
            (list, tuple, array) is accepted. If omitted, all but the last column are
            used, and then M must equal ``self.bits + 1``.
        rng: numpy.random.Generator, optional
            Source of randomness for tie-breaks. Defaults to the legacy global
            ``np.random`` state, as before.

        Raises
        ======
        AssertionError
            If the lut is already trained or the number of feature columns does not
            match ``self.bits``.
        """
        assert self.lut is None, "Lut is already trained!"
        training_set = np.asarray(training_set)
        if cols is None:
            cols = range(training_set.shape[1] - 1)
        cols = list(cols)  # own copy; also makes tuples/arrays work
        assert len(cols) == self.bits, f"Number of selected columns has to match bit size"
        self.cols = cols
        training_set_ = training_set[:, cols + [-1]]

        # np.add.at accumulates correctly even when the same pattern occurs many times.
        cnt = np.zeros([2] * (self.bits + 1), dtype=np.uint32)
        np.add.at(cnt, tuple(training_set_.astype(np.intp).T), 1)

        # Row i of ``counts`` holds (#y0, #y1) for the pattern whose C-order flat index is i.
        counts = cnt.reshape((-1, 2))
        n0, n1 = counts[:, 0], counts[:, 1]
        lut_flat = n1 > n0
        rnd_flat = n0 == n1
        n_ties = int(rnd_flat.sum())
        if n_ties:
            source = np.random if rng is None else rng
            lut_flat[rnd_flat] = source.choice([False, True], size=n_ties)

        shape = [2] * self.bits
        self.lut = lut_flat.reshape(shape)
        self.rnd = rnd_flat.reshape(shape)
        self.cnt = cnt

    def predict(self, data_set):
        """
        Predict the label for every row of ``data_set``.

        Parameters
        ==========
        data_set: numpy.ndarray
            Array of shape (N, M) with 0/1 (or bool) entries; only ``self.cols`` are read.

        Returns
        =======
        numpy.ndarray
            Boolean array of shape (N,).
        """
        assert self.lut is not None, f"Lut not trained yet!"
        data_set = np.asarray(data_set)
        bit_idx = data_set[:, self.cols].astype(np.intp)
        # One fancy-index per bit dimension looks up all rows at once.
        return np.asarray(self.lut[tuple(bit_idx.T)], dtype=bool)

    def __repr__(self):
        if self.lut is None:
            return f"Empty lut, {self.bits} bits"
        flat_lut = self.lut.ravel()
        flat_rnd = self.rnd.ravel()
        lines = []
        for i in range(2 ** self.bits):
            marker = "*" if flat_rnd[i] else ""  # '*' marks a random tie-break
            lines.append(f"{i:0{self.bits}b}  {int(flat_lut[i])}{marker}\n")
        return "".join(lines)

    def __getitem__(self, idx):
        return self.lut[idx]

    def get_cnt_table(self):
        """
        Returns a string of the count table for the lut: one line per bit pattern
        with the counts of y0 and y1.
        """
        assert self.cnt is not None, "Lut not trained yet, no counts to return"
        counts = self.cnt.reshape((-1, 2))
        return "".join(
            f"{i:0{self.bits}b}  {counts[i, 0]} {counts[i, 1]}\n" for i in range(2 ** self.bits)
        )

In [None]:
# 2-bit LUT over feature columns 0 and 1 of the toy dataset
lut_0 = Lut(bits=2)
lut_0.train(X, cols=[0, 1])
lut_0

In [None]:
# 2-bit LUT over feature columns 0 and 2 of the toy dataset
lut_1 = Lut(bits=2)
lut_1.train(X, cols=[0, 2])
lut_1

In [None]:
def classify_single_training_example(lut, cols, training_example):
    """
    Given a single training example and list of column indices that the lut operates on,
    classify that bit pattern.

    The lut is indexed one bit at a time (``lut[b0][b1]...``) using the values of
    ``training_example`` at ``cols``, in order.

    Parameters
    ==========
    lut: Lut or numpy.ndarray
        Lookup table object (anything indexable bit by bit, e.g. ``Lut`` or a nested bool array).
    cols: list
        List of columns that the lut operates on. Not modified.
    training_example: array-like
        Single row from a training set.

    Returns
    =======
    The lut entry for the selected bit pattern (a numpy bool for a trained ``Lut``).

    Raises
    ======
    ValueError
        If ``cols`` is empty.
    """
    if len(cols) == 0:
        raise ValueError("cols must select at least one column")
    entry = lut
    # Walk down one table dimension per selected column, without consuming ``cols``.
    for col in cols:
        entry = entry[int(training_example[col])]
    return entry

def training_set_from_luts(luts, orig_training_set):
    """
    Build a new training set from the outputs of already-trained luts.

    Column ``j`` of the result holds what ``luts[j]`` yields for each row of
    ``orig_training_set``; the last column is a copy of the original labels y.

    Parameters
    ==========
    luts: array-like
        List of trained luts (each providing ``cols`` and bit-by-bit indexing).
    orig_training_set: np.ndarray
        Original training_set where the luts were trained on; its last column is y.

    Returns
    =======
    np.ndarray
        Boolean array of shape (orig_training_set.shape[0], len(luts) + 1).
    """
    n_rows = orig_training_set.shape[0]
    result = np.zeros((n_rows, len(luts) + 1), dtype=bool)
    result[:, -1] = orig_training_set[:, -1]
    for j, lut in enumerate(luts):
        for i, row in enumerate(orig_training_set):
            # The classifier may consume its cols list, so hand it a private copy.
            result[i, j] = classify_single_training_example(lut, copy.deepcopy(lut.cols), list(row))
    return result

In [None]:
# Features of new_X are the outputs of lut_0 and lut_1; the label column is copied from X.
new_X = training_set_from_luts(luts=[lut_0, lut_1], orig_training_set=X)
new_X

In [None]:
# Second-layer LUT trained on the outputs of lut_0 and lut_1
lut_3 = Lut(bits=2)
lut_3.train(training_set=new_X)
lut_3

## Iris Dataset (differentiating class 1 from class 2; class 0 is dropped)

In [None]:
from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

In [None]:
# Bunch with .data (150 x 4 feature matrix) and .target (class ids 0, 1, 2)
iris = datasets.load_iris(as_frame=False)

In [None]:
scaler = MinMaxScaler(feature_range=(0, 1))
X_iris_scaled = scaler.fit_transform(iris.data[iris.target != 0])

# Binary label: True for class 2, False for class 1 (class 0 was dropped above).
target = (iris.target[iris.target != 0] == 2).reshape(-1, 1)
# Binarize every scaled feature at the midpoint of its range and append the label column.
X_iris = np.hstack((X_iris_scaled > 0.5, target))
X_iris.shape

In [None]:
X_train, X_test = train_test_split(X_iris, random_state=42, test_size=0.33)  # 2/3 train, 1/3 test

In [None]:
# One 4-bit LUT over all four binarized iris features
lut = Lut(bits=4)
lut.train(training_set=X_train)
lut

In [None]:
preds = lut.predict(X_train)
accuracy_score(y_true=X_train[:, -1], y_pred=preds)

In [None]:
preds = lut.predict(X_test)
accuracy_score(y_true=X_test[:, -1], y_pred=preds)

## Custom dataset

In [None]:
import matplotlib.pyplot as plt

In [None]:
num_examples = 1000
dist = 4
# Seeded generator so the clusters (and every result derived from them) are reproducible.
rng = np.random.default_rng(42)

# Cluster "a": centred at (-dist/2, -dist/2), label 0.
ax1 = rng.normal(loc=-dist / 2, scale=1.0, size=num_examples)
ax2 = rng.normal(loc=-dist / 2, scale=1.0, size=num_examples)
ay = np.zeros((num_examples,), dtype=int)
# Binarize each coordinate at 0 and append the label as the last column.
a = np.column_stack(((ax1 > 0.0).astype(int), (ax2 > 0.0).astype(int), ay))

# Cluster "b": centred at (+dist/2, +dist/2), label 1.
bx1 = rng.normal(loc=dist / 2, scale=1.0, size=num_examples)
bx2 = rng.normal(loc=dist / 2, scale=1.0, size=num_examples)
by = np.ones((num_examples,), dtype=int)
b = np.column_stack(((bx1 > 0.0).astype(int), (bx2 > 0.0).astype(int), by))

fig, ax = plt.subplots(1, 1)

ax.scatter(ax1, ax2, label="a")
ax.scatter(bx1, bx2, label="b", alpha=0.5)
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_title("Two Gaussian clusters (binarized at 0 for the LUT)")
ax.legend();

In [None]:
X_ab = np.concatenate((a, b), axis=0)

X_train, X_test = train_test_split(X_ab, shuffle=True, random_state=42, test_size=0.33)

# One bit per binarized coordinate
lut = Lut(bits=2)
lut.train(training_set=X_train)
lut

In [None]:
preds = lut.predict(X_train)
accuracy_score(y_true=X_train[:, -1], y_pred=preds)

In [None]:
preds = lut.predict(X_test)
accuracy_score(y_true=X_test[:, -1], y_pred=preds)

Here the two clusters lie symmetrically around 0, so the criterion for binarizing the dataset ($>0$) is well chosen. With a badly chosen criterion (e.g. a threshold far from the midpoint between the clusters), the performance of the LUTs can drop significantly.

## MNIST

In [None]:
from pathlib import Path

# Local cache of MNIST; downloaded once from OpenML (mnist_784, version 1) if missing,
# so the notebook also runs on a fresh machine.
MNIST_PATH = Path("MNIST.npz")
if not MNIST_PATH.exists():
    from sklearn.datasets import fetch_openml

    X_download, y_download = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
    # Labels are converted to ints before caching.
    np.savez(MNIST_PATH, X=X_download, y=np.array([int(label) for label in y_download]))

# Both cached arrays are numeric, so pickle support is not needed (and not enabled);
# the context manager closes the underlying .npz file handle.
with np.load(MNIST_PATH) as data:
    X = data["X"]
    y = data["y"]

In [None]:
%%time
bits = 2

pca = PCA(n_components=bits)
X_pca = pca.fit_transform(X)

# fig, ax = plt.subplots(1, 1)
# X_back = pca.inverse_transform(X_pca)
# ax.imshow(X_back[10].reshape((28,28)), cmap="gray")

scaler = MinMaxScaler(feature_range=(0, 1))
X_tf = scaler.fit_transform(X_pca)

X_mnist = np.hstack(
    (
        (X_tf > 0.5).astype(bool),
        ((y == 0) | (y == 1) | (y == 2) | (y == 3) | (y == 4)).astype(bool)[:, None],
    )
)

X_train, X_test = train_test_split(X_mnist, test_size=0.33, random_state=42, shuffle=True)

lut = Lut(bits)
lut.train(X_train)

preds = lut.predict(X_train)
print(f"Accuracy on training set: {accuracy_score(preds, X_train[:, -1]):.2f}%")

preds = lut.predict(X_test)
print(f"Accuracy on test set: {accuracy_score(preds, X_test[:, -1]):.2f}%")

print(f"{lut.rnd.sum() / (~lut.rnd).sum():.2f}% of lut entries are random")

In [None]:
from tqdm.notebook import tqdm

In [None]:
bit_arr = list(range(2, 21))
train_arr = []
test_arr = []
rnd_arr = []

# Everything below is independent of the bit count, so it is computed once, outside the loop.
y_mnist = np.isin(y, [0, 1, 2, 3, 4])  # digits 0-4 vs. 5-9

# Split raw pixels first; PCA and scaler are fitted on the training part only (no test leakage).
X_raw_train, X_raw_test, y_train, y_test = train_test_split(
    X, y_mnist, test_size=0.33, random_state=42, shuffle=True
)

# With svd_solver="full", PCA computes the complete SVD and keeps the leading components, so
# the first `bits` columns of a max(bit_arr)-component projection are exactly a `bits`-component
# PCA. MinMaxScaler works column-wise, so slicing after scaling/binarizing is equivalent too.
pca = PCA(n_components=max(bit_arr), svd_solver="full")
scaler = MinMaxScaler(feature_range=(0, 1))
bits_train_all = scaler.fit_transform(pca.fit_transform(X_raw_train)) > 0.5
bits_test_all = scaler.transform(pca.transform(X_raw_test)) > 0.5

for bits in tqdm(bit_arr):
    X_train = np.hstack((bits_train_all[:, :bits], y_train[:, None]))
    X_test = np.hstack((bits_test_all[:, :bits], y_test[:, None]))
    lut = Lut(bits)
    lut.train(X_train)
    preds = lut.predict(X_train)
    train_arr.append(accuracy_score(X_train[:, -1], preds))
    preds = lut.predict(X_test)
    test_arr.append(accuracy_score(X_test[:, -1], preds))
    # Percentage (0-100) of lut entries set by a random tie-break, as the plot legend says.
    rnd_arr.append(100 * lut.rnd.mean())

In [None]:
# Colours of matplotlib's default property cycle, in cycle order
color_list = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [None]:
from matplotlib.ticker import MaxNLocator

In [None]:
fig, ax = plt.subplots(1, 1)

ax.plot(bit_arr, train_arr, label="Train")
ax.plot(bit_arr, test_arr, "--", label="Test")
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
ax.set_xlabel("Number of Bits")
ax.set_ylabel("Accuracy")
ax.set_title("Performance of a single lut on 0-4 vs. 5-9 MNIST classification\n(PCA used to reduce dimensions to corresponding bit size)", pad=20)
ax.grid()
ax.legend()

# Secondary y-axis for the random-entry curve. Named ax_rnd (not ax2) so it does not
# overwrite the ax2 sample array created in the custom-dataset section.
ax_rnd = ax.twinx()
ax_rnd.plot(bit_arr, rnd_arr, "-.", label="Percentage of\nlut entries\nrandom", c=color_list[2])
ax_rnd.set_ylabel("Random lut entries")
ax_rnd.legend(bbox_to_anchor=(1.1, 1), loc="upper left");

In [None]:
# Export the accuracy-vs-bits figure; bbox_inches="tight" keeps the outside legend in frame.
fig.savefig(
    "single_lut_performance.jpg",
    bbox_inches="tight",
    dpi=100,
)