# Prior Generation Comparison (`head` vs `roots` vs `is_causal=True`)

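This notebook benchmarks how hard binary classification is on datasets drawn from three prior data-generation modes. Each sampled dataset is used to fit standard classifiers (logistic regression, RBF SVM, random forest, and XGBoost when installed), and their test metrics are averaged per mode: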

- `is_causal=False`, `noncausal_feature_source="head"`
- `is_causal=False`, `noncausal_feature_source="roots"` (TabICL-like roots as features)
- `is_causal=True` (intermediate-node feature sampling)

Constraints in this benchmark:

- No PU row dropping and no hidden labels (`pu_keep_probability=1.0`)
- Train and test splits both contain two classes for every sampled dataset
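- All other prior settings come from `default_base_prior_config()`, except that the roots mode also sets `num_causes` equal to `num_features`, so the number of root causes matches the number of observed features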

In [17]:
from pathlib import Path
import sys

# Locate the repository root: the nearest directory (cwd or one of its ancestors)
# that contains a simplified_prior/ folder.
search_start = Path.cwd().resolve()
repo_root = next(
    (
        candidate
        for candidate in (search_start, *search_start.parents)
        if (candidate / "simplified_prior").exists()
    ),
    None,
)
if repo_root is None:
    raise RuntimeError("Could not find repo root containing simplified_prior/.")

# Later imports use the `slim_pretrain.` package prefix, so the *parent* of the repo
# root goes on sys.path (assumes the repo directory itself is named slim_pretrain).
package_parent = str(repo_root.parent)
if package_parent not in sys.path:
    sys.path.insert(0, package_parent)

print("Repo root:", repo_root)

Repo root: /Users/qltian/Library/CloudStorage/GoogleDrive-qltian2021@gmail.com/Other computers/My Laptop/Documents/Research/ai/slim_pretrain


In [18]:
from dataclasses import replace
import warnings
from typing import Dict

import numpy as np
import pandas as pd

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score, roc_auc_score

from slim_pretrain.pretrain.train import default_base_prior_config
from slim_pretrain.simplified_prior import SimplifiedPriorConfig, generate_simplified_prior_data

# Hide RuntimeWarnings emitted from sklearn modules so benchmark output stays readable.
warnings.filterwarnings("ignore", category=RuntimeWarning, module="sklearn")

# xgboost is optional: build_models only adds the "xgboost" entry when this flag is set.
# The catch is intentionally broad (presumably because xgboost can fail at import time
# for reasons other than ImportError, e.g. a missing native library).
HAVE_XGBOOST = False
try:
    from xgboost import XGBClassifier
except Exception:
    pass
else:
    HAVE_XGBOOST = True

print("xgboost available:", HAVE_XGBOOST)

xgboost available: True


In [19]:
# Benchmark size: increase for a more stable comparison.
N_DATASETS_PER_MODE = 30
MAX_TRIES_PER_DATASET = 300  # seeds tried per dataset before giving up
BASE_SEED = 123

# Human-readable label per generation mode; dict insertion order fixes the run order.
MODE_LABELS = {
    "noncausal_head": "is_causal=False + head",
    "noncausal_roots": "is_causal=False + roots",
    "causal_intermediate": "is_causal=True",
}
MODES = list(MODE_LABELS)

print("Datasets per mode:", N_DATASETS_PER_MODE)

Datasets per mode: 30


In [20]:
def build_mode_config(base_cfg: SimplifiedPriorConfig, mode: str) -> SimplifiedPriorConfig:
    """Return a copy of ``base_cfg`` configured for one benchmark generation mode.

    Every mode forces ``pu_keep_probability=1.0`` so full labels are kept (no PU hiding).

    Args:
        base_cfg: Starting prior config; not modified (``dataclasses.replace`` returns a copy).
        mode: One of ``"noncausal_head"``, ``"noncausal_roots"``, ``"causal_intermediate"``.

    Returns:
        The derived ``SimplifiedPriorConfig``.

    Raises:
        ValueError: If ``mode`` is not one of the names above.
    """
    if mode == "noncausal_head":
        overrides = {"is_causal": False, "noncausal_feature_source": "head"}
    elif mode == "noncausal_roots":
        # TabICL-like non-causal mode: roots are the observed features, so the
        # number of root causes is set to the feature count.
        overrides = {
            "is_causal": False,
            "noncausal_feature_source": "roots",
            "num_causes": base_cfg.num_features,
        }
    elif mode == "causal_intermediate":
        overrides = {"is_causal": True}
    else:
        raise ValueError(f"Unknown mode: {mode}")

    return replace(base_cfg, pu_keep_probability=1.0, **overrides)


def generate_valid_dataset(cfg: SimplifiedPriorConfig, seed_start: int, max_tries: int = 200) -> Dict[str, np.ndarray]:
    """Sample one fully-labelled binary dataset whose train and test splits both contain classes 0 and 1.

    Seeds ``seed_start, seed_start + 1, ...`` are tried in order, generating one dataset
    per seed, until a draw passes every check.

    Args:
        cfg: Prior config; its ``seed`` field is overridden for each attempt.
        seed_start: First seed to try.
        max_tries: Maximum number of consecutive seeds to try.

    Returns:
        Dict with numpy arrays ``X_train``, ``y_train``, ``X_test``, ``y_test`` and the
        int ``seed`` that produced the accepted dataset.

    Raises:
        RuntimeError: If no seed in ``[seed_start, seed_start + max_tries)`` yields a valid dataset.
    """
    # Downstream code (f1_score's default pos_label=1, predict_proba[:, 1], and the
    # per-class summary counts) assumes labels are encoded exactly as 0/1, so require
    # that set rather than merely "two distinct values" (which would accept e.g. {0, NaN}).
    required_classes = {0, 1}

    for seed in range(seed_start, seed_start + max_tries):
        cfg_try = replace(cfg, seed=seed)
        out = generate_simplified_prior_data(cfg_try, num_datasets=1)

        X = out["X"][0].cpu().numpy()
        y = out["y"][0].cpu().numpy()
        train_size = int(out["train_sizes"][0].item())
        is_pu = bool(out["is_pu"][0].item())

        # Requirement 1: keep original full labels (no PU hiding). Negative labels are
        # treated as hidden/unlabelled rows and rejected.
        if is_pu or np.any(y < 0):
            continue

        y_train, y_test = y[:train_size], y[train_size:]

        # Requirement 2: both train and test splits contain exactly the classes 0 and 1.
        if set(np.unique(y_train).tolist()) != required_classes:
            continue
        if set(np.unique(y_test).tolist()) != required_classes:
            continue

        return {
            "X_train": X[:train_size],
            "y_train": y_train,
            "X_test": X[train_size:],
            "y_test": y_test,
            "seed": seed,
        }

    raise RuntimeError(
        "Failed to sample a valid binary train/test split with full labels "
        f"after {max_tries} tries (seeds {seed_start}..{seed_start + max_tries - 1})."
    )


def build_models(random_state: int):
    """Return a fresh, unfitted set of baseline classifiers keyed by short name.

    Logistic regression and the RBF SVM get a StandardScaler in front; the random
    forest (and XGBoost, added only when HAVE_XGBOOST is true) use raw features.
    Every model is seeded with ``random_state``.
    """

    def scaled(estimator):
        # Standardize features before handing them to ``estimator``.
        return Pipeline([("scaler", StandardScaler()), ("clf", estimator)])

    models = {
        "logreg": scaled(LogisticRegression(max_iter=5000, random_state=random_state)),
        "svm_rbf": scaled(SVC(kernel="rbf", C=1.0, probability=True, random_state=random_state)),
        "random_forest": RandomForestClassifier(n_estimators=400, random_state=random_state, n_jobs=-1),
    }
    if not HAVE_XGBOOST:
        return models

    models["xgboost"] = XGBClassifier(
        objective="binary:logistic",
        eval_metric="logloss",
        tree_method="hist",
        n_estimators=300,
        max_depth=6,
        learning_rate=0.05,
        subsample=0.9,
        colsample_bytree=0.9,
        random_state=random_state,
        n_jobs=1,
    )
    return models


def evaluate_model(model, X_train, y_train, X_test, y_test):
    """Fit ``model`` on the train split and score it on the test split.

    Args:
        model: An unfitted sklearn-compatible binary classifier.
        X_train, y_train: Training features and 0/1 labels.
        X_test, y_test: Test features and 0/1 labels.

    Returns:
        Dict with ``accuracy``, ``balanced_accuracy``, ``f1`` (positive label 1) and
        ``roc_auc`` computed on the test split.
    """
    model.fit(X_train, y_train)
    pred = model.predict(X_test)

    # Continuous ranking score for ROC-AUC. Prefer decision_function: for
    # SVC(probability=True), predict_proba comes from a separately cross-validated
    # Platt calibration that can disagree with predict, whereas decision_function is
    # the model's own margin. For logistic regression the two are monotonically
    # related, so its AUC is unaffected. Fall back to the positive-class probability,
    # then to hard predictions.
    if hasattr(model, "decision_function"):
        score = model.decision_function(X_test)
    elif hasattr(model, "predict_proba"):
        score = model.predict_proba(X_test)[:, 1]
    else:
        score = pred

    return {
        "accuracy": accuracy_score(y_test, pred),
        "balanced_accuracy": balanced_accuracy_score(y_test, pred),
        "f1": f1_score(y_test, pred),
        "roc_auc": roc_auc_score(y_test, score),
    }

In [21]:
base_cfg = default_base_prior_config()

# Seed layout: BASE_SEED + mode_idx * MODE_SEED_STRIDE + ds_idx * DATASET_SEED_STRIDE + retry offset.
# generate_valid_dataset may consume up to MAX_TRIES_PER_DATASET consecutive seeds per dataset,
# so fail fast if the configured sizes would make neighbouring datasets (or modes) share seeds.
DATASET_SEED_STRIDE = 1_000
MODE_SEED_STRIDE = 1_000_000
assert MAX_TRIES_PER_DATASET <= DATASET_SEED_STRIDE, (
    f"MAX_TRIES_PER_DATASET={MAX_TRIES_PER_DATASET} exceeds DATASET_SEED_STRIDE={DATASET_SEED_STRIDE}; "
    "seed ranges of consecutive datasets would overlap"
)
assert N_DATASETS_PER_MODE * DATASET_SEED_STRIDE <= MODE_SEED_STRIDE, (
    f"N_DATASETS_PER_MODE={N_DATASETS_PER_MODE} is too large for MODE_SEED_STRIDE={MODE_SEED_STRIDE}; "
    "seed ranges of consecutive modes would overlap"
)

# Metric columns reported per run (NaN-filled when a model fails).
METRIC_NAMES = ("accuracy", "balanced_accuracy", "f1", "roc_auc")

rows = []
sample_dataset_summaries = []

for mode_idx, mode in enumerate(MODES):
    cfg = build_mode_config(base_cfg, mode)
    mode_label = MODE_LABELS[mode]
    for ds_idx in range(N_DATASETS_PER_MODE):
        seed_start = BASE_SEED + mode_idx * MODE_SEED_STRIDE + ds_idx * DATASET_SEED_STRIDE
        dataset = generate_valid_dataset(cfg, seed_start=seed_start, max_tries=MAX_TRIES_PER_DATASET)

        X_train = dataset["X_train"]
        y_train = dataset["y_train"]
        X_test = dataset["X_test"]
        y_test = dataset["y_test"]

        if ds_idx == 0:
            # Keep size / class-balance details of the first dataset per mode as a sanity check.
            sample_dataset_summaries.append(
                {
                    "mode": mode_label,
                    "seed": dataset["seed"],
                    "n_train": len(y_train),
                    "n_test": len(y_test),
                    "n_features": X_train.shape[1],
                    "train_class0": int((y_train == 0).sum()),
                    "train_class1": int((y_train == 1).sum()),
                    "test_class0": int((y_test == 0).sum()),
                    "test_class1": int((y_test == 1).sum()),
                }
            )

        # Deterministic per-dataset seed for the models' own randomness.
        models = build_models(random_state=seed_start + 77)
        for model_name, model in models.items():
            row = {"mode": mode_label, "dataset_id": ds_idx, "model": model_name}
            try:
                metrics = evaluate_model(model, X_train, y_train, X_test, y_test)
            except Exception as exc:
                # Best-effort: record the failure and keep benchmarking the remaining models.
                row.update(status="error", error=str(exc), **dict.fromkeys(METRIC_NAMES, np.nan))
            else:
                row.update(status="ok", error="", **metrics)
            rows.append(row)

results_df = pd.DataFrame(rows)
sample_summary_df = pd.DataFrame(sample_dataset_summaries)

print("Benchmark rows:", len(results_df))
sample_summary_df

Benchmark rows: 360


Unnamed: 0,mode,seed,n_train,n_test,n_features,train_class0,train_class1,test_class0,test_class1
0,is_causal=False + head,123,179,77,20,83,96,45,32
1,is_causal=False + roots,1000123,179,77,20,91,88,37,40
2,is_causal=True,2000123,179,77,20,89,90,39,38


In [22]:
# Only successful runs feed the averages; failed runs are tallied separately.
metric_cols = ["accuracy", "balanced_accuracy", "f1", "roc_auc"]

results_ok = results_df.loc[results_df["status"].eq("ok")].copy()
failed_runs = results_df.loc[results_df["status"].eq("error")]
error_counts = failed_runs.groupby(["mode", "model"], as_index=False).size()

# Mean metrics per (mode, model), best accuracy first within each mode.
per_model_means = results_ok.groupby(["mode", "model"], as_index=False)[metric_cols].mean()
model_mode_summary = per_model_means.sort_values(by=["mode", "accuracy"], ascending=[True, False])

# Mean metrics per mode across all models/datasets, plus their unweighted average.
per_mode_means = results_ok.groupby("mode")[metric_cols].mean()
per_mode_means["mean_metric"] = per_mode_means.mean(axis=1)
mode_overall_summary = per_mode_means.sort_values(by="mean_metric", ascending=False)

print("Successful model runs:", len(results_ok), "/", len(results_df))
if not error_counts.empty:
    print("Model failures by mode/model:")
    display(error_counts)

model_mode_summary

Successful model runs: 360 / 360


Unnamed: 0,mode,model,accuracy,balanced_accuracy,f1,roc_auc
0,is_causal=False + head,logreg,0.855411,0.857489,0.850273,0.940286
2,is_causal=False + head,svm_rbf,0.800866,0.804322,0.796572,0.888548
3,is_causal=False + head,xgboost,0.787013,0.788717,0.782355,0.87041
1,is_causal=False + head,random_forest,0.773593,0.776275,0.766658,0.86099
6,is_causal=False + roots,svm_rbf,0.561472,0.567915,0.556512,0.576042
4,is_causal=False + roots,logreg,0.561039,0.566182,0.555728,0.587825
5,is_causal=False + roots,random_forest,0.548485,0.555309,0.544447,0.569962
7,is_causal=False + roots,xgboost,0.529004,0.53356,0.531471,0.558644
8,is_causal=True,logreg,0.734199,0.737136,0.73459,0.809973
10,is_causal=True,svm_rbf,0.720346,0.723962,0.724648,0.801629


In [23]:
# Per-mode averages over all successful (dataset, model) runs, sorted by mean_metric
# (the unweighted mean of accuracy, balanced_accuracy, f1 and roc_auc), highest first.
mode_overall_summary

Unnamed: 0_level_0,accuracy,balanced_accuracy,f1,roc_auc,mean_metric
mode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
is_causal=False + head,0.804221,0.8067,0.798964,0.890058,0.824986
is_causal=True,0.718831,0.721514,0.718598,0.792645,0.737897
is_causal=False + roots,0.55,0.555742,0.547039,0.573118,0.556475


In [24]:
# mode_overall_summary is sorted by mean_metric in descending order, so its first
# row is the mode whose datasets were easiest for the benchmarked models.
ranked_modes = mode_overall_summary.index.tolist()
easiest_mode = ranked_modes[0]
print("Easiest generation mode by mean metric:", easiest_mode)
print("(Higher means easier to classify for the selected model family and config.)")

Easiest generation mode by mean metric: is_causal=False + head
(Higher means easier to classify for the selected model family and config.)
