# PU/OSLS TabPFN Pretraining with Curriculum (TabICL Prior)

Curriculum schedule used here is controlled by config variables:
- `update_every_steps` (how often hyperparameters are updated)
- `max_updates` (number of update stages before becoming stable)
- `min_features` stays fixed while `max_features` expands over stages


In [10]:
from pathlib import Path
import importlib
import sys

import numpy as np
import torch

def _find_repo_root(start: Path) -> Path:
    for candidate in [start.resolve(), *start.resolve().parents]:
        if (candidate / "src" / "pu_osls_tabpfn").exists():
            return candidate
    raise RuntimeError(
        "Could not find project root containing src/pu_osls_tabpfn. "
        "Open this notebook from the repository root or install with `pip install -e .`."
    )

ROOT = _find_repo_root(Path.cwd())
sys.path.insert(0, str(ROOT / "src"))

from pu_osls_tabpfn.eval_pu_osls import EvalConfig, evaluate_pu_osls, print_results
from pu_osls_tabpfn.model import CustomNanoTabPFNModel
from pu_osls_tabpfn.prior_data import PriorGeneratorConfig, TabICLPriorConfig, TestLabelShiftConfig, generate_batch
import pu_osls_tabpfn.train as train_module
train_module = importlib.reload(train_module)
CurriculumConfig = train_module.CurriculumConfig
get_device = train_module.get_device
train = train_module.train


## 1) Final target configuration

In [11]:
cfg = PriorGeneratorConfig(
    max_classes=10,
    min_features=3,
    max_features=8,
    min_rows=500,
    max_rows=1000,
    min_train_fraction=0.4,
    max_train_fraction=0.8,
    remove_poisson_lambda=1.0,
    seed=0,
    prior_backend="tabicl",
    tabicl=TabICLPriorConfig(
        prior_type="mlp_scm",
        n_jobs=1,
        batch_size_per_gp=4,
        batch_size_per_subgp=2,
    ),
    test_label_shift=TestLabelShiftConfig(enabled=False, strategy="none", strength=0.0),
)
update_every_steps = 500
max_updates = 20
curriculum_cfg = CurriculumConfig(
    enabled=True,
    update_every_steps=update_every_steps,
    max_updates=max_updates,
    start_max_classes=2,
    start_max_features=cfg.min_features,
    start_min_rows=120,
    start_max_rows=300,
    start_remove_poisson_lambda=0.3,
    tabicl_sampled_hp_start={
        "num_layers": {"max_mean": 2.0},
        "hidden_dim": {"max_mean": 24.0, "min_mean": 4.0},
        "num_causes": {"max_mean": 4.0},
    },
)
device = get_device()
unseen_label = cfg.max_classes
num_outputs = cfg.max_classes + 1
print(f"Device: {device}")
print(curriculum_cfg)
print(f"Curriculum updates every {curriculum_cfg.update_every_steps} steps")
print(f"Curriculum stabilizes after {curriculum_cfg.update_every_steps * curriculum_cfg.max_updates} steps")


Device: mps
CurriculumConfig(enabled=True, update_every_steps=500, max_updates=20, start_max_classes=2, start_max_features=3, start_min_rows=120, start_max_rows=300, start_remove_poisson_lambda=0.3, tabicl_sampled_hp_start={'num_layers': {'max_mean': 2.0}, 'hidden_dim': {'max_mean': 24.0, 'min_mean': 4.0}, 'num_causes': {'max_mean': 4.0}})
Curriculum updates every 500 steps
Curriculum stabilizes after 10000 steps


## 2) Quick batch sanity check

In [12]:
batch = generate_batch(cfg, batch_size=4, device=device, rng=np.random.default_rng(cfg.seed))
print("x shape:", tuple(batch["x"].shape))
print("y shape:", tuple(batch["y"].shape))
print("split indices:", tuple(batch["train_test_split_index"].tolist()))
print("num_classes:", tuple(batch["num_classes"].tolist()))
print("num_features:", tuple(batch["num_features"].tolist()))
print("removed_class_count:", tuple(batch["removed_class_count"].tolist()))

x shape: (4, 647, 8)
y shape: (4, 647)
split indices: (347, 347, 347, 245)
num_classes: (2, 2, 8, 8)


KeyError: 'num_features'

## 3) Train with curriculum

Choose `num_steps` for your run; the schedule above determines when updates happen and when curriculum stabilizes.

In [None]:
model = CustomNanoTabPFNModel(
    embedding_size=32,
    num_attention_heads=4,
    mlp_hidden_size=64,
    num_layers=2,
    num_outputs=num_outputs,
    unseen_label=unseen_label,
)
eval_cfg = EvalConfig(n_tasks=100, batch_size=8, seed=999, outlier_score="msp")
model, losses = train(
    model,
    cfg,
    batch_size=8,
    lr=5e-4,
    device=device,
    num_steps=1200,
    unseen_label=unseen_label,
    eval_cfg=eval_cfg,
    eval_interval=0,
    curriculum_cfg=curriculum_cfg,
)
print(f"Final loss: {losses[-1]:.4f}")


## 4) Final evaluation + checkpoint

In [None]:
results = evaluate_pu_osls(
    model,
    cfg_prior=cfg,
    unseen_label=unseen_label,
    device=device,
    eval_cfg=eval_cfg,
)
print_results(results)

artifacts_dir = ROOT / "artifacts"
artifacts_dir.mkdir(parents=True, exist_ok=True)
checkpoint_path = artifacts_dir / "pretrained_pu_osls_tabpfn.pt"
torch.save(
    {
        "model_state_dict": model.state_dict(),
        "cfg": cfg,
        "curriculum_cfg": curriculum_cfg,
        "eval_cfg": eval_cfg,
        "unseen_label": unseen_label,
    },
    checkpoint_path,
)
print(f"Saved checkpoint to: {checkpoint_path}")