# Notebook 4: Deep Learning (U-Net) Image Restoration

This notebook implements a U-Net for image inpainting in PyTorch. Each training sample stacks a damaged image and its damage mask into a 4-channel input, with the corresponding ground-truth image as the target, so the model learns to restore the damaged regions.


In [None]:
import sys
import os
from pathlib import Path
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchvision
from torchvision import transforms

sys.path.append('..')
from utils import load_image, save_image, ensure_dir, get_image_files

# Seed the torch and NumPy RNGs so weight initialisation, the train/val
# split and batch shuffling are repeatable across runs.
np.random.seed(42)
torch.manual_seed(42)

# Prefer a CUDA GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


## 1. U-Net Architecture Definition


In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Both convolutions use ``padding=1``, so height and width are preserved and
    only the channel count changes from ``in_channels`` to ``out_channels``.
    """
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


def _pad_to_match(x, ref):
    """Zero-pad the last two dims of ``x`` so they equal those of ``ref``.

    After a 2x max-pool (floor) and a 2x transposed conv, an odd-sized skip
    map is one pixel larger than the upsampled map; the difference is split
    between the two sides. Returns ``x`` unchanged when sizes already match.
    """
    diff_h = ref.size(-2) - x.size(-2)
    diff_w = ref.size(-1) - x.size(-1)
    if diff_h == 0 and diff_w == 0:
        return x
    # F.pad order for the last two dims: (left, right, top, bottom).
    return F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                     diff_h // 2, diff_h - diff_h // 2])


class UNet(nn.Module):
    """U-Net for image inpainting.

    Four encoder stages (64-512 channels, each followed by 2x max-pooling), a
    1024-channel bottleneck, and four decoder stages that upsample with 2x
    transposed convolutions and concatenate the matching encoder output
    (skip connection). A final 1x1 conv maps to ``out_channels`` and a sigmoid
    squashes the result into [0, 1].

    Input is ``(N, in_channels, H, W)`` and output is ``(N, out_channels, H, W)``.
    H and W need not be multiples of 16: upsampled maps are zero-padded to
    their skip's size. They must be at least 16 so the bottleneck is non-empty.
    """
    def __init__(self, in_channels=4, out_channels=3):
        super(UNet, self).__init__()

        # Encoder (downsampling)
        self.enc1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder (upsampling); each dec* sees upsampled + skip channels.
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        # Output layer
        self.final = nn.Conv2d(64, out_channels, 1)

    def _decode(self, up, dec, x, skip):
        """Upsample ``x``, align it with ``skip``, concatenate, and refine."""
        x = _pad_to_match(up(x), skip)
        return dec(torch.cat([x, skip], dim=1))

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))

        # Bottleneck
        b = self.bottleneck(self.pool4(e4))

        # Decoder with skip connections
        d4 = self._decode(self.up4, self.dec4, b, e4)
        d3 = self._decode(self.up3, self.dec3, d4, e3)
        d2 = self._decode(self.up2, self.dec2, d3, e2)
        d1 = self._decode(self.up1, self.dec1, d2, e1)

        # Output in [0, 1] range
        return torch.sigmoid(self.final(d1))

print("U-Net model defined!")


## 2. Dataset Class


In [None]:
class InpaintingDataset(Dataset):
    """Dataset of (damaged image + mask) -> ground-truth pairs for inpainting.

    Samples are matched by file stem: for each file returned by
    ``get_image_files(damaged_dir)`` the mask is read from
    ``mask_dir/<stem>.png`` and the target from ``gt_dir/<stem>.png``.

    Each item is ``(input_tensor, gt_tensor)``. ``input_tensor`` stacks the
    damaged image's channels with the mask as the last channel (4 channels when
    ``load_image`` returns 3-channel images). All values are divided by 255.

    Args:
        damaged_dir: Directory of damaged input images.
        mask_dir: Directory of masks, read as single-channel grayscale.
        gt_dir: Directory of ground-truth images.
        transform: Stored on the instance but not applied by ``__getitem__``.
        target_size: ``(width, height)`` passed to ``cv2.resize`` for all three
            images; a falsy value skips resizing.

    Raises:
        FileNotFoundError: from ``__getitem__`` when an image or mask for the
            requested sample cannot be read.
    """
    def __init__(self, damaged_dir, mask_dir, gt_dir, transform=None, target_size=(256, 256)):
        # Sort first (keeping the original ordering, which random_split indices
        # depend on), then wrap in Path so `.stem` works even if
        # get_image_files yields plain strings.
        self.damaged_files = [Path(p) for p in sorted(get_image_files(damaged_dir))]
        self.mask_dir = mask_dir
        self.gt_dir = gt_dir
        self.transform = transform  # NOTE(review): currently unused
        self.target_size = target_size

    def __len__(self):
        return len(self.damaged_files)

    def __getitem__(self, idx):
        damaged_path = self.damaged_files[idx]
        stem = damaged_path.stem

        # Load damaged image, mask and ground truth (same stem, .png for the latter two)
        damaged = load_image(damaged_path)
        mask_path = Path(self.mask_dir) / f"{stem}.png"
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        gt_path = Path(self.gt_dir) / f"{stem}.png"
        gt = load_image(gt_path)

        # cv2.imread returns None instead of raising for a missing/unreadable
        # file (load_image may do the same), so fail here with a clear message
        # rather than later inside cv2.resize with an opaque OpenCV error.
        for what, array, path in (("damaged image", damaged, damaged_path),
                                  ("mask", mask, mask_path),
                                  ("ground truth", gt, gt_path)):
            if array is None:
                raise FileNotFoundError(f"Could not read {what} for sample '{stem}': {path}")

        # Resize if needed (cv2.resize takes (width, height))
        if self.target_size:
            damaged = cv2.resize(damaged, self.target_size)
            mask = cv2.resize(mask, self.target_size)
            gt = cv2.resize(gt, self.target_size)

        # Normalize to [0, 1]
        damaged = damaged.astype(np.float32) / 255.0
        mask = mask.astype(np.float32) / 255.0
        gt = gt.astype(np.float32) / 255.0

        # Convert to tensors: [C, H, W]
        damaged_tensor = torch.from_numpy(damaged).permute(2, 0, 1)
        mask_tensor = torch.from_numpy(mask).unsqueeze(0)
        gt_tensor = torch.from_numpy(gt).permute(2, 0, 1)

        # Concatenate damaged image and mask along channels as the model input
        input_tensor = torch.cat([damaged_tensor, mask_tensor], dim=0)

        return input_tensor, gt_tensor

print("Dataset class defined!")


## 3. Loss Function (L1 + Perceptual Loss)


In [None]:
class VGGPerceptualLoss(nn.Module):
    """Perceptual loss comparing frozen VGG16 feature maps of two images.

    The ImageNet-pretrained VGG16 ``features`` stack is cut into four
    consecutive stages ending at relu1_2, relu2_2, relu3_3 and relu4_3. Both
    images are pushed through the stages in order, and the MSE between the two
    feature maps at each stage is summed. The VGG weights have
    ``requires_grad=False`` and are never updated.

    Inputs are expected to be 3-channel images already normalised with the
    ImageNet mean/std (``combined_loss`` does this before calling the loss).
    """
    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        vgg = self._pretrained_vgg16_features()
        self.feature_layers = nn.ModuleList([
            vgg[:4],   # relu1_2
            vgg[4:9],  # relu2_2
            vgg[9:16], # relu3_3
            vgg[16:23] # relu4_3
        ])
        for param in self.feature_layers.parameters():
            param.requires_grad = False

    @staticmethod
    def _pretrained_vgg16_features():
        """Return the ImageNet-pretrained VGG16 ``features`` module.

        torchvision 0.13 deprecated ``pretrained=True`` in favour of the
        ``weights=`` enum (``IMAGENET1K_V1`` is the same checkpoint). Older
        releases lack the enum, so fall back to ``pretrained=True`` there.
        """
        weights_enum = getattr(torchvision.models, "VGG16_Weights", None)
        if weights_enum is None:
            return torchvision.models.vgg16(pretrained=True).features
        return torchvision.models.vgg16(weights=weights_enum.IMAGENET1K_V1).features

    def forward(self, pred, target):
        """Return the sum over stages of ``MSE(stage(pred), stage(target))``.

        Each stage consumes the previous stage's output, so later terms
        compare progressively deeper features.
        """
        loss = 0
        for stage in self.feature_layers:
            pred = stage(pred)
            target = stage(target)
            loss += F.mse_loss(pred, target)
        return loss


def combined_loss(pred, target, perceptual_loss_fn=None, lambda_l1=1.0, lambda_perceptual=0.1):
    """Weighted sum of pixel-wise L1 loss and an optional perceptual loss.

    Args:
        pred: Predicted images, shape ``(N, 3, H, W)`` when a perceptual loss
            is used (the ImageNet statistics below are broadcast over dim 1).
        target: Ground-truth images with the same shape as ``pred``.
        perceptual_loss_fn: Optional callable ``(pred, target) -> loss``. Both
            arguments are ImageNet mean/std normalised before the call. When
            ``None``, the perceptual term is 0.
        lambda_l1: Weight of the L1 term.
        lambda_perceptual: Weight of the perceptual term.

    Returns:
        ``lambda_l1 * L1 + lambda_perceptual * perceptual``.
    """
    pixel_term = F.l1_loss(pred, target)

    perceptual_term = 0
    if perceptual_loss_fn is not None:
        # ImageNet channel statistics expected by pretrained VGG backbones.
        imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=pred.device).view(1, 3, 1, 1)
        imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=pred.device).view(1, 3, 1, 1)
        perceptual_term = perceptual_loss_fn(
            (pred - imagenet_mean) / imagenet_std,
            (target - imagenet_mean) / imagenet_std,
        )

    return lambda_l1 * pixel_term + lambda_perceptual * perceptual_term

print("Loss functions defined!")


## 4. Training Setup


In [None]:
# Paths
DAMAGED_DIR = Path('../data/damaged')
MASKS_DIR = Path('../data/masks')
GROUND_TRUTH_DIR = Path('../data/ground_truth')
# NOTE: despite its name, RESULTS_DIR here is where the trained checkpoint is
# written; the inference cell later rebinds RESULTS_DIR to its output folder.
RESULTS_DIR = Path('../models')
MODEL_SAVE_PATH = RESULTS_DIR / 'unet_inpainting.pth'

ensure_dir(RESULTS_DIR)

# Hyperparameters
BATCH_SIZE = 8
LEARNING_RATE = 0.0001
NUM_EPOCHS = 10  # Adjust based on dataset size
# (width, height) for cv2.resize. The U-Net pools 4 times, so multiples of 16
# keep every skip connection exactly aligned.
TARGET_SIZE = (256, 256)

# Create dataset
print("Loading dataset...")
full_dataset = InpaintingDataset(DAMAGED_DIR, MASKS_DIR, GROUND_TRUTH_DIR, target_size=TARGET_SIZE)

# An 80/20 split needs at least 2 samples to give both sides one sample.
# Otherwise an empty loader would only surface later as a ZeroDivisionError
# in the training loop.
if len(full_dataset) < 2:
    raise RuntimeError(
        f"Need at least 2 damaged images in {DAMAGED_DIR} for a train/val split, "
        f"found {len(full_dataset)}."
    )

# Split into train/val (80/20). A dedicated, seeded generator keeps the split
# reproducible even if cells are re-run or executed out of order.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"Train samples: {len(train_dataset)}")
print(f"Val samples: {len(val_dataset)}")

# Initialize model
model = UNet(in_channels=4, out_channels=3).to(device)
print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}")

# Loss and optimizer; LR is halved after 2 epochs without val-loss improvement
perceptual_loss_fn = VGGPerceptualLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

print("Training setup complete!")


## 5. Training Loop


In [None]:
# Training history: one per-sample mean loss per epoch.
train_losses = []
val_losses = []

# Fail fast on empty loaders instead of hitting a ZeroDivisionError after a
# full (and possibly long) training epoch.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise RuntimeError("Training and validation loaders must each contain at least one batch.")

print("Starting training...")
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    # Training phase
    model.train()
    train_loss = 0.0
    train_count = 0
    for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = combined_loss(outputs, targets, perceptual_loss_fn, lambda_l1=1.0, lambda_perceptual=0.1)
        loss.backward()
        optimizer.step()

        # combined_loss is a per-batch mean; weight it by batch size so a
        # smaller final batch does not skew the epoch average.
        batch_n = inputs.size(0)
        train_loss += loss.item() * batch_n
        train_count += batch_n

    avg_train_loss = train_loss / train_count
    train_losses.append(avg_train_loss)

    # Validation phase (eval mode: BatchNorm uses running stats; no gradients)
    model.eval()
    val_loss = 0.0
    val_count = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = combined_loss(outputs, targets, perceptual_loss_fn, lambda_l1=1.0, lambda_perceptual=0.1)
            batch_n = inputs.size(0)
            val_loss += loss.item() * batch_n
            val_count += batch_n

    avg_val_loss = val_loss / val_count
    val_losses.append(avg_val_loss)

    # Let the plateau scheduler react to the validation loss
    scheduler.step(avg_val_loss)

    print(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

    # Save best model (by validation loss)
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        print(f"  -> Saved best model (val_loss: {best_val_loss:.4f})")

print("\nTraining complete!")
print(f"Best validation loss: {best_val_loss:.4f}")

# Plot training curves
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Curves')
plt.legend()
plt.grid(True)
plt.show()


## 6. Test on All Images

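Load the best checkpoint and run it on every damaged image. Each image and its mask are resized to 256×256 for the model, and the prediction is resized back to the image's original resolution. Note that this includes the images used for training, so these outputs are not a held-out evaluation.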


In [None]:
# Reload the checkpoint with the lowest validation loss (saved by the
# training loop) and switch the network to inference mode.
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
model.eval()

# Output folder for restored images. NOTE: this rebinds RESULTS_DIR, which the
# training-setup cell used for the checkpoint directory.
RESULTS_DIR = Path('../methods/Deep/results')
ensure_dir(RESULTS_DIR)

# Every damaged image is processed (the full dataset, training images included).
test_files = get_image_files(DAMAGED_DIR)

print("Processing {} images for testing...".format(len(test_files)))

def process_image(model, damaged_img, mask, device, target_size=(256, 256)):
    """Inpaint one image with ``model`` and return it at the input resolution.

    The image and mask are resized to ``target_size`` and scaled to [0, 1].
    They are stacked into a ``(1, C+1, H', W')`` tensor and passed through the
    model. The prediction is converted to uint8 and resized back to the
    original height and width.

    Args:
        model: Network mapping the stacked tensor to ``(1, 3, H', W')`` values
            presumed to lie in [0, 1] (the U-Net ends in a sigmoid).
        damaged_img: ``(H, W, C)`` image array; presumably uint8 RGB, given
            the division by 255 (TODO confirm).
        mask: ``(H, W)`` single-channel mask, presumably 0..255.
        device: Device the model lives on; the input tensor is moved there.
        target_size: ``(width, height)`` given to ``cv2.resize``. It should
            match the resolution the model was trained at. The default (256,
            256) is the previously hard-coded value.

    Returns:
        uint8 array of shape ``(H, W, 3)``.
    """
    h, w = damaged_img.shape[:2]

    # Resize to the model's working resolution and scale to [0, 1]
    damaged_small = cv2.resize(damaged_img, target_size).astype(np.float32) / 255.0
    mask_small = cv2.resize(mask, target_size).astype(np.float32) / 255.0

    # HWC -> 1CHW for the image, HW -> 11HW for the mask, stacked on channels
    damaged_tensor = torch.from_numpy(damaged_small).permute(2, 0, 1).unsqueeze(0)
    mask_tensor = torch.from_numpy(mask_small)[None, None]
    input_tensor = torch.cat([damaged_tensor, mask_tensor], dim=1).to(device)

    with torch.no_grad():
        prediction = model(input_tensor)
    prediction = prediction.squeeze(0).cpu().permute(1, 2, 0).numpy()

    # Round rather than truncate (truncation biases every pixel down by 0.5 on
    # average, e.g. 0.999 -> 254 instead of 255); clip guards against any
    # out-of-range values before the uint8 cast.
    output = np.clip(np.rint(prediction * 255.0), 0, 255).astype(np.uint8)

    # NOTE(review): the whole image is replaced by the low-resolution
    # prediction, including undamaged pixels. Compositing the output into the
    # masked region only would preserve known detail, but first confirm which
    # mask value marks damage.
    return cv2.resize(output, (w, h))

# Process all images: restore each damaged image using its mask and save the
# result under the same stem as a PNG in RESULTS_DIR.
for img_path in tqdm(test_files, desc="Deep Learning Inference"):
    img_path = Path(img_path)  # `.stem` needs a Path even if get_image_files yields str
    damaged_img = load_image(img_path)

    mask_path = MASKS_DIR / f"{img_path.stem}.png"
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None rather than raising for missing/unreadable files;
    # fail with a clear message instead of an opaque cv2.resize error.
    if mask is None:
        raise FileNotFoundError(f"Could not read mask for '{img_path.name}': {mask_path}")

    restored = process_image(model, damaged_img, mask, device)

    save_image(restored, RESULTS_DIR / f"{img_path.stem}.png")

print(f"\n✓ All {len(test_files)} images processed and saved to {RESULTS_DIR}")


## 7. Summary

Deep learning restoration complete! All restored images saved to `methods/Deep/results/`.

- **Model Architecture:** U-Net with skip connections
- **Training:** L1 loss + perceptual loss (VGG16 features)
- **Input:** 4 channels (RGB damaged image + mask)
- **Output:** 3 channels (RGB restored image)

The model learns textures and structures from the training data, which lets it handle a variety of damage types. Expect it to generalize best to damage that resembles what it saw during training.
