In [45]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as functional

In [2]:
def conv3x3(in_channels: int, out_channels: int) -> nn.Conv2d:
    """Create an unpadded 3x3 2-D convolution.

    With ``padding=0`` and the default stride of 1, every application of the
    returned layer shrinks each spatial dimension by 2 pixels.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels (filters) the layer produces.

    Returns:
        A freshly initialised ``nn.Conv2d`` with ``kernel_size=3``.
    """
    unpadded_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)
    return unpadded_conv


def max_pool_2d() -> nn.MaxPool2d:
    """Create a 2x2 max-pooling layer with stride 2.

    Returns:
        An ``nn.MaxPool2d`` whose window and stride are both 2, i.e. the
        windows do not overlap.
    """
    window = 2
    return nn.MaxPool2d(window, stride=window)

class UnetEncodeLayer(nn.Module):
    """A single unpadded 3x3 convolution with optional ReLU and max-pooling.

    The sub-modules are applied in a fixed order: convolution, then ReLU
    (when ``activated``), then 2x2 max-pooling (when ``max_pool``).

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution.
        activated: Append an ``nn.ReLU`` after the convolution.
        max_pool: Append the layer returned by ``max_pool_2d()`` last.
    """

    def __init__(self, in_channels: int, out_channels: int, activated=True, max_pool=False):
        super().__init__()

        # NOTE: no normalisation layer is inserted after the convolution.
        stages = [conv3x3(in_channels, out_channels)]
        if activated:
            stages.append(nn.ReLU())
        if max_pool:
            stages.append(max_pool_2d())

        # Stored under the name `layer`, so parameters appear as `layer.0.*`.
        self.layer = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the configured stack and return the result."""
        return self.layer(x)
    

In [69]:
class UNET(nn.Module):
    """Vanilla U-Net built from unpadded 3x3 convolutions.

    The contracting path has five stages (64, 128, 256, 512 and 1024
    channels); every stage after the first starts with a 2x2 max-pool. The
    expanding path upsamples by 2, halves the channel count with a 1x1
    convolution, concatenates the centre-cropped encoder output of the same
    depth along the channel axis and applies two 3x3 convolutions with ReLU.
    A final 1x1 convolution produces the per-class score map.

    Because the 3x3 convolutions are unpadded, the output is spatially
    smaller than the input: a 572x572 input yields a 388x388 map.

    Args:
        in_channels: Number of channels of the input image. Defaults to 3.
        num_classes: Number of channels of the returned score map.
            Defaults to 2.
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 2):
        super().__init__()

        # Contracting path. Attribute names, registration order and the
        # order in which sub-modules are created are kept stable so that
        # state_dict keys and seeded initialisation do not change.
        self.encode1 = self._encoder_stage(in_channels, 64, pool=False)
        self.encode2 = self._encoder_stage(64, 128)
        self.encode3 = self._encoder_stage(128, 256)
        self.encode4 = self._encoder_stage(256, 512)
        self.encode5 = self._encoder_stage(512, 1024)

        # Expanding path: `decodeN` upsamples, `decode_forwardN` fuses the
        # concatenated (skip + upsampled) feature map.
        self.decode1 = self._upsample_stage(1024, 512)
        self.decode_forward1 = self._decoder_stage(1024, 512)
        self.decode2 = self._upsample_stage(512, 256)
        self.decode_forward2 = self._decoder_stage(512, 256)
        self.decode3 = self._upsample_stage(256, 128)
        self.decode_forward3 = self._decoder_stage(256, 128)
        self.decode4 = self._upsample_stage(128, 64)
        # The last fusion block also carries the final 1x1 classification conv.
        self.decode_forward4 = self._decoder_stage(128, 64, head_channels=num_classes)

    @staticmethod
    def _encoder_stage(in_channels: int, out_channels: int, pool: bool = True) -> nn.Sequential:
        """Optional 2x2 max-pool followed by two conv3x3 + ReLU layers."""
        stages = [max_pool_2d()] if pool else []
        stages.append(UnetEncodeLayer(in_channels, out_channels))
        stages.append(UnetEncodeLayer(out_channels, out_channels))
        return nn.Sequential(*stages)

    @staticmethod
    def _upsample_stage(in_channels: int, out_channels: int) -> nn.Sequential:
        """2x upsampling followed by a 1x1 convolution that changes the channel count."""
        return nn.Sequential(
            nn.Upsample(scale_factor=(2, 2)),
            nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0),
        )

    @staticmethod
    def _decoder_stage(in_channels: int, out_channels: int, head_channels=None) -> nn.Sequential:
        """Two conv3x3 + ReLU layers, optionally followed by a 1x1 output conv.

        ``head_channels`` is the channel count of the optional trailing 1x1
        convolution; ``None`` omits it.
        """
        stages = [
            conv3x3(in_channels, out_channels),
            nn.ReLU(),
            conv3x3(out_channels, out_channels),
            nn.ReLU(),
        ]
        if head_channels is not None:
            # Created after the 3x3 convolutions on purpose: this keeps the
            # parameter-initialisation order of the block unchanged.
            stages.append(nn.Conv2d(out_channels, head_channels, kernel_size=1))
        return nn.Sequential(*stages)

    @staticmethod
    def _merge_skip(skip: torch.Tensor, upsampled: torch.Tensor) -> torch.Tensor:
        """Centre-crop ``skip`` to the spatial size of ``upsampled`` and concatenate on channels."""
        # Deriving the crop size from the decoder tensor (instead of fixed
        # constants) lets the network accept input sizes other than 572x572.
        target_size = [int(upsampled.shape[-2]), int(upsampled.shape[-1])]
        cropped = functional.center_crop(skip, target_size)
        return torch.cat((cropped, upsampled), dim=1)

    def forward(self, x: torch.Tensor):
        """Compute the segmentation score map for a batch of images.

        Args:
            x: Input batch laid out as (batch, channels, height, width).

        Returns:
            A tensor with ``num_classes`` channels whose spatial size is
            smaller than the input's (388x388 for a 572x572 input).
        """
        # Contracting path; every stage output is kept for its skip connection.
        x1 = self.encode1(x)
        x2 = self.encode2(x1)
        x3 = self.encode3(x2)
        x4 = self.encode4(x3)
        x5 = self.encode5(x4)

        # Expanding path: upsample, fuse with the matching encoder output, refine.
        y = self.decode1(x5)
        y = self.decode_forward1(self._merge_skip(x4, y))

        y = self.decode2(y)
        y = self.decode_forward2(self._merge_skip(x3, y))

        y = self.decode3(y)
        y = self.decode_forward3(self._merge_skip(x2, y))

        y = self.decode4(y)
        seg_map = self.decode_forward4(self._merge_skip(x1, y))

        return seg_map


In [70]:
net = UNET()  # instantiate your net
# Parameter.numel() gives the element count directly, so the total is a
# plain Python int and no round-trip through numpy is needed.
num_params = sum(p.numel() for p in net.parameters())
print(f"Number of parameters : {num_params}")
print('-' * 50)

# Smoke test on a random 572x572 RGB batch. Only the output shape matters
# here, so autograd is disabled to avoid building a graph for the full pass.
X = torch.rand((1, 3, 572, 572))
with torch.no_grad():
    seg_map = net(X)
# The unpadded convolutions shrink a 572x572 input to a 388x388 map.
assert seg_map.shape == (1, 2, 388, 388), f"unexpected output shape {tuple(seg_map.shape)}"
print('output shape for unet encoding', seg_map.shape)


Number of parameters : 28942850
--------------------------------------------------
output shape for unet encoding torch.Size([1, 2, 388, 388])
