In [11]:
import torch  # Core PyTorch library for tensor operations and neural networks
import torch.nn as nn  # Neural network module from PyTorch
import torch.optim as optim  # Optimization algorithms (like Adam, SGD)
from torch.utils.data import Dataset, DataLoader, random_split  # Data utilities
from torchvision import transforms  # Image transformations
from PIL import Image  # Python Imaging Library for image handling
import os  # Operating system interface for file operations
import matplotlib.pyplot as plt  # Plotting library for visualizing results
import numpy as np

# =============================================================================
# Custom Dataset Class for Image Super-Resolution
# =============================================================================

class ImageSuperResDataset(Dataset):
    """
    Paired low-resolution / high-resolution image dataset for super-resolution.

    Inputs and targets are matched by filename: for every entry of ``input_dir``
    a file with the same name is opened from ``output_dir``. Every entry returned
    by ``os.listdir(input_dir)`` is treated as a sample, so that directory should
    contain only image files.
    """

    def __init__(self, input_dir, output_dir, input_size=(225, 300), output_size=(450, 600)):
        """
        Initialize the dataset with directories containing low-res and high-res images.

        Args:
            input_dir: Directory containing low-resolution input images
            output_dir: Directory containing the corresponding high-resolution
                target images, stored under the same filenames as in ``input_dir``
            input_size: (height, width) every input image is resized to
                (default: 225x300)
            output_size: (height, width) every target image is resized to
                (default: 450x600, i.e. double the default input size)
        """
        self.input_dir = input_dir
        self.output_dir = output_dir
        # Sorted so that indices map to the same files on every run/platform
        # (important for reproducible random_split results).
        self.filenames = sorted(os.listdir(input_dir))

        # Resize to a fixed size, then ToTensor converts the PIL image to a
        # float (C, H, W) tensor scaled to [0, 1].
        self.input_transform = transforms.Compose([
            transforms.Resize(input_size),
            transforms.ToTensor(),
        ])
        self.output_transform = transforms.Compose([
            transforms.Resize(output_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the total number of images in the dataset"""
        return len(self.filenames)

    @staticmethod
    def _load_rgb(path, transform):
        """Open ``path`` as an RGB image, apply ``transform`` and close the file handle."""
        with Image.open(path) as img:
            return transform(img.convert('RGB'))  # Ensure RGB format (drops alpha / expands grayscale)

    def __getitem__(self, idx):
        """
        Retrieve a pair of low-res and high-res images by index.

        Args:
            idx: Index of the image to retrieve

        Returns:
            A tuple (input_tensor, output_tensor) of 3-D tensors (channels x height x width).
            With the default sizes the input is 3x225x300 and the output is 3x450x600.
        """
        img_name = self.filenames[idx]

        input_tensor = self._load_rgb(os.path.join(self.input_dir, img_name), self.input_transform)
        output_tensor = self._load_rgb(os.path.join(self.output_dir, img_name), self.output_transform)

        # Images are returned as (C x H x W) tensors without flattening
        return input_tensor, output_tensor
    

# =============================================================================
# VISUALIZATION FUNCTION
# =============================================================================
def visualize_super_resolution(model, test_loader, device, num_examples=5, save_path=None):
    """
    Visualize super-resolution results with side-by-side comparisons.

    Only the first batch yielded by ``test_loader`` is used. For each example,
    one row with three panels is drawn:
    1. Input: Low-resolution image (225x300)
    2. Model Output: Super-resolved image (450x600)
    3. Ground Truth: Actual high-resolution image (450x600)

    Per-example MSE and PSNR are printed. They are computed on images clipped to
    [0, 1] (exactly what is displayed), so they can differ slightly from
    ``evaluate_metrics``, which scores the raw, unclipped model outputs.

    Args:
        model: The trained super-resolution model
        test_loader: DataLoader (or any iterable) yielding (inputs, targets) batches
        device: torch.device to run the model on; the selected images are moved there
        num_examples: Number of examples to visualize (default: 5). Capped at the
            number of images in the first batch.
        save_path: Optional file path. If given, the figure is also saved there
            (dpi=150, tight bounding box) before being shown.

    Raises:
        ValueError: If ``num_examples`` is less than 1 or ``test_loader`` is empty.
    """
    if num_examples < 1:
        raise ValueError(f"num_examples must be at least 1, got {num_examples}")

    model.eval()
    try:
        inputs, targets = next(iter(test_loader))
    except StopIteration:
        raise ValueError("test_loader is empty; nothing to visualize") from None

    # The first batch can hold fewer images than requested (small test set or
    # small batch_size); never index past what is actually there.
    num_examples = min(num_examples, inputs.size(0))

    # Keep only the requested examples and move them to the model's device
    inputs = inputs[:num_examples].to(device)
    targets = targets[:num_examples].to(device)

    print(f"\nGenerating super-resolution for {num_examples} test images...")

    # Generate predictions (no gradient computation needed for inference)
    with torch.no_grad():
        outputs = model(inputs)

    # Move to CPU as float32 numpy arrays. `.float()` matters if the model ran in
    # reduced precision: numpy has no bfloat16 dtype.
    inputs_np = inputs.float().cpu().numpy()
    outputs_np = outputs.float().cpu().numpy()
    targets_np = targets.float().cpu().numpy()

    # One row per example, three columns per row. squeeze=False keeps `axes`
    # 2-D even when num_examples == 1.
    fig, axes = plt.subplots(num_examples, 3, figsize=(15, 5 * num_examples), squeeze=False)

    panels = (
        ('Input (Low-Res)', inputs_np),
        ('Model Output', outputs_np),
        ('Ground Truth', targets_np),
    )

    for i in range(num_examples):
        shown = []
        for col, (title, batch_np) in enumerate(panels):
            chw = batch_np[i]
            # (C, H, W) -> (H, W, C) for matplotlib, clipped to the valid [0, 1] range
            img = np.clip(np.transpose(chw, (1, 2, 0)), 0, 1)
            ax = axes[i, col]
            ax.imshow(img)
            ax.set_title(f'{title}\n{chw.shape[1]}x{chw.shape[2]}')
            ax.axis('off')
            shown.append(img)

        # Quality metrics on the displayed (clipped) output vs. ground truth
        output_img, target_img = shown[1], shown[2]
        mse = np.mean((output_img - target_img) ** 2)
        psnr = 10 * np.log10(1.0 / (mse + 1e-10))  # Small epsilon avoids division by zero

        print(f"  Example {i+1} - MSE: {mse:.6f}, PSNR: {psnr:.2f} dB")

    plt.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"\nVisualization saved as '{save_path}'")
    plt.show()

# =============================================================================
# QUANTITATIVE EVALUATION FUNCTION
# =============================================================================
def evaluate_metrics(model, test_loader, device):
    """
    Calculate average quality metrics across the entire test set.

    Metrics are computed per image and then averaged over all images:
    - MSE (Mean Squared Error): Lower is better, measures pixel-wise difference
    - PSNR (Peak Signal-to-Noise Ratio): Higher is better, logarithmic quality
      measure. Uses a peak value of 1.0, i.e. assumes pixel values in [0, 1].

    Model outputs are scored as produced (not clipped to [0, 1]) and in float32,
    even if the model itself returns a reduced-precision dtype.

    Args:
        model: The trained super-resolution model
        test_loader: DataLoader (or any iterable) yielding (inputs, targets) batches
        device: torch.device to run the evaluation on; batches are moved there

    Returns:
        avg_mse: Average mean squared error per image
        avg_psnr: Average PSNR in decibels per image

    Raises:
        ValueError: If ``test_loader`` yields no images.
    """
    model.eval()

    # Accumulate metrics across all batches
    total_mse = 0.0
    total_psnr = 0.0
    count = 0

    print("\nEvaluating model on test set...")

    # Process all test images
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # Generate predictions
            outputs = model(inputs)

            # Per-image MSE: reduce over every non-batch dimension. Compute in
            # float32 so bf16/fp16 outputs do not lose precision in the metric.
            diff = outputs.float() - targets.float()
            per_image_dims = tuple(range(1, diff.dim()))
            mse = torch.mean(diff ** 2, dim=per_image_dims)
            psnr = 10 * torch.log10(1.0 / (mse + 1e-10))  # epsilon avoids log of infinity

            # Accumulate statistics
            total_mse += mse.sum().item()
            total_psnr += psnr.sum().item()
            count += inputs.size(0)

            # Print progress every 10 batches
            if (batch_idx + 1) % 10 == 0:
                print(f"  Processed {count} images...")

    if count == 0:
        raise ValueError("test_loader yielded no images; cannot compute average metrics")

    # Calculate averages
    avg_mse = total_mse / count
    avg_psnr = total_psnr / count

    # Print results
    print(f"\n{'='*70}")
    print("TEST SET EVALUATION RESULTS")
    print(f"{'='*70}")
    print(f"Total images evaluated: {count}")
    print(f"Average MSE:  {avg_mse:.6f}")
    print(f"Average PSNR: {avg_psnr:.2f} dB")

    # Simple quality assessment
    if avg_psnr >= 30:
        quality = "Good"
    else:
        quality = "Better hyperparameter search needed"

    print(f"Quality Assessment: {quality}")
    print(f"{'='*70}\n")

    return avg_mse, avg_psnr
print("Defined: ImageSuperResDataset, visualize_super_resolution, evaluate_metrics")  # confirms the cell ran


check


In [13]:
"""
=============================================================================
SUPER-RESOLUTION MODEL EVALUATION
=============================================================================
This notebook demonstrates how to:
1. Load a trained super-resolution neural network
2. Test it on unseen data
3. Visualize the results with before/after comparisons
4. Calculate quantitative metrics (MSE and PSNR)

Super-resolution is the task of taking a low-resolution image and generating
a high-resolution version of it. Our model takes 225x300 (height x width)
images and outputs 450x600 images (2x upscaling in each dimension).

A GPU with 2 GB of VRAM is enough to run the neural network; it also runs
(more slowly) on the CPU.
=============================================================================
"""


# Import necessary libraries for deep learning, image processing, and visualization
import torch  # Core PyTorch library for tensor operations and neural networks
import torch.nn as nn  # Neural network module from PyTorch
import torch.optim as optim  # Optimization algorithms (like Adam, SGD)
from torch.utils.data import Dataset, DataLoader, random_split  # Data utilities
from torchvision import transforms  # Image transformations
from PIL import Image  # Python Imaging Library for image handling
import os  # Operating system interface for file operations
import myutilities


# Pick the compute device and decide whether bfloat16 (BF16) can be used.
# BF16 is only considered on a CUDA GPU that reports support for it; on the CPU
# fallback it is always disabled.
if torch.cuda.is_available():
    device = torch.device('cuda')
    use_bf16 = torch.cuda.is_bf16_supported()
else:
    device = torch.device('cpu')
    use_bf16 = False

print(f"Using device: {device}")
print(f"BF16 enabled: {use_bf16}")

check
Using device: cpu
BF16 enabled: False


In [15]:
# Load the paired low-res (inputs) / high-res (targets) images
print("Loading dataset...")
dataset = myutilities.ImageSuperResDataset('kaggle/train_x', 'kaggle/train_y')

# Re-create the 70% / 15% / 15% train / validation / test split.
# NOTE(review): the fractions and the seed presumably match the training run, so that
# this test split holds only images the model never trained on. Keep them in sync.
total_size = len(dataset)
train_size = int(0.70 * total_size)  # 70% for training
val_size = int(0.15 * total_size)    # 15% for validation
test_size = total_size - train_size - val_size  # Remainder (~15%, absorbs rounding) for testing

# Only the test subset is needed here. Index the result instead of unpacking into `_`:
# assigning `_` in IPython shadows the last-output variable and stops IPython from
# updating `_`, `__` and `___` for the rest of the session.
test_dataset = random_split(
    dataset, [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42),  # Fixed seed for consistent splits
)[2]
print(f"Test samples: {len(test_dataset)}")

# Create data loaders for efficient batch processing; shuffle=False keeps the order deterministic
batch_size = 10  # Number of images to process in each batch (adjust based on GPU memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Loading dataset...
Test samples: 383


In [41]:
# Sanity-check the shape of one *test* sample's low-res input. Index the Subset itself:
# `test_dataset.dataset[0]` would return item 0 of the full, unsplit dataset instead.
test_dataset[0][0].shape

torch.Size([3, 225, 300])

In [42]:
# =============================================================================
# SETUP: Device Configuration
# =============================================================================
# Neural networks run much faster on a GPU, so use CUDA whenever PyTorch can
# see one and fall back to the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# =============================================================================
# MODEL ARCHITECTURE DEFINITIONS
# =============================================================================

# =============================================================================
# CUSTOM PIXEL SHUFFLE IMPLEMENTATION
# =============================================================================
class CustomPixelShuffle(nn.Module):
    """
    Hand-made implementation of PixelShuffle (also known as sub-pixel convolution).
    PyTorch ships the same operation as ``nn.PixelShuffle``; this version gives
    identical results and exists to show how the rearrangement works.

    This operation moves elements from the channel (depth) dimension into the
    spatial dimensions. Placed after a convolution that outputs
    ``channels * r^2`` feature maps, it lets the network learn how to upscale
    instead of relying on fixed interpolation.

    How it works:
    - Input:  (*, channels * r^2, height, width)
    - Output: (*, channels, height * r, width * r)

    where r is the upscale factor and ``*`` is zero or more leading (batch)
    dimensions.

    Example with r=2:
    - Input:  (B, 12, 100, 100) where 12 = 3 * 2^2
    - Output: (B, 3, 200, 200)

    The magic is in the reshaping and permutation that rearranges the channel
    data into spatial locations.
    """
    def __init__(self, upscale_factor):
        super().__init__()
        if upscale_factor < 1:
            raise ValueError(f"upscale_factor must be a positive integer, got {upscale_factor}")
        self.upscale_factor = upscale_factor

    def forward(self, x):
        """
        Rearrange tensor to increase spatial resolution.

        Args:
            x: Input tensor of shape (*, channels * r², height, width)

        Returns:
            Output tensor of shape (*, channels, height * r, width * r)

        Raises:
            ValueError: If ``x`` has fewer than 3 dimensions or its channel count
                is not divisible by r².
        """
        r = self.upscale_factor
        if x.dim() < 3:
            raise ValueError(f"expected a tensor of shape (*, C, H, W), got shape {tuple(x.shape)}")

        # Get input dimensions; `lead` holds any batch dimensions (possibly none)
        *lead, c, h, w = x.shape
        if c % (r * r) != 0:
            raise ValueError(
                f"channel count {c} is not divisible by upscale_factor^2 = {r * r}"
            )
        out_c = c // (r * r)
        n = len(lead)

        # Step 1: Split the channels into (out_c, r, r)
        # (*, c, h, w) -> (*, c/r², r, r, h, w)
        x = x.reshape(*lead, out_c, r, r, h, w)

        # Step 2: Interleave each r×r block with its spatial position
        # (*, c/r², r, r, h, w) -> (*, c/r², h, r, w, r)
        # For a 4-D input (n == 1) this is permute(0, 1, 4, 2, 5, 3).
        x = x.permute(*range(n), n, n + 3, n + 1, n + 4, n + 2)

        # Step 3: Collapse the r dimensions into the spatial dimensions
        # (*, c/r², h, r, w, r) -> (*, c/r², h*r, w*r)
        return x.reshape(*lead, out_c, h * r, w * r)

    def extra_repr(self):
        """Show the upscale factor when the module is printed."""
        return f"upscale_factor={self.upscale_factor}"

Using device: cpu


In [44]:
# =============================================================================
# SUPER RESOLUTION CNN
# =============================================================================
   
class SuperResCNN(nn.Module):
    """
    Convolutional network for 2x single-image super-resolution.

    Pipeline:
    - Feature extraction: three convolutions (9x9, 5x5, 5x5), each followed by ReLU
    - Upsampling: a 3x3 convolution producing 12 = 3 * 2^2 channels, which
      CustomPixelShuffle(2) rearranges into 3 channels at twice the height and width
    - Refinement: one final 3x3 convolution on the RGB result (no activation, so
      the output is not constrained to [0, 1])

    Every convolution uses "same" padding (kernel_size // 2), so an input of
    (batch, 3, H, W) yields (batch, 3, 2H, 2W), e.g. 225x300 -> 450x600.

    The attribute names (conv1, conv2, conv3, upsample_conv, refine) become the
    state_dict keys that saved checkpoints are loaded by; do not rename them.
    """
    def __init__(self):
        super().__init__()
        # Feature extraction: learn local image patterns at the input resolution
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, padding=2)

        # Sub-pixel upsampling: 12 channels -> 3 channels at 2x resolution.
        # nn.PixelShuffle(upscale_factor=2) would be a drop-in replacement for the custom module.
        self.upsample_conv = nn.Conv2d(in_channels=32, out_channels=12, kernel_size=3, padding=1)
        self.pixel_shuffle = CustomPixelShuffle(upscale_factor=2)

        # Final RGB -> RGB refinement convolution
        self.refine = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)

    def forward(self, x):
        """
        Run the network.

        Args:
            x: Input tensor of shape (batch_size, 3, 225, 300)
        Returns:
            Output tensor of shape (batch_size, 3, 450, 600)
        """
        for feature_conv in (self.conv1, self.conv2, self.conv3):
            x = torch.relu(feature_conv(x))
        upscaled = self.pixel_shuffle(self.upsample_conv(x))
        return self.refine(upscaled)

In [None]:
# =============================================================================
# LOAD THE TRAINED MODEL
# =============================================================================
print("\n" + "="*70)
print("LOADING TRAINED MODEL")
print("="*70)

CHECKPOINT_PATH = 'best_superres_model.pth'  # best weights saved during training

# Create model instance; the saved weights are loaded into it below
model = SuperResCNN().to(device)

# Records whether the trained weights were actually loaded. If this stays False,
# `model` holds freshly initialised (or only partially loaded) weights and any
# metrics computed from it do not describe the trained model.
model_loaded = False
try:
    # The checkpoint is a plain state_dict, so weights_only=True is sufficient and
    # refuses to unpickle arbitrary Python objects (some PyTorch versions also warn
    # when this argument is omitted).
    state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model_loaded = True

    # Print confirmation message with parameter count
    print(f"Model loaded successfully with {sum(p.numel() for p in model.parameters()):,} parameters")

except FileNotFoundError:
    print(f"Warning: '{CHECKPOINT_PATH}' not found. Using the current model state "
          "(weights are NOT trained).")
except Exception as e:
    print(f"Error loading model: {e}. Using the current model state "
          "(weights may be untrained or only partially loaded).")

# Evaluation mode in every case, not only after a successful load
# (disables dropout/batchnorm training behavior if the model ever gains such layers)
model.eval()