# Super Resolution

> Neural network modules for image super-resolution

In [None]:
#| default_exp models.superres

In [None]:
#| hide
# Reload edited modules automatically while developing in this notebook
%load_ext autoreload
%autoreload 2
# nbdev documentation helpers
from nbdev.showdoc import *

In [None]:
#| export
import logging
from functools import partial
from typing import Any, Callable, List, Optional, Sequence, Type

import torch
import torch.nn as nn
from rich import print

from nimrod.models.resnet import ResBlock
from nimrod.models.core import Regressor
from nimrod.utils import get_device, set_seed

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Fix the random seed before any model weights or random tensors are created.
set_seed(42)
# Device chosen by nimrod's helper (logged on selection).
device = get_device()

Seed set to 42
Seed set to 42
Seed set to 42
[23:09:19] INFO - Using device: mps


## Autoencoder

In [None]:
#| export

class UpBlock(nn.Module):
    """Upsample a feature map, then refine it with a `ResBlock`.

    Nearest-neighbour upsampling enlarges the spatial dimensions by
    `scale_factor`; the following `ResBlock` maps `in_channels` to
    `out_channels`.
    """
    def __init__(
        self,
        in_channels:int, # Number of input channels
        out_channels:int, # Number of output channels
        kernel_size:int=3, # Kernel size
        activation:Optional[Type[nn.Module]]=nn.ReLU, # Activation function
        scale_factor:int=2, # Spatial upsampling factor (nearest neighbour)
    ):
        super().__init__()
        self.nnet = nn.Sequential(
            # nn.Upsample(mode='nearest') is the non-deprecated equivalent of
            # nn.UpsamplingNearest2d; it has no parameters, so state-dict keys
            # (nnet.1.*) are unchanged.
            nn.Upsample(scale_factor=scale_factor, mode='nearest'),
            # ResBlock changes the channel count from in_channels to out_channels.
            ResBlock(in_channels, out_channels, kernel_size=kernel_size, activation=activation),
        )

    def forward(self, x:torch.Tensor)->torch.Tensor:
        """Upsample `x` and apply the residual block."""
        return self.nnet(x)

In [None]:
# Smoke test: push a random 3-channel 64x64 batch through an UpBlock (3 -> 8 channels).
m = UpBlock(in_channels=3, out_channels=8)
x = torch.randn(1, 3, 64, 64)
y = m(x)
print(y.shape)

In [None]:
#| export 

class SuperResAutoencoder(nn.Module):
    """Encoder/decoder network built from `ResBlock`s and `UpBlock`s.

    The encoder walks `n_features` forwards and the decoder walks it backwards.
    The first block takes `n_features[0]` input channels, and the output also
    has `n_features[0]` channels.
    """
    def __init__(
        self,
        n_features:Sequence[int]=(3, 8, 16, 32, 64, 128), # Number of features in each layer (>= 2 entries; first is the image channel count)
    ):
        super().__init__()
        n_features = list(n_features)
        if len(n_features) < 2:
            raise ValueError(
                f"n_features must contain at least 2 entries, got {len(n_features)}"
            )

        down = partial(ResBlock, kernel_size=3, activation=nn.ReLU, stride=2)
        up = partial(UpBlock, kernel_size=3, activation=nn.ReLU)

        # Encoder: a stride-1 stem block, then one stride-2 block per remaining
        # consecutive pair of feature sizes.
        enc = [down(n_features[0], n_features[1], stride=1)]
        enc += [
            down(c_in, c_out)
            for c_in, c_out in zip(n_features[1:-1], n_features[2:])
        ]

        # Decoder: one UpBlock per consecutive pair, walking n_features backwards
        # all the way down to n_features[0].
        dec = [
            up(n_features[i], n_features[i - 1])
            for i in range(len(n_features) - 1, 0, -1)
        ]
        # The decoder has one more UpBlock than the encoder has stride-2 blocks;
        # this final stride-2 block compensates for that extra upsampling.
        # NOTE(review): the up-then-down pair at full resolution is costly, and if
        # ResBlock applies `activation` to its output the result is clamped to
        # >= 0 by ReLU — confirm both are intended for the target normalization.
        dec.append(down(n_features[0], n_features[0]))

        self.autoencoder = nn.Sequential(*enc, *dec)

    def forward(self, x:torch.Tensor)->torch.Tensor:
        """Run `x` through the encoder followed by the decoder."""
        return self.autoencoder(x)
        

In [None]:
# RGB: 3-channel 64x64 input through the default-depth autoencoder.
model = SuperResAutoencoder(n_features=[3, 8, 16, 32, 64, 128])
x = torch.randn(1, 3, 64, 64)
y = model(x)
print(y.shape)

# Grayscale: 1-channel 28x28 input. 28 is not a multiple of 2**3 (three
# stride-2 encoder blocks here), so the printed output size may differ from
# the input; compare the spatial dims in the printed shape.
model = SuperResAutoencoder(n_features=[1, 8, 16, 32, 64])
x = torch.randn(1, 1, 28, 28)
y = model(x)
print(y.shape)

    


In [None]:
#| export

class SuperResAutoencoderX(Regressor):
    """`Regressor` wrapper around a `SuperResAutoencoder` network.

    The optimizer and scheduler are given as partials and forwarded unchanged
    to `Regressor.__init__` together with the network.
    """
    def __init__(
        self,
        nnet:SuperResAutoencoder, # super res autoencoder neural net
        optimizer: Callable[...,torch.optim.Optimizer], # optimizer partial
        scheduler: Optional[Callable[...,Any]]=None, # scheduler partial
    ):
        logger.info("SuperResAutoencoderX: init")
        super().__init__(
            nnet=nnet,
            optimizer=optimizer,
            scheduler=scheduler
            )
        # Assigning an nn.Module attribute registers it as a submodule via
        # nn.Module.__setattr__, so no separate register_module call is needed.
        self.nnet = nnet

In [None]:

# Optimizer and scheduler are passed as partials (not yet bound to parameters).
opt_factory = partial(torch.optim.AdamW, lr=1e-4, weight_decay=1e-5)
sched_factory = partial(
    torch.optim.lr_scheduler.ReduceLROnPlateau, mode='min', factor=0.1, patience=10
)
m = SuperResAutoencoderX(
    nnet=SuperResAutoencoder(),
    optimizer=opt_factory,
    scheduler=sched_factory,
)

# Forward a random RGB 64x64 batch through the wrapped model.
x = torch.randn(1, 3, 64, 64)
y = m(x)
print(y.shape)

[22:08:49] INFO - SuperResAutoencoderX: init
[22:08:49] INFO - Regressor: init
/Users/slegroux/miniforge3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'nnet' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['nnet'])`.


In [None]:
#| hide
# Export the `#| export` cells of this notebook to the library module.
import nbdev
nbdev.nbdev_export()