# Manual mobilenet implementation

In [1]:
from models.mobilenet import ZkMobileNet
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pytorch_lightning as pl
import torchvision
import torchvision.transforms.v2 as transforms
import json
import os

In [2]:
# Circom defines the range of positives are [0, p/2] and the range of negatives are [(p/2)+1, (p-1)].
# CIRCOM_PRIME = 21888242871839275222246405745257275088548364400416034343698204186575808495617
p = CIRCOM_PRIME = 28948022309329048855892746252171976963363056481941647379679742748393362948097
MAX_POSITIVE = CIRCOM_PRIME // 2
MAX_NEGATIVE = MAX_POSITIVE + 1 # The most positive number
CIRCOM_NEGATIVE_1 = CIRCOM_PRIME - 1
EXPONENT = 15

def from_circom(x):
    if type(x) != int:
        x = int(x)
    if x > MAX_POSITIVE: 
        return x - CIRCOM_PRIME
    return x
    
def to_circom(x):
    return x % CIRCOM_PRIME
    
def to_circom_input(array: np.array): 
    if type(array) != np.array:
        array = np.array(array)
    int_array = array.round().astype(int)
    int_array = to_circom(int_array)
    return int_array.astype(str).tolist()


def DepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    assert(nFilters % nChannels == 0)
    outRows = (nRows - kernelSize)//strides + 1
    outCols = (nCols - kernelSize)//strides + 1
    
    # out = np.zeros((outRows, outCols, nFilters))
    out = [[[0 for _ in range(nFilters)] for _ in range(outCols)] for _ in range(outRows)]
    remainder = [[[0 for _ in range(nFilters)] for _ in range(outCols)] for _ in range(outRows)]
    # remainder = np.zeros((outRows, outCols, nFilters))
    
    for row in range(outRows):
        for col in range(outCols):
            for channel in range(nChannels):
                for x in range(kernelSize):
                    for y in range(kernelSize):
                        out[row][col][channel] += int(input[row*strides+x, col*strides+y, channel]) * int(weights[x, y, channel])
                
                out[row][col][channel] += int(bias[channel])
                remainder[row][col][channel] = str(int(out[row][col][channel] % n))
                out[row][col][channel] = int(out[row][col][channel] // n)
                            
    return out, remainder

def PointwiseConv2d(nRows, nCols, nChannels, nFilters, strides, n, input, weights, bias):
    kernelSize = 1
    outRows = (nRows - kernelSize)//strides + 1
    outCols = (nCols - kernelSize)//strides + 1
    out = [[[0 for _ in range(nFilters)] for _ in range(outCols)] for _ in range(outRows)]
    str_out = [[[0 for _ in range(nFilters)] for _ in range(outCols)] for _ in range(outRows)]
    remainder = [[[None for _ in range(nFilters)] for _ in range(outCols)] for _ in range(outRows)]
    for row in range(outRows):
        for col in range(outCols):
            for filter in range(nFilters):
                for k in range(nChannels):
                    out[row][col][filter] += int(input[row*strides, col*strides, k]) * int(weights[k, filter])
                    
                out[row][col][filter] += int(bias[filter])
                remainder[row][col][filter] = str(int(out[row][col][filter] % n))
                out[row][col][filter] = int(out[row][col][filter] // n)
                str_out[row][col][filter] = str(out[row][col][filter] % p)
                            
    return out, str_out, remainder

def SeparableConvImpl(nRows, nCols, nChannels, nDepthFilters, nPointFilters, kernelSize, strides, n, input, depthWeights, pointWeights, depthBias, pointBias):
    outRows = (nRows - kernelSize)//strides + 1
    outCols = (nCols - kernelSize)//strides + 1

    depth_out, depth_remainder = DepthwiseConv(nRows, nCols, nChannels, nDepthFilters, kernelSize, strides, n, input, depthWeights, depthBias)
    point_out, point_str_out, point_remainder = PointwiseConv2d(outRows, outCols, nChannels, nPointFilters, strides, n, depth_out, pointWeights, pointBias)
    return depth_out, depth_remainder, point_out, point_str_out, point_remainder

In [3]:
class DatasetWrapper(torch.utils.data.Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform
        
    def __getitem__(self, index):
        x, y = self.subset[index]
        
        if self.transform:
            x = self.transform(x)
        return x, y
        
    def __len__(self):
        return len(self.subset)

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

testset = torchvision.datasets.CIFAR10(
    root='./CIFAR10', train=False, download=True, transform=transform)

trainset = torchvision.datasets.CIFAR10(
    root='./CIFAR10', train=True, download=True)

# split the train set into train/validation
train_set_size = int(len(trainset) * 0.8)
valid_set_size = len(trainset) - train_set_size

seed = torch.Generator().manual_seed(42)
trainset, validset = torch.utils.data.random_split(trainset, [train_set_size, valid_set_size], generator=seed)

trainset = DatasetWrapper(trainset, transform)
validset = DatasetWrapper(validset, transform)

# Create train dataloader
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=512, shuffle=True, num_workers=24)

# Create validation dataloader
validloader = torch.utils.data.DataLoader(
    validset, batch_size=512, shuffle=False, num_workers=24)

# Create test dataloader
testloader = torch.utils.data.DataLoader(
    testset, batch_size=512, shuffle=False, num_workers=24)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')



Files already downloaded and verified
Files already downloaded and verified


In [5]:
MODEL_WEIGHTS_PATH = "./checkpoints/no_padding_100epochs.pth", 

model = ZkMobileNet(trainloader, num_classes=10, alpha=0.25, max_epochs=100)

saved = torch.load("./checkpoints/no_padding_100epochs.pth")
# model.load_state_dict(saved['state_dict'])
model.load_state_dict(saved['net'])
model.eval()

image, label = validset[0]
image = image.unsqueeze(0)
logits = model(image)
pred_idx = logits.argmax()

print(f"Predicted {classes[pred_idx]} - idx: {pred_idx}")

Predicted frog - idx: 6


In [12]:
def Conv2DInt(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    p = CIRCOM_PRIME
    Input = [[[str(int(input[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    Weights = [[[[str(int(weights[i][j][k][l]) % p) for l in range(nFilters)] for k in range(nChannels)] for j in range(kernelSize)] for i in range(kernelSize)]
    Bias = [str(int(bias[i]) % p) for i in range(nFilters)]
    out = [[[0 for _ in range(nFilters)] for _ in range((nCols - kernelSize)//strides + 1)] for _ in range((nRows - kernelSize)//strides + 1)]
    remainder = [[[None for _ in range(nFilters)] for _ in range((nCols - kernelSize)//strides + 1)] for _ in range((nRows - kernelSize)//strides + 1)]
    for i in range((nRows - kernelSize)//strides + 1):
        for j in range((nCols - kernelSize)//strides + 1):
            for m in range(nFilters):
                for k in range(nChannels):
                    for x in range(kernelSize):
                        for y in range(kernelSize):
                            out[i][j][m] += int(input[i*strides+x][j*strides+y][k]) * int(weights[x][y][k][m])
                out[i][j][m] += int(bias[m])
                remainder[i][j][m] = str(int(out[i][j][m]) % n)
                out[i][j][m] = str(int(out[i][j][m]) // n % p)
    return Input, Weights, Bias, out, remainder

In [13]:
def BatchNormalizationInt(nRows, nCols, nChannels, n, X_in, a_in, b_in):
    p = CIRCOM_PRIME
    X = [[[str(int(X_in[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    A = [str(int(a_in[k]) % p) for k in range(nChannels)]
    B = [str(int(b_in[k]) % p) for k in range(nChannels)]
    out = [[[0 for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[None for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    for i in range(nRows):
        for j in range(nCols):
            for k in range(nChannels):
                out[i][j][k] = int(int(X_in[i][j][k]) * int(a_in[k]) + int(b_in[k]))
                remainder[i][j][k] = str(int(out[i][j][k]) % n)
                out[i][j][k] = int(out[i][j][k] // n)
    return X, A, B, out, remainder

In [17]:
gamma = model.bn.weight
beta = model.bn.bias
mean = model.bn.running_mean
var = model.bn.running_var
eps = model.bn.eps

a = (gamma/(var+eps)**.5).detach()
b = (beta-gamma*mean/(var+eps)**.5).detach().tolist()

image, label = testset[0]
image = image.unsqueeze(0)
out = model.conv(image)
expected = model.bn(out)
# out = out.squeeze()
print(f"{out.shape=}")

expected = torch.permute(expected.squeeze(), (1, 2, 0))

print(f"{expected.shape=}")

quantized_in = torch.permute(out.squeeze(), (1, 2, 0)) * 10**EXPONENT
quantized_a = (a * 10**EXPONENT).tolist()
quantized_b = [bi * 10**(2*EXPONENT) for bi in b]
print(out.shape)

X, A, B, actual, remainder = BatchNormalizationInt(32, 32, 8, 10**EXPONENT, quantized_in, quantized_a, quantized_b)
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in actual]

assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

out.shape=torch.Size([1, 8, 32, 32])
expected.shape=torch.Size([32, 32, 8])
torch.Size([1, 8, 32, 32])


# Testing head layer

In [18]:
# CONVOLUTION LAYER
weights = model.conv.weight.detach().numpy().transpose(2, 3, 1, 0)
image, label = testset[0]

expected = model.conv(image).detach().numpy()
bias = torch.zeros(weights.shape[3]).numpy()

# padded = pad(image, 1).transpose(1,2,0)
padded = F.pad(image, (1,1,1,1), "constant", 0).numpy()
padded = padded.transpose(1,2,0)

quantized_image = (padded * 10**EXPONENT).round()
quantized_weights = (weights * 10**EXPONENT).round() # .transpose(0, 3, 1, 0) # [nFilters, nChannels, H, W] -> 

circuit_in, circuit_conv_weights, circuit_conv_bias, circuit_conv_out, circuit_conv_remainder = Conv2DInt(34, 34, 3, 8, 3, 1, 10**EXPONENT, quantized_image, quantized_weights, bias)
# output, remainder = Conv2DInt(34, 34, 3, 8, 3, 1, 10**EXPONENT, quantized_image, quantized_weights, bias)
# test_output = output / 10**(EXPONENT)
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in circuit_conv_out]

expected = expected.transpose((1, 2, 0))

assert(np.allclose(test_output, expected, atol=1e-6))

# BATCH NORM CONSTANTS
gamma = model.bn.weight
beta = model.bn.bias
mean = model.bn.running_mean
var = model.bn.running_var
eps = model.bn.eps

a = (gamma/(var+eps)**.5).detach()
b = (beta-gamma*mean/(var+eps)**.5).detach().tolist()

quantized_a = (a * 10**EXPONENT).tolist()
quantized_b = [bi * 10**(2*EXPONENT) for bi in b]

# BATCH NORM USING PYTORCH OUTPUT
image, label = testset[0]
image = image.unsqueeze(0)
out = model.conv(image)
expected = model.bn(out)

expected = torch.permute(expected.squeeze(), (1, 2, 0))

quantized_in = torch.permute(out.squeeze(), (1, 2, 0)) * 10**EXPONENT

X, A, B, actual, remainder = BatchNormalizationInt(32, 32, 8, 10**EXPONENT, quantized_in, quantized_a, quantized_b)
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in vec] for vec in matrix] for matrix in actual]
assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

# BATCH NORM USING CIRCUIT CONV OUTPUT
quantized_in = [[[from_circom(int(out)) for out in vec] for vec in matrix] for matrix in circuit_conv_out]

_, circuit_bn_a, circuit_bn_b, circuit_bn_out, circuit_bn_remainder = BatchNormalizationInt(32, 32, 8, 10**EXPONENT, quantized_in, quantized_a, quantized_b)

test_output = [[[int(from_circom(int(out))) / 10**EXPONENT for out in vec] for vec in matrix] for matrix in actual]

assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

# RELU USING CIRCUIT OUTPUT
relu_in = [[[to_circom(bn_out) if bn_out < MAX_POSITIVE else 0 for bn_out in vec] for vec in matrix] for matrix in circuit_bn_out]
relu_out = [[[str(bn_out) if bn_out < MAX_POSITIVE else 0 for bn_out in vec] for vec in matrix] for matrix in relu_in]

input_json_path = "head_input1.json"
with open(input_json_path, "w") as input_file:
    json.dump({"in": circuit_in,
               "conv2d_weights": circuit_conv_weights,
               "conv2d_bias": circuit_conv_bias,
               "conv2d_out": circuit_conv_out,
               "conv2d_remainder": circuit_conv_remainder,
               
               "bn_a": circuit_bn_a,
               "bn_b": circuit_bn_b,
               "bn_out": circuit_bn_out,
               "bn_remainder": circuit_bn_remainder,
               
               "relu_out": relu_out,
               }, input_file)

os.chdir("circuits")
!./head/head_cpp/head ../head_input1.json head.wtns
# !npx snarkjs groth16 prove head/circuit_final.zkey head.wtns proof.json public_test.json
os.chdir("../")

print("ok")

something
before conv2d
after conv2d
after bn
at:  0 1 6
in:  1745789835369983
after relu
end
ok


# Testing padding over highly padded input (to try and fold the circuit using nova)

In [7]:
class SeparableConv2d(nn.Module):
    '''Separable convolution'''
    def __init__(self, in_channels, out_channels, stride=1):
        super(SeparableConv2d, self).__init__()
        self.dw_conv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=0, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels),
            # nn.ReLU(inplace=False),
        )
        self.pw_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
            # nn.ReLU(inplace=False),
        )
        
input = torch.randn((1, 3, 5, 5))
test_model = SeparableConv2d(3, 6)

# Padded Convolution test

# 

In [8]:
input = torch.randn((1, 3, 7, 7))
# model = SeparableConv2d(3, 6)
# model.eval()

In [9]:
def PaddedDepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    assert(nFilters % nChannels == 0)
    outRows = (nRows - kernelSize)//strides + 1
    outCols = (nCols - kernelSize)//strides + 1
    
    # out = np.zeros((outRows, outCols, nFilters))
    Input = [[[str(int(input[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    Weights = [[[str(int(weights[i][j][k].round()) % p) for k in range(weights.shape[2])] for j in range(weights.shape[1])] for i in range(weights.shape[0])]
    out = [[[0 for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    Bias = bias.round().astype(int).astype(str).tolist()
    out_str = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    
    for row in range(outRows):
        for col in range(outCols):
            for channel in range(nChannels):
                for x in range(kernelSize):
                    for y in range(kernelSize):
                        out[row+1][col+1][channel] += int(input[row*strides+x][col*strides+y][channel]) * int(weights[x, y, channel])
                
                out[row+1][col+1][channel] += int(bias[channel])
                remainder[row+1][col+1][channel] = str(int(out[row+1][col+1][channel] % n))
                out[row+1][col+1][channel] = int(out[row+1][col+1][channel] // n)
                out_str[row+1][col+1][channel] = str(out[row+1][col+1][channel] % p)
                            
    return Input, Weights, Bias, out_str, out, remainder

In [10]:
def PointwiseConv2d(nRows, nCols, nChannels, nFilters, strides, n, input, weights, bias):
    kernelSize = 1
    outRows = (nRows - kernelSize)//strides + 1
    outCols = (nCols - kernelSize)//strides + 1
    # out = [[[0 for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    # remainder = [[[None for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    
    Input = [[[str(int(input[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    # Weights = [[[str(int(weights[i][j][k].round()) % p) for k in range(weights.shape[2])] for j in range(weights.shape[1])] for i in range(weights.shape[0])]
    Weights = [[str(int(weights[i][j].round()) % p)for j in range(weights.shape[1])] for i in range(weights.shape[0])]
    out = [[[0 for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    Bias = bias.round().astype(int).astype(str).tolist()
    out_str = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    
    for row in range(outRows):
        for col in range(outCols):
            for filter in range(nFilters):
                for k in range(nChannels):
                    out[row][col][filter] += int(input[row*strides][col*strides][k]) * int(weights[k, filter])
                    
                out[row][col][filter] += int(bias[filter])
                remainder[row][col][filter] = str(int(out[row][col][filter] % n))
                out[row][col][filter] = int(out[row][col][filter] // n)
                out_str[row][col][filter] = str(out[row][col][filter] % p)
                            
    # return out, remainder
    return Input, Weights, Bias, out_str, out, remainder

In [14]:
# Depthwise convolution
input = torch.randn((1, 3, 7, 7))

depth_weights = test_model.dw_conv[0].weight.squeeze().detach().numpy()
depth_bias = torch.zeros(depth_weights.shape[0]).numpy()

print("Input shape: ", input.shape)
expected = test_model.dw_conv[0](input).detach()
print("Expected shape: ", expected.shape)

depth_weights = depth_weights.transpose((1, 2, 0))

quantized_image = input.squeeze().numpy().transpose((1,2,0)) * 10**EXPONENT
# quantized_image = padded * 10**EXPONENT
quantized_depth_weights = depth_weights * 10**EXPONENT

circuit_in, circuit_depth_weights, circuit_depth_bias, circuit_depth_out, depth_out, circuit_depth_remainder = PaddedDepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, quantized_image.round(), quantized_depth_weights.round(), depth_bias)

test_output = np.array([[[int(out) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in depth_out])
test_output = test_output[1:-1, 1:-1, :]

expected = expected.squeeze().numpy().transpose((1,2,0))

assert(np.allclose(expected, test_output, atol=0.00001))

# Batch normalization step
# BATCH NORM CONSTANTS
test_model.eval()
gamma = test_model.dw_conv[1].weight
beta = test_model.dw_conv[1].bias
mean = test_model.dw_conv[1].running_mean
var = test_model.dw_conv[1].running_var
eps = test_model.dw_conv[1].eps

a = (gamma/(var+eps)**.5).detach()
print('a shape: ', a.shape)
b = (beta-gamma*mean/(var+eps)**.5).detach()
print('b shape: ', b.shape)

b = b.tolist()

quantized_a = (a * 10**EXPONENT).tolist()
quantized_b = [bi * 10**(2*EXPONENT) for bi in b]

out = test_model.dw_conv[0](input)
expected = test_model.dw_conv[1](out)

expected = torch.permute(expected.squeeze(), (1, 2, 0))

quantized_in = torch.permute(out.squeeze(), (1, 2, 0)) * 10**EXPONENT

X, A, B, actual, remainder = BatchNormalizationInt(5, 5, 3, 10**EXPONENT, quantized_in, quantized_a, quantized_b)
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in vec] for vec in matrix] for matrix in actual]

assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

# quantized_in = [[[from_circom(int(out)) for out in vec] for vec in matrix] for matrix in circuit_conv_out]

_, circuit_depth_bn_a, circuit_depth_bn_b, circuit_depth_bn_out, circuit_depth_bn_remainder = BatchNormalizationInt(7, 7, 3, 10**EXPONENT, depth_out, quantized_a, quantized_b)

test_output = np.array([[[int(out) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in circuit_depth_bn_out])
test_output = test_output[1:-1, 1:-1, :]

assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

# Pointwise convolution
# point_weights = test_model.pw_conv[0].weight.squeeze().detach().numpy()
point_weights = test_model.pw_conv[0].weight.detach().numpy()
print("point weights shape: ", point_weights.shape)
point_bias = torch.zeros(point_weights.shape[0]).numpy()

print("Input shape: ", input.shape)
depth_expected = test_model.dw_conv[0](input)
bn_expected = test_model.dw_conv[1](depth_expected)
point_expected = test_model.pw_conv[0](bn_expected)
print("Depth Expected shape: ", depth_expected.shape)
# print("Point Expected shape: ", point_expected.shape)

point_expected = point_expected.squeeze().detach().numpy().transpose((1,2,0))
print("Point Expected shape: ", point_expected.shape)

point_weights = point_weights.transpose((2, 3, 1, 0)).squeeze()
quantized_point_weights = point_weights * 10**EXPONENT
print("point weights shape: ", quantized_point_weights.shape)

# circuit_in, circuit_depth_weights, circuit_depth_bias, circuit_depth_out, depth_out, circuit_depth_remainder = PaddedDepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, quantized_image.round(), quantized_depth_weights.round(), depth_bias)
point_input, circuit_point_weights, circuit_point_bias, circuit_point_out, point_out, circuit_point_remainder = PointwiseConv2d(7, 7, 3, 6, 1, 10**EXPONENT, circuit_depth_bn_out, quantized_point_weights.round(), point_bias)

test_output = np.array([[[int(out) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in point_out])
test_output = test_output[1:-1, 1:-1, :]

assert(np.allclose(point_expected, test_output, atol=1e-6))

                   
# Batch normalization step
# BATCH NORM CONSTANTS
test_model.eval()
gamma = test_model.pw_conv[1].weight
beta = test_model.pw_conv[1].bias
mean = test_model.pw_conv[1].running_mean
var = test_model.pw_conv[1].running_var
eps = test_model.pw_conv[1].eps

a = (gamma/(var+eps)**.5).detach()
print('a shape: ', a.shape)
b = (beta-gamma*mean/(var+eps)**.5).detach()
print('b shape: ', b.shape)

b = b.tolist()

quantized_a = (a * 10**EXPONENT).tolist()
quantized_b = [bi * 10**(2*EXPONENT) for bi in b]

# out = test_model.dw_conv[0](input)
# expected = test_model.dw_conv[1](out)
        
print("Input shape: ", input.shape)
depth_expected = test_model.dw_conv[0](input)
bn_expected = test_model.dw_conv[1](depth_expected)
point_expected = test_model.pw_conv[0](bn_expected)
expected = test_model.pw_conv[1](point_expected)
print("Depth Expected shape: ", expected.shape)

              
expected = torch.permute(expected.squeeze(), (1, 2, 0))

quantized_in = torch.permute(point_expected.squeeze(), (1, 2, 0)) * 10**EXPONENT

X, A, B, actual, remainder = BatchNormalizationInt(5, 5, 6, 10**EXPONENT, quantized_in, quantized_a, quantized_b)
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in vec] for vec in matrix] for matrix in actual]

assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

_, circuit_point_bn_a, circuit_point_bn_b, circuit_point_bn_out, circuit_point_bn_remainder = BatchNormalizationInt(7, 7, 6, 10**EXPONENT, point_out, quantized_a, quantized_b)

test_output = np.array([[[int(out) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in circuit_point_bn_out])
test_output = test_output[1:-1, 1:-1, :]

print("Expected shape: ", expected.shape)
print("test shape: ", test_output.shape)
assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

Input shape:  torch.Size([1, 3, 7, 7])
Expected shape:  torch.Size([1, 3, 5, 5])
a shape:  torch.Size([3])
b shape:  torch.Size([3])
point weights shape:  (6, 3, 1, 1)
Input shape:  torch.Size([1, 3, 7, 7])
Depth Expected shape:  torch.Size([1, 3, 5, 5])
Point Expected shape:  (5, 5, 6)
point weights shape:  (3, 6)
a shape:  torch.Size([6])
b shape:  torch.Size([6])
Input shape:  torch.Size([1, 3, 7, 7])
Depth Expected shape:  torch.Size([1, 6, 5, 5])
Expected shape:  torch.Size([5, 5, 6])
test shape:  (5, 5, 6)
terminate called after throwing an instance of 'std::runtime_error'
  what():  Error loading signal dw_bn_a: Not enough values

OK


# Padded convolution 2 iterations test with true input

In [16]:
# MODEL_WEIGHTS_PATH = './checkpoints/model_small_100epochs.pth'

# model = ZkMobileNet(trainloader, num_classes=10, alpha=0.25, max_epochs=100)
# checkpoint = torch.load(MODEL_WEIGHTS_PATH)
# # model.load_state_dict(checkpoint['state_dict'])
# model.load_state_dict(checkpoint['net'])
model.eval()

image, label = testset[0]
image = image.unsqueeze(0)
logits = model(image)
pred_idx = logits.argmax()

print(f"Predicted {classes[pred_idx]} - idx: {pred_idx}")

Predicted frog - idx: 6


In [17]:
# CONVOLUTION LAYER
weights = model.conv.weight.detach().numpy().transpose(2, 3, 1, 0)
image, label = testset[0]

expected = model.conv(image).detach().numpy()
bias = torch.zeros(weights.shape[3]).numpy()

# padded = pad(image, 1).transpose(1,2,0)
padded = F.pad(image, (1,1,1,1), "constant", 0).numpy()
padded = padded.transpose(1,2,0)

quantized_image = (padded * 10**EXPONENT).round()
quantized_weights = (weights * 10**EXPONENT).round() # .transpose(0, 3, 1, 0) # [nFilters, nChannels, H, W] -> 

circuit_in, circuit_conv_weights, circuit_conv_bias, circuit_conv_out, circuit_conv_remainder = Conv2DInt(34, 34, 3, 8, 3, 1, 10**EXPONENT, quantized_image, quantized_weights, bias)
# output, remainder = Conv2DInt(34, 34, 3, 8, 3, 1, 10**EXPONENT, quantized_image, quantized_weights, bias)
# test_output = output / 10**(EXPONENT)
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in circuit_conv_out]

expected = expected.transpose((1, 2, 0))

assert(np.allclose(test_output, expected, atol=1e-6))

# BATCH NORM CONSTANTS
gamma = model.bn.weight
beta = model.bn.bias
mean = model.bn.running_mean
var = model.bn.running_var
eps = model.bn.eps

a = (gamma/(var+eps)**.5).detach()
b = (beta-gamma*mean/(var+eps)**.5).detach().tolist()

quantized_a = (a * 10**EXPONENT).tolist()
quantized_b = [bi * 10**(2*EXPONENT) for bi in b]

# BATCH NORM USING PYTORCH OUTPUT
image, label = testset[0]
image = image.unsqueeze(0)
out = model.conv(image)
expected = model.bn(out)

expected = torch.permute(expected.squeeze(), (1, 2, 0))

quantized_in = torch.permute(out.squeeze(), (1, 2, 0)) * 10**EXPONENT

X, A, B, actual, remainder = BatchNormalizationInt(32, 32, 8, 10**EXPONENT, quantized_in, quantized_a, quantized_b)
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in vec] for vec in matrix] for matrix in actual]
assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

# BATCH NORM USING CIRCUIT CONV OUTPUT
quantized_in = [[[from_circom(int(out)) for out in vec] for vec in matrix] for matrix in circuit_conv_out]

_, circuit_bn_a, circuit_bn_b, circuit_bn_out, circuit_bn_remainder = BatchNormalizationInt(32, 32, 8, 10**EXPONENT, quantized_in, quantized_a, quantized_b)

test_output = [[[int(from_circom(int(out))) / 10**EXPONENT for out in vec] for vec in matrix] for matrix in circuit_bn_out]

assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

# RELU USING CIRCUIT OUTPUT
relu_in = [[[to_circom(bn_out) if bn_out < MAX_POSITIVE else 0 for bn_out in vec] for vec in matrix] for matrix in circuit_bn_out]
relu_out = [[[str(bn_out) if bn_out < MAX_POSITIVE else 0 for bn_out in vec] for vec in matrix] for matrix in relu_in]

image, label = testset[0]
image = image.unsqueeze(0)
conv_output = model.conv(image)
bn_output = model.bn(conv_output)
relu_expected = model.relu(bn_output).squeeze().detach().numpy().transpose((1,2,0))

test_output = [[[out / 10**EXPONENT if out > 0 else 0 for out in vec] for vec in matrix] for matrix in circuit_bn_out]

assert(np.allclose(test_output, relu_expected, atol=1e-6))

# DW INPUT WITH ACTUAL IMAGE

In [23]:
quantized_input = torch.Tensor([[[out if out > 0 else 0 for out in vec] for vec in matrix] for matrix in circuit_bn_out])
pytorch_input = torch.Tensor([[[int(out) / 10**EXPONENT if out > 0 else 0 for out in vec] for vec in matrix] for matrix in circuit_bn_out])
pytorch_input = torch.permute(pytorch_input, (2, 0, 1)).unsqueeze(0)

In [18]:
def BatchNormalizationInt(nRows, nCols, nChannels, n, X_in, a_in, b_in):
    p = CIRCOM_PRIME
    X = [[[str(int(X_in[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    A = [str(int(a_in[k]) % p) for k in range(nChannels)]
    B = [str(int(b_in[k]) % p) for k in range(nChannels)]
    out = [[[0 for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    out_str = [[[str(0) for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[str(0) for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    for i in range(nRows):
        for j in range(nCols):
            for k in range(nChannels):
                out[i][j][k] = int(int(X_in[i][j][k]) * int(a_in[k]) + int(b_in[k]))
                remainder[i][j][k] = str(int(out[i][j][k]) % n)
                out[i][j][k] = int(out[i][j][k] // n)
                out_str[i][j][k] = str(out[i][j][k] % p)
    return X, A, B, out_str, out, remainder

In [209]:
def circom2pytorch(circuit_output):
    formatted = torch.Tensor([[[int(out) / 10**EXPONENT if out > 0 else 0 for out in vec] for vec in matrix] for matrix in circuit_output])
    formatted = torch.permute(formatted, (2, 0, 1)).unsqueeze(0)
    return formatted
    
def pytorch2quantized(pytorch_output: torch.Tensor):
    return pytorch_output.squeeze().detach().numpy().transpose((1, 2, 0)) * 10**EXPONENT

def dequantize(input: List[List[List[int]]], padding: int=1, channel_padding: Optional[int]=None):
    test_output = np.array([[[int(value) / 10**EXPONENT for value in vec] for vec in matrix] for matrix in input])
    
    if channel_padding is None:
        return test_output[padding:-padding, padding:-padding, :]
    
    return test_output[padding:-padding, padding:-padding, :channel_padding]
    
def check_quantized_input(quantized_input):
    """quantized_input should be quantized and should be (Height, Depth, Channels)"""
    assert(len(np.array(quantized_input).shape) == 3)

def check_pytorch_input(quantized_input):
    """pytorch_input should be quantized and should be (N=1, Channels, Height, Depth)"""
    assert(len(quantized_input.shape) == 4)
    assert(quantized_input.shape[0] == 1)
    assert(type(quantized_input) == torch.Tensor)

In [20]:
from pydantic import BaseModel
from typing import List, Optional, Union

class CircuitConvInput(BaseModel):
    input: Optional[List[List[List[str]]]]
    weights: Union[List[List[List[str]]], List[List[str]]]
    bias: List[str]
    out_str: List[List[List[str]]]
    remainder: List[List[List[str]]]
    
    out: List[List[List[int]]]
    
    
class CircuitBatchNormInput(BaseModel):
    input: Optional[List[List[List[str]]]]
    a: List[str]
    b: List[str]
    out_str: List[List[List[str]]]
    out: List[List[List[int]]]
    remainder: List[List[List[str]]]


class ConvBN(BaseModel):
    conv: CircuitConvInput
    bn: CircuitBatchNormInput

class CircuitLayerInput(BaseModel):
    depthwise: ConvBN
    pointwise: ConvBN

    def input(self):
        return self.depthwise.conv.input
        
    def out(self):
        return self.pointwise.bn.out
    
    def to_json(self, json_path: str):
        with open(json_path, "w") as input_file:
            json.dump({
                       "in": self.depthwise.conv.input,
                       "dw_conv_weights": self.depthwise.conv.weights,
                       "dw_conv_bias": self.depthwise.conv.bias,
                       "dw_conv_remainder": self.depthwise.conv.remainder,
                       "dw_conv_out": self.depthwise.conv.out_str,
                
                       "dw_bn_a": self.depthwise.bn.a,
                       "dw_bn_b": self.depthwise.bn.b,
                       "dw_bn_remainder": self.depthwise.bn.remainder,
                       "dw_bn_out": self.depthwise.bn.out,
                
                       "pw_conv_weights": self.pointwise.conv.weights,
                       "pw_conv_bias": self.pointwise.conv.bias,
                       "pw_conv_remainder": self.pointwise.conv.remainder,
                       "pw_conv_out": self.pointwise.conv.out_str,
                
                       "pw_bn_a": self.pointwise.bn.a,
                       "pw_bn_b": self.pointwise.bn.b,
                       "pw_bn_remainder": self.pointwise.bn.remainder,
                       "pw_bn_out": self.pointwise.bn.out,
                      },
                      input_file)

In [205]:
def TypedPaddedDepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    Input, Weights, Bias, out_str, out, remainder = PaddedDepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias)
    return CircuitConvInput(
        input=Input, 
        weights=Weights, 
        bias=Bias, 
        out_str=out_str, 
        out=out, 
        remainder=remainder, 
    )
    
def BatchNormalizationPadded(nRows, nCols, nChannels, n, X_in, a_in, b_in, padding):
    p = CIRCOM_PRIME
    X = [[[str(int(X_in[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    A = [str(int(a_in[k]) % p) for k in range(nChannels)]
    B = [str(int(b_in[k]) % p) for k in range(nChannels)]
    out = [[[0 for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    out_str = [[[str(0) for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[str(0) for _ in range(nChannels)] for _ in range(nCols)] for _ in range(nRows)]
    for i in range(padding, nRows-padding):
        for j in range(padding, nCols-padding):
            for k in range(nChannels):
                out[i][j][k] = int(int(X_in[i][j][k]) * int(a_in[k]) + int(b_in[k]))
                remainder[i][j][k] = str(int(out[i][j][k]) % n)
                out[i][j][k] = int(out[i][j][k] // n)
                out_str[i][j][k] = str(out[i][j][k] % p)
    return X, A, B, out_str, out, remainder
    
def TypedBatchNormalizationInt(nRows, nCols, nChannels, n, X_in, a_in, b_in, padding):
    X, A, B, out_str, out, remainder = BatchNormalizationPadded(nRows, nCols, nChannels, n, X_in, a_in, b_in, padding)
    return CircuitBatchNormInput(
        input=X, 
        a=A, 
        b=B, 
        out_str=out_str, 
        out=out, 
        remainder=remainder, 
    )
    
def TypedPointwiseConv2d(nRows, nCols, nChannels, nFilters, strides, n, input, weights, bias):
    Input, Weights, Bias, out_str, out, remainder = PointwiseConv2d(nRows, nCols, nChannels, nFilters, strides, n, input, weights, bias)
    return CircuitConvInput(
        input=Input, 
        weights=Weights, 
        bias=Bias, 
        out_str=out_str, 
        out=out, 
        remainder=remainder, 
    )
        

In [197]:
model.features[1].dw_conv

Sequential(
  (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), groups=16, bias=False)
  (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [206]:
class CircuitMobilenet():
    def __init__(self, model: ZkMobileNet):
        self.model = model
        self.model.eval()

    def _get_conv_circuit_input(self, layer: int, quantized_input: List[List[List[int]]], expected: np.array, module: nn.Conv2d) -> CircuitConvInput:
        print("_get_conv_circuit_input")
        check_quantized_input(quantized_input)
        
        print("_pointwise circuit_input")
        weights = module.weight.detach().numpy()
        bias = torch.zeros(weights.shape[0]).numpy()
        
        weights = weights.transpose((2, 3, 1, 0)).squeeze()
        quantized_weights = weights * 10**EXPONENT

        if module.kernel_size == (3, 3): 
            print("_depthwise circuit_input")
            conv_input = TypedPaddedDepthwiseConv(32, 32, 8, 8, 3, 1, 10**EXPONENT, quantized_input, quantized_weights.round(), bias)
            # -------------- Need to use (MAX_H, MAX_W, MAX_N_CHANNELS, MAX_N_FILTERS)
            # conv_input = TypedPaddedDepthwiseConv(32, 32, 16, 16, 3, 1, 10**EXPONENT, circuit_layer_input.out(), quantized_weights.round(), bias)
        elif module.kernel_size == (1, 1):
            print("_pointwise circuit_input")
            conv_input = TypedPointwiseConv2d(32, 32, 8, 16, 1, 10**EXPONENT, quantized_input, quantized_weights.round(), bias)
        
        test_output = dequantize(conv_input.out)
        
        assert(np.allclose(expected, test_output, atol=1e-5))
        return conv_input
        
    def _get_bn_circuit_input(self, layer: int, quantized_input: List[List[List[int]]], expected: np.array, batch_norm: nn.modules.batchnorm.BatchNorm2d) -> CircuitBatchNormInput:
        print("_get_bn_circuit_input")
        check_quantized_input(quantized_input)
        
        gamma = batch_norm.weight
        beta = batch_norm.bias
        mean = batch_norm.running_mean
        var = batch_norm.running_var
        eps = batch_norm.eps
        
        a = (gamma/(var+eps)**.5).detach().tolist()
        b = (beta-gamma*mean/(var+eps)**.5).detach().tolist()
        
        quantized_a = [ai * 10**(EXPONENT) for ai in a]
        quantized_b = [bi * 10**(2*EXPONENT) for bi in b]
        
        bn_input = TypedBatchNormalizationInt(32, 32, len(quantized_a), 10**EXPONENT, quantized_input, quantized_a, quantized_b, padding=(layer+1))
        
        test_output = dequantize(bn_input.out)
        
        assert(np.allclose(test_output, expected, atol=1e-5))
        return bn_input

    def _forward_module(self, module: nn.Module, input: torch.Tensor):
        print("_forward_module")
        output = module(input)
        expected = output.squeeze().detach().numpy().transpose((1,2,0))
        return output, expected
        
    def _circuit_conv_bn(self, layer: int, pytorch_input: torch.Tensor, quantized_input: List[List[List[int]]], conv: nn.Sequential):
        print("_circuit_conv_bn")
        conv_output, conv_expected = self._forward_module(conv[0], pytorch_input)
        
        circuit_conv_input = self._get_conv_circuit_input(layer, quantized_input, conv_expected, conv[0])
        
        bn_output, bn_expected = self._forward_module(conv[1], conv_output)
        
        circuit_bn_input = self._get_bn_circuit_input(layer, circuit_conv_input.out, bn_expected, conv[1])
        
        return ConvBN(conv=circuit_conv_input, bn=circuit_bn_input), bn_output

    def circuit_layer_inputs(self, layer: int, pytorch_input: torch.Tensor, quantized_input: List[List[List[int]]]):
        print("circuit_layer_inputs")
        check_pytorch_input(pytorch_input)
        check_quantized_input(quantized_input)

        depthwise, dw_output = self._circuit_conv_bn(layer, pytorch_input, quantized_input, self.model.features[layer].dw_conv)
        pointwise, pw_output = self._circuit_conv_bn(layer, dw_output, depthwise.bn.out, self.model.features[layer].pw_conv)

        layer_input = CircuitLayerInput(depthwise=depthwise, pointwise=pointwise)

        layer_input.to_json("it_worked.json")
        return layer_input, pw_output
        
circuit = CircuitMobilenet(model)
quantized_input = torch.Tensor([[[out if out > 0 else 0 for out in vec] for vec in matrix] for matrix in circuit_bn_out])
pytorch_input = torch.Tensor([[[int(out) / 10**EXPONENT if out > 0 else 0 for out in vec] for vec in matrix] for matrix in circuit_bn_out])
pytorch_input = torch.permute(pytorch_input, (2, 0, 1)).unsqueeze(0)


# print("quantized shape: ", np.array(circuit_layer_input.pointwise.bn.out).shape)
print("input shape: ", pytorch_input.shape)
# print("Pytorch shape: ", layer_pytorch_input.shape)
circuit_layer_input, pytorch_output = circuit.circuit_layer_inputs(0, pytorch_input, quantized_input)

image, label = testset[0]
image = image.unsqueeze(0)
conv_output = model.conv(image)
bn_output = model.bn(conv_output)
relu_expected = model.relu(bn_output)
# layer0_expected = model.features[0](bn_output).squeeze().detach().numpy().transpose((1,2,0))
layer0_expected = model.features[0](relu_expected)



assert(torch.allclose(pytorch_output, layer0_expected, atol=1e-5))

test_output = dequantize(circuit_layer_input.out())
expected = layer0_expected.squeeze().detach().numpy().transpose((1,2,0))
assert(np.allclose(test_output, expected, atol=1e-5))

input shape:  torch.Size([1, 8, 32, 32])
circuit_layer_inputs
_circuit_conv_bn
_forward_module
_get_conv_circuit_input
_pointwise circuit_input
_depthwise circuit_input
_forward_module
_get_bn_circuit_input
X type: <class 'str'> -> shape: (32, 32, 8)
A type: <class 'str'> -> shape: (8,)
B type: <class 'str'> -> shape: (8,)
_circuit_conv_bn
_forward_module
_get_conv_circuit_input
_pointwise circuit_input
_pointwise circuit_input
_forward_module
_get_bn_circuit_input
X type: <class 'str'> -> shape: (32, 32, 16)
A type: <class 'str'> -> shape: (16,)
B type: <class 'str'> -> shape: (16,)


In [43]:
pytorch_output[0][0][0], layer0_expected[0][0][0]

(tensor([0.5832, 0.7942, 0.9404, 0.8328, 0.8619, 0.9082, 0.8824, 0.9248, 1.0196,
         1.2147, 1.3648, 1.0355, 1.0215, 1.3017, 1.4859, 1.7842, 2.0280, 1.9749,
         2.2359, 2.3105, 2.3574, 2.3273, 2.2716, 2.1797, 1.8374, 1.5675, 1.2039,
         1.0939, 1.0149, 1.1237], grad_fn=<SelectBackward0>),
 tensor([0.5832, 0.7942, 0.9404, 0.8328, 0.8619, 0.9082, 0.8824, 0.9248, 1.0196,
         1.2147, 1.3648, 1.0355, 1.0215, 1.3017, 1.4859, 1.7842, 2.0280, 1.9749,
         2.2359, 2.3105, 2.3574, 2.3273, 2.2716, 2.1797, 1.8374, 1.5675, 1.2039,
         1.0939, 1.0149, 1.1237], grad_fn=<SelectBackward0>))

# REWRITING FOR SECOND LAYER!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

In [75]:
layer1_expected = model.features[1](layer0_expected)
print("Image shape: ", image.shape)
print("Conv  shape: ", conv_output.shape)
print("BN    shape: ", bn_output.shape)
print("ReLU  shape: ", relu_expected.shape)

layer0_dw_expected = model.features[0].dw_conv(relu_expected)
print("after features[0].depthwise shape: ", layer0_expected.shape)
print("----------> layer0 depthwise  shape: ", np.array(circuit_layer_input.depthwise.conv.out).shape)
layer0_expected = model.features[0].pw_conv(layer0_dw_expected)
print("after features[1].pointwise shape: ", layer0_expected.shape)
print("----------> layer0 pointwise  shape: ", np.array(circuit_layer_input.pointwise.conv.out).shape)
expected = model.features[1].dw_conv[0](layer0_expected).detach()
print("expected shape: ", pytorch_output.shape)

# print("depthwise circuit_input")
weights = model.features[1].dw_conv[0].weight.detach().numpy()
bias = torch.zeros(weights.shape[0]).numpy()

weights = weights.transpose((2, 3, 1, 0)).squeeze()
quantized_weights = weights * 10**EXPONENT
print("----------> layer1 --->> quantized_weights  shape: ", np.array(quantized_weights).shape)

conv_input = TypedPaddedDepthwiseConv(32, 32, 16, 16, 3, 1, 10**EXPONENT, circuit_layer_input.out(), quantized_weights.round(), bias)
            # conv_input = TypedPaddedDepthwiseConv(32, 32, 8, 8, 3, 1, 10**EXPONENT, quantized_input, quantized_weights.round(), bias)
    
print("----------> layer1 depthwise  shape: ", np.array(conv_input.out).shape)
test_output = dequantize(conv_input.out, 2)
print("conv_input.out shape: ", np.array(conv_input.out).shape)
print("test_output shape: ", np.array(test_output).shape)
print("expected shape: ", np.array(expected).shape)
expected = expected.squeeze().detach().numpy().transpose((1, 2, 0))

assert(np.allclose(expected, test_output, atol=1e-5))
# circuit_depth_input = circuit.circuit_layer_inputs(1, layer0_expected, circuit_layer_input.out())

Image shape:  torch.Size([1, 3, 32, 32])
Conv  shape:  torch.Size([1, 8, 32, 32])
BN    shape:  torch.Size([1, 8, 32, 32])
ReLU  shape:  torch.Size([1, 8, 32, 32])
after features[0].depthwise shape:  torch.Size([1, 16, 30, 30])
----------> layer0 depthwise  shape:  (32, 32, 8)
after features[1].pointwise shape:  torch.Size([1, 16, 30, 30])
----------> layer0 pointwise  shape:  (32, 32, 16)
expected shape:  torch.Size([1, 32, 28, 28])
----------> layer1 --->> quantized_weights  shape:  (3, 3, 16)
----------> layer1 depthwise  shape:  (32, 32, 16)
conv_input.out shape:  (32, 32, 16)
test_output shape:  (28, 28, 16)
expected shape:  (1, 16, 28, 28)


In [87]:
image, label = testset[0]
image = image.unsqueeze(0)
conv_output = model.conv(image)
bn_output = model.bn(conv_output)
relu_expected = model.relu(bn_output)
# layer0_expected = model.features[0](bn_output).squeeze().detach().numpy().transpose((1,2,0))
print("Image shape: ", image.shape)
print("Conv  shape: ", conv_output.shape)
print("BN    shape: ", bn_output.shape)
print("ReLU  shape: ", relu_expected.shape)

layer0_dw_expected = model.features[0].dw_conv(relu_expected)

# print("depthwise circuit_input")
weights = model.features[1].dw_conv[0].weight.detach().numpy()
bias = torch.zeros(weights.shape[0]).numpy()

weights = weights.transpose((2, 3, 1, 0)).squeeze()
quantized_weights = weights * 10**EXPONENT
print("----------> layer1 --->> quantized_weights  shape: ", np.array(quantized_weights).shape)

Image shape:  torch.Size([1, 3, 32, 32])
Conv  shape:  torch.Size([1, 8, 32, 32])
BN    shape:  torch.Size([1, 8, 32, 32])
ReLU  shape:  torch.Size([1, 8, 32, 32])
----------> layer1 --->> quantized_weights  shape:  (3, 3, 16)


In [96]:
model.features[-1]

ZkSeparableConv2d(
  (dw_conv): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), groups=256, bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pw_conv): Sequential(
    (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [127]:
# initial    SHAPE: 32x32x3
# conv       SHAPE: 32x32x8
      # (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), groups=8, bias=False)
      # (0): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
# layer0_dw  SHAPE: 30x30x8
# layer0_pw  SHAPE: 30x30x16

      # (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), groups=16, bias=False)
      # (0): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
# layer1_dw  SHAPE: 28x28x16
# layer1_pw  SHAPE: 28x28x32

      # (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
      # (0): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
# layer2_dw  SHAPE: 26x26x32
# layer2_pw  SHAPE: 26x26x32

      # (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
      # (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
# layer3_dw  SHAPE: 24x24x32
# layer3_pw  SHAPE: 24x24x64

      # (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), groups=64, bias=False)
      # (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
# layer4_dw  SHAPE: 22x22x64
# layer4_pw  SHAPE: 22x22x64

      # (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), groups=64, bias=False)
      # (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
# layer5_dw  SHAPE: 20x20x64
# layer5_pw  SHAPE: 20x20x128

      # (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), groups=128, bias=False)
      # (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
# layer6_dw  SHAPE: 18x18x128
# layer6_pw  SHAPE: 18x18x128

      # (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), groups=128, bias=False)
      # (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
# layer7_dw  SHAPE: 16x16x128
# layer7_pw  SHAPE: 16x16x128

      # (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), groups=128, bias=False)
      # (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
# layer8_dw  SHAPE: 14x14x128
# layer8_pw  SHAPE: 14x14x128

      # (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), groups=128, bias=False)
      # (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
# layer9_dw  SHAPE: 12x12x128
# layer9_pw  SHAPE: 12x12x128

      # (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), groups=128, bias=False)
      # (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
# layer10_dw  SHAPE: 10x10x128
# layer10_pw  SHAPE: 10x10x128

      # (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), groups=128, bias=False)
      # (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
# layer11_dw  SHAPE: 8x8x128
# layer11_pw  SHAPE: 8x8x256

      # (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), groups=256, bias=False)
      # (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
# layer12_dw  SHAPE: 6x6x256
# layer12_pw  SHAPE: 6x6x256

In [185]:
# def format_circuit_input(input: List[List[List[int]]], dims: List[int]):
#     # [[[str(int(input[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
#     assert(len(dims) == 3)
#     input_shape = np.array(input).shape
#     assert(len(input_shape) == 3)
#     print(f"{dims = }")
#     print(f"{input_shape = }")
#     formatted = [[[str(0) for k in range(dims[2])] for j in range(dims[1])] for i in range(dims[0])]
    

def CircuitDepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    assert(nFilters % nChannels == 0)
    outRows = (nRows - kernelSize)//strides + 1
    outCols = (nCols - kernelSize)//strides + 1
    
    # out = np.zeros((outRows, outCols, nFilters))
    Input = [[[str(int(input[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    # Input = format_circuit_input(input, (nRows, nCols, nChannels))
    Weights = [[[str(int(weights[i][j][k].round()) % p) for k in range(weights.shape[2])] for j in range(weights.shape[1])] for i in range(weights.shape[0])]
    Bias = bias.round().astype(int).astype(str).tolist()
    
    
   
    out = [[[0 for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    out_str = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    
    for row in range(outRows):
        for col in range(outCols):
            for channel in range(nFilters):
                for x in range(kernelSize):
                    for y in range(kernelSize):
                        out[row+1][col+1][channel] += int(input[row*strides+x][col*strides+y][channel]) * int(weights[x, y, channel])
                
                out[row+1][col+1][channel] += int(bias[channel])
                remainder[row+1][col+1][channel] = str(int(out[row+1][col+1][channel] % n))
                out[row+1][col+1][channel] = int(out[row+1][col+1][channel] // n)
                out_str[row+1][col+1][channel] = str(out[row+1][col+1][channel] % p)
                            
    return Input, Weights, Bias, out_str, out, remainder

def TypedPaddedDepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    Input, Weights, Bias, out_str, out, remainder = CircuitDepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias)
    return CircuitConvInput(
        input=Input, 
        weights=Weights, 
        bias=Bias, 
        out_str=out_str, 
        out=out, 
        remainder=remainder, 
    )
    

# # Need 32 x 32 x 256 input/output shapes


# Have 32 x 32 x 8 input shape,
# Need 32 x 32 x 16 input/output shapes
image, label = testset[0]
image = image.unsqueeze(0)
conv_output = model.conv(image)
bn_output = model.bn(conv_output)
relu_expected = model.relu(bn_output)
# layer0_expected = model.features[0](bn_output).squeeze().detach().numpy().transpose((1,2,0))
print("Image shape: ", image.shape)
print("Conv  shape: ", conv_output.shape)
print("BN    shape: ", bn_output.shape)
print("ReLU  shape: ", relu_expected.shape)

layer0_dw_expected = model.features[0].dw_conv(relu_expected)
print("layer0_dw_expected shape: ", layer0_dw_expected.shape)

# print("depthwise circuit_input")
# weights = model.features[1].dw_conv[0].weight.detach().numpy()
# bias = torch.zeros(weights.shape[0]).numpy()

# weights = weights.transpose((2, 3, 1, 0)).squeeze()
# quantized_weights = weights * 10**EXPONENT
# print("----------> layer1 --->> quantized_weights  shape: ", np.array(quantized_weights).shape)

expected = layer0_dw_expected.squeeze().detach().numpy().transpose((1, 2, 0))
print("transposed layer0_dw_expected shape: ", expected.shape)

# conv_input = TypedPaddedDepthwiseConv(32, 32, 16, 16, 3, 1, 10**EXPONENT, circuit_layer_input.out(), quantized_weights.round(), bias)
# weights = model.features[0].dw_conv[0].weight.detach().numpy()
weights = model.features[0].dw_conv[0].weight
print("WEIGHTS SHAPE: ", weights.shape)
bias = torch.zeros(16).numpy()

# weights = weights.transpose((2, 3, 1, 0)).squeeze()
weights = torch.permute(weights, (2, 3, 1, 0)).squeeze()
print("TRANSPOSED WEIGHTS SHAPE: ", weights.shape)

# (3 x 3 x 8) -> (3 x 3 x 16)
padded_weights = F.pad(weights, (0, 8), "constant", 0)
print("PADDED WEIGHTS SHAPE: ", padded_weights.shape)
quantized_weights = padded_weights * 10**EXPONENT

dw_input = circuit_layer_input.out()
print("DW INPUT SHAPE: ", np.array(dw_input).shape)
conv_input = TypedPaddedDepthwiseConv(32, 32, 16, 16, 3, 1, 10**EXPONENT, dw_input, quantized_weights.round(), bias)
# conv_input = TypedPaddedDepthwiseConv(32, 32, 8, 8, 3, 1, 10**EXPONENT, quantized_input, quantized_weights.round(), bias)
    
test_output = dequantize(conv_input.out, 1, 8)
print("Dequantized shape: ", np.array(test_output).shape)
print("conv_input.out shape: ", np.array(conv_input.out).shape)


# circuit_depth_input = circuit.circuit_layer_inputs(1, layer0_expected, circuit_layer_input.out())

# layer0_pw_expected = model.features[0].pw_conv(layer0_dw_expected)
# print("layer0_pw_expected shape: ", layer0_pw_expected.shape)
assert(np.allclose(expected, test_output, atol=1e-5))

Image shape:  torch.Size([1, 3, 32, 32])
Conv  shape:  torch.Size([1, 8, 32, 32])
BN    shape:  torch.Size([1, 8, 32, 32])
ReLU  shape:  torch.Size([1, 8, 32, 32])
layer0_dw_expected shape:  torch.Size([1, 8, 30, 30])
transposed layer0_dw_expected shape:  (30, 30, 8)
WEIGHTS SHAPE:  torch.Size([8, 1, 3, 3])
TRANSPOSED WEIGHTS SHAPE:  torch.Size([3, 3, 8])
PADDED WEIGHTS SHAPE:  torch.Size([3, 3, 16])
DW INPUT SHAPE:  (32, 32, 16)
Dequantized shape:  (30, 30, 8)
conv_input.out shape:  (32, 32, 16)


AssertionError: 

In [172]:
# np.array(circuit_layer_input.out())[0][0][0]

3146114871436651

In [167]:
expected[0][0]

array([-0.18779713,  0.8504402 , -1.3140223 ,  3.357099  , -1.304945  ,
        4.7155004 ,  3.8864698 ,  0.88151836], dtype=float32)

In [152]:
os.chdir("circuits")
!./padded/padded_cpp/padded ../it_worked.json head.wtns
# # !npx snarkjs groth16 prove ./origDepthwiseConv2d/circuit_final.zkey head.wtns proof.json public_test.json
os.chdir("../")

print("OK")

START
PRIME - 1 28948022309329048855892746252171976963363056481941647379679742748393362948096
------------------------------------------------
------------------------------------------------
dw_conv done
depth batch norm done
pw_conv done
point batch norm done
END
OK


In [None]:
circuit_depth_input = circuit.circuit_layer_inputs(0, pytorch_input, quantized_input)

In [6]:
model.features[0]

ZkSeparableConv2d(
  (dw_conv): Sequential(
    (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), groups=8, bias=False)
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pw_conv): Sequential(
    (0): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)