# ToyNet

ToyNet consists of a single linear layer whose inputs and weights are binarised with the sign function
$$\mathrm{Sign}(x) = \begin{cases}
    +1, &\text{if } x > 0, \\
    -1, &\text{otherwise}.
  \end{cases}$$

This notebook tests FHE inference on a simple matrix multiplication
followed by batch normalisation and ReLU.

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell is executed, so edits to the project packages take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

## Network Implementation

### Clear Network

The following network is trained in the clear with standard PyTorch methods.

In [None]:
from torch import Tensor
from torch.nn import BatchNorm1d, Module, Linear, ReLU, Sequential
import torch.nn.functional as F

from doren_bnn.xnorpp import Sign


class ToyNet(Module):
    """Toy network: a bias-free linear layer, batch normalisation, then ReLU.

    In ``forward`` both the input and the linear layer's weight are passed
    through ``Sign.apply`` before the matrix multiplication.

    Args:
        num_input: Number of leading features kept from each flattened image
            (the linear layer's ``in_features``).
        num_output: Number of output features of the linear layer.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_input: int = 10, num_output: int = 10, **kwargs):
        super().__init__()

        self.block = Sequential(
            Linear(num_input, num_output, bias=False),
            BatchNorm1d(num_output),
            ReLU(inplace=True),
        )

    def forward(self, input: Tensor) -> Tensor:
        """Run the clear (non-FHE) forward pass.

        Args:
            input: Tensor whose elements can be regrouped into rows of
                ``3 * 32 * 32`` values; only the first ``num_input`` values
                of each row are used.

        Returns:
            The ReLU output, one row per input row.  When the module is not
            in training mode, the first ten columns are also printed.
        """
        linear, batch_norm, activation = self.block
        num_input = linear.weight.size(-1)

        # reshape (rather than view) so non-contiguous inputs are accepted too
        input = input.reshape(-1, 3 * 32 * 32)[:, :num_input]
        output_lin = F.linear(Sign.apply(input), Sign.apply(linear.weight))
        output_bn = batch_norm(output_lin)
        output = activation(output_bn)
        if not self.training:
            # printed so the values can be compared with the FHE output
            print(output[:, :10])
        return output

### FHE Network

The following network is an FHE version of the clear network.

In [None]:
from doren_bnn_concrete import toynet


class ToyNet_FHE(ToyNet):
    """FHE counterpart of ``ToyNet``; inference is delegated to ``toynet``.

    The layers (and therefore the state dict layout) are inherited from
    ``ToyNet``; only ``forward`` differs.

    Args:
        **kwargs: Forwarded unchanged to ``ToyNet.__init__``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, input: Tensor) -> Tensor:
        """Run inference through ``toynet``, one image at a time.

        Args:
            input: Tensor whose elements can be regrouped into rows of
                ``3 * 32 * 32`` values; only the first ``num_input`` values
                of each row are used.

        Returns:
            A tensor built from the per-image results returned by ``toynet``.

        Raises:
            AssertionError: If the module is in training mode.
        """
        assert not self.training, "ToyNet_FHE only supports inference mode"

        num_input = self.block[0].weight.size(-1)

        # Replace the real-valued weight by its Sign output as nested integer lists.
        state_dict = self.state_dict()
        state_dict["block.0.weight"] = Sign.apply(self.block[0].weight).long().tolist()
        print(state_dict)

        # reshape (rather than view) so non-contiguous inputs are accepted too
        images = input.reshape(-1, 3 * 32 * 32)[:, :num_input].tolist()
        output = []
        for im in images:
            # The image was previously never handed to ``toynet``, so every
            # sample produced the same, input-independent result.
            # NOTE(review): assumes ``toynet`` takes the input vector as its
            # second positional argument - confirm against doren_bnn_concrete.
            output_tn = toynet(state_dict, im)
            output.append(output_tn)
            print(output_tn[:10])
        return Tensor(output)

## Training

In [None]:
# define training settings
NUM_EPOCHS = 1  # passed to experiment.train below
BATCH_SIZE = 2  # used for the model summary and passed to Experiment

In [None]:
import torch

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
from torchinfo import summary

NUM_INPUT = 10  # determines how many input neurons
# Build the clear model and move it to the selected device.
model = ToyNet(num_input=NUM_INPUT, num_output=10).to(device)

# Print a layer summary for an input of shape (BATCH_SIZE, 3, 32, 32).
summary(model, input_size=(BATCH_SIZE, 3, 32, 32))

In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Loss, optimiser and learning-rate schedule handed to experiment.train below.
criterion = CrossEntropyLoss().to(device)
optimizer = AdamW(model.parameters(), lr=1e-2, weight_decay=5e-6)
# The second positional argument is T_0: the first warm restart happens after
# 30 scheduler iterations.
scheduler = CosineAnnealingWarmRestarts(optimizer, 30)

### Experiment Setup

In [None]:
from doren_bnn.utils import Dataset, Experiment

# Experiment object for CIFAR-10; it is used below for training, testing and
# checkpoint loading.
EXPERIMENT_ID = "toynet"
# NOTE(review): multiplier=0.001 presumably restricts the run to a small
# fraction of the dataset - confirm in doren_bnn.utils.
experiment = Experiment(EXPERIMENT_ID, Dataset.CIFAR10, BATCH_SIZE, multiplier=0.001)

In [None]:
# Train the clear model with the loss, optimiser and scheduler defined above;
# NUM_EPOCHS is passed as the epoch count.
experiment.train(
    model,
    criterion,
    optimizer,
    scheduler,
    NUM_EPOCHS,
    device=device,
)

### Clear Inference

In [None]:
# Evaluate the clear model. ToyNet.forward prints the first ten output columns
# whenever the model is not in training mode.
experiment.test(model, device=device)

### FHE Inference

In [None]:
from doren_bnn_concrete import preload_keys

# NOTE(review): presumably prepares the FHE keys ahead of inference - confirm
# in doren_bnn_concrete.
preload_keys()

In [None]:
# FHE model built with the same layer sizes as the clear model.
model_fhe = ToyNet_FHE(num_input=NUM_INPUT, num_output=10)
# NOTE(review): presumably restores the trained parameters into model_fhe;
# the returned value `cp` is not used in the visible cells - confirm in
# doren_bnn.utils.
cp = experiment.load_checkpoint(model_fhe, optimizer, scheduler)

If FHE inference is correct, the output should be exactly the same (after rounding) as the output of clear inference.

In [None]:
# Run the FHE test. ToyNet_FHE.forward prints the first ten outputs per image,
# which can be compared with the values printed by the clear run.
experiment.test_fhe(model_fhe)