In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import random
import sys
from typing import List, Union

import matplotlib.pyplot as plt
import numpy as np
import tqdm
from sklearn.metrics import accuracy_score, classification_report

from wisard.encoders import ThermometerEncoder, encode_dataset
from wisard.wisard import WiSARD
from wisard.utils import untie
from wisard.optimize import find_best_bleach_bayesian, find_best_bleach_bin_search
from wisard.data import IrisDataset

from keras.datasets import mnist, fashion_mnist

%matplotlib inline

2022-08-05 01:28:33.751663: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-08-05 01:28:33.751721: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [3]:
def do_train_and_evaluate(x_train,
                          y_train,
                          x_test,
                          y_test,
                          tuple_size: int,
                          bleach: Union[int, str] = "auto",
                          **kwargs):
    num_classes = len(np.unique(y_train))

    print(" ----- Training model ----- ")

    model = WiSARD(num_inputs=x_train[0].size,
                   num_classes=num_classes,
                   unit_inputs=tuple_size,
                   unit_entries=1,
                   unit_hashes=1,
                   input_idxs=np.random.shuffle(np.arange(x_train[0].size)),
                   shared_rand_vals=True,
                   randomize=False)

    model.fit(x_train, y_train)
    max_bleach = model.max_bleach()
    print(f"Max bleach is: {max_bleach}\n")

    print(" ----- Evaluating model ----- ")

    if isinstance(bleach, int):
        y_pred = model.predict(x_test, y_test, bleach=bleach, use_tqdm=True)
        y_pred, ties = untie(y_pred, use_tqdm=False)
        accuracy = accuracy_score(y_test, y_pred)
        print(f"Accuracy: {accuracy:.3f}")
    elif bleach == "auto":
        bleach = find_best_bleach_bin_search(model,
                                           X=x_test,
                                           y=y_test,
                                           min_bleach=1,
                                           max_bleach=max_bleach,
                                           **kwargs)
    else:
        raise ValueError(f"Invalid value for bleach: '{bleach}'")

    return model, bleach

# IRIS

Encoder:
- Type: Thermometer
- Resolution: 16
- Min, max: 0, 8

In [4]:
thermometer = ThermometerEncoder(minimum=0, maximum=8, resolution=16)
(x_train, y_train), (x_test, y_test) = IrisDataset(test_size=0.2).load_data()
x_train = encode_dataset(thermometer, x_train)
x_test = encode_dataset(thermometer, x_test)

Encoding dataset: 100%|██████████████████████████████████████████| 120/120 [00:00<00:00, 23792.97it/s]
Encoding dataset: 100%|████████████████████████████████████████████| 30/30 [00:00<00:00, 28906.30it/s]


In [5]:
model, bleach = do_train_and_evaluate(x_train, y_train, x_test, y_test, tuple_size=8)
print(f"Best bleach: {bleach}")

 ----- Training model ----- 


Training model: 100%|██████████████████████████████████████████████| 120/120 [00:00<00:00, 126.63it/s]


Max bleach is: 40

 ----- Evaluating model ----- 
Testing with bleach=10
[b=10] Accuracy=0.933, ties=1
Testing with bleach=20
[b=20] Accuracy=0.833, ties=3
Testing with bleach=30
[b=30] Accuracy=0.800, ties=7
Testing with bleach=5
[b=5] Accuracy=0.867, ties=4
Testing with bleach=10
[b=10] Accuracy=0.933
Testing with bleach=15
[b=15] Accuracy=0.933, ties=1
Testing with bleach=8
[b=8] Accuracy=0.833, ties=4
Testing with bleach=10
[b=10] Accuracy=0.933
Testing with bleach=12
[b=12] Accuracy=0.933, ties=1
Testing with bleach=9
[b=9] Accuracy=0.833, ties=4
Testing with bleach=10
[b=10] Accuracy=0.933
Testing with bleach=11
[b=11] Accuracy=0.933, ties=1
Best bleach: 10....
Best bleach: 10


In [6]:
model, bleach = do_train_and_evaluate(x_train, y_train, x_test, y_test, tuple_size=12)
print(f"Best bleach: {bleach}")

 ----- Training model ----- 


Training model: 100%|████████████████████████████████████████████| 120/120 [00:00<00:00, 26459.70it/s]

Max bleach is: 40

 ----- Evaluating model ----- 
Testing with bleach=10
[b=10] Accuracy=0.933, ties=2
Testing with bleach=20
[b=20] Accuracy=0.867, ties=7
Testing with bleach=30
[b=30] Accuracy=0.533, ties=20
Testing with bleach=5
[b=5] Accuracy=0.867, ties=4
Testing with bleach=10
[b=10] Accuracy=0.933
Testing with bleach=15
[b=15] Accuracy=0.867, ties=3
Testing with bleach=8
[b=8] Accuracy=0.833, ties=5
Testing with bleach=10
[b=10] Accuracy=0.933
Testing with bleach=12
[b=12] Accuracy=0.933, ties=3
Testing with bleach=9
[b=9] Accuracy=0.833, ties=5
Testing with bleach=10
[b=10] Accuracy=0.933
Testing with bleach=11
[b=11] Accuracy=0.933, ties=1
Best bleach: 10....
Best bleach: 10





In [7]:
model, bleach = do_train_and_evaluate(x_train, y_train, x_test, y_test, tuple_size=16)
print(f"Best bleach: {bleach}")

 ----- Training model ----- 


Training model: 100%|████████████████████████████████████████████| 120/120 [00:00<00:00, 27835.22it/s]

Max bleach is: 40

 ----- Evaluating model ----- 
Testing with bleach=10
[b=10] Accuracy=0.933, ties=0
Testing with bleach=20
[b=20] Accuracy=0.767, ties=4
Testing with bleach=30
[b=30] Accuracy=0.500, ties=11
Testing with bleach=5
[b=5] Accuracy=0.867, ties=3
Testing with bleach=10
[b=10] Accuracy=0.933
Testing with bleach=15
[b=15] Accuracy=0.800, ties=5
Testing with bleach=8
[b=8] Accuracy=0.833, ties=3
Testing with bleach=10
[b=10] Accuracy=0.933
Testing with bleach=12
[b=12] Accuracy=0.900, ties=5
Testing with bleach=9
[b=9] Accuracy=0.833, ties=3
Testing with bleach=10
[b=10] Accuracy=0.933
Testing with bleach=11
[b=11] Accuracy=0.933, ties=0
Best bleach: 10....
Best bleach: 10





# Fashion MNIST

Encoder:
- Type: Thermometer
- Resolution: 16
- Min, max: 0, 255

In [8]:
thermometer = ThermometerEncoder(minimum=0, maximum=255, resolution=16)
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train = encode_dataset(thermometer, x_train)
x_test = encode_dataset(thermometer, x_test)

Encoding dataset: 100%|███████████████████████████████████████| 60000/60000 [00:09<00:00, 6391.18it/s]
Encoding dataset: 100%|███████████████████████████████████████| 10000/10000 [00:01<00:00, 7619.20it/s]


In [9]:
model, bleach = do_train_and_evaluate(x_train, y_train, x_test, y_test, tuple_size=8)
print(f"Best bleach: {bleach}")

 ----- Training model ----- 


Training model: 100%|██████████████████████████████████████████| 60000/60000 [01:51<00:00, 535.93it/s]


Max bleach is: 6000

 ----- Evaluating model ----- 
Testing with bleach=1500
[b=1500] Accuracy=0.676, ties=69
Testing with bleach=3000
[b=3000] Accuracy=0.478, ties=101
Testing with bleach=4500
[b=4500] Accuracy=0.269, ties=55
Testing with bleach=750
[b=750] Accuracy=0.633, ties=78
Testing with bleach=1500
[b=1500] Accuracy=0.676
Testing with bleach=2250
[b=2250] Accuracy=0.604, ties=92
Testing with bleach=1125
[b=1125] Accuracy=0.666, ties=64
Testing with bleach=1500
[b=1500] Accuracy=0.676
Testing with bleach=1875
[b=1875] Accuracy=0.661, ties=71
Testing with bleach=1313
[b=1313] Accuracy=0.675, ties=53
Testing with bleach=1500
[b=1500] Accuracy=0.676
Testing with bleach=1687
[b=1687] Accuracy=0.668, ties=65
Testing with bleach=1407
[b=1407] Accuracy=0.676, ties=77
Testing with bleach=1500
[b=1500] Accuracy=0.676
Testing with bleach=1593
[b=1593] Accuracy=0.674, ties=71
Testing with bleach=1454
[b=1454] Accuracy=0.679, ties=71
Testing with bleach=1500
[b=1500] Accuracy=0.676
Testing 

In [10]:
model, bleach = do_train_and_evaluate(x_train, y_train, x_test, y_test, tuple_size=12)
print(f"Best bleach: {bleach}")

 ----- Training model ----- 


Training model: 100%|██████████████████████████████████████████| 60000/60000 [01:18<00:00, 764.64it/s]


Max bleach is: 6000

 ----- Evaluating model ----- 
Testing with bleach=1500
[b=1500] Accuracy=0.613, ties=146
Testing with bleach=3000
[b=3000] Accuracy=0.406, ties=207
Testing with bleach=4500
[b=4500] Accuracy=0.204, ties=121
Testing with bleach=750
[b=750] Accuracy=0.622, ties=106
Testing with bleach=1500
[b=1500] Accuracy=0.613
Testing with bleach=2250
[b=2250] Accuracy=0.504, ties=215
Testing with bleach=375
[b=375] Accuracy=0.689, ties=86
Testing with bleach=750
[b=750] Accuracy=0.622
Testing with bleach=1125
[b=1125] Accuracy=0.632, ties=134
Testing with bleach=188
[b=188] Accuracy=0.725, ties=87
Testing with bleach=375
[b=375] Accuracy=0.689
Testing with bleach=562
[b=562] Accuracy=0.642, ties=111
Testing with bleach=95
[b=95] Accuracy=0.760, ties=93
Testing with bleach=188
[b=188] Accuracy=0.725
Testing with bleach=281
[b=281] Accuracy=0.706, ties=94
Testing with bleach=49
[b=49] Accuracy=0.775, ties=106
Testing with bleach=95
[b=95] Accuracy=0.760
Testing with bleach=141
[b=

In [11]:
model, bleach = do_train_and_evaluate(x_train, y_train, x_test, y_test, tuple_size=16)
print(f"Best bleach: {bleach}")

 ----- Training model ----- 


Training model: 100%|██████████████████████████████████████████| 60000/60000 [01:02<00:00, 961.03it/s]


Max bleach is: 6000

 ----- Evaluating model ----- 
Testing with bleach=1500
[b=1500] Accuracy=0.528, ties=462
Testing with bleach=3000
[b=3000] Accuracy=0.335, ties=392
Testing with bleach=4500
[b=4500] Accuracy=0.130, ties=155
Testing with bleach=750
[b=750] Accuracy=0.607, ties=164
Testing with bleach=1500
[b=1500] Accuracy=0.528
Testing with bleach=2250
[b=2250] Accuracy=0.418, ties=468
Testing with bleach=375
[b=375] Accuracy=0.679, ties=135
Testing with bleach=750
[b=750] Accuracy=0.607
Testing with bleach=1125
[b=1125] Accuracy=0.567, ties=326
Testing with bleach=188
[b=188] Accuracy=0.719, ties=120
Testing with bleach=375
[b=375] Accuracy=0.679
Testing with bleach=562
[b=562] Accuracy=0.627, ties=151
Testing with bleach=95
[b=95] Accuracy=0.756, ties=105
Testing with bleach=188
[b=188] Accuracy=0.719
Testing with bleach=281
[b=281] Accuracy=0.696, ties=136
Testing with bleach=49
[b=49] Accuracy=0.772, ties=123
Testing with bleach=95
[b=95] Accuracy=0.756
Testing with bleach=141

# MNIST

Encoder:
- Type: Thermometer
- Resolution: 16
- Min, max: 0, 255

In [12]:
thermometer = ThermometerEncoder(minimum=0, maximum=255, resolution=16)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = encode_dataset(thermometer, x_train)
x_test = encode_dataset(thermometer, x_test)

Encoding dataset: 100%|███████████████████████████████████████| 60000/60000 [00:08<00:00, 6872.08it/s]
Encoding dataset: 100%|███████████████████████████████████████| 10000/10000 [00:01<00:00, 7592.19it/s]


In [13]:
model, bleach = do_train_and_evaluate(x_train, y_train, x_test, y_test, tuple_size=8)
print(f"Best bleach: {bleach}")

 ----- Training model ----- 


Training model: 100%|██████████████████████████████████████████| 60000/60000 [01:50<00:00, 542.20it/s]


Max bleach is: 6742

 ----- Evaluating model ----- 
Testing with bleach=1686
[b=1686] Accuracy=0.425, ties=169
Testing with bleach=3371
[b=3371] Accuracy=0.134, ties=150
Testing with bleach=5056
[b=5056] Accuracy=0.114, ties=0
Testing with bleach=844
[b=844] Accuracy=0.701, ties=93
Testing with bleach=1686
[b=1686] Accuracy=0.425
Testing with bleach=2528
[b=2528] Accuracy=0.215, ties=219
Testing with bleach=423
[b=423] Accuracy=0.792, ties=59
Testing with bleach=844
[b=844] Accuracy=0.701
Testing with bleach=1265
[b=1265] Accuracy=0.566, ties=159
Testing with bleach=213
[b=213] Accuracy=0.857, ties=35
Testing with bleach=423
[b=423] Accuracy=0.792
Testing with bleach=633
[b=633] Accuracy=0.741, ties=74
Testing with bleach=108
[b=108] Accuracy=0.889, ties=21
Testing with bleach=213
[b=213] Accuracy=0.857
Testing with bleach=318
[b=318] Accuracy=0.826, ties=41
Testing with bleach=56
[b=56] Accuracy=0.897, ties=48
Testing with bleach=108
[b=108] Accuracy=0.889
Testing with bleach=160
[b=1

In [14]:
model, bleach = do_train_and_evaluate(x_train, y_train, x_test, y_test, tuple_size=12)
print(f"Best bleach: {bleach}")

 ----- Training model ----- 


Training model: 100%|██████████████████████████████████████████| 60000/60000 [01:17<00:00, 770.32it/s]


Max bleach is: 6742

 ----- Evaluating model ----- 
Testing with bleach=1686
[b=1686] Accuracy=0.248, ties=537
Testing with bleach=3371
[b=3371] Accuracy=0.125, ties=389
Testing with bleach=5056
[b=5056] Accuracy=0.114, ties=31
Testing with bleach=844
[b=844] Accuracy=0.616, ties=221
Testing with bleach=1686
[b=1686] Accuracy=0.248
Testing with bleach=2528
[b=2528] Accuracy=0.137, ties=926
Testing with bleach=423
[b=423] Accuracy=0.761, ties=93
Testing with bleach=844
[b=844] Accuracy=0.616
Testing with bleach=1265
[b=1265] Accuracy=0.421, ties=377
Testing with bleach=213
[b=213] Accuracy=0.825, ties=62
Testing with bleach=423
[b=423] Accuracy=0.761
Testing with bleach=633
[b=633] Accuracy=0.692, ties=139
Testing with bleach=108
[b=108] Accuracy=0.876, ties=36
Testing with bleach=213
[b=213] Accuracy=0.825
Testing with bleach=318
[b=318] Accuracy=0.787, ties=102
Testing with bleach=56
[b=56] Accuracy=0.905, ties=40
Testing with bleach=108
[b=108] Accuracy=0.876
Testing with bleach=160


In [15]:
model, bleach = do_train_and_evaluate(x_train, y_train, x_test, y_test, tuple_size=16)
print(f"Best bleach: {bleach}")

 ----- Training model ----- 


Training model: 100%|██████████████████████████████████████████| 60000/60000 [01:02<00:00, 967.62it/s]


Max bleach is: 6742

 ----- Evaluating model ----- 
Testing with bleach=1686
[b=1686] Accuracy=0.203, ties=2563
Testing with bleach=3371
[b=3371] Accuracy=0.126, ties=1141
Testing with bleach=5056
[b=5056] Accuracy=0.114, ties=273
Testing with bleach=844
[b=844] Accuracy=0.440, ties=780
Testing with bleach=1686
[b=1686] Accuracy=0.203
Testing with bleach=2528
[b=2528] Accuracy=0.155, ties=1969
Testing with bleach=423
[b=423] Accuracy=0.666, ties=220
Testing with bleach=844
[b=844] Accuracy=0.440
Testing with bleach=1265
[b=1265] Accuracy=0.284, ties=2148
Testing with bleach=213
[b=213] Accuracy=0.769, ties=105
Testing with bleach=423
[b=423] Accuracy=0.666
Testing with bleach=633
[b=633] Accuracy=0.549, ties=401
Testing with bleach=108
[b=108] Accuracy=0.831, ties=69
Testing with bleach=213
[b=213] Accuracy=0.769
Testing with bleach=318
[b=318] Accuracy=0.726, ties=171
Testing with bleach=56
[b=56] Accuracy=0.874, ties=37
Testing with bleach=108
[b=108] Accuracy=0.831
Testing with blea