<a href="https://colab.research.google.com/github/shahroz1610/cowin-noti/blob/master/Unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

In [23]:
class Unet(nn.Module):
    """U-Net style encoder/decoder built from unpadded 3x3 convolutions.

    The encoder applies four double-convolution stages, each followed by
    2x2 max pooling.  The decoder upsamples with 2x2, stride-2 transposed
    convolutions and concatenates every upsampled map with the matching
    encoder output (centre-cropped, because the unpadded convolutions make
    the encoder maps larger than the decoder maps).

    Because no padding is used, the output is spatially smaller than the
    input: a 572x572 input gives a 388x388 output and 188x188 gives 4x4.

    Args:
        in_channel: Number of channels in the input image.
        out_channel: Number of channels produced by the final 1x1 conv.
        base_channels: Width of the first encoder stage; each deeper stage
            doubles it.  The default of 64 yields 64/128/256/512/1024.
    """

    def __init__(self, in_channel, out_channel, base_channels=64):
        super().__init__()

        self.in_channel = in_channel
        self.out_channel = out_channel
        self.base_channels = base_channels

        # Channel widths of the four encoder stages and the bottleneck.
        c1, c2, c3, c4, c5 = (base_channels * 2 ** i for i in range(5))

        # MaxPool2d has no parameters, so one instance serves every stage.
        self.mp1 = nn.MaxPool2d(kernel_size=2)
        self.gd1 = self.going_down(in_channel, c1)
        self.gd2 = self.going_down(c1, c2)
        self.gd3 = self.going_down(c2, c3)
        self.gd4 = self.going_down(c3, c4)
        self.btm = self.bottom(c4, c5, c4)
        # Each decoder stage receives the upsampled map concatenated with a
        # skip of equal width, hence twice the channels it emits upstream.
        self.gp4 = self.going_up(c5, c4, c3)
        self.gp3 = self.going_up(c4, c3, c2)
        self.gp2 = self.going_up(c3, c2, c1)
        self.ll = self.last(c2, c1, out_channel)

    def going_down(self, in_channels, out_channels, kernel_size=3):
        """Return an encoder stage: two unpadded conv + ReLU pairs."""
        block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size),
            nn.ReLU(True),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size),
            nn.ReLU(True)
        )
        return block

    def going_up(self, in_channels, mid_channels, out_channels, kernel_size=3):
        """Return a decoder stage: two conv + ReLU pairs, then a 2x upsample.

        ``kernel_size`` sets the two convolutions (as in ``going_down``).
        The transposed convolution is always 2x2 with stride 2 so that it
        doubles the spatial size; with stride 1 it would not upsample and
        the result could never line up with the next skip connection.
        """
        block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=kernel_size),
            nn.ReLU(True),
            nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=kernel_size),
            nn.ReLU(True),
            nn.ConvTranspose2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=2, stride=2)
        )
        return block

    def bottom(self, in_channels, mid_channels, out_channels):
        """Return the bottleneck: two 3x3 conv + ReLU pairs, then a 2x upsample."""
        block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3),
            nn.ReLU(True),
            nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=3),
            nn.ReLU(True),
            # stride=2 is required for the 2x2 kernel to double the map size.
            nn.ConvTranspose2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=2, stride=2)
        )
        return block

    def last(self, in_channels, mid_channels, out_channels):
        """Return the output head: two 3x3 conv + ReLU pairs and a 1x1 conv."""
        block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=mid_channels, kernel_size=3),
            nn.ReLU(True),
            nn.Conv2d(in_channels=mid_channels, out_channels=mid_channels, kernel_size=3),
            nn.ReLU(True),
            nn.Conv2d(in_channels=mid_channels, out_channels=out_channels, kernel_size=1)
        )
        return block

    def concat(self, up, down):
        """Concatenate a decoder map with an encoder skip along channels.

        Args:
            up: 4-D decoder tensor; its height and width set the target size.
            down: 4-D encoder tensor, at least as large as ``up`` spatially.

        Returns:
            ``up`` followed by the centre crop of ``down`` on dimension 1.

        Raises:
            ValueError: If ``down`` is spatially smaller than ``up``.
        """
        target_h, target_w = up.size(2), up.size(3)
        extra_h = down.size(2) - target_h
        extra_w = down.size(3) - target_w
        if extra_h < 0 or extra_w < 0:
            raise ValueError(
                "skip tensor %s is smaller than upsampled tensor %s"
                % (tuple(down.shape), tuple(up.shape))
            )
        # Centre-crop the skip so both tensors share height and width.
        top, left = extra_h // 2, extra_w // 2
        down = down[:, :, top:top + target_h, left:left + target_w]
        return torch.cat((up, down), 1)

    def forward(self, x):
        """Run the network on a 4-D input and return the final feature map."""
        # Encoder: keep every pre-pooling output for the skip connections.
        eb1 = self.gd1(x)
        em1 = self.mp1(eb1)
        eb2 = self.gd2(em1)
        em2 = self.mp1(eb2)
        eb3 = self.gd3(em2)
        em3 = self.mp1(eb3)
        eb4 = self.gd4(em3)
        em4 = self.mp1(eb4)
        btm = self.btm(em4)

        # Decoder: merge with the matching encoder output, then refine.
        c4 = self.concat(btm, eb4)
        db4 = self.gp4(c4)

        c3 = self.concat(db4, eb3)
        db3 = self.gp3(c3)

        c2 = self.concat(db3, eb2)
        db2 = self.gp2(c2)

        c1 = self.concat(db2, eb1)
        res = self.ll(c1)

        return res


In [24]:
# Build the network for 1 input channel and 2 output channels.
unet = Unet(in_channel=1, out_channel=2)

In [25]:
# Echo the model so the notebook prints its module-by-module repr.
unet

Unet(
  (mp1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (gd1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
  (gd2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
  (gd3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
  (gd4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU(inplace=True)
  )
  (btm): Sequential(
    (0): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace