In [2]:
import torch
from torch import nn

In [None]:
class IntermediateBlock(nn.Module):

  """
  Intermediate block made of several parallel convolutional units.

  Every unit sees the same input. Their outputs are blended with per-sample
  weights that a small feed-forward network predicts from the channel-wise
  mean of the input.
  """

  def __init__(
      self,
      input_channels = 3,
      output_channels = 256,
      output_volume = 32,
      units = 4,
      groups = True
      ):
    """
    input_channels:
      The number of channels for incoming image
    output_channels:
      The number of channels for outgoing image
    output_volume:
      The size of the output image (target size of the adaptive average
      pooling that ends every unit)
    units:
      The number of convolutional units within the block
    groups:
      Whether to split inputs to groups for convolving. When truthy the
      convolutions use groups = input_channels, so output_channels has to be
      a multiple of input_channels; otherwise groups = 1.
    """

    super(IntermediateBlock, self).__init__()

    self.units = units
    self.groups = input_channels if groups else 1

    # Feed-forward network with one output per unit; the softmax turns the
    # outputs into weights that sum to one for every sample
    self.fc = nn.Sequential(
        nn.LayerNorm(input_channels),
        nn.Linear(input_channels, 256),
        nn.GELU(),
        nn.Dropout(0.2),
        nn.Linear(256, self.units),
        nn.Softmax(dim = 1)
    )

    # Build a *separate* module for every unit. Multiplying a one-element list
    # by `units` would repeat the same module object, so all units would share
    # a single set of weights.
    # Each unit applies GELU to the convolution output and then adaptive
    # average pooling so the output size is the same for every unit.
    self.conv_layer = nn.ModuleList([
        nn.Sequential(
          nn.Conv2d(
              input_channels,
              output_channels,
              groups = self.groups,
              kernel_size = 8,
              stride = 1,
              padding = "same"
              ),
          nn.GELU(),
          nn.AdaptiveAvgPool2d(output_volume)
        )
        for _ in range(units)
    ])

  def forward(self, x):

    # Channel-wise mean over the two trailing (spatial) dimensions, passed
    # through the feed-forward network: one weight per unit for every sample
    channel_means = torch.mean(x, dim = [2, 3])
    unit_weights = self.fc(channel_means)

    # Scale the output of every unit by its weight. unit_weights.T iterates
    # over units; reshaping to (-1, 1, 1, 1) lets the per-sample weight
    # broadcast over channels and spatial positions.
    weighted_outputs = [
        torch.mul(coef.reshape(-1, 1, 1, 1), unit(x))
        for coef, unit in zip(unit_weights.T, self.conv_layer)
    ]

    # Stack and sum the weighted outputs of all units
    return torch.sum(torch.stack(weighted_outputs, dim = 0), dim = 0)

class OutputBlock(nn.Module):

  """
  Output block: flattens the incoming feature map, layer-normalises it and
  maps it to class logits with a single linear layer.
  """

  def __init__(self, input_channels, input_volume, num_classes = 10):
    """
    input_channels:
      The number of channels of the incoming feature map
    input_volume:
      The height (= width) of the incoming feature map
    num_classes:
      The number of logits produced; defaults to 10
    """

    super(OutputBlock, self).__init__()

    self.input_channels = input_channels
    self.input_volume = input_volume
    self.num_classes = num_classes

    # Calculate size of flattened image for input layer to feed forward network
    self.first_layer_inputs = (input_volume * input_volume * input_channels)

    # Set up feed forward network
    self.fc = nn.Sequential(
      nn.Flatten(),
      nn.LayerNorm(self.first_layer_inputs),
      nn.Linear(self.first_layer_inputs, num_classes)
    )

  def forward(self, x):

    # Calculate logits from feed forward network
    logits = self.fc(x)

    return logits

class Residual(nn.Module):
  """
  Wrapper class which pools the original input and adds it to the output
  of the wrapped module.

  https://github.com/kentaroy47/vision-transformers-cifar10/blob/main/models/convmixer.py
  """
  def __init__(self, fn):
    """
    fn:
      The module being wrapped
    """
    super().__init__()
    self.fn = fn

  def forward(self, x):

    # Calculate the output from the forward pass of the wrapped module
    output = self.fn(x)

    # Max-pool x to the height and width of the output. Both trailing
    # dimensions are used so that a non-square output is matched as well, and
    # the functional form avoids building a new pooling module on every call.
    residual = nn.functional.adaptive_max_pool2d(x, tuple(output.shape[-2:]))

    # Sum the output with the residual
    return output + residual

class SelfAttention(nn.Module):
    """
    Self-attention over the spatial positions of a feature map, added back to
    the input through a learnable scale.

    https://github.com/heykeetae/Self-Attention-GAN/blob/master/sagan_models.py
    """
    def __init__(self,input_channels):
        """
        input_channels:
          The number of channels of the incoming feature map. Queries and
          keys are projected to input_channels // 8 channels, values keep
          input_channels channels.
        """

        super(SelfAttention, self).__init__()
        self.input_channels = input_channels

        def pointwise(out_channels):
            # 1x1 convolution followed by GELU
            return nn.Sequential(
                nn.Conv2d(
                    in_channels = input_channels,
                    out_channels = out_channels,
                    kernel_size = 1
                ),
                nn.GELU()
            )

        reduced_channels = input_channels // 8
        self.query_conv = pointwise(reduced_channels)
        self.key_conv = pointwise(reduced_channels)
        self.value_conv = pointwise(input_channels)

        # Normalises the attention scores over the last dimension
        self.softmax = nn.Softmax(dim=-1)

        # Scale of the attention branch; starts at zero, so the layer
        # initially returns its input unchanged
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self,x):
        """
            inputs :
                x : input feature maps (B x C x W x H)
            returns :
                out : gamma * self attention value + input feature
                      (B x C x W x H)
        """

        batchsize, channels, width, height = x.size()
        positions = width * height

        # Flatten the spatial dimensions: B x C' x N with N = W * H
        queries = self.query_conv(x).view(batchsize, -1, positions)
        keys = self.key_conv(x).view(batchsize, -1, positions)
        values = self.value_conv(x).view(batchsize, -1, positions)

        # B x N x N scores between every pair of positions, softmax over the
        # last dimension
        attention = self.softmax(torch.bmm(queries.transpose(1, 2), keys))

        # Mix the values with the attention weights and restore the layout
        attended = torch.bmm(values, attention.transpose(1, 2))
        attended = attended.view(batchsize, channels, width, height)

        return self.gamma * attended + x

In [None]:
class AdvancedCNN(nn.Module):

  """
  Configurable convolutional network: a fixed stem block, a chain of residual
  intermediate blocks described by a config dictionary, global average
  pooling and an output block.
  """

  def __init__(self, block_config, self_attention = False):

    """
    block_config:
      Dictionary whose values are keyword-argument dictionaries for
      IntermediateBlock, one per intermediate block, used in iteration order
    self_attention:
      If truthy, a self-attention stage is added after each of the first
      three configured blocks
    """

    super(AdvancedCNN, self).__init__()

    # Stem: expands the 3 input channels to 256
    stages = [
        nn.Sequential(
            IntermediateBlock(
                input_channels = 3,
                output_channels = 256,
                output_volume = 32,
                units = 4,
                groups = False
            ),
            nn.BatchNorm2d(256)
        )
    ]

    for index, block_kwargs in enumerate(block_config.values()):

      # Intermediate block with a residual connection around it
      stages.append(
          nn.Sequential(
              Residual(IntermediateBlock(**block_kwargs)),
              nn.BatchNorm2d(256)
          )
      )

      # Pointwise (1x1) convolution stage; no residual connection here
      stages.append(
          nn.Sequential(
              nn.Conv2d(256, 256, kernel_size=1),
              nn.GELU(),
              nn.BatchNorm2d(256)
          )
      )

      # Optional self-attention after the first three configured blocks
      if self_attention and index < 3:
        stages.append(
            nn.Sequential(
                SelfAttention(256),
                nn.BatchNorm2d(256)
            )
        )

    # Reduce every channel to a single value, then compute the logits
    stages.append(nn.AdaptiveAvgPool2d((1,1)))
    stages.append(OutputBlock(256, 1))

    self.spine = nn.Sequential(*stages)

  def forward(self, x):

    # Run the input through all stages in order
    return self.spine(x)

In [None]:
class AdvancedCNN(nn.Module):
    """
    Hand-wired convolutional network with dense skip connections.

    Six convolutional stages and two self-attention stages are chained; every
    stage also receives the sum of the earlier 256-channel feature maps,
    average-pooled to its own spatial size. The raw network input is not part
    of these skip sums: it has 3 channels and cannot be added to the
    256-channel feature maps.
    """

    @staticmethod
    def _conv_stage(output_volume):
        # Intermediate block followed by a pointwise convolution, with batch
        # normalisation after each
        return nn.Sequential(
            IntermediateBlock(
                input_channels = 256,
                output_channels = 256,
                output_volume = output_volume,
                units = 4
            ),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, 1),
            nn.GELU(),
            nn.BatchNorm2d(256)
        )

    @staticmethod
    def _attention_stage():
        # Self-attention followed by batch normalisation
        return nn.Sequential(
            SelfAttention(256),
            nn.BatchNorm2d(256)
        )

    def __init__(self):
        super(AdvancedCNN, self).__init__()

        # Stem: expands the 3 input channels to 256
        self.conv1 = nn.Sequential(
            IntermediateBlock(
                input_channels = 3,
                output_channels = 256,
                output_volume = 32,
                units = 4,
                groups = False
            ),
            nn.BatchNorm2d(256)
            )

        # Output sizes of the stages: 32, 26, 20, 20, 18
        self.conv2 = self._conv_stage(32)
        self.conv3 = self._conv_stage(26)
        self.conv4 = self._conv_stage(20)
        self.conv5 = self._conv_stage(20)
        self.conv6 = self._conv_stage(18)

        self.attn1 = self._attention_stage()
        self.attn2 = self._attention_stage()

        # Resize earlier feature maps to the size of a later stage
        self.pool1 = nn.AdaptiveAvgPool2d((26, 26))
        self.pool2 = nn.AdaptiveAvgPool2d((20, 20))
        self.pool3 = nn.AdaptiveAvgPool2d((18, 18))

        self.outpool = nn.AdaptiveAvgPool2d((1, 1))

        self.output = OutputBlock(256, 1)

    def forward(self, x):

        # x has 3 channels, every feature map below has 256, so x itself is
        # never used as a skip connection
        c1 = self.conv1(x)
        c2 = self.conv2(c1) + c1
        c3 = self.conv3(c2) + self.pool1(c2) + self.pool1(c1)

        # Skips from c1..c3 at 20x20 are shared by the next four stages, so
        # they are pooled once
        skip20 = self.pool2(c3) + self.pool2(c2) + self.pool2(c1)
        c4 = self.conv4(c3) + skip20
        a1 = self.attn1(c4) + c4 + skip20
        c5 = self.conv5(a1) + c4 + skip20
        a2 = self.attn2(c5) + c5 + c4 + skip20

        # Last stage: every earlier feature map pooled to 18x18
        c6 = self.conv6(a2)
        for earlier in (a2, c5, c4, c3, c2, c1):
            c6 = c6 + self.pool3(earlier)

        out = self.outpool(c6)
        out = self.output(out)

        return out

        
        

In [None]:
# Configuration of the five intermediate blocks ("block_1" .. "block_5").
# All blocks use 256 input and output channels and 4 units; only the output
# volume changes from block to block.
advanced_block_config = {
    f"block_{number}": {
        "input_channels": 256,
        "output_channels": 256,
        "output_volume": volume,
        "units": 4,
    }
    for number, volume in enumerate((32, 26, 20, 20, 18), start=1)
}