# Towards Scalable ZKML "training" via Prior Fitted Networks


### What if we could prove what training data was used to make predictions, in the same way we prove inference?



### TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second
https://github.com/automl/TabPFN


#### From the abstract
We present TabPFN, a trained Transformer that can do supervised classification for small tabular datasets in less than a second, needs no hyperparameter tuning and is competitive with state-of-the-art classification methods. TabPFN performs in-context learning (ICL), it learns to make predictions using sequences of labeled examples (x, f(x)) given in the input, without requiring further parameter updates. TabPFN is fully entailed in the weights of our network, which accepts training and test samples as a set-valued input and yields predictions for the entire test set in a single forward pass. TabPFN is a Prior-Data Fitted Network (PFN) and is trained offline once, to approximate Bayesian inference on synthetic datasets drawn from our prior. This prior incorporates ideas from causal reasoning: It entails a large space of structural causal models with a preference for simple structures. On the 18 datasets in the OpenML-CC18 suite that contain up to 1 000 training data points, up to 100 purely numerical features without missing values, and up to 10 classes, we show that our method clearly outperforms boosted trees and performs on par with complex state-of-the-art AutoML systems with up to 230× speedup. This increases to a 5 700× speedup when using a GPU. We also validate these results on an additional 67 small numerical datasets from OpenML. We provide all our code, the trained TabPFN, an interactive browser demo and a Colab notebook at https://github.com/automl/TabPFN.


### Setup

In [9]:
%load_ext autoreload

%autoreload 2



In [10]:
import time
import torch
import numpy as np
import os


from scipy.ndimage import zoom
from giza_actions.action import action, Action
from giza_actions.task import task
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from scripts.model_builder import get_default_spec, save_model, load_model_only_inference
from scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, TabPFNClassifier

from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids, test_dids_classification

from scripts import tabular_metrics
import random

from functools import partial
import tabpfn.encoders as encoders

from transformer import TransformerModel

import uuid

os.environ['GIZA_API_HOST'] = 'https://api-dev.gizatech.xyz'
print(os.environ['GIZA_API_HOST'])

https://api-dev.gizatech.xyz


### Load datasets

In [11]:
@task(name='get_dataset')
def get_dataset():
    max_samples = 1000
    # Load OpenML dataset id 11
    test_datasets = load_openml_list([11], multiclass=True, shuffled=True, filter_for_nan=False, max_samples=max_samples, num_feats=100, return_capped=True)[0]
    ds = test_datasets[0]
    print(f'Evaluation dataset name: {ds[0]} shape {ds[1].shape}')
    xs, ys = ds[1].clone(), ds[2].clone()
    print(xs.shape)
    print(ys.shape)
    eval_position = xs.shape[0] // 2
    print(eval_position)
    train_xs, train_ys = xs[0:eval_position], ys[0:eval_position]
    test_xs, test_ys = xs[eval_position:], ys[eval_position:]
    return train_xs, train_ys, test_xs, test_ys




In [12]:
def load_model_only_inference_for_onnx(path, filename, device):
    """
    Loads a saved model from the specified position. This function only restores inference capabilities and
    cannot be used for further training.
    """

    model_state, optimizer_state, config_sample = torch.load(os.path.join(path, filename), map_location='cpu')

    if (('nan_prob_no_reason' in config_sample and config_sample['nan_prob_no_reason'] > 0.0) or
        ('nan_prob_a_reason' in config_sample and config_sample['nan_prob_a_reason'] > 0.0) or
        ('nan_prob_unknown_reason' in config_sample and config_sample['nan_prob_unknown_reason'] > 0.0)):
        encoder = encoders.NanHandlingEncoder
    else:
        encoder = partial(encoders.Linear, replace_nan_by_zero=True)

    n_out = config_sample['max_num_classes']

    device = device if torch.cuda.is_available() else 'cpu:0'
    encoder = encoder(config_sample['num_features'], config_sample['emsize'])

    nhid = config_sample['emsize'] * config_sample['nhid_factor']
    y_encoder_generator = encoders.get_Canonical(config_sample['max_num_classes']) \
        if config_sample.get('canonical_y_encoder', False) else encoders.Linear

    assert config_sample['max_num_classes'] > 2
    loss = torch.nn.CrossEntropyLoss(reduction='none', weight=torch.ones(int(config_sample['max_num_classes'])))
    with torch.no_grad():
        model = TransformerModel(encoder, n_out, config_sample['emsize'], config_sample['nhead'], nhid,
                                config_sample['nlayers'], y_encoder=y_encoder_generator(1, config_sample['emsize']),
                                dropout=config_sample['dropout'],
                                full_attention=True,
                                num_global_att_tokens=None,
                                )

        # print(f"Using a Transformer with {sum(p.numel() for p in model.parameters()) / 1000 / 1000:.{2}f} M parameters")

        model.criterion = loss
        module_prefix = 'module.'
        model_state = {k.replace(module_prefix, ''): v for k, v in model_state.items()}
        model.load_state_dict(model_state)
        model.to(device)
        model.eval()

        return model # no loss measured


In [13]:
@task(name='generate_onnx_model')
def generate_onnx_model():
    path = '../tabpfn/models_diff/'
    filename = 'prior_diff_real_checkpoint_n_0_epoch_42.cpkt'
    device = 'cpu'
    model = load_model_only_inference_for_onnx(path, filename, device)
    
    torch.manual_seed(420)

    x =  torch.randn(625, 3, 100)
    y = torch.randn(312,3)

    d = (x,y)
    model.eval()

    input_names = ["src","onnx::Unsqueeze_1"]
    torch.onnx.export(model, (d, ), "tabpfn.onnx", input_names=input_names, export_params=True, opset_version=13, do_constant_folding=True)
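
The dummy tensors used for the export look consistent with the split in `get_dataset`: 312 is 625 // 2, which is how `eval_position` is computed there. The sketch below spells out that reading. It is an assumption, not something the notebook states: `x` is taken to be laid out as (rows, batch, features) and `y` as (labeled rows, batch), where only the first half of the rows carries labels. `export_input_shapes` is a hypothetical helper written for illustration only.

```python
def export_input_shapes(n_rows, batch, n_features):
    """Return (x_shape, y_shape) under the assumed (rows, batch, features) layout.

    Assumption: the first half of the rows is the labeled training context,
    mirroring eval_position = xs.shape[0] // 2 in get_dataset.
    """
    eval_position = n_rows // 2
    x_shape = (n_rows, batch, n_features)  # training and test rows together
    y_shape = (eval_position, batch)       # labels for the training rows only
    return x_shape, y_shape


print(export_input_shapes(625, 3, 100))  # ((625, 3, 100), (312, 3))
```

If this reading is right, a dataset with a different number of rows would need dummy inputs of a different shape, since no dynamic axes are passed to the export call.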





### TabPFN Sklearn interface

- The `fit` function does not perform any computation; it only stores the training data. All computation happens at inference time, when `predict` is called.

- Note that the pre-saved models were trained for up to 100 features, 10 classes, and 1000 samples. The model has no hard bound on the number of samples, but the feature and class counts are restricted, and exceeding them raises an error.
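
To make the "fit only stores, predict does the work" pattern concrete, here is a minimal sketch using only the standard library. It is not TabPFN: `ContextOnlyClassifier` is a hypothetical toy class that uses nearest neighbours in place of the Transformer, and where the limit check lives (`fit` here) is an assumption for illustration.

```python
from collections import Counter
from math import dist

MAX_FEATURES = 100  # limits quoted above for the pre-saved models
MAX_CLASSES = 10


class ContextOnlyClassifier:
    """Toy stand-in for a fit-stores / predict-computes interface (not TabPFN)."""

    def __init__(self, k=3):
        self.k = k
        self.context_ = None

    def fit(self, xs, ys):
        xs, ys = list(xs), list(ys)
        # Assumed placement of the size check; only cheap validation happens here.
        if any(len(x) > MAX_FEATURES for x in xs):
            raise ValueError("too many features")
        if len(set(ys)) > MAX_CLASSES:
            raise ValueError("too many classes")
        # No model computation: the labeled rows are simply kept as context.
        self.context_ = (xs, ys)
        return self

    def predict(self, xs):
        if self.context_ is None:
            raise RuntimeError("call fit() before predict()")
        train_xs, train_ys = self.context_
        preds = []
        for x in xs:
            # All real work happens at inference time, against the stored context.
            order = sorted(range(len(train_xs)), key=lambda i: dist(x, train_xs[i]))
            votes = Counter(train_ys[i] for i in order[: self.k])
            preds.append(votes.most_common(1)[0][0])
        return preds


clf = ContextOnlyClassifier(k=1).fit([(0.0, 0.0), (1.0, 1.0)], ["a", "b"])
print(clf.predict([(0.1, 0.2), (0.9, 0.8)]))  # ['a', 'b']
```

The relevant point for ZKML is that the training rows enter the computation as ordinary inputs at prediction time, so they sit inside the same forward pass that would be proven.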

In [14]:
@task(name='init inference_only_model')
def init_inference_only_model():
    return TabPFNClassifier(device='cpu', only_inference=True)

@task(name='fit')
def fit(model, train_xs, train_ys):
    '''Initializes the TabPFN class with the training data, following their sklearn interface. 
       Note, there is NOT any model interaction happening here. :)'''
    print("Setting up TabPFN with training data context")
    return model.fit(train_xs, train_ys)
 
@task(name='predict')
def predict(model, test_xs, with_onnx=True):
    '''The TabPFN workflow is enhanced to use the GizaModel() if with_onnx=True.
       Please see transformer_prediction_interface.py for implementation details.'''
    print("prediction!")
    # Forward the caller's flag instead of always forcing the ONNX path
    return model.predict_proba(test_xs, with_onnx=with_onnx)

    
@task(name='eval')
def auc_eval(test_ys, prediction):
    roc = tabular_metrics.auc_metric(test_ys, prediction)
    print('AUC', float(roc))
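
`tabular_metrics.auc_metric` comes from the TabPFN scripts and its implementation is not shown here. For intuition about the number it reports, the following sketch computes a plain binary AUC as the probability that a randomly chosen positive example is scored above a randomly chosen negative one, with ties counted as half. `binary_auc` is a hypothetical helper, not part of TabPFN or Giza, and it does not cover the multiclass case.

```python
def binary_auc(labels, scores):
    """Binary ROC AUC via pairwise comparison of positive and negative scores."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")
    # A positive/negative pair counts 1 if ranked correctly, 0.5 on a tie, else 0.
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


print(binary_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

The pairwise loop is quadratic in the number of examples, which is acceptable for datasets of the size considered here.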




In [15]:
@action(name="Action: Test TabPFN", log_prints=True)
def run_model():
    
    generate_onnx_model()
    
    train_xs, train_ys, test_xs, test_ys = get_dataset()
    
    # loads in TabPFNClassifier with sklearn interface
    model = init_inference_only_model()
    
    # initializing only, no actual model interaction here!
    model = fit(model, train_xs, train_ys)
    
    # makes prediction on test_xs using approximate bayesian inference with train data 'in-context'
    prediction = predict(model, test_xs, with_onnx=True)
    
    auc_eval(test_ys, prediction)




In [16]:
run_model()

[Completed(message=None, type=COMPLETED, result=UnpersistedResult(type='unpersisted', artifact_type='result', artifact_description='Unpersisted result of type `NoneType`')),
 Completed(message=None, type=COMPLETED, result=UnpersistedResult(type='unpersisted', artifact_type='result', artifact_description='Unpersisted result of type `tuple`')),
 Completed(message=None, type=COMPLETED, result=UnpersistedResult(type='unpersisted', artifact_type='result', artifact_description='Unpersisted result of type `TabPFNClassifier`')),
 Completed(message=None, type=COMPLETED, result=UnpersistedResult(type='unpersisted', artifact_type='result', artifact_description='Unpersisted result of type `TabPFNClassifier`')),
 Completed(message=None, type=COMPLETED, result=UnpersistedResult(type='unpersisted', artifact_type='result', artifact_description='Unpersisted result of type `ndarray`')),
 Completed(message=None, type=COMPLETED, result=UnpersistedResult(type='unpersisted', artifact_type='result', artifact

### Unsupported Transpiler Operators

- {'IsInf', 'Mod', 'Less', 'Cast', 'Shape', 'Greater', 'Erf', 'Pow', 'Transpose', 'Where', 'Sqrt', 'ReduceMean', 'IsNaN', 'Gather', 'Slice'}
- Many of these operators repeat across the transformer encoder layers, so it may make more sense to add them to the transpiler than to implement each occurrence manually.
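
A list like the one above can be produced by comparing the operator types in the exported graph against the set a backend supports. The sketch below does this with plain Python collections and also counts repeats. `unsupported_ops` and the `supported` set are hypothetical, and the sample `graph_ops` list is made up for illustration rather than taken from `tabpfn.onnx`.

```python
from collections import Counter


def unsupported_ops(graph_ops, supported_ops):
    """Map each unsupported operator type to how often it appears in the graph."""
    counts = Counter(graph_ops)
    return {op: n for op, n in sorted(counts.items()) if op not in supported_ops}


# With the onnx package installed, graph_ops could be read from the export:
#   model = onnx.load("tabpfn.onnx")
#   graph_ops = [node.op_type for node in model.graph.node]
graph_ops = ["MatMul", "Add", "Erf", "Sqrt", "Erf", "Transpose", "MatMul"]
supported = {"MatMul", "Add"}  # hypothetical supported set

print(unsupported_ops(graph_ops, supported))  # {'Erf': 2, 'Sqrt': 1, 'Transpose': 1}
```

The counts give a rough idea of which operators would pay off most if added to the transpiler first.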

#### Not yet implemented in Orion...

- {'Cast', 'Shape'}
- Maybe the `shape` attribute of the Orion tensor trait could simply be used instead?

### Next steps

- Implement these Operators in the transpiler :)
- Benchmark ZK overhead for TabPFN 
- Test out TabPFN on Giza Datasets