# ToyNet

ToyNet consists of a single linear layer along with the activation function
$$\mathrm{Sign}(x) = \begin{cases}
    +1, &\text{if } x > 0, \\
    -1, &\text{otherwise}.
  \end{cases}$$

This notebook tests inference under fully homomorphic encryption (FHE) using a
simple matrix multiplication along with bootstrapping.

In [1]:
%reset

## Network Implementation

### Clear Network

The following network is trained in the clear with standard PyTorch methods.

In [2]:
from torch import Tensor
from torch.nn import Module, Linear
import torch.nn.functional as F

from doren_bnn.xnorpp import Sign


class ToyNet(Module):
    """Toy network: one bias-free linear layer wrapped in ``Sign`` activations.

    Args:
        num_input: Number of input features consumed by the linear layer.
        num_classes: Number of output units of the linear layer.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, num_input: int = 1024, num_classes: int = 1000, **kwargs):
        super().__init__()

        # No bias term, so the layer is a plain matrix multiplication.
        self.fc = Linear(in_features=num_input, out_features=num_classes, bias=False)

    def forward(self, input: Tensor) -> Tensor:
        """Apply ``Sign`` to inputs and weights, multiply, then apply ``Sign`` again.

        Args:
            input: Tensor that can be viewed as rows of ``3 * 224 * 224`` values.

        Returns:
            The ``Sign`` of the linear layer's output, one row per input row.
        """
        in_features = self.fc.weight.shape[-1]

        # Flatten every image to one row and keep only its leading
        # ``in_features`` values, which is all the linear layer consumes.
        flat = input.view(-1, 3 * 224 * 224)
        truncated = flat[:, :in_features]

        binary_input = Sign.apply(truncated)
        binary_weight = Sign.apply(self.fc.weight)
        return Sign.apply(F.linear(binary_input, binary_weight))

### FHE Network

The following network is an FHE version of the clear network.

In [3]:
from doren_bnn_concrete import toynet


class ToyNet_FHE(ToyNet):
    """FHE variant of ``ToyNet`` whose forward pass is delegated to ``toynet``.

    The constructor forwards all keyword arguments to ``ToyNet``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, input: Tensor) -> Tensor:
        """Run every input row through ``toynet`` and stack the results.

        Must only be called when the module is not in training mode.

        Args:
            input: Tensor that can be viewed as rows of ``3 * 224 * 224`` values.

        Returns:
            A ``Tensor`` built from the per-row results returned by ``toynet``.
        """
        assert not self.training

        in_features = self.fc.weight.shape[-1]

        # Hand the weights over as nested lists of booleans: True where the
        # weight is strictly positive, False otherwise.
        fhe_state = self.state_dict()
        fhe_state["fc.weight"] = (self.fc.weight > 0).tolist()

        # Same flatten-and-truncate as the clear network, converted to plain
        # Python lists because ``toynet`` is called once per row.
        rows = input.view(-1, 3 * 224 * 224)[:, :in_features].tolist()
        return Tensor([toynet(fhe_state, row) for row in rows])

## Training

In [4]:
# Training settings shared by the cells below.
NUM_EPOCHS = 1  # number of epochs passed to experiment.train
BATCH_SIZE = 2  # batch size used for the model summary and the Experiment

In [5]:
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
from torchinfo import summary

# ToyNet.forward keeps only the first NUM_INPUT values of each flattened image.
NUM_INPUT = 10  # number of input neurons of the linear layer
model = ToyNet(num_input=NUM_INPUT, num_classes=10).to(device)

# Print a layer/parameter summary for a batch of BATCH_SIZE 3x224x224 inputs.
summary(model, input_size=(BATCH_SIZE, 3, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
ToyNet                                   [2, 10]                   --
├─Linear: 1-1                            --                        100
Total params: 100
Trainable params: 100
Non-trainable params: 0
Total mult-adds (M): 0.00
Input size (MB): 1.20
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 1.20

In [7]:
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Loss, optimizer and learning-rate schedule handed to experiment.train.
criterion = CrossEntropyLoss().to(device)
optimizer = AdamW(model.parameters(), lr=1e-2, weight_decay=5e-6)
# The second positional argument is T_0, the length of the first restart cycle.
scheduler = CosineAnnealingWarmRestarts(optimizer, 30)

### Experiment Setup

In [8]:
from doren_bnn.utils import Dataset, Experiment

# Identifier handed to Experiment; presumably used to name this run's
# checkpoints/logs -- confirm in doren_bnn.utils.
EXPERIMENT_ID = "toynet"
# Project helper configured for the CIFAR-10 dataset and the chosen batch size.
experiment = Experiment(EXPERIMENT_ID, Dataset.CIFAR10, BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
# Train the clear model for NUM_EPOCHS epochs with the loss, optimizer and
# scheduler defined above.
# NOTE(review): resume=False presumably starts from scratch instead of
# continuing from a saved checkpoint -- confirm in Experiment.train.
experiment.train(
    device,
    model,
    criterion,
    optimizer,
    scheduler,
    NUM_EPOCHS,
    resume=False,
)

  0%|          | 0/1 [00:00<?, ?it/s]

### Clear Inference

In [10]:
# Run the clear (unencrypted) model through the project's test helper; the
# tensor shown below is the output recorded from this call.
experiment.test(device, model)

tensor([[ 1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1.,  1.]], device='cuda:0',
       grad_fn=<SliceBackward0>)


### FHE Inference

In [11]:
from doren_bnn_concrete import preload_keys

# Load the FHE secret keys before running inference; in the recorded run
# below, existing keys were found and loaded.
preload_keys()

Loading existing secret keys...
Existing secret keys loaded.


In [14]:
# Same layer sizes as the clear model; it is not moved to `device`.
# NOTE(review): ToyNet_FHE.forward asserts the module is not in training mode
# and no model_fhe.eval() call is visible in this notebook -- presumably
# experiment.test_fhe switches modes; confirm.
model_fhe = ToyNet_FHE(num_input=NUM_INPUT, num_classes=10)
# Presumably restores the weights saved during training into model_fhe --
# confirm in Experiment.load_checkpoint. `cp` is not used in the cells shown.
cp = experiment.load_checkpoint(model_fhe, optimizer, scheduler)

If FHE inference is correct, the output should be exactly the same (after rounding) as the output of clear inference.

In [15]:
# Run the FHE model through the project's FHE test helper; compare the printed
# output with the clear-inference output above.
experiment.test_fhe(model_fhe)

[5.999999999999892, 3.999999999998245, 4.000000000000346, -5.999999999997441, 8.000000000001258, 1.9999999999986997, 3.9999999999961897, 7.9999999999987494, 7.999999999999435, 4.0000000000005205]
Encoder { o: -10.0, delta: 64.0, nb_bit_precision: 1, nb_bit_padding: 5, round: false }
Encoder { o: -1.0, delta: 4.0, nb_bit_precision: 1, nb_bit_padding: 9, round: false }
[5.999999999998707, 3.9999999999996163, 3.9999999999983515, -5.999999999999916, 8.000000000000789, 1.999999999999261, 3.9999999999991935, 7.999999999998703, 7.999999999999915, 4.000000000003164]
Encoder { o: -10.0, delta: 64.0, nb_bit_precision: 1, nb_bit_padding: 5, round: false }
Encoder { o: -1.0, delta: 4.0, nb_bit_precision: 1, nb_bit_padding: 9, round: false }
tensor([[-0.9979,  0.9779, -0.9427, -0.9685,  0.9843, -0.9731,  1.0000,  0.9989,
          0.9714, -0.9967],
        [ 1.0000, -1.0000,  1.0000,  1.0000,  1.0000, -1.0000, -0.9689,  1.0000,
          1.0000,  0.9619]])
