In [5]:
import torch
import torch.nn as nn


In [9]:
def conv3x3(in_channels: int, out_channels: int):
    """Build a 3x3 ``nn.Conv2d``.

    With the default stride of 1, a padding of 1 leaves height and width unchanged.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.

    Returns:
        The configured ``nn.Conv2d`` module (bias enabled, stride 1).
    """
    same_padding = 1
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        padding=same_padding,
    )


def max_pool_2d():
    """Return a 2x2 max-pool with stride 2.

    This halves height and width, flooring odd sizes (e.g. 5 -> 2).
    """
    window = 2
    return nn.MaxPool2d(window, stride=window)

class UNETlayer(nn.Module):
    """One U-Net building block.

    The block is a 3x3 conv, then an optional BatchNorm, an optional ReLU and an
    optional 2x2 max-pool, applied in that order.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        activated: if True, append ``nn.ReLU()`` after the convolution
            (and after the batch norm, when enabled).
        max_pool: if True, append a 2x2, stride-2 max-pool as the last step.
        batch_norm: if True, insert ``nn.BatchNorm2d(out_channels)`` directly
            after the convolution. Off by default, so the submodule list
            (and therefore the ``state_dict`` keys ``layer.0``, ``layer.1``, ...)
            matches the plain conv -> ReLU -> pool stack.
    """

    def __init__(self, in_channels: int, out_channels: int, activated=True, max_pool=False,
                 batch_norm=False):
        super(UNETlayer, self).__init__()
        layers = [conv3x3(in_channels, out_channels)]
        if batch_norm:
            # BatchNorm2d only accepts batched (N, C, H, W) input.
            layers.append(nn.BatchNorm2d(out_channels))
        if activated:
            layers.append(nn.ReLU())
        if max_pool:
            layers.append(max_pool_2d())
        self.layer = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the configured conv / [bn] / [relu] / [pool] stack to ``x``."""
        return self.layer(x)

In [10]:
class UNET(nn.Module):
    """Contracting (encoder) path of the vanilla U-Net.

    The encoder has five stages of two ``UNETlayer`` blocks each, with widths
    64 -> 128 -> 256 -> 512 -> 1024 channels. In each of the first four stages,
    the second layer ends in a 2x2, stride-2 max-pool. The 3x3 convs are padded
    by 1, so only the pools change the spatial size. Each pool floors odd
    sizes; for example, a 300x300 input becomes 18x18.

    Args:
        in_channels: number of channels of the input image. The default of 3
            reproduces the previously hard-coded value.
    """

    def __init__(self, in_channels: int = 3):
        super(UNET, self).__init__()
        self.encode = nn.Sequential(
            UNETlayer(in_channels, 64),
            UNETlayer(64, 64, max_pool=True),

            UNETlayer(64, 128),
            UNETlayer(128, 128, max_pool=True),

            UNETlayer(128, 256),
            UNETlayer(256, 256, max_pool=True),

            UNETlayer(256, 512),
            UNETlayer(512, 512, max_pool=True),

            # Bottleneck: no pooling after the last stage.
            UNETlayer(512, 1024),
            UNETlayer(1024, 1024),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the encoder and return the 1024-channel bottleneck features."""
        return self.encode(x)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

net = UNET()  # encoder with the default 3 input channels
num_params = sum(p.numel() for p in net.parameters())
print(f"Number of parameters : {num_params}")
print('-' * 50)

# Use an explicit batch dimension: layers such as BatchNorm2d reject unbatched
# (C, H, W) input, and the usual Conv2d/MaxPool2d convention is (N, C, H, W).
X = torch.rand((1, 3, 300, 300))
with torch.no_grad():  # shape check only; no autograd graph is needed
    encoded = net(X)
print('output shape for unet encoding', encoded.shape)
