# In [None]:
# FULL 128x128 DDPM TRAINING PIPELINE
# Features: Mixed Precision, Gradient Accumulation, EMA, and Robust Data Loading.

# Startup banner, printed before the heavier imports below are loaded.
print("🔥 Script started — loading imports...")

# Let PIL decode images whose file data is truncated instead of raising,
# so a partially corrupt file does not abort loading.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import os
import random
import math
import copy
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

# ---------------- CONFIGURATION ----------------
# Root directory that is walked recursively for .jpg/.jpeg/.png training images
DATASET_PATH = "/kaggle/input/cartoon-classification/cartoon_classification/TRAIN"
IMG_SIZE = 128            # Images are resized to IMG_SIZE x IMG_SIZE
BASE_CH = 32              # UNet base channels
BATCH_SIZE = 2            # Low batch size per GPU step to save VRAM
ACCUM_STEPS = 4           # effective_batch = BATCH_SIZE * ACCUM_STEPS
EPOCHS = 12               # Number of passes over the DataLoader
LR = 2e-4                 # Learning rate passed to AdamW
NUM_WORKERS = 4           # DataLoader worker processes
PIN_MEMORY = True         # DataLoader pin_memory flag; also used as non_blocking for .to()
TIMESTEPS = 1000          # Diffusion steps
SUBSET_SIZE = None        # Set int to limit dataset size for testing
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_DIR = "checkpoints"   # Directory for per-epoch and final state dicts
CHECKPOINT_FREQ = 1       # Save a checkpoint every CHECKPOINT_FREQ epochs
SAMPLES_TO_SAVE = 4       # Number of dataset items sampled and plotted after training
SEED = 42                 # Seed for Python's `random` and for PyTorch

# Module-level side effects: create the checkpoint directory and seed the RNGs.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
random.seed(SEED)
torch.manual_seed(SEED)

print(f"Config: Device={DEVICE} | Size={IMG_SIZE} | Batch={BATCH_SIZE}x{ACCUM_STEPS}")

# ---------------- DATASET ----------------
class CartoonColorizationDataset(Dataset):
    """Paired (grayscale, color) image dataset for conditional colorization.

    Every ``.jpg``/``.jpeg``/``.png`` file found recursively under ``root`` is
    one sample. Each item is ``(x, y)`` where ``y`` is the RGB image resized to
    ``img_size`` x ``img_size`` as a ``(3, H, W)`` tensor and ``x`` is its
    grayscale version as a ``(1, H, W)`` tensor.

    Args:
        root: Directory walked recursively for image files.
        img_size: Side length every image is resized to.
        subset: If truthy, keep only the first ``subset`` paths (sorted order).
    """

    # Extensions include the dot so names like "notes_png" are not picked up.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, root, img_size=IMG_SIZE, subset=SUBSET_SIZE):
        # Remember the requested size so the fallbacks in __getitem__ produce
        # tensors of the same shape as successfully loaded samples.
        self.img_size = img_size

        self.paths = []
        for dirpath, _dirnames, filenames in os.walk(root):
            for filename in filenames:
                if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.paths.append(os.path.join(dirpath, filename))

        # Sort for a deterministic sample order regardless of os.walk order.
        self.paths.sort()
        if subset:
            self.paths = self.paths[:subset]

        self.tf = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor()
        ])
        self.gray = transforms.Grayscale()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(x_gray, y_color)`` for the image at index ``idx``.

        Loading is deliberately best-effort: an unreadable file yields a flat
        gray placeholder, and a failing transform yields all-zero tensors, so
        one bad file cannot abort a training run.
        """
        path = self.paths[idx]

        # Robust loading for corrupt images. The context manager closes the
        # underlying file as soon as the pixels are converted, instead of
        # leaving the handle open until garbage collection.
        try:
            with Image.open(path) as handle:
                img = handle.convert('RGB')
        except Exception:
            # Gray placeholder at the dataset's own size (not the global one).
            size = (self.img_size, self.img_size)
            img = Image.new('RGB', size, color=(128, 128, 128))

        try:
            y = self.tf(img)               # Ground truth (Color)
            x = self.tf(self.gray(img))    # Condition (Grayscale)
        except Exception:
            y = torch.zeros(3, self.img_size, self.img_size)
            x = torch.zeros(1, self.img_size, self.img_size)

        return x, y

# ---------------- NOISE SCHEDULE ----------------
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Return ``timesteps`` betas evenly spaced from ``beta_start`` to ``beta_end``.

    Both endpoints are included; the result is a 1-D tensor of length
    ``timesteps``.
    """
    betas = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return betas

class NoiseSchedule:
    """Precomputed diffusion coefficients derived from a linear beta schedule.

    All tables are 1-D tensors of length ``timesteps`` stored on ``device``:
    ``betas``, ``alphas`` (1 - beta), ``alphas_cumprod`` (running product of
    alphas), ``alphas_cumprod_prev`` (the same product shifted right by one
    with a leading 1.0), and the two square-root tables used by ``q_sample``.
    """

    def __init__(self, timesteps=TIMESTEPS, device='cpu'):
        self.timesteps = timesteps

        self.betas = linear_beta_schedule(timesteps).to(device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Shifted copy: entry t holds the cumulative product up to t-1,
        # with 1.0 standing in for "before the first step".
        leading_one = torch.tensor([1.0], device=device)
        self.alphas_cumprod_prev = torch.cat(
            [leading_one, self.alphas_cumprod[:-1]], dim=0
        )

        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)

    def q_sample(self, x_start, t, noise=None):
        """Noise ``x_start`` to timestep(s) ``t``.

        Args:
            x_start: Clean batch; the coefficients are reshaped to
                ``(-1, 1, 1, 1)``, so a 4-D batch is expected.
            t: Integer timestep indices, one per batch element.
            noise: Optional noise tensor; standard normal noise shaped like
                ``x_start`` is drawn when omitted.

        Returns:
            ``(noisy, noise)``: the noised batch and the noise that was mixed in.
        """
        if noise is None:
            noise = torch.randn_like(x_start)

        signal_scale = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        noise_scale = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        noisy = signal_scale * x_start + noise_scale * noise
        return noisy, noise

# ---------------- MODEL (UNET) ----------------
class SinusoidalPositionEmb(nn.Module):
    """Sinusoidal embedding of integer timesteps.

    Maps a 1-D tensor of ``B`` timesteps to a ``(B, dim)`` tensor laid out as
    ``[sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})]`` with ``h = dim // 2`` and
    geometrically spaced frequencies from 1 down to 1/10000. When ``dim`` is
    odd, a single zero column is appended to reach exactly ``dim`` features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        # With dim of 2 or 3, half - 1 is 0 and the original division raised
        # ZeroDivisionError. Clamping the divisor to 1 leaves every dim >= 4
        # unchanged and yields the single frequency 1.0 for the small cases.
        log_step = math.log(10000) / max(half - 1, 1)
        freqs = torch.exp(torch.arange(half, device=t.device) * -log_step)
        angles = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        if self.dim % 2 == 1:
            # Pad odd dims with one zero column.
            emb = torch.cat([emb, torch.zeros(t.size(0), 1, device=t.device)], dim=1)
        return emb

class ConvBlock(nn.Module):
    """Two conv -> norm -> SiLU layers with a residual path and time conditioning.

    Args:
        in_ch: Input channels.
        out_ch: Output channels. GroupNorm with 8 groups is used when
            ``out_ch >= 8``, BatchNorm2d otherwise.
        time_emb_dim: Width of the timestep embedding passed to ``forward``.
            When it differs from ``out_ch`` a learned ``nn.Linear`` projection
            (``time_proj``) maps the embedding to ``out_ch`` features. When it
            is ``None`` or equal to ``out_ch`` the embedding is added as-is.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim=None):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch) if out_ch>=8 else nn.BatchNorm2d(out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch) if out_ch>=8 else nn.BatchNorm2d(out_ch),
            nn.SiLU()
        )
        self.res_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch!=out_ch else nn.Identity()

        # The projection must be a registered submodule created once here.
        # Building a fresh nn.Linear inside forward() gave random, untrained
        # weights on every call, so the timestep signal was effectively noise
        # and never reached the optimizer, the EMA copy or the checkpoints.
        if time_emb_dim is None or time_emb_dim == out_ch:
            self.time_proj = nn.Identity()
        else:
            self.time_proj = nn.Linear(time_emb_dim, out_ch)

    def forward(self, x, t_emb=None):
        """Apply the block; ``t_emb`` (B, time_emb_dim) is optional."""
        h = self.net(x)
        if t_emb is not None:
            te = self.time_proj(t_emb)
            if te.shape[1] != h.shape[1]:
                raise ValueError(
                    f"time embedding has {te.shape[1]} features but the block "
                    f"outputs {h.shape[1]} channels; construct ConvBlock with "
                    f"time_emb_dim={t_emb.shape[1]}"
                )
            # Broadcast the per-channel shift over the spatial dimensions.
            h = h + te.unsqueeze(-1).unsqueeze(-1)
        return h + self.res_conv(x)

class MediumUNet(nn.Module):
    """Three-level UNet that predicts noise for conditional diffusion.

    The network input is the noisy color image concatenated with the grayscale
    condition along the channel axis (``in_ch`` channels in total); the output
    has 3 channels. Each of the three stride-2 stages halves the resolution and
    is mirrored by a stride-2 transposed convolution, so the skip connections
    line up when height and width are divisible by 8.

    Args:
        in_ch: Channels of the concatenated input (noisy image + condition).
        base_ch: Channel width of the first stage; deeper stages use 2x/4x/8x.
        time_emb_dim: Width of the timestep embedding fed to every ConvBlock.

    Note: every ConvBlock whose width differs from ``time_emb_dim`` owns
    ``time_proj`` parameters, which therefore appear in the state dict.
    """

    def __init__(self, in_ch=4, base_ch=BASE_CH, time_emb_dim=128):
        super().__init__()
        # Time Embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmb(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim*4),
            nn.SiLU(),
            nn.Linear(time_emb_dim*4, time_emb_dim)
        )

        # Encoder
        self.enc1 = ConvBlock(in_ch, base_ch, time_emb_dim)
        self.down1 = nn.Conv2d(base_ch, base_ch*2, 4, stride=2, padding=1)
        self.enc2 = ConvBlock(base_ch*2, base_ch*2, time_emb_dim)
        self.down2 = nn.Conv2d(base_ch*2, base_ch*4, 4, stride=2, padding=1)
        self.enc3 = ConvBlock(base_ch*4, base_ch*4, time_emb_dim)
        self.down3 = nn.Conv2d(base_ch*4, base_ch*8, 4, stride=2, padding=1)

        # Bottleneck
        self.bot1 = ConvBlock(base_ch*8, base_ch*8, time_emb_dim)
        self.bot2 = ConvBlock(base_ch*8, base_ch*8, time_emb_dim)

        # Decoder (each dec block consumes upsampled features + matching skip)
        self.up3 = nn.ConvTranspose2d(base_ch*8, base_ch*4, 4, stride=2, padding=1)
        self.dec3 = ConvBlock(base_ch*8, base_ch*4, time_emb_dim)
        self.up2 = nn.ConvTranspose2d(base_ch*4, base_ch*2, 4, stride=2, padding=1)
        self.dec2 = ConvBlock(base_ch*4, base_ch*2, time_emb_dim)
        self.up1 = nn.ConvTranspose2d(base_ch*2, base_ch, 4, stride=2, padding=1)
        self.dec1 = ConvBlock(base_ch*2, base_ch, time_emb_dim)

        self.out = nn.Conv2d(base_ch, 3, 3, padding=1)

    def forward(self, x_gray, y_noisy, t):
        """Return the 3-channel prediction for ``y_noisy`` at timesteps ``t``.

        Args:
            x_gray: Grayscale condition, concatenated after ``y_noisy``.
            y_noisy: Noisy color image.
            t: 1-D tensor of timesteps, one per batch element.
        """
        t_emb = self.time_mlp(t)
        # Condition: Concatenate noisy image + grayscale input
        x = torch.cat([y_noisy, x_gray], dim=1)

        # Down: e1/e2/e3 are kept as skip connections
        e1 = self.enc1(x, t_emb)
        e2 = self.enc2(self.down1(e1), t_emb)
        e3 = self.enc3(self.down2(e2), t_emb)

        # Bottleneck
        mid = self.bot1(self.down3(e3), t_emb)
        mid = self.bot2(mid, t_emb)

        # Up
        u3 = self.dec3(torch.cat([self.up3(mid), e3], dim=1), t_emb)
        u2 = self.dec2(torch.cat([self.up2(u3), e2], dim=1), t_emb)
        u1 = self.dec1(torch.cat([self.up1(u2), e1], dim=1), t_emb)

        return self.out(u1)

# ---------------- EMA & SAMPLER ----------------
class EMA:
    """Exponential moving average of a model's state dict.

    ``self.ema`` is a deep copy of the model switched to eval mode. Each
    ``update`` blends the live model's floating-point entries into that copy
    and copies non-floating-point entries verbatim.
    """

    def __init__(self, model, decay=0.9999):
        self.num_updates = 0
        self.decay = decay
        shadow_model = copy.deepcopy(model)
        self.ema = shadow_model.eval()

    def update(self, model):
        """Blend the current weights of ``model`` into the averaged copy."""
        self.num_updates += 1
        # Warm-up: the effective decay grows with the update count until it
        # reaches the configured ceiling.
        warmup_decay = (1 + self.num_updates) / (10 + self.num_updates)
        decay = min(self.decay, warmup_decay)

        live_state = model.state_dict()
        for name, averaged in self.ema.state_dict().items():
            if not averaged.dtype.is_floating_point:
                # Integer entries (e.g. counters) cannot be averaged; mirror them.
                averaged.copy_(live_state[name])
                continue
            averaged *= decay
            averaged += (1 - decay) * live_state[name].detach()

@torch.no_grad()
def ddim_sample(model, x_gray, schedule, steps=50, eta=0.0, use_ema=False, ema_obj=None):
    """Sample a color image with DDIM, conditioned on ``x_gray``.

    Args:
        model: Noise-prediction network called as ``model(x_gray, y, t)``.
        x_gray: Condition batch ``(B, C, H, W)``; the sample takes its batch
            size, spatial size and device from it.
        schedule: Object exposing ``timesteps`` and ``alphas_cumprod``.
        steps: Number of timesteps visited, evenly spaced from ``T-1`` to 0.
        eta: Stochasticity; 0.0 gives deterministic DDIM.
        use_ema: Sample with ``ema_obj.ema`` instead of ``model``.
        ema_obj: EMA holder; only used when ``use_ema`` is true.

    Returns:
        Tensor ``(B, 3, H, W)`` clamped to ``[0, 1]``.
    """
    device = x_gray.device
    batch = x_gray.size(0)
    # Follow the condition's resolution instead of a global image size.
    height, width = x_gray.shape[-2:]

    # Run the averaged network directly rather than copying its weights into
    # `model` and back; the caller's model is never left half-swapped.
    net = ema_obj.ema if (use_ema and ema_obj is not None) else model
    was_training = net.training
    net.eval()

    alphas_cumprod = schedule.alphas_cumprod.to(device)
    times = torch.linspace(schedule.timesteps - 1, 0, steps).long().tolist()
    final_alpha = torch.ones((), device=device, dtype=alphas_cumprod.dtype)

    y = torch.randn(batch, 3, height, width, device=device)

    try:
        for i, t in enumerate(times):
            is_last = i + 1 >= len(times)
            t_batch = torch.full((batch,), t, dtype=torch.long, device=device)
            eps_pred = net(x_gray, y, t_batch)

            alpha_t = alphas_cumprod[t]
            # The target noise level is that of the NEXT visited timestep, not
            # of t-1: with `steps` < T the sampler skips timesteps, and using
            # alpha at t-1 leaves the sample far noisier than the level the
            # network is told it has at the next step.
            alpha_prev = final_alpha if is_last else alphas_cumprod[times[i + 1]]

            x0_pred = (y - torch.sqrt(1 - alpha_t) * eps_pred) / (torch.sqrt(alpha_t) + 1e-8)
            # Clamp guards against a tiny negative value from rounding.
            variance = (1 - alpha_prev) / (1 - alpha_t) * (1 - alpha_t / alpha_prev)
            sigma_t = eta * torch.sqrt(torch.clamp(variance, min=0.0))
            dir_xt = torch.sqrt(torch.clamp(1 - alpha_prev - sigma_t ** 2, min=0.0)) * eps_pred
            y = torch.sqrt(alpha_prev) * x0_pred + dir_xt

            # Stochastic DDIM injects fresh noise scaled by sigma_t; without
            # this term eta > 0 only shrank the deterministic direction.
            if eta > 0 and not is_last:
                y = y + sigma_t * torch.randn_like(y)
    finally:
        # Restore the train/eval mode even if the network raised.
        net.train(was_training)

    return y.clamp(0, 1)

# ---------------- TRAINING LOOP ----------------
def _optimizer_step(model, opt, scaler, ema_obj=None):
    """Apply the accumulated gradients, reset them, then refresh the EMA."""
    scaler.step(opt)
    scaler.update()
    opt.zero_grad()
    if ema_obj:
        ema_obj.update(model)


def train(model, schedule, loader, opt, scaler, ema_obj=None):
    """Train ``model`` to predict the noise added by ``schedule.q_sample``.

    Runs ``EPOCHS`` epochs with mixed precision and gradient accumulation over
    ``ACCUM_STEPS`` micro-batches, and writes model (and EMA) state dicts to
    ``CHECKPOINT_DIR`` every ``CHECKPOINT_FREQ`` epochs.

    Args:
        model: Network called as ``model(x_gray, y_noisy, t)``.
        schedule: Provides ``timesteps`` and ``q_sample``.
        loader: Iterable yielding ``(x_gray, y)`` batches.
        opt: Optimizer over the model's parameters.
        scaler: Gradient scaler used for the backward pass and optimizer step.
        ema_obj: Optional EMA holder updated after every optimizer step.

    Raises:
        RuntimeError: Any runtime error other than out-of-memory is re-raised.
    """
    model.train()

    for epoch in range(EPOCHS):
        pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        running_loss = 0.0
        batches_seen = 0   # successful micro-batches this epoch
        pending = 0        # micro-batches accumulated since the last step
        opt.zero_grad()

        for x_gray, y in pbar:
            try:
                x_gray = x_gray.to(DEVICE, non_blocking=PIN_MEMORY)
                y = y.to(DEVICE, non_blocking=PIN_MEMORY)

                # Sample noise
                B = y.size(0)
                t = torch.randint(0, schedule.timesteps, (B,), device=DEVICE)
                y_noisy, noise = schedule.q_sample(y, t)

                # Mixed Precision Forward
                with torch.cuda.amp.autocast(enabled=(DEVICE=='cuda')):
                    pred = model(x_gray, y_noisy, t)
                    loss = ((pred - noise)**2).mean() / ACCUM_STEPS

                # Backward & Step
                scaler.scale(loss).backward()
                pending += 1

                # Count real accumulated micro-batches rather than the loader
                # index, so a skipped (OOM) batch cannot shorten the window.
                if pending >= ACCUM_STEPS:
                    _optimizer_step(model, opt, scaler, ema_obj)
                    pending = 0

                running_loss += float(loss.item()) * ACCUM_STEPS
                batches_seen += 1
                # Average over this epoch only: running_loss is reset per
                # epoch, so dividing by a cross-epoch counter under-reports it.
                pbar.set_postfix({'loss': running_loss / batches_seen})

            except RuntimeError as e:
                if 'out of memory' not in str(e):
                    raise
                # Clear cache on OOM and skip batch; partial gradients are dropped.
                print('OOM: Clearing cache...')
                torch.cuda.empty_cache()
                opt.zero_grad()
                pending = 0

        # Apply gradients left over when the epoch length is not a multiple of
        # ACCUM_STEPS; otherwise the next epoch's zero_grad() discards them.
        if pending:
            _optimizer_step(model, opt, scaler, ema_obj)

        # Save checkpoints
        if (epoch + 1) % CHECKPOINT_FREQ == 0:
            torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, f'model_epoch{epoch+1}.pth'))
            if ema_obj:
                torch.save(ema_obj.ema.state_dict(), os.path.join(CHECKPOINT_DIR, f'ema_epoch{epoch+1}.pth'))

    print('Training complete.')

# ---------------- MAIN ----------------
def main():
    """Build the dataset, train the UNet, save the weights and plot samples.

    Raises:
        ValueError: If no images are found under ``DATASET_PATH``.
    """
    dataset = CartoonColorizationDataset(DATASET_PATH, img_size=IMG_SIZE, subset=SUBSET_SIZE)
    print(f'Total images: {len(dataset)}')
    if len(dataset) == 0:
        # Fail with a clear message instead of an opaque sampler error.
        raise ValueError(f'No images found under {DATASET_PATH!r}')

    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

    schedule = NoiseSchedule(timesteps=TIMESTEPS, device=DEVICE)

    model = MediumUNet(in_ch=4, base_ch=BASE_CH, time_emb_dim=128).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE=='cuda'))
    ema_obj = EMA(model, decay=0.9999)

    print('Starting training...')
    train(model, schedule, loader, opt, scaler, ema_obj=ema_obj)

    # Final Save -- done before plotting so a sampling or display failure
    # cannot lose the trained weights.
    torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'model_final.pth'))
    torch.save(ema_obj.ema.state_dict(), os.path.join(CHECKPOINT_DIR, 'ema_final.pth'))
    print('Saved final models.')

    # Validate/Visualize
    print('Generating samples...')
    # A small dataset (e.g. a tiny SUBSET_SIZE) may hold fewer images than
    # SAMPLES_TO_SAVE; never index past its end.
    n_samples = min(SAMPLES_TO_SAVE, len(dataset))
    plt.figure(figsize=(8, 2 * n_samples))

    for i in range(n_samples):
        x_gray, _ = dataset[i]
        x_gray_b = x_gray.unsqueeze(0).to(DEVICE)

        y_sample = ddim_sample(model, x_gray_b, schedule, steps=50, eta=0.0, use_ema=True, ema_obj=ema_obj)[0].cpu()

        # Left column: grayscale condition
        plt.subplot(n_samples, 2, i*2+1)
        plt.imshow(x_gray.permute(1,2,0).squeeze(), cmap='gray')
        plt.axis('off')

        # Right column: colorized sample
        plt.subplot(n_samples, 2, i*2+2)
        plt.imshow(y_sample.permute(1,2,0))
        plt.axis('off')

    plt.tight_layout()
    plt.show()

if __name__ == '__main__':
    main()