In [1]:
import torch
import torch.nn as nn


In [10]:
class UNet(nn.Module):
    """U-Net style encoder/decoder for dense (per-pixel) prediction.

    The encoder applies four stages of two 3x3 convolutions (padding=1, so
    spatial size is preserved) each followed by a 2x2 max-pool.  A two-conv
    bottleneck follows, then four decoder stages.  Every decoder stage
    upsamples with a stride-2 transposed convolution, concatenates the
    matching encoder feature map along the channel axis and applies two more
    3x3 convolutions.  A final 1x1 convolution maps to ``output`` channels.
    No activation is applied to the final output.

    Args:
        input: Number of channels of the input tensor.  (The name shadows the
            ``input`` builtin; it is kept so existing keyword callers such as
            ``UNet(input=3, output=1)`` keep working.)
        output: Number of channels produced by the final 1x1 convolution.
        base_channels: Width of the first encoder stage.  Each deeper stage
            doubles it (64, 128, 256, 512, 1024 for the default of 64).

    Raises:
        ValueError: If ``base_channels`` is smaller than 1.
    """

    def __init__(self, input, output, base_channels=64):
        super().__init__()
        if base_channels < 1:
            raise ValueError(f"base_channels must be >= 1, got {base_channels}")

        # Channel widths per resolution level; w5 is the bottleneck width.
        w1 = base_channels
        w2, w3, w4, w5 = w1 * 2, w1 * 4, w1 * 8, w1 * 16

        # NOTE: attribute names and creation order are kept exactly as before
        # so state_dict keys, parameter order and seeded initialisation are
        # unchanged for the default configuration.

        # ---- encoder ----
        self.conv1 = nn.Conv2d(input, w1, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(w1, w1, kernel_size=3, padding=1)

        self.conv3 = nn.Conv2d(w1, w2, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(w2, w2, kernel_size=3, padding=1)

        self.conv5 = nn.Conv2d(w2, w3, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(w3, w3, kernel_size=3, padding=1)

        self.conv7 = nn.Conv2d(w3, w4, kernel_size=3, padding=1)
        self.conv8 = nn.Conv2d(w4, w4, kernel_size=3, padding=1)

        # ---- bottleneck ----
        self.conv9 = nn.Conv2d(w4, w5, kernel_size=3, padding=1)
        self.conv10 = nn.Conv2d(w5, w5, kernel_size=3, padding=1)

        # ---- decoder ----
        # The first conv of each stage takes the concatenation of the
        # upsampled map and the skip connection, hence twice the stage width.
        self.conv11 = nn.Conv2d(w5, w4, kernel_size=3, padding=1)
        self.conv12 = nn.Conv2d(w4, w4, kernel_size=3, padding=1)

        self.conv13 = nn.Conv2d(w4, w3, kernel_size=3, padding=1)
        self.conv14 = nn.Conv2d(w3, w3, kernel_size=3, padding=1)

        self.conv15 = nn.Conv2d(w3, w2, kernel_size=3, padding=1)
        self.conv16 = nn.Conv2d(w2, w2, kernel_size=3, padding=1)

        self.conv17 = nn.Conv2d(w2, w1, kernel_size=3, padding=1)
        self.conv18 = nn.Conv2d(w1, w1, kernel_size=3, padding=1)

        # ---- output head (1x1 conv) ----
        self.conv19 = nn.Conv2d(w1, output, kernel_size=1)

        # ---- learned 2x upsampling ----
        self.upsam1 = nn.ConvTranspose2d(w5, w4, kernel_size=2, stride=2)
        self.upsam2 = nn.ConvTranspose2d(w4, w3, kernel_size=2, stride=2)
        self.upsam3 = nn.ConvTranspose2d(w3, w2, kernel_size=2, stride=2)
        self.upsam4 = nn.ConvTranspose2d(w2, w1, kernel_size=2, stride=2)

        # Shared, parameter-free layers: ReLU activation and 2x downsampling.
        self.activation = nn.ReLU()
        self.maxpooling = nn.MaxPool2d(kernel_size=2, stride=2)

    @staticmethod
    def _match_spatial(x, skip):
        """Zero-pad ``x`` so its last two dims equal those of ``skip``.

        Max-pooling floors odd sizes, so after the stride-2 transposed conv
        the decoder map can be one pixel smaller than the encoder map it is
        concatenated with.  When the sizes already agree ``x`` is returned
        untouched.
        """
        dh = skip.shape[-2] - x.shape[-2]
        dw = skip.shape[-1] - x.shape[-1]
        if dh == 0 and dw == 0:
            return x
        # F.pad order for the last two dims: (left, right, top, bottom).
        return nn.functional.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def forward(self, x):
        """Run the network.

        Args:
            x: Batched input of shape (N, input, H, W).  H and W should each
                be at least 16 so that four 2x poolings leave a non-empty
                bottleneck; they no longer need to be multiples of 16.

        Returns:
            Tensor of shape (N, output, H, W) holding raw (un-activated)
            outputs of the final 1x1 convolution.
        """
        act = self.activation
        pool = self.maxpooling

        # ---- encoder: keep each pre-pool map for the skip connections ----
        skip1 = act(self.conv2(act(self.conv1(x))))
        down1 = pool(skip1)

        skip2 = act(self.conv4(act(self.conv3(down1))))
        down2 = pool(skip2)

        skip3 = act(self.conv6(act(self.conv5(down2))))
        down3 = pool(skip3)

        skip4 = act(self.conv8(act(self.conv7(down3))))
        down4 = pool(skip4)

        # ---- bottleneck ----
        x = act(self.conv10(act(self.conv9(down4))))

        # ---- decoder: upsample, align, concat skip on channel axis, refine ----
        x = self._match_spatial(self.upsam1(x), skip4)
        x = torch.cat([x, skip4], dim=1)
        x = act(self.conv12(act(self.conv11(x))))

        x = self._match_spatial(self.upsam2(x), skip3)
        x = torch.cat([x, skip3], dim=1)
        x = act(self.conv14(act(self.conv13(x))))

        x = self._match_spatial(self.upsam3(x), skip2)
        x = torch.cat([x, skip2], dim=1)
        x = act(self.conv16(act(self.conv15(x))))

        x = self._match_spatial(self.upsam4(x), skip1)
        x = torch.cat([x, skip1], dim=1)
        x = act(self.conv18(act(self.conv17(x))))

        return self.conv19(x)






# Build a network taking 3 input channels and producing 1 output channel,
# then print its layer listing.
model = UNet(3, 1)
print(model)



UNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv7): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv8): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv9): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv10): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv11): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv12): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv13): Conv2d(512, 256, kernel_size=(3, 

In [3]:
!pip install torchinfo





Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [11]:
from torchsummary import summary
import torch

# torchsummary builds its dummy input on CUDA whenever a GPU is available
# (its `device` argument defaults to "cuda"), so choose the device explicitly
# and keep the model on that same device to avoid a device-mismatch error.
device = "cuda" if torch.cuda.is_available() else "cpu"

# 3 input channels (RGB) and 1 output channel (binary segmentation).
model = UNet(input=3, output=1).to(device)

# Per-layer summary for a 3-channel 128x128 input (shape excludes the batch dim).
summary(model, (3, 128, 128), device=device)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           1,792
              ReLU-2         [-1, 64, 128, 128]               0
            Conv2d-3         [-1, 64, 128, 128]          36,928
              ReLU-4         [-1, 64, 128, 128]               0
         MaxPool2d-5           [-1, 64, 64, 64]               0
            Conv2d-6          [-1, 128, 64, 64]          73,856
              ReLU-7          [-1, 128, 64, 64]               0
            Conv2d-8          [-1, 128, 64, 64]         147,584
              ReLU-9          [-1, 128, 64, 64]               0
        MaxPool2d-10          [-1, 128, 32, 32]               0
           Conv2d-11          [-1, 256, 32, 32]         295,168
             ReLU-12          [-1, 256, 32, 32]               0
           Conv2d-13          [-1, 256, 32, 32]         590,080
             ReLU-14          [-1, 256,