# U-Net

> Neural net model

In [None]:
#| default_exp models.unet

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

In [2]:
#| export
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch_lr_finder import LRFinder

from omegaconf import OmegaConf
from hydra.utils import instantiate

from matplotlib import pyplot as plt

from nimrod.models.conv import ConvLayer
from nimrod.models.resnet import ResBlock
from nimrod.models.superres import UpBlock
from nimrod.utils import get_device, set_seed

from functools import partial

from typing import List
import logging

In [None]:
#| export
logger = logging.getLogger(__name__)
set_seed()

Seed set to 42


## Tiny Unet

In [None]:
def up_block(ni, nf, kernel_size=3, norm=None):
    return nn.Sequential(
        nn.UpsamplingNearest2d(scale_factor=2),
        ResBlock(ni, nf, kernel_size=kernel_size)
    )

In [None]:
x = torch.randn(1, 3, 128, 128)
up_block(3, 6)(x).shape



torch.Size([1, 6, 256, 256])

In [6]:
#| export

def init_weights(m, leaky=0.):
    if isinstance(m, (nn.Conv1d,nn.Conv2d,nn.Conv3d)): nn.init.kaiming_normal_(m.weight, a=leaky)

def zero_weights(layer):
    with torch.no_grad():
        layer.weight.zero_()
        if hasattr(layer, 'bias') and hasattr(layer.bias, 'zero_'): layer.bias.zero_()

class TinyUnet(nn.Module):
    def __init__(
        self,
        n_features:List[int]=[3, 32, 64, 128, 256, 512, 1024], # Number of features in each layer
        activation=partial(nn.LeakyReLU, negative_slope=0.1), # Activation function
        leaky:float=0.1,# Leaky ReLU negative slope
        normalization=nn.BatchNorm2d # Normalization function
    ):
        super().__init__()

        # first layer
        self.start = ResBlock(n_features[0], n_features[1], kernel_size=3, stride=1, activation=activation, normalization=normalization)
        self.encoder = nn.ModuleList()
        # encoder downsample receptive field
        down = partial(ResBlock, kernel_size=3,  stride=2, activation=activation, normalization=normalization)
        for i in range(1, len(n_features) - 1):
            self.encoder += [down(n_features[i], n_features[i+1])]

        # decoder upsampling receptive field
        up = partial(UpBlock, kernel_size=3, activation=activation, normalization=normalization)
        self.decoder = nn.ModuleList()
        for i in range(len(n_features) - 1, 1, -1):
            self.decoder += [up(n_features[i], n_features[i-1])]
        self.decoder += [up(n_features[1], n_features[0])]
        self.end = ResBlock(n_features[0], n_features[0], kernel_size=3, stride=2, activation=nn.Identity, normalization=normalization)

    def forward(self, x:torch.Tensor)->torch.Tensor:
        layers = [] # store the output of each layer
        x = self.start(x)
        for layer in self.encoder:
            layers.append(x)
            x = layer(x)
        n = len(layers)
        for i, layer in enumerate(self.decoder):
            if i != 0:
                x += layers[n-i]
            x = layer(x)
        return self.end(x+layers[0])
        

In [13]:
model = TinyUnet(n_features=[3, 16, 32])
x = torch.randn(1, 3, 64, 64)
model(x).shape



RuntimeError: The size of tensor a (128) must match the size of tensor b (64) at non-singleton dimension 3

In [24]:
class TinyUnet(nn.Module):
    def __init__(self, act=nn.ReLU, nfs=(32,64,128,256,512,1024), norm=nn.BatchNorm2d):
        super().__init__()
        self.start = ResBlock(3, nfs[0], stride=1, activation=act, normalization=norm)
        self.dn = nn.ModuleList([ResBlock(nfs[i], nfs[i+1], activation=act, normalization=norm, stride=2)
                                 for i in range(len(nfs)-1)])
        self.up = nn.ModuleList([UpBlock(nfs[i], nfs[i-1], activation=act, normalization=norm)
                                 for i in range(len(nfs)-1,0,-1)])
        self.up += [ResBlock(nfs[0], 3, activation=act, normalization=norm)]
        self.end = ResBlock(3, 3, activation=nn.Identity, normalization=norm)

    def forward(self, x):
        layers = []
        layers.append(x)
        x = self.start(x)
        for l in self.dn:
            layers.append(x)
            x = l(x)
        n = len(layers)
        for i,l in enumerate(self.up):
            if i!=0: x += layers[n-i]
            x = l(x)
        return self.end(x+layers[0])

In [25]:
model = TinyUnet()
x = torch.randn(1, 3, 64, 64)
model(x).shape



torch.Size([1, 3, 64, 64])

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()