# Training

Experiments were run with:
- `mb_vae_dti/training/run.py`: the command-line interface for training
- the scripts in `scripts/training/`, which launch the individual experiments

This notebook plots and analyzes the results of unsupervised pretraining, general DTI training, and benchmark fine-tuning.
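
The grid-search cell further down keeps the five runs with the lowest validation loss per encoder type and averages their test metrics. The selection step could be sketched roughly as follows. This is only an illustration under assumptions: each result file is taken to be a JSON object with a `best_val_loss` key, a nested `config` dict, and top-level `test/...` metric keys, and `load_results` and `top_k_test_averages` are hypothetical names, not the helpers in `mb_vae_dti.validating.analysis`.

```python
import json
from pathlib import Path

import pandas as pd


def load_results(folder):
    """Read every *.json file in `folder` into one DataFrame.

    Nested keys are flattened into dotted column names,
    e.g. {"config": {"model": {...}}} becomes "config.model.<key>".
    The file layout is an assumption, not the project's actual schema.
    """
    paths = sorted(Path(folder).glob("*.json"))
    records = [json.loads(p.read_text()) for p in paths]
    return pd.json_normalize(records)


def top_k_test_averages(df, k=5, sort_key="best_val_loss", prefix="test/"):
    """Average all `prefix` columns over the k rows with the lowest `sort_key`."""
    best = df.nsmallest(k, sort_key)
    test_cols = [c for c in best.columns if c.startswith(prefix)]
    return best[test_cols].mean(numeric_only=True)
```

Selecting with `nsmallest` is equivalent to sorting ascending and taking the first `k` rows; filtering the DataFrame on a config column beforehand gives the per-encoder comparison.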

In [1]:
from resolve import *

Setting working directory to: /home/robsyc/Desktop/thesis/MB-VAE-DTI


In [None]:
from mb_vae_dti.validating.analysis import *

# TODO: update these helper functions so that they
# - read a folder full of JSON files
# - collect all first- and second-level keys into a DataFrame

# Downstream analysis:
# - sort by best_val_loss
# - keep the 5 runs with the lowest validation loss
# - average test results, n_params and train_time over those top 5
# - additional filters:
#    - for baseline: which encoder did better?
#    - for multi-modal: which aggregator did better?

df = load_gridsearch_results("notebooks/results/grid_mo_rand/")

# split the runs by the field of interest (here encoder_type; aggregator_type for multi-modal runs)
with_resnet = df[df["config.model.encoder_type"] == "resnet"]
with_transformer = df[df["config.model.encoder_type"] == "transformer"]

# keep the 5 runs with the lowest validation loss for each encoder type
with_resnet_best = with_resnet.sort_values(by="best_val_loss", ascending=True).head(5)
with_transformer_best = with_transformer.sort_values(by="best_val_loss", ascending=True).head(5)

# report the mean parameter count (in millions) and the mean test metrics
print(f"Params: {with_resnet_best['trainable_params'].mean() / 1e6:.1f}M")
print(get_test_averages(with_resnet_best))
print("-" * 100)
print(f"Params: {with_transformer_best['trainable_params'].mean() / 1e6:.1f}M")
print(get_test_averages(with_transformer_best))

2025-07-30 01:53:43,772 - mb_vae_dti.validating.analysis - INFO - Loading 47 result files from notebooks/results/grid_mo_rand


Params: 44.3M
test/loss_Y               0.351866
test/loss_Y_pKd           0.203438
test/loss_Y_pKi           0.676976
test/loss_Y_KIBA          0.175217
test/loss                 0.859665
test/binary_accuracy      0.869318
test/binary_auprc         0.832982
test/binary_auroc         0.917402
test/binary_f1            0.758347
test/real_pKd_ci          0.862616
test/real_pKd_mse         0.428890
test/real_pKd_pearson     0.858399
test/real_pKd_r2          0.721310
test/real_pKd_rmse        0.654620
test/real_pKi_ci          0.823461
test/real_pKi_mse         0.696080
test/real_pKi_pearson     0.831988
test/real_pKi_r2          0.670206
test/real_pKi_rmse        0.833835
test/real_KIBA_ci         0.829142
test/real_KIBA_mse        0.294405
test/real_KIBA_pearson    0.806901
test/real_KIBA_r2         0.559779
test/real_KIBA_rmse       0.538769
dtype: float64
----------------------------------------------------------------------------------------------------
Params: 45.1M
test/loss_Y     