<a href="https://colab.research.google.com/github/winchelo/Projet-GIF4001-GIF7005/blob/main/RCMUNETPytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
def double_conv(in_chnl, out_chnl, use_batch_norm=True):
    """Return two stacked 3x3 'same'-padded conv stages as an ``nn.Sequential``.

    Each stage is ``Conv2d -> [BatchNorm2d] -> ReLU(inplace=True)``; the
    BatchNorm layer is only inserted when ``use_batch_norm`` is true.  The
    first conv maps ``in_chnl -> out_chnl``, the second ``out_chnl -> out_chnl``,
    so spatial size is preserved and only the channel count changes.

    Args:
        in_chnl: Number of input channels.
        out_chnl: Number of output channels of both convolutions.
        use_batch_norm: Whether to add a BatchNorm2d after each convolution.
    """
    stages = []
    for src, dst in ((in_chnl, out_chnl), (out_chnl, out_chnl)):
        stages.append(nn.Conv2d(src, dst, kernel_size=3, padding='same'))
        if use_batch_norm:
            stages.append(nn.BatchNorm2d(dst))
        stages.append(nn.ReLU(inplace=True))
    return nn.Sequential(*stages)

In [None]:
class Dense4NN(nn.Module):
    """Four-layer MLP that encodes a 1D feature vector as a 512-channel 1x1 map.

    Each sample is flattened to ``nb_chnl`` features and passed through Linear
    layers of widths 128 -> 256 -> 256 -> 512, with ReLU after the first three
    and no activation after the last.  The result has shape ``(B, 512, 1, 1)``
    so it can be concatenated channel-wise with a convolutional feature map.

    Args:
        nb_chnl: Number of features per sample once all non-batch dimensions
            are flattened.
    """

    def __init__(self, nb_chnl=41):
        super().__init__()

        self.fc1 = nn.Linear(nb_chnl, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 256)
        self.fc4 = nn.Linear(256, 512)

    def forward(self, x):
        """Encode ``x`` of shape ``(B, ...)`` with ``nb_chnl`` elements per sample.

        Returns:
            Tensor of shape ``(B, 512, 1, 1)``.
        """
        # torch.flatten copies when necessary, so non-contiguous inputs work
        # (Tensor.view would raise) and an empty batch keeps B == 0 instead of
        # making the ``-1`` dimension ambiguous.
        x = torch.flatten(x, start_dim=1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)

        # Append two singleton spatial dims: (B, 512) -> (B, 512, 1, 1).
        return x[:, :, None, None]

In [None]:
def crop_tensor(tensor, target_tensor):
    """Center-crop ``tensor`` to the spatial (H, W) size of ``target_tensor``.

    Both tensors are indexed as ``(N, C, H, W)``; only the last two dims are
    cropped.  Height and width are handled independently, so non-square
    tensors are supported.  When a size difference is odd, the extra row or
    column is dropped from the bottom/right, so the result always matches the
    target's H and W exactly (required by the subsequent ``torch.cat``).

    Args:
        tensor: Tensor to crop (typically a U-Net skip connection).
        target_tensor: Tensor whose H and W the result must match.

    Returns:
        A view of ``tensor`` with spatial size equal to ``target_tensor``'s.

    Raises:
        ValueError: If ``tensor`` is smaller than ``target_tensor`` in H or W.
    """
    target_h, target_w = target_tensor.shape[2:4]
    height, width = tensor.shape[2:4]
    if height < target_h or width < target_w:
        raise ValueError(
            f"cannot crop tensor of shape {tuple(tensor.shape)} "
            f"to spatial size {(target_h, target_w)}"
        )
    top = (height - target_h) // 2
    left = (width - target_w) // 2
    return tensor[:, :, top:top + target_h, left:left + target_w]


In [None]:
class RCMUNet(nn.Module):
    """U-Net with an auxiliary 1D-feature branch fused at the bottleneck.

    The encoder applies ``double_conv`` followed by 2x2 max-pooling four
    times, then a final ``double_conv`` (no BatchNorm) at 1/16 resolution.
    The bottleneck features are concatenated channel-wise with a ``Dense4NN``
    encoding of ``input1D``, broadcast over the bottleneck's spatial grid.
    The decoder mirrors the encoder with transposed convolutions and skip
    connections.  Two extra transposed-conv + ``double_conv`` stages then
    upsample by another factor of 4, so a 16x16 input yields a 64x64 output.

    Args:
        nb_chnl: Number of channels of ``image2D``.
        nb_class: Number of output channels produced by the final 1x1 conv.
        nb_chnl_1d: Number of features per sample in ``input1D`` after
            flattening; forwarded to ``Dense4NN``.
    """

    def __init__(self, nb_chnl=20, nb_class=1, nb_chnl_1d=41):
        super().__init__()
        self.nb_chnl = nb_chnl
        self.nb_class = nb_class
        self.nb_chnl_1d = nb_chnl_1d
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: 64 -> 128 -> 256 -> 512 channels, then a 512-channel
        # bottleneck without BatchNorm.
        self.down_conv_1 = double_conv(nb_chnl, 64)
        self.down_conv_2 = double_conv(64, 128)
        self.down_conv_3 = double_conv(128, 256)
        self.down_conv_4 = double_conv(256, 512)
        self.down_conv_5 = double_conv(512, 512, False)

        # 1D branch. It is built with 512 output features, and concatenating
        # them with the 512-channel bottleneck gives up_transp_1's 1024 inputs.
        self.dense_layer = Dense4NN(nb_chnl_1d)

        # Decoder: each transposed conv doubles H and W; the following
        # double_conv consumes [upsampled, skip] concatenated on channels.
        self.up_transp_1 = nn.ConvTranspose2d(
            in_channels=1024, out_channels=512, kernel_size=2, stride=2)
        self.up_conv_1 = double_conv(1024, 512)

        self.up_transp_2 = nn.ConvTranspose2d(
            in_channels=512, out_channels=256, kernel_size=2, stride=2)
        self.up_conv_2 = double_conv(512, 256)

        self.up_transp_3 = nn.ConvTranspose2d(
            in_channels=256, out_channels=128, kernel_size=2, stride=2)
        self.up_conv_3 = double_conv(256, 128)

        self.up_transp_4 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, kernel_size=2, stride=2)
        self.up_conv_4 = double_conv(128, 64)

        # Extra upsampling stages added on top of the standard U-Net
        # (no skip connections): each doubles H and W again.
        self.new_transp_1 = nn.ConvTranspose2d(
            in_channels=64, out_channels=64, kernel_size=2, stride=2)
        self.new_conv_1 = double_conv(64, 64)

        self.new_transp_2 = nn.ConvTranspose2d(
            in_channels=64, out_channels=64, kernel_size=2, stride=2)
        self.new_conv_2 = double_conv(64, 64)

        # 1x1 projection to one output map per class. (nb_class was previously
        # passed as kernel_size with a hard-coded single output channel.)
        self.out = nn.Conv2d(in_channels=64,
                             out_channels=nb_class,
                             kernel_size=1)

    @staticmethod
    def _up(x, skip, transp, conv):
        """Upsample ``x``, crop ``skip`` to match, concatenate on channels, refine."""
        x = transp(x)
        return conv(torch.cat([x, crop_tensor(skip, x)], dim=1))

    def forward(self, image2D, input1D):
        """Run the network.

        Args:
            image2D: Image batch ``(B, nb_chnl, H, W)``.
            input1D: Per-sample 1D features, flattened by ``Dense4NN`` to
                ``nb_chnl_1d`` values per sample.

        Returns:
            Tensor with ``nb_class`` channels and 4x the decoder's spatial size.
        """
        # Encoder; pre-pooling activations are kept as skip connections.
        skip1 = self.down_conv_1(image2D)
        skip2 = self.down_conv_2(self.max_pool_2x2(skip1))
        skip3 = self.down_conv_3(self.max_pool_2x2(skip2))
        skip4 = self.down_conv_4(self.max_pool_2x2(skip3))
        bottleneck = self.down_conv_5(self.max_pool_2x2(skip4))

        # The dense encoding has 1x1 spatial size; broadcast it over the
        # bottleneck grid so inputs of 32x32 and larger do not break torch.cat.
        # This is a no-op when the bottleneck is already 1x1.
        encoded_1d = self.dense_layer(input1D)
        encoded_1d = encoded_1d.expand(
            -1, -1, bottleneck.size(2), bottleneck.size(3))
        x = torch.cat([bottleneck, encoded_1d], dim=1)

        # Decoder with skip connections, deepest first.
        x = self._up(x, skip4, self.up_transp_1, self.up_conv_1)
        x = self._up(x, skip3, self.up_transp_2, self.up_conv_2)
        x = self._up(x, skip2, self.up_transp_3, self.up_conv_3)
        x = self._up(x, skip1, self.up_transp_4, self.up_conv_4)

        # Added transposed-conv + convolution stages.
        x = self.new_conv_1(self.new_transp_1(x))
        x = self.new_conv_2(self.new_transp_2(x))

        return self.out(x)




In [None]:
if __name__ == "__main__":
    # Smoke test: one 20-channel 16x16 image plus a 41-value 1D input.
    # The inputs are drawn before the model is built, so the RNG order is unchanged.
    batch, img_chnl, side = 1, 20, 16
    image2D = torch.rand(batch, img_chnl, side, side)
    input1D = torch.rand(batch, 41, 1, 1)
    model = RCMUNet(img_chnl, 1)
    x = model(image2D, input1D)
    print(f'output shape: {x.size()}')
    print(f'output value {x}')
    print(model)


output shape: torch.Size([1, 1, 64, 64])
output value tensor([[[[ 0.1522, -0.4444,  0.4682,  ...,  0.0316,  0.2154,  0.1440],
          [ 0.1267, -0.1600,  0.0871,  ..., -0.0169,  0.2047, -0.2992],
          [ 0.0848,  0.1891,  0.0150,  ..., -0.4224,  0.2091, -0.3283],
          ...,
          [-0.2404,  0.0228,  0.1069,  ..., -0.1632,  0.2730, -0.0371],
          [-0.0544,  0.1043, -0.0670,  ..., -0.3485,  0.0174, -0.0639],
          [ 0.0051, -0.1846, -0.2083,  ...,  0.0143, -0.0821,  0.2002]]]],
       grad_fn=<ConvolutionBackward0>)
RCMUNet(
  (max_pool_2x2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (down_conv_1): Sequential(
    (0): Conv2d(20, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 