# In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from pathlib import Path
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from tqdm import tqdm
from torchvision import transforms
import argparse
import traceback

def parse_args(argv=None):
    """Parse the command-line options that control checkpoint resumption.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` reads the process
            arguments, exactly as a call without arguments always did.

    Returns:
        argparse.Namespace with ``resume`` (bool, ``False`` unless
        ``--resume`` is given) and ``checkpoint`` (str, or ``None`` unless
        ``--checkpoint PATH`` is given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--resume', action='store_true', help='Resume training from latest checkpoint')
    parser.add_argument('--checkpoint', type=str, default=None, help='Specific checkpoint path to resume from')
    # argparse falls back to sys.argv[1:] when argv is None.
    return parser.parse_args(argv)

def load_checkpoint(checkpoint_path, unet, proj, optimizer):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Args:
        checkpoint_path: Path of a file written by ``torch.save`` holding a
            dict with ``'unet'`` and ``'proj'`` state dicts and, optionally,
            ``'optimizer'``, ``'epoch'`` and ``'step'`` entries.
        unet: Module that receives ``checkpoint['unet']``.
        proj: Module that receives ``checkpoint['proj']``.
        optimizer: Optimizer that receives ``checkpoint['optimizer']`` when
            that key is present; pass ``None`` to skip optimizer state.

    Returns:
        ``(epoch, step)`` as stored in the checkpoint; each defaults to 0
        when its key is missing.

    Raises:
        KeyError: if the checkpoint has no ``'unet'`` or ``'proj'`` entry.
    """
    # Deserialize onto the CPU: a checkpoint written on a GPU host can then be
    # read anywhere, and resuming does not need a second copy of the weights
    # on the accelerator. load_state_dict copies the values into the modules'
    # (and optimizer's) existing parameters, so they stay on their device.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    unet.load_state_dict(checkpoint['unet'])
    proj.load_state_dict(checkpoint['proj'])
    if optimizer is not None and 'optimizer' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer'])
    epoch = checkpoint.get('epoch', 0)
    step = checkpoint.get('step', 0)
    print(f"✅ Loaded checkpoint from epoch {epoch}, step {step}")
    return epoch, step

def save_checkpoint(epoch, step, unet, proj, optimizer, output_dir, prefix='epoch'):
    """Write the training state to a named checkpoint and to ``latest_checkpoint.pt``.

    Args:
        epoch: Epoch number stored in the checkpoint and used in the file name.
        step: Step number stored in the checkpoint and used in the file name.
        unet: Module whose ``state_dict()`` is stored under ``'unet'``.
        proj: Module whose ``state_dict()`` is stored under ``'proj'``.
        optimizer: Optimizer whose ``state_dict()`` is stored under ``'optimizer'``.
        output_dir: Existing directory (``str`` or ``Path``) that receives the files.
        prefix: File-name prefix of the named checkpoint.

    Returns:
        Path of the named checkpoint,
        ``<output_dir>/<prefix>_<epoch:03d>_step<step:06d>.pt``.
    """
    output_dir = Path(output_dir)
    state = {
        'unet': unet.state_dict(),
        'proj': proj.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch,
        'step': step
    }

    ckpt_path = output_dir / f"{prefix}_{epoch:03d}_step{step:06d}.pt"
    torch.save(state, ckpt_path)

    # Also save as latest. Write to a sibling temp file and rename it into
    # place so an interruption mid-write cannot leave a truncated
    # latest_checkpoint.pt -- that file is the default resume target.
    latest_path = output_dir / "latest_checkpoint.pt"
    tmp_path = output_dir / "latest_checkpoint.pt.tmp"
    torch.save(state, tmp_path)
    os.replace(tmp_path, latest_path)
    return ckpt_path

class CelebALatentDataset(Dataset):
    """Dataset that turns stored latents into noise-prediction training samples.

    Every access draws fresh Gaussian noise and a random timestep, then asks
    the scheduler's ``add_noise`` for the corresponding noised latent. The
    attribute code at the same index is returned as the conditioning vector.
    """

    def __init__(self, precomputed_latents, attr_codes, scheduler, device):
        # Indexable latents; their count defines the dataset length.
        self.latents = precomputed_latents
        # Indexable conditioning codes, looked up with the same index.
        self.attr_codes = attr_codes
        # Must expose ``config.num_train_timesteps`` and ``add_noise``.
        self.scheduler = scheduler
        # Kept on the instance; none of the methods below read it.
        self.device = device

    def __len__(self):
        """Return the number of stored latents."""
        return len(self.latents)

    def __getitem__(self, idx):
        """Return ``(noisy_latent, timestep, noise, cond)`` for sample *idx*."""
        clean = self.latents[idx]
        eps = torch.randn_like(clean)

        # One timestep per sample, drawn from [0, num_train_timesteps).
        upper = self.scheduler.config.num_train_timesteps
        timestep = torch.randint(0, upper, (1,)).long()

        # add_noise is called on a batch of one; drop that axis afterwards.
        noised = self.scheduler.add_noise(clean[None], eps[None], timestep)
        noised = noised.squeeze(0)

        condition = self.attr_codes[idx].float()
        return noised, timestep[0], eps, condition

def _precompute_latents(img_dir, latents_path, device):
    """Encode every ``*.jpg`` in *img_dir* with the SD v1.5 VAE and save the result.

    Images are resized to 512x512, normalised with mean/std 0.5, encoded one
    at a time, scaled by ``vae.config.scaling_factor`` and concatenated (in
    sorted file-name order) into one CPU tensor written to *latents_path*.

    Raises:
        FileNotFoundError: if *img_dir* contains no ``*.jpg`` files.
    """
    print("Precomputing VAE latents...")
    image_paths = sorted(img_dir.glob("*.jpg"))
    if not image_paths:
        raise FileNotFoundError(f"No .jpg images found in {img_dir}")

    # The VAE is only needed here, so it is loaded locally and released when
    # this function returns instead of occupying the device during training.
    vae = AutoencoderKL.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="vae"
    ).to(device)
    vae.eval()

    preprocess = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    all_latents = []
    with torch.no_grad():
        for i, path in enumerate(tqdm(image_paths)):
            # Close the image file as soon as its pixels have been converted.
            with Image.open(path) as raw:
                img = preprocess(raw.convert("RGB")).unsqueeze(0).to(device)
            latents = vae.encode(img).latent_dist.sample() * vae.config.scaling_factor
            all_latents.append(latents.cpu())
            if i % 1000 == 0:
                torch.cuda.empty_cache()

    torch.save(torch.cat(all_latents, dim=0), latents_path)


def _train_epoch(epoch, num_epochs, loader, unet, proj, optimizer, mse_loss,
                 device, output_dir, skip_steps=0):
    """Run one epoch of noise-prediction training and return the mean loss.

    Args:
        epoch: Current epoch number (used for the progress bar and checkpoints).
        num_epochs: Total number of epochs (progress-bar label only).
        loader: Iterable of ``(noisy_latents, timesteps, noise, cond)`` batches.
        unet, proj: Modules being trained; ``proj`` maps ``cond`` to the
            ``encoder_hidden_states`` passed to ``unet``.
        optimizer: Optimizer stepping both modules.
        mse_loss: Loss applied to ``(prediction, noise)``.
        device: Device the batches are moved to.
        output_dir: Directory for the interim checkpoints written every 1000 steps.
        skip_steps: Number of leading batches to draw and discard, used when
            resuming inside an epoch so that step numbers continue where the
            checkpoint left off.

    Returns:
        Mean loss over the steps actually trained, or ``nan`` if none were.
    """
    unet.train()
    proj.train()
    loss_sum = 0.0
    trained_steps = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch}/{num_epochs}")
    for step, (noisy_latents, timesteps, noise, cond) in enumerate(pbar, start=1):
        # A checkpoint tagged with step N was written after step N finished,
        # so steps 1..N are the ones to skip.
        if step <= skip_steps:
            continue

        # Move data to device
        noisy_latents = noisy_latents.to(device, non_blocking=True)
        timesteps = timesteps.to(device, non_blocking=True)
        noise = noise.to(device, non_blocking=True)
        cond = proj(cond.to(device, non_blocking=True))

        # Forward pass
        pred = unet(noisy_latents, timesteps, encoder_hidden_states=cond).sample
        loss = mse_loss(pred, noise)

        # Backward pass
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        loss_sum += loss_value
        trained_steps += 1
        pbar.set_postfix({"loss": f"{loss_value:.4f}"})

        # Save checkpoint every 1000 steps
        if step % 1000 == 0:
            ckpt_path = save_checkpoint(epoch, step, unet, proj, optimizer, output_dir, "interim")
            print(f"Saved interim checkpoint to {ckpt_path}")
            torch.cuda.empty_cache()

    return loss_sum / trained_steps if trained_steps else float("nan")


def main():
    """Fine-tune the SD v1.5 UNet on CelebA latents conditioned on attribute codes.

    Parses the CLI flags, loads the attribute codes, builds the latent cache
    if it is missing, loads the UNet and noise scheduler, optionally resumes
    from a checkpoint, then trains for ``num_epochs`` epochs. An interim
    checkpoint is saved every 1000 steps and a regular one after each epoch.

    Any exception is reported together with its traceback and converted into
    ``sys.exit(1)``; "SCRIPT COMPLETED" is printed on every exit path.
    """
    args = parse_args()

    try:
        # =====================================================================
        # 1. Initialization
        # =====================================================================
        print("="*80)
        print("STARTING SCRIPT EXECUTION")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"PyTorch: {torch.__version__}")
        print(f"Device: {device}")
        print("="*80)

        # =====================================================================
        # 2. Configuration
        # =====================================================================
        IMG_DIR = Path("/mnt/beegfs/home/gs113310/KGGD-project/img_align_celeba/img_align_celeba")
        ATTR_PATH = Path("/mnt/beegfs/home/gs113310/KGGD-project/compressed_attr_vectors.npy")
        OUTPUT_DIR = Path("/mnt/beegfs/home/gs113310/KGGD-project/sd_checkpoints")
        LATENTS_PATH = OUTPUT_DIR / "precomputed_latents.pt"
        OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

        # =====================================================================
        # 3. Load Data and Models
        # =====================================================================
        print("Loading data and models...")
        attr_codes = torch.from_numpy(np.load(ATTR_PATH)).float()

        # Load or compute latents
        if not LATENTS_PATH.exists():
            _precompute_latents(IMG_DIR, LATENTS_PATH, device)
            torch.cuda.empty_cache()
        precomputed_latents = torch.load(LATENTS_PATH, map_location="cpu")

        # Fail before training starts rather than with an IndexError from a
        # data-loader worker part-way through an epoch.
        if len(attr_codes) < len(precomputed_latents):
            raise ValueError(
                f"{len(precomputed_latents)} latents but only "
                f"{len(attr_codes)} attribute codes"
            )

        unet = UNet2DConditionModel.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            subfolder="unet"
        ).to(device)

        scheduler = DDPMScheduler.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            subfolder="scheduler"
        )

        # =====================================================================
        # 4. Projection and Dataset Setup
        # =====================================================================
        # Project each attribute code to a single cross-attention token:
        # [batch, code_dim] -> [batch, 1, cross_attention_dim].
        proj = nn.Sequential(
            nn.Linear(attr_codes.shape[1], unet.config.cross_attention_dim),
            nn.Unflatten(1, (1, -1))
        ).to(device)

        dataset = CelebALatentDataset(precomputed_latents, attr_codes, scheduler, device)
        loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)

        # =====================================================================
        # 5. Training Setup with Resume Capability
        # =====================================================================
        optimizer = optim.AdamW(list(unet.parameters()) + list(proj.parameters()), lr=1e-4)
        mse_loss = nn.MSELoss()
        num_epochs = 5

        # Resume logic
        start_epoch = 1
        start_step = 0
        if args.resume or args.checkpoint:
            checkpoint_path = args.checkpoint if args.checkpoint else OUTPUT_DIR / "latest_checkpoint.pt"
            start_epoch, start_step = load_checkpoint(checkpoint_path, unet, proj, optimizer)
            if start_step == 0:
                # End-of-epoch checkpoints are saved with step 0, meaning
                # that epoch is finished: continue with the next one.
                start_epoch += 1
            print(f"Resuming from epoch {start_epoch}, step {start_step}")

        if start_step > 0:
            # The loader is shuffled and no seed is set, so the discarded
            # batches are not the exact ones trained on before; skipping only
            # keeps the step counter and the epoch length consistent.
            print(f"Skipping first {start_step} steps...")

        # =====================================================================
        # 6. Training Loop with Checkpointing
        # =====================================================================
        print(f"🚀 Starting training from epoch {start_epoch}...")
        for epoch in range(start_epoch, num_epochs + 1):
            # Only the epoch being resumed has steps to skip.
            skip_steps = start_step if epoch == start_epoch else 0
            avg_loss = _train_epoch(
                epoch, num_epochs, loader, unet, proj, optimizer, mse_loss,
                device, OUTPUT_DIR, skip_steps=skip_steps,
            )

            # Save epoch checkpoint
            print(f"✅ Epoch {epoch} complete - Avg Loss: {avg_loss:.6f}")
            ckpt_path = save_checkpoint(epoch, 0, unet, proj, optimizer, OUTPUT_DIR)
            print(f"Saved checkpoint to {ckpt_path}")

    except Exception as e:
        print(f"❌ Training failed: {str(e)}")
        # The message alone rarely locates the failure; show the stack too.
        traceback.print_exc()
        sys.exit(1)
    finally:
        print("SCRIPT COMPLETED")

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
