## Setup

In [32]:
%load_ext autoreload

%autoreload 2

  from imp import reload


In [81]:
import random
import time
import warnings
from datetime import datetime

from scripts.model_builder import get_model, get_default_spec, save_model, load_model
from scripts.transformer_prediction_interface import transformer_predict, get_params_from_config, load_model_workflow, TabPFNClassifier
from scripts.model_configs import *
from datasets import load_openml_list, open_cc_dids, open_cc_valid_dids
from priors.utils import plot_prior, plot_features
from priors.utils import uniform_int_sampler_f

from scripts.tabular_metrics import calculate_score_per_method, calculate_score
from priors.differentiable_prior import DifferentiableHyperparameterList, draw_random_style, merge_style_with_info
from scripts import tabular_metrics
from notebook_utils import *

In [82]:
# Parameters for training (the larger values are used when `large_datasets` is set).
# NOTE(review): the global `bptt` is not read by any visible cell; the config
# cell below sets config['bptt'] directly.
large_datasets = True
if large_datasets:
    max_samples, bptt = 10000, 10000
else:
    max_samples, bptt = 5000, 3000
suite = 'cc'

In [83]:
# Compute device, the directory under which `models_diff/` checkpoints live,
# and the feature cap (not read by any visible cell).
device, base_path, max_features = 'cpu', '.', 100

In [84]:
def print_models(model_string, max_runs=80, max_epochs=50, base_dir=None):
    """Print the path of every checkpoint file saved for ``model_string``.

    Probes ``<base_dir>/models_diff/prior_diff_real_checkpoint{model_string}_n_{i}_epoch_{e}.cpkt``
    for every run index ``i`` in ``range(max_runs)`` and checkpoint index ``e`` in
    ``range(max_epochs)``. This is the file naming used by ``train_function``'s
    save callback. A blank line is printed after each run index as a separator.

    Args:
        model_string: Name suffix of the run (the ``add_name`` given to ``train_function``).
        max_runs: Number of run indices to scan (default 80).
        max_epochs: Number of checkpoint indices to scan per run (default 50).
        base_dir: Directory containing ``models_diff/``; defaults to the
            notebook-level ``base_path``.
    """
    # Local imports: `os` and `Path` are not explicitly imported in the import
    # cell (they are only reachable, if at all, through star imports).
    import os
    from pathlib import Path

    root = base_path if base_dir is None else base_dir
    print(model_string)
    for run_idx in range(max_runs):
        for epoch_idx in range(max_epochs):
            # Build the candidate path once instead of twice per probe.
            checkpoint = os.path.join(
                root,
                f'models_diff/prior_diff_real_checkpoint{model_string}_n_{run_idx}_epoch_{epoch_idx}.cpkt',
            )
            if Path(checkpoint).is_file():
                print(checkpoint)
        print()

In [85]:
def train_function(config_sample, i, add_name='', maximum_runtime=15, device=None):
    """Train a model via ``get_model`` and save checkpoints on a wall-clock schedule.

    Checkpoints are written with ``save_model`` to
    ``<base_path>/models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{k}.cpkt``.
    Here ``k`` counts checkpoints, not training epochs. A new checkpoint is written
    at the first epoch callback after each ``maximum_runtime * 60 / N_epochs_to_save``
    seconds of elapsed time.

    Args:
        config_sample: Sampled config passed to ``get_model``. It is mutated to
            record ``'epoch_in_training'`` before each save.
        i: Run index embedded in checkpoint filenames.
        add_name: Extra suffix embedded in checkpoint filenames.
        maximum_runtime: Minutes over which ``N_epochs_to_save`` checkpoint slots are spread.
        device: Device passed to ``get_model``. Defaults to the notebook-level ``device``.

    Returns:
        Whatever ``get_model`` returns. Previously it was discarded, so callers
        assigning the result always got ``None``.
    """
    if device is None:
        # The `device` parameter shadows the notebook global of the same name,
        # so read the global explicitly for the default.
        device = globals()['device']

    start_time = time.time()
    N_epochs_to_save = 50
    # Wall-clock seconds between consecutive checkpoint slots (loop-invariant).
    seconds_per_checkpoint = maximum_runtime * 60 / N_epochs_to_save

    def save_callback(model, epoch):
        # The checkpoint counter is stored on the model object itself.
        if not hasattr(model, 'last_saved_epoch'):
            model.last_saved_epoch = 0
        # Save once the elapsed time has crossed the next checkpoint boundary.
        if (time.time() - start_time) / seconds_per_checkpoint > model.last_saved_epoch:
            print('Saving model..')
            config_sample['epoch_in_training'] = epoch
            save_model(model, base_path, f'models_diff/prior_diff_real_checkpoint{add_name}_n_{i}_epoch_{model.last_saved_epoch}.cpkt',
                       config_sample)
            # Despite its name this is a checkpoint counter. TODO: rename to checkpoint.
            model.last_saved_epoch = model.last_saved_epoch + 1

    return get_model(config_sample,
                     device,
                     should_train=True,
                     verbose=1,
                     epoch_callback=save_callback)

## Define prior settings

In [86]:
def reload_config(config_type='causal', task_type='multiclass', longer=0):
    """Build the base prior config for a multiclass training run.

    Args:
        config_type: Forwarded to ``get_prior_config``.
        task_type: Not used; only the multiclass setup is built.
        longer: Not used.

    Returns:
        ``(config, model_string)``. ``model_string`` is ``'_multiclass_'``
        followed by the current local time formatted as ``MM_DD_YYYY_HH_MM_SS``.
    """
    config = get_prior_config(config_type=config_type)

    num_classes_cap = 10
    config.update({
        'prior_type': 'mlp',
        'differentiable': True,
        'flexible': True,
        'epochs': 12000,
        'recompute_attn': False,
        'max_num_classes': num_classes_cap,
        # Number of classes drawn with uniform_int_sampler_f(2, max_num_classes).
        'num_classes': uniform_int_sampler_f(2, num_classes_cap),
        'balanced': False,
    })

    timestamp = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")
    return config, f'_multiclass_{timestamp}'

## Parameters for training

In [95]:
config, model_string = reload_config(longer=1)
config['bptt_extra_samples'] = None

# Pin these hyperparameters to constant values and remove each one from
# config['differentiable_hyperparameters'] (set, then delete, in this order).
for pinned_name, pinned_value in (
    ('output_multiclass_ordered_p', 0.),
    ('multiclass_type', 'rank'),
    ('sampling', 'normal'),  # <- sampling of weights
):
    config[pinned_name] = pinned_value
    del config['differentiable_hyperparameters'][pinned_name]

config.update({
    'pre_sample_causes': True,
    'multiclass_loss_type': 'nono',
    'normalize_to_ranking': False,
    'categorical_feature_p': .2,  # <- categorical feature probability
    # --- Inclusion of NaNs in the features ---
    'nan_prob_no_reason': .0,
    'nan_prob_unknown_reason': .0,
    'set_value_to_nan': .1,
    'normalize_with_sqrt': False,
    # --- Prior generation parameters ---
    'new_mlp_per_example': True,
    'prior_mlp_scale_weights_sqrt': True,
    'batch_size_per_gp_sample': None,
    'normalize_ignore_label_too': False,
    'differentiable_hps_as_style': False,
    'max_eval_pos': 1000,  # <- max evaluation position in batch
    'random_feature_rotation': True,
    'rotate_normalized_labels': True,
    'mix_activations': False,
})

# --- Transformer architecture ---
config.update({
    'emsize': 512,    # <- embedding size in the intersample attention
    'emsize_f': 100,  # <- embedding size in the interfeature attention
})
# Number of heads in the intersample attention (always 4 in interfeature).
config['nhead'] = config['emsize'] // 128
config.update({
    'bptt': 1024 + 128,
    'canonical_y_encoder': False,
    'nlayers': 1,
})

# --- Optimisation schedule ---
config['aggregate_k_gradients'] = 8
config.update({
    'batch_size': 1 * config['aggregate_k_gradients'],
    'num_steps': 1024 // config['aggregate_k_gradients'],
    'epochs': 400,  # overrides the 12000 set by reload_config
    'total_available_time_in_s': None,  # alternative: 60*60*22 (22 hours for some safety...)
    'train_mixed_precision': True,
    'efficient_eval_masking': True,
})

config_sample = evaluate_hypers(config)

## Training

In [52]:
# Initiates training without saving checkpoints.
# `device` (from the configuration cell) is passed once, positionally, matching
# the call in `train_function`. Passing it again as `device=` collides with that
# positional argument (presumably get_model's `device` parameter) and ignores the
# configured device.
model = get_model(config_sample, device, should_train=True, verbose=1)

Using style prior: True
Using cpu:0 device
Using a Transformer with 3.63 M parameters


In [None]:
# Initiates training. Checkpoints are saved by `train_function` under
# `<base_path>/models_diff/`; the device is the notebook-level `device`.
model = train_function(config_sample, i=0, add_name='')