# McKean–Vlasov 3D Diffusion — Colab Runner (GPU + Drive + Private Git)

**What this does**
1. Mounts Google Drive (datasets in, artifacts out)
2. Installs JAX (CUDA), Flax, Optax, Torch (CPU-only for `.pt`)
3. Clones your **private** GitHub repo/branch without printing your token
4. Runs `main.py` on GPU, saving checkpoints/samples to Drive
5. Visualizes latest generated samples (`.npy`) as 3D MPL landscapes

**Before you run**
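- Put your dataset `.pt` in Drive under `/MyDrive/ph-mckeanvlasov-diff/` (the auto-search below only looks there for `unified_topological_data*.pt`), e.g. `/MyDrive/ph-mckeanvlasov-diff/datasets/unified_topological_data_v6_semifast.pt`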

In [None]:
# Colab / Drive / GPU setup + deps
import os, sys, subprocess, json, time, textwrap, glob
from pathlib import Path

# Must be set BEFORE JAX brings up its GPU backend (the jax.devices() call
# below). With the default allocator this kernel would reserve most of the
# GPU memory up front, leaving little for the training subprocess that is
# launched later from this same notebook.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # grow-as-needed GPU memory

# Detect whether we are running inside Google Colab.
try:
    import google.colab  # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
print("IN_COLAB:", IN_COLAB)

if IN_COLAB:
    from google.colab import drive  # type: ignore
    drive.mount('/content/drive', force_remount=True)
    # Show GPU
    try:
        print(subprocess.check_output(["nvidia-smi"], text=True))
    except (OSError, subprocess.CalledProcessError) as e:
        # OSError: nvidia-smi binary missing; CalledProcessError: it ran but failed.
        print("No NVIDIA GPU visible:", e)

import jax
print("JAX devices:", jax.devices())
# Public API for the active platform name; jax.lib.xla_bridge is a private,
# deprecated module path.
print("Backend:", jax.default_backend())

In [None]:
# Point this to the folder in Drive that contains your code:
# The folder should have: dataloader.py, models.py, losses_steps.py, sampling.py, main.py, etc.
# Example: /content/drive/MyDrive/mckean-vlasov
REPO_DIR = Path("/content/drive/MyDrive/ph-mckeanvlasov-diff/src/mckean-vlasov").resolve()

# Explicit check instead of `assert`: assertions are stripped under `python -O`,
# and a path that exists but is a plain file would only fail later in chdir.
if not REPO_DIR.is_dir():
    raise FileNotFoundError(f"Repo dir not found: {REPO_DIR}")

# main.py is launched by relative name later on, so the working directory
# has to be the code folder.
os.chdir(REPO_DIR)
print("CWD:", Path.cwd())

# Quick tree for sanity
def tree(path, max_levels=2, prefix=""):
    """Print an indented listing of ``path``.

    The directory's own name is printed first (with a trailing ``/``),
    followed by its entries in sorted order, each indented two spaces deeper.
    Sub-directories are descended into until ``max_levels`` is used up; a
    directory reached at level 0 is printed by name only.

    Args:
        path: Directory to list (anything ``Path`` accepts).
        max_levels: How many levels below ``path`` to expand.
        prefix: Indentation already accumulated by outer calls.
    """
    root = Path(path)
    print(prefix + root.name + "/")
    if max_levels > 0:
        child_prefix = prefix + "  "
        for entry in sorted(root.iterdir()):
            if not entry.is_dir():
                print(child_prefix + entry.name)
                continue
            tree(entry, max_levels - 1, child_prefix)

tree(REPO_DIR, max_levels=2)

# Verify required files exist. Raise explicitly rather than `assert`, which
# is stripped when Python runs with -O.
required = ["dataloader.py","models.py","losses_steps.py","sampling.py","main.py"]
missing = [f for f in required if not (REPO_DIR/f).is_file()]
if missing:
    raise FileNotFoundError(f"Missing files: {missing}")

# Locate dataset .pt (edit if needed).
# glob returns matches in arbitrary filesystem order; sort so that the first
# candidate (used as the default dataset later) is the same on every run.
CANDIDATES = sorted(
    glob.glob("/content/drive/MyDrive/ph-mckeanvlasov-diff/**/unified_topological_data*.pt", recursive=True)
)
print("Found .pt candidates:", CANDIDATES[:3])
if not CANDIDATES:
    print(">>> If you don't see your dataset, set DATA_PT manually in next cell.")

## Env setup

In [None]:
# Install pinned versions of gudhi, multipers, geomstats and POT into the
# running kernel's environment (-q keeps pip output quiet).
%pip -q install gudhi==3.11.0 multipers==2.3.2 geomstats==2.8.0 POT==0.9.5

In [None]:
# Optional: KeOps forced to CPU (avoid NVCC builds on Colab)
%pip -q install pykeops==2.2.3
import os
# Environment switches meant to keep PyKeOps quiet and CPU-only.
# NOTE(review): confirm that pykeops 2.2.3 actually reads these variable
# names; nothing in this notebook verifies that they take effect.
os.environ["KEOPS_VERBOSE"]="0"
os.environ["PYKEOPS_FORCE_BUILD"]="1"
os.environ["USE_CUDA"]="0"
print("PyKeOps set to CPU mode.")

# Sanity check
import os
# NOTE(review): jax was already imported and its devices queried in the setup
# cell, so this setting probably comes too late for the current kernel. It is
# still inherited by subprocesses started afterwards. TODO confirm.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"]="platform"  # grow-as-needed GPU memory
import jax, jax.numpy as jnp, numpy as np, flax, optax, ml_dtypes
print("JAX:", jax.__version__, "| Devices:", jax.devices())
# Small matmul; block_until_ready() waits for the result so that a failure
# shows up in this cell.
x = jnp.ones((2048,2048))
print("Matmul:", (x@x).block_until_ready().shape)
import torch
print("Torch:", torch.__version__, "CUDA avail:", torch.cuda.is_available())
print("NumPy:", np.__version__, "Flax:", flax.__version__, "Optax:", optax.__version__)

In [None]:
# Each `!command` runs in its own short-lived subshell, so `!export VAR=...`
# is forgotten as soon as that line finishes. Set the variables on the
# kernel's own environment instead; the training subprocess launched later
# inherits os.environ.
import os

os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# NOTE(review): these flags were never applied before this change. Verify
# both names against the installed jaxlib, since XLA may reject unknown flags
# at startup.
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=2 --xla_gpu_enable_triton=false"

## Repo / Paths
Set the dataset path, output folder, and training arguments below. Artifacts will go to Drive under `OUTDIR`.

In [None]:
import shlex, datetime

# === Set paths/args ===
# Dataset: take the first auto-discovered candidate when there is one,
# otherwise fall back to the explicit path (edit it if your file lives elsewhere).
DATA_PT = (
    CANDIDATES[0]
    if CANDIDATES
    else "/content/drive/MyDrive/ph-mckeanvlasov-diff/datasets/unified_topological_data_v6_semifast.pt"
)

# Output folder on Drive: one timestamped sub-folder per launch.
stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
OUTDIR = Path("/content/drive/MyDrive/ph-mckeanvlasov-diff/src/mckean-vlasov/runs/mv_sde", stamp)
OUTDIR.mkdir(parents=True, exist_ok=True)

# Training args — tune as you like.
# Values are strings; an empty string marks a boolean flag that is passed
# by name only. Insertion order is the order they appear on the command line.
args = {}
args.update({
    "--data_pt": DATA_PT,
    "--batch": "4",                # bump to 24/32 on A100 if it fits
    "--steps": "20000",
    "--seed": "0",
    "--lr": "2e-4",
    "--lr_energy": "3e-4",
    "--lr_enc": "1e-3",
})
args.update({
    "--T": "1000",
    "--schedule": "cosine",
    "--v_pred": "",                 # flag
    "--ema_decay": "0.999",
})
# energy_* options
args.update({
    "--use_energy": "",             # flag
    "--energy_scale": "0.1",
    "--energy_tau": "0.07",
    "--energy_gp": "1e-4",
})
# mf_* options
args.update({
    "--mf_mode": "rbf",
    "--mf_lambda": "0.05",
    "--mf_bandwidth": "0.5",
})
# output / checkpoint / sampling cadence
args.update({
    "--outdir": str(OUTDIR),
    "--ckpt_every": "2000",
    "--sample_every": "2000",
    "--sample_steps": "250",
    "--sample_label": "-1",
})
# Classifier-free guidance options (currently always passed to main.py).
args.update({
    "--cfg_drop": "0.1",
    "--cfg_scale": "3.0",
    "--cfg_sched": "cosine",
    "--cfg_strength": "5.0",
})

# Pretty print
print("Training OUTDIR:", OUTDIR)
print("DATA_PT:", DATA_PT)

# Build argv: run main.py unbuffered (-u) with the current interpreter.
# An empty-string value marks a boolean flag (name only); everything else is
# passed as a "name value" pair.
argv = [sys.executable, "-u", "main.py"]
for k, v in args.items():
    if v == "":                       # boolean flags
        argv.append(k)
    else:
        argv.extend([k, v])

print('Launching:\n', ' '.join(shlex.quote(a) for a in argv))

# Stream logs live. The context manager closes the pipe and reaps the child
# on every exit path.
with subprocess.Popen(argv, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) as proc:
    try:
        for line in proc.stdout:
            print(line, end="")
    except BaseException:
        # Cell interrupted (or printing failed): stop the training process
        # rather than leaving it running unattended, then re-raise the
        # original exception unchanged.
        proc.terminate()
        try:
            proc.wait(timeout=30)
        except subprocess.TimeoutExpired:
            proc.kill()
        raise
    ret = proc.wait()

# Report the result outside any `finally`, so a failure exit code can no
# longer replace an exception that was already propagating.
print("\nTraining exit code:", ret)
if ret != 0:
    raise SystemExit(ret)
print("Done.")