In [None]:
#default_exp unet_resnet
#all_slow

In [None]:
#export
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from noise2noise.helpers import *

# U-Net with a ResNet-18 encoder

In [None]:
#export
#maybe implement basic block

class ResNetBlock(nn.Module):
    """Residual block that keeps channel count and spatial size unchanged.

    Two 3x3 convolutions (padding 1, no bias), each followed by BatchNorm, with
    a ReLU between them. The block input is added back (identity shortcut) and
    a final ReLU is applied to the sum.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # Same-size 3x3 convolution; bias is omitted because BatchNorm follows.
            return nn.Conv2d(channels, channels, kernel_size=(3, 3), padding=(1, 1), bias=False)

        # Keep this exact Sequential layout: its indices form the state_dict keys.
        self.layers = nn.Sequential(
            conv3x3(),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        residual = self.layers(x)
        return F.relu(x + residual)
    
class ResizeBlock(nn.Module):
    """Change the channel count with a 3x3 convolution, then resize spatially.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        size: target spatial size handed to ``F.interpolate`` (e.g. ``(H, W)``).
            Resizing uses ``F.interpolate``'s default mode (nearest).
    """

    def __init__(self, in_channels, out_channels, size):
        super().__init__()
        self.resize_channel = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            bias=False,
        )
        self.size = size

    def forward(self, x):
        projected = self.resize_channel(x)
        return F.interpolate(projected, size=self.size)
    

In [None]:
# Load both noisy image sets; keep a two-image batch (via to_float_image) for shape checks below.
noisy_imgs_1, noisy_imgs_2 = load_images()
img = to_float_image(noisy_imgs_1[:2])
img.shape

torch.Size([2, 3, 32, 32])

In [None]:
# ResizeBlock: 3 -> 32 channels, spatial size resized to 40x40.
block = ResizeBlock(in_channels=3, out_channels=32, size=(40, 40))
block(img).shape

torch.Size([2, 32, 40, 40])

In [None]:
# A ResNetBlock preserves channels and spatial size, so the output shape must
# equal the input shape, including the batch dimension.
block = ResNetBlock(3)
out = block(img)
assert out.shape == img.shape, f"unexpected output shape {out.shape}"
out.shape

torch.Size([2, 3, 32, 32])

In [None]:
#export

class ResNetUnet(nn.Module):
    """U-Net-style encoder/decoder whose encoder is torchvision's ResNet-18.

    Encoder stages: the ResNet-18 stem (conv1, bn1, relu, maxpool) as one
    stage, then ``layer1``..``layer4`` (the trailing avgpool and fc are
    dropped). For each encoder stage a decoder stage (``ResNetBlock`` followed
    by ``ResizeBlock``) is built that maps that stage's output back to the
    stage's input channel count and spatial size. Decoder stages run
    deepest-first; each receives the running decoder activation plus the
    matching encoder output (additive skip connection).

    Args:
        pretrained: forwarded to ``torchvision.models.resnet18``. Defaults to
            True (the previous behaviour).
        input_size: ``(height, width)`` of the inputs. The decoder's resize
            targets are fixed at construction from a probe pass at this size,
            so inputs of a different spatial size are not supported. Defaults
            to ``(32, 32)``, the previously hard-coded size.
    """

    def __init__(self, pretrained=True, input_size=(32, 32)):
        super().__init__()
        resnet = torchvision.models.resnet18(pretrained=pretrained)
        resnet_layers = list(resnet.children())
        self.encoder = nn.ModuleList([nn.Sequential(*resnet_layers[:4]), *resnet_layers[4:-2]])

        # Probe each encoder stage with a dummy batch to learn its input/output
        # channel counts and input spatial size. Train-mode BatchNorm would
        # update its running statistics on this dummy batch (clobbering the
        # pretrained ones), so probe in eval mode without autograd and restore
        # the training flag afterwards.
        was_training = self.encoder.training
        self.encoder.eval()
        stage_specs = []
        with torch.no_grad():
            probe = torch.zeros((2, 3, *input_size))
            for stage in self.encoder:
                probe_next = stage(probe)
                stage_specs.append((probe_next.size(1),                  # decoder in_channels
                                    probe.size(1),                       # decoder out_channels
                                    (probe.size(2), probe.size(3))))     # decoder target size
                probe = probe_next
        self.encoder.train(was_training)

        # Build decoder stages in encoder order (keeps parameter-init RNG order),
        # then reverse so the deepest stage runs first.
        decoder_layers = [
            nn.Sequential(ResNetBlock(c_in), ResizeBlock(c_in, c_out, size))
            for c_in, c_out, size in stage_specs
        ]
        decoder_layers.reverse()
        self.decoder = nn.ModuleList(decoder_layers)
        # Bottleneck block on the deepest encoder output.
        self.middle = ResNetBlock(stage_specs[-1][0])

    def forward(self, x):
        """Encode ``x``, apply the bottleneck, then decode with additive skips.

        Args:
            x: image batch; assumed (batch, 3, H, W) with ``(H, W)`` equal to
                the ``input_size`` given at construction — TODO confirm callers.

        Returns:
            Tensor of shape (batch, 3, *input_size).
        """
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)

        x = self.middle(x)
        # Deepest decoder stage pairs with the deepest encoder output.
        for stage, skip in zip(self.decoder, reversed(skips)):
            x = stage(x + skip)
        return x

In [None]:
# The U-Net maps an image batch back to the same shape as its input.
model = ResNetUnet()
prediction = model(img)
assert prediction.shape == img.shape, f"unexpected output shape {prediction.shape}"
prediction.shape

torch.Size([2, 3, 32, 32])