In [None]:
!pip install lmdb
!pip install torchviz
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import os
import cv2
import glob
from torchvision import models
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.nn.utils import spectral_norm
import kornia
from kornia.losses import ssim as kornia_ssim
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim_sk
from torch.utils.data.sampler import RandomSampler


# Turn on cuDNN benchmark mode (optional optimization).
torch.backends.cudnn.benchmark = True

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

import os
import glob
import cv2
import random
import time
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
import lmdb

############################################
# 1. File-Based Dataset for SIDD Training/Validation Patches
############################################

class SIDDFileDataset(Dataset):
    """
    Paired noisy / ground-truth SIDD patches read from individual PNG files.

    Both directories are globbed for ``*.png`` and each listing is sorted;
    sample ``i`` pairs the i-th noisy path with the i-th ground-truth path.
    The two folders are therefore expected to use matching filenames
    (e.g. '00000001.png' in both). Only the file counts are checked.

    Args:
        noisy_dir: Directory holding the noisy PNG patches.
        gt_dir: Directory holding the ground-truth PNG patches.
        transform: Optional callable ``(noisy, gt) -> (noisy, gt)`` applied to
            the tensor pair before it is returned.

    Raises:
        ValueError: if the two directories contain a different number of PNGs.
    """
    def __init__(self, noisy_dir, gt_dir, transform=None):
        self.noisy_dir = noisy_dir
        self.gt_dir = gt_dir
        self.transform = transform

        # Sorted listings so that equal positions refer to the same patch.
        png_pattern = '*.png'
        self.noisy_files = sorted(glob.glob(os.path.join(noisy_dir, png_pattern)))
        self.gt_files = sorted(glob.glob(os.path.join(gt_dir, png_pattern)))

        if len(self.noisy_files) != len(self.gt_files):
            raise ValueError("Mismatch between number of noisy and GT images. "
                             f"Found {len(self.noisy_files)} noisy and {len(self.gt_files)} GT images.")

    def __len__(self):
        return len(self.noisy_files)

    @staticmethod
    def _bgr_to_chw_tensor(bgr_img):
        """Convert an OpenCV BGR HWC array to an RGB CHW float tensor in [0, 1]."""
        rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
        return torch.from_numpy(np.transpose(rgb_img, (2, 0, 1))).float() / 255.0

    def __getitem__(self, idx):
        """Return the ``(noisy, gt)`` tensor pair at position ``idx``."""
        noisy_file, gt_file = self.noisy_files[idx], self.gt_files[idx]

        # cv2.imread signals failure by returning None rather than raising.
        noisy_bgr = cv2.imread(noisy_file, cv2.IMREAD_COLOR)
        gt_bgr = cv2.imread(gt_file, cv2.IMREAD_COLOR)
        if noisy_bgr is None or gt_bgr is None:
            raise RuntimeError(f"Failed to read images: {noisy_file} or {gt_file}")

        noisy_img = self._bgr_to_chw_tensor(noisy_bgr)
        gt_img = self._bgr_to_chw_tensor(gt_bgr)

        if self.transform is None:
            return noisy_img, gt_img

        noisy_img, gt_img = self.transform(noisy_img, gt_img)
        return noisy_img, gt_img

############################################
# 2. Custom Sampler: Randomly Subsample a Fixed Number of Patches Each Epoch
############################################

class SubsetRandomSampler(Sampler):
    """
    Sampler that draws a fresh random subset of dataset indices every epoch.

    Each call to ``__iter__`` yields ``min(num_samples, len(data_source))``
    distinct indices in random order, and ``__len__`` reports that same number
    so the DataLoader's batch count matches what is actually produced.

    Args:
        data_source: Sized collection (typically a Dataset) to draw indices from.
        num_samples (int): Number of indices requested per epoch.
        seed (int, optional): Seed for the sampler's private RNG. With the
            default ``None`` the RNG is seeded by Python from a system source,
            so every run draws different subsets. With a fixed seed the
            sequence of per-epoch subsets is reproducible.

    Raises:
        ValueError: if ``num_samples`` is negative.
    """
    def __init__(self, data_source, num_samples, seed=None):
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples}")
        self.data_source = data_source
        self.num_samples = num_samples
        # A private RNG keeps sampling from touching the global `random` state,
        # so seeds set elsewhere in the program are not overwritten each epoch.
        self._rng = random.Random(seed)

    def __iter__(self):
        # random.sample picks distinct indices without building and shuffling
        # the full index list.
        population = range(len(self.data_source))
        return iter(self._rng.sample(population, len(self)))

    def __len__(self):
        # Never promise more samples than the dataset can provide.
        return min(self.num_samples, len(self.data_source))

############################################
# 3. Worker Initialization Function
############################################

def worker_init_fn(worker_id):
    """
    Seed NumPy and Python's ``random`` inside a DataLoader worker.

    The seed is derived from ``torch.initial_seed()`` instead of the wall
    clock. Inside a DataLoader worker PyTorch has already given each worker
    its own torch seed, so deriving from it keeps the workers distinct and
    makes runs reproducible whenever the main process seeds torch. PyTorch's
    own RNG is left as the DataLoader configured it.

    Args:
        worker_id (int): Index of the worker. Unused here because the
            per-worker torch seed already differs between workers; kept for
            the ``worker_init_fn`` calling convention.
    """
    # NumPy only accepts 32-bit seeds, so fold the 64-bit torch seed down.
    seed = torch.initial_seed() % 2**32
    np.random.seed(seed)
    random.seed(seed)

############################################
# 4. Example Usage for Training and Validation
############################################

# Loader parameters shared by the training and validation DataLoaders below.
batch_size = 8                # samples per batch
num_samples_per_epoch = 320   # training patches drawn per epoch by the custom sampler

if __name__ == '__main__':
    # -------------------------------
    # Training
    # -------------------------------
    # Directories holding the noisy inputs and their ground-truth counterparts.
    train_noisy_dir = '/kaggle/input/sagluuuu/train_input'
    train_gt_dir = '/kaggle/input/sagluuuu/train_gt'
    
    # File-based dataset pairing noisy and ground-truth PNGs by sorted position.
    train_dataset = SIDDFileDataset(train_noisy_dir, train_gt_dir)
    
    # Sampler that limits each epoch to num_samples_per_epoch random patches.
    sampler = SubsetRandomSampler(train_dataset, num_samples_per_epoch)
    
    # `sampler` and `shuffle` are mutually exclusive in DataLoader, so shuffling
    # is delegated entirely to the custom sampler.
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,          # Module-level batch size defined above.
        sampler=sampler,
        num_workers=4,                  # Number of loader worker processes.
        pin_memory=True,
        persistent_workers=True,        # Keep workers alive between epochs.
        worker_init_fn=worker_init_fn
    )
    
    # Smoke test: fetch a single training batch and report its shapes.
    print("TRAINING EPOCH:")
    for batch_idx, (input_imgs, gt_imgs) in enumerate(train_loader):
        print(f"Train Batch {batch_idx}: Input shape {input_imgs.shape}, GT shape {gt_imgs.shape}")
        # Only the first batch is needed for this demonstration.
        break

    # -------------------------------
    # Validation
    # -------------------------------
    # Validation data is read from a pair of LMDB databases rather than PNG files.
    class SIDDValLMDB(Dataset):
        """
        Paired noisy / ground-truth validation patches stored in two LMDBs.

        The key list is taken from the input LMDB (skipping the b'length'
        entry) and sorted; the same key is then looked up in both databases.
        Each value is decoded with cv2.imdecode, converted BGR -> RGB, and
        returned as a CHW float tensor scaled by 1/255.

        Args:
            input_lmdb_path: Path of the LMDB holding the noisy patches.
            gt_lmdb_path: Path of the LMDB holding the ground-truth patches.
            transform: Optional callable ``(input, gt) -> (input, gt)``.
        """
        def __init__(self, input_lmdb_path, gt_lmdb_path, transform=None):
            # NOTE(review): both environments are opened here, in the process
            # that builds the dataset, and are later used from the DataLoader
            # workers (num_workers=4 below). Confirm that sharing an LMDB
            # environment with worker processes is safe on this platform.
            self.input_env = lmdb.open(input_lmdb_path, readonly=True, lock=False, readahead=False, meminit=False)
            self.gt_env = lmdb.open(gt_lmdb_path, readonly=True, lock=False, readahead=False, meminit=False)
            self.transform = transform
            # Collect every key except the b'length' bookkeeping entry.
            with self.input_env.begin() as txn:
                self.keys = [key for key, _ in txn.cursor() if key != b'length']
            self.keys = sorted(self.keys)
            self.length = len(self.keys)

        def __len__(self):
            return self.length

        def __getitem__(self, idx):
            # The same key addresses the noisy patch and its ground truth.
            key = self.keys[idx]
            with self.input_env.begin() as txn:
                input_bytes = txn.get(key)
            with self.gt_env.begin() as txn:
                gt_bytes = txn.get(key)
            # txn.get returns None for a missing key.
            if input_bytes is None or gt_bytes is None:
                raise KeyError(f"Key {key} not found in one of the LMDBs.")
            # Decode the stored bytes into BGR image arrays.
            input_img = cv2.imdecode(np.frombuffer(input_bytes, np.uint8), cv2.IMREAD_COLOR)
            gt_img = cv2.imdecode(np.frombuffer(gt_bytes, np.uint8), cv2.IMREAD_COLOR)
            # cv2.imdecode signals failure by returning None rather than raising.
            if input_img is None or gt_img is None:
                raise RuntimeError(f"Failed to decode images for key {key}.")
            # Convert from BGR (OpenCV order) to RGB.
            input_img = cv2.cvtColor(input_img, cv2.COLOR_BGR2RGB)
            gt_img = cv2.cvtColor(gt_img, cv2.COLOR_BGR2RGB)
            # HWC -> CHW, cast to float and divide by 255.
            input_img = torch.from_numpy(np.transpose(input_img, (2, 0, 1))).float() / 255.0
            gt_img = torch.from_numpy(np.transpose(gt_img, (2, 0, 1))).float() / 255.0
            if self.transform is not None:
                input_img, gt_img = self.transform(input_img, gt_img)
            return input_img, gt_img

    # Locations of the validation LMDBs (noisy crops and ground-truth crops).
    val_input_lmdb = '/kaggle/input/smartphone-image-denoising-dataset/SIDD-val-lmdb/SIDD/val/input_crops.lmdb'
    val_gt_lmdb = '/kaggle/input/smartphone-image-denoising-dataset/SIDD-val-lmdb/SIDD/val/gt_crops.lmdb'
    val_dataset = SIDDValLMDB(val_input_lmdb, val_gt_lmdb)
    
    # Validation loader: fixed order (no shuffling), same batch size as training.
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        persistent_workers=True,
        worker_init_fn=worker_init_fn
    )
    
    # Smoke test: fetch a single validation batch and report its shapes.
    print("\nVALIDATION BATCH:")
    for input_imgs, gt_imgs in val_loader:
        print("Validation batch shapes:", input_imgs.shape, gt_imgs.shape)
        break

# Generator


In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# --------------------------------------------------------------------------
# 1) Utility Functions: Splitting and Merging Patches
# --------------------------------------------------------------------------
def split_into_patches(img, patch_size, overlap=0):
    """
    Splits a batch of images into overlapping or non-overlapping patches.

    Patches are taken on a regular grid with stride ``patch_size - overlap``,
    row by row (top to bottom, left to right within a row). Patches that run
    past the bottom or right edge are padded by edge replication so that every
    patch has the same size.

    Args:
        img (torch.Tensor): shape (B, C, H, W).
        patch_size (int): the spatial size of each patch.
        overlap (int): overlap in pixels for adjacent patches. Must be smaller
                       than ``patch_size``. If > 0, patches overlap and must be
                       blended when merged.

    Returns:
        (torch.Tensor):
            Patches of shape (B, N, C, patch_size, patch_size),
            where N = #patches per image (horiz_patches * vert_patches).

    Raises:
        ValueError: if ``overlap >= patch_size`` (the stride would not be positive).
    """
    B, C, H, W = img.shape
    stride = patch_size - overlap
    if stride <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than patch_size ({patch_size})"
        )

    vertical_count = math.ceil((H - overlap) / stride)
    horizontal_count = math.ceil((W - overlap) / stride)

    # The grid is identical for every image, so slice the whole batch at once
    # instead of looping over images.
    patches = []
    for v in range(vertical_count):
        top = v * stride
        bottom = min(top + patch_size, H)
        pad_bottom = patch_size - (bottom - top)
        for h in range(horizontal_count):
            left = h * stride
            right = min(left + patch_size, W)
            pad_right = patch_size - (right - left)

            patch = img[:, :, top:bottom, left:right]
            # Border patches may be smaller; pad so every patch is the same size.
            if pad_bottom > 0 or pad_right > 0:
                patch = F.pad(patch, (0, pad_right, 0, pad_bottom), mode='replicate')
            patches.append(patch)

    # Each entry is (B, C, patch_size, patch_size); stack along a new patch axis.
    return torch.stack(patches, dim=1)


def merge_from_patches(patches, out_shape, patch_size, overlap=0):
    """
    Merges patches (potentially overlapping) back into a single image.

    Patches are expected in row-major grid order with stride
    ``patch_size - overlap``. Pixels covered by several patches are averaged
    with equal weights. Any padding on border patches is cropped away before
    accumulation.

    Args:
        patches (torch.Tensor): shape (B, N, C, patch_size, patch_size).
            Only the first ``rows * cols`` patches along N are used.
        out_shape (tuple): the original image shape (B, C, H, W).
        patch_size (int): patch dimension used in splitting.
        overlap (int): overlap in pixels. Must be smaller than ``patch_size``.

    Returns:
        (torch.Tensor):
            A tensor of shape (B, C, H, W) with overlapped regions blended.
            The dtype is float32, or float64 when ``patches`` is float64.

    Raises:
        ValueError: if ``overlap >= patch_size`` or if fewer patches are
            supplied than the grid for ``out_shape`` requires.
    """
    B, N, C, _, _ = patches.shape
    _, _, H, W = out_shape
    stride = patch_size - overlap
    if stride <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than patch_size ({patch_size})"
        )

    vertical_count = math.ceil((H - overlap) / stride)
    horizontal_count = math.ceil((W - overlap) / stride)
    n_patches = vertical_count * horizontal_count
    if N < n_patches:
        raise ValueError(
            f"Expected at least {n_patches} patches per image for output size "
            f"{H}x{W}, got {N}."
        )

    # Accumulate in at least float32, but do not downcast float64 patches.
    acc_dtype = torch.promote_types(patches.dtype, torch.float32)
    accumulator = torch.zeros((B, C, H, W), dtype=acc_dtype, device=patches.device)
    # Coverage counts depend only on the grid, so a single (1, 1, H, W) map is
    # shared across the whole batch and all channels.
    weight_map = torch.zeros((1, 1, H, W), dtype=acc_dtype, device=patches.device)

    patch_idx = 0
    for v in range(vertical_count):
        top = v * stride
        # Border patches were padded when split; keep only the valid part.
        patch_height = min(patch_size, H - top)
        for h in range(horizontal_count):
            left = h * stride
            patch_width = min(patch_size, W - left)

            accumulator[:, :, top:top + patch_height, left:left + patch_width] += \
                patches[:, patch_idx, :, :patch_height, :patch_width]
            weight_map[:, :, top:top + patch_height, left:left + patch_width] += 1.0
            patch_idx += 1

    # Average overlapped zones; guard uncovered pixels against division by zero.
    weight_map.clamp_(min=1e-9)
    return accumulator / weight_map


# --------------------------------------------------------------------------
# 2) CNN or Small Denoiser Blocks
# --------------------------------------------------------------------------
class ResidualBlock(nn.Module):
    """
    Two stacked 3x3 convolutions wrapped in an identity shortcut.

    Computes ``ReLU(conv2(ReLU(conv1(x))) + x)``. Both convolutions keep the
    channel count and, through ``padding=1``, the spatial size.
    """
    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = self.activation(hidden)
        residual = self.conv2(hidden)
        # Add the shortcut before the final non-linearity.
        return self.activation(residual + x)


class MiniPatchDenoiser(nn.Module):
    """
    Small residual CNN used at the finest fractal level.

    Pipeline: 3x3 conv + ReLU into ``hidden_channels``, two ResidualBlocks,
    then a 3x3 conv back to ``in_channels``. The result is added to the input,
    so the network learns a correction rather than the full image.

    Args:
        in_channels: Number of channels of the input (and output) tensor.
        hidden_channels: Width of the internal feature maps.
    """
    def __init__(self, in_channels, hidden_channels=32):
        super().__init__()
        self.initial = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1)
        self.res1 = ResidualBlock(hidden_channels)
        self.res2 = ResidualBlock(hidden_channels)
        self.final = nn.Conv2d(hidden_channels, in_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # x: (B, C, H, W)
        features = F.relu(self.initial(x))
        for res_block in (self.res1, self.res2):
            features = res_block(features)
        correction = self.final(features)
        # Residual output: input plus the predicted correction.
        return x + correction


# --------------------------------------------------------------------------
# 3) Recursive "FractalBlock" - a building block that calls an optional child
# --------------------------------------------------------------------------
class FractalDenoiseBlock(nn.Module):
    """
    One level of the fractal denoiser.

    Behaviour depends on the child module:
      - child is another FractalDenoiseBlock: the input is split into patches,
        every patch is denoised by the child, the patches are merged back and
        the merged image is refined by ``local_cnn``.
      - otherwise (by default a MiniPatchDenoiser created here): the child is
        applied directly to the whole input. No splitting happens and
        ``local_cnn`` is not applied.

    Args:
        patch_size: Patch edge length used when splitting at this level.
        overlap: Overlap in pixels between neighbouring patches.
        in_channels: Channels of the input/output images.
        hidden_channels: Width of the refinement CNN (and of the default child).
        child_block: Module for the next finer level. ``None`` selects a
            MiniPatchDenoiser.
    """
    def __init__(
        self,
        patch_size,
        overlap,
        in_channels,
        hidden_channels,
        child_block=None
    ):
        super().__init__()
        self.patch_size, self.overlap = patch_size, overlap
        self.in_channels, self.hidden_channels = in_channels, hidden_channels

        # Two-layer refinement CNN applied to the merged image at this scale.
        self.local_cnn = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, in_channels, kernel_size=3, padding=1),
        )

        # Without an explicit child this is the finest level: use the mini denoiser.
        self.child_block = (
            MiniPatchDenoiser(in_channels, hidden_channels)
            if child_block is None
            else child_block
        )

    def forward(self, x):
        """
        Denoise ``x`` of shape (B, C, H, W) at this fractal level.

        Recurses over patches when the child is a FractalDenoiseBlock;
        otherwise hands the whole tensor to the child.
        """
        # Finest level: the child works directly on the full input.
        if not isinstance(self.child_block, FractalDenoiseBlock):
            return self.child_block(x)

        # Split into a (B, N, C, Hp, Wp) stack of patches.
        tiles = split_into_patches(
            x, patch_size=self.patch_size, overlap=self.overlap
        )
        batch, count, channels, tile_h, tile_w = tiles.shape

        # Fold the patch axis into the batch axis so the child sees plain images.
        flat_tiles = tiles.view(batch * count, channels, tile_h, tile_w)
        refined = self.child_block(flat_tiles)
        refined = refined.view(batch, count, channels, tile_h, tile_w)

        # Reassemble the image and refine it at this scale.
        merged = merge_from_patches(
            refined, x.shape, self.patch_size, self.overlap
        )
        return self.local_cnn(merged)


# --------------------------------------------------------------------------
# 4) Top-level "FractalDenoiseGenerator" Class
# --------------------------------------------------------------------------
class FractalDenoiseGenerator(nn.Module):
    """
    Recursively builds fractal-based denoising modules for multi-scale patch refinement.
    Similar to "FractalGen" in structure:
      - We define multiple fractal levels
      - Each level calls the next recursively, until the final level, which is
        a mini patch denoiser applied directly to the patches it receives
    """
    def __init__(
        self,
        patch_sizes,
        overlap,
        in_channels=3,
        hidden_channels=32
    ):
        """
        Args:
            patch_sizes (list): e.g. [16, 4, 1], from coarse to fine patches.
                Must not be empty. The last entry belongs to the innermost
                block, which applies its mini denoiser without splitting.
            overlap (int): overlap in pixels, shared by every level
            in_channels (int): e.g. 3 for RGB
            hidden_channels (int): internal channels for local refinements

        Raises:
            ValueError: if ``patch_sizes`` is empty.
        """
        super().__init__()
        patch_sizes = list(patch_sizes)
        if not patch_sizes:
            # Otherwise root_block would be None and forward() would fail
            # later with an unrelated-looking TypeError.
            raise ValueError("patch_sizes must contain at least one patch size.")

        # Build from the finest level outwards: the first block gets no child
        # (and so creates its own mini denoiser), every later block wraps the
        # one built before it.
        child = None
        for psize in reversed(patch_sizes):
            child = FractalDenoiseBlock(
                patch_size=psize,
                overlap=overlap,
                in_channels=in_channels,
                hidden_channels=hidden_channels,
                child_block=child
            )

        # The last block built corresponds to the largest (first) patch size
        # and is the entry point of the recursion.
        self.root_block = child

    def forward(self, noisy_img):
        """
        Orchestrates fractal recursion from large patch scale down to final pixel-level cleanup.

        Args:
            noisy_img: shape (B, C, H, W)

        Returns:
            denoised_img: shape (B, C, H, W)
        """
        return self.root_block(noisy_img)

# Discriminator

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, List, Optional, Dict, Union

# Configuration
CONFIG = dict(
    IMAGE_SIZE=256,              # Input image size
    NUM_CHANNELS_IN=3,           # Number of input channels (RGB)
    NUM_CHANNELS_OUT=3,          # Number of output channels (same as input for denoising)
    INIT_FEATURES=64,            # Initial number of features
    DEPTH_LEVELS=5,              # Number of downsampling/upsampling steps
    EXPANSION_LEVEL=3,           # Fractal block recursion depth
    ACTIVATION='leaky_relu',     # Activation function - leaky_relu is better for GANs
    USE_BATCHNORM=True,          # Use batch normalization
)


class FractalConvBlock(nn.Module):
    """
    Fractal Convolutional Block implementing equations (2) and (3) from the FractalSpiNet paper.

    Structure by ``expansion_level``:
      - level 1 (base case): conv -> optional BatchNorm -> activation.
      - level c > 1: two level-(c-1) blocks applied in sequence, plus a 1x1
        convolution of the block input added to the result as a skip path.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
        expansion_level (int): Fractal recursion depth, must be >= 1
        use_bn (bool): Whether to use batch normalization. When enabled, the
            base-case convolution is created without a bias term.
        activation (str): Activation function to use ('relu' or 'leaky_relu')
        kernel_size (int): Kernel size for convolutions (padding is kernel_size // 2)

    Raises:
        ValueError: if ``expansion_level`` is below 1 or ``activation`` is not supported.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        expansion_level: int = 1,
        use_bn: bool = True,
        activation: str = 'relu',
        kernel_size: int = 3
    ):
        super().__init__()
        # Without this check a level below 1 never reaches the base case and
        # the constructor recurses until Python raises RecursionError.
        if expansion_level < 1:
            raise ValueError(f"expansion_level must be >= 1, got {expansion_level}")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.expansion_level = expansion_level
        self.use_bn = use_bn
        self.activation = activation
        self.kernel_size = kernel_size

        # Define activation function
        if activation == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif activation == 'leaky_relu':
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        if expansion_level == 1:
            # Base case, eq. (2): a single convolution. The bias is dropped
            # when BatchNorm follows.
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bias=not use_bn
            )
            if use_bn:
                self.bn = nn.BatchNorm2d(out_channels)
        else:
            # Recursive case, eq. (3).
            # First stage maps in_channels -> out_channels.
            self.branch1 = FractalConvBlock(
                in_channels,
                out_channels,
                expansion_level - 1,
                use_bn,
                activation,
                kernel_size
            )
            # Second stage consumes the first stage's output.
            self.branch2 = FractalConvBlock(
                out_channels,
                out_channels,
                expansion_level - 1,
                use_bn,
                activation,
                kernel_size
            )
            # 1x1 projection so the skip path matches out_channels.
            self.skip_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the fractal block."""
        # Base case: single convolution (eq. 2)
        if self.expansion_level == 1:
            out = self.conv(x)
            if self.use_bn:
                out = self.bn(out)
            return self.act(out)

        # Recursive case (eq. 3): chain the two sub-blocks, then add the
        # projected input as a residual.
        branch1_out = self.branch1(x)
        branch2_out = self.branch2(branch1_out)
        skip = self.skip_conv(x)
        return branch2_out + skip



# Multiscale PatchGAN Discriminator with Fractal Blocks
class Discriminator(nn.Module):
    """
    Conditional multiscale PatchGAN discriminator built from fractal blocks.

    The condition image and the candidate image are concatenated along the
    channel axis and passed through a strided stem convolution. Each of the
    ``depth_levels`` stages then doubles the feature width with a
    FractalConvBlock, halves the resolution with average pooling, and emits a
    one-channel patch score map from the pooled features.

    Args:
        in_channels (int): Number of channels of each input image
        init_features (int): Feature channels produced by the stem convolution
        depth_levels (int): Number of downsampling stages (and of score maps)
        expansion_level (int): Fractal recursion depth of every stage
        use_bn (bool): Whether the fractal blocks use batch normalization
        activation (str): Activation used inside the fractal blocks
    """
    def __init__(
        self,
        in_channels: int = CONFIG['NUM_CHANNELS_IN'],
        init_features: int = 64,
        depth_levels: int = 3,
        expansion_level: int = 2,  # Fractal recursion depth
        use_bn: bool = True,
        activation: str = 'leaky_relu'
    ):
        super().__init__()

        # Stem: the input holds both images, hence twice the per-image channels.
        self.initial_conv = nn.Sequential(
            nn.Conv2d(in_channels * 2, init_features, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # One entry per scale, each with a 'down' path and an 'output' head.
        self.scales = nn.ModuleList()
        width = init_features
        for _ in range(depth_levels):
            doubled = width * 2
            stage = nn.ModuleDict({
                # Feature extraction followed by 2x spatial downsampling.
                'down': nn.Sequential(
                    FractalConvBlock(
                        width,
                        doubled,
                        expansion_level=expansion_level,
                        use_bn=use_bn,
                        activation=activation,
                        kernel_size=3
                    ),
                    nn.AvgPool2d(kernel_size=2, stride=2),
                ),
                # PatchGAN head: one score per spatial location.
                'output': nn.Conv2d(doubled, 1, kernel_size=4, stride=1, padding=1),
            })
            self.scales.append(stage)
            width = doubled

    def forward(self, condition_img: torch.Tensor, target_img: torch.Tensor) -> List[torch.Tensor]:
        """
        Score an image pair at every scale.

        Args:
            condition_img: The conditional input image (noisy image)
            target_img: The target image (real clean or generated fake)

        Returns:
            List with one score map per scale, ordered from the first
            (highest-resolution) stage to the last.
        """
        # Stack condition and candidate along the channel dimension.
        features = self.initial_conv(torch.cat((condition_img, target_img), dim=1))

        outputs = []
        for stage in self.scales:
            # Each stage feeds its pooled features to the next one.
            features = stage['down'](features)
            outputs.append(stage['output'](features))
        return outputs

# Loss and utils

In [None]:
# ============================
#  Visualization Utility
# ============================
def show_images_from_batches(generator, dataloader, device, num_images=5):
    """
    Plot noisy / denoised / clean triplets for the first ``num_images`` samples.

    One figure row is drawn per sample, with the noisy input, the generator
    output and the clean reference side by side. The generator is switched to
    eval mode for inference and its previous train/eval mode is restored
    before returning, so calling this in the middle of training does not leave
    the model in eval mode.

    Args:
        generator: Model mapping a noisy batch (B, C, H, W) to a denoised batch.
        dataloader: Iterable yielding ``(noisy_imgs, clean_imgs)`` batches.
        device: Device the noisy batch is moved to before inference.
        num_images (int): Number of samples (figure rows) to display.
    """
    def _to_display(batch):
        # (B, C, H, W) -> (B, H, W, C), clipped to the displayable [0, 1] range.
        return np.clip(batch.transpose(0, 2, 3, 1), 0, 1)

    column_labels = ("Noisy", "Denoised", "Clean")

    was_training = generator.training
    generator.eval()
    try:
        images_shown = 0
        plt.figure(figsize=(15, 5 * num_images), dpi=300)
        for noisy_imgs, clean_imgs in dataloader:
            if images_shown >= num_images:
                break
            noisy_imgs = noisy_imgs.to(device)
            with torch.no_grad():
                denoised = _to_display(generator(noisy_imgs).cpu().numpy())
            noisy = _to_display(noisy_imgs.cpu().numpy())
            clean = _to_display(clean_imgs.cpu().numpy())

            for panels in zip(noisy, denoised, clean):
                if images_shown >= num_images:
                    break
                for column, (label, image) in enumerate(zip(column_labels, panels), start=1):
                    plt.subplot(num_images, 3, 3 * images_shown + column)
                    plt.imshow(image)
                    plt.title(f"{label} Image {images_shown+1}")
                    plt.axis("off")
                images_shown += 1
        plt.tight_layout()
        plt.show()
    finally:
        # Hand the model back in the mode the caller had it in.
        generator.train(was_training)


# =====================
#  Checkpoint Utilities
# =====================
def save_best_checkpoint(generator, critic, optimizer_G, optimizer_D,
                         epoch, loss_history, best_ssim, best_psnr, filepath):
    """
    Write a checkpoint containing both models, both optimizers and the metrics.

    The file stores the epoch, the four ``state_dict``s, the loss history and
    the best SSIM / PSNR values, and is written with ``torch.save``.

    Args:
        generator, critic: Models whose ``state_dict()`` is saved.
        optimizer_G, optimizer_D: Optimizers whose ``state_dict()`` is saved.
        epoch: Epoch number recorded in the checkpoint.
        loss_history: Loss history object stored as-is.
        best_ssim, best_psnr: Best metric values recorded with the checkpoint.
        filepath: Destination passed to ``torch.save``.
    """
    checkpoint = {}
    checkpoint['epoch'] = epoch
    checkpoint['generator_state_dict'] = generator.state_dict()
    checkpoint['critic_state_dict'] = critic.state_dict()
    checkpoint['optimizer_G_state_dict'] = optimizer_G.state_dict()
    checkpoint['optimizer_D_state_dict'] = optimizer_D.state_dict()
    checkpoint['loss_history'] = loss_history
    checkpoint['best_ssim'] = best_ssim
    checkpoint['best_psnr'] = best_psnr

    torch.save(checkpoint, filepath)

    message = (f"Best checkpoint saved at '{filepath}' "
               f"with SSIM: {best_ssim:.4f}, PSNR: {best_psnr:.2f}")
    print(message)


def load_checkpoint(generator, critic, optimizer_G, optimizer_D, filepath, device):
    """
    Restore both models and both optimizers in place from a checkpoint file.

    Args:
        generator, critic: Models that receive their saved ``state_dict``.
        optimizer_G, optimizer_D: Optimizers that receive their saved state.
        filepath: Checkpoint file read with ``torch.load``.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        tuple: ``(epoch, loss_history)`` as stored in the checkpoint.
    """
    checkpoint = torch.load(filepath, map_location=device)

    # Restore in a fixed order: models first, then their optimizers.
    restore_plan = (
        (generator, 'generator_state_dict'),
        (critic, 'critic_state_dict'),
        (optimizer_G, 'optimizer_G_state_dict'),
        (optimizer_D, 'optimizer_D_state_dict'),
    )
    for target, key in restore_plan:
        target.load_state_dict(checkpoint[key])

    epoch, loss_history = checkpoint['epoch'], checkpoint['loss_history']
    print(f"Checkpoint loaded from '{filepath}' (Epoch {epoch})")
    return epoch, loss_history
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

# ======================
# MS-SSIM Loss
# ======================
class MS_SSIM_Loss(nn.Module):
    """
    Multi-Scale Structural Similarity (MS-SSIM) loss for 3-channel images.

    Implementation focusing only on structural similarity without L1 component.
    Every colour channel is blurred with one Gaussian per entry of
    ``gaussian_sigmas``; the loss is ``1 - lM * prod(cs)`` averaged over the
    batch and spatial positions.

    Args:
        gaussian_sigmas: Standard deviations of the Gaussian windows, one per
            scale. The last entry determines the kernel size and padding.
        data_range: Value range of the inputs (used for the C1/C2 constants)
        K: Pair (K1, K2) of SSIM stabilisation constants
        compensation: Scalar multiplier applied to the final loss
        cuda_dev: CUDA device index used to build the filters when CUDA is
            available. The filters are a registered buffer, so a later
            ``.to(device)`` moves them with the module.
    """
    def __init__(self, gaussian_sigmas=(0.5, 1.0, 2.0, 4.0, 8.0),
                 data_range=1.0,
                 K=(0.01, 0.03),
                 compensation=1.0,
                 cuda_dev=0):
        super(MS_SSIM_Loss, self).__init__()

        # Immutable default + local copy: never share a list between instances
        gaussian_sigmas = tuple(gaussian_sigmas)

        # Device the filters are initially created on (kept for compatibility)
        self.device = torch.device(f'cuda:{cuda_dev}' if torch.cuda.is_available() else 'cpu')

        self.DR = data_range
        self.C1 = (K[0] * data_range) ** 2
        self.C2 = (K[1] * data_range) ** 2
        self.pad = int(2 * gaussian_sigmas[-1])
        self.compensation = compensation
        self.num_scales = len(gaussian_sigmas)
        filter_size = int(4 * gaussian_sigmas[-1] + 1)

        # One 2-D Gaussian per scale: (S, 1, k, k)
        per_scale = torch.stack(
            [self._fspecial_gauss_2d(filter_size, sigma) for sigma in gaussian_sigmas]
        ).unsqueeze(1)

        # F.conv2d(..., groups=3) hands output filters [c*S, (c+1)*S) to input
        # channel c, so the bank must be channel-major:
        #   r@s0..r@sS-1, g@s0..g@sS-1, b@s0..b@sS-1
        # (a scale-major layout would blur each channel with the wrong sigmas).
        g_masks = per_scale.repeat(3, 1, 1, 1)

        # Buffer => follows .to()/.cuda(); non-persistent => state_dict unchanged
        self.register_buffer('g_masks', g_masks, persistent=False)

    def _fspecial_gauss_1d(self, size, sigma):
        """Create a normalized 1-D Gaussian kernel of length ``size``."""
        coords = torch.arange(size, dtype=torch.float, device=self.device)
        coords -= size // 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g /= g.sum()
        return g.reshape(-1)

    def _fspecial_gauss_2d(self, size, sigma):
        """Create a 2-D Gaussian kernel as the outer product of the 1-D kernel."""
        gaussian_vec = self._fspecial_gauss_1d(size, sigma)
        return torch.outer(gaussian_vec, gaussian_vec)

    def forward(self, x, y):
        """
        Compute MS-SSIM loss between images x and y.

        Args:
            x: Generated images tensor of shape (B, 3, H, W)
            y: Target/ground truth images tensor of shape (B, 3, H, W)

        Returns:
            Scalar tensor containing the loss value
        """
        # Follow the input's device/dtype (no copy when they already match)
        masks = self.g_masks.to(device=x.device, dtype=x.dtype)

        def blur(t):
            return F.conv2d(t, masks, groups=3, padding=self.pad)

        mux = blur(x)
        muy = blur(y)

        mux2 = mux * mux
        muy2 = muy * muy
        muxy = mux * muy

        sigmax2 = blur(x * x) - mux2
        sigmay2 = blur(y * y) - muy2
        sigmaxy = blur(x * y) - muxy

        # l(j), cs(j) in MS-SSIM, one map per (channel, scale): [B, 3*S, H, W]
        l = (2 * muxy + self.C1) / (mux2 + muy2 + self.C1)
        cs = (2 * sigmaxy + self.C2) / (sigmax2 + sigmay2 + self.C2)

        # Luminance term at the coarsest scale of EACH channel: with the
        # channel-major layout those sit at indices S-1, 2S-1, 3S-1.
        lM = l[:, self.num_scales - 1::self.num_scales, :, :].prod(dim=1)
        PIcs = cs.prod(dim=1)

        # 1 - MS-SSIM as the loss (since MS-SSIM is a similarity measure)
        loss_ms_ssim = 1 - lM * PIcs  # [B, H, W]
        return self.compensation * loss_ms_ssim.mean()


class VGGPerceptualLoss(nn.Module):
    """
    Perceptual loss using pretrained VGG19 features.

    Runs both images through the first ``layer_cutoff`` layers of torchvision's
    VGG19 ``features`` stack (frozen) and returns the MSE between the two
    feature maps.

    NOTE(review): images are fed to VGG as-is; no ImageNet mean/std
    normalization is applied in this class -- confirm whether callers are
    expected to normalize.

    Args:
        device: Device the feature extractor is moved to
        layer_cutoff: Number of leading VGG19 feature layers to keep
    """
    def __init__(self, device, layer_cutoff=21):
        super().__init__()
        # `pretrained=True` is deprecated in newer torchvision releases in
        # favour of the `weights=` enum API; use the enum when it exists and
        # fall back to the legacy flag on older installs.
        weights_enum = getattr(models, 'VGG19_Weights', None)
        if weights_enum is not None:
            vgg19 = models.vgg19(weights=weights_enum.IMAGENET1K_V1).features
        else:
            vgg19 = models.vgg19(pretrained=True).features
        self.features = nn.Sequential(*[vgg19[i] for i in range(layer_cutoff)])

        # Freeze the parameters to avoid training them
        for param in self.features.parameters():
            param.requires_grad = False

        # Move the features to the specified device
        self.features.to(device)
        self.criterion = nn.MSELoss()

    def forward(self, gen_img, target_img):
        """
        Compute perceptual loss between generated and target images.

        Args:
            gen_img: Generated image tensor
            target_img: Target image tensor

        Returns:
            Scalar tensor containing the perceptual loss
        """
        # Resize the generated image when its size differs from the target's
        if gen_img.size() != target_img.size():
            gen_img = F.interpolate(gen_img, size=target_img.shape[2:],
                                    mode='bilinear', align_corners=False)

        # Extract features
        gen_feat = self.features(gen_img)
        tgt_feat = self.features(target_img)

        # Compute MSE loss between features
        return self.criterion(gen_feat, tgt_feat)


# ======================
# Edge Loss
# ======================
def edge_loss(pred, target, eps=1e-3):
    """
    Charbonnier distance between the Laplacian responses of two image batches.

    Both inputs are filtered channel-by-channel with a 3x3 Laplacian stencil
    (zero padding of 1), and the smoothed absolute difference of the two
    responses is averaged over all elements.

    Args:
        pred: Generated images tensor, 4-D (batch, channels, height, width)
        target: Target images tensor with the same layout as ``pred``
        eps: Charbonnier smoothing constant

    Returns:
        Scalar tensor containing the edge loss
    """
    _, num_channels, _, _ = pred.shape

    stencil = torch.tensor(
        [[0.,  1., 0.],
         [1., -4., 1.],
         [0.,  1., 0.]],
        dtype=torch.float32, device=pred.device,
    )
    # Same stencil for every channel -> grouped conv with one group per channel
    per_channel_kernel = stencil.view(1, 1, 3, 3).expand(num_channels, -1, -1, -1)

    def laplacian(images):
        return F.conv2d(images, per_channel_kernel, padding=1, groups=num_channels)

    residual = laplacian(pred) - laplacian(target)
    return (residual * residual + eps * eps).sqrt().mean()


# ======================
# Charbonnier Loss
# ======================
def charbonnier_loss(pred, target, eps=0.001):
    """
    Charbonnier loss: mean of sqrt(diff^2 + eps^2) over all elements.

    Behaves like a smoothed L1 distance between ``pred`` and ``target``.

    Args:
        pred: Generated images tensor
        target: Target images tensor
        eps: Smoothing constant added (squared) under the square root

    Returns:
        Scalar tensor containing the Charbonnier loss
    """
    residual = pred - target
    smoothed_abs = (residual * residual + eps * eps).sqrt()
    return smoothed_abs.mean()


# ======================
# Content Loss
# ======================
def compute_content_loss(fake_imgs, clean_imgs, 
                         vgg_perceptual_loss,
                         lambda_pixel=10,
                         lambda_edge=0.5,
                         lambda_vgg=1,
                         eps_char=0.001,
                         eps_edge=0.001):
    """
    Weighted sum of the pixel, edge and perceptual terms of the content loss.

    Args:
        fake_imgs: Generated images tensor
        clean_imgs: Target images tensor
        vgg_perceptual_loss: Callable taking (fake_imgs, clean_imgs) and
            returning the perceptual loss
        lambda_pixel: Weight of the Charbonnier pixel term
        lambda_edge: Weight of the Laplacian edge term
        lambda_vgg: Weight of the perceptual term
        eps_char: Epsilon forwarded to ``charbonnier_loss``
        eps_edge: Epsilon forwarded to ``edge_loss``

    Returns:
        Scalar tensor containing the weighted combination of losses
    """
    # Each term is scaled as soon as it is computed (pixel, edge, perceptual)
    weighted_pixel = lambda_pixel * charbonnier_loss(fake_imgs, clean_imgs, eps=eps_char)
    weighted_edge = lambda_edge * edge_loss(fake_imgs, clean_imgs, eps=eps_edge)
    weighted_perceptual = lambda_vgg * vgg_perceptual_loss(fake_imgs, clean_imgs)

    return weighted_pixel + weighted_edge + weighted_perceptual


# ======================
# Total Variation Loss
# ======================
def total_variation_loss(img):
    """
    Total Variation loss: mean absolute difference between neighbouring
    pixels along dimension 2 plus the same quantity along dimension 3.

    Args:
        img: Generated images tensor (indexed as a 4-D tensor)

    Returns:
        Scalar tensor containing the TV loss
    """
    # Differences between consecutive entries along dim 2 and dim 3
    steps_dim2 = img[:, :, 1:, :] - img[:, :, :-1, :]
    steps_dim3 = img[:, :, :, 1:] - img[:, :, :, :-1]
    return steps_dim2.abs().mean() + steps_dim3.abs().mean()


# ======================
# Dual Denoising Loss
# ======================
def dual_denoising_loss(clean_img, denoised_img, generator):
    """
    Dual denoising loss: denoise the already-denoised image a second time and
    compare the result with the clean target using smooth L1.

    The first-pass output is detached before being fed back, so gradients
    reach the generator only through the second pass.

    Args:
        clean_img: Target clean images tensor
        denoised_img: First-pass denoised images tensor
        generator: Callable mapping an image tensor to a denoised tensor

    Returns:
        Scalar tensor containing the dual denoising loss
    """
    first_pass = denoised_img.detach()
    second_pass = generator(first_pass)
    return F.smooth_l1_loss(second_pass, clean_img)


# ======================
# Multiscale Adversarial Loss
# ======================
def multiscale_adversarial_loss(discriminator_outputs, is_real=True):
    """
    WGAN-style adversarial loss aggregated over several discriminator outputs.

    The i-th output (0-based) is reduced to its mean, negated when
    ``is_real`` is True, and weighted by ``1 / (i + 1)``. The total is the sum
    of the weighted terms divided by the sum of the weights.

    Args:
        discriminator_outputs: Iterable of tensors, one per discriminator scale
        is_real: True to negate the mean score, False to use it unchanged

    Returns:
        Tuple of (total loss, list of individual weighted scale losses)
    """
    def signed_mean(scores):
        avg = scores.mean()
        return -avg if is_real else avg

    # Pair every scale's weight with its signed mean score
    weighted_terms = [
        (1.0 / rank, signed_mean(scores))
        for rank, scores in enumerate(discriminator_outputs, start=1)
    ]

    weights = [weight for weight, _ in weighted_terms]
    scale_losses = [term * weight for weight, term in weighted_terms]

    # Normalize by the total weight
    total_loss = sum(scale_losses) / sum(weights)
    return total_loss, scale_losses


# ======================
# Multiscale Gradient Penalty
# ======================
def gradient_penalty_multiscale(critic, noisy_imgs, real_clean, fake_denoised, device, lambda_gp=10):
    """
    WGAN-GP gradient penalty for a critic that returns one output per scale.

    A random per-sample mix of real and generated images is scored by the
    critic. For every returned output, the gradient with respect to the mixed
    images is taken, its per-sample L2 norm is pushed towards 1, and the
    resulting penalty is weighted by ``1 / (i + 1)`` for the i-th output.

    Args:
        critic: Callable ``critic(noisy_imgs, images)`` returning an iterable
            of output tensors
        noisy_imgs: Noisy input images passed through to the critic
        real_clean: Ground truth clean images
        fake_denoised: Generator output images
        device: Device for the random mixing coefficients and grad seeds
        lambda_gp: Gradient penalty weight

    Returns:
        Weighted-average gradient penalty multiplied by ``lambda_gp``
    """
    n_samples = real_clean.size(0)

    # One mixing coefficient per sample, broadcast over the remaining dims
    mix = torch.rand(n_samples, 1, 1, 1, device=device)
    blended = (mix * real_clean + (1 - mix) * fake_denoised).requires_grad_(True)

    penalty_sum = 0
    weight_sum = 0
    for rank, scores in enumerate(critic(noisy_imgs, blended), start=1):
        scale_weight = 1.0 / rank
        weight_sum += scale_weight

        # d(scores)/d(blended); create_graph so the penalty itself is differentiable
        grad_seed = torch.ones_like(scores, requires_grad=False, device=device)
        (grads,) = torch.autograd.grad(
            outputs=scores,
            inputs=blended,
            grad_outputs=grad_seed,
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )

        per_sample_norm = grads.view(n_samples, -1).norm(2, dim=1)
        penalty_sum += ((per_sample_norm - 1) ** 2).mean() * scale_weight

    # Normalize by the total weight, then apply the penalty coefficient
    return (penalty_sum / weight_sum) * lambda_gp


# Training Loop

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim_sk

# Assuming all models, loss functions, etc. are defined above

# ===============
#  Hyperparameters
# ===============
# Run configuration for the training cell below. The training loop reads:
# epochs, n_critic, lambda_gp, lambda_ssim, lambda_adv, lambda_tv, lambda_dual,
# lambda_content, lambda_pixel, lambda_edge, lambda_vgg and charbonnier_eps.
epochs = 100          # last epoch index; the loop runs start_epoch+1 .. epochs
n_critic = 7          # critic updates performed per generator update
lambda_gp = 10        # gradient-penalty weight passed to gradient_penalty_multiscale
smoothing_lambda = 0.09   # NOTE(review): not referenced by the training loop in this cell
lambda_ssim = 0.5     # weight of the MS-SSIM term in the generator loss
lambda_dd = 0.5       # Dual denoising weight -- NOTE(review): the loop uses lambda_dual, not this
perceptual_lambda = 0.1   # NOTE(review): not referenced by the training loop in this cell
lambda_char = 1.0         # NOTE(review): not referenced by the training loop in this cell

# Running best validation metrics and per-epoch history (written by the loop)
best_ssim = 0.0
best_psnr = 0.0
loss_history = []

# Loss weights
lambda_adv = 0.01           # Adversarial loss for generator
lambda_tv = 0.01            # Total variation
lambda_dual = 0.1           # Dual denoising
lambda_content = 1.0        # Overall content loss weight

# Inside content loss
lambda_pixel = 10.0         # Charbonnier/pixel loss
lambda_edge = 0.5           # Edge loss (if used)
lambda_vgg = 1.0            # VGG perceptual

# Charbonnier/edge epsilon (the loop passes it as both eps_char and eps_edge)
charbonnier_eps = 0.001     # from text references

# ===============
#  Model & Optim
# ===============
# Initialize models
generator = FractalDenoiseGenerator(
    patch_sizes=[16, 4],  # or whatever patch scales you want
    overlap=2,            # or 0 if you prefer no overlap
    in_channels=3,        # defaults to 3
    hidden_channels=32    # defaults to 32
).to(device)
critic = Discriminator().to(device)

# If you have multiple GPUs:
# if torch.cuda.device_count() > 1:
#     print("Using", torch.cuda.device_count(), "GPUs!")
#     generator = nn.DataParallel(generator)
#     critic = nn.DataParallel(critic)

# Separate Adam optimizers; the critic uses twice the generator's learning rate
optimizer_G = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.9, 0.999))
optimizer_D = optim.Adam(critic.parameters(), lr=2e-4, betas=(0.9, 0.999))

# (Optional) Resume from a previous checkpoint
# NOTE(review): the training section below assigns start_epoch = 0
# unconditionally, which would overwrite the value loaded here -- adjust that
# assignment when enabling this block.
# resume_training = True
# checkpoint_path = "/kaggle/input/12-epoch-lfg/best_generator_ssim.pth"
# if resume_training:
#     start_epoch, loss_history = load_checkpoint(
#         generator, critic, optimizer_G, optimizer_D,
#         filepath=checkpoint_path, 
#         device=device
#     )
#     print(f"Resuming training from epoch {start_epoch+1}")
# else:
#     start_epoch = 0

# Create MSSSIM loss and VGG perceptual loss
msssim_loss = MS_SSIM_Loss(cuda_dev=0).to(device)
vgg_perceptual_loss = VGGPerceptualLoss(device=device, layer_cutoff=21)


def calculate_metrics(generator, dataloader, device, win_size=3):
    """
    Evaluate the generator on a dataloader and return average PSNR and SSIM.

    The generator is switched to eval mode (and left in eval mode). Every
    batch is denoised without gradient tracking, each image is converted from
    channel-first to channel-last, and PSNR/SSIM are computed per image with
    ``data_range=1``.

    Args:
        generator: Model mapping a noisy batch to a denoised batch
        dataloader: Iterable yielding (noisy_imgs, clean_imgs) tensor pairs
        device: Device the noisy batch is moved to before inference
        win_size: Window size forwarded to the SSIM computation

    Returns:
        Tuple (avg_psnr, avg_ssim) averaged over all evaluated images
    """
    import inspect  # local import keeps this function self-contained

    # scikit-image replaced `multichannel=True` with `channel_axis=<int>`
    # (deprecated in 0.19, removed in later releases). Pick whichever keyword
    # the installed structural_similarity actually accepts.
    if 'channel_axis' in inspect.signature(ssim_sk).parameters:
        channel_kwargs = {'channel_axis': -1}
    else:
        channel_kwargs = {'multichannel': True}

    generator.eval()
    psnr_values = []
    ssim_values = []

    with torch.no_grad():
        for noisy_imgs, clean_imgs in dataloader:
            noisy_imgs = noisy_imgs.to(device)
            clean_imgs = clean_imgs.cpu().numpy()

            denoised_imgs = generator(noisy_imgs).cpu().numpy()

            # Calculate PSNR and SSIM for each image
            for denoised, clean in zip(denoised_imgs, clean_imgs):
                # Convert CHW -> HWC so the channel axis is last
                denoised = np.transpose(denoised, (1, 2, 0))
                clean = np.transpose(clean, (1, 2, 0))

                psnr_values.append(psnr(clean, denoised, data_range=1))
                ssim_values.append(
                    ssim_sk(clean, denoised, data_range=1,
                            win_size=win_size, **channel_kwargs)
                )

    avg_psnr = np.mean(psnr_values)
    avg_ssim = np.mean(ssim_values)
    return avg_psnr, avg_ssim


# ==================
#     Training
# ==================
start_epoch = 0  # or from checkpoint if resuming

# Per-epoch history for plotting: one value appended per completed epoch
critic_losses = []
gen_losses = []
content_losses = []
msssim_losses = []
tv_losses = []
dual_losses = []
psnr_metrics = []
ssim_metrics = []

########################
#   TRAINING LOOP
########################
for epoch in range(start_epoch + 1, epochs + 1):
    generator.train()
    critic.train()
    
    # Running sums over the epoch; turned into averages after the batch loop
    epoch_critic_loss = 0
    epoch_gen_adv_loss = 0
    epoch_content_loss = 0
    epoch_msssim_loss = 0
    epoch_tv_loss = 0
    epoch_dual_loss = 0
    batch_count = 0

    for i, (noisy_imgs, clean_imgs) in enumerate(train_loader):
        noisy_imgs = noisy_imgs.to(device)
        clean_imgs = clean_imgs.to(device)
        batch_count += 1

        ################################################
        # (1) TRAIN CRITIC (WGAN-GP)
        ################################################
        for _ in range(n_critic):
            # The critic step uses the generator output purely as data, so do
            # not build (and keep alive) the generator's autograd graph here.
            with torch.no_grad():
                fake_imgs = generator(noisy_imgs)
            optimizer_D.zero_grad()

            # Get real and fake scores (the critic returns one output per scale)
            real_outputs = critic(noisy_imgs, clean_imgs)
            fake_outputs = critic(noisy_imgs, fake_imgs)
            
            # Calculate losses for each scale
            real_loss, real_scale_losses = multiscale_adversarial_loss(real_outputs, is_real=True)
            fake_loss, fake_scale_losses = multiscale_adversarial_loss(fake_outputs, is_real=False)
            
            # Calculate multi-scale gradient penalty
            gp = gradient_penalty_multiscale(
                critic, noisy_imgs, clean_imgs, fake_imgs, device, lambda_gp=lambda_gp
            )
            
            # Combined critic loss
            critic_loss = real_loss + fake_loss + gp

            critic_loss.backward()
            optimizer_D.step()

        ################################################
        # (2) TRAIN GENERATOR
        ################################################
        optimizer_G.zero_grad()

        # Re-generate fakes (after critic update), this time with gradients
        fake_imgs = generator(noisy_imgs)
        
        # Get discriminator outputs for fake images
        fake_outputs = critic(noisy_imgs, fake_imgs)
        
        # Generator wants its fakes scored like real samples, hence is_real=True
        gen_adv_loss, gen_scale_losses = multiscale_adversarial_loss(fake_outputs, is_real=True)
        gen_adv_loss_scaled = lambda_adv * gen_adv_loss

        # Content loss (Charbonnier + Edge + VGG)
        cont_l = compute_content_loss(
            fake_imgs, clean_imgs,
            vgg_perceptual_loss=vgg_perceptual_loss,
            lambda_pixel=lambda_pixel,
            lambda_edge=lambda_edge,
            lambda_vgg=lambda_vgg,
            eps_char=charbonnier_eps,
            eps_edge=charbonnier_eps
        )
        content_loss_scaled = lambda_content * cont_l

        # MS-SSIM loss (replaced MS-SSIM-L1)
        msssim_val = msssim_loss(fake_imgs, clean_imgs)
        msssim_scaled = lambda_ssim * msssim_val

        # Total Variation loss
        tv_l = total_variation_loss(fake_imgs)
        tv_loss_scaled = lambda_tv * tv_l

        # Dual Denoising loss
        dd_l = dual_denoising_loss(clean_imgs, fake_imgs.detach(), generator)
        dd_loss_scaled = lambda_dual * dd_l

        # Combine all generator losses
        total_gen_loss = (gen_adv_loss_scaled
                         + content_loss_scaled
                         + msssim_scaled
                         + tv_loss_scaled
                         + dd_loss_scaled)

        total_gen_loss.backward()
        optimizer_G.step()
        
        # Track unweighted losses for this epoch (critic: last of the n_critic steps)
        epoch_critic_loss += critic_loss.item()
        epoch_gen_adv_loss += gen_adv_loss.item()
        epoch_content_loss += cont_l.item()
        epoch_msssim_loss += msssim_val.item()
        epoch_tv_loss += tv_l.item()
        epoch_dual_loss += dd_l.item()
        
        # Print batch progress
        # if (i + 1) % 10 == 0:
        #     print(f"[Epoch {epoch}/{epochs}] [Batch {i+1}/{len(train_loader)}] "
        #           f"D Loss: {critic_loss.item():.4f}, G Loss: {total_gen_loss.item():.4f}")

    # Calculate epoch averages
    epoch_critic_loss /= batch_count
    epoch_gen_adv_loss /= batch_count
    epoch_content_loss /= batch_count
    epoch_msssim_loss /= batch_count
    epoch_tv_loss /= batch_count
    epoch_dual_loss /= batch_count
    
    # End-of-epoch logging
    print(f"[Epoch {epoch}/{epochs}] "
          f"Critic Loss: {epoch_critic_loss:.4f} "
          f"Gen Adv Loss: {epoch_gen_adv_loss:.4f} "
          f"Content: {epoch_content_loss:.4f} "
          f"MS-SSIM: {epoch_msssim_loss:.4f} "
          f"TV: {epoch_tv_loss:.4f} "
          f"Dual: {epoch_dual_loss:.4f}")

    # Validation
    avg_psnr, avg_ssim = calculate_metrics(generator, val_loader, device, win_size=3)
    print(f"Validation -- PSNR: {avg_psnr:.2f}  SSIM: {avg_ssim:.4f}")
    
    # Store metrics for plotting
    critic_losses.append(epoch_critic_loss)
    gen_losses.append(epoch_gen_adv_loss)
    content_losses.append(epoch_content_loss)
    msssim_losses.append(epoch_msssim_loss)
    tv_losses.append(epoch_tv_loss)
    dual_losses.append(epoch_dual_loss)
    psnr_metrics.append(avg_psnr)
    ssim_metrics.append(avg_ssim)

    # Store in loss history
    loss_history.append((epoch,
                         epoch_critic_loss,
                         epoch_gen_adv_loss,
                         epoch_content_loss,
                         epoch_msssim_loss,
                         epoch_tv_loss,
                         epoch_dual_loss,
                         avg_psnr,
                         avg_ssim))
    
    # Save checkpoints. Both running bests are updated BEFORE anything is
    # written so that each checkpoint saved this epoch records the current
    # best SSIM and best PSNR (not last epoch's value of the other metric).
    ssim_improved = avg_ssim > best_ssim
    psnr_improved = avg_psnr > best_psnr
    if ssim_improved:
        best_ssim = avg_ssim
    if psnr_improved:
        best_psnr = avg_psnr

    if ssim_improved:
        save_best_checkpoint(generator, critic, optimizer_G, optimizer_D,
                             epoch, loss_history, best_ssim, best_psnr,
                             'best_generator_ssim.pth')
    if psnr_improved:
        save_best_checkpoint(generator, critic, optimizer_G, optimizer_D,
                             epoch, loss_history, best_ssim, best_psnr,
                             'best_generator_psnr.pth')
    
    # Visualize sample results periodically
    if epoch % 20 == 0:
        show_images_from_batches(generator, val_loader, device, num_images=15)

print("Training complete!")

# ==================
#   Plot Losses
# ==================
# X-axis = the epochs that were actually recorded in this run. Deriving it
# from the history (instead of range(1, epochs + 1)) keeps x and y the same
# length when training was resumed (start_epoch > 0) or stopped early.
num_recorded_epochs = len(psnr_metrics)
epochs_range = range(start_epoch + 1, start_epoch + num_recorded_epochs + 1)

# Overview figure: (title, y-label, [(values, line style, legend label), ...])
overview_panels = [
    # Plot 1: Generator and Discriminator losses
    ('Adversarial Losses', 'Loss', [(critic_losses, 'b-', 'Critic Loss'),
                                    (gen_losses, 'r-', 'Generator Adv Loss')]),
    # Plot 2: Content Loss
    ('Content Loss', 'Loss', [(content_losses, 'g-', 'Content Loss')]),
    # Plot 3: MS-SSIM Loss
    ('MS-SSIM Loss', 'Loss', [(msssim_losses, 'c-', 'MS-SSIM Loss')]),
    # Plot 4: TV and Dual Losses
    ('Regularization Losses', 'Loss', [(tv_losses, 'm-', 'TV Loss'),
                                       (dual_losses, 'y-', 'Dual Denoising Loss')]),
    # Plot 5: PSNR
    ('PSNR Metric', 'PSNR (dB)', [(psnr_metrics, 'b-', 'PSNR')]),
    # Plot 6: SSIM
    ('SSIM Metric', 'SSIM', [(ssim_metrics, 'r-', 'SSIM')]),
]

# Plot loss curves on a 3x2 grid
plt.figure(figsize=(15, 10))
for panel_index, (panel_title, panel_ylabel, panel_curves) in enumerate(overview_panels, start=1):
    plt.subplot(3, 2, panel_index)
    for curve_values, curve_style, curve_label in panel_curves:
        plt.plot(epochs_range, curve_values, curve_style, label=curve_label)
    plt.title(panel_title)
    plt.xlabel('Epochs')
    plt.ylabel(panel_ylabel)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.savefig('training_metrics.png', dpi=300)
plt.show()

# ==================
# Plot Individual Metrics
# ==================

# PSNR over epochs, then SSIM over epochs: one stand-alone figure each
single_metric_figures = [
    (psnr_metrics, 'b-', 'PSNR Over Training Epochs', 'PSNR (dB)', 'psnr_progress.png'),
    (ssim_metrics, 'r-', 'SSIM Over Training Epochs', 'SSIM', 'ssim_progress.png'),
]
for metric_values, metric_style, figure_title, figure_ylabel, figure_file in single_metric_figures:
    plt.figure(figsize=(12, 6))
    plt.plot(epochs_range, metric_values, metric_style, linewidth=2)
    plt.title(figure_title, fontsize=16)
    plt.xlabel('Epochs', fontsize=14)
    plt.ylabel(figure_ylabel, fontsize=14)
    plt.grid(True)
    plt.savefig(figure_file, dpi=300)
    plt.show()

# Combined MS-SSIM and Content loss
plt.figure(figsize=(12, 6))
plt.plot(epochs_range, msssim_losses, 'c-', label='MS-SSIM Loss', linewidth=2)
plt.plot(epochs_range, content_losses, 'g-', label='Content Loss', linewidth=2)
plt.title('Perceptual Quality Losses', fontsize=16)
plt.xlabel('Epochs', fontsize=14)
plt.ylabel('Loss', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True)
plt.savefig('perceptual_losses.png', dpi=300)
plt.show()