In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
import time
import shutil
import random
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

# Third-party deps (imported unconditionally, so they are required, not optional)
import anndata as ad
import scanpy as sc
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# OT
import ot  # pip install POT

from trajectory_models import TrajectoryInference
from utils import *

In [None]:
# Run configuration.
# Every setting is a plain variable (kept so any cell that reads the bare name
# still works) and is also bundled into `args` at the bottom of this cell,
# because the training cell reads each setting as `args.<name>`.

# --- data ---
adata = "../data/dyngen.h5ad"  # path handed to load_adata()
time_col = "days_pd"           # time column name handed to load_adata()
layer = None                   # forwarded as load_adata(..., layer=layer)
n_pcs = 64                     # forwarded as preprocess_features(X, n_pcs=n_pcs)

# --- training parameters ---
batch_size = 1024              # number of chains requested from the sampler per step
steps = 50_000                 # last training step; also T_max of the cosine schedule
lr = 1e-3                      # learning rate handed to build_model()
weight_decay = 0.0             # weight decay handed to build_model()
grad_clip = 1.0                # max grad norm for clip_grad_norm_; falsy disables clipping
log_every = 200                # print every N steps; loss JSON is written every 5*N steps
ckpt_every = 200               # checkpoint every N steps; falsy disables periodic checkpoints
eval_trajectories = 256        # upper bound on the number of start points used for evaluation


# A purely numeric string is converted to float by the training cell; any other
# spec is passed to TrajectoryInference unchanged.
sigma = "adaptive4-1.0-0.0"
# NOTE(review): the training cell reads `args.interp` (TrajectoryInference's
# `interpolation` argument, and it is stored in checkpoint metadata), but no
# value was ever defined in this cell. "linear" is a placeholder assumption.
# TODO confirm against the values TrajectoryInference actually accepts.
interp = "linear"
t_eps = 1e-3                   # upper bound on the anchor margin (also capped at 0.25 * smallest time gap)

seed = 0
device = "auto"               # "auto" or explicit "cuda:n"
save_dir = "./tempo_run"

resume = False                 # resume if ckpt exists
resume_from = None            # explicit path, else save_dir/ckpt_last.pt

# --- scheduler ---
use_cosine = True              # enable CosineAnnealingLR
min_lr = 1e-4                  # eta_min of the cosine schedule

# Namespace consumed by the training cell. The training cell rebinds the bare
# names `adata` and `device` to runtime objects, so it must read the original
# settings from here rather than from those variables.
args = SimpleNamespace(
    adata=adata,
    time_col=time_col,
    layer=layer,
    n_pcs=n_pcs,
    batch_size=batch_size,
    steps=steps,
    lr=lr,
    weight_decay=weight_decay,
    grad_clip=grad_clip,
    log_every=log_every,
    ckpt_every=ckpt_every,
    eval_trajectories=eval_trajectories,
    sigma=sigma,
    interp=interp,
    t_eps=t_eps,
    seed=seed,
    device=device,
    save_dir=save_dir,
    resume=resume,
    resume_from=resume_from,
    use_cosine=use_cosine,
    min_lr=min_lr,
)


In [None]:
# Seed through the project's set_seed helper before anything stochastic runs.
set_seed(args.seed)

# "auto" selects CUDA when available and CPU otherwise; any other value is
# handed to torch.device unchanged.
if args.device == "auto":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device(args.device)

# Output directory: expand "~", make absolute, create if missing.
outdir = Path(args.save_dir).expanduser().resolve()
outdir.mkdir(parents=True, exist_ok=True)

# ---------- Data ----------
adata, X, t_raw = load_adata(args.adata, args.time_col, layer=args.layer)
Xz, scaler, pca = preprocess_features(X, n_pcs=args.n_pcs)
t_norm, timepoints, _ = normalize_times(t_raw)
X_by_t = split_by_time(Xz, t_norm, timepoints)
d = Xz.shape[1]  # feature dimension seen by the model
print(f"[data] n_timepoints={len(timepoints)}; dims={Xz.shape[1]}; counts={[len(x) for x in X_by_t]}")

# ---------- OT plans and chain sampler ----------
plans = compute_pairwise_ot_plans(X_by_t)
sampler = ChainSampler(X_by_t, plans, rng=np.random.default_rng(args.seed))

# ---------- Model ----------
model, opt = build_model(d, args.lr, weight_decay=args.weight_decay)
model.to(device)

# A sigma string made only of digits with at most one '.' is converted to a
# float; any other spec string is passed through as-is.
_sigma_is_numeric = args.sigma.replace('.', '', 1).isdigit()
_sigma_arg = float(args.sigma) if _sigma_is_numeric else args.sigma
vel_model = TrajectoryInference(sigma=_sigma_arg, interpolation=args.interp)

# ---------- Scheduler (optional) ----------
sched = CosineAnnealingLR(opt, T_max=args.steps, eta_min=args.min_lr) if args.use_cosine else None

# ---------- Resume ----------
# Checkpoint to resume from: an explicit path wins, otherwise the rolling
# "ckpt_last.pt" inside the output directory.
ckpt_path = Path(args.resume_from) if args.resume_from else (outdir / "ckpt_last.pt")
start_step = 0
loss_hist = []

# Best-effort reload of a loss curve left behind in the output directory.
# The file is opened via `with` so the handle is always closed, and a failed
# read is reported instead of being swallowed silently.
loss_file = outdir / "training_loss.json"
if loss_file.exists():
    try:
        with open(loss_file) as fh:
            loss_hist = json.load(fh)
    except Exception as exc:
        print(f"[resume] could not read {loss_file.name} ({exc}); starting with an empty loss history")
        loss_hist = []

if args.resume and ckpt_path.exists():
    ckpt, start_step, loss_hist_from_ckpt = load_checkpoint(ckpt_path, model, opt, device)

    # Prefer the history stored with the checkpoint: the JSON file is written
    # less often than checkpoints, so it can be out of step with `start_step`.
    # NOTE(review): assumes load_checkpoint returns a sequence of numbers (or
    # None) as its third value; confirm against utils.load_checkpoint.
    if loss_hist_from_ckpt is not None and len(loss_hist_from_ckpt) > 0:
        loss_hist = [float(v) for v in loss_hist_from_ckpt]

    # Restore sampler RNG if present
    if "sampler_bitgen_state" in ckpt:
        try:
            sampler.rng.bit_generator.state = ckpt["sampler_bitgen_state"]
        except Exception:
            # The sampler's current rng could not take the state; rebuild a
            # default Generator and retry once (a second failure propagates).
            sampler.rng = np.random.default_rng(args.seed)
            sampler.rng.bit_generator.state = ckpt["sampler_bitgen_state"]

    # Try to infer step if the old ckpt didn't store it
    if start_step == 0:
        # files like ckpt_step0002000.pt -> 2000; names whose suffix is not a
        # plain number are skipped individually instead of aborting the scan.
        archived_steps = [
            int(suffix)
            for suffix in (p.stem.replace("ckpt_step", "") for p in outdir.glob("ckpt_step*.pt"))
            if suffix.isdecimal()
        ]
        if archived_steps:
            start_step = max(archived_steps)

    # Continue LR schedule smoothly.
    # CosineAnnealingLR.step() updates the LR recursively from the optimizer's
    # *current* lr, so replaying `start_step` calls is O(start_step), warns
    # about stepping the scheduler before the optimizer, and yields a wrong LR
    # whenever the optimizer's lr was itself restored from the checkpoint.
    # Jump straight to the closed-form cosine value for `start_step` instead.
    if sched is not None and start_step > 0:
        sched.last_epoch = start_step
        cos_factor = 0.5 * (1.0 + float(np.cos(np.pi * start_step / sched.T_max)))
        for group, base_lr in zip(opt.param_groups, sched.base_lrs):
            group["lr"] = sched.eta_min + (base_lr - sched.eta_min) * cos_factor

    print(f"[resume] loaded {ckpt_path.name} at step {start_step}")
else:
    print("[resume] starting fresh")

# ---------- Train ----------
model.train()
wall = time.time()
loss_ema, ema_beta = None, 0.95

# Anchors and the safe margin around them never change during training, so
# they are computed once here rather than on every step.
tp = np.asarray(timepoints, dtype=np.float32)
min_gap = float(np.min(np.diff(tp))) if len(tp) > 1 else 1.0
eps = min(args.t_eps, 0.25 * min_gap)
# Built with torch.tensor (a copy) rather than from_numpy; expanded per batch below.
tp_anchor = torch.tensor(tp, dtype=torch.float32, device=device)

# Checkpoint metadata is also loop-invariant: build it once (the .tolist()
# conversions of the scaler/PCA arrays are not free) and reuse it for both the
# periodic and the final checkpoint so the two can never drift apart.
ckpt_meta = dict(
    dimension=d,
    timepoints=timepoints.tolist(),
    sigma=args.sigma,
    interpolation=args.interp,
    # preprocessing
    scaler_mean=scaler.mean_.tolist(),
    scaler_scale=scaler.scale_.tolist(),
    pca_components=(pca.components_.tolist() if pca is not None else None),
    pca_mean=(pca.mean_.tolist() if pca is not None else None),
    # model arch (for reloads)
    x_latent_dim=128,
    time_embed_dim=128,
    conditional_model=False,
    normalization="layernorm",
    activation="SELU",
    num_out_layers=3,
    n_pcs=args.n_pcs,
)

# Defined up front so the final save below still has a valid step number when
# the loop body never runs (e.g. resuming a run that already reached args.steps).
step = start_step

for step in range(start_step + 1, args.steps + 1):
    xs = torch.from_numpy(sampler.sample(args.batch_size)).to(device)  # [B, K+1, d] per the original author's note
    B = xs.shape[0]

    tp_tensor = tp_anchor.expand(B, -1)  # one row of anchors per batch element
    t_batch = sample_times_away_from_anchors(B, tp, eps=eps, device=device)

    t_out, xt, ut, _, _ = vel_model.sample_location_and_conditional_flow(xs, tp_tensor, t=t_batch)

    # Safety checks
    if (not torch.isfinite(xt).all()) or (not torch.isfinite(ut).all()):
        print(f"[warn@{step}] Non-finite xt/ut — skipping batch.")
        continue

    # The model input is the interpolated state with the time appended as a last column.
    pred = model(torch.cat([xt.squeeze(1), t_out.unsqueeze(1)], 1))
    # Squared error summed over features, averaged over the batch.
    loss = ((pred - ut.squeeze(1)) ** 2).sum(1).mean()

    if not torch.isfinite(loss):
        print(f"[warn@{step}] Non-finite loss — skipping batch.")
        continue

    opt.zero_grad(set_to_none=True)
    loss.backward()
    if args.grad_clip:
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
    opt.step()
    if sched is not None:
        sched.step()

    # ----- logging -----
    # Read the scalar once; each .item() call is a separate device-to-host sync.
    loss_val = float(loss.item())
    loss_ema = loss_val if loss_ema is None else ema_beta * loss_ema + (1 - ema_beta) * loss_val
    if step % args.log_every == 0:
        lr_now = opt.param_groups[0]["lr"]
        print(f"[{step:>6d}/{args.steps}] loss={loss_val:8.3f}  ema={loss_ema:8.3f}  lr={lr_now:.2e}  ({time.time()-wall:.1f}s)")
        wall = time.time()
    loss_hist.append(loss_val)

    # ----- checkpointing -----
    if args.ckpt_every and (step % args.ckpt_every == 0):
        meta = dict(ckpt_meta)
        tmp = outdir / ".ckpt_last.tmp"
        save_checkpoint(tmp, step, model, opt, loss_hist, meta, device, sampler=sampler)
        os.replace(tmp, outdir / "ckpt_last.pt")  # atomic rename
        # rolling archive without torch.load
        shutil.copy2(outdir / "ckpt_last.pt", outdir / f"ckpt_step{step:07d}.pt")

    if step % (args.log_every * 5) == 0:
        with open(loss_file, "w") as fh:
            json.dump(loss_hist, fh)

# Final save
final_meta = dict(ckpt_meta)
# final checkpoint with optimizer (useful for later resume); written through
# the same temp-file + rename path as the periodic checkpoints so an
# interrupted write cannot leave a truncated ckpt_last.pt behind.
tmp = outdir / ".ckpt_last.tmp"
save_checkpoint(tmp, step, model, opt, loss_hist, final_meta, device, sampler=sampler)
os.replace(tmp, outdir / "ckpt_last.pt")
with open(loss_file, "w") as fh:
    json.dump(loss_hist, fh)

# Evaluate a few trajectories (shape: X: [B, d])
# Start points are the leading rows of the first time slice, capped at
# args.eval_trajectories and cast to float32.
_start_slice = X_by_t[0]
_n_eval = min(args.eval_trajectories, len(_start_slice))
X0 = _start_slice[:_n_eval].astype(np.float32)

# Switch to inference mode and integrate without tracking gradients.
model.eval()
with torch.no_grad():
    _start_points = torch.from_numpy(X0)
    _dummy_labels = torch.zeros(len(X0))  # zero labels; the model is run with conditional_model=False
    traj = sample_trajectory(
        model,
        X=_start_points,
        y=_dummy_labels,
        device=device,
        guidance=1.0,
        conditional_model=False,
        steps=1001,
        method="dopri5",
    )

# Persist the result next to the checkpoints.
np.save(outdir / "trajectories.npy", traj)
print(f"Saved to: {outdir}")
