In [1]:
import torch.nn as nn
import torch.nn.functional as F

In [2]:
import torch

In [3]:
# Inspect the concrete Python type that torch.rand hands back.
type(torch.rand(2, 2))

torch.Tensor

In [11]:
# Seed the global torch RNG so this draw (and the random draws in later
# cells) reproduces on Restart Kernel -> Run All.
torch.manual_seed(0)
x = torch.rand([1, 2, 4, 4])
x

tensor([[[[0.5105, 0.7162, 0.6452, 0.7423],
          [0.6244, 0.7486, 0.3236, 0.6378],
          [0.8262, 0.6350, 0.8745, 0.0091],
          [0.6302, 0.1055, 0.9326, 0.8065]],

         [[0.2670, 0.3333, 0.4485, 0.5862],
          [0.1464, 0.8743, 0.3858, 0.2479],
          [0.5450, 0.2510, 0.4583, 0.8675],
          [0.4275, 0.6116, 0.8135, 0.3809]]]])

In [12]:
# Split each spatial axis into (blocks, 2): (1, 2, 4, 4) -> (1, 2, 2, 2, 2, 2),
# so every 2x2 window is spread over axes 3 and 5.
x = x.view(1, 2, 2, 2, 2, 2)
x

tensor([[[[[[0.5105, 0.7162],
            [0.6452, 0.7423]],

           [[0.6244, 0.7486],
            [0.3236, 0.6378]]],


          [[[0.8262, 0.6350],
            [0.8745, 0.0091]],

           [[0.6302, 0.1055],
            [0.9326, 0.8065]]]],



         [[[[0.2670, 0.3333],
            [0.4485, 0.5862]],

           [[0.1464, 0.8743],
            [0.3858, 0.2479]]],


          [[[0.5450, 0.2510],
            [0.4583, 0.8675]],

           [[0.4275, 0.6116],
            [0.8135, 0.3809]]]]]])

In [13]:
# Collapsing the two within-window axes sums each 2x2 window (a 2x2 sum-pool).
x = x.sum(dim=(3, 5))
x

tensor([[[[2.5997, 2.3489],
          [2.1969, 2.6227]],

         [[1.6211, 1.6684],
          [1.8351, 2.5202]]]])

In [10]:
# Confirm the pooled shape.
x.size()

torch.Size([1, 2, 2, 2])

In [32]:
class CUSTOM_POOLING(nn.Module):
    """Trainable 2x2 sum-pooling with a per-channel scale and bias.

    Every non-overlapping 2x2 window of each channel is summed, then the
    result is multiplied by a trainable per-channel weight and shifted by a
    trainable per-channel bias.

    Args:
        in_channels: Unused; accepted so the constructor mirrors the other
            layers. Pooling is channel-wise, so the input is expected to have
            ``out_channels`` channels (a single weight/bias, i.e.
            ``out_channels == 1``, broadcasts over all channels).
        out_channels: Number of per-channel weights and biases.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.weights = nn.Parameter(torch.randn(out_channels))
        self.bias = nn.Parameter(torch.randn(out_channels))

    def forward(self, x):
        """Pool ``x`` of shape (B, C, H, W) to (B, C, H // 2, W // 2).

        Raises:
            ValueError: if ``x`` is not 4-D (from the shape unpack) or if H or
                W is odd.
        """
        B, C, H, W = x.size()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if H % 2 != 0 or W % 2 != 0:
            raise ValueError(
                f"CUSTOM_POOLING requires even spatial dims, got H={H}, W={W}"
            )

        # Split each spatial axis into (blocks, 2) and sum the two
        # within-block axes -> sum over every non-overlapping 2x2 window.
        x = x.view(B, C, H // 2, 2, W // 2, 2)
        x = torch.sum(x, dim=(3, 5))
        # Per-channel affine; reshape parameters to broadcast over (B, C, H', W').
        x = x * self.weights.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)

        return x

In [33]:
# Pooling is channel-wise: in/out channels must both match the input's 3 channels.
x = torch.randn([1, 3, 4, 4])
y = CUSTOM_POOLING(3, 3)
# Call the module itself (not .forward) so nn.Module hooks run.
y(x)

tensor([[[[ 0.8743,  2.1207],
          [ 1.7699, -2.8809]],

         [[ 7.5807, -4.9952],
          [ 3.1645,  2.0584]],

         [[ 1.2447,  1.1393],
          [ 1.3759,  1.0937]]]], grad_fn=<AddBackward0>)

In [26]:
# LeNet-5 C3 connection table: entry i lists the S2 maps feeding C3 map i.
#   maps 0-5  : 3 cyclically-contiguous inputs
#   maps 6-11 : 4 cyclically-contiguous inputs
#   maps 12-14: 4 non-contiguous inputs
#   map 15    : all 6 inputs
table1 = (
    [[(start + k) % 6 for k in range(3)] for start in range(6)]
    + [[(start + k) % 6 for k in range(4)] for start in range(6)]
    + [[0, 1, 3, 4], [1, 2, 4, 5], [0, 2, 3, 5]]
    + [list(range(6))]
)

In [28]:
class SPARSE_CONV(nn.Module):
    """Convolution with a sparse input-to-output channel connection table.

    Output channel ``i`` is produced by its own single-output ``nn.Conv2d``
    that only sees the input channels listed in ``connection_scheme[i]``.

    Args:
        in_channels: Number of input channels; every index in
            ``connection_scheme`` must lie in ``[0, in_channels)``.
        out_channels: Number of output channels; must equal
            ``len(connection_scheme)``.
        kernel_size: Passed to every per-output ``nn.Conv2d``.
        connection_scheme: Sequence of non-empty sequences of input-channel
            indices, one per output channel. It is copied, so later edits to
            the caller's table cannot desynchronise the module.

    Raises:
        ValueError: if the scheme length does not match ``out_channels``, an
            entry is empty, or an index is out of range.
    """

    def __init__(self, in_channels, out_channels, kernel_size, connection_scheme):
        super().__init__()
        # Copy the table: forward() indexes with it, and it must keep matching
        # the per-entry conv widths built below.
        scheme = [list(channels) for channels in connection_scheme]
        if len(scheme) != out_channels:
            raise ValueError(
                f"connection_scheme has {len(scheme)} entries but out_channels={out_channels}"
            )
        for i, channels in enumerate(scheme):
            if not channels:
                raise ValueError(f"connection_scheme[{i}] is empty")
            bad = [c for c in channels if not 0 <= c < in_channels]
            if bad:
                raise ValueError(
                    f"connection_scheme[{i}] has indices {bad} outside [0, {in_channels})"
                )
        self.connection_scheme = scheme
        self.convs = nn.ModuleList(
            [nn.Conv2d(len(channels), 1, kernel_size) for channels in scheme]
        )

    def forward(self, x):
        """Apply each per-output conv to its channel subset of ``x`` (B, C, H, W).

        Returns the per-output results concatenated along the channel axis,
        giving ``out_channels`` channels.
        """
        return torch.cat(
            [
                conv(x[:, channels, :, :])
                for conv, channels in zip(self.convs, self.connection_scheme)
            ],
            dim=1,
        )

In [29]:
# 6 input maps -> 16 output maps through table1; a 5x5 kernel turns 14x14 into 10x10.
x = torch.randn(1, 6, 14, 14)
m = SPARSE_CONV(in_channels=6, out_channels=16, kernel_size=5, connection_scheme=table1)
y = m(x)
y.shape

torch.Size([1, 16, 10, 10])

In [36]:
# squeeze drops every size-1 axis: (2, 3, 1, 1) -> (2, 3).
temp = torch.randn(2, 3, 1, 1)
torch.squeeze(temp)

tensor([[ 0.5441, -0.2279,  0.4016],
        [ 0.2634,  0.3783, -0.7449]])

In [None]:
class RBF(nn.Module):
    """Euclidean radial-basis-function layer.

    Output unit ``i`` returns ``sum_j (x_j - w_ij) ** 2``, the squared
    distance between the input vector and that unit's trainable parameter
    vector.

    Args:
        in_channels: Length of each input vector.
        out_channels: Number of output units.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # One parameter vector per output unit, shape (out_channels, in_channels).
        self.weights = nn.Parameter(torch.randn(out_channels, in_channels))

    def forward(self, x):
        """Map ``x`` of shape (B, in_channels) to squared distances (B, out_channels)."""
        # (B, 1, in) - (1, out, in) -> (B, out, in), then sum over the input axis.
        diff = x.unsqueeze(1) - self.weights.unsqueeze(0)
        return diff.pow(2).sum(dim=-1)

In [None]:
class MOCK_LENET5(nn.Module):
    """LeNet-5-style network built from this notebook's custom layers (work in progress).

    C* attributes are convolutions; S* attributes are CUSTOM_POOLING
    subsampling layers. C3 connects S2's 6 maps to 16 outputs via ``table1``.
    """

    def __init__(self):
        super().__init__()
        self.C1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1)
        self.S2 = CUSTOM_POOLING(in_channels=6, out_channels=6)
        # SPARSE_CONV needs a connection table with one entry per output map.
        self.C3 = SPARSE_CONV(in_channels=6, out_channels=16, kernel_size=5, connection_scheme=table1)
        self.S4 = CUSTOM_POOLING(in_channels=16, out_channels=16)
        self.C5 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)
        
        