In [1]:
import torch
import torch.nn as nn

In [2]:
# Seed the global torch RNG before `x` is drawn and `CNN()` initialises its
# weights, so both are reproducible on Restart & Run All.
torch.manual_seed(seed=0)

<torch._C.Generator at 0x76ce4ffccc50>

In [3]:
channels = 1
height = 28
width = 28
batch = 1

input_shape = (batch, channels, height, width)

In [4]:
# Random (unit-normal) input sample of shape `input_shape`.
x = torch.randn(*input_shape)

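## Network modules

`Diverge` and `CNNLayerBlock` are the building blocks. `CNN` wires them into a network whose first block fans out into two pooling branches, which are summed again at the input of block 3.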

In [5]:
class Diverge(nn.Module):
    """One branch of a fan-out connection: ``pooling(pre_pooling_block(x))``.

    Several ``Diverge`` instances may wrap the *same* ``pre_pooling_block``
    object, each with its own ``pooling``, to route one block's output to
    different downstream layers. Each instance calls ``pre_pooling_block``
    itself, so a shared block is evaluated once per branch.

    Args:
        pre_pooling_block: module applied to the input first.
        pooling: module applied to the output of ``pre_pooling_block``.
    """

    def __init__(self, pre_pooling_block, pooling) -> None:
        super().__init__()
        # Assignment order (block first, then pooling) fixes the order of the
        # registered child modules.
        self.pre_pooling_block = pre_pooling_block
        self.pooling = pooling

    def forward(self, x):
        """Return ``self.pooling(self.pre_pooling_block(x))``."""
        features = self.pre_pooling_block(x)
        return self.pooling(features)

In [6]:
class CNNLayerBlock(nn.Module):
    """A ``pre_activation -> activation [-> pooling]`` stage.

    Args:
        pre_activation: module applied first (a Conv2d or Linear in this notebook).
        activation: module applied to the ``pre_activation`` output.
        pooling: optional module applied last; ``None`` (the default) skips it.
    """

    def __init__(self, pre_activation, activation, pooling = None) -> None:
        super().__init__()

        self.pre_activation = pre_activation
        self.activation = activation
        self.pooling = pooling

    def forward(self, x):
        """Apply ``pre_activation``, then ``activation``, then ``pooling`` if set."""
        x = self.pre_activation(x)
        x = self.activation(x)
        # Compare against None explicitly: `if self.pooling:` calls bool() on the
        # module, which falls back to __len__ and would silently skip a pooling
        # module that reports a length of 0.
        if self.pooling is not None:
            x = self.pooling(x)
        return x

In [7]:
class CNN(nn.Module):
    """Small branching CNN.

    Block 1 fans out into two pooling paths, which are summed before block 3.
    Two fully connected blocks follow.

    Data flow in ``forward``, with shapes for the (1, 1, 28, 28) input built in
    this notebook (N = batch size):

        x                                        (N, 1, 28, 28)
        cnnlblk_1: Conv2d(1->30, k=2) + ReLU     (N, 30, 27, 27)
          diverge_1a: cnnlblk_1 + AvgPool2d(2,2) (N, 30, 13, 13)
            cnnlblk_2: Conv2d(30->30, k=2) + ReLU + AvgPool2d(2,2)
                                                 (N, 30, 6, 6)
          diverge_1b: cnnlblk_1 + AvgPool2d(4,4) (N, 30, 6, 6)
        cnnlblk_3 on (blk2_out + div1b_out):
            Conv2d(30->1, k=3) + ReLU            (N, 1, 4, 4)
        flat                                     (N, 16)
        cnnlblk_4: Linear(16->500) + ReLU        (N, 500)
        cnnlblk_5: Linear(500->10) + ReLU        (N, 10)

    Notes:
        * ``diverge_1a`` and ``diverge_1b`` wrap the same ``cnnlblk_1`` instance,
          so its weights are shared. However, ``forward`` evaluates it twice,
          once inside each ``Diverge``.
        * The 16 in ``Linear(16, 500)`` is 1 * 4 * 4, the flattened size reached
          from a 28x28 input. Other input resolutions need it changed.
        * The last block also applies ReLU to the 10 outputs.
        * Modules are created in a fixed order, so their initial weights depend
          on the global torch RNG state at construction time.
    """
    def __init__(self) -> None:
        super().__init__()

        # CNNLayerBlock 1 - its output branches out to two different layers:
        # diverge_1a and diverge_1b below wrap this same module instance, each
        # with its own pooling.

        self.cnnlblk_1 = CNNLayerBlock(
            pre_activation              = nn.Conv2d(1, 30, 2, 1, bias=False), 
            activation                  = nn.ReLU()
            )
        
        self.diverge_1a = Diverge(      # branch 1a: block 1 + AvgPool2d(2, 2), feeds cnnlblk_2.
            pre_pooling_block           = self.cnnlblk_1,
            pooling                     = nn.AvgPool2d(2,2)
            )
        
        self.diverge_1b = Diverge(      # branch 1b: block 1 + AvgPool2d(4, 4), added to cnnlblk_2's output.
            pre_pooling_block           = self.cnnlblk_1, 
            pooling                     = nn.AvgPool2d(4,4)
            )

        # CNNLayerBlock 2.

        self.cnnlblk_2 = CNNLayerBlock(
            pre_activation              = nn.Conv2d(30, 30, 2, 1, bias=False),
            activation                  = nn.ReLU(),
            pooling                     = nn.AvgPool2d(2,2)
            )

        # CNNLayerBlock 3 - consumes the sum of both branches, so the branch
        # output shapes must match.

        self.cnnlblk_3 = CNNLayerBlock(
            pre_activation              = nn.Conv2d(30, 1, 3, 1, bias=False), 
            activation                  = nn.ReLU()
            )

        # CNNLayerBlock 4 - in_features=16 is the flattened cnnlblk_3 output
        # for a 28x28 input (1 channel * 4 * 4).

        self.cnnlblk_4 = CNNLayerBlock(
            pre_activation              = nn.Linear(16, 500, bias=False), 
            activation                  = nn.ReLU()
            )
        
        # CNNLayerBlock 5 - 10 outputs, also passed through ReLU.

        self.cnnlblk_5 = CNNLayerBlock(
            pre_activation              = nn.Linear(500, 10, bias=False), 
            activation                  = nn.ReLU()
            )

        # Parameter-free helper used between cnnlblk_3 and cnnlblk_4.

        self.flat = nn.Flatten()


    def forward(self, x):
        """Run the branching network on ``x`` and return the cnnlblk_5 output."""
        # CNNLayerBlock 1, fanned out: each Diverge runs cnnlblk_1 on x and then
        # applies its own pooling, so cnnlblk_1 is computed twice.
        div1a_out = self.diverge_1a(x)
        div1b_out = self.diverge_1b(x)
        
        # CNNLayerBlock 2 (branch 1a only).
        blk2_out = self.cnnlblk_2(div1a_out)

        # CNNLayerBlock 3: merges the two branches by element-wise addition.
        blk3_out = self.cnnlblk_3(blk2_out + div1b_out)

        # CNNLayerBlock 4: flatten block 3's output before the linear layer.
        blk4_out = self.cnnlblk_4(self.flat(blk3_out))

        # CNNLayerBlock 5.
        blk5_out = self.cnnlblk_5(blk4_out)

        return blk5_out

In [8]:
# Instantiate the network; its weights are drawn from the seeded global RNG,
# after `x` was sampled.
cnn: nn.Module = CNN()

In [9]:
# Replay CNN.forward step by step to expose the intermediate activations.
# CNNLayerBlock 1, fanned out: each Diverge branch runs cnnlblk_1 on x.
div1a_out = cnn.diverge_1a(x)
div1b_out = cnn.diverge_1b(x)

# CNNLayerBlock 2.
blk2_out = cnn.cnnlblk_2(div1a_out)

# CNNLayerBlock 3: its input is the element-wise sum of both branches, so the
# shapes must agree exactly (a silent broadcast would hide a wiring mistake).
print(blk2_out.shape, div1b_out.shape)
assert blk2_out.shape == div1b_out.shape, (
    f"branch shapes differ: {tuple(blk2_out.shape)} vs {tuple(div1b_out.shape)}"
)
blk3_out = cnn.cnnlblk_3(blk2_out + div1b_out)

# CNNLayerBlock 4.
blk4_out = cnn.cnnlblk_4(cnn.flat(blk3_out))

# CNNLayerBlock 5.
blk5_out = cnn.cnnlblk_5(blk4_out)

# This cell duplicates CNN.forward; check the two have not drifted apart.
assert torch.allclose(blk5_out, cnn(x)), "manual replay disagrees with cnn(x)"

torch.Size([1, 30, 6, 6]) torch.Size([1, 30, 6, 6])
