# Super Resolution

In [22]:
%load_ext autoreload
%autoreload 2

Seed set to 42


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
import torch.nn as nn
import torch

from nimrod.image.datasets import ImageDataset, ImageDataModule, ImageSuperResDataModule, ImageSuperResDataset
from nimrod.models.core import lr_finder, train_one_cycle
from nimrod.models.resnet import ResBlock
from nimrod.utils import get_device, set_seed

from hydra.utils import instantiate
from omegaconf import OmegaConf
from rich import print
from typing import Optional, Type, List
from functools import partial

set_seed(42)
device = get_device()

Seed set to 42
[16:06:28] INFO - Using device: mps


## tiny imagenet

In [28]:
dm = ImageSuperResDataModule(
    "slegroux/tiny-imagenet-200-clean",
    data_dir = "../data/image",
    batch_size = 512
)

[16:06:32] INFO - Init ImageSuperResDataModule for slegroux/tiny-imagenet-200-clean
/Users/slegroux/miniforge3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'transform_x' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['transform_x'])`.
/Users/slegroux/miniforge3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'transform_y' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['transform_y'])`.
[16:06:32] INFO - Init ImageDataModule for slegroux/tiny-imagenet-200-clean
/Users/slegroux/miniforge3/envs/nimrod/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'transforms' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore th

In [4]:
dm.prepare_data()
dm.setup()

[15:15:08] INFO - loading dataset slegroux/tiny-imagenet-200-clean with args () from split train
[15:15:08] INFO - loading dataset slegroux/tiny-imagenet-200-clean from split train
Overwrite dataset info from restored data version if exists.
[15:15:11] INFO - Overwrite dataset info from restored data version if exists.
Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2
[15:15:11] INFO - Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2
Found cached dataset tiny-imagenet-200-clean (/Users/slegroux/Projects/nimrod/tutorials/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2)
[15:15:11] INFO - Found cached dataset tiny-imagenet-200-clean (/Users/slegroux/Projects/nimrod/tutorials/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f

In [5]:
print(dm.dim)

## Model

In [6]:
#| export
class UpBlock(nn.Module):
    def __init__(
        self,
        in_channels:int,
        out_channels:int,
        kernel_size:int=3,
        activation:Optional[Type[nn.Module]]=nn.ReLU
    ):
        super().__init__()
        layers = []
        # upsample receptive field
        layers.append(nn.UpsamplingNearest2d(scale_factor=2))
        # resnet block increase channels
        layers.append(ResBlock(in_channels, out_channels, kernel_size=kernel_size, activation=activation))
        self.nnet = nn.Sequential(*layers)

    def forward(self, x):
        return self.nnet(x)

In [7]:
m = UpBlock(3, 8)
x = torch.randn(1, 3, 64, 64)
y = m(x)
print(y.shape)

In [8]:
m.nnet[0](x).shape

torch.Size([1, 3, 128, 128])

In [14]:
class SuperResAutoencoder(nn.Module):
    def __init__(
        self,
        n_features:List[int]=[3, 8, 16, 32, 64, 128],
    ):
        super().__init__()

        down = partial(ResBlock, kernel_size=3, activation=nn.ReLU, stride=2)
        enc  =  [down(n_features[0], n_features[1], stride=1)]
        for i in range(1, len(n_features) - 1):
            enc += [down(n_features[i], n_features[i+1])]

        up = partial(UpBlock, kernel_size=3, activation=nn.ReLU)
        dec = []
        for i in range(len(n_features) - 1, 1, -1):
            dec += [up(n_features[i], n_features[i-1])]
        dec += [up(n_features[1], n_features[0])]
        dec += [down(n_features[0], n_features[0])]

        self.autoencoder = nn.Sequential(*enc, *dec)

    def forward(self, x:torch.Tensor)->torch.Tensor:
        return self.autoencoder(x)
        

In [10]:
model = SuperResAutoencoder(n_features=[3, 8, 16, 32, 64, 128, 256])
x = torch.randn(1, 3, 64, 64)
y = model(x)
print(y.shape)

## Training

In [None]:
#| notest

device = get_device()

%time
# data
cfg = OmegaConf.load('../config/data/image/tiny_imagenet_superres.yaml')
cfg.batch_size = 512
dm = instantiate(cfg)
dm.prepare_data()
dm.setup()

# model
model = SuperResAutoencoder(n_features=[3, 8, 16, 32, 64, 128, 256]).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)


Seed set to 42
[17:06:32] INFO - Using device: mps
[17:06:32] INFO - Init ImageSuperResDataModule for slegroux/tiny-imagenet-200-clean
[17:06:32] INFO - Init ImageDataModule for slegroux/tiny-imagenet-200-clean


CPU times: user 1 μs, sys: 0 ns, total: 1 μs
Wall time: 3.1 μs


[17:06:35] INFO - loading dataset slegroux/tiny-imagenet-200-clean with args () from split train
[17:06:35] INFO - loading dataset slegroux/tiny-imagenet-200-clean from split train
Overwrite dataset info from restored data version if exists.
[17:06:37] INFO - Overwrite dataset info from restored data version if exists.
Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2
[17:06:37] INFO - Loading Dataset info from ../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2
Found cached dataset tiny-imagenet-200-clean (/Users/slegroux/Projects/nimrod/tutorials/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f2)
[17:06:37] INFO - Found cached dataset tiny-imagenet-200-clean (/Users/slegroux/Projects/nimrod/tutorials/../data/image/slegroux___tiny-imagenet-200-clean/default/0.0.0/4b908d89fab3eb36aa8ebcd41c1996b28da7d6f

In [18]:
dm.dim

[64, 64]

In [20]:
n_epochs = 1
losses = []
lrs = []
current_step = 0
total_steps = len(dm.train_dataloader()) * n_epochs
print(f"total_steps: {total_steps}")

for epoch in range(n_epochs):
    model.train()
    for low_res, high_res in dm.train_dataloader():
        
        optimizer.zero_grad()
        low_res = low_res.to(device)
        high_res = high_res.to(device)
        outputs = model(low_res)
        loss = criterion(outputs, high_res)        
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        current_lr = optimizer.param_groups[0]['lr']
        lrs.append(current_lr)
        if not (current_step % 100):
            print(f"Loss {loss.item():.4f}, Current LR: {current_lr:.10f}, Step: {current_step}/{total_steps}")
        current_step += 1

    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for low_res, high_res in dm.test_dataloader():
            # model expects input (B,H*W)
            low_res = low_res.to(device)
            low_res = low_res.to(device)
            high_res = high_res.to(device)
            # Pass the input through the model
            outputs = model(low_res)
            val_loss = criterion(outputs, high_res)
            print(f"val Loss {val_loss.item():.4f}")
            # Get the predicted high_res

        # Print the accuracy
        print(f"Epoch {epoch + 1}: Accuracy = {100 * correct / total:.2f}%")


  return F.mse_loss(input, target, reduction=self.reduction)


RuntimeError: The size of tensor a (64) must match the size of tensor b (512) at non-singleton dimension 3