In [1]:
import torch
import torch.nn as nn
import import_ipynb
from utils import generate_parameter_vector, scaled_tanh

In [2]:
x = torch.rand([1, 2, 4, 4])
x

tensor([[[[0.3769, 0.4346, 0.1880, 0.9186],
          [0.7414, 0.3024, 0.8848, 0.4306],
          [0.3657, 0.5118, 0.7352, 0.2907],
          [0.8290, 0.3280, 0.5288, 0.2291]],

         [[0.1441, 0.9879, 0.5977, 0.2393],
          [0.4624, 0.5259, 0.4718, 0.6364],
          [0.0130, 0.8579, 0.7141, 0.9142],
          [0.8387, 0.5812, 0.9526, 0.0079]]]])

In [3]:
x = x.view(1,2,4//2,2,4//2,2)
x

tensor([[[[[[0.3769, 0.4346],
            [0.1880, 0.9186]],

           [[0.7414, 0.3024],
            [0.8848, 0.4306]]],


          [[[0.3657, 0.5118],
            [0.7352, 0.2907]],

           [[0.8290, 0.3280],
            [0.5288, 0.2291]]]],



         [[[[0.1441, 0.9879],
            [0.5977, 0.2393]],

           [[0.4624, 0.5259],
            [0.4718, 0.6364]]],


          [[[0.0130, 0.8579],
            [0.7141, 0.9142]],

           [[0.8387, 0.5812],
            [0.9526, 0.0079]]]]]])

In [4]:
x = torch.sum(x,dim=(3,5))
x

tensor([[[[1.8552, 2.4220],
          [2.0345, 1.7838]],

         [[2.1203, 1.9452],
          [2.2908, 2.5887]]]])

In [5]:
x.shape

torch.Size([1, 2, 2, 2])

In [6]:
class CUSTOM_POOLING(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.weights = nn.Parameter(torch.randn(out_channels))
        self.bias = nn.Parameter(torch.randn(out_channels))
    
    def forward(self, x):
        B, C, H, W = x.size()
        assert H%2==0 and W%2==0

        x = x.view(B, C, H//2, 2, W//2, 2)
        x = torch.sum(x, dim=(3,5))
        x = x * self.weights.view(1,-1,1,1) + self.bias.view(1,-1,1,1)

        return x

In [7]:
x = torch.randn([1, 3, 4, 4])
y = CUSTOM_POOLING(1,3)
y.forward(x)

tensor([[[[ 0.2035,  2.1864],
          [ 0.5668,  1.7187]],

         [[ 0.9794,  1.4307],
          [ 2.9042,  2.5593]],

         [[ 0.7545,  4.1771],
          [-1.8419, -1.2931]]]], grad_fn=<AddBackward0>)

In [8]:
table1 = [
    [0,1,2],
    [1,2,3],
    [2,3,4],
    [3,4,5],
    [4,5,0],
    [5,0,1],
    [0,1,2,3],
    [1,2,3,4],
    [2,3,4,5],
    [3,4,5,0],
    [4,5,0,1],
    [5,0,1,2],
    [0,1,3,4],
    [1,2,4,5],
    [0,2,3,5],
    [0,1,2,3,4,5]
]

In [9]:
class SPARSE_CONV(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, connection_scheme):
        super().__init__()
        self.connection_scheme = connection_scheme
        self.convs = nn.ModuleList(
            [nn.Conv2d(len(S2_channels), 1, kernel_size) for S2_channels in connection_scheme]
        )
    
    def forward(self, x):
        output = []
        for i, S2_channels in enumerate(self.connection_scheme):
            output.append(
                self.convs[i](x[:, S2_channels, :, :])
            )
        
        return torch.cat(output, dim=1)

In [10]:
x = torch.randn(1, 6, 14, 14)
m = SPARSE_CONV(6, 16, 5, connection_scheme=table1)
y = m(x)
y.shape

torch.Size([1, 16, 10, 10])

In [11]:
temp = torch.randn(2, 3, 1, 1)
temp.squeeze()

tensor([[ 0.0328,  0.2197, -1.2524],
        [ 0.5077,  0.5836,  1.6626]])

In [12]:
torch.stack([generate_parameter_vector(0.5, 4) for _ in range(4)])

tensor([[ 1., -1., -1.,  1.],
        [-1.,  1., -1.,  1.],
        [ 1., -1.,  1., -1.],
        [-1.,  1.,  1., -1.]])

In [37]:
class RBF(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.parameter_vector = nn.Parameter(
            torch.stack(
                [generate_parameter_vector(0.2, length=in_channels) for _ in range(out_channels)]
            )
        )
    
    def forward(self, x):
        # print('x ', x.shape)
        # print('param: ', self.parameter_vector.shape)
        # print('x-param: ', (x-self.parameter_vector).shape)
        # print('sum(x-param:) ', torch.sum(x-self.parameter_vector, dim=1), torch.sum((x-self.parameter_vector),dim=1, keepdim=True).shape)
        output = torch.sum(torch.square(x-self.parameter_vector), dim=1, keepdim=True)
        return output

In [38]:
x = torch.rand(4)
m = RBF(4, 3)
y = m(x)
print(y, y.shape)



start


x  torch.Size([4])
param:  torch.Size([3, 4])
x-param:  torch.Size([3, 4])


end


sum(x-param:)  tensor([-0.1607, -2.1607, -0.1607], grad_fn=<SumBackward1>) torch.Size([3, 1])
tensor([[2.2292],
        [1.5074],
        [5.2928]], grad_fn=<SumBackward1>) torch.Size([3, 1])


In [39]:
class MOCK_LENET5(nn.Module):
    def __init__(self):
        super().__init__()
        self.C1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
        self.S2 = CUSTOM_POOLING(in_channels=6, out_channels=6)
        self.C3 = SPARSE_CONV(in_channels=6, out_channels=16, kernel_size=5, connection_scheme=table1)
        self.S4 = CUSTOM_POOLING(in_channels=16, out_channels=16)
        self.C5 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)
        self.F6 = nn.Linear(in_features=120, out_features=84)
        self.RBF = RBF(in_channels=84, out_channels=10)
    
    def forward(self, x):
        x = scaled_tanh(input=self.C1(x))
        x = scaled_tanh(input=self.S2(x))
        x = scaled_tanh(input=self.C3(x))
        x = scaled_tanh(input=self.S4(x))
        x = scaled_tanh(input=self.C5(x))
        x = x.squeeze()
        print(x.shape)
        x = scaled_tanh(input=self.F6(x))
        print(x.shape)
        output = self.RBF(x)

        return output

        
        

In [40]:
input = torch.rand(1,1,32,32)
model = MOCK_LENET5()
output = model(input)

torch.Size([1, 120, 1, 1])
torch.Size([120])
torch.Size([84])


start


x  torch.Size([84])
param:  torch.Size([10, 84])
x-param:  torch.Size([10, 84])


end


sum(x-param:)  tensor([-37.9054, -57.9054, -51.9054, -55.9054, -47.9054, -41.9054, -45.9054,
        -45.9054, -55.9054, -61.9054], grad_fn=<SumBackward1>) torch.Size([10, 1])


In [42]:
output.shape

torch.Size([10, 1])

In [30]:
output

tensor([[177.2170],
        [181.4866],
        [149.2645],
        [164.0868],
        [164.3589],
        [159.2939],
        [157.3618],
        [171.2735],
        [178.8741],
        [164.8284]], grad_fn=<SumBackward1>)