This notebook shows how to use TabPFN for tabular prediction through its scikit-learn-style wrapper:

```python
classifier = TabPFNClassifier(device='cpu')
classifier.fit(train_xs, train_ys)
prediction_ = classifier.predict(test_xs)
```

The `fit` function does not perform any computation; it only stores the training data. All computation happens at inference time, when `predict` is called.
Note that the provided pretrained models were trained on datasets with up to 100 features, 10 classes and 1000 samples. While the model has no hard limit on the number of samples, the numbers of features and classes are restricted, and exceeding those limits raises an error.

### Setup

In [None]:
# Automatically reload edited modules (e.g. the local `scripts` package) before each cell executes,
# so changes to the library code are picked up without restarting the kernel.
%load_ext autoreload

%autoreload 2

In [None]:
import os
import pickle
import random
import time

import numpy as np
import pandas as pd
import torch

from scripts.model_builder import get_model, get_default_spec, save_model, load_model
from scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, TabPFNClassifier
from scripts.differentiable_pfn_evaluation import eval_model, eval_model_range
from scripts import tabular_metrics

from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids, test_dids_classification

In [None]:
# Root directory for model checkpoints and results: the current working directory.
base_path = os.curdir

### Load datasets

In [None]:
# Per-dataset limits passed to load_openml_list. max_features matches the 100-feature limit of the
# pretrained models described in the introduction.
max_samples = 10000
max_features = 100

# Shared loader options so the test and validation suites are loaded identically.
openml_load_kwargs = dict(multiclass=True, shuffled=True, filter_for_nan=False,
                          max_samples=max_samples, num_feats=max_features, return_capped=True)

cc_test_datasets_multiclass, cc_test_datasets_multiclass_df = load_openml_list(open_cc_dids, **openml_load_kwargs)
cc_valid_datasets_multiclass, cc_valid_datasets_multiclass_df = load_openml_list(open_cc_valid_dids, **openml_load_kwargs)

# Reproducibly shuffle the order of the validation datasets (this seeds the global `random` module).
random.seed(0)
random.shuffle(cc_valid_datasets_multiclass)

In [None]:
def get_datasets(selector, task_type, suite='cc'):
    """Return the list of datasets for one split of one benchmark suite.

    Args:
        selector: 'valid' for the validation split or 'test' for the test split.
        task_type: 'binary' selects the binary lists; any other value selects the
            multiclass lists.
        suite: 'cc' (the cc_* multiclass lists) or 'openml'. Ignored when
            task_type is 'binary'.

    Returns:
        The dataset list held in the corresponding notebook-level variable.

    Raises:
        ValueError: for an unknown ``selector`` or ``suite``. ValueError subclasses
            Exception, so callers that caught the previous generic Exception still work.
    """
    if selector not in ('valid', 'test'):
        raise ValueError(f"Unknown selector {selector!r}; expected 'valid' or 'test'")
    is_valid = selector == 'valid'

    if task_type == 'binary':
        # NOTE(review): valid_datasets_binary / test_datasets_binary are not defined in this notebook.
        return valid_datasets_binary if is_valid else test_datasets_binary
    if suite == 'openml':
        # NOTE(review): valid_datasets_multiclass is not defined in this notebook, and
        # test_datasets_multiclass is only loaded in the generalization section.
        return valid_datasets_multiclass if is_valid else test_datasets_multiclass
    if suite == 'cc':
        return cc_valid_datasets_multiclass if is_valid else cc_test_datasets_multiclass
    raise ValueError(f"Unknown suite {suite!r}; expected 'cc' or 'openml'")

In [None]:
# Run configuration shared by the cells below.
model_string = ''  # forwarded as add_name / model_string to the evaluation helpers
longer = 1  # NOTE(review): not read anywhere in this notebook
task_type = 'multiclass'
eval_positions = [1000]
bptt = 2000

test_datasets = get_datasets('test', task_type, suite='cc')
valid_datasets = get_datasets('valid', task_type, suite='cc')

### Run on a single dataset

In [None]:
# (index, name) for every test dataset; use the index to pick one below.
[(idx, entry[0]) for idx, entry in enumerate(test_datasets)]

In [None]:
# Pick one test dataset by position; valid indices are shown by the listing cell above.
evaluation_dataset_index = 0 # Index into test_datasets of the dataset to predict
ds = test_datasets[evaluation_dataset_index]
print(f'Evaluation dataset name: {ds[0]} shape {ds[1].shape}')

In [None]:
# Work on copies so the split below never aliases the loaded dataset.
xs = ds[1].clone()
ys = ds[2].clone()

# The first half of the rows is used for training and the second half for testing.
# (The datasets were loaded with shuffled=True.)
eval_position = xs.shape[0] // 2
train_xs, test_xs = xs[:eval_position], xs[eval_position:]
train_ys, test_ys = ys[:eval_position], ys[eval_position:]

In [None]:
# Same usage as in the introduction, except predict_proba is used: the AUC and cross-entropy
# metrics in the next cell are computed from these per-class scores.
classifier = TabPFNClassifier(device='cpu')
# Per the introduction, fit only stores the training data; inference happens in predict_proba.
classifier.fit(train_xs, train_ys)
prediction_ = classifier.predict_proba(test_xs)

In [None]:
# Score the held-out half; the tuple on the last line is shown as the cell output.
roc = tabular_metrics.auc_metric(test_ys, prediction_)
ce = tabular_metrics.cross_entropy(test_ys, prediction_)
('AUC', float(roc), 'Cross Entropy', float(ce))

### Run on all datasets
This section performs a differentiable hyperparameter-tuning run and saves the results to a file that can be loaded in TabularEval.ipynb to compare against other baselines.

In [None]:
# Evaluate on every CC test dataset; the results are saved for comparison in TabularEval.ipynb.
eval_positions = [1000]
bptt = 2000

eval_addition = 'user_run'  # tag forwarded to eval_model_range (presumably part of the results name — TODO confirm)
device = 'cpu'

eval_model_range(
    i_range=[0],
    e=-1,
    valid_datasets=[],  # validation is skipped; presumably pass cc_valid_datasets_multiclass to enable it
    test_datasets=cc_test_datasets_multiclass,
    train_datasets=[],
    eval_positions_test=eval_positions,
    bptt_test=bptt,
    add_name=model_string,
    base_path=base_path,
    selection_metric='auc',
    best_grad_steps=0,
    eval_addition=eval_addition,
    N_ensemble_configurations_list=[32],
    device=device,
)

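### Run generalization experiments
Measures how test ROC AUC changes as the number of training samples grows from 250 to 5000, on the binary test datasets with at least 10,000 rows.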

In [None]:
# Optional: load the larger OpenML test suite used by the generalization experiments.
test_datasets_multiclass, test_datasets_multiclass_df = load_openml_list(
    test_dids_classification,
    multiclass=True,
    shuffled=True,
    filter_for_nan=False,
    max_samples=10000,
    num_feats=100,
    return_capped=True,
)


In [None]:
# Keep only datasets with at least 10000 rows for the training-size sweep.
test_datasets_longer_generalization = [entry for entry in test_datasets_multiclass if entry[1].shape[0] >= 10000]

In [None]:
def test_gen(classifier_key, split, device='cuda', model_base_path=None, model_index=0):
    """Record test ROC AUC versus training-set size on the long binary datasets.

    For each entry of ``test_datasets_longer_generalization`` with at least 10000 rows and
    at most two classes:
      - the features are passed through ``normalize_data`` and NaNs are set to 0;
      - the data is split 50/50 with ``random_state=split``;
      - the model is fit on the first 250, 500, ..., 5000 training rows;
      - the model is scored on the full test half.
    All records are pickled to ``generalization_{classifier_key}_{split}.obj``.

    Args:
        classifier_key: ``'tabpfn'`` builds a TabPFNClassifier. Any other key is looked up
            in ``classifier_dict``, which must already exist in the notebook namespace.
        split: seed for ``train_test_split``; also part of the output filename.
        device: device for the TabPFNClassifier (default ``'cuda'``, as before).
        model_base_path: passed to TabPFNClassifier as ``base_path``. Defaults to the
            notebook-level ``base_path`` instead of a machine-specific absolute path.
        model_index: forwarded as TabPFNClassifier's ``i``. The previous code read an
            undefined global ``i``, which raised NameError on a fresh kernel. The default
            of 0 matches ``i_range=[0]`` in the all-datasets run.
            NOTE(review): confirm this selects the intended model.

    Each pickled record is a dict with these keys:
      - 'bptt': training rows used
      - 'k': position in test_datasets_longer_generalization
      - 'm': ROC AUC
      - 'method'
      - 'split'
    """
    # NOTE(review): tqdm, normalize_data and train_test_split are used below but are not
    # imported in this notebook's setup cell; bring them into scope before calling.
    if model_base_path is None:
        model_base_path = base_path
    if classifier_key == 'tabpfn':
        model = TabPFNClassifier(device=device, base_path=model_base_path,
                                 model_string=model_string, N_ensemble_configurations=4,
                                 no_preprocess_mode=False, i=model_index,
                                 feature_shift_decoder=False)
    else:
        model = classifier_dict[classifier_key]

    # Same grid as the original [500, 1000, ..., 10000] halved: 250, 500, ..., 5000.
    train_sizes = range(250, 5001, 250)
    records = []
    for k in tqdm(range(len(test_datasets_longer_generalization))):
        entry = test_datasets_longer_generalization[k]
        x, y = entry[1], entry[2].numpy()
        x = normalize_data(x).numpy()
        x[np.isnan(x)] = 0.0

        # Only large (>= 10000 rows) binary datasets take part in the sweep.
        if x.shape[0] < 10000 or len(np.unique(y)) > 2:
            continue

        for train_size in train_sizes:
            # The same random_state each iteration gives every training size the same split;
            # train_test_split returns new arrays, so x and y are never modified.
            x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5, random_state=split)
            x_fit, y_fit = x_train[:train_size], y_train[:train_size]
            model.fit(x_fit, y_fit)
            pred = model.predict_proba(x_test)
            auc = tabular_metrics.auc_metric(y_test, pred)
            records.append({'bptt': train_size, 'k': k, 'm': float(auc),
                            'method': classifier_key, 'split': split})
            print(x_fit.shape, auc)

    with open(f'generalization_{classifier_key}_{split}.obj', "wb") as fh:
        pickle.dump(records, fh)

In [None]:
# Run the training-size sweep for TabPFN on split 0 (writes generalization_tabpfn_0.obj).
test_gen(classifier_key='tabpfn', split=0)

In [None]:
# Collect the records written by test_gen(). Result files for method/split combinations
# that were never run are skipped; any other read or unpickling error is raised.
ces = []
# NOTE(review): `classifier_dict` (baseline models by name) is not defined in this notebook;
# without it, only the TabPFN results produced above are loaded.
for classifier_key in globals().get('classifier_dict', ['tabpfn']):
    for split in range(5):
        try:
            # These files are produced locally by test_gen; never unpickle untrusted files.
            with open(f'generalization_{classifier_key}_{split}.obj', "rb") as fh:
                ces += pickle.load(fh)
        except FileNotFoundError:
            continue
assert ces, "no generalization_*.obj result files found; run test_gen first"
df = pd.DataFrame(ces)

In [None]:
# Average the AUC per (training size, split, method) across datasets. Using a new name keeps
# `df` (the raw per-dataset records) intact, so this cell can be re-run safely.
df_mean = df.groupby(['bptt', 'split', 'method']).mean().reset_index()
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# NOTE(review): `relabeler` (display names per method) is not defined in this notebook;
# fall back to the raw method key when it is missing.
method_labels = globals().get('relabeler', {})
colors = iter(sns.color_palette("tab10"))
for classifier_key in ['tabpfn']:  # use df_mean.method.unique() to plot every method
    sns.lineplot(x='bptt', y='m', data=df_mean[df_mean.method == classifier_key],
                 label=method_labels.get(classifier_key, classifier_key),
                 color=next(colors), ax=ax)

ax.get_legend().remove()
ax.set(xlabel='Number of training samples', ylabel='ROC AUC')
# Dashed reference line at 1024 training samples.
# NOTE(review): presumably the training sequence length of the pretrained model — confirm.
ax.axvline(x=1024, linestyle='dashed', color='red')
ax.set_ylim((0.73, 0.79))
ax.set_xlim((250, 5000))
plt.show()