# Notebook 5: Hybrid Image Restoration

This notebook implements a hybrid approach that combines PDE-based restoration with CNN refinement. The pipeline is: PDE output → CNN (U-Net) refinement of the masked regions → final restored image.


In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F

sys.path.append('..')
from utils import load_image, save_image, ensure_dir, get_image_files

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Define U-Net architecture (same as Notebook_4)
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> batch norm -> ReLU) stages.

    Both convolutions use ``padding=1``, so height and width are preserved;
    only the channel count changes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Layers are created in application order and stored under
        # ``self.conv`` so the parameter names (conv.0, conv.1, conv.3,
        # conv.4) stay compatible with previously saved checkpoints.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution stages and return the result."""
        return self.conv(x)


class UNet(nn.Module):
    """U-Net encoder/decoder used for image inpainting.

    The network has four encoder levels (64, 128, 256, 512 channels), a
    1024-channel bottleneck and four decoder levels that mirror the encoder.
    Each decoder level upsamples by 2, concatenates the matching encoder
    feature map (skip connection) and fuses the two with a ``DoubleConv``.

    Args:
        in_channels: Number of input channels (default 4).
        out_channels: Number of output channels (default 3).

    Note:
        The input is max-pooled by 2 four times, so height and width should
        be multiples of 16; otherwise the upsampled maps and the skip
        connections end up with different spatial sizes and the
        concatenation fails.
    """

    def __init__(self, in_channels=4, out_channels=3):
        super().__init__()

        # Contracting path: DoubleConv followed by 2x2 max pooling.
        self.enc1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Deepest level of the network.
        self.bottleneck = DoubleConv(512, 1024)

        # Expanding path: each transposed conv halves the channels and
        # doubles the resolution; the following DoubleConv takes twice as
        # many input channels because the skip connection is concatenated.
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        # 1x1 convolution mapping the 64 feature channels to the output.
        self.final = nn.Conv2d(64, out_channels, 1)

    def forward(self, x):
        """Return the network output for ``x``, squashed to (0, 1) by a sigmoid."""
        # Encoder: keep every level's output for the skip connections.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        skip4 = self.enc4(self.pool3(skip3))

        features = self.bottleneck(self.pool4(skip4))

        # Decoder: walk back up from the deepest level to full resolution.
        decoder_stages = (
            (self.up4, self.dec4, skip4),
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for upsample, fuse, skip in decoder_stages:
            features = fuse(torch.cat([upsample(features), skip], dim=1))

        # Sigmoid keeps the output in the [0, 1] range.
        return torch.sigmoid(self.final(features))

# Progress message: confirms the model classes above were defined in this cell.
print("U-Net model defined!")


## 1. Load PDE Results and Trained CNN Model


In [None]:
# Input and output locations (relative paths)
PDE_RESULTS_DIR = Path('../methods/PDE/results')
MASKS_DIR = Path('../data/masks')
MODEL_PATH = Path('../models/unet_inpainting.pth')
RESULTS_DIR = Path('../methods/Hybrid/results')

# Make sure the output directory is available before anything is written.
ensure_dir(RESULTS_DIR)

# Rebuild the U-Net, move it to the compute device and restore its weights.
print("Loading trained U-Net model...")
model = UNet(in_channels=4, out_channels=3)
model = model.to(device)
# NOTE(review): torch.load unpickles the checkpoint file, so only load
# checkpoints from a trusted source.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Switch to inference mode (BatchNorm then uses its running statistics).
model.eval()
print("Model loaded successfully!")

# Collect the PDE outputs that the hybrid pipeline will refine.
pde_files = get_image_files(PDE_RESULTS_DIR)
print(f"Found {len(pde_files)} PDE results to refine")


## 2. Hybrid Pipeline: PDE + CNN Refinement


In [None]:
def refine_with_cnn(model, pde_result, mask, device, input_size=256):
    """
    Refine PDE result using CNN.

    The PDE output and the mask are resized to ``input_size`` x ``input_size``,
    scaled to [0, 1], stacked into one tensor (image channels + 1 mask
    channel) and passed through ``model``. The prediction is resized back to
    the original resolution and pasted into the damaged pixels only; every
    other pixel is taken unchanged from ``pde_result``.

    Args:
        model: Trained U-Net model (expected to already be in eval mode and
            on ``device``)
        pde_result: PDE restoration output, H x W x C array (RGB).
            Assumed to hold 0-255 values - TODO confirm against load_image.
        mask: Binary mask, H x W array (255 for damaged regions). Pixels
            above 127 are treated as damaged so that near-white values are
            not dropped. A mask whose size differs from ``pde_result`` is
            resized to match using nearest-neighbour interpolation.
        device: Device for computation
        input_size: Side length of the square network input (default 256,
            the value that was previously hard-coded)

    Returns:
        Refined image with the same height and width as ``pde_result``

    Raises:
        ValueError: If ``pde_result`` is not a 3-D array or ``mask`` is not
            a 2-D array (e.g. ``None`` from a failed ``cv2.imread``).
    """
    if pde_result is None or getattr(pde_result, "ndim", 0) != 3:
        raise ValueError("pde_result must be an H x W x C image array")
    if mask is None or getattr(mask, "ndim", 0) != 2:
        raise ValueError(
            "mask must be a 2-D (H x W) array; got None or an array of the "
            "wrong rank (was the mask file missing or unreadable?)"
        )

    h, w = pde_result.shape[:2]

    # Align the mask with the image so the final blend cannot fail on a
    # shape mismatch; nearest-neighbour keeps the mask binary.
    if mask.shape != (h, w):
        mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)

    mask_bool = (mask > 127)[:, :, np.newaxis]

    # Nothing is marked as damaged: the blend would return the PDE result
    # unchanged, so skip the inference altogether.
    if not mask_bool.any():
        return pde_result.copy()

    # Resize for model input
    # NOTE(review): the mask fed to the network keeps cv2's default
    # interpolation, as before - confirm this matches the preprocessing
    # used when the model was trained.
    model_size = (input_size, input_size)
    pde_resized = cv2.resize(pde_result, model_size)
    mask_resized = cv2.resize(mask, model_size)

    # Normalize to [0, 1]
    pde_norm = pde_resized.astype(np.float32) / 255.0
    mask_norm = mask_resized.astype(np.float32) / 255.0

    # Convert to tensors: image HWC -> 1 x C x H x W, mask HW -> 1 x 1 x H x W
    pde_tensor = torch.from_numpy(pde_norm).permute(2, 0, 1).unsqueeze(0)
    mask_tensor = torch.from_numpy(mask_norm).unsqueeze(0).unsqueeze(0)
    input_tensor = torch.cat([pde_tensor, mask_tensor], dim=1).to(device)

    # Inference (no gradients needed)
    with torch.no_grad():
        output = model(input_tensor)

    output = output.squeeze(0).cpu().permute(1, 2, 0).numpy()
    # Round instead of truncating so the conversion to 8-bit does not bias
    # every pixel towards darker values.
    output = np.clip(np.rint(output * 255.0), 0, 255).astype(np.uint8)

    # Resize back to original size
    output = cv2.resize(output, (w, h))

    # Blend: use CNN output in masked regions, keep PDE result elsewhere
    return np.where(mask_bool, output, pde_result)

# Progress message: confirms refine_with_cnn was defined in this cell.
print("Refinement function defined!")


## 3. Test on Sample Image


In [None]:
# Test on sample: refine the first PDE result and show it next to the input.
if pde_files:
    sample_path = pde_files[0]
    # Masks are looked up by the PDE result's file stem.
    sample_mask_path = MASKS_DIR / f"{sample_path.stem}.png"

    sample_pde = load_image(sample_path)
    # cv2.imread returns None (it does not raise) when the file is missing
    # or cannot be decoded, so the result has to be checked explicitly.
    sample_mask = cv2.imread(str(sample_mask_path), cv2.IMREAD_GRAYSCALE)

    print(f"Testing on: {sample_path.name}")

    if sample_mask is None:
        print(f"Mask not found or unreadable: {sample_mask_path} - skipping sample test")
    else:
        # Refine with CNN
        refined = refine_with_cnn(model, sample_pde, sample_mask, device)

        # Visualize: PDE input, hybrid output and their absolute difference
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        axes[0].imshow(sample_pde)
        axes[0].set_title('PDE Result')
        axes[0].axis('off')

        axes[1].imshow(refined)
        axes[1].set_title('Hybrid (PDE + CNN)')
        axes[1].axis('off')

        # Compute the difference in float so uint8 subtraction cannot wrap.
        diff = np.abs(refined.astype(float) - sample_pde.astype(float))
        axes[2].imshow(diff.astype(np.uint8))
        axes[2].set_title('Difference (Hybrid - PDE)')
        axes[2].axis('off')

        plt.tight_layout()
        plt.show()

        print("Sample test complete!")
else:
    print(f"No PDE results found in {PDE_RESULTS_DIR} - nothing to test")


## 4. Process All Images


In [None]:
# Process all PDE results
print(f"Processing {len(pde_files)} images with hybrid pipeline...")

# Names of PDE results that could not be refined because their mask was
# missing or unreadable; reported after the loop instead of aborting the run.
skipped = []

for pde_path in tqdm(pde_files, desc="Hybrid Refinement"):
    # Masks are looked up by the PDE result's file stem.
    mask_path = MASKS_DIR / f"{pde_path.stem}.png"

    # cv2.imread returns None (it does not raise) on a missing/corrupt file.
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        skipped.append(pde_path.name)
        continue

    # Load PDE result
    pde_result = load_image(pde_path)

    # Refine with CNN
    hybrid_result = refine_with_cnn(model, pde_result, mask, device)

    # Save result
    save_image(hybrid_result, RESULTS_DIR / f"{pde_path.stem}.png")

if skipped:
    processed = len(pde_files) - len(skipped)
    print(f"\n{processed} of {len(pde_files)} images processed and saved to {RESULTS_DIR}")
    print(f"Skipped {len(skipped)} image(s) with a missing or unreadable mask: {', '.join(skipped)}")
else:
    print(f"\n✓ All {len(pde_files)} images processed and saved to {RESULTS_DIR}")


## 5. Summary

Hybrid restoration is complete! All restored images are saved to `methods/Hybrid/results/`.

**Pipeline:**
1. PDE-based inpainting (from Notebook_2)
2. CNN refinement using trained U-Net (from Notebook_4)
3. Blending: CNN output in masked regions, PDE result elsewhere

The hybrid approach combines the smoothness of PDE methods with the detail-preserving capabilities of deep learning.
