In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json

In [2]:
class SeparableConv2D(nn.Module):
    """Depthwise-separable 3x3 convolution block.

    Stage 1 (``dw_conv``): a 3x3 depthwise convolution (``groups=in_channels``,
    so each input channel gets its own filter) with padding 1, followed by
    BatchNorm and ReLU.
    Stage 2 (``pw_conv``): a 1x1 pointwise convolution that mixes channels
    into ``out_channels``, followed by BatchNorm and ReLU.
    Neither convolution has a bias term.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the pointwise stage.
        stride: stride of the depthwise convolution; the pointwise stage
            always uses stride 1.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Sequential indices are relied on elsewhere (e.g. ``model.pw_conv[0]``),
        # so each stage is always [conv, batchnorm, relu].
        depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=in_channels,
            bias=False,
        )
        self.dw_conv = nn.Sequential(depthwise, nn.BatchNorm2d(in_channels), nn.ReLU(inplace=False))

        pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.pw_conv = nn.Sequential(pointwise, nn.BatchNorm2d(out_channels), nn.ReLU(inplace=False))

    def forward(self, x):
        """Run the depthwise stage, then the pointwise stage."""
        return self.pw_conv(self.dw_conv(x))

In [4]:
# Seed the RNG so the random input and the randomly initialised weights
# below are reproducible under Restart & Run All.
torch.manual_seed(0)

# Smoke test: push one (N, C, H, W) = (1, 8, 32, 32) batch through the block.
# Named `sample_input` to avoid shadowing the builtin `input`.
sample_input = torch.randn((1, 8, 32, 32))

model = SeparableConv2D(8, 8)
expected = model(sample_input)
# 3x3 depthwise with padding 1 and stride 1, then 1x1 pointwise: spatial size is preserved.
assert expected.shape == (1, 8, 32, 32)

In [15]:
def PointwiseConv2d(nRows, nCols, nChannels, nFilters, strides, input, weights, bias):
    """Reference 1x1 (pointwise) convolution written as explicit loops.

    Args:
        nRows, nCols: spatial size of ``input``.
        nChannels: number of input channels summed over.
        nFilters: number of output channels.
        strides: stride applied to both spatial axes; must be >= 1.
        input: array indexed ``input[row, col, channel]`` (H x W x C layout).
        weights: array indexed ``weights[channel, filter]`` (C x K layout).
        bias: per-filter bias indexed ``bias[filter]``, or ``None`` for no bias.

    Returns:
        float64 array of shape ``(outRows, outCols, nFilters)``, where
        ``outRows = (nRows - 1) // strides + 1`` (likewise for columns).

    Raises:
        ValueError: if ``strides`` is less than 1.
    """
    if strides < 1:
        raise ValueError(f"strides must be >= 1, got {strides}")

    kernelSize = 1
    outRows = (nRows - kernelSize) // strides + 1
    outCols = (nCols - kernelSize) // strides + 1
    out = np.zeros((outRows, outCols, nFilters))

    # Start every output pixel at its filter's bias (broadcast over rows/cols),
    # then accumulate the per-channel products on top.
    if bias is not None:
        out[...] = bias

    # `f` rather than `filter` to avoid shadowing the builtin.
    for f in range(nFilters):
        for row in range(outRows):
            for col in range(outCols):
                for k in range(nChannels):
                    out[row, col, f] += input[row * strides, col * strides, k] * weights[k, f]

    return out

# Check the loop-based PointwiseConv2d against PyTorch's 1x1 conv in the model.
conv = model.pw_conv[0]
weights = conv.weight.detach().numpy()  # PyTorch layout: (out_channels, in_channels, 1, 1)
print(f"{weights.shape=}")
n_filters, n_channels = weights.shape[:2]

# The layer is built with bias=False; fall back to an all-zero bias in that case.
if conv.bias is not None:
    bias = conv.bias.detach().numpy()
else:
    bias = np.zeros(n_filters, dtype=weights.dtype)
stride = conv.stride[0]  # assumes equal row/column stride

input = torch.randn((1, n_channels, 32, 32))
with torch.no_grad():
    expected = conv(input).numpy()

# Take the single NCHW sample and convert it to H x W x C. A 1x1 kernel needs
# no padding, so `padded` is simply the HWC view of `input`.
padded = input[0].numpy().transpose((1, 2, 0))
# (out, in, 1, 1) -> (in, out), i.e. weights[channel, filter] as PointwiseConv2d expects.
weights = weights[:, :, 0, 0].T

n_rows, n_cols = padded.shape[:2]
actual = PointwiseConv2d(n_rows, n_cols, n_channels, n_filters, stride, padded, weights, bias)

expected = expected[0].transpose((1, 2, 0))
np.testing.assert_allclose(actual, expected, rtol=1e-5, atol=1e-5)

weights.shape=(8, 8, 1, 1)
