# Manual MobileNet implementation

In [22]:
from models.mobilenet import MyMobileNet
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pytorch_lightning as pl
import torchvision
import torchvision.transforms.v2 as transforms
import json
import os

In [97]:
# Circom arithmetic is modulo CIRCOM_PRIME. Signed integers are encoded as field
# elements: non-negative values occupy [0, p // 2] and negative values occupy
# [p // 2 + 1, p - 1], i.e. -k is stored as p - k.
CIRCOM_PRIME = 21888242871839275222246405745257275088548364400416034343698204186575808495617
MAX_POSITIVE = CIRCOM_PRIME // 2  # Largest field element that decodes to a non-negative value.
MAX_NEGATIVE = MAX_POSITIVE + 1  # Smallest field element that decodes to a negative value: the most negative number, -(p // 2).
CIRCOM_NEGATIVE_1 = CIRCOM_PRIME - 1  # Field encoding of -1.
EXPONENT = 10  # Fixed-point scale: real values are multiplied by 10**EXPONENT and rounded.

def from_circom(x):
    """Decode a Circom field element into a signed Python int.

    ``x`` is first coerced with ``int()`` (so strings and NumPy scalars work).
    Elements above ``MAX_POSITIVE`` are encodings of negative numbers and are
    mapped back by subtracting ``CIRCOM_PRIME``; everything else is returned
    unchanged.
    """
    value = int(x)
    return value - CIRCOM_PRIME if value > MAX_POSITIVE else value
    
def to_circom(x):
    """Encode a signed integer (or an integer array) as a Circom field element.

    Uses ``%`` with a positive modulus, whose result is always in
    [0, CIRCOM_PRIME), so -k becomes CIRCOM_PRIME - k.
    """
    field_value = x % CIRCOM_PRIME
    return field_value
    
def to_circom_input(array):
    """Convert quantized values into Circom field elements written as decimal strings.

    Each entry is rounded to the nearest integer (``np.round``: ties to even),
    mapped into the field with ``to_circom`` (negatives wrap to
    ``CIRCOM_PRIME - k``), and rendered with ``str``. Presumably strings are
    used so the ~254-bit field elements survive JSON consumers intact.

    Args:
        array: ndarray or (nested) list of numbers.

    Returns:
        Nested Python list of ``str`` with the same shape as ``array``.
    """
    # np.asarray converts lists and leaves ndarrays uncopied. The previous
    # `type(array) != np.array` check compared against a function, not a type,
    # so it was always true.
    values = np.asarray(array)
    # Round into an explicit int64; a plain `int` is 32-bit on Windows with
    # NumPy 1.x. Then switch to Python ints (object dtype), so the modulo by
    # the 254-bit prime is exact. An int64 array cannot hold the prime, and
    # NumPy 2 raises OverflowError for `int64_array % huge_int`.
    int_array = values.round().astype(np.int64).astype(object)
    field_array = to_circom(int_array)
    return field_array.astype(str).tolist()


def DepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    """Integer depthwise 2D convolution ("valid", no padding) followed by division by ``n``.

    Each input channel is convolved with its own ``kernelSize x kernelSize``
    kernel. For output position (row, col) and channel ch:

        S = bias[ch] + sum_{x,y} input[row*strides+x][col*strides+y][ch] * weights[x][y][ch]
        out[row][col][ch] = S // n          (floor division, Python semantics)
        remainder[row][col][ch] = str(S % n)

    Every operand passes through ``int()``, so arithmetic is exact Python-int
    arithmetic.

    Args:
        nRows, nCols, nChannels: input spatial size and channel count.
        nFilters: size of the output channel axis; must be a multiple of nChannels.
        kernelSize, strides: square kernel size and stride.
        n: fixed-point scale divisor applied after accumulation.
        input: indexable as input[row][col][channel] (ndarray or nested lists).
        weights: indexable as weights[x][y][channel].
        bias: indexable as bias[channel].

    Returns:
        (out, remainder): nested lists of shape (outRows, outCols, nFilters).
        ``out`` holds ints and ``remainder`` holds decimal strings.

    NOTE(review): only the first ``nChannels`` output slots are computed. When
    nFilters > nChannels (depth multiplier > 1), the extra slots stay int 0 in
    both outputs. Confirm against the circuit before using a multiplier.
    """
    assert nFilters % nChannels == 0
    outRows = (nRows - kernelSize) // strides + 1
    outCols = (nCols - kernelSize) // strides + 1

    out = [[[0] * nFilters for _ in range(outCols)] for _ in range(outRows)]
    remainder = [[[0] * nFilters for _ in range(outCols)] for _ in range(outRows)]

    for row in range(outRows):
        for col in range(outCols):
            for channel in range(nChannels):
                acc = 0
                for x in range(kernelSize):
                    # Chained indexing works for ndarrays and for the nested
                    # lists returned by the other integer layers.
                    in_row = input[row * strides + x]
                    w_row = weights[x]
                    for y in range(kernelSize):
                        acc += int(in_row[col * strides + y][channel]) * int(w_row[y][channel])
                acc += int(bias[channel])
                remainder[row][col][channel] = str(int(acc % n))
                out[row][col][channel] = int(acc // n)

    return out, remainder

def PointwiseConv2d(nRows, nCols, nChannels, nFilters, strides, n, input, weights, bias, prime=None):
    """Integer 1x1 (pointwise) convolution followed by division by ``n``.

    For output position (row, col) and filter f:

        S = bias[f] + sum_k input[row*strides][col*strides][k] * weights[k][f]
        out[row][col][f] = S // n              (floor division)
        remainder[row][col][f] = str(S % n)
        str_out[row][col][f] = str(out[row][col][f] % prime)

    Args:
        nRows, nCols, nChannels: input spatial size and channel count.
        nFilters: number of output channels.
        strides: spatial stride.
        n: fixed-point scale divisor applied after accumulation.
        input: indexable as input[row][col][k] (ndarray or nested lists).
        weights: indexable as weights[k][filter].
        bias: indexable as bias[filter].
        prime: field modulus used for ``str_out``; defaults to ``CIRCOM_PRIME``.

    Returns:
        (out, str_out, remainder): nested lists of shape (outRows, outCols, nFilters).
        ``out`` holds signed ints, ``str_out`` holds their field encodings as
        decimal strings, and ``remainder`` holds decimal strings.
    """
    # The original referenced an undefined name `p` here (NameError); the
    # Circom prime is what the field encoding needs.
    if prime is None:
        prime = CIRCOM_PRIME
    kernelSize = 1
    outRows = (nRows - kernelSize) // strides + 1
    outCols = (nCols - kernelSize) // strides + 1
    out = [[[0] * nFilters for _ in range(outCols)] for _ in range(outRows)]
    str_out = [[[0] * nFilters for _ in range(outCols)] for _ in range(outRows)]
    remainder = [[[None] * nFilters for _ in range(outCols)] for _ in range(outRows)]

    for row in range(outRows):
        for col in range(outCols):
            # Chained indexing accepts both ndarrays and nested lists
            # (e.g. DepthwiseConv's output).
            pixel = input[row * strides][col * strides]
            for f in range(nFilters):
                acc = sum(int(pixel[k]) * int(weights[k][f]) for k in range(nChannels))
                acc += int(bias[f])
                remainder[row][col][f] = str(int(acc % n))
                out[row][col][f] = int(acc // n)
                str_out[row][col][f] = str(out[row][col][f] % prime)

    return out, str_out, remainder

def SeparableConvImpl(nRows, nCols, nChannels, nDepthFilters, nPointFilters, kernelSize, strides, n, input, depthWeights, pointWeights, depthBias, pointBias):
    """Integer depthwise-separable convolution: a depthwise conv then a 1x1 pointwise conv.

    ``strides`` applies to the depthwise step only. The pointwise step runs
    with stride 1 over the depthwise output, so the final spatial size is
    ((nRows - kernelSize) // strides + 1, (nCols - kernelSize) // strides + 1).

    Returns:
        (depth_out, depth_remainder, point_out, point_str_out, point_remainder)
        exactly as produced by DepthwiseConv and PointwiseConv2d.
    """
    outRows = (nRows - kernelSize) // strides + 1
    outCols = (nCols - kernelSize) // strides + 1

    depth_out, depth_remainder = DepthwiseConv(nRows, nCols, nChannels, nDepthFilters, kernelSize, strides, n, input, depthWeights, depthBias)
    # Stride 1: the depthwise step already applied `strides`. Re-applying it
    # here would subsample the feature map a second time.
    # NOTE(review): the depthwise output has nDepthFilters channels, but nChannels
    # is passed here; the two only agree when the depth multiplier is 1.
    point_out, point_str_out, point_remainder = PointwiseConv2d(outRows, outCols, nChannels, nPointFilters, 1, n, depth_out, pointWeights, pointBias)
    return depth_out, depth_remainder, point_out, point_str_out, point_remainder

In [98]:
class DatasetWrapper(torch.utils.data.Dataset):
    """Dataset view that applies an optional transform to each sample's input.

    Wraps any indexable ``subset`` whose items are ``(x, y)`` pairs, such as a
    ``random_split`` Subset. When ``transform`` is truthy it is applied to ``x``
    on every access; the target ``y`` is passed through untouched.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

In [99]:
# Preprocessing: image -> float32 tensor in [0, 1] -> per-channel normalization.
# `transforms.v2.ToTensor` is deprecated; ToImage + ToDtype(scale=True) is its
# documented replacement and yields the same values (uint8 / 255).
transform = transforms.Compose([
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

BATCH_SIZE = 512
# Cap the worker count so DataLoader does not oversubscribe smaller machines.
NUM_WORKERS = min(24, os.cpu_count() or 1)

testset = torchvision.datasets.CIFAR10(
    root='./CIFAR10', train=False, download=True, transform=transform)

# No transform here: it is applied by DatasetWrapper after the split below.
trainset = torchvision.datasets.CIFAR10(
    root='./CIFAR10', train=True, download=True)

# split the train set into train/validation (80/20, fixed seed for reproducibility)
train_set_size = int(len(trainset) * 0.8)
valid_set_size = len(trainset) - train_set_size

seed = torch.Generator().manual_seed(42)
trainset, validset = torch.utils.data.random_split(trainset, [train_set_size, valid_set_size], generator=seed)

trainset = DatasetWrapper(trainset, transform)
validset = DatasetWrapper(validset, transform)

# Create train dataloader
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# Create validation dataloader
validloader = torch.utils.data.DataLoader(
    validset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Create test dataloader
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')



Files already downloaded and verified
Files already downloaded and verified


In [100]:
MODEL_WEIGHTS_PATH = './checkpoints/model_small_100epochs.pth'

model = MyMobileNet(trainloader, num_classes=10, alpha=0.25, max_epochs=100)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
checkpoint = torch.load(MODEL_WEIGHTS_PATH, map_location="cpu")
# Weights live under the 'net' key in this checkpoint (an earlier variant of
# this cell read checkpoint['state_dict'] instead).
model.load_state_dict(checkpoint['net'])
model.eval()

# Sanity check: classify one validation image.
image, label = validset[0]
image = image.unsqueeze(0)  # add a batch dimension
with torch.no_grad():
    logits = model(image)
pred_idx = logits.argmax().item()

print(f"Predicted {classes[pred_idx]} - idx: {pred_idx}")

Predicted horse - idx: 7


In [101]:
def Conv2DInt(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    """Integer 2D convolution ("valid", no padding) followed by division by ``n``.

    For output position (r, c) and filter f:

        S = bias[f] + sum_{ch, dx, dy} input[r*strides+dx][c*strides+dy][ch] * weights[dx][dy][ch][f]
        out[r][c][f] = S // n              (floor division, Python semantics)
        remainder[r][c][f] = S % n

    Every operand passes through ``int()``, so the arithmetic is exact.

    Returns:
        (out, remainder): nested lists of ints, shape (outRows, outCols, nFilters).
    """
    outRows = (nRows - kernelSize) // strides + 1
    outCols = (nCols - kernelSize) // strides + 1

    out = [[[0] * nFilters for _ in range(outCols)] for _ in range(outRows)]
    remainder = [[[None] * nFilters for _ in range(outCols)] for _ in range(outRows)]

    for r in range(outRows):
        top = r * strides
        for c in range(outCols):
            left = c * strides
            for f in range(nFilters):
                acc = sum(
                    int(input[top + dx][left + dy][ch]) * int(weights[dx][dy][ch][f])
                    for ch in range(nChannels)
                    for dx in range(kernelSize)
                    for dy in range(kernelSize)
                )
                acc += int(bias[f])
                quotient, rem = divmod(acc, n)
                remainder[r][c][f] = int(rem)
                out[r][c][f] = int(quotient)

    return out, remainder

In [50]:
# Check the fixed-point Conv2DInt against the float first conv layer on one test
# image, then export the quantized tensors as the circuit input JSON and run the
# witness generator and the Groth16 prover.

# Reorder the conv weights into (kH, kW, in_channels, out_channels), the layout
# Conv2DInt reads as weights[x][y][channel][filter].
weights = model.conv.weight.detach().numpy().transpose(2, 3, 1, 0)
image, label = testset[0]
# Float reference output (unbatched input), later transposed to (H, W, C).
expected = model.conv(image).detach().numpy()
# Zero bias. NOTE(review): assumes model.conv has no bias term — confirm.
bias = torch.zeros(weights.shape[3]).numpy()

# Zero-pad H and W by 1 on each side so the "valid" Conv2DInt reproduces what is
# presumably a padding=1 conv in the model (34x34 in -> 32x32 out) — confirm.
# padded = pad(image, 1).transpose(1,2,0)
padded = F.pad(image, (1,1,1,1), "constant", 0).numpy()
padded = padded.transpose(1,2,0)  # (C, H, W) -> (H, W, C)

# Quantize to fixed point with scale 10**EXPONENT.
quantized_image = (padded * 10**EXPONENT).round()
quantized_weights = (weights * 10**EXPONENT).round() # already (kH, kW, in, out) from the transpose above

# image*weight products carry scale 10**(2*EXPONENT); dividing by n = 10**EXPONENT
# inside Conv2DInt brings the output back to scale 10**EXPONENT.
output, remainder = Conv2DInt(34, 34, 3, 8, 3, 1, 10**EXPONENT, quantized_image, quantized_weights, bias)
# test_output = output / 10**(EXPONENT)
# Dequantize element-wise (output is a nested list, not an ndarray).
test_output = [[[out / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in output]

expected = expected.transpose((1, 2, 0))

assert(np.allclose(test_output, expected, atol=0.00001))

# Encode everything as Circom field elements (decimal strings) for the JSON input.
circuit_in = to_circom_input(quantized_image)
circuit_weights = to_circom_input(quantized_weights)
circuit_bias = to_circom_input(bias)
circuit_out = to_circom_input(output)
circuit_remainder = to_circom_input(remainder)

input_json_path = "head_input.json"
with open(input_json_path, "w") as input_file:
    json.dump({"in": circuit_in,
               "conv2d_weights": circuit_weights,
               "conv2d_remainder": circuit_remainder,
               "conv2d_out": circuit_out,
               "conv2d_bias": circuit_bias,
               }, input_file)

# Run from circuits/: compute the witness (head.wtns) with what is presumably the
# circom-generated C++ witness generator, then produce a Groth16 proof with snarkjs.
# NOTE(review): `!` shell failures do not raise, so the chdir back always runs.
os.chdir("circuits")
!./head/head_cpp/head ../head_input.json head.wtns
!npx snarkjs groth16 prove head/circuit_final.zkey head.wtns proof.json public_test.json
os.chdir("../")

print("TEST")

TEST


In [117]:
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Inference-mode batch normalization using stored running statistics.

    Computes ``gamma * (X - moving_mean) / sqrt(moving_var + eps) + beta``.
    The per-channel parameters broadcast against the LAST axis of ``X``, so
    ``X`` should be channels-last (e.g. (H, W, C)).

    Args:
        X: input tensor/array (or scalar).
        gamma, beta: learned per-channel scale and shift.
        moving_mean, moving_var: running per-channel statistics.
        eps: numerical-stability constant added to the variance.
        momentum: accepted for signature compatibility; unused because only the
            inference path (no statistics update) is implemented.

    Returns:
        The normalized, scaled and shifted values, broadcast like ``X``.
    """
    # `** 0.5` works for torch tensors, NumPy arrays and Python floats alike.
    X_hat = (X - moving_mean) / (moving_var + eps) ** 0.5
    Y = gamma * X_hat + beta  # Scale and shift
    return Y

In [158]:
# Inspect the first BatchNorm layer's learned affine parameters and running statistics.
gamma = model.bn.weight
beta = model.bn.bias
mean = model.bn.running_mean
var = model.bn.running_var
eps = model.bn.eps

print(f"{gamma=}")
print(f"{beta=}")
print(f"{mean=}")
print(f"{var=}")
print(f"{eps=}")

# Inference-mode BN folded into a per-channel affine map: bn(x) = a * x + b.
a = gamma/(var+eps)**.5
b = beta-gamma*mean/(var+eps)**.5

image, label = testset[0]
image = image.unsqueeze(0)
out = model.conv(image)
print(out.shape)
expected = model.bn(out)
# expected = expected.squeeze().detach().numpy().transpose(1, 2, 0)
# (1, C, H, W) -> (H, W, C): channels-last.
expected = torch.permute(expected.squeeze(), (1, 2, 0))

quantized_weights = (weights * 10**EXPONENT).round()
print(f"{out.shape=}")

# X, A, B, out, remainder = BatchNormalizationInt(32, 32, 8, 10**EXPONENT, quantized_image, quantized_a, quantized_b)
# Use the same channels-last layout as `expected`, so the C-length
# gamma/beta/mean/var broadcast over the last axis. Permuting the 4-D tensor
# with (2, 3, 1, 0) gave shape (H, W, C, 1), which broadcast into a bogus
# (H, W, C, C) result.
actual = batch_norm(torch.permute(out.squeeze(), (1, 2, 0)), gamma, beta, mean, var, eps, 0)

gamma=Parameter containing:
tensor([1.7132, 0.9582, 0.6173, 0.8834, 0.6231, 0.8451, 0.7998, 1.2492],
       requires_grad=True)
beta=Parameter containing:
tensor([ 0.2958,  1.0282,  1.3571,  1.1435,  1.2801,  0.9873,  1.1311, -0.0990],
       requires_grad=True)
mean=tensor([ 0.2411,  0.3613,  0.0998, -0.1060,  0.1290, -0.0216, -0.2255,  0.0701])
var=tensor([28.2820, 20.5319,  5.9350,  8.0902, 11.8494, 17.0723, 22.9259,  6.8973])
eps=1e-05
torch.Size([1, 8, 32, 32])
out.shape=torch.Size([1, 8, 32, 32])


In [159]:
# Reference batch norm on a channels-last (H, W, C) view of the conv output.
actual = batch_norm(out.squeeze().permute(1, 2, 0), gamma, beta, mean, var, eps, 0)

In [160]:
# Both should be channels-last (H, W, C).
(actual.shape, expected.shape)

(torch.Size([32, 32, 8]), torch.Size([32, 32, 8]))

In [161]:
# Channel vector of the top-left pixel from the reference batch_norm.
actual[0, 0]

tensor([2.6740, 2.1584, 1.8093, 2.0547, 1.9869, 0.7890, 0.1890, 0.4811],
       grad_fn=<SelectBackward0>)

In [162]:
# Same pixel from model.bn, for comparison with the cell above.
expected[0, 0]

tensor([2.6740, 2.1584, 1.8093, 2.0547, 1.9869, 0.7890, 0.1890, 0.4811],
       grad_fn=<SelectBackward0>)

In [115]:
# Running mean of the first BatchNorm layer (displayed for reference).
mean

tensor([ 0.2411,  0.3613,  0.0998, -0.1060,  0.1290, -0.0216, -0.2255,  0.0701])