In [14]:
import os
import glob
import math
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torchvision.models import vgg19, VGG19_Weights
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import wandb
import cv2
from skimage.metrics import structural_similarity as ssim_metric
import time
import warnings
warnings.filterwarnings("ignore")

In [15]:
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [16]:
!nvidia-smi

Thu Apr 24 23:26:58 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 572.60                 Driver Version: 572.60         CUDA Version: 12.8     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4050 ...  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   41C    P8              2W /  140W |    1033MiB /   6141MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [17]:
# Set random seed for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(42)

# Configuration
class Config:
    # Dataset
    data_root = 'DIV2K'  # Adjust to your path
    scale = 2  # Upscaling factor (2, 3, or 4)
    patch_size = 128  # HR patch size for training
    aug_probability = 0.5  # Probability of applying augmentations

    # Model
    base_channels = 64
    use_channel_attention = True
    use_spatial_attention = True
    num_residual_blocks = 16

    # Training
    batch_size = 16
    num_workers = 0 # Adjust based on your system
    lr = 2e-4
    min_lr = 1e-6
    weight_decay = 1e-4
    num_epochs = 100
    warmup_epochs = 5

    # Loss
    l1_weight = 1.0
    perceptual_weight = 0.1
    freq_loss_weight = 0.05

    # Logging
    save_dir = 'results'
    checkpoint_interval = 5
    log_interval = 100
    use_wandb = False  # Set to True if you want to use Weights & Biases

    # Mixed precision
    use_amp = True

# Initialize config
cfg = Config()

# Create save directory
os.makedirs(cfg.save_dir, exist_ok=True)

# Initialize wandb if enabled
if cfg.use_wandb:
    # Make sure to login to wandb first if running locally
    try:
        wandb.init(project="hybrid-freq-spatial-image-restoration", config=vars(cfg))
    except Exception as e:
        print(f"Could not initialize wandb: {e}. Set cfg.use_wandb=False to disable.")
        cfg.use_wandb = False

In [18]:
class DIV2KDataset(Dataset):
    """
    Dataset for DIV2K with both LR and HR images
    """
    def __init__(self, root_dir, split="train", scale=2, patch_size=128, augment=True):
        """
        Args:
            root_dir: Directory with DIV2K dataset
            split: 'train' or 'valid'
            scale: Downscaling factor (2, 3, or 4)
            patch_size: Size of cropped patches for training
            augment: Whether to apply data augmentation
        """
        self.root_dir = root_dir
        self.split = split
        self.scale = scale
        self.patch_size = patch_size
        self.augment = augment and split == "train"

        # Define paths
        self.hr_dir = os.path.join(root_dir, f"DIV2K_{split}_HR")
        self.lr_dir = os.path.join(root_dir, f"DIV2K_{split}_LR_bicubic", f"X{scale}")

        # Get all HR images
        self.hr_images = sorted(glob.glob(os.path.join(self.hr_dir, "*.png")))

        # Make sure we have files
        if len(self.hr_images) == 0:
            raise RuntimeError(f"No images found in {self.hr_dir}. Please check the path.")

        # Basic transforms
        self.to_tensor = transforms.ToTensor() # Converts PIL image [0, 255] to Tensor [0, 1]

        print(f"Loaded {len(self.hr_images)} images for {split}")

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        # Get HR image path
        hr_path = self.hr_images[idx]

        # Extract image ID from filename (assuming format like "0001.png")
        img_id = os.path.basename(hr_path).split('.')[0]

        # Construct LR image path (e.g., "0001x2.png")
        lr_path = os.path.join(self.lr_dir, f"{img_id}x{self.scale}.png")

        # Load images
        try:
            hr_img = Image.open(hr_path).convert('RGB')
            lr_img = Image.open(lr_path).convert('RGB')
        except FileNotFoundError as e:
            print(f"Error loading image: {e}")
            # Return dummy data or handle appropriately
            # For simplicity, we'll raise the error here
            raise e

        # Random crop for training
        if self.split == "train":
            # Get dimensions
            hr_width, hr_height = hr_img.size

            # Randomly select patch
            hr_patch_size = self.patch_size
            lr_patch_size = hr_patch_size // self.scale

            # Ensure we can extract a patch of the desired size
            if hr_width < hr_patch_size or hr_height < hr_patch_size:
                # If image is smaller than patch_size, resize it (might affect quality)
                # Consider skipping smaller images or using different padding/cropping strategy
                hr_img = hr_img.resize((hr_patch_size, hr_patch_size), Image.BICUBIC)
                lr_img = lr_img.resize((lr_patch_size, lr_patch_size), Image.BICUBIC)
                hr_x, hr_y = 0, 0
                lr_x, lr_y = 0, 0
            else:
                # Random crop
                hr_x = random.randint(0, hr_width - hr_patch_size)
                hr_y = random.randint(0, hr_height - hr_patch_size)

                lr_x = hr_x // self.scale
                lr_y = hr_y // self.scale

                hr_img = hr_img.crop((hr_x, hr_y, hr_x + hr_patch_size, hr_y + hr_patch_size))
                lr_img = lr_img.crop((lr_x, lr_y, lr_x + lr_patch_size, lr_y + lr_patch_size))

        # Apply augmentations for training
        if self.augment and random.random() < cfg.aug_probability:
            # Random horizontal flip
            if random.random() < 0.5:
                hr_img = hr_img.transpose(Image.FLIP_LEFT_RIGHT)
                lr_img = lr_img.transpose(Image.FLIP_LEFT_RIGHT)

            # Random vertical flip
            if random.random() < 0.5:
                hr_img = hr_img.transpose(Image.FLIP_TOP_BOTTOM)
                lr_img = lr_img.transpose(Image.FLIP_TOP_BOTTOM)

            # Random rotation (90, 180, 270 degrees)
            rot_factor = random.choice([0, 1, 2, 3])
            if rot_factor > 0:
                hr_img = hr_img.rotate(90 * rot_factor)
                lr_img = lr_img.rotate(90 * rot_factor)

        # Apply transforms
        hr_tensor = self.to_tensor(hr_img)
        lr_tensor = self.to_tensor(lr_img)

        return {'lr': lr_tensor, 'hr': hr_tensor, 'idx': idx, 'hr_path': hr_path}

In [19]:
# Channel Attention Module
class ChannelAttention(nn.Module):
    def __init__(self, channel, reduction=16):
        super(ChannelAttention, self).__init__()
        # Global average pooling
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Global max pooling
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Shared MLP
        self.conv = nn.Sequential(
            nn.Conv2d(channel, channel // reduction, 1, bias=False), # 1x1 conv for reduction
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel, 1, bias=False)  # 1x1 conv for expansion
        )
        self.sigmoid = nn.Sigmoid() # Sigmoid activation to get attention weights

    def forward(self, x):
        avg_out = self.conv(self.avg_pool(x)) # Pass avg pooled features through MLP
        max_out = self.conv(self.max_pool(x)) # Pass max pooled features through MLP
        attention_weights = self.sigmoid(avg_out + max_out) # Combine and activate
        return attention_weights * x # Apply attention weights to input feature map

# Spatial Attention Module
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        # Convolution layer to process concatenated pooled features
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False) # 2 input channels (avg + max)
        self.sigmoid = nn.Sigmoid() # Sigmoid activation to get attention map

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True) # Average pooling across channels
        max_out, _ = torch.max(x, dim=1, keepdim=True) # Max pooling across channels
        y = torch.cat([avg_out, max_out], dim=1) # Concatenate pooled features
        attention_map = self.sigmoid(self.conv(y)) # Generate spatial attention map
        return attention_map * x # Apply attention map to input feature map

# Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    def __init__(self, channels, reduction=16, use_ca=True, use_sa=False):
        super(RCAB, self).__init__()
        self.use_ca = use_ca
        self.use_sa = use_sa

        # Main convolutional path
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1),
            nn.LeakyReLU(0.2, inplace=True), # Using LeakyReLU
            nn.Conv2d(channels, channels, 3, 1, 1)
        )

        # Optional Channel Attention
        if use_ca:
            self.ca = ChannelAttention(channels, reduction)
        # Optional Spatial Attention
        if use_sa:
            self.sa = SpatialAttention()

    def forward(self, x):
        res = self.body(x) # Pass through main conv layers
        if self.use_ca:
            res = self.ca(res) # Apply channel attention
        if self.use_sa:
            res = self.sa(res) # Apply spatial attention
        return res + x # Residual connection

# Residual Group (contains multiple RCABs)
class ResidualGroup(nn.Module):
    def __init__(self, channels, n_blocks=8, use_ca=True, use_sa=False):
        super(ResidualGroup, self).__init__()

        # Stack multiple RCABs
        body = [RCAB(channels, use_ca=use_ca, use_sa=use_sa) for _ in range(n_blocks)]
        self.body = nn.Sequential(*body)

        # Final convolution within the group
        self.conv = nn.Conv2d(channels, channels, 3, 1, 1)

    def forward(self, x):
        res = self.body(x) # Pass through all RCABs
        res = self.conv(res) # Final convolution
        return res + x # Residual connection for the whole group

In [20]:
class ConvBlock(nn.Module):
    """
    Basic convolutional block with PReLU activation and optional instance normalization
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_norm=True):
        super(ConvBlock, self).__init__()
        layers = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)]
        if use_norm:
            # Using InstanceNorm instead of BatchNorm for potentially better style transfer/SR tasks
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.PReLU(num_parameters=out_channels)) # Using PReLU activation
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class SpatialBranch(nn.Module):
    """
    Spatial branch using residual groups for feature extraction from pixel data
    """
    def __init__(self, in_channels=64, base_channels=64, num_blocks=8):
        super(SpatialBranch, self).__init__()

        # Initial feature extraction (adjusts input channels if needed)
        self.conv_first = nn.Conv2d(in_channels, base_channels, 3, 1, 1)

        # Residual Groups with channel and spatial attention
        # Dividing the total residual blocks between groups
        self.residual_groups = nn.ModuleList([
            ResidualGroup(base_channels, n_blocks=num_blocks,
                         use_ca=cfg.use_channel_attention,
                         use_sa=cfg.use_spatial_attention)
            for _ in range(2) # Using 2 residual groups
        ])

        # Final feature processing convolution within the branch
        self.conv_last = nn.Conv2d(base_channels, base_channels, 3, 1, 1)

    def forward(self, x):
        # Initial feature extraction
        x = F.leaky_relu(self.conv_first(x), 0.2, inplace=True) # Apply activation after first conv
        residual = x # Store for global residual connection

        # Pass through residual groups
        for rg in self.residual_groups:
            x = rg(x)

        # Final processing convolution
        x = self.conv_last(x)

        # Global residual connection across the spatial branch
        return x + residual

class ComplexConv2d(nn.Module):
    """Complex-valued convolutional layer"""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ComplexConv2d, self).__init__()
        # Convolution for the real part of the input
        self.conv_real = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # Convolution for the imaginary part of the input
        self.conv_imag = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

    def forward(self, x):
        # Input x is a tuple of (real_tensor, imag_tensor)
        real, imag = x
        # Complex convolution formula: (a+bi)*(c+di) = (ac-bd) + (ad+bc)i
        # Output real part: conv_real(real) - conv_imag(imag)
        real_out = self.conv_real(real) - self.conv_imag(imag)
        # Output imaginary part: conv_real(imag) + conv_imag(real)
        imag_out = self.conv_real(imag) + self.conv_imag(real)
        return (real_out, imag_out)

class ComplexBatchNorm2d(nn.Module):
    """Complex batch normalization"""
    def __init__(self, num_features):
        super(ComplexBatchNorm2d, self).__init__()
        # Separate BatchNorm for real and imaginary parts
        self.bn_real = nn.BatchNorm2d(num_features)
        self.bn_imag = nn.BatchNorm2d(num_features)

    def forward(self, x):
        real, imag = x
        # Apply BatchNorm independently to real and imaginary tensors
        return (self.bn_real(real), self.bn_imag(imag))

class ComplexReLU(nn.Module):
    """Complex ReLU activation (applies ReLU independently to real and imag parts)"""
    def __init__(self):
        super(ComplexReLU, self).__init__()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        real, imag = x
        # Apply ReLU to real and imaginary parts separately
        return (self.relu(real), self.relu(imag))

class ComplexConvBlock(nn.Module):
    """Complex convolutional block with normalization and activation"""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ComplexConvBlock, self).__init__()
        self.conv = ComplexConv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = ComplexBatchNorm2d(out_channels) # Use complex batch norm
        self.relu = ComplexReLU() # Use complex ReLU

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)

In [27]:
class FrequencyBranch(nn.Module):
    """
    Enhanced frequency branch that processes the image in frequency domain
    Handles potential precision issues with FFT during validation.
    """
    def __init__(self, in_channels=64, base_channels=64): # Input to branch has base_channels
        super(FrequencyBranch, self).__init__()

        # Processing for real and imaginary components directly from FFT
        self.complex_blocks = nn.ModuleList([
            ComplexConvBlock(base_channels, base_channels) # Expects base_channels input
            for _ in range(4) # Number of complex blocks
        ])

        # Magnitude processing path
        self.magnitude_conv = nn.Sequential(
            ConvBlock(in_channels, base_channels // 2), # Reduce channels to 32
            ResidualGroup(base_channels // 2, n_blocks=4, # Process with 32 channels
                         use_ca=cfg.use_channel_attention,
                         use_sa=False)
        )

        # Phase processing path
        self.phase_conv = nn.Sequential(
            ConvBlock(in_channels, base_channels // 2), # Reduce channels to 32
            ResidualGroup(base_channels // 2, n_blocks=4, # Process with 32 channels
                         use_ca=cfg.use_channel_attention,
                         use_sa=False)
        )

        # Convert combined magnitude/phase features back towards spatial domain representation
        self.freq_to_spatial = nn.Sequential(
            nn.Conv2d(base_channels, base_channels, 3, 1, 1),
            nn.InstanceNorm2d(base_channels),
            nn.PReLU(num_parameters=base_channels)
        )

    def forward(self, x):
        # Input x has shape [B, C=base_channels, H, W]

        # --- FIX: Cast to float32 before FFT ---
        # Ensure FFT runs in full precision to avoid cuFFT half-precision limitations
        # on non-power-of-two dimensions, especially during validation.
        x_float32 = x.float()
        # -----------------------------------------

        # Apply 2D FFT to convert to frequency domain
        x_fft = torch.fft.fft2(x_float32, norm='ortho') # Use the float32 version for FFT

        # Split into real and imaginary parts
        real = x_fft.real # Shape [B, C, H, W]
        imag = x_fft.imag # Shape [B, C, H, W]

        # Compute magnitude and phase
        magnitude = torch.sqrt(real**2 + imag**2 + 1e-10)
        phase = torch.atan2(imag, real + 1e-10)

        # Apply log scaling to magnitude
        log_magnitude = torch.log(magnitude + 1.0)

        # Process magnitude path
        # Note: log_magnitude might need casting back if subsequent layers expect AMP type
        # However, ConvBlock likely handles mixed types due to autocast.
        mag_features = self.magnitude_conv(log_magnitude) # Output shape [B, C/2, H, W]

        # Process phase path
        phase_features = self.phase_conv(phase) # Output shape [B, C/2, H, W]

        # Process complex components directly using complex convolutions
        # The ComplexConvBlock inputs (real, imag) derived from x_fft will be float32.
        # The ComplexConv2d layers inside will perform calculations potentially
        # using float16 due to autocast, but the input types are handled.
        complex_input = (real, imag) # Tuple for complex layers (float32)
        complex_out_features = complex_input
        for block in self.complex_blocks:
            complex_out_features = block(complex_out_features)
        # complex_out_features is a tuple (real_feat, imag_feat)

        # Combine processed magnitude and phase features
        combined_mag_phase_features = torch.cat([mag_features, phase_features], dim=1) # Shape [B, C, H, W]

        # Convert combined magnitude/phase features to spatial domain features
        spatial_domain_freq_features = self.freq_to_spatial(combined_mag_phase_features) # Shape [B, C, H, W]

        # Return the spatial-domain representation from mag/phase and the processed complex features
        # The types of returned features will depend on the last operation within autocast context.
        return spatial_domain_freq_features, complex_out_features


class AttentionFusion(nn.Module):
    """
    Advanced fusion module using channel and spatial attention to combine features
    """
    def __init__(self, channels=64):
        super(AttentionFusion, self).__init__()

        # Initial fusion using 1x1 convolution to reduce channels
        self.conv1 = nn.Conv2d(channels * 2, channels, kernel_size=1, stride=1, padding=0) # Input is concatenation (2*channels)

        # Channel attention on fused features
        self.ca = ChannelAttention(channels)

        # Spatial attention on fused features
        self.sa = SpatialAttention()

        # Final processing block after attention
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.norm = nn.InstanceNorm2d(channels)
        self.prelu = nn.PReLU(num_parameters=channels)

    def forward(self, spatial_features, freq_features):
        # Concatenate features from both branches along the channel dimension
        concat_features = torch.cat([spatial_features, freq_features], dim=1) # Shape [B, 2*C, H, W]

        # Initial fusion and channel reduction
        fused = F.leaky_relu(self.conv1(concat_features), 0.2, inplace=True) # Shape [B, C, H, W]

        # Apply attention mechanisms sequentially
        fused_ca = self.ca(fused)
        fused_sa = self.sa(fused_ca) # Apply spatial attention after channel attention

        # Final processing
        fused_out = self.conv2(fused_sa)
        fused_out = self.norm(fused_out)
        fused_out = self.prelu(fused_out)

        # Residual connection can be added here if desired, e.g., return fused_out + fused
        return fused_out

class UpscaleBlock(nn.Module):
    """
    Upscaling block using PixelShuffle
    """
    def __init__(self, in_channels, scale_factor):
        super(UpscaleBlock, self).__init__()
        # Convolution increases channels to in_channels * (scale_factor^2)
        self.conv = nn.Conv2d(in_channels, in_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1)
        # PixelShuffle rearranges elements from [B, C * r^2, H, W] to [B, C, H * r, W * r]
        self.pixel_shuffle = nn.PixelShuffle(scale_factor)
        self.prelu = nn.PReLU(num_parameters=in_channels) # Activation after pixel shuffle

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x

class HybridFrequencySpatialNetwork(nn.Module):
    """
    Enhanced Hybrid Network for frequency-spatial image restoration
    """
    def __init__(self, in_channels=3, out_channels=3, base_channels=64, scale_factor=2):
        super(HybridFrequencySpatialNetwork, self).__init__()

        # Initial shallow feature extraction from input LR image
        self.feature_extract = nn.Conv2d(in_channels, base_channels, kernel_size=3, stride=1, padding=1)

        # Spatial processing branch
        self.spatial_branch = SpatialBranch(in_channels=base_channels, base_channels=base_channels,
                                           num_blocks=cfg.num_residual_blocks // 2) # Half blocks here

        # Frequency processing branch
        self.frequency_branch = FrequencyBranch(in_channels=base_channels, base_channels=base_channels)

        # Feature fusion module
        self.fusion = AttentionFusion(base_channels)

        # Deep feature extraction after fusion using Residual Groups
        self.deep_features = ResidualGroup(base_channels, n_blocks=cfg.num_residual_blocks, # All blocks here
                                          use_ca=cfg.use_channel_attention,
                                          use_sa=cfg.use_spatial_attention)

        # Reconstruction module (Upscaling + final convolution)
        self.reconstruction = nn.Sequential(
            # Optional: Conv before upscaling
            # nn.Conv2d(base_channels, base_channels, 3, 1, 1),
            UpscaleBlock(base_channels, scale_factor),
            # Final convolution to get the output image channels
            nn.Conv2d(base_channels, out_channels, kernel_size=3, stride=1, padding=1)
            # Optional: Add Tanh or Sigmoid activation if output needs to be in a specific range,
            # but usually not needed if target is [0, 1] and using L1/MSE loss.
            # nn.Tanh() or nn.Sigmoid()
        )

    def forward(self, x):
        # x: Input LR image tensor [B, 3, H_lr, W_lr]

        # Extract initial features
        initial_features = F.leaky_relu(self.feature_extract(x), 0.2, inplace=True) # [B, C, H_lr, W_lr]

        # Process through spatial branch
        spatial_features = self.spatial_branch(initial_features) # [B, C, H_lr, W_lr]

        # Process through frequency branch
        # freq_features are the spatial-domain representation from mag/phase path
        # complex_features are the processed (real, imag) tuple from complex path
        freq_features, complex_features = self.frequency_branch(initial_features) # [B, C, H_lr, W_lr], ([B,C,H,W], [B,C,H,W])

        # Fuse features from both branches
        fused_features = self.fusion(spatial_features, freq_features) # [B, C, H_lr, W_lr]

        # Process fused features through deep feature extraction block
        deep_features = self.deep_features(fused_features) # [B, C, H_lr, W_lr]

        # Reconstruct the high-resolution image
        output = self.reconstruction(deep_features) # [B, 3, H_hr, W_hr]

        # Return final output and intermediate features (optional, for potential auxiliary losses)
        return output, (spatial_features, freq_features, complex_features)

In [28]:
class FrequencyLoss(nn.Module):
    """
    Loss function comparing magnitude and phase in the frequency domain
    """
    def __init__(self, use_log_magnitude=True, phase_weight=0.5):
        super(FrequencyLoss, self).__init__()
        self.use_log_magnitude = use_log_magnitude
        self.phase_weight = phase_weight
        self.l1_loss = nn.L1Loss()

    def forward(self, output, target):
        # Ensure input tensors are floating point
        output = output.float()
        target = target.float()

        # Convert to frequency domain using FFT
        output_fft = torch.fft.fft2(output, norm='ortho')
        target_fft = torch.fft.fft2(target, norm='ortho')

        # Compute magnitude difference
        output_magnitude = torch.abs(output_fft) # Magnitude = sqrt(real^2 + imag^2)
        target_magnitude = torch.abs(target_fft)

        if self.use_log_magnitude:
            # Use L1 loss on log magnitude for better handling of dynamic range
            output_log_magnitude = torch.log(output_magnitude + 1e-8) # Add epsilon before log
            target_log_magnitude = torch.log(target_magnitude + 1e-8)
            magnitude_loss = self.l1_loss(output_log_magnitude, target_log_magnitude)
        else:
            # Use L1 loss directly on magnitude
            magnitude_loss = self.l1_loss(output_magnitude, target_magnitude)

        # Compute phase difference
        output_phase = torch.angle(output_fft) # Phase = atan2(imag, real)
        target_phase = torch.angle(target_fft)

        # Phase loss (L1 distance, accounting for phase wrapping [-pi, pi])
        phase_diff = output_phase - target_phase
        # Map difference to [-pi, pi]
        phase_diff = torch.remainder(phase_diff + np.pi, 2 * np.pi) - np.pi
        phase_loss = torch.abs(phase_diff).mean() # Mean absolute phase error

        # Combined loss
        return magnitude_loss + self.phase_weight * phase_loss


class PerceptualLoss(nn.Module):
    """
    Perceptual loss using VGG19 features.
    Compares feature maps from specific layers of a pre-trained VGG network.
    """
    def __init__(self, feature_layers=[2, 7, 12, 21, 30], use_input_norm=True):
        super(PerceptualLoss, self).__init__()

        # Load pre-trained VGG19 model
        vgg = vgg19(weights=VGG19_Weights.DEFAULT).features.eval() # Set to evaluation mode

        # Normalize input based on ImageNet mean/std if needed
        self.use_input_norm = use_input_norm
        if self.use_input_norm:
            mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
            self.register_buffer('mean', mean)
            self.register_buffer('std', std)

        self.feature_extractors = nn.ModuleList()
        self.feature_layers = feature_layers # Layers to extract features from

        # Create feature extractors up to the maximum layer needed
        for i in range(max(feature_layers) + 1):
            layer = vgg[i]
            # Freeze VGG parameters
            for param in layer.parameters():
                param.requires_grad = False
            self.feature_extractors.append(layer)

        # Use L1 loss for comparing feature maps
        self.loss_fn = nn.L1Loss()

    def forward(self, x, target):
        # Ensure inputs are on the same device as the model
        x = x.to(next(self.parameters()).device)
        target = target.to(next(self.parameters()).device)

        # Normalize inputs if required
        if self.use_input_norm:
            x = (x - self.mean) / self.std
            target = (target - self.mean) / self.std

        perceptual_loss = 0.0
        current_x = x
        current_target = target

        # Extract features layer by layer and compute loss
        for i, layer in enumerate(self.feature_extractors):
            current_x = layer(current_x)
            current_target = layer(current_target)

            if i in self.feature_layers:
                perceptual_loss += self.loss_fn(current_x, current_target)

        return perceptual_loss / len(self.feature_layers) # Average loss over selected layers

In [33]:
# --- Utility Functions ---

def calculate_psnr(img1, img2, data_range=1.0):
    """Calculate Peak Signal-to-Noise Ratio (PSNR) between two images."""
    # Ensure images are numpy arrays and float type
    img1 = np.asarray(img1, dtype=np.float32)
    img2 = np.asarray(img2, dtype=np.float32)

    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf') # PSNR is infinite if images are identical
    psnr = 20 * math.log10(data_range / math.sqrt(mse))
    return psnr

def calculate_ssim(img1, img2, data_range=1.0, win_size=7):
    """
    Calculate Structural Similarity Index (SSIM) between two images.
    Handles cases where images are too small for the specified window size.
    """
    # Ensure images are numpy arrays and float type
    img1 = np.asarray(img1, dtype=np.float32)
    img2 = np.asarray(img2, dtype=np.float32)

    # SSIM expects channel dimension last or grayscale
    if img1.ndim == 3 and img1.shape[0] in [1, 3]: # If channel is first (e.g., [C, H, W])
        img1 = img1.transpose(1, 2, 0) # Convert to [H, W, C]
    if img2.ndim == 3 and img2.shape[0] in [1, 3]:
        img2 = img2.transpose(1, 2, 0)

    # --- FIX: Check image dimensions before calculating SSIM ---
    # Get spatial dimensions (height, width)
    if img1.ndim == 3: # HWC
        h1, w1, _ = img1.shape
        h2, w2, _ = img2.shape
    elif img1.ndim == 2: # HW (grayscale)
        h1, w1 = img1.shape
        h2, w2 = img2.shape
    else:
        warnings.warn(f"Unexpected image dimension: {img1.ndim}. Skipping SSIM calculation.")
        return 0.0 # Cannot calculate SSIM

    # Check if window size is valid for both images
    if h1 < win_size or w1 < win_size or h2 < win_size or w2 < win_size:
        warnings.warn(
            f"Image size ({h1}x{w1}, {h2}x{w2}) is too small for "
            f"SSIM calculation with win_size={win_size}. Returning SSIM=0.0."
        )
        return 0.0 # Return 0.0 if images are too small

    # Ensure win_size is odd
    if win_size % 2 == 0:
        warnings.warn(f"win_size must be odd, but got {win_size}. Using win_size={win_size+1}")
        win_size += 1
    # ----------------------------------------------------------

    # Calculate SSIM using skimage
    # For multichannel images (RGB), set channel_axis=-1 (last axis)
    is_multichannel = img1.ndim == 3 and img1.shape[-1] == 3
    channel_axis = -1 if is_multichannel else None

    # Use data_range argument correctly
    ssim_val = ssim_metric(
        img1, img2,
        win_size=win_size,
        data_range=data_range,
        channel_axis=channel_axis # Use channel_axis instead of multichannel
    )
    return ssim_val

In [34]:
def visualize_results(model, val_loader, device, epoch, save_dir):
    """Visualize model results on a validation sample and save the plot."""
    model.eval() # Set model to evaluation mode
    with torch.no_grad(): # Disable gradient calculation
        try:
            # Get a batch of validation images
            batch = next(iter(val_loader))
            lr_imgs = batch['lr'].to(device) # Shape [B, C, H_lr, W_lr]
            hr_imgs = batch['hr'].to(device) # Shape [B, C, H_hr, W_hr]

            # Generate predictions using the model
            outputs, _ = model(lr_imgs) # Shape [B, C, H_hr, W_hr]

            # Select the first image from the batch for visualization
            lr_img = lr_imgs[0].cpu().numpy().transpose(1, 2, 0) # Convert to HWC, numpy
            sr_img = outputs[0].cpu().numpy().transpose(1, 2, 0) # Convert to HWC, numpy
            hr_img = hr_imgs[0].cpu().numpy().transpose(1, 2, 0) # Convert to HWC, numpy

            # Clip predictions to the valid range [0, 1]
            sr_img = np.clip(sr_img, 0, 1)
            hr_img = np.clip(hr_img, 0, 1) # Also clip HR just in case
            lr_img = np.clip(lr_img, 0, 1) # Also clip LR

            # Calculate metrics for the visualized sample
            psnr = calculate_psnr(sr_img, hr_img, data_range=1.0)
            ssim = calculate_ssim(sr_img, hr_img, data_range=1.0)

            # Plot results: LR, SR (Super-Resolved), HR (High-Resolution)
            fig, axes = plt.subplots(1, 3, figsize=(18, 6)) # Create figure and axes

            # Plot Low Resolution
            axes[0].imshow(lr_img, vmin=0, vmax=1)
            axes[0].set_title('Low Resolution Input')
            axes[0].axis('off')

            # Plot Super Resolution
            axes[1].imshow(sr_img, vmin=0, vmax=1)
            axes[1].set_title(f'Super Resolution Output\nPSNR: {psnr:.2f} dB, SSIM: {ssim:.4f}')
            axes[1].axis('off')

            # Plot High Resolution Ground Truth
            axes[2].imshow(hr_img, vmin=0, vmax=1)
            axes[2].set_title('High Resolution Ground Truth')
            axes[2].axis('off')

            plt.tight_layout() # Adjust layout
            save_path = os.path.join(save_dir, f'results_epoch_{epoch:03d}.png')
            plt.savefig(save_path) # Save the figure
            print(f"Saved visualization to {save_path}")
            plt.close(fig) # Close the figure to free memory

            # Log to wandb if enabled
            if cfg.use_wandb:
                wandb.log({
                    f"examples/epoch_{epoch}": wandb.Image(save_path),
                    'epoch': epoch # Ensure epoch is logged for x-axis
                })

        except Exception as e:
            print(f"Error during visualization: {e}")
            # If val_loader is exhausted, reset it or handle appropriately
            if isinstance(e, StopIteration):
                print("Validation loader exhausted during visualization.")
            # Optionally, re-raise the error or continue training
            # raise e


def plot_training_curves(train_losses, val_psnrs, save_dir):
    """Plot and save training loss and validation PSNR curves."""
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(12, 5))

    # Plot Training Loss
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Training Loss')
    plt.title('Training Loss per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)
    plt.legend()

    # Plot Validation PSNR
    plt.subplot(1, 2, 2)
    plt.plot(epochs, val_psnrs, label='Validation PSNR', color='orange')
    plt.title('Validation PSNR per Epoch')
    plt.xlabel('Epoch')
    plt.ylabel('PSNR (dB)')
    plt.grid(True)
    plt.legend()

    plt.tight_layout()
    save_path = os.path.join(save_dir, 'training_curves.png')
    plt.savefig(save_path)
    print(f"Saved training curves to {save_path}")
    plt.close() # Close the plot

    # Log curves to wandb if enabled
    if cfg.use_wandb:
        for i in range(len(train_losses)):
             wandb.log({'chart/train_loss': train_losses[i], 'chart/val_psnr': val_psnrs[i], 'epoch': i + 1})

In [37]:
def save_checkpoint(model, optimizer, scheduler, epoch, val_psnr, save_path):
    """Save model checkpoint."""
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_psnr': val_psnr,
        'config': vars(cfg) # Save config for reference
    }
    torch.save(state, save_path)
    print(f"Checkpoint saved to {save_path} at epoch {epoch}")

def load_checkpoint(model, optimizer, scheduler, checkpoint_path, device):
    """Load model checkpoint."""
    if not os.path.exists(checkpoint_path):
        print(f"Checkpoint file not found: {checkpoint_path}")
        return 0, 0.0 # Return start epoch 0, best PSNR 0

    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Load model state
    # Handle potential DataParallel prefix if model was saved with it
    state_dict = checkpoint['model_state_dict']
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)

    # Load optimizer and scheduler state
    if optimizer and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    start_epoch = checkpoint.get('epoch', -1) + 1 # Resume from next epoch
    best_val_psnr = checkpoint.get('val_psnr', 0.0)

    print(f"Resuming training from epoch {start_epoch} with best validation PSNR: {best_val_psnr:.4f}")
    return start_epoch, best_val_psnr


def warmup_learning_rate(base_lr, current_epoch, warmup_epochs):
    """Linear warm-up for learning rate."""
    if current_epoch >= warmup_epochs:
        return base_lr # Return base LR after warmup
    # Linearly increase LR from a small value (e.g., base_lr / 10) to base_lr
    start_lr = base_lr / 10
    return start_lr + (base_lr - start_lr) * (current_epoch / warmup_epochs)


def train_model():
    """Train the hybrid frequency-spatial image restoration model."""
    # --- Setup ---
    # Create data loaders
    try:
        train_dataset = DIV2KDataset(
            cfg.data_root,
            split='train',
            scale=cfg.scale,
            patch_size=cfg.patch_size,
            augment=True
        )
        val_dataset = DIV2KDataset(
            cfg.data_root,
            split='valid',
            scale=cfg.scale,
            patch_size=cfg.patch_size, # Use full images or large patches for validation? Usually full images.
            augment=False
        )
    except RuntimeError as e:
        print(f"Error creating dataset: {e}")
        print("Please ensure the DIV2K dataset is correctly placed and the 'data_root' path is correct.")
        return None, [], [] # Exit gracefully

    train_loader = DataLoader(
        train_dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
        pin_memory=True, # Improves data transfer speed to GPU
        drop_last=True # Drop last incomplete batch
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=1, # Validate one image at a time for metric calculation
        shuffle=False,
        num_workers=cfg.num_workers
    )

    # Create model
    model = HybridFrequencySpatialNetwork(
        in_channels=3,
        out_channels=3,
        base_channels=cfg.base_channels,
        scale_factor=cfg.scale
    )

    # Move model to device (GPU or CPU)
    model = model.to(device)

    # Optional: Use DataParallel for multi-GPU training
    # if torch.cuda.device_count() > 1:
    #    print(f"Using {torch.cuda.device_count()} GPUs!")
    #    model = nn.DataParallel(model)

    # Define loss functions
    l1_loss = nn.L1Loss().to(device)
    perceptual_loss = PerceptualLoss().to(device)
    frequency_loss = FrequencyLoss().to(device)

    # Define optimizer
    optimizer = optim.Adam(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        betas=(0.9, 0.999) # Default Adam betas
    )

    # Define learning rate scheduler
    # Cosine Annealing with Warm Restarts
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=cfg.num_epochs // 10 if cfg.num_epochs >= 10 else cfg.num_epochs, # Number of epochs for the first restart
        T_mult=2, # Factor to increase T_i after a restart
        eta_min=cfg.min_lr # Minimum learning rate
    )

    # Gradient scaler for Automatic Mixed Precision (AMP)
    scaler = GradScaler(enabled=cfg.use_amp)

    # Initialize statistics
    best_val_psnr = 0.0
    train_losses = []
    val_psnrs = []
    start_epoch = 0

    # --- Checkpoint Loading ---
    best_checkpoint_path = os.path.join(cfg.save_dir, 'best_model.pth')
    latest_checkpoint_path = os.path.join(cfg.save_dir, 'latest_model.pth') # Optional: save latest checkpoint too

    # Prefer loading the best model if it exists
    load_path = best_checkpoint_path if os.path.exists(best_checkpoint_path) else latest_checkpoint_path
    if os.path.exists(load_path):
         start_epoch, best_val_psnr = load_checkpoint(model, optimizer, scheduler, load_path, device)
         # Reload stats if needed (e.g., load train_losses, val_psnrs from checkpoint)

    # --- Training Loop ---
    print(f"Starting training from epoch {start_epoch}...")
    for epoch in range(start_epoch, cfg.num_epochs):
        print(f"\n--- Epoch {epoch+1}/{cfg.num_epochs} ---")
        model.train() # Set model to training mode
        epoch_loss = 0.0
        epoch_l1_loss = 0.0
        epoch_perc_loss = 0.0
        epoch_freq_loss = 0.0

        # Initialize progress bar for the training epoch
        pbar = tqdm(train_loader, desc=f"Train E{epoch+1}", leave=False)

        # Apply learning rate warm-up if applicable
        current_lr = optimizer.param_groups[0]['lr'] # Get current LR before potential warmup adjustment
        if epoch < cfg.warmup_epochs:
            new_lr = warmup_learning_rate(cfg.lr, epoch, cfg.warmup_epochs)
            for param_group in optimizer.param_groups:
                param_group['lr'] = new_lr
            current_lr = new_lr # Update current_lr for logging
            print(f"Warm-up Epoch {epoch+1}: LR set to {current_lr:.6f}")
        else:
             print(f"Epoch {epoch+1}: LR = {current_lr:.6f}")


        # --- Training Batch Loop ---
        for batch_idx, batch in enumerate(pbar):
            # Get data and move to device
            lr_imgs = batch['lr'].to(device, non_blocking=True) # Use non_blocking for potential speedup
            hr_imgs = batch['hr'].to(device, non_blocking=True)

            # Zero gradients before the backward pass
            optimizer.zero_grad()

            # Forward pass with Automatic Mixed Precision (AMP) context
            with autocast(enabled=cfg.use_amp):
                # Get model outputs
                outputs, (spatial_features, freq_features, complex_features) = model(lr_imgs)
                # Ensure outputs are float32 for loss calculation if needed
                outputs = outputs.float()

                # Calculate individual loss components
                loss_l1 = l1_loss(outputs, hr_imgs)
                loss_perceptual = perceptual_loss(outputs, hr_imgs)
                loss_frequency = frequency_loss(outputs, hr_imgs)

                # Combine losses with specified weights
                total_loss = (cfg.l1_weight * loss_l1 +
                              cfg.perceptual_weight * loss_perceptual +
                              cfg.freq_loss_weight * loss_frequency)

            # Backward pass: Calculate gradients using the scaled loss
            scaler.scale(total_loss).backward()

            # Optional: Gradient clipping (unscale first)
            scaler.unscale_(optimizer) # Unscales gradients inplace
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) # Clip gradient norm

            # Optimizer step: Update model weights using scaled gradients
            scaler.step(optimizer)

            # Update the scale factor for the next iteration
            scaler.update()

            # --- Logging and Progress ---
            batch_loss = total_loss.item()
            epoch_loss += batch_loss
            epoch_l1_loss += loss_l1.item()
            epoch_perc_loss += loss_perceptual.item()
            epoch_freq_loss += loss_frequency.item()

            # Update progress bar description
            pbar.set_postfix({
                'Loss': f'{batch_loss:.4f}',
                'L1': f'{loss_l1.item():.4f}',
                'Perc': f'{loss_perceptual.item():.4f}',
                'Freq': f'{loss_frequency.item():.4f}',
                'LR': f'{current_lr:.1e}'
            })

            # Log batch metrics to wandb if enabled
            if cfg.use_wandb and batch_idx % cfg.log_interval == 0:
                step = epoch * len(train_loader) + batch_idx
                wandb.log({
                    'train/batch_loss': batch_loss,
                    'train/batch_l1_loss': loss_l1.item(),
                    'train/batch_perceptual_loss': loss_perceptual.item(),
                    'train/batch_frequency_loss': loss_frequency.item(),
                    'train/learning_rate': current_lr,
                    'step': step,
                    'epoch': epoch + (batch_idx / len(train_loader)) # Log fractional epoch
                })

        # --- End of Epoch ---
        # Calculate average losses for the epoch
        avg_epoch_loss = epoch_loss / len(train_loader)
        avg_l1_loss = epoch_l1_loss / len(train_loader)
        avg_perc_loss = epoch_perc_loss / len(train_loader)
        avg_freq_loss = epoch_freq_loss / len(train_loader)
        train_losses.append(avg_epoch_loss)

        print(f"Epoch {epoch+1} Average Train Loss: {avg_epoch_loss:.4f} "
              f"(L1: {avg_l1_loss:.4f}, Perc: {avg_perc_loss:.4f}, Freq: {avg_freq_loss:.4f})")


        # --- Validation Phase ---
        model.eval() # Set model to evaluation mode
        total_val_psnr = 0.0
        total_val_ssim = 0.0

        val_pbar = tqdm(val_loader, desc=f"Validate E{epoch+1}", leave=False)
        with torch.no_grad(): # Disable gradient calculations during validation
            for batch in val_pbar:
                lr_imgs = batch['lr'].to(device, non_blocking=True)
                hr_imgs = batch['hr'].to(device, non_blocking=True)

                # Forward pass (no need for AMP context here, but doesn't hurt)
                with autocast(enabled=cfg.use_amp):
                    outputs, _ = model(lr_imgs)
                    outputs = outputs.float() # Ensure float32 for metrics

                # Convert tensors to numpy arrays for metric calculation
                # Process each image in the batch (batch size is 1 here)
                output_np = outputs[0].cpu().numpy().transpose(1, 2, 0) # HWC format
                target_np = hr_imgs[0].cpu().numpy().transpose(1, 2, 0) # HWC format

                # Clip predictions to the valid range [0, 1] before calculating metrics
                output_np = np.clip(output_np, 0, 1)
                target_np = np.clip(target_np, 0, 1)

                # Calculate metrics
                psnr = calculate_psnr(output_np, target_np, data_range=1.0)
                ssim = calculate_ssim(output_np, target_np, data_range=1.0)

                total_val_psnr += psnr
                total_val_ssim += ssim
                val_pbar.set_postfix({'PSNR': f'{psnr:.2f}', 'SSIM': f'{ssim:.4f}'})


        # Calculate average validation metrics
        avg_val_psnr = total_val_psnr / len(val_loader)
        avg_val_ssim = total_val_ssim / len(val_loader)
        val_psnrs.append(avg_val_psnr)

        print(f"Epoch {epoch+1} Validation Results - Avg PSNR: {avg_val_psnr:.4f} dB, Avg SSIM: {avg_val_ssim:.4f}")

        # Update learning rate scheduler (step after validation)
        # Note: Step scheduler based on epoch, not validation score unless it's ReduceLROnPlateau
        if epoch >= cfg.warmup_epochs: # Don't step scheduler during warmup
             scheduler.step()


        # --- Logging Epoch Results ---
        if cfg.use_wandb:
            wandb.log({
                'epoch': epoch + 1, # Log integer epoch
                'train/epoch_loss': avg_epoch_loss,
                'train/epoch_l1_loss': avg_l1_loss,
                'train/epoch_perceptual_loss': avg_perc_loss,
                'train/epoch_frequency_loss': avg_freq_loss,
                'val/avg_psnr': avg_val_psnr,
                'val/avg_ssim': avg_val_ssim,
                'train/learning_rate': current_lr # Log LR used for the epoch
            })

        # --- Checkpointing ---
        # Save the latest model state
        save_checkpoint(
            model, optimizer, scheduler, epoch, avg_val_psnr,
            latest_checkpoint_path
        )

        # Save checkpoint if validation PSNR improved
        if avg_val_psnr > best_val_psnr:
            best_val_psnr = avg_val_psnr
            save_checkpoint(
                model, optimizer, scheduler, epoch, best_val_psnr,
                best_checkpoint_path
            )
            print(f"*** New best model saved with PSNR: {best_val_psnr:.4f} ***")

        # Save periodic checkpoint (optional)
        if (epoch + 1) % cfg.checkpoint_interval == 0:
            periodic_path = os.path.join(cfg.save_dir, f'checkpoint_epoch_{epoch+1:03d}.pth')
            save_checkpoint(
                model, optimizer, scheduler, epoch, avg_val_psnr,
                periodic_path
            )

        # --- Visualization ---
        # Visualize some results periodically (e.g., every 10 epochs)
        if (epoch + 1) % 10 == 0 or epoch == cfg.num_epochs - 1:
            visualize_results(model, val_loader, device, epoch + 1, cfg.save_dir)


        # --- Early Stopping (Optional) ---
        # Add early stopping logic if desired, e.g., stop if val PSNR doesn't improve for N epochs.
        patience = 15
        if len(val_psnrs) > patience and avg_val_psnr < max(val_psnrs[-patience-1:-1]):
             print(f"Validation PSNR did not improve for {patience} epochs. Early stopping.")
             break


    # --- End of Training ---
    print("\nTraining finished!")
    print(f"Best validation PSNR achieved: {best_val_psnr:.4f}")

    # Plot final training curves
    plot_training_curves(train_losses, val_psnrs, cfg.save_dir)

    # Cleanup wandb run
    if cfg.use_wandb:
        wandb.finish()

    return model, train_losses, val_psnrs

In [38]:
def main():
    """Main function to setup and run training"""
    print("Starting Hybrid Frequency-Spatial Image Restoration Training...")
    print(f"Configuration: {vars(cfg)}")

    # Ensure save directory exists
    os.makedirs(cfg.save_dir, exist_ok=True)

    # Initialize wandb if enabled (moved inside train_model for better scope)

    # Train the model
    model, train_losses, val_psnrs = train_model()

    if model is not None:
        print("\nTraining completed successfully!")
        if val_psnrs:
            print(f"Final Best Validation PSNR: {max(val_psnrs):.4f}")
        else:
            print("No validation results recorded.")
    else:
        print("\nTraining failed or was interrupted.")

if __name__ == "__main__":
    # This block runs when the script is executed directly
    main()

Starting Hybrid Frequency-Spatial Image Restoration Training...
Configuration: {}
Loaded 800 images for train
Loaded 100 images for valid
Loading checkpoint from results\best_model.pth...
Resuming training from epoch 1 with best validation PSNR: 12.2912
Starting training from epoch 1...

--- Epoch 2/100 ---


Train E2:   0%|          | 0/50 [00:00<?, ?it/s]

Warm-up Epoch 2: LR set to 0.000056


                                                                                                                       

Epoch 2 Average Train Loss: 0.7110 (L1: 0.1842, Perc: 4.1059, Freq: 2.3238)


                                                                                       

Epoch 2 Validation Results - Avg PSNR: 13.4346 dB, Avg SSIM: 0.3652
Checkpoint saved to results\latest_model.pth at epoch 1
Checkpoint saved to results\best_model.pth at epoch 1
*** New best model saved with PSNR: 13.4346 ***

--- Epoch 3/100 ---


Train E3:   0%|          | 0/50 [00:00<?, ?it/s]

Warm-up Epoch 3: LR set to 0.000092


                                                                                                                       

Epoch 3 Average Train Loss: 0.5780 (L1: 0.1224, Perc: 3.5627, Freq: 1.9865)


                                                                                       

Epoch 3 Validation Results - Avg PSNR: 14.8454 dB, Avg SSIM: 0.4698
Checkpoint saved to results\latest_model.pth at epoch 2
Checkpoint saved to results\best_model.pth at epoch 2
*** New best model saved with PSNR: 14.8454 ***

--- Epoch 4/100 ---


Train E4:   0%|          | 0/50 [00:00<?, ?it/s]

Warm-up Epoch 4: LR set to 0.000128


                                                                                                                       

Epoch 4 Average Train Loss: 0.5186 (L1: 0.1053, Perc: 3.1984, Freq: 1.8692)


                                                                                       

Epoch 4 Validation Results - Avg PSNR: 12.3651 dB, Avg SSIM: 0.4354
Checkpoint saved to results\latest_model.pth at epoch 3

--- Epoch 5/100 ---


Train E5:   0%|          | 0/50 [00:00<?, ?it/s]

Warm-up Epoch 5: LR set to 0.000164


                                                                                                                       

Epoch 5 Average Train Loss: 0.4998 (L1: 0.0966, Perc: 3.1409, Freq: 1.7832)


                                                                                       

Epoch 5 Validation Results - Avg PSNR: 13.4819 dB, Avg SSIM: 0.4841
Checkpoint saved to results\latest_model.pth at epoch 4
Checkpoint saved to results\checkpoint_epoch_005.pth at epoch 4

--- Epoch 6/100 ---


Train E6:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 6: LR = 0.000164


                                                                                                                       

Epoch 6 Average Train Loss: 0.4687 (L1: 0.0931, Perc: 2.9001, Freq: 1.7108)


                                                                                       

Epoch 6 Validation Results - Avg PSNR: 13.5431 dB, Avg SSIM: 0.4809
Checkpoint saved to results\latest_model.pth at epoch 5

--- Epoch 7/100 ---


Train E7:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 7: LR = 0.000195


                                                                                                                       

Epoch 7 Average Train Loss: 0.4518 (L1: 0.0906, Perc: 2.7754, Freq: 1.6727)


                                                                                       

Epoch 7 Validation Results - Avg PSNR: 12.3039 dB, Avg SSIM: 0.4582
Checkpoint saved to results\latest_model.pth at epoch 6

--- Epoch 8/100 ---


Train E8:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 8: LR = 0.000181


                                                                                                                       

Epoch 8 Average Train Loss: 0.4287 (L1: 0.0845, Perc: 2.6395, Freq: 1.6050)


                                                                                       

Epoch 8 Validation Results - Avg PSNR: 12.7601 dB, Avg SSIM: 0.4483
Checkpoint saved to results\latest_model.pth at epoch 7

--- Epoch 9/100 ---


Train E9:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 9: LR = 0.000159


                                                                                                                       

Epoch 9 Average Train Loss: 0.4159 (L1: 0.0826, Perc: 2.5469, Freq: 1.5729)


                                                                                       

Epoch 9 Validation Results - Avg PSNR: 13.2515 dB, Avg SSIM: 0.4775
Checkpoint saved to results\latest_model.pth at epoch 8

--- Epoch 10/100 ---


Train E10:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 10: LR = 0.000131


                                                                                                                        

Epoch 10 Average Train Loss: 0.4107 (L1: 0.0845, Perc: 2.4876, Freq: 1.5491)


                                                                                        

Epoch 10 Validation Results - Avg PSNR: 12.7946 dB, Avg SSIM: 0.4676
Checkpoint saved to results\latest_model.pth at epoch 9
Checkpoint saved to results\checkpoint_epoch_010.pth at epoch 9
Saved visualization to results\results_epoch_010.png

--- Epoch 11/100 ---


Train E11:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 11: LR = 0.000101


                                                                                                                        

Epoch 11 Average Train Loss: 0.3913 (L1: 0.0766, Perc: 2.3827, Freq: 1.5288)


                                                                                        

Epoch 11 Validation Results - Avg PSNR: 12.2890 dB, Avg SSIM: 0.4274
Checkpoint saved to results\latest_model.pth at epoch 10

--- Epoch 12/100 ---


Train E12:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 12: LR = 0.000070


                                                                                                                        

Epoch 12 Average Train Loss: 0.3761 (L1: 0.0747, Perc: 2.2583, Freq: 1.5106)


                                                                                        

Epoch 12 Validation Results - Avg PSNR: 12.5736 dB, Avg SSIM: 0.4363
Checkpoint saved to results\latest_model.pth at epoch 11

--- Epoch 13/100 ---


Train E13:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 13: LR = 0.000042


                                                                                                                        

Epoch 13 Average Train Loss: 0.3714 (L1: 0.0738, Perc: 2.2363, Freq: 1.4780)


                                                                                        

Epoch 13 Validation Results - Avg PSNR: 12.1984 dB, Avg SSIM: 0.4129
Checkpoint saved to results\latest_model.pth at epoch 12

--- Epoch 14/100 ---


Train E14:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 14: LR = 0.000020


                                                                                                                        

Epoch 14 Average Train Loss: 0.3636 (L1: 0.0741, Perc: 2.1646, Freq: 1.4618)


                                                                                        

Epoch 14 Validation Results - Avg PSNR: 12.0569 dB, Avg SSIM: 0.4101
Checkpoint saved to results\latest_model.pth at epoch 13

--- Epoch 15/100 ---


Train E15:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 15: LR = 0.000006


                                                                                                                        

Epoch 15 Average Train Loss: 0.3572 (L1: 0.0717, Perc: 2.1269, Freq: 1.4554)


                                                                                        

Epoch 15 Validation Results - Avg PSNR: 12.5078 dB, Avg SSIM: 0.4380
Checkpoint saved to results\latest_model.pth at epoch 14
Checkpoint saved to results\checkpoint_epoch_015.pth at epoch 14

--- Epoch 16/100 ---


Train E16:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 16: LR = 0.000200


                                                                                                                        

Epoch 16 Average Train Loss: 0.3899 (L1: 0.0810, Perc: 2.3486, Freq: 1.4813)


                                                                                        

Epoch 16 Validation Results - Avg PSNR: 12.5694 dB, Avg SSIM: 0.4471
Checkpoint saved to results\latest_model.pth at epoch 15

--- Epoch 17/100 ---


Train E17:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 17: LR = 0.000199


                                                                                                                        

Epoch 17 Average Train Loss: 0.3847 (L1: 0.0788, Perc: 2.3353, Freq: 1.4477)


                                                                                        

Epoch 17 Validation Results - Avg PSNR: 12.3911 dB, Avg SSIM: 0.4353
Checkpoint saved to results\latest_model.pth at epoch 16
Validation PSNR did not improve for 15 epochs. Early stopping.

Training finished!
Best validation PSNR achieved: 14.8454
Saved training curves to results\training_curves.png

Training completed successfully!
Final Best Validation PSNR: 14.8454
