In [2]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Optional

import torch 
from torch import nn
from torchvision.datasets import CIFAR10

In [3]:
def get_activation(name: str, **kwargs):
    """Instantiate a ``torch.nn`` activation module from its name.

    The lookup is case-insensitive, so ``'ReLU'``, ``'relu'`` and ``'RELU'``
    all resolve to ``nn.ReLU``.

    Args:
        name: One of ``relu``, ``tanh``, ``sigmoid``, ``silu``, ``softplus``
            or ``leakyrelu`` (any casing).
        **kwargs: Forwarded unchanged to the activation's constructor.

    Returns:
        A freshly constructed activation module.

    Raises:
        KeyError: If ``name`` is not one of the supported activations.
    """
    activations = {
        'relu': nn.ReLU,
        'tanh': nn.Tanh,
        'sigmoid': nn.Sigmoid,
        'silu': nn.SiLU,
        'softplus': nn.Softplus,
        'leakyrelu': nn.LeakyReLU
    }
    # Normalise once so the membership test and the lookup agree; previously
    # the test used the raw name and rejected e.g. 'ReLU'.
    key = name.lower()
    if key not in activations:
        raise KeyError('No such activation')
    return activations[key](**kwargs)

In [4]:
class ResidualBlock(nn.Module):
    def __init__(self, type: str, in_channels: int, out_channels: Optional[int] = None, mid_channels: Optional[int] = None, activations: List[str] | str = 'relu', kernel_size: int = 3, stride: int = 1, padding: int = 1) -> None:
        super(ResidualBlock, self).__init__()
        
        n_layers = 2
        
        if not isinstance(activations, List):
            activations = [activations] * n_layers
        if len(activations) != n_layers:
            raise Exception('Not enough activations')
        
        match type:
            case 'Standard':
                mid_channels = mid_channels or in_channels
                out_channels = out_channels or in_channels
                
                self.layers = nn.Sequential(
                    nn.Conv2d(in_channels, mid_channels, kernel_size=kernel_size, stride=stride, padding=padding),
                    nn.BatchNorm2d(mid_channels),
                    get_activation(activations[0]),
                    
                    nn.Conv2d(mid_channels, out_channels, kernel_size=kernel_size, padding=padding),
                    nn.BatchNorm2d(out_channels)
                )
            case 'Bottleneck':
                mid_channels = mid_channels or in_channels // 2
                out_channels = out_channels or in_channels
                
                self.layers = nn.Sequential(
                    nn.Conv2d(in_channels, mid_channels, kernel_size=1),
                    nn.BatchNorm2d(mid_channels),
                    get_activation(activations[0]),
                    
                    nn.Conv2d(mid_channels, mid_channels, kernel_size=kernel_size, stride=stride, padding=padding),
                    nn.BatchNorm2d(mid_channels),
                    get_activation(activations[1]),
                    
                    nn.Conv2d(mid_channels, out_channels, kernel_size=1),
                    nn.BatchNorm2d(out_channels)
                )
            case _:
                raise Exception('Unsupported block type')
                
        self.activation = get_activation(activations[-1])
        
        self.sc_pool = None
        self.sc_scale = None
        
        if stride > 1:
            self.sc_pool = nn.AvgPool2d(kernel_size=stride, stride=stride)
        if in_channels != out_channels:
            self.sc_scale = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        
    def forward(self, x):
        shortcut = x
        
        x = self.layers(x)
                
        if self.sc_pool:
            shortcut = self.sc_pool(shortcut)
        if self.sc_scale:
            shortcut = self.sc_scale(shortcut)
        
        x = self.activation(shortcut + x)
        return x
# Shape smoke test: both block types, first with downsampling + channel
# expansion (2 -> 4, stride 2), then with the defaults. One shape per line.
print(
    *(
        ResidualBlock(variant, 2, **options)(torch.rand(1, 2, 10, 10)).shape
        for options in ({'out_channels': 4, 'stride': 2}, {})
        for variant in ('Standard', 'Bottleneck')
    ),
    sep='\n',
)

torch.Size([1, 4, 5, 5])
torch.Size([1, 4, 5, 5])
torch.Size([1, 2, 10, 10])
torch.Size([1, 2, 10, 10])


In [20]:
class ResidualLayer(nn.Module):
    """A stack of ``n_blocks`` residual blocks of the same type.

    The first block uses ``stride=2`` and maps ``in_channels`` to
    ``out_channels``; the remaining ``n_blocks - 1`` blocks keep
    ``out_channels`` and use the block defaults.

    Args:
        block_type: Passed to ``ResidualBlock`` as its ``type``.
        n_blocks: Total number of blocks, including the first strided one.
        in_channels: Number of channels of the input tensor.
        out_channels: Output channels; defaults to ``in_channels * 2``.
        activations: A single value reused for every block, or a list with
            one entry per block. Each entry is handed to ``ResidualBlock``
            as its ``activations`` argument.

    Raises:
        Exception: If ``activations`` is a list whose length differs from
            ``n_blocks``.
    """

    def __init__(self, block_type: str, n_blocks, in_channels: int, out_channels: Optional[int] = None, activations: List[str] | str = 'relu'):
        super().__init__()

        # Broadcast a single spec to every block; lists are used as given.
        if isinstance(activations, list):
            per_block = activations
        else:
            per_block = [activations] * n_blocks
        if len(per_block) != n_blocks:
            raise Exception('Not enough activations')

        if not out_channels:
            out_channels = in_channels * 2

        # Downsampling / channel-changing block first.
        self.block1 = ResidualBlock(block_type, in_channels, out_channels, stride=2, activations=per_block[0])

        # Then the shape-preserving blocks, in order.
        followers = []
        for idx in range(1, n_blocks):
            followers.append(ResidualBlock(block_type, out_channels, activations=per_block[idx]))
        self.same_blocks = nn.Sequential(*followers)

    def forward(self, x):
        """Apply the strided first block, then the remaining blocks."""
        return self.same_blocks(self.block1(x))
# Two-block layer with per-block activations: 'relu' for the first block and
# ['tanh', 'relu'] for the second. The bare expression displays the module repr.
ResidualLayer(block_type='Standard', n_blocks=2, in_channels=2, activations=['relu', ['tanh', 'relu']])

ResidualLayer(
  (block1): ResidualBlock(
    (layers): Sequential(
      (0): Conv2d(2, 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(2, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (activation): ReLU()
    (sc_pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (sc_scale): Conv2d(2, 4, kernel_size=(1, 1), stride=(1, 1))
  )
  (same_blocks): Sequential(
    (0): ResidualBlock(
      (layers): Sequential(
        (0): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Tanh()
        (3): Conv2d(4, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_runni