# McKean–Vlasov 3D Diffusion — Colab Runner (GPU + Drive)

**What this does**
1. Mounts Google Drive (project code and datasets in, artifacts out)
2. Installs the extra dependencies (tqdm, gudhi, multipers, geomstats, POT, optional PyKeOps) on top of the JAX (CUDA), Flax, Optax and Torch already available in Colab (Torch is used to read the `.pt` dataset)
3. Runs the project code directly from a Drive folder (`REPO_DIR`) — no GitHub clone or access token is needed
4. Runs `main.py` on GPU, saving checkpoints/samples to Drive
5. Visualizes latest generated samples (`.npy`) as 3D MPL landscapes

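**Before you run**
- Put the project code (`main.py`, `dataloader.py`, `models.py`, `losses_steps.py`, `sampling.py`) in a Drive folder and point `REPO_DIR` at it (default: `/MyDrive/ph-mckeanvlasov-diff/src/mckean-vlasov`)
- Put your dataset `.pt` in Drive anywhere under `/MyDrive/ph-mckeanvlasov-diff/` so the automatic search finds it, e.g. `/MyDrive/ph-mckeanvlasov-diff/datasets/unified_topological_data_v6_semifast.pt` (otherwise set `DATA_PT` manually)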

In [1]:
# Colab / Drive / GPU setup + deps
import os, sys, subprocess, json, time, textwrap, glob
from pathlib import Path

try:
    import google.colab  # type: ignore
except Exception:
    IN_COLAB = False
else:
    IN_COLAB = True
print("IN_COLAB:", IN_COLAB)

if IN_COLAB:
    # Mount Drive (code + datasets in, artifacts out); force_remount re-mounts
    # even if Drive is already mounted in this runtime.
    from google.colab import drive  # type: ignore
    drive.mount('/content/drive', force_remount=True)

    # Report the attached GPU. Any failure (e.g. nvidia-smi missing on a
    # CPU-only runtime) is reported instead of stopping the notebook.
    try:
        gpu_report = subprocess.check_output(["nvidia-smi"], text=True)
    except Exception as err:
        print("No NVIDIA GPU visible:", err)
    else:
        print(gpu_report)

import jax

# Report which devices/backend JAX picked up. `jax.default_backend()` is the
# public API; `jax.lib.xla_bridge.get_backend()` is deprecated and emitted a
# warning here on JAX 0.5.x.
print("JAX devices:", jax.devices())
print("Backend:", jax.default_backend())

IN_COLAB: True
Mounted at /content/drive
Mon Aug 25 07:44:28 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   44C    P0             58W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
       

  print("Backend:", jax.lib.xla_bridge.get_backend().platform)


In [2]:
# Drive folder holding the project sources: dataloader.py, models.py,
# losses_steps.py, sampling.py, main.py, ... Edit this if your copy lives
# elsewhere, e.g. /content/drive/MyDrive/mckean-vlasov
REPO_DIR = Path("/content/drive/MyDrive", "ph-mckeanvlasov-diff", "src", "mckean-vlasov").resolve()
assert REPO_DIR.exists(), f"Repo dir not found: {REPO_DIR}"

# Relative paths used later (e.g. "main.py" in the launch command) resolve against this.
os.chdir(REPO_DIR)
print("CWD:", Path.cwd())

# Quick tree for sanity
def tree(path, max_levels=2, prefix=""):
    """Print a sorted directory listing of `path`, at most `max_levels` deep.

    Directory names get a trailing "/", and each nesting level indents its
    entries by two more spaces. A directory reached with `max_levels <= 0` is
    printed but not expanded.
    """
    root = Path(path)
    print(f"{prefix}{root.name}/")
    if max_levels > 0:
        child_prefix = prefix + "  "
        for entry in sorted(root.iterdir()):
            if entry.is_dir():
                tree(entry, max_levels - 1, child_prefix)
            else:
                print(f"{child_prefix}{entry.name}")

tree(REPO_DIR, max_levels=2)

# Verify required files exist
required = ["dataloader.py", "models.py", "losses_steps.py", "sampling.py", "main.py"]
missing = [f for f in required if not (REPO_DIR / f).exists()]
assert not missing, f"Missing files: {missing}"

# Locate dataset .pt (edit if needed). glob() returns matches in filesystem-
# dependent order; sort so that CANDIDATES[0] (the default DATA_PT in the
# training cell) is the same file on every run.
CANDIDATES = sorted(glob.glob("/content/drive/MyDrive/ph-mckeanvlasov-diff/**/unified_topological_data*.pt", recursive=True))
print("Found .pt candidates:", CANDIDATES[:3])
if not CANDIDATES:
    print(">>> If you don't see your dataset, set DATA_PT manually in the training cell below.")
elif len(CANDIDATES) > 1:
    print(f">>> {len(CANDIDATES)} matches; CANDIDATES[0] will be used as DATA_PT unless you set it manually.")

CWD: /content/drive/MyDrive/ph-mckeanvlasov-diff/src/mckean-vlasov
mckean-vlasov/
  .DS_Store
  __pycache__/
    dataloader.cpython-312.pyc
    losses_steps.cpython-312.pyc
    losses_steps.py
    models.cpython-312.pyc
    sampling.cpython-312.pyc
  dataloader.py
  losses_steps.py
  main.py
  models.py
  run.ipynb
  runs/
    mv_sde/
  sampling.py
Found .pt candidates: ['/content/drive/MyDrive/ph-mckeanvlasov-diff/datasets/unified_topological_data_v6_semifast.pt']


## Env setup

In [3]:
# Install additional dependencies into the kernel's environment; all but tqdm are pinned to exact versions.
%pip -q install tqdm gudhi==3.11.0 multipers==2.3.2 geomstats==2.8.0 POT==0.9.5

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/552.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m552.2/552.2 kB[0m [31m36.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m118.5/118.5 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.2/4.2 MB[0m [31m95.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.5/8.5 MB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.5/10.5 MB[0m [31m30.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m901.7/901.7 kB[0m [31m63.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
# Optional: KeOps forced to CPU (avoid NVCC builds on Colab)
# Pinned PyKeOps install; the environment variables that configure it follow below.
%pip -q install pykeops==2.2.3
import os

# PyKeOps settings intended to keep it in CPU mode.
# NOTE(review): confirm these variable names/meanings against the pykeops 2.2.3 docs.
os.environ.update(KEOPS_VERBOSE="0", PYKEOPS_FORCE_BUILD="1", USE_CUDA="0")

# XLA reads XLA_FLAGS when a backend is created. JAX was already initialised in
# the first cell, so this only affects processes started afterwards (such as the
# main.py training subprocess, which inherits os.environ). An XLA_FLAGS value
# that already exists is left untouched.
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_autotune_level=1 --xla_gpu_deterministic_ops=true")
print("PyKeOps set to CPU mode.")

# Sanity check: the core libraries import and JAX runs a correct matmul on its default device.
# NOTE: JAX already created its backend in the first cell (jax.devices()), so the
# allocator variable below cannot change memory behaviour in *this* kernel. It is
# kept because the training subprocess launched later inherits os.environ.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # allocate GPU memory on demand, no preallocation
import jax, jax.numpy as jnp, numpy as np, flax, optax, ml_dtypes  # ml_dtypes is unused here (import check)
print("JAX:", jax.__version__, "| Devices:", jax.devices())
x = jnp.ones((2048,2048))
y = (x @ x).block_until_ready()
print("Matmul:", y.shape)
# ones @ ones equals the inner dimension in every entry (exact even under TF32);
# any other value points at a broken device/XLA setup.
assert bool((y == x.shape[1]).all()), f"Unexpected matmul result, e.g. {float(y[0, 0])}"
import torch
print("Torch:", torch.__version__, "CUDA avail:", torch.cuda.is_available())
print("NumPy:", np.__version__, "Flax:", flax.__version__, "Optax:", optax.__version__)

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.5/92.5 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m100.3/100.3 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for pykeops (setup.py) ... [?25l[?25hdone
  Building wheel for keopscore (setup.py) ... [?25l[?25hdone
PyKeOps set to CPU mode.
JAX: 0.5.3 | Devices: [CudaDevice(id=0)]
Matmul: (2048, 2048)
Torch: 2.8.0+cu126 CUDA avail: True
NumPy: 2.0.2 Flax: 0.10.6 Optax: 0.2.5


In [5]:
# `!export VAR=...` runs in a throw-away subshell, so those lines never changed
# the kernel's environment or the training subprocess (which inherits
# os.environ). Set the variables on os.environ instead. JAX in this kernel is
# already initialised, so only processes launched later are affected.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# NOTE(review): this cell also tried to export
#   XLA_FLAGS="--xla_gpu_autotune_level=2 --xla_gpu_enable_triton=false"
# but that never took effect either. Training therefore used whatever XLA_FLAGS
# the earlier PyKeOps/sanity-check cell left in os.environ. XLA aborts at
# startup on flags it does not recognise, so verify both flags against the
# installed jaxlib before assigning os.environ["XLA_FLAGS"] here.

## Paths & training run
Set the dataset path (`DATA_PT`) and the training arguments, then launch `main.py`. Artifacts go to a new timestamped folder on Drive (`OUTDIR`).

In [6]:
import shlex, datetime

# === Set paths/args ===
# Dataset: the first (sorted) .pt found by the search in the setup cell, otherwise
# the default location below. Set DATA_PT explicitly here if neither is right.
DATA_PT = CANDIDATES[0] if CANDIDATES else "/content/drive/MyDrive/ph-mckeanvlasov-diff/datasets/unified_topological_data_v6_semifast.pt"
# Fail fast, before an empty run folder is created on Drive and main.py is launched.
assert Path(DATA_PT).is_file(), f"Dataset not found: {DATA_PT} (set DATA_PT manually)"

# Output folder on Drive: <REPO_DIR>/runs/mv_sde/<timestamp>. Derived from
# REPO_DIR so the run folder cannot drift away from the code folder.
stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
OUTDIR = REPO_DIR / "runs" / "mv_sde" / stamp
OUTDIR.mkdir(parents=True, exist_ok=True)

# Training args for main.py — tune as you like.
# Every value is a CLI string; "" marks a boolean switch that is passed as a
# bare flag (no value). Flags appear on the command line in this order.
args = {
    # data / output
    "--data_pt": DATA_PT,
    "--outdir": str(OUTDIR),
    # optimisation
    "--steps": "5000",
    "--batch": "8",
    "--lr": "1.5e-4",
    "--v_pred": "",
    # guidance-related flags
    "--use_guidance": "",
    "--lr_guidance": "5e-5",
    "--guidance_loss_weight": "0.1",
    "--guidance_scale": "0.25",
    "--cfg_scale": "2.5",
    # sampling / mf / cfg-drop flags
    "--sample_steps": "200",
    "--mf_mode": "rbf",
    "--mf_lambda": "0.01",
    "--sample_count": "512",
    "--cfg_drop": "0.1",
}

# Show where this run reads from and writes to.
print("Training OUTDIR:", OUTDIR)
print("DATA_PT:", DATA_PT)

# Command line for main.py: "" values become bare switches, everything else
# becomes a "--flag value" pair.
argv = [sys.executable, "-u", "main.py"]
for flag, value in args.items():
    argv += [flag] if value == "" else [flag, value]

# shlex.join quotes each token exactly as shlex.quote would.
print('Launching:\n', shlex.join(argv))

# Stream main.py's combined stdout/stderr into the notebook while it runs.
# cwd is pinned to REPO_DIR so the relative "main.py" in argv does not depend on
# the os.chdir() performed in an earlier cell.
with subprocess.Popen(argv, cwd=REPO_DIR, stdout=subprocess.PIPE,
                      stderr=subprocess.STDOUT, text=True) as proc:
    try:
        for line in proc.stdout:
            print(line, end="")
    except BaseException:
        # Interrupted cell (or any error while streaming): make sure main.py does
        # not keep running, and do not block until training ends on its own.
        # Ask politely first, then kill after 30 s.
        proc.terminate()
        try:
            proc.wait(timeout=30)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
        print("\nTraining interrupted; exit code:", proc.returncode)
        raise
    ret = proc.wait()

print("\nTraining exit code:", ret)
if ret != 0:
    raise SystemExit(ret)
print("Done.")

Training OUTDIR: /content/drive/MyDrive/ph-mckeanvlasov-diff/src/mckean-vlasov/runs/mv_sde/20250825_074507
DATA_PT: /content/drive/MyDrive/ph-mckeanvlasov-diff/datasets/unified_topological_data_v6_semifast.pt
Launching:
 /usr/bin/python3 -u main.py --data_pt /content/drive/MyDrive/ph-mckeanvlasov-diff/datasets/unified_topological_data_v6_semifast.pt --outdir /content/drive/MyDrive/ph-mckeanvlasov-diff/src/mckean-vlasov/runs/mv_sde/20250825_074507 --steps 5000 --batch 8 --lr 1.5e-4 --v_pred --use_guidance --lr_guidance 5e-5 --guidance_loss_weight 0.1 --guidance_scale 0.25 --cfg_scale 2.5 --sample_steps 200 --mf_mode rbf --mf_lambda 0.01 --sample_count 512 --cfg_drop 0.1
N=1000  vol=(N,H,W,K,C)=(1000, 128, 128, 3, 3)  KS=3  degrees=3  res=128
[train] Starting 5000 steps... (v_pred: True, guidance: True)
step 00001/5000 | loss=0.7833 (diff=0.7325, guide=0.5087)
step 00100/5000 | loss=0.5311 (diff=0.4815, guide=0.4957)
step 00200/5000 | loss=0.6918 (diff=0.6469, guide=0.4487)
step 00300/50