# CNN inference on encrypted MNIST

In this example we train a CNN on plain data using PyTorch, then reuse the trained weights to reimplement the same inference as a numpy function. That numpy function can then be compiled to its homomorphic equivalent to run inference on encrypted inputs.

In [None]:
import torch
from torchvision import datasets
import torchvision.transforms as transforms
import numpy as np

In [None]:
# Seed torch's RNG so weight initialisation and data shuffling are reproducible.
torch.manual_seed(73)

batch_size = 64

# MNIST train/test splits, downloaded into the `data` directory on first use;
# ToTensor converts each image to a float tensor scaled to [0, 1].
train_data, test_data = (
    datasets.MNIST(
        "data", train=is_train, download=True, transform=transforms.ToTensor()
    )
    for is_train in (True, False)
)

# Shuffled mini-batch loaders for both splits.
train_loader, test_loader = (
    torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True)
    for split in (train_data, test_data)
)

In [None]:
class Net(torch.nn.Module):
    """Small CNN for single-channel 28x28 MNIST images.

    Architecture: one 7x7 convolution (1 -> 4 channels, stride 3), then three
    fully-connected layers, with sigmoid activations after the convolution and
    after ``fc1`` and ``fc2``. ``fc3`` outputs raw logits.

    Args:
        hiddens: layer sizes ``(flattened conv output, fc1 out, fc2 out)``.
            The first entry must equal the flattened conv output; for a 28x28
            input the conv yields 4 x 8 x 8 = 256 features. Any indexable
            sequence works (a tuple default avoids the shared mutable default
            argument pitfall).
        output: number of classes produced by ``fc3``.
    """

    def __init__(self, hiddens=(256, 64, 32), output=10):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=1, out_channels=4, kernel_size=7, stride=3
        )
        self.fc1 = torch.nn.Linear(hiddens[0], hiddens[1])
        self.fc2 = torch.nn.Linear(hiddens[1], hiddens[2])
        self.fc3 = torch.nn.Linear(hiddens[2], output)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, output)``.

        ``x`` is a batch of images of shape ``(batch, 1, 28, 28)``. No softmax
        is applied because ``CrossEntropyLoss`` consumes logits directly.
        """
        x = torch.sigmoid(self.conv(x))
        # Flatten everything but the batch dimension (channel-major order).
        x = x.flatten(start_dim=1)
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        return self.fc3(x)


def train(model, train_loader, criterion, optimizer, n_epochs=10):
    """Fit ``model`` for ``n_epochs`` passes over ``train_loader``.

    Every batch performs one optimizer step on
    ``criterion(model(data), target)``. After each epoch the mean batch loss is
    printed. The model is put in training mode for the loop and switched to
    evaluation mode before returning.

    Returns:
        The same ``model`` object, left in evaluation mode.
    """
    model.train()
    for epoch in range(1, n_epochs + 1):
        running = 0.0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), labels)
            batch_loss.backward()
            optimizer.step()
            running += batch_loss.item()

        mean_loss = running / len(train_loader)
        print("Epoch: {} \tTraining Loss: {:.6f}".format(epoch, mean_loss))

    model.eval()
    return model

In [None]:
# Train the CNN on plaintext MNIST: cross-entropy loss, Adam with learning
# rate 1e-3 and weight decay 6e-4, for 10 epochs.
model = Net()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=6e-4)
model = train(model, train_loader, criterion, optimizer, n_epochs=10)

In [None]:
def test(model, test_loader, criterion):
    # initialize lists to monitor test loss and accuracy
    test_loss = 0.0
    class_correct = list(0.0 for i in range(10))
    class_total = list(0.0 for i in range(10))

    # model in evaluation mode
    model.eval()

    for data, target in test_loader:
        output = model(data)
        loss = criterion(output, target)
        test_loss += loss.item()
        # convert output probabilities to predicted class
        _, pred = torch.max(output, 1)
        # compare predictions to true label
        correct = np.squeeze(pred.eq(target.data.view_as(pred)))
        # calculate test accuracy for each object class
        for i in range(len(target)):
            label = target.data[i]
            class_correct[label] += correct[i].item()
            class_total[label] += 1

    # calculate and print avg test loss
    test_loss = test_loss / len(test_loader)
    print(f"Test Loss: {test_loss:.6f}\n")

    for label in range(10):
        print(
            f"Test Accuracy of {label}: {int(100 * class_correct[label] / class_total[label])}% "
            f"({int(np.sum(class_correct[label]))}/{int(np.sum(class_total[label]))})"
        )

    print(
        f"\nTest Accuracy (Overall): {int(100 * np.sum(class_correct) / np.sum(class_total))}% "
        f"({int(np.sum(class_correct))}/{int(np.sum(class_total))})"
    )


# Report loss and per-class accuracy of the trained model on the MNIST test split.
test(model=model, test_loader=test_loader, criterion=criterion)

## Encrypted inference

In [None]:
import numpy as np
import hnumpy as hnp
# conv2d is not a native numpy operation, so we provide it in npx
import hnumpy.extended as npx
import time

# Plain-data bounds and image shape used to declare the encrypted input.
DATA_RANGE = (0, 1.0)
INPUT_SIZE = (28, 28)

# Pull the trained convolution parameters out of the PyTorch model as numpy arrays.
kernels = model.conv.weight.detach().numpy()
biases = model.conv.bias.detach().numpy()

# Linear layers keyed 1..3. torch stores Linear weights as (out, in); they are
# transposed here so that `np.dot(x, W) + b` reproduces the layer.
fc = {
    idx: (layer.weight.T.detach().numpy(), layer.bias.detach().numpy())
    for idx, layer in enumerate((model.fc1, model.fc2, model.fc3), start=1)
}


# we implement the forward function of the PyTorch model as a numpy function so that we can compile it
def inference(x):
    """Numpy re-implementation of ``Net.forward`` for a single image.

    This function is what ``hnp.compile_fhe`` compiles later, so it is written
    only with numpy/``npx`` operations. It reads the module-level ``kernels``,
    ``biases`` and ``fc`` arrays extracted from the trained PyTorch model.

    Args:
        x: one input image; presumably a (28, 28) float array with values in
            ``DATA_RANGE``, matching ``INPUT_SIZE`` — TODO confirm.

    Returns:
        One score per class (10 for the default ``Net``) after a final sigmoid.

    NOTE(review): ``Net.forward`` returns raw ``fc3`` logits, but this function
    also applies a sigmoid after the last linear layer. Sigmoid is strictly
    increasing, so ``argmax`` (the predicted class) is unchanged, but the
    returned values are not the PyTorch logits.
    """
    # perform convolution
    # NOTE(review): assumes npx.conv2d matches torch's Conv2d and returns
    # (out_channels, H, W), so that .flatten() matches the channel-major
    # x.view(batch_size, -1) in Net.forward — confirm.
    x = npx.conv2d(x, kernels, biases, stride=3).flatten()
    # sigmoid written with np.exp, mirroring torch.sigmoid in Net.forward
    x = 1 / (1 + np.exp(-x))
    # forward through the linear layers + activation; fc[i][0] is already
    # transposed, so np.dot(x, W) + b reproduces torch's Linear layer
    for i in range(1, 4):
        x = np.dot(x, fc[i][0]) + fc[i][1]
        x = 1 / (1 + np.exp(-x))
    return x

In [None]:
# Compile `inference` into its homomorphic equivalent. Its single argument `x`
# is declared as an encrypted array bounded by DATA_RANGE with shape
# INPUT_SIZE, using the "handselected" parameter optimizer.
config = hnp.config.CompilationConfig(parameter_optimizer="handselected")
h = hnp.compile_fhe(
    inference,
    dict(x=hnp.encrypted_ndarray(bounds=DATA_RANGE, shape=INPUT_SIZE)),
    config=config,
)

In [None]:
# generate context and keys
# Create a context from the compiled circuit `h`, then generate keys from it.
# `keys` is the argument that `h.encrypt_and_run(keys, x)` takes for real
# (non-simulated) encrypted inference; simulation via `h.simulate` does not use it.
ctx = h.create_context()
keys = ctx.keygen()

We first run inference on the entire test set in simulation mode to make sure everything works correctly. Then we run encrypted inference on a subset of the test set and compare the results against the plain evaluation.

In [None]:
# Simulation pass over the whole test set: each image goes through the plain
# numpy `inference` and through `h.simulate` (the compiled circuit run in
# simulation mode); only the simulated call is timed.

# Batch size 1: the compiled circuit consumes one (28, 28) image per call.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True)

expected_results = []
results = []
targets = []
times = []

n_iter = len(test_loader)
for step, (image, label) in enumerate(test_loader, start=1):
    print(f"evaluation {step}/{n_iter}\r", end="")
    x = image.detach().numpy().reshape(28, 28)
    expected_results.append(inference(x))
    started = time.perf_counter()
    results.append(h.simulate(x))
    times.append(time.perf_counter() - started)
    targets.append(label)

# Compare predicted classes (argmax) of simulated vs plain outputs with each
# other and with the ground-truth labels.
enc_preds = [out.argmax() for out in results]
plain_preds = [out.argmax() for out in expected_results]
diff_plain_enc = sum(1 for p, ep in zip(enc_preds, plain_preds) if p != ep)
diff_plain = sum(1 for ep, t in zip(plain_preds, targets) if ep != t)
diff_enc = sum(1 for p, t in zip(enc_preds, targets) if p != t)

print(f"diff between encrypted (simulated) and plain output is {diff_plain_enc} out of {n_iter} computation")
print(f"plain accuracy {(n_iter - diff_plain) / n_iter}")
print(f"simulated enc accuracy {(n_iter - diff_enc) / n_iter}")
# mean wall-clock seconds per h.simulate call
print(f"average time {sum(times) / n_iter}")

In [None]:
# Run plain and homomorphic inference on a few test images, store outputs and
# timings, and compare the predictions with each other and with the labels.
import itertools

# Set to True to run real encrypted inference with `keys` (much slower than
# simulation). When False, the compiled circuit is only simulated, and the
# printed labels say so instead of claiming encrypted results.
RUN_ENCRYPTED = False
enc_label = "enc" if RUN_ENCRYPTED else "simulated enc"
mode_label = "encrypted" if RUN_ENCRYPTED else "encrypted (simulated)"

# don't batch as we want one input at a time
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True)

expected_results = []
results = []
targets = []
times = []
# run 10 examples only; set n_iter = len(test_loader) to use the whole test set
n_iter = 10
for i, (data, target) in enumerate(itertools.islice(test_loader, n_iter)):
    print(f"evaluation {i + 1}/{n_iter}\r", end="")
    x = data.detach().numpy().reshape(28, 28)
    expected_results.append(inference(x))
    tick = time.perf_counter()
    if RUN_ENCRYPTED:
        results.append(h.encrypt_and_run(keys, x))
    else:
        results.append(h.simulate(x))
    tock = time.perf_counter()
    times.append(tock - tick)
    targets.append(target)

# The loader may hold fewer than n_iter samples; report on what actually ran.
n_iter = len(results)

# print results
diff_plain_enc = 0
diff_plain = 0
diff_enc = 0
for result, expected, t in zip(results, expected_results, targets):
    p = result.argmax()
    ep = expected.argmax()
    if p != ep:
        diff_plain_enc += 1
    if ep != t:
        diff_plain += 1
    if p != t:
        diff_enc += 1

print(f"diff between {mode_label} and plain output is {diff_plain_enc} out of {n_iter} computation")
print(f"plain accuracy {(n_iter - diff_plain) / n_iter}")
print(f"{enc_label} accuracy {(n_iter - diff_enc) / n_iter}")
# mean wall-clock seconds per inference call; compilation and key generation
# ran in earlier cells and are not included
print(f"average time {sum(times) / n_iter}")