## Imports

In [106]:
import wandb
from pathlib import Path
import torch
import os
import shutil
import torch.nn as nn
from sklearn.model_selection import train_test_split
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np
from torch.nn import functional as F
from torchvision import transforms
from PIL import Image
import torchvision.models as models
import time
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader

from einops import repeat, rearrange
from einops.layers.torch import Rearrange

from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block

## Constants

In [107]:
DATA_A_PATH = "dataset3a/"
SPLIT_DATA_PATH_A = "data_a/"
PATH_TRAIN_LR = "../sr/data_a/train/LR"
PATH_TRAIN_HR = "../sr/data_a/train/HR"
PATH_VAL_LR = "../sr/data_a/val/LR"
PATH_VAL_HR = "../sr/data_a/val/HR"

DATA_B_PATH = "dataset3b/"
SPLIT_DATA_PATH_B = "data_b/"
PATH_TRAIN_LR_B = "data_b/train/LR"
PATH_TRAIN_HR_B = "data_b/train/HR"
PATH_VAL_LR_B = "data_b/val/LR"
PATH_VAL_HR_B = "data_b/val/HR"

SAVED_MODEL_PATH = "best_mae_model2.pth"

LR = "LR"
HR = "HR"
TRAIN = "train"
VAL = "val"

WANDB_USERNAME = "samkitshah1262-warner-bros-discovery"
WANDB_PROJECT = "deeplens-foundational"

## Utils

### Logger

In [108]:
class WandBLogger:
    def __init__(self, config, model):
        self.config = config
        self.run = wandb.init(
            project=config['wandb_project'],
            entity=config['wandb_entity'],
            config=config,
            tags=config.get('tags', ['super-resolution', 'lensing']),
            dir=str(Path.cwd())
        )

        wandb.watch(
            model,
            log='all',
            log_freq=config.get('log_interval', 50),
            log_graph=True
        )
        
    def log_metrics(self, metrics, step=None, commit=True):
        wandb.log(metrics, step=step, commit=commit)
        
    def log_images(self, lr, sr, hr, caption="LR/SR/HR Comparison"):
        # Resize LR to match HR/SR dimensions
        lr_upscaled = F.interpolate(lr, scale_factor=2, mode='bilinear')
        
        # Denormalize images
        lr_upscaled = (lr_upscaled * 0.5 + 0.5).clamp(0, 1)
        sr = (sr * 0.5 + 0.5).clamp(0, 1)
        hr = (hr * 0.5 + 0.5).clamp(0, 1)

        grid = torch.cat([lr_upscaled, sr, hr], dim=-1)  # Concatenate along width

        images = wandb.Image(grid, caption=caption)
        wandb.log({"Examples": images})
        
    def log_model(self, model_path, metadata=None):
        artifact = wandb.Artifact(
            name=f"model-{wandb.run.id}",
            type="model",
            description="Equiformer super-resolution model",
            metadata=metadata or {}
        )
        artifact.add_file(model_path)
        wandb.log_artifact(artifact)
        
    def finish(self):
        wandb.finish()


### Metrics

In [109]:
def calculate_psnr(img1, img2):
    """Calculate PSNR between two image tensors"""
    img1 = img1.detach().cpu().numpy().transpose(0,2,3,1)
    img2 = img2.detach().cpu().numpy().transpose(0,2,3,1)
    return np.mean([psnr(im1, im2, data_range=1.0) 
                   for im1, im2 in zip(img1, img2)])

def calculate_ssim(img1, img2, data_range=1.0, eps=1e-8):
    """Numerically stable SSIM calculation for grayscale images"""
    # Input validation and clamping
    img1 = torch.clamp(img1, -data_range, data_range).detach()
    img2 = torch.clamp(img2, -data_range, data_range).detach()
    
    # Convert to numpy with double precision
    img1_np = img1.cpu().numpy().squeeze(1).astype(np.float64)
    img2_np = img2.cpu().numpy().squeeze(1).astype(np.float64)
    
    ssim_values = []
    
    for i in range(img1_np.shape[0]):
        im1 = img1_np[i]
        im2 = img2_np[i]
        
        try:
            # Dynamic window size selection with safety checks
            min_dim = min(im1.shape)
            win_size = min(7, min_dim - 1 if min_dim % 2 == 0 else min_dim)
            win_size = max(3, win_size)
            
            # Calculate SSIM with stability parameters
            ssim_val = ssim(
                im1, im2,
                data_range=data_range,
                win_size=win_size,
                channel_axis=None,
                gaussian_weights=True,
                sigma=1.5,
                use_sample_covariance=False
            )
            
            # Handle potential NaN/Inf
            if np.isnan(ssim_val) or np.isinf(ssim_val):
                raise ValueError("Invalid SSIM value")
                
        except Exception as e:
            print(f"SSIM calculation failed for image {i}: {str(e)}")
            ssim_val = -1  # Sentinel value for failures
            
        ssim_values.append(ssim_val)
    
    # Filter out failed calculations
    valid_ssim = [v for v in ssim_values if v >= 0]
    
    return np.mean(valid_ssim) if valid_ssim else 0.0

## Configs and settings

In [110]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

In [111]:
config = {
    'dim': 64,
    'num_blocks': 8,
    'num_heads': 4,
    'upscale': 2,
    'lr': 1e-4,
    'weight_decay': 1e-4,
    'batch_size': 16,
    'epochs': 100,
    'patience': 10,
    'use_amp': True,
    'train_lr_dir': PATH_TRAIN_LR,
    'train_hr_dir': PATH_TRAIN_HR,
    'val_lr_dir': PATH_VAL_LR,
    'val_hr_dir': PATH_VAL_HR,
    'transform': transforms.Compose([
        transforms.Normalize(mean=[0.5], std=[0.5]) 
    ]),
    'wandb_project': WANDB_PROJECT,
    'wandb_entity': WANDB_USERNAME,
    'tags': ['gsoc2025', 'diffilens'],
    'log_interval': 50,
    'sample_interval': 200,
    'architecture': 'Equiformer'
}


## Data loading

In [112]:
class LensDataset(torch.utils.data.Dataset):
    """Strong lensing dataset handler for .npy files"""

    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_files = sorted(os.listdir(lr_dir))  # Ensure order
        self.hr_files = sorted(os.listdir(hr_dir))
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.transform = transform

        assert len(self.lr_files) == len(self.hr_files), "Mismatch in LR and HR files!"

    def __len__(self):
        return len(self.lr_files)

    def __getitem__(self, idx):

        lr = np.load(os.path.join(self.lr_dir, self.lr_files[idx]))
        hr = np.load(os.path.join(self.hr_dir, self.hr_files[idx]))

        if len(lr.shape) == 2:
            lr = lr[None, ...]
        if len(hr.shape) == 2:
            hr = hr[None, ...]

        lr = torch.tensor(lr, dtype=torch.float32)
        hr = torch.tensor(hr, dtype=torch.float32) 

        if self.transform:
            lr = self.transform(lr)
            hr = self.transform(hr)

        return lr, hr
    

In [113]:
class LensDataPreprocessor:
    def __init__(self, crop_size=75, scale_factor=2):
        self.train_transform = transforms.Compose([
            transforms.RandomChoice([
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip()
            ]),
            transforms.RandomRotation(15),
            transforms.RandomCrop(crop_size*scale_factor),
            transforms.Lambda(lambda x: self._degrade(x, scale_factor)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])
        
        self.val_transform = transforms.Compose([
            transforms.CenterCrop(crop_size*scale_factor),
            transforms.Lambda(lambda x: self._degrade(x, scale_factor)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def _degrade(self, hr, scale_factor):
        lr_size = (hr.size[1]//scale_factor, hr.size[0]//scale_factor)
        return hr.resize(lr_size, Image.BICUBIC)

    def get_transforms(self):
        return {
            'train': PairedTransform(self.train_transform),
            'val': PairedTransform(self.val_transform)
        }

class PairedTransform:
    def __init__(self, transform):
        self.transform = transform
        
    def __call__(self, hr):
        lr = self.transform(hr)
        hr = self.transform(hr)
        return lr, hr

## Model

In [114]:
def random_indexes(size : int):
    forward_indexes = np.arange(size)
    np.random.shuffle(forward_indexes)
    backward_indexes = np.argsort(forward_indexes)
    return forward_indexes, backward_indexes

def take_indexes(sequences, indexes):
    return torch.gather(sequences, 0, repeat(indexes, 't b -> t b c', c=sequences.shape[-1]))

class PatchShuffle(torch.nn.Module):
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        T, B, C = patches.shape
        remain_T = int(T * (1 - self.ratio))

        indexes = [random_indexes(T) for _ in range(B)]
        forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)

        patches = take_indexes(patches, forward_indexes)
        patches = patches[:remain_T]

        return patches, forward_indexes, backward_indexes

class MAE_Encoder(torch.nn.Module):
    def __init__(self,
                 image_size=64,
                 patch_size=4,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)

        self.patchify = torch.nn.Conv2d(1, emb_dim, patch_size, patch_size)

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding

        patches, forward_indexes, backward_indexes = self.shuffle(patches)

        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, 'b t c -> t b c')

        return features, backward_indexes

class MAE_Decoder(torch.nn.Module):
    def __init__(self,
                 image_size=64,
                 patch_size=4,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.head = torch.nn.Linear(emb_dim, 1 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, features, backward_indexes):
        T = features.shape[0]
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:]

        patches = self.head(features)
        mask = torch.zeros_like(patches)
        mask[T-1:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask

class MAE_ViT(torch.nn.Module):
    def __init__(self,
                 image_size=64,
                 patch_size=4,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(image_size, patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio)
        self.decoder = MAE_Decoder(image_size, patch_size, emb_dim, decoder_layer, decoder_head)

    def forward(self, img):
        features, backward_indexes = self.encoder(img)
        predicted_img, mask = self.decoder(features,  backward_indexes)
        return predicted_img, mask

In [117]:
from torchvision.models import vgg16_bn
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels)
        )
    
    def forward(self, x):
        return x + self.conv(x)

class UpSampleBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.upsample = nn.Sequential(
            nn.Conv2d(in_ch, out_ch*4, 3, padding=1),  # Pixel shuffle
            nn.PixelShuffle(2),
            nn.PReLU()
        )
        self.res = ResidualBlock(out_ch)
    
    def forward(self, x):
        x = self.upsample(x)
        return self.res(x)

# class EnhancedDecoder(nn.Module):
#     def __init__(self, emb_dim=192):
#         super().__init__()
#         self.layers = nn.Sequential(
#             nn.Conv2d(emb_dim, 256, 3, padding=1),
#             ResidualBlock(256),
#             UpSampleBlock(256, 128),  # 16x16 → 32x32
#             UpSampleBlock(128, 64),   # 32x32 → 64x64
#             UpSampleBlock(64, 32),    # 64x64 → 128x128
#             nn.Conv2d(32, 3, 3, padding=1)
#         )
#         self.final_upscale = nn.Sequential(
#             nn.Upsample(size=150, mode='bicubic', align_corners=False),
#             # nn.Conv2d(3, 3, 3, padding=1)
#         )
    
#     def forward(self, x):
#         x = self.layers(x)        # Output: 128x128
#         return self.final_upscale(x)  # → 150x150
class EnhancedDecoder(nn.Module):
    def __init__(self, emb_dim=192):
        super().__init__()
        self.initial_conv = nn.Sequential(
            nn.Conv2d(emb_dim, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.PReLU()
        )
        
        self.layers = nn.Sequential(
            ResidualBlock(256),
            UpSampleBlock(256, 128),  # 16x16 → 32x32
            UpSampleBlock(128, 64),   # 32x32 → 64x64
            UpSampleBlock(64, 32),    # 64x64 → 128x128
            nn.Conv2d(32, 1, 3, padding=1)  # Final conv to get single channel
        )
        
        # Final upscale to 150x150
        self.final_upscale = nn.Sequential(
            nn.Upsample(size=150, mode='bicubic', align_corners=False),
            nn.Conv2d(1, 1, 3, padding=1)
        )
    
    def forward(self, x):
        x = self.initial_conv(x)
        x = self.layers(x)
        return self.final_upscale(x)

# ---------------------------
# 2. Loss Function Improvements
# ---------------------------
class CombinedLoss(nn.Module):
    def __init__(self):
        super().__init__()
        # Remove VGG loss since we're working with single-channel images
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()
        
    def forward(self, pred, target):
        l1 = self.l1_loss(pred, target)
        mse = self.mse_loss(pred, target)
        return 0.5 * l1 + 0.5 * mse

# ---------------------------
# 3. Input Adaptation
# ---------------------------
class SuperResolutionMAE(nn.Module):
    def __init__(self, pretrained_mae):
        super().__init__()
        self.encoder = pretrained_mae.encoder
        self.encoder.mask_ratio = 0
        
        # Adjust for 75x75 input via padding (75 → 96)
        self.transform = nn.Sequential(
            transforms.CenterCrop(64),  # First crop to 64x64
            # nn.ReflectionPad2d((0, 0, 0, 0))  # No padding needed after crop
        )  # 75→96
        
        self.decoder = EnhancedDecoder(emb_dim=192)
        
    def forward(self, lr_img):
        # Pad 75x75 → 96x96 (divisible by MAE's patch size)
        x = self.transform(lr_img)  # Now compatible with MAE's 4x4 patches
        
        # Extract features
        features, _ = self.encoder(x)
        # B = features.shape[1]
        features = features[1:]  # Remove CLS token
        
        # Reshape to 2D feature map
        num_patches = int(np.sqrt(features.shape[0]))
        features = rearrange(features, '(h w) b c -> b c h w', h=num_patches, w=num_patches)
        
        # Decode
        return self.decoder(features)

In [118]:
mae = MAE_ViT(
        image_size=64,
        patch_size=4,
        emb_dim=192,
        encoder_layer=12,
        encoder_head=3,
        decoder_layer=4,
        decoder_head=3,
        mask_ratio=0.75
).to(device)
mae.load_state_dict(torch.load('best_mae_model2.pth')['model_state_dict'])
model = SuperResolutionMAE(mae).to(device)
lr_img = torch.randn(1, 1, 75, 75).to(device)
hr_pred = model(lr_img)

In [119]:
with torch.no_grad():
    model.encoder.patchify.weight[:, :, :4, :4].copy_(mae.encoder.patchify.weight)

In [120]:
print(hr_pred.shape)

torch.Size([1, 1, 150, 150])


## Trainer

In [121]:
class Trainer:
    def __init__(self, model, train_loader, val_loader, 
                 optimizer, criterion, device, config,
                 scheduler=None, use_amp=True):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.scheduler = scheduler
        self.scaler = GradScaler(device, enabled=use_amp)
        self.best_val_loss = float('inf')
        self.logger = WandBLogger(config, model)
        self.log_interval = config.get('log_interval', 50)
        self.sample_interval = config.get('sample_interval', 200)
        self.best_val_loss = float('inf')
        self.best_epoch = 0 
        
    def train_epoch(self, epoch):
        self.model.train()
        total_loss = 0.0
        psnr_values = []
        ssim_values = []
        mse_values = []
        start_time = time.time()
        
        for batch_idx, (lr, hr) in enumerate(self.train_loader):
            lr = lr.to(self.device, non_blocking=True)
            hr = hr.to(self.device, non_blocking=True)
            # print("LR Shape: ",lr.shape)
            # print("HR Shape: ",hr.shape)
            self.optimizer.zero_grad(set_to_none=True)
            
            with autocast(enabled=self.scaler.is_enabled()):
                outputs = self.model(lr)
                # print("Output Shape: ",outputs.shape)
                loss = self.criterion(outputs, hr)
                
            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()
            
            total_loss += loss.item()
            batch_mse = F.mse_loss(outputs, hr).item()
            batch_psnr = calculate_psnr(outputs, hr)
            batch_ssim = calculate_ssim(outputs, hr)
            psnr_values.append(batch_psnr)
            ssim_values.append(batch_ssim)
            mse_values.append(batch_mse)

            if batch_idx % self.log_interval == 0:
                self.logger.log_metrics({
                    "train/loss": loss.item(),
                    "train/batch_mse": batch_mse,
                    "train/batch_psnr": batch_psnr,
                    "train/batch_ssim": batch_ssim,
                    "lr": self.optimizer.param_groups[0]['lr']
                }, commit=False)
                
            if batch_idx % self.sample_interval == 0:
                with torch.no_grad():
                    self.logger.log_images(lr[:1], outputs[:1], hr[:1])
                
        avg_loss = total_loss / len(self.train_loader)
        avg_psnr = np.mean(psnr_values)
        avg_ssim = np.mean(ssim_values)
        avg_mse = np.mean(mse_values)
        epoch_time = time.time() - start_time
        
        self.logger.log_metrics({
            "epoch": epoch,
            "train/avg_loss": avg_loss,
            "train/epoch_mse": avg_mse,
            "train/avg_psnr": avg_psnr,
            "train/avg_ssim": avg_ssim,
            "epoch_time": epoch_time
        })
        return avg_loss
    
    @torch.no_grad()
    def validate(self, epoch):
        self.model.eval()
        total_loss = 0.0
        mse_values = []
        psnr_values = []
        ssim_values = []
        start_time = time.time()
        
        for lr, hr in self.val_loader:
            lr = lr.to(self.device, non_blocking=True)
            hr = hr.to(self.device, non_blocking=True)
            
            outputs = self.model(lr)
            loss = self.criterion(outputs, hr)
            
            total_loss += loss.item()
            mse_values.append(F.mse_loss(outputs, hr).item())
            psnr_values.append(calculate_psnr(outputs, hr))
            ssim_values.append(calculate_ssim(outputs, hr))
            
        avg_loss = total_loss / len(self.val_loader)
        avg_psnr = np.mean(psnr_values)
        avg_ssim = np.mean(ssim_values)
        avg_mse = np.mean(mse_values)
        epoch_time = time.time() - start_time
        
        self.logger.log_metrics({
            "val/loss": avg_loss,
            "val/mse": avg_mse,
            "val/psnr": avg_psnr,
            "val/ssim": avg_ssim,
            "epoch_time": epoch_time
        })

        self.logger.log_metrics({
            "val_output_dist": wandb.Histogram(outputs.cpu().numpy())
        })

        if avg_loss < self.best_val_loss:
            self.best_val_loss = avg_loss
            self.best_epoch = epoch
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': avg_loss,
            }, 'equiformer_best.pth')
            print("Saved best model!")
            self.logger.log_model('sr_mae_best.pth', {
                'epoch': epoch,
                'val_loss': avg_loss,
                'val_psnr': avg_psnr
            })

        if self.scheduler is not None:
            self.scheduler.step(avg_loss)
            
        return avg_loss

## Runner

In [122]:
def train_model(config):
    
    print(device)
    mae = MAE_ViT(
        image_size=64,
        patch_size=4,
        emb_dim=192,
        encoder_layer=12,
        encoder_head=3,
        decoder_layer=4,
        decoder_head=3,
        mask_ratio=0
    ).to(device)
    mae.load_state_dict(torch.load('best_mae_model2.pth')['model_state_dict'])
    model = SuperResolutionMAE(mae).to(device)
    
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay']
    )
    
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='min',
        factor=0.5,
        patience=3,
        verbose=True
    )
    
    criterion = CombinedLoss().to(device=device)
    
    # preprocessor = LensDataPreprocessor(crop_size=75)
    # transforms = preprocessor.get_transforms()

    train_dataset = LensDataset(
        lr_dir=config['train_lr_dir'],
        hr_dir=config['train_hr_dir'],
        transform=config['transform']
    )
    
    val_dataset = LensDataset(
        lr_dir=config['val_lr_dir'],
        hr_dir=config['val_hr_dir'],
        transform=config['transform']
    )
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=0,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=0,
        pin_memory=True
    )
    
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
        scheduler=scheduler,
        config=config,
        use_amp=config['use_amp']
    )

    try:
        for epoch in range(1, config['epochs'] + 1):
            train_loss = trainer.train_epoch(epoch)
            val_loss = trainer.validate(epoch)

            # Early stopping
            if (epoch - trainer.best_epoch) > config['patience']:
                print(f"Early stopping at epoch {epoch}")
                break



    finally:
        trainer.logger.finish()


In [123]:
train_model(config)

mps


  self.scaler = GradScaler(device, enabled=use_amp)


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
  with autocast(enabled=self.scaler.is_enabled()):


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.