In [74]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json

In [75]:
# Prime modulus of the target field; integers exported to JSON are reduced
# modulo this value, so negative numbers wrap to p - |x|. (This value is the
# BN254 scalar-field order, the default prime used by circom.)
p = 21888242871839275222246405745257275088548364400416034343698204186575808495617

class SeparableConv2D(nn.Module):
    '''Depthwise-separable 2D convolution without bias.

    A 3x3 depthwise convolution (one kernel per input channel, padding 1,
    optional stride) followed by a 1x1 pointwise convolution that mixes
    channels into ``out_channels`` outputs.
    '''

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # groups == in_channels: every 3x3 kernel sees exactly one channel.
        self.dw_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_channels,
            bias=False,
        )
        # 1x1 kernel: a per-pixel linear map in_channels -> out_channels.
        self.pw_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )

    def forward(self, x):
        return self.pw_conv(self.dw_conv(x))

In [76]:
# Seed PyTorch so the random input and the model's initial weights -- and hence
# the exported JSON fixture further down -- are identical on every
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# (batch, channels, rows, cols), the layout nn.Conv2d expects.
# NOTE: the name shadows the builtin input(); later cells rely on it.
input = torch.randn((1, 3, 5, 5))
model = SeparableConv2D(3, 6)

In [77]:
def PointwiseConv2dInt(nRows, nCols, nChannels, nFilters, strides, n, input, weights, bias, field_prime=None):
    '''Integer 1x1 convolution that also produces the rescaling remainders.

    For every output pixel (row, col) and filter f it computes, in exact
    Python-int arithmetic,

        acc = sum_c input[row*strides][col*strides][c] * weights[c][f] + bias[f]

    and splits it as acc = n * (acc // n) + (acc % n), using Python floor
    semantics, so 0 <= remainder < n for positive n.

    Args:
        nRows, nCols, nChannels: size of ``input`` (indexed [row][col][channel]).
        nFilters: number of filters to use (``weights`` indexed [channel][filter],
            ``bias`` indexed [filter]).
        strides: stride applied to both spatial axes.
        n: divisor used to rescale the accumulated value.
        input, weights, bias: integer-valued sequences/arrays.
        field_prime: modulus used to map values into the field; defaults to
            the module-level ``p``.

    Returns:
        (Input, Weights, Bias, out, remainder), all nested lists of decimal
        strings: the operands reduced modulo the prime, (acc // n) reduced
        modulo the prime, and acc % n.
    '''
    prime = p if field_prime is None else field_prime
    kernelSize = 1
    outRows = (nRows - kernelSize) // strides + 1
    outCols = (nCols - kernelSize) // strides + 1

    # Convert to Python int *before* reducing. With NumPy 2 (NEP 50), an
    # int64 scalar combined with a Python int beyond int64 (the prime)
    # raises OverflowError. Python-int % also wraps negatives to prime - |x|.
    Input = [[[str(int(input[i][j][k]) % prime) for k in range(nChannels)] for j in range(nCols)] for i in range(nRows)]
    Weights = [[str(int(weights[k][l]) % prime) for l in range(nFilters)] for k in range(nChannels)]
    Bias = [str(int(bias[i]) % prime) for i in range(nFilters)]

    out = [[[None] * nFilters for _ in range(outCols)] for _ in range(outRows)]
    remainder = [[[None] * nFilters for _ in range(outCols)] for _ in range(outRows)]

    for row in range(outRows):
        src_row = row * strides
        for col in range(outCols):
            src_col = col * strides
            for f in range(nFilters):
                # Exact arbitrary-precision dot product: products of two
                # ~1e15 values overflow int64, so stay in Python ints.
                acc = sum(int(input[src_row][src_col][c]) * int(weights[c][f]) for c in range(nChannels))
                acc += int(bias[f])
                remainder[row][col][f] = str(int(acc % n))
                out[row][col][f] = str(acc // n % prime)
    return Input, Weights, Bias, out, remainder
    
def PointwiseConv2d(nRows, nCols, nChannels, nFilters, strides, n, input, weights, bias):
    '''Floating-point 1x1 convolution, the reference for PointwiseConv2dInt.

    result[r, c, f] = (sum_ch input[r*strides, c*strides, ch] * weights[ch, f]
                       + bias[f]) / n

    ``input`` is indexed [row, col, channel] and ``weights`` [channel, filter]
    using NumPy tuple indexing. Returns a float64 array of shape
    (out_rows, out_cols, nFilters).
    '''
    # Kernel size is 1, so the usual (size - kernel) // stride + 1 reduces to:
    out_rows = (nRows - 1) // strides + 1
    out_cols = (nCols - 1) // strides + 1
    result = np.zeros((out_rows, out_cols, nFilters))

    for r in range(out_rows):
        src_r = r * strides
        for c in range(out_cols):
            src_c = c * strides
            for f in range(nFilters):
                # Accumulate directly in the float64 output cell.
                for ch in range(nChannels):
                    result[r, c, f] += input[src_r, src_c, ch] * weights[ch, f]
                result[r, c, f] += bias[f]
                result[r, c, f] /= n
    return result

# Float reference check: the loop-based PointwiseConv2d (n = 1, zero bias)
# must reproduce PyTorch's 1x1 pointwise convolution on the same input.
weights = model.pw_conv.weight.detach().numpy()  # (out_channels, in_channels, 1, 1)
print(f"{weights.shape=}")
bias = torch.zeros(weights.shape[0]).numpy()

expected = model.pw_conv(input).detach().numpy()

# Move to (rows, cols, channels) and (channels, filters) layouts. Index the
# singleton axes explicitly: squeeze() would also drop a channel/filter axis
# of size 1.
image_hwc = input[0].numpy().transpose((1, 2, 0))
print(image_hwc.shape)
pw_weights = weights[:, :, 0, 0].T

n_rows, n_cols, n_channels = image_hwc.shape
n_filters = pw_weights.shape[1]
actual = PointwiseConv2d(n_rows, n_cols, n_channels, n_filters, 1, 1, image_hwc, pw_weights, bias)

expected_hwc = expected[0].transpose((1, 2, 0))

# Unlike a bare assert, this is not stripped under -O and it rejects shape mismatches.
np.testing.assert_allclose(actual, expected_hwc, rtol=1e-5, atol=1e-5)

weights.shape=(6, 3, 1, 1)
(5, 5, 3)


In [78]:
# Fixed-point check: scale image and weights by 10**EXPONENT and round to
# integers. PointwiseConv2d with n = 10**EXPONENT removes one factor of the
# scale; dividing once more recovers real values to compare with PyTorch.
EXPONENT = 15

weights = model.pw_conv.weight.detach().numpy()  # (out_channels, in_channels, 1, 1)
# NOTE(review): inside PointwiseConv2d the bias is added to products at scale
# 10**(2*EXPONENT), so a non-zero bias would need that scaling. It is all
# zeros here, matching the bias-free pw_conv.
bias = torch.zeros(weights.shape[0]).numpy()

expected = model.pw_conv(input).detach().numpy()

pw_weights = weights[:, :, 0, 0].T  # (in_channels, out_channels)

# Upcast to float64 before scaling. float32 has a 24-bit mantissa, so x * 1e15
# cannot land on the nearest integer. float64 holds every integer below
# 2**53 (~9e15) exactly.
quantized_image = input[0].numpy().astype(np.float64).transpose((1, 2, 0)) * 10**EXPONENT
quantized_weights = pw_weights.astype(np.float64) * 10**EXPONENT
print(f"{quantized_image.shape=}")
print(f"{quantized_weights.shape=}")

n_rows, n_cols, n_channels = quantized_image.shape
n_filters = quantized_weights.shape[1]
actual = PointwiseConv2d(n_rows, n_cols, n_channels, n_filters, 1, 10**EXPONENT,
                         quantized_image.round(), quantized_weights.round(), bias) / 10**EXPONENT
print(f"{actual.shape=}")
print(f"{expected.shape=}")

expected = expected[0].transpose((1, 2, 0))
print(f"{expected.shape=}")

np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)

quantized_image.shape=(5, 5, 3)
quantized_weights.shape=(3, 6)
actual.shape=(5, 5, 6)
expected.shape=(1, 6, 5, 5)
expected.shape=(5, 5, 6)


In [79]:
# Export integer witness data for the pointwise-conv circuit. The operands are
# at scale 10**EXPONENT; the circuit divides the accumulated value by n and
# checks the remainder.
n_rows, n_cols, n_channels = quantized_image.shape
# Derive the filter count from the weights so every filter is exported.
n_filters = quantized_weights.shape[1]

# Use int64 explicitly: astype(int) is 32-bit on Windows with NumPy 1.x and
# would silently wrap values around 10**15.
q_input, q_weights, str_bias, str_actual, rem = PointwiseConv2dInt(
    n_rows, n_cols, n_channels, n_filters, 1, 10**EXPONENT,
    quantized_image.round().astype(np.int64),
    quantized_weights.round().astype(np.int64),
    bias.astype(np.int64),
)

input_json_path = "pointwiseConv2D_input.json"
with open(input_json_path, "w") as input_file:
    json.dump({"in": q_input,
               "weights": q_weights,
               "remainder": rem,
               "out": str_actual,
               "bias": str_bias,
              },
              input_file)