# Class-Conditional Latent Diffusion on TPU v5e-8

Trains a class-conditional latent diffusion model (DDIM) on the PlantVillage dataset using a Kaggle TPU.

**Important**: Select the TPU accelerator (TPU v5e-8) in Notebook Settings before running any cells.

In [None]:
# Cell 1: Environment Setup (MUST RUN FIRST!)
# These variables are consulted when protobuf / JAX / TensorFlow initialize,
# so this cell has to execute before any of those libraries are imported.
import os

os.environ.update({
    'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION': 'python',  # pure-Python protobuf backend
    'JAX_PLATFORMS': 'tpu',                              # restrict JAX to the TPU backend
    'TF_CPP_MIN_LOG_LEVEL': '3',                         # hide TF C++ INFO/WARNING/ERROR logs
})

print("✓ Environment variables set")

In [None]:
# Cell 2: Install Dependencies
!pip install -q flax==0.8.0 optax==0.1.9
!pip install -q diffusers==0.25.1 transformers
!pip install -q tensorflow-hub scipy einops
!pip install -q wandb PyYAML tqdm

print("✓ Dependencies installed")

In [None]:
# Cell 3: Clone/Setup Code
# Option 1: Clone from GitHub
!git clone https://github.com/YOUR_USERNAME/ddim.git /kaggle/working/ddim

# Option 2: Copy from Kaggle dataset (if you uploaded as dataset)
# !cp -r /kaggle/input/your-code-dataset/ddim /kaggle/working/

%cd /kaggle/working/ddim/jax
!ls -la
print("✓ Code ready")

In [None]:
# Cell 4: Verify TPU
import jax
import jax.numpy as jnp

print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {jax.devices()}")
print(f"Device count: {jax.device_count()}")
print(f"Local devices: {jax.local_devices()}")

# Fail fast if JAX is not running on the TPU backend (e.g. Cell 1 was skipped).
backend = jax.default_backend()
assert backend == "tpu", f"Expected TPU backend, got {backend!r}"

# Quick TPU test. JAX dispatches work asynchronously and `y.shape` is known
# without executing the matmul, so wait for the result and check its value.
x = jnp.ones((1000, 1000))
y = jnp.dot(x, x).block_until_ready()
assert abs(float(y[0, 0]) - 1000.0) < 1e-3, f"Unexpected matmul result: {y[0, 0]}"
print(f"✓ TPU test passed: {y.shape}")

In [None]:
# Cell 5: Create Directories
import os

# Output folders under the Kaggle working dir; exist_ok makes re-runs harmless.
WORKING_ROOT = "/kaggle/working"
OUTPUT_SUBDIRS = ("logs", "checkpoints", "samples", "fid_stats")

for subdir in OUTPUT_SUBDIRS:
    target_dir = os.path.join(WORKING_ROOT, subdir)
    os.makedirs(target_dir, exist_ok=True)
    print(f"✓ Created {target_dir}")

In [None]:
# Cell 6: Configure Wandb (Optional)
# Never paste the API key into the notebook source: it is saved with every
# version. Read it from the WANDB_API_KEY env var or a Kaggle secret instead.
import os

import wandb

wandb_key = os.environ.get("WANDB_API_KEY")
if not wandb_key:
    try:
        from kaggle_secrets import UserSecretsClient
        wandb_key = UserSecretsClient().get_secret("WANDB_API_KEY")
    except Exception as exc:  # best-effort: secret missing or not running on Kaggle
        print(f"No WANDB_API_KEY secret available ({exc})")

if wandb_key:
    wandb.login(key=wandb_key)
    print("✓ Wandb configured")
else:
    # wandb is optional; disable it so the training subprocess (which inherits
    # this environment) does not block on an interactive login prompt.
    os.environ["WANDB_MODE"] = "disabled"
    print("⚠ Wandb disabled (no API key found)")

In [None]:
# Cell 7: Verify VAE
# jnp is imported here too so this cell does not rely on Cell 4 having run.
import jax.numpy as jnp

from utils.vae import create_vae

print("Loading VAE...")
vae = create_vae("stabilityai/sd-vae-ft-mse")
print("✓ VAE loaded successfully")

# Quick test: encode a constant image and make sure the latents are usable.
# NOTE(review): assumes vae.encode takes NHWC images; confirm the expected
# input value range against utils.vae.
test_images = jnp.ones((1, 256, 256, 3)) * 0.5
test_latents = vae.encode(test_images)
print(f"Test encode: {test_images.shape} -> {test_latents.shape}")
assert bool(jnp.all(jnp.isfinite(test_latents))), "VAE produced non-finite latents"
print("✓ VAE test passed")

In [None]:
# Cell 8: Compute FID Stats (Run once)
# Computes FID statistics for the validation set. The result is cached at
# FID_STATS_PATH, so re-running this cell skips the expensive computation.
import os
import subprocess
import sys

FID_STATS_PATH = "/kaggle/working/fid_stats/plantvillage_val_fid_stats.npz"

if os.path.exists(FID_STATS_PATH):
    print(f"✓ FID stats already cached at {FID_STATS_PATH}")
else:
    # Unlike `!python`, check=True raises on a non-zero exit status, so the
    # success message is only printed when the script actually succeeded.
    subprocess.run(
        [
            sys.executable, "compute_fid_stats.py",
            "--config", "plantvillage_latent.yml",
            "--split", "val",
            "--num_samples", "500",
            "--output", FID_STATS_PATH,
        ],
        check=True,
    )
    print("✓ FID stats computed")

In [None]:
# Cell 9: Start Training
# This will train for 200k steps (takes ~20-30 hours on TPU v5e-8)

!python train_tpu.py \
    --config plantvillage_latent.yml \
    --doc plantvillage_tpu_v1

# Training will:
# - Save checkpoints every 25k steps to /kaggle/working/logs/plantvillage_tpu_v1/checkpoints/
# - Generate samples every 10k steps to /kaggle/working/logs/plantvillage_tpu_v1/samples/
# - Compute FID every 25k steps
# - Log to wandb (if configured)

In [None]:
# Cell 10: Monitor Training (Optional)
import glob
import os

from IPython.display import Image, display

# Show the most recent sample. Choose by modification time: a plain string
# sort can misorder step numbers that are not zero-padded ("9000" > "10000").
sample_files = glob.glob("/kaggle/working/logs/*/samples/*.png")
if sample_files:
    latest = max(sample_files, key=os.path.getmtime)
    print(f"Latest sample: {latest}")
    display(Image(latest))
else:
    print("No samples yet - training just started")

In [None]:
# Cell 11: View Training Logs
!tail -n 50 /kaggle/working/logs/plantvillage_tpu_v1/stdout.txt 2>/dev/null || echo "No logs yet"

In [None]:
# Cell 12: List Checkpoints
import glob
import os

# Sort by modification time so the last five listed really are the newest;
# a plain string sort misorders step numbers that are not zero-padded.
ckpts = sorted(glob.glob("/kaggle/working/logs/*/checkpoints/*.pkl"), key=os.path.getmtime)
print(f"Found {len(ckpts)} checkpoints:")
for ckpt in ckpts[-5:]:
    print(f"  {ckpt}")

## Tips

### Resume Training
If training is interrupted, re-run Cell 9 (after re-running Cells 1–4 if the kernel was restarted). Training auto-resumes from the last checkpoint in the run's log directory.

### Adjust Training
Edit `configs/plantvillage_latent.yml` to change:
- Batch size
- Learning rate  
- Number of steps
- Sample/checkpoint frequency

### Save Results
Download these directories before the Kaggle session ends:
- `/kaggle/working/logs/plantvillage_tpu_v1/checkpoints/` - Model weights
- `/kaggle/working/logs/plantvillage_tpu_v1/samples/` - Generated samples

### Debugging
If you get errors:
1. Restart the kernel and re-run Cells 1–4 in order (Cell 1 must run before JAX is imported)
2. Check TPU is enabled in Settings
3. Verify that the dataset path in `configs/plantvillage_latent.yml` matches the dataset's location under `/kaggle/input/`
