In [None]:

!pip install segmentation-models-pytorch
!pip install pytorch-msssim
!pip install ptflops
!pip install lpips
!pip install --upgrade pip
!pip install torch
!pip install cv2
!pip install scikit-image
!pip install albumentations
!pip install einops
!pip install wandb
!pip install torchmetrics
!pip install pyiqa
!pip install pytorch_fid
!pip install piqa

#!apt-get update && apt-get install libgl1 -y
#!apt-get update && apt-get install -y python3-opencv
#!pip install opencv-python
from torch.nn.utils import spectral_norm
from pytorch_msssim import ms_ssim, ssim, ssim as f_ssim
from ptflops import get_model_complexity_info
import segmentation_models_pytorch as smp
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import math
import torch.nn.init as init
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10, CIFAR100
from torchvision.datasets import ImageFolder
from sklearn.model_selection import train_test_split
from torchvision.utils import save_image, make_grid
from torchvision.models import resnet50, densenet121, DenseNet121_Weights, ResNet50_Weights
import matplotlib.pyplot as plt
from PIL import Image
import os
import sys
import gc
import numpy as np
import torchvision
from tqdm import tqdm
from skimage.metrics import mean_squared_error as f_mse
from skimage.metrics import peak_signal_noise_ratio as f_psnr
from skimage.metrics import structural_similarity as f_ssim
from skimage.metrics import normalized_root_mse as f_nrmse
from skimage.metrics import normalized_mutual_information as f_nmi
from torch.cuda.amp import autocast, GradScaler
from PIL import Image
import cv2
from torchvision.models import vgg19, VGG19_Weights, VGG16_Weights
import albumentations as A
from albumentations.pytorch import ToTensorV2
from einops import rearrange, repeat


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, reduce
from torch.nn.parameter import Parameter

class AdaptiveFeatureNorm(nn.Module):
    """Instance normalisation with input-conditioned affine terms.

    Each sample is normalised per channel over its spatial dimensions. The
    learnable ``gamma``/``beta`` are then modulated by a small squeeze network
    (``stats_net``) that predicts a per-channel scale and shift from the
    globally pooled input.

    Args:
        num_features: Number of channels of the expected ``(B, C, H, W)`` input.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = Parameter(torch.ones(1, num_features, 1, 1))
        self.beta = Parameter(torch.zeros(1, num_features, 1, 1))

        # Squeeze network: global pool -> bottleneck -> (scale, shift) pair.
        bottleneck = num_features // 4
        self.stats_net = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(num_features, bottleneck, 1),
            nn.ReLU(True),
            nn.Conv2d(bottleneck, num_features * 2, 1),
        )

    def forward(self, x):
        """Normalise ``x`` and apply the modulated scale and shift."""
        # Unpacking keeps the original requirement that the input is 4-D.
        _n, _c, _h, _w = x.size()

        # First half of the channels modulates the scale, second half the shift.
        mod_scale, mod_shift = self.stats_net(x).chunk(2, dim=1)

        # Per-sample, per-channel statistics over the spatial dimensions.
        variance, mu = torch.var_mean(x, dim=(2, 3), keepdim=True)
        normalised = (x - mu) / torch.sqrt(variance + self.eps)

        scale = (1 + mod_scale) * self.gamma
        shift = mod_shift * self.beta
        return scale * normalised + shift

class MultiScaleFrequencyAttention(nn.Module):
    """Attention branch combined with three weighted convolutional branches.

    The attention branch splits a 1x1 ``qkv`` projection into heads and
    multiplies the trailing ``(H, W)`` matrices of query and key, so the
    resulting attention map has shape ``(B, heads, C // heads, H, H)``.
    The three "frequency" branches are grouped 3x3 convolutions followed by a
    1x1 expansion; they are mixed with softmax-normalised learnable weights.

    Args:
        dim: Channel count of the input; must be divisible by ``num_heads``
            and by 4 for the reshapes and grouped convolutions to be valid.
        num_heads: Number of attention heads.
    """

    def __init__(self, dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.scale = dim ** -0.5

        self.qkv = nn.Conv2d(dim, dim * 3, 1)
        self.proj = nn.Conv2d(dim, dim, 1)

        def build_branch():
            # Grouped 3x3 reduction to dim/4 channels, then 1x1 expansion.
            return nn.Sequential(
                nn.Conv2d(dim, dim // 4, 3, padding=1, groups=dim // 4),
                nn.GELU(),
                nn.Conv2d(dim // 4, dim, 1),
            )

        # Three parallel branches (labelled low, mid and high frequency).
        self.freq_decomp = nn.ModuleList([build_branch() for _ in range(3)])

        # One learnable mixing logit per branch, normalised with softmax.
        self.freq_weights = nn.Parameter(torch.ones(3))
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        """Return ``proj(attention(x))`` plus the weighted branch outputs."""
        B, C, H, W = x.shape
        head_dim = C // self.num_heads

        # (B, 3, heads, head_dim, H, W) -> three (B, heads, head_dim, H, W).
        packed = self.qkv(x).reshape(B, 3, self.num_heads, head_dim, H, W)
        query, key, value = packed.unbind(1)

        # Matrix product over the trailing (H, W) dims, softmax over the last.
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale
        attn = scores.softmax(dim=-1)

        # Weighted sum of the convolutional branches.
        branch_outputs = [branch(x) for branch in self.freq_decomp]
        mix = self.softmax(self.freq_weights)
        freq_out = sum([coef * feat for coef, feat in zip(mix, branch_outputs)])

        # Apply attention to the values, then fold heads back into channels.
        attended = torch.matmul(attn, value).transpose(1, 2).reshape(B, C, H, W)
        return self.proj(attended) + freq_out

class TemporalConsistencyModule(nn.Module):
    """Align previous features to the current ones and fuse them with a gate.

    A small conv net predicts a 2-channel offset field from the concatenated
    current/previous features. The previous features are warped by that field
    with ``grid_sample``, and a sigmoid gate blends both streams before a 1x1
    fusion convolution.

    Args:
        dim: Channel count of both the current and the previous feature maps.
    """

    def __init__(self, dim):
        super().__init__()
        half = dim // 2

        # Predicts a per-pixel (dx, dy) offset field in pixel units.
        self.alignment_net = nn.Sequential(
            nn.Conv2d(dim * 2, half, 3, padding=1),
            nn.ReLU(True),
            nn.Conv2d(half, 2, 3, padding=1),
        )

        # Per-channel, per-pixel gate in (0, 1).
        self.temporal_attn = nn.Sequential(
            nn.Conv2d(dim * 2, half, 1),
            nn.ReLU(True),
            nn.Conv2d(half, dim, 1),
            nn.Sigmoid(),
        )

        # Collapses the two gated streams back to ``dim`` channels.
        self.fusion = nn.Conv2d(dim * 2, dim, 1)

    def forward(self, current, previous):
        """Return the fused features as a float32 tensor."""
        flow = self.alignment_net(torch.cat((current, previous), dim=1))

        # Resample the previous features at the offset positions.
        aligned_prev = F.grid_sample(previous, self.get_grid(flow), align_corners=True)

        gate = self.temporal_attn(torch.cat((current, aligned_prev), dim=1))
        blended = torch.cat((current * gate, aligned_prev * (1 - gate)), dim=1)
        return self.fusion(blended).float()

    def get_grid(self, flow):
        """Turn a ``(B, 2, H, W)`` pixel offset field into a sampling grid.

        Returns a ``(B, H, W, 2)`` tensor of absolute coordinates (base pixel
        grid plus offsets) rescaled to ``[-1, 1]``, with x in channel 0 and y
        in channel 1.
        """
        B, _, H, W = flow.size()
        rows, cols = torch.meshgrid(torch.arange(0, H), torch.arange(0, W), indexing='ij')
        base = torch.stack((cols, rows), dim=0).unsqueeze(0).repeat(B, 1, 1, 1)
        shifted = base.float().to(flow.device) + flow

        # Map pixel coordinates to [-1, 1]; max(..., 1) guards size-1 axes.
        norm_x = 2.0 * shifted[:, 0, :, :] / max(W - 1, 1) - 1.0
        norm_y = 2.0 * shifted[:, 1, :, :] / max(H - 1, 1) - 1.0
        return torch.stack((norm_x, norm_y), dim=-1)

class AdaptiveResidualBlock(nn.Module):
    """Residual block that softly routes between two parallel branches.

    ``branch1`` is a bottlenecked conv stack with adaptive normalisation;
    ``branch2`` is a ``MultiScaleFrequencyAttention`` module. A router
    produces two softmax weights per sample from the pooled input, and the
    weighted branch outputs are added to the input.

    Args:
        dim: Channel count of the input and output.
    """

    def __init__(self, dim):
        super().__init__()
        squeezed = dim // 4

        self.branch1 = nn.Sequential(
            nn.Conv2d(dim, squeezed, 3, padding=1),
            AdaptiveFeatureNorm(squeezed),
            nn.GELU(),
            nn.Conv2d(squeezed, dim, 3, padding=1),
            AdaptiveFeatureNorm(dim),
        )

        self.branch2 = MultiScaleFrequencyAttention(dim)

        # Global pool -> 2 logits -> softmax across the two branches.
        self.router = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, 2, 1),
            nn.Softmax(dim=1),
        )

    def forward(self, x):
        """Return ``x`` plus the routed mixture of both branch outputs."""
        # Each gate has shape (B, 1, 1, 1) and broadcasts over the features.
        conv_gate, attn_gate = self.router(x).split(1, dim=1)
        conv_out = self.branch1(x)
        attn_out = self.branch2(x)
        return x + conv_gate * conv_out + attn_gate * attn_out

def adjust(x1, x2):
    """Resize ``x1`` to the spatial size of ``x2`` with nearest-neighbour sampling."""
    target_hw = x2.shape[2:]
    return F.interpolate(x1, size=target_hw, mode='nearest')

class AFTNet(nn.Module):
    """Encoder/decoder deblurring network with multi-scale residual outputs.

    The network has a three-stage strided encoder, a stack of middle blocks
    that fuse the running features with features derived from a previous
    frame, and a three-stage transposed-conv decoder with skip fusion. Every
    scale produces an image by adding a predicted residual to a bicubically
    resized copy of the input and clamping to ``[0, 1]``.

    Args:
        in_channels: Channels of the input (and output) images.
        dim: Base feature width; stage ``i`` uses ``dim * 2**i`` channels.
        num_blocks: Number of middle (residual + temporal) blocks.
    """

    def __init__(self, in_channels=3, dim=32, num_blocks=4):
        super().__init__()
        # Feature widths at the four resolutions: dim, 2*dim, 4*dim, 8*dim.
        widths = [dim * (2 ** level) for level in range(4)]
        bottleneck = widths[-1]

        # Stem: lift the image to ``dim`` channels.
        self.init_conv = nn.Sequential(
            nn.Conv2d(in_channels, dim, 3, padding=1),
            AdaptiveFeatureNorm(dim),
        )

        # Encoder: residual block, then a stride-2 conv that doubles the width.
        encoder_stages = []
        for narrow, wide in zip(widths[:-1], widths[1:]):
            encoder_stages.append(nn.Sequential(
                AdaptiveResidualBlock(narrow),
                nn.Conv2d(narrow, wide, 2, stride=2),
                AdaptiveFeatureNorm(wide),
            ))
        self.encoder = nn.ModuleList(encoder_stages)

        # Lifts the stem features of the previous frame to the bottleneck width.
        self.conv_middle = nn.Conv2d(dim, bottleneck, 1)

        # Middle blocks; each Sequential only groups the pair, it is indexed
        # in forward() because the temporal module takes two inputs.
        self.middle = nn.ModuleList([
            nn.Sequential(
                AdaptiveResidualBlock(bottleneck),
                TemporalConsistencyModule(bottleneck),
            )
            for _ in range(num_blocks)
        ])

        # Decoder: transposed conv halves the width, then norm + residual block.
        decoder_stages = []
        for level in (2, 1, 0):
            decoder_stages.append(nn.Sequential(
                nn.ConvTranspose2d(widths[level + 1], widths[level], 2, stride=2),
                AdaptiveFeatureNorm(widths[level]),
                AdaptiveResidualBlock(widths[level]),
            ))
        self.decoder = nn.ModuleList(decoder_stages)

        # 1x1 convs merging decoder features with the matching encoder skip.
        self.pyramid_fusion = nn.ModuleList([
            nn.Conv2d(width * 2, width, 1) for width in widths[:3]
        ])

        # One image head per resolution (index 0 is the finest scale).
        self.output_layers = nn.ModuleList([
            nn.Conv2d(width, in_channels, 3, padding=1) for width in widths
        ])

    def forward(self, x, prev_frame=None):
        """Return a list of restored images ordered from fine to coarse.

        Args:
            x: Input image batch.
            prev_frame: Optional previous frame; ``x`` itself is used when it
                is omitted.
        """
        blurred = x
        if prev_frame is None:
            prev_frame = x

        # Stem features; the previous frame is embedded without gradients.
        x = self.init_conv(x)
        with torch.no_grad():
            prev_features = self.init_conv(prev_frame)

        # Encoder pass, remembering the features at every resolution.
        skips = [x]
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)

        # Bring the previous-frame features to the bottleneck size and width.
        prev_features = F.interpolate(prev_features, size=x.shape[2:], mode='bicubic', align_corners=False)
        prev_features = self.conv_middle(prev_features)

        # Each middle block fuses with the output of the one before it.
        for refine, temporal in self.middle:
            x = temporal(refine(x), prev_features)
            prev_features = x

        def restore(features, head):
            # Predicted residual + resized input, clamped to the image range.
            residual = head(features)
            resized_input = F.interpolate(blurred, size=features.shape[2:], mode='bicubic', align_corners=False)
            return (residual + resized_input).clamp(0, 1)

        # Coarsest output comes straight from the bottleneck.
        outputs = [restore(x, self.output_layers[3])]

        # Decoder: stage ``step`` restores resolution level ``2 - step``.
        for step, stage in enumerate(self.decoder):
            level = 2 - step
            upsample, norm, refine = stage

            skip = skips[level]
            x = adjust(upsample(x), skip)
            x = self.pyramid_fusion[level](torch.cat((x, skip), dim=1))
            x = refine(norm(x)).float()

            outputs.append(restore(x, self.output_layers[level]))

        return outputs[::-1]  # fine to coarse


# Discriminator

In [None]:
# models/discriminator.py
class Discriminator(nn.Module):
    """Convolutional patch discriminator producing a map of logits.

    The stack is a stride-2 stem, ``n_layers - 1`` further stride-2 blocks
    with batch norm, one stride-1 block, and a final stride-1 conv to a single
    channel. Channel multipliers double per block and are capped at 8.

    Args:
        input_nc: Channels of the input image.
        ndf: Base number of filters.
        n_layers: Number of stride-2 convolutions (including the stem).
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3):
        super().__init__()

        # Width multipliers of the stride-2 part: 1 for the stem, then 2**i.
        multipliers = [1] + [min(2 ** i, 8) for i in range(1, n_layers)]

        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        ]

        for before, after in zip(multipliers[:-1], multipliers[1:]):
            layers.extend([
                nn.Conv2d(ndf * before, ndf * after, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(ndf * after),
                nn.LeakyReLU(0.2, True),
            ])

        # Stride-1 tail: one more widening block, then the 1-channel head.
        before, after = multipliers[-1], min(2 ** n_layers, 8)
        layers.extend([
            nn.Conv2d(ndf * before, ndf * after, kernel_size=4, stride=1, padding=1),
            nn.BatchNorm2d(ndf * after),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(ndf * after, 1, kernel_size=4, stride=1, padding=1),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the patch-wise logit map for ``x``."""
        return self.model(x)

# Losses

In [None]:
from torch.autograd import Variable
from torch.fft import fft2, ifft2

class CharbonnierLoss(nn.Module):
    """Charbonnier loss: the mean of ``sqrt((pred - target)^2 + epsilon^2)``.

    Args:
        epsilon: Smoothing constant that keeps the loss differentiable at zero.
    """

    def __init__(self, epsilon=1e-3):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, pred, target):
        """Return the scalar float32 loss, differentiable w.r.t. ``pred``.

        The value is computed directly on the autograd graph. Wrapping it in
        ``Variable(..., requires_grad=True)`` would create a fresh leaf tensor,
        so no gradient would reach ``pred``; casting with
        ``.type(torch.FloatTensor)`` would also move the loss to the CPU.
        """
        diff = pred - target
        penalty = torch.sqrt(diff * diff + self.epsilon * self.epsilon)
        # .float() keeps the float32 result dtype without leaving the graph.
        return penalty.float().mean()

class MSEGDL(nn.Module):
    """Squared error plus a gradient-difference term.

    The gradient term compares first-order finite differences of ``inputs``
    and ``targets`` along the last two axes. All three sums are divided by the
    number of elements in ``inputs``.

    Args:
        lambda_mse: Weight of the squared-error sum.
        lambda_gdl: Weight of each gradient-difference sum.
    """

    def __init__(self, lambda_mse=1, lambda_gdl=1):
        super().__init__()
        self.lambda_mse = lambda_mse
        self.lambda_gdl = lambda_gdl

    def forward(self, inputs, targets):
        """Return the weighted, element-count-normalised loss."""
        residual_sq = (inputs - targets) ** 2
        # Differences along the last axis, then along the second-to-last axis.
        last_axis_sq = (torch.diff(inputs, dim=-1) - torch.diff(targets, dim=-1)) ** 2
        prev_axis_sq = (torch.diff(inputs, dim=-2) - torch.diff(targets, dim=-2)) ** 2

        weighted_sum = (
            self.lambda_mse * residual_sq.sum()
            + self.lambda_gdl * last_axis_sq.sum()
            + self.lambda_gdl * prev_axis_sq.sum()
        )
        return weighted_sum / inputs.numel()

class SSIMLoss(nn.Module):
    """Loss defined as ``1 - SSIM``.

    Args:
        data_range: Value range of the images, forwarded to ``ssim``.
        size_average: Forwarded to ``ssim``.
    """

    def __init__(self, data_range=1.0, size_average=True):
        super(SSIMLoss, self).__init__()
        self.data_range = data_range
        self.size_average = size_average

    def forward(self, img1, img2):
        """Return ``1 - ssim(img1, img2)`` as a float32 tensor.

        The SSIM value stays on the autograd graph. Wrapping it in
        ``Variable(..., requires_grad=True)`` would create a detached leaf, so
        the loss would contribute no gradient to the network; casting with
        ``.type(torch.FloatTensor)`` would also move it to the CPU.
        """
        similarity = ssim(img1, img2, data_range=self.data_range, size_average=self.size_average)
        # .float() keeps the float32 result dtype without leaving the graph.
        return 1 - similarity.float()

class MSSSIMLoss(nn.Module):
    """Loss defined as ``1 - MS-SSIM``.

    Args:
        data_range: Value range of the images, forwarded to ``ms_ssim``.
        size_average: Forwarded to ``ms_ssim``.
    """

    def __init__(self, data_range=1.0, size_average=True):
        super(MSSSIMLoss, self).__init__()
        self.data_range = data_range
        self.size_average = size_average

    def forward(self, img1, img2):
        """Return ``1 - ms_ssim(img1, img2)`` as a float32 tensor.

        The MS-SSIM value stays on the autograd graph. Wrapping it in
        ``Variable(..., requires_grad=True)`` would create a detached leaf, so
        the loss would contribute no gradient to the network; casting with
        ``.type(torch.FloatTensor)`` would also move it to the CPU.
        """
        similarity = ms_ssim(img1, img2, data_range=self.data_range, size_average=self.size_average)
        # .float() keeps the float32 result dtype without leaving the graph.
        return 1 - similarity.float()

class VGGLoss(nn.Module):
    """Perceptual loss: MSE between frozen VGG-19 feature maps.

    Args:
        layer: Number of leading layers of ``vgg19().features`` to keep.
    """

    def __init__(self, layer=36):
        super().__init__()

        self.vgg = vgg19(weights=VGG19_Weights.DEFAULT).features[:layer].eval()
        self.loss = nn.MSELoss()

        # The feature extractor is fixed; only the inputs receive gradients.
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, output, target):
        """Return the MSE between the VGG features of ``output`` and ``target``."""
        # A parent module's .train() call would flip the extractor back to
        # training mode, so eval mode is re-applied on every call.
        self.vgg.eval()
        output_features = self.vgg(output)
        target_features = self.vgg(target)
        # No gc.collect() / torch.cuda.empty_cache() here: the autograd graph
        # keeps the feature maps alive until backward, so those calls free
        # nothing and only add overhead to every loss evaluation.
        return self.loss(output_features, target_features)

class DeblurLoss(nn.Module):
    """Multi-scale training objective built from five loss terms.

    For every prediction scale the loss combines L1, SSIM, gradient-difference,
    FFT-magnitude and VGG perceptual terms with fixed coefficients, then weights
    the result by a per-scale factor.
    """

    def __init__(self):
        super().__init__()
        self.l1_loss = nn.L1Loss()
        self.mse_loss = nn.MSELoss()
        self.gdl_loss = MSEGDL()
        self.ssim_loss = SSIMLoss()
        self.vgg_loss = VGGLoss()

    def get_frequency_loss(self, pred, target):
        """Return the MSE between the 2-D FFT magnitudes of both inputs."""
        pred_spectrum = torch.fft.fft2(pred)
        target_spectrum = torch.fft.fft2(target)
        return F.mse_loss(pred_spectrum.abs(), target_spectrum.abs())

    def forward(self, pred_list, target):
        """Return the summed loss over all prediction scales.

        Args:
            pred_list: Predictions, one per scale; at most four are used
                because they are zipped with the four scale weights.
            target: Ground-truth batch; resized to each prediction's spatial
                size when the shapes differ.
        """
        scale_weights = (1.0, 0.75, 0.45, 0.35)
        total_loss = 0

        for pred, scale_weight in zip(pred_list, scale_weights):
            # Match the reference to this scale (default interpolation mode).
            if pred.shape == target.shape:
                reference = target
            else:
                reference = F.interpolate(target, size=pred.shape[2:]).to(target)

            # Align the prediction's dtype/device with the target.
            pred = pred.to(target)

            pixel_term = self.l1_loss(pred, reference)
            frequency_term = self.get_frequency_loss(pred, reference)
            perceptual_term = self.vgg_loss(pred, reference)
            structure_term = self.ssim_loss(pred, reference)
            gradient_term = self.gdl_loss(pred, reference)

            combined = (
                1.0 * pixel_term +
                0.6 * structure_term +
                0.3 * gradient_term +
                0.1 * frequency_term +
                0.8 * perceptual_term
            )
            total_loss += scale_weight * combined

        return total_loss

# Utilities

In [None]:
import random

# Training and Validation Functions
def calculate_metrics(pred, target):
    """Return ``(psnr, ssim)`` as Python floats for images in ``[0, 1]``.

    PSNR is computed as ``10 * log10(1 / mse)``, i.e. with a peak value of 1.
    """
    mean_sq_err = F.mse_loss(pred, target)
    psnr_db = 10 * torch.log10(1 / mean_sq_err)
    structural = ssim(pred, target, data_range=1.0, size_average=True)
    return psnr_db.item(), structural.item()

def get_model_size(model):
    """
    Report how much memory a model's tensors occupy, in megabytes.

    Args:
        model (torch.nn.Module): Model whose parameters and buffers are measured.

    Returns:
        float: Combined size of all parameters and buffers in MB (1 MB = 1024**2 bytes).
    """
    def _total_bytes(tensors):
        # Bytes per tensor = element count * bytes per element.
        return sum(t.nelement() * t.element_size() for t in tensors)

    occupied = _total_bytes(model.parameters()) + _total_bytes(model.buffers())
    return occupied / 1024**2

def plot_dataset(train_loader):
    """Show the first image of up to five batches as a 2x5 grid.

    The top row shows the first element of each batch pair and the bottom row
    the second. Each image is expected in ``(C, H, W)`` layout, since it is
    permuted to ``(H, W, C)`` for display.

    Args:
        train_loader: Iterable yielding ``(low_res, high_res)`` batch pairs.
    """
    max_columns = 5
    fig, axes = plt.subplots(2, max_columns, figsize=(14, 7))

    shown = 0
    for col, (low_res, high_res) in enumerate(train_loader):
        if col >= max_columns:
            break

        axes[0, col].imshow(low_res[0].permute(1, 2, 0))
        axes[0, col].set_title("Low Resolution")
        axes[0, col].axis('off')

        axes[1, col].imshow(high_res[0].permute(1, 2, 0))
        axes[1, col].set_title("High Resolution")
        axes[1, col].axis('off')
        shown = col + 1

    # Hide the columns that received no image (loader with < 5 batches).
    for ax in axes[:, shown:].ravel():
        ax.axis('off')

    # Always render, including when the loader yields five batches or fewer.
    plt.show()

def get_pil_image(image_tensor):
    """Convert a ``(C, H, W)`` tensor scaled to ``[0, 1]`` into a PIL image.

    The tensor is moved to channels-last layout, multiplied by 255, cast to
    ``uint8`` on the CPU and handed to ``ToPILImage``.
    """
    channels_last = image_tensor.permute(1, 2, 0)
    scaled = channels_last * 255.
    pixels = scaled.cpu().numpy().astype(np.uint8)
    return transforms.ToPILImage()(pixels)

def save_image_tensor(tensor_image, image_name):
    """Write a tensor image to ``image_name`` after squeezing dimension 0."""
    get_pil_image(tensor_image.squeeze(0)).save(image_name)

def save_pil_image(image, image_name):
    """Save ``image`` to the path ``image_name`` via its ``save`` method."""
    image.save(image_name)

def save_samples(encoder, real_images, sharp_images, index, sample_dir='generated', show=True, device='cuda'):
    """Run the model on a batch, save the result as a grid and optionally plot it.

    Args:
        encoder: Model whose first returned element is the image batch to save.
        real_images: Input batch passed to ``encoder``.
        sharp_images: Reference batch, only displayed when ``show`` is true.
        index: Integer used in the zero-padded output file name.
        sample_dir: Directory the grid image is written to.
        show: When true, plot the generated grid followed by the reference grid.
        device: Accepted for compatibility; not used by this function.
    """
    with torch.no_grad():
        fake_images = encoder(real_images)[0]
        out_path = os.path.join(sample_dir, "generated-images-{0:0=4d}.png".format(index))
        save_image(fake_images, out_path, nrow=8)

        if show:
            # Generated grid first, then the reference grid.
            for batch in (fake_images, sharp_images):
                fig, ax = plt.subplots(figsize=(20, 20))
                ax.set_xticks([])
                ax.set_yticks([])
                grid = make_grid(batch.cpu().detach(), nrow=8)
                ax.imshow(grid.permute(1, 2, 0))
                plt.show()

def show_images(images):
    """Plot a batch of images as a tick-less grid with eight images per row."""
    figure, axis = plt.subplots(figsize=(20, 20))
    for clear_ticks in (axis.set_xticks, axis.set_yticks):
        clear_ticks([])
    grid = make_grid(images.cpu().detach(), nrow=8)
    axis.imshow(grid.permute(1, 2, 0))
    plt.show()

def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs; make cuDNN deterministic on CUDA."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Disable cuDNN autotuning and force deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def get_model_size(model):
    """
    Measure the memory footprint of a PyTorch model in megabytes (MB).

    NOTE(review): a function with this name and the same body is already
    defined earlier in this cell; this later definition is the one in effect.

    Args:
        model (torch.nn.Module): Model whose parameters and buffers are counted.

    Returns:
        float: Total size of parameters plus buffers in megabytes (MB).
    """
    byte_count = 0
    # Parameters first, then registered buffers.
    for tensor in list(model.parameters()) + list(model.buffers()):
        byte_count += tensor.nelement() * tensor.element_size()

    return byte_count / (1024 * 1024)

# Datasets

In [None]:
import torch
from torch.utils.data import Dataset
from pathlib import Path
from PIL import Image, ImageFile
import torchvision.transforms as transforms
import random
import numpy as np
import os

# Allow loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

class DeblurDataset(Dataset):
    """Paired blur/sharp image dataset with random-crop training patches.

    Sharp images are listed from ``<root_dir>/sharp``; the blurred partner is
    looked up in ``<root_dir>/blur`` by file name, falling back to the same
    stem with any supported extension. Samples that cannot be loaded are
    replaced by a randomly chosen other sample.
    """

    # Upper bound on random fallback attempts for unreadable samples, so a
    # directory without any usable pair fails with a clear error instead of
    # recursing until the interpreter's recursion limit is hit.
    _MAX_RETRIES = 100

    def __init__(self, root_dir, split='train', patch_size=256):
        """
        Dataset for deblurring training/validation with support for multiple image extensions

        Args:
            root_dir: Root directory containing 'sharp' and 'blur' subdirectories
            split: 'train' or 'val'
            patch_size: Size of training patches (only used during training)
        """
        self.root_dir = Path(root_dir)
        self.split = split
        self.patch_size = patch_size

        self.sharp_dir = self.root_dir / 'sharp'
        self.blur_dir = self.root_dir / 'blur'

        # Common image extensions
        self.extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif', '.webp']

        # Collect into a set: on case-insensitive file systems the lower- and
        # upper-case patterns match the same files, which would otherwise be
        # listed (and sampled) twice.
        found = set()
        for ext in self.extensions:
            found.update(self.sharp_dir.glob(f'*{ext}'))
            found.update(self.sharp_dir.glob(f'*{ext.upper()}'))

        # Sort the files to ensure deterministic behavior
        self.sharp_files = sorted(found)

        print(f"Found {len(self.sharp_files)} images in {self.sharp_dir}")

        # Basic transforms
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        # Augmentation transforms for training
        self.augment = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(90),
        ]) if split == 'train' else None

    def __len__(self):
        """Return the number of sharp images found."""
        return len(self.sharp_files)

    def get_random_crop_params(self, img):
        """Pick a random ``patch_size`` x ``patch_size`` crop window.

        Images smaller than the patch in either dimension are upscaled first.

        Returns:
            tuple: ``(top, left, height, width, img)`` where ``img`` is the
            input image, or its upscaled replacement when a resize was needed.
            Every code path returns all five elements.
        """
        w, h = img.size
        th, tw = self.patch_size, self.patch_size
        if w == tw and h == th:
            # Same 5-tuple shape as the general path; callers unpack five values.
            return 0, 0, h, w, img
        if w < tw or h < th:
            # Handle images smaller than patch size by resizing
            scale = max(tw / w, th / h) * 1.1  # Scale up with a small margin
            new_w, new_h = int(w * scale), int(h * scale)
            img = img.resize((new_w, new_h), Image.BICUBIC)
            w, h = new_w, new_h

        i = random.randint(0, h - th)
        j = random.randint(0, w - tw)
        return i, j, th, tw, img

    def _resolve_blur_path(self, sharp_path):
        """Return the blur file paired with ``sharp_path``, or ``None``."""
        exact = self.blur_dir / sharp_path.name
        if exact.exists():
            return exact
        # Same stem with any known extension, lower case before upper case.
        for ext in self.extensions:
            for suffix in (ext, ext.upper()):
                candidate = self.blur_dir / f"{sharp_path.stem}{suffix}"
                if candidate.exists():
                    return candidate
        return None

    def _load_pair(self, idx):
        """Load sample ``idx``; return ``(blur, sharp)`` tensors or ``None``.

        ``None`` means the pair is unusable (no blur partner, or the files
        could not be opened) and the caller should fall back to another index.
        """
        sharp_path = self.sharp_files[idx]
        blur_path = self._resolve_blur_path(sharp_path)
        if blur_path is None:
            print(f"Warning: No matching blur image for {sharp_path.name}")
            return None

        try:
            sharp_img = Image.open(sharp_path).convert('RGB')
            blur_img = Image.open(blur_path).convert('RGB')
        except Exception as e:
            print(f"Error loading images: {e}")
            return None

        # Ensure both images have the same size
        if sharp_img.size != blur_img.size:
            blur_img = blur_img.resize(sharp_img.size, Image.BICUBIC)

        if self.split == 'train':
            i, j, h, w, sharp_img_resized = self.get_random_crop_params(sharp_img)
            if sharp_img_resized is not sharp_img:  # image was upscaled
                sharp_img = sharp_img_resized
                blur_img = blur_img.resize(sharp_img.size, Image.BICUBIC)

            # Crop both images to the same region
            crop_box = (j, i, j + w, i + h)
            sharp_img = sharp_img.crop(crop_box)
            blur_img = blur_img.crop(crop_box)

            # Augment half of the samples; restoring the torch RNG state makes
            # the blur image receive the same random transform as the sharp one.
            if random.random() > 0.5 and self.augment:
                state = torch.get_rng_state()
                sharp_img = self.augment(sharp_img)
                torch.set_rng_state(state)
                blur_img = self.augment(blur_img)

        sharp_tensor = self.transform(sharp_img)
        blur_tensor = self.transform(blur_img)
        return blur_tensor, sharp_tensor

    def __getitem__(self, idx):
        """Return ``(blur_tensor, sharp_tensor)`` for ``idx``.

        When the requested sample cannot be produced, a random other sample is
        tried instead, up to ``_MAX_RETRIES`` times.

        Raises:
            RuntimeError: If no usable sample is found within the retry budget.
        """
        for _ in range(self._MAX_RETRIES + 1):
            try:
                sample = self._load_pair(idx)
            except Exception as e:
                print(f"Error processing image {idx}: {e}")
                sample = None
            if sample is not None:
                return sample
            # Fall back to a random sample.
            idx = random.randint(0, len(self) - 1)

        raise RuntimeError(
            f"Could not load a valid blur/sharp pair from {self.root_dir} "
            f"after {self._MAX_RETRIES} fallback attempts"
        )


def create_dataloaders(root_dir_train, root_dir_val, batch_size=8, patch_size=256, num_workers=4):
    """Build the training and validation ``DataLoader`` pair.

    Args:
        root_dir_train: Dataset root used for training.
        root_dir_val: Dataset root used for validation.
        batch_size: Batch size of both loaders.
        patch_size: Crop size forwarded to both datasets.
        num_workers: Worker processes per loader.

    Returns:
        tuple: ``(train_loader, val_loader)``; only the training loader shuffles.
    """
    train_dataset = DeblurDataset(root_dir_train, split='train', patch_size=patch_size)
    # NOTE(review): the validation dataset is also created with split='train',
    # so it gets random crops and augmentation too. Presumably this keeps
    # validation batches at a fixed patch size -- confirm this is intended.
    val_dataset = DeblurDataset(root_dir_val, split='train', patch_size=patch_size)

    # Settings shared by both loaders.
    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **shared)
    val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **shared)

    return train_loader, val_loader

# Metrics

In [None]:
# utils/metrics.py
import torch
import lpips
from pytorch_fid import fid_score
import pyiqa
import wandb
from pytorch_msssim import ssim, ms_ssim
from skimage.metrics import peak_signal_noise_ratio as f_psnr

class MetricsCalculator:
    """Computes PSNR, SSIM and LPIPS for image batches scaled to ``[0, 1]``.

    Args:
        device: Device the LPIPS and NIQE networks are moved to.
    """

    def __init__(self, device):
        self.lpips_fn = lpips.LPIPS(net='vgg').to(device)
        self.ssim = ssim
        # NIQE is instantiated but currently not evaluated (reported as 0).
        self.niqe = pyiqa.create_metric('niqe').to(device)

    def calculate_metrics(self, pred, target):
        """Return a dict with ``psnr``, ``ssim``, ``lpips`` and ``niqe`` floats.

        Both tensors are treated as ``[0, 1]`` images, matching the
        ``data_range=1.0`` already used for PSNR:

        * SSIM is called with ``data_range=1.0``; its library default of 255
          makes the stabilising constants far too large for unit-range data.
        * LPIPS is called with ``normalize=True`` so the ``[0, 1]`` inputs are
          rescaled to the ``[-1, 1]`` range the network expects.
        """
        with torch.no_grad():
            psnr_value = f_psnr(pred.detach().cpu().numpy(), target.detach().cpu().numpy(), data_range=1.0)
            ssim_value = self.ssim(pred, target, data_range=1.0).cpu()
            lpips_value = self.lpips_fn(pred, target, normalize=True).mean().cpu()
            #niqe_value  = self.niqe(pred.clip(0.0, 1.0)).mean().cpu()

            return {
                'psnr': float(psnr_value),
                'ssim': ssim_value.item(),
                'lpips': lpips_value.item(),
                'niqe': 0 #niqe_value.item()
            }

# Train

In [None]:
# train.py
import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import time
from pathlib import Path

class Trainer:
    """Trains the AFTNet generator with periodic validation and checkpointing.

    Checkpoints store the index of the NEXT epoch to run, so resuming from an
    end-of-epoch checkpoint continues with the following epoch instead of
    repeating the one that already finished.
    """

    def __init__(self, config):
        """Build the model, optimizer, scheduler, loss and metrics from `config`.

        Reads `config.log_dir`, `config.generated_dir` and `config.lr` here; the
        other methods additionally read `epochs`, `resume`, `save_frequency` and
        `val_frequency`.
        """
        self.config        = config
        self.log_dir       = config.log_dir
        self.generated_dir = config.generated_dir
        self.device        = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Initialize models
        self.netG = AFTNet().to(self.device)

        # Initialize optimizers
        self.optimG = torch.optim.AdamW(self.netG.parameters(), lr=config.lr, weight_decay=0.01, betas=(0.5, 0.999))

        # Scheduler (stepped once per epoch in `train`)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimG, T_0=10, T_mult=2)

        self.com_criterion = DeblurLoss().to(self.device)

        # Default checkpoint used for resuming when no explicit path is given.
        self.path = f"{self.log_dir}/checkpoint_global.pt"

        print("Model's Total Num Model Parameters: {}".format(sum([param.nelement() for param in self.netG.parameters()])))
        model_size = get_model_size(self.netG)
        print(f"The model size is {model_size:.2f} MB")

        # Initialize metrics calculator
        self.metrics = MetricsCalculator(self.device)

        # Initialize wandb
        #wandb.login(key=config.key)
        #wandb.init(project=config.project_name, name=config.name, config=config.__dict__)

    def train(self, train_loader, val_loader, resume_from=None):
        """Run the training loop from the resumed epoch up to `config.epochs`.

        Args:
            train_loader: Iterable of (blurred, target) training batches.
            val_loader: Iterable of (blurred, target) validation batches.
            resume_from: Optional checkpoint path; when omitted (or when
                `config.resume` is false) the default global checkpoint path
                is used instead.

        A KeyboardInterrupt is caught and an "interrupted" checkpoint is saved.
        A "final" checkpoint is always written on exit, labelled with the next
        epoch that would have run.
        """
        # Resume if checkpoint provided, otherwise fall back to the default path.
        if self.config.resume and resume_from is not None:
            resume_path = resume_from
        else:
            resume_path = self.path
        start_epoch = self.load_checkpoint(resume_path)

        # Report the checkpoint that was actually used, not always self.path.
        print(f"Resuming from step {resume_path} with : (epoch {start_epoch})")

        # Epoch to start from if training is resumed later; advanced only after
        # an epoch has fully completed.
        next_epoch = start_epoch

        try:
            for epoch in range(start_epoch, self.config.epochs):
                self.netG.train()
                pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{self.config.epochs}')
                for i, (blurred, target) in enumerate(pbar):
                    blurred = blurred.float().to(self.device)
                    target  = target.float().to(self.device)

                    # Train Generator
                    self.optimG.zero_grad()
                    outputs = self.netG(blurred)

                    loss = self.com_criterion(outputs, target)
                    loss.backward()

                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=1.0)

                    self.optimG.step()

                    # Metrics are computed on the first element of the model output.
                    metrics = self.metrics.calculate_metrics(outputs[0].detach(), target)
                    # Update progress bar
                    pbar.set_postfix({k: f'{v:.3f}' for k, v in metrics.items()})

                    if i % self.config.save_frequency == 0:
                        # Mid-epoch checkpoint: this epoch is unfinished, so a
                        # resume must restart it.
                        checkpoint_path = f"{self.log_dir}/checkpoint_batch.pt"
                        self.save_checkpoint(checkpoint_path, epoch)

                self.scheduler.step()

                # The epoch is complete (optimizer and scheduler both advanced),
                # so a resume must continue with the following epoch.
                next_epoch = epoch + 1

                # Save epoch checkpoint
                checkpoint_path = f"{self.log_dir}/checkpoint_global.pt"
                self.save_checkpoint(checkpoint_path, next_epoch)
                print(f"Saved epoch checkpoint to {checkpoint_path}")

                if epoch % self.config.val_frequency == 0:
                    # Validation
                    self.netG.eval()
                    val_loss = 0
                    with torch.no_grad():
                        for blurred, target in val_loader:
                            blurred, target = blurred.float().to(self.device), target.float().to(self.device)
                            outputs = self.netG(blurred)
                            val_loss += self.com_criterion(outputs, target).item()

                    print(f'Epoch: {epoch}, Validation Loss: {val_loss/len(val_loader):.4f}')

                if (epoch + 1) % 200 == 0:
                    # Periodically dump sample outputs for the first validation batch.
                    real_images, sharp_images = next(iter(val_loader))
                    real_images = real_images.to(self.device)
                    save_samples(self.netG, real_images, sharp_images, epoch+1, sample_dir=self.generated_dir)
                    del real_images
                    torch.cuda.empty_cache()

        except KeyboardInterrupt:
            print("Training interrupted by user")
            # Save interrupted checkpoint, labelled with the epoch to resume at.
            checkpoint_path = f"{self.log_dir}/checkpoint_interrupted.pt"
            self.save_checkpoint(checkpoint_path, next_epoch)
            print(f"Saved interrupt checkpoint to {checkpoint_path}")
        finally:
            # Save final checkpoint. Use the real progress rather than
            # config.epochs so a crashed or interrupted run is not recorded
            # as fully trained.
            checkpoint_path = f"{self.log_dir}/checkpoint_final.pt"
            self.save_checkpoint(checkpoint_path, next_epoch)
            print(f"Saved final checkpoint to {checkpoint_path}")

    def log_metrics(self, loss_dict, metrics, epoch, iteration):
        """Log the training loss (`loss_dict['total_g']`) and metrics to wandb."""
        # Log losses
        wandb.log({
            'train/total_loss': loss_dict['total_g'],
            'epoch': epoch,
            'iteration': iteration
        })

        # Log metrics
        wandb.log({
            'train/psnr': metrics['psnr'],
            'train/ssim': metrics['ssim'],
            'train/lpips': metrics['lpips'],
            'train/niqe': metrics['niqe']
        })

    def log_images(self, blurred, sharp, fake):
        """Log a blurred / sharp / deblurred image grid to wandb."""
        # Create image grid: one row per image kind, one column per sample.
        img_grid = make_grid(torch.cat([
            blurred, sharp, fake
        ], dim=0), nrow=sharp.size(0), normalize=True, value_range=(-1, 1))

        wandb.log({
            'images': wandb.Image(img_grid, caption='Blurred | Sharp | Deblurred')
        })

    def log_validation_metrics(self, metrics, epoch):
        """Log validation metrics for `epoch` to wandb."""
        wandb.log({
            'val/psnr': metrics['psnr'],
            'val/ssim': metrics['ssim'],
            'val/lpips': metrics['lpips'],
            'val/niqe': metrics['niqe'],
            'epoch': epoch
        })

    def save_checkpoint(self, path: str, epoch: int):
        """Write model, optimizer and scheduler state plus `epoch` to `path`.

        `epoch` is the value `load_checkpoint` returns, i.e. the epoch index at
        which a resumed run starts.
        """
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.netG.state_dict(),
            'optim_state_dict': self.optimG.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }, path)

    def load_checkpoint(self, path: str) -> int:
        """Restore state from `path` and return the epoch to start from.

        Returns 0 without touching any state when `config.resume` is false or
        `path` does not exist.
        """
        checkpoint = {}
        checkpoint['epoch'] = 0
        if self.config.resume and os.path.exists(path):
            # weights_only=False: the checkpoint also holds optimizer/scheduler
            # state. Only load checkpoints from a trusted source.
            checkpoint = torch.load(path, map_location=self.device, weights_only=False)
            self.netG.load_state_dict(checkpoint['model_state_dict'])
            self.optimG.load_state_dict(checkpoint['optim_state_dict'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        return checkpoint['epoch']

# Training

In [None]:
set_seed(42)

# Output locations used by this run (created up front if missing).
checkpoint_dir = './checkpoints'
results_dir = './results'
samples_dir = './samples'
generated_dir = './generated'
for out_dir in (samples_dir, results_dir, checkpoint_dir, generated_dir):
    os.makedirs(out_dir, exist_ok=True)


# Lightweight config object: the settings become class attributes of an
# anonymous `Config` type, read as `config.<name>`.
# `save_frequency` is compared against the batch index and `val_frequency`
# against the epoch index inside Trainer.train.
config = type('Config', (), dict(
    project_name='aft-net',
    name='aft-net',
    lr=2e-4,
    epochs=2500,
    batch_size=16,
    latent_dim=8,
    lambda_l1=10.0,
    lambda_kl=0.01,
    lambda_tv=0.1,
    lambda_adv=0.01,
    log_frequency=2500,
    val_frequency=500,
    save_frequency=500,
    key='',  # wandb key
    log_dir=checkpoint_dir,
    generated_dir=generated_dir,
    resume=True,
))()

# Dataset roots: left empty here and must be filled in before running.
train_dir = ''
val_dir   = ''

train_loader, val_loader = create_dataloaders(
    train_dir,
    val_dir,
    batch_size=config.batch_size,
    patch_size=256,
    num_workers=4,
)

plot_dataset(train_loader)

trainer = Trainer(config)

trainer.train(train_loader, val_loader)
# Keep a handle on the trained network for the inference cells.
generator = trainer.netG

# Inference

In [None]:
def load_image(image_path):
    """Load an image file as a batched 3-channel tensor.

    The image is converted to RGB so grayscale, palette and RGBA files all
    produce three channels, and the file handle is released as soon as the
    pixel data has been read.

    Args:
        image_path: Path of the image file to read.

    Returns:
        Tensor of shape (1, 3, H, W) produced by `transforms.ToTensor`.
    """
    # convert() reads the pixels and returns an independent image, so it stays
    # valid after the file is closed by the context manager.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image = transforms.ToTensor()(image)
    # Add the batch dimension.
    image = image.unsqueeze(0)
    return image


def infer(image_path, sharp_file, file_name, generator, out_name = '/kaggle/working/gen_results', device='cpu', timesteps=1000):
    """Run the generator on one image, save the input/output and plot the result.

    Args:
        image_path: Path of the input image fed to the generator.
        sharp_file: Path of the reference image shown next to the result.
        file_name: File name appended to the saved outputs.
        generator: Model to run; the first element of its output is used.
        out_name: String prefix for saved files. It is concatenated as-is, so
            include a trailing path separator when it names a directory.
        device: Device on which to run the model.
        timesteps: Unused; kept for interface compatibility.
    """
    generator.to(device)
    generator.eval()

    # Load the image
    image = load_image(image_path).to(device)
    print("Processing image: ", image_path)
    print(image.shape)
    # Inference only: never build an autograd graph, even when the caller did
    # not wrap this call in torch.no_grad().
    with torch.no_grad():
        sr_imgs = generator(image)[0]
    print(sr_imgs.shape)

    # Save the input and the generated image
    save_image_tensor(image, out_name+f'normal_{file_name}')
    save_image_tensor(sr_imgs, out_name+f'upsample_16_{file_name}')

    hr_image = load_image(sharp_file).to(device)
    print(hr_image.shape)

    # Metric computation is disabled; zeros are printed as placeholders.
    psnr, ssim = 0, 0 #calculate_metrics(hr_image, sr_imgs)
    print(f"PSNR: {psnr}, SSIM: {ssim}")
    # Display the input, the reference and the generated image side by side
    plt.subplot(1, 3, 1)
    plt.imshow(get_pil_image(image.detach().cpu().squeeze(0)))
    plt.title('Low Resolution Image')
    plt.axis('off')
    plt.subplot(1, 3, 2)
    plt.imshow(get_pil_image(hr_image.detach().cpu().squeeze(0)))
    plt.title('Original Image')
    plt.axis('off')
    plt.subplot(1, 3, 3)
    plt.imshow(get_pil_image(sr_imgs.detach().cpu().squeeze(0)))
    plt.title('UpScale Image')
    plt.axis('off')
    plt.tight_layout()
    plt.show()
    #plt.save(out_name+f'generated_{file_name}')
    plt.close()
    # Drop the tensors promptly so repeated calls do not accumulate memory.
    del hr_image, image, sr_imgs
    torch.cuda.empty_cache()



In [None]:
with torch.no_grad():
    data_path = '/kaggle/input/a-curated-list-of-image-deblurring-datasets/DBlur/Wider-Face/test/blur'
    sharp_dir = '/kaggle/input/a-curated-list-of-image-deblurring-datasets/DBlur/Wider-Face/test/sharp'
    results_dir = 'samples/'
    os.makedirs(results_dir, exist_ok=True)

    # Drop notebook artefacts from BOTH listings: blurred and sharp files are
    # paired by sorted position, so a stray entry in only one listing would
    # shift every following pair.
    list_of_files      = sorted(f for f in os.listdir(data_path) if f != '.ipynb_checkpoints')
    list_of_shap_files = sorted(f for f in os.listdir(sharp_dir) if f != '.ipynb_checkpoints')

    # Only the first 25 pairs are processed; zip() also stops at the shorter
    # listing instead of raising IndexError.
    max_images = 25
    image_pairs = list(zip(list_of_files, list_of_shap_files))[:max_images]
    for file, sharp_file in image_pairs:
        file_path  = data_path + '/' + file
        sharp_path = sharp_dir + '/' + sharp_file
        infer(file_path, sharp_path, file, generator, out_name=results_dir, device='cpu')