In [1]:
"""
WaveFusion-Net: Dual-Branch Image Deblurring with Wavelet-Spatial Fusion
HIDE Dataset Version (under 12h)
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import numpy as np
from tqdm import tqdm
from torch import amp
import matplotlib.pyplot as plt
import random

# ======================== WAVELET TRANSFORMS ========================
class DWT(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('ll', torch.tensor([[0.5, 0.5], [0.5, 0.5]]).view(1, 1, 2, 2))
        self.register_buffer('lh', torch.tensor([[0.5, 0.5], [-0.5, -0.5]]).view(1, 1, 2, 2))
        self.register_buffer('hl', torch.tensor([[0.5, -0.5], [0.5, -0.5]]).view(1, 1, 2, 2))
        self.register_buffer('hh', torch.tensor([[0.5, -0.5], [-0.5, 0.5]]).view(1, 1, 2, 2))

    def forward(self, x):
        B, C, H, W = x.shape
        if H % 2 != 0:
            x = F.pad(x, (0, 0, 0, 1))
        if W % 2 != 0:
            x = F.pad(x, (0, 1, 0, 0))
        x = x.contiguous().view(B * C, 1, x.shape[2], x.shape[3])
        ll = F.conv2d(x, self.ll, stride=2)
        lh = F.conv2d(x, self.lh, stride=2)
        hl = F.conv2d(x, self.hl, stride=2)
        hh = F.conv2d(x, self.hh, stride=2)
        return (
            ll.view(B, C, ll.shape[2], ll.shape[3]),
            lh.view(B, C, lh.shape[2], lh.shape[3]),
            hl.view(B, C, hl.shape[2], hl.shape[3]),
            hh.view(B, C, hh.shape[2], hh.shape[3]),
        )

class IDWT(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('ll', torch.tensor([[0.5, 0.5], [0.5, 0.5]]).view(1, 1, 2, 2))
        self.register_buffer('lh', torch.tensor([[0.5, 0.5], [-0.5, -0.5]]).view(1, 1, 2, 2))
        self.register_buffer('hl', torch.tensor([[0.5, -0.5], [0.5, -0.5]]).view(1, 1, 2, 2))
        self.register_buffer('hh', torch.tensor([[0.5, -0.5], [-0.5, 0.5]]).view(1, 1, 2, 2))

    def forward(self, ll, lh, hl, hh):
        B, C, H, W = ll.shape
        ll = ll.view(B * C, 1, H, W)
        lh = lh.view(B * C, 1, H, W)
        hl = hl.view(B * C, 1, H, W)
        hh = hh.view(B * C, 1, H, W)
        ll = F.conv_transpose2d(ll, self.ll, stride=2)
        lh = F.conv_transpose2d(lh, self.lh, stride=2)
        hl = F.conv_transpose2d(hl, self.hl, stride=2)
        hh = F.conv_transpose2d(hh, self.hh, stride=2)
        out = ll + lh + hl + hh
        return out.view(B, C, out.shape[2], out.shape[3])

# ======================== ATTENTION MODULES ========================
class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2

class StripAttention(nn.Module):
    def __init__(self, channels, strip_size=7):
        super().__init__()
        self.h_conv = nn.Conv2d(channels, channels, (1, strip_size), padding=(0, strip_size // 2), groups=channels)
        self.v_conv = nn.Conv2d(channels, channels, (strip_size, 1), padding=(strip_size // 2, 0), groups=channels)
        self.proj = nn.Conv2d(channels * 2, channels, 1)

    def forward(self, x):
        B, C, H, W = x.shape
        h_out = self.h_conv(x)[:, :, :H, :W]
        v_out = self.v_conv(x)[:, :, :H, :W]
        h_attn = torch.sigmoid(h_out)
        v_attn = torch.sigmoid(v_out)
        return self.proj(torch.cat([x * h_attn, x * v_attn], dim=1))

class SCA(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        return x * self.fc(self.gap(x))

# ======================== CORE BLOCKS ========================
class NAFBlock(nn.Module):
    def __init__(self, channels, dw_expand=2):
        super().__init__()
        dw_channels = channels * dw_expand
        self.conv1 = nn.Conv2d(channels, dw_channels, 1)
        self.conv2 = nn.Conv2d(dw_channels, dw_channels, 3, padding=1, groups=dw_channels)
        self.conv3 = nn.Conv2d(dw_channels // 2, channels, 1)
        self.sca = SCA(dw_channels // 2)
        self.sg = SimpleGate()
        self.norm = nn.LayerNorm(channels)

    def forward(self, x):
        residual = x
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = self.sca(x)
        x = self.conv3(x)
        return x + residual

class WaveletBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.ll_conv = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(channels, channels, 3, padding=1)
        )
        self.hf_conv = nn.Sequential(
            nn.Conv2d(channels * 3, channels * 3, 3, padding=1, groups=3),
            nn.GELU(),
            nn.Conv2d(channels * 3, channels * 3, 3, padding=1, groups=3)
        )
        self.hf_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels * 3, channels * 3, 1),
            nn.Sigmoid()
        )

    def forward(self, ll, lh, hl, hh):
        ll_out = ll + self.ll_conv(ll)
        hf = torch.cat([lh, hl, hh], dim=1)
        hf_feat = self.hf_conv(hf)
        hf_attn = self.hf_attn(hf_feat)
        hf_out = hf + hf_feat * hf_attn
        lh_out, hl_out, hh_out = hf_out.chunk(3, dim=1)
        return ll_out, lh_out, hl_out, hh_out

class CrossBranchFusion(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.spatial_proj = nn.Conv2d(channels, channels, 1)
        self.wavelet_proj = nn.Conv2d(channels, channels, 1)
        self.gate = nn.Sequential(
            nn.Conv2d(channels * 2, channels, 1),
            nn.Sigmoid()
        )
        self.out = nn.Conv2d(channels, channels, 1)

    def forward(self, spatial_feat, wavelet_feat):
        if spatial_feat.shape[2:] != wavelet_feat.shape[2:]:
            wavelet_feat = F.interpolate(wavelet_feat, size=spatial_feat.shape[2:], mode='bilinear', align_corners=False)
        s = self.spatial_proj(spatial_feat)
        w = self.wavelet_proj(wavelet_feat)
        gate = self.gate(torch.cat([s, w], dim=1))
        fused = gate * s + (1 - gate) * w
        return self.out(fused)

# ======================== MAIN NETWORK ========================
class WaveFusionNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3, base_channels=48, num_blocks=[4, 6, 6, 4]):
        super().__init__()
        self.dwt = DWT()
        self.idwt = IDWT()
        self.intro = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        self.enc1 = nn.Sequential(*[NAFBlock(base_channels) for _ in range(num_blocks[0])])
        self.down1 = nn.Conv2d(base_channels, base_channels * 2, 2, stride=2)
        self.enc2 = nn.Sequential(*[NAFBlock(base_channels * 2) for _ in range(num_blocks[1])])
        self.down2 = nn.Conv2d(base_channels * 2, base_channels * 4, 2, stride=2)
        self.enc3 = nn.Sequential(*[NAFBlock(base_channels * 4) for _ in range(num_blocks[2])])
        self.down3 = nn.Conv2d(base_channels * 4, base_channels * 8, 2, stride=2)
        self.wav_intro = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        self.wav_block1 = WaveletBlock(base_channels)
        self.wav_proj1 = nn.Conv2d(base_channels, base_channels * 2, 1)
        self.wav_block2 = WaveletBlock(base_channels * 2)
        self.wav_proj2 = nn.Conv2d(base_channels * 2, base_channels * 4, 1)
        self.wav_block3 = WaveletBlock(base_channels * 4)
        self.fusion1 = CrossBranchFusion(base_channels * 2)
        self.fusion2 = CrossBranchFusion(base_channels * 4)
        self.bottleneck = nn.Sequential(
            NAFBlock(base_channels * 8),
            StripAttention(base_channels * 8),
            NAFBlock(base_channels * 8),
            NAFBlock(base_channels * 8),
        )
        self.up3 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, 2, stride=2)
        self.dec3 = nn.Sequential(*[NAFBlock(base_channels * 4) for _ in range(num_blocks[2])])
        self.up2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 2, stride=2)
        self.dec2 = nn.Sequential(*[NAFBlock(base_channels * 2) for _ in range(num_blocks[1])])
        self.up1 = nn.ConvTranspose2d(base_channels * 2, base_channels, 2, stride=2)
        self.dec1 = nn.Sequential(*[NAFBlock(base_channels) for _ in range(num_blocks[0])])
        self.refine = nn.Sequential(
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
            nn.GELU(),
            nn.Conv2d(base_channels, base_channels, 3, padding=1),
        )
        self.outro = nn.Conv2d(base_channels, out_channels, 3, padding=1)

    def forward(self, x):
        B, C, H, W = x.shape
        pad_h = (8 - H % 8) % 8
        pad_w = (8 - W % 8) % 8
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
        f0 = self.intro(x)
        f1 = self.enc1(f0)
        f1_down = self.down1(f1)
        f2 = self.enc2(f1_down)
        f2_down = self.down2(f2)
        f3 = self.enc3(f2_down)
        f3_down = self.down3(f3)
        w0 = self.wav_intro(x)
        ll1, lh1, hl1, hh1 = self.dwt(w0)
        ll1, lh1, hl1, hh1 = self.wav_block1(ll1, lh1, hl1, hh1)
        w1 = self.wav_proj1(ll1)
        ll2, lh2, hl2, hh2 = self.dwt(w1)
        ll2, lh2, hl2, hh2 = self.wav_block2(ll2, lh2, hl2, hh2)
        w2 = self.wav_proj2(ll2)
        ll3, lh3, hl3, hh3 = self.dwt(w2)
        ll3, lh3, hl3, hh3 = self.wav_block3(ll3, lh3, hl3, hh3)
        f2_fused = self.fusion1(f2, w1)
        f3_fused = self.fusion2(f3, w2)
        bottleneck_out = self.bottleneck(f3_down)
        d3 = self.up3(bottleneck_out) + f3_fused
        d3 = self.dec3(d3)
        d2 = self.up2(d3) + f2_fused
        d2 = self.dec2(d2)
        d1 = self.up1(d2) + f1
        d1 = self.dec1(d1)
        out = self.refine(d1)
        out = out + f0
        out = self.outro(out)
        out = out + x
        if pad_h > 0 or pad_w > 0:
            out = out[:, :, :H, :W]
        return out

# ======================== LOSS FUNCTIONS ========================
class VGGPerceptualLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features
        self.slice1 = nn.Sequential(*list(vgg.children())[:4])
        self.slice2 = nn.Sequential(*list(vgg.children())[4:9])
        self.slice3 = nn.Sequential(*list(vgg.children())[9:18])
        for p in self.parameters():
            p.requires_grad = False
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, pred, target):
        pred = (pred - self.mean) / self.std
        target = (target - self.mean) / self.std
        pf1 = self.slice1(pred)
        pf2 = self.slice2(pf1)
        pf3 = self.slice3(pf2)
        with torch.no_grad():
            tf1 = self.slice1(target)
            tf2 = self.slice2(tf1)
            tf3 = self.slice3(tf2)
        return F.l1_loss(pf1, tf1) + F.l1_loss(pf2, tf2) + F.l1_loss(pf3, tf3)

class FFTLoss(nn.Module):
    def forward(self, pred, target):
        pred_fft = torch.fft.rfft2(pred)
        target_fft = torch.fft.rfft2(target)
        return F.l1_loss(pred_fft.real, target_fft.real) + F.l1_loss(pred_fft.imag, target_fft.imag)

class GradientLoss(nn.Module):
    def __init__(self):
        super().__init__()
        sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
        sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
        self.register_buffer('sobel_x', sobel_x.view(1, 1, 3, 3).repeat(3, 1, 1, 1))
        self.register_buffer('sobel_y', sobel_y.view(1, 1, 3, 3).repeat(3, 1, 1, 1))

    def forward(self, pred, target):
        pgx = F.conv2d(pred, self.sobel_x, padding=1, groups=3)
        pgy = F.conv2d(pred, self.sobel_y, padding=1, groups=3)
        tgx = F.conv2d(target, self.sobel_x, padding=1, groups=3)
        tgy = F.conv2d(target, self.sobel_y, padding=1, groups=3)
        return F.l1_loss(pgx, tgx) + F.l1_loss(pgy, tgy)

class WaveletHFLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.dwt = DWT()

    def forward(self, pred, target):
        _, plh, phl, phh = self.dwt(pred)
        _, tlh, thl, thh = self.dwt(target)
        return F.l1_loss(plh, tlh) + F.l1_loss(phl, thl) + F.l1_loss(phh, thh)

class CombinedLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.L1Loss()
        self.vgg = VGGPerceptualLoss()
        self.fft = FFTLoss()
        self.gradient = GradientLoss()
        self.wavelet_hf = WaveletHFLoss()
        self.w_l1 = 1.0
        self.w_vgg = 0.1
        self.w_fft = 0.05
        self.w_gradient = 0.1
        self.w_wavelet = 0.02

    def forward(self, pred, target):
        l1_loss = self.l1(pred, target)
        vgg_loss = self.vgg(pred, target)
        fft_loss = self.fft(pred, target)
        gradient_loss = self.gradient(pred, target)
        wavelet_loss = self.wavelet_hf(pred, target)
        total = (self.w_l1 * l1_loss + self.w_vgg * vgg_loss + self.w_fft * fft_loss +
                 self.w_gradient * gradient_loss + self.w_wavelet * wavelet_loss)
        return total, {
            'l1': l1_loss.item(),
            'vgg': vgg_loss.item(),
            'fft': fft_loss.item(),
            'gradient': gradient_loss.item(),
            'wavelet': wavelet_loss.item(),
        }

# ======================== DATASET ========================
class HIDEPairs(Dataset):
    """HIDE dataset loader. Supports blur/GT folders or train/test txt lists."""
    def __init__(self, root_dir, split='train', patch_size=256):
        self.patch_size = patch_size
        self.split = split
        self.blur_images = []
        self.sharp_images = []

        def resolve(path_str):
            # Honor absolute paths from txt; otherwise treat as relative to root_dir
            if os.path.isabs(path_str) or path_str.startswith('/'):
                return path_str
            return os.path.join(root_dir, path_str)

        list_path = os.path.join(root_dir, f"{split}.txt")
        used_txt = False
        if os.path.exists(list_path):
            used_txt = True
            with open(list_path, 'r') as f:
                for line in f:
                    parts = line.strip().split()
                    if len(parts) >= 2:
                        blur_path = resolve(parts[0])
                        gt_path = resolve(parts[1])
                        if os.path.exists(blur_path) and os.path.exists(gt_path):
                            self.blur_images.append(blur_path)
                            self.sharp_images.append(gt_path)
                    elif len(parts) == 1 and parts[0]:
                        blur_path = resolve(parts[0])
                        fname = os.path.basename(blur_path)
                        candidates = [
                            os.path.join(root_dir, 'GT', fname),
                            os.path.join(root_dir, 'gt', fname),
                            os.path.join(root_dir, split, 'GT', fname),
                            os.path.join(root_dir, split, 'gt', fname),
                        ]
                        gt_path = next((p for p in candidates if os.path.exists(p)), None)
                        if os.path.exists(blur_path) and gt_path:
                            self.blur_images.append(blur_path)
                            self.sharp_images.append(gt_path)
        if not used_txt or (used_txt and len(self.blur_images) == 0):
            blur_dir_candidates = [
                os.path.join(root_dir, split, 'blur'),
                os.path.join(root_dir, split, 'input'),
                os.path.join(root_dir, split),
                os.path.join(root_dir, 'blur'),
                os.path.join(root_dir, 'input'),
                os.path.join(root_dir, split, 'test'),
                os.path.join(root_dir, 'test'),
            ]
            gt_dir_candidates = [
                os.path.join(root_dir, split, 'gt'),
                os.path.join(root_dir, split, 'GT'),
                os.path.join(root_dir, 'gt'),
                os.path.join(root_dir, 'GT'),
            ]
            gt_dir = next((p for p in gt_dir_candidates if os.path.exists(p)), None)

            # Scan recursively for blur files
            extensions = ('.png', '.jpg', '.jpeg')
            for blur_root in blur_dir_candidates:
                if not os.path.exists(blur_root):
                    continue
                for dirpath, _, filenames in os.walk(blur_root):
                    # skip GT folders
                    if 'GT' in os.path.normpath(dirpath).split(os.sep):
                        continue
                    for name in filenames:
                        if not name.lower().endswith(extensions):
                            continue
                        blur_path = os.path.join(dirpath, name)
                        if gt_dir:
                            gt_path = os.path.join(gt_dir, name)
                        else:
                            gt_path = None
                            for candidate in [
                                os.path.join(root_dir, 'GT', name),
                                os.path.join(root_dir, 'gt', name),
                            ]:
                                if os.path.exists(candidate):
                                    gt_path = candidate
                                    break
                        if gt_path and os.path.exists(gt_path):
                            self.blur_images.append(blur_path)
                            self.sharp_images.append(gt_path)
        print(f"Found {len(self.blur_images)} {split} pairs in HIDE")
        if len(self.blur_images) == 0:
            raise ValueError(f"No {split} pairs found in HIDE. Check train/test.txt or folder structure.")
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.blur_images)

    def __getitem__(self, idx):
        blur = Image.open(self.blur_images[idx]).convert('RGB')
        sharp = Image.open(self.sharp_images[idx]).convert('RGB')
        blur = self.to_tensor(blur)
        sharp = self.to_tensor(sharp)
        if self.split == 'train' and self.patch_size:
            _, h, w = blur.shape
            if h >= self.patch_size and w >= self.patch_size:
                top = np.random.randint(0, h - self.patch_size + 1)
                left = np.random.randint(0, w - self.patch_size + 1)
                blur = blur[:, top:top + self.patch_size, left:left + self.patch_size]
                sharp = sharp[:, top:top + self.patch_size, left:left + self.patch_size]
            if np.random.random() > 0.5:
                blur = torch.flip(blur, [2])
                sharp = torch.flip(sharp, [2])
            if np.random.random() > 0.5:
                blur = torch.flip(blur, [1])
                sharp = torch.flip(sharp, [1])
        return blur, sharp

# ======================== METRICS ========================
def calculate_psnr(pred, target):
    mse = F.mse_loss(pred, target)
    if mse == 0:
        return float('inf')
    return 10 * torch.log10(1.0 / mse)

# ======================== VISUALIZATION ========================
def visualize_results(model, test_loader, device, save_dir, num_samples=8):
    model.eval()
    results_dir = os.path.join(save_dir, 'results_hide')
    os.makedirs(results_dir, exist_ok=True)
    total = len(test_loader.dataset)
    indices = random.sample(range(total), min(num_samples, total))
    stats = []
    with torch.no_grad():
        for idx, sample_idx in enumerate(indices):
            blur, sharp = test_loader.dataset[sample_idx]
            blur = blur.unsqueeze(0).to(device)
            sharp = sharp.unsqueeze(0).to(device)
            with amp.autocast('cuda'):
                pred = model(blur)
            pred = torch.clamp(pred, 0, 1)
            psnr = calculate_psnr(pred, sharp).item()
            stats.append(psnr)
            comparison = torch.cat([blur.cpu(), pred.cpu(), sharp.cpu()], dim=0)
            grid = torch.cat([comparison[i] for i in range(3)], dim=2)
            img = transforms.ToPILImage()(grid)
            img.save(os.path.join(results_dir, f'sample_{idx}_psnr{psnr:.2f}.png'))
    print(f"Saved {len(indices)} comparisons to {results_dir}; mean PSNR={np.mean(stats):.2f}dB")

# ======================== TRAINING ========================
def train():
    config = {
        'data_root': '/kaggle/input/hideblur/HIDE_dataset',  # adjust to your path
        'batch_size': 4,
        'patch_size': 256,
        'epochs': 30,
        'lr': 2e-4,
        'min_lr': 1e-7,
        'num_workers': 4,
        'base_channels': 48,
        'num_blocks': [4, 6, 6, 4],
        'save_dir': '/kaggle/working',
    }

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    # Build model
    model = WaveFusionNet(base_channels=config['base_channels'], num_blocks=config['num_blocks'])

    # Print detailed architecture table
    print("\n" + "="*80)
    print("MODEL ARCHITECTURE: WaveFusion-Net (HIDE)")
    print("="*80)

    total_params = 0
    trainable_params = 0
    print(f"{'Module':<40} {'Parameters':<15} {'Shape':<25}")
    print("-"*80)
    for name, param in model.named_parameters():
        params = param.numel()
        total_params += params
        if param.requires_grad:
            trainable_params += params
        short_name = name.replace('module.', '')
        print(f"{short_name:<40} {params:>12,} {str(list(param.shape)):<25}")
    print("-"*80)
    print(f"{'Total Parameters':<40} {total_params:>12,}")
    print(f"{'Trainable Parameters':<40} {trainable_params:>12,}")
    print(f"{'Total (Millions)':<40} {total_params/1e6:>12.2f}M")
    print("="*80)

    # Architecture summary
    print("\nARCHITECTURE SUMMARY:")
    print(f"  - Base Channels: {config['base_channels']}")
    print(f"  - Encoder Blocks: {config['num_blocks']}")
    print(f"  - Dual-Branch: Spatial (NAFBlocks) + Wavelet (DWT)")
    print(f"  - Fusion: Cross-Branch Gated Fusion at 2 levels")
    print(f"  - Bottleneck: Strip Attention")
    print("="*80 + "\n")

    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs")
        model = nn.DataParallel(model)
    model = model.to(device)

    criterion = CombinedLoss().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'], weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
    scaler = amp.GradScaler('cuda')

    train_dataset = HIDEPairs(config['data_root'], split='train', patch_size=config['patch_size'])
    test_dataset = HIDEPairs(config['data_root'], split='test', patch_size=None)
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=config['num_workers'], pin_memory=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

    # Baseline PSNR check (fewer samples for speed)
    print("\n=== Baseline Check ===")
    with torch.no_grad():
        baseline_psnrs = []
        for i, (blur, sharp) in enumerate(test_loader):
            if i >= 5:
                break
            baseline_psnrs.append(calculate_psnr(blur, sharp).item())
        if len(baseline_psnrs) > 0:
            print(f"Baseline PSNR (blur vs sharp): {np.mean(baseline_psnrs):.2f} dB")
        else:
            print("Baseline PSNR: no test samples found.")

    best_psnr = 0
    print("\n=== Starting Training ===")
    for epoch in range(config['epochs']):
        model.train()
        epoch_loss = 0
        loss_components = {'l1': 0, 'vgg': 0, 'fft': 0, 'gradient': 0, 'wavelet': 0}
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config['epochs']}")
        for blur, sharp in pbar:
            blur = blur.to(device)
            sharp = sharp.to(device)
            optimizer.zero_grad()
            with amp.autocast('cuda'):
                pred = model(blur)
                loss, comps = criterion(pred, sharp)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
            for k, v in comps.items():
                loss_components[k] += v
            pbar.set_postfix({'loss': f"{loss.item():.4f}", 'lr': f"{optimizer.param_groups[0]['lr']:.2e}"})
        scheduler.step()
        n_batches = len(train_loader)
        epoch_loss /= n_batches
        for k in loss_components:
            loss_components[k] /= n_batches

        if (epoch + 1) % 10 == 0 or epoch == 0:
            model.eval()
            val_psnr = 0
            with torch.no_grad():
                for blur, sharp in tqdm(test_loader, desc="Validating"):
                    blur = blur.to(device)
                    sharp = sharp.to(device)
                    with amp.autocast('cuda'):
                        pred = model(blur)
                    pred = torch.clamp(pred, 0, 1)
                    val_psnr += calculate_psnr(pred, sharp).item()
            val_psnr /= len(test_loader)
            print(f"Epoch {epoch+1}: Loss={epoch_loss:.4f}, PSNR={val_psnr:.2f}dB")
            print(f"  Components - L1:{loss_components['l1']:.4f}, VGG:{loss_components['vgg']:.4f}, FFT:{loss_components['fft']:.4f}, Grad:{loss_components['gradient']:.4f}, Wav:{loss_components['wavelet']:.4f}")
            if val_psnr > best_psnr:
                best_psnr = val_psnr
                save_path = os.path.join(config['save_dir'], 'best_model_hide.pth')
                state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
                torch.save({'epoch': epoch + 1, 'model_state_dict': state_dict, 'optimizer_state_dict': optimizer.state_dict(), 'psnr': val_psnr}, save_path)
                print(f"  *** New best model saved! PSNR: {val_psnr:.2f}dB ***")

        if (epoch + 1) % 20 == 0:
            save_path = os.path.join(config['save_dir'], f'checkpoint_hide_epoch{epoch+1}.pth')
            state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
            torch.save({'epoch': epoch + 1, 'model_state_dict': state_dict, 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict()}, save_path)

    print(f"\n=== Training Complete ===\nBest PSNR: {best_psnr:.2f}dB")
    if best_psnr > 0:
        best_model_path = os.path.join(config['save_dir'], 'best_model_hide.pth')
        if os.path.exists(best_model_path):
            checkpoint = torch.load(best_model_path)
            if hasattr(model, 'module'):
                model.module.load_state_dict(checkpoint['model_state_dict'])
            else:
                model.load_state_dict(checkpoint['model_state_dict'])
            visualize_results(model, test_loader, device, config['save_dir'], num_samples=8)

if __name__ == '__main__':
    train()


Using device: cuda

MODEL ARCHITECTURE: WaveFusion-Net (HIDE)
Module                                   Parameters      Shape                    
--------------------------------------------------------------------------------
intro.weight                                    1,296 [48, 3, 3, 3]            
intro.bias                                         48 [48]                     
enc1.0.conv1.weight                             4,608 [96, 48, 1, 1]           
enc1.0.conv1.bias                                  96 [96]                     
enc1.0.conv2.weight                               864 [96, 1, 3, 3]            
enc1.0.conv2.bias                                  96 [96]                     
enc1.0.conv3.weight                             2,304 [48, 48, 1, 1]           
enc1.0.conv3.bias                                  48 [48]                     
enc1.0.sca.fc.weight                            2,304 [48, 48, 1, 1]           
enc1.0.sca.fc.bias                                 48 

100%|██████████| 548M/548M [00:02<00:00, 239MB/s]  


Found 8422 train pairs in HIDE
Found 4050 test pairs in HIDE

=== Baseline Check ===
Baseline PSNR (blur vs sharp): 23.43 dB

=== Starting Training ===


Epoch 1/30: 100%|██████████| 2105/2105 [14:17<00:00,  2.45it/s, loss=1.2794, lr=2.00e-04]
Validating: 100%|██████████| 4050/4050 [36:17<00:00,  1.86it/s]


Epoch 1: Loss=0.7106, PSNR=24.73dB
  Components - L1:0.0420, VGG:1.6224, FFT:9.6900, Grad:0.2109, Wav:0.0401
  *** New best model saved! PSNR: 24.73dB ***


Epoch 2/30: 100%|██████████| 2105/2105 [13:57<00:00,  2.51it/s, loss=0.6357, lr=1.99e-04]
Epoch 3/30: 100%|██████████| 2105/2105 [13:56<00:00,  2.52it/s, loss=0.4432, lr=1.98e-04]
Epoch 4/30: 100%|██████████| 2105/2105 [13:58<00:00,  2.51it/s, loss=0.5784, lr=1.95e-04]
Epoch 5/30: 100%|██████████| 2105/2105 [13:59<00:00,  2.51it/s, loss=0.7588, lr=1.91e-04]
Epoch 6/30: 100%|██████████| 2105/2105 [14:05<00:00,  2.49it/s, loss=1.1327, lr=1.87e-04]
Epoch 7/30: 100%|██████████| 2105/2105 [14:06<00:00,  2.49it/s, loss=0.4192, lr=1.81e-04]
Epoch 8/30: 100%|██████████| 2105/2105 [14:04<00:00,  2.49it/s, loss=0.5984, lr=1.74e-04]
Epoch 9/30: 100%|██████████| 2105/2105 [14:01<00:00,  2.50it/s, loss=0.6778, lr=1.67e-04]
Epoch 10/30: 100%|██████████| 2105/2105 [14:03<00:00,  2.50it/s, loss=0.5509, lr=1.59e-04]
Validating: 100%|██████████| 4050/4050 [36:19<00:00,  1.86it/s]


Epoch 10: Loss=0.6026, PSNR=27.50dB
  Components - L1:0.0293, VGG:1.3996, FFT:8.3014, Grad:0.1756, Wav:0.0355
  *** New best model saved! PSNR: 27.50dB ***


Epoch 11/30: 100%|██████████| 2105/2105 [14:03<00:00,  2.50it/s, loss=0.7635, lr=1.50e-04]
Epoch 12/30: 100%|██████████| 2105/2105 [13:59<00:00,  2.51it/s, loss=0.6433, lr=1.41e-04]
Epoch 13/30: 100%|██████████| 2105/2105 [14:01<00:00,  2.50it/s, loss=0.5943, lr=1.31e-04]
Epoch 14/30: 100%|██████████| 2105/2105 [13:59<00:00,  2.51it/s, loss=0.5081, lr=1.21e-04]
Epoch 15/30: 100%|██████████| 2105/2105 [14:01<00:00,  2.50it/s, loss=0.4639, lr=1.10e-04]
Epoch 16/30: 100%|██████████| 2105/2105 [14:00<00:00,  2.50it/s, loss=0.5775, lr=1.00e-04]
Epoch 17/30: 100%|██████████| 2105/2105 [14:02<00:00,  2.50it/s, loss=0.5522, lr=8.96e-05]
Epoch 18/30: 100%|██████████| 2105/2105 [13:58<00:00,  2.51it/s, loss=0.3833, lr=7.93e-05]
Epoch 19/30: 100%|██████████| 2105/2105 [14:00<00:00,  2.50it/s, loss=0.7211, lr=6.92e-05]
Epoch 20/30: 100%|██████████| 2105/2105 [13:57<00:00,  2.51it/s, loss=0.4699, lr=5.94e-05]
Validating: 100%|██████████| 4050/4050 [36:18<00:00,  1.86it/s]


Epoch 20: Loss=0.5612, PSNR=28.33dB
  Components - L1:0.0260, VGG:1.3225, FFT:7.7253, Grad:0.1605, Wav:0.0334
  *** New best model saved! PSNR: 28.33dB ***


Epoch 21/30: 100%|██████████| 2105/2105 [13:56<00:00,  2.52it/s, loss=0.4877, lr=5.01e-05]
Epoch 22/30: 100%|██████████| 2105/2105 [13:55<00:00,  2.52it/s, loss=0.5226, lr=4.13e-05]
Epoch 23/30: 100%|██████████| 2105/2105 [13:56<00:00,  2.52it/s, loss=0.4784, lr=3.32e-05]
Epoch 24/30: 100%|██████████| 2105/2105 [13:56<00:00,  2.52it/s, loss=0.8744, lr=2.58e-05]
Epoch 25/30: 100%|██████████| 2105/2105 [13:58<00:00,  2.51it/s, loss=0.3664, lr=1.92e-05]
Epoch 26/30: 100%|██████████| 2105/2105 [13:59<00:00,  2.51it/s, loss=0.4916, lr=1.35e-05]
Epoch 27/30: 100%|██████████| 2105/2105 [13:59<00:00,  2.51it/s, loss=0.9436, lr=8.74e-06]
Epoch 28/30: 100%|██████████| 2105/2105 [13:59<00:00,  2.51it/s, loss=0.4650, lr=4.99e-06]
Epoch 29/30: 100%|██████████| 2105/2105 [14:01<00:00,  2.50it/s, loss=0.5247, lr=2.28e-06]
Epoch 30/30: 100%|██████████| 2105/2105 [14:03<00:00,  2.50it/s, loss=0.7571, lr=6.48e-07]
Validating: 100%|██████████| 4050/4050 [36:19<00:00,  1.86it/s]


Epoch 30: Loss=0.5545, PSNR=28.61dB
  Components - L1:0.0253, VGG:1.3141, FFT:7.6247, Grad:0.1589, Wav:0.0334
  *** New best model saved! PSNR: 28.61dB ***

=== Training Complete ===
Best PSNR: 28.61dB
Saved 8 comparisons to /kaggle/working/results_hide; mean PSNR=28.35dB
