# ToyNet

ToyNet consists of a single linear layer along with the activation function
$$\mathrm{Sign}(x) = \begin{cases}
    +1, &\text{if } x > 0, \\
    -1, &\text{otherwise}.
  \end{cases}$$

This notebook is used for testing FHE inference with a simple matrix multiplication
along with batch normalisation and ReLU.

In [1]:
%load_ext autoreload
%autoreload 2

## Network Implementation

### Clear Network

The following network is trained in the clear with standard PyTorch methods.

In [2]:
from torch import Tensor
from torch.nn import BatchNorm1d, Module, Linear, ReLU, Sequential
import torch.nn.functional as F

from doren_bnn.xnorpp import Sign


class ToyNet(Module):
    def __init__(self, num_input: int = 10, num_output: int = 10, **kwargs):
        super(ToyNet, self).__init__()

        self.block = Sequential(
            Linear(num_input, num_output, bias=False),
            BatchNorm1d(num_output),
            ReLU(inplace=True),
        )

    def forward(self, input: Tensor) -> Tensor:
        num_input = self.block[0].weight.size(-1)

        input = input.view(-1, 3 * 32 * 32)[:, :num_input]
        (output_lin,) = (F.linear(Sign.apply(input), Sign.apply(self.block[0].weight)),)
        output_bn = self.block[1](output_lin)
        return self.block[2](output_bn)

### FHE Network

The following network is a FHE version of the clear network.

In [3]:
from doren_bnn_concrete import toynet


class ToyNet_FHE(ToyNet):
    def __init__(self, **kwargs):
        super(ToyNet_FHE, self).__init__(**kwargs)

    def forward(self, input: Tensor) -> Tensor:
        assert not self.training

        num_input = self.block[0].weight.size(-1)

        state_dict = self.state_dict()
        state_dict["block.0.weight"] = Sign.apply(self.block[0].weight).long().tolist()
        print(state_dict)

        input = input.view(-1, 3 * 32 * 32)[:, :num_input].tolist()
        output = []
        for im in input:
            output.append(toynet(state_dict, im))
        return Tensor(output)

## Training

In [4]:
# define training settings
NUM_EPOCHS = 1
BATCH_SIZE = 2

In [5]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
from torchinfo import summary

NUM_INPUT = 10  # determines how many input neurons
model = ToyNet(num_input=NUM_INPUT, num_output=10).to(device)

summary(model, input_size=(BATCH_SIZE, 3, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
ToyNet                                   [2, 10]                   --
├─Sequential: 1-1                        --                        --
│    └─Linear: 2-1                       --                        100
│    └─BatchNorm1d: 2-2                  [2, 10]                   20
│    └─ReLU: 2-3                         [2, 10]                   --
Total params: 120
Trainable params: 120
Non-trainable params: 0
Total mult-adds (M): 0.00
Input size (MB): 0.02
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.03

In [7]:
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

criterion = CrossEntropyLoss().to(device)
optimizer = AdamW(model.parameters(), lr=1e-2, weight_decay=5e-6)
scheduler = CosineAnnealingWarmRestarts(optimizer, 30)

### Experiment Setup

In [8]:
from doren_bnn.utils import Dataset, Experiment

EXPERIMENT_ID = "toynet"
experiment = Experiment(EXPERIMENT_ID, Dataset.CIFAR10, BATCH_SIZE, multiplier=0.001)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
experiment.train(
    model,
    criterion,
    optimizer,
    scheduler,
    NUM_EPOCHS,
    device=device,
)

  0%|          | 0/1 [00:00<?, ?it/s]

### Clear Inference

In [10]:
experiment.test(model, device=device)

### FHE Inference

In [11]:
from doren_bnn_concrete import preload_keys

# preload_keys()

In [12]:
model_fhe = ToyNet_FHE(num_input=NUM_INPUT, num_output=10)
cp = experiment.load_checkpoint(model_fhe, optimizer, scheduler)

If FHE inference is correct, the output should be exactly the same (after rounding) as the output of clear inference.

In [13]:
experiment.test_fhe(model_fhe)

OrderedDict([('block.0.weight', [[1, 1, 1, -1, -1, -1, -1, 1, 1, 1], [1, 1, -1, 1, -1, 1, -1, 1, -1, 1], [1, 1, -1, -1, -1, 1, 1, -1, 1, 1], [-1, -1, -1, 1, -1, 1, -1, -1, -1, 1], [-1, -1, 1, -1, 1, 1, -1, 1, -1, -1], [1, -1, -1, 1, -1, 1, 1, 1, 1, -1], [-1, -1, -1, -1, -1, 1, 1, -1, 1, -1], [-1, 1, -1, 1, 1, -1, -1, 1, 1, -1], [1, -1, 1, -1, -1, 1, 1, -1, 1, 1], [1, 1, 1, 1, -1, -1, 1, 1, 1, 1]]), ('block.1.weight', tensor([0.9526, 0.9739, 0.9689, 0.9690, 0.9786, 0.8366, 0.9200, 0.8543, 0.8449,
        0.9867])), ('block.1.bias', tensor([-0.0474, -0.0261, -0.0311,  0.0123, -0.0214, -0.1634, -0.0800, -0.1457,
        -0.1551, -0.0315])), ('block.1.running_mean', tensor([ 0.7556,  0.1569,  0.2595, -0.3677,  0.0578,  0.1910, -0.0127,  0.0292,
         0.2432,  0.7833])), ('block.1.running_var', tensor([ 2.9861,  2.5818,  3.9839,  9.8438,  1.9769,  6.1268, 17.2033,  1.7501,
         3.0241, 13.2730])), ('block.1.num_batches_tracked', tensor(25))])
Loading existing client & server keys...


RuntimeError: fc.weight not found