## Imports

In [None]:
import wandb
from pathlib import Path
import torch
import os
import shutil
import torch.nn as nn
from sklearn.model_selection import train_test_split
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np
from torch.nn import functional as F
from torchvision import transforms
from PIL import Image
import torchvision.models as models
import time
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader

## Constants

In [None]:
DATA_A_PATH = "dataset3a/"
SPLIT_DATA_PATH_A = "data_a/"
PATH_TRAIN_LR = "data_a/train/LR"
PATH_TRAIN_HR = "data_a/train/HR"
PATH_VAL_LR = "data_a/val/LR"
PATH_VAL_HR = "data_a/val/HR"

DATA_B_PATH = "dataset3b/"
SPLIT_DATA_PATH_B = "data_b/"
PATH_TRAIN_LR_B = "data_b/train/LR"
PATH_TRAIN_HR_B = "data_b/train/HR"
PATH_VAL_LR_B = "data_b/val/LR"
PATH_VAL_HR_B = "data_b/val/HR"

SAVED_MODEL_PATH = "task3A/equiformer_best_0.pth"

LR = "LR"
HR = "HR"
TRAIN = "train"
VAL = "val"

WANDB_USERNAME = "samkitshah1262-warner-bros-discovery"
WANDB_PROJECT = "ml4sci-superres"

## Utils

### Logger

In [None]:
class WandBLogger:
    """Small wrapper around the Weights & Biases calls used during training.

    Constructing the logger starts a run with ``wandb.init`` and registers
    ``model`` with ``wandb.watch``.

    Args:
        config: Run configuration dict. ``'wandb_project'`` and
            ``'wandb_entity'`` are required; ``'tags'`` and ``'log_interval'``
            are optional. The dict itself is also passed as the run config.
        model: Module handed to ``wandb.watch``.
    """

    def __init__(self, config, model):
        self.config = config
        self.run = wandb.init(
            project=config['wandb_project'],
            entity=config['wandb_entity'],
            config=config,
            tags=config.get('tags', ['super-resolution', 'lensing']),
            dir=str(Path.cwd())
        )

        wandb.watch(
            model,
            log='all',
            log_freq=config.get('log_interval', 50),
            log_graph=True
        )

    def log_metrics(self, metrics, step=None, commit=True):
        """Forward ``metrics`` (a dict) to ``wandb.log`` unchanged."""
        wandb.log(metrics, step=step, commit=commit)

    def log_images(self, lr, sr, hr, caption="LR/SR/HR Comparison"):
        """Log one LR | SR | HR panel under the key ``"Examples"``.

        Args:
            lr: Low-resolution batch; treated as (N, C, h, w) by the resize.
            sr: Model output batch with the same spatial size as ``hr``.
            hr: Ground-truth high-resolution batch.
            caption: Caption attached to the logged image.
        """
        # Resize LR to HR's spatial size instead of assuming a fixed 2x factor,
        # so the width-wise concatenation below works for any upscale ratio.
        lr_upscaled = F.interpolate(lr, size=tuple(hr.shape[-2:]), mode='bilinear')

        # Undo the mean=0.5 / std=0.5 normalisation and clip to [0, 1]
        lr_upscaled = (lr_upscaled * 0.5 + 0.5).clamp(0, 1)
        sr = (sr * 0.5 + 0.5).clamp(0, 1)
        hr = (hr * 0.5 + 0.5).clamp(0, 1)

        grid = torch.cat([lr_upscaled, sr, hr], dim=-1)  # Concatenate along width

        images = wandb.Image(grid, caption=caption)
        wandb.log({"Examples": images})

    def log_model(self, model_path, metadata=None):
        """Upload the file at ``model_path`` as a W&B artifact of type "model".

        Args:
            model_path: Path of the checkpoint file to attach.
            metadata: Optional dict stored on the artifact; ``None`` becomes ``{}``.
        """
        artifact = wandb.Artifact(
            name=f"model-{wandb.run.id}",
            type="model",
            description="Equiformer super-resolution model",
            metadata=metadata or {}
        )
        artifact.add_file(model_path)
        wandb.log_artifact(artifact)

    def finish(self):
        """Close the active W&B run."""
        wandb.finish()


### Splitter

In [None]:
def move_files(files, src_dir, dest_dir):
    """Copy every name in ``files`` from ``src_dir`` into ``dest_dir``.

    Despite the name, the source files are left in place (``shutil.copy``).
    """
    for file in files:
        shutil.copy(os.path.join(src_dir, file), os.path.join(dest_dir, file))


def split_dataset(dataset_root, output_root, test_size=0.1, random_state=42):
    """Split a paired LR/HR dataset into train and validation folders.

    Reads ``<dataset_root>/LR`` and ``<dataset_root>/HR``, pairs the files by
    position after sorting both listings, and copies them into
    ``<output_root>/{train,val}/{LR,HR}``.

    Args:
        dataset_root: Directory that contains the ``LR`` and ``HR`` folders.
        output_root: Directory that receives the ``train`` / ``val`` tree.
        test_size: Fraction of pairs placed in the validation split.
        random_state: Seed forwarded to ``train_test_split``.

    Raises:
        ValueError: If the LR and HR folders hold a different number of files.
    """
    lr_dir = os.path.join(dataset_root, LR)
    hr_dir = os.path.join(dataset_root, HR)

    destinations = {
        (split, res): os.path.join(output_root, split, res)
        for split in (TRAIN, VAL)
        for res in (LR, HR)
    }
    for dir_path in destinations.values():
        os.makedirs(dir_path, exist_ok=True)

    # Sorting both listings is what keeps LR/HR samples aligned by index.
    lr_files = sorted(os.listdir(lr_dir))
    hr_files = sorted(os.listdir(hr_dir))

    # Raise explicitly: an ``assert`` would be stripped when running with -O.
    if len(lr_files) != len(hr_files):
        raise ValueError("Mismatch in LR and HR files!")

    # Passing both lists in one call applies the same shuffle to each.
    train_lr, val_lr, train_hr, val_hr = train_test_split(
        lr_files, hr_files, test_size=test_size, random_state=random_state
    )

    move_files(train_lr, lr_dir, destinations[(TRAIN, LR)])
    move_files(train_hr, hr_dir, destinations[(TRAIN, HR)])
    move_files(val_lr, lr_dir, destinations[(VAL, LR)])
    move_files(val_hr, hr_dir, destinations[(VAL, HR)])

    print("Dataset split completed successfully!")


# Split both datasets (90% train, 10% val) with identical settings.
split_dataset(DATA_A_PATH, SPLIT_DATA_PATH_A)
split_dataset(DATA_B_PATH, SPLIT_DATA_PATH_B)

### Metrics

In [None]:
def calculate_psnr(img1, img2, data_range=1.0):
    """Return the mean per-image PSNR between two batches of image tensors.

    Args:
        img1: Batch of images; transposed as (N, C, H, W) -> (N, H, W, C).
        img2: Batch of images with the same shape as ``img1``.
        data_range: Value span passed to the PSNR routine. Defaults to 1.0,
            the value that was previously hard-coded.
            NOTE(review): the training config normalises with mean=0.5 /
            std=0.5, so tensors may span [-1, 1]; confirm the intended range.

    Returns:
        The average of the per-image PSNR values over the batch.
    """
    first = img1.detach().cpu().numpy().transpose(0, 2, 3, 1)
    second = img2.detach().cpu().numpy().transpose(0, 2, 3, 1)
    return np.mean([psnr(a, b, data_range=data_range)
                    for a, b in zip(first, second)])

def calculate_ssim(img1, img2, data_range=1.0, eps=1e-8):
    """Return the mean SSIM over a batch of single-channel image tensors.

    Images whose SSIM computation raises or yields NaN/Inf are reported on
    stdout and left out of the average. Valid scores are always kept, including
    negative ones (SSIM can be below zero for strongly dissimilar images).

    Args:
        img1: Batch tensor with a channel axis of size 1 at dim 1.
        img2: Batch tensor with the same shape as ``img1``.
        data_range: Clamp bound for the inputs and the ``data_range`` passed
            to the SSIM routine.
        eps: Unused; kept so existing call sites keep working.

    Returns:
        Mean SSIM of the images that could be scored, or ``0.0`` if none could.
    """
    # Clamp to [-data_range, data_range] and drop the autograd graph
    img1 = torch.clamp(img1, -data_range, data_range).detach()
    img2 = torch.clamp(img2, -data_range, data_range).detach()

    # Remove the channel axis and switch to double precision for stability
    img1_np = img1.cpu().numpy().squeeze(1).astype(np.float64)
    img2_np = img2.cpu().numpy().squeeze(1).astype(np.float64)

    valid_ssim = []

    for i, (im1, im2) in enumerate(zip(img1_np, img2_np)):
        try:
            # Odd window no larger than the smallest image side, at least 3
            min_dim = min(im1.shape)
            win_size = min(7, min_dim - 1 if min_dim % 2 == 0 else min_dim)
            win_size = max(3, win_size)

            ssim_val = ssim(
                im1, im2,
                data_range=data_range,
                win_size=win_size,
                channel_axis=None,
                gaussian_weights=True,
                sigma=1.5,
                use_sample_covariance=False
            )

            if not np.isfinite(ssim_val):
                raise ValueError("Invalid SSIM value")

        except Exception as e:
            # Best effort: report and skip this image rather than abort the batch
            print(f"SSIM calculation failed for image {i}: {str(e)}")
            continue

        valid_ssim.append(ssim_val)

    return np.mean(valid_ssim) if valid_ssim else 0.0

## Configs and settings

In [None]:
# Pick the compute device: CUDA when present, otherwise Apple MPS, otherwise
# CPU. The previous expression tested for MPS but then selected "cuda".
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
# Experiment settings consumed by train_model, Trainer and WandBLogger.
config = dict(
    # Model
    dim=64,
    num_blocks=8,
    num_heads=4,
    upscale=2,
    # Optimisation
    lr=1e-4,
    weight_decay=1e-4,
    batch_size=16,
    epochs=100,
    patience=10,
    use_amp=True,
    # Data locations and preprocessing
    train_lr_dir=PATH_TRAIN_LR,
    train_hr_dir=PATH_TRAIN_HR,
    val_lr_dir=PATH_VAL_LR,
    val_hr_dir=PATH_VAL_HR,
    transform=transforms.Compose([
        transforms.Normalize(mean=[0.5], std=[0.5])
    ]),
    # Weights & Biases
    wandb_project=WANDB_PROJECT,
    wandb_entity=WANDB_USERNAME,
    tags=['gsoc2025', 'diffilens'],
    log_interval=50,
    sample_interval=200,
    architecture='Equiformer',
)


## Data loading

In [None]:
class LensDataset(torch.utils.data.Dataset):
    """Paired LR/HR dataset backed by one ``.npy`` file per sample.

    Both directories are listed and sorted; sample ``i`` pairs the ``i``-th
    LR file with the ``i``-th HR file.

    Args:
        lr_dir: Directory holding the low-resolution arrays.
        hr_dir: Directory holding the high-resolution arrays.
        transform: Optional callable applied to the LR tensor and then to the
            HR tensor, independently.
    """

    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))
        self.lr_dir, self.hr_dir = lr_dir, hr_dir
        self.transform = transform

        assert len(self.lr_files) == len(self.hr_files), "Mismatch in LR and HR files!"

    def __len__(self):
        """Number of LR/HR pairs."""
        return len(self.lr_files)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return ``(lr, hr)`` as float32 tensors."""
        paths = (
            os.path.join(self.lr_dir, self.lr_files[idx]),
            os.path.join(self.hr_dir, self.hr_files[idx]),
        )
        arrays = [np.load(path) for path in paths]
        lr, hr = (torch.tensor(array, dtype=torch.float32) for array in arrays)

        if not self.transform:
            return lr, hr
        return self.transform(lr), self.transform(hr)
    

In [None]:
class LensDataPreprocessor:
    """Builds the train and validation preprocessing pipelines.

    Each pipeline crops to ``crop_size * scale_factor`` pixels, downsamples by
    ``scale_factor`` with bicubic resampling, converts to a tensor and
    normalises with mean 0.5 / std 0.5. The training pipeline additionally
    applies a random flip and a random rotation of up to 15 degrees.

    NOTE(review): ``_degrade`` uses ``.size`` / ``.resize`` and runs before
    ``ToTensor``, so the pipelines presumably expect PIL images — confirm
    against the caller.

    Args:
        crop_size: Side length of the downsampled crop, in pixels.
        scale_factor: Integer downsampling factor.
    """

    def __init__(self, crop_size=75, scale_factor=2):
        self.train_transform = transforms.Compose([
            transforms.RandomChoice([
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip()
            ]),
            transforms.RandomRotation(15),
            transforms.RandomCrop(crop_size*scale_factor),
            transforms.Lambda(lambda x: self._degrade(x, scale_factor)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

        self.val_transform = transforms.Compose([
            transforms.CenterCrop(crop_size*scale_factor),
            transforms.Lambda(lambda x: self._degrade(x, scale_factor)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def _degrade(self, hr, scale_factor):
        """Downsample ``hr`` by ``scale_factor`` using bicubic resampling."""
        # PIL reports size as (width, height) and resize expects the same
        # order; the earlier code swapped the two, which only happened to work
        # for square inputs.
        width, height = hr.size
        lr_size = (width // scale_factor, height // scale_factor)
        return hr.resize(lr_size, Image.BICUBIC)

    def get_transforms(self):
        """Return ``{'train': ..., 'val': ...}`` wrapped in ``PairedTransform``."""
        return {
            'train': PairedTransform(self.train_transform),
            'val': PairedTransform(self.val_transform)
        }

class PairedTransform:
    """Applies one transform twice to the same input and returns both results.

    NOTE(review): the two results come from independent calls of the same
    pipeline, so a random transform will generally produce two different
    outputs — confirm this is the intended way to build an LR/HR pair.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, hr):
        """Return ``(transform(hr), transform(hr))``, evaluated in that order."""
        first = self.transform(hr)
        second = self.transform(hr)
        return first, second

## Loss

In [None]:
class VGGPerceptualLoss(nn.Module):
    """L1 distance between VGG16 feature maps of the prediction and target.

    Args:
        layers: Indices into ``vgg16().features`` stored in ``vgg_layers``.
            Defaults to ``[2, 7, 14]``.
            NOTE(review): ``forward`` runs the complete feature stack and does
            not consult ``vgg_layers``; confirm whether a multi-layer loss was
            intended.
    """

    def __init__(self, layers=None):
        super().__init__()
        # Build the default per instance to avoid a shared mutable default.
        if layers is None:
            layers = [2, 7, 14]
        vgg = models.vgg16(pretrained=True).features
        self.layers = layers
        self.vgg_layers = nn.ModuleList([vgg[i] for i in layers])
        self.vgg = vgg.eval()
        # Freeze only after `self.vgg` is registered so that every VGG weight
        # is covered; previously just the `vgg_layers` subset was frozen and
        # gradients were still computed for the remaining layers.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, sr, hr):
        """Return the L1 loss between the VGG features of ``sr`` and ``hr``.

        Both inputs are tiled three times along the channel axis, so they are
        expected to be single-channel (B, 1, H, W) batches.
        """
        sr = sr.repeat(1, 3, 1, 1)  # (B, 1, H, W) → (B, 3, H, W)
        hr = hr.repeat(1, 3, 1, 1)

        sr_features = self.vgg(sr)
        hr_features = self.vgg(hr)
        return F.l1_loss(sr_features, hr_features)
    

class PhysicsConstrainedLoss(nn.Module):
    """L1 reconstruction loss plus two weighted regularisation terms.

    ``forward`` returns ``L1(sr, hr) + alpha * mass + beta * lensing``.

    Args:
        alpha: Weight of the mass-conservation term.
        beta: Weight of the lensing-equation term.
        device: Device on which the fixed convolution kernels are created.
    """

    def __init__(self, alpha=0.1, beta=0.01, device=device):
        super().__init__()
        self.alpha = alpha  # Weight for mass conservation
        self.beta = beta    # Weight for lensing equation
        self.device = device

        # Fixed 3x3 kernels, registered as non-persistent buffers: they follow
        # `.to(...)` with the module but stay out of the state dict.
        self.register_buffer(
            "sobel_x",
            torch.tensor([[[[1, 0, -1],
                            [2, 0, -2],
                            [1, 0, -1]]]], dtype=torch.float32, device=device),
            persistent=False,
        )
        self.register_buffer(
            "sobel_y",
            torch.tensor([[[[1, 2, 1],
                            [0, 0, 0],
                            [-1, -2, -1]]]], dtype=torch.float32, device=device),
            persistent=False,
        )
        # Built once here instead of on every `laplacian` call.
        self.register_buffer(
            "laplace_kernel",
            torch.tensor([[[[0, 1, 0],
                            [1, -4, 1],
                            [0, 1, 0]]]], dtype=torch.float32, device=device),
            persistent=False,
        )

    def gradient(self, img):
        """Return ``(grad_x, grad_y)`` of a single-channel batch via Sobel filters."""
        grad_x = F.conv2d(img, self.sobel_x, padding=1)
        grad_y = F.conv2d(img, self.sobel_y, padding=1)
        return grad_x, grad_y

    def laplacian(self, img):
        """Return the 5-point discrete Laplacian of a single-channel batch."""
        return F.conv2d(img, self.laplace_kernel, padding=1)

    def mass_conservation_loss(self, sr, hr):
        """MSE between the mean of a half-resolution resample of ``sr`` and the mean of ``sr``.

        ``sr`` is downsampled by 0.5 with bicubic interpolation and the
        per-image, per-channel means before and after are compared.

        NOTE(review): ``hr`` is accepted but not used here — confirm whether
        the comparison was meant to involve the target or the LR input.
        """
        sr_downscaled = F.interpolate(sr, scale_factor=0.5, mode='bicubic')
        return F.mse_loss(sr_downscaled.mean(dim=(2, 3)), sr.mean(dim=(2, 3)))

    def lensing_equation_loss(self, sr):
        """Mean squared residual of ``laplacian(sr) - 2 * kappa``.

        ``kappa`` is taken as the channel-wise mean of ``sr`` (a simplified
        stand-in for the convergence in ∇²ψ = 2κ).
        """
        lap = self.laplacian(sr)

        # Simulated convergence (κ) from image intensity
        kappa = sr.mean(dim=1, keepdim=True)  # Simplified assumption

        # The Sobel gradients were previously computed here and then dropped;
        # they never entered the residual, so that call is omitted.
        residual = lap - 2*kappa
        return torch.mean(residual**2)

    def forward(self, sr, hr):
        """Return the combined loss for prediction ``sr`` and target ``hr``."""
        base_loss = F.l1_loss(sr, hr)
        mass_loss = self.mass_conservation_loss(sr, hr)
        lens_loss = self.lensing_equation_loss(sr)

        return base_loss + self.alpha*mass_loss + self.beta*lens_loss

class TVLoss(nn.Module):
    """Anisotropic total-variation penalty, averaged over all elements.

    Args:
        weight: Multiplier applied to the penalty.
    """

    def __init__(self, weight=1.0):
        super().__init__()
        self.weight = weight

    def forward(self, x):
        """Return ``weight * (sum|dH| + sum|dW|) / (N*C*H*W)`` for a 4-D batch."""
        n, c, h, w = x.shape
        # Absolute differences between vertically / horizontally adjacent pixels
        vertical = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().sum()
        horizontal = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().sum()
        total = vertical + horizontal
        return self.weight * total / (n * c * h * w)

class HybridLoss(nn.Module):
    """Fixed-weight blend of the physics, VGG perceptual and TV losses.

    The result is ``0.7 * physics + 0.2 * vgg + 0.1 * tv``.

    Args:
        device: Device passed to the physics loss and used for the VGG network.
    """

    def __init__(self,device):
        super().__init__()
        self.physics = PhysicsConstrainedLoss(device=device)
        self.vgg = VGGPerceptualLoss().to(device)
        self.tv = TVLoss()

    def forward(self, sr, hr):
        """Return the weighted sum of the three terms for ``sr`` against ``hr``."""
        physics_term = 0.7 * self.physics(sr, hr)
        perceptual_term = 0.2 * self.vgg(sr, hr)
        smoothness_term = 0.1 * self.tv(sr)
        return physics_term + perceptual_term + smoothness_term

## Model

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class GroupConv(nn.Module):
    """Grouped 2-D convolution with ``kernel_size // 2`` padding.

    Args:
        in_channels: Number of input channels (must be divisible by ``groups``).
        out_channels: Number of output channels (must be divisible by ``groups``).
        kernel_size: Side length of the square kernel.
        groups: Number of channel groups.
    """

    def __init__(self, in_channels, out_channels, kernel_size, groups=4):
        super().__init__()
        self.groups = groups
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                              padding=kernel_size // 2, groups=groups)

    def forward(self, x):
        """Apply the grouped convolution to ``x``."""
        # The former view -> reshape -> view round trip left the data
        # untouched, and its final view back to the input height/width failed
        # when an even kernel size made the output one pixel larger.
        return self.conv(x)

class EquivariantAttention(nn.Module):
    """Multi-head attention sub-layer built from grouped 1x1 convolutions.

    The input is group-normalised, projected to q/k/v, and each is viewed as
    (B, heads, C // heads, H, W). The matrix products act on the last two axes,
    so the attention map has shape (H, H) per channel and mixes image rows.

    Args:
        dim: Channel count of the input and output.
        num_heads: Number of heads the channels are divided into.
        groups: Group count for the grouped convolutions and the GroupNorm.
        qkv_bias: Unused; kept for interface compatibility.
    """

    def __init__(self, dim, num_heads=4, groups=4, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.groups = groups
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = GroupConv(dim, dim*3, 1, groups=groups)
        self.proj = GroupConv(dim, dim, 1, groups=groups)

        self.norm = nn.GroupNorm(groups, dim)

    def forward(self, x):
        """Return ``proj(attention(norm(x))) + x`` for a (B, C, H, W) input."""
        B, C, H, W = x.shape
        residual = x  # keep the input for the skip connection
        qkv = self.qkv(self.norm(x)).chunk(3, dim=1)

        q, k, v = (t.view(B, self.num_heads, C // self.num_heads, H, W) for t in qkv)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        out = (attn @ v).reshape(B, C, H, W)
        out = self.proj(out)
        # NOTE: this used to be `x + x` after `x` had been overwritten by the
        # projection, which doubled the output instead of adding the input.
        # Parameter shapes are unchanged, but checkpoints trained with the old
        # forward pass will produce different activations.
        return out + residual

class EquivariantFFN(nn.Module):
    """Pointwise feed-forward sub-layer with a skip connection.

    Computes ``net(norm(x)) + x`` where ``net`` is a grouped 1x1 convolution
    to ``dim * expansion`` channels, a GELU, and a grouped 1x1 convolution
    back to ``dim`` channels.

    Args:
        dim: Channel count of the input and output.
        expansion: Width multiplier of the hidden layer.
        groups: Group count for the convolutions and the GroupNorm.
    """

    def __init__(self, dim, expansion=4, groups=4):
        super().__init__()
        widened = dim * expansion
        layers = [
            GroupConv(dim, widened, 1, groups=groups),
            nn.GELU(),
            GroupConv(widened, dim, 1, groups=groups),
        ]
        self.net = nn.Sequential(*layers)
        self.norm = nn.GroupNorm(groups, dim)

    def forward(self, x):
        """Apply the normalised MLP to ``x`` and add ``x`` back."""
        normalized = self.norm(x)
        return self.net(normalized) + x

class EquiformerBlock(nn.Module):
    """One attention + feed-forward stage, each with pre-norm and a residual.

    Args:
        dim: Channel count carried through the block.
        num_heads: Head count passed to the attention sub-layer.
        groups: Group count for the GroupNorms and both sub-layers.
        mlp_expansion: Hidden-width multiplier of the feed-forward sub-layer.
    """

    def __init__(self, dim, num_heads, groups=4, mlp_expansion=4):
        super().__init__()
        self.attn = EquivariantAttention(dim, num_heads, groups)
        self.ffn = EquivariantFFN(dim, mlp_expansion, groups)
        self.norm1 = nn.GroupNorm(groups, dim)
        self.norm2 = nn.GroupNorm(groups, dim)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        after_attention = self.attn(self.norm1(x)) + x
        return self.ffn(self.norm2(after_attention)) + after_attention

class Equiformer(nn.Module):
    """Super-resolution network: conv embedding, Equiformer blocks, PixelShuffle.

    The output is the upsampler result plus a bilinear upsample of the input,
    so the network learns a residual on top of plain interpolation. That sum
    requires ``in_channels`` and ``out_channels`` to be broadcast-compatible.

    Args:
        in_channels: Channels of the low-resolution input.
        out_channels: Channels of the super-resolved output.
        dim: Feature width used throughout the blocks.
        num_blocks: Number of ``EquiformerBlock`` stages.
        num_heads: Attention heads per block.
        groups: Group count passed to every block.
        upscale: Integer upsampling factor. Defaults to 2, which reproduces
            the previously hard-coded layer shapes.
    """

    def __init__(self, in_channels=1, out_channels=1, dim=64,
                 num_blocks=8, num_heads=4, groups=4, upscale=2):
        super().__init__()
        self.upscale = upscale
        self.embed = nn.Conv2d(in_channels, dim, 3, padding=1)

        self.blocks = nn.Sequential(*[
            EquiformerBlock(dim, num_heads, groups)
            for _ in range(num_blocks)
        ])

        self.upsampler = nn.Sequential(
            # PixelShuffle(r) folds r² channel groups into an r-times larger grid
            nn.Conv2d(dim, dim * upscale ** 2, 3, padding=1),
            nn.PixelShuffle(upscale),
            nn.Conv2d(dim, out_channels, 3, padding=1)
        )

    def forward(self, x):
        """Return the super-resolved image for a (B, C, H, W) input."""
        x_low = x
        x = self.embed(x)
        x = self.blocks(x)
        x = self.upsampler(x)
        return x + F.interpolate(x_low, scale_factor=self.upscale, mode='bilinear')
        

## Trainer

In [None]:
class Trainer:
    """Runs the training and validation loops and reports to Weights & Biases.

    Args:
        model: Network to train; moved to ``device`` on construction.
        train_loader: Iterable of ``(lr, hr)`` training batches.
        val_loader: Iterable of ``(lr, hr)`` validation batches.
        optimizer: Optimizer over the model parameters.
        criterion: Callable ``criterion(outputs, hr)`` returning a scalar loss.
        device: Target device (``torch.device`` or device string).
        config: Settings dict; also handed to ``WandBLogger``. Reads the
            optional keys ``'log_interval'`` and ``'sample_interval'``.
        scheduler: Optional scheduler whose ``step`` takes the validation loss.
        use_amp: Request mixed precision; only honoured on CUDA devices.
    """

    def __init__(self, model, train_loader, val_loader,
                 optimizer, criterion, device, config,
                 scheduler=None, use_amp=True):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device
        self.scheduler = scheduler
        # `torch.cuda.amp.GradScaler` has no device parameter: its first
        # positional argument is `init_scale`, so the device must not be
        # passed there. The CUDA autocast/scaler pair is enabled only when
        # training actually runs on CUDA.
        amp_enabled = bool(use_amp) and torch.device(device).type == "cuda"
        self.scaler = GradScaler(enabled=amp_enabled)
        self.logger = WandBLogger(config, model)
        self.log_interval = config.get('log_interval', 50)
        self.sample_interval = config.get('sample_interval', 200)
        self.best_val_loss = float('inf')
        self.best_epoch = 0

    def train_epoch(self, epoch):
        """Train for one pass over ``train_loader`` and return the mean loss.

        Per-batch metrics are logged every ``log_interval`` batches, a sample
        image every ``sample_interval`` batches, and epoch averages at the end.
        """
        self.model.train()
        total_loss = 0.0
        psnr_values = []
        ssim_values = []
        mse_values = []
        start_time = time.time()

        for batch_idx, (lr, hr) in enumerate(self.train_loader):
            lr = lr.to(self.device, non_blocking=True)
            hr = hr.to(self.device, non_blocking=True)

            self.optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=self.scaler.is_enabled()):
                outputs = self.model(lr)
                loss = self.criterion(outputs, hr)

            self.scaler.scale(loss).backward()
            # Unscale before clipping so the norm threshold applies to the
            # true gradients rather than the loss-scaled ones.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()

            loss_value = loss.item()
            total_loss += loss_value

            # Metrics run on a detached float32 copy: under autocast the
            # outputs can be half precision, and no graph is needed here.
            with torch.no_grad():
                preds = outputs.detach().float()
                batch_mse = F.mse_loss(preds, hr).item()
                batch_psnr = calculate_psnr(preds, hr)
                batch_ssim = calculate_ssim(preds, hr)
            psnr_values.append(batch_psnr)
            ssim_values.append(batch_ssim)
            mse_values.append(batch_mse)

            if batch_idx % self.log_interval == 0:
                self.logger.log_metrics({
                    "train/loss": loss_value,
                    "train/batch_mse": batch_mse,
                    "train/batch_psnr": batch_psnr,
                    "train/batch_ssim": batch_ssim,
                    "lr": self.optimizer.param_groups[0]['lr']
                }, commit=False)

            if batch_idx % self.sample_interval == 0:
                with torch.no_grad():
                    self.logger.log_images(lr[:1], preds[:1], hr[:1])

        avg_loss = total_loss / len(self.train_loader)
        avg_psnr = np.mean(psnr_values)
        avg_ssim = np.mean(ssim_values)
        avg_mse = np.mean(mse_values)
        epoch_time = time.time() - start_time

        self.logger.log_metrics({
            "epoch": epoch,
            "train/avg_loss": avg_loss,
            "train/epoch_mse": avg_mse,
            "train/avg_psnr": avg_psnr,
            "train/avg_ssim": avg_ssim,
            "epoch_time": epoch_time
        })
        return avg_loss

    @torch.no_grad()
    def validate(self, epoch):
        """Evaluate on ``val_loader``, checkpoint on improvement, step the scheduler.

        When the mean validation loss beats the best seen so far, the model
        and optimizer state are written to ``equiformer_best.pth`` and
        uploaded through the logger.

        Returns:
            The mean validation loss.
        """
        self.model.eval()
        total_loss = 0.0
        mse_values = []
        psnr_values = []
        ssim_values = []
        start_time = time.time()

        for lr, hr in self.val_loader:
            lr = lr.to(self.device, non_blocking=True)
            hr = hr.to(self.device, non_blocking=True)

            outputs = self.model(lr)
            loss = self.criterion(outputs, hr)

            total_loss += loss.item()
            mse_values.append(F.mse_loss(outputs, hr).item())
            psnr_values.append(calculate_psnr(outputs, hr))
            ssim_values.append(calculate_ssim(outputs, hr))

        avg_loss = total_loss / len(self.val_loader)
        avg_psnr = np.mean(psnr_values)
        avg_ssim = np.mean(ssim_values)
        avg_mse = np.mean(mse_values)
        epoch_time = time.time() - start_time

        self.logger.log_metrics({
            "val/loss": avg_loss,
            "val/mse": avg_mse,
            "val/psnr": avg_psnr,
            "val/ssim": avg_ssim,
            "epoch_time": epoch_time
        })

        # Histogram of the last validation batch's outputs
        self.logger.log_metrics({
            "val_output_dist": wandb.Histogram(outputs.cpu().numpy())
        })

        if avg_loss < self.best_val_loss:
            self.best_val_loss = avg_loss
            self.best_epoch = epoch
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'loss': avg_loss,
            }, 'equiformer_best.pth')
            print("Saved best model!")
            self.logger.log_model('equiformer_best.pth', {
                'epoch': epoch,
                'val_loss': avg_loss,
                'val_psnr': avg_psnr
            })

        if self.scheduler is not None:
            self.scheduler.step(avg_loss)

        return avg_loss

## Runner

In [None]:
def train_model(config):
    """Build the model, data loaders and trainer from ``config`` and train.

    Runs up to ``config['epochs']`` epochs and stops early once more than
    ``config['patience']`` epochs have passed since the best validation loss.
    The W&B run is always closed, even if training raises.

    Args:
        config: Settings dict; see the keys read below.
    """
    print(device)
    model = Equiformer(
        dim=config['dim'],
        num_blocks=config['num_blocks'],
        num_heads=config['num_heads'],
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['lr'],
        weight_decay=config['weight_decay']
    )

    # `verbose` is deprecated on PyTorch schedulers and is no longer passed;
    # the current learning rate is already logged by the Trainer under "lr".
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3
    )

    criterion = HybridLoss(device=device)

    train_dataset = LensDataset(
        lr_dir=config['train_lr_dir'],
        hr_dir=config['train_hr_dir'],
        transform=config['transform']
    )

    val_dataset = LensDataset(
        lr_dir=config['val_lr_dir'],
        hr_dir=config['val_hr_dir'],
        transform=config['transform']
    )

    # Pinned host memory only helps host-to-CUDA copies; skip it elsewhere.
    pin_memory = torch.device(device).type == "cuda"

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=0,
        pin_memory=pin_memory
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=0,
        pin_memory=pin_memory
    )

    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device,
        scheduler=scheduler,
        config=config,
        use_amp=config['use_amp']
    )

    try:
        for epoch in range(1, config['epochs'] + 1):
            trainer.train_epoch(epoch)
            trainer.validate(epoch)

            # Early stopping
            if (epoch - trainer.best_epoch) > config['patience']:
                print(f"Early stopping at epoch {epoch}")
                break
    finally:
        trainer.logger.finish()


In [None]:
# Start a full training run using the module-level `config` dict.
train_model(config)