In [71]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt


In [72]:
def conv3x3(in_channels: int, out_channels: int):
    return nn.Conv2d(in_channels=in_channels, out_channels=out_channels,kernel_size=3, padding=0)


def max_pool_2d():
    return nn.MaxPool2d(kernel_size=2, stride=2)

class UnetEncodeLayer(nn.Module):
    # just a standard convolution layer.
    def __init__(self, in_channels: int, out_channels: int, activated=True,max_pool=False):
        super(UnetEncodeLayer, self).__init__()
        layers = [
            conv3x3(in_channels, out_channels),
            # nn.BatchNorm2d(out_channels),
        ]
        if activated:
            layers += [nn.ReLU()]

        if max_pool:
            layers += [max_pool_2d()]

        self.layer = nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)
    

In [84]:
class UNET(nn.Module):
	def __init__(self):
		super(UNET, self).__init__()
    	# encoding part of the Unet vanilla architecture
		self.encode = nn.Sequential(
			UnetEncodeLayer(3, 64),
			UnetEncodeLayer(64, 64, max_pool = True),

			UnetEncodeLayer(64, 128),
			UnetEncodeLayer(128, 128, max_pool = True),

			UnetEncodeLayer(128, 256),
			UnetEncodeLayer(256, 256, max_pool = True),

			UnetEncodeLayer(256, 512),
			UnetEncodeLayer(512, 512, max_pool = True),

			UnetEncodeLayer(512, 1024),
			UnetEncodeLayer(1024, 1024),
		)
		self.decode = nn.Sequential(
			nn.Upsample(scale_factor=(2,2)),
			nn.Conv2d(1024,512,2,padding=0)
		)
	def forward(self, x: torch.Tensor):
		x = self.encode(x)
		x = self.decode(x)
		return x


In [85]:
net = UNET()  # instantiate your net
num_params = sum([np.prod(p.shape) for p in net.parameters()])
print(f"Number of parameters : {num_params}")
print('-' * 50)

X = torch.rand((1,3, 572, 572))
print('output shape for unet encoding', net(X).shape)


Number of parameters : 20940864
--------------------------------------------------
output shape for unet encoding torch.Size([1, 512, 55, 55])
