In [None]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision.models import vgg19
import math

# Seed the NumPy (random crops) and torch (weight init) RNGs so runs repeat.
np.random.seed(42)
torch.manual_seed(42)

# Prefer CUDA whenever a GPU is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

class ComplexConv2d(nn.Module):
    """
    Convolution over complex-valued inputs stored as ``(real, imag)`` pairs.

    ``conv_real`` (W_r) and ``conv_imag`` (W_i) together act as one complex
    kernel W = W_r + i*W_i. For an input z = a + i*b the output is
        (W_r*a - W_i*b) + i*(W_r*b + W_i*a).
    Each of the two Conv2d layers has its own bias. The real output therefore
    carries ``bias_r - bias_i`` and the imaginary output ``bias_r + bias_i``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        conv_args = (in_channels, out_channels, kernel_size, stride, padding)
        self.conv_real = nn.Conv2d(*conv_args)
        self.conv_imag = nn.Conv2d(*conv_args)

    def forward(self, x):
        """Convolve ``x = (real, imag)``; returns an ``(real, imag)`` tuple."""
        a, b = x
        w_r, w_i = self.conv_real, self.conv_imag
        return (w_r(a) - w_i(b), w_r(b) + w_i(a))

class DIV2KDataset(Dataset):
    """
    Paired low-/high-resolution dataset for DIV2K.

    Expected layout:
        {root_dir}/DIV2K_{split}_HR/*.png
        {root_dir}/DIV2K_{split}_LR_bicubic/X{scale}/{id}x{scale}.png

    For ``split == "train"`` every item is a random, pixel-aligned LR/HR patch
    pair. Other splits return the full images. Items are dicts
    ``{'lr': Tensor, 'hr': Tensor}`` produced by ``transforms.ToTensor``.
    """
    def __init__(self, root_dir, split="train", scale=2, patch_size=96, transform=None):
        """
        Args:
            root_dir: Directory with DIV2K dataset
            split: 'train', 'valid', or 'test'
            scale: Downscaling factor (2, 3, or 4)
            patch_size: Size of cropped HR patches for training. If it is not
                a multiple of ``scale``, it is rounded down to one so the LR
                and HR patches cover the same area.
            transform: Additional transforms, applied to each tensor separately

        Raises:
            RuntimeError: if no HR ``*.png`` files are found.
        """
        self.root_dir = root_dir
        self.split = split
        self.scale = scale
        self.patch_size = patch_size
        self.transform = transform

        # Define paths
        self.hr_dir = os.path.join(root_dir, f"DIV2K_{split}_HR")
        self.lr_dir = os.path.join(root_dir, f"DIV2K_{split}_LR_bicubic", f"X{scale}")

        # Sorted so that indices are stable across runs.
        self.hr_images = sorted(glob.glob(os.path.join(self.hr_dir, "*.png")))

        if len(self.hr_images) == 0:
            raise RuntimeError(f"No images found in {self.hr_dir}")

        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.hr_images)

    def _random_aligned_crop(self, hr_img, lr_img):
        """Crop a random LR patch and the HR patch covering exactly the same area."""
        lr_patch_size = self.patch_size // self.scale
        hr_patch_size = lr_patch_size * self.scale

        hr_width, hr_height = hr_img.size
        lr_width, lr_height = lr_img.size

        if (hr_width < hr_patch_size or hr_height < hr_patch_size
                or lr_width < lr_patch_size or lr_height < lr_patch_size):
            # Image smaller than the patch: resize both images to the patch sizes.
            hr_img = transforms.Resize((hr_patch_size, hr_patch_size))(hr_img)
            lr_img = transforms.Resize((lr_patch_size, lr_patch_size))(lr_img)
            return hr_img, lr_img

        # Sample the corner on the LR grid and map it to HR. This keeps the pair
        # aligned. Sampling on the HR grid and floor-dividing would shift the LR
        # patch by up to (scale - 1) HR pixels relative to the HR patch.
        lr_x = np.random.randint(0, lr_width - lr_patch_size + 1)
        lr_y = np.random.randint(0, lr_height - lr_patch_size + 1)
        hr_x = lr_x * self.scale
        hr_y = lr_y * self.scale

        lr_img = lr_img.crop((lr_x, lr_y, lr_x + lr_patch_size, lr_y + lr_patch_size))
        hr_img = hr_img.crop((hr_x, hr_y, hr_x + hr_patch_size, hr_y + hr_patch_size))
        return hr_img, lr_img

    def __getitem__(self, idx):
        """Return ``{'lr': lr_tensor, 'hr': hr_tensor}`` for the idx-th HR image."""
        hr_path = self.hr_images[idx]

        # DIV2K naming: HR "0001.png" pairs with LR "0001x{scale}.png".
        img_id = os.path.basename(hr_path).split('.')[0]
        lr_path = os.path.join(self.lr_dir, f"{img_id}x{self.scale}.png")

        # Close the files deterministically; convert() returns a separate copy.
        with Image.open(hr_path) as img:
            hr_img = img.convert('RGB')
        with Image.open(lr_path) as img:
            lr_img = img.convert('RGB')

        if self.split == "train":
            hr_img, lr_img = self._random_aligned_crop(hr_img, lr_img)

        hr_tensor = self.to_tensor(hr_img)
        lr_tensor = self.to_tensor(lr_img)

        if self.transform:
            # NOTE(review): applied to LR and HR independently; a *random*
            # transform here would de-synchronise the pair.
            hr_tensor = self.transform(hr_tensor)
            lr_tensor = self.transform(lr_tensor)

        return {'lr': lr_tensor, 'hr': hr_tensor}

class ConvBlock(nn.Module):
    """
    Conv2d -> BatchNorm2d -> ReLU.

    Defaults (3x3 kernel, stride 1, padding 1) keep the spatial size unchanged.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return self.relu(y)

class ResidualBlock(nn.Module):
    """
    Channel-preserving residual block.

    Computes ``relu(bn2(conv2(ConvBlock(x))) + x)``. Both convolutions are 3x3
    with padding 1, so the output has the same shape as ``x``.
    """
    def __init__(self, channels):
        super().__init__()
        self.conv1 = ConvBlock(channels, channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        branch = self.bn2(self.conv2(self.conv1(x)))
        branch += x  # identity shortcut
        return self.relu(branch)

class DownBlock(nn.Module):
    """
    U-Net encoder stage: ConvBlock followed by 2x2 max-pooling.

    ``forward`` returns ``(pooled, features)``. ``features`` is the
    pre-pooling ConvBlock output, kept for the decoder skip connection.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        skip = self.conv(x)
        return self.pool(skip), skip

class UpBlock(nn.Module):
    """
    U-Net decoder stage.

    Upsamples ``x`` by 2 with a transposed conv (``in_channels`` ->
    ``in_channels // 2``), concatenates it with ``skip`` along channels and
    applies a ConvBlock to reach ``out_channels``. ``skip`` must contribute
    ``in_channels - in_channels // 2`` channels so the concatenation has
    ``in_channels`` channels.
    """
    def __init__(self, in_channels, out_channels):
        super(UpBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_channels, out_channels)

    def forward(self, x, skip):
        x = self.up(x)
        # MaxPool2d floors odd sizes, so after doubling, x can be one pixel
        # short of skip (e.g. 339 -> 169 -> 338). Pad x to skip's spatial size;
        # negative amounts crop instead. Without this, torch.cat raises.
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        if diff_h or diff_w:
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                          diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)

class SpatialBranch(nn.Module):
    """
    Spatial-domain feature branch with a three-level U-Net.

    Channels grow c -> 2c -> 4c -> 8c through the encoder (c = base_channels)
    and reach 16c in the bridge at 1/8 resolution. The decoder mirrors this
    with skip connections, and a final ConvBlock returns ``base_channels``
    feature maps. When H and W are divisible by 8 the output has the input's
    spatial size. Otherwise pooling floors odd sizes and the skip
    concatenations must reconcile the mismatch.
    """
    def __init__(self, in_channels=3, base_channels=64):
        super().__init__()
        c = base_channels
        self.init_conv = ConvBlock(in_channels, c)

        # Encoder: each stage doubles the channels and halves the resolution.
        self.down1 = DownBlock(c, 2 * c)
        self.down2 = DownBlock(2 * c, 4 * c)
        self.down3 = DownBlock(4 * c, 8 * c)

        # Bottleneck at 1/8 resolution.
        self.bridge = ConvBlock(8 * c, 16 * c)

        # Decoder: each stage doubles the resolution and halves the channels.
        self.up1 = UpBlock(16 * c, 8 * c)
        self.up2 = UpBlock(8 * c, 4 * c)
        self.up3 = UpBlock(4 * c, 2 * c)

        self.final_conv = ConvBlock(2 * c, c)

    def forward(self, x):
        feats = self.init_conv(x)

        # Encode, remembering the pre-pool features of every level.
        skips = []
        for down in (self.down1, self.down2, self.down3):
            feats, skip = down(feats)
            skips.append(skip)

        feats = self.bridge(feats)

        # Decode, consuming the skips deepest-first.
        for up in (self.up1, self.up2, self.up3):
            feats = up(feats, skips.pop())

        return self.final_conv(feats)

class FrequencyBranch(nn.Module):
    """
    Frequency-domain feature branch.

    Takes the (unshifted) 2-D FFT of every input channel. The log-magnitude
    and phase spectra are each processed by their own ConvBlock followed by
    two ResidualBlocks. The two feature maps are then concatenated and merged
    back to ``base_channels`` channels. All convolutions are 3x3 with
    padding 1, so the spatial size is preserved.
    """
    def __init__(self, in_channels=3, base_channels=64):
        super(FrequencyBranch, self).__init__()

        # Processing for magnitude
        self.mag_conv1 = ConvBlock(in_channels, base_channels)
        self.mag_res1 = ResidualBlock(base_channels)
        self.mag_res2 = ResidualBlock(base_channels)

        # Processing for phase
        self.phase_conv1 = ConvBlock(in_channels, base_channels)
        self.phase_res1 = ResidualBlock(base_channels)
        self.phase_res2 = ResidualBlock(base_channels)

        # Final processing
        self.final_conv = ConvBlock(base_channels*2, base_channels)

    @staticmethod
    def _log_magnitude_and_phase(x, eps=1e-8):
        """
        Return ``(log(|FFT2(x)| + eps), angle(FFT2(x)))`` with the shape of ``x``.

        ``torch.fft.fft2`` transforms the last two dimensions. One call
        therefore handles every (batch, channel) plane independently, with no
        per-channel loop and no preallocated buffers. The log compresses the
        magnitude's large dynamic range; ``eps`` avoids log(0).
        """
        spectrum = torch.fft.fft2(x)
        magnitude = torch.log(spectrum.abs() + eps)
        phase = torch.angle(spectrum)
        return magnitude, phase

    def forward(self, x):
        magnitude, phase = self._log_magnitude_and_phase(x)

        mag_features = self.mag_res2(self.mag_res1(self.mag_conv1(magnitude)))
        phase_features = self.phase_res2(self.phase_res1(self.phase_conv1(phase)))

        combined = torch.cat([mag_features, phase_features], dim=1)
        return self.final_conv(combined)

class FusionModule(nn.Module):
    """
    Merge spatial- and frequency-branch feature maps.

    The two inputs (``in_channels`` each) are concatenated and reduced back
    to ``in_channels`` by a ConvBlock. The result is gated element-wise by a
    sigmoid attention map (1x1 conv + sigmoid) and projected to
    ``out_channels`` by a final ConvBlock.
    """
    def __init__(self, in_channels=64, out_channels=64):
        super().__init__()
        self.fusion_conv = ConvBlock(2 * in_channels, in_channels)
        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=1),
            nn.Sigmoid(),
        )
        self.final_conv = ConvBlock(in_channels, out_channels)

    def forward(self, spatial_features, freq_features):
        stacked = torch.cat((spatial_features, freq_features), dim=1)
        mixed = self.fusion_conv(stacked)
        gate = self.attention(mixed)  # values in (0, 1)
        return self.final_conv(mixed * gate)

class ReconstructionModule(nn.Module):
    """
    Upsample feature maps by ``scale_factor`` and project them to an image.

    * x2: 3x3 conv (C -> 4C) + PixelShuffle(2) + ReLU
    * x4: two such stages with separate weights
    * x3: one x2 stage followed by bilinear interpolation by 1.5
    A final 3x3 conv maps C channels to ``out_channels``. No output
    activation is applied, so values are not constrained to [0, 1].

    Raises:
        ValueError: if ``scale_factor`` is not 2, 3 or 4. Any other value
            would previously have produced x2 output silently.
    """
    def __init__(self, in_channels=64, out_channels=3, scale_factor=2):
        super(ReconstructionModule, self).__init__()

        if scale_factor not in (2, 3, 4):
            raise ValueError(f"scale_factor must be 2, 3 or 4, got {scale_factor!r}")

        self.scale_factor = scale_factor

        # First x2 upsampling stage (used by every supported scale).
        self.upconv1 = nn.Conv2d(in_channels, in_channels*4, kernel_size=3, padding=1)
        self.pixel_shuffle1 = nn.PixelShuffle(2)
        self.upconv_relu1 = nn.ReLU(inplace=True)

        # Second x2 stage only exists for x4.
        if scale_factor == 4:
            self.upconv2 = nn.Conv2d(in_channels, in_channels*4, kernel_size=3, padding=1)
            self.pixel_shuffle2 = nn.PixelShuffle(2)
            self.upconv_relu2 = nn.ReLU(inplace=True)

        # Final reconstruction
        self.final_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        # First x2 upsampling
        x = self.upconv1(x)
        x = self.pixel_shuffle1(x)
        x = self.upconv_relu1(x)

        if self.scale_factor == 4:
            x = self.upconv2(x)
            x = self.pixel_shuffle2(x)
            x = self.upconv_relu2(x)
        elif self.scale_factor == 3:
            # PixelShuffle only handles the x2 part; 2 * 1.5 = 3.
            x = F.interpolate(x, scale_factor=1.5, mode='bilinear', align_corners=False)

        return self.final_conv(x)

class HybridFrequencySpatialNetwork(nn.Module):
    """
    Super-resolution network with parallel spatial and frequency branches.

    LR input -> (SpatialBranch, FrequencyBranch) -> FusionModule ->
    ReconstructionModule (upsampling by ``scale_factor``) -> SR output.
    Both branches and the fusion module work at ``base_channels`` width.
    """
    def __init__(self, in_channels=3, out_channels=3, base_channels=64, scale_factor=2):
        super().__init__()
        self.spatial_branch = SpatialBranch(in_channels=in_channels, base_channels=base_channels)
        self.frequency_branch = FrequencyBranch(in_channels=in_channels, base_channels=base_channels)
        self.fusion_module = FusionModule(in_channels=base_channels, out_channels=base_channels)
        self.reconstruction_module = ReconstructionModule(
            in_channels=base_channels,
            out_channels=out_channels,
            scale_factor=scale_factor,
        )

    def forward(self, x):
        fused = self.fusion_module(self.spatial_branch(x), self.frequency_branch(x))
        return self.reconstruction_module(fused)

class PerceptualLoss(nn.Module):
    """
    Perceptual loss: MSE between frozen, ImageNet-pretrained VGG19 features.

    The extractor is the first ``feature_layer`` modules of
    ``vgg19().features``. With the default of 35, this presumably ends at
    conv5_4 before its ReLU in torchvision's layout; verify if changing it.

    Args:
        feature_layer: number of leading VGG ``features`` modules to keep.
        normalize: if True (default), inputs are assumed to be RGB in [0, 1].
            They are standardised with the ImageNet mean/std the pretrained
            weights were trained with before entering VGG.
    """
    def __init__(self, feature_layer=35, normalize=True):
        super(PerceptualLoss, self).__init__()
        try:
            from torchvision.models import VGG19_Weights
        except ImportError:
            # torchvision < 0.13 only supports the legacy `pretrained` flag.
            vgg = vgg19(pretrained=True).features
        else:
            vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
        self.feature_extractor = nn.Sequential(*list(vgg.children())[:feature_layer]).eval()

        # Freeze VGG parameters
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

        self.normalize = normalize
        # Non-persistent buffers follow .to(device) but stay out of state_dict.
        self.register_buffer(
            'imagenet_mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False
        )
        self.register_buffer(
            'imagenet_std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False
        )

    def _prepare(self, img):
        """Standardise ``img`` with ImageNet statistics when ``self.normalize``."""
        if not self.normalize:
            return img
        return (img - self.imagenet_mean) / self.imagenet_std

    def forward(self, x, target):
        """Return the MSE between VGG features of ``x`` and ``target``."""
        x_features = self.feature_extractor(self._prepare(x))
        target_features = self.feature_extractor(self._prepare(target))
        return F.mse_loss(x_features, target_features)

def calculate_psnr(img1, img2, max_pixel=1.0):
    """
    Peak signal-to-noise ratio between two images, in dB.

    Args:
        img1, img2: array-likes of the same shape. They are converted to
            float64 first, so integer images (e.g. uint8) do not wrap around
            when subtracted.
        max_pixel: peak signal value (1.0 for [0, 1] images, 255 for uint8).

    Returns:
        ``20 * log10(max_pixel / sqrt(MSE))``, or ``inf`` for identical images.
    """
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(max_pixel / math.sqrt(mse))

def calculate_ssim(img1, img2, data_range=1.0, win_size=11):
    """
    Mean structural similarity (SSIM) using a uniform ``win_size`` window.

    Local means, variances and covariance are box-filtered over every fully
    contained ``win_size x win_size`` window ('valid' positions only), using
    population statistics. For HxWxC inputs each channel is filtered
    independently, and the SSIM map is averaged over all positions and
    channels. Basic implementation - for production use a library like
    scikit-image.

    Args:
        img1, img2: arrays of identical shape, either HxW or HxWxC.
        data_range: dynamic range of the pixel values (1.0 for [0, 1] images).
        win_size: side length of the averaging window.

    Returns:
        Mean SSIM as a Python float (1.0 for identical images).

    Raises:
        ValueError: on shape mismatch, unsupported rank, or images smaller
            than the window.
    """
    img1 = np.asarray(img1, dtype=np.float64)
    img2 = np.asarray(img2, dtype=np.float64)
    if img1.shape != img2.shape:
        raise ValueError(f"Image shapes differ: {img1.shape} vs {img2.shape}")
    if img1.ndim not in (2, 3):
        raise ValueError(f"Expected HxW or HxWxC images, got shape {img1.shape}")
    if min(img1.shape[:2]) < win_size:
        raise ValueError(f"Images must be at least {win_size}x{win_size}")

    C1 = (0.01 * data_range) ** 2
    C2 = (0.03 * data_range) ** 2
    k = win_size
    pad = [(1, 0), (1, 0)] + [(0, 0)] * (img1.ndim - 2)

    def box_mean(a):
        # Integral image: s[i, j] == a[:i, :j].sum(axis=(0, 1)). Each window
        # sum is then four lookups, giving an O(H*W) 'valid' mean filter.
        s = np.pad(a.cumsum(axis=0).cumsum(axis=1), pad)
        window_sum = s[k:, k:] - s[:-k, k:] - s[k:, :-k] + s[:-k, :-k]
        return window_sum / (k * k)

    mu1 = box_mean(img1)
    mu2 = box_mean(img2)
    # Var[X] = E[X^2] - E[X]^2, Cov[X, Y] = E[XY] - E[X]E[Y] over each window.
    sigma1_sq = box_mean(img1 * img1) - mu1 ** 2
    sigma2_sq = box_mean(img2 * img2) - mu2 ** 2
    sigma12 = box_mean(img1 * img2) - mu1 * mu2

    numerator = (2 * mu1 * mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1 ** 2 + mu2 ** 2 + C1) * (sigma1_sq + sigma2_sq + C2)
    return float(np.mean(numerator / denominator))

def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
    model.train()
    total = 0.0
    for batch in loader:
        lr_imgs = batch['lr'].to(device)
        hr_imgs = batch['hr'].to(device)

        optimizer.zero_grad()
        loss = criterion(model(lr_imgs), hr_imgs)
        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller last batch is averaged correctly.
        total += loss.item() * lr_imgs.size(0)
    return total / len(loader.dataset)


def _mean_val_psnr(model, loader, device):
    """Mean PSNR over every sample in ``loader``; predictions are clipped to [0, 1]."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for batch in loader:
            lr_imgs = batch['lr'].to(device)
            hr_imgs = batch['hr'].to(device)
            preds = model(lr_imgs)
            for pred, target in zip(preds, hr_imgs):
                pred_np = np.clip(pred.cpu().numpy().transpose(1, 2, 0), 0, 1)
                target_np = target.cpu().numpy().transpose(1, 2, 0)
                total += calculate_psnr(pred_np, target_np)
    return total / len(loader.dataset)


def _plot_training_curves(train_losses, val_psnrs):
    """Save side-by-side loss / PSNR curves to 'training_curves.png'."""
    panels = (
        (train_losses, 'Training Loss', 'Loss'),
        (val_psnrs, 'Validation PSNR', 'PSNR (dB)'),
    )
    plt.figure(figsize=(10, 5))
    for position, (values, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(values)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(ylabel)
    plt.tight_layout()
    plt.savefig('training_curves.png')
    plt.close()


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=50, device="cuda"):
    """
    Train ``model`` for ``num_epochs`` epochs with per-epoch PSNR validation.

    Each epoch runs one pass of ``criterion(outputs, hr)`` optimisation over
    ``train_loader``, then one ``scheduler.step()``, then computes the mean
    validation PSNR over ``val_loader``. The weights are written to
    'best_model.pth' whenever the validation PSNR improves. At the end the
    curves are saved to 'training_curves.png'.

    Returns:
        (model, train_losses, val_psnrs): the model (moved to ``device``)
        and the per-epoch training-loss and validation-PSNR histories.
    """
    model = model.to(device)
    train_losses, val_psnrs = [], []
    best_val_psnr = 0.0

    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")
        print("-" * 10)

        epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        train_losses.append(epoch_loss)
        scheduler.step()

        val_psnr = _mean_val_psnr(model, val_loader, device)
        val_psnrs.append(val_psnr)

        print(f"Train Loss: {epoch_loss:.4f}, Val PSNR: {val_psnr:.4f}")

        if val_psnr > best_val_psnr:
            best_val_psnr = val_psnr
            torch.save(model.state_dict(), 'best_model.pth')
            print("New best model saved!")

    _plot_training_curves(train_losses, val_psnrs)
    return model, train_losses, val_psnrs

def _save_example(lr_img, sr_img, hr_img, psnr, ssim, path):
    """Save an LR / SR / HR comparison figure to ``path``.

    ``lr_img`` is a CHW tensor; ``sr_img`` and ``hr_img`` are HWC arrays.
    """
    panels = (
        (lr_img.cpu().numpy().transpose(1, 2, 0), 'LR Image'),
        (sr_img, f'SR Image (PSNR: {psnr:.2f}, SSIM: {ssim:.4f})'),
        (hr_img, 'HR Image'),
    )
    plt.figure(figsize=(15, 5))
    for position, (img, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(img)
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def main():
    """
    Train the hybrid SR network on DIV2K and report validation PSNR/SSIM.

    Reads data from ./DIV2K. train_model writes 'best_model.pth' and
    'training_curves.png'. This function additionally writes
    'final_model.pth' and 'example_{0..4}.png', and prints the mean metrics
    of the final model over the validation set.
    """
    # Hyperparameters
    batch_size, lr, num_epochs = 16, 1e-4, 50
    scale_factor = 2  # 2, 3, or 4
    patch_size = 96   # HR patch size
    data_root = 'DIV2K'

    # Data
    dataset_kwargs = dict(scale=scale_factor, patch_size=patch_size)
    train_dataset = DIV2KDataset(data_root, split='train', **dataset_kwargs)
    val_dataset = DIV2KDataset(data_root, split='valid', **dataset_kwargs)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)

    # Model (constructed before the loss so RNG consumption order is fixed)
    model = HybridFrequencySpatialNetwork(in_channels=3, out_channels=3, base_channels=64, scale_factor=scale_factor)

    l1_loss = nn.L1Loss()
    perceptual_loss = PerceptualLoss().to(device)

    def criterion(outputs, targets):
        """Pixel-level L1 plus 0.1 x VGG perceptual loss."""
        return l1_loss(outputs, targets) + 0.1 * perceptual_loss(outputs, targets)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=lr*0.1)

    model, train_losses, val_psnrs = train_model(
        model, train_loader, val_loader, criterion, optimizer, scheduler,
        num_epochs=num_epochs, device=device
    )
    torch.save(model.state_dict(), 'final_model.pth')

    # Final evaluation of the last-epoch weights on the validation set
    model.eval()
    psnr_sum = ssim_sum = 0.0
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            lr_imgs = batch['lr'].to(device)
            hr_imgs = batch['hr'].to(device)
            outputs = model(lr_imgs)

            output = np.clip(outputs[0].cpu().numpy().transpose(1, 2, 0), 0, 1)
            target = hr_imgs[0].cpu().numpy().transpose(1, 2, 0)

            psnr = calculate_psnr(output, target)
            ssim = calculate_ssim(output, target)
            psnr_sum += psnr
            ssim_sum += ssim

            if i < 5:
                _save_example(lr_imgs[0], output, target, psnr, ssim, f'example_{i}.png')

    n_val = len(val_loader)
    print(f"Final Validation PSNR: {psnr_sum / n_val:.4f}, SSIM: {ssim_sum / n_val:.4f}")

if __name__ == "__main__":
    main()