In [None]:
import torch
import torchvision
import torch.nn as nn

In [None]:
class residual_block(nn.Module):
    """Residual unit computing ``x + F(x)`` with ``F`` = conv-ReLU-conv-ReLU-conv.

    Each convolution maps ``width`` channels to ``width`` channels with a 3x3
    kernel, stride 1 and padding 1. ``F(x)`` therefore has the same shape as
    ``x``, and the skip connection is a plain element-wise addition. No
    activation is applied after the sum.

    Args:
        width: number of input channels (also the number of output channels).
    """

    def __init__(self, width):
        super().__init__()
        self.width = width

        def same_conv():
            # 3x3 kernel with padding 1 keeps the spatial size unchanged.
            return nn.Conv2d(width, width, 3, padding=1)

        # Sequential indices 0, 2, 4 hold the convs (state_dict keys block.0/2/4).
        self.block = nn.Sequential(
            same_conv(), nn.ReLU(), same_conv(), nn.ReLU(), same_conv()
        )

    def forward(self, x):
        """Return ``x`` plus the residual branch evaluated on ``x``."""
        residual = self.block(x)
        return x + residual

In [None]:
class convnet(nn.Module):
    """Small CNN classifier: two conv stages, a residual stage and a linear head.

    Input has a single channel (``layer1`` is ``Conv2d(1, 8, ...)``). The two
    2x max-pools must reduce it to 2x2 spatially for the ``16 * 2 * 2`` linear
    input, so the spatial size must be 8x8 through 11x11 (e.g. 8x8). Presumably
    the intended size is 8x8; confirm against the data pipeline.

    ``forward`` returns a per-sample probability distribution over 10 classes,
    because softmax is applied to the logits.
    """

    def __init__(self):
        super().__init__()

        # Conv2d args: input channels, output channels, kernel size, padding.
        self.layer1 = nn.Conv2d(1, 8, 3, padding=1)
        self.layer2 = nn.Conv2d(8, 16, 3, padding=1)
        self.layer3 = residual_block(16)
        self.layer4 = nn.Linear(16 * 2 * 2, 10)

        self.pool = nn.MaxPool2d(2)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

        self.batchnorm1 = nn.BatchNorm2d(8)
        self.batchnorm2 = nn.BatchNorm2d(16)

    def forward(self, x):
        """Map a batch ``(N, 1, H, W)`` to class probabilities ``(N, 10)``.

        Raises a RuntimeError in ``layer4`` if the pooled feature map is not
        16x2x2 per sample.
        """
        # layer 1: conv -> batchnorm -> 2x2 max-pool -> ReLU
        x = self.layer1(x)
        x = self.batchnorm1(x)
        x = self.pool(x)
        x = self.relu(x)

        # layer 2: conv -> 2x2 max-pool -> ReLU
        x = self.layer2(x)
        x = self.pool(x)
        x = self.relu(x)

        # layer 3: residual block -> batchnorm -> ReLU
        x = self.layer3(x)
        x = self.batchnorm2(x)
        x = self.relu(x)

        # layer 4: flatten each sample and classify.
        # Flatten from dim 1 so the batch dimension is preserved. A reshape to
        # (-1, 64) would silently regroup features across samples when the
        # spatial size is not 2x2; this way layer4 raises a shape error instead.
        x = x.flatten(1)
        x = self.layer4(x)
        # NOTE(review): this returns probabilities. If training uses
        # nn.CrossEntropyLoss (which applies log-softmax itself), softmax ends
        # up applied twice; confirm the loss used with this model.
        x = self.softmax(x)
        return x