In [1]:
import os
import sys
import time
# Make the project checkout importable before the project imports below.
# NOTE(review): this is an absolute path inside one user's home directory; the
# notebook will not run elsewhere without editing it. Presumably the `models`
# and `cross_db_benchmark` packages imported below live under this directory.
sys.path.append("/home/ziniuw/zero-shot-cost-estimation")
from models.zero_shot_models.specific_models.model import zero_shot_models
from cross_db_benchmark.benchmark_tools.database import DatabaseSystem
from models.training.train import train_default, train_readout_hyperparams

Using backend: pytorch


In [2]:
from copy import copy

import numpy as np
import optuna
import torch
import torch.optim as opt
from tqdm import tqdm

from cross_db_benchmark.benchmark_tools.utils import load_json
from models.dataset.dataset_creation import create_dataloader
from models.training.checkpoint import save_checkpoint, load_checkpoint, save_csv
from models.training.metrics import MAPE, RMSE, QError
from models.training.utils import batch_to, flatten_dict, find_early_stopping_metric
from models.zero_shot_models.specific_models.model import zero_shot_models

def training_model_loader(workload_runs,
                          test_workload_runs,
                          statistics_file,
                          target_dir,
                          filename_model,
                          optimizer_class_name='Adam',
                          optimizer_kwargs=None,
                          final_mlp_kwargs=None,
                          node_type_kwargs=None,
                          model_kwargs=None,
                          tree_layer_name='GATConv',
                          tree_layer_kwargs=None,
                          hidden_dim=32,
                          batch_size=32,
                          output_dim=1,
                          epochs=0,
                          device='cpu',
                          plan_featurization_name=None,
                          max_epoch_tuples=100000,
                          param_dict=None,
                          num_workers=1,
                          early_stopping_patience=20,
                          trial=None,
                          database=None,
                          limit_queries=None,
                          limit_queries_affected_wl=None,
                          skip_train=False,
                          seed=0):
    """
    Rebuild a zero-shot model and restore its state from a checkpoint, without training.

    The function creates the dataloaders with ``create_dataloader``, instantiates the model
    registered for ``database`` in ``zero_shot_models``, builds the optimizer named by
    ``optimizer_class_name`` from ``torch.optim`` and passes model, optimizer and metrics to
    ``load_checkpoint`` together with ``target_dir`` and ``filename_model``.

    The signature mirrors a training entry point so that the same keyword dictionary can be
    forwarded; ``epochs``, ``max_epoch_tuples``, ``param_dict``, ``early_stopping_patience``,
    ``trial`` and ``skip_train`` are accepted but not used here.

    :param workload_runs: training workload files, forwarded to ``create_dataloader``
    :param test_workload_runs: test workload files, forwarded to ``create_dataloader``
    :param statistics_file: statistics file, forwarded to ``create_dataloader``
    :param target_dir: directory handed to ``load_checkpoint``
    :param filename_model: checkpoint base name handed to ``load_checkpoint``
    :param optimizer_class_name: attribute name of the optimizer class in ``torch.optim``
    :param optimizer_kwargs: keyword arguments for the optimizer; ``None`` means no extra arguments
    :param final_mlp_kwargs: model keyword arguments; must contain ``'loss_class_name'``
        with value ``'QLoss'`` or ``'MSELoss'``
    :param model_kwargs: extra keyword arguments for the model constructor; ``None`` means none
    :param seed: seed applied to ``torch`` and ``numpy`` before any data is loaded
    :return: tuple ``(test_loaders, model)`` where ``model`` is the object returned by
        ``load_checkpoint``
    :raises ValueError: if ``final_mlp_kwargs`` lacks ``'loss_class_name'`` or names an
        unsupported loss
    """
    # Validate up front: previously a missing final_mlp_kwargs crashed with a TypeError and an
    # unknown loss name only failed (unbound `metrics`) after the dataloaders and the model had
    # already been built.
    if final_mlp_kwargs is None or 'loss_class_name' not in final_mlp_kwargs:
        raise ValueError("final_mlp_kwargs must be a dict containing the key 'loss_class_name'")
    loss_class_name = final_mlp_kwargs['loss_class_name']
    if loss_class_name not in ('QLoss', 'MSELoss'):
        raise ValueError(f"Unsupported loss_class_name {loss_class_name!r} (expected 'QLoss' or 'MSELoss')")

    if model_kwargs is None:
        model_kwargs = dict()
    if optimizer_kwargs is None:
        # the default of None cannot be unpacked with ** below
        optimizer_kwargs = dict()

    # seed for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)

    # create a dataset (only the test loaders are returned; train/val loaders are not needed here)
    label_norm, feature_statistics, _train_loader, _val_loader, test_loaders = \
        create_dataloader(workload_runs, test_workload_runs, statistics_file, plan_featurization_name, database,
                          val_ratio=0.15, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                          pin_memory=False, limit_queries=limit_queries,
                          limit_queries_affected_wl=limit_queries_affected_wl, loss_class_name=loss_class_name)

    # the metric flagged with early_stopping_metric depends on the loss that is optimized
    if loss_class_name == 'QLoss':
        metrics = [RMSE(), MAPE(), QError(percentile=50, early_stopping_metric=True), QError(percentile=95),
                   QError(percentile=100)]
    else:  # 'MSELoss'
        metrics = [RMSE(early_stopping_metric=True), MAPE(), QError(percentile=50), QError(percentile=95),
                   QError(percentile=100)]

    # create zero shot model dependent on database
    model = zero_shot_models[database](device=device, hidden_dim=hidden_dim, final_mlp_kwargs=final_mlp_kwargs,
                                       node_type_kwargs=node_type_kwargs, output_dim=output_dim,
                                       feature_statistics=feature_statistics, tree_layer_name=tree_layer_name,
                                       tree_layer_kwargs=tree_layer_kwargs,
                                       plan_featurization_name=plan_featurization_name,
                                       label_norm=label_norm,
                                       **model_kwargs)
    # move to the device the model was configured with
    model = model.to(model.device)
    optimizer = getattr(opt, optimizer_class_name)(model.parameters(), **optimizer_kwargs)

    # load_checkpoint returns a 7-tuple; only the restored model is needed by the caller
    _csv_stats, _epochs_wo_improvement, _epoch, model, _optimizer, _metrics, _finished = \
        load_checkpoint(model, target_dir, filename_model, optimizer=optimizer, metrics=metrics, filetype='.pt')
    return test_loaders, model

def load_model(workload_runs,
               test_workload_runs,
               statistics_file,
               target_dir,
               filename_model,
               hyperparameter_path,
               device='cpu',
               max_epoch_tuples=100000,
               num_workers=1,
               loss_class_name='QLoss',
               database=None,
               seed=0,
               limit_queries=None,
               limit_queries_affected_wl=None,
               max_no_epochs=None,
               skip_train=False):
    """
    Read tuned hyperparameters from a JSON file and restore the matching model.

    The hyperparameters are turned into the keyword arguments expected by
    ``training_model_loader``, which is then called to build the dataloaders and
    the model and to load the checkpoint named by ``target_dir``/``filename_model``.
    Every key of the hyperparameter file has to be consumed, otherwise an
    ``AssertionError`` is raised.

    :return: tuple ``(test_loaders, model)`` as returned by ``training_model_loader``
    """
    print(f"Reading hyperparameters from {hyperparameter_path}")
    hyperparams = load_json(hyperparameter_path, namespace=False)
    take = hyperparams.pop  # each consumed key is removed so leftovers can be detected below

    # options shared by all fully connected stacks (final MLP, tree layers, node encoders)
    dropout_prob = take('p_dropout')
    shared_fc_kwargs = {
        'p_dropout': dropout_prob,
        'activation_class_name': 'LeakyReLU',
        'activation_class_kwargs': {},
        'norm_class_name': 'Identity',
        'norm_class_kwargs': {},
        'residual': take('residual'),
        'dropout': take('dropout'),
        'activation': True,
        'inplace': True,
    }

    # component specific options, each extended by the shared options
    final_mlp_kwargs = {
        'width_factor': take('final_width_factor'),
        'n_layers': take('final_layers'),
        'loss_class_name': loss_class_name,
        'loss_class_kwargs': {},
        **shared_fc_kwargs,
    }
    tree_layer_kwargs = {
        'width_factor': take('tree_layer_width_factor'),
        'n_layers': take('message_passing_layers'),
        **shared_fc_kwargs,
    }
    node_type_kwargs = {
        'width_factor': take('node_type_width_factor'),
        'n_layers': take('node_layers'),
        'one_hot_embeddings': True,
        'max_emb_dim': take('max_emb_dim'),
        'drop_whole_embeddings': False,
        **shared_fc_kwargs,
    }

    optimizer_kwargs = {'lr': take('lr')}
    tree_layer_name = take('tree_layer_name')
    plan_featurization_name = take('plan_featurization_name')
    hidden_dim = take('hidden_dim')
    if max_no_epochs is None:
        epochs = 200
    else:
        epochs = max_no_epochs
    batch_size = take('batch_size')

    train_kwargs = {
        'optimizer_class_name': 'AdamW',
        'optimizer_kwargs': optimizer_kwargs,
        'final_mlp_kwargs': final_mlp_kwargs,
        'node_type_kwargs': node_type_kwargs,
        'tree_layer_kwargs': tree_layer_kwargs,
        'tree_layer_name': tree_layer_name,
        'plan_featurization_name': plan_featurization_name,
        'hidden_dim': hidden_dim,
        'output_dim': 1,
        'epochs': epochs,
        'early_stopping_patience': 20,
        'max_epoch_tuples': max_epoch_tuples,
        'batch_size': batch_size,
        'device': device,
        'num_workers': num_workers,
        'seed': seed,
        'limit_queries': limit_queries,
        'limit_queries_affected_wl': limit_queries_affected_wl,
        'skip_train': skip_train,
    }

    # anything left in the file was not understood by this reader
    assert len(hyperparams) == 0, f"Not all hyperparams were used (not used: {hyperparams.keys()}). Hence generation " \
                                  f"and reading does not seem to fit"

    flat_params = flatten_dict(train_kwargs)

    loaders, restored_model = training_model_loader(workload_runs, test_workload_runs, statistics_file,
                                                    target_dir, filename_model,
                                                    param_dict=flat_params, database=database, **train_kwargs)
    return loaders, restored_model

def validate_model(val_loader, model, epoch=0, epoch_stats=None, metrics=None, max_epoch_tuples=None,
                   custom_batch_to=batch_to, verbose=False, log_all_queries=False):
    """
    Run the model over a dataloader in evaluation mode and collect labels and predictions.

    For every batch the input is prepared with ``custom_batch_to``, the model is applied and
    ``model.loss_fxn`` is accumulated. If ``model.label_norm`` is set, predictions and labels
    are mapped back with its ``inverse_transform`` before they are returned.

    ``metrics``, ``verbose`` and ``log_all_queries`` are accepted for signature compatibility
    but are not used by this function.

    :param val_loader: iterable of batches; must expose ``batch_size`` when ``max_epoch_tuples``
        is given
    :param model: model exposing ``device``, ``label_norm`` and ``loss_fxn``
    :param epoch: epoch number, only used in the printed loss message
    :param epoch_stats: optional dict; when given it receives ``val_time``, ``val_num_tuples``
        and ``val_loss`` (mean loss over the batches that were actually evaluated)
    :param max_epoch_tuples: optional cap; iteration stops once
        ``batch_idx * val_loader.batch_size`` exceeds it
    :param custom_batch_to: callable returning ``(input_model, label, sample_idxs)`` for a batch
    :return: tuple ``(labels, preds)`` of flat numpy arrays; both are empty arrays when no batch
        was evaluated
    """
    model.eval()

    labels = []
    preds = []
    val_loss = torch.Tensor([0])
    val_num_tuples = 0
    num_batches = 0

    with torch.no_grad():
        # evaluate test set using model
        test_start_t = time.perf_counter()
        for batch_idx, batch in enumerate(tqdm(val_loader)):
            if max_epoch_tuples is not None and batch_idx * val_loader.batch_size > max_epoch_tuples:
                break

            # the sample indexes returned by custom_batch_to are not needed here
            input_model, label, _ = custom_batch_to(batch, model.device, model.label_norm)
            output = model(input_model)

            # sum up mean batch losses
            val_loss += model.loss_fxn(output, label).cpu()
            num_batches += 1

            # inverse transform the predictions and labels
            curr_pred = output.cpu().numpy()
            curr_label = label.cpu().numpy()
            if model.label_norm is not None:
                curr_pred = model.label_norm.inverse_transform(curr_pred)
                curr_label = model.label_norm.inverse_transform(curr_label.reshape(-1, 1))

            curr_pred = curr_pred.reshape(-1)
            curr_label = curr_label.reshape(-1)

            # count what was really evaluated; the nominal batch size over-counts a partial last batch
            val_num_tuples += len(curr_label)

            preds.append(curr_pred)
            labels.append(curr_label)

        if epoch_stats is not None:
            epoch_stats.update(val_time=time.perf_counter() - test_start_t)
            epoch_stats.update(val_num_tuples=val_num_tuples)
            # average over the batches that were processed: len(val_loader) would be too large
            # after an early stop and the max() guards against an empty loader
            val_loss = (val_loss / max(num_batches, 1)).item()
            print(f'val_loss epoch {epoch}: {val_loss}')
            epoch_stats.update(val_loss=val_loss)

    if not labels:
        # np.concatenate raises on an empty list, so return empty results explicitly
        return np.empty(0), np.empty(0)

    labels = np.concatenate(labels, axis=0)
    preds = np.concatenate(preds, axis=0)
    return labels, preds

In [3]:
# --- Experiment configuration -------------------------------------------------
# `database` is forwarded by load_model to training_model_loader, where it is
# used as the key into `zero_shot_models`.
database = DatabaseSystem.POSTGRES
# Training workloads (three files) and the single test workload.
workload_runs = ["../zero-shot-data/runs/parsed_plans/imdb_full/complex_workload_400k_s4_c8220.json",
                 "../zero-shot-data/runs/parsed_plans/imdb_full/complex_workload_400k_s5_c8220.json", 
                 "../zero-shot-data/runs/parsed_plans/imdb_full/complex_workload_400k_s6_c8220.json"]
test_workload_runs = ["../zero-shot-data/runs/parsed_plans/imdb_full/job_full_c8220.json"]
statistics_file = "../zero-shot-data/runs/parsed_plans/statistics_workload_combined.json"
# Directory and base name handed to load_checkpoint for the pretrained model.
target_dir = "../zero-shot-data/evaluation/job_full_tune/"
filename_model = "imdb_full_0"
# NOTE(review): absolute path inside one user's home directory; not portable.
hyperparameter_path = "/home/ziniuw/zero-shot-cost-estimation/setup/tuned_hyperparameters/tune_best_config.json"
# Query cap for the fine-tuning run further below; it is not passed to
# load_model in this cell, so the checkpoint is loaded with the full workloads.
limit_queries = 1000
# Separate checkpoint name for the fine-tuned model.
ft_filename_model = filename_model + f"_ft_{limit_queries}"

# Rebuild the model from the tuned hyperparameters and load its checkpoint;
# all other load_model arguments keep their defaults (e.g. device='cpu').
test_loaders, model = load_model(workload_runs, test_workload_runs, statistics_file, target_dir, filename_model, 
                  hyperparameter_path, database=database)

Reading hyperparameters from /home/ziniuw/zero-shot-cost-estimation/setup/tuned_hyperparameters/tune_best_config.json
No of Plans: 50722
No of Plans: 77
Successfully loaded checkpoint from epoch 85 (85 csv rows) in 0.064 secs


In [4]:
# Evaluate the loaded checkpoint on the first (and only configured) test loader.
# validate_model returns (labels, preds) as flat numpy arrays.
true, pred = validate_model(test_loaders[0], model)
# Element-wise q-error: the larger of true/pred and pred/true.
# NOTE(review): assumes labels and predictions are strictly positive; a zero in
# either array would produce inf/nan here.
qerror = np.maximum(true/pred, pred/true)
print("performance without finetuning:")
# Printed in order: 50th, 90th, 95th, 99th and 100th (max) percentile.
for i in [50, 90, 95, 99, 100]:
    print(np.percentile(qerror, i))

100%|█████████████████████████████████████████████| 1/1 [00:01<00:00,  1.06s/it]

performance without finetuning:
1.2624925374984741
2.1067908763885503
2.556484699249268
6.628351116180374
13.535557746887207





In [7]:
# Fine-tuning run: at most 50 epochs, with `limit_queries` queries spread over
# 3 affected workloads, written under the separate name `ft_filename_model`.
# NOTE(review): this rebinds `model`, replacing the checkpoint loaded earlier;
# what train_readout_hyperparams returns is not visible in this notebook.
# NOTE(review): the execution count jumps from 4 to 7 and the recorded output
# below mentions tune_est_best_config.json and a checkpoint named
# imdb_full_0_pg_est_ft_1000.pt, which do not match `hyperparameter_path` and
# `ft_filename_model` as defined above -- the output appears to stem from a
# different configuration; re-run from a fresh kernel to confirm.
model = train_readout_hyperparams(workload_runs, test_workload_runs, statistics_file, target_dir, ft_filename_model,
                          hyperparameter_path, num_workers=16, database=database, max_no_epochs=50, 
                          limit_queries=limit_queries, limit_queries_affected_wl=3
)

Reading hyperparameters from /home/ziniuw/zero-shot-cost-estimation/setup/tuned_hyperparameters/tune_est_best_config.json
Capping workload ../zero-shot-data/runs/parsed_plans/imdb_full/complex_workload_400k_s4_c8220.json after 333 queries
Stopping now
Capping workload ../zero-shot-data/runs/parsed_plans/imdb_full/complex_workload_400k_s5_c8220.json after 333 queries
Stopping now
Capping workload ../zero-shot-data/runs/parsed_plans/imdb_full/complex_workload_400k_s6_c8220.json after 333 queries
Stopping now
No of Plans: 1005
No of Plans: 77
PostgresZeroShotModel(
  (loss_fxn): QLoss()
  (fcout): Sequential(
    (0): FcLayer(
      (layers): Sequential(
        (0): Linear(in_features=128, out_features=192, bias=True)
        (1): LeakyReLU(negative_slope=0.01, inplace=True)
      )
    )
    (1): FcLayer(
      (layers): Sequential(
        (0): Linear(in_features=192, out_features=192, bias=True)
        (1): LeakyReLU(negative_slope=0.01, inplace=True)
      )
    )
    (2): FcLayer(


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.26it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 5: 0.8347772359848022
val_mse: 10.6745 [best: 6.5581]
val_mape: 2.1939 [best: 2.8307]
val_median_q_error_50: 2.1982 [best: 2.2543]
New best model for val_median_q_error_50
val_median_q_error_95: 10.6802 [best: 11.3514]
val_median_q_error_100: 18.4472 [best: 21.8811]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.065 secs
Epoch 6


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.25it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.02it/s]


val_loss epoch 6: 0.7558313608169556
val_mse: 17.7687 [best: 6.5581]
val_mape: 1.8009 [best: 2.1939]
val_median_q_error_50: 2.1602 [best: 2.1982]
New best model for val_median_q_error_50
val_median_q_error_95: 10.5310 [best: 10.6802]
val_median_q_error_100: 16.3334 [best: 18.4472]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.061 secs
Epoch 7


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.26it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.05it/s]


val_loss epoch 7: 0.6428970694541931
val_mse: 11.5070 [best: 6.5581]
val_mape: 1.0995 [best: 1.8009]
val_median_q_error_50: 1.8935 [best: 2.1602]
New best model for val_median_q_error_50
val_median_q_error_95: 6.2648 [best: 10.5310]
val_median_q_error_100: 21.5065 [best: 16.3334]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.064 secs
Epoch 8


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.19it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 8: 0.6220356225967407
val_mse: 9.6421 [best: 6.5581]
val_mape: 0.6758 [best: 1.0995]
val_median_q_error_50: 1.7587 [best: 1.8935]
New best model for val_median_q_error_50
val_median_q_error_95: 8.5414 [best: 6.2648]
val_median_q_error_100: 35.4329 [best: 16.3334]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.065 secs
Epoch 9


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.26it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.01it/s]


val_loss epoch 9: 0.6336520910263062
val_mse: 5.8531 [best: 6.5581]
val_mape: 0.6021 [best: 0.6758]
val_median_q_error_50: 1.7556 [best: 1.7587]
New best model for val_median_q_error_50
val_median_q_error_95: 6.7735 [best: 6.2648]
val_median_q_error_100: 69.7286 [best: 16.3334]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.064 secs
Epoch 10


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.27it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 10: 0.5453565716743469
val_mse: 7.7701 [best: 5.8531]
val_mape: 0.8256 [best: 0.6021]
val_median_q_error_50: 1.6857 [best: 1.7556]
New best model for val_median_q_error_50
val_median_q_error_95: 4.3835 [best: 6.2648]
val_median_q_error_100: 29.6279 [best: 16.3334]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.068 secs
Epoch 11


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.26it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.06it/s]


val_loss epoch 11: 0.5616745352745056
val_mse: 10.4737 [best: 5.8531]
val_mape: 0.9413 [best: 0.6021]
val_median_q_error_50: 1.7240 [best: 1.6857]
val_median_q_error_95: 4.5513 [best: 4.3835]
val_median_q_error_100: 16.0253 [best: 16.3334]
epochs_wo_improvement: 1
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.072 secs
Epoch 12


100%|█████████████████████████████████████████████| 4/4 [00:02<00:00,  1.98it/s]
100%|█████████████████████████████████████████████| 1/1 [00:01<00:00,  1.37s/it]


val_loss epoch 12: 0.5073257088661194
val_mse: 5.2252 [best: 5.8531]
val_mape: 0.6051 [best: 0.6021]
val_median_q_error_50: 1.5974 [best: 1.6857]
New best model for val_median_q_error_50
val_median_q_error_95: 4.5733 [best: 4.3835]
val_median_q_error_100: 17.0066 [best: 16.0253]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.063 secs
Epoch 13


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.28it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.05it/s]


val_loss epoch 13: 0.49012452363967896
val_mse: 7.1892 [best: 5.2252]
val_mape: 0.6715 [best: 0.6021]
val_median_q_error_50: 1.6327 [best: 1.5974]
val_median_q_error_95: 3.4415 [best: 4.3835]
val_median_q_error_100: 13.4480 [best: 16.0253]
epochs_wo_improvement: 1
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.065 secs
Epoch 14


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.26it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.04it/s]


val_loss epoch 14: 0.44587960839271545
val_mse: 4.9170 [best: 5.2252]
val_mape: 0.5140 [best: 0.6021]
val_median_q_error_50: 1.5369 [best: 1.5974]
New best model for val_median_q_error_50
val_median_q_error_95: 3.5398 [best: 3.4415]
val_median_q_error_100: 11.7735 [best: 13.4480]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.063 secs
Epoch 15


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.25it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.02it/s]


val_loss epoch 15: 0.5058817863464355
val_mse: 10.2620 [best: 4.9170]
val_mape: 0.7582 [best: 0.5140]
val_median_q_error_50: 1.6346 [best: 1.5369]
val_median_q_error_95: 4.0432 [best: 3.4415]
val_median_q_error_100: 9.2517 [best: 11.7735]
epochs_wo_improvement: 1
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.070 secs
Epoch 16


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.26it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.04it/s]


val_loss epoch 16: 0.45984938740730286
val_mse: 5.3334 [best: 4.9170]
val_mape: 0.7345 [best: 0.5140]
val_median_q_error_50: 1.5909 [best: 1.5369]
val_median_q_error_95: 3.5439 [best: 3.4415]
val_median_q_error_100: 8.3236 [best: 9.2517]
epochs_wo_improvement: 2
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.063 secs
Epoch 17


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.22it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.02it/s]


val_loss epoch 17: 0.3877796530723572
val_mse: 5.0488 [best: 4.9170]
val_mape: 0.4768 [best: 0.5140]
val_median_q_error_50: 1.3922 [best: 1.5369]
New best model for val_median_q_error_50
val_median_q_error_95: 3.0215 [best: 3.4415]
val_median_q_error_100: 9.9830 [best: 8.3236]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.063 secs
Epoch 18


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.25it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.02it/s]


val_loss epoch 18: 0.3544260561466217
val_mse: 4.5736 [best: 4.9170]
val_mape: 0.4090 [best: 0.4768]
val_median_q_error_50: 1.3272 [best: 1.3922]
New best model for val_median_q_error_50
val_median_q_error_95: 2.9220 [best: 3.0215]
val_median_q_error_100: 9.7387 [best: 8.3236]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.066 secs
Epoch 19


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.22it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.02it/s]


val_loss epoch 19: 0.4009760618209839
val_mse: 4.9017 [best: 4.5736]
val_mape: 0.4291 [best: 0.4090]
val_median_q_error_50: 1.4755 [best: 1.3272]
val_median_q_error_95: 3.2820 [best: 2.9220]
val_median_q_error_100: 12.8993 [best: 8.3236]
epochs_wo_improvement: 1
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.083 secs
Epoch 20


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.27it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 20: 0.36074620485305786
val_mse: 4.6128 [best: 4.5736]
val_mape: 0.5024 [best: 0.4090]
val_median_q_error_50: 1.3800 [best: 1.3272]
val_median_q_error_95: 2.9532 [best: 2.9220]
val_median_q_error_100: 5.9002 [best: 8.3236]
epochs_wo_improvement: 2
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.068 secs
Epoch 21


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.24it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.05it/s]


val_loss epoch 21: 0.35066160559654236
val_mse: 4.7711 [best: 4.5736]
val_mape: 0.4082 [best: 0.4090]
val_median_q_error_50: 1.3169 [best: 1.3272]
New best model for val_median_q_error_50
val_median_q_error_95: 3.0719 [best: 2.9220]
val_median_q_error_100: 8.8836 [best: 5.9002]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.067 secs
Epoch 22


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.22it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.05it/s]


val_loss epoch 22: 0.41150790452957153
val_mse: 4.9606 [best: 4.5736]
val_mape: 0.6288 [best: 0.4082]
val_median_q_error_50: 1.4954 [best: 1.3169]
val_median_q_error_95: 3.3082 [best: 2.9220]
val_median_q_error_100: 5.9199 [best: 5.9002]
epochs_wo_improvement: 1
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.064 secs
Epoch 23


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.25it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.06it/s]


val_loss epoch 23: 0.4131030738353729
val_mse: 8.0252 [best: 4.5736]
val_mape: 0.5717 [best: 0.4082]
val_median_q_error_50: 1.4538 [best: 1.3169]
val_median_q_error_95: 3.4458 [best: 2.9220]
val_median_q_error_100: 8.9445 [best: 5.9002]
epochs_wo_improvement: 2
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.078 secs
Epoch 24


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.25it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.05it/s]


val_loss epoch 24: 0.3446336090564728
val_mse: 4.6077 [best: 4.5736]
val_mape: 0.4620 [best: 0.4082]
val_median_q_error_50: 1.3922 [best: 1.3169]
val_median_q_error_95: 2.8047 [best: 2.9220]
val_median_q_error_100: 6.1151 [best: 5.9002]
epochs_wo_improvement: 3
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.071 secs
Epoch 25


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.26it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 25: 0.3754085898399353
val_mse: 4.6078 [best: 4.5736]
val_mape: 0.3554 [best: 0.4082]
val_median_q_error_50: 1.3877 [best: 1.3169]
val_median_q_error_95: 3.2417 [best: 2.8047]
val_median_q_error_100: 10.1069 [best: 5.9002]
epochs_wo_improvement: 4
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.069 secs
Epoch 26


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.30it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.02it/s]


val_loss epoch 26: 0.316417932510376
val_mse: 4.4351 [best: 4.5736]
val_mape: 0.3483 [best: 0.3554]
val_median_q_error_50: 1.2903 [best: 1.3169]
New best model for val_median_q_error_50
val_median_q_error_95: 2.9149 [best: 2.8047]
val_median_q_error_100: 6.8986 [best: 5.9002]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.064 secs
Epoch 27


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.24it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.04it/s]


val_loss epoch 27: 0.3146883547306061
val_mse: 4.4822 [best: 4.4351]
val_mape: 0.3663 [best: 0.3483]
val_median_q_error_50: 1.3104 [best: 1.2903]
val_median_q_error_95: 2.6798 [best: 2.8047]
val_median_q_error_100: 6.8508 [best: 5.9002]
epochs_wo_improvement: 1
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.067 secs
Epoch 28


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.25it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 28: 0.32015544176101685
val_mse: 4.2958 [best: 4.4351]
val_mape: 0.3342 [best: 0.3483]
val_median_q_error_50: 1.3254 [best: 1.2903]
val_median_q_error_95: 2.8673 [best: 2.6798]
val_median_q_error_100: 6.5323 [best: 5.9002]
epochs_wo_improvement: 2
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.069 secs
Epoch 29


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.22it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.04it/s]


val_loss epoch 29: 0.36817654967308044
val_mse: 6.3904 [best: 4.2958]
val_mape: 0.5172 [best: 0.3342]
val_median_q_error_50: 1.4274 [best: 1.2903]
val_median_q_error_95: 2.8965 [best: 2.6798]
val_median_q_error_100: 5.0485 [best: 5.9002]
epochs_wo_improvement: 3
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.070 secs
Epoch 30


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.19it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 30: 0.33414843678474426
val_mse: 4.3376 [best: 4.2958]
val_mape: 0.4539 [best: 0.3342]
val_median_q_error_50: 1.3749 [best: 1.2903]
val_median_q_error_95: 2.6099 [best: 2.6798]
val_median_q_error_100: 5.1960 [best: 5.0485]
epochs_wo_improvement: 4
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.076 secs
Epoch 31


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.23it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 31: 0.35122525691986084
val_mse: 4.2543 [best: 4.2958]
val_mape: 0.3510 [best: 0.3342]
val_median_q_error_50: 1.3675 [best: 1.2903]
val_median_q_error_95: 2.9455 [best: 2.6099]
val_median_q_error_100: 7.3008 [best: 5.0485]
epochs_wo_improvement: 5
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.078 secs
Epoch 32


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.21it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.06it/s]


val_loss epoch 32: 0.30999085307121277
val_mse: 4.1869 [best: 4.2543]
val_mape: 0.3447 [best: 0.3342]
val_median_q_error_50: 1.3209 [best: 1.2903]
val_median_q_error_95: 2.4963 [best: 2.6099]
val_median_q_error_100: 5.9395 [best: 5.0485]
epochs_wo_improvement: 6
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.063 secs
Epoch 33


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.25it/s]
100%|█████████████████████████████████████████████| 1/1 [00:01<00:00,  1.18s/it]


val_loss epoch 33: 0.3075467050075531
val_mse: 4.1720 [best: 4.1869]
val_mape: 0.3223 [best: 0.3342]
val_median_q_error_50: 1.3109 [best: 1.2903]
val_median_q_error_95: 2.7175 [best: 2.4963]
val_median_q_error_100: 5.7297 [best: 5.0485]
epochs_wo_improvement: 7
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.068 secs
Epoch 34


100%|█████████████████████████████████████████████| 4/4 [00:02<00:00,  1.99it/s]
100%|█████████████████████████████████████████████| 1/1 [00:01<00:00,  1.40s/it]


val_loss epoch 34: 0.3193889856338501
val_mse: 4.3117 [best: 4.1720]
val_mape: 0.3197 [best: 0.3223]
val_median_q_error_50: 1.3286 [best: 1.2903]
val_median_q_error_95: 2.8656 [best: 2.4963]
val_median_q_error_100: 6.1055 [best: 5.0485]
epochs_wo_improvement: 8
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.069 secs
Epoch 35


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.26it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.05it/s]


val_loss epoch 35: 0.3351970314979553
val_mse: 4.1913 [best: 4.1720]
val_mape: 0.3329 [best: 0.3197]
val_median_q_error_50: 1.3500 [best: 1.2903]
val_median_q_error_95: 2.6628 [best: 2.4963]
val_median_q_error_100: 6.9244 [best: 5.0485]
epochs_wo_improvement: 9
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.080 secs
Epoch 36


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.27it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.02it/s]


val_loss epoch 36: 0.3702506124973297
val_mse: 4.6224 [best: 4.1720]
val_mape: 0.3376 [best: 0.3197]
val_median_q_error_50: 1.4125 [best: 1.2903]
val_median_q_error_95: 3.2152 [best: 2.4963]
val_median_q_error_100: 6.4431 [best: 5.0485]
epochs_wo_improvement: 10
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.065 secs
Epoch 37


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.27it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.04it/s]


val_loss epoch 37: 0.3694334924221039
val_mse: 4.3225 [best: 4.1720]
val_mape: 0.3499 [best: 0.3197]
val_median_q_error_50: 1.4397 [best: 1.2903]
val_median_q_error_95: 3.0194 [best: 2.4963]
val_median_q_error_100: 6.8558 [best: 5.0485]
epochs_wo_improvement: 11
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.077 secs
Epoch 38


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.28it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 38: 0.3215378522872925
val_mse: 4.1602 [best: 4.1720]
val_mape: 0.3228 [best: 0.3197]
val_median_q_error_50: 1.3125 [best: 1.2903]
val_median_q_error_95: 2.8427 [best: 2.4963]
val_median_q_error_100: 5.7025 [best: 5.0485]
epochs_wo_improvement: 12
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.078 secs
Epoch 39


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.26it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.02it/s]


val_loss epoch 39: 0.2955303192138672
val_mse: 4.0164 [best: 4.1602]
val_mape: 0.3307 [best: 0.3197]
val_median_q_error_50: 1.2812 [best: 1.2903]
New best model for val_median_q_error_50
val_median_q_error_95: 2.6354 [best: 2.4963]
val_median_q_error_100: 4.9800 [best: 5.0485]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.070 secs
Epoch 40


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.26it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 40: 0.2988407015800476
val_mse: 4.0602 [best: 4.0164]
val_mape: 0.3348 [best: 0.3197]
val_median_q_error_50: 1.3089 [best: 1.2812]
val_median_q_error_95: 2.5918 [best: 2.4963]
val_median_q_error_100: 5.1725 [best: 4.9800]
epochs_wo_improvement: 1
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.076 secs
Epoch 41


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.22it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.05it/s]


val_loss epoch 41: 0.298584520816803
val_mse: 4.0906 [best: 4.0164]
val_mape: 0.3365 [best: 0.3197]
val_median_q_error_50: 1.2916 [best: 1.2812]
val_median_q_error_95: 2.6018 [best: 2.4963]
val_median_q_error_100: 5.3529 [best: 4.9800]
epochs_wo_improvement: 2
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.074 secs
Epoch 42


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.31it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 42: 0.30834224820137024
val_mse: 4.0839 [best: 4.0164]
val_mape: 0.3481 [best: 0.3197]
val_median_q_error_50: 1.3403 [best: 1.2812]
val_median_q_error_95: 2.5797 [best: 2.4963]
val_median_q_error_100: 4.8368 [best: 4.9800]
epochs_wo_improvement: 3
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.076 secs
Epoch 43


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.28it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 43: 0.2964106500148773
val_mse: 3.9646 [best: 4.0164]
val_mape: 0.3666 [best: 0.3197]
val_median_q_error_50: 1.2909 [best: 1.2812]
val_median_q_error_95: 2.4119 [best: 2.4963]
val_median_q_error_100: 4.5699 [best: 4.8368]
epochs_wo_improvement: 4
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.067 secs
Epoch 44


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.29it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 44: 0.2902454137802124
val_mse: 4.2376 [best: 3.9646]
val_mape: 0.3396 [best: 0.3197]
val_median_q_error_50: 1.2728 [best: 1.2812]
New best model for val_median_q_error_50
val_median_q_error_95: 2.4989 [best: 2.4119]
val_median_q_error_100: 5.0239 [best: 4.5699]
epochs_wo_improvement: 0
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.072 secs
Epoch 45


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.26it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.05it/s]


val_loss epoch 45: 0.2846093475818634
val_mse: 4.0671 [best: 3.9646]
val_mape: 0.3368 [best: 0.3197]
val_median_q_error_50: 1.2985 [best: 1.2728]
val_median_q_error_95: 2.4430 [best: 2.4119]
val_median_q_error_100: 4.7464 [best: 4.5699]
epochs_wo_improvement: 1
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.091 secs
Epoch 46


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.25it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.05it/s]


val_loss epoch 46: 0.28765416145324707
val_mse: 4.0521 [best: 3.9646]
val_mape: 0.3291 [best: 0.3197]
val_median_q_error_50: 1.2842 [best: 1.2728]
val_median_q_error_95: 2.6288 [best: 2.4119]
val_median_q_error_100: 4.7534 [best: 4.5699]
epochs_wo_improvement: 2
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.074 secs
Epoch 47


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.21it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.03it/s]


val_loss epoch 47: 0.29770615696907043
val_mse: 4.2868 [best: 3.9646]
val_mape: 0.3747 [best: 0.3197]
val_median_q_error_50: 1.3161 [best: 1.2728]
val_median_q_error_95: 2.3248 [best: 2.4119]
val_median_q_error_100: 4.4016 [best: 4.5699]
epochs_wo_improvement: 3
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.067 secs
Epoch 48


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.24it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.08it/s]


val_loss epoch 48: 0.29744502902030945
val_mse: 3.9903 [best: 3.9646]
val_mape: 0.3769 [best: 0.3197]
val_median_q_error_50: 1.3362 [best: 1.2728]
val_median_q_error_95: 2.1857 [best: 2.3248]
val_median_q_error_100: 4.2493 [best: 4.4016]
epochs_wo_improvement: 4
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.063 secs
Epoch 49


100%|█████████████████████████████████████████████| 4/4 [00:01<00:00,  2.26it/s]
100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.05it/s]


val_loss epoch 49: 0.31218430399894714
val_mse: 4.3808 [best: 3.9646]
val_mape: 0.4157 [best: 0.3197]
val_median_q_error_50: 1.3403 [best: 1.2728]
val_median_q_error_95: 2.2187 [best: 2.1857]
val_median_q_error_100: 4.3901 [best: 4.2493]
epochs_wo_improvement: 5
Saved checkpoint to ../zero-shot-data/evaluation/job_full_tune/imdb_full_0_pg_est_ft_1000.pt in 0.074 secs
Early stopping kicked in due to no improvement in 20 epochs
Starting validation for ../zero-shot-data/evaluation/job_full_tune/test_imdb_full_0_pg_est_ft_1000_job_full_c8220.csv
Reloading best model


100%|█████████████████████████████████████████████| 1/1 [00:01<00:00,  1.14s/it]

val_loss epoch 50: 0.5266938209533691
val_mse: 6.2942 [best: 3.9646]
val_mape: 0.6153 [best: 0.3197]
val_median_q_error_50: 1.4630 [best: 1.2728]
val_median_q_error_95: 7.1657 [best: 2.1857]
val_median_q_error_100: 11.6235 [best: 4.2493]





In [8]:
# Run the model over the first test loader; validate_model hands back two
# array-likes, unpacked here as (true, pred).
# NOTE(review): presumably ground-truth vs. predicted runtimes -- confirm
# against validate_model's definition.
true, pred = validate_model(test_loaders[0], model)

# Q-error is the symmetric ratio max(true/pred, pred/true): it is 1.0 when the
# two values agree and grows with the error in either direction.
true_over_pred = true / pred
pred_over_true = pred / true
qerror = np.maximum(true_over_pred, pred_over_true)

print(f"performance with {limit_queries} finetuning:")

# One value per line, from the median up to the maximum (100th percentile).
for i in (50, 90, 95, 99, 100):
    print(np.percentile(qerror, i))

100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  1.15it/s]

performance with 1000 finetuning:
1.4629887342453003
3.233961677551271
7.165688800811769
10.138860626220692
11.623549461364746



