# Evaluate Pre-trained nanoTabPFN
This notebook loads a pre-trained model and evaluates it on OpenML datasets, comparing against baseline methods.

## Configuration

In [None]:
# Checkpoint to evaluate (a file readable by `torch.load`)
MODEL_PATH = "../trained_model_20241201_120000.pth"  # Update this path!

# Architecture hyperparameters, passed as keyword arguments to NanoTabPFNModel.
# They must match the values the checkpoint was trained with, otherwise the
# saved weights will not fit the freshly constructed model.
MODEL_CONFIG = dict(
    embedding_size=96,
    num_attention_heads=4,
    mlp_hidden_size=192,
    num_layers=3,
    num_outputs=2,  # presumably the number of classes; keep in sync with TARGET_CLASSES (verify)
)

# Filters applied when selecting / subsampling the OpenML evaluation datasets
MAX_FEATURES = 10     # skip datasets with more features than this
NUM_INSTANCES = 200   # subsample larger datasets down to this many rows
TARGET_CLASSES = 2    # skip datasets with more classes than this

## Imports and Setup

In [None]:
import sys
sys.path.append('..')

from matplotlib import pyplot as plt
import pandas as pd
from sklearn.metrics import roc_auc_score
import numpy as np
from sklearn.model_selection import StratifiedKFold
import seaborn as sns
import openml
from openml.tasks import TaskType
from sklearn.compose import ColumnTransformer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder, FunctionTransformer
import torch

from model import NanoTabPFNModel, NanoTabPFNClassifier
from train import get_default_device

# Pick the compute device through the project's `get_default_device` helper (imported from `train`).
# The same `device` is reused below for `torch.load(map_location=...)`, `model.to(...)`
# and the NanoTabPFNClassifier wrapper.
device = get_default_device()
print(f"Using device: {device}")

## Helper Functions

In [None]:
"""
=================== DATA LOADING AND PREPROCESSING ===================
"""

def get_feature_preprocessor(X: np.ndarray | pd.DataFrame) -> ColumnTransformer:
    """
    builds an (unfitted) ColumnTransformer whose column selection is derived from `X`

    Every column of `X` is inspected once and assigned to one of three groups:
      * constant (at most one distinct non-NaN value): selected by neither mask
      * numeric (every non-NaN entry can be parsed by `pd.to_numeric`, so numbers
        stored as strings count too): coerced to numeric values
      * categorical (anything else): ordinal-encoded, categories unseen during
        fitting are encoded as NaN

    Columns selected by neither mask are not passed to any transformer, so they are
    left out of the output under ColumnTransformer's default `remainder` handling.
    No imputation is performed here; NaNs are passed through.
    The caller is responsible for calling `fit` / `fit_transform` on the result.
    """
    frame = pd.DataFrame(X)

    def _classify(column: pd.Series) -> tuple[bool, bool]:
        # returns (is_numeric, is_categorical); constant columns are neither
        if len(column.dropna().unique()) <= 1:
            return False, False
        present = column.notna().sum()
        # entries that fail to parse become NaN, so a count mismatch means "not purely numeric"
        parseable = pd.to_numeric(column, errors='coerce').notna().sum()
        return present == parseable, present != parseable

    flags = [_classify(frame[name]) for name in frame]
    num_mask = np.array([is_numeric for is_numeric, _ in flags])
    cat_mask = np.array([is_categorical for _, is_categorical in flags])

    def _as_frame(x):
        # ColumnTransformer may hand over a numpy array; `apply` below needs a DataFrame
        return x if isinstance(x, pd.DataFrame) else pd.DataFrame(x)

    def _coerce_numeric(x):
        return x.apply(pd.to_numeric, errors='coerce').to_numpy()

    numeric_steps = [
        ("to_pandas", FunctionTransformer(_as_frame)),
        ("to_numeric", FunctionTransformer(_coerce_numeric)),
    ]
    categorical_steps = [
        ('encoder', OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=np.nan)),
    ]

    return ColumnTransformer(
        transformers=[
            ('num', Pipeline(numeric_steps), num_mask),
            ('cat', Pipeline(categorical_steps), cat_mask),
        ]
    )

def get_openml_datasets(
        max_features_eval: int = 10,
        new_instances_eval: int = 200,
        target_classes_filter: int = 2,
        **kwargs,
        ) -> dict[str, tuple[np.ndarray, np.ndarray]]:
    """
    Load the classification datasets of OpenML TabArena v0.1 that pass the size filters.

    A task is skipped when it is not a supervised classification task, or when its
    dataset has more than `max_features_eval` features, more than
    `target_classes_filter` classes, any instance with missing values, or a
    minority class below 2.5% (all read from the OpenML dataset qualities).

    Datasets with more than `new_instances_eval` rows are subsampled (stratified on
    the target, fixed seed) to exactly `new_instances_eval` rows. Targets are
    label-encoded and features are run through `get_feature_preprocessor`.

    Extra keyword arguments are accepted and ignored.

    Returns a dict mapping the OpenML dataset name to `(X, y)`.
    """
    tabarena_task_ids = (
        363612, 363613, 363614, 363615, 363616, 363618, 363619, 363620,
        363621, 363623, 363624, 363625, 363626, 363627, 363628, 363629,
        363630, 363631, 363632, 363671, 363672, 363673, 363674, 363675,
        363676, 363677, 363678, 363679, 363681, 363682, 363683, 363684,
        363685, 363686, 363689, 363691, 363693, 363694, 363696, 363697,
        363698, 363699, 363700, 363702, 363704, 363705, 363706, 363707,
        363708, 363711, 363712,
    )  # TabArena v0.1

    loaded = {}
    for task_id in tabarena_task_ids:
        task = openml.tasks.get_task(task_id, download_splits=False)
        if task.task_type_id != TaskType.SUPERVISED_CLASSIFICATION:
            continue  # only classification tasks are evaluated

        # metadata only; the data itself is downloaded after the filters pass
        dataset = task.get_dataset(download_data=False)
        qualities = dataset.qualities
        # checks stay chained with `or` so later qualities are only looked up when needed
        rejected = (
            qualities["NumberOfFeatures"] > max_features_eval
            or qualities["NumberOfClasses"] > target_classes_filter
            or qualities["PercentageOfInstancesWithMissingValues"] > 0
            or qualities["MinorityClassPercentage"] < 2.5
        )
        if rejected:
            continue

        features, labels, _, _ = dataset.get_data(
            target=task.target_name, dataset_format="dataframe"
        )
        if new_instances_eval < len(labels):
            # the "test" part of the split is the stratified subsample we keep
            _, features, _, labels = train_test_split(
                features, labels,
                test_size=new_instances_eval,
                stratify=labels,
                random_state=0,
            )

        X = features.to_numpy(copy=True)
        y = LabelEncoder().fit_transform(labels.to_numpy(copy=True))

        X = get_feature_preprocessor(X).fit_transform(X)
        loaded[dataset.name] = (X, y)
    return loaded

In [None]:
"""
=================== EVALUATION ===================
"""

# Shared 5-fold stratified splitter; the fixed seed gives every model the same folds
_skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

def eval_model(model, datasets):
    """
    Cross-validates `model` on every dataset and reports ROC AUC.

    For each `(X, y)` in `datasets`, the model is refit on each training fold of the
    module-level `_skf` splitter and queried with `predict_proba` on the held-out
    fold. Held-out predictions are pooled over all folds and scored once per dataset.

    Returns a dict with one "<dataset name>/ROC AUC" entry per dataset, plus a
    "ROC AUC" entry holding the mean over all datasets as a plain float.
    """
    metrics = {}
    for dataset_name, (X, y) in datasets.items():
        held_out_labels = []
        held_out_scores = []

        for train_idx, test_idx in _skf.split(X, y):
            held_out_labels.append(y[test_idx])
            model.fit(X[train_idx], y[train_idx])
            scores = model.predict_proba(X[test_idx])
            # with two probability columns, keep only the second one as the score
            if scores.shape[1] == 2:
                scores = scores[:, 1]
            held_out_scores.append(scores)

        pooled_labels = np.concatenate(held_out_labels, axis=0)
        pooled_scores = np.concatenate(held_out_scores, axis=0)
        metrics[f"{dataset_name}/ROC AUC"] = roc_auc_score(
            pooled_labels, pooled_scores, multi_class="ovr"
        )

    # add one aggregate entry per metric name (the part after the last "/")
    for metric_name in list({key.rsplit("/", 1)[-1] for key in metrics}):
        per_dataset_values = [metrics[key] for key in metrics if key.endswith(metric_name)]
        metrics[f"{metric_name}"] = float(np.mean(per_dataset_values))

    return metrics

## Load Pre-trained Model

In [None]:
# Initialize model with the same configuration used during training
model = NanoTabPFNModel(**MODEL_CONFIG)

# Load the trained weights.
# `weights_only=True` restricts unpickling to tensors and plain containers, so a
# tampered or untrusted checkpoint file cannot execute arbitrary code on load.
# That is sufficient here because the file is fed straight into `load_state_dict`.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model.to(device)
model.eval()  # inference mode (disables dropout etc.)

# Create classifier wrapper exposing the fit / predict_proba interface used by eval_model
nano_classifier = NanoTabPFNClassifier(model, device)

print(f"Successfully loaded model from {MODEL_PATH}")

## Load OpenML Datasets

In [None]:
# Fetch and preprocess the evaluation datasets using the filters from the configuration cell.
# `get_openml_datasets` queries OpenML, so this step presumably needs network access or a
# populated OpenML cache (verify). `DATASETS` maps dataset name -> (X, y).
print("Loading OpenML datasets...")
DATASETS = get_openml_datasets(
    max_features_eval=MAX_FEATURES,
    new_instances_eval=NUM_INSTANCES,
    target_classes_filter=TARGET_CLASSES
)
print(f"Loaded {len(DATASETS)} datasets: {list(DATASETS.keys())}")

## Evaluate Models

In [None]:
# Cross-validate the pre-trained nanoTabPFN classifier on every loaded dataset.
# `nano_results` holds one "<dataset>/ROC AUC" entry per dataset plus the mean under "ROC AUC".
print("Evaluating nanoTabPFN...")
nano_results = eval_model(nano_classifier, DATASETS)
print("\nnanoTabPFN Results:")
print(nano_results)

In [None]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from tabpfn import TabPFNClassifier
from tabpfn.config import ModelInterfaceConfig, PreprocessorConfig

print("Evaluating baseline models...")

# TabPFN inference config with the fingerprint feature off and a single 'none'
# preprocessing transform. Presumably this disables TabPFN's own feature
# preprocessing for a like-for-like comparison with nanoTabPFN (verify against tabpfn docs).
no_preprocessing_inference_config = ModelInterfaceConfig(
    FINGERPRINT_FEATURE=False,
    PREPROCESS_TRANSFORMS=[PreprocessorConfig(name='none')]
)

# Display name -> estimator; every estimator is scored with the same eval_model routine
baseline_models = {
    "TabPFN v2": TabPFNClassifier(random_state=0),
    "TabPFN v2 (no preprocessing)": TabPFNClassifier(
        inference_config=no_preprocessing_inference_config, 
        n_estimators=1, 
        random_state=0
    ),
    "Random Forest": RandomForestClassifier(random_state=0),
    "K-Nearest Neighbors": KNeighborsClassifier(),
    "Decision Tree": DecisionTreeClassifier(random_state=0),
}

baseline_results = {}
# The loop variable is deliberately not called `model`: at notebook top level it would
# rebind the global `model` that holds the loaded nanoTabPFN network.
for name, baseline_model in baseline_models.items():
    print(f"  Evaluating {name}...")
    baseline_results[name] = eval_model(baseline_model, DATASETS)

print("\nBaseline evaluation complete!")

## Results Summary

In [None]:
# Collect every model's metrics: after the transpose there is one row per model
# and one column per metric; rows are ordered by average ROC AUC, best first
all_results = {"nanoTabPFN": nano_results, **baseline_results}
results_df = pd.DataFrame(all_results).T.sort_values("ROC AUC", ascending=False)

# Text summary of the average score only
print(
    "\n" + "=" * 80,
    "OVERALL RESULTS (Average ROC AUC across all datasets)",
    "=" * 80,
    results_df[["ROC AUC"]].to_string(),
    "=" * 80,
    sep="\n",
)

# Rich table including the per-dataset columns (`display` is provided by IPython/Jupyter)
display(results_df)

## Visualization

In [None]:
# Plot overall comparison: one horizontal bar per model showing its average ROC AUC
sns.set_style("whitegrid")
sns.set_context("notebook")

fig, ax = plt.subplots(figsize=(10, 6))

# Create bar plot; nanoTabPFN is drawn in blue, every baseline in red
colors = ['#1f77b4' if idx == 'nanoTabPFN' else '#d62728' for idx in results_df.index]
results_df["ROC AUC"].plot(kind='barh', ax=ax, color=colors)

ax.set_xlabel('ROC AUC', fontsize=12)
ax.set_ylabel('Model', fontsize=12)
ax.set_title('Model Comparison on OpenML TabArena Datasets', fontsize=14, fontweight='bold')
# Dashed reference line at nanoTabPFN's score to make it easy to compare against the baselines
ax.axvline(x=results_df["ROC AUC"].loc["nanoTabPFN"], color='blue', linestyle='--', alpha=0.3)
# Zoom the x-axis to the observed score range (bars therefore do not start at 0)
ax.set_xlim(results_df["ROC AUC"].min() - 0.02, results_df["ROC AUC"].max() + 0.02)

plt.tight_layout()
plt.show()

In [None]:
# Plot per-dataset comparison
# Per-dataset columns are named "<dataset name>/ROC AUC"; the aggregate column has no "/".
dataset_metrics = [col for col in results_df.columns if "/" in col]
# Split on the LAST "/" so dataset names that themselves contain "/" are kept intact
dataset_names = [col.rsplit("/", 1)[0] for col in dataset_metrics]

# Create subplot for each dataset, laid out on a grid with a fixed number of columns
n_datasets = len(dataset_names)
n_cols = 3
n_rows = (n_datasets + n_cols - 1) // n_cols

if n_datasets == 0:
    # plt.subplots rejects a grid with zero rows, so there is nothing to draw
    print("No per-dataset metrics available to plot.")
else:
    # squeeze=False always yields a 2-D array of Axes, regardless of n_rows,
    # so flattening is valid for one dataset as well as for many
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)
    axes = axes.flatten()

    for idx, (dataset_name, metric_col) in enumerate(zip(dataset_names, dataset_metrics)):
        ax = axes[idx]
        # ascending sort puts the best model at the top of the horizontal bar chart
        dataset_results = results_df[metric_col].sort_values(ascending=True)
        colors = ['#1f77b4' if model_name == 'nanoTabPFN' else '#d62728' for model_name in dataset_results.index]
        dataset_results.plot(kind='barh', ax=ax, color=colors)
        ax.set_title(dataset_name, fontsize=10, fontweight='bold')
        ax.set_xlabel('ROC AUC', fontsize=9)
        ax.set_ylabel('')
        ax.tick_params(axis='y', labelsize=8)

    # Hide unused subplots in the last grid row
    for idx in range(n_datasets, len(axes)):
        axes[idx].set_visible(False)

    plt.tight_layout()
    plt.show()

## Save Results

In [None]:
# Save results to CSV
import os
from datetime import datetime

# Create the output directory next to the notebook's parent folder if it does not exist yet
os.makedirs("../results", exist_ok=True)
# Timestamp (from datetime.now(), no timezone passed) keeps repeated runs from overwriting each other
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_filename = f"../results/pretrained_evaluation_{timestamp}.csv"
# One row per model, one column per metric (per-dataset ROC AUC plus the average)
results_df.to_csv(results_filename)
print(f"Results saved to: {results_filename}")