# In [None]:
# FULL 128x128 DDPM TRAINING PIPELINE
# Features: Mixed Precision, Gradient Accumulation, EMA, and Robust Data Loading.

print("🔥 Script started — loading imports...")

# ---------------- CRITICAL FIX FOR DATASET ----------------
# By default PIL raises an error when decoding truncated (e.g. partially downloaded) images.
# This setting makes PIL load them anyway, leaving the missing region grey/black,
# so a single bad file cannot crash the entire training loop.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import os
import random
import math
import copy
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

# ---------------- CONFIGURATION ----------------
DATASET_PATH = "/kaggle/input/cartoon-classification/cartoon_classification/TRAIN"
IMG_SIZE = 128            # Higher res = more detail; activation memory grows roughly with IMG_SIZE**2.
BASE_CH = 32              # Base network width. 32->64->128->256 channels.
BATCH_SIZE = 2            # Kept tiny (2) to prevent "CUDA Out Of Memory" on limited GPUs.
ACCUM_STEPS = 4           # GRADIENT ACCUMULATION: weights are updated once every 4 micro-batches.
                          # Effective Batch Size = 2 * 4 = 8 images.
EPOCHS = 12
LR = 2e-4                 # Learning rate for the AdamW optimizer
NUM_WORKERS = 4           # CPU worker processes that preload images so the GPU does not wait for data
# Pinned host memory only speeds up host->CUDA copies; without a CUDA device it is
# useless and DataLoader warns about it, so enable it only when CUDA is present.
PIN_MEMORY = torch.cuda.is_available()
TIMESTEPS = 1000          # Number of diffusion steps: noise is added over timesteps 0..TIMESTEPS-1
SUBSET_SIZE = None        # Set to an integer (e.g., 100) to debug on a small dataset quickly
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_DIR = "checkpoints"
CHECKPOINT_FREQ = 1       # Save weights every CHECKPOINT_FREQ epochs
SAMPLES_TO_SAVE = 4       # How many test images to generate at the end
SEED = 42

os.makedirs(CHECKPOINT_DIR, exist_ok=True)
random.seed(SEED)
torch.manual_seed(SEED)

print(f"Config: Device={DEVICE} | Size={IMG_SIZE} | Batch={BATCH_SIZE}x{ACCUM_STEPS}")

# ---------------- DATASET ----------------
class CartoonColorizationDataset(Dataset):
    """Colorization pairs built from every image found under a directory tree.

    Each item is ``(x, y)``. ``y`` is the RGB image and ``x`` is its
    single-channel grayscale version. Both are resized to
    ``img_size`` x ``img_size`` and converted with ``transforms.ToTensor``.

    Unreadable or corrupt files do not raise. They fall back to a flat grey
    image, and if the transforms fail, to all-zero tensors of the configured
    size. This keeps the training loop alive.

    Args:
        root: Directory searched recursively with ``os.walk``.
        img_size: Output side length for both tensors (including fallbacks).
        subset: If truthy, keep only the first ``subset`` paths (after sorting).
    """

    # File extensions (lower-case, including the dot) that are treated as images.
    _EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root, img_size=IMG_SIZE, subset=SUBSET_SIZE):
        # Remembered so the error fallbacks match the requested size, not the global IMG_SIZE.
        self.img_size = img_size

        # Sorting makes the order independent of os.walk's traversal order (reproducible runs).
        self.paths = sorted(
            os.path.join(dirpath, name)
            for dirpath, _, filenames in os.walk(root)
            for name in filenames
            if name.lower().endswith(self._EXTENSIONS)
        )
        if subset:
            self.paths = self.paths[:subset]

        # Resize to a square and convert to a float tensor (ToTensor scales 8-bit images to [0, 1]).
        self.tf = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
        ])
        self.gray = transforms.Grayscale()  # builds the grayscale conditioning input

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        size = self.img_size

        # Corrupt/unreadable file -> flat grey placeholder instead of crashing the loader.
        # The context manager closes the underlying file handle; convert() returns a
        # fully-loaded independent image, so it stays valid after the file is closed.
        try:
            with Image.open(path) as im:
                img = im.convert('RGB')
        except Exception:
            img = Image.new('RGB', (size, size), color=(128, 128, 128))

        try:
            y = self.tf(img)               # GROUND TRUTH (color) the model learns to produce
            x = self.tf(self.gray(img))    # CONDITION (grayscale) given to the model
        except Exception:
            # Transform failure -> zero tensors with the same shapes as a normal sample.
            y = torch.zeros(3, size, size)
            x = torch.zeros(1, size, size)

        return x, y

# ---------------- NOISE SCHEDULE ----------------
# Diffusion works by gradually adding noise. We pre-calculate exactly how much
# noise to add at every timestep t = 0 .. TIMESTEPS-1 (0..999 with the default config).

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Return ``timesteps`` noise variances (betas) spaced evenly on a straight line.

    The first value is ``beta_start`` and the last is ``beta_end`` (both inclusive),
    as a 1-D float tensor of length ``timesteps``.
    """
    betas = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return betas

class NoiseSchedule:
    """Precomputed DDPM noise-schedule tables plus the forward (noising) process.

    Args:
        timesteps: Number of diffusion steps T; valid timestep indices are 0..T-1.
        device: Device on which every schedule tensor is allocated.
    """

    def __init__(self, timesteps=TIMESTEPS, device='cpu'):
        self.timesteps = timesteps
        betas = linear_beta_schedule(timesteps).to(device)
        alphas = 1.0 - betas

        # alpha_bar_t = prod_{s<=t} alpha_s: how much of the original signal is left.
        # alpha_bar_0 = 1 - beta_0 (close to 1) and it decays towards 0 as t -> T-1.
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alpha_bar_{t-1}, with the value "before step 0" defined as exactly 1.0.
        alphas_cumprod_prev = torch.cat(
            [torch.ones(1, device=device, dtype=alphas_cumprod.dtype), alphas_cumprod[:-1]],
            dim=0,
        )

        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        self.alphas_cumprod_prev = alphas_cumprod_prev

        # Square roots used by the forward diffusion formula:
        # x_t = sqrt(alpha_bar_t) * x_0  +  sqrt(1 - alpha_bar_t) * noise
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` up to timestep ``t`` (forward process).

        Args:
            x_start: Clean data whose first dimension is the batch; any number of
                trailing dimensions is supported (e.g. (B, C, H, W) images).
            t: Integer timesteps indexing the schedule, one per batch element.
            noise: Optional Gaussian noise shaped like ``x_start``; drawn with
                ``torch.randn_like`` when omitted.

        Returns:
            ``(x_t, noise)``: the noised data and the noise that was used.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        # Reshape the per-sample coefficients to (B, 1, ..., 1) so they broadcast
        # over all non-batch dimensions, whatever the rank of x_start.
        coeff_shape = (-1,) + (1,) * (x_start.dim() - 1)
        a = self.sqrt_alphas_cumprod[t].view(coeff_shape)
        b = self.sqrt_one_minus_alphas_cumprod[t].view(coeff_shape)

        return a * x_start + b * noise, noise

# ---------------- MODEL (UNET) ----------------
# The UNet predicts the NOISE that was added to the image.

class SinusoidalPositionEmb(nn.Module):
    """Map integer timesteps to fixed sinusoidal feature vectors.

    Lets the UNet tell a lightly noised input (small t) from an almost fully
    noised one (large t). For a 1-D tensor ``t`` of length B the output has
    shape (B, dim). It holds ``dim // 2`` sine features, then ``dim // 2``
    cosine features, and for odd ``dim`` one trailing zero column.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000 (Transformer-style encoding).
        # max(..., 1) avoids a ZeroDivisionError for dim 2 or 3, where half == 1.
        scale = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(torch.arange(half, device=t.device) * -scale)
        args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)
        if self.dim % 2 == 1:
            # Pad odd widths with one zero column (same dtype/device as the embedding).
            emb = torch.cat([emb, emb.new_zeros(t.size(0), 1)], dim=1)
        return emb

class ConvBlock(nn.Module):
    """Residual double-conv block with an optional timestep-embedding injection.

    Layout: Conv2d -> GroupNorm -> SiLU -> Conv2d -> GroupNorm -> SiLU, plus a
    residual connection (1x1-projected when the channel counts differ).

    Args:
        in_ch: Input channel count.
        out_ch: Output channel count.
        time_emb_dim: Width of the timestep embedding passed to ``forward``. A
            learned ``nn.Linear(time_emb_dim, out_ch)`` is registered as
            ``time_proj``, so it is trained, saved in ``state_dict`` and copied
            by deep copies (e.g. an EMA shadow model). Defaults to 128, matching
            ``MediumUNet``'s default ``time_emb_dim``. Pass ``None`` to disable
            the projection; the embedding must then already have ``out_ch``
            channels and is added directly.

    NOTE: state_dicts saved before ``time_proj`` existed lack its keys and will
    not load with ``strict=True``.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            # GroupNorm instead of BatchNorm: with a batch size of 2, batch statistics
            # are too noisy, while GroupNorm does not depend on the batch at all.
            self._norm(out_ch),
            nn.SiLU(),  # SiLU (Swish) activation is the usual choice in diffusion UNets
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            self._norm(out_ch),
            nn.SiLU(),
        )
        # Residual connection: project the input when channel counts differ.
        self.res_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
        # Registered (hence trainable/saved) projection of the time embedding to out_ch.
        self.time_proj = nn.Linear(time_emb_dim, out_ch) if time_emb_dim is not None else None

    @staticmethod
    def _norm(ch):
        """GroupNorm with the largest group count <= 8 that divides ``ch``."""
        # GroupNorm requires ch % num_groups == 0; gcd(8, ch) is 8 for the
        # 32/64/128/256-channel widths used here and stays valid for any other ch.
        return nn.GroupNorm(math.gcd(8, ch), ch)

    def forward(self, x, t_emb=None):
        """Apply the block to ``x`` (N, in_ch, H, W); ``t_emb`` is (N, emb_dim) or None.

        Raises:
            ValueError: if ``t_emb`` has a width the block cannot consume.
        """
        h = self.net(x)
        if t_emb is not None:
            if self.time_proj is not None:
                if t_emb.shape[1] != self.time_proj.in_features:
                    raise ValueError(
                        f"ConvBlock expects a time embedding of width "
                        f"{self.time_proj.in_features}, got {t_emb.shape[1]}"
                    )
                t_emb = self.time_proj(t_emb)
            elif t_emb.shape[1] != h.shape[1]:
                raise ValueError(
                    f"ConvBlock without time projection needs a time embedding of width "
                    f"{h.shape[1]}, got {t_emb.shape[1]}"
                )
            # (N, C) -> (N, C, 1, 1) so the embedding broadcasts over H and W.
            h = h + t_emb[:, :, None, None]
        return h + self.res_conv(x)

class MediumUNet(nn.Module):
    """Three-level conditional UNet that predicts diffusion noise for colorization.

    The noisy RGB image and the grayscale condition are concatenated along the
    channel axis (``in_ch`` = 3 + 1 by default). The result passes through an
    encoder with three stride-2 downsamplings (channels base_ch -> 2x -> 4x -> 8x),
    a two-block bottleneck, and a mirrored decoder with skip connections. Every
    ConvBlock receives the timestep embedding.

    Args:
        in_ch: Channels after concatenating ``y_noisy`` and ``x_gray``.
        base_ch: Width of the first encoder level.
        time_emb_dim: Width of the timestep embedding produced by ``time_mlp``.
    """

    def __init__(self, in_ch=4, base_ch=BASE_CH, time_emb_dim=128):
        super().__init__()
        # Time embedding: fixed sinusoidal features followed by a small MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmb(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim*4),
            nn.SiLU(),
            nn.Linear(time_emb_dim*4, time_emb_dim)
        )

        # Encoder (downsampling): halves H/W and doubles channels at each level.
        self.enc1 = ConvBlock(in_ch, base_ch)
        self.down1 = nn.Conv2d(base_ch, base_ch*2, 4, stride=2, padding=1)
        self.enc2 = ConvBlock(base_ch*2, base_ch*2)
        self.down2 = nn.Conv2d(base_ch*2, base_ch*4, 4, stride=2, padding=1)
        self.enc3 = ConvBlock(base_ch*4, base_ch*4)
        self.down3 = nn.Conv2d(base_ch*4, base_ch*8, 4, stride=2, padding=1)

        # Bottleneck: the deepest, lowest-resolution part of the network.
        self.bot1 = ConvBlock(base_ch*8, base_ch*8)
        self.bot2 = ConvBlock(base_ch*8, base_ch*8)

        # Decoder (upsampling) with learned transpose convolutions; each dec block
        # consumes the upsampled features concatenated with the matching skip.
        self.up3 = nn.ConvTranspose2d(base_ch*8, base_ch*4, 4, stride=2, padding=1)
        self.dec3 = ConvBlock(base_ch*8, base_ch*4)
        self.up2 = nn.ConvTranspose2d(base_ch*4, base_ch*2, 4, stride=2, padding=1)
        self.dec2 = ConvBlock(base_ch*4, base_ch*2)
        self.up1 = nn.ConvTranspose2d(base_ch*2, base_ch, 4, stride=2, padding=1)
        self.dec1 = ConvBlock(base_ch*2, base_ch)

        # Final projection to 3 channels: the predicted noise for the RGB image.
        self.out = nn.Conv2d(base_ch, 3, 3, padding=1)

    def forward(self, x_gray, y_noisy, t):
        """Predict the noise contained in ``y_noisy`` given ``x_gray`` and timestep ``t``.

        Args:
            x_gray: Grayscale condition, (B, 1, H, W) in this pipeline.
            y_noisy: Noisy RGB image, (B, 3, H, W).
            t: Integer timesteps, one per batch element.

        Returns:
            Predicted noise of shape (B, 3, H, W).

        Raises:
            ValueError: if H or W is not divisible by 8; three stride-2
                downsamplings must be exactly undone by the decoder.
        """
        height, width = y_noisy.shape[-2:]
        if height % 8 or width % 8:
            raise ValueError(
                f"MediumUNet needs H and W divisible by 8 (three stride-2 "
                f"downsamplings); got {height}x{width}"
            )

        t_emb = self.time_mlp(t)

        # CONDITIONING: noisy RGB (3 ch) + grayscale hint (1 ch) -> in_ch channels.
        x = torch.cat([y_noisy, x_gray], dim=1)

        # Encoder; e1..e3 are kept for the skip connections.
        e1 = self.enc1(x, t_emb)
        e2 = self.enc2(self.down1(e1), t_emb)
        e3 = self.enc3(self.down2(e2), t_emb)

        # Bottleneck.
        b = self.bot1(self.down3(e3), t_emb)
        b = self.bot2(b, t_emb)

        # Decoder: upsample, concatenate the matching skip, refine.
        u3 = self.dec3(torch.cat([self.up3(b), e3], dim=1), t_emb)
        u2 = self.dec2(torch.cat([self.up2(u3), e2], dim=1), t_emb)
        u1 = self.dec1(torch.cat([self.up1(u2), e1], dim=1), t_emb)

        return self.out(u1)

# ---------------- EMA & SAMPLER ----------------
class EMA:
    """Exponential moving average ("shadow" copy) of a model's weights.

    ``self.ema`` is a deep copy of the model in eval mode with gradients
    disabled. Each ``update`` moves its floating-point state towards the live
    model's state. Non-floating entries (e.g. integer buffers) are copied
    verbatim.

    Args:
        model: The model to track.
        decay: Upper bound on the per-update decay factor.
    """

    def __init__(self, model, decay=0.9999):
        self.ema = copy.deepcopy(model).eval()
        # The shadow copy is only ever updated manually, never by an optimizer:
        # freeze it so autograd never tracks it, and drop any gradients the deep
        # copy may have inherited from the live model.
        self.ema.requires_grad_(False)
        for p in self.ema.parameters():
            p.grad = None
        self.decay = decay
        self.num_updates = 0

    @torch.no_grad()
    def update(self, model):
        """Blend ``model``'s current state into the shadow copy."""
        self.num_updates += 1
        # Warm-up: the effective decay starts small (fast tracking) and grows towards self.decay.
        decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
        model_state = model.state_dict()
        # state_dict() tensors share storage with the shadow module, so in-place
        # ops here update self.ema directly.
        for name, shadow in self.ema.state_dict().items():
            current = model_state[name]
            if shadow.dtype.is_floating_point:
                # shadow = decay * shadow + (1 - decay) * current
                shadow.mul_(decay).add_(current, alpha=1 - decay)
            else:
                shadow.copy_(current)

@torch.no_grad()
def ddim_sample(model, x_gray, schedule, steps=50, eta=0.0, use_ema=False, ema_obj=None):
    """Colorize ``x_gray`` with the DDIM reverse process.

    Visits ``steps`` timesteps spaced evenly from T-1 down to 0. At each one it
    jumps directly to the noise level of the *next visited* timestep (not t-1),
    which keeps strided DDIM sampling consistent. The final step lands on the
    clean-image estimate (alpha_bar = 1).

    Args:
        model: Noise-prediction network, called as ``model(x_gray, y, t)``.
        x_gray: Grayscale condition. Its batch size and last two (spatial)
            dimensions determine the shape of the generated RGB batch.
        schedule: Object exposing ``timesteps`` and ``alphas_cumprod``.
        steps: Number of timesteps to visit.
        eta: DDIM stochasticity. 0.0 is fully deterministic; larger values add
            fresh Gaussian noise at every step (1.0 corresponds to the DDPM
            posterior variance in the DDIM formulation).
        use_ema: If True and ``ema_obj`` is given, sample with ``ema_obj.ema``
            instead of ``model``. The online model's weights are never modified.
        ema_obj: Object holding the EMA network in its ``ema`` attribute.

    Returns:
        RGB samples of shape (B, 3, H, W), clamped to [0, 1].

    The network used is switched to eval mode for sampling and its previous
    train/eval mode is restored afterwards, even if sampling raises.
    """
    device = x_gray.device
    batch = x_gray.size(0)
    height, width = x_gray.shape[-2:]

    # Timesteps to visit in descending order (e.g. 999, 978, ..., 0). Repeats,
    # which only occur when steps > T, are dropped since they would be no-op steps.
    timesteps = torch.unique_consecutive(
        torch.linspace(schedule.timesteps - 1, 0, steps).long()
    ).tolist()

    # Sample directly with the EMA network instead of swapping weights into `model`.
    net = ema_obj.ema if (use_ema and ema_obj is not None) else model
    was_training = net.training
    net.eval()
    try:
        # Start from pure Gaussian noise.
        y = torch.randn(batch, 3, height, width, device=device)

        for i, t in enumerate(timesteps):
            t_batch = torch.full((batch,), t, dtype=torch.long, device=device)
            eps_pred = net(x_gray, y, t_batch)

            alpha_t = schedule.alphas_cumprod[t].to(device)
            if i + 1 < len(timesteps):
                alpha_prev = schedule.alphas_cumprod[timesteps[i + 1]].to(device)
            else:
                alpha_prev = torch.ones_like(alpha_t)  # last step: jump to the clean image

            # 1. Estimate the clean image x0 from the current sample.
            x0_pred = (y - torch.sqrt(1 - alpha_t) * eps_pred) / (torch.sqrt(alpha_t) + 1e-8)

            # 2. Optional stochastic component (zero when eta == 0).
            sigma = eta * torch.sqrt((1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev))

            # 3. Direction pointing back to x_t; clamp guards against tiny negative round-off.
            dir_xt = torch.sqrt((1 - alpha_prev - sigma ** 2).clamp(min=0)) * eps_pred

            # 4. Sample at the next visited timestep's noise level.
            y = torch.sqrt(alpha_prev) * x0_pred + dir_xt
            if eta > 0:
                y = y + sigma * torch.randn_like(y)
    finally:
        net.train(was_training)

    return y.clamp(0, 1)  # data tensors live in [0, 1]

# ---------------- TRAINING LOOP ----------------
def _optimizer_step(model, opt, scaler, ema_obj):
    """Apply the accumulated (scaled) gradients, clear them and refresh the EMA copy."""
    scaler.step(opt)
    scaler.update()
    opt.zero_grad()
    if ema_obj is not None:
        ema_obj.update(model)


def train(model, schedule, loader, opt, scaler, ema_obj=None):
    """Train ``model`` to predict the noise added by ``schedule.q_sample``.

    Features:
    - Autocast mixed precision when ``DEVICE == 'cuda'``.
    - Gradient accumulation over ``ACCUM_STEPS`` micro-batches.
    - An EMA update after every optimizer step when ``ema_obj`` is given.
    - Skipping (instead of crashing on) batches that run out of GPU memory.

    The progress bar shows the average (un-scaled) loss of the current epoch.
    Weights, and EMA weights if present, are saved to ``CHECKPOINT_DIR`` every
    ``CHECKPOINT_FREQ`` epochs.

    Args:
        model: Noise-prediction network, called as ``model(x_gray, y_noisy, t)``.
        schedule: ``NoiseSchedule``-like object with ``timesteps`` and ``q_sample``.
        loader: Iterable of ``(x_gray, y)`` batches.
        opt: Optimizer over ``model``'s parameters.
        scaler: AMP gradient scaler (may be disabled).
        ema_obj: Optional EMA tracker.
    """
    model.train()
    use_amp = DEVICE == 'cuda'

    for epoch in range(EPOCHS):
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        running_loss = 0.0
        n_batches = 0   # micro-batches completed this epoch (denominator of the shown loss)
        n_accum = 0     # micro-batches whose gradients are waiting for an optimizer step
        opt.zero_grad()

        for x_gray, y in pbar:
            try:
                # non_blocking copies can overlap with compute when memory is pinned.
                x_gray = x_gray.to(DEVICE, non_blocking=PIN_MEMORY)
                y = y.to(DEVICE, non_blocking=PIN_MEMORY)

                B = y.size(0)
                # One random timestep per image, in [0, T-1].
                t = torch.randint(0, schedule.timesteps, (B,), device=DEVICE)
                # Forward diffusion: noise the clean targets.
                y_noisy, noise = schedule.q_sample(y, t)

                # Mixed precision: run eligible ops in float16 to save memory/time.
                with torch.autocast(device_type='cuda', enabled=use_amp):
                    pred = model(x_gray, y_noisy, t)
                    # MSE between predicted and true noise, pre-divided for accumulation.
                    loss = ((pred - noise) ** 2).mean() / ACCUM_STEPS

                # The scaler keeps small float16 gradients from underflowing.
                scaler.scale(loss).backward()
                n_accum += 1

                # GRADIENT ACCUMULATION: step once every ACCUM_STEPS micro-batches.
                if n_accum == ACCUM_STEPS:
                    _optimizer_step(model, opt, scaler, ema_obj)
                    n_accum = 0

                running_loss += loss.item() * ACCUM_STEPS
                n_batches += 1
                pbar.set_postfix({'loss': running_loss / n_batches})

            except RuntimeError as e:
                # OOM protection: free cached memory, drop the partial accumulation
                # window and skip this batch instead of crashing.
                if 'out of memory' in str(e):
                    print('OOM: Clearing cache...')
                    torch.cuda.empty_cache()
                    opt.zero_grad()
                    n_accum = 0
                    continue
                else:
                    raise

        # Flush gradients left over when the epoch length is not a multiple of
        # ACCUM_STEPS; otherwise zero_grad() at the next epoch would discard them.
        if n_accum:
            _optimizer_step(model, opt, scaler, ema_obj)

        # Save checkpoints
        if (epoch + 1) % CHECKPOINT_FREQ == 0:
            torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f'model_epoch{epoch+1}.pth'))
            if ema_obj is not None:
                torch.save(ema_obj.ema.state_dict(), os.path.join(CHECKPOINT_DIR, f'ema_epoch{epoch+1}.pth'))

    print('Training complete.')

# ---------------- MAIN ----------------
def main():
    """Build data and model, train, plot a few EMA samples and save the final weights.

    Raises:
        ValueError: if no images are found under ``DATASET_PATH``.
    """
    dataset = CartoonColorizationDataset(DATASET_PATH, img_size=IMG_SIZE, subset=SUBSET_SIZE)
    if len(dataset) == 0:
        # Fail early with a clear message; DataLoader(shuffle=True) would otherwise
        # raise a much less descriptive sampler error.
        raise ValueError(f'No images found under DATASET_PATH={DATASET_PATH!r}')
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    print(f'Total images: {len(dataset)}')

    schedule = NoiseSchedule(timesteps=TIMESTEPS, device=DEVICE)

    model = MediumUNet(in_ch=4, base_ch=BASE_CH, time_emb_dim=128).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=LR)

    # GradScaler is required when using Mixed Precision training to prevent gradient underflow
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=='cuda'))
    ema_obj = EMA(model, decay=0.9999)

    print('Starting training...')
    train(model, schedule, loader, opt, scaler, ema_obj=ema_obj)

    # Visualization: never request more samples than the dataset holds.
    n_samples = min(SAMPLES_TO_SAVE, len(dataset))
    if n_samples > 0:
        print('Generating samples...')
        plt.figure(figsize=(8, 2 * n_samples))

        for i in range(n_samples):
            x_gray, _ = dataset[i]
            # Add a batch dimension: [1, H, W] -> [1, 1, H, W].
            x_gray_b = x_gray.unsqueeze(0).to(DEVICE)

            # Inference with the EMA weights.
            y_sample = ddim_sample(model, x_gray_b, schedule, steps=50, eta=0.0, use_ema=True, ema_obj=ema_obj)[0].cpu()

            # Left column: grayscale input. [C, H, W] -> [H, W, C] for Matplotlib.
            plt.subplot(n_samples, 2, i*2+1)
            plt.imshow(x_gray.permute(1, 2, 0).squeeze(), cmap='gray')
            plt.axis('off')

            # Right column: colorized output.
            plt.subplot(n_samples, 2, i*2+2)
            plt.imshow(y_sample.permute(1, 2, 0))
            plt.axis('off')

        plt.tight_layout()
        plt.show()

    # Final Save
    torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'model_final.pth'))
    torch.save(ema_obj.ema.state_dict(), os.path.join(CHECKPOINT_DIR, 'ema_final.pth'))
    print('Saved final models.')

# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()