In [None]:
#default_exp architectures
#all_slow

In [None]:
#hide
# Run this cell to regenerate the library modules from the `#export` cells of every notebook.
from nbdev.export import notebook2script
notebook2script()

Converted 00_architecture_to_test.ipynb.
Converted 00_augmentation.ipynb.
Converted 00_baseline.ipynb.
Converted 00_helpers.ipynb.
Converted 00_training.ipynb.
Converted 00_unet.ipynb.
Converted 00_unet_resnet.ipynb.
Converted index.ipynb.


In [None]:
#export
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from noise2noise.helpers import *

# Architectures

> Architectures we compare in the first part: a plain residual network (`ResNet`) and a U-Net (`Unet`), both built from `ResNetBlock`s.

In [None]:
#export
class ResNetBlock(nn.Module):
    
    def __init__(self, channels):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=(3, 3), padding=(1, 1), bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=(3, 3), padding=(1, 1), bias=False),
            nn.BatchNorm2d(channels))
    
    def forward(self, x):
        return F.relu(self.layers(x)+x)
    

In [None]:
# Build a tiny two-image batch (via the project's `to_float_image` helper) to check output shapes below.
noisy_imgs_1, noisy_imgs_2 = load_images()
img = to_float_image(noisy_imgs_1[:2])
img.shape

torch.Size([2, 3, 32, 32])

In [None]:
block = ResNetBlock(3)
block_out = block(img)
# The skip connection adds the input back, so the block must preserve its input shape.
assert block_out.shape == img.shape, block_out.shape
block_out.shape

torch.Size([2, 3, 32, 32])

In [None]:
#export
class ResNet(nn.Module):
    """Fully convolutional residual network mapping a 3-channel image to a 3-channel image.

    Every convolution is 3x3 with padding 1 and there is no pooling, so the output has the
    same height and width as the input. The network widens to 64 and then 128 channels,
    applies ``n_blocks`` residual blocks at 128 channels, and projects back to 3 channels.

    Args:
        n_blocks: number of 128-channel ``ResNetBlock``s in the middle trunk. Defaults to 4,
            the depth this architecture originally hard-coded.

    Raises:
        ValueError: if ``n_blocks`` is negative.

    Note: the final stage ends in a ``ResNetBlock``, whose forward applies a ReLU, so the
    output is element-wise non-negative.
    """
    def __init__(self, n_blocks=4):
        super().__init__()
        if n_blocks < 0:
            raise ValueError(f"n_blocks must be non-negative, got {n_blocks}")

        # Attribute names are kept unchanged so existing state_dicts still load.
        self.to_64 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1), bias=False),
            nn.BatchNorm2d(64),
            ResNetBlock(64)
        )

        self.to_128 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=(1, 1), bias=False),
            nn.BatchNorm2d(128),
            ResNetBlock(128)
        )

        # n_blocks == 0 yields an empty Sequential, which passes its input through unchanged.
        self.resblocks = nn.Sequential(*[ResNetBlock(128) for _ in range(n_blocks)])

        self.to_3 = nn.Sequential(
            nn.Conv2d(128, 3, kernel_size=(3, 3), padding=(1, 1), bias=False),
            nn.BatchNorm2d(3),
            ResNetBlock(3)
        )

    def forward(self, x):
        """Widen to 64 then 128 channels, run the residual trunk, project back to 3 channels."""
        x = self.to_64(x)
        x = self.to_128(x)
        x = self.resblocks(x)
        return self.to_3(x)

In [None]:
model = ResNet()
resnet_out = model(img)
# Denoising is image-to-image: the output must have exactly the input's shape.
assert resnet_out.shape == img.shape, resnet_out.shape
resnet_out.shape

torch.Size([2, 3, 32, 32])

In [None]:
#export

class Unet(nn.Module):
    """U-Net style encoder/decoder built from ``ResNetBlock``s.

    Encoder: three stages (3->64, 64->128, 128->256 channels), each a 3x3 conv (padding 1),
    BatchNorm and ResNetBlock. Each stage's output is saved as a skip connection and then
    max-pooled by 2. A 256-channel ResNetBlock (``middle``) sits at the bottom.

    Decoder: for each encoder stage, deepest first, the current features are resized with
    ``F.interpolate`` (default mode) to the skip's spatial size, the skip is added, and the
    stage maps back to the channel count that encoder stage received. A final 3x3 conv
    (``last_layer``) produces the 3-channel output at the input resolution.
    """
    def __init__(self):
        super().__init__()

        channels = [(3,64), (64,128), (128,256)]
        self.encoder = nn.ModuleList([nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1),
                                                    nn.BatchNorm2d(ch_out),
                                                    ResNetBlock(ch_out)) for ch_in, ch_out in channels])
        self.down = nn.MaxPool2d(2, stride=2)
        self.decoder = self.make_decoder_from_encoder(self.encoder)

        self.middle = ResNetBlock(256)
        self.last_layer = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def make_decoder_from_encoder(self, encoder):
        """Build decoder stages that mirror ``encoder``.

        The channel counts of each encoder stage are discovered by running a zero probe
        through it. This assumes the encoder accepts a 3-channel 32x32 input. For a stage
        mapping ``c_in -> c_out``, the decoder stage is ``ResNetBlock(c_out)``, then a 3x3
        conv ``c_out -> c_in``, then ``BatchNorm2d(c_in)``.

        Args:
            encoder: iterable of encoder stages (an ``nn.ModuleList`` here).

        Returns:
            ``nn.ModuleList`` of decoder stages, ordered deepest stage first.
        """
        # Probe in eval mode and without autograd. In training mode every BatchNorm in the
        # encoder would fold this dummy batch into its running_mean/running_var and bump
        # num_batches_tracked, corrupting the statistics before training even starts.
        stage_channels = []
        was_training = encoder.training
        encoder.eval()
        try:
            with torch.no_grad():
                x = torch.zeros((1, 3, 32, 32))
                for stage in encoder:
                    ch_in = x.size(1)
                    x = stage(x)
                    stage_channels.append((ch_in, x.size(1)))
        finally:
            encoder.train(was_training)

        # Modules are created in encoder order (then reversed) so that, for a fixed seed,
        # weight initialisation matches the original construction order.
        decoder = [nn.Sequential(ResNetBlock(ch_out),
                                 nn.Conv2d(ch_out, ch_in, kernel_size=3, padding=1),
                                 nn.BatchNorm2d(ch_in))
                   for ch_in, ch_out in stage_channels]
        decoder.reverse()
        return nn.ModuleList(decoder)

    def forward(self, x):
        """Run encoder -> middle -> decoder with additive skip connections."""
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)
            x = self.down(x)

        x = self.middle(x)

        for stage, skip in zip(self.decoder, reversed(skips)):
            # Resizing to the skip's exact size also handles odd sizes floored by MaxPool2d.
            x = F.interpolate(x, skip.shape[2:])
            x = stage(x + skip)

        return self.last_layer(x)

In [None]:
model = Unet()
unet_out = model(img)
# Denoising is image-to-image: the output must have exactly the input's shape.
assert unet_out.shape == img.shape, unet_out.shape
unet_out.shape

torch.Size([2, 3, 32, 32])