<a href="https://www.kaggle.com/code/shobhiii/super-resolution-using-pretrained-mae?scriptVersionId=230384264" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

### Imports Required Libraries  
This code imports essential libraries for deep learning (PyTorch, torchvision), image processing (skimage), and utility tools (tqdm, matplotlib) for training and evaluating models.

In [1]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from skimage.metrics import structural_similarity as ssim
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.models import vgg16
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import time
from tqdm import tqdm
import matplotlib.pyplot as plt

### Custom Dataset and DataLoader Creation  
Defines `ImageDataset` to load low-resolution and high-resolution image pairs from `.npy` files.  
The `create_dataloaders` function splits the dataset into training and validation sets and returns corresponding DataLoaders.


In [2]:
class ImageDataset(Dataset):
    def __init__(self, low_res_path, high_res_path):
        self.low_res_path = low_res_path
        self.high_res_path = high_res_path
        self.file_names = sorted(os.listdir(low_res_path))  
    
    def __len__(self):
        return len(self.file_names)
    
    def __getitem__(self, idx):
        file_name = self.file_names[idx]
        low_res = np.load(os.path.join(self.low_res_path, file_name))
        high_res = np.load(os.path.join(self.high_res_path, file_name))
        low_res = np.repeat(low_res, 3, axis=0)  # (3, 75, 75)
        high_res = np.repeat(high_res, 3, axis=0)
        
        return torch.tensor(low_res, dtype=torch.float32), torch.tensor(high_res, dtype=torch.float32)

def create_dataloaders(base_path, batch_size=32):
    """
    Returns:
        tuple: (train_loader, val_loader)
    """
    low_res_path = os.path.join(base_path, "LR")
    high_res_path = os.path.join(base_path, "HR")
    dataset = ImageDataset(low_res_path, high_res_path)
    
    val_size = int(0.1 * len(dataset))
    train_size = len(dataset) - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, val_loader

### Patch Embedding Layer  
Defines `PatchEmbed`, a module that converts an image into a sequence of patch embeddings using a convolutional layer.  
It splits the input image into non-overlapping patches and projects them into an embedding space.


In [3]:
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding """
    def __init__(self, img_size=75, patch_size=16, in_chans=3, embed_dim=1024):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size//patch_size)**2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)
        return x

### Multi-Head Self-Attention Module  
Defines an `Attention` module that implements multi-head self-attention.  
It computes query, key, and value projections, applies scaled dot-product attention, and processes the output through linear layers with dropout.


In [4]:
class Attention(nn.Module):
    def __init__(self, dim, num_heads=16, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


### Multi-Layer Perceptron (MLP) Block  
Defines an `Mlp` module with two fully connected layers, an activation function (GELU), and dropout.  
It expands the input features and applies non-linearity, commonly used in transformer architectures.


In [5]:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features * 4
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

### Transformer Block  
Defines a `TransformerBlock` consisting of multi-head self-attention, layer normalization, and a feed-forward MLP.  
It follows the standard Transformer architecture with residual connections and normalization layers.


In [6]:
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path = 0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

### Transformer Encoder  
Defines a `TransformerEncoder` that processes images by embedding patches and applying multiple Transformer blocks.  
It includes positional embeddings, normalization, and a stack of self-attention layers for feature extraction.


In [7]:
class TransformerEncoder(nn.Module):
    def __init__(self, img_size=75, patch_size=16, in_chans=3, embed_dim=1024, depth=24, 
                 num_heads=16, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = (img_size // patch_size) ** 2
        
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate)
            for i in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Extract patches
        x = self.patch_embed(x)
        x = rearrange(x, 'b c h w -> b (h w) c')  # Flatten spatial dims into sequence
        
        # Add positional embedding
        x = x + self.pos_embed
        
        # Apply transformer blocks
        for blk in self.blocks:
            x = blk(x)
        
        x = self.norm(x)
        return x

### Transformer Decoder  
Defines a `TransformerDecoder` that reconstructs the full image from encoded patches and mask tokens.  
It applies a series of Transformer blocks and normalization to refine the decoded representation.


In [8]:
class TransformerDecoder(nn.Module):
    """
    Decoder will be used to reconstruct the full image from the encoded visible patches and mask tokens.
    """
    def __init__(self,
                embed_dim = 768,
                depth = 8,
                num_heads = 16,
                mlp_ratio = 4.,
                qkv_bias = True,
                drop_rate = 0.,
                attn_drop_rate = 0.,
                drop_path_rate = 0.):
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim = embed_dim,
                num_heads = num_heads,
                mlp_ratio = mlp_ratio,
                qkv_bias = qkv_bias,
                drop = drop_rate,
                attn_drop = attn_drop_rate,
                drop_path = drop_path_rate * i / depth
            ) for i in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        
    def forward(self, x):
        for block in self.blocks:
            x  = block(x)
        x = self.norm(x)
        return x

### Super-Resolution Transformer Decoder  
Defines `SRTransformerDecoder`, a Transformer-based decoder for super-resolution tasks.  
It processes encoded features using multiple Transformer blocks and applies layer normalization to refine the output.


In [9]:
class SRTransformerDecoder(nn.Module):
    def __init__(self, embed_dim=512, depth=8, num_heads=8, mlp_ratio=4., 
                 qkv_bias=True, drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate)
            for i in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x


### Pixel Shuffle Upsampling  
Defines `PixelShuffleUpsample`, which upsamples feature maps using a convolutional layer followed by pixel shuffle.  
It increases the spatial resolution while maintaining feature integrity and applies GELU activation.


In [10]:
class PixelShuffleUpsample(nn.Module):
    def __init__(self, in_features, scale_factor=2):
        super().__init__()
        self.conv = nn.Conv2d(in_features, in_features * scale_factor**2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.act = nn.GELU()
        
    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.act(x)
        return x

## Masked Autoencoder (MAE) for Image Reconstruction

This model implements a **Masked Autoencoder (MAE)** using a Transformer-based architecture.  
It randomly masks image patches, encodes visible ones, and reconstructs the full image using a decoder.  
The MAE learns image representations by predicting the missing patches, optimizing reconstruction loss.  


In [11]:
class MaskedAutoEncoder(nn.Module):
    def __init__(self,
                img_size = 224,
                patch_size = 16,
                in_chans = 3,
                embed_dim = 1024,
                depth = 24,
                num_heads = 16,
                decoder_embed_dim = 512,
                decoder_depth = 8,
                decoder_num_heads = 16,
                mlp_ratio = 4.,
                norm_layer = nn.LayerNorm):
        super().__init__()
        # Encoder Components
        # PatchEmbed splits the image into patches and embeds them
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        self.num_patches = self.patch_embed.num_patches

        # Class token and Positional Encoding for encoder
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))

        # Encoder for the visible Patches
        self.encoder = TransformerEncoder(
            embed_dim = embed_dim,
            depth = depth,
            num_heads = num_heads,
            mlp_ratio = mlp_ratio
        )

        # Decoder components
        # Convert the encoder output to decoder dimension
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias = True)
        # Learnable mask token that is used for masked patches
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        # Positional encoding ffor the decoder
        self.decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, decoder_embed_dim))
        
        # Decoder to reconstruct the full image
        self.decoder = TransformerDecoder(
            embed_dim = decoder_embed_dim,
            depth = decoder_depth,
            num_heads = decoder_num_heads,
            mlp_ratio = mlp_ratio
            )
        # Final prediction Layer: predict pixel values for eeach patch
        self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias = True)
        
        # Initialize weights for all components
        self.initialize_weights()

        # Store model parameters for later use
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.img_size = img_size

    def initialize_weights(self):
        pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, self.pos_embed.shape[2]))
        decoder_pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, self.decoder_pos_embed.shape[2]))
        
        
        # Use truncated normal distributions 
        nn.init.trunc_normal_(self.pos_embed, std = 0.02)
        nn.init.trunc_normal_(self.decoder_pos_embed, std = 0.02)

        nn.init.trunc_normal_(self.cls_token, std = 0.02)
        nn.init.trunc_normal_(self.mask_token, std = 0.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std = 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def random_masking(self, x, mask_ratio):
        N, L, D = x.shape # Batch, length, dimension
        len_keep = int(L * (1 - mask_ratio)) # Number of patches to keep

        # Generate uniform random noise for each patch in each sample
        noise = torch.rand(N, L, device = x.device) 

        # Sort noise to determine which patches to keep/remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first len_keep patches (lowest noise values)
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index = ids_keep.unsqueeze(-1).repeat(1, 1, D))

        mask = torch.ones([N, L], device = x.device)
        mask[:, :len_keep] = 0
        # Unshuffle to get the binary mask for original sequence
        mask = torch.gather(mask, dim = 1, index = ids_restore)

        return x_masked, mask, ids_restore

    def forward_encoder(self, x, mask_ratio):
        # Convert img to patches
        x = self.patch_embed(x)
    
        # Add positional embeddings
        cls_token = self.cls_token + self.pos_embed[:, :1, :]  # [1, 1, D]
        x = x + self.pos_embed[:, 1:, :]  # [B, N, D]
    
        # Apply random masking
        x, mask, ids_restore = self.random_masking(x, mask_ratio)  # [B, N', D]

        # Expand class token to match batch size
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)  # [B, 1, D]

        # Concatenate cls_token and image tokens
        x = torch.cat((cls_tokens, x), dim=1)  # [B, N'+1, D]

        # Process through transformer encoder
        x = self.encoder(x)

        return x, mask, ids_restore


    def forward_decoder(self, x, ids_restore):
        # embed the encoder output
        x = self.decoder_embed(x)

        # add mask tokens
        mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)

        # exclude class token x[:, 1:] and append mask tokens
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim = 1)

        # unshuffle: restore the original sequence order
        x_ = torch.gather(x_, dim = 1, index = ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))

        # append class token
        x = torch.cat([x[:, :1, :], x_], dim = 1)

        # apply positional embedding 
        x = x + self.decoder_pos_embed

        # apply transformer decoder 
        x = self.decoder(x)

        # predict pixel values for each patch
        x = self.decoder_pred(x)

        # remove class token from prediictions
        x = x[:, 1:, :]

        return x

    def forward(self, imgs, mask_ratio = 0.75):
        # Forward Pass the entire MAE model

        # run encoder on the images with masking
        latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)

        # run decoder to predict all patches
        pred = self.forward_decoder(latent, ids_restore)

        # convert input images to patches for loss calculation
        target = self.patchify(imgs)

        # calculate mse loss only for masked patches
        loss = self.calculate_loss(pred, target, mask)

        return loss, pred, mask

    def patchify(self, imgs):
        # convert imgs to patches for calculating loss
        p = self.patch_size
        h = w = self.img_size//p

        x = rearrange(imgs, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        return x

    def unpatchify(self, x):
        p = self.patch_size
        h = w = int(x.shape[1] ** 0.5)

        imgs = rearrange(x, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)', h = h, w = w, p1 = p, p2 = p)
        return imgs

    def calculate_loss(self, pred, target, mask):
        # calculate mse loss for masked patches only

        loss = (pred - target)**2
        loss = loss.mean(dim = -1)

        loss = (loss*mask).sum()/mask.sum() 

        return loss

## Super-Resolution Model using Transformer-based MAE

This model builds upon a **Masked Autoencoder (MAE)** to perform **image super-resolution**.  
It encodes low-resolution images, processes them with a Transformer-based decoder, and progressively upsamples them to a higher resolution.  
The model also supports loading pretrained MAE weights for better initialization and improved performance.  


In [12]:
class SuperResolutionModel(nn.Module):
    def __init__(self, img_size=75, patch_size=15, in_chans=3, encoder_embed_dim=1024, 
                 decoder_embed_dim=512, encoder_depth=24, decoder_depth=8, encoder_num_heads=16, 
                 decoder_num_heads=8, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        
        # Encoder from MAE (pretrained)
        self.encoder = TransformerEncoder(
            img_size=img_size, 
            patch_size=patch_size, 
            in_chans=in_chans,
            embed_dim=encoder_embed_dim, 
            depth=encoder_depth,
            num_heads=encoder_num_heads, 
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias, 
            drop_rate=drop_rate, 
            attn_drop_rate=attn_drop_rate
        )
        
        # Bridge from encoder to SR decoder
        self.sr_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim)
        
        # SR-specific decoder
        self.sr_decoder = SRTransformerDecoder(
            embed_dim=decoder_embed_dim,
            depth=decoder_depth,
            num_heads=decoder_num_heads,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate
        )
        
        # Calculate output sequence length (patches)
        self.num_patches = (img_size // patch_size) ** 2
        self.patch_size = patch_size
        
        # Adjusted spatial dimensions after reshaping
        self.patches_h = self.patches_w = img_size // patch_size
        
        # Progressive upsampling layers
        self.upsampling = nn.Sequential(
            nn.Conv2d(decoder_embed_dim, decoder_embed_dim * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(upscale_factor=2),  # 5x5 → 10x10
            nn.GELU(),
    
            nn.Conv2d(decoder_embed_dim, decoder_embed_dim * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(upscale_factor=2),  # 10x10 → 20x20
            nn.GELU(),
    
            nn.Conv2d(decoder_embed_dim, decoder_embed_dim * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(upscale_factor=2),  # 20x20 → 40x40
            nn.GELU(),

            nn.Conv2d(decoder_embed_dim, decoder_embed_dim * 4, kernel_size=3, padding=1),
            nn.PixelShuffle(upscale_factor=2),  # 40x40 → 80x80
            nn.GELU(),

            nn.Upsample(size=(150, 150), mode="bilinear", align_corners=True)  # Final resize
        )
        
        # Final projection to RGB
        self.final_conv = nn.Conv2d(decoder_embed_dim, in_chans, kernel_size=3, padding=1)
        
    def forward(self, x):
        # Encode
        features = self.encoder(x)
        
        # Bridge to decoder embedding dimension
        features = self.sr_embed(features)
        
        # Decode
        features = self.sr_decoder(features)
        
        # Reshape to spatial format
        features = rearrange(features, 'b (h w) c -> b c h w', h=self.patches_h, w=self.patches_w)
        
        # Upsample
        features = self.upsampling(features)
        
        # Final projection to output image
        output = self.final_conv(features)
        
        return output


    def load_from_mae(self, mae_model):
        # Load encoder
        self.encoder.patch_embed.proj.weight.data = mae_model.patch_embed.proj.weight.data
        self.encoder.patch_embed.proj.bias.data = mae_model.patch_embed.proj.bias.data
        
        # Adjust position embeddings for potentially different input size
        if self.encoder.pos_embed.shape != mae_model.encoder.pos_embed.shape:
            # Interpolate position embeddings to new size
            pos_embed = mae_model.encoder.pos_embed
            src_size = int(mae_model.encoder.pos_embed.shape[1] ** 0.5)
            tgt_size = int(self.encoder.pos_embed.shape[1] ** 0.5)
            
            pos_embed = pos_embed.reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2)
            pos_embed = F.interpolate(pos_embed, size=(tgt_size, tgt_size), mode='bicubic')
            pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(1, tgt_size*tgt_size, -1)
            self.encoder.pos_embed.data = pos_embed
        else:
            self.encoder.pos_embed.data = mae_model.encoder.pos_embed.data
            
        # Load transformer blocks
        for i, blk in enumerate(self.encoder.blocks):
            blk.load_state_dict(mae_model.encoder.blocks[i].state_dict())
            
        # Load encoder norm
        self.encoder.norm.load_state_dict(mae_model.encoder.norm.state_dict())
        
        # Bridge embed (initialize from MAE decoder_embed)
        self.sr_embed.weight.data = mae_model.decoder_embed.weight.data
        self.sr_embed.bias.data = mae_model.decoder_embed.bias.data
        
        # Initialize decoder from MAE decoder (if dimensions match)
        for i, blk in enumerate(self.sr_decoder.blocks):
            if i < len(mae_model.decoder.blocks):
                # Only load parameters with matching dimensions
                mae_blk = mae_model.decoder.blocks[i]
                
                # Check and load attention
                if blk.attn.qkv.weight.shape == mae_blk.attn.qkv.weight.shape:
                    blk.attn.qkv.load_state_dict(mae_blk.attn.qkv.state_dict())
                if blk.attn.proj.weight.shape == mae_blk.attn.proj.weight.shape:
                    blk.attn.proj.load_state_dict(mae_blk.attn.proj.state_dict())
                
                # Check and load MLP
                if blk.mlp.fc1.weight.shape == mae_blk.mlp.fc1.weight.shape:
                    blk.mlp.fc1.load_state_dict(mae_blk.mlp.fc1.state_dict())
                if blk.mlp.fc2.weight.shape == mae_blk.mlp.fc2.weight.shape:
                    blk.mlp.fc2.load_state_dict(mae_blk.mlp.fc2.state_dict())
                
                # Load norms (if dimensions match)
                if blk.norm1.weight.shape == mae_blk.norm1.weight.shape:
                    blk.norm1.load_state_dict(mae_blk.norm1.state_dict())
                if blk.norm2.weight.shape == mae_blk.norm2.weight.shape:
                    blk.norm2.load_state_dict(mae_blk.norm2.state_dict())
        
        # Load decoder norm
        if self.sr_decoder.norm.weight.shape == mae_model.decoder.norm.weight.shape:
            self.sr_decoder.norm.load_state_dict(mae_model.decoder.norm.state_dict())

### VGG Perceptual Loss
This class implements a perceptual loss function using a pre-trained VGG16 network. It extracts hierarchical features from input images and computes an L1 loss at multiple layers to measure perceptual similarity. The loss is commonly used for tasks like image super-resolution and style transfer.


In [13]:
class VGGPerceptualLoss(nn.Module):
    def __init__(self, resize=True):
        super(VGGPerceptualLoss, self).__init__()
        vgg_pretrained = vgg16(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained[x])
            
        for param in self.parameters():
            param.requires_grad = False
        
        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4]
        self.resize = resize
        
    def forward(self, x, y):
        if self.resize:
            x = nn.functional.interpolate(x, mode='bilinear', size=(224, 224), align_corners=False)
            y = nn.functional.interpolate(y, mode='bilinear', size=(224, 224), align_corners=False)
        
        loss = 0.0
        x_vgg, y_vgg = self.preprocess(x), self.preprocess(y)
        
        x_feat1 = self.slice1(x_vgg)
        y_feat1 = self.slice1(y_vgg)
        loss += self.weights[0] * nn.functional.l1_loss(x_feat1, y_feat1)
        
        x_feat2 = self.slice2(x_feat1)
        y_feat2 = self.slice2(y_feat1)
        loss += self.weights[1] * nn.functional.l1_loss(x_feat2, y_feat2)
        
        x_feat3 = self.slice3(x_feat2)
        y_feat3 = self.slice3(y_feat2)
        loss += self.weights[2] * nn.functional.l1_loss(x_feat3, y_feat3)
        
        x_feat4 = self.slice4(x_feat3)
        y_feat4 = self.slice4(y_feat3)
        loss += self.weights[3] * nn.functional.l1_loss(x_feat4, y_feat4)
        
        return loss
    
    def preprocess(self, x):
        # Normalize to match VGG input
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
        return (x - mean) / std

### Image Quality Metrics Calculation  
This function computes MSE, SSIM, and PSNR for a batch of images by comparing predicted and target images. It processes images in the [0,1] range and returns the average metrics, commonly used for image restoration and super-resolution evaluation.


In [14]:
def calculate_metrics(pred, target):
    # Convert tensors to numpy arrays
    pred = pred.detach().cpu().numpy().transpose(0, 2, 3, 1)  # B, H, W, C
    target = target.detach().cpu().numpy().transpose(0, 2, 3, 1)  # B, H, W, C
    
    # Initialize metrics
    batch_mse = 0
    batch_ssim = 0
    batch_psnr = 0
    batch_size = pred.shape[0]
    
    # Calculate metrics for each image in batch
    for i in range(batch_size):
        # Clip values to valid image range [0, 1]
        p = np.clip(pred[i], 0, 1)
        t = np.clip(target[i], 0, 1)
        
        # MSE
        mse = np.mean((p - t) ** 2)
        batch_mse += mse
        
        # SSIM (multichannel for RGB)
        data_range = t.max() - t.min()
        batch_ssim += ssim(p, t, multichannel=True, channel_axis=2, data_range = data_range)
        
        # PSNR
        batch_psnr += psnr(t, p, data_range=1.0)
    
    # Return average metrics
    return {
        'mse': batch_mse / batch_size,
        'ssim': batch_ssim / batch_size,
        'psnr': batch_psnr / batch_size
    }


In [15]:
def plot_training_curves(history, save_dir):
    """Plot and save training curves"""
    # Create figure with subplots
    fig, axs = plt.subplots(2, 2, figsize=(15, 10))
    
    # Plot loss curves
    axs[0, 0].plot(history['train_loss'], label='Train Loss')
    axs[0, 0].plot(history['val_loss'], label='Validation Loss')
    axs[0, 0].set_title('Loss')
    axs[0, 0].set_xlabel('Epoch')
    axs[0, 0].set_ylabel('Loss')
    axs[0, 0].legend()
    
    # Plot MSE
    axs[0, 1].plot(history['val_mse'], label='Validation MSE')
    axs[0, 1].set_title('Mean Squared Error (MSE)')
    axs[0, 1].set_xlabel('Epoch')
    axs[0, 1].set_ylabel('MSE')
    axs[0, 1].legend()
    
    # Plot SSIM
    axs[1, 0].plot(history['val_ssim'], label='Validation SSIM')
    axs[1, 0].set_title('Structural Similarity Index (SSIM)')
    axs[1, 0].set_xlabel('Epoch')
    axs[1, 0].set_ylabel('SSIM')
    axs[1, 0].legend()
    
    # Plot PSNR
    axs[1, 1].plot(history['val_psnr'], label='Validation PSNR')
    axs[1, 1].set_title('Peak Signal-to-Noise Ratio (PSNR)')
    axs[1, 1].set_xlabel('Epoch')
    axs[1, 1].set_ylabel('PSNR (dB)')
    axs[1, 1].legend()
    
    plt.tight_layout()
    
    # Save plot
    plt.savefig(os.path.join(save_dir, 'training_curves.png'))
    plt.close()

### Super-Resolution Model Training  
This function trains a super-resolution model using both pixel-wise (L1) and perceptual (VGG-based) loss. It leverages an MAE-pretrained model, tracks validation metrics (MSE, SSIM, PSNR), and saves the best model based on PSNR. Training progress is logged, and model checkpoints are saved periodically.


In [16]:
def train_model(model, mae_model, train_dataloader, val_dataloader, device, 
               num_epochs=50, save_dir='./checkpoints'):
    
    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)
    
    # Load weights from MAE model
    model.load_from_mae(mae_model)
    model = model.to(device)
    
    # Initialize losses
    criterion_pixel = nn.L1Loss()
    criterion_perceptual = VGGPerceptualLoss().to(device)
    
    # Initialize optimizer and scheduler
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-6)
    
    # Training statistics
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_mse': [],
        'val_ssim': [],
        'val_psnr': []
    }
    
    best_psnr = 0
    
    # Start training
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        
        # Training phase
        model.train()
        train_loss = 0
        
        progress_bar = tqdm(train_dataloader, desc=f"Training")
        for batch_idx, (low_res, high_res) in enumerate(progress_bar):
            low_res, high_res = low_res.to(device), high_res.to(device)
            
            # Forward pass
            optimizer.zero_grad()
            # low_res = low_res.repeat(1, 3, 1, 1)
            outputs = model(low_res)
            
            # Calculate losses
            pixel_loss = criterion_pixel(outputs, high_res)
            perceptual_loss = criterion_perceptual(outputs, high_res)
            loss = pixel_loss + 0.1 * perceptual_loss
            
            # Backward pass
            loss.backward()
            optimizer.step()
            
            # Update statistics
            train_loss += loss.item()
            progress_bar.set_postfix({"loss": loss.item()})
        
        # Calculate average training loss
        train_loss /= len(train_dataloader)
        history['train_loss'].append(train_loss)
        
        # Validation phase
        model.eval()
        val_loss = 0
        val_metrics = {'mse': 0, 'ssim': 0, 'psnr': 0}
        
        with torch.no_grad():
            progress_bar = tqdm(val_dataloader, desc=f"Validation")
            for batch_idx, (low_res, high_res) in enumerate(progress_bar):
                low_res, high_res = low_res.to(device), high_res.to(device)
                
                # Forward pass
                outputs = model(low_res)
                
                # Calculate losses
                pixel_loss = criterion_pixel(outputs, high_res)
                perceptual_loss = criterion_perceptual(outputs, high_res)
                loss = pixel_loss + 0.1 * perceptual_loss
                
                # Update statistics
                val_loss += loss.item()
                
                # Calculate metrics
                metrics = calculate_metrics(outputs, high_res)
                val_metrics['mse'] += metrics['mse']
                val_metrics['ssim'] += metrics['ssim']
                val_metrics['psnr'] += metrics['psnr']
                
                progress_bar.set_postfix({"val_loss": loss.item()})
        
        # Calculate average validation loss and metrics
        val_loss /= len(val_dataloader)
        val_metrics['mse'] /= len(val_dataloader)
        val_metrics['ssim'] /= len(val_dataloader)
        val_metrics['psnr'] /= len(val_dataloader)
        
        # Update history
        history['val_loss'].append(val_loss)
        history['val_mse'].append(val_metrics['mse'])
        history['val_ssim'].append(val_metrics['ssim'])
        history['val_psnr'].append(val_metrics['psnr'])
        
        # Update scheduler
        scheduler.step()
        
        # Print epoch statistics
        print(f"Train Loss: {train_loss:.6f}")
        print(f"Val Loss: {val_loss:.6f}, MSE: {val_metrics['mse']:.6f}, "
              f"SSIM: {val_metrics['ssim']:.6f}, PSNR: {val_metrics['psnr']:.6f}")
        
        # Save best model based on PSNR
        if val_metrics['psnr'] > best_psnr:
            best_psnr = val_metrics['psnr']
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': val_loss,
                'psnr': val_metrics['psnr'],
                'ssim': val_metrics['ssim'],
                'mse': val_metrics['mse'],
            }, os.path.join(save_dir, 'best_model.pth'))
            print(f"Saved best model with PSNR: {best_psnr:.6f}")
        
        # Save checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': val_loss,
                'history': history,
            }, os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth'))
    
    # Plot training curves
    plot_training_curves(history, save_dir)
    
    return model, history

In [17]:
base_path = "/kaggle/input/dataset-6-for-sr/Dataset"
train_dataloader, val_dataloader = create_dataloaders(base_path, batch_size=8)

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

In [19]:
mae_model = MaskedAutoEncoder().to(device)
mae_checkpoint = torch.load('/kaggle/input/pretrained-masked-auto-encoder/pytorch/default/1/model.pth')
mae_model.load_state_dict(mae_checkpoint, strict = False)

  mae_checkpoint = torch.load('/kaggle/input/pretrained-masked-auto-encoder/pytorch/default/1/model.pth')


_IncompatibleKeys(missing_keys=['encoder.pos_embed', 'encoder.patch_embed.proj.weight', 'encoder.patch_embed.proj.bias'], unexpected_keys=[])

In [20]:
sr_model = SuperResolutionModel(
    img_size=75,  # Low-res input size
    patch_size=16,
    in_chans=3,
    encoder_embed_dim=1024,  
    decoder_embed_dim=512,   
    encoder_depth=24,        
    decoder_depth=8          
)

In [21]:
trained_model, history = train_model(
    model=sr_model,
    mae_model=mae_model,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    device=device,
    num_epochs=10,
    save_dir='./sr_checkpoints'
)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:02<00:00, 224MB/s]


Epoch 1/10


Training: 100%|██████████| 1125/1125 [07:57<00:00,  2.36it/s, loss=0.0125]
Validation: 100%|██████████| 125/125 [00:35<00:00,  3.53it/s, val_loss=0.0129]


Train Loss: 0.016772
Val Loss: 0.012808, MSE: 0.000186, SSIM: 0.960664, PSNR: 37.466745
Saved best model with PSNR: 37.466745
Epoch 2/10


Training: 100%|██████████| 1125/1125 [05:26<00:00,  3.44it/s, loss=0.0118]
Validation: 100%|██████████| 125/125 [00:18<00:00,  6.64it/s, val_loss=0.0121]


Train Loss: 0.012518
Val Loss: 0.012079, MSE: 0.000144, SSIM: 0.962628, PSNR: 38.535957
Saved best model with PSNR: 38.535957
Epoch 3/10


Training: 100%|██████████| 1125/1125 [05:23<00:00,  3.48it/s, loss=0.0114]
Validation: 100%|██████████| 125/125 [00:18<00:00,  6.70it/s, val_loss=0.0118]


Train Loss: 0.012064
Val Loss: 0.011818, MSE: 0.000131, SSIM: 0.963415, PSNR: 38.933978
Saved best model with PSNR: 38.933978
Epoch 4/10


Training: 100%|██████████| 1125/1125 [05:23<00:00,  3.48it/s, loss=0.0117]
Validation: 100%|██████████| 125/125 [00:18<00:00,  6.65it/s, val_loss=0.0118]


Train Loss: 0.011764
Val Loss: 0.011736, MSE: 0.000127, SSIM: 0.963533, PSNR: 39.038091
Saved best model with PSNR: 39.038091
Epoch 5/10


Training: 100%|██████████| 1125/1125 [05:23<00:00,  3.47it/s, loss=0.0117]
Validation: 100%|██████████| 125/125 [00:18<00:00,  6.60it/s, val_loss=0.0115]


Train Loss: 0.011568
Val Loss: 0.011472, MSE: 0.000111, SSIM: 0.964075, PSNR: 39.613521
Saved best model with PSNR: 39.613521
Epoch 6/10


Training: 100%|██████████| 1125/1125 [05:23<00:00,  3.48it/s, loss=0.0113]
Validation: 100%|██████████| 125/125 [00:18<00:00,  6.62it/s, val_loss=0.0115]


Train Loss: 0.011420
Val Loss: 0.011509, MSE: 0.000114, SSIM: 0.964036, PSNR: 39.479521
Epoch 7/10


Training: 100%|██████████| 1125/1125 [05:22<00:00,  3.49it/s, loss=0.0116]
Validation: 100%|██████████| 125/125 [00:18<00:00,  6.58it/s, val_loss=0.0113]


Train Loss: 0.011271
Val Loss: 0.011291, MSE: 0.000103, SSIM: 0.964433, PSNR: 39.935734
Saved best model with PSNR: 39.935734
Epoch 8/10


Training: 100%|██████████| 1125/1125 [05:23<00:00,  3.48it/s, loss=0.0113]
Validation: 100%|██████████| 125/125 [00:18<00:00,  6.66it/s, val_loss=0.0112]


Train Loss: 0.011191
Val Loss: 0.011228, MSE: 0.000100, SSIM: 0.964523, PSNR: 40.039625
Saved best model with PSNR: 40.039625
Epoch 9/10


Training: 100%|██████████| 1125/1125 [05:22<00:00,  3.48it/s, loss=0.0112]
Validation: 100%|██████████| 125/125 [00:18<00:00,  6.82it/s, val_loss=0.0112]


Train Loss: 0.011130
Val Loss: 0.011165, MSE: 0.000097, SSIM: 0.964601, PSNR: 40.160305
Saved best model with PSNR: 40.160305
Epoch 10/10


Training: 100%|██████████| 1125/1125 [05:23<00:00,  3.48it/s, loss=0.0112]
Validation: 100%|██████████| 125/125 [00:18<00:00,  6.66it/s, val_loss=0.0112]


Train Loss: 0.011095
Val Loss: 0.011139, MSE: 0.000096, SSIM: 0.964683, PSNR: 40.209333
Saved best model with PSNR: 40.209333
