In [1]:
import torch
import torch.nn as nn
import import_ipynb
from utils import generate_parameter_vector, scaled_tanh

In [2]:
class CUSTOM_POOLING(nn.Module):
    """Trainable 2x2 sum-pooling (LeNet-5 style subsampling).

    Every non-overlapping 2x2 window of each channel is summed, then the sum is
    multiplied by a trainable per-channel coefficient and shifted by a
    trainable per-channel bias.

    Args:
        in_channels: accepted for signature symmetry with the other layers;
            not used.
        out_channels: number of per-channel coefficient/bias pairs. The input's
            channel count must match it for the per-channel broadcast to work.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Drawn in this order (coefficient first, then bias) from the global RNG.
        self.weights = nn.Parameter(torch.randn(out_channels))
        self.bias = nn.Parameter(torch.randn(out_channels))

    def forward(self, x):
        """Pool ``x`` of shape (B, C, H, W) down to (B, C, H/2, W/2).

        H and W must both be even (checked with an assert).
        """
        batch, channels, height, width = x.size()
        assert height % 2 == 0 and width % 2 == 0

        # Split each spatial axis into (blocks, 2) so each 2x2 window gets its
        # own pair of axes (3 and 5), then reduce over that pair.
        windows = x.view(batch, channels, height // 2, 2, width // 2, 2)
        pooled = windows.sum(dim=(3, 5))

        # Reshape the per-channel vectors to (1, C, 1, 1) for broadcasting.
        scale = self.weights[None, :, None, None]
        shift = self.bias[None, :, None, None]
        return pooled * scale + shift

In [3]:
# C3 <- S2 connection table: entry i lists the S2 feature maps that C3 output
# map i reads. Mirrors Table I of LeCun et al. (1998), "Gradient-Based Learning
# Applied to Document Recognition". Order matters: it fixes both the C3 output
# channel order and the input-channel order inside each sparse convolution.
table1 = [
    # C3 maps 0-5: every cyclically contiguous triple of S2 maps
    [0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 0], [5, 0, 1],
    # C3 maps 6-11: every cyclically contiguous quadruple
    [0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 0], [4, 5, 0, 1], [5, 0, 1, 2],
    # C3 maps 12-14: three non-contiguous quadruples
    [0, 1, 3, 4], [1, 2, 4, 5], [0, 2, 3, 5],
    # C3 map 15: all six S2 maps
    [0, 1, 2, 3, 4, 5],
]

In [4]:
class SPARSE_CONV(nn.Module):
    """Convolution with a sparse input-channel connection pattern.

    Output channel ``i`` is produced by its own single-output ``nn.Conv2d`` that
    sees only the input channels listed in ``connection_scheme[i]``. The result
    has ``len(connection_scheme)`` channels.

    Args:
        in_channels: accepted for signature symmetry; not used.
        out_channels: accepted for signature symmetry; not used (the output
            channel count comes from ``connection_scheme``).
        kernel_size: kernel size passed to every per-output ``nn.Conv2d``.
        connection_scheme: sequence of index lists, one per output channel.
    """

    def __init__(self, in_channels, out_channels, kernel_size, connection_scheme):
        super().__init__()
        self.connection_scheme = connection_scheme
        per_output = []
        for inputs in connection_scheme:
            # One kernel stack over exactly the selected input channels.
            per_output.append(nn.Conv2d(len(inputs), 1, kernel_size))
        self.convs = nn.ModuleList(per_output)

    def forward(self, x):
        """Apply each sparse conv to its channel subset and stack along dim 1."""
        feature_maps = [
            self.convs[idx](x[:, inputs, :, :])
            for idx, inputs in enumerate(self.connection_scheme)
        ]
        return torch.cat(feature_maps, dim=1)

In [5]:
class RBF(nn.Module):
    """Euclidean RBF output layer.

    Holds ``out_channels`` fixed prototype vectors of length ``in_channels`` and
    returns, for every input vector, its squared Euclidean distance to each
    prototype.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # One prototype row per output unit, built by utils.generate_parameter_vector
        # (its exact values depend on that helper, which is not visible here).
        # requires_grad=False excludes the prototypes from gradient computation.
        self.parameter_vector = nn.Parameter(
            torch.stack(
                [generate_parameter_vector(0.2, length=in_channels) for _ in range(out_channels)]
            ),
            requires_grad=False
        )

    def forward(self, x):
        """Squared distances from ``x`` to every prototype.

        Args:
            x: tensor of shape (..., in_channels). Any number of leading
                dimensions is accepted, including none (a single vector).

        Returns:
            Tensor of shape (..., out_channels) whose entry [..., j] is
            sum_k (x[..., k] - parameter_vector[j, k]) ** 2.
        """
        # (..., 1, in) - (out, in) broadcasts to (..., out, in), so no explicit
        # expand over the batch is needed. Working on the trailing axes instead
        # of x.shape[0] keeps an unbatched 1-D input from being misread as a
        # batch of in_channels scalars.
        diff = x.unsqueeze(-2) - self.parameter_vector
        return diff.square().sum(dim=-1)

In [6]:
class MOCK_LENET5(nn.Module):
    """LeNet-5-style network: C1 -> S2 -> C3 -> S4 -> C5 -> F6 -> RBF.

    Every stage except the final RBF layer is followed by ``scaled_tanh``.
    Input is a batch of single-channel images, shape (B, 1, H, W). With the
    5x5 convolutions and 2x2 pooling below, a 32x32 input shrinks to 1x1 at C5
    (32 -> 28 -> 14 -> 10 -> 5 -> 1). F6's 120 input features require that.
    """

    def __init__(self):
        super().__init__()
        self.C1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
        self.S2 = CUSTOM_POOLING(in_channels=6, out_channels=6)
        # Each C3 output map reads only the S2 maps listed for it in table1.
        self.C3 = SPARSE_CONV(in_channels=6, out_channels=16, kernel_size=5, connection_scheme=table1)
        self.S4 = CUSTOM_POOLING(in_channels=16, out_channels=16)
        self.C5 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)
        self.F6 = nn.Linear(in_features=120, out_features=84)
        self.RBF = RBF(in_channels=84, out_channels=10)
        # NOTE(review): `dense` is never used in forward(). It is kept so that
        # existing state_dicts/checkpoints (which contain its keys) still load.
        self.dense = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the RBF layer's output, shape (B, 10), for x of shape (B, 1, 32, 32)."""
        x = scaled_tanh(input=self.C1(x))
        x = scaled_tanh(input=self.S2(x))
        x = scaled_tanh(input=self.C3(x))
        x = scaled_tanh(input=self.S4(x))
        x = scaled_tanh(input=self.C5(x))
        # (B, 120, 1, 1) -> (B, 120). A bare squeeze() would also drop the batch
        # dimension when B == 1, so F6/RBF would see an unbatched vector.
        x = torch.flatten(x, start_dim=1)
        x = scaled_tanh(input=self.F6(x))
        return self.RBF(x)

        
        