<a href="https://colab.research.google.com/github/taravatp/Future_MRI_Image_Generation/blob/main/models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
from torchsummary import summary

In [3]:
class Downconv(nn.Module):
  """Double 3D convolution block: (Conv3d -> BatchNorm3d -> ReLU) applied twice.

  With the default kernel_size=3, stride=1 and padding=1 both convolutions are
  "same" convolutions, so only the number of channels changes.

  Args:
    in_channels: number of channels of the incoming tensor.
    out_channels: number of channels produced by each of the two convolutions.
    kernel_size: kernel size passed to both Conv3d layers.
    stride: stride passed to both Conv3d layers.
    padding: padding passed to both Conv3d layers.
  """

  def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
    super().__init__()
    # Author's note (translated): the conv bias could be set to False here,
    # since a batch norm is used as well.
    conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding)

    stages = []
    # First stage maps in_channels -> out_channels, second keeps out_channels.
    for stage_in in (in_channels, out_channels):
      stages.append(nn.Conv3d(in_channels=stage_in, out_channels=out_channels, **conv_options))
      stages.append(nn.BatchNorm3d(out_channels))
      stages.append(nn.ReLU())

    self.layers = nn.Sequential(*stages)

  def forward(self, x):
    """Apply both conv stages to x and return the result."""
    return self.layers(x)

In [29]:
class generator(nn.Module):
  """3D U-Net style generator built from Downconv blocks.

  The encoder applies one Downconv per entry of `features`, each followed by
  a 2x max-pool. A bottleneck Downconv doubles the deepest channel count.
  The decoder mirrors the encoder: every level upsamples with a stride-2
  transposed convolution, concatenates the matching encoder output (skip
  connection) along the channel axis and applies another Downconv. A final
  1x1x1 convolution maps to `out_channels`.

  Args:
    in_channels: channels of the input volume.
    out_channels: channels of the generated volume.
    features: channel width of each encoder level, shallowest first.

  Raises:
    ValueError: if `features` is empty.
  """

  def __init__(self, in_channels=3, out_channels=3, features=(32, 64, 128, 256)):
    super().__init__()

    self.features = list(features)
    if not self.features:
      raise ValueError("features must contain at least one channel width")

    self.downs = nn.ModuleList()
    self.ups = nn.ModuleList()
    self.pool = nn.MaxPool3d(kernel_size=2, stride=2)

    # Encoder: with the defaults (3,32) - (32,64) - (64,128) - (128,256).
    for feature in self.features:
      self.downs.append(Downconv(in_channels=in_channels, out_channels=feature))
      in_channels = feature

    # Bottleneck doubles the deepest width: (256,512) with the defaults.
    self.bottleneck = Downconv(self.features[-1], self.features[-1] * 2)

    # Decoder: self.ups alternates [upsample, double conv] per level.
    # The channel count entering each upsample is tracked explicitly so that
    # widths which do not double from level to level still line up; for the
    # default widths this equals feature*2 at every level.
    up_in_channels = self.features[-1] * 2
    for feature in reversed(self.features):
      self.ups.append(
          # kernel_size=2, stride=2 doubles depth, height and width.
          nn.ConvTranspose3d(in_channels=up_in_channels, out_channels=feature, kernel_size=2, stride=2)
      )
      # Input is the skip connection (feature) concatenated with the upsampled tensor (feature).
      self.ups.append(Downconv(in_channels=feature * 2, out_channels=feature))
      up_in_channels = feature

    self.final_conv = nn.Conv3d(self.features[0], out_channels, kernel_size=1)

  def forward(self, x):
    """Run the encoder/decoder on x and return the generated volume.

    Args:
      x: tensor laid out as (N, C, D, H, W). Each spatial dimension has to be
        at least 2**len(features) so that no pooling step produces an empty
        volume.

    Returns:
      Tensor with `out_channels` channels and the same spatial size as x.
    """
    skip_connections = []
    for down in self.downs:
      x = down(x)
      skip_connections.append(x)
      x = self.pool(x)

    x = self.bottleneck(x)

    # Walk the decoder from the deepest level up, pairing each level with the
    # encoder output of the same resolution.
    for level, skip_connection in enumerate(reversed(skip_connections)):
      upsample = self.ups[2 * level]
      double_conv = self.ups[2 * level + 1]

      x = upsample(x)
      # Max-pooling floors odd sizes, so the upsampled tensor can be one voxel
      # short of the skip connection; resize so the concatenation is valid.
      # For sizes divisible by 2**len(features) this branch never runs.
      if x.shape[2:] != skip_connection.shape[2:]:
        x = nn.functional.interpolate(x, size=skip_connection.shape[2:])

      concat_skip = torch.cat((skip_connection, x), dim=1)
      x = double_conv(concat_skip)

    return self.final_conv(x)

In [30]:
if __name__ == "__main__":
  # Smoke test: push a dummy volume through the generator and print the
  # input and output shapes.
  model = generator()
  # Named dummy_input rather than input so the builtin input() is not shadowed.
  dummy_input = torch.zeros((1, 3, 64, 64, 64)) #(N, c, D , H, W)
  print('input shape:',dummy_input.shape)
  # Only the output shape is inspected, so skip building the autograd graph.
  with torch.no_grad():
    out = model(dummy_input)
  print('output shape:',out.shape)

input shape: torch.Size([1, 3, 64, 64, 64])
output shape: torch.Size([1, 3, 64, 64, 64])
