In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Load the Fashion-MNIST train and test splits from local CSV files.
training, testing = (
    pd.read_csv(path)
    for path in ('fashion-mnist_train.csv', 'fashion-mnist_test.csv')
)

In [3]:
def split_features_label(frame, label_col='label'):
    """Split ``frame`` into (features, label) DataFrames by column name.

    Features are every column other than ``label_col``; the label part is a
    DataFrame holding just the ``label_col`` column(s).
    """
    is_label = frame.columns == label_col
    return frame.loc[:, ~is_label], frame.loc[:, is_label]


X_train, y_train = split_features_label(training)
X_test, y_test = split_features_label(testing)

In [4]:
# Convert the DataFrames to plain ndarrays. np.asarray is a no-op on an array
# that is already an ndarray, so re-running this cell is safe (DataFrame's
# .to_numpy() would raise AttributeError on the second run).
X_train, X_test = np.asarray(X_train), np.asarray(X_test)
y_train, y_test = np.asarray(y_train), np.asarray(y_test)

In [5]:
# Rescale features by 1/255 (assumes raw 8-bit pixel values, i.e. 0..255).
# Not idempotent: re-running this cell divides by 255 a second time.
X_train = X_train / 255
X_test = X_test / 255

In [6]:
from sklearn.preprocessing import OneHotEncoder

# Fit the encoder on the training labels only, then reuse that fitted category
# mapping for the test labels so both matrices share the same column layout
# (refitting on y_test could reorder or drop classes absent from the test set).
# Expects the (n, 1) label column; re-running this cell on already-encoded
# labels will not give the intended result.
onehot_encoder = OneHotEncoder()
y_train = onehot_encoder.fit_transform(y_train).toarray()
y_test = onehot_encoder.transform(y_test).toarray()

In [7]:
# Not needed: the OneHotEncoder cell already densifies with .toarray().
# y_train = y_train.toarray()
# y_test = y_test.toarray()

In [8]:
class FCNN(nn.Module):
    """Fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        input_size: number of input features per sample.
        hidden_size: width of the hidden layer (default 128).
        n_classes: number of output logits (default 10).
    """

    def __init__(self, input_size, hidden_size=128, n_classes=10):
        super().__init__()
        # Submodule names and registration order are kept stable so that
        # state_dict keys match checkpoints of the original definition.
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, n_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw (unnormalised) logits; x's last dimension is input_size."""
        x = self.linear1(x)
        x = self.relu(x)
        return self.linear2(x)

In [9]:
# Seed torch so weight init and mini-batch sampling are reproducible.
SEED = 0
torch.manual_seed(SEED)

model = FCNN(X_train.shape[1])

learning_rate = 0.01
optimiser = torch.optim.Adam(model.parameters(), lr=learning_rate)
# y_batch holds one-hot float rows; this relies on CrossEntropyLoss accepting
# class-probability targets (PyTorch >= 1.10).
criterion = nn.CrossEntropyLoss()

n_epochs = 5
n_iters = 5001  # 5001 so that iteration 5000 is also logged
batch_size = 50
log_every = 1000

for epoch in range(n_epochs):
    # NOTE: an "epoch" here is n_iters random mini-batches, not one pass over X_train.
    for i in range(n_iters):
        # batch_size distinct row indices (no replacement within a batch).
        idx = torch.randperm(X_train.shape[0])[:batch_size].numpy()
        X_batch = torch.as_tensor(X_train[idx], dtype=torch.float32)
        y_batch = torch.as_tensor(y_train[idx], dtype=torch.float32)
        y_pred = model(X_batch)

        loss = criterion(y_pred, y_batch)
        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        if i % log_every == 0 and i != 0:
            # Accuracy on the current mini-batch only: noisy, not a validation score.
            predictions = y_pred.argmax(dim=1)
            labels = y_batch.argmax(dim=1)
            accuracy = predictions.eq(labels).sum().item() / batch_size * 100
            print(f'Iteration: {i} | loss: {loss.item()} |Accuracy: {accuracy}%')

    print(f'Epoch {epoch+1} completed.')

Iteration: 1000 | loss: 0.2650490403175354 |Accuracy: 90.0%
Iteration: 2000 | loss: 0.27972570061683655 |Accuracy: 90.0%
Iteration: 3000 | loss: 0.3511337637901306 |Accuracy: 86.0%
Iteration: 4000 | loss: 0.40125328302383423 |Accuracy: 82.0%
Iteration: 5000 | loss: 0.5328548550605774 |Accuracy: 86.0%
Epoch 1 completed.
Iteration: 1000 | loss: 0.23673681914806366 |Accuracy: 94.0%
Iteration: 2000 | loss: 0.2324506938457489 |Accuracy: 90.0%
Iteration: 3000 | loss: 0.2109733521938324 |Accuracy: 90.0%
Iteration: 4000 | loss: 0.30207526683807373 |Accuracy: 92.0%
Iteration: 5000 | loss: 0.22341850399971008 |Accuracy: 92.0%
Epoch 2 completed.
Iteration: 1000 | loss: 0.34380996227264404 |Accuracy: 84.0%
Iteration: 2000 | loss: 0.6069839000701904 |Accuracy: 78.0%
Iteration: 3000 | loss: 0.19872169196605682 |Accuracy: 94.0%
Iteration: 4000 | loss: 0.5647631883621216 |Accuracy: 84.0%
Iteration: 5000 | loss: 0.34545010328292847 |Accuracy: 88.0%
Epoch 3 completed.
Iteration: 1000 | loss: 0.312299966

In [10]:
# Evaluate on the full test set without recording an autograd graph.
model.eval()
with torch.no_grad():
    test_logits = model(torch.as_tensor(X_test, dtype=torch.float32))
test_pred = test_logits.argmax(dim=1)
test_labels = torch.as_tensor(y_test).argmax(dim=1)  # decode one-hot rows
n_correct = test_pred.eq(test_labels).sum().item()
acc = 100 * n_correct / y_test.shape[0]
print(f'Test Accuracy: {acc}%')

Test Accuracy: 86.93000030517578%


In [None]:
from concrete.ml.torch.compile import compile_torch_model

# Compile the trained model to an FHE circuit at 3 bits. X_train is passed as
# the input set (presumably for quantisation calibration — see concrete-ml docs).
print('Compiling the model to FHE.')

try:
    quantised_compiled_module = compile_torch_model(
        model,
        X_train,
        n_bits=3,
    )
except RuntimeError as e:
    # Accumulator overflow at this bit width: re-raise with actionable advice.
    if str(e).startswith("max_bit_width of some nodes is too high"):
        raise RuntimeError(
            "Could not compile the model to FHE. "
            "You may need to decrease the n_bits parameter to avoid potential overflows."
        ) from e
    raise
else:
    # Only reached when compile_torch_model returned without raising.
    print('The network is trained and FHE friendly.')

Compiling the model to FHE.


In [None]:
from concrete.ml.torch.compile import compile_torch_model

In [None]:
from concrete.common.compilation import CompilationConfiguration

# Compiler options used by the virtual-library bit-width sweep.
cfg = CompilationConfiguration(
    use_insecure_key_cache=False,
    treat_warnings_as_errors=True,
    # Unsafe features are switched on for experimentation only; never use that in prod.
    enable_unsafe_features=True,
    dump_artifacts_on_unexpected_failures=False,
)

In [None]:
from concrete.ml.torch.compile import compile_torch_model

def test_with_concrete_virtual_lib(quantised_module, use_fhe, use_vl, X=None, y=None):
    """Return the classification accuracy of a quantised/compiled module.

    Each sample is quantised, then run one at a time (leading batch dimension
    of 1) either through ``forward_fhe`` (when ``use_fhe`` or ``use_vl``) or
    through the plain quantised ``forward_and_dequant``. The argmax of the
    dequantised output is compared with the ground-truth class.

    Args:
        quantised_module: object exposing ``quantize_input``,
            ``forward_fhe.encrypt_run_decrypt``, ``dequantize_output`` and
            ``forward_and_dequant``.
        use_fhe: run through ``forward_fhe``; quantised inputs are cast to
            ``np.uint8`` (otherwise ``np.int32``).
        use_vl: run through ``forward_fhe`` (virtual library).
        X: float features, one row per sample. Defaults to the global ``X_test``.
        y: labels as one-hot rows or class indices. Defaults to the global ``y_test``.

    Returns:
        Fraction of correctly classified samples, in [0, 1].

    Raises:
        ValueError: if ``X`` is empty or ``X`` and ``y`` differ in length.
    """
    if X is None:
        X = X_test
    if y is None:
        y = y_test
    y = np.asarray(y)
    if len(X) == 0:
        raise ValueError("X must contain at least one sample")
    if len(X) != len(y):
        raise ValueError(f"X and y length mismatch: {len(X)} != {len(y)}")

    # Ground-truth class indices: decode one-hot rows, otherwise flatten labels.
    if y.ndim == 2 and y.shape[1] > 1:
        all_targets = np.argmax(y, axis=1).astype(np.int32)
    else:
        all_targets = y.reshape(-1).astype(np.int32)

    dtype_inputs = np.uint8 if use_fhe else np.int32
    # Quantise all rows once (assumes quantize_input is elementwise, so this
    # matches per-row quantisation — confirm for the concrete-ml version used).
    x_q_all = quantised_module.quantize_input(X).astype(dtype_inputs)

    all_y_pred = np.zeros(len(X), dtype=np.int32)
    for idx in range(len(X)):
        # Single sample with a leading batch dimension of 1.
        x_q = np.expand_dims(x_q_all[idx], 0)
        if use_fhe or use_vl:
            out_fhe = quantised_module.forward_fhe.encrypt_run_decrypt(x_q)
            output = quantised_module.dequantize_output(out_fhe)
        else:
            output = quantised_module.forward_and_dequant(x_q)
        all_y_pred[idx] = np.argmax(output, axis=1)[0]

    n_correct = np.sum(all_targets == all_y_pred)
    return n_correct / len(X)


# Not a pytest test despite the ``test_`` prefix; keep collectors from running it.
test_with_concrete_virtual_lib.__test__ = False


# Compile with the virtual library at each bit width and record, per setting,
# the widest accumulator in the circuit and the resulting test accuracy.
# NOTE: only n_bits=2 is evaluated; widen the range for a full sweep.
accs, accum_bits = [], []
for n_bits in range(2, 3):
    q_module_vl = compile_torch_model(
        model, X_train, n_bits=n_bits,
        use_virtual_lib=True, compilation_configuration=cfg,
    )
    accum_bits.append(q_module_vl.forward_fhe.get_max_bit_width())
    accs.append(test_with_concrete_virtual_lib(q_module_vl, use_fhe=False, use_vl=True))
        

In [None]:
# Test accuracy vs. quantisation bit width, each point annotated with the
# circuit's max accumulator bit width. x-values are derived from len(accs) so
# the plot matches however many settings the sweep ran (assumes it starts at 2).
bit_widths = list(range(2, 2 + len(accs)))

plt.rcParams["font.size"] = 14
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(bit_widths, accs, "-x")
# Loop names chosen so the global test accuracy `acc` is not overwritten.
for bits, bit_acc, accum in zip(bit_widths, accs, accum_bits):
    ax.annotate(str(accum), (bits - 0.1, bit_acc + 0.025))
ax.set_ylabel("Accuracy on test set")
ax.set_xlabel("Weight & activation quantization")
ax.grid(True)
ax.set_title("Accuracy for varying quantization bit width")
plt.show()


In [None]:
import numpy as np

# Sanity check: shape of one quantised test sample after adding the leading
# batch dimension, the same form test_with_concrete_virtual_lib feeds to
# forward_fhe. Uses the module compiled by the bit-width sweep.
X_test_q = q_module_vl.quantize_input(X_test)
a = np.expand_dims(X_test_q[0, :], 0)
a.shape