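# Implementation of the U-Net architecture from the paper "U-Net: Convolutional Networks for Biomedical Image Segmentation" (Ronneberger, Fischer & Brox, 2015):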

https://arxiv.org/pdf/1505.04597.pdf

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF


In [8]:
class doubleconvlayer(nn.Module):
  """Two (3x3 conv -> BatchNorm -> ReLU) stages: the basic U-Net building block.

  Both convolutions use kernel_size=3, stride=1 and padding=1, so the spatial
  size is preserved: an (N, input_channels, H, W) batch becomes
  (N, output_channels, H, W). (Unpadded 3x3 convolutions would instead shrink
  each side by 2 per conv, e.g. 512 -> 510 -> 508; padding=1 avoids that.)

  Args:
    input_channels: channels of the incoming tensor (dim 1 of an NCHW batch).
    output_channels: channels produced by both convolutions.
  """

  def __init__(self, input_channels, output_channels):
    super().__init__()

    self.input_channels = input_channels
    self.output_channels = output_channels

    # NOTE(review): each conv keeps its bias even though the BatchNorm right
    # after it re-centres the activations (so the bias is redundant). It is
    # kept because dropping it would change the state_dict keys and break
    # loading existing checkpoints.
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels=input_channels, out_channels=output_channels,
                  kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(output_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels=output_channels, out_channels=output_channels,
                  kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(output_channels),
        nn.ReLU(),
    )

  def forward(self, x):
    """Apply both conv/BN/ReLU stages to ``x`` (spatial size unchanged)."""
    return self.conv(x)


# Smoke test: padding=1 keeps the spatial size, so a (5, 3, 512, 512) input
# yields a (5, 64, 512, 512) output.
m=doubleconvlayer(3,64)
x=torch.randn(5,3,512,512)
y=m(x)
print(x.shape)
print(y.shape)



torch.Size([5, 3, 512, 512])
torch.Size([5, 64, 512, 512])


In [9]:

class UNET(nn.Module):
  """U-Net encoder/decoder with skip connections.

  Encoder: each stage is a ``doubleconvlayer`` whose output is saved as a skip
  connection, followed by 2x2 max-pooling. Bottleneck: a ``doubleconvlayer``
  that doubles the deepest feature width. Decoder: each stage upsamples with a
  2x2 stride-2 transposed convolution, concatenates the matching skip
  connection along the channel axis, and fuses the result with another
  ``doubleconvlayer``. A final 1x1 convolution maps to ``output_channels``.

  All 3x3 convolutions are padded, so the output has the same height/width as
  the input. If H or W is not divisible by ``2 ** len(features)``, the
  upsampled tensor is bilinearly resized to the skip connection's size before
  concatenation (every pooled stage must still be at least 1x1).

  Args:
    input_channels: channels of the input tensor (dim 1 of an NCHW batch).
    output_channels: channels produced by the final 1x1 convolution.
    features: channel widths of the encoder stages, shallowest first. The
      widths need not double from one stage to the next.

  Raises:
    ValueError: if ``features`` is empty.
  """

  def __init__(self, input_channels=3, output_channels=1, features=(64, 128, 256, 512)):
    super().__init__()
    features = list(features)
    if not features:
      raise ValueError("features must contain at least one channel width")

    # Keep the construction order (downs, ups, bottleneck, final_layer): it is
    # the order in which random weight initialisation consumes the RNG, so
    # reordering would change the initial weights for a given seed.
    self.ups = nn.ModuleList()
    self.downs = nn.ModuleList()
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    # Encoder: one double conv per stage; pooling is applied in forward().
    in_channels = input_channels
    for feature in features:
      self.downs.append(doubleconvlayer(in_channels, feature))
      in_channels = feature

    # Decoder, stored as alternating [upsample, fuse] entries in self.ups.
    # The transposed conv takes whatever width the previous stage produced:
    # the bottleneck's features[-1] * 2 first, then each decoder stage's own
    # output width.
    up_in_channels = features[-1] * 2
    for feature in reversed(features):
      self.ups.append(nn.ConvTranspose2d(up_in_channels, feature, kernel_size=2, stride=2))
      # Fuse input = skip connection (feature) + upsampled tensor (feature).
      self.ups.append(doubleconvlayer(feature * 2, feature))
      up_in_channels = feature

    self.bottleneck = doubleconvlayer(features[-1], features[-1] * 2)

    self.final_layer = nn.Conv2d(features[0], output_channels, kernel_size=1)

  def forward(self, x):
    """Run the network on an NCHW batch ``x``.

    Returns the raw (un-activated) output of the final 1x1 convolution, of
    shape (N, output_channels, H, W) for an input of shape (N, C, H, W).
    """
    skip_connections = []
    for down in self.downs:
      x = down(x)
      skip_connections.append(x)  # saved before pooling, at the stage's resolution
      x = self.pool(x)

    x = self.bottleneck(x)

    # Walk the decoder deepest-first, pairing each (upsample, fuse) with the
    # encoder output of the same depth.
    upsamplers = self.ups[0::2]
    fusers = self.ups[1::2]
    for upsample, fuse, skip in zip(upsamplers, fusers, reversed(skip_connections)):
      x = upsample(x)
      if x.shape[2:] != skip.shape[2:]:
        # Max-pooling floors odd sizes, so after the 2x upsample the tensor
        # can be a row/column smaller than the skip connection.
        x = F.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=False)
      x = fuse(torch.cat((skip, x), dim=1))

    return self.final_layer(x)




    





# Smoke test: the U-Net keeps the input's spatial size and emits
# output_channels (1 by default) channels.
m, x = UNET(), torch.randn(5, 3, 512, 512)
y = m(x)
print(*(t.shape for t in (x, y)), sep="\n")



torch.Size([5, 3, 512, 512])
torch.Size([5, 1, 512, 512])
