This notebook shows how to use TabPFN for tabular prediction through its scikit-learn-style wrapper:

classifier = TabPFNClassifier(device='cpu')
classifier.fit(train_xs, train_ys)
prediction_ = classifier.predict(test_xs)

The `fit` function does not perform any computation; it only stores the training data. All computation happens at inference time, when `predict` is called.
Note that the pre-trained models were trained with up to 100 features, 10 classes and 1000 samples. While the model has no hard limit on the number of samples, the numbers of features and classes are restricted, and exceeding them leads to an error.

### Setup

In [5]:

%autoreload all

UsageError: Line magic function `%autoreload` not found.


In [6]:
import time
import torch
import numpy as np
import os

from scripts.model_builder import get_default_spec, save_model, load_model_only_inference
from scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, TabPFNClassifier
# from scripts.differentiable_pfn_evaluation import eval_model, eval_model_range

from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids, test_dids_classification

from scripts import tabular_metrics
import random
import torch

from functools import partial
import tabpfn.encoders as encoders

from transformer import TransformerModel

from tabpfn.utils import get_uniform_single_eval_pos_sampler
import torch
import math



In [4]:
# Interactive inspection only: IPython's `??` displays the source of TransformerModel.
# Development aid; remove before sharing the notebook.
TransformerModel??

In [None]:
def load_model_only_inference(path, filename, device):
    """
    Loads a saved model from the specified position. This function only restores inference capabilities and
    cannot be used for further training.

    NOTE: this definition shadows the ``load_model_only_inference`` imported from
    ``scripts.model_builder`` in the imports cell.

    Args:
        path: directory containing the checkpoint file.
        filename: checkpoint file name inside ``path``.
        device: requested device; replaced by ``'cpu:0'`` when CUDA is not available.

    Returns:
        A ``TransformerModel`` in eval mode, with the checkpoint weights loaded, moved to the
        resolved device, and with ``criterion`` set to an unweighted per-element
        ``CrossEntropyLoss``.

    Raises:
        AssertionError: if the checkpoint config has ``max_num_classes <= 2``.
    """
    # The checkpoint is a (model_state, optimizer_state, config) tuple; the optimizer state is
    # not needed for inference.
    model_state, _optimizer_state, config_sample = torch.load(os.path.join(path, filename), map_location='cpu')

    # Use the NaN-aware encoder if the prior was configured with any kind of missing values.
    nan_prob_keys = ('nan_prob_no_reason', 'nan_prob_a_reason', 'nan_prob_unknown_reason')
    if any(config_sample.get(key, 0.0) > 0.0 for key in nan_prob_keys):
        encoder = encoders.NanHandlingEncoder
    else:
        encoder = partial(encoders.Linear, replace_nan_by_zero=True)

    n_out = config_sample['max_num_classes']

    device = device if torch.cuda.is_available() else 'cpu:0'
    encoder = encoder(config_sample['num_features'], config_sample['emsize'])

    nhid = config_sample['emsize'] * config_sample['nhid_factor']
    y_encoder_generator = encoders.get_Canonical(config_sample['max_num_classes']) \
        if config_sample.get('canonical_y_encoder', False) else encoders.Linear

    assert config_sample['max_num_classes'] > 2
    loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.ones(int(config_sample['max_num_classes'])))
    with torch.no_grad():
        model = TransformerModel(encoder, n_out, config_sample['emsize'], config_sample['nhead'], nhid,
                                 config_sample['nlayers'], y_encoder=y_encoder_generator(1, config_sample['emsize']),
                                 dropout=config_sample['dropout'],
                                 # efficient_eval_masking=config_sample['efficient_eval_masking']
                                 full_attention=True,
                                 num_global_att_tokens=None,
                                 )

        model.criterion = loss

        # Keys saved from a wrapped model (presumably DataParallel/DDP) start with 'module.'.
        # Strip only that leading prefix: str.replace would also rewrite any 'module.' that
        # happens to occur later inside a key (e.g. in a submodule name).
        module_prefix = 'module.'
        model_state = {
            (key[len(module_prefix):] if key.startswith(module_prefix) else key): value
            for key, value in model_state.items()
        }
        model.load_state_dict(model_state)
        model.to(device)
        model.eval()

        return model  # no loss measured

In [None]:
path = '../tabpfn/models_diff/'
filename = 'prior_diff_real_checkpoint_n_0_epoch_42.cpkt'
model = load_model_only_inference(path,filename,'cpu')

In [None]:
x =  torch.randn(625, 3, 100)
y = torch.randn(312,3)
dummy = {'src': (x,y)}
model.forward(**dummy)

In [None]:
def convert_to_onnx(model, onnx_file_path, n_samples=625, batch_size=3, num_features=100, single_eval_pos=312):
    """Export ``model`` to ONNX by tracing it once on random dummy inputs.

    The model is called as ``model(src=(x, y), single_eval_pos=single_eval_pos)`` with
    ``x`` of shape ``(n_samples, batch_size, num_features)`` and ``y`` of shape
    ``(single_eval_pos, batch_size)``. ``single_eval_pos`` is a Python int, so the export
    traces it as a constant.

    Args:
        model: the module to export.
        onnx_file_path: destination path of the ``.onnx`` file.
        n_samples, batch_size, num_features: shape of the dummy feature tensor.
        single_eval_pos: value passed as ``single_eval_pos``; also the number of dummy target rows.
    """
    x = torch.randn(n_samples, batch_size, num_features)
    y = torch.randn(single_eval_pos, batch_size)
    # torch.onnx.export takes the positional arguments as a tuple whose optional last element
    # is a dict of keyword arguments; use that documented form rather than a bare dict.
    export_args = ((x, y), {'single_eval_pos': single_eval_pos})
    torch.onnx.export(model, export_args, onnx_file_path,
                      export_params=True, opset_version=13, do_constant_folding=True)

    print(f"Model has been converted to ONNX and saved as {onnx_file_path}")


In [None]:
convert_to_onnx(model, 'full_model.onnx')

In [None]:
model.y_encoder.weight.shape

In [None]:
torch_input = torch.randn(512,1)
torch.onnx.export(model.y_encoder,torch_input,'y_encoder.onnx')

In [None]:
test = next(model.transformer_encoder.children())

In [None]:
t_enc_0 = model.transformer_encoder._modules['layers'][0]


In [None]:
t_enc_0

In [None]:
t_enc_0.self_attn.out_proj.weight

In [None]:
model.y_encoder.weight

In [None]:

import torch.nn as nn

# ONNX-export hack: copy the project y-encoder's parameters into a plain nn.Linear and export
# that instead. Both weight AND bias must be copied; copying only the weight would leave the
# new module's randomly initialised bias in the exported graph.
y_out_features, y_in_features = model.y_encoder.weight.shape
y_has_bias = getattr(model.y_encoder, 'bias', None) is not None
new_module = nn.Linear(y_in_features, y_out_features, bias=y_has_bias)

with torch.no_grad():
    new_module.weight.copy_(model.y_encoder.weight)
    if y_has_bias:
        new_module.bias.copy_(model.y_encoder.bias)

torch_input = torch.randn(512, y_in_features)
# Overwrites the y_encoder.onnx exported earlier directly from model.y_encoder.
torch.onnx.export(new_module, torch_input, 'y_encoder.onnx')

In [None]:
# `transformer_encoder.layers` is an nn.ModuleList, which defines no forward() and so cannot be
# traced/exported on its own. Chain the layers in an nn.Sequential instead.
# NOTE(review): this applies the layers with no attention mask and without any final norm the
# encoder may have — confirm that matches how TransformerModel calls them.
encoder_layers = torch.nn.Sequential(*model.transformer_encoder.layers)
emsize = model.transformer_encoder.layers[0].self_attn.embed_dim
# Assumes the (seq_len, batch, emsize) layout used elsewhere in the model — TODO confirm.
torch_input = torch.randn(512, 1, emsize)
torch.onnx.export(encoder_layers, torch_input, 't_enc_all.onnx')

In [None]:
# for module_name in model._modules:
#     module = getattr(model,module_name)
#     torch_input = torch.randn(module.out_features,module.in_features)
#     torch.onnx.export(model.module, torch_input,f'{module}.onnx')

In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# class MyModel(nn.Module):

#     def __init__(self):
#         super(MyModel, self).__init__()
#         self.conv1 = nn.Conv2d(1, 6, 5)
#         self.conv2 = nn.Conv2d(6, 16, 5)
#         self.fc1 = nn.Linear(16 * 5 * 5, 120)
#         self.fc2 = nn.Linear(120, 512)
#         self.fc3 = nn.Linear(512, 10)

#     def forward(self, x):
#         x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
#         x = F.max_pool2d(F.relu(self.conv2(x)), 2)
#         x = torch.flatten(x, 1)
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = self.fc3(x)
#         return x

# torch_model = MyModel()
# torch_input = torch.randn(1, 1, 32, 32)
# torch.onnx.export(torch_model, torch_input, 'model2.onnx')

In [None]:
# `model` (returned by load_model_only_inference above) is used as-is; the former no-op
# `model = model` self-assignment was removed.

### Load datasets

In [None]:
base_path = '.'
max_samples = 10000
bptt = 10000

# Keyword arguments shared by the two load_openml_list calls below.
_cc_load_kwargs = dict(
    multiclass=True,
    shuffled=True,
    filter_for_nan=False,
    max_samples=max_samples,
    num_feats=100,
    return_capped=True,
)
cc_test_datasets_multiclass, cc_test_datasets_multiclass_df = load_openml_list(open_cc_dids, **_cc_load_kwargs)
cc_valid_datasets_multiclass, cc_valid_datasets_multiclass_df = load_openml_list(open_cc_valid_dids, **_cc_load_kwargs)

# Loading longer OpenML Datasets for generalization experiments (optional)
# test_datasets_multiclass, test_datasets_multiclass_df = load_openml_list(test_dids_classification, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 10000, num_feats=100, return_capped=True)

# Seeded so the validation-set order is reproducible.
random.seed(0)
random.shuffle(cc_valid_datasets_multiclass)

In [None]:
def get_datasets(selector, task_type, suite='cc'):
    """Return the notebook-level dataset list for a split, task type and suite.

    The lists are looked up by global name at call time, so they must already be defined.

    Args:
        selector: ``'valid'`` selects the validation list; any other value selects the test list.
        task_type: ``'binary'`` selects the binary lists (``suite`` is then ignored); any other
            value selects the multiclass lists.
        suite: ``'cc'`` or ``'openml'`` (multiclass only).

    Returns:
        The matching dataset list.

    Raises:
        ValueError: if ``task_type`` is not ``'binary'`` and ``suite`` is neither ``'openml'``
            nor ``'cc'``.
    """
    use_valid = selector == 'valid'
    if task_type == 'binary':
        # NOTE(review): valid_datasets_binary / test_datasets_binary are not defined in the
        # visible notebook.
        return valid_datasets_binary if use_valid else test_datasets_binary
    if suite == 'openml':
        # NOTE(review): valid_datasets_multiclass is not defined in the visible notebook;
        # test_datasets_multiclass is only loaded in the generalization section.
        return valid_datasets_multiclass if use_valid else test_datasets_multiclass
    if suite == 'cc':
        return cc_valid_datasets_multiclass if use_valid else cc_test_datasets_multiclass
    raise ValueError(f"Unknown suite: {suite!r}")

In [None]:
model_string = ''
longer = 1
task_type = 'multiclass'
eval_positions = [1000]
# Overrides the bptt set in the dataset-loading cell.
bptt = 2000

test_datasets = get_datasets('test', task_type, suite='cc')
valid_datasets = get_datasets('valid', task_type, suite='cc')

### Run on a single dataset

In [None]:
[(i, test_datasets[i][0]) for i in range(len(test_datasets))]

In [None]:
evaluation_dataset_index = 0 # Index of the dataset to predict
ds = test_datasets[evaluation_dataset_index]
print(f'Evaluation dataset name: {ds[0]} shape {ds[1].shape}')

In [None]:
xs, ys = ds[1].clone(), ds[2].clone()
eval_position = xs.shape[0] // 2
train_xs, train_ys = xs[0:eval_position], ys[0:eval_position]
test_xs, test_ys = xs[eval_position:], ys[eval_position:]

In [None]:
xs.shape

In [None]:
# Keep this interface. Plan: override the model loading to use Giza/ONNX instead, and pass the
# Giza models to transformer_predict instead of self.model[2].
classifier = TabPFNClassifier(only_inference=True, device='cpu')

# TODO: make fit() a task
classifier.fit(train_xs, train_ys)

# TODO:
# - modify transformer_predict to use the GizaModel predict interface instead of checkpoint(predict, ...)
# - line 360 in predict: use GizaModel and its predict interface instead of model()
# - replace model() with a workflow function that chains together all the GizaModels
# - make the dataset a single batch

# TODO: make predict_proba a task
prediction_ = classifier.predict_proba(test_xs)

In [None]:
# TODO:

# try porting over original tabpfn as a giza action! - make sure everything but onnx works
# save out totally pre-processed batch train and test
# port over TabPFN dependencies into the zk_tabpfn repo. (reproduce)

# test model execution using onnx format (test with just linear layer with dummy input, then with preprocessed real data)
# load in all the model layers and chain their outputs
# use final prediction logic from Tabpfn

# figure out how eval_pos is used by the model forward pass!
# should the input_feed pass eval_xs, eval_ys, eval_pos?

# try the weight copy hack to get full transpilation

In [None]:
classifier.model[2].transformer_encoder.layers._modules['0']._modules

In [None]:
classifier.model[2].transformer_encoder

In [None]:
classifier.style

In [None]:
roc = tabular_metrics.auc_metric(test_ys, prediction_)
ce = tabular_metrics.cross_entropy(test_ys, prediction_)
('AUC', float(roc), 'Cross Entropy', float(ce))

### Run on all datasets
This section runs a differentiable hyperparameter-tuning evaluation and saves the results to a results file, which can be loaded in TabularEval.ipynb for comparison with other baselines.

In [None]:
# eval_model_range's import is commented out in the imports cell, so without this line the
# call below raises NameError.
from scripts.differentiable_pfn_evaluation import eval_model_range

eval_positions = [1000]
bptt = 2000

# NOTE(review): N_models and models_per_block are not passed to the call below.
N_models = 3
models_per_block = 1

eval_addition = 'user_run'
device = 'cpu'

eval_model_range(
    i_range=[0],
    e=-1,
    valid_datasets=[],  # alternatively: cc_valid_datasets_multiclass
    test_datasets=cc_test_datasets_multiclass,
    train_datasets=[],
    eval_positions_test=eval_positions,
    bptt_test=bptt,
    add_name=model_string,
    base_path=base_path,
    selection_metric='auc',
    best_grad_steps=0,
    eval_addition=eval_addition,
    N_ensemble_configurations_list=[32],
    device=device,
)

### Run generalization experiments

In [None]:
# Loading longer OpenML Datasets for generalization experiments (optional)
test_datasets_multiclass, test_datasets_multiclass_df = load_openml_list(test_dids_classification, multiclass=True, shuffled=True, filter_for_nan=False, max_samples = 10000, num_feats=100, return_capped=True)


In [None]:
test_datasets_longer_generalization = [ds for ds in test_datasets_multiclass if ds[1].shape[0] >= 10000]

In [None]:
def test_gen(classifier_key, split):
    ces = []
    for k in tqdm(range(0, len(test_datasets_longer_generalization))):
        x, y = test_datasets_longer_generalization[k][1], test_datasets_longer_generalization[k][2].numpy()
        x = normalize_data(x).numpy()
        x[np.isnan(x)] = 0.0
        print(x.shape[0])
        
        if x.shape[0] < 10000:
            continue
        if len(np.unique(y)) > 2:
            continue

        for bptt_ in [500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 5500, 6000, 6500, 7000, 7500, 8000, 8500, 9000, 9500, 10000]:
            bptt_ = bptt_ // 2
            #model = classifier_dict[classifier_key]
            x_, y_ = x.copy(), y.copy()
            x_train, x_test, y_train, y_test = train_test_split(x_, y_, test_size=0.5, random_state=split)
            x_train, y_train = x_train[0:bptt_], y_train[0:bptt_]
            model.fit(x_train, y_train) # ranking[0:j]
            pred = model.predict_proba(x_test) # ranking[0:j]
            ce = tabular_metrics.auc_metric(y_test, pred)
            ces += [{'bptt': bptt_, 'k': k, 'm': float(ce), 'method': classifier_key, 'split': split}]
            print(x_train.shape, ce)
    with open(f'generalization_{classifier_key}_{split}.obj',"wb") as fh:
        pickle.dump(ces,fh)

In [None]:
test_gen('tabpfn', 0)

In [None]:
ces = []
for classifier_key in classifier_dict:
    for split in range(0,5):
        try:
            with open(f'generalization_{classifier_key}_{split}.obj',"rb") as fh:
                ces += pickle.load(fh)
        except:
            pass
df = pd.DataFrame(ces)

In [None]:
# Average the per-dataset AUC ('m') over datasets for each (training size, split, method).
df = df.groupby(['bptt', 'split', 'method']).mean().reset_index()
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

colors = iter(sns.color_palette("tab10"))
# `relabeler` (display names per method) is not defined in the visible notebook; fall back to the raw key.
display_names = globals().get('relabeler', {})
for classifier_key in ['tabpfn']:
    c = next(colors)
    sns.lineplot(x='bptt', y='m', data=df[df.method == classifier_key],
                 label=display_names.get(classifier_key, classifier_key), color=c, ax=ax)

ax.get_legend().remove()
ax.set(xlabel='Number of training samples', ylabel='ROC AUC',
       title='ROC AUC vs. number of training samples')
# NOTE(review): 1024 presumably marks the sequence length used when training the prior — confirm.
ax.axvline(x=1024, linestyle='dashed', color='red')
ax.set_ylim((0.73, 0.79))
ax.set_xlim((250, 5000));