**IMPORTING REQUIRED LIBRARIES**

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np
from Models.PolyKervNet import Kerv2d
from Models.Poly1Net import Poly1Net
import tenseal as ts
from tqdm import tqdm
import os
import torch.optim as optim

**LOADING THE MNIST DATASET**

In [None]:
# Both splits only apply ToTensor: no normalization and no augmentation.
train_transform = transforms.Compose([transforms.ToTensor()])
test_transform = transforms.Compose([transforms.ToTensor()])

# MNIST lives under ./data (relative to the working directory); download=True fetches it when needed.
train_data = datasets.MNIST(root='data', train=True, transform=train_transform, download=True)
test_data = datasets.MNIST(root='data', train=False, transform=test_transform, download=True)

batch_size = 128

# Shuffle the training split each epoch; keep the test split in a fixed order.
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

**DEFINING THE TRAIN & TEST FUNCTIONS FOR THE PLAINTEXT DOMAIN**

In [None]:
def train(model, train_loader, criterion, optimizer, scheduler, n_epochs=100):
    """Train ``model`` for ``n_epochs`` epochs and hand it back in evaluation mode.

    Each epoch performs one optimizer step per batch, prints the mean batch
    loss for that epoch, and then advances ``scheduler`` once.

    Args:
        model: PyTorch module to optimize; batches are moved to CUDA when
            available (else CPU), so the model is expected to live there too.
        train_loader: iterable of ``(data, targets)`` batches.
        criterion: loss called as ``criterion(output, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        n_epochs: number of passes over ``train_loader``.

    Returns:
        The same ``model`` object, switched to ``eval()`` mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.train()

    for epoch in range(n_epochs):
        epoch_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device=device), labels.to(device=device)
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), labels)
            batch_loss.backward()
            optimizer.step()
            epoch_loss += batch_loss.item()

        # Mean of the per-batch losses over this epoch.
        mean_loss = epoch_loss / len(train_loader)
        print(f'Epoch: {epoch + 1} \tTraining Loss: {mean_loss:.6f}')
        scheduler.step()

    model.eval()
    return model

def test(model, test_loader, criterion, classes):
    # initialize lists to monitor test loss and accuracy
    test_loss = 0.0
    class_correct = list(0. for i in range(classes))
    class_total = list(0. for i in range(classes))

    # model in evaluation mode
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval()

    for data, targets in test_loader:
        data = data.to(device=device)
        targets = targets.to(device=device)
        output = model(data)
        loss = criterion(output, targets)
        test_loss += loss.item()
        # convert output probabilities to predicted class
        _, pred = torch.max(output, 1)
        # compare predictions to true label
        correct = np.squeeze(pred.eq(targets.data.view_as(pred)))
        # calculate test accuracy for each object class
        for i in range(len(targets)):
            label = targets.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

    # calculate and print avg test loss
    test_loss = test_loss/len(test_loader)
    print(f'Test Loss: {test_loss:.6f}\n')

    for label in range(classes):
        print(
            f'Test Accuracy of {label}: {int(100 * class_correct[label] / class_total[label])}% '
            f'({int(np.sum(class_correct[label]))}/{int(np.sum(class_total[label]))})'
        )

    print(
        f'\nTest Accuracy (Overall): {int(100 * np.sum(class_correct) / np.sum(class_total))}% ' 
        f'({int(np.sum(class_correct))}/{int(np.sum(class_total))})'
    )

**DEFINING THE TRAINING CONFIGURATIONS/PARAMETERS**

In [None]:
# Seed PyTorch's RNGs so weight initialisation and training-loader shuffling
# are repeatable across runs.
SEED = 0
torch.manual_seed(SEED)

# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Poly1Net; the argument is presumably the number of output classes (10 MNIST digits) — TODO confirm.
model = Poly1Net(10).to(device)

# Cross-entropy loss, Adam at lr=1e-3, and a StepLR that cuts the learning rate
# by 10x every 15 scheduler steps (train() steps it once per epoch).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.1)

**TRAINING**

In [None]:
# 20 epochs; with StepLR(step_size=15) the learning rate drops once, after epoch 15.
model = train(
    model,
    train_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    n_epochs=20,
)

**TESTING**

In [None]:
# Plaintext evaluation on the MNIST test split (10 digit classes).
test(
    model,
    test_loader,
    criterion=criterion,
    classes=10,
)

**DEFINING THE ENCRYPTED POLY1NET MODEL AND EVALUATING IT ON CKKS-ENCRYPTED MNIST (TENSEAL)**

In [None]:
class EncPoly1Net:
    """Encrypted-inference mirror of a trained Poly1Net for TenSEAL CKKS vectors.

    The plaintext weights of ``kerv1``, ``fc1`` and ``fc2`` are copied out of the
    PyTorch model as nested Python lists so they can be applied to encrypted inputs.
    """

    def __init__(self, torch_nn):
        kerv = torch_nn.kerv1
        # NOTE(review): the view drops the input-channel axis, so this assumes kerv1
        # has a single input channel (grayscale MNIST) — TODO confirm.
        self.kerv1_weight = kerv.weight.data.view(
            kerv.out_channels, kerv.kernel_size[0], kerv.kernel_size[1]).tolist()
        self.kerv1_bias = kerv.bias.data.tolist()

        # Linear weights are stored transposed (in_features x out_features) so that
        # ``enc.mm(W)`` matches nn.Linear's ``x @ weight.T``.
        self.fc1_weight = torch_nn.fc1.weight.T.data.tolist()
        self.fc1_bias = torch_nn.fc1.bias.data.tolist()

        self.fc2_weight = torch_nn.fc2.weight.T.data.tolist()
        self.fc2_bias = torch_nn.fc2.bias.data.tolist()

    def forward(self, enc_x, windows_nb):
        """Encrypted forward pass: polynomial kervolution -> fc1 -> fc2.

        Args:
            enc_x: im2col-encoded encrypted image (as produced by ``ts.im2col_encoding``).
            windows_nb: number of convolution windows returned alongside ``enc_x``.

        Returns:
            Encrypted output vector; decrypt it to obtain the class scores.
        """
        # One encrypted vector per output channel of kerv1.
        enc_channels = [
            self._kerv_channel(enc_x, kernel, bias, windows_nb)
            for kernel, bias in zip(self.kerv1_weight, self.kerv1_bias)
        ]
        # Pack all channels into a single flattened vector.
        enc_x = ts.CKKSVector.pack_vectors(enc_channels)
        # fc1 layer
        enc_x = enc_x.mm(self.fc1_weight) + self.fc1_bias
        # fc2 layer
        enc_x = enc_x.mm(self.fc2_weight) + self.fc2_bias
        return enc_x

    @staticmethod
    def _kerv_channel(enc_x, kernel, bias, windows_nb):
        """Compute one kervolution channel: square the encrypted convolution, then add the bias."""
        # Use .square(), not ``^``: in Python ``^`` is bitwise XOR and binds looser
        # than ``+``, so ``conv ^ 2 + bias`` would mean ``conv ^ (2 + bias)``.
        # NOTE(review): keeps the (conv)**2 + bias ordering; it must match the
        # plaintext Kerv2d formula in Models.PolyKervNet — TODO confirm.
        return enc_x.conv2d_im2col(kernel, windows_nb).square() + bias

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    
def enc_test(context, model, test_loader, criterion, kernel_shape, stride, classes=10):
    """Run encrypted (CKKS) inference over a loader and print loss and accuracy.

    Each sample is im2col-encoded and encrypted with ``context``, evaluated by
    ``model`` on the ciphertext, decrypted, and then scored in plaintext.

    Args:
        context: TenSEAL context used for encoding/encryption and decryption.
        model: callable ``model(enc_x, windows_nb)`` returning an encrypted vector
            with a ``decrypt()`` method (e.g. an ``EncPoly1Net``).
        test_loader: loader yielding one single-channel image per batch (batch_size=1).
        criterion: loss applied to the decrypted scores, shaped ``(1, classes)``.
        kernel_shape: (height, width) of the convolution kernel used for im2col.
        stride: convolution stride used for im2col.
        classes: number of target classes to report on (10 for MNIST).
    """
    test_loss = 0.0
    class_correct = [0] * classes
    class_total = [0] * classes

    for data, target in test_loader:
        # im2col encoding takes one 2-D image, so batches must hold exactly one
        # single-channel image (view() fails otherwise).
        image = data.view(data.shape[-2], data.shape[-1]).tolist()
        x_enc, windows_nb = ts.im2col_encoding(
            context, image, kernel_shape[0], kernel_shape[1], stride
        )
        # Encrypted evaluation with the model passed in, not a notebook global.
        enc_output = model(x_enc, windows_nb)
        # Decrypt and score in plaintext.
        output = torch.tensor(enc_output.decrypt()).view(1, -1)

        test_loss += criterion(output, target).item()

        # Predicted class = index of the highest decrypted score.
        _, pred = torch.max(output, 1)
        label = int(target[0])
        class_correct[label] += int(pred.item() == label)
        class_total[label] += 1

    # With batch_size=1 the sample count equals the number of loss terms.
    n_samples = sum(class_total)
    test_loss = test_loss / n_samples if n_samples else float('nan')
    print(f'Test Loss: {test_loss:.6f}\n')

    for label in range(classes):
        hit, total = class_correct[label], class_total[label]
        if total:
            print(f'Test Accuracy of {label}: {int(100 * hit / total)}% ({hit}/{total})')
        else:
            print(f'Test Accuracy of {label}: N/A (no samples)')

    hit, total = sum(class_correct), n_samples
    if total:
        print(f'\nTest Accuracy (Overall): {int(100 * hit / total)}% ({hit}/{total})')
    else:
        print('\nTest Accuracy (Overall): N/A (no samples)')


# Encrypted evaluation consumes one image per batch.
enc_testdl = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True)

# The im2col encoding must use the same kernel geometry and stride as the trained kerv1 layer.
kernel_shape = model.kerv1.kernel_size
stride = model.kerv1.stride[0]

## Encryption Parameters

# Bit size of the scale; controls the precision of the fractional part.
bits_scale = 26

# CKKS context: 31-bit primes at both ends with six bits_scale-bit primes in between.
# NOTE(review): the number of middle primes presumably bounds the multiplicative
# depth available to EncPoly1Net.forward — TODO confirm if the network changes.
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=8192,
    coeff_mod_bit_sizes=[31] + [bits_scale] * 6 + [31],
)
context.global_scale = 2 ** bits_scale

# Galois keys are required for ciphertext rotations.
context.generate_galois_keys()

enc_model = EncPoly1Net(model)
enc_test(context, enc_model, enc_testdl, criterion, kernel_shape, stride)