# Neural network models

In [264]:
import numpy as np
import torch
import torch.nn as nn
import math
from torchvision import models
from collections import OrderedDict

## U-Net Network
num_classes = number of classes expected in the output

In [265]:
def conv2d_layer(channels_in, channels_out, kernel_size, padding, stride, dilation):
    """Build one convolution step: a Conv2d followed by a ReLU.

    Args:
        channels_in: Number of input channels of the convolution.
        channels_out: Number of output channels of the convolution.
        kernel_size: Kernel size forwarded to nn.Conv2d.
        padding: Padding forwarded to nn.Conv2d.
        stride: Stride forwarded to nn.Conv2d.
        dilation: Dilation forwarded to nn.Conv2d.

    Returns:
        An nn.ModuleList holding the convolution and then the activation.
    """
    convolution = nn.Conv2d(
        channels_in,
        channels_out,
        kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
    )
    return nn.ModuleList([convolution, nn.ReLU()])

In [266]:
class Encoder_Block(nn.Module):
    def __init__(self, channels_in, channels_out, params: dict):
        super().__init__()

        # Config parameters
        try:
            block_width = params['block_width']
            kernel_size = params['kernel_size']
            padding = params['padding']
            stride = params['stride']
            dilation = params['dilation']
            pool_type = params['pool_type']
            pool_size = params['pool_kernel_size']
        except KeyError as e:
            raise Exception(f'Parameter "{e.args[0]}" NOT found!')
        
        # Add first Conv2D layer
        self.layers = nn.ModuleList()
        self.layers.extend(conv2d_layer(channels_in, channels_out, kernel_size, padding, stride, dilation))

        # Add other Conv2D layers
        for _ in range(block_width - 1):
            self.layers.extend(conv2d_layer(channels_out, channels_out, kernel_size, padding, stride, dilation))

        # Add pooling layer
        match pool_type:
            case 'max':
                self.pooling = nn.MaxPool2d(kernel_size=pool_size, stride=pool_size)
            case 'avg':
                self.pooling = nn.AvgPool2d(kernel_size=pool_size, stride=pool_size)


    def forward(self, x):
        # Convolution
        for layer in self.layers:
            x = layer(x)

        # Pooling
        return (self.pooling(x), x)

In [267]:
class Decoder_Block(nn.Module):
    """Decoder stage: transposed convolution, concatenation with encoder features, convolutions.

    Args:
        channels_in: Number of channels of the tensor entering the block. The
            transposed convolution halves it, and the first Conv2d expects
            ``channels_in`` channels again after the concatenation.
        channels_out: Number of channels produced by every Conv2d of the block.
        padding_t: ``output_padding`` of the transposed convolution.
        params: Configuration dict. It must provide the keys 'block_width'
            (number of Conv2d + ReLU pairs), 'kernel_size', 'padding',
            'stride', 'dilation' and 'pool_kernel_size' (used as kernel size
            and stride of the transposed convolution).

    Raises:
        Exception: If a required key is missing from ``params``.
    """

    def __init__(self, channels_in, channels_out, padding_t, params: dict):
        super().__init__()

        # Config parameters
        try:
            block_width = params['block_width']
            kernel_size = params['kernel_size']
            padding = params['padding']
            stride = params['stride']
            dilation = params['dilation']
            pool_size = params['pool_kernel_size']
        except KeyError as e:
            raise Exception(f'Parameter "{e.args[0]}" NOT found!') from e

        # Add transpose convolution
        self.convT = nn.ConvTranspose2d(
            in_channels=channels_in,
            out_channels=math.floor(channels_in / 2),
            kernel_size=pool_size,
            stride=pool_size,
            output_padding=padding_t,
        )

        # Add first Conv2D layer
        self.layers = nn.ModuleList()
        self.layers.extend(conv2d_layer(channels_in, channels_out, kernel_size, padding, stride, dilation))

        # Add other Conv2D layers
        for _ in range(block_width - 1):
            self.layers.extend(conv2d_layer(channels_out, channels_out, kernel_size, padding, stride, dilation))

    def forward(self, x, enc_features):
        """Upsample ``x``, concatenate it with ``enc_features`` and convolve.

        Args:
            x: Tensor coming from the previous (deeper) stage.
            enc_features: Convolution output of the matching encoder stage.

        Returns:
            The output of the last Conv2d + ReLU pair of the block.
        """
        # Transposed convolution
        x = self.convT(x)

        # Copy and crop features: centre-crop the encoder features to the
        # spatial size of the upsampled tensor. The offsets are zero when the
        # sizes already match, so the crop is then a no-op.
        height, width = x.size(-2), x.size(-1)
        top = max(enc_features.size(-2) - height, 0) // 2
        left = max(enc_features.size(-1) - width, 0) // 2
        enc_features = enc_features[..., top:top + height, left:left + width]

        # Concatenate along the channel dimension
        x = torch.cat([enc_features, x], dim=1)

        # Convolution
        for layer in self.layers:
            x = layer(x)

        return x

In [268]:
class U_Net(nn.Module):
    """U-Net made of encoder blocks, a bridge convolution and decoder blocks.

    Args:
        channels_in: Number of channels of the input tensor.
        num_classes: Number of output channels of the final 1x1 convolution
            (one logit map per class).
        params: Configuration dict. It must provide 'channel_mul' (factor
            applied to the channel count from one level to the next),
            'network_depth' (number of encoder and of decoder blocks),
            'channels_out_init' (output channels of the first encoder) and
            'padding_convT' (one ``output_padding`` value per decoder, in
            decoding order). The same dict is passed on to Encoder_Block and
            Decoder_Block, so it must also hold the keys they require.

    Raises:
        Exception: If a required key is missing from ``params``.
        IndexError: If 'padding_convT' has fewer than 'network_depth' entries.
    """

    def __init__(self, channels_in, num_classes, params: dict):
        super().__init__()

        # Config parameters
        try:
            channel_mul = params['channel_mul']
            depth = params['network_depth']
            channels_out = params['channels_out_init']
            padding_convT = params['padding_convT']
        except KeyError as e:
            raise Exception(f'Parameter "{e.args[0]}" NOT found!') from e

        # Work on a copy: consuming the caller's list would leave ``params``
        # unusable for building a second model.
        paddings_convT = list(padding_convT)
        if len(paddings_convT) < depth:
            raise IndexError(
                f'"padding_convT" needs {depth} entries, got {len(paddings_convT)}'
            )

        # Create encoder
        self.encoders = nn.ModuleList()

        for _ in range(depth):
            self.encoders.append(Encoder_Block(channels_in, channels_out, params))
            channels_in = channels_out
            channels_out = math.floor(channels_out * channel_mul)

        # Create bridge as Conv2D layer: only the first convolution of a
        # throwaway encoder block is kept (no activation, no pooling).
        last_encoder = Encoder_Block(channels_in, channels_out, params)
        self.bridge = last_encoder.layers[0]

        channels_in = channels_out
        channels_out = math.floor(channels_in / channel_mul)

        # Create decoder
        self.decoders = nn.ModuleList()

        for paddingT in paddings_convT[:depth]:
            self.decoders.append(Decoder_Block(channels_in, channels_out, paddingT, params))
            channels_in = channels_out
            channels_out = math.floor(channels_out / channel_mul)

        # Create output layer (1x1 convolution, return logits)
        self.output_layer = nn.Conv2d(
            in_channels=channels_in,
            out_channels=num_classes,
            kernel_size=1,
        )

    def forward(self, x):
        """Run the network and return the logits of the output layer."""
        # Features created during encoding
        features = []

        # Encoding
        for enc in self.encoders:
            x, conv_output = enc(x)
            features.append(conv_output)

        # Bridge between encoder and decoder
        x = self.bridge(x)

        # Decoding: the most recent encoder features are consumed first
        for dec in self.decoders:
            enc_output = features.pop()
            x = dec(x, enc_output)

        # Return logits
        return self.output_layer(x)

In [269]:
def conv2d_output_shape(height, width, conv2d: nn.Module):
    """Compute the output shape of a 2D convolution layer.

    Args:
        height: Input height.
        width: Input width.
        conv2d: Layer exposing per-axis ``padding``, ``dilation``,
            ``kernel_size`` and ``stride`` sequences plus ``out_channels``.

    Returns:
        A tuple ``(height, width, out_channels)``.
    """
    def _axis_size(size_in, axis):
        # Extent covered by the dilated kernel beyond its first element
        kernel_span = conv2d.dilation[axis] * (conv2d.kernel_size[axis] - 1)
        padded = size_in + 2 * conv2d.padding[axis]
        return math.floor((padded - kernel_span - 1) / conv2d.stride[axis] + 1)

    rows = _axis_size(height, 0)
    cols = _axis_size(width, 1)
    return (rows, cols, conv2d.out_channels)

def pool_output_shape(height, width, conv2d: nn.Module):
    """Compute the output height and width of a 2D pooling layer.

    Works for layers without a ``dilation`` attribute (``nn.AvgPool2d``) and
    for layers whose kernel size, stride or padding are stored either as a
    single int or as a ``(height, width)`` pair. The result is always rounded
    down; a layer's ``ceil_mode`` setting is not taken into account.

    Args:
        height: Input height.
        width: Input width.
        conv2d: The pooling layer (the parameter keeps its historical name).

    Returns:
        A tuple ``(height, width)``.
    """
    def _pair(value):
        # Normalise an int or a 2-sequence to a (height, width) pair
        return tuple(value) if isinstance(value, (tuple, list)) else (value, value)

    kernel_size = _pair(conv2d.kernel_size)
    stride = _pair(conv2d.stride)
    padding = _pair(conv2d.padding)
    # Average pooling has no dilation; a dilation of 1 leaves the formula unchanged
    dilation = _pair(getattr(conv2d, 'dilation', 1))

    def _axis_size(size_in, axis):
        kernel_span = dilation[axis] * (kernel_size[axis] - 1)
        return math.floor((size_in + 2 * padding[axis] - kernel_span - 1) / stride[axis] + 1)

    return (_axis_size(height, 0), _axis_size(width, 1))

def convT_output_shape(height, width, convT: nn.Module):
    """Compute the output shape of a 2D transposed convolution layer.

    Args:
        height: Input height.
        width: Input width.
        convT: Layer exposing per-axis ``stride``, ``padding``, ``dilation``,
            ``kernel_size`` and ``output_padding`` sequences plus
            ``out_channels``.

    Returns:
        A tuple ``(height, width, out_channels)``.
    """
    def _axis_size(size_in, axis):
        return (
            (size_in - 1) * convT.stride[axis]
            - 2 * convT.padding[axis]
            + convT.dilation[axis] * (convT.kernel_size[axis] - 1)
            + convT.output_padding[axis]
            + 1
        )

    rows = _axis_size(height, 0)
    cols = _axis_size(width, 1)
    return (rows, cols, convT.out_channels)


def output_shapes(net : U_Net, image):
    """Print the spatial size and channel count after each layer of ``net``.

    The sizes are computed analytically from the layer hyper-parameters with
    conv2d_output_shape, pool_output_shape and convT_output_shape; the network
    itself is not executed.

    Args:
        net: The network whose encoders, bridge, decoders and output layer
            are reported.
        image: Object whose ``size(-2)`` and ``size(-1)`` give the input
            height and width.
    """
    separator = "----------------------------------------------"
    rows = image.size(-2)  # height
    cols = image.size(-1)  # width
    depth = len(net.encoders) + 1

    print(separator)
    print(f"CNN with depth = {depth}")
    print(separator)

    # Encoder levels are numbered from 1 going down
    for level, encoder in enumerate(net.encoders, start=1):
        print(f"||Layer {level} - encoders||")

        for module in encoder.layers:
            if isinstance(module, nn.ReLU):
                continue
            rows, cols, channels = conv2d_output_shape(rows, cols, module)
            print (f"Conv2D -- H = {rows} | W = {cols} | C = {channels}")

        rows, cols = pool_output_shape(rows, cols, encoder.pooling)
        # Pooling keeps the channel count of the last convolution
        print (f"Pooling -- H = {rows} | W = {cols} | C = {channels}")

        print()

    # The bridge sits one level below the last encoder
    level = depth
    print(f"||Layer {level} - bridge||")
    rows, cols, channels = conv2d_output_shape(rows, cols, net.bridge)
    print (f"H = {rows} | W = {cols} | C = {channels}")
    print(separator)

    # Decoder levels count back up from the bridge level
    for decoder in net.decoders:
        print(f"||Layer {level} - decoders||")

        rows, cols, channels = convT_output_shape(rows, cols, decoder.convT)
        print (f"Transpose -- H = {rows} | W = {cols} | C = {channels}")

        for module in decoder.layers:
            if isinstance(module, nn.ReLU):
                continue
            rows, cols, channels = conv2d_output_shape(rows, cols, module)
            print (f"Conv2D -- H = {rows} | W = {cols} | C = {channels}")

        print()
        level -= 1

    print(f"||Layer {level} - output||")
    rows, cols, channels = conv2d_output_shape(rows, cols, net.output_layer)
    print (f"H = {rows} | W = {cols} | C = {channels}")
    print(separator)

In [274]:
from sklearn.model_selection import train_test_split

# Execute the configuration notebook in this namespace.
# NOTE(review): presumably it defines config_unet and config_to_dict, which
# the commented-out lines at the end of this cell refer to - confirm.
%run config.ipynb

# Load data from .npy files
x = np.load('images_array.npy')  # image array
y = np.load('inst_maps_replaced_resized.npy')  # presumably the instance maps used as labels - TODO confirm

# Split the dataset into train, validation, and test sets:
# 20% is held out for testing, then 25% of the remaining 80% (20% of the
# total) becomes the validation set, leaving 60% for training.
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.25, random_state=42) 

# Show the training images (the array printed below this cell)
print(X_train)

#net = U_Net(2, 10, config_to_dict(config_unet))

#output_shapes(net, x)

#out = net(x)

[[[[111  27  98]
   [114  29 108]
   [114  23 119]
   ...
   [213  98 161]
   [248 142 182]
   [251 170 199]]

  [[123  31 110]
   [127  34 118]
   [127  30 130]
   ...
   [239 147 176]
   [253 182 201]
   [254 197 241]]

  [[140  34 135]
   [138  32 125]
   [140  35 123]
   ...
   [251 173 189]
   [253 202 225]
   [253 210 247]]

  ...

  [[237 131 195]
   [247 151 199]
   [225 135 206]
   ...
   [189 104 207]
   [169  84 170]
   [170  76 176]]

  [[232 134 197]
   [251 165 209]
   [243 164 215]
   ...
   [179  91 203]
   [167  80 166]
   [170  78 157]]

  [[216 120 203]
   [238 154 213]
   [250 177 216]
   ...
   [171  84 197]
   [154  66 155]
   [177  83 166]]]


 [[[233 221 231]
   [233 221 231]
   [233 221 231]
   ...
   [238 226 236]
   [238 226 235]
   [237 226 234]]

  [[233 221 231]
   [233 221 231]
   [233 221 231]
   ...
   [237 225 235]
   [237 226 234]
   [237 226 234]]

  [[233 221 231]
   [233 221 231]
   [233 221 231]
   ...
   [240 229 238]
   [237 225 235]
   [235 224