In [None]:
#default_exp unet

In [None]:
#hide
# Run this cell to export the library (nbdev converts the project's notebooks to modules).
from nbdev.export import notebook2script; notebook2script()

Converted 00_baseline.ipynb.
Converted 00_helpers.ipynb.
Converted 00_unet.ipynb.
Converted index.ipynb.


In [None]:
#export
import torch
from noise2noise.helpers import *
import torch.nn as nn

# Unet

> U-Net building blocks and a small two-level U-Net model.

In [None]:
#export
#hide
class UnetBlockDown(nn.Module):
    """Encoder block: two 3x3 convolutions, each followed by a ReLU.

    Both convolutions use ``padding=1`` with the default stride, so the
    spatial size of the input is kept; only the channel count changes,
    from ``in_channel`` to ``out_channel``.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # (input, output) channel pairs of the two convolutions, in order.
        channel_pairs = [(in_channel, out_channel), (out_channel, out_channel)]
        layers = []
        for c_in, c_out in channel_pairs:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            layers.append(nn.ReLU())
        # Layer positions (conv at 0 and 2) define the state-dict key names.
        self.stack = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the conv/ReLU stack to ``x`` and return the result."""
        return self.stack(x)
    
class UnetBlockUp(nn.Module):
    """Decoder block: two 3x3 transposed convolutions, each followed by a ReLU.

    The first layer keeps ``in_channel`` channels and the second maps them to
    ``out_channel``. With the default stride and ``padding=1`` the spatial
    size of the input is kept.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        same_width = nn.ConvTranspose2d(in_channel, in_channel, kernel_size=3, padding=1)
        to_out_width = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=3, padding=1)
        # Layer positions (0 and 2) define the state-dict key names.
        self.stack = nn.Sequential(same_width, nn.ReLU(), to_out_width, nn.ReLU())

    def forward(self, x):
        """Apply the transposed-conv/ReLU stack to ``x`` and return the result."""
        return self.stack(x)

In [None]:
# Load the data (presumably two sets of noisy images -- confirm in helpers)
# and convert a single-image batch with `to_float_image`.
noisy_imgs_1, noisy_imgs_2 = load_images()
first_noisy = noisy_imgs_1[:1]  # slice rather than index so the leading dimension is kept
img = to_float_image(first_noisy)
img.shape

torch.Size([1, 3, 32, 32])

In [None]:
# Shape check: a down block followed by the mirrored up block should give
# back a tensor with the same shape as the input image batch.
block_down, block_up = UnetBlockDown(3, 64), UnetBlockUp(64, 3)
down = block_down(img)
up = block_up(down)
down.shape, up.shape

(torch.Size([1, 64, 32, 32]), torch.Size([1, 3, 32, 32]))

In [None]:
#export
#hide
class Unet(nn.Module):
    """Two-level U-Net-style encoder/decoder with additive skip connections.

    The encoder applies two ``UnetBlockDown`` blocks, each followed by a 2x2
    max-pool that also returns the pooling indices. After a two-convolution
    bottleneck, the decoder reverses each pooling step with ``MaxUnpool2d``
    (using the stored indices), adds the matching encoder activation, and
    applies a ``UnetBlockUp`` block.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        base_channels: Channel width of the first level; the second level and
            the bottleneck use twice this width.

    The defaults (3, 3, 64) reproduce the original fixed architecture, so
    ``Unet()`` builds the same model with the same parameter names.
    """

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        wide_channels = 2 * base_channels

        self.block_down_1 = UnetBlockDown(in_channels, base_channels)
        self.block_down_2 = UnetBlockDown(base_channels, wide_channels)

        # return_indices=True: the decoder needs the argmax positions to unpool.
        self.pool = nn.MaxPool2d(2, stride=2, return_indices=True)
        self.unpool = nn.MaxUnpool2d(2, stride=2)

        self.bottom = nn.Sequential(
            nn.Conv2d(wide_channels, wide_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(wide_channels, wide_channels, 3, padding=1),
            nn.ReLU()
        )

        self.block_up_1 = UnetBlockUp(wide_channels, base_channels)
        self.block_up_2 = UnetBlockUp(base_channels, out_channels)

    def forward(self, x):
        """Run the network on a batch ``x`` of shape (N, in_channels, H, W).

        Returns a tensor of shape (N, out_channels, H, W). The last layer is
        a ReLU, so the output is non-negative.
        """
        # Encoder: keep the pre-pooling activations for the skip connections.
        x_1 = self.block_down_1(x)
        x, indices_1 = self.pool(x_1)

        x_2 = self.block_down_2(x)
        x, indices_2 = self.pool(x_2)

        x = self.bottom(x)

        # Decoder. Pooling floors odd sizes, so the default unpool output
        # (exactly twice the pooled size) can be one pixel smaller than the
        # skip tensor; passing output_size restores the exact encoder size
        # and keeps the addition valid for H or W not divisible by 4.
        x = self.unpool(x, indices_2, output_size=x_2.size()) + x_2
        x = self.block_up_1(x)

        x = self.unpool(x, indices_1, output_size=x_1.size()) + x_1
        x = self.block_up_2(x)

        return x

In [None]:
# Shape check: the full model should return a tensor shaped like its input.
unet = Unet()
unet_out = unet(img)
unet_out.shape

torch.Size([1, 3, 32, 32])