In [None]:
#default_exp unet_resnet
#all_slow

In [1]:
#hide
from google.colab import drive
drive.mount('/content/drive')
%cd drive/MyDrive/noise2noise

Mounted at /content/drive


In [3]:
#export
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from noise2noise.helpers import *

# U-Net with a ResNet-18 encoder

In [4]:
#export
#maybe implement basic block

class ResNetBlock(nn.Module):
    """Residual block that preserves both channel count and spatial size.

    The residual branch is conv3x3 -> batch norm -> ReLU -> conv3x3 -> batch norm
    (bias-free convolutions, padding 1). Its output is added to the block input
    and the sum is passed through a final ReLU.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        # Both convolutions share the same geometry; padding 1 with a 3x3 kernel
        # keeps height and width unchanged so the skip addition lines up.
        conv_kwargs = dict(kernel_size=(3, 3), padding=(1, 1), bias=False)
        branch = [
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.BatchNorm2d(channels),
        ]
        self.layers = nn.Sequential(*branch)

    def forward(self, x):
        """Return relu(branch(x) + x); the output has the same shape as `x`."""
        residual = self.layers(x)
        return F.relu(residual + x)
    
class ResizeBlock(nn.Module):
    """Resample a feature map to a fixed spatial size, then remap its channels.

    Args:
        in_channels: channel count of the incoming tensor.
        out_channels: channel count produced by the 3x3 convolution.
        size: target (height, width) handed to `F.interpolate`.
    """

    def __init__(self, in_channels, out_channels, size):
        super().__init__()
        self.size = size
        # Bias-free 3x3 convolution with padding 1: changes channels, keeps H and W.
        self.resize_channel = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(3, 3),
            padding=(1, 1),
            bias=False,
        )

    def forward(self, x):
        """Interpolate `x` to `self.size` (default interpolation mode), then convolve."""
        resized = F.interpolate(x, size=self.size)
        return self.resize_channel(resized)
    

In [5]:
# Build a two-sample batch for the shape checks in the following cells: take the
# first two entries of the first value returned by load_images() and convert
# them with to_float_image() (both presumably come from noise2noise.helpers).
noisy_imgs_1 , noisy_imgs_2 = load_images()
img = to_float_image(noisy_imgs_1[:2])
img.shape

torch.Size([2, 3, 32, 32])

In [6]:
# Shape check: ResizeBlock should turn the 3-channel batch into 32 channels at 40x40.
block = ResizeBlock(3,32,(40,40))
block(img).shape

torch.Size([2, 32, 40, 40])

In [7]:
# Shape check: ResNetBlock should return a tensor with the same shape as its input.
block = ResNetBlock(3)
block(img).shape

torch.Size([2, 3, 32, 32])

In [8]:
# Instantiate torchvision's pretrained ResNet-18 and display its module tree to
# see which child modules are available for building the encoder.
resnet = torchvision.models.resnet18(pretrained = True)
resnet

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [21]:
#export

class ResNetUnet(nn.Module):
    """U-Net style image-to-image network with a torchvision ResNet-18 encoder.

    The encoder is split into five stages: the ResNet stem (conv1, bn1, relu,
    maxpool) followed by layer1..layer4; the avgpool and fc children are dropped.
    For every stage after the stem a decoder step is created that applies a
    ResNetBlock and then a ResizeBlock mapping the stage output back to the
    channel count and spatial size of the stage input. In `forward`, encoder
    features are added (not concatenated) to the decoder path, and the final
    3-channel convolution sees the decoded features concatenated with the
    original input image.

    The decoder resizes to fixed spatial sizes measured at construction time, so
    the model expects inputs whose height and width equal `input_size`.

    Args:
        pretrained: forwarded to `torchvision.models.resnet18(pretrained=...)`.
        input_size: (height, width) of the images the model will be applied to.
    """

    def __init__(self, pretrained=True, input_size=(32, 32)):
        super().__init__()
        resnet = torchvision.models.resnet18(pretrained=pretrained)
        resnet_layers = list(resnet.children())
        # Stage 0 bundles the first four children (the stem); the remaining
        # stages are the residual layers, excluding the last two children.
        self.encoder = nn.ModuleList([nn.Sequential(*resnet_layers[:4]), *resnet_layers[4:-2]])

        self.input_size = tuple(input_size)

        # Probe the encoder with a dummy batch to learn each stage's channel
        # count and spatial size. This runs in eval mode without autograd so the
        # probe does not update the BatchNorm running statistics of the
        # (pretrained) encoder or build a computation graph.
        decoder_layers = []
        stem_channels = None
        was_training = self.encoder.training
        self.encoder.eval()
        with torch.no_grad():
            x = torch.zeros((1, 3, *self.input_size))
            for i, stage in enumerate(self.encoder):
                x_next = stage(x)
                in_channels = x_next.size(1)
                out_channels = x.size(1)
                size = (x.size(2), x.size(3))

                if i == 0:
                    # The stem has no decoder step of its own; `to_32` undoes it.
                    stem_channels = in_channels
                else:
                    decoder_layers.append(nn.Sequential(ResNetBlock(in_channels),
                                                        ResizeBlock(in_channels, out_channels, size)))

                x = x_next
        self.encoder.train(was_training)

        # Decoder steps are applied deepest-first, i.e. in reverse encoder order.
        decoder_layers.reverse()
        self.decoder = nn.ModuleList(decoder_layers)
        # After the loop `in_channels` is the channel count of the deepest stage.
        self.middle = ResNetBlock(in_channels)
        # Bring the stem-resolution features back to the input resolution.
        self.to_32 = nn.Sequential(ResizeBlock(stem_channels, 32, self.input_size), ResNetBlock(32))
        # 32 decoded channels + 3 channels of the input image -> 3 output channels.
        self.to_3 = nn.Conv2d((32+3), 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Map a 3-channel image batch to a 3-channel batch of the same size.

        Args:
            x: tensor of shape (batch, 3, H, W) with (H, W) equal to `input_size`.

        Returns:
            Tensor with 3 channels and the spatial size `input_size`.
        """
        img = x
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)

        # Deepest feature first, matching the order of `self.decoder`.
        skips.reverse()
        x = self.middle(x)

        # There is one fewer decoder step than encoder stages, so zip leaves the
        # stem feature (the last entry of `skips`) unused here.
        for step, skip in zip(self.decoder, skips):
            x = step(x + skip)

        x = self.to_32(x + skips[-1])
        return self.to_3(torch.cat([x, img], dim=1))

In [22]:
# Shape check: the full model should return a tensor with the same shape as the input batch.
model = ResNetUnet()
model(img).shape

torch.Size([2, 3, 32, 32])

In [None]:
#export

class ResNetUnetNotPretrained(nn.Module):
  def __init__(self):
    super().__init__()