# Neural network models

In [23]:
import torch
import torch.nn as nn
import math
from torchvision import models
from collections import OrderedDict
from copy import deepcopy

## U-Net Network
num_classes = Number of classes expected from output

In [24]:
def conv2d_layer(channels_in, channels_out, kernel_size, padding, stride, dilation):
    """Build one Conv2D + ReLU pair.

    The convolution weight is re-initialised with Kaiming-normal values;
    the bias keeps the initialisation done by ``nn.Conv2d`` itself.

    Returns:
        An ``nn.ModuleList`` holding ``[Conv2d, ReLU]`` in that order.
    """
    convolution = nn.Conv2d(
        channels_in,
        channels_out,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
    )
    nn.init.kaiming_normal_(convolution.weight)

    return nn.ModuleList([convolution, nn.ReLU()])

In [25]:
class Encoder_Block(nn.Module):
    def __init__(self, channels_in, channels_out, params: dict):
        super().__init__()

        # Config parameters
        try:
            block_width = params['block_width']
            kernel_size = params['kernel_size']
            padding = params['padding']
            stride = params['stride']
            dilation = params['dilation']
            pool_type = params['pool_type']
            pool_size = params['pool_kernel_size']
        except KeyError as e:
            raise Exception(f'Parameter "{e.args[0]}" NOT found!')
        
        # Add first Conv2D layer
        self.layers = nn.ModuleList()
        self.layers.extend(conv2d_layer(channels_in, channels_out, kernel_size, padding, stride, dilation))

        # Add other Conv2D layers
        for _ in range(block_width - 1):
            self.layers.extend(conv2d_layer(channels_out, channels_out, kernel_size, padding, stride, dilation))

        # Add pooling layer
        match pool_type:
            case 'max':
                self.pooling = nn.MaxPool2d(kernel_size=pool_size, stride=pool_size)
            case 'avg':
                self.pooling = nn.AvgPool2d(kernel_size=pool_size, stride=pool_size)


    def forward_noSkip(self, x):
        # Convolution
        for layer in self.layers:
            x = layer(x)

        # Pooling
        pooled = self.pooling(x)
        
        return pooled
    
    def forward_skip(self, x):
        # Convolution
        for layer in self.layers:
            x = layer(x)

        # Pooling
        pooled = self.pooling(x)

        return (pooled, x)

In [26]:
class Decoder_Block(nn.Module):
    """Decoder stage: a transposed convolution followed by Conv2D + ReLU pairs.

    The module defines no ``forward``; callers pick ``forward_skip`` or
    ``forward_noSkip`` explicitly.

    Args:
        channels_in: Input channels of the transposed convolution.
        channels_out: Output channels of every Conv2D in the block.
        padding_t: ``output_padding`` of the transposed convolution.
        params: Configuration dict. Required keys: 'block_width',
            'kernel_size', 'padding', 'stride', 'dilation',
            'pool_kernel_size' and 'skip_features'.

    Raises:
        Exception: If a required key is missing from ``params``.
    """

    def __init__(self, channels_in, channels_out, padding_t, params: dict):
        super().__init__()

        # Config parameters
        try:
            block_width = params['block_width']
            kernel_size = params['kernel_size']
            padding = params['padding']
            stride = params['stride']
            dilation = params['dilation']
            pool_size = params['pool_kernel_size']
            skip_features = params['skip_features']
        except KeyError as e:
            raise Exception(f'Parameter "{e.args[0]}" NOT found!') from e

        # Transposed convolution halves the channel count; kernel and stride
        # are both pool_size.
        channel_out_convT = math.floor(channels_in / 2)

        self.convT = nn.ConvTranspose2d(
            in_channels = channels_in,
            out_channels = channel_out_convT,
            kernel_size = pool_size,
            stride = pool_size,
            output_padding = padding_t
        )
        # kaiming_normal_ expects a tensor, so initialise the weight rather
        # than passing the module itself.
        torch.nn.init.kaiming_normal_(self.convT.weight)

        if skip_features == "none":
            # Without skip features the first Conv2D sees only the convT output.
            channels_in = channel_out_convT
        # Any other mode keeps channels_in for the first Conv2D. For "concat"
        # that means the encoder feature must contribute
        # channels_in - channel_out_convT channels.

        # Add first Conv2D layer
        self.layers = nn.ModuleList()
        self.layers.extend(conv2d_layer(channels_in, channels_out, kernel_size, padding, stride, dilation))

        # Add other Conv2D layers
        for _ in range(block_width - 1):
            self.layers.extend(conv2d_layer(channels_out, channels_out, kernel_size, padding, stride, dilation))

    def forward_noSkip(self, x):
        """Upsample ``x`` and run the convolution stack, ignoring skip features."""
        # Transposed convolution
        x = self.convT(x)

        # Convolution
        for layer in self.layers:
            x = layer(x)

        return x

    def forward_skip(self, x, enc_feature):
        """Upsample ``x``, concatenate ``enc_feature`` on dim 1, then convolve.

        No cropping is applied to ``enc_feature``: the concatenation needs
        both tensors to agree on every dimension except dim 1.
        """
        # Transposed convolution
        x = self.convT(x)

        # Debug output: sizes of the two tensors about to be concatenated
        print(x.size())
        print(enc_feature.size())

        # Copy features
        x = torch.cat([enc_feature, x], dim = 1)

        # Convolution
        for layer in self.layers:
            x = layer(x)

        return x

In [27]:
class U_Net(nn.Module):
    def __init__(self, channels_in, num_classes, params: dict):
        super().__init__()

        torch.set_default_dtype(torch.float32)

        # Config parameters
        try:
            channel_mul = params['channel_mul']
            depth = params['network_depth']
            channels_out = params['channels_out_init']
            padding_convT = deepcopy(params['padding_convT'])
            self.skip_features = params['skip_features']
            if (len(padding_convT) != depth):
                raise Exception(f"Padding of Transposed convolution has length {len(padding_convT)}, but should be {depth}!")
            
        except KeyError as e:
            raise Exception(f'Parameter "{e.args[0]}" NOT found!')
        
        # Create encoder
        self.encoders = nn.ModuleList()

        for i in range(depth):
            self.encoders.append(Encoder_Block(channels_in, channels_out, params))
            channels_in = channels_out
            channels_out = math.floor(channels_out * channel_mul)

        # Create bridge as Conv2D layer
        last_encoder = Encoder_Block(channels_in, channels_out, params)
        self.bridge = last_encoder.layers[0]

        channels_in = channels_out
        channels_out = math.floor(channels_in / channel_mul)

        # Create decoder
        self.decoders = nn.ModuleList()
        
        for _ in range(depth) :
            paddingT = padding_convT.pop(0)
            self.decoders.append(Decoder_Block(channels_in, channels_out, paddingT, params))
            channels_in = channels_out
            channels_out = math.floor(channels_out / channel_mul)

        # Create output layer (1x1 convolution, return logits)
        self.output_layer = nn.Conv2d(
            in_channels = channels_in, 
            out_channels = num_classes, 
            kernel_size = 1
        )
        
    def forward(self, x):
        match self.skip_features:
            # No skipping
            case "none":
                print("This U-Net does NOT skip any connections!")
                # Encoding
                for enc in self.encoders:
                    x = enc.forward_noSkip(x)

                # Bridge between encoder and decoder
                x = self.bridge(x)

                # Decoding
                for dec in self.decoders:
                    x = dec.forward_noSkip(x)

            # Cat tensors => features + pooled output
            case "concat":
                print("This U-Net skips connections using concatenation!")
                # Features created during enconding
                features = []

                # Encoding
                for enc in self.encoders:
                    x, enc_feature = enc.forward_skip(x)
                    features.append(enc_feature)

                # Bridge between encoder and decoder
                x = self.bridge(x)

                # Decoding
                for dec in self.decoders:
                    enc_output = features.pop()
                    x = dec.forward_skip(x, enc_output)

            case "sdi":
                pass
    

        # Output as logits
        x = self.output_layer(x)

        return x

In [28]:
def conv2d_output_shape(height, width, conv2d: nn.Module):
    """Compute the output size of a 2D convolution for a given input size.

    Reads ``padding``, ``dilation``, ``kernel_size`` and ``stride`` from
    ``conv2d`` as per-axis pairs (index 0 = height, index 1 = width).

    Returns:
        Tuple ``(height, width, out_channels)``.
    """
    def along(size, axis):
        # Span covered by the dilated kernel beyond its first element.
        reach = conv2d.dilation[axis] * (conv2d.kernel_size[axis] - 1)
        padded = size + (2 * conv2d.padding[axis])
        return math.floor(((padded - reach - 1) / conv2d.stride[axis]) + 1)

    return (along(height, 0), along(width, 1), conv2d.out_channels)

def pool_output_shape(height, width, conv2d: nn.Module):
    """Compute the output size of a 2D pooling layer for a given input size.

    Works for layers whose ``kernel_size``, ``stride`` and ``padding`` are
    either a single int or a per-axis pair (index 0 = height, index 1 =
    width). A missing ``dilation`` attribute (as on ``nn.AvgPool2d``) is
    treated as dilation 1. ``ceil_mode`` is not taken into account; the
    result is always rounded down.

    Args:
        height: Input height.
        width: Input width.
        conv2d: The pooling module (name kept for backward compatibility).

    Returns:
        Tuple ``(height, width)``.
    """
    def as_pair(value):
        # Pooling layers store their settings exactly as given: int or tuple.
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value, value)

    kernel = as_pair(conv2d.kernel_size)
    stride = as_pair(conv2d.stride)
    padding = as_pair(conv2d.padding)
    dilation = as_pair(getattr(conv2d, "dilation", 1))

    def along(size, axis):
        reach = dilation[axis] * (kernel[axis] - 1)
        return math.floor(((size + (2 * padding[axis]) - reach - 1) / stride[axis]) + 1)

    return (along(height, 0), along(width, 1))

def convT_output_shape(height, width, convT: nn.Module):
    """Compute the output size of a 2D transposed convolution.

    Reads ``stride``, ``padding``, ``dilation``, ``kernel_size`` and
    ``output_padding`` from ``convT`` as per-axis pairs (index 0 = height,
    index 1 = width).

    Returns:
        Tuple ``(height, width, out_channels)``.
    """
    out_h, out_w = (
        (size - 1) * convT.stride[axis]
        - (2 * convT.padding[axis])
        + (convT.dilation[axis] * (convT.kernel_size[axis] - 1))
        + convT.output_padding[axis]
        + 1
        for axis, size in enumerate((height, width))
    )

    return (out_h, out_w, convT.out_channels)


def output_shapes(net : U_Net, height, width):
    """Print the tensor size after every layer of ``net`` for a given input size.

    Walks encoders, bridge, decoders and the output layer in order, feeding
    each computed (height, width) into the next shape calculation. Nothing is
    returned; the report goes to stdout.
    """
    separator = "----------------------------------------------"
    cur_h, cur_w = height, width

    print(separator)
    print(f"CNN with depth = {len(net.encoders) + 1}")
    print(separator)

    level = 0
    for level, encoder in enumerate(net.encoders, start=1):
        print(f"||Layer {level} - encoders||")

        # Every non-ReLU entry of the stack is treated as a convolution.
        for module in encoder.layers:
            if isinstance(module, nn.ReLU):
                continue
            cur_h, cur_w, channels = conv2d_output_shape(cur_h, cur_w, module)
            print(f"Conv2D -- H = {cur_h} | W = {cur_w} | C = {channels}")

        cur_h, cur_w = pool_output_shape(cur_h, cur_w, encoder.pooling)
        print(f"Pooling -- H = {cur_h} | W = {cur_w} | C = {channels}")

        print()

    # The bridge is numbered one past the last encoder.
    level += 1

    print(f"||Layer {level} - bridge||")
    cur_h, cur_w, channels = conv2d_output_shape(cur_h, cur_w, net.bridge)
    print(f"H = {cur_h} | W = {cur_w} | C = {channels}")
    print(separator)

    for decoder in net.decoders:
        print(f"||Layer {level} - decoders||")

        cur_h, cur_w, channels = convT_output_shape(cur_h, cur_w, decoder.convT)
        print(f"Transpose -- H = {cur_h} | W = {cur_w} | C = {channels}")

        for module in decoder.layers:
            if isinstance(module, nn.ReLU):
                continue
            cur_h, cur_w, channels = conv2d_output_shape(cur_h, cur_w, module)
            print(f"Conv2D -- H = {cur_h} | W = {cur_w} | C = {channels}")

        print()
        level -= 1

    print(f"||Layer {level} - output||")
    cur_h, cur_w, channels = conv2d_output_shape(cur_h, cur_w, net.output_layer)
    print(f"H = {cur_h} | W = {cur_w} | C = {channels}")
    print(separator)