## Setup

In [22]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
# NOTE(review): the recorded output below shows that running this cell on a kernel
# where the extension is already loaded prints a notice suggesting `%reload_ext autoreload`.
%load_ext autoreload

%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
import random
import time
import warnings
from datetime import datetime

import torch

import numpy as np

import matplotlib.pyplot as plt
#from scripts.differentiable_pfn_evaluation import eval_model_range
# No string
from scripts.model_builder import get_model, get_default_spec, save_model, load_model
from scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, load_model_workflow, TabPFNClassifier

from scripts.model_configs import *

from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids
from priors.utils import plot_prior, plot_features
from priors.utils import uniform_int_sampler_f

from scripts.tabular_metrics import calculate_score_per_method, calculate_score
# No string
#from scripts.tabular_evaluation import evaluate

from priors.differentiable_prior import DifferentiableHyperparameterList, draw_random_style, merge_style_with_info
from scripts import tabular_metrics
from notebook_utils import *

In [24]:
# Size presets: one flag picks the values of both `max_samples` and `bptt`.
# NOTE(review): neither name is read in the visible part of this notebook —
# presumably consumed by later evaluation cells; confirm.
large_datasets = True
if large_datasets:
    max_samples, bptt = 10000, 10000
else:
    max_samples, bptt = 5000, 3000
suite = 'cc'

In [25]:
# Device string handed to get_model(...) and TabPFNClassifier(device=...) in the cells below.
device = 'cpu'
# Root directory for checkpoints: print_models joins it with 'models_diff/...',
# and train_function passes it to save_model.
base_path = '.'
# NOTE(review): not referenced in the visible part of this notebook —
# presumably a cap on the number of features; confirm before relying on it.
max_features = 100

In [26]:
def print_models(model_string):
    """List the checkpoint files that exist on disk for one model tag.

    Prints ``model_string`` first, then probes every combination of run index
    0..79 and checkpoint index 0..49 under ``base_path`` and prints the path of
    each checkpoint that is an existing file. A blank line is printed after
    every run index, whether or not any of its checkpoints were found.
    """
    print(model_string)

    for run_idx in range(80):
        for epoch_idx in range(50):
            # Build the candidate path once; it is both tested and printed.
            checkpoint = os.path.join(
                base_path,
                f'models_diff/prior_diff_real_checkpoint{model_string}_n_{run_idx}_epoch_{epoch_idx}.cpkt',
            )
            if Path(checkpoint).is_file():
                print(checkpoint)
        print()

In [27]:
def train_function(config_sample, i, add_name='', maximum_runtime = 15):
    """Train a model from ``config_sample`` and write time-spaced checkpoints.

    Args:
        config_sample: Configuration passed to ``get_model``. It is mutated:
            before every save the callback stores the current epoch under
            ``'epoch_in_training'``.
        i: Run index embedded in the checkpoint file name (``_n_{i}_``).
        add_name: Tag inserted into the checkpoint file name directly after
            ``prior_diff_real_checkpoint``.
        maximum_runtime: Time budget in minutes. It is only used to space the
            checkpoints (the budget is divided into 50 slots); it is not passed
            to ``get_model`` and does not stop training by itself.

    Returns:
        The value returned by ``get_model`` for the trained model. Previously
        the function ended with a bare ``return``, so callers that wrote
        ``model = train_function(...)`` received ``None``.
    """
    start_time = time.time()
    N_epochs_to_save = 50
    # Seconds per checkpoint slot; constant for the whole run, so compute it once.
    checkpoint_interval = maximum_runtime * 60 / N_epochs_to_save

    def save_callback(model, epoch):
        # The counter is stored on the model object so it survives between callback calls.
        if not hasattr(model, 'last_saved_epoch'):
            model.last_saved_epoch = 0
        # Save once the elapsed time has moved past the slot of the last checkpoint written.
        if (time.time() - start_time) / checkpoint_interval > model.last_saved_epoch:
            print('Saving model..')
            config_sample['epoch_in_training'] = epoch
            # The `epoch_` suffix in the file name is the checkpoint counter, not the training epoch.
            save_model(model, base_path, f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{model.last_saved_epoch}.cpkt',
                           config_sample)
            model.last_saved_epoch = model.last_saved_epoch + 1 # TODO: Rename to checkpoint

    model = get_model(config_sample
                      , device
                      , should_train=True
                      , verbose=1
                      , epoch_callback = save_callback)

    # Hand the trained model back instead of discarding it.
    return model

## Define prior settings

In [28]:
def reload_config(config_type='causal', task_type='multiclass', longer=0):
    """Build the base prior config for a multiclass run plus a timestamped tag.

    ``task_type`` and ``longer`` are accepted but not read by this function.

    Returns:
        A ``(config, model_string)`` pair: the object returned by
        ``get_prior_config(config_type=...)`` with the overrides below applied,
        and a tag of the form ``_multiclass_<MM_DD_YYYY_HH_MM_SS>``.
    """
    config = get_prior_config(config_type=config_type)

    # Prior selection flags.
    config['prior_type'] = 'mlp'
    config['differentiable'] = True
    config['flexible'] = True

    config['epochs'] = 12000
    config['recompute_attn'] = False

    # num_classes is whatever uniform_int_sampler_f(2, max_num_classes) returns
    # (presumably a sampler over that range — confirm against priors.utils).
    config['max_num_classes'] = 10
    config['num_classes'] = uniform_int_sampler_f(2, config['max_num_classes'])
    config['balanced'] = False

    timestamp = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
    return config, f'_multiclass_{timestamp}'

## Visualize Prior samples

In [29]:
# Start from the base config and override it key by key for this experiment.
# NOTE: reload_config does not read its `longer` argument, so `longer=1` has no effect.
config, model_string = reload_config(longer=1)
# Printed before any of the overrides below are applied, so the recorded output shows the base values.
print(config)
config['bptt_extra_samples'] = None

# diff
# For each of the next three keys: pin a fixed value on the config and remove the
# entry of the same name from config['differentiable_hyperparameters'].
config['output_multiclass_ordered_p'] = 0.
del config['differentiable_hyperparameters']['output_multiclass_ordered_p']

config['multiclass_type'] = 'rank'
del config['differentiable_hyperparameters']['multiclass_type']

config['sampling'] = 'normal' # maybe bad?
del config['differentiable_hyperparameters']['sampling']

config['pre_sample_causes'] = True
# end diff

config['multiclass_loss_type'] = 'nono' # 'compatible'
config['normalize_to_ranking'] = False # False

config['categorical_feature_p'] = .2 # diff: .0

# turn this back on in a random search!?
config['nan_prob_no_reason'] = .0
config['nan_prob_unknown_reason'] = .0 # diff: .0
config['set_value_to_nan'] = .1 # diff: 1.

config['normalize_with_sqrt'] = False

config['new_mlp_per_example'] = True
config['prior_mlp_scale_weights_sqrt'] = True
config['batch_size_per_gp_sample'] = None

config['normalize_ignore_label_too'] = False

config['differentiable_hps_as_style'] = False
config['max_eval_pos'] = 1000

config['random_feature_rotation'] = True
config['rotate_normalized_labels'] = True

config["mix_activations"] = False # False actually means True

# Model size settings.
config['emsize'] = 512 # 512 in the paper
config['emsize_f'] = 100
config['nhead'] = config['emsize'] // 8 # derived from emsize: 512 // 8 = 64
config['bptt'] = 1024+128
config['canonical_y_encoder'] = False
config['nlayers'] = 2

    
# Batch size and steps per epoch are both derived from aggregate_k_gradients:
# batch_size = 2 * 8 = 16, num_steps = 1024 // 8 = 128.
config['aggregate_k_gradients'] = 8
config['batch_size'] = 2 * config['aggregate_k_gradients']
config['num_steps'] = 1024//config['aggregate_k_gradients']
# Replaces the 12000 epochs set inside reload_config.
config['epochs'] = 400
config['total_available_time_in_s'] = None #60*60*22 # 22 hours for some safety...

config['train_mixed_precision'] = True
config['efficient_eval_masking'] = True

# NOTE(review): evaluate_hypers is not defined in the visible part of the notebook
# (presumably provided by one of the star imports); its result is what the training cells consume.
config_sample = evaluate_hypers(config)

{'lr': lr, Type: UniformFloat, Range: [0.0001, 0.00015], Default: 0.0001224745, on log-scale, 'dropout': dropout, Type: Categorical, Choices: {0.0}, Default: 0.0, 'emsize': emsize, Type: Categorical, Choices: {256}, Default: 256, 'emsize_f': 256, 'batch_size': batch_size, Type: Categorical, Choices: {64, 128}, Default: 64, 'nlayers': nlayers, Type: Categorical, Choices: {12}, Default: 12, 'num_features': 100, 'nhead': nhead, Type: Categorical, Choices: {4}, Default: 4, 'nhid_factor': 2, 'bptt': 50, 'eval_positions': None, 'seq_len_used': 50, 'sampling': 'normal', 'epochs': 12000, 'num_steps': 100, 'verbose': False, 'mix_activations': False, 'pre_sample_causes': True, 'multiclass_type': 'rank', 'nan_prob_unknown_reason_reason_prior': nan_prob_unknown_reason_reason_prior, Type: Categorical, Choices: {0.5}, Default: 0.5, 'categorical_feature_p': categorical_feature_p, Type: Categorical, Choices: {0.0, 0.1, 0.2}, Default: 0.0, 'nan_prob_no_reason': nan_prob_no_reason, Type: Categorical, Ch

In [30]:
%%script echo skipping
# NOTE: the %%script cell magic above must stay on the first line of the cell. It hands the
# whole cell body to `echo`, so none of the Python below is executed; the recorded output
# is just "skipping". Remove the magic line to re-enable the visualisation.

# If re-enabled, this mutates the shared config_sample that the training cells also use.
config_sample['batch_size'] = 4
model = get_model(config_sample, device, should_train=False, verbose=2) # , state_dict=model[2].state_dict()
# Pull one batch from element 3 of get_model's return value
# (presumably a data loader — TODO confirm).
(hp_embedding, data, _), targets, single_eval_pos = next(iter(model[3]))

# NOTE(review): normalize_data is imported here but not used in this cell.
from utils import normalize_data
fig = plt.figure(figsize=(8, 8))
N = 100
# Plot the first N entries of index 0 on the second axis, restricted to the first 4 entries
# of the last axis (assumes data is (sequence, batch, feature) — TODO confirm).
plot_features(data[0:N, 0, 0:4], targets[0:N, 0], fig=fig)

# Stack the transposed feature slice with the targets as one extra row, zero out NaNs,
# and show the absolute correlation matrix of the rows.
d = np.concatenate([data[:, 0, :].T, np.expand_dims(targets[:, 0], -1).T])
d[np.isnan(d)] = 0
c = np.corrcoef(d)
plt.matshow(np.abs(c), vmin=0, vmax=1)
plt.show()

skipping


## Training

In [20]:
# Build and train a model straight from config_sample. No epoch_callback is passed,
# so this call does not write the checkpoints that train_function's callback produces.
# NOTE(review): this cell's execution count (In [20]) is lower than that of the cells above it,
# so it was last run before them; re-run the notebook top to bottom to confirm it still works.
model = get_model(config_sample, device, should_train=True, verbose=1)

Using style prior: True
Using cpu:0 device
Using a Transformer with 5.03 M parameters


In [32]:
# Train through train_function, whose callback hands save_model the names
# 'models_diff/prior_diff_real_checkpoint_Testing_1__n_0_epoch_<k>.cpkt' together with base_path.
model = train_function(config_sample, i=0, add_name = '_Testing_1_')

Using style prior: True
Using cpu:0 device
Using a Transformer with 5.03 M parameters


In [39]:
# Load a saved checkpoint as an inference-only classifier (see the recorded output below).
# NOTE(review): base_path here is a hard-coded absolute path on one machine rather than the
# notebook-level `base_path`, and model_string '_Second_test_' is not the '_Testing_1_' tag
# used by the training cell above — confirm which checkpoint is intended.
model = TabPFNClassifier(device=device, base_path='/Users/antanas/GitRepo/TabPFN/tabpfn',
                         N_ensemble_configurations=32, model_string = '_Second_test_', epoch = 1)

Loading model that can be used for inference only
Using a Transformer with 5.03 M parameters


In [34]:
# Inspect the raw checkpoint file, mapping its tensors to CPU. As a bare last expression the
# whole loaded object is echoed, which is why the output below is a long dump of weight tensors.
# NOTE(review): torch.load unpickles the file — only open checkpoints you trust. The path is
# machine-specific (absolute, under one user's home directory).
torch.load('/Users/antanas/GitRepo/TabPFN/tabpfn/models_diff/prior_diff_real_checkpoint_Second_test__n_0_epoch_1.cpkt',
           map_location=torch.device('cpu'))

(OrderedDict([('transformer_encoder.layers.0.self_attn.in_proj_weight',
               tensor([[ 0.0454,  0.0307, -0.0531,  ..., -0.0323, -0.0213,  0.0488],
                       [ 0.0401,  0.0006,  0.0315,  ...,  0.0326,  0.0319, -0.0293],
                       [ 0.0345, -0.0355,  0.0251,  ...,  0.0536,  0.0120, -0.0469],
                       ...,
                       [ 0.0071, -0.0001,  0.0325,  ..., -0.0441, -0.0508,  0.0288],
                       [-0.0304,  0.0268, -0.0551,  ..., -0.0501,  0.0418,  0.0033],
                       [ 0.0049,  0.0376, -0.0540,  ..., -0.0370, -0.0142, -0.0328]])),
              ('transformer_encoder.layers.0.self_attn.in_proj_bias',
               tensor([ 0.0006,  0.0011,  0.0014,  ..., -0.0017,  0.0017,  0.0017])),
              ('transformer_encoder.layers.0.self_attn.out_proj.weight',
               tensor([[ 0.0016,  0.0017,  0.0018,  ...,  0.0017, -0.0017, -0.0017],
                       [-0.0004, -0.0008, -0.0007,  ..., -0.0007,  0.0007