In [None]:
#default_exp unet_resnet
#all_slow

In [None]:
#hide
from google.colab import drive
drive.mount('/content/drive')
%cd drive/MyDrive/noise2noise

Mounted at /content/drive


In [None]:
#export
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from noise2noise.helpers import *

# U-Net with ResNet blocks

In [None]:
#export
#maybe implement basic block

class ResNetBlock(nn.Module):
    """Residual block: (conv3x3 -> BN -> ReLU -> conv3x3 -> BN) + identity, then ReLU.

    Both convolutions keep the channel count and use stride 1 / padding 1, so the
    output has the same shape as the input.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=(3, 3), padding=(1, 1), bias=False)
        # Attribute name `layers` and the Sequential indices are part of the
        # state_dict layout; keep them stable for checkpoint compatibility.
        self.layers = nn.Sequential(
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        residual = self.layers(x)
        return F.relu(residual + x)
    
class ResizeBlock(nn.Module):
    """Resample to a fixed spatial size, then project channels with conv3x3 + BN.

    Resampling uses ``F.interpolate`` with its default mode (nearest).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the 3x3 convolution.
        size: target spatial size passed to ``F.interpolate``.
    """

    def __init__(self, in_channels, out_channels, size):
        super().__init__()
        self.size = size
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1), bias=False)
        norm = nn.BatchNorm2d(out_channels)
        self.resize_channel = nn.Sequential(projection, norm)

    def forward(self, x):
        resampled = F.interpolate(x, size=self.size)
        return self.resize_channel(resampled)

class UpBlock(nn.Module):
    """Double the spatial resolution (``F.interpolate``, default nearest mode),
    then map channels with conv3x3 + BN.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels after the 3x3 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1), bias=False)
        norm = nn.BatchNorm2d(out_channels)
        self.resize_channel = nn.Sequential(conv, norm)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2)
        return self.resize_channel(upsampled)

class DownBlock(nn.Module):
    """Halve the spatial resolution with 2x2 max pooling (stride 2),
    then map channels with conv3x3 + BN.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels after the 3x3 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1), bias=False)
        norm = nn.BatchNorm2d(out_channels)
        self.resize_channel = nn.Sequential(conv, norm)

    def forward(self, x):
        pooled = F.max_pool2d(x, kernel_size=2, stride=2)
        return self.resize_channel(pooled)
    

In [None]:
# NOTE(review): `load_images` / `to_float_image` come from noise2noise.helpers;
# presumably they load the paired noisy images and convert them to float tensors — confirm.
noisy_imgs_1, noisy_imgs_2 = load_images()
# Small 2-image batch used to smoke-test the blocks and models below.
img = to_float_image(noisy_imgs_1[:2])
# ResNetUnet() as constructed below assumes 3x32x32 inputs; fail fast otherwise.
assert img.shape[1:] == (3, 32, 32)
img.shape

torch.Size([2, 3, 32, 32])

In [None]:
# ResizeBlock changes both channel count (3 -> 32) and spatial size (-> 40x40).
block = ResizeBlock(3, 32, (40, 40))
resized = block(img)
assert resized.shape == (img.size(0), 32, 40, 40)
resized.shape

torch.Size([2, 32, 40, 40])

In [None]:
# A ResNetBlock preserves the input shape.
block = ResNetBlock(3)
res_block_out = block(img)
assert res_block_out.shape == img.shape
res_block_out.shape

torch.Size([2, 3, 32, 32])

In [None]:
#export

class ResNetUnet(nn.Module):
    """U-Net style model whose encoder is a pretrained torchvision ResNet-18.

    Encoder stages are the ResNet stem (``conv1, bn1, relu, maxpool``) followed by
    ``layer1``..``layer4``; ``avgpool`` and ``fc`` are dropped. Every stage after the
    stem gets a decoder stage (``ResNetBlock`` + ``ResizeBlock``). That decoder stage
    maps the stage's features back to the channel count and spatial size of the
    stage before it. Encoder features are added element-wise as skip connections.
    The raw input image is concatenated before the final 3x3 convolution.

    Args:
        img_size: ``(height, width)`` (or a single int) of the images the model will
            be applied to. ``ResizeBlock`` targets are fixed at construction time from
            a dummy forward pass, so ``forward`` only accepts inputs of this spatial
            size. Defaults to ``(32, 32)``.
    """

    def __init__(self, img_size=(32, 32)):
        super().__init__()
        if isinstance(img_size, int):
            img_size = (img_size, img_size)
        img_size = tuple(img_size)

        # NOTE: `pretrained=` is deprecated since torchvision 0.13 in favour of `weights=`.
        resnet = torchvision.models.resnet18(pretrained=True)
        resnet_layers = list(resnet.children())
        self.encoder = nn.ModuleList([nn.Sequential(*resnet_layers[:4]), *resnet_layers[4:-2]])

        # Probe the encoder once to record each stage's (channels, (h, w)).
        # Run it in eval mode under no_grad. In training mode every pretrained
        # BatchNorm would fold the statistics of this all-zeros dummy batch into
        # its running_mean / running_var.
        stage_shapes = []
        self.encoder.eval()
        with torch.no_grad():
            probe = torch.zeros((1, 3) + img_size)
            for stage in self.encoder:
                probe = stage(probe)
                stage_shapes.append((probe.size(1), (probe.size(2), probe.size(3))))
        self.encoder.train()

        # One decoder stage per encoder stage after the stem, mapping stage i back to
        # stage i-1's channels and size. The list is built shallow-to-deep, then
        # reversed so that self.decoder[0] consumes the deepest features.
        self.decoder_layers = []
        for (out_channels, size), (in_channels, _) in zip(stage_shapes[:-1], stage_shapes[1:]):
            self.decoder_layers.append(nn.Sequential(ResNetBlock(in_channels),
                                                     ResizeBlock(in_channels, out_channels, size)))
        self.decoder_layers.reverse()
        self.decoder = nn.ModuleList(self.decoder_layers)

        stem_channels = stage_shapes[0][0]
        deepest_channels = stage_shapes[-1][0]
        self.middle = ResNetBlock(deepest_channels)
        # Name `to_32` (32 output channels) kept so existing checkpoints still load.
        self.to_32 = nn.Sequential(ResizeBlock(stem_channels, 32, img_size), ResNetBlock(32))
        # +3: the raw input image is concatenated to the 32 feature channels.
        self.to_3 = nn.Conv2d((32 + 3), 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Map an image batch ``x`` of shape (N, 3, H, W), with (H, W) == ``img_size``,
        to an output of shape (N, 3, H, W)."""
        img = x
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)

        skips.reverse()  # deepest first, stem output last
        x = self.middle(x)

        # zip stops after len(self.decoder) == len(skips) - 1 steps, so the stem
        # output (skips[-1]) is left for the final skip connection below.
        for decoder_stage, skip in zip(self.decoder, skips):
            x = decoder_stage(x + skip)

        x = self.to_32(x + skips[-1])
        return self.to_3(torch.cat([x, img], dim=1))

In [None]:
# The U-Net must return an image with the same shape as its input.
model = ResNetUnet()
unet_out = model(img)
assert unet_out.shape == img.shape
unet_out.shape

torch.Size([2, 3, 32, 32])

In [None]:
#export

class ResNetUnetNotPretrained(nn.Module):
    """Randomly initialised U-Net built from ``ResNetBlock``s.

    The input is projected to ``channels[0]`` channels. It then passes through one
    ``ResNetBlock`` per encoder level, with a ``DownBlock`` (2x max-pool + conv/BN)
    between consecutive levels. A middle ``ResNetBlock`` follows. The decoder mirrors
    the encoder with ``UpBlock``s and additive skip connections. Finally the raw input
    is concatenated and mapped back to 3 channels.

    Args:
        channels: channel width of each encoder level, shallowest first. Every level
            after the first halves the spatial size, so input height and width must be
            divisible by ``2 ** (len(channels) - 1)``. Defaults to ``(64, 128, 256)``.

    Raises:
        ValueError: if ``channels`` is empty.
    """

    def __init__(self, channels=(64, 128, 256)):
        super().__init__()
        enc_channels = list(channels)
        if not enc_channels:
            raise ValueError("channels must contain at least one entry")
        dec_channels = enc_channels[::-1]

        self.encoder = nn.ModuleList([ResNetBlock(c) for c in enc_channels])
        # Downsample between levels; the deepest level goes straight to `middle`.
        self.down_blocks = nn.ModuleList(
            [DownBlock(c_in, c_out) for c_in, c_out in zip(enc_channels, enc_channels[1:])]
            + [nn.Identity()])

        self.decoder = nn.ModuleList([ResNetBlock(c) for c in dec_channels])
        # The deepest decoder level already matches `middle`'s resolution.
        self.up_blocks = nn.ModuleList(
            [nn.Identity()]
            + [UpBlock(c_in, c_out) for c_in, c_out in zip(dec_channels, dec_channels[1:])])

        self.middle = ResNetBlock(enc_channels[-1])

        # Name `to_64` kept for checkpoint compatibility; it projects to channels[0].
        self.to_64 = nn.Sequential(
            nn.Conv2d(3, enc_channels[0], kernel_size=(3, 3), padding=(1, 1), bias=False),
            nn.BatchNorm2d(enc_channels[0]))
        # +3: the raw input image is concatenated before the output head.
        self.to_3 = nn.Sequential(ResNetBlock(enc_channels[0] + 3),
                                  nn.Conv2d(enc_channels[0] + 3, 3, kernel_size=3, padding=1))

        # NOTE(review): `pool` and `up` are not used by forward(); they hold no
        # parameters and are kept only so existing references to them keep working.
        self.pool = nn.MaxPool2d(2, stride=2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, x):
        """Map an image batch ``x`` of shape (N, 3, H, W) to an output of shape (N, 3, H, W)."""
        skips = []

        img = x
        x = self.to_64(x)

        for level, down in zip(self.encoder, self.down_blocks):
            x = level(x)
            skips.append(x)
            x = down(x)

        x = self.middle(x)

        skips.reverse()  # deepest first, to pair with the decoder order
        for level, up, skip in zip(self.decoder, self.up_blocks, skips):
            x = up(x)
            x = level(x + skip)

        return self.to_3(torch.cat([x, img], dim=1))

    

In [None]:
# The randomly initialised U-Net must also preserve the input shape.
model = ResNetUnetNotPretrained()
unet_np_out = model(img)
assert unet_np_out.shape == img.shape
unet_np_out.shape

torch.Size([2, 3, 32, 32])

In [None]:
#export
class ResNet(nn.Module):
    """Plain (non-U-Net) residual CNN mapping a 3-channel image to a 3-channel image.

    Channels go 3 -> 64 -> 128, then ``num_blocks`` ``ResNetBlock``s at 128 channels,
    then 128 -> 3. Every convolution is 3x3 with padding 1 and stride 1, so the
    spatial size is preserved.

    Args:
        num_blocks: number of 128-channel ``ResNetBlock``s in the trunk. Defaults to 8.
    """

    def __init__(self, num_blocks=8):
        super().__init__()
        self.to_64 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=(3, 3), padding=(1, 1), bias=False),
            nn.BatchNorm2d(64),
            ResNetBlock(64)
        )

        self.to_128 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding=(1, 1), bias=False),
            nn.BatchNorm2d(128),
            ResNetBlock(128)
        )

        self.resblocks = nn.Sequential(*[ResNetBlock(128) for _ in range(num_blocks)])

        # NOTE(review): the head ends with BatchNorm2d(3) followed by a ResNetBlock,
        # whose forward ends in a ReLU, so outputs are always >= 0. Confirm this
        # matches the value range of the target images.
        self.to_3 = nn.Sequential(
            nn.Conv2d(128, 3, kernel_size=(3, 3), padding=(1, 1), bias=False),
            nn.BatchNorm2d(3),
            ResNetBlock(3)
        )

    def forward(self, x):
        """Map an image batch ``x`` of shape (N, 3, H, W) to an output of shape (N, 3, H, W)."""
        x = self.to_64(x)
        x = self.to_128(x)
        x = self.resblocks(x)
        return self.to_3(x)

In [None]:
# The plain ResNet must preserve the input shape as well.
model = ResNet()
resnet_out = model(img)
assert resnet_out.shape == img.shape
resnet_out.shape

torch.Size([2, 3, 32, 32])