In [None]:
import numpy as np
import copy

from sklearn import datasets
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

color_list = [x["color"] for x in plt.rcParams["axes.prop_cycle"]]

from matplotlib.ticker import MaxNLocator

import multiprocessing
import os

## Toy example from paper

In [None]:
# Toy dataset: seven examples with three boolean features each. Repeated rows
# are spelled out through list repetition.
X = np.array(
    [[0, 0, 0]] * 3 + [[0, 0, 1], [1, 0, 0]] + [[1, 1, 0]] * 2,
    dtype=bool,
)

# One boolean target per row of X.
y = np.array([False, True, True, True, False, False, True])

## Vectorized implementation

In [None]:
def get_idxs(X, bit_pattern_tiled, N, bits):
    """
    Get indexes of bit pattern. Each row of X corresponds to one bit pattern, e.g.
    `[0, 0, 1]`. The first entry of bit_pattern is `[0, 0, 0]`. So this function
    would return index 1 for `[0, 0, 1]`.

    Parameters
    ==========
    X: np.ndarray
        Dataset of shape (N, bits) and dtype bool.
    bit_pattern_tiled: np.ndarray
        Tiled bit pattern: `np.tile(bit_pattern, (N, 1))`.
    N: int
        Number of examples, i.e. `X.shape[0]`.
    bits:
        Number of bits, i.e. `X.shape[1]`.

    Returns
    =======
    np.ndarray
        Pattern index of every (row, pattern) match, ordered by row. When the
        patterns are unique and every row of X equals one of them, this is
        exactly one index per row, i.e. an array of length N.
    """
    # View the tiled patterns as (N, 2 ** bits, bits) and let broadcasting compare
    # every row of X with every pattern. This avoids materialising
    # `np.repeat(X, 2 ** bits, axis=0)`, a full-size copy of the data.
    patterns = bit_pattern_tiled.reshape(N, 2 ** bits, bits)
    matches = np.all(patterns == np.asarray(X)[:, np.newaxis, :], axis=2)
    # Column 1 of the nonzero coordinates is the pattern index of each match.
    return np.nonzero(matches)[1]

def get_lut(indexes, labels, bits):
    """
    Fill one lookup table from the examples that fall into each of its entries.

    Parameters
    ==========
    indexes: np.ndarray
        Table entry hit by each example (non-negative integers).
    labels: np.ndarray
        Weight contributed by each example; the weights are summed per entry.
    bits: int
        Number of input bits of the table; it has at least `2 ** bits` entries.

    Returns
    =======
    np.ndarray
        Boolean table. An entry is True where the summed weight is positive and
        False where it is negative. Entries whose sum is exactly zero (including
        entries no example falls into) are drawn at random from {0, 1}.
    """
    votes = np.bincount(indexes, weights=labels, minlength=2 ** bits)
    undecided = votes == 0
    table = votes > 0
    # A single draw covering all undecided entries, assigned in index order.
    table[undecided] = np.random.choice([0, 1], size=undecided.sum())
    return table

def get_bit_pattern(bits):
    """
    Enumerate all bit patterns of a given width.

    Parameters
    ==========
    bits: int
        Width of the patterns; must be non-negative.

    Returns
    =======
    np.ndarray
        Boolean array of shape (2 ** bits, bits). Row `i` holds the binary
        representation of `i`, most significant bit first, padded with leading
        zeros, e.g. row 1 of the 3-bit pattern is `[0, 0, 1]`.

    Raises
    ======
    ValueError
        If `bits` is negative.
    """
    if bits < 0:
        raise ValueError("bits must be non-negative")
    values = np.arange(2 ** bits)[:, np.newaxis]
    # Shift amounts from the most significant bit down to bit 0, so that column j
    # holds bit (bits - 1 - j) of the row number.
    shifts = np.arange(bits - 1, -1, -1)
    return ((values >> shifts) & 1).astype(bool)

#### Network of luts

In [None]:
# Read the arrays stored under the keys "X" and "y" of the local MNIST.npz archive.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code. Only use it with an archive from a trusted source.
data = np.load("MNIST.npz", allow_pickle=True)
X_ = data["X"]
y_ = data["y"]

# Rescale every feature column to the range [0, 1].
scaler = MinMaxScaler(feature_range=(0, 1))
X_tf = scaler.fit_transform(X_)

In [None]:
%%time
# --- Data -----------------------------------------------------------------------
# Use the first `num_examples` samples, binarised at 0.5.
num_examples = 2000
X = (X_tf > 0.5).astype(bool)[:num_examples]
# Boolean targets: True where the label is one of the digits 0-4.
y_bool = ((y_ == 0) | (y_ == 1) | (y_ == 2) | (y_ == 3) | (y_ == 4)).astype(bool)[:num_examples]
# Signed targets (+1 / -1): get_lut sums them per table entry, so the sign of the
# sum is the majority vote of the examples that hit the entry.
y = np.where(y_bool, 1, -1)

# --- Shared bit patterns --------------------------------------------------------
bits = 8
N = X.shape[0]
bit_pattern = get_bit_pattern(bits)
bit_pattern_tiled = np.tile(bit_pattern, (N, 1))

# --- Hidden layer 1: every LUT reads `bits` random input columns -----------------
num_luts_1 = 1024
cols_arr_1 = np.zeros((num_luts_1, bits), dtype=int)
lut_arr_1 = np.zeros((num_luts_1, 2 ** bits), dtype=bool)
idxs_arr_1 = np.zeros((num_luts_1, N), dtype=int)
X_1 = np.zeros((N, num_luts_1), dtype=bool)

for i in tqdm(range(num_luts_1)):
    # np.random.choice samples with replacement by default, so a LUT may read the
    # same column more than once.
    cols_arr_1[i] = np.random.choice(range(X.shape[1]), size=bits)
    idxs_arr_1[i] = get_idxs(X[:, cols_arr_1[i]], bit_pattern_tiled, N, bits)
    lut_arr_1[i] = get_lut(idxs_arr_1[i], y, bits)
    # Output of LUT i for every example; becomes column i of the next layer's input.
    X_1[:, i] = lut_arr_1[i][idxs_arr_1[i]]

# --- Hidden layer 2: same procedure on the outputs of layer 1 -------------------
num_luts_2 = 1024
cols_arr_2 = np.zeros((num_luts_2, bits), dtype=int)
lut_arr_2 = np.zeros((num_luts_2, 2 ** bits), dtype=bool)
idxs_arr_2 = np.zeros((num_luts_2, N), dtype=int)
X_2 = np.zeros((N, num_luts_2), dtype=bool)

for i in tqdm(range(num_luts_2)):
    cols_arr_2[i] = np.random.choice(range(num_luts_1), size=bits)
    idxs_arr_2[i] = get_idxs(X_1[:, cols_arr_2[i]], bit_pattern_tiled, N, bits)
    lut_arr_2[i] = get_lut(idxs_arr_2[i], y, bits)
    X_2[:, i] = lut_arr_2[i][idxs_arr_2[i]]

# --- Output: a single LUT on top of layer 2 -------------------------------------
cols_3 = np.random.choice(range(num_luts_2), size=bits)
idxs_3 = get_idxs(X_2[:, cols_3], bit_pattern_tiled, N, bits)
lut_3 = get_lut(idxs_3, y, bits)

preds = lut_3[idxs_3]
# accuracy_score returns a fraction in [0, 1]; the "%" format type scales it by
# 100 and appends the percent sign.
print(f"Accuracy on training set: {accuracy_score(preds, y_bool):.2%}")

## Multiprocessing

In [None]:
# Read the arrays stored under the keys "X" and "y" of the local MNIST.npz archive
# (same loading steps as in the previous section, so this section runs on its own).
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code. Only use it with an archive from a trusted source.
data = np.load("MNIST.npz", allow_pickle=True)
X_ = data["X"]
y_ = data["y"]

# Rescale every feature column to the range [0, 1].
scaler = MinMaxScaler(feature_range=(0, 1))
X_tf = scaler.fit_transform(X_)

In [None]:
%%time
# Use the first `num_examples` samples, binarised at 0.5.
num_examples = 2000
X = (X_tf > 0.5).astype(bool)[:num_examples]

# Boolean targets: True where the label is one of the digits 0-4.
y_bool = ((y_ == 0) | (y_ == 1) | (y_ == 2) | (y_ == 3) | (y_ == 4)).astype(bool)[
    :num_examples
]
# Signed version of the targets: +1 for True, -1 for False.
y = np.where(y_bool, 1, -1)

# All patterns of width `bits`, repeated once per example for get_idxs.
bits = 8
N = X.shape[0]
bit_pattern = get_bit_pattern(bits)
bit_pattern_tiled = np.tile(bit_pattern, (N, 1))


def get_cols(inp_len, bits):
    """
    Draw the input columns for one lut.

    The global NumPy generator is re-seeded from fresh OS entropy before the
    draw, so the result does not depend on the generator's previous state.

    Parameters
    ==========
    inp_len: int
        Number of columns to choose from.
    bits: int
        Number of column indexes to draw.

    Returns
    =======
    np.ndarray
        `bits` column indexes in `range(inp_len)`. np.random.choice samples with
        replacement by default, so indexes may repeat.
    """
    entropy = os.urandom(4)
    seed = int.from_bytes(entropy, byteorder="little")
    np.random.seed(seed)
    candidates = range(inp_len)
    return np.random.choice(candidates, size=bits)


num_luts_1 = 1024
num_luts_2 = 1024

if __name__ == "__main__":
    # The context manager shuts the worker processes down as soon as both hidden
    # layers are built, instead of leaving an open pool behind after the cell.
    with multiprocessing.Pool() as pool:
        # --- Hidden layer 1: one task per lut for each of the three steps --------
        cols_arr_1 = np.array(
            pool.starmap(get_cols, [[X.shape[1], bits]] * num_luts_1)
        )
        idxs_arr_1 = np.array(
            pool.starmap(
                get_idxs,
                [
                    [X[:, cols_arr_1[i]], bit_pattern_tiled, N, bits]
                    for i in range(num_luts_1)
                ],
            )
        )
        lut_arr_1 = np.array(
            pool.starmap(
                get_lut, [[idxs_arr_1[i], y, bits] for i in range(num_luts_1)]
            )
        )
        # Row i is the output of lut i for every example; transpose to
        # (examples, luts) so it can serve as the input of the next layer.
        X_1 = np.array([lut_arr_1[i][idxs_arr_1[i]] for i in range(num_luts_1)]).T

        # --- Hidden layer 2: same procedure on the outputs of layer 1 -----------
        cols_arr_2 = np.array(
            pool.starmap(get_cols, [[num_luts_1, bits]] * num_luts_2)
        )
        idxs_arr_2 = np.array(
            pool.starmap(
                get_idxs,
                [
                    [X_1[:, cols_arr_2[i]], bit_pattern_tiled, N, bits]
                    for i in range(num_luts_2)
                ],
            )
        )
        lut_arr_2 = np.array(
            pool.starmap(
                get_lut, [[idxs_arr_2[i], y, bits] for i in range(num_luts_2)]
            )
        )
        X_2 = np.array([lut_arr_2[i][idxs_arr_2[i]] for i in range(num_luts_2)]).T


# --- Output: a single lut on top of layer 2, computed in this process -----------
cols_3 = np.random.choice(range(num_luts_2), size=bits)
idxs_3 = get_idxs(X_2[:, cols_3], bit_pattern_tiled, N, bits)
lut_3 = get_lut(idxs_3, y, bits)

preds = lut_3[idxs_3]
# accuracy_score returns a fraction in [0, 1]; the "%" format type scales it by
# 100 and appends the percent sign.
print(f"Accuracy on training set: {accuracy_score(preds, y_bool):.2%}")

## Lut class

In [None]:
def get_cols(inp_len, bits):
    """
    Pick the input columns one lut reads from.

    Re-seeds the global NumPy generator from OS entropy first, then draws
    `bits` indexes out of `range(inp_len)`. The draw uses np.random.choice's
    default of sampling with replacement, so an index can occur more than once.

    Parameters
    ==========
    inp_len: int
        Number of available input columns.
    bits: int
        How many column indexes to draw.

    Returns
    =======
    np.ndarray
        Array with `bits` column indexes.
    """
    fresh_seed = int.from_bytes(os.urandom(4), byteorder="little")
    np.random.seed(fresh_seed)
    chosen = np.random.choice(range(inp_len), size=bits)
    return chosen


class Lut:
    """
    Feed-forward network of lookup tables (luts) for binary classification.

    Every hidden layer is a set of independent luts. Each lut reads `bits`
    randomly chosen columns of the previous layer's boolean output (of the
    input data for the first layer) and emits one boolean per example. A single
    final lut on top of the last hidden layer produces the prediction.

    Parameters
    ==========
    bits: int
        Number of input bits of every lut.
    hidden_layers: list of int
        Number of luts in each hidden layer, in order. May be empty, in which
        case the final lut reads directly from the input data.

    Attributes
    ==========
    cols_arr: list of np.ndarray
        Per layer, the input columns of each lut; the last entry belongs to the
        final lut.
    idxs_arr: list of np.ndarray
        Per layer, the table entry hit by every training example.
    lut_arr: list of np.ndarray
        Per layer, the boolean tables; the last entry is the final lut.
    X_arr: list of np.ndarray
        Output of every hidden layer on the training data, shape
        (examples, luts).
    """

    def __init__(self, bits, hidden_layers):
        self.bits = bits
        self.hidden_layers = hidden_layers
        self.bit_pattern = get_bit_pattern(bits)

        self.cols_arr = []
        self.idxs_arr = []
        self.lut_arr = []
        self.X_arr = []

    def train(self, X, y):
        """
        Fit all luts layer by layer and return the predictions on `X`.

        Parameters
        ==========
        X: np.ndarray
            Boolean training data of shape (examples, features).
        y: np.ndarray
            One weight per example, summed per table entry by get_lut. The sign
            of the sum decides the entry, so use e.g. +1 / -1 labels.

        Returns
        =======
        np.ndarray
            Boolean prediction of the final lut for every training example.
        """
        N = X.shape[0]
        bit_pattern_tiled = np.tile(self.bit_pattern, (N, 1))

        # Start from a clean slate: appending to the state of an earlier train()
        # call would leave predict() mixing luts of two different networks.
        self.cols_arr = []
        self.idxs_arr = []
        self.lut_arr = []
        self.X_arr = []

        layer_input = X
        # The context manager stops the worker processes when training is done;
        # otherwise every call would leave a full pool of processes behind.
        with multiprocessing.Pool() as pool:
            for num_luts in self.hidden_layers:
                cols = np.array(
                    pool.starmap(
                        get_cols, [[layer_input.shape[1], self.bits]] * num_luts
                    )
                )
                idxs = np.array(
                    pool.starmap(
                        get_idxs,
                        [
                            [layer_input[:, cols[i]], bit_pattern_tiled, N, self.bits]
                            for i in range(num_luts)
                        ],
                    )
                )
                luts = np.array(
                    pool.starmap(
                        get_lut, [[idxs[i], y, self.bits] for i in range(num_luts)]
                    )
                )
                # Row i is the output of lut i; transpose to (examples, luts) so
                # the next layer can select columns from it.
                layer_input = np.array(
                    [luts[i][idxs[i]] for i in range(num_luts)]
                ).T

                self.cols_arr.append(cols)
                self.idxs_arr.append(idxs)
                self.lut_arr.append(luts)
                self.X_arr.append(layer_input)

        # Final lut, computed in this process on the last hidden layer's output
        # (or on X itself when there are no hidden layers).
        self.cols_arr.append(
            np.random.choice(range(layer_input.shape[1]), size=self.bits)
        )
        self.idxs_arr.append(
            get_idxs(
                layer_input[:, self.cols_arr[-1]], bit_pattern_tiled, N, self.bits
            )
        )
        self.lut_arr.append(get_lut(self.idxs_arr[-1], y, self.bits))
        preds = self.lut_arr[-1][self.idxs_arr[-1]]
        return preds

    def predict(self, X):
        """
        Run `X` through the trained luts.

        Parameters
        ==========
        X: np.ndarray
            Boolean data of shape (examples, features) with the same number of
            features as the training data.

        Returns
        =======
        np.ndarray
            Boolean prediction of the final lut for every example.
        """
        N = X.shape[0]
        bit_pattern_tiled = np.tile(self.bit_pattern, (N, 1))

        layer_input = X
        with multiprocessing.Pool() as pool:
            for j, num_luts in enumerate(self.hidden_layers):
                idxs = np.array(
                    pool.starmap(
                        get_idxs,
                        [
                            [
                                layer_input[:, self.cols_arr[j][i]],
                                bit_pattern_tiled,
                                N,
                                self.bits,
                            ]
                            for i in range(num_luts)
                        ],
                    )
                )
                layer_input = np.array(
                    [self.lut_arr[j][i][idxs[i]] for i in range(num_luts)]
                ).T

        idxs = get_idxs(
            layer_input[:, self.cols_arr[-1]], bit_pattern_tiled, N, self.bits
        )
        preds = self.lut_arr[-1][idxs]
        return preds

## Experiments

In [None]:
# Read the arrays stored under the keys "X" and "y" of the local MNIST.npz archive
# (same loading steps as in the earlier sections, so this section runs on its own).
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code. Only use it with an archive from a trusted source.
data = np.load("MNIST.npz", allow_pickle=True)
X_ = data["X"]
y_ = data["y"]

# Rescale every feature column to the range [0, 1].
scaler = MinMaxScaler(feature_range=(0, 1))
X_tf = scaler.fit_transform(X_)

In [None]:
# Binarise the scaled features at 0.5.
X = (X_tf > 0.5).astype(bool)

# Boolean targets (True for the digits 0-4) and their signed +1 / -1 counterpart.
y_bool = ((y_ == 0) | (y_ == 1) | (y_ == 2) | (y_ == 3) | (y_ == 4)).astype(bool)
y = np.where(y_bool, 1, -1)

# Draw a fixed random subset of 10,000 examples, keeping the three arrays aligned.
X, y, y_bool = shuffle(X, y, y_bool, n_samples=10_000, random_state=0)

# The data is already shuffled, so split it in order: the last 33% form the test set.
X_train, X_test, y_train, y_test, y_bool_train, y_bool_test = train_test_split(
    X, y, y_bool, shuffle=False, test_size=0.33, random_state=42
)

In [None]:
%%time
# Network with 9-bit luts and three hidden layers of 500 luts each.
lut = Lut(9, [500, 500, 500])
preds_train = lut.train(X_train, y_train)
preds_test = lut.predict(X_test)

# accuracy_score returns a fraction in [0, 1]; the "%" format type scales it by
# 100 and appends the percent sign (":.2f" followed by "%" printed e.g. "0.87%").
print(f"Accuracy on training set: {accuracy_score(preds_train, y_bool_train):.2%}")
print(f"Accuracy on test set: {accuracy_score(preds_test, y_bool_test):.2%}")

In [None]:
# Grid of configurations: lut width (4 to 12 bits) x depth (4 to 9 hidden layers
# of 100 luts each).
bit_arr = [*range(4, 13)]
arc_arr = [[100] * depth for depth in range(4, 10)]

# Accuracy tables: rows follow bit_arr, columns follow arc_arr.
train_mesh = np.zeros(shape=(len(bit_arr), len(arc_arr)), dtype=np.float32)
test_mesh = np.zeros_like(train_mesh)

for i, bits in enumerate(tqdm(bit_arr)):
    for j, arc in enumerate(arc_arr):
        # Train a fresh network for this configuration and score both splits.
        lut = Lut(bits, arc)
        preds_train = lut.train(X_train, y_train)
        preds_test = lut.predict(X_test)
        train_mesh[i, j] = accuracy_score(preds_train, y_bool_train)
        test_mesh[i, j] = accuracy_score(preds_test, y_bool_test)

In [None]:
# Show the training accuracies: rows follow bit_arr, columns follow arc_arr.
train_mesh

In [None]:
# Cell edges for pcolormesh: one more edge than there are cells along each axis.
# Dedicated names keep the label array `y` from being overwritten by plot data.
x_edges = np.arange(len(bit_arr) + 1)
y_edges = np.arange(len(arc_arr) + 1)

fig, axs = plt.subplots(1, 2, figsize=(9, 4))

# Both panels share the same layout; only the data and the captions differ.
for ax, (split_name, mesh) in zip(axs, [("Train", train_mesh), ("Test", test_mesh)]):
    # The meshes are indexed [bits, architecture]; transpose so that the bits run
    # along the x axis and the architectures along the y axis.
    pm = ax.pcolormesh(x_edges, y_edges, mesh.T, shading="auto")
    cbar = plt.colorbar(pm, ax=ax, label=f"{split_name} accuracy")
    cbar.ax.get_yaxis().labelpad = 13
    ax.set_xlabel("Number of bits per lut")
    ax.set_ylabel("Number of hidden layers\n(Each hidden layer has 100 luts)")
    # Put the ticks at the cell centres and label them with the grid values.
    ax.set_xticks(x_edges[:-1] + 0.5)
    ax.set_xticklabels(bit_arr)
    ax.set_yticks(y_edges[:-1] + 0.5)
    # Derive the depth labels from the architectures instead of hard-coding them.
    ax.set_yticklabels([len(arc) for arc in arc_arr])
    ax.set_title(split_name)

fig.suptitle("Train and test accuracies for 0-4 vs. 5-9 MNIST classification", fontsize=14)

plt.tight_layout();

In [None]:
# Write the two-panel accuracy figure to disk (plain string: nothing to interpolate).
fig.savefig("layers_bits_acc.jpg", bbox_inches="tight", dpi=100)