# PU/OSLS TabPFN Pretraining with Curriculum (TabICL Prior)

The curriculum schedule used here is controlled by these config variables:
- `update_every_steps`: how many training steps pass between hyperparameter updates
- `max_updates`: how many update stages run before the schedule becomes stable (i.e. after `update_every_steps * max_updates` steps)
- `min_features` stays fixed while `max_features` expands over the stages


In [None]:
from pathlib import Path
import importlib
import sys

import numpy as np
import torch

def _find_repo_root(start: Path) -> Path:
    """Return the nearest directory at or above *start* that contains ``src/pu_osls_tabpfn``.

    *start* is resolved to an absolute path first, then it and each of its
    ancestors are checked in order, closest first.

    Raises:
        RuntimeError: if neither *start* nor any ancestor holds the package sources.
    """
    origin = start.resolve()
    match = next(
        (
            directory
            for directory in (origin, *origin.parents)
            if directory.joinpath("src", "pu_osls_tabpfn").exists()
        ),
        None,
    )
    if match is None:
        raise RuntimeError(
            "Could not find project root containing src/pu_osls_tabpfn. "
            "Open this notebook from the repository root or install with `pip install -e .`."
        )
    return match

# Locate the repository from the notebook's working directory.
ROOT = _find_repo_root(Path.cwd())
SRC_DIR = str(ROOT / "src")
# Put the repo's src/ first on the import path. Rebuilding the list (instead of a
# bare insert) keeps exactly one entry, so re-running this cell does not stack
# duplicate sys.path entries while still giving src/ top priority.
sys.path[:] = [SRC_DIR, *(entry for entry in sys.path if entry != SRC_DIR)]

from pu_osls_tabpfn.eval_pu_osls import EvalConfig, evaluate_pu_osls, print_results
from pu_osls_tabpfn.model import CustomNanoTabPFNModel
from pu_osls_tabpfn.prior_data import PriorGeneratorConfig, TabICLPriorConfig, TestLabelShiftConfig, generate_batch
import pu_osls_tabpfn.train as train_module
# Reload the train module so edits to pu_osls_tabpfn/train.py are picked up without
# restarting the kernel, then rebind the names this notebook uses from it.
# NOTE(review): only `train` is reloaded; the other pu_osls_tabpfn modules imported
# above keep their already-loaded versions.
train_module = importlib.reload(train_module)
CurriculumConfig, get_device, train = (
    train_module.CurriculumConfig,
    train_module.get_device,
    train_module.train,
)


## 1) Final target configuration

In [None]:
# Final target prior configuration. The curriculum below starts from smaller
# `start_*` settings and stabilizes after update_every_steps * max_updates steps.
cfg = PriorGeneratorConfig(
    max_classes=5,
    min_features=4,
    max_features=10,
    min_rows=800,
    max_rows=1000,
    min_train_fraction=0.5,
    max_train_fraction=0.6,
    # Presumably controls how many classes are removed from the train split
    # (cf. `removed_class_count` in the sanity check) — confirm in prior_data.
    remove_poisson_lambda=1.2,
    min_train_rows_after_removal=30,
    seed=99,
    prior_backend="tabicl",
    tabicl=TabICLPriorConfig(
        prior_type="mlp_scm",
        n_jobs=1,
        batch_size_per_gp=4,
        batch_size_per_subgp=2,
    ),
    # Test-time label shift is switched off for this run.
    test_label_shift=TestLabelShiftConfig(enabled=False, strategy="none", strength=0.0),
)
update_every_steps = 100
max_updates = 20
# Stage-0 curriculum values; later stages presumably move toward `cfg`
# (see train.CurriculumConfig for the exact schedule).
curriculum_cfg = CurriculumConfig(
    enabled=True,
    update_every_steps=update_every_steps,
    max_updates=max_updates,
    start_max_classes=3,
    # Start with the feature range collapsed to min_features; max_features expands over stages.
    start_max_features=cfg.min_features,
    start_min_rows=700,
    start_max_rows=800,
    start_remove_poisson_lambda=0.3,
    # Starting overrides for TabICL's sampled SCM hyperparameters (keys are TabICL
    # hyperparameter names; value semantics are defined by train/prior_data — confirm there).
    tabicl_sampled_hp_start={
        "num_layers": {"max_mean": 2.0},
        "hidden_dim": {"max_mean": 24.0, "min_mean": 4.0},
        "num_causes": {"max_mean": 4.0},
    },
)
device = get_device()
# The label index just past the known classes is reserved for "unseen"; the model
# therefore gets one extra output (presumably known classes use 0..max_classes-1).
unseen_label = cfg.max_classes
num_outputs = cfg.max_classes + 1
print(f"Device: {device}")
print(curriculum_cfg)
print(f"Curriculum updates every {curriculum_cfg.update_every_steps} steps")
print(f"Curriculum stabilizes after {curriculum_cfg.update_every_steps * curriculum_cfg.max_updates} steps")


## 2) Quick batch sanity check

In [None]:
# Sample a small batch from the final target config `cfg` (not a curriculum stage)
# with a fixed seed, then print tensor shapes and per-task metadata.
batch = generate_batch(cfg, batch_size=4, device=device, rng=np.random.default_rng(cfg.seed))
print("x shape:", tuple(batch["x"].shape))
print("y shape:", tuple(batch["y"].shape))
# The remaining fields are presumably 1-D tensors with one entry per task.
print("split indices:", tuple(batch["train_test_split_index"].tolist()))
print("num_classes:", tuple(batch["num_classes"].tolist()))
print("num_features:", tuple(batch["num_features"].tolist()))
print("removed_class_count:", tuple(batch["removed_class_count"].tolist()))

## 3) Train with curriculum

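Choose `num_steps` for your run; the schedule above determines when updates happen and when the curriculum stabilizes. Use at least `update_every_steps * max_updates` steps so that training reaches the final target configuration.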

## 3a) GPU memory profiling for a given config

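Run this before full training to estimate peak CUDA memory for a chosen curriculum step and batch size. The probe runs a single forward/backward pass; the example calls need `model`, which is created in the next cell.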


In [None]:
def _pad_train_labels(y_full, split, fill_value):
    """Return each task's train-prefix labels, right-padded to the longest prefix.

    Row ``b`` of the result is ``y_full[b, :split[b]]`` followed by ``fill_value``
    up to ``int(split.max())`` columns. Dtype and device follow ``y_full``.

    Args:
        y_full: 2-D label tensor with one row per task; assumes it has at least
            ``max(split)`` columns.
        split: 1-D tensor with the number of train rows of each task.
        fill_value: label written into the padded tail of each row.
    """
    width = int(split.max().item())
    columns = torch.arange(width, device=y_full.device)
    # True where the column lies past that task's train prefix.
    is_padding = columns.unsqueeze(0) >= split.to(y_full.device).unsqueeze(1)
    # masked_fill is out-of-place, so y_full itself is not modified.
    return y_full[:, :width].masked_fill(is_padding, fill_value)


def profile_gpu_step(model, cfg, curriculum_cfg, step, batch_size, unseen_label, device):
    """Run one forward/backward pass at a curriculum step and print peak CUDA memory.

    Builds the prior config for ``step`` via ``_build_step_cfg`` and samples one batch
    of ``batch_size`` tasks, seeded with ``cfg.seed + step``. Each task's train labels
    are padded with ``unseen_label``, and ``model`` then runs forward and backward on
    ``device``. The function prints peak allocated and reserved memory plus free/total
    memory for ``device``. A RuntimeError raised during the pass is reported instead of
    propagated; CUDA out-of-memory errors are RuntimeError subclasses.

    Side effects: moves ``model`` to ``device`` and switches it to train mode (in place,
    assuming ``model`` is a torch.nn.Module). Gradients produced by the probe are
    cleared before returning, so they do not linger into a later training run.

    Raises:
        ValueError: if ``device`` is not a CUDA device.
    """
    if device.type != "cuda":
        raise ValueError("GPU profiling requires CUDA.")

    from pu_osls_tabpfn.prior_data import generate_batch
    from pu_osls_tabpfn.train import _build_step_cfg

    model = model.to(device).train()
    step_cfg = _build_step_cfg(cfg, step=step, curriculum_cfg=curriculum_cfg)
    batch = generate_batch(
        step_cfg,
        batch_size=batch_size,
        device=device,
        rng=np.random.default_rng(cfg.seed + step),
    )

    x = batch["x"]
    split = batch["train_test_split_index"]
    y_train_padded = _pad_train_labels(batch["y"], split, unseen_label).to(x.device)

    torch.cuda.empty_cache()
    # Measure only this pass, and on the requested device (not the current default one).
    torch.cuda.reset_peak_memory_stats(device)

    try:
        logits = model((x, y_train_padded), split)
        loss = logits.float().mean()
        loss.backward()

        peak_alloc = torch.cuda.max_memory_allocated(device) / 1024**3
        peak_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
        free, total = torch.cuda.mem_get_info(device)
        print(f"step={step}, batch_size={batch_size}")
        print(f"curriculum rows={step_cfg.min_rows}-{step_cfg.max_rows}, features={step_cfg.min_features}-{step_cfg.max_features}")
        print(f"peak allocated: {peak_alloc:.2f} GiB")
        print(f"peak reserved : {peak_reserved:.2f} GiB")
        print(f"free/total    : {free/1024**3:.2f}/{total/1024**3:.2f} GiB")
    except RuntimeError as e:
        print("OOM/runtime error:", str(e).split("\n")[0])
    finally:
        # Drop probe gradients so they neither hold GPU memory nor leak into training.
        model.zero_grad(set_to_none=True)


# Example probes:
# profile_gpu_step(model, cfg, curriculum_cfg, step=0, batch_size=64, unseen_label=unseen_label, device=device)
# profile_gpu_step(model, cfg, curriculum_cfg, step=2000, batch_size=16, unseen_label=unseen_label, device=device)


In [None]:
# Small TabPFN-style model. The output head is sized num_outputs = max_classes + 1,
# so the unseen label (index max_classes) presumably gets its own output slot.
model = CustomNanoTabPFNModel(
    # Label-space settings derived from the prior config.
    num_outputs=num_outputs,
    unseen_label=unseen_label,
    # Architecture size.
    num_layers=3,
    num_attention_heads=4,
    embedding_size=32,
    mlp_hidden_size=64,
)

# Optional re-check before long training:
# profile_gpu_step(model, cfg, curriculum_cfg, step=20000, batch_size=64, unseen_label=unseen_label, device=device)


In [None]:
# Evaluation settings; also reused by the final evaluation cell below.
eval_cfg = EvalConfig(n_tasks=100, batch_size=8, seed=999, outlier_score="msp")

num_steps = 30000
# Step at which the curriculum stops changing (same formula as the config cell's printout).
curriculum_final_step = curriculum_cfg.update_every_steps * curriculum_cfg.max_updates
if curriculum_cfg.enabled and num_steps < curriculum_final_step:
    print(
        f"Warning: num_steps={num_steps} ends before the curriculum stabilizes "
        f"at step {curriculum_final_step}; the final target configuration may never be trained on."
    )

model, losses = train(
    model,
    cfg,
    batch_size=64,
    lr=2e-4,
    device=device,
    num_steps=num_steps,
    unseen_label=unseen_label,
    eval_cfg=eval_cfg,
    eval_interval=0,  # presumably disables periodic evaluation during training — confirm in train()
    curriculum_cfg=curriculum_cfg,
)
# losses[-1] would raise IndexError if train() recorded no losses; len() works for
# lists as well as tensors/arrays (whose truthiness is ambiguous).
if len(losses) > 0:
    print(f"Final loss: {losses[-1]:.4f}")
else:
    print("Final loss: n/a (train returned no losses)")


## 4) Final evaluation + checkpoint

In [None]:
# Evaluate the trained model against the final target prior `cfg` using the
# settings in `eval_cfg`, then print the metrics.
results = evaluate_pu_osls(
    model,
    cfg_prior=cfg,
    unseen_label=unseen_label,
    device=device,
    eval_cfg=eval_cfg,
)
print_results(results)

# Save the weights together with the configs used for this run under <repo>/artifacts.
artifacts_dir = ROOT / "artifacts"
artifacts_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = artifacts_dir / "pretrained_pu_osls_tabpfn.pt"
# NOTE(review): the config objects are pickled as-is. Recent PyTorch releases default
# torch.load to weights_only=True, so loading this file will likely need
# weights_only=False (or allow-listed classes) — confirm against the loader.
# The model's constructor arguments (embedding_size, num_layers, ...) are not stored,
# so they must be known separately to rebuild CustomNanoTabPFNModel.
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "cfg": cfg,
        "curriculum_cfg": curriculum_cfg,
        "eval_cfg": eval_cfg,
        "unseen_label": unseen_label,
    },
    checkpoint_path,
)
print(f"Saved checkpoint to: {checkpoint_path}")