# Super Resolution

In [6]:
# Automatically re-import edited modules (e.g. the local `nimrod` package) before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
import torch.nn as nn
import torch

from nimrod.image.datasets import ImageDataset, ImageDataModule
from nimrod.models.core import lr_finder, train_one_cycle
from nimrod.models.resnet import ResBlock

from hydra.utils import instantiate
from omegaconf import OmegaConf
from rich import print
from typing import Optional, Type

## Tiny ImageNet

In [4]:
# Data module for the cleaned Tiny ImageNet dataset; the dataset cache lives under data_dir
# (see the "Found cached dataset" log below).
dm = ImageDataModule(
    "slegroux/tiny-imagenet-200-clean",
    data_dir="../data/image",
    batch_size=512,
)

[23:15:42] INFO - Init ImageDataModule for slegroux/tiny-imagenet-200-clean
/Users/slegroux/miniforge3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'transforms' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['transforms'])`.


In [5]:
# Lightning-style data-module lifecycle: prepare_data() first (fetch / locate the cached dataset),
# then setup() to build the splits. Keep this order.
dm.prepare_data()
dm.setup()

[23:15:46] INFO - loading dataset slegroux/tiny-imagenet-200-clean with args () from split train
[23:15:46] INFO - loading dataset slegroux/tiny-imagenet-200-clean from split train
Overwrite dataset info from restored data version if exists.
[23:15:48] INFO - Overwrite dataset info from restored data version if exists.
Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2
[23:15:48] INFO - Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2
Found cached dataset tiny-imagenet-200-clean (/Users/slegroux/Projects/nimrod/tutorials/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2)
[23:15:48] INFO - Found cached dataset tiny-imagenet-200-clean (/Users/slegroux/Projects/nimrod/tutorials/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f

In [7]:
# Show the data module's `dim` attribute (defined in nimrod; presumably the input image dimensions — TODO confirm).
print(dm.dim)

In [32]:
#| export
class UpBlock(nn.Module):
    """Upsample a feature map spatially, then refine it with a residual block.

    The block is ``nn.Sequential(upsample, ResBlock)`` stored as ``self.nnet``:
    nearest-neighbour upsampling by ``scale_factor`` followed by a ``ResBlock``
    built with ``in_channels`` -> ``out_channels``.

    Args:
        in_channels: channel count of the input tensor (passed to ``ResBlock``).
        out_channels: output channel count passed to ``ResBlock``.
        kernel_size: convolution kernel size forwarded to ``ResBlock``.
        activation: activation class forwarded to ``ResBlock``; ``None`` is passed
            through unchanged (how ``ResBlock`` handles it is defined in nimrod).
        scale_factor: spatial upsampling factor. Defaults to 2, the value that was
            previously hard-coded.
    """

    def __init__(
        self,
        in_channels:int,
        out_channels:int,
        kernel_size:int=3,
        activation:Optional[Type[nn.Module]]=nn.ReLU,
        scale_factor:int=2,
    ):
        super().__init__()
        self.nnet = nn.Sequential(
            # Upsample the spatial resolution. nn.UpsamplingNearest2d is deprecated;
            # nn.Upsample(mode="nearest") is its documented equivalent. It has no
            # parameters, so the ResBlock state_dict keys (nnet.1.*) are unchanged.
            nn.Upsample(scale_factor=scale_factor, mode="nearest"),
            # Residual block mapping in_channels -> out_channels.
            ResBlock(in_channels, out_channels, kernel_size=kernel_size, activation=activation),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` and apply the residual block.

        Assumes ``x`` is (N, C, H, W) with C == in_channels — TODO confirm against
        ResBlock's expectations.
        """
        return self.nnet(x)

In [31]:
# Smoke-test UpBlock on a random 3-channel 64x64 batch of one. Statement order is kept as-is:
# module construction and torch.randn presumably both draw from the global RNG.
m = UpBlock(in_channels=3, out_channels=8)
x = torch.randn((1, 3, 64, 64))
y = m(x)
print(y.size())

In [29]:
# Run only the first stage (the upsampler): spatial dims double, channel count is unchanged.
m.nnet[0](x).size()

torch.Size([1, 3, 128, 128])