In [2]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
from pathlib import Path
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import math

In [3]:
# ==================== Restormer Components ====================

class LayerNorm(nn.Module):
    """Layer Normalization that supports two data formats: channels_last and channels_first."""
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class GDFN(nn.Module):
    """Gated-Dconv Feed-Forward Network"""
    def __init__(self, dim, ffn_expansion_factor=2.66, bias=False):
        super().__init__()
        hidden_features = int(dim * ffn_expansion_factor)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3,
                               stride=1, padding=1, groups=hidden_features * 2, bias=bias)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = nn.functional.gelu(x1) * x2
        x = self.project_out(x)
        return x


class MDTA(nn.Module):
    """Multi-Dconv Head Transposed Attention"""
    def __init__(self, dim, num_heads=8, bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1,
                                    padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        q = q.reshape(b, self.num_heads, -1, h * w)
        k = k.reshape(b, self.num_heads, -1, h * w)
        v = v.reshape(b, self.num_heads, -1, h * w)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)

        out = (attn @ v)
        out = out.reshape(b, -1, h, w)

        out = self.project_out(out)
        return out


class TransformerBlock(nn.Module):
    """Restormer Transformer Block"""
    def __init__(self, dim, num_heads=8, ffn_expansion_factor=2.66, bias=False):
        super().__init__()

        self.norm1 = LayerNorm(dim, data_format='channels_first')
        self.attn = MDTA(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, data_format='channels_first')
        self.ffn = GDFN(dim, ffn_expansion_factor, bias)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x


class OverlapPatchEmbed(nn.Module):
    """Overlapping Patch Embedding"""
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super().__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, x):
        x = self.proj(x)
        return x


class Downsample(nn.Module):
    """Downsampling layer"""
    def __init__(self, n_feat):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelUnshuffle(2)
        )

    def forward(self, x):
        return self.body(x)


class Upsample(nn.Module):
    """Upsampling layer"""
    def __init__(self, n_feat):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(n_feat, n_feat * 2, kernel_size=3, stride=1, padding=1, bias=False),
            nn.PixelShuffle(2)
        )

    def forward(self, x):
        return self.body(x)


class Restormer(nn.Module):
    """
    Restormer: Efficient Transformer for High-Resolution Image Restoration (CVPR 2022)
    Paper: https://arxiv.org/abs/2111.09881
    """
    def __init__(self,
                 inp_channels=3,
                 out_channels=3,
                 dim=48,
                 num_blocks=[4, 6, 6, 8],
                 num_refinement_blocks=4,
                 heads=[1, 2, 4, 8],
                 ffn_expansion_factor=2.66,
                 bias=False):
        super().__init__()

        self.patch_embed = OverlapPatchEmbed(inp_channels, dim)

        self.encoder_level1 = nn.Sequential(
            *[TransformerBlock(dim=dim, num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias) for i in range(num_blocks[0])])

        self.down1_2 = Downsample(dim)
        self.encoder_level2 = nn.Sequential(
            *[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias) for i in range(num_blocks[1])])

        self.down2_3 = Downsample(int(dim*2**1))
        self.encoder_level3 = nn.Sequential(
            *[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias) for i in range(num_blocks[2])])

        self.down3_4 = Downsample(int(dim*2**2))
        self.latent = nn.Sequential(
            *[TransformerBlock(dim=int(dim*2**3), num_heads=heads[3], ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias) for i in range(num_blocks[3])])

        self.up4_3 = Upsample(int(dim*2**3))
        self.reduce_chan_level3 = nn.Conv2d(int(dim*2**3), int(dim*2**2), kernel_size=1, bias=bias)
        self.decoder_level3 = nn.Sequential(
            *[TransformerBlock(dim=int(dim*2**2), num_heads=heads[2], ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias) for i in range(num_blocks[2])])

        self.up3_2 = Upsample(int(dim*2**2))
        self.reduce_chan_level2 = nn.Conv2d(int(dim*2**2), int(dim*2**1), kernel_size=1, bias=bias)
        self.decoder_level2 = nn.Sequential(
            *[TransformerBlock(dim=int(dim*2**1), num_heads=heads[1], ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias) for i in range(num_blocks[1])])

        self.up2_1 = Upsample(int(dim*2**1))
        self.decoder_level1 = nn.Sequential(
            *[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias) for i in range(num_blocks[0])])

        self.refinement = nn.Sequential(
            *[TransformerBlock(dim=int(dim*2**1), num_heads=heads[0], ffn_expansion_factor=ffn_expansion_factor,
                             bias=bias) for i in range(num_refinement_blocks)])

        self.output = nn.Conv2d(int(dim*2**1), out_channels, kernel_size=3, stride=1, padding=1, bias=bias)

    def forward(self, inp_img):
        inp_enc_level1 = self.patch_embed(inp_img)
        out_enc_level1 = self.encoder_level1(inp_enc_level1)

        inp_enc_level2 = self.down1_2(out_enc_level1)
        out_enc_level2 = self.encoder_level2(inp_enc_level2)

        inp_enc_level3 = self.down2_3(out_enc_level2)
        out_enc_level3 = self.encoder_level3(inp_enc_level3)

        inp_enc_level4 = self.down3_4(out_enc_level3)
        latent = self.latent(inp_enc_level4)

        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3)

        inp_dec_level2 = self.up3_2(out_dec_level3)
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
        out_dec_level2 = self.decoder_level2(inp_dec_level2)

        inp_dec_level1 = self.up2_1(out_dec_level2)
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
        out_dec_level1 = self.decoder_level1(inp_dec_level1)

        out_dec_level1 = self.refinement(out_dec_level1)

        out_dec_level1 = self.output(out_dec_level1) + inp_img

        return out_dec_level1


In [4]:
# ==================== Dataset ====================
class GoPRODataset(Dataset):
    def __init__(self, blur_dir, sharp_dir, transform=None):
        self.blur_dir = Path(blur_dir)
        self.sharp_dir = Path(sharp_dir)
        self.transform = transform

        blur_images = list(self.blur_dir.glob('*.png')) + list(self.blur_dir.glob('*.jpg'))
        sharp_images = list(self.sharp_dir.glob('*.png')) + list(self.sharp_dir.glob('*.jpg'))

        print(f"Found {len(blur_images)} blur images")
        print(f"Found {len(sharp_images)} sharp images")

        blur_dict = {img.name: img for img in blur_images}
        sharp_dict = {img.name: img for img in sharp_images}

        common_names = sorted(set(blur_dict.keys()) & set(sharp_dict.keys()))

        if len(common_names) == 0:
            raise ValueError("No matching image pairs found!")

        self.blur_images = [blur_dict[name] for name in common_names]
        self.sharp_images = [sharp_dict[name] for name in common_names]

        print(f"Using {len(self.blur_images)} matched image pairs")

    def __len__(self):
        return len(self.blur_images)

    def __getitem__(self, idx):
        blur_img = Image.open(self.blur_images[idx]).convert('RGB')
        sharp_img = Image.open(self.sharp_images[idx]).convert('RGB')

        if self.transform:
            blur_img = self.transform(blur_img)
            sharp_img = self.transform(sharp_img)

        return blur_img, sharp_img



In [5]:

# ==================== Loss Functions ====================
class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (L1 smooth loss)"""
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, pred, target):
        diff = pred - target
        loss = torch.mean(torch.sqrt(diff * diff + self.eps))
        return loss


class EdgeLoss(nn.Module):
    """Edge-aware loss"""
    def __init__(self):
        super().__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        self.kernel = torch.matmul(k.t(), k).unsqueeze(0).repeat(3, 1, 1, 1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        n_channels, _, kw, kh = self.kernel.shape
        img = nn.functional.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
        return nn.functional.conv2d(img, self.kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        filtered = self.conv_gauss(current)
        down = filtered[:, :, ::2, ::2]
        new_filter = torch.zeros_like(filtered)
        new_filter[:, :, ::2, ::2] = down * 4
        filtered = self.conv_gauss(new_filter)
        diff = current - filtered
        return diff

    def forward(self, pred, target):
        loss = self.loss(self.laplacian_kernel(pred), self.laplacian_kernel(target))
        return loss

In [6]:
# ==================== Training ====================
class RestormerTrainer:
    def __init__(self, train_loader, val_loader, device='cuda'):
        self.device = device
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Model
        self.model = Restormer(
            inp_channels=3,
            out_channels=3,
            dim=48,
            num_blocks=[4, 6, 6, 8],
            num_refinement_blocks=4,
            heads=[1, 2, 4, 8],
            ffn_expansion_factor=2.66,
            bias=False
        ).to(device)

        # Optimizer
        self.optimizer = optim.AdamW(self.model.parameters(), lr=3e-4, betas=(0.9, 0.999),
                                     weight_decay=1e-4)

        # Scheduler
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=50, eta_min=1e-6)

        # Loss functions
        self.char_loss = CharbonnierLoss()
        self.edge_loss = EdgeLoss()

        # Tracking
        self.history = {'train_loss': [], 'val_loss': [], 'val_psnr': []}

    def calculate_psnr(self, pred, target):
        """Calculate PSNR"""
        mse = torch.mean((pred - target) ** 2)
        if mse == 0:
            return 100
        psnr = 20 * torch.log10(1.0 / torch.sqrt(mse))
        return psnr.item()

    def train_epoch(self, epoch):
        self.model.train()
        epoch_loss = 0

        pbar = tqdm(self.train_loader, desc=f'Epoch {epoch+1}')
        for blur_imgs, sharp_imgs in pbar:
            blur_imgs = blur_imgs.to(self.device)
            sharp_imgs = sharp_imgs.to(self.device)

            self.optimizer.zero_grad()

            restored = self.model(blur_imgs)

            # Combined loss
            loss_char = self.char_loss(restored, sharp_imgs)
            loss_edge = self.edge_loss(restored, sharp_imgs)
            loss = loss_char + (0.05 * loss_edge)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 0.01)
            self.optimizer.step()

            epoch_loss += loss.item()
            pbar.set_postfix({'Loss': f'{loss.item():.4f}'})

        return epoch_loss / len(self.train_loader)

    def validate(self):
        self.model.eval()
        val_loss = 0
        psnrs = []

        with torch.no_grad():
            for blur_imgs, sharp_imgs in self.val_loader:
                blur_imgs = blur_imgs.to(self.device)
                sharp_imgs = sharp_imgs.to(self.device)

                restored = self.model(blur_imgs)

                loss = self.char_loss(restored, sharp_imgs)
                val_loss += loss.item()

                psnr = self.calculate_psnr(restored, sharp_imgs)
                psnrs.append(psnr)

        return val_loss / len(self.val_loader), np.mean(psnrs)

    def train(self, num_epochs=100, save_dir='checkpoints'):
        os.makedirs(save_dir, exist_ok=True)
        best_psnr = 0

        for epoch in range(num_epochs):
            train_loss = self.train_epoch(epoch)
            val_loss, val_psnr = self.validate()

            self.scheduler.step()

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['val_psnr'].append(val_psnr)

            print(f'Epoch {epoch+1}/{num_epochs}')
            print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val PSNR: {val_psnr:.2f} dB')
            print(f'LR: {self.optimizer.param_groups[0]["lr"]:.6f}')
            print('-' * 50)

            # Save best model
            if val_psnr > best_psnr:
                best_psnr = val_psnr
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'best_psnr': best_psnr,
                    'scheduler_state_dict': self.scheduler.state_dict()
                }, f'{save_dir}/best_model.pth')
                print(f'âœ… Saved best model with PSNR: {best_psnr:.2f} dB')

            # Save checkpoint every 10 epochs
            if (epoch + 1) % 10 == 0:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'psnr': val_psnr
                }, f'{save_dir}/checkpoint_epoch_{epoch+1}.pth')

        self.plot_history(save_dir)

    def plot_history(self, save_dir):
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))

        axes[0].plot(self.history['train_loss'], label='Train')
        axes[0].plot(self.history['val_loss'], label='Val')
        axes[0].set_title('Loss')
        axes[0].set_xlabel('Epoch')
        axes[0].set_ylabel('Loss')
        axes[0].legend()

        axes[1].plot(self.history['val_psnr'])
        axes[1].set_title('Validation PSNR')
        axes[1].set_xlabel('Epoch')
        axes[1].set_ylabel('PSNR (dB)')

        axes[2].plot([self.optimizer.param_groups[0]['lr'] for _ in range(len(self.history['train_loss']))])
        axes[2].set_title('Learning Rate')
        axes[2].set_xlabel('Epoch')
        axes[2].set_ylabel('LR')

        plt.tight_layout()
        plt.savefig(f'{save_dir}/training_history.png')
        plt.close()


In [1]:
# ==================== Main ====================
def main():
    # Mount Google Drive (for Colab)
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        print("Google Drive mounted successfully!")
    except:
        print("Not running in Colab or Drive already mounted")

    # Configuration
    DATASET_PATH = "/content/drive/MyDrive/CV/gopro_deblur"
    BLUR_DIR = f'{DATASET_PATH}/blur/images'
    SHARP_DIR = f'{DATASET_PATH}/sharp/images'
    BATCH_SIZE = 4  # Restormer is memory intensive
    NUM_EPOCHS = 5
    IMG_SIZE = 256

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    if torch.cuda.is_available():
        print(f'GPU: {torch.cuda.get_device_name(0)}')
        print(f'GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')

    # Transforms
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor()
    ])

    # Dataset
    dataset = GoPRODataset(BLUR_DIR, SHARP_DIR, transform=transform)

    # Split dataset
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size]
    )

    print(f'Train size: {len(train_dataset)}, Val size: {len(val_dataset)}')

    # DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                            shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE,
                          shuffle=False, num_workers=4, pin_memory=True)

    # Save checkpoints to Drive
    CHECKPOINT_DIR = '/content/drive/MyDrive/CV/restormer_checkpoints'

    # Train
    trainer = RestormerTrainer(train_loader, val_loader, device=device)

    # Print model size
    total_params = sum(p.numel() for p in trainer.model.parameters())
    print(f'Total parameters: {total_params / 1e6:.2f}M')

    trainer.train(num_epochs=NUM_EPOCHS, save_dir=CHECKPOINT_DIR)

    print('Training completed!')
    print(f'Checkpoints saved to: {CHECKPOINT_DIR}')


if __name__ == '__main__':
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted successfully!


NameError: name 'torch' is not defined