In [None]:
import os
import glob
import math
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torchvision.models import vgg19, VGG19_Weights
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import wandb
import cv2
from skimage.metrics import structural_similarity as ssim_metric
import time
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Select the compute device once; every later cell moves models/tensors onto it.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Print NVIDIA GPU status via the IPython `!` shell escape (errors out where no NVIDIA driver is installed).
!nvidia-smi

In [None]:
# Seed every RNG the notebook touches so runs are reproducible.
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Also forces deterministic cuDNN kernels and disables cuDNN autotuning,
    trading some speed for run-to-run reproducibility.

    Note: PYTHONHASHSEED is only read at interpreter start-up, so setting it
    here does not change hash randomisation of the already-running kernel.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

set_seed(42)

# Configuration
class Config:
    """Hyper-parameters and paths for the whole notebook.

    Defaults are declared as class attributes (so ``Config.lr`` works).
    ``__init__`` copies them onto the instance so that ``vars(cfg)`` returns
    the full set of settings (e.g. for ``wandb.init(config=vars(cfg))``)
    instead of an empty dict.
    """
    # Dataset
    data_root = 'DIV2K'  # Adjust to your path
    scale = 2  # Upscaling factor (2, 3, or 4)
    patch_size = 128  # HR patch size for training
    aug_probability = 0.5  # Probability of applying augmentations

    # Model
    base_channels = 64
    use_channel_attention = True
    use_spatial_attention = True
    num_residual_blocks = 16

    # Training
    batch_size = 16
    num_workers = 1
    lr = 2e-4
    min_lr = 1e-6
    weight_decay = 1e-4
    num_epochs = 100
    warmup_epochs = 5

    # Loss
    l1_weight = 1.0
    perceptual_weight = 0.1
    freq_loss_weight = 0.05

    # Logging
    save_dir = 'results'
    checkpoint_interval = 5
    log_interval = 100
    use_wandb = False  # Set to True if you want to use Weights & Biases

    # Mixed precision
    use_amp = True

    def __init__(self):
        # vars() on an instance only sees instance attributes, so copy every
        # public, non-callable class attribute (base classes first, so that
        # subclass overrides win) onto the instance.
        for klass in reversed(type(self).__mro__[:-1]):  # [:-1] skips `object`
            for name, value in vars(klass).items():
                if not name.startswith('_') and not callable(value):
                    setattr(self, name, value)

# Shared configuration instance read by every subsequent cell
cfg = Config()

# Output directory for checkpoints and figures (no-op when it already exists)
os.makedirs(cfg.save_dir, exist_ok=True)

# Optional Weights & Biases experiment tracking
if cfg.use_wandb:
    wandb.init(
        project="hybrid-freq-spatial-image-restoration",
        config=vars(cfg),
    )

In [None]:
class DIV2KDataset(Dataset):
    """
    Paired LR/HR dataset for DIV2K (bicubic-downscaled LR track).

    * split 'train': returns a random, pixel-aligned (LR, HR) patch pair,
      optionally flipped/rotated.
    * other splits: returns the full LR image and the HR image trimmed to
      exactly ``scale`` x the LR size, so model output and target shapes match.

    Each item is a dict with keys 'lr', 'hr' (tensors from ``ToTensor``),
    'idx' and 'hr_path'.
    """
    def __init__(self, root_dir, split="train", scale=2, patch_size=128, augment=True):
        """
        Args:
            root_dir: Directory with DIV2K dataset
            split: 'train' or 'valid'
            scale: Downscaling factor (2, 3, or 4)
            patch_size: HR training patch size; rounded down to a multiple of
                ``scale`` so the LR patch maps onto it exactly
            augment: Whether to apply data augmentation (only used for 'train')

        Raises:
            RuntimeError: if no ``*.png`` files exist in the HR directory.
        """
        self.root_dir = root_dir
        self.split = split
        self.scale = scale
        self.patch_size = patch_size
        self.augment = augment and split == "train"

        # Expected layout: DIV2K_{split}_HR/<id>.png and
        # DIV2K_{split}_LR_bicubic/X{scale}/<id>x{scale}.png
        self.hr_dir = os.path.join(root_dir, f"DIV2K_{split}_HR")
        self.lr_dir = os.path.join(root_dir, f"DIV2K_{split}_LR_bicubic", f"X{scale}")

        self.hr_images = sorted(glob.glob(os.path.join(self.hr_dir, "*.png")))
        if len(self.hr_images) == 0:
            raise RuntimeError(f"No images found in {self.hr_dir}")

        self.to_tensor = transforms.ToTensor()

        print(f"Loaded {len(self.hr_images)} images for {split}")

    def __len__(self):
        return len(self.hr_images)

    @staticmethod
    def _random_crop_coords(lr_width, lr_height, lr_patch_size, scale):
        """Sample a random LR crop origin and the matching HR origin.

        Sampling on the LR grid and multiplying by ``scale`` keeps the HR crop
        exactly aligned with the LR crop; sampling on the HR grid and dividing
        (``hr_x // scale``) shifts the pair by up to ``scale - 1`` HR pixels.

        Returns:
            (lr_x, lr_y, hr_x, hr_y)
        """
        lr_x = random.randint(0, lr_width - lr_patch_size)
        lr_y = random.randint(0, lr_height - lr_patch_size)
        return lr_x, lr_y, lr_x * scale, lr_y * scale

    @staticmethod
    def _load_rgb(path):
        """Load an image as RGB and close the underlying file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def _crop_train_pair(self, hr_img, lr_img):
        """Return an aligned (HR, LR) patch pair for training."""
        lr_patch_size = self.patch_size // self.scale
        # HR patch is an exact multiple of scale so SR output and HR target agree
        hr_patch_size = lr_patch_size * self.scale
        lr_width, lr_height = lr_img.size

        if lr_width < lr_patch_size or lr_height < lr_patch_size:
            # Image too small for a patch: resize both sides instead
            hr_img = hr_img.resize((hr_patch_size, hr_patch_size), Image.BICUBIC)
            lr_img = lr_img.resize((lr_patch_size, lr_patch_size), Image.BICUBIC)
            return hr_img, lr_img

        lr_x, lr_y, hr_x, hr_y = self._random_crop_coords(
            lr_width, lr_height, lr_patch_size, self.scale)
        hr_img = hr_img.crop((hr_x, hr_y, hr_x + hr_patch_size, hr_y + hr_patch_size))
        lr_img = lr_img.crop((lr_x, lr_y, lr_x + lr_patch_size, lr_y + lr_patch_size))
        return hr_img, lr_img

    def _match_hr_to_lr(self, hr_img, lr_img):
        """Trim HR (anchored top-left) to ``scale`` x the LR size when it is larger."""
        lr_width, lr_height = lr_img.size
        hr_width, hr_height = hr_img.size
        target_w = min(hr_width, lr_width * self.scale)
        target_h = min(hr_height, lr_height * self.scale)
        if (target_w, target_h) != (hr_width, hr_height):
            hr_img = hr_img.crop((0, 0, target_w, target_h))
        return hr_img

    def __getitem__(self, idx):
        hr_path = self.hr_images[idx]

        # "0001.png" -> "0001"; the LR counterpart is "0001x{scale}.png"
        img_id = os.path.splitext(os.path.basename(hr_path))[0]
        lr_path = os.path.join(self.lr_dir, f"{img_id}x{self.scale}.png")

        hr_img = self._load_rgb(hr_path)
        lr_img = self._load_rgb(lr_path)

        if self.split == "train":
            hr_img, lr_img = self._crop_train_pair(hr_img, lr_img)
        else:
            hr_img = self._match_hr_to_lr(hr_img, lr_img)

        # Apply the same random flips/rotation to both images to keep them paired
        if self.augment and random.random() < cfg.aug_probability:
            # Random horizontal flip
            if random.random() < 0.5:
                hr_img = hr_img.transpose(Image.FLIP_LEFT_RIGHT)
                lr_img = lr_img.transpose(Image.FLIP_LEFT_RIGHT)

            # Random vertical flip
            if random.random() < 0.5:
                hr_img = hr_img.transpose(Image.FLIP_TOP_BOTTOM)
                lr_img = lr_img.transpose(Image.FLIP_TOP_BOTTOM)

            # Random rotation (90, 180, 270 degrees); training patches are square
            rot_factor = random.choice([0, 1, 2, 3])
            if rot_factor > 0:
                hr_img = hr_img.rotate(90 * rot_factor)
                lr_img = lr_img.rotate(90 * rot_factor)

        hr_tensor = self.to_tensor(hr_img)
        lr_tensor = self.to_tensor(lr_img)

        return {'lr': lr_tensor, 'hr': hr_tensor, 'idx': idx, 'hr_path': hr_path}

In [None]:
# Channel Attention Module
class ChannelAttention(nn.Module):
    """Channel attention: rescales each channel of ``x`` by a sigmoid gate
    computed from its global average- and max-pooled values, both passed
    through the same 1x1-conv bottleneck and summed.

    Args:
        channel: number of input/output channels.
        reduction: bottleneck reduction ratio. The hidden width is clamped to
            at least 1 so channel counts smaller than ``reduction`` still work.
    """
    def __init__(self, channel, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # channel // reduction would be 0 for channel < reduction
        hidden = max(1, channel // reduction)
        self.conv = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.conv(self.avg_pool(x))
        max_out = self.conv(self.max_pool(x))
        return self.sigmoid(avg_out + max_out) * x

# Spatial Attention Module
class SpatialAttention(nn.Module):
    """Spatial gate: one conv over the per-pixel channel mean and channel max
    yields a sigmoid mask that rescales every channel of ``x``."""

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        same_padding = kernel_size // 2  # keeps H and W for odd kernel sizes
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=same_padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = torch.mean(x, dim=1, keepdim=True)
        channel_max = torch.max(x, dim=1, keepdim=True).values
        descriptors = torch.cat([channel_mean, channel_max], dim=1)
        gate = self.sigmoid(self.conv(descriptors))
        return gate * x

# Residual Channel Attention Block
class RCAB(nn.Module):
    """Residual block: conv -> LeakyReLU(0.2) -> conv, optionally followed by
    channel and/or spatial attention, with an identity skip connection."""

    def __init__(self, channels, reduction=16, use_ca=True, use_sa=False):
        super(RCAB, self).__init__()
        self.use_ca = use_ca
        self.use_sa = use_sa

        layers = [
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(channels, channels, 3, 1, 1),
        ]
        self.body = nn.Sequential(*layers)

        # Attention sub-modules only exist when enabled (state_dict keys depend on it)
        if use_ca:
            self.ca = ChannelAttention(channels, reduction)
        if use_sa:
            self.sa = SpatialAttention()

    def forward(self, x):
        out = self.body(x)
        if self.use_ca:
            out = self.ca(out)
        if self.use_sa:
            out = self.sa(out)
        return x + out

# Residual Group
class ResidualGroup(nn.Module):
    """``n_blocks`` RCABs followed by a 3x3 conv, wrapped in a long skip connection."""

    def __init__(self, channels, n_blocks=8, use_ca=True, use_sa=False):
        super(ResidualGroup, self).__init__()
        blocks = [RCAB(channels, use_ca=use_ca, use_sa=use_sa) for _ in range(n_blocks)]
        self.body = nn.Sequential(*blocks)
        self.conv = nn.Conv2d(channels, channels, 3, 1, 1)

    def forward(self, x):
        return self.conv(self.body(x)) + x

In [None]:
class ConvBlock(nn.Module):
    """
    Conv2d -> optional InstanceNorm2d -> per-channel PReLU.

    Layers are kept in that order inside ``self.block`` so state_dict keys
    (block.0, block.1, ...) stay stable.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_norm=True):
        super(ConvBlock, self).__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        norm = [nn.InstanceNorm2d(out_channels)] if use_norm else []
        activation = nn.PReLU(num_parameters=out_channels)
        self.block = nn.Sequential(conv, *norm, activation)

    def forward(self, x):
        return self.block(x)

class SpatialBranch(nn.Module):
    """
    Spatial-domain feature branch: 3x3 conv, two residual groups (attention
    flags read from the global ``cfg``), 3x3 conv, plus a global skip from the
    first conv's output.
    """
    def __init__(self, in_channels=3, base_channels=64, num_blocks=8):
        super(SpatialBranch, self).__init__()

        self.conv_first = nn.Conv2d(in_channels, base_channels, 3, 1, 1)

        groups = []
        for _ in range(2):
            groups.append(ResidualGroup(base_channels, n_blocks=num_blocks,
                                        use_ca=cfg.use_channel_attention,
                                        use_sa=cfg.use_spatial_attention))
        self.residual_groups = nn.ModuleList(groups)

        self.conv_last = nn.Conv2d(base_channels, base_channels, 3, 1, 1)

    def forward(self, x):
        shallow = self.conv_first(x)
        deep = shallow
        for group in self.residual_groups:
            deep = group(deep)
        return self.conv_last(deep) + shallow

class ComplexConv2d(nn.Module):
    """Complex-valued convolution on a ``(real, imag)`` pair of tensors.

    Computes (W_r + iW_i) * (a + ib) = (W_r*a - W_i*b) + i(W_r*b + W_i*a)
    using two real Conv2d layers, each with its own bias.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ComplexConv2d, self).__init__()
        self.conv_real = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.conv_imag = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

    def forward(self, x):
        a, b = x  # (real, imag)
        conv_r, conv_i = self.conv_real, self.conv_imag
        return (conv_r(a) - conv_i(b), conv_r(b) + conv_i(a))

class ComplexBatchNorm2d(nn.Module):
    """Complex batch norm implemented as two independent BatchNorm2d layers,
    one for the real part and one for the imaginary part."""

    def __init__(self, num_features):
        super(ComplexBatchNorm2d, self).__init__()
        self.bn_real = nn.BatchNorm2d(num_features)
        self.bn_imag = nn.BatchNorm2d(num_features)

    def forward(self, x):
        real_part, imag_part = x
        return self.bn_real(real_part), self.bn_imag(imag_part)

class ComplexReLU(nn.Module):
    """Applies ReLU separately (and in place) to the real and imaginary parts."""

    def __init__(self):
        super(ComplexReLU, self).__init__()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        real_part, imag_part = x
        return self.relu(real_part), self.relu(imag_part)

class ComplexConvBlock(nn.Module):
    """ComplexConv2d -> ComplexBatchNorm2d -> ComplexReLU on a ``(real, imag)`` pair."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ComplexConvBlock, self).__init__()
        self.conv = ComplexConv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = ComplexBatchNorm2d(out_channels)
        self.relu = ComplexReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

In [None]:
class FrequencyBranch(nn.Module):
    """
    Frequency-domain branch.

    Applies a 2-D FFT to an ``in_channels`` feature map and then:
      * processes log-magnitude and phase with separate ConvBlock +
        ResidualGroup stacks, concatenates them and maps the result to
        ``base_channels`` features (the main output);
      * runs the raw (real, imag) spectrum through four complex conv blocks,
        returned as auxiliary features.

    Args:
        in_channels: channels of the input feature map.
        base_channels: channels of the returned spatial features; the
            magnitude, phase and complex paths each use ``base_channels // 2``.
    """
    def __init__(self, in_channels=3, base_channels=64):
        super(FrequencyBranch, self).__init__()
        half = base_channels // 2

        # Complex path on (real, imag). The first block consumes the FFT of the
        # input and therefore needs `in_channels` inputs; later blocks work at `half`.
        self.complex_blocks = nn.ModuleList(
            [ComplexConvBlock(in_channels, half)]
            + [ComplexConvBlock(half, half) for _ in range(3)]
        )

        # Magnitude processing
        self.magnitude_conv = nn.Sequential(
            ConvBlock(in_channels, half),
            ResidualGroup(half, n_blocks=4,
                          use_ca=cfg.use_channel_attention,
                          use_sa=False)
        )

        # Phase processing
        self.phase_conv = nn.Sequential(
            ConvBlock(in_channels, half),
            ResidualGroup(half, n_blocks=4,
                          use_ca=cfg.use_channel_attention,
                          use_sa=False)
        )

        # Concatenated magnitude+phase features have 2 * half channels
        # (== base_channels when base_channels is even).
        self.freq_to_spatial = nn.Sequential(
            nn.Conv2d(2 * half, base_channels, 3, 1, 1),
            nn.InstanceNorm2d(base_channels),
            nn.PReLU(num_parameters=base_channels)
        )

    def forward(self, x):
        """
        Args:
            x: 4-D feature map with ``in_channels`` channels (N, C, H, W).

        Returns:
            spatial_features: ``base_channels``-channel features, same H, W.
            (real_features, imag_features): outputs of the complex path.
        """
        # FFT in float32: under autocast `x` may be fp16, and half-precision
        # cuFFT only supports power-of-two sizes and overflows easily.
        x_fft = torch.fft.fft2(x.float())
        real = x_fft.real
        imag = x_fft.imag

        # eps keeps sqrt differentiable at 0 and atan2 away from (0, 0)
        magnitude = torch.sqrt(real**2 + imag**2 + 1e-10)
        phase = torch.atan2(imag, real + 1e-10)

        # log(1 + |F|) compresses the large dynamic range of the spectrum
        log_magnitude = torch.log(magnitude + 1.0)

        mag_features = self.magnitude_conv(log_magnitude)
        phase_features = self.phase_conv(phase)

        complex_features = (real, imag)
        for block in self.complex_blocks:
            complex_features = block(complex_features)
        real_features, imag_features = complex_features

        combined_features = torch.cat([mag_features, phase_features], dim=1)
        spatial_features = self.freq_to_spatial(combined_features)

        return spatial_features, (real_features, imag_features)

class AttentionFusion(nn.Module):
    """
    Fuses spatial- and frequency-branch features:
    concat -> 1x1 conv -> channel attention -> spatial attention
    -> 3x3 conv -> InstanceNorm -> PReLU.
    Both inputs are expected to have ``channels`` channels and equal H, W.
    """
    def __init__(self, channels=64):
        super(AttentionFusion, self).__init__()
        self.conv1 = nn.Conv2d(channels*2, channels, 1, 1, 0)
        self.ca = ChannelAttention(channels)
        self.sa = SpatialAttention()
        self.conv2 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.norm = nn.InstanceNorm2d(channels)
        self.prelu = nn.PReLU(num_parameters=channels)

    def forward(self, spatial_features, freq_features):
        stacked = torch.cat([spatial_features, freq_features], dim=1)
        attended = self.sa(self.ca(self.conv1(stacked)))
        return self.prelu(self.norm(self.conv2(attended)))

class UpscaleBlock(nn.Module):
    """
    Sub-pixel upsampler: 3x3 conv to ``in_channels * scale_factor**2``
    channels, PixelShuffle(scale_factor), then per-channel PReLU. The output
    has ``in_channels`` channels and ``scale_factor`` x the spatial size.
    """
    def __init__(self, in_channels, scale_factor):
        super(UpscaleBlock, self).__init__()
        expanded_channels = in_channels * (scale_factor ** 2)
        self.conv = nn.Conv2d(in_channels, expanded_channels, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.prelu = nn.PReLU(num_parameters=in_channels)

    def forward(self, x):
        return self.prelu(self.pixel_shuffle(self.conv(x)))

class HybridFrequencySpatialNetwork(nn.Module):
    """
    Super-resolution network with parallel spatial and frequency branches.

    Pipeline: 3x3 conv -> {SpatialBranch, FrequencyBranch} -> AttentionFusion
    -> ResidualGroup -> conv, UpscaleBlock(scale_factor), conv.
    Depth and attention flags come from the global ``cfg``.
    """
    def __init__(self, in_channels=3, out_channels=3, base_channels=64, scale_factor=2):
        super(HybridFrequencySpatialNetwork, self).__init__()
        depth = cfg.num_residual_blocks
        use_ca = cfg.use_channel_attention
        use_sa = cfg.use_spatial_attention

        # Shared shallow feature extractor feeding both branches
        self.feature_extract = nn.Conv2d(in_channels, base_channels, 3, 1, 1)

        self.spatial_branch = SpatialBranch(base_channels, base_channels,
                                            num_blocks=depth // 2)
        self.frequency_branch = FrequencyBranch(base_channels, base_channels)

        self.fusion = AttentionFusion(base_channels)

        self.deep_features = ResidualGroup(base_channels, n_blocks=depth,
                                           use_ca=use_ca, use_sa=use_sa)

        self.reconstruction = nn.Sequential(
            nn.Conv2d(base_channels, base_channels, 3, 1, 1),
            UpscaleBlock(base_channels, scale_factor),
            nn.Conv2d(base_channels, out_channels, 3, 1, 1)
        )

    def forward(self, x):
        """Return ``(sr_image, (spatial_features, freq_features, complex_features))``."""
        shallow = self.feature_extract(x)
        spatial_feats = self.spatial_branch(shallow)
        freq_feats, complex_feats = self.frequency_branch(shallow)
        fused = self.fusion(spatial_feats, freq_feats)
        sr_image = self.reconstruction(self.deep_features(fused))
        return sr_image, (spatial_feats, freq_feats, complex_feats)

In [None]:
class FrequencyLoss(nn.Module):
    """
    Frequency-domain reconstruction loss between two image batches.

    With F the 2-D FFT over the last two dims:
        loss = L1(log(1 + |F(output)|), log(1 + |F(target)|))
               + phase_weight * mean(min(|dphi|, 2*pi - |dphi|))
    The wrapped phase error lies in [0, pi].

    The FFT is evaluated in float32 whatever the input dtype: under autocast
    the network output is fp16, and half-precision FFTs are limited to
    power-of-two sizes on CUDA (unsupported on CPU) and lose precision.

    Args:
        phase_weight: weight of the phase term (default 0.5).
        eps: small constant inside sqrt/atan2 for numerical stability.
    """
    def __init__(self, phase_weight=0.5, eps=1e-10):
        super(FrequencyLoss, self).__init__()
        self.phase_weight = phase_weight
        self.eps = eps

    def _log_magnitude_and_phase(self, x):
        """Return (log(1 + |F(x)|), angle(F(x))) computed in float32."""
        spectrum = torch.fft.fft2(x.float())
        real, imag = spectrum.real, spectrum.imag
        magnitude = torch.sqrt(real**2 + imag**2 + self.eps)
        phase = torch.atan2(imag, real + self.eps)
        return torch.log(magnitude + 1.0), phase

    def forward(self, output, target):
        out_log_mag, out_phase = self._log_magnitude_and_phase(output)
        tgt_log_mag, tgt_phase = self._log_magnitude_and_phase(target)

        magnitude_loss = F.l1_loss(out_log_mag, tgt_log_mag)

        # Account for phase wrapping: -pi and +pi are the same angle
        phase_diff = torch.abs(out_phase - tgt_phase)
        phase_diff = torch.minimum(phase_diff, 2 * math.pi - phase_diff)
        phase_loss = phase_diff.mean()

        return magnitude_loss + self.phase_weight * phase_loss

class PerceptualLoss(nn.Module):
    """
    Perceptual loss using frozen VGG19 features.

    Runs ``x`` and ``target`` through ``vgg19().features`` up to the deepest
    requested layer and returns the mean, over ``feature_layers``, of the L1
    distance between the two activations at each of those indices.

    Args:
        feature_layers: indices into ``vgg19().features`` to compare. The
            defaults (2, 7, 12, 21) are convolution outputs in torchvision's
            VGG19 layout.
        normalize: if True, inputs (assumed RGB in [0, 1]) are standardised
            with the ImageNet mean/std expected by the pretrained weights.
    """
    def __init__(self, feature_layers=(2, 7, 12, 21), normalize=True):
        super(PerceptualLoss, self).__init__()

        # Load pre-trained VGG19
        vgg = vgg19(weights=VGG19_Weights.DEFAULT).features.eval()

        self.feature_layers = list(feature_layers)
        self.normalize = normalize
        self.feature_extractors = nn.ModuleList(
            vgg[i] for i in range(max(self.feature_layers) + 1)
        )

        # The compared activations are pre-ReLU conv outputs; an in-place ReLU
        # right after them would overwrite tensors the loss still references.
        for layer in self.feature_extractors:
            if isinstance(layer, nn.ReLU):
                layer.inplace = False

        # ImageNet statistics; non-persistent so they stay out of state_dict
        self.register_buffer(
            'mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer(
            'std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

        # Freeze *all* VGG weights (not only the compared layers): they are
        # never optimised, so computing/accumulating their gradients is waste.
        self.requires_grad_(False)

    def forward(self, x, target):
        """Return the averaged perceptual L1 distance between ``x`` and ``target``."""
        if self.normalize:
            x = (x - self.mean) / self.std
            target = (target - self.mean) / self.std

        loss = 0.0
        for i, layer in enumerate(self.feature_extractors):
            x = layer(x)
            target = layer(target)
            if i in self.feature_layers:
                loss += F.l1_loss(x, target)

        return loss / len(self.feature_layers)

In [None]:
def calculate_psnr(img1, img2, data_range=1.0):
    """Peak signal-to-noise ratio (dB) between two equally-shaped images.

    Args:
        img1, img2: array-likes of identical shape.
        data_range: maximum possible pixel value (1.0 for [0, 1] floats,
            255 for 8-bit images).

    Returns:
        PSNR in dB, or ``inf`` when the images are identical.

    Raises:
        ValueError: if the shapes differ.

    Inputs are promoted to float64 first so unsigned-integer images do not
    wrap around on subtraction.
    """
    a = np.asarray(img1, dtype=np.float64)
    b = np.asarray(img2, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"Image shapes differ: {a.shape} vs {b.shape}")
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(data_range / math.sqrt(mse))

def calculate_ssim(img1, img2, data_range=1.0):
    """SSIM between two images using skimage's ``structural_similarity``.

    (H, W, 3) RGB inputs are reduced to one luma channel with the weights
    0.299/0.587/0.114 before scoring; any other input goes to skimage as-is.

    Args:
        img1, img2: numpy arrays of the same shape.
        data_range: value range of the inputs (1.0 for [0, 1] floats).
    """
    def _luma(img):
        return 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2]

    # Require a real channel axis: a 2-D grayscale image that happens to be
    # 3 pixels wide must not be treated as RGB.
    if img1.ndim == 3 and img1.shape[-1] == 3:
        return ssim_metric(_luma(img1), _luma(img2), data_range=data_range)
    return ssim_metric(img1, img2, data_range=data_range)

In [None]:
def visualize_results(model, val_loader, device, epoch):
    """Save an LR / SR / HR comparison figure for the first validation sample.

    The figure, annotated with PSNR/SSIM, is written to
    ``{cfg.save_dir}/results_epoch_{epoch}.png`` and, when ``cfg.use_wandb``
    is set, also logged to wandb. The model is left in eval mode.
    """
    model.eval()
    with torch.no_grad():
        batch = next(iter(val_loader))
        lr_imgs = batch['lr'].to(device)
        hr_imgs = batch['hr'].to(device)
        outputs, _ = model(lr_imgs)

    # (C, H, W) tensors -> (H, W, C) arrays for matplotlib and the metrics
    lr_np = lr_imgs[0].cpu().numpy().transpose(1, 2, 0)
    sr_np = np.clip(outputs[0].float().cpu().numpy().transpose(1, 2, 0), 0, 1)
    hr_np = hr_imgs[0].cpu().numpy().transpose(1, 2, 0)

    psnr = calculate_psnr(sr_np, hr_np)
    ssim = calculate_ssim(sr_np, hr_np)

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        panels = (
            (lr_np, 'Low Resolution'),
            (sr_np, f'Super Resolution\nPSNR: {psnr:.2f}, SSIM: {ssim:.4f}'),
            (hr_np, 'High Resolution'),
        )
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(image)
            ax.set_title(title)
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(os.path.join(cfg.save_dir, f'results_epoch_{epoch}.png'))

        # Log while the figure is still open: after plt.close(), plt.gcf()
        # would return a brand-new empty figure.
        if cfg.use_wandb:
            wandb.log({
                'examples': wandb.Image(fig),
                'epoch': epoch
            })
    finally:
        plt.close(fig)

def plot_training_curves(train_losses, val_psnrs):
    """Save a two-panel figure (per-epoch training loss, validation PSNR)
    to ``{cfg.save_dir}/training_curves.png``."""
    plt.figure(figsize=(12, 5))

    panels = (
        (train_losses, 'Training Loss', 'Loss'),
        (val_psnrs, 'Validation PSNR', 'PSNR (dB)'),
    )
    for position, (values, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(values)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(cfg.save_dir, 'training_curves.png'))
    plt.close()

In [None]:
def save_checkpoint(model, optimizer, scheduler, epoch, val_psnr, save_path):
    """Serialise the training state to ``save_path`` with ``torch.save``.

    The stored dict has the keys 'epoch', 'model_state_dict',
    'optimizer_state_dict', 'scheduler_state_dict' and 'val_psnr'.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        val_psnr=val_psnr,
    )
    torch.save(state, save_path)

def warmup_learning_rate(base_lr, iter, warmup_iter):
    """Linear learning-rate warm-up.

    Args:
        base_lr: learning rate reached at the end of warm-up.
        iter: 0-based index of the current step/epoch (name kept for
            backward compatibility with keyword callers).
        warmup_iter: number of warm-up steps/epochs.

    Returns:
        ``base_lr * (iter + 1) / warmup_iter`` during warm-up, so the very
        first step already trains at ``base_lr / warmup_iter`` rather than 0,
        and ``base_lr`` once warm-up is complete or when ``warmup_iter <= 0``.
    """
    if warmup_iter <= 0:
        return base_lr
    return base_lr * min(1.0, (iter + 1) / warmup_iter)

def _build_dataloaders():
    """Create the DIV2K loaders: shuffled random patches for training and
    full images (batch size 1) for validation."""
    train_dataset = DIV2KDataset(
        cfg.data_root,
        split='train',
        scale=cfg.scale,
        patch_size=cfg.patch_size,
        augment=True
    )

    val_dataset = DIV2KDataset(
        cfg.data_root,
        split='valid',
        scale=cfg.scale,
        patch_size=cfg.patch_size,
        augment=False
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True
    )

    # Full-resolution validation images differ in size, so they are not batched
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=cfg.num_workers
    )
    return train_loader, val_loader


def _train_one_epoch(model, train_loader, optimizer, scaler, criteria, amp_enabled, epoch):
    """Run one optimisation pass over ``train_loader``; return the mean total loss."""
    l1_loss, perceptual_loss, frequency_loss = criteria
    model.train()
    epoch_loss = 0.0
    num_batches = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.num_epochs}")
    for batch_idx, batch in enumerate(pbar):
        lr_imgs = batch['lr'].to(device, non_blocking=True)
        hr_imgs = batch['hr'].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=amp_enabled):
            outputs, _ = model(lr_imgs)

            loss_l1 = l1_loss(outputs, hr_imgs)
            loss_perceptual = perceptual_loss(outputs, hr_imgs)
            loss_frequency = frequency_loss(outputs, hr_imgs)

            loss = (cfg.l1_weight * loss_l1 +
                    cfg.perceptual_weight * loss_perceptual +
                    cfg.freq_loss_weight * loss_frequency)

        scaler.scale(loss).backward()

        # Unscale first so the clipping threshold applies to the true gradients
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()

        loss_value = loss.item()
        epoch_loss += loss_value
        num_batches += 1
        pbar.set_postfix({'loss': f'{loss_value:.4f}'})

        if cfg.use_wandb and batch_idx % cfg.log_interval == 0:
            wandb.log({
                'train/loss': loss_value,
                'train/l1_loss': loss_l1.item(),
                'train/perceptual_loss': loss_perceptual.item(),
                'train/frequency_loss': loss_frequency.item(),
                'train/learning_rate': optimizer.param_groups[0]['lr']
            })

    return epoch_loss / max(num_batches, 1)


def _validate(model, val_loader):
    """Return the mean (PSNR, SSIM) over every image in ``val_loader``.

    Predictions are clipped to [0, 1] before scoring. Leaves the model in
    eval mode. Returns (0.0, 0.0) for an empty loader.
    """
    model.eval()
    psnr_sum, ssim_sum, count = 0.0, 0.0, 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            lr_imgs = batch['lr'].to(device)
            hr_imgs = batch['hr'].to(device)

            outputs, _ = model(lr_imgs)

            for output, target in zip(outputs, hr_imgs):
                # (C, H, W) tensors -> (H, W, C) arrays for the metrics
                output_np = np.clip(output.float().cpu().numpy().transpose(1, 2, 0), 0, 1)
                target_np = target.cpu().numpy().transpose(1, 2, 0)
                psnr_sum += calculate_psnr(output_np, target_np)
                ssim_sum += calculate_ssim(output_np, target_np)
                count += 1

    if count == 0:
        return 0.0, 0.0
    return psnr_sum / count, ssim_sum / count


def train_model():
    """Train the hybrid frequency-spatial image restoration model.

    Builds the DIV2K loaders, model, losses, Adam optimiser and
    CosineAnnealingWarmRestarts schedule, resumes from
    ``{cfg.save_dir}/best_model.pth`` if it exists (i.e. from the best epoch,
    not necessarily the latest), then trains for up to ``cfg.num_epochs``
    epochs with linear warm-up, mixed precision (CUDA only), gradient
    clipping, per-epoch validation, checkpointing and early stopping after
    ``cfg.early_stopping_patience`` (default 10) epochs without a new best
    validation PSNR.

    Returns:
        (model, train_losses, val_psnrs) — per-epoch values for the epochs run
        in this call (empty lists if no epoch ran).
    """
    train_loader, val_loader = _build_dataloaders()

    model = HybridFrequencySpatialNetwork(
        in_channels=3,
        out_channels=3,
        base_channels=cfg.base_channels,
        scale_factor=cfg.scale
    ).to(device)

    l1_loss = nn.L1Loss().to(device)
    # VGG is a fixed feature extractor: freeze it so its weights never
    # receive (or accumulate) gradients; it is not in the optimiser anyway.
    perceptual_loss = PerceptualLoss().to(device).eval().requires_grad_(False)
    frequency_loss = FrequencyLoss().to(device)
    criteria = (l1_loss, perceptual_loss, frequency_loss)

    optimizer = optim.Adam(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay
    )

    # T_0 must be a positive integer; guard against num_epochs < 10
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=max(1, cfg.num_epochs // 10),
        T_mult=2,
        eta_min=cfg.min_lr
    )

    # Mixed precision (autocast + GradScaler) is only meaningful on CUDA
    amp_enabled = cfg.use_amp and device.type == 'cuda'
    scaler = GradScaler(enabled=amp_enabled)

    best_val_psnr = 0.0
    train_losses = []
    val_psnrs = []
    start_epoch = 0

    checkpoint_path = os.path.join(cfg.save_dir, 'best_model.pth')
    if os.path.exists(checkpoint_path):
        print("Loading checkpoint...")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        best_val_psnr = checkpoint['val_psnr']
        print(f"Resuming from epoch {start_epoch} with PSNR {best_val_psnr:.4f}")

    # Values written to the final checkpoint if the loop body never runs
    val_psnr = best_val_psnr
    last_epoch = start_epoch - 1

    patience = getattr(cfg, 'early_stopping_patience', 10)
    epochs_without_improvement = 0

    for epoch in range(start_epoch, cfg.num_epochs):
        print(f"Epoch {epoch+1}/{cfg.num_epochs}")
        start_time = time.time()

        if epoch < cfg.warmup_epochs:
            warm_lr = warmup_learning_rate(cfg.lr, epoch, cfg.warmup_epochs)
            for param_group in optimizer.param_groups:
                param_group['lr'] = warm_lr

        epoch_loss = _train_one_epoch(model, train_loader, optimizer, scaler,
                                      criteria, amp_enabled, epoch)
        train_losses.append(epoch_loss)

        val_psnr, val_ssim = _validate(model, val_loader)
        val_psnrs.append(val_psnr)

        scheduler.step()
        last_epoch = epoch

        print(f"Epoch {epoch+1}/{cfg.num_epochs} - "
              f"Train Loss: {epoch_loss:.4f} - "
              f"Val PSNR: {val_psnr:.4f} - "
              f"Val SSIM: {val_ssim:.4f} - "
              f"Time: {time.time() - start_time:.1f}s")

        if cfg.use_wandb:
            wandb.log({
                'epoch': epoch + 1,
                'train/epoch_loss': epoch_loss,
                'val/psnr': val_psnr,
                'val/ssim': val_ssim,
                'train/learning_rate': optimizer.param_groups[0]['lr']
            })

        if val_psnr > best_val_psnr:
            best_val_psnr = val_psnr
            epochs_without_improvement = 0
            save_checkpoint(
                model, optimizer, scheduler, epoch, val_psnr,
                os.path.join(cfg.save_dir, 'best_model.pth')
            )
            print(f"New best model saved with PSNR: {val_psnr:.4f}")
        else:
            epochs_without_improvement += 1

        if (epoch + 1) % cfg.checkpoint_interval == 0:
            save_checkpoint(
                model, optimizer, scheduler, epoch, val_psnr,
                os.path.join(cfg.save_dir, f'checkpoint_epoch_{epoch+1}.pth')
            )

        if (epoch + 1) % 10 == 0:
            visualize_results(model, val_loader, device, epoch + 1)

        # Early stopping: `patience` consecutive epochs without a new best PSNR
        if epochs_without_improvement >= patience:
            print("Early stopping triggered")
            break

    save_checkpoint(
        model, optimizer, scheduler, last_epoch, val_psnr,
        os.path.join(cfg.save_dir, 'final_model.pth')
    )

    plot_training_curves(train_losses, val_psnrs)

    if cfg.use_wandb:
        wandb.finish()

    return model, train_losses, val_psnrs

In [None]:
def main():
    """Run training end to end.

    Ensures the output directory exists, starts a wandb run when enabled and
    none is active yet, trains the model, and prints the best validation PSNR
    from this call (if any epoch ran).
    """
    os.makedirs(cfg.save_dir, exist_ok=True)

    # The config cell may already have called wandb.init(); don't start a
    # second run in the same process.
    if cfg.use_wandb and wandb.run is None:
        # Collect settings via getattr so class-level Config attributes are
        # included (vars() on an instance only sees instance attributes).
        run_config = {
            name: getattr(cfg, name)
            for name in dir(cfg)
            if not name.startswith('_') and not callable(getattr(cfg, name))
        }
        wandb.init(
            project="hybrid-freq-spatial-image-restoration",
            config=run_config,
            name=f"scale{cfg.scale}_channels{cfg.base_channels}"
        )

    model, train_losses, val_psnrs = train_model()

    print("Training completed!")
    # val_psnrs is empty when no epoch ran (e.g. a resumed run was already finished)
    if val_psnrs:
        print(f"Best validation PSNR: {max(val_psnrs):.4f}")
    else:
        print("No epochs were run; no validation PSNR to report.")

if __name__ == "__main__":
    main()