# CIFAR-10 FHE classification with 8-bit split VGG

As mentioned in the [README](./README.md), this notebook shows how to compile a split torch model to FHE.
The model is a CIFAR-10 classifier based on the VGG architecture. It was trained with pruning and accumulator bit-width monitoring so that its accumulators never exceed a bit-width of 8.
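
To give a feel for why pruning helps meet an 8-bit accumulator constraint, here is a minimal sketch of a conservative worst-case bound on the bit-width of a dot product. The helper names and the value ranges (signed weights, unsigned activations) are assumptions made for illustration; they are not taken from `model.py` or from Concrete ML.

```python
def accumulator_bits(n_terms: int, weight_bits: int, act_bits: int) -> int:
    """Conservative signed bit-width of a sum of n_terms weight*activation products.

    Assumption: signed weights in [-2**(weight_bits-1), 2**(weight_bits-1) - 1]
    and unsigned activations in [0, 2**act_bits - 1].
    """
    max_weight = 2 ** (weight_bits - 1)
    max_act = 2 ** act_bits - 1
    worst_case = n_terms * max_weight * max_act
    # bit_length() covers the magnitude; one extra bit holds the sign.
    return worst_case.bit_length() + 1


def max_terms(target_bits: int, weight_bits: int, act_bits: int) -> int:
    """Largest number of non-zero weights whose worst case fits in target_bits."""
    max_magnitude = 2 ** (target_bits - 1) - 1
    return max_magnitude // (2 ** (weight_bits - 1) * (2 ** act_bits - 1))


# Hypothetical layer: a 3x3 convolution over 64 input channels (576 terms).
dense_bits = accumulator_bits(n_terms=576, weight_bits=2, act_bits=2)
budget = max_terms(target_bits=8, weight_bits=2, act_bits=2)
print(dense_bits, budget)
```

Under these assumptions, a dense layer of that size would need far more than 8 bits in the worst case, while keeping only a small number of non-zero weights per output brings the bound back under the limit. This bound is pessimistic: a trained model is usually checked against the values actually observed on representative inputs.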

The first layers of the model run on clear data on the client's side, and the rest of the model runs in FHE on the server's side.

In [1]:
import time

import pandas as pd
import torch
import torchvision
from model import CNV  # pylint: disable=no-name-in-module
from sklearn.metrics import top_k_accuracy_score
from torchvision import transforms

from concrete.ml.torch.compile import compile_brevitas_qat_model

No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'


In `model.py` we define our model architecture.

The main model `CNV` is split into two sub-models, `ClearModule` and `EncryptedModule`.

- `ClearModule` runs on clear data on the client's side. It can perform any floating-point operation and does not require quantization.
- `EncryptedModule` runs on the server's side. Because this part of the model runs in FHE, it must be quantized, which is why we use Brevitas for Quantization Aware Training.
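
As a rough illustration of what quantization means for the encrypted part, the sketch below maps float weights to low-bit signed integers with a single scale factor. This is plain post-hoc rounding written for this example, not what Brevitas does: Quantization Aware Training learns the quantization parameters during training. The `quantize` and `dequantize` helpers and the weight values are hypothetical.

```python
def quantize(values, n_bits):
    """Uniformly quantize floats to signed integers in [-q_max, q_max]."""
    q_max = 2 ** (n_bits - 1) - 1
    max_abs = max(abs(v) for v in values)
    # One scale for the whole tensor; fall back to 1.0 for an all-zero input.
    scale = max_abs / q_max if max_abs > 0 else 1.0
    q_values = [max(-q_max, min(q_max, round(v / scale))) for v in values]
    return q_values, scale


def dequantize(q_values, scale):
    """Map the integers back to approximate float values."""
    return [q * scale for q in q_values]


# Hypothetical float weights, quantized to 2 bits.
weights = [0.8, -0.3, 0.05, -0.9]
q_weights, scale = quantize(weights, n_bits=2)
print(q_weights, dequantize(q_weights, scale))
```

With 2 bits and a symmetric range, every weight collapses to -1, 0, or 1 times the scale, which is why such a model has to be trained with quantization in the loop to stay accurate.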

In [2]:
model = CNV(num_classes=10, weight_bit_width=2, act_bit_width=2, in_bit_width=3, in_ch=3)

We won't train the model in this notebook, as that would be quite computationally intensive. Instead, we provide an already trained model that satisfies the 8-bit accumulator constraint and performs better than random on CIFAR-10.

In [3]:
loaded = torch.load("./8_bit_model.pt")

In [4]:
model.load_state_dict(loaded["model_state_dict"])
model = model.eval()

In [5]:
IMAGE_TRANSFORM = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

try:
    train_set = torchvision.datasets.CIFAR10(
        root=".data/",
        train=True,
        download=False,
        transform=IMAGE_TRANSFORM,
        target_transform=None,
    )
except RuntimeError:
    train_set = torchvision.datasets.CIFAR10(
        root=".data/",
        train=True,
        download=True,
        transform=IMAGE_TRANSFORM,
        target_transform=None,
    )
test_set = torchvision.datasets.CIFAR10(
    root=".data/",
    train=False,
    download=False,
    transform=IMAGE_TRANSFORM,
    target_transform=None,
)

print((train_set, test_set))

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to .data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting .data/cifar-10-python.tar.gz to .data/
(Dataset CIFAR10
    Number of datapoints: 50000
    Root location: .data/
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
           ), Dataset CIFAR10
    Number of datapoints: 10000
    Root location: .data/
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
           ))


We use a sub-sample of the training set for the FHE compilation to maintain acceptable compilation times and avoid out-of-memory errors.

In [6]:
num_samples = 1000
train_sub_set = torch.stack(
    [train_set[index][0] for index in range(min(num_samples, len(train_set)))]
)

Since we compile only part of the network, we need to give it representative inputs: in our case, the feature maps produced by the clear part of the network.

In [7]:
# Pre-processing -> images -> feature maps
with torch.no_grad():
    train_features_sub_set = model.clear_module(train_sub_set)

# FHE Simulation

First, we can check that our FHE constraints are respected.

In [8]:
optional_kwargs = {}

# Compile the model
compilation_onnx_path = "compilation_model.onnx"
print("Compiling the model")
start_compile = time.time()
quantized_numpy_module = compile_brevitas_qat_model(
    # our encrypted model
    torch_model=model.encrypted_module,
    # a representative input-set to be used for compilation
    torch_inputset=train_features_sub_set,
    **optional_kwargs,
    output_onnx_file=compilation_onnx_path,
)

end_compile = time.time()
print(f"Compilation finished in {end_compile - start_compile:.2f} seconds")

# Check that the network is compatible with FHE constraints
assert quantized_numpy_module.fhe_circuit is not None
bitwidth = quantized_numpy_module.fhe_circuit.graph.maximum_integer_bit_width()
print(f"Max bit-width: {bitwidth} bits!")

Compiling the model
Compilation finished in 93.09 seconds
Max bitwidth: 8 bits!


In [9]:
img, _ = train_set[0]
with torch.no_grad():
    feature_maps = model.clear_module(img[None, :])

In [10]:
output_simulated = quantized_numpy_module.forward(feature_maps.numpy(), fhe="simulate")

In [11]:
with torch.no_grad():
    torch_output = model(img[None, :])

In [12]:
print(torch_output - output_simulated)

tensor([[ 0.0171,  0.0171, -0.0215,  0.0122,  0.0232, -0.0144,  0.0042, -0.0115,
          0.0180,  0.0065]], dtype=torch.float64)


The outputs of the torch model and of the FHE simulation differ slightly.

This is expected, and as the following code blocks show, it does not affect the results: the top-k accuracies are identical between PyTorch and the FHE simulation mode.
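
To see why small differences in the raw outputs need not change top-k accuracy, here is a minimal sketch that compares the top-k ranking of two score vectors. The `torch_logits` values are hypothetical (the notebook does not print the raw logits); only the `diff` values are copied from the output above.

```python
def top_k_indices(scores, k):
    """Indices of the k largest scores, best first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def top_k_accuracy(all_scores, labels, k):
    """Fraction of samples whose true label is among the k best scores."""
    hits = sum(
        label in top_k_indices(scores, k)
        for scores, label in zip(all_scores, labels)
    )
    return hits / len(labels)


# Hypothetical logits for one image (not taken from the notebook run).
torch_logits = [1.20, -0.40, 0.30, 2.10, 0.05, 1.75, -1.00, 0.60, -0.20, 0.10]
# torch_output - output_simulated, as printed above.
diff = [0.0171, 0.0171, -0.0215, 0.0122, 0.0232,
        -0.0144, 0.0042, -0.0115, 0.0180, 0.0065]
simulated_logits = [t - d for t, d in zip(torch_logits, diff)]

same_top3 = top_k_indices(torch_logits, 3) == top_k_indices(simulated_logits, 3)
print(same_top3)
```

Differences of a few hundredths only change the ranking when two classes are nearly tied, so the accuracies can match even though the raw outputs do not.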

In [13]:
def evaluate(file_path: str, k=3):
    predictions = pd.read_csv(file_path)
    prob_columns = [elt for elt in predictions.columns if elt.endswith("_prob")]
    predictions["pred_label"] = predictions[prob_columns].values.argmax(axis=1)

    # Report the top-1 through top-k accuracies
    for k_ in range(1, k + 1):
        print(
            f"top-{k_}-accuracy: ",
            top_k_accuracy_score(
                y_true=predictions["label"], y_score=predictions[prob_columns], k=k_
            ),
        )

We can use the `infer_fhe_simulation.py` script to generate the model's predictions, using PyTorch for the first layer and FHE simulation for the rest of the network.

In [14]:
%run infer_fhe_simulation.py
evaluate("./fhe_simulated_predictions.csv")

Compiling the model
Compilation finished in 86.79 seconds


                                                              

Finished inference
top-3-accuracy:  0.6231
top-3-accuracy:  0.8072
top-3-accuracy:  0.8906


Similarly, the `infer_torch.py` script generates the pure PyTorch predictions.

In [15]:
%run infer_torch.py
evaluate("./predictions.csv")

                                                              

Finished inference
top-3-accuracy:  0.6231
top-3-accuracy:  0.8072
top-3-accuracy:  0.8906


# FHE execution results

In this notebook we showed how to compile to FHE a split VGG model trained to classify CIFAR-10 images.

While satisfying the FHE constraints, the model achieves the following performance:

- top-1-accuracy: 0.6234
- top-2-accuracy: 0.8075
- top-3-accuracy: 0.8905

*We don't run the inference in FHE in this notebook, as inference on even a single image takes quite some time.*

For reference, we ran the inference of one image on an AWS c6i.metal machine using the `fhe_inference.py` script and got the following timings:

- Time to compile: 103 seconds
- Time to keygen: 639 seconds
- Time to infer: ~1800 seconds
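
To put these timings in perspective, the short calculation below estimates the end-to-end time for one image and for the whole 10,000-image test set. It assumes that compilation and key generation are one-time costs, that images are processed one after another on the same machine, and that every image takes the same time as the measured one. The `total_seconds` helper is written for this example.

```python
SECONDS_PER_DAY = 24 * 60 * 60


def total_seconds(n_images, compile_s=103, keygen_s=639, infer_s=1800):
    """One-time compile and keygen costs plus sequential per-image inference."""
    return compile_s + keygen_s + n_images * infer_s


one_image = total_seconds(1)
full_test_set = total_seconds(10_000)

print(f"1 image: {one_image / 60:.1f} minutes")
print(f"10,000 images: {full_test_set / SECONDS_PER_DAY:.0f} days")
```

Under these assumptions, a single image takes roughly 42 minutes end to end, and the full test set would take about 208 days, which is why the accuracies are measured with FHE simulation instead.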