In [None]:
# CELL 1: Install Dependencies (Minimal - Pure PyTorch)
# =============================================================================
# Fix numpy version conflict first
!pip install -q numpy==1.24.3 --force-reinstall
!pip install -q pillow tqdm matplotlib

print("‚úÖ Dependencies installed!")
print("‚ö†Ô∏è  RESTART THE KERNEL NOW: Runtime ‚Üí Restart session")
print("   Then SKIP this cell and run Cell 2")

In [None]:
# CELL 2: Imports
# =============================================================================
import os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import gc
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.models as models

# Seed the Python, NumPy and PyTorch RNGs with the same value for reproducibility
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

print(f"‚úÖ Using device: {DEVICE}")
if DEVICE == 'cuda':
    # Report the name and total memory (in units of 1e9 bytes) of GPU 0
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# CELL 3: Configuration
# =============================================================================
# Paths: dataset inputs and the directories this notebook writes to
DATASET_PATH = '/kaggle/input/burn-scar-faces-with-original'
CLEAN_FOLDER = os.path.join(DATASET_PATH, 'original')      # target images (clean faces)
BURNS_FOLDER = os.path.join(DATASET_PATH, 'with_burns')    # input images (burned faces)
OUTPUT_DIR = '/kaggle/working/outputs'                     # sample grids and loss plot
CHECKPOINT_DIR = '/kaggle/working/checkpoints'             # model checkpoints

# Create the output locations up front; exist_ok makes the cell safe to re-run
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Training settings
IMAGE_SIZE = 256          # Good balance of quality and speed
BATCH_SIZE = 16           # Increased for faster training with full dataset
MAX_IMAGES = None         # None = use ALL images (7000)
EPOCHS = 30               # Reduced epochs since we have more data
LR_G = 2e-4               # Generator learning rate
LR_D = 2e-4               # Discriminator learning rate
LAMBDA_L1 = 100           # Weight for L1 loss
LAMBDA_PERCEPTUAL = 10    # Weight for perceptual loss
SAVE_EVERY = 5            # Save checkpoint every N epochs (more frequent)

print(f"‚úÖ Config loaded")
print(f"   Image size: {IMAGE_SIZE}x{IMAGE_SIZE}")
print(f"   Batch size: {BATCH_SIZE}")
# MAX_IMAGES=None means "no cap"; report it as "all" rather than printing the literal None
print(f"   Training on {'all' if MAX_IMAGES is None else MAX_IMAGES} images for {EPOCHS} epochs")

In [None]:
# =============================================================================
# CELL 4: Dataset Class
# =============================================================================
class BurnFaceDataset(Dataset):
    """
    Paired dataset: burned face (input) -> clean face (target).

    Files in each folder are filtered by extension (.jpg/.png/.jpeg,
    case-insensitive), sorted by name and paired by position in the sorted
    lists. The two folders therefore have to sort into the same order for the
    pairs to be correct.

    Args:
        clean_folder: directory holding the target (clean) images.
        burns_folder: directory holding the input (burned) images.
        max_images: optional cap on the number of pairs; a falsy value
            (None or 0) means no cap.
        image_size: both images are resized to image_size x image_size.
        augment: when True, a pair is flipped horizontally whenever
            random.random() > 0.5 (the same flip is applied to both images).
    """
    _EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, clean_folder, burns_folder, max_images=None, image_size=256, augment=True):
        self.image_size = image_size
        self.augment = augment

        # Get sorted file lists
        clean_files = self._list_images(clean_folder)
        burn_files = self._list_images(burns_folder)

        # Pairing is positional, so differing counts mean the lists are not
        # one-to-one; say so instead of truncating silently.
        if len(clean_files) != len(burn_files):
            print(f"WARNING: found {len(clean_files)} clean and {len(burn_files)} burned images; "
                  f"pairing by sorted position and truncating to the shorter list")

        num_pairs = min(len(clean_files), len(burn_files))
        if max_images:
            num_pairs = min(num_pairs, max_images)

        # Each entry is (input path, target path)
        self.pairs = [
            (os.path.join(burns_folder, burn_files[i]),
             os.path.join(clean_folder, clean_files[i]))
            for i in range(num_pairs)
        ]

        # Base transforms: ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1]
        self.to_tensor = transforms.ToTensor()
        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        print(f"‚úÖ Dataset loaded: {len(self.pairs)} image pairs")

    @staticmethod
    def _list_images(folder):
        """Return the sorted names of the image files found directly inside folder."""
        return sorted(
            name for name in os.listdir(folder)
            if name.lower().endswith(BurnFaceDataset._EXTENSIONS)
        )

    def _load_rgb(self, path):
        """Open path, convert it to RGB and resize it to the configured square size."""
        # convert() returns a new in-memory image, so the file can be closed
        # as soon as the with-block ends.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return rgb.resize((self.image_size, self.image_size), Image.BILINEAR)

    def __len__(self):
        """Number of (burned, clean) pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """
        Load pair idx.

        Returns:
            dict with 'burned' and 'clean' tensors normalized to [-1, 1] and
            'burn_path', the path of the input image.
        """
        burn_path, clean_path = self.pairs[idx]

        burn_img = self._load_rgb(burn_path)
        clean_img = self._load_rgb(clean_path)

        # Data augmentation (same transform for both so the pair stays aligned)
        if self.augment and random.random() > 0.5:
            burn_img = burn_img.transpose(Image.FLIP_LEFT_RIGHT)
            clean_img = clean_img.transpose(Image.FLIP_LEFT_RIGHT)

        return {
            'burned': self.normalize(self.to_tensor(burn_img)),
            'clean': self.normalize(self.to_tensor(clean_img)),
            'burn_path': burn_path
        }

# Build the dataset and actually load one pair, so that an empty folder or an
# unreadable image fails here instead of inside the training loop.
dataset = BurnFaceDataset(CLEAN_FOLDER, BURNS_FOLDER, max_images=MAX_IMAGES, image_size=IMAGE_SIZE)
if len(dataset) == 0:
    raise RuntimeError(f"No image pairs found in {CLEAN_FOLDER} and {BURNS_FOLDER}")
_first_pair = dataset[0]
assert _first_pair['burned'].shape == _first_pair['clean'].shape == (3, IMAGE_SIZE, IMAGE_SIZE)
print(f"   Sample pair loaded successfully!")

In [None]:
# CELL 5: U-Net Generator Architecture
# =============================================================================
class ConvBlock(nn.Module):
    """
    Encoder block that halves the spatial resolution.

    Layout: 4x4 stride-2 Conv (no bias) -> optional BatchNorm ->
    LeakyReLU(0.2) -> optional Dropout(0.5).
    """
    def __init__(self, in_ch, out_ch, use_bn=True, use_dropout=False):
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, 4, stride=2, padding=1, bias=False)
        maybe_norm = [nn.BatchNorm2d(out_ch)] if use_bn else []
        activation = nn.LeakyReLU(0.2, inplace=True)
        maybe_drop = [nn.Dropout(0.5)] if use_dropout else []
        self.block = nn.Sequential(conv, *maybe_norm, activation, *maybe_drop)

    def forward(self, x):
        """Apply the block to x."""
        return self.block(x)

class DeconvBlock(nn.Module):
    """
    Decoder block that doubles the spatial resolution.

    Layout: 4x4 stride-2 ConvTranspose (no bias) -> BatchNorm -> ReLU ->
    optional Dropout(0.5).
    """
    def __init__(self, in_ch, out_ch, use_dropout=False):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1, bias=False)
        maybe_drop = [nn.Dropout(0.5)] if use_dropout else []
        self.block = nn.Sequential(
            upsample,
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            *maybe_drop,
        )

    def forward(self, x):
        """Apply the block to x."""
        return self.block(x)

class UNetGenerator(nn.Module):
    """
    U-Net generator for image-to-image translation.

    Eight stride-2 stages (seven encoders plus the bottleneck) each halve the
    resolution, and eight transposed-conv stages double it back. Every decoder
    output is concatenated along the channel axis with the encoder output of
    the same resolution before it is passed on.

    Input: burned face (3 channels)
    Output: predicted clean face (3 channels), passed through Tanh
    """
    def __init__(self, in_channels=3, out_channels=3, features=64):
        super().__init__()
        # Channel widths used throughout the network
        c1, c2, c4, c8 = features, features*2, features*4, features*8

        # Encoder (downsampling); sizes in comments are for a 256x256 input
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, c1, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )                                 # 256 -> 128
        self.enc2 = ConvBlock(c1, c2)     # 128 -> 64
        self.enc3 = ConvBlock(c2, c4)     # 64 -> 32
        self.enc4 = ConvBlock(c4, c8)     # 32 -> 16
        self.enc5 = ConvBlock(c8, c8)     # 16 -> 8
        self.enc6 = ConvBlock(c8, c8)     # 8 -> 4
        self.enc7 = ConvBlock(c8, c8)     # 4 -> 2

        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(c8, c8, 4, stride=2, padding=1),  # 2 -> 1
            nn.ReLU(inplace=True)
        )

        # Decoder (upsampling). From dec2 onward the input width is doubled
        # because the previous output is concatenated with a skip connection.
        self.dec1 = DeconvBlock(c8, c8, use_dropout=True)       # 1 -> 2
        self.dec2 = DeconvBlock(c8*2, c8, use_dropout=True)     # 2 -> 4
        self.dec3 = DeconvBlock(c8*2, c8, use_dropout=True)     # 4 -> 8
        self.dec4 = DeconvBlock(c8*2, c8)                       # 8 -> 16
        self.dec5 = DeconvBlock(c8*2, c4)                       # 16 -> 32
        self.dec6 = DeconvBlock(c4*2, c2)                       # 32 -> 64
        self.dec7 = DeconvBlock(c2*2, c1)                       # 64 -> 128

        # Final layer
        self.final = nn.Sequential(
            nn.ConvTranspose2d(c2, out_channels, 4, stride=2, padding=1),  # 128 -> 256
            nn.Tanh()
        )

    def forward(self, x):
        """Encode x, then decode while re-attaching the encoder activations."""
        encoders = (self.enc1, self.enc2, self.enc3, self.enc4,
                    self.enc5, self.enc6, self.enc7)
        decoders = (self.dec1, self.dec2, self.dec3, self.dec4,
                    self.dec5, self.dec6, self.dec7)

        # Encoder: keep every activation so the decoder can reuse it
        skips = []
        hidden = x
        for encoder in encoders:
            hidden = encoder(hidden)
            skips.append(hidden)

        # Bottleneck
        hidden = self.bottleneck(hidden)

        # Decoder: the deepest skip is consumed first (dec1 <-> enc7, ..., dec7 <-> enc1)
        for decoder in decoders:
            hidden = torch.cat([decoder(hidden), skips.pop()], dim=1)

        return self.final(hidden)

print("‚úÖ U-Net Generator defined")


In [None]:
# =============================================================================
# CELL 6: PatchGAN Discriminator
# =============================================================================
class PatchDiscriminator(nn.Module):
    """
    PatchGAN Discriminator - classifies 70x70 patches as real/fake
    Input: concatenated burned + clean/generated images (6 channels)
    Output: one raw score (logit) per patch; no sigmoid is applied here
    """
    def __init__(self, in_channels=6, features=64):
        super().__init__()

        # Layer 1: No BatchNorm
        layers = [
            nn.Conv2d(in_channels, features, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # Layers 2-4: Conv -> BatchNorm -> LeakyReLU, given as
        # (input width multiplier, output width multiplier, stride)
        for mult_in, mult_out, stride in ((1, 2, 2), (2, 4, 2), (4, 8, 1)):
            layers.extend([
                nn.Conv2d(features*mult_in, features*mult_out, 4,
                          stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(features*mult_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])

        # Output layer: single-channel score map
        layers.append(nn.Conv2d(features*8, 1, 4, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        """Score the pair (x, y): stack both images on the channel axis and classify."""
        return self.model(torch.cat([x, y], dim=1))

print("‚úÖ PatchGAN Discriminator defined")

In [None]:
# =============================================================================
# CELL 7: Perceptual Loss (VGG-based)
# =============================================================================
class VGGPerceptualLoss(nn.Module):
    """
    Perceptual loss using VGG16 features.

    Both images are mapped from [-1, 1] to [0, 1], normalized with the
    ImageNet mean/std stored as buffers, passed through the first 16 layers of
    a pretrained VGG16, and compared with an L1 loss on the resulting features.
    """
    def __init__(self):
        super().__init__()
        backbone = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        first_layers = list(backbone.features[:16])
        self.features = nn.Sequential(*first_layers).eval()

        # The feature extractor is fixed: none of its weights receive gradients
        for frozen in self.features.parameters():
            frozen.requires_grad = False

        # Buffers (not parameters) so they follow .to(device) without being trained
        imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer('mean', imagenet_mean)
        self.register_buffer('std', imagenet_std)

    def forward(self, pred, target):
        """Return the L1 distance between the VGG features of pred and target."""
        def to_vgg_input(img):
            # [-1, 1] -> [0, 1], then ImageNet normalization
            return ((img + 1) / 2 - self.mean) / self.std

        pred_features = self.features(to_vgg_input(pred))
        target_features = self.features(to_vgg_input(target))
        return F.l1_loss(pred_features, target_features)

print("‚úÖ Perceptual Loss defined")

In [None]:
# =============================================================================
# CELL 8: Initialize Models and Training
# =============================================================================
# Initialize models (the VGG loss network is frozen and only used to score images)
generator = UNetGenerator().to(DEVICE)
discriminator = PatchDiscriminator().to(DEVICE)
vgg_loss = VGGPerceptualLoss().to(DEVICE)

# Count parameters
g_params = sum(p.numel() for p in generator.parameters())
d_params = sum(p.numel() for p in discriminator.parameters())
print(f"‚úÖ Generator parameters: {g_params:,}")
print(f"‚úÖ Discriminator parameters: {d_params:,}")

# Optimizers: one Adam per network, both with beta1 = 0.5
optimizer_G = torch.optim.Adam(generator.parameters(), lr=LR_G, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=LR_D, betas=(0.5, 0.999))

# Loss functions: the discriminator emits raw logits, hence BCE-with-logits
criterion_GAN = nn.BCEWithLogitsLoss()
criterion_L1 = nn.L1Loss()

# DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    # Pinned host memory only speeds up host-to-GPU copies; on a CPU-only
    # session it brings no benefit and makes the DataLoader emit a warning.
    pin_memory=(DEVICE == 'cuda')
)

print(f"‚úÖ Training setup complete!")
print(f"   Steps per epoch: {len(dataloader)}")

In [None]:
# =============================================================================
# CELL 9: Training Loop
# =============================================================================
def train():
    """
    Run the adversarial training loop and return the per-epoch loss history.

    For every batch the discriminator is updated first (real pair vs. detached
    generated pair, the two losses averaged), then the generator is updated
    with: adversarial loss + LAMBDA_L1 * L1 + LAMBDA_PERCEPTUAL * perceptual.
    Every SAVE_EVERY epochs, and after the last epoch, a checkpoint and a grid
    of sample predictions are written. When all epochs are done the generator
    weights are saved as 'generator_final.pt' in CHECKPOINT_DIR.

    Returns:
        dict with keys 'g_loss' and 'd_loss'; each is a list holding the mean
        loss of every epoch.
    """
    print("\n" + "="*60)
    print("üöÄ STARTING TRAINING")
    print("="*60 + "\n")

    history = {'g_loss': [], 'd_loss': []}

    for epoch in range(EPOCHS):
        generator.train()
        discriminator.train()

        epoch_g_loss = 0
        epoch_d_loss = 0

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")

        for batch in pbar:
            burned = batch['burned'].to(DEVICE)
            clean = batch['clean'].to(DEVICE)

            # ==================
            # Train Discriminator
            # ==================
            optimizer_D.zero_grad()

            # Real loss
            pred_real = discriminator(burned, clean)

            # Labels are shaped after the discriminator output, so the patch
            # grid follows the input size instead of being fixed at 30x30
            # (which is only correct for 256x256 images).
            real_label = torch.ones_like(pred_real)
            fake_label = torch.zeros_like(pred_real)

            loss_D_real = criterion_GAN(pred_real, real_label)

            # Fake loss; detach so this step does not backpropagate into the generator
            fake_clean = generator(burned)
            pred_fake = discriminator(burned, fake_clean.detach())
            loss_D_fake = criterion_GAN(pred_fake, fake_label)

            # Combined D loss
            loss_D = (loss_D_real + loss_D_fake) * 0.5
            loss_D.backward()
            optimizer_D.step()

            # ==================
            # Train Generator
            # ==================
            optimizer_G.zero_grad()

            # GAN loss: the generator wants its output to be scored as real
            pred_fake = discriminator(burned, fake_clean)
            loss_G_GAN = criterion_GAN(pred_fake, real_label)

            # L1 loss
            loss_G_L1 = criterion_L1(fake_clean, clean)

            # Perceptual loss
            loss_G_perceptual = vgg_loss(fake_clean, clean)

            # Combined G loss
            loss_G = loss_G_GAN + LAMBDA_L1 * loss_G_L1 + LAMBDA_PERCEPTUAL * loss_G_perceptual
            loss_G.backward()
            optimizer_G.step()

            # Logging
            epoch_g_loss += loss_G.item()
            epoch_d_loss += loss_D.item()

            pbar.set_postfix({
                'G': f'{loss_G.item():.3f}',
                'D': f'{loss_D.item():.3f}',
                'L1': f'{loss_G_L1.item():.3f}'
            })

        # Epoch summary
        avg_g_loss = epoch_g_loss / len(dataloader)
        avg_d_loss = epoch_d_loss / len(dataloader)
        history['g_loss'].append(avg_g_loss)
        history['d_loss'].append(avg_d_loss)

        print(f"üìä Epoch {epoch+1}: G_loss={avg_g_loss:.4f}, D_loss={avg_d_loss:.4f}")

        # Save checkpoint and samples
        if (epoch + 1) % SAVE_EVERY == 0 or epoch == EPOCHS - 1:
            # Save model. Note: the stored 'epoch' is the 0-based loop index,
            # while the file name uses the 1-based epoch number.
            torch.save({
                'epoch': epoch,
                'generator': generator.state_dict(),
                'discriminator': discriminator.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D': optimizer_D.state_dict(),
            }, os.path.join(CHECKPOINT_DIR, f'checkpoint_epoch_{epoch+1}.pt'))
            print(f"üíæ Saved checkpoint at epoch {epoch+1}")

            # Save sample images
            save_samples(epoch+1)

        # Clear cache
        torch.cuda.empty_cache()
        gc.collect()

    # Save final model
    torch.save(generator.state_dict(), os.path.join(CHECKPOINT_DIR, 'generator_final.pt'))
    print(f"\nüéâ Training complete! Final model saved.")

    return history

def save_samples(epoch, num_samples=4):
    """
    Save a grid of sample predictions for the given epoch.

    Draws num_samples random items from the dataset and plots one row per
    item: input (burned), generator output, ground truth (clean). The figure
    is written to OUTPUT_DIR as 'samples_epoch_{epoch}.png'.

    Args:
        epoch: number used in the output file name and the log message.
        num_samples: number of rows in the grid.

    The generator is switched to eval mode while predicting and put back into
    train mode afterwards, even if plotting or saving fails.
    """
    def to_display(batch):
        """Turn the first image of a [-1, 1] NCHW batch into an HWC array clipped to [0, 1]."""
        image = (batch[0].cpu().numpy().transpose(1, 2, 0) + 1) / 2
        return np.clip(image, 0, 1)

    generator.eval()

    # squeeze=False keeps `axes` two-dimensional, so axes[i] is a row of three
    # Axes even when num_samples == 1.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, num_samples*4), squeeze=False)

    try:
        with torch.no_grad():
            for i in range(num_samples):
                idx = random.randint(0, len(dataset)-1)
                sample = dataset[idx]

                burned = sample['burned'].unsqueeze(0).to(DEVICE)
                clean = sample['clean'].unsqueeze(0)

                predicted = generator(burned)

                panels = (
                    (burned, 'Input (Burned)'),
                    (predicted, 'Predicted (Post-Surgery)'),
                    (clean, 'Ground Truth (Clean)'),
                )
                for ax, (batch, title) in zip(axes[i], panels):
                    ax.imshow(to_display(batch))
                    ax.set_title(title)
                    ax.axis('off')

        fig.tight_layout()
        fig.savefig(os.path.join(OUTPUT_DIR, f'samples_epoch_{epoch}.png'), dpi=150)
        print(f"üì∏ Saved samples for epoch {epoch}")
    finally:
        # Always release the figure and hand the generator back in train mode
        plt.close(fig)
        generator.train()

print("‚úÖ Training functions defined")
print("\n‚è≥ Run the next cell to start training...")

In [None]:
# =============================================================================
# CELL 10: START TRAINING
# =============================================================================
# Kick off training; `history` holds the mean generator/discriminator loss per epoch
history = train()

# Plot both loss curves on one set of axes, save the figure, then display it
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(history['g_loss'], label='Generator Loss')
ax.plot(history['d_loss'], label='Discriminator Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training History')
ax.legend()
fig.savefig(os.path.join(OUTPUT_DIR, 'training_history.png'))
plt.show()