# Manual mobilenet implementation

# Imports

In [127]:
from models.mobilenet import ZkMobileNet
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pytorch_lightning as pl
import torchvision
import torchvision.transforms.v2 as transforms
import json
import os
from typing import List

# Utility functions

In [128]:
# Circom defines the range of positives are [0, p/2] and the range of negatives are [(p/2)+1, (p-1)].
# CIRCOM_PRIME = 21888242871839275222246405745257275088548364400416034343698204186575808495617 # bn254
p = CIRCOM_PRIME = 28948022309329048855892746252171976963363056481941647379679742748393362948097 # vesta
MAX_POSITIVE = CIRCOM_PRIME // 2
MAX_NEGATIVE = MAX_POSITIVE + 1 # The most positive number
CIRCOM_NEGATIVE_1 = CIRCOM_PRIME - 1
EXPONENT = 15

def from_circom(x):
    if type(x) != int:
        x = int(x)
    if x > MAX_POSITIVE: 
        return x - CIRCOM_PRIME
    return x
    
def to_circom(x):
    return x % CIRCOM_PRIME
    
def to_circom_input(array: np.array): 
    if type(array) != np.array:
        array = np.array(array)
    int_array = array.round().astype(int)
    int_array = to_circom(int_array)
    return int_array.astype(str).tolist()

# taken from https://github.com/socathie/circomlib-ml
def Conv2DInt(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    p = CIRCOM_PRIME
    Input = [[[str(int(input[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    Weights = [[[[str(int(weights[i][j][k][l]) % p) for l in range(nFilters)] for k in range(nChannels)] for j in range(kernelSize)] for i in range(kernelSize)]
    Bias = [str(int(bias[i]) % p) for i in range(nFilters)]
    out = [[[0 for _ in range(nFilters)] for _ in range((nCols - kernelSize)//strides + 1)] for _ in range((nRows - kernelSize)//strides + 1)]
    remainder = [[[None for _ in range(nFilters)] for _ in range((nCols - kernelSize)//strides + 1)] for _ in range((nRows - kernelSize)//strides + 1)]
    for i in range((nRows - kernelSize)//strides + 1):
        for j in range((nCols - kernelSize)//strides + 1):
            for m in range(nFilters):
                for k in range(nChannels):
                    for x in range(kernelSize):
                        for y in range(kernelSize):
                            out[i][j][m] += int(input[i*strides+x][j*strides+y][k]) * int(weights[x][y][k][m])
                out[i][j][m] += int(bias[m])
                remainder[i][j][m] = str(int(out[i][j][m]) % n)
                out[i][j][m] = str(int(out[i][j][m]) // n % p)
    return Input, Weights, Bias, out, remainder
    
# taken from https://github.com/socathie/circomlib-ml
def BatchNormalizationInt(nRows, nCols, nChannels, n, X_in, a_in, b_in):
    p = CIRCOM_PRIME
    X = [[[str(int(X_in[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    A = [str(int(a_in[k]) % p) for k in range(nChannels)]
    B = [str(int(b_in[k]) % p) for k in range(nChannels)]
    out = [[[int(0) for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    out_str = [[[str(0) for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[str(0) for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    for i in range(nRows):
        for j in range(nCols):
            for k in range(nChannels):
                out[i][j][k] = int(int(X_in[i][j][k]) * int(a_in[k]) + int(b_in[k]))
                remainder[i][j][k] = str(int(out[i][j][k]) % n)
                out[i][j][k] = int(out[i][j][k] // n)
                out_str[i][j][k] = str(out[i][j][k] % p)
    return X, A, B, out_str, out, remainder

def AveragePooling2DInt (nRows, nCols, nChannels, poolSize, strides, input):
    Input = [[[str(input[i][j][k] % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    out = [[[0 for _ in range(nChannels)] for _ in range((nCols-poolSize)//strides + 1)] for _ in range((nRows-poolSize)//strides + 1)]
    out_str = [[[str(0) for _ in range(nChannels)] for _ in range((nCols-poolSize)//strides + 1)] for _ in range((nRows-poolSize)//strides + 1)]
    remainder = [[[None for _ in range(nChannels)] for _ in range((nCols-poolSize)//strides + 1)] for _ in range((nRows-poolSize)//strides + 1)]
    for i in range((nRows-poolSize)//strides + 1):
        for j in range((nCols-poolSize)//strides + 1):
            for k in range(nChannels):
                for x in range(poolSize):
                    for y in range(poolSize):
                        out[i][j][k] += input[i*strides+x][j*strides+y][k]
                remainder[i][j][k] = str(out[i][j][k] % poolSize**2 % p)
                out[i][j][k] = int(int(out[i][j][k]) // poolSize**2)
                out_str[i][j][k] = str(out[i][j][k] % p)
    return Input, out_str, out, remainder
    
def DenseInt(nInputs, nOutputs, n, input, weights, bias):
    Input = [str(int(input[i]) % p) for i in range(nInputs)]
    Weights = [[str(int(weights[i][j]) % p) for j in range(nOutputs)] for i in range(nInputs)]
    Bias = [str(int(bias[i]) % p) for i in range(nOutputs)]
    out = [0 for _ in range(nOutputs)]
    out_str = [str(0) for _ in range(nOutputs)]
    remainder = [None for _ in range(nOutputs)]
    for j in range(nOutputs):
        for i in range(nInputs):
            out[j] += int(input[i]) * int(weights[i][j])
        out[j] += int(bias[j])
        remainder[j] = str(int(out[j]) % n)
        out[j] = int(int(out[j]) // n % p)
        out_str[j] = str(int(out[j]) % p)
    return Input, Weights, Bias, out_str, out, remainder
        

def PaddedDepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    """output is padded"""
    assert(nFilters % nChannels == 0)
    outRows = (nRows - kernelSize)//strides + 1
    outCols = (nCols - kernelSize)//strides + 1
    
    # out = np.zeros((outRows, outCols, nFilters))
    Input = [[[str(int(input[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    Weights = [[[str(int(weights[i][j][k]) % p) for k in range(weights.shape[2])] for j in range(weights.shape[1])] for i in range(weights.shape[0])]
    out = [[[0 for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    Bias = bias.round().astype(int).astype(str).tolist()
    out_str = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    
    for row in range(outRows):
        for col in range(outCols):
            for channel in range(nChannels):
                for x in range(kernelSize):
                    for y in range(kernelSize):
                        out[row+1][col+1][channel] += int(input[row*strides+x][col*strides+y][channel]) * int(weights[x, y, channel])
                
                out[row+1][col+1][channel] += int(bias[channel])
                remainder[row+1][col+1][channel] = str(int(out[row+1][col+1][channel] % n))
                out[row+1][col+1][channel] = int(out[row+1][col+1][channel] // n)
                out_str[row+1][col+1][channel] = str(out[row+1][col+1][channel] % p)
                            
    return Input, Weights, Bias, out_str, out, remainder

def PointwiseConv2d(nRows, nCols, nChannels, nFilters, strides, n, input, weights, bias):
    kernelSize = 1
    outRows = (nRows - kernelSize)//strides + 1
    outCols = (nCols - kernelSize)//strides + 1
    
    Input = [[[str(int(input[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    Weights = [[str(int(weights[i][j]) % p)for j in range(weights.shape[1])] for i in range(weights.shape[0])]
    out = [[[0 for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    Bias = bias.round().astype(int).astype(str).tolist()
    out_str = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    
    for row in range(outRows):
        for col in range(outCols):
            for filter in range(nFilters):
                for k in range(nChannels):
                    out[row][col][filter] += int(input[row*strides][col*strides][k]) * int(weights[k, filter])
                    
                out[row][col][filter] += int(bias[filter])
                remainder[row][col][filter] = str(int(out[row][col][filter] % n))
                out[row][col][filter] = int(out[row][col][filter] // n)
                out_str[row][col][filter] = str(out[row][col][filter] % p)
                            
    return Input, Weights, Bias, out_str, out, remainder

In [129]:
class DatasetWrapper(torch.utils.data.Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform
        
    def __getitem__(self, index):
        x, y = self.subset[index]
        
        if self.transform:
            x = self.transform(x)
        return x, y
        
    def __len__(self):
        return len(self.subset)

In [130]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    # transforms.Resize((28, 28)),
])

testset = torchvision.datasets.CIFAR10(
    root='./CIFAR10', train=False, download=True, transform=transform)

trainset = torchvision.datasets.CIFAR10(
    root='./CIFAR10', train=True, download=True)

# split the train set into train/validation
train_set_size = int(len(trainset) * 0.8)
valid_set_size = len(trainset) - train_set_size

seed = torch.Generator().manual_seed(42)
trainset, validset = torch.utils.data.random_split(trainset, [train_set_size, valid_set_size], generator=seed)

trainset = DatasetWrapper(trainset, transform)
validset = DatasetWrapper(validset, transform)

# Create train dataloader
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=512, shuffle=True, num_workers=24)

# Create validation dataloader
validloader = torch.utils.data.DataLoader(
    validset, batch_size=512, shuffle=False, num_workers=24)

# Create test dataloader
testloader = torch.utils.data.DataLoader(
    testset, batch_size=512, shuffle=False, num_workers=24)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')



Files already downloaded and verified
Files already downloaded and verified


In [131]:
MODEL_WEIGHTS_PATH = "./checkpoints/no_padding_100epochs.pth", 
class Test(ZkMobileNet):
    def forward(self, x):
        # print("STARTING SHAPE: ", x.shape)
        x = self.relu(self.bn(self.conv(x)))
        # print("CONV1 SHAPE: ", x.shape)
        x = self.features(x)
        # x = self.relu(x)
        # print("BACKBONE SHAPE: ", x.shape)
        # print("PRE AVG-POOL SHAPE: ", x.shape)
        x = F.avg_pool2d(x, 6)
        # print("POST AVG-POOL SHAPE: ", x.shape)
        x = x.view(x.size()[0], -1)
        # print("PRE-CLASSIFIER SHAPE: ", x.shape)
        x = self.linear(x)
        # print("POST-CLASSIFIER SHAPE: ", x.shape)
        return x
        
# alpha = (0.25 * 0.5 * 0.75)
alpha = (0.25 * 0.125)
# alpha = 0.25
print(alpha)
model = Test(trainloader, num_classes=10, alpha=alpha, max_epochs=100)

saved = torch.load("./checkpoints/no_padding_100epochs.pth")
# model.load_state_dict(saved['state_dict'])
# model.load_state_dict(saved['net'])
model.eval()

image, label = validset[0]
print("IMAGE SHAPE: ", image.shape)
image = image.unsqueeze(0)
print("IMAGE SHAPE: ", image.shape)
logits = model(image)
pred_idx = logits.argmax()

print(f"Predicted {classes[pred_idx]} - idx: {pred_idx}")

0.03125
IMAGE SHAPE:  torch.Size([3, 32, 32])
IMAGE SHAPE:  torch.Size([1, 3, 32, 32])
Predicted deer - idx: 4


In [132]:
model.bn

BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

# Testing head layer

In [133]:
np.array(quantized_in).shape

(6, 6, 32)

In [163]:
# CONVOLUTION LAYER
weights = model.conv.weight.detach().numpy().transpose(2, 3, 1, 0)
image, label = testset[0]

expected = model.conv(image).detach().numpy()
bias = torch.zeros(weights.shape[3]).numpy()

# padded = pad(image, 1).transpose(1,2,0)
padded = F.pad(image, (1,1,1,1), "constant", 0).numpy()
padded = padded.transpose(1,2,0)

quantized_image = (padded * 10**EXPONENT).round()
quantized_weights = (weights * 10**EXPONENT).round() # .transpose(0, 3, 1, 0) # [nFilters, nChannels, H, W] -> 

circuit_in, circuit_conv_weights, circuit_conv_bias, circuit_conv_out, circuit_conv_remainder = Conv2DInt(34, 34, 3, 1, 3, 1, 10**EXPONENT, quantized_image, quantized_weights, bias)
# output, remainder = Conv2DInt(34, 34, 3, 8, 3, 1, 10**EXPONENT, quantized_image, quantized_weights, bias)
# test_output = output / 10**(EXPONENT)
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in circuit_conv_out]

expected = expected.transpose((1, 2, 0))

assert(np.allclose(test_output, expected, atol=1e-6))

# BATCH NORM CONSTANTS
gamma = model.bn.weight
beta = model.bn.bias
mean = model.bn.running_mean
var = model.bn.running_var
eps = model.bn.eps

a = (gamma/(var+eps)**.5).detach()
b = (beta-gamma*mean/(var+eps)**.5).detach().tolist()

quantized_a = (a * 10**EXPONENT).tolist()
quantized_b = [bi * 10**(2*EXPONENT) for bi in b]

# BATCH NORM USING PYTORCH OUTPUT
image, label = testset[0]
image = image.unsqueeze(0)
out = model.conv(image)
expected = model.bn(out)

expected = torch.permute(expected.squeeze(0), (1, 2, 0))

quantized_in = torch.permute(out.squeeze(0), (1, 2, 0)) * 10**EXPONENT

X, A, B, out_str, actual, remainder = BatchNormalizationInt(32, 32, 1, 10**EXPONENT, quantized_in, quantized_a, quantized_b)
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in vec] for vec in matrix] for matrix in actual]
assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

# BATCH NORM USING CIRCUIT CONV OUTPUT
quantized_in = [[[from_circom(int(out)) for out in vec] for vec in matrix] for matrix in circuit_conv_out]

_, circuit_bn_a, circuit_bn_b, out_str, circuit_bn_out, circuit_bn_remainder = BatchNormalizationInt(32, 32, 1, 10**EXPONENT, quantized_in, quantized_a, quantized_b)

test_output = [[[int(from_circom(int(out))) / 10**EXPONENT for out in vec] for vec in matrix] for matrix in actual]

assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

# RELU USING CIRCUIT OUTPUT
relu_in = [[[to_circom(bn_out) if bn_out < MAX_POSITIVE else 0 for bn_out in vec] for vec in matrix] for matrix in circuit_bn_out]
relu_out = [[[str(bn_out) if bn_out < MAX_POSITIVE else 0 for bn_out in vec] for vec in matrix] for matrix in relu_in]

input_json_path = "test_inputs/head_input.json"
with open(input_json_path, "w") as input_file:
    json.dump({"in": circuit_in,
               "conv2d_weights": circuit_conv_weights,
               "conv2d_bias": circuit_conv_bias,
               "conv2d_out": circuit_conv_out,
               "conv2d_remainder": circuit_conv_remainder,
               
               "bn_a": circuit_bn_a,
               "bn_b": circuit_bn_b,
               "bn_out": circuit_bn_out,
               "bn_remainder": circuit_bn_remainder,
               
               "relu_out": relu_out,
               }, input_file)

os.chdir("circuits")
!./head/head_cpp/head ../test_inputs/head_input.json head.wtns
# !npx snarkjs groth16 prove head/circuit_final.zkey head.wtns proof.json public_test.json
os.chdir("../")

print("done")

done


# Generating input for backbone layer

In [85]:
# CONVOLUTION LAYER
weights = model.conv.weight.detach().numpy().transpose(2, 3, 1, 0)
image, label = testset[0]

expected = model.conv(image).detach().numpy()
bias = torch.zeros(weights.shape[3]).numpy()

# padded = pad(image, 1).transpose(1,2,0)
padded = F.pad(image, (1,1,1,1), "constant", 0).numpy()
padded = padded.transpose(1,2,0)

quantized_image = (padded * 10**EXPONENT).round()
quantized_weights = (weights * 10**EXPONENT).round() # .transpose(0, 3, 1, 0) # [nFilters, nChannels, H, W] -> 

circuit_in, circuit_conv_weights, circuit_conv_bias, circuit_conv_out, circuit_conv_remainder = Conv2DInt(34, 34, 3, 1, 3, 1, 10**EXPONENT, quantized_image, quantized_weights, bias)
# output, remainder = Conv2DInt(34, 34, 3, 8, 3, 1, 10**EXPONENT, quantized_image, quantized_weights, bias)
# test_output = output / 10**(EXPONENT)
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in circuit_conv_out]

expected = expected.transpose((1, 2, 0))

assert(np.allclose(test_output, expected, atol=1e-6))

# BATCH NORM CONSTANTS
gamma = model.bn.weight
beta = model.bn.bias
mean = model.bn.running_mean
var = model.bn.running_var
eps = model.bn.eps

a = (gamma/(var+eps)**.5).detach()
b = (beta-gamma*mean/(var+eps)**.5).detach().tolist()

quantized_a = (a * 10**EXPONENT).tolist()
quantized_b = [bi * 10**(2*EXPONENT) for bi in b]

# BATCH NORM USING PYTORCH OUTPUT
image, label = testset[0]
image = image.unsqueeze(0)
out = model.conv(image)
expected = model.bn(out)

expected = torch.permute(expected.squeeze(0), (1, 2, 0))

quantized_in = torch.permute(out.squeeze(0), (1, 2, 0)) * 10**EXPONENT

X, A, B, _, actual, remainder = BatchNormalizationInt(32, 32, 1, 10**EXPONENT, quantized_in, quantized_a, quantized_b)
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in vec] for vec in matrix] for matrix in actual]
assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

# BATCH NORM USING CIRCUIT CONV OUTPUT
quantized_in = [[[from_circom(int(out)) for out in vec] for vec in matrix] for matrix in circuit_conv_out]

_, circuit_bn_a, circuit_bn_b, _, circuit_bn_out, circuit_bn_remainder = BatchNormalizationInt(32, 32, 1, 10**EXPONENT, quantized_in, quantized_a, quantized_b)

test_output = [[[int(from_circom(int(out))) / 10**EXPONENT for out in vec] for vec in matrix] for matrix in circuit_bn_out]

assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

# RELU USING CIRCUIT OUTPUT
relu_in = [[[to_circom(bn_out) if bn_out < MAX_POSITIVE else 0 for bn_out in vec] for vec in matrix] for matrix in circuit_bn_out]
relu_out = [[[str(bn_out) if bn_out < MAX_POSITIVE else 0 for bn_out in vec] for vec in matrix] for matrix in relu_in]

image, label = testset[0]
image = image.unsqueeze(0)
conv_output = model.conv(image)
bn_output = model.bn(conv_output)
relu_expected = model.relu(bn_output).squeeze(0).detach().numpy().transpose((1,2,0))

test_output = [[[out / 10**EXPONENT if out > 0 else 0 for out in vec] for vec in matrix] for matrix in circuit_bn_out]

assert(np.allclose(test_output, relu_expected, atol=1e-6))

# Auxiliary functions for converting between pytorch and circom input formats

In [135]:
from typing import List, Optional, Union

def circom2pytorch(circuit_output):
    formatted = torch.Tensor([[[int(out) / 10**EXPONENT if out > 0 else 0 for out in vec] for vec in matrix] for matrix in circuit_output])
    formatted = torch.permute(formatted, (2, 0, 1)).unsqueeze(0)
    return formatted
    
def pytorch2quantized(pytorch_output: torch.Tensor):
    return pytorch_output.squeeze().detach().numpy().transpose((1, 2, 0)) * 10**EXPONENT

def dequantize(input: List[List[List[int]]], padding: int=1, channel_padding: Optional[int]=None):
    test_output = np.array([[[int(value) / 10**EXPONENT for value in vec] for vec in matrix] for matrix in input])
    
    if channel_padding is None or channel_padding == 0:
        return test_output[padding:-padding, padding:-padding, :]
    
    return test_output[padding:-padding, padding:-padding, :channel_padding]
    
def check_quantized_input(quantized_input):
    """quantized_input should be quantized and should be (Height, Depth, Channels)"""
    assert(len(np.array(quantized_input).shape) == 3)

def check_pytorch_input(quantized_input):
    """pytorch_input should be quantized and should be (N=1, Channels, Height, Depth)"""
    assert(len(quantized_input.shape) == 4)
    assert(quantized_input.shape[0] == 1)
    assert(type(quantized_input) == torch.Tensor)
    
def pad(cube: List[List[List[int]]], square_pad: int, channel_pad: int):
    max_i = len(cube) + square_pad*2
    max_j = len(cube[0]) + square_pad*2
    max_k = len(cube[0][0]) + channel_pad
    result = [[[0 for k in range(max_k)] for j in range(max_j)] for i in range(max_i)]
    
    for i in range(len(cube) + square_pad*2):
        for j in range(len(cube[0]) + square_pad*2):
            for k in range(len(cube[0][0]) + channel_pad):
                if i >= square_pad and i < len(cube) + square_pad and j >= square_pad and j < len(cube[0]) + square_pad and k < len(cube[0][0]):
                    # print(f"{i-square_pad=}, {j-square_pad=}, {k=} {channel_pad=} {len(cube[0][0])=} {channel_pad - len(cube[0][0])=}")
                    result[i][j][k] = cube[i-square_pad][j-square_pad][k]

    # print(f"{max_i=}, {max_j=}, {max_k=}")
                

    return result

# Auxiliary dataclasses for circuit inputs

In [136]:
from pydantic import BaseModel

class ArbBaseModel(BaseModel):
    class Config:
        arbitrary_types_allowed = True

class CircuitConvInput(BaseModel):
    input: Optional[List[List[List[str]]]]
    weights: Union[List[List[List[str]]], List[List[str]]]
    bias: List[str]
    out_str: List[List[List[str]]]
    remainder: List[List[List[str]]]
    
    out: List[List[List[int]]]
    
    
class CircuitBatchNormInput(BaseModel):
    input: Optional[List[List[List[str]]]]
    a: List[str]
    b: List[str]
    out_str: List[List[List[str]]]
    out: List[List[List[int]]]
    remainder: List[List[List[str]]]


class ConvBN(BaseModel):
    conv: CircuitConvInput
    bn: CircuitBatchNormInput
    
    def to_dict(self, prefix: str):
        return {
            f"{prefix}conv_weights": self.conv.weights,
            f"{prefix}conv_bias": self.conv.bias,
            f"{prefix}conv_remainder": self.conv.remainder,
            f"{prefix}conv_out": self.conv.out_str,
            
            f"{prefix}bn_a": self.bn.a,
            f"{prefix}bn_b": self.bn.b,
            f"{prefix}bn_remainder": self.bn.remainder,
            f"{prefix}bn_out": self.bn.out_str,
        }

class CircuitLayerInput(BaseModel):
    depthwise: ConvBN
    pointwise: ConvBN

    def input(self):
        return self.depthwise.conv.input
        
    def out(self):
        return self.pointwise.bn.out
    
    def to_dict(self, prefix: str):
        return {
            **self.depthwise.to_dict(f"{prefix}dw_"),
            **self.pointwise.to_dict(f"{prefix}pw_")
        }
    def to_json(self, prefix: str, json_path: str):
        with open(json_path, "w") as input_file:
            json.dump(self.to_dict(prefix), input_file)
            
def TypedPaddedDepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    Input, Weights, Bias, out_str, out, remainder = PaddedDepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias)
    return CircuitConvInput(
        input=Input, 
        weights=Weights, 
        bias=Bias, 
        out_str=out_str, 
        out=out, 
        remainder=remainder, 
    )
    
def BatchNormalizationPadded(nRows, nCols, nChannels, n, X_in, a_in, b_in, padding):
    p = CIRCOM_PRIME
    X = [[[str(int(X_in[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    A = [str(int(a_in[k]) % p) for k in range(nChannels)]
    B = [str(int(b_in[k]) % p) for k in range(nChannels)]
    out = [[[0 for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    out_str = [[[str(0) for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[str(n) for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    for i in range(padding, nRows-padding):
        for j in range(padding, nCols-padding):
            for k in range(nChannels):
                out[i][j][k] = int(int(X_in[i][j][k]) * int(a_in[k]) + int(b_in[k]))
                remainder[i][j][k] = str(int(out[i][j][k]) % n)
                out[i][j][k] = int(out[i][j][k] // n)
                out_str[i][j][k] = str(out[i][j][k] % p)
    return X, A, B, out_str, out, remainder
    
def TypedBatchNormalizationInt(nRows, nCols, nChannels, n, X_in, a_in, b_in, padding):
    X, A, B, out_str, out, remainder = BatchNormalizationPadded(nRows, nCols, nChannels, n, X_in, a_in, b_in, padding)
    return CircuitBatchNormInput(
        input=X, 
        a=A, 
        b=B, 
        out_str=out_str, 
        out=out, 
        remainder=remainder, 
    )
    
def TypedPointwiseConv2d(nRows, nCols, nChannels, nFilters, strides, n, input, weights, bias):
    Input, Weights, Bias, out_str, out, remainder = PointwiseConv2d(nRows, nCols, nChannels, nFilters, strides, n, input, weights, bias)
    return CircuitConvInput(
        input=Input, 
        weights=Weights, 
        bias=Bias, 
        out_str=out_str, 
        out=out, 
        remainder=remainder, 
    )
        

In [137]:
# Getting proper inputs

In [138]:
quantized_input = torch.Tensor([[[out if out > 0 else 0 for out in vec] for vec in matrix] for matrix in circuit_bn_out])
pytorch_input = torch.Tensor([[[int(out) / 10**EXPONENT if out > 0 else 0 for out in vec] for vec in matrix] for matrix in circuit_bn_out])
pytorch_input = torch.permute(pytorch_input, (2, 0, 1)).unsqueeze(0)

In [140]:
class CircuitBackbone():
    def __init__(self, model: ZkMobileNet, max_dims: List[int]):
        self.model = model
        self.model.eval()
        self.max_rows = max_dims[0]
        self.max_cols = max_dims[1]
        self.max_channels = max_dims[2]
        self.max_depth_filters = self.max_point_filters = max_dims[2]
        self.dw_kernel_size = 3
        self.stride = 1
        self.scalar_factor = 10**EXPONENT

    def _forward_module(self, module: nn.Module, input: torch.Tensor):
        output = module(input)
        expected = output.squeeze(0).detach().numpy().transpose((1,2,0))
        return output, expected
        
    def _circuit_conv_bn(self, layer: int, pytorch_input: torch.Tensor, quantized_input: List[List[List[int]]], conv: nn.Sequential):
        conv_output, conv_expected = self._forward_module(conv[0], pytorch_input)
        
        circuit_conv_input = self._get_conv_circuit_input(layer, quantized_input, conv_expected, conv[0])
        
        bn_output, bn_expected = self._forward_module(conv[1], conv_output)
        
        circuit_bn_input = self._get_bn_circuit_input(layer, circuit_conv_input.out, bn_expected, conv[1])
        
        return ConvBN(conv=circuit_conv_input, bn=circuit_bn_input), bn_output

    def circuit_layer_inputs(self, layer: int, pytorch_input: torch.Tensor, quantized_input: List[List[List[int]]]):
        check_pytorch_input(pytorch_input)
        check_quantized_input(quantized_input)

        depthwise, dw_output = self._circuit_conv_bn(layer, pytorch_input, quantized_input, self.model.features[layer].dw_conv)
        pointwise, pw_output = self._circuit_conv_bn(layer, dw_output, depthwise.bn.out, self.model.features[layer].pw_conv)

        layer_input = CircuitLayerInput(depthwise=depthwise, pointwise=pointwise)

        layer_input.to_json(f"l{layer}", "help.json")
        return layer_input, pw_output

    def _get_conv_circuit_input(self, layer: int, quantized_input: List[List[List[int]]], expected: np.array, module: nn.Conv2d) -> CircuitConvInput:
        check_quantized_input(quantized_input)

        padding = layer+1
        weights = module.weight.detach()
        
        if module.kernel_size == (3, 3): 
            filter_padding = self.max_depth_filters - weights.shape[0]
            assert(filter_padding >= 0)
            
            bias = np.zeros(weights.shape[0] + filter_padding)

            weights = weights.squeeze(1)
            weights = torch.permute(weights, (1, 2, 0))
            assert(filter_padding >= 0)
        
            if len(quantized_input[0][0]) != self.max_depth_filters:
                quantized_input = pad(quantized_input, 0, filter_padding)
                # quantized_input = F.pad(quantized_input, (0, filter_padding), "constant", 0)
            
            padded_weights = F.pad(weights, (0, filter_padding), "constant", 0)
            quantized_weights = padded_weights * 10**EXPONENT
            
            conv_input = TypedPaddedDepthwiseConv(
                self.max_rows, 
                self.max_cols, 
                self.max_channels, 
                self.max_depth_filters, 
                self.dw_kernel_size, 
                self.stride, 
                self.scalar_factor,
                quantized_input, 
                quantized_weights.round(), 
                bias
            )
            test_output = dequantize(conv_input.out, padding, self.max_point_filters - filter_padding)
            
        elif module.kernel_size == (1, 1):
            channel_padding = self.max_depth_filters - weights.shape[1]
            filter_padding = self.max_point_filters - weights.shape[0]
            
            assert(filter_padding >= 0)
            bias = np.zeros(weights.shape[0] + filter_padding)

            weights = weights.squeeze(-1).squeeze(-1) # removing H x W
            weights = torch.permute(weights, (1, 0))
            assert(filter_padding >= 0)
        
            if len(quantized_input[0][0]) != self.max_point_filters:
                quantized_input = F.pad(quantized_input, (0, filter_padding), "constant", 0)
                
            assert(channel_padding >= 0)
            padded_weights = F.pad(weights, (0, filter_padding, 0, channel_padding), "constant", 0)
            quantized_weights = padded_weights * 10**EXPONENT
            conv_input = TypedPointwiseConv2d(
                self.max_rows, 
                self.max_cols, 
                self.max_depth_filters,
                self.max_point_filters,
                self.stride, 
                self.scalar_factor,
                quantized_input,
                quantized_weights.round(),
                bias
            )
            
            test_output = dequantize(conv_input.out, padding, self.max_point_filters - filter_padding)
        
        assert(np.allclose(expected, test_output, atol=1e-4))
        
        return conv_input
        
    def _get_bn_circuit_input(self, layer: int, quantized_input: List[List[List[int]]], expected: np.array, batch_norm: nn.modules.batchnorm.BatchNorm2d) -> CircuitBatchNormInput:
        check_quantized_input(quantized_input)
        
        padding = layer+1
        assert(len(expected.shape) == 3)
        channel_padding = self.max_point_filters - expected.shape[2]
       
        gamma = batch_norm.weight
        beta = batch_norm.bias
        mean = batch_norm.running_mean
        var = batch_norm.running_var
        eps = batch_norm.eps
        
        a = (gamma/(var+eps)**.5).detach()
        b = (beta-gamma*mean/(var+eps)**.5).detach()
        
        channel_padding = self.max_point_filters - len(a)
        a = F.pad(a, (0, channel_padding), "constant", 0).tolist()
        b = F.pad(b, (0, channel_padding), "constant", 0).tolist()
        
        quantized_a = [ai * 10**(EXPONENT) for ai in a]
        quantized_b = [bi * 10**(2*EXPONENT) for bi in b]
        
        bn_input = TypedBatchNormalizationInt(
            self.max_rows, 
            self.max_cols, 
            self.max_point_filters, 
            self.scalar_factor,
            quantized_input,
            quantized_a,
            quantized_b,
            padding=(layer+1)
        )
        test_output = dequantize(bn_input.out, padding, self.max_point_filters - channel_padding)

        assert(np.allclose(test_output, expected, atol=1e-4))
        return bn_input
        
print("input shape: ", pytorch_input.shape)
circuit = CircuitBackbone(model, (32, 32, 32))
circuit_layer_input, pytorch_output = circuit.circuit_layer_inputs(0, pytorch_input, quantized_input)

image, label = testset[0]
image = image.unsqueeze(0)
conv_output = model.conv(image)
bn_output = model.bn(conv_output)
relu_expected = model.relu(bn_output)
layer0_expected = model.features[0](relu_expected)

assert(torch.allclose(pytorch_output, layer0_expected, atol=1e-5))

padding = (np.array(circuit_layer_input.out()).shape[0] - expected.shape[0]) // 2
channel_padding = np.array(circuit_layer_input.out()).shape[2] - expected.shape[2]
test_output = dequantize(circuit_layer_input.out(), padding, circuit.max_point_filters - channel_padding)
expected = layer0_expected.squeeze().detach().numpy().transpose((1,2,0))
circuit_layer_input, pytorch_output = circuit.circuit_layer_inputs(1, pytorch_output, circuit_layer_input.out())

input shape:  torch.Size([1, 1, 32, 32])


In [141]:
class BackboneCircuitInput():
    layers: List[CircuitLayerInput]
    pytorch_outputs: List[torch.Tensor]
    quantized_input: List[List[List[int]]]
    n_layers: int
    circuit: CircuitBackbone
    
    def __init__(self, model: ZkMobileNet, n_layers: int, pytorch_input: torch.Tensor, quantized_input: List[List[List[int]]]):
        self.n_layers = n_layers
        self.quantized_input = quantized_input
        self.circuit = CircuitBackbone(model, (32, 32, 32))
        # self.circuit = CircuitMobilenet(model, (32, 32, 64))
        # self.circuit = CircuitMobilenet(model, (32, 32, 96))
        circuit_layer_input, pytorch_output = self.circuit.circuit_layer_inputs(0, pytorch_input, quantized_input)
        self.layers = [circuit_layer_input]
        self.pytorch_outputs = [pytorch_output]
        for layer in range(1, n_layers):
            circuit_layer_input, pytorch_output = self.circuit.circuit_layer_inputs(layer, pytorch_output, circuit_layer_input.out())
            self.layers.append(circuit_layer_input)
            self.pytorch_outputs.append(pytorch_output)

    def to_dict(self):
        return {
                "inp": self.layers[0].depthwise.conv.input,
                "backbone": [self.layers[i].to_dict("") for i in range(len(self.layers))],
        }

In [142]:
backbone = BackboneCircuitInput(model, len(model.features), pytorch_input, quantized_input)

In [157]:
d = {
        "step_in": ["0", 
                    "13848531447176013426093659219117515409412026631794481225551280988741196757461"],
        "in": backbone.layers[0].depthwise.conv.input,
        **backbone.layers[0].to_dict("")
        # "in": backbone.layers[0].pointwise.bn.out_str,
}

json_path = "test_inputs/layer_test.json"
with open(json_path, "w") as f:
    json.dump(d, f)
    
os.chdir("circuits")
!./backbone/backbone_cpp/backbone ../test_inputs/layer_test.json layer_test.wtns
# !npx snarkjs groth16 prove ./backbone/circuit_final.zkey layer_test.wtns layer_proof.json layer_public_test.json
os.chdir("../")

print("OK")

BACKBONE STARTED
STEP_IN     RESULT 13848531447176013426093659219117515409412026631794481225551280988741196757461
HASH OUTPUT RESULT 13848531447176013426093659219117515409412026631794481225551280988741196757461
WEIGHTS HASH RESULT 28410143885270053506126428624450852609816959586875909354515247824569374022385
START
dw_conv done
depth batch norm done
pw_conv done
point batch norm done
END
LAYER DONE
step_in[0] 0
step_in[1] 13848531447176013426093659219117515409412026631794481225551280988741196757461
step_out[0] 27137464690529701616276065280124118395853528564618288887388359929961766464500
step_out[1] 6840624638541427477068564814949756440191684922118746481459477645605895269858
END
OK


# Testing Tail

In [124]:
quantized_input = torch.Tensor(backbone.layers[-1].pointwise.bn.out)
pytorch_input = torch.Tensor([[[int(out) / 10**EXPONENT if out > 0 else 0 for out in vec] for vec in matrix] for matrix in backbone.layers[-1].pointwise.bn.out])
pytorch_input = torch.permute(pytorch_input, (2, 0, 1)).unsqueeze(0)
pytorch_input.shape

torch.Size([1, 32, 32, 32])

In [125]:
quantized_input = torch.Tensor(backbone.layers[-1].pointwise.bn.out)
pytorch_input = torch.Tensor([[[int(out) / 10**EXPONENT if out > 0 else 0 for out in vec] for vec in matrix] for matrix in backbone.layers[-1].pointwise.bn.out])
pytorch_input = torch.permute(pytorch_input, (2, 0, 1)).unsqueeze(0)
pytorch_input.shape

test_output = dequantize(backbone.layers[-1].pointwise.bn.out, 13, 0)
backbone.pytorch_outputs[-1].shape

expected = torch.permute(backbone.pytorch_outputs[-1].squeeze(0), (1, 2, 0))
assert(np.allclose(expected.detach(), test_output, atol=1e-6))

test_output = torch.Tensor(test_output.transpose((2, 0, 1))).unsqueeze(0)
input_expected = F.avg_pool2d(test_output, 6)
output_expected = F.avg_pool2d(backbone.pytorch_outputs[-1], 6)
assert(np.allclose(output_expected.detach(), test_output, atol=1e-6))

quantized_in = np.array(backbone.layers[-1].pointwise.bn.out)[13:-13, 13:-13, :]
input, out_str, out, remainder = AveragePooling2DInt(6, 6, 32, 6, 1, quantized_in)
test_output = [[[int(o) / 10**EXPONENT for o in vec] for vec in matrix] for matrix in out]
expected = torch.permute(output_expected.squeeze(0), (1, 2, 0))
test_output = np.array([[[int(from_circom(int(value))) / 10**EXPONENT for value in vec] for vec in matrix] for matrix in out])
assert(np.allclose(test_output, expected.detach(), atol=1e-6))

In [155]:
quantized_in = np.array(backbone.layers[-1].pointwise.bn.out)[13:-13, 13:-13, :]
input, out_str, out, remainder = AveragePooling2DInt(6, 6, 32, 6, 1, quantized_in)

print("out shape: ", np.array(out).shape)
out = out[0][0]
print("out shape: ", np.array(out).shape)
print("remainder shape: ", np.array(remainder).shape)

weights = model.linear.weight.detach().numpy().transpose((1,0))
weights = weights * 10**EXPONENT
bias = torch.zeros(weights.shape[1]).tolist()
print("bias shape: ", np.array(bias).shape)
# weights = model.conv.weight.detach().numpy().transpose(2, 3, 1, 0)

print("WEIGHTS shape: ", weights.shape)
dense_input, dense_weights, dense_bias, dense_out_str, dense_out, dense_remainder = DenseInt(32, 10, 10**EXPONENT, out, weights.round(), bias)
# print("dense weights shape: ", dense_weights)
d = {
        "step_in": ["0", 
                    "28111771283980637647336518757611443787181888759533047726301239107230122265508"],
        "in": backbone.layers[-1].pointwise.bn.out,
        # "in": input,
        "avg_pool_out": out,
        "avg_pool_remainder": remainder,
    
        "dense_weights": dense_weights,
        "dense_bias": dense_bias,
        "dense_out": dense_out_str,
        "dense_remainder": dense_remainder,
        # "pw_bn_remainder": self.bn.remainder,
}

json_path = "test_inputs/tail_test.json"
with open(json_path, "w") as f:
    json.dump(d, f)
    
os.chdir("circuits")
!./tail/tail_cpp/tail ../test_inputs/tail_test.json tail.wtns
# !npx snarkjs groth16 prove ./tail/circuit_final.zkey tail.witns tail_proof.json tail_public_test.json
os.chdir("../")

print("OK")

out shape:  (1, 1, 32)
out shape:  (32,)
remainder shape:  (1, 1, 32)
bias shape:  (10,)
WEIGHTS shape:  (32, 10)
TAIL STARTED
end pooling
MIMC_INPUT HASH :  28111771283980637647336518757611443787181888759533047726301239107230122265508
end!!
OK


In [147]:
    
d = {
        "inp": backbone.layers[0].depthwise.conv.input,
        # "backbone": [backbone.layers[0].to_dict(""), backbone.layers[1].to_dict(""), backbone.layers[2].to_dict("")],
        "backbone": [backbone.layers[i].to_dict("") for i in range(len(backbone.layers))],
        # **backbone.layers[0].to_dict(""),
}

json_path = "test_inputs/nova_backbone_input.json"
with open(json_path, "w") as f:
    json.dump(d, f)
    
d = {
        "step_in": ["0", 
                    "13848531447176013426093659219117515409412026631794481225551280988741196757461"],
        "in": d["inp"],
        **d["backbone"][0], 
}

json_path = "test_inputs/layer_test.json"
with open(json_path, "w") as f:
    json.dump(d, f)


os.chdir("circuits")
!./backbone/backbone_cpp/backbone ../test_inputs/layer_test.json backbone.wtns
# # # !npx snarkjs groth16 prove ./backbone/circuit_final.zkey backbone.wtns proof.json public_test.json
os.chdir("../")

print("OK")

BACKBONE STARTED
STEP_IN     RESULT 13848531447176013426093659219117515409412026631794481225551280988741196757461
HASH OUTPUT RESULT 13848531447176013426093659219117515409412026631794481225551280988741196757461
WEIGHTS HASH RESULT 28410143885270053506126428624450852609816959586875909354515247824569374022385
START
dw_conv done
depth batch norm done
pw_conv done
point batch norm done
END
LAYER DONE
step_in[0] 0
step_in[1] 13848531447176013426093659219117515409412026631794481225551280988741196757461
step_out[0] 27137464690529701616276065280124118395853528564618288887388359929961766464500
step_out[1] 6840624638541427477068564814949756440191684922118746481459477645605895269858
END
OK


# Testing two layers

In [145]:
d = {
        "step_in": ["0", 
                    "13848531447176013426093659219117515409412026631794481225551280988741196757461"],
        # "in": d["inp"],
        "in": backbone.layers[0].depthwise.conv.input,
        **backbone.layers[0].to_dict("l0_"),
        **backbone.layers[1].to_dict("l1_"),
}

json_path = "test_inputs/test.json"
with open(json_path, "w") as f:
    json.dump(d, f)


os.chdir("circuits")
!./model_test/model_test_cpp/model_test ../test_inputs/test.json layers.wtns
os.chdir("../")

print("OK")

MODEL TEST STARTED
STEP_IN     RESULT 13848531447176013426093659219117515409412026631794481225551280988741196757461
HASH OUTPUT RESULT 13848531447176013426093659219117515409412026631794481225551280988741196757461
WEIGHTS HASH RESULT 28410143885270053506126428624450852609816959586875909354515247824569374022385
PARAMS HASH RESULT 26641664397551561746797824861551809283236976665550582276760575734004319122274
POINTWISE WEIGHTS HASH RESULT 19773429561560358638202436376014129192346904520982647497029246391825639520892
OUTPUT HASH RESULT 6840624638541427477068564814949756440191684922118746481459477645605895269858
L0_STEP_OUT[0] RESULT 27137464690529701616276065280124118395853528564618288887388359929961766464500
L0_STEP_OUT[1] RESULT 6840624638541427477068564814949756440191684922118746481459477645605895269858
START
dw_conv done
depth batch norm done
pw_conv done
point batch norm done
END
LAYER 0 DONE
WEIGHTS HASH RESULT 1099969501423856252816105515498886098677356435046934325722109103972021778941