In [1]:
import torch

In [2]:
class ConvNet(torch.nn.Module):
  """Fully-convolutional network mapping a 3-channel image to a 1-channel score map.

  Layout: an 11x11 stride-2 stem conv + ReLU, then ``n_blocks`` residual
  blocks (each one strides by 2 and doubles the channel count), then a 1x1
  conv down to a single output channel. No global pooling is applied, so the
  output keeps a spatial grid (with the defaults, a 1x3x64x64 input gives a
  1x1x4x4 output).

  Args:
    channels_l0: number of channels produced by the stem conv.
    n_blocks: number of residual blocks; the channel count after the last
      block is ``channels_l0 * 2**n_blocks``.
  """

  class block(torch.nn.Module):
    """Residual block: three Conv-BatchNorm-ReLU stages plus a skip connection.

    Only the first 3x3 conv uses ``stride``; the other two keep the spatial
    size. The skip path is the identity when input and output already have the
    same shape, and a 1x1 conv projection (same stride as the main path)
    otherwise, so the two branches can always be added.

    Args:
      in_channels: channels of the block input.
      out_channels: channels of the block output.
      stride: stride of the first conv and of the skip projection.
    """

    def __init__(self, in_channels, out_channels, stride):
      super().__init__()
      kernel_size = 3
      padding = (kernel_size - 1) // 2  # keeps H, W unchanged for stride-1 convs
      self.block_model = torch.nn.Sequential(
          # Only the first conv of the block is strided.
          torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
          # BatchNorm sits between Conv and ReLU; normalizing before the conv
          # would mean having to deal with the conv bias separately.
          torch.nn.BatchNorm2d(out_channels),
          torch.nn.ReLU(),
          torch.nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding),
          torch.nn.BatchNorm2d(out_channels),
          torch.nn.ReLU(),
          torch.nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding),
          torch.nn.BatchNorm2d(out_channels),
          torch.nn.ReLU(),
      )

      # Residual connection. A projection is needed whenever the main path
      # changes the shape: a different channel count OR a non-unit stride.
      # Its stride must match the main path's, otherwise the addition in
      # forward() sees mismatched spatial sizes.
      if in_channels != out_channels or stride != 1:
        self.skip = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride)
      else:
        self.skip = torch.nn.Identity()

    def forward(self, x):
      """Return ``block_model(x) + skip(x)`` (the last ReLU runs before the add)."""
      return self.block_model(x) + self.skip(x)


  def __init__(self, channels_l0 = 64, n_blocks = 3):
    super().__init__()
    stem_kernel = 11
    cnn_layers = [
        # Large stride-2 stem conv. Author's note: a big first layer helps
        # maintain gradients when building deep networks.
        torch.nn.Conv2d(in_channels=3, out_channels=channels_l0, kernel_size=stem_kernel, stride=2, padding=(stem_kernel - 1) // 2),
        torch.nn.ReLU(),
    ]
    c1 = channels_l0  # channels entering the next block
    for _ in range(n_blocks):
      c2 = c1 * 2  # each block doubles the channel count (and strides by 2)
      cnn_layers.append(self.block(c1, c2, stride=2))
      c1 = c2
    # 1x1 conv classifier: one score per remaining spatial location.
    cnn_layers.append(torch.nn.Conv2d(c1, 1, kernel_size=1))
    # Optional: append torch.nn.AdaptiveAvgPool2d(1) here to average all
    # outputs into a single score per image (one classification). It is left
    # out so that all per-location outputs are kept.
    self.model = torch.nn.Sequential(*cnn_layers)

  def forward(self, x):
    """Run the network; ``x`` must have 3 channels (N, 3, H, W) and the result has 1 channel."""
    return self.model(x)

# Smoke test: push one random 3-channel 64x64 input through a default ConvNet
# and report the output shape. The input is sampled before the network is
# built, so the RNG consumption order matches the original cell.
x = torch.randn(1, 3, 64, 64)
net = ConvNet()
out_shape = net(x).shape
print(out_shape)

torch.Size([1, 1, 4, 4])


In [3]:
print(repr(net))  # nn.Module's text form is its repr: the full layer tree

ConvNet(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(2, 2), padding=(5, 5))
    (1): ReLU()
    (2): block(
      (block_model): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU()
        (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (8): ReLU()
      )
      (skip): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
    )
    (3): block(
      (block_model): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): BatchNorm2d(256, eps=