# Training

The experiments were run with:
- `mb_vae_dti/training/run.py`: the command-line interface for training
- the scripts in `scripts/training/`, used to run the experiments

This notebook shows some plots and analysis of the unsupervised pretraining, general DTI training and benchmark fine-tuning results.

In [57]:
from resolve import *

In [None]:
from mb_vae_dti.validating.analysis import *

# TODO: update these helper functions so that they
# - read a folder full of JSON result files
# - collect all first- and second-level keys and turn them into a DataFrame

# Downstream
# - sort by best_val_loss
# - get 5 best val loss
# - average test results, n_params, train_time over top 5
# - additional filters:
#    - for baseline: which encoder did better?
#    - for multi-modal: which aggregator did better?

# Load the grid-search results stored under this folder into a single DataFrame
# (presumably one row per run -- confirm against load_gridsearch_results).
df = load_gridsearch_results("notebooks/results/grid_mo_rand/")

# Keep only the first element of each provenance entry so the column holds a
# single value per row (assumes every entry is a non-empty sequence -- TODO confirm).
if "metadata.data.provenance_cols" in df.columns:
    df["metadata.data.provenance_cols"] = df["metadata.data.provenance_cols"].apply(lambda x: x[0])

# Drop the two feature columns when they are present. errors="ignore" handles
# each column independently; the previous guard only checked for the drug
# column and raised KeyError whenever the target column was missing.
df = df.drop(
    columns=["metadata.data.drug_features", "metadata.data.target_features"],
    errors="ignore",
)

# Display the runs ordered from lowest to highest validation loss.
df.sort_values(by="best_val_loss", ascending=True)

2025-07-29 17:03:46,098 - mb_vae_dti.validating.analysis - INFO - Successfully loaded 47 results with 10 columns


Unnamed: 0,experiment_name,best_val_loss,best_epoch,val_metrics,test_metrics,trainable_params,config,timing.total_training_time,timing.avg_time_per_epoch,timing.total_epochs
34,multi_output_train_rand_b04c0004,0.774465,22,"{'val/loss_Y': 0.34521543979644775, 'val/loss_...","{'test/loss_Y': 0.33906838297843933, 'test/los...",44318469,"{'training.learning_rate': 0.0005, 'training.s...",18862.226883,538.920768,35
6,multi_output_train_rand_b03c0003,0.795921,21,"{'val/loss_Y': 0.3455863296985626, 'val/loss_Y...","{'test/loss_Y': 0.3403141498565674, 'test/loss...",44318469,"{'training.learning_rate': 0.001, 'training.sc...",18193.235152,535.095152,34
43,multi_output_train_rand_b02c0004,0.811379,37,"{'val/loss_Y': 0.3678390383720398, 'val/loss_Y...","{'test/loss_Y': 0.36141425371170044, 'test/los...",44318469,"{'training.learning_rate': 0.0005, 'training.s...",27077.049963,541.540999,50
23,multi_output_train_rand_b00c0004,0.821792,29,"{'val/loss_Y': 0.3467855155467987, 'val/loss_Y...","{'test/loss_Y': 0.3420794904232025, 'test/loss...",45114885,"{'training.learning_rate': 0.0005, 'training.s...",25892.316397,616.483724,42
19,multi_output_train_rand_b05c0003,0.871481,25,"{'val/loss_Y': 0.3511863648891449, 'val/loss_Y...","{'test/loss_Y': 0.3425217270851135, 'test/loss...",44318469,"{'training.learning_rate': 0.001, 'training.sc...",27079.617825,712.621522,38
39,multi_output_train_rand_b04c0001,0.900583,31,"{'val/loss_Y': 0.37625232338905334, 'val/loss_...","{'test/loss_Y': 0.3760126233100891, 'test/loss...",44318469,"{'training.learning_rate': 0.0001, 'training.s...",23727.569358,539.26294,44
2,multi_output_train_rand_b04c0000,0.905496,29,"{'val/loss_Y': 0.361272931098938, 'val/loss_Y_...","{'test/loss_Y': 0.35842740535736084, 'test/los...",44318469,"{'training.learning_rate': 0.0005, 'training.s...",11551.701686,275.040516,42
0,multi_output_train_rand_b06c0001,0.925667,27,"{'val/loss_Y': 0.3481124937534332, 'val/loss_Y...","{'test/loss_Y': 0.34490832686424255, 'test/los...",44318469,"{'training.learning_rate': 0.001, 'training.sc...",11391.050251,284.776256,40
8,multi_output_train_rand_b00c0001,0.942529,37,"{'val/loss_Y': 0.39523595571517944, 'val/loss_...","{'test/loss_Y': 0.38714006543159485, 'test/los...",45114885,"{'training.learning_rate': 0.0005, 'training.s...",31010.740196,620.214804,50
26,multi_output_train_rand_b05c0001,0.968147,34,"{'val/loss_Y': 0.34488821029663086, 'val/loss_...","{'test/loss_Y': 0.3431718945503235, 'test/loss...",44318469,"{'training.learning_rate': 0.0005, 'training.s...",17580.354755,374.050101,47


In [159]:
# Attributes reported as varying by the helper, minus the experiment name
# (an identifier rather than a hyperparameter). Filtering instead of calling
# list.remove() avoids a ValueError when the helper does not report that
# attribute, and leaves the list returned by the helper untouched.
varying_attrs = [
    attr
    for attr in get_varying_attributes(df)
    if attr != 'metadata.logging.experiment_name'
]

# Outcome columns to show next to the varying attributes. Skip any that are
# already listed so the later column selection does not yield duplicates.
outcome_cols = ["best_val_loss", 'best_epoch', 'trainable_params', "timing.total_training_time"]
varying_attrs += [col for col in outcome_cols if col not in varying_attrs]
varying_attrs

2025-07-27 19:35:42,354 - mb_vae_dti.validating.analysis - INFO - Found 8 varying attributes: ['metadata.training.learning_rate', 'metadata.data.batch_size', 'metadata.logging.experiment_name', 'metadata.model.encoder_kwargs.hidden_dim', 'metadata.model.encoder_kwargs.n_layers', 'metadata.model.embedding_dim', 'metadata.model.encoder_type', 'metadata.model.aggregator_type']


['metadata.training.learning_rate',
 'metadata.data.batch_size',
 'metadata.model.encoder_kwargs.hidden_dim',
 'metadata.model.encoder_kwargs.n_layers',
 'metadata.model.embedding_dim',
 'metadata.model.encoder_type',
 'metadata.model.aggregator_type',
 'best_val_loss',
 'best_epoch',
 'trainable_params',
 'timing.total_training_time']

In [160]:
# Take the rows returned by get_top_performers (n_top=5), keep only the
# columns of interest, and shorten every dotted column name to its last
# component (e.g. "timing.total_training_time" -> "total_training_time").
top_runs = get_top_performers(df, n_top=5)
subdf = top_runs[varying_attrs].rename(columns=lambda name: name.split(".")[-1])

# Summary over the selected rows: summed training time (seconds -> hours,
# assuming the column is in seconds) and mean parameter count in millions.
print(f"Total training time: {(subdf['total_training_time'].sum() / 3600):.2f} hours")
print(f"Average trainable params: {subdf['trainable_params'].mean() / 1000000:.1f}M")
subdf

Total training time: 7.57 hours
Average trainable params: 10.1M


Unnamed: 0,learning_rate,batch_size,hidden_dim,n_layers,embedding_dim,encoder_type,aggregator_type,best_val_loss,best_epoch,trainable_params,total_training_time
290,0.0005,16,256,3,1024,resnet,concat,0.154958,85,10555136,7962.247344
422,0.001,32,512,1,1024,resnet,concat,0.155305,83,16484864,2606.612936
300,0.001,16,256,3,1024,resnet,concat,0.156679,93,10555136,8547.191034
379,0.002,16,256,1,1024,resnet,concat,0.157262,95,5294336,4868.378374
229,0.0005,32,256,2,768,resnet,concat,0.1576,90,7596544,3258.366287


In [176]:
# Five individual scores to average (their origin is not recorded in this cell).
a, b, c, d, e = 0.65317, 0.62066, 0.64136, 0.6565, 0.64104

# Accumulate left to right, in the same order as a + b + c + d + e, so the
# floating-point result is identical.
total = 0.0
for score in (a, b, c, d, e):
    total += score

print(total / 5)

0.642546


In [186]:
from statistics import mean

# Metrics table pasted as plain text: two header lines, then one metric name
# per line, then one "•"-delimited row per run (run name followed by one value
# per metric, in the same order as the metric names).
ensemble = """
Name
5 visualized
test/Y_KIBA_ci
test/Y_KIBA_mse
test/Y_KIBA_pearson
test/Y_KIBA_r2
test/Y_KIBA_rmse
•
multi_modal_finetune_KIBA_rand_ensemble
0.85238
0.19966
0.84787
0.71881
0.43361
•
multi_modal_finetune_KIBA_rand_ensemble
0.84646
0.19966
0.83758
0.70141
0.44683
•
multi_modal_finetune_KIBA_rand_ensemble
0.84461
0.20252
0.83521
0.69713
0.45002
•
multi_modal_finetune_KIBA_rand_ensemble
0.84727
0.19507
0.84162
0.70826
0.44167
•
multi_modal_finetune_KIBA_rand_ensemble
0.84913
0.19269
0.84389
0.71183
0.43896
""".replace("0.85238\n0.19966", "0.85238\n0.18802", 1)


def parse_pasted_metrics_table(text: str, header_lines: int = 2, row_marker: str = "•") -> dict:
    """Parse a pasted one-value-per-line metrics table into per-metric lists.

    Expected layout after blank lines are removed:
    ``header_lines`` header lines, then one metric name per line, then rows
    that each start with ``row_marker`` and contain an optional run-name line
    followed by exactly one numeric value per metric.

    Args:
        text: The pasted table.
        header_lines: Number of leading lines to skip before the metric names.
        row_marker: Line content that marks the start of a row.

    Returns:
        A dict mapping each metric name to the list of its values, in row order.

    Raises:
        ValueError: If no row marker or metric name is found, if a row has the
            wrong number of fields, or if a metric value is not numeric.
    """
    lines = [line.strip() for line in text.strip().split('\n') if line.strip()]
    if row_marker not in lines:
        raise ValueError(f"no row marker {row_marker!r} found in pasted table")
    first_marker = lines.index(row_marker)
    columns = lines[header_lines:first_marker]
    if not columns:
        raise ValueError("no metric names found before the first row marker")

    # Group the remaining lines into one list of fields per marker-delimited row.
    rows = []
    for line in lines[first_marker:]:
        if line == row_marker:
            rows.append([])
        else:
            rows[-1].append(line)

    parsed = {column: [] for column in columns}
    for row_number, fields in enumerate(rows, start=1):
        if not fields:
            # Consecutive markers: nothing to record for this row.
            continue
        # A row is either just the values, or a run name followed by the values.
        # Validating the field count keeps a short or oddly named row from
        # silently shifting values into the wrong metric.
        if len(fields) not in (len(columns), len(columns) + 1):
            raise ValueError(
                f"row {row_number} has {len(fields)} fields, expected "
                f"{len(columns)} values with an optional run name: {fields}"
            )
        raw_values = fields[-len(columns):]
        try:
            values = [float(raw) for raw in raw_values]
        except ValueError as exc:
            raise ValueError(
                f"row {row_number} contains a non-numeric metric value: {fields}"
            ) from exc
        for column, value in zip(columns, values):
            parsed[column].append(value)
    return parsed


# Per-metric values for every row, and the metric names in table order.
data = parse_pasted_metrics_table(ensemble)
column_names = list(data)

# Average each metric across the parsed rows.
results = {col_name: mean(values) for col_name, values in data.items()}

results

{'test/Y_KIBA_ci': 0.84797,
 'test/Y_KIBA_mse': 0.195592,
 'test/Y_KIBA_pearson': 0.841234,
 'test/Y_KIBA_r2': 0.707488,
 'test/Y_KIBA_rmse': 0.442218}

In [187]:
# Show the per-metric value lists collected in `data` (one entry per parsed
# row), so the averages in `results` can be checked against individual values.
data

{'test/Y_KIBA_ci': [0.85238, 0.84646, 0.84461, 0.84727, 0.84913],
 'test/Y_KIBA_mse': [0.18802, 0.19966, 0.20252, 0.19507, 0.19269],
 'test/Y_KIBA_pearson': [0.84787, 0.83758, 0.83521, 0.84162, 0.84389],
 'test/Y_KIBA_r2': [0.71881, 0.70141, 0.69713, 0.70826, 0.71183],
 'test/Y_KIBA_rmse': [0.43361, 0.44683, 0.45002, 0.44167, 0.43896]}