In [1]:
filename = "requirements.txt"
# Every third-party package imported by the imports cell must appear here;
# `wandb` and `PyYAML` are needed for `import wandb` / `import yaml`.
# NOTE(review): only absl-py is pinned; pin the rest for reproducible runs.
packs = """absl-py==2.3.0
certifi
charset-normalizer
clean-fid
colorama
contourpy
cycler
filelock
fonttools
fsspec
grpcio
h5py
idna
imageio
Jinja2
kiwisolver
lazy_loader
lightning-utilities
Markdown
MarkupSafe
matplotlib
mpmath
networkx
numpy
packaging
patchify
pillow
piq
protobuf
pyparsing
python-dateutil
PyYAML
requests
scikit-image
scipy
six
sympy
tensorboard
tensorboard-data-server
tifffile
torch
torch-fidelity
torchaudio
torchmetrics
torchvision
tqdm
typing_extensions
urllib3
wandb
Werkzeug
"""

with open(filename, "w", encoding="utf-8") as f:
    f.write(packs)

In [2]:
!pip install -r requirements.txt




In [3]:
import json
import os
import random
import tempfile
import wandb
import yaml
import csv
import zipfile
import uuid
import requests
import numpy as np
from pathlib import Path
import sys
import torch
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Dataset
from torch.optim import Adam
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF   # For tensor ↔ image conversion
from torchvision.models import vgg19, VGG19_Weights
from torchvision.utils import save_image
from torchmetrics.image.fid import FrechetInceptionDistance
from tqdm import tqdm
from PIL import Image, ImageDraw, ImageFont, ImageFile       # For drawing text and making collages
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from piq import ssim  # pip install piq
from concurrent.futures import ThreadPoolExecutor

In [4]:
#################### Models ########################

#==================== SRCNN Model =========================
class SRCNN(nn.Module):
    """
    Vanilla three-layer SRCNN on 3-channel images.

    conv1: 9x9, 3 -> 64, ReLU
    conv2: 5x5, 64 -> 32, ReLU
    conv3: 5x5, 32 -> 3 (linear output)

    Padding keeps the spatial size unchanged. Attribute names are part of the
    state_dict keys, so they must stay stable for saved checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=5, padding=2)

    def forward(self, x):
        hidden = F.relu(self.conv2(F.relu(self.conv1(x))))
        return self.conv3(hidden)
#==========================================================

#================== VDSR ========================
class VDSR(nn.Module):
    """
    VDSR-style network: 20 3x3 convolutions with a global residual connection.

    Layout of ``self.net``: conv(num_channels -> 64) + ReLU, 18 x [conv(64 -> 64)
    + ReLU], conv(64 -> num_channels). The network predicts a residual that is
    added to the input, so the output has the same shape as ``x``.
    """

    def __init__(self, num_channels=1):
        super().__init__()
        width = 64
        head = [nn.Conv2d(num_channels, width, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
        body = []
        for _ in range(18):
            body += [nn.Conv2d(width, width, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
        tail = [nn.Conv2d(width, num_channels, kernel_size=3, padding=1)]
        self.net = nn.Sequential(*head, *body, *tail)

    def forward(self, x):
        residual = self.net(x)
        return residual + x  # Residual learning
#===========================================================

#================ VDSR_Attention Model =====================
class VDSR_SA(nn.Module):
    """
    VDSR variant whose body is a stack of residual blocks with spatial attention.

    Pipeline: 3x3 conv + ReLU (num_channels -> num_features), ``num_resblocks``
    ``ResidualBlockSA`` blocks, then a two-conv head
    (num_features -> num_features // 2 -> num_channels). The head output is
    added to the input (global residual), so input and output shapes match.
    """

    def __init__(self, num_channels=3, num_features=64, num_resblocks=18):
        super().__init__()
        half = num_features // 2
        self.input_conv = nn.Conv2d(num_channels, num_features, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        blocks = [ResidualBlockSA(num_features) for _ in range(num_resblocks)]
        self.resblocks = nn.Sequential(*blocks)
        self.output_conv = nn.Sequential(
            nn.Conv2d(num_features, half, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(half, num_channels, 3, padding=1),
        )

    def forward(self, x):
        features = self.resblocks(self.relu(self.input_conv(x)))
        return self.output_conv(features) + x  # residual learning

class ResidualBlockSA(nn.Module):
    """Residual block: conv3x3 -> ReLU -> conv3x3 -> SpatialAttention(7), plus an identity skip."""

    def __init__(self, num_features=64):
        super().__init__()
        self.conv1 = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
        self.sa = SpatialAttention(kernel_size=7)

    def forward(self, x):
        branch = self.sa(self.conv2(self.relu(self.conv1(x))))
        return x + branch


class SpatialAttention(nn.Module):
    """
    Spatial attention gate.

    For each pixel, the max and the mean over channels are stacked into a
    2-channel map, convolved to 1 channel (no bias) and passed through a
    sigmoid. The input is multiplied by this (B, 1, H, W) gate.

    Args:
        kernel_size: 3 or 7; "same" padding keeps H and W unchanged.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), "Kernel size must be 3 or 7"
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_max, _ = x.max(dim=1, keepdim=True)
        channel_mean = x.mean(dim=1, keepdim=True)
        descriptor = torch.cat((channel_max, channel_mean), dim=1)  # (B, 2, H, W)
        gate = self.sigmoid(self.conv(descriptor))                  # (B, 1, H, W)
        return x * gate
#===========================================================

#===========================================================

In [5]:
#=================== Data ===================
from PIL import ImageFile

#================== DIV2K Data Loader ==================
class DIV2KDataset(Dataset):
    """
    (LR, HR) tensor pairs built from every ``.png`` file in ``hr_folder``.

    The HR target is the source image bicubically resized to ``hr_size``. The LR
    input is the same source image bicubically resized to
    ``(hr_size[0] // scale, hr_size[1] // scale)``. Both are converted with
    ``ToTensor``. Files are served in sorted filename order.
    """

    def __init__(self, hr_folder, scale=2, hr_size=(256, 256)):
        self.hr_folder = hr_folder
        self.hr_files = sorted(name for name in os.listdir(hr_folder) if name.endswith('.png'))
        self.scale = scale
        self.hr_size = hr_size
        self.lr_size = (hr_size[0] // scale, hr_size[1] // scale)

        def resize_then_tensor(size):
            # Bicubic resize to `size`, then PIL image -> float tensor.
            return transforms.Compose([
                transforms.Resize(size, interpolation=Image.BICUBIC),
                transforms.ToTensor(),
            ])

        self.hr_transform = resize_then_tensor(hr_size)
        self.lr_transform = resize_then_tensor(self.lr_size)

    def __len__(self):
        return len(self.hr_files)

    def __getitem__(self, idx):
        source = Image.open(os.path.join(self.hr_folder, self.hr_files[idx])).convert('RGB')
        hr_tensor = self.hr_transform(source)
        lr_tensor = self.lr_transform(source)
        return lr_tensor, hr_tensor

class TiledDIV2KDataset(Dataset):
    """
    Non-overlapping ``crop_size`` x ``crop_size`` HR tiles from every ``.png`` in ``hr_folder``.

    At construction, each image's size is read (headers only, in a thread pool)
    to enumerate the tile origins that fit completely inside the image. Each
    sample is stored as ``(file_index, x, y)``. ``__getitem__`` crops the tile
    and builds the LR input by bicubic downscaling to ``crop_size // scale``.
    Returns ``(lr_tensor, hr_tensor)``.
    """

    def __init__(self, hr_folder="Data/DIV2K", scale=4, crop_size=512, num_threads=8):
        self.hr_folder = hr_folder
        self.hr_files = sorted(name for name in os.listdir(hr_folder) if name.endswith('.png'))
        self.scale = scale
        self.crop_size = crop_size
        self.lr_size = crop_size // scale
        self.transform = transforms.ToTensor()
        self.samples = []

        print(f"Preprocessing {len(self.hr_files)} HR images for tiling using {num_threads} threads...")

        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            per_image = list(tqdm(
                pool.map(self._tile_origins, enumerate(self.hr_files)),
                total=len(self.hr_files),
                desc="Tiling DIV2K"
            ))

        # Flatten the per-image lists into one sample list (row-major per image).
        self.samples = [origin for tiles in per_image for origin in tiles]

        print(f"Total crop pairs generated: {len(self.samples)}")

    def _tile_origins(self, indexed_name):
        """Return ``[(file_idx, x, y), ...]`` for every full tile of one image; ``[]`` on error."""
        file_idx, file_name = indexed_name
        path = os.path.join(self.hr_folder, file_name)
        try:
            with Image.open(path) as img:
                width, height = img.size
            cols, rows = width // self.crop_size, height // self.crop_size
            return [(file_idx, col * self.crop_size, row * self.crop_size)
                    for row in range(rows) for col in range(cols)]
        except Exception as e:
            print(f"Error processing {file_name}: {e}")
            return []

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_idx, left, top = self.samples[idx]
        path = os.path.join(self.hr_folder, self.hr_files[file_idx])
        edge = self.crop_size

        with Image.open(path) as img:
            hr_crop = img.convert("RGB").crop((left, top, left + edge, top + edge))
            lr_crop = hr_crop.resize((self.lr_size, self.lr_size), Image.BICUBIC)

        return self.transform(lr_crop), self.transform(hr_crop)
#===========================================================

#================== DIV2K Data Utils ==================
def download_div2k(destination="data"):
    """
    Download DIV2K train-HR, extract it and flatten all images into ``<destination>/DIV2K``.

    Steps (each is skipped if its output already exists):
      1. Download the zip to ``<destination>/DIV2K_train_HR.zip``. The payload is
         streamed into a ``.part`` file, which is renamed to the final name only
         after the transfer completes. An interrupted download is therefore
         retried on the next call; a truncated zip is never mistaken for a
         finished one.
      2. Extract into a temporary folder and move every .png/.jpg/.jpeg into
         ``<destination>/DIV2K``. On duplicate filenames the first file seen
         wins. Then delete the temporary folder.

    Args:
        destination: Root directory for the zip and the extracted images.

    Raises:
        requests.HTTPError: if the server returns an error status.
        zipfile.BadZipFile: if the zip on disk is corrupt.
    """
    import shutil

    os.makedirs(destination, exist_ok=True)
    url = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip"
    zip_path = os.path.join(destination, "DIV2K_train_HR.zip")
    extract_temp = os.path.join(destination, "_temp_extract")
    final_folder = os.path.join(destination, "DIV2K")

    # === Download ZIP (atomically) ===
    if not os.path.exists(zip_path):
        print("Downloading DIV2K...")
        partial_path = zip_path + ".part"
        # The timeout bounds the connect and each read, so a stalled server cannot hang the cell.
        with requests.get(url, stream=True, timeout=60) as r:
            r.raise_for_status()
            total = int(r.headers.get("content-length", 0)) or None
            with open(partial_path, 'wb') as f, \
                    tqdm(total=total, unit="B", unit_scale=True, desc="Downloading") as bar:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
                    bar.update(len(chunk))
        os.replace(partial_path, zip_path)
    else:
        print("DIV2K zip already exists.")

    # === Extract ZIP to temp directory ===
    if not os.path.exists(final_folder):
        print("Extracting DIV2K...")
        # Start from a clean temp dir in case an earlier run was interrupted mid-extraction.
        shutil.rmtree(extract_temp, ignore_errors=True)
        os.makedirs(extract_temp, exist_ok=True)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_temp)
        print("Extraction complete.")

        # Flatten all images into DIV2K; later files with an already-seen name are dropped.
        os.makedirs(final_folder, exist_ok=True)
        seen = set()
        for root, _, files in os.walk(extract_temp):
            for file in files:
                if file.lower().endswith(('.png', '.jpg', '.jpeg')) and file not in seen:
                    os.rename(os.path.join(root, file), os.path.join(final_folder, file))
                    seen.add(file)

        # Clean up temporary extraction (including any skipped duplicates).
        shutil.rmtree(extract_temp)
        print("Files flattened and duplicates removed.")

def preprocess_div2k_center_crop(source_folder="Data/DIV2K", target_folder="Data/DIV2K_CROPPED", patch_size=(512, 512)):
    """
    Save the centred ``patch_size`` = (width, height) crop of every file in ``source_folder``.

    Each crop is written to ``target_folder`` under the same filename. Images
    smaller than the patch in either dimension are skipped with a message. Any
    other error (unreadable or corrupt file, non-image file) is reported and the
    file is skipped.
    """
    if not os.path.exists(target_folder):
        os.makedirs(target_folder, exist_ok=True)

    print(f"Cropping center patches of size {patch_size}...")

    for img_name in os.listdir(source_folder):
        src = os.path.join(source_folder, img_name)
        dst = os.path.join(target_folder, img_name)
        try:
            with Image.open(src) as img:
                img.load()  # force a full decode so corrupt files fail here
                width, height = img.size
                crop_w, crop_h = patch_size

                too_small = width < crop_w or height < crop_h
                if too_small:
                    print(f"Skipping {img_name}: image too small ({width}x{height})")
                    continue

                left, top = (width - crop_w) // 2, (height - crop_h) // 2
                img.crop((left, top, left + crop_w, top + crop_h)).save(dst)
        except Exception as e:
            print(f"Skipping {img_name}: {e}")

    print("Center cropping complete. Cropped patches saved to:", target_folder)


def preprocess_div2k(source_folder="Data/DIV2K", target_folder="Data/DIV2K_NORMALIZED", standard_size=(2048, 1408)):
    """
    Resize every image in ``source_folder`` to ``standard_size`` (width, height) and save it.

    Output goes to ``target_folder`` under the same filename, using bicubic
    resampling. ``img.load()`` forces a full decode so corrupt files raise
    immediately. Any failing file is reported and skipped.
    """
    if not os.path.exists(target_folder):
        os.makedirs(target_folder, exist_ok=True)

    print(f"Preprocessing images to standard resolution: {standard_size}...")

    for img_name in os.listdir(source_folder):
        img_path = os.path.join(source_folder, img_name)
        target_path = os.path.join(target_folder, img_name)

        try:
            with Image.open(img_path) as img:
                img.load()  # Ensure the full image is read to trigger an exception if corrupted
                resized = img.resize(standard_size, Image.BICUBIC)
            # Save after the source is closed, so writing works even when
            # target_folder == source_folder.
            resized.save(target_path)
        except Exception as e:
            print(f"Skipping {img_name}: {e}")

    print("Preprocessing complete. Normalized images saved to:", target_folder)


def save_batch_as_images(batch_tensor, root_dir):
    """
    Save each image tensor of a batch to ``root_dir`` as a PNG with a random UUID filename.

    Values are clamped to [0, 1] and moved to the CPU before saving.
    (This helper was previously nested inside ``preprocess_div2k`` by
    indentation, which made it unreachable.)
    """
    os.makedirs(root_dir, exist_ok=True)
    for img in batch_tensor:
        img = img.clamp(0, 1).cpu()
        vutils.save_image(img, os.path.join(root_dir, f"{uuid.uuid4().hex}.png"))

def resize_lr_images(folder_path, target_size=(512, 512)):
    """
    Resize, in place, every image in ``folder_path`` whose filename contains ``_lr``.

    Only files ending in .png/.jpg/.jpeg/.bmp/.tiff (case-insensitive) are
    touched. Each one is bicubically resized to ``target_size`` and written back
    over the original file. Failures are printed and skipped.

    Args:
        folder_path (str): Folder containing the images.
        target_size (tuple): (width, height) to resize to. Default is (512, 512).
    """
    valid_exts = (".png", ".jpg", ".jpeg", ".bmp", ".tiff")
    candidates = [name for name in os.listdir(folder_path)
                  if "_lr" in name and name.lower().endswith(valid_exts)]

    for filename in candidates:
        file_path = os.path.join(folder_path, filename)
        try:
            with Image.open(file_path) as img:
                img.resize(target_size, Image.BICUBIC).save(file_path)
        except Exception as e:
            print(f"Failed to process {filename}: {e}")
#===========================================================


In [6]:
#====================== Metric utils =========================

def compute_psnr(sr: torch.Tensor, hr: torch.Tensor, data_range=1.0) -> float:
    """
    Mean PSNR over a batch, computed per image with scikit-image.

    Both tensors are clamped to [0, 1], detached and copied to numpy. Each
    [C, H, W] sample is transposed to [H, W, C] before scoring.

    Args:
        sr: Super-resolved batch [B, C, H, W].
        hr: High-res ground-truth batch [B, C, H, W].
        data_range: Dynamic range passed to ``peak_signal_noise_ratio``.

    Returns:
        PSNR averaged over the B samples.
    """
    pred = sr.clamp(0, 1).detach().cpu().numpy()
    target = hr.clamp(0, 1).detach().cpu().numpy()
    batch = pred.shape[0]
    total = sum(
        (peak_signal_noise_ratio(target[i].transpose(1, 2, 0),
                                 pred[i].transpose(1, 2, 0),
                                 data_range=data_range)
         for i in range(batch)),
        0.0,
    )
    return total / batch

def compute_ssim_batch(sr: torch.Tensor, hr: torch.Tensor, data_range=1.0) -> float:
    """
    Mean SSIM over a batch, computed per image with scikit-image.

    Both tensors are clamped to [0, 1], detached and copied to numpy. Each
    [C, H, W] sample is transposed to [H, W, C] and scored with the channel
    axis last.

    Args:
        sr: Super-resolved batch [B, C, H, W].
        hr: High-res ground-truth batch [B, C, H, W].
        data_range: Dynamic range passed to ``structural_similarity``.

    Returns:
        SSIM averaged over the B samples.
    """
    pred = sr.clamp(0, 1).detach().cpu().numpy()
    target = hr.clamp(0, 1).detach().cpu().numpy()
    batch = pred.shape[0]
    total = sum(
        (structural_similarity(target[i].transpose(1, 2, 0),
                               pred[i].transpose(1, 2, 0),
                               channel_axis=2,
                               data_range=data_range)
         for i in range(batch)),
        0.0,
    )
    return total / batch

#===========================================================

In [7]:
#====================== Plot utils =========================
def annotate_image(tensor_img, text):
    """
    Draw ``text`` in white at pixel (5, 5) of an image tensor.

    Args:
        tensor_img: Image tensor accepted by ``TF.to_pil_image`` after
            ``squeeze(0)``. Presumably (1, C, H, W) or (C, H, W) with values in
            [0, 1]; verify against callers.
        text: String to draw.

    Returns:
        The annotated image converted back with ``TF.to_tensor``.
    """
    img = TF.to_pil_image(tensor_img.squeeze(0).cpu())
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except (OSError, ImportError):
        # Arial is missing on many hosts (OSError), or Pillow lacks FreeType
        # support (ImportError); fall back to PIL's built-in font. Other errors
        # (including KeyboardInterrupt) are no longer swallowed.
        font = ImageFont.load_default()
    draw.text((5, 5), text, fill="white", font=font)
    return TF.to_tensor(img)

def create_collage(images, save_path):
    """
    Paste PIL images side by side (left to right, top-aligned) and save the result.

    The canvas is RGB, as wide as the sum of the widths and as tall as the
    tallest image.
    """
    sizes = [im.size for im in images]
    canvas_w = sum(w for w, _ in sizes)
    canvas_h = max(h for _, h in sizes)
    collage = Image.new('RGB', (canvas_w, canvas_h))

    left = 0
    for im, (w, _) in zip(images, sizes):
        collage.paste(im, (left, 0))
        left += w
    collage.save(save_path)


# === Training Curves Plot ===
def plot_training_curves(history, model_name, save_path=None):
    """
    Plot per-epoch training curves, one panel per metric family.

    Loss, PSNR (dB), SSIM and FID have very different magnitudes. On a single
    shared y-axis the small-valued curves (loss, SSIM) would be squashed flat,
    so each family gets its own panel.

    Args:
        history: Dict of per-epoch lists. ``train_loss``, ``val_psnr`` and
            ``val_ssim`` are required and must have equal length. ``val_loss``
            and ``val_fid`` are optional; ``val_fid`` holds ``None`` for epochs
            where FID was not computed.
        model_name: Used in the figure title.
        save_path: If given, the figure is also saved to this path.
    """
    epochs = list(range(1, len(history['train_loss']) + 1))
    fid_values = history.get('val_fid', [])
    has_fid = any(value is not None for value in fid_values)

    n_panels = 4 if has_fid else 3
    fig, axes = plt.subplots(1, n_panels, figsize=(4.5 * n_panels, 4))
    ax_loss, ax_psnr, ax_ssim = axes[:3]

    ax_loss.plot(epochs, history['train_loss'], label='Train Loss')
    if 'val_loss' in history:
        ax_loss.plot(epochs, history['val_loss'], label='Val Loss')
    ax_loss.set(title='Loss', xlabel='Epoch', ylabel='Loss')
    ax_loss.legend()

    ax_psnr.plot(epochs, history['val_psnr'], label='Val PSNR')
    ax_psnr.set(title='Val PSNR', xlabel='Epoch', ylabel='PSNR (dB)')

    ax_ssim.plot(epochs, history['val_ssim'], label='Val SSIM')
    ax_ssim.set(title='Val SSIM', xlabel='Epoch', ylabel='SSIM')

    if has_fid:
        fid_series = [np.nan if value is None else value for value in fid_values]
        # FID is only computed every few epochs. Matplotlib draws no line
        # segments across NaN gaps, so markers are needed to make the isolated
        # points visible.
        axes[3].plot(epochs, fid_series, marker='o', label='Val FID')
        axes[3].set(title='Val FID', xlabel='Epoch', ylabel='FID')

    fig.suptitle(f"Training Curves - {model_name}")
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
    plt.show()

#===========================================================

In [8]:
#====================== Results Logger =========================
def log_result(
    model_name,
    loss_type,
    metrics,
    save_dir,
    final_train_loss=None,
    final_val_loss=None,
    csv_path="results.csv"
):
    """
    Append one row of model results, with timestamps and losses, to a central CSV log.

    A header row is written first when ``csv_path`` does not exist yet or is
    empty.

    Args:
        model_name: Value for the ``model`` column.
        loss_type: Value for the ``loss`` column.
        metrics: Mapping read for ``test_psnr``, ``test_ssim`` and ``test_fid``.
            Missing keys produce empty cells.
        save_dir: Value for the ``save_dir`` column.
        final_train_loss: Optional final training loss.
        final_val_loss: Optional final validation loss.
        csv_path: CSV file to append to.
    """
    # The notebook's import cell does not import these; without them the
    # function raised NameError on its first call.
    import time
    from datetime import datetime

    fieldnames = [
        "timestamp_unix",
        "timestamp_readable",
        "model",
        "loss",
        "train_loss",
        "val_loss",
        "psnr",
        "ssim",
        "fid",
        "save_dir"
    ]
    # An existing but empty file (e.g. from an interrupted run) still needs a header.
    needs_header = not os.path.exists(csv_path) or os.path.getsize(csv_path) == 0

    row = {
        "timestamp_unix": int(time.time()),
        "timestamp_readable": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "model": model_name,
        "loss": loss_type,
        "train_loss": final_train_loss,
        "val_loss": final_val_loss,
        "psnr": metrics.get("test_psnr"),
        "ssim": metrics.get("test_ssim"),
        "fid": metrics.get("test_fid"),
        "save_dir": save_dir
    }

    with open(csv_path, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()
        writer.writerow(row)
#===============================================================

In [9]:
#====================== Loss Functions =========================
class PerceptualLoss(nn.Module):
    """
    MSE between frozen VGG19 feature maps of two images.

    Features are taken from ``vgg19(...).features[:9]``, i.e. up to and
    including relu2_2 (index 8). The VGG weights are frozen and the extractor
    is put in eval mode.
    """

    def __init__(self, device="cpu"):
        super().__init__()

        # Load pretrained VGG19 and use layers up to relu2_2 (layer 8)
        vgg = vgg19(weights=VGG19_Weights.DEFAULT).features[:9].to(device).eval()

        # Freeze VGG parameters
        for param in vgg.parameters():
            param.requires_grad = False

        self.feature_extractor = vgg
        self.criterion = nn.MSELoss()

        # ImageNet statistics expected by the pretrained VGG19
        self.normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

    def forward(self, sr, hr):
        """
        sr, hr: [B, 3, H, W] images in [0, 1] range.

        Returns the MSE between their VGG feature maps.
        """
        # Normalize broadcasts over (B, C, H, W), so the whole batch is
        # normalized in one call; no per-image Python loop or torch.stack needed.
        sr_feat = self.feature_extractor(self.normalize(sr))
        hr_feat = self.feature_extractor(self.normalize(hr))

        return self.criterion(sr_feat, hr_feat)

class NewCombinedLoss(nn.Module):
    """
    Weighted mix of MSE, L1 and (1 - SSIM).

    loss = alpha * MSE + beta * L1 + (1 - alpha - beta) * (1 - SSIM)

    MSE and L1 use the raw tensors. SSIM (``piq.ssim``) is computed on copies
    clamped to [0, 1], since it expects values within ``data_range``.

    Args:
        alpha: Weight for MSE.
        beta: Weight for L1.
        data_range: Value range passed to SSIM.
    """

    def __init__(self, alpha=0.2, beta=0.4, data_range=1.0):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.data_range = data_range
        self.mse = nn.MSELoss()
        self.l1 = nn.L1Loss()

    def forward(self, prediction, target):
        pixel_terms = self.alpha * self.mse(prediction, target) + self.beta * self.l1(prediction, target)
        structural = 1 - ssim(prediction.clamp(0, 1), target.clamp(0, 1), data_range=self.data_range)
        return pixel_terms + (1 - self.alpha - self.beta) * structural

class CombinedLoss(nn.Module):
    """
    Weighted sum of pixel MSE and VGG perceptual loss.

    loss = alpha * MSE(sr, hr) + (1 - alpha) * PerceptualLoss(sr, hr)

    Args:
        alpha: Weight of the MSE term.
        resize: Accepted for compatibility; currently unused.
        device: Device for the VGG feature extractor.
    """

    def __init__(self, alpha=0.8, resize=True, device='cpu'):
        super().__init__()
        self.perceptual = PerceptualLoss(device=device)
        self.alpha = alpha
        self.mse = nn.MSELoss()

    def forward(self, sr, hr):
        pixel = self.mse(sr, hr)
        feature = self.perceptual(sr, hr)
        mse_weight = self.alpha
        return mse_weight * pixel + (1 - mse_weight) * feature

class CharbonnierLoss(torch.nn.Module):
    """
    Charbonnier loss: mean over all elements of sqrt((pred - target)^2 + eps^2).

    For |pred - target| much larger than eps this is close to L1. It is smooth
    at zero, and its minimum value (identical inputs) is ``eps``.
    """

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        squared_error = (pred - target) ** 2
        return torch.sqrt(squared_error + self.eps ** 2).mean()
#===========================================================

In [10]:
#====================== Training =========================

def train_and_validate(
    model,
    train_loader,
    val_loader,
    optimizer,
    loss_fn,
    save_dir,
    checkpoint_dir="checkpoints",
    model_name="Model",
    num_epochs=20,
    val_fid_interval=5,
    device=None,
    verbose=True,
    early_stopping_patience=10,
    use_wandb=False
):
    """
    Train an SR model on bicubically upsampled inputs and validate it every epoch.

    Each batch ``(lr_img, hr_img)`` is processed as follows:
      1. Upsample ``lr_img`` to the spatial size of ``hr_img`` (bicubic).
      2. Run ``model`` and clamp the output to [0, 1].
      3. Score the output against ``hr_img`` with ``loss_fn``.

    Per epoch this records train loss, validation loss, PSNR and SSIM. Every
    ``val_fid_interval`` epochs it also records validation FID. The weights with
    the best validation PSNR are saved to
    ``<checkpoint_dir>/<model_name>_best_model.pth``. Training stops early after
    ``early_stopping_patience`` consecutive epochs without a PSNR improvement.
    The history is written (overwritten) to ``<save_dir>/metrics.json``.

    Args:
        model, train_loader, val_loader, optimizer, loss_fn: Usual training objects.
        save_dir: Directory for ``metrics.json``.
        checkpoint_dir: Directory for the best-PSNR checkpoint.
        model_name: Prefix of the checkpoint filename.
        num_epochs: Maximum number of epochs.
        val_fid_interval: Compute FID every this many epochs. ``None`` or ``0``
            disables FID; the ``val_fid`` entries are then all ``None``.
        device: Torch device; defaults to CUDA when available.
        verbose: Print per-epoch metrics and the early-stopping message.
        early_stopping_patience: Epochs without PSNR improvement before stopping.
        use_wandb: Also log per-epoch metrics to Weights & Biases.

    Returns:
        ``(model, history)``: the model as of the last trained epoch (not
        reloaded from the best checkpoint) and a dict of per-epoch lists with
        keys ``train_loss``, ``val_loss``, ``val_psnr``, ``val_ssim``, ``val_fid``.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # NOTE(review): the plateau scheduler is stepped on the *training* loss;
    # stepping on validation loss is the more common choice — confirm intent.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    best_model_path = os.path.join(checkpoint_dir, model_name + '_best_model.pth')

    # -inf guarantees a checkpoint is written after the first validated epoch.
    best_val_psnr = float('-inf')
    early_stop_counter = 0
    history = {'train_loss': [], 'val_loss': [], 'val_psnr': [], 'val_ssim': [], 'val_fid': []}

    def _super_resolve(lr_img, hr_img):
        """Move a batch to `device`, bicubic-upsample LR to HR size, run the model, clamp to [0, 1]."""
        lr_img, hr_img = lr_img.to(device), hr_img.to(device)
        lr_up = F.interpolate(lr_img, size=hr_img.shape[-2:], mode='bicubic', align_corners=False)
        # NOTE(review): clamping before the loss zeroes gradients for out-of-range pixels.
        return model(lr_up).clamp(0, 1), hr_img

    for epoch in range(num_epochs):
        # === Training ===
        model.train()
        running_loss = 0.0

        for lr_img, hr_img in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            output, hr_img = _super_resolve(lr_img, hr_img)
            loss = loss_fn(output, hr_img)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)
        scheduler.step(avg_train_loss)
        history['train_loss'].append(avg_train_loss)

        # === Validation ===
        model.eval()
        val_loss_total = 0.0
        psnr_list, ssim_list = [], []
        with torch.no_grad():
            for lr_img, hr_img in val_loader:
                output, hr_img = _super_resolve(lr_img, hr_img)
                val_loss_total += loss_fn(output, hr_img).item()
                psnr_list.append(compute_psnr(output, hr_img))
                ssim_list.append(compute_ssim_batch(output, hr_img))

        avg_val_loss = val_loss_total / len(val_loader)
        val_psnr = np.mean(psnr_list)
        val_ssim = np.mean(ssim_list)

        history['val_loss'].append(avg_val_loss)
        history['val_psnr'].append(val_psnr)
        history['val_ssim'].append(val_ssim)

        # === Optional FID (guarded so 0/None disables it instead of dividing by zero) ===
        if val_fid_interval and (epoch + 1) % val_fid_interval == 0:
            fid_metric = FrechetInceptionDistance(feature=2048, normalize=True).to(device)
            with torch.no_grad():
                for lr_img, hr_img in val_loader:
                    output, hr_img = _super_resolve(lr_img, hr_img)
                    sr_resized = F.interpolate(output, size=(299, 299), mode='bilinear', align_corners=False)
                    hr_resized = F.interpolate(hr_img, size=(299, 299), mode='bilinear', align_corners=False)
                    fid_metric.update(sr_resized, real=False)
                    fid_metric.update(hr_resized, real=True)
                fid_score = fid_metric.compute().item()
        else:
            fid_score = None
        history['val_fid'].append(fid_score)

        # === Logging ===
        if verbose:
            print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}, "
                  f"PSNR={val_psnr:.2f}, SSIM={val_ssim:.4f}, FID={fid_score}")

        if use_wandb:
            wandb.log({
                "epoch": epoch + 1,
                "train_loss": avg_train_loss,
                "val_loss": avg_val_loss,
                "val_psnr": val_psnr,
                "val_ssim": val_ssim,
                "val_fid": fid_score
            })

        # === Checkpointing / early stopping ===
        if val_psnr > best_val_psnr:
            best_val_psnr = val_psnr
            early_stop_counter = 0
            torch.save(model.state_dict(), best_model_path)
        else:
            early_stop_counter += 1

        if early_stop_counter >= early_stopping_patience:
            if verbose:
                print(f"\nEarly stopping at epoch {epoch+1}. No improvement in PSNR for {early_stopping_patience} epochs.")
            break

    # === Write training history to metrics.json (overwrites any existing file) ===
    metrics_path = os.path.join(save_dir, "metrics.json")
    with open(metrics_path, "w") as f:
        json.dump(history, f, indent=2)

    return model, history


def test_upsample(
    model,
    test_loader,
    save_dir,
    checkpoint_dir="checkpoints",
    model_name="Model",
    forced_indices=None,
    device=None,
    verbose=True,
    use_wandb=False
):
    """
    Evaluate an SR model on selected test samples and save images, collages and metrics.

    If ``<checkpoint_dir>/<model_name>_best_model.pth`` exists, it is loaded
    first. For each index in ``forced_indices`` (default: up to 10 random,
    sorted indices of ``test_loader.dataset``) the function:
      1. Bicubic-upsamples the LR sample and runs the model.
      2. Computes PSNR and SSIM.
      3. Saves an LR | SR | HR collage to ``<save_dir>/collages``.
      4. Saves the LR/SR/HR images to ``<save_dir>/test_examples``.

    Afterwards the saved LR images are upscaled to 512x512, and FID is computed
    from the saved SR/HR PNGs.
    NOTE(review): FID from this handful of images is a rough indicator at best.

    The test metrics are merged into ``<save_dir>/metrics.json``, and per-example
    info is written to ``<save_dir>/test_examples.json``.

    Returns:
        ``(final_metrics, example_data)``, where ``final_metrics`` has keys
        ``test_psnr``, ``test_ssim`` and ``test_fid``.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    best_model_path = os.path.join(checkpoint_dir, model_name + '_best_model.pth')
    if os.path.exists(best_model_path):
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        if verbose:
            print(f"Loaded best model from checkpoint: {best_model_path}")

    model.eval()
    psnr_list, ssim_list = [], []
    collage_dir = Path(save_dir) / "collages"
    example_dir = Path(save_dir) / "test_examples"
    os.makedirs(collage_dir, exist_ok=True)
    os.makedirs(example_dir, exist_ok=True)

    example_data = {}
    dataset = test_loader.dataset

    if forced_indices is None:
        # random.sample raises if asked for more items than the dataset holds.
        forced_indices = sorted(random.sample(range(len(dataset)), min(10, len(dataset))))

    with torch.no_grad():
        for idx in forced_indices:
            lr_img, hr_img = dataset[idx]
            lr_img, hr_img = lr_img.unsqueeze(0).to(device), hr_img.unsqueeze(0).to(device)
            lr_up = F.interpolate(lr_img, size=hr_img.shape[-2:], mode='bicubic', align_corners=False)
            output = model(lr_up).clamp(0, 1)

            # `ssim_val` avoids shadowing the imported `piq.ssim` function.
            psnr = compute_psnr(output, hr_img)
            ssim_val = compute_ssim_batch(output, hr_img)
            psnr_list.append(psnr)
            ssim_list.append(ssim_val)

            collage = [TF.to_pil_image(t.squeeze(0).cpu()) for t in [lr_img, output, hr_img]]
            collage_path = collage_dir / f"{idx:05d}_PSNR{psnr:.2f}_SSIM{ssim_val:.4f}.png"
            create_collage(collage, collage_path)

            paths = {
                "lr": example_dir / f"{idx}_lr.png",
                "sr": example_dir / f"{idx}_sr.png",
                "hr": example_dir / f"{idx}_hr.png"
            }

            save_image(lr_img.clamp(0, 1), paths["lr"])
            save_image(output.clamp(0, 1), paths["sr"])
            save_image(hr_img.clamp(0, 1), paths["hr"])

            example_data[idx] = {
                "lr": str(paths["lr"]),
                "sr": str(paths["sr"]),
                "hr": str(paths["hr"]),
                "psnr": float(psnr),
                "ssim": float(ssim_val)
            }

    # Resize every saved LR example once, after all of them exist. Calling this
    # inside the loop (before each save) re-resized earlier files on every
    # iteration and never resized the last one.
    resize_lr_images(example_dir, target_size=(512, 512))

    fid_metric = FrechetInceptionDistance(feature=2048, normalize=True).to(device)
    with torch.no_grad():
        for entry in example_data.values():
            with Image.open(entry["sr"]) as sr_file, Image.open(entry["hr"]) as hr_file:
                sr = TF.to_tensor(sr_file.convert("RGB")).unsqueeze(0).to(device)
                hr = TF.to_tensor(hr_file.convert("RGB")).unsqueeze(0).to(device)
            sr = F.interpolate(sr, size=(299, 299), mode='bilinear', align_corners=False)
            hr = F.interpolate(hr, size=(299, 299), mode='bilinear', align_corners=False)
            fid_metric.update(sr, real=False)
            fid_metric.update(hr, real=True)

    final_metrics = {
        "test_psnr": float(np.mean(psnr_list)),
        "test_ssim": float(np.mean(ssim_list)),
        "test_fid": float(fid_metric.compute().item())
    }

    if verbose:
        print("\n=== Final Test Metrics ===")
        for k, v in final_metrics.items():
            print(f"{k.upper()}: {v:.4f}")

    metrics_path = os.path.join(save_dir, 'metrics.json')

    # Load existing history from metrics.json if it exists
    if os.path.exists(metrics_path):
        try:
            with open(metrics_path, 'r') as f:
                metrics_data = json.load(f)
            if not isinstance(metrics_data, dict):
                print("Warning: metrics.json is not a valid dictionary. Resetting.")
                metrics_data = {}
        except Exception as e:
            print(f"Error reading metrics.json: {e}. Resetting.")
            metrics_data = {}
    else:
        metrics_data = {}

    # Merge test results into the training history and overwrite the file.
    metrics_data.update(final_metrics)
    with open(metrics_path, 'w') as f:
        json.dump(metrics_data, f, indent=2)

    # Save test examples (int keys become strings in JSON).
    with open(Path(save_dir) / "test_examples.json", 'w') as f:
        json.dump(example_data, f, indent=2)

    if use_wandb:
        wandb.log(final_metrics)
        for idx in list(example_data.keys())[:3]:  # log 3 example images
            wandb.log({
                f"Example_{idx}": [
                    wandb.Image(str(example_data[idx]["lr"]), caption="LR"),
                    wandb.Image(str(example_data[idx]["sr"]), caption="SR"),
                    wandb.Image(str(example_data[idx]["hr"]), caption="HR"),
                ]
            })

    return final_metrics, example_data
#===========================================================

In [11]:
# ---------------------- Channel Attention Layer ----------------------
class CALayer(nn.Module):
    """Squeeze-and-excitation style channel attention.

    Global-average-pools the input to one value per channel, passes it through a
    1x1-conv bottleneck (``channel -> channel // reduction -> channel``) and
    multiplies the input by the resulting per-channel gate.

    Args:
        channel: Number of input (and output) channels.
        reduction: Bottleneck reduction factor. The hidden width is clamped to at
            least 1 so ``channel < reduction`` does not create a zero-width conv.
        use_sigmoid: If True, gate with ``nn.Sigmoid`` (weights in (0, 1), the usual
            channel-attention formulation). The default ``False`` keeps this file's
            existing ReLU gate, whose weights are unbounded and can be exactly zero.
    """

    def __init__(self, channel, reduction=8, use_sigmoid=False):
        super(CALayer, self).__init__()
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Layout (0: conv, 1: relu, 2: conv, 3: gate) matches the original so existing
        # checkpoints (keys fc.0.* and fc.2.*) still load.
        self.fc = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, bias=True),
            nn.Sigmoid() if use_sigmoid else nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Scale each channel of ``x`` by its learned attention weight."""
        # The pooled gate has spatial size 1x1 and broadcasts over H and W.
        y = self.fc(self.avg_pool(x))
        return x * y

# ---------------------- Pyramid Deep SRCNN with Channel Attention ----------------------
class PyramidDeepSRCNN_CA(nn.Module):
    """Plain (non-residual) pyramid CNN with channel attention for SR refinement.

    Feature width grows 16 -> 32 -> 64, passes through six 64-channel conv+ReLU
    blocks followed by a ``CALayer(64)``, then shrinks 64 -> 32 -> 16 before a final
    3x3 conv back to ``num_channels``. Every conv is 3x3 with padding 1 (stride 1),
    so the spatial size of the input is preserved; no upsampling happens here.

    Args:
        num_channels: Number of image channels in and out (3 for RGB).
    """

    def __init__(self, num_channels=3):
        super(PyramidDeepSRCNN_CA, self).__init__()

        def conv_relu(in_ch, out_ch):
            # 3x3 "same" convolution followed by an in-place ReLU.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        # Attribute names and construction order mirror the original layout so
        # saved state_dicts load and seeded initialisation is unchanged.
        self.entry = conv_relu(num_channels, 16)
        self.conv1 = conv_relu(16, 32)
        self.conv2 = conv_relu(32, 64)
        middle = [conv_relu(64, 64) for _ in range(6)]
        middle.append(CALayer(64))
        self.middle_blocks = nn.Sequential(*middle)
        # "deconv" stages are ordinary convolutions that reduce channel width.
        self.deconv1 = conv_relu(64, 32)
        self.deconv2 = conv_relu(32, 16)
        self.exit = nn.Conv2d(16, num_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Apply every stage in order and return the final feature map."""
        stages = (
            self.entry,
            self.conv1,
            self.conv2,
            self.middle_blocks,
            self.deconv1,
            self.deconv2,
            self.exit,
        )
        out = x
        for stage in stages:
            out = stage(out)
        return out


In [18]:
def preprocess_pipeline(config):
    """Download DIV2K, build the dataset and return seeded train/val/test loaders.

    Args:
        config: Parsed config dict. Required keys: ``"scale"``, ``"batch_size"``,
            ``"idx"``. Optional keys: ``"aug"`` and ``"tiled"`` (default False),
            ``"seed"`` (default 42).

    Returns:
        ``(train_loader, val_loader, test_loader, forced_indices)`` where
        ``forced_indices`` is ``config["idx"]`` passed through unchanged.
    """
    print("Downloading and preparing DIV2K dataset...")
    download_div2k("Data")

    # Seed every RNG the pipeline touches so splits and batch order are reproducible.
    seed = int(config.get("seed", 42))
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # NOTE(review): add_augmentation runs on the dataset directory *before* the split;
    # if it adds augmented copies of images, variants of a val/test image can end up
    # in the training set. Confirm against add_augmentation.
    if config.get("aug", False):
        add_augmentation("Data/DIV2K")

    if config.get("tiled", False):
        print("Using TiledDIV2KDataset")
        dataset = TiledDIV2KDataset("Data/DIV2K", scale=config["scale"])
    else:
        print("Using DIV2KDataset")
        dataset = DIV2KDataset("Data/DIV2K", scale=config["scale"])

    # 80/10/10 split; any rounding remainder goes to the test set.
    # NOTE(review): the split is over dataset items; for TiledDIV2KDataset those are
    # presumably crops, so crops of one image may land in different splits.
    total_size = len(dataset)
    train_size = int(0.8 * total_size)
    val_size = int(0.1 * total_size)
    test_size = total_size - train_size - val_size
    train_set, val_set, test_set = random_split(
        dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(seed),
    )

    # random_split fixes a single permutation, so without shuffling every epoch would
    # see identical batches. A dedicated seeded generator keeps the per-epoch order
    # reproducible. Val/test stay unshuffled so the index -> image mapping used by
    # forced_indices is stable.
    train_loader = DataLoader(
        train_set,
        batch_size=config["batch_size"],
        shuffle=True,
        generator=torch.Generator().manual_seed(seed),
    )
    val_loader = DataLoader(val_set, batch_size=1, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)

    forced_indices = config["idx"]

    print("Dataset ready and reproducible.")
    return train_loader, val_loader, test_loader, forced_indices

def test_pipeline(config, test_loader, forced_indices, device, history=None):
    """Build the configured model, evaluate it on ``test_loader`` and record results.

    The architecture is rebuilt here; weights are presumably loaded by the selected
    test function from ``checkpoints``. Hyper-parameters must therefore match
    ``train_pipeline``. ``VDSR_SA`` reads ``config["num_resblocks"]`` (default 15,
    the value training uses).

    Args:
        config: Parsed config dict (``model``, ``loss``, ``save_dir``, optionally
            ``scale``, ``use_wandb``, ``num_resblocks``).
        test_loader: DataLoader over the test split.
        forced_indices: Passed through to the test function.
        device: torch device to run on.
        history: Optional training history; the last ``train_loss`` / ``val_loss``
            entries are logged alongside the test metrics when present and non-empty.

    Returns:
        The metrics dict returned by the test function.

    Raises:
        ValueError: if ``config["model"]`` is not a supported model name.
    """
    model_name = config["model"]

    # Lazy builders: only the requested architecture is instantiated.
    builders = {
        "SRCNN": lambda: SRCNN(),
        "SvOcSRCNN": lambda: SvOcSRCNN(),
        "VDSR": lambda: VDSR(num_channels=3),
        "VDSR_SA": lambda: VDSR_SA(
            num_features=64, num_resblocks=int(config.get("num_resblocks", 15))
        ),
        "dsrcnn_ca": lambda: PyramidDeepSRCNN_CA(num_channels=3),
        "RCAN": lambda: RCAN(num_channels=3, scale=config["scale"]),
        "RCAN_SWIN": lambda: RCAN_Swin(num_channels=3),
    }
    if model_name not in builders:
        raise ValueError(f"Unsupported model in test_pipeline: {model_name}")

    # RCAN variants presumably upsample internally; the others get a pre-upsampled input.
    if model_name in ("RCAN", "RCAN_SWIN"):
        test_fn = test_model_no_upsample
    else:
        test_fn = test_model_with_upsample

    model = builders[model_name]().to(device)

    metrics, example_data = test_fn(
        model=model,
        test_loader=test_loader,
        save_dir=config["save_dir"],
        checkpoint_dir="checkpoints",
        model_name=model_name,
        forced_indices=forced_indices,
        device=device,
        use_wandb=config.get("use_wandb", False),
        verbose=True
    )

    # NOTE(review): the test function already receives use_wandb and may log these
    # metrics itself; this extra log can duplicate them — confirm and drop if so.
    if config.get("use_wandb", False):
        wandb.log(metrics)

    def _last_loss(key):
        # Missing history, missing key or an empty list all yield None.
        if not history or key not in history or not history[key]:
            return None
        return history[key][-1]

    # Log exactly once (previously logged twice, producing duplicate result rows),
    # and before collage generation so a collage failure cannot lose the result.
    log_result(
        model_name=model_name,
        loss_type=config["loss"],
        metrics=metrics,
        save_dir=config["save_dir"],
        final_train_loss=_last_loss("train_loss"),
        final_val_loss=_last_loss("val_loss"),
    )
    generate_summary_collage_from_checkpoints()

    print(f"Testing complete for model: {model_name}")
    return metrics


def get_loss_fn(name, device=None):
    """Return a freshly constructed loss module for the config name ``name``.

    Supported names: ``"mse"``, ``"charbonnier"``, ``"combined"`` (the only one
    that receives ``device``) and ``"NewCombinedLoss"``.

    Raises:
        ValueError: for any other name.
    """
    # Factories are lambdas so only the requested loss is ever instantiated.
    factories = {
        "mse": lambda: nn.MSELoss(),
        "charbonnier": lambda: CharbonnierLoss(),
        "combined": lambda: CombinedLoss(alpha=0.8, device=device),
        "NewCombinedLoss": lambda: NewCombinedLoss(alpha=0.2, beta=0.6),
    }
    factory = factories.get(name)
    if factory is None:
        raise ValueError("Unsupported loss function: " + name)
    return factory()

def train_pipeline(config, train_loader, val_loader, device):
    """Build the configured model, train it and return ``(trained_model, history)``.

    ``RCAN`` / ``RCAN_SWIN`` are trained with ``train_no_upsample`` (its third return
    value is discarded); every other model uses ``train_and_validate``.

    Config keys read: ``model``, ``loss``, ``save_dir``, ``epochs``, and optionally
    ``lr`` (default 1e-4), ``scale`` (RCAN only), ``use_wandb`` (default False),
    ``num_resblocks`` (VDSR_SA only, default 15) and ``early_stopping_patience``
    (default 10). ``num_resblocks`` must match what ``test_pipeline`` builds.

    Raises:
        ValueError: for an unsupported model (or loss, via ``get_loss_fn``).
    """
    model_name = config["model"]
    loss_fn = get_loss_fn(config["loss"], device)
    lr = float(config.get("lr", 1e-4))
    patience = int(config.get("early_stopping_patience", 10))

    # Lazy builders: only the requested architecture is instantiated.
    builders = {
        "SRCNN": lambda: SRCNN(),
        "SvOcSRCNN": lambda: SvOcSRCNN(),
        "VDSR_SA": lambda: VDSR_SA(
            num_features=64, num_resblocks=int(config.get("num_resblocks", 15))
        ),
        "VDSR": lambda: VDSR(num_channels=3),
        "dsrcnn_ca": lambda: PyramidDeepSRCNN_CA(num_channels=3),
        "RCAN_SWIN": lambda: RCAN_Swin(num_channels=3),
        "RCAN": lambda: RCAN(num_channels=3, scale=config["scale"]),
    }
    if model_name not in builders:
        raise ValueError(f"Model {model_name} not supported in train_pipeline.")
    model = builders[model_name]().to(device)

    optimizer = Adam(model.parameters(), lr=lr)

    print(f"Training model: {model_name} with loss: {config['loss']}")

    if model_name in ["RCAN", "RCAN_SWIN"]:
        trained_model, history, _ = train_no_upsample(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=None,  # not used in training phase
            optimizer=optimizer,
            loss_fn=loss_fn,
            model_name=model_name,
            save_dir=config["save_dir"],
            checkpoint_dir="checkpoints",
            num_epochs=config["epochs"],
            device=device,
            forced_indices=None,
            verbose=True,
            early_stopping_patience=patience
        )
    else:
        trained_model, history = train_and_validate(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            loss_fn=loss_fn,
            save_dir=config["save_dir"],
            checkpoint_dir="checkpoints",
            model_name=model_name,
            num_epochs=config["epochs"],
            device=device,
            use_wandb=config.get("use_wandb", False),
            early_stopping_patience=patience
        )

    print(f"Training completed for model: {model_name}")
    return trained_model, history

In [19]:
def create_multi_model_collage(root_dir: str = "checkpoints", font_path: str = None, font_size: int = 22, output_path: str = None):
    """Write one side-by-side comparison image per test example.

    Every sub-directory of ``root_dir`` containing both ``test_examples.json`` and
    ``metrics.json`` is treated as one model. For each example index, one row is
    written with columns ``LR | <each model's SR, sorted by directory name> | HR``.
    SR cells are captioned with the example's PSNR/SSIM and the model's test FID.

    Args:
        root_dir: Directory holding one sub-directory per model.
        font_path: Optional TrueType font file; Pillow's default font otherwise.
        font_size: Caption font size (only applied when ``font_path`` is given).
        output_path: Directory for the collages; defaults to
            ``<root_dir>/summary_collages`` and is created if missing.

    Raises:
        ValueError: if no model directory qualifies, or models disagree on the
            set of example indices.
    """
    root = Path(root_dir)
    out_dir = Path(output_path) if output_path is not None else root / "summary_collages"
    model_dirs = sorted(
        (
            d for d in root.iterdir()
            if d.is_dir() and (d / "test_examples.json").exists() and (d / "metrics.json").exists()
        ),
        key=lambda d: d.name,
    )
    if not model_dirs:
        raise ValueError("No valid model directories found.")

    # Load every model's examples/metrics and require identical example indices.
    model_data = {}
    example_indices = None
    for model_dir in model_dirs:
        name = model_dir.name
        with open(model_dir / "test_examples.json") as f:
            examples = json.load(f)
        with open(model_dir / "metrics.json") as f:
            metrics = json.load(f)

        indices = sorted(map(int, examples.keys()))
        if example_indices is None:
            example_indices = indices
        elif example_indices != indices:
            raise ValueError(f"Example indices do not match across models. Check model {name}.")

        model_data[name] = {"examples": examples, "metrics": metrics}

    def _first_existing(*candidates):
        # Return the first candidate path that exists on disk, else None.
        for candidate in candidates:
            if candidate is not None and Path(candidate).exists():
                return Path(candidate)
        return None

    out_dir.mkdir(parents=True, exist_ok=True)
    columns = ["LR", *model_data.keys(), "HR"]
    cell_size = 400
    # SR captions span two lines, so reserve room for both plus a margin.
    caption_height = 2 * (font_size + 6) + 10
    cell_height = cell_size + caption_height
    font = ImageFont.truetype(font_path, font_size) if font_path else ImageFont.load_default()
    # LR/HR inputs are shared by all models; take them from the first model's records.
    reference_examples = model_data[model_dirs[0].name]["examples"]

    for example_idx in example_indices:
        key = str(example_idx)
        collage = Image.new("RGB", (len(columns) * cell_size, cell_height), (255, 255, 255))
        draw = ImageDraw.Draw(collage)

        for col_idx, label in enumerate(columns):
            x = col_idx * cell_size
            if label in ("LR", "HR"):
                kind = label.lower()
                # Prefer the path recorded in test_examples.json; fall back to the
                # "<idx>_<kind>.png" naming convention under test_examples/.
                img_path = _first_existing(
                    reference_examples[key].get(kind),
                    model_dirs[0] / "test_examples" / f"{example_idx}_{kind}.png",
                )
                caption = "Low Resolution" if label == "LR" else "High Resolution"
            else:
                entry = model_data[label]["examples"][key]
                img_path = _first_existing(entry["sr"])
                fid = model_data[label]["metrics"]["test_fid"]
                caption = (
                    f"{label}\nPSNR: {entry['psnr']:.2f}, "
                    f"SSIM: {entry['ssim']:.4f}, FID: {fid:.2f}"
                )

            if img_path is None:
                continue

            with Image.open(img_path) as src:
                img = src.convert("RGB").resize((cell_size, cell_size), Image.BICUBIC)
            collage.paste(img, (x, 0))
            draw.text((x + 5, cell_size + 2), caption, fill=(0, 0, 0), font=font)

        out_path = out_dir / f"{example_idx:03d}_comparison_collage.png"
        collage.save(out_path)
        print(f"Saved collage for example {example_idx} to {out_path}")

def generate_summary_collage_from_checkpoints(checkpoints_root="checkpoints", output_dir="checkpoints/summary_collages"):
    """Create comparison collages for every model under ``checkpoints_root``.

    Model sub-directories without ``test_examples.json`` are reported and skipped.
    If none qualify, or ``checkpoints_root`` does not exist, a message is printed
    and nothing is generated. A ValueError from ``create_multi_model_collage``
    (e.g. no directory with both JSON files, or mismatched example indices) is
    printed instead of raised, so a finished test run is not aborted by this
    summary step.

    NOTE(review): collages are written by ``create_multi_model_collage`` into
    ``<checkpoints_root>/summary_collages``; ``output_dir`` is created but only
    coincides with that location for the default arguments.
    """
    if not os.path.isdir(checkpoints_root):
        print(f"Skipping collage generation, {checkpoints_root} does not exist.")
        return

    # Listed before creating output directories so they are not mistaken for models.
    model_names = sorted(
        d for d in os.listdir(checkpoints_root)
        if os.path.isdir(os.path.join(checkpoints_root, d)) and d != "summary_collages"
    )

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(checkpoints_root, "summary_collages"), exist_ok=True)

    has_examples = False
    for model_name in model_names:
        json_path = os.path.join(checkpoints_root, model_name, "test_examples.json")
        if os.path.exists(json_path):
            has_examples = True
        else:
            print(f"Skipping {model_name}, no test_examples.json found.")

    if not has_examples:
        print("No test_examples.json found in any model directory; no collages created.")
        return

    try:
        create_multi_model_collage(root_dir=checkpoints_root)
    except ValueError as err:
        print(f"Could not create collages: {err}")
        return

    print("All collages created successfully.")

In [20]:
#====================== Main =========================

# Make modules in the working directory importable; guarded so re-running this
# cell does not keep appending duplicate entries to sys.path.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())
def load_config(path="config/config.yaml"):
    """Read the YAML file at ``path`` with ``yaml.safe_load`` and return the parsed object."""
    with open(path, "r") as config_file:
        parsed = yaml.safe_load(config_file)
    return parsed

def main():
    """Run the configured stages: preprocessing, then optional training and testing.

    Reads ``config/config.yaml`` via ``load_config``. ``config["train"]`` and
    ``config["test"]`` toggle the stages (missing keys count as disabled). When
    ``use_wandb`` is set, the wandb run is finished even if a stage raises.
    """
    config = load_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    use_wandb = config.get("use_wandb", False)
    if use_wandb:
        wandb.init(project="super-resolution", config=config, name=config["model"])

    try:
        # === Step 1: Preprocessing ===
        train_loader, val_loader, test_loader, forced_indices = preprocess_pipeline(config)

        # === Step 2: Training ===
        # history stays None when training is skipped so testing can still run
        # (previously this raised NameError for train=False, test=True).
        history = None
        if config.get("train", False):
            trained_model, history = train_pipeline(
                config=config,
                train_loader=train_loader,
                val_loader=val_loader,
                device=device
            )
            # test_pipeline builds a fresh model (weights presumably loaded from
            # checkpoints), so release the in-memory copy and cached CUDA blocks.
            del trained_model
            if device.type == "cuda":
                torch.cuda.empty_cache()

        # === Step 3: Testing ===
        if config.get("test", False):
            test_pipeline(
                config=config,
                test_loader=test_loader,
                forced_indices=forced_indices,
                device=device,
                history=history
            )
    finally:
        if use_wandb:
            wandb.finish()

    print("All stages complete.")

# Entry point: executing this cell runs the whole pipeline configured in config/config.yaml.
main()
#=====================================================

Using device: cuda
Downloading and preparing DIV2K dataset...
DIV2K zip already exists.
Using TiledDIV2KDataset
Preprocessing 800 HR images for tiling using 8 threads...


Tiling DIV2K: 100%|██████████| 800/800 [00:00<00:00, 7687.81it/s]


Total crop pairs generated: 5052
Dataset ready and reproducible.
Training model: VDSR_SA with loss: NewCombinedLoss


Epoch 1/10:   0%|          | 0/1011 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 256.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 194.12 MiB is free. Process 17290 has 14.55 GiB memory in use. Of the allocated memory 14.41 GiB is allocated by PyTorch, and 21.31 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)