# Manual MobileNet implementation

In [24]:
from models.mobilenet import MyMobileNet
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pytorch_lightning as pl
import torchvision
import torchvision.transforms.v2 as transforms
import json
import os

In [157]:
# Circom interprets field elements as signed numbers: [0, p/2] are the
# non-negative values and [(p/2)+1, p-1] represent the negative values.
# CIRCOM_PRIME = 21888242871839275222246405745257275088548364400416034343698204186575808495617
CIRCOM_PRIME = 28948022309329048855892746252171976963363056481941647379679742748393362948097
MAX_POSITIVE = CIRCOM_PRIME // 2
MAX_NEGATIVE = MAX_POSITIVE + 1  # Field element that decodes to the most negative number
CIRCOM_NEGATIVE_1 = CIRCOM_PRIME - 1  # Field element that decodes to -1
EXPONENT = 15  # Fixed-point scale used throughout: values are multiplied by 10**EXPONENT


def from_circom(x):
    """Decode a field element into a signed Python int.

    Values above MAX_POSITIVE are treated as negatives (x - CIRCOM_PRIME);
    everything else is returned unchanged. Non-int inputs (e.g. numeric
    strings) are converted with int() first.
    """
    if type(x) != int:
        x = int(x)
    if x > MAX_POSITIVE:
        return x - CIRCOM_PRIME
    return x


def to_circom(x):
    """Encode a signed integer as a field element in [0, CIRCOM_PRIME - 1]."""
    return x % CIRCOM_PRIME


def to_circom_input(array: np.array):
    """Round an array-like and encode it as nested lists of field-element strings.

    Each element is rounded to the nearest integer (ties to even, as
    ndarray.round does), reduced modulo CIRCOM_PRIME and rendered as a decimal
    string, preserving the nesting of the input.
    """
    # np.array is a factory function, not a type, so the previous
    # `type(array) != np.array` test was always true. np.asarray accepts both
    # ndarrays and plain sequences without that check.
    values = np.asarray(array).tolist()

    def encode(value):
        if isinstance(value, list):
            return [encode(item) for item in value]
        # Work on Python ints instead of int64 so large fixed-point values
        # cannot overflow before the modular reduction.
        return str(to_circom(int(round(value))))

    return encode(values)


def DepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    """Integer depthwise convolution with quotient/remainder split.

    For every output position and channel the kernel window of `input` is
    multiplied element-wise with that channel's kernel, the bias is added, and
    the sum is split into a floor quotient by `n` and a remainder.

    `input` and `weights` are indexed with tuples (`input[r, c, ch]`,
    `weights[x, y, ch]`), so they must support that (e.g. numpy arrays).

    Returns:
        (out, remainder): nested lists of shape [outRows][outCols][nFilters].
        `out` holds int quotients, `remainder` holds remainders as strings.
        Only the first `nChannels` entries of the last axis are written; the
        rest stay at the int 0.
    """
    assert nFilters % nChannels == 0
    outRows = (nRows - kernelSize) // strides + 1
    outCols = (nCols - kernelSize) // strides + 1

    out = [[[0] * nFilters for _ in range(outCols)] for _ in range(outRows)]
    remainder = [[[0] * nFilters for _ in range(outCols)] for _ in range(outRows)]

    # Kernel offsets in the same (x-major) order the sums are accumulated in.
    taps = [(dx, dy) for dx in range(kernelSize) for dy in range(kernelSize)]

    for r in range(outRows):
        top = r * strides
        for c in range(outCols):
            left = c * strides
            for ch in range(nChannels):
                acc = out[r][c][ch]
                for dx, dy in taps:
                    acc += int(input[top + dx, left + dy, ch]) * int(weights[dx, dy, ch])
                acc += int(bias[ch])
                remainder[r][c][ch] = str(int(acc % n))
                out[r][c][ch] = int(acc // n)

    return out, remainder

def PointwiseConv2d(nRows, nCols, nChannels, nFilters, strides, n, input, weights, bias):
    """Integer 1x1 (pointwise) convolution with quotient/remainder split.

    Args:
        nRows, nCols: spatial size of `input`.
        nChannels: number of input channels summed per output value.
        nFilters: number of output channels.
        strides: spatial step between sampled input positions.
        n: divisor used to rescale the accumulated sum (fixed-point scale).
        input: [row][col][channel] values; nested lists or arrays both work.
        weights: [channel][filter] values.
        bias: per-filter bias.

    Returns:
        (out, str_out, remainder), each shaped [outRows][outCols][nFilters]:
        `out` holds the signed int floor quotients, `str_out` the same values
        reduced modulo CIRCOM_PRIME as strings, and `remainder` the remainders
        as strings.
    """
    # Bind the field prime locally (as Conv2DInt does) rather than depending on
    # a notebook-global `p` that only exists after a later cell has run.
    p = CIRCOM_PRIME
    kernelSize = 1
    outRows = (nRows - kernelSize)//strides + 1
    outCols = (nCols - kernelSize)//strides + 1
    out = [[[0 for _ in range(nFilters)] for _ in range(outCols)] for _ in range(outRows)]
    str_out = [[[0 for _ in range(nFilters)] for _ in range(outCols)] for _ in range(outRows)]
    remainder = [[[None for _ in range(nFilters)] for _ in range(outCols)] for _ in range(outRows)]
    for row in range(outRows):
        for col in range(outCols):
            for filter in range(nFilters):
                acc = 0
                for k in range(nChannels):
                    # Chained indexing works for numpy arrays and for the nested
                    # lists that DepthwiseConv returns; tuple indexing
                    # (input[r, c, k]) fails on plain lists.
                    acc += int(input[row*strides][col*strides][k]) * int(weights[k][filter])

                acc += int(bias[filter])
                remainder[row][col][filter] = str(int(acc % n))
                out[row][col][filter] = int(acc // n)
                str_out[row][col][filter] = str(out[row][col][filter] % p)

    return out, str_out, remainder

def SeparableConvImpl(nRows, nCols, nChannels, nDepthFilters, nPointFilters, kernelSize, strides, n, input, depthWeights, pointWeights, depthBias, pointBias):
    """Integer depthwise-separable convolution: depthwise stage, then a 1x1 pointwise stage.

    `strides` applies to the depthwise stage only. The pointwise stage walks
    the depthwise output with stride 1, matching the reference
    `SeparableConv2d` module whose pointwise Conv2d uses stride=1.

    Returns:
        (depth_out, depth_remainder, point_out, point_str_out, point_remainder)

    NOTE(review): this unpacks three values from PointwiseConv2d, so it is only
    compatible with the variant of that function returning
    (out, str_out, remainder).
    """
    outRows = (nRows - kernelSize)//strides + 1
    outCols = (nCols - kernelSize)//strides + 1

    depth_out, depth_remainder = DepthwiseConv(nRows, nCols, nChannels, nDepthFilters, kernelSize, strides, n, input, depthWeights, depthBias)
    # The depthwise output is already outRows x outCols; passing `strides`
    # again here would subsample it a second time.
    point_out, point_str_out, point_remainder = PointwiseConv2d(outRows, outCols, nChannels, nPointFilters, 1, n, depth_out, pointWeights, pointBias)
    return depth_out, depth_remainder, point_out, point_str_out, point_remainder

In [26]:
class DatasetWrapper(torch.utils.data.Dataset):
    """Dataset view that applies an optional transform to each sample.

    Wraps any indexable `subset` yielding `(x, y)` pairs. When `transform` is
    truthy it is applied to `x` only; `y` is passed through untouched.
    """

    def __init__(self, subset, transform=None):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

In [27]:
# Preprocessing shared by every split: convert to a tensor, then normalise each
# of the three channels with fixed mean/std constants.
# NOTE(review): presumably the per-channel CIFAR-10 statistics — confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# The test split gets the transform at dataset level.
testset = torchvision.datasets.CIFAR10(
    root='./CIFAR10', train=False, download=True, transform=transform)

# The train split is loaded without a transform; it is applied after the
# train/validation split through DatasetWrapper below.
trainset = torchvision.datasets.CIFAR10(
    root='./CIFAR10', train=True, download=True)

# split the train set into train/validation
train_set_size = int(len(trainset) * 0.8)
valid_set_size = len(trainset) - train_set_size

# Fixed generator seed so the 80/20 split is reproducible across runs.
seed = torch.Generator().manual_seed(42)
trainset, validset = torch.utils.data.random_split(trainset, [train_set_size, valid_set_size], generator=seed)

trainset = DatasetWrapper(trainset, transform)
validset = DatasetWrapper(validset, transform)

# Create train dataloader
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=512, shuffle=True, num_workers=24)

# Create validation dataloader
validloader = torch.utils.data.DataLoader(
    validset, batch_size=512, shuffle=False, num_workers=24)

# Create test dataloader
testloader = torch.utils.data.DataLoader(
    testset, batch_size=512, shuffle=False, num_workers=24)

# Class names indexed by label id; used to print predictions.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [28]:
# Load the trained checkpoint into the model and run one validation image as a
# smoke test.
MODEL_WEIGHTS_PATH = './checkpoints/model_small_100epochs.pth'

model = MyMobileNet(trainloader, num_classes=10, alpha=0.25, max_epochs=100)
checkpoint = torch.load(MODEL_WEIGHTS_PATH)
# model.load_state_dict(checkpoint['state_dict'])
# The weights are stored under the 'net' key of this checkpoint.
model.load_state_dict(checkpoint['net'])
model.eval()

image, label = validset[0]
image = image.unsqueeze(0)  # add a leading batch dimension
# NOTE(review): runs without torch.no_grad(), so autograd still tracks this
# forward pass; harmless for a single image.
logits = model(image)
pred_idx = logits.argmax()

print(f"Predicted {classes[pred_idx]} - idx: {pred_idx}")

Predicted horse - idx: 7


In [6]:
def Conv2DInt(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    """Integer 2D convolution producing Circom-ready witness values.

    `input` is indexed as [row][col][channel], `weights` as
    [x][y][channel][filter] and `bias` as [filter]. For each output position
    and filter the window sum plus bias is split into a floor quotient by `n`
    and a remainder.

    Returns:
        (Input, Weights, Bias, out, remainder), where Input/Weights/Bias are
        the arguments reduced modulo CIRCOM_PRIME as strings, `out` holds the
        quotients reduced modulo CIRCOM_PRIME as strings, and `remainder`
        holds the remainders as strings.
    """
    p = CIRCOM_PRIME
    outRows = (nRows - kernelSize) // strides + 1
    outCols = (nCols - kernelSize) // strides + 1

    def field(value):
        # Decimal string of the value as a field element.
        return str(int(value) % p)

    Input = [[[field(input[i][j][k]) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    Weights = [
        [[[field(weights[x][y][k][m]) for m in range(nFilters)] for k in range(nChannels)] for y in range(kernelSize)]
        for x in range(kernelSize)
    ]
    Bias = [field(bias[m]) for m in range(nFilters)]

    out = []
    remainder = []
    for i in range(outRows):
        out_row, rem_row = [], []
        for j in range(outCols):
            out_cell, rem_cell = [], []
            for m in range(nFilters):
                acc = 0
                for k in range(nChannels):
                    for x in range(kernelSize):
                        for y in range(kernelSize):
                            acc += int(input[i * strides + x][j * strides + y][k]) * int(weights[x][y][k][m])
                acc += int(bias[m])
                rem_cell.append(str(int(acc) % n))
                out_cell.append(str(int(acc) // n % p))
            out_row.append(out_cell)
            rem_row.append(rem_cell)
        out.append(out_row)
        remainder.append(rem_row)
    return Input, Weights, Bias, out, remainder

In [7]:
def BatchNormalizationInt(nRows, nCols, nChannels, n, X_in, a_in, b_in):
    """Integer batch normalisation, out = (X * a + b) // n per channel.

    `X_in` is indexed as [row][col][channel]; `a_in` and `b_in` are the
    per-channel scale and offset.

    Returns:
        (X, A, B, out, remainder), where X/A/B are the arguments reduced modulo
        CIRCOM_PRIME as strings, `out` holds the signed int floor quotients and
        `remainder` holds the remainders modulo `n` as strings.
    """
    p = CIRCOM_PRIME
    channels = range(nChannels)

    X = [[[str(int(X_in[i][j][k]) % p) for k in channels] for j in range(nCols)] for i in range(nRows)]
    A = [str(int(scale) % p) for scale in (a_in[k] for k in channels)]
    B = [str(int(offset) % p) for offset in (b_in[k] for k in channels)]

    out, remainder = [], []
    for i in range(nRows):
        out_row, rem_row = [], []
        for j in range(nCols):
            out_cell, rem_cell = [], []
            for k in channels:
                scaled = int(int(X_in[i][j][k]) * int(a_in[k]) + int(b_in[k]))
                rem_cell.append(str(int(scaled) % n))
                out_cell.append(int(scaled // n))
            out_row.append(out_cell)
            rem_row.append(rem_cell)
        out.append(out_row)
        remainder.append(rem_row)
    return X, A, B, out, remainder

In [8]:
# Check BatchNormalizationInt against model.bn on the first test image.
gamma = model.bn.weight
beta = model.bn.bias
mean = model.bn.running_mean
var = model.bn.running_var
eps = model.bn.eps

# Fold batch norm into an affine map y = a * x + b.
a = (gamma/(var+eps)**.5).detach()
b = (beta-gamma*mean/(var+eps)**.5).detach().tolist()

image, label = testset[0]
image = image.unsqueeze(0)
out = model.conv(image)
expected = model.bn(out)
# out = out.squeeze()
print(f"{out.shape=}")

# Move channels last to match the [row][col][channel] indexing of the integer code.
expected = torch.permute(expected.squeeze(), (1, 2, 0))

print(f"{expected.shape=}")

# x and a are scaled by 10**EXPONENT, so their product is at 10**(2*EXPONENT);
# b is scaled to that same level before being added.
quantized_in = torch.permute(out.squeeze(), (1, 2, 0)) * 10**EXPONENT
quantized_a = (a * 10**EXPONENT).tolist()
quantized_b = [bi * 10**(2*EXPONENT) for bi in b]
print(out.shape)

X, A, B, actual, remainder = BatchNormalizationInt(32, 32, 8, 10**EXPONENT, quantized_in, quantized_a, quantized_b)
# Rescale the integer result to floats. The comprehension variable `out` is
# local to the comprehension and does not overwrite the tensor `out` above.
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in actual]

assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

out.shape=torch.Size([1, 8, 32, 32])
expected.shape=torch.Size([32, 32, 8])
torch.Size([1, 8, 32, 32])


# Testing head layer

In [9]:
# CONVOLUTION LAYER
# Reorder the conv weights with transpose(2, 3, 1, 0) so they can be indexed as
# weights[x][y][channel][filter] by Conv2DInt.
weights = model.conv.weight.detach().numpy().transpose(2, 3, 1, 0)
image, label = testset[0]

expected = model.conv(image).detach().numpy()
# Conv2DInt always adds a bias, so a zero bias is supplied here.
bias = torch.zeros(weights.shape[3]).numpy()

# padded = pad(image, 1).transpose(1,2,0)
# Zero-pad one pixel on every side and move channels last.
padded = F.pad(image, (1,1,1,1), "constant", 0).numpy()
padded = padded.transpose(1,2,0)

quantized_image = (padded * 10**EXPONENT).round()
quantized_weights = (weights * 10**EXPONENT).round() # .transpose(0, 3, 1, 0) # [nFilters, nChannels, H, W] -> 

circuit_in, circuit_conv_weights, circuit_conv_bias, circuit_conv_out, circuit_conv_remainder = Conv2DInt(34, 34, 3, 8, 3, 1, 10**EXPONENT, quantized_image, quantized_weights, bias)
# output, remainder = Conv2DInt(34, 34, 3, 8, 3, 1, 10**EXPONENT, quantized_image, quantized_weights, bias)
# test_output = output / 10**(EXPONENT)
# circuit_conv_out holds field-element strings; decode to signed ints and rescale.
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in circuit_conv_out]

expected = expected.transpose((1, 2, 0))

assert(np.allclose(test_output, expected, atol=1e-6))

# BATCH NORM CONSTANTS
gamma = model.bn.weight
beta = model.bn.bias
mean = model.bn.running_mean
var = model.bn.running_var
eps = model.bn.eps

# Fold batch norm into an affine map y = a * x + b.
a = (gamma/(var+eps)**.5).detach()
b = (beta-gamma*mean/(var+eps)**.5).detach().tolist()

# b is scaled by 10**(2*EXPONENT) to match the scale of the a * x product.
quantized_a = (a * 10**EXPONENT).tolist()
quantized_b = [bi * 10**(2*EXPONENT) for bi in b]

# BATCH NORM USING PYTORCH OUTPUT
image, label = testset[0]
image = image.unsqueeze(0)
out = model.conv(image)
expected = model.bn(out)

expected = torch.permute(expected.squeeze(), (1, 2, 0))

quantized_in = torch.permute(out.squeeze(), (1, 2, 0)) * 10**EXPONENT

X, A, B, actual, remainder = BatchNormalizationInt(32, 32, 8, 10**EXPONENT, quantized_in, quantized_a, quantized_b)
test_output = [[[from_circom(int(out)) / 10**EXPONENT for out in vec] for vec in matrix] for matrix in actual]
assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

# BATCH NORM USING CIRCUIT CONV OUTPUT
# Decode the field-element strings from Conv2DInt back to signed ints so they
# can be fed to the integer batch norm.
quantized_in = [[[from_circom(int(out)) for out in vec] for vec in matrix] for matrix in circuit_conv_out]

_, circuit_bn_a, circuit_bn_b, circuit_bn_out, circuit_bn_remainder = BatchNormalizationInt(32, 32, 8, 10**EXPONENT, quantized_in, quantized_a, quantized_b)

# Validate the circuit-path result. The previous version rebuilt test_output
# from `actual` (the PyTorch-input batch norm above), so this assertion only
# repeated the earlier check and never exercised circuit_bn_out.
# NOTE(review): errors from the integer convolution now propagate into this
# comparison; if it fails by a small margin, revisit the atol.
test_output = [[[int(from_circom(int(out))) / 10**EXPONENT for out in vec] for vec in matrix] for matrix in circuit_bn_out]

assert(torch.allclose(torch.Tensor(test_output), expected, atol=1e-6))

# RELU USING CIRCUIT OUTPUT
# circuit_bn_out holds signed ints. The first pass maps them to field elements
# (negatives end up above MAX_POSITIVE); the second pass replaces anything that
# is not below MAX_POSITIVE with 0 and stringifies the rest.
# NOTE(review): the zeroed entries are the int 0 while the others are strings,
# so relu_out mixes types in the JSON.
relu_in = [[[to_circom(bn_out) if bn_out < MAX_POSITIVE else 0 for bn_out in vec] for vec in matrix] for matrix in circuit_bn_out]
relu_out = [[[str(bn_out) if bn_out < MAX_POSITIVE else 0 for bn_out in vec] for vec in matrix] for matrix in relu_in]

# Write every witness input for the head circuit to one JSON file.
input_json_path = "head_input1.json"
with open(input_json_path, "w") as input_file:
    json.dump({"in": circuit_in,
               "conv2d_weights": circuit_conv_weights,
               "conv2d_bias": circuit_conv_bias,
               "conv2d_out": circuit_conv_out,
               "conv2d_remainder": circuit_conv_remainder,
               
               "bn_a": circuit_bn_a,
               "bn_b": circuit_bn_b,
               "bn_out": circuit_bn_out,
               "bn_remainder": circuit_bn_remainder,
               
               "relu_out": relu_out,
               }, input_file)

# Run the compiled witness generator and the Groth16 prover from the circuits
# directory, then return to the notebook directory.
os.chdir("circuits")
!./head/head_cpp/head ../head_input1.json head.wtns
!npx snarkjs groth16 prove head/circuit_final.zkey head.wtns proof.json public_test.json
os.chdir("../")

print("ok")

something
before conv2d
after conv2d
after bn
after relu
end
ok


# Testing padded convolution on a heavily padded input (to try to fold the circuit using Nova)

In [186]:
class SeparableConv2d(nn.Module):
    '''Separable convolution: a 3x3 depthwise conv followed by a 1x1 pointwise conv.

    Both stages are unpadded and bias-free. `stride` applies to the depthwise
    stage only; the pointwise stage always uses stride 1. Each stage is kept
    in an nn.Sequential so the conv layers stay reachable as `dw_conv[0]` and
    `pw_conv[0]`.
    '''
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.dw_conv = nn.Sequential(
            # groups=in_channels gives one 3x3 kernel per input channel.
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=stride, padding=0, groups=in_channels, bias=False),
            # nn.BatchNorm2d(in_channels),
            # nn.ReLU(inplace=False),
        )
        self.pw_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            # nn.BatchNorm2d(out_channels),
            # nn.ReLU(inplace=False),
        )

    def forward(self, x):
        '''Apply the depthwise stage, then the pointwise stage.'''
        # Without forward(), calling the module itself raises; only the
        # sub-modules could be called directly.
        return self.pw_conv(self.dw_conv(x))
        
# Random test input and a freshly initialised separable-conv block.
# NOTE(review): `input` shadows the Python builtin, and neither the input nor
# the weights are seeded, so values differ between runs.
input = torch.randn((1, 3, 5, 5))
test_model = SeparableConv2d(3, 6)

In [48]:
# Compare the integer DepthwiseConv with test_model.dw_conv on the random input.
weights = test_model.dw_conv[0].weight.squeeze().detach().numpy()
bias = torch.zeros(weights.shape[0]).numpy()

expected = test_model.dw_conv(input).detach().numpy()
print(expected.shape)

# Zero-pad one pixel on every side, then move channels last.
# NOTE(review): SeparableConv2d builds dw_conv with padding=0, so for a 5x5
# input the PyTorch result is 3x3 while DepthwiseConv on this 7x7 padded input
# yields 5x5. The recorded output (1, 3, 5, 5) suggests this cell last ran
# against a padding=1 version of the class — confirm before relying on it.
padded = F.pad(input, (1,1,1,1), "constant", 0)
padded = padded.squeeze().numpy().transpose((1, 2, 0))
weights = weights.transpose((1, 2, 0))

quantized_image = padded * 10**EXPONENT
quantized_weights = weights * 10**EXPONENT

actual, rem = DepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, quantized_image.round(), quantized_weights.round(), bias)

# `actual` is the quotient after dividing by 10**EXPONENT, so it is still
# scaled by 10**EXPONENT; scale the float reference to match.
expected = expected.squeeze().transpose((1, 2, 0))
expected = expected * 10**EXPONENT

assert(np.allclose(expected, actual, atol=0.00001))

(1, 3, 5, 5)


In [188]:
# New random input and block.
# NOTE(review): this rebinds `model`, which earlier cells use for the MobileNet
# (model.conv / model.bn); those cells need the MobileNet reloaded afterwards.
input = torch.randn((1, 3, 5, 5))
model = SeparableConv2d(3, 6)

In [69]:
# Re-run the integer depthwise convolution on the quantized image, weights and
# bias left in the kernel by an earlier cell.
actual, rem = DepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, quantized_image.round(), quantized_weights.round(), bias)

[[[-84775161859527321744602510080,
   -280863115324426945947120868368,
   -316509859287301671233788074460],
  [-412066814475013670629964755790,
   -307907089920012261542884320079,
   288733480109590434131766261999],
  [88725553719746761339810928348,
   57251477502350644469532169751,
   167073417747014310631386775640],
  [-451547140994921540780591667954,
   1001315881174683093580772914720,
   -234907986425664705310764373280],
  [-368065005502237197896044401205,
   -49827584055244381693338364130,
   57013630737515115414246974069],
  [0, 0, 0],
  [0, 0, 0]],
 [[225228896280826144742018725083,
   937129546496503660185568529414,
   -424060379269759190852526866365],
  [1000974140683911103665188311425,
   436335591241852765388689964311,
   -210043991425453207740685880708],
  [-1325614803459234583698167899456,
   -66160312754055519347597073647,
   497778292584125615331388252343],
  [-2810563469212315852893212650,
   -203075372530997502527143471525,
   -51156603705155919392846418305],
  [730813

In [171]:
def DepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    """Integer depthwise convolution with quotient/remainder split.

    For every output position and channel the kernel window of `input` is
    multiplied element-wise with that channel's kernel, the bias is added, and
    the sum is split into a floor quotient by `n` and a remainder.

    `input` and `weights` are indexed with tuples (`input[r, c, ch]`,
    `weights[x, y, ch]`), so they must support that (e.g. numpy arrays).

    Returns:
        (out, remainder): nested lists of shape [outRows][outCols][nFilters].
        `out` holds int quotients, `remainder` holds remainders as strings.
        Only the first `nChannels` entries of the last axis are written; the
        rest stay at the int 0.
    """
    assert nFilters % nChannels == 0
    outRows = (nRows - kernelSize) // strides + 1
    outCols = (nCols - kernelSize) // strides + 1

    out = [[[0] * nFilters for _ in range(outCols)] for _ in range(outRows)]
    remainder = [[[0] * nFilters for _ in range(outCols)] for _ in range(outRows)]

    # Kernel offsets in the same (x-major) order the sums are accumulated in.
    taps = [(dx, dy) for dx in range(kernelSize) for dy in range(kernelSize)]

    for r in range(outRows):
        top = r * strides
        for c in range(outCols):
            left = c * strides
            for ch in range(nChannels):
                acc = out[r][c][ch]
                for dx, dy in taps:
                    acc += int(input[top + dx, left + dy, ch]) * int(weights[dx, dy, ch])
                acc += int(bias[ch])
                remainder[r][c][ch] = str(int(acc % n))
                out[r][c][ch] = int(acc // n)

    return out, remainder
    
# Check DepthwiseConv against test_model.dw_conv, then export the witness
# inputs for the depthwise circuit to backbone_input.json.
weights = test_model.dw_conv[0].weight.squeeze().detach().numpy()
bias = torch.zeros(weights.shape[0]).numpy()

expected = test_model.dw_conv(input).detach().numpy()
print(expected.shape)

# Zero-pad one pixel on every side, then move channels last.
padded = F.pad(input, (1,1,1,1), "constant", 0)
padded = padded.squeeze().numpy().transpose((1, 2, 0))
weights = weights.transpose((1, 2, 0))

quantized_image = padded * 10**EXPONENT
quantized_weights = weights * 10**EXPONENT

actual, rem = DepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, quantized_image.round(), quantized_weights.round(), bias)

# Rescale the integer quotients back to floats for the comparison.
test_output = np.array([[[int(out) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in actual])
# test_output = test_output[1:6, 1:6, :]

expected = expected.squeeze().transpose((1, 2, 0))
# expected = expected * 10**EXPONENT
assert(np.allclose(expected, test_output, atol=0.00001))

# NOTE(review): circuit_in is built here but the JSON below uses `Input`.
circuit_in = quantized_image.round().astype(int).astype(str).tolist()
# circuit_weights = quantized_weights.round().astype(int).astype(str).tolist()
circuit_bias = bias.round().astype(int).astype(str).tolist()

# `p` becomes a notebook global here; several functions in this notebook read
# it without defining it themselves.
p = CIRCOM_PRIME
Input = [[[str(int(quantized_image[i][j][k].round()) % p) for k in range(quantized_image.shape[2])] for j in range(quantized_image.shape[1])] for i in range(quantized_image.shape[0])]
Weights = [[[str(int(quantized_weights[i][j][k].round()) % p) for k in range(quantized_weights.shape[2])] for j in range(quantized_weights.shape[1])] for i in range(quantized_weights.shape[0])]
# Out = [[[str(int(actual[i][j][k]) % p) for k in range(3)] for j in range(7)] for i in range(7)]

test_output = np.array([[[int(out) for out in asdf] for asdf in asdfasdf] for asdfasdf in actual])
# test_output = test_output[1:6, 1:6, :]
# NOTE(review): assigning to `Out` replaces IPython's output-history dict of
# the same name for the rest of the session.
Out = [[[str(int(test_output[i][j][k]) % p) for k in range(3)] for j in range(5)] for i in range(5)]
Rem = [[[str(rem[i][j][k]) for k in range(3)] for j in range(5)] for i in range(5)]

input_json_path = "backbone_input.json"
with open(input_json_path, "w") as input_file:
    json.dump({
               "in": Input,
               "weights": Weights,
               "bias": circuit_bias,
               "remainder": Rem,
               "out": Out,
              },
              input_file)


# os.chdir("circuits")
# !./origDepthwiseConv2d/origDepthwiseConv2d_cpp/origDepthwiseConv2d ../backbone_input.json head.wtns
# !npx snarkjs groth16 prove ./origDepthwiseConv2d/circuit_final.zkey head.wtns proof.json public_test.json
# os.chdir("../")

print("OK")

(1, 3, 5, 5)
Signal not found
origDepthwiseConv2d: calcwit.cpp:60: uint Circom_CalcWit::getInputSignalHashPosition(u64): Assertion `false' failed.
OK


# Padded Convolution test

# 

In [198]:
# 7x7 random input for the padded-convolution experiments, plus a fresh block.
# NOTE(review): rebinds the globals `input` and `model`; unseeded, so values
# differ between runs.
input = torch.randn((1, 3, 7, 7))
model = SeparableConv2d(3, 6)

In [284]:
def PointwiseConv2d(nRows, nCols, nChannels, nFilters, strides, n, input, weights, bias):
    """Integer 1x1 (pointwise) convolution with quotient/remainder split.

    `input` is indexed as input[row][col][channel]; `weights` is indexed with
    a tuple (`weights[channel, filter]`), so it must support that (e.g. a
    numpy array). The result buffers are allocated at the full
    nRows x nCols size; only the first outRows x outCols positions are written.

    Returns:
        (out, remainder): `out` holds the int floor quotients by `n` (0 where
        not written), `remainder` the remainders as strings (None where not
        written).
    """
    kernelSize = 1
    outRows = (nRows - kernelSize) // strides + 1
    outCols = (nCols - kernelSize) // strides + 1

    out = [[[0] * nFilters for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[None] * nFilters for _ in range(nCols)] for _ in range(nRows)]

    for r in range(outRows):
        for c in range(outCols):
            for f in range(nFilters):
                total = out[r][c][f]
                for k in range(nChannels):
                    total += int(input[r * strides][c * strides][k]) * int(weights[k, f])
                total += int(bias[f])
                remainder[r][c][f] = str(int(total % n))
                out[r][c][f] = int(total // n)

    return out, remainder

In [384]:
def PaddedDepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    """Integer depthwise convolution that keeps the input's spatial size.

    The valid convolution result for output position (row, col) is stored at
    (row + 1, col + 1) of an nRows x nCols buffer. Positions never written
    keep 0 (quotient) / "0" (remainder), so the result can be fed straight
    back in as the next layer's input.

    Args:
        nRows, nCols: spatial size of `input` and of the returned buffers.
        nChannels: number of channels convolved.
        nFilters: size of the last axis of the buffers; must be a multiple of
            nChannels.
        kernelSize, strides: kernel extent and step of the convolution.
        n: divisor used to rescale the accumulated sum (fixed-point scale).
        input: [row][col][channel] values; nested lists or arrays both work.
        weights: [x][y][channel] kernel values.
        bias: per-channel bias.

    Returns:
        (out, remainder): `out` holds int floor quotients, `remainder` holds
        remainders as strings.
    """
    assert(nFilters % nChannels == 0)
    outRows = (nRows - kernelSize)//strides + 1
    outCols = (nCols - kernelSize)//strides + 1

    # The Input / Weights / Bias / out_str locals of the earlier version were
    # never returned, yet they read a notebook-global `p` and required numpy
    # inputs. They are dropped so the function is self-contained.
    out = [[[0 for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]

    # Offset of the results inside the buffer.
    # NOTE(review): 1 is the border width for a 3x3 kernel with stride 1;
    # other kernel sizes would presumably need a different offset — confirm.
    pad = 1

    for row in range(outRows):
        for col in range(outCols):
            for channel in range(nChannels):
                acc = 0
                for x in range(kernelSize):
                    for y in range(kernelSize):
                        acc += int(input[row*strides+x][col*strides+y][channel]) * int(weights[x][y][channel])

                acc += int(bias[channel])
                remainder[row+pad][col+pad][channel] = str(int(acc % n))
                out[row+pad][col+pad][channel] = int(acc // n)

    return out, remainder
    
# Check PaddedDepthwiseConv against model.dw_conv for one and for two stacked
# applications, then export the second application's witness inputs.
# NOTE(review): this cell unpacks two return values, so it only works with the
# PaddedDepthwiseConv definition that returns (out, remainder).
weights = model.dw_conv[0].weight.squeeze().detach().numpy()
bias = torch.zeros(weights.shape[0]).numpy()

print("Input shape: ", input.shape)
expected = model.dw_conv(input).detach()
print("Expected shape: ", expected.shape)

# NOTE(review): `padded` is computed but not used below; the quantized image
# is built from the unpadded input.
padded = F.pad(input, (1,1,1,1), "constant", 0)
padded = padded.squeeze().numpy().transpose((1, 2, 0))
weights = weights.transpose((1, 2, 0))

quantized_image = input.squeeze().numpy().transpose((1,2,0)) * 10**EXPONENT
# quantized_image = padded * 10**EXPONENT
quantized_weights = weights * 10**EXPONENT

actual, rem = PaddedDepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, quantized_image.round(), quantized_weights.round(), bias)

# Rescale to floats and keep the 5x5 interior, which holds the valid results.
test_output = np.array([[[int(out) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in actual])
test_output = test_output[1:6, 1:6, :]

# expected = expected.squeeze().transpose((1, 2, 0))
# expected = F.pad(expected.squeeze(), (1,1,1,1)).numpy().transpose((1,2,0))
expected = expected.squeeze().numpy().transpose((1,2,0))
print("Expected shape: ", expected.shape)
print("test_output shape: ", test_output.shape)

# expected = expected * 10**EXPONENT
assert(np.allclose(expected, test_output, atol=0.00001))
print("--------------------------------")

# Second experiment: apply the depthwise conv twice. The 7x7 buffer from the
# first call is fed back in unchanged; only its 3x3 centre is compared.
print("Input shape: ", input.shape)
expected = model.dw_conv(input).detach()
print("expected shape: ", expected.shape)
expected = model.dw_conv(expected).detach()
expected = expected.squeeze().numpy().transpose((1,2,0))
print("expected shape: ", expected.shape)
dw_input, rem = PaddedDepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, quantized_image.round(), quantized_weights.round(), bias)
print("1st shape: ", np.array(dw_input).shape)
actual, rem = PaddedDepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, dw_input, quantized_weights.round(), bias)
print("2nd shape: ", np.array(actual).shape)
test_output = np.array([[[int(out) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in actual])
test_output = test_output[2:5, 2:5, :]


# expected = F.pad(expected.squeeze(), (2,2,2,2)).numpy().transpose((1,2,0))
print("actual shape: ", np.array(actual).shape)
print("expected shape: ", np.array(expected).shape)
assert(np.allclose(expected, test_output, atol=0.00001))

# Export the second application: the circuit input is the first call's output
# buffer, and out/remainder come from the second call.
p = CIRCOM_PRIME
quantized_image = np.array(dw_input)
Input = [[[str(int(dw_input[i][j][k]) % p) for k in range(quantized_image.shape[2])] for j in range(quantized_image.shape[1])] for i in range(quantized_image.shape[0])]
Weights = [[[str(int(quantized_weights[i][j][k].round()) % p) for k in range(quantized_weights.shape[2])] for j in range(quantized_weights.shape[1])] for i in range(quantized_weights.shape[0])]
# Out = [[[str(int(actual[i][j][k]) % p) for k in range(3)] for j in range(7)] for i in range(7)]

test_output = np.array([[[int(out) for out in asdf] for asdf in asdfasdf] for asdfasdf in actual])
print("test_output shape: ", np.array(test_output).shape)
# # test_output = test_output[1:6, 1:6, :]
# NOTE(review): assigning to `Out` replaces IPython's output-history dict of
# the same name for the rest of the session.
Out = [[[str(int(test_output[i][j][k]) % p) for k in range(test_output.shape[2])] for j in range(test_output.shape[1])] for i in range(test_output.shape[0])]
Rem = [[[str(rem[i][j][k]) for k in range(test_output.shape[2])] for j in range(test_output.shape[1])] for i in range(test_output.shape[0])]
print("OUT SHAPE: ", np.array(Out).shape)
print("REM SHAPE: ", np.array(Rem).shape)

# NOTE(review): circuit_bias is not defined in this cell; it must still be in
# the kernel from an earlier cell.
input_json_path = "padded_input.json"
with open(input_json_path, "w") as input_file:
    json.dump({
               "in": Input,
               "dw_conv_weights": Weights,
               "dw_conv_bias": circuit_bias,
               "dw_conv_remainder": Rem,
               "dw_conv_out": Out,
              },
              input_file)


# Run the compiled witness generator from the circuits directory.
os.chdir("circuits")
!./padded/padded_cpp/padded ../padded_input.json head.wtns
# # !npx snarkjs groth16 prove ./origDepthwiseConv2d/circuit_final.zkey head.wtns proof.json public_test.json
os.chdir("../")

# print("OK")

# print("OK")

Input shape:  torch.Size([1, 3, 7, 7])
Expected shape:  torch.Size([1, 3, 5, 5])
Expected shape:  (5, 5, 3)
test_output shape:  (5, 5, 3)
--------------------------------
Input shape:  torch.Size([1, 3, 7, 7])
expected shape:  torch.Size([1, 3, 5, 5])
expected shape:  (3, 3, 3)
1st shape:  (7, 7, 3)
2nd shape:  (7, 7, 3)
actual shape:  (7, 7, 3)
expected shape:  (3, 3, 3)
test_output shape:  (7, 7, 3)
OUT SHAPE:  (7, 7, 3)
REM SHAPE:  (7, 7, 3)
START
PRIME - 1 28948022309329048855892746252171976963363056481941647379679742748393362948096
------------------------------------------------
elemSum.out 52783729333862158060632417821
out[ 1 ][ 1 ][ 0 ]:  52783729333862
remainder[ 1 ][ 1 ][ 0 ]:  158060632417821
result:  52783729333862158060632417821
elemSum.out 381816001850244179695128192072
out[ 1 ][ 1 ][ 1 ]:  381816001850244
remainder[ 1 ][ 1 ][ 1 ]:  179695128192072
result:  381816001850244179695128192072
elemSum.out 2894802230932904885589274625217197696336305648178338365167966158605864163

In [385]:
def PaddedDepthwiseConv(nRows, nCols, nChannels, nFilters, kernelSize, strides, n, input, weights, bias):
    """Integer depthwise convolution that keeps the input's spatial size and
    also returns Circom-ready witness values.

    The valid convolution result for output position (row, col) is stored at
    (row + 1, col + 1) of an nRows x nCols buffer; positions never written
    keep 0 / "0".

    Args:
        nRows, nCols: spatial size of `input` and of the returned buffers.
        nChannels: number of channels convolved.
        nFilters: size of the last axis of the buffers; must be a multiple of
            nChannels.
        kernelSize, strides: kernel extent and step of the convolution.
        n: divisor used to rescale the accumulated sum (fixed-point scale).
        input: [row][col][channel] values.
        weights: array with a `.shape`, indexed as weights[x, y, channel].
        bias: per-channel bias.

    Returns:
        (Input, Weights, Bias, out_str, out, remainder): Input/Weights/Bias
        are the arguments reduced modulo CIRCOM_PRIME as strings, `out` holds
        the signed int floor quotients, `out_str` the same values reduced
        modulo CIRCOM_PRIME as strings, and `remainder` the remainders as
        strings.
    """
    assert(nFilters % nChannels == 0)
    # Bind the field prime locally (as Conv2DInt does) rather than depending on
    # a notebook-global `p` that only exists after another cell has run.
    p = CIRCOM_PRIME
    outRows = (nRows - kernelSize)//strides + 1
    outCols = (nCols - kernelSize)//strides + 1

    # out = np.zeros((outRows, outCols, nFilters))
    Input = [[[str(int(input[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    Weights = [[[str(int(weights[i][j][k].round()) % p) for k in range(weights.shape[2])] for j in range(weights.shape[1])] for i in range(weights.shape[0])]
    out = [[[0 for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    # Export the bias exactly as the loop below uses it (int(...)) and as a
    # field element, so a negative or fractional bias cannot make the witness
    # disagree with the computed sums.
    Bias = [str(int(b) % p) for b in bias]
    out_str = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]

    for row in range(outRows):
        for col in range(outCols):
            for channel in range(nChannels):
                for x in range(kernelSize):
                    for y in range(kernelSize):
                        out[row+1][col+1][channel] += int(input[row*strides+x][col*strides+y][channel]) * int(weights[x, y, channel])

                out[row+1][col+1][channel] += int(bias[channel])
                remainder[row+1][col+1][channel] = str(int(out[row+1][col+1][channel] % n))
                out[row+1][col+1][channel] = int(out[row+1][col+1][channel] // n)
                out_str[row+1][col+1][channel] = str(out[row+1][col+1][channel] % p)

    return Input, Weights, Bias, out_str, out, remainder

In [397]:
def PointwiseConv2d(nRows, nCols, nChannels, nFilters, strides, n, input, weights, bias):
    """Integer 1x1 (pointwise) convolution that also returns Circom-ready
    witness values.

    Args:
        nRows, nCols: spatial size of `input` and of the returned buffers.
        nChannels: number of input channels summed per output value.
        nFilters: number of output channels.
        strides: spatial step between sampled input positions.
        n: divisor used to rescale the accumulated sum (fixed-point scale).
        input: [row][col][channel] values.
        weights: 2-D array with a `.shape`, indexed as weights[channel, filter].
        bias: per-filter bias.

    Returns:
        (Input, Weights, Bias, out_str, out, remainder): Input/Weights/Bias
        are the arguments reduced modulo CIRCOM_PRIME as strings, `out` holds
        the signed int floor quotients, `out_str` the same values reduced
        modulo CIRCOM_PRIME as strings, and `remainder` the remainders as
        strings. Positions never written keep 0 / "0".
    """
    # Bind the field prime locally (as Conv2DInt does) rather than depending on
    # a notebook-global `p` that only exists after another cell has run.
    p = CIRCOM_PRIME
    kernelSize = 1
    outRows = (nRows - kernelSize)//strides + 1
    outCols = (nCols - kernelSize)//strides + 1
    # out = [[[0 for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    # remainder = [[[None for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]

    Input = [[[str(int(input[i][j][k]) % p) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    # Weights = [[[str(int(weights[i][j][k].round()) % p) for k in range(weights.shape[2])] for j in range(weights.shape[1])] for i in range(weights.shape[0])]
    Weights = [[str(int(weights[i][j].round()) % p)for j in range(weights.shape[1])] for i in range(weights.shape[0])]
    out = [[[0 for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    # Export the bias exactly as the loop below uses it (int(...)) and as a
    # field element, so a negative or fractional bias cannot make the witness
    # disagree with the computed sums.
    Bias = [str(int(b) % p) for b in bias]
    out_str = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]
    remainder = [[[str(0) for _ in range(nFilters)] for _ in range(nCols)] for _ in range(nRows)]

    for row in range(outRows):
        for col in range(outCols):
            for filter in range(nFilters):
                for k in range(nChannels):
                    out[row][col][filter] += int(input[row*strides][col*strides][k]) * int(weights[k, filter])

                out[row][col][filter] += int(bias[filter])
                remainder[row][col][filter] = str(int(out[row][col][filter] % n))
                out[row][col][filter] = int(out[row][col][filter] // n)
                out_str[row][col][filter] = str(out[row][col][filter] % p)

    # return out, remainder
    return Input, Weights, Bias, out_str, out, remainder

In [399]:
# Check the integer depthwise + pointwise pipeline against model.dw_conv and
# model.pw_conv on the 7x7 random input.
# NOTE(review): this cell unpacks six return values, so it needs the six-value
# PaddedDepthwiseConv / PointwiseConv2d definitions to be the active ones.
depth_weights = model.dw_conv[0].weight.squeeze().detach().numpy()
depth_bias = torch.zeros(depth_weights.shape[0]).numpy()

print("Input shape: ", input.shape)
expected = model.dw_conv(input).detach()
print("Expected shape: ", expected.shape)

# Move the leading axis last so the kernel is indexed as [x, y, channel].
depth_weights = depth_weights.transpose((1, 2, 0))

quantized_image = input.squeeze().numpy().transpose((1,2,0)) * 10**EXPONENT
# quantized_image = padded * 10**EXPONENT
quantized_depth_weights = depth_weights * 10**EXPONENT

circuit_in, circuit_depth_weights, circuit_depth_bias, circuit_depth_out, depth_out, circuit_depth_remainder = PaddedDepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, quantized_image.round(), quantized_depth_weights.round(), depth_bias)

# Rescale to floats and keep the 5x5 interior, which holds the valid results.
test_output = np.array([[[int(out) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in depth_out])
test_output = test_output[1:6, 1:6, :]

expected = expected.squeeze().numpy().transpose((1,2,0))

assert(np.allclose(expected, test_output, atol=0.00001))

# point_weights = model.pw_conv[0].weight.squeeze().detach().numpy()
point_weights = model.pw_conv[0].weight.detach().numpy()
print("point weights shape: ", point_weights.shape)
point_bias = torch.zeros(point_weights.shape[0]).numpy()

print("Input shape: ", input.shape)
depth_expected = model.dw_conv(input)
point_expected = model.pw_conv(depth_expected)
print("Depth Expected shape: ", depth_expected.shape)
# print("Point Expected shape: ", point_expected.shape)

point_expected = point_expected.squeeze().detach().numpy().transpose((1,2,0))
print("Point Expected shape: ", point_expected.shape)

# Reorder and squeeze the 1x1 kernel into a 2-D [channel, filter] matrix
# (the recorded shapes go from (6, 3, 1, 1) to (3, 6)).
point_weights = point_weights.transpose((2, 3, 1, 0)).squeeze()
quantized_point_weights = point_weights * 10**EXPONENT
print("point weights shape: ", quantized_point_weights.shape)

# circuit_in, circuit_depth_weights, circuit_depth_bias, circuit_depth_out, depth_out, circuit_depth_remainder = PaddedDepthwiseConv(7, 7, 3, 3, 3, 1, 10**EXPONENT, quantized_image.round(), quantized_depth_weights.round(), depth_bias)
# The pointwise stage consumes the full 7x7 depthwise buffer, border included.
point_input, circuit_point_weights, circuit_point_bias, circuit_point_out, point_out, circuit_point_remainder = PointwiseConv2d(7, 7, 3, 6, 1, 10**EXPONENT, depth_out, quantized_point_weights.round(), point_bias)

test_output = np.array([[[int(out) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in point_out])
test_output = test_output[1:6, 1:6, :]

assert(np.allclose(point_expected, test_output, atol=0.00001))

Input shape:  torch.Size([1, 3, 7, 7])
Expected shape:  torch.Size([1, 3, 5, 5])
point weights shape:  (6, 3, 1, 1)
Input shape:  torch.Size([1, 3, 7, 7])
Depth Expected shape:  torch.Size([1, 3, 5, 5])
Point Expected shape:  (5, 5, 6)
point weights shape:  (3, 6)


In [388]:
p = CIRCOM_PRIME

input_json_path = "padded_input.json"
with open(input_json_path, "w") as input_file:
    json.dump({
               "in": circuit_in,
               "dw_conv_weights": circuit_depth_weights,
               "dw_conv_bias": circuit_depth_bias,
               "dw_conv_remainder": circuit_depth_remainder,
               "dw_conv_out": circuit_depth_out,
        
               "pw_conv_weights": circuit_depth_weights,
               "pw_conv_bias": circuit_depth_bias,
               "dw_conv_remainder": circuit_depth_remainder,
               "dw_conv_out": circuit_depth_out,
              },
              input_file)


os.chdir("circuits")
!./padded/padded_cpp/padded ../padded_input.json head.wtns
# # !npx snarkjs groth16 prove ./origDepthwiseConv2d/circuit_final.zkey head.wtns proof.json public_test.json
os.chdir("../")

# print("OK")

START
PRIME - 1 28948022309329048855892746252171976963363056481941647379679742748393362948096
------------------------------------------------
------------------------------------------------
dw_conv done
END


In [383]:
# Inspect one raw (still fixed-point) depthwise output value; test_output is
# the same data rescaled to floats.
test_output = np.array([[[int(out) / 10**EXPONENT for out in asdf] for asdf in asdfasdf] for asdfasdf in depth_out])
# test_output = test_output[1:6, 1:6, :]
depth_out[1][1][0]

-614123337976230

# Padded convolution 2 iterations test

In [362]:
# Side-by-side shapes of the PyTorch references and the integer buffers.
depth_expected.shape, point_expected.shape, input.shape, np.array(depth_out).shape, np.array(actual).shape

(torch.Size([1, 3, 5, 5]),
 (5, 5, 6),
 torch.Size([1, 3, 7, 7]),
 (7, 7, 3),
 (7, 7, 6))

In [381]:
# `Out` here is the list of field-element strings assigned in an earlier cell,
# which shadows IPython's output-history dict of the same name.
Out[1][1][0]

'28948022309329048855892746252171976963363056481941647379679742134270024971867'

In [365]:
# Reference pointwise output (all filters) at the top-left spatial position.
point_expected[0][0]

array([-0.23325375,  0.42944416,  0.88247865,  0.3639508 , -0.36691314,
       -0.5982673 ], dtype=float32)