In [21]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import copy
import os
import random
import shutil
import sys
import pdb

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm

sys.path.append('../')
from meta_fusion.benchmarks_general import *
from meta_fusion.methods_general import *
from meta_fusion.models_general import *
from meta_fusion.utils import *
from meta_fusion.third_party import *
from meta_fusion.synthetic_data_general import PrepareSyntheticData
from meta_fusion.synthetic_data_general import PrepareSyntheticData

from meta_fusion.config import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [89]:
def run_single_experiment(config, extractor_config, 
                          n, dim_modalities, dim_latent, noise_ratios, trans_type, mod_prop, interactive_modalities, interactive_prop,
                          mod_outs, combined_hiddens, mod_hiddens, is_mod_static, freeze_mod_extractors,
                          extractor_type, seperate, 
                          random_state, 
                          run_oracle=False, run_coop=True):
    """Run one repetition of the synthetic experiment and collect all test metrics.

    Splits a synthetic dataset, then trains and tests the benchmark models, the
    oracle benchmarks (optional), Coop (optional, two modalities only), and the
    MetaFuse cohort plus its ablation. Every metric reported by the ``test`` /
    ``test_ablation`` calls is gathered into one DataFrame.

    NOTE(review): besides its arguments, this function reads the notebook-level
    globals ``data_preparer``, ``data_name``, ``output_dim`` and ``scale``. They
    must be defined (config cell) before calling it.

    Args:
        config: training configuration shared by all models. ``random_state`` is
            written into it, so the caller's object is mutated.
        extractor_config: configuration for the encoder extractors.
            ``random_state`` is written into it as well.
        n, dim_modalities, dim_latent, noise_ratios, trans_type, mod_prop,
        interactive_modalities, interactive_prop: data-generation settings
            forwarded to ``data_preparer.get_data_loaders``.
        mod_outs: per-modality extractor output sizes for MetaFuse.
        combined_hiddens: hidden layer sizes passed to every ``Cohorts``.
        mod_hiddens: encoder hidden layer sizes (used only when
            ``extractor_type == 'encoder'``).
        is_mod_static, freeze_mod_extractors: per-modality flags forwarded to
            the MetaFuse ``Cohorts``.
        extractor_type: ``'encoder'`` or ``'PCA'``.
        seperate: forwarded to ``get_encoder_extractors(seperate=...)``.
        random_state: seed stored in both configs and in the output table.
        run_oracle: also train/test benchmarks on the oracle loaders, whose
            modality sizes are derived from ``dim_latent``.
        run_coop: also train/test Coop. It is skipped unless the data has
            exactly two modalities.

    Returns:
        pandas.DataFrame with one row per reported method and the columns
        ``Method``, ``Test_metric``, ``best_rho``, ``cohort_pairs``,
        ``random_state``, ``dim_modalities``, ``n``, ``n_train``, ``n_val``,
        ``n_test`` and ``scale``.

    Raises:
        ValueError: if ``extractor_type`` is neither ``'encoder'`` nor ``'PCA'``.
    """
    # Fail fast, before any data generation or training. Otherwise an unknown
    # extractor_type would leave meta_extractor without any get_*_extractors call.
    if extractor_type not in ('encoder', 'PCA'):
        raise ValueError(f"Unknown extractor_type {extractor_type!r}; expected 'encoder' or 'PCA'.")

    config['random_state'] = random_state
    extractor_config['random_state'] = random_state
    res_list = []
    best_rho = {}
    cohort_pairs = {}

    #----------------#
    # Split dataset  #
    #----------------#
    train_loader, val_loader, test_loader, oracle_train_loader, oracle_val_loader, oracle_test_loader =\
    data_preparer.get_data_loaders(n, trans_type=trans_type, mod_prop=mod_prop, 
                                   interactive_prop = interactive_prop, interactive_modalities = interactive_modalities,
                                   dim_modalities=dim_modalities, dim_latent=dim_latent,
                                   noise_ratios=noise_ratios, random_state=random_state)
    # Overwrite the requested sizes with what the preparer actually produced.
    data_info = data_preparer.get_data_info()
    dim_modalities = data_info[0]
    n = data_info[1]
    n_train = data_info[2]
    n_val = data_info[3]
    n_test = data_info[4]

    print(f"Finished splitting {data_name} dataset. Data information are summarized below:\n"
            f"Modality dimensions: {dim_modalities}\n"
            f"Data size: {n}\n"
            f"Train size: {n_train}\n"
            f"Val size: {n_val}\n"
            f"Test size: {n_test}")
    sys.stdout.flush() 

    #------------------#
    # Benchmark models #
    #------------------#
    # Dummy extractors with output spec [d, 0] for every modality.
    bm_outs = [[d,0] for d in dim_modalities]
    bm_extractor = Extractors(bm_outs, dim_modalities, train_loader, val_loader)
    _ = bm_extractor.get_dummy_extractors()
    bm_cohort = Cohorts(extractors=bm_extractor, combined_hidden_layers=combined_hiddens, output_dim=output_dim)

    if run_oracle:
        # Oracle modality sizes: all latent blocks but the last, with the last
        # (shared) latent size folded into the final oracle modality.
        dim_oracles = dim_latent[:-1].copy()
        dim_oracles[-1]+=dim_latent[-1]
        oracle_outs = [[d,0] for d in dim_oracles]
        oracle_extractor = Extractors(oracle_outs, dim_oracles, oracle_train_loader, oracle_val_loader)
        _ = oracle_extractor.get_dummy_extractors()
        oracle_cohort = Cohorts(extractors=oracle_extractor, combined_hidden_layers=combined_hiddens, output_dim=output_dim)

    #----------------------------#
    # Proposed model: MetaFuse   #
    #----------------------------#
    meta_extractor = Extractors(mod_outs, dim_modalities, train_loader, val_loader)
    if extractor_type == 'encoder':
        _ = meta_extractor.get_encoder_extractors(mod_hiddens, seperate=seperate, config=extractor_config)
    else:  # 'PCA' (validated above)
        _ = meta_extractor.get_PCA_extractors()
    meta_cohort = Cohorts(extractors=meta_extractor, combined_hidden_layers=combined_hiddens, output_dim=output_dim,
                          is_mod_static=is_mod_static, freeze_mod_extractors=freeze_mod_extractors)


    #------------------------------#
    #  Train and test benchmarks   #
    #------------------------------#
    bm_models = bm_cohort.get_cohort_models()
    _, bm_dims = bm_cohort.get_cohort_info()
    bm = Benchmarks(config, bm_models, bm_dims, [train_loader, val_loader])
    bm.train()
    res = bm.test(test_loader)
    res_list.append(res)
    print(f"Finished running basic benchmarks!")
    
    if run_oracle:
        # Use a shallow copy so the oracle learning rate does not leak into the
        # caller's `config` (and hence into MetaFuse training below). Also
        # actually train the oracle with that copy.
        oracle_config = copy.copy(config)
        oracle_config["init_lr"] = 0.001
        oracle_models = oracle_cohort.get_cohort_models()
        _, oracle_dims = oracle_cohort.get_cohort_info()
        oracle = Benchmarks(oracle_config, oracle_models, oracle_dims, [oracle_train_loader, oracle_val_loader])
        oracle.train()
        res = oracle.test(oracle_test_loader)
        res = {f"oracle_{key}": value for key, value in res.items()}
        res_list.append(res)
        print(f"Finished running oracle benchmarks!")
        
    if run_coop:
        if len(dim_modalities) == 2:
            bm_models = bm_cohort.get_cohort_models()
            _, bm_dims = bm_cohort.get_cohort_info()    
            coop = Coop(config, bm_models, bm_dims, [train_loader, val_loader])
            coop.train()
            res = coop.test(test_loader)
            res_list.append(res)
            best_rho['coop'] = coop.best_rho
            print(f"Finished running coop!")
        else:
            print(f"Skip coop benchmark since it is not clear how it handles multiple modalities with interactions.")

    
    #------------------------------#
    #  Train and test MetaFuse     #
    #------------------------------#
    cohort_models = meta_cohort.get_cohort_models()
    _, dim_pairs = meta_cohort.get_cohort_info()
    metafuse = Trainer(config, cohort_models, [train_loader, val_loader])
    metafuse.train() 
    res = metafuse.test(test_loader)
    res_list.append(res)
    
    metafuse.train_ablation() 
    res = metafuse.test_ablation(test_loader)
    res_list.append(res)
    
    best_rho['meta_learner'] = metafuse.best_rho
    cohort_pairs['cohort'] = dim_pairs
    print(f"Finished running meta fusion!")
    
    # One row per (method, metric value). best_rho / cohort_pairs are filled
    # only for methods whose name matches a key recorded above.
    results = []
    for res in res_list:
        for method, val in res.items():
            results.append({'Method': method, 'Test_metric': val, 
                            'best_rho':best_rho.get(method), 'cohort_pairs':cohort_pairs.get(method)})

    results = pd.DataFrame(results)

    results['random_state']=random_state
    results["dim_modalities"] = [dim_modalities] * len(results)
    results['n'] = n
    results['n_train'] = n_train
    results['n_val'] = n_val
    results['n_test'] = n_test 
    results['scale'] = scale

    return results

In [90]:
# Fixed data parameters
repetition = 1  # number of repetitions executed by the experiment loop
seed = 1        # selects which block of random states this run uses; also part of the checkpoint path
scale = 0       # stored in the results table and checkpoint path; run_single_experiment reads it as a global

# Data model parameters
n = 2000
dim_modalities = [300, 200, 100, 50]
dim_latent = [20, 20, 10, 10, 0]
noise_ratios = [0.5, 0.1, 0.2, 0.1]
trans_type = ["quadratic", "linear", "linear", "linear", "linear"]
mod_prop = [1, 1, 1, 1, 0, 0]
interactive_modalities = [1, 1, 1, 1]
interactive_prop = 0.1

mod_outs = [[50, 30], [80, 50, 30], [30], [20]]
combined_hiddens = [300, 200, 100]
# One independent inner list per modality. `[[128]] * k` would make every entry
# the same list object, so mutating one modality's layers would change all of them.
mod_hiddens = [[128] for _ in dim_modalities]
# mod_outs and dim_modalities are paired positionally in Extractors(...).
assert len(mod_outs) == len(dim_modalities), "mod_outs needs one entry per modality"

# Alternative two-modality setting (kept for reference, not used in this run):
# n = 2000
# dim_modalities = [2000, 100]
# dim_latent = [20, 20, 0]
# noise_ratios = [0.5, 0.1]
# trans_type = ["quadratic", "linear", "linear"]
# mod_prop = [1, 1, 0, 1]
# interactive_prop = 0.1
#
# mod_outs = [[60],[40, 60]]
# combined_hiddens = [300,200,100]
# mod_hiddens = [[128] for _ in dim_modalities]

# data parameters
data_name = 'regression'
exp_name = data_name + "_" + "quadratic_early"
output_dim = 1  # specify the output dimension for regression

# MetaFuse extractor settings ('encoder' or 'PCA')
extractor_type = 'PCA'
seperate = False
is_mod_static = [True] * len(dim_modalities)
freeze_mod_extractors = [False] * len(dim_modalities)

# Load default model configurations 
config = load_config('../experiments_synthetic/config.json')
extractor_config = load_config('../experiments_synthetic/config_extractor.json')

# Model files directory (deleted again after the experiment loop finishes)
ckpt_dir = f"./checkpoints/{exp_name}/scale{scale}_seed{seed}/"
config['ckpt_dir'] = extractor_config['ckpt_dir'] = ckpt_dir

# Update other training parameters
config['divergence_weight_type'] = 'clustering'
config['optimal_k'] = None
config['output_dim'] = extractor_config['output_dim'] = output_dim
config["init_lr"] = 0.001
config["epochs"] = 10
extractor_config["init_lr"] = 0.001
config["ensemble_methods"] = [
        "simple_average",
        "weighted_average",
        "greedy_ensemble"
        ]
#####################
#    Load Dataset   #
#####################
data_preparer = PrepareSyntheticData(data_name = data_name, test_size = 0.2, val_size = 0.2)

In [91]:
# Run `repetition` independent experiments and stack their result tables.
# pd.concat raises an opaque "No objects to concatenate" on an empty list.
assert repetition >= 1, "repetition must be >= 1"
per_rep_results = []

for i in range(1, repetition + 1):
    print(f'Running with repetition {i}...')
    # Consecutive seeds map to disjoint blocks of random states:
    # seed=1 -> 1..repetition, seed=2 -> repetition+1..2*repetition, ...
    random_state = repetition * (seed - 1) + i
    set_random_seed(random_state)

    # Run experiment
    rep_result = run_single_experiment(config, extractor_config,
                                       n, dim_modalities, dim_latent, noise_ratios, trans_type, mod_prop, interactive_modalities, interactive_prop,
                                       mod_outs, combined_hiddens, mod_hiddens, is_mod_static, freeze_mod_extractors,
                                       extractor_type, seperate, run_oracle=True, run_coop=False, random_state=random_state)
    per_rep_results.append(rep_result)

results = pd.concat(per_rep_results, ignore_index=True)


#####################
#    Save Results   #
#####################
# TODO: define `outfile` before enabling the save step.
# results.to_csv(outfile, index=False)
# print("\nResults written to {:s}\n".format(outfile))
# sys.stdout.flush()

# After the job is done, remove the model directory to free up space
if os.path.exists(ckpt_dir):
    print(f"Deleting the model checkpoint directory: {ckpt_dir}")
    shutil.rmtree(ckpt_dir)
    print(f"Model checkpoint directory {ckpt_dir} has been deleted.")

Running with repetition 1...
Finished splitting regression dataset. Data information are summarized below:
Modality dimensions: [300, 200, 100, 50]
Data size: 2000
Train size: 1280
Val size: 320
Test size: 400
Start training benchmark models...
Training with disagreement penalty = 0

Epoch: 1/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:03<00:00, 349.76it/s]


model_1: train loss: 497.006, train task loss: 497.006 - val loss: 261.421, val task loss: 261.421 [*] Best so far
Created directory: ./checkpoints/regression_quadratic_early/scale0_seed1/0
model_2: train loss: 473.858, train task loss: 473.858 - val loss: 243.686, val task loss: 243.686 [*] Best so far
model_3: train loss: 489.301, train task loss: 489.301 - val loss: 289.694, val task loss: 289.694 [*] Best so far
model_4: train loss: 490.328, train task loss: 490.328 - val loss: 295.394, val task loss: 295.394 [*] Best so far
model_5: train loss: 444.354, train task loss: 444.354 - val loss: 158.570, val task loss: 158.570 [*] Best so far

Epoch: 2/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:03<00:00, 328.21it/s]


model_1: train loss: 256.575, train task loss: 256.575 - val loss: 230.286, val task loss: 230.286 [*] Best so far
model_2: train loss: 265.788, train task loss: 265.788 - val loss: 240.228, val task loss: 240.228 [*] Best so far
model_3: train loss: 319.710, train task loss: 319.710 - val loss: 284.617, val task loss: 284.617 [*] Best so far
model_4: train loss: 318.679, train task loss: 318.679 - val loss: 274.073, val task loss: 274.073 [*] Best so far
model_5: train loss: 151.980, train task loss: 151.980 - val loss: 144.855, val task loss: 144.855 [*] Best so far

Epoch: 3/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:04<00:00, 285.40it/s]


model_1: train loss: 214.143, train task loss: 214.143 - val loss: 230.348, val task loss: 230.348
model_2: train loss: 241.594, train task loss: 241.594 - val loss: 238.193, val task loss: 238.193 [*] Best so far
model_3: train loss: 303.981, train task loss: 303.981 - val loss: 283.630, val task loss: 283.630 [*] Best so far
model_4: train loss: 294.670, train task loss: 294.670 - val loss: 276.710, val task loss: 276.710
model_5: train loss: 107.858, train task loss: 107.858 - val loss: 126.189, val task loss: 126.189 [*] Best so far

Epoch: 4/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:04<00:00, 281.54it/s]


model_1: train loss: 187.352, train task loss: 187.352 - val loss: 228.216, val task loss: 228.216 [*] Best so far
model_2: train loss: 239.590, train task loss: 239.590 - val loss: 235.654, val task loss: 235.654 [*] Best so far
model_3: train loss: 300.102, train task loss: 300.102 - val loss: 282.280, val task loss: 282.280 [*] Best so far
model_4: train loss: 291.379, train task loss: 291.379 - val loss: 275.232, val task loss: 275.232
model_5: train loss: 90.678, train task loss: 90.678 - val loss: 124.290, val task loss: 124.290 [*] Best so far

Epoch: 5/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:05<00:00, 251.08it/s]


model_1: train loss: 170.688, train task loss: 170.688 - val loss: 226.165, val task loss: 226.165 [*] Best so far
model_2: train loss: 231.776, train task loss: 231.776 - val loss: 237.317, val task loss: 237.317
model_3: train loss: 292.552, train task loss: 292.552 - val loss: 291.793, val task loss: 291.793
model_4: train loss: 287.338, train task loss: 287.338 - val loss: 272.521, val task loss: 272.521 [*] Best so far
model_5: train loss: 78.338, train task loss: 78.338 - val loss: 122.870, val task loss: 122.870 [*] Best so far

Epoch: 6/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:05<00:00, 231.57it/s]


model_1: train loss: 157.868, train task loss: 157.868 - val loss: 226.983, val task loss: 226.983
model_2: train loss: 226.834, train task loss: 226.834 - val loss: 236.775, val task loss: 236.775
model_3: train loss: 288.481, train task loss: 288.481 - val loss: 288.004, val task loss: 288.004
model_4: train loss: 284.698, train task loss: 284.698 - val loss: 276.320, val task loss: 276.320
model_5: train loss: 69.608, train task loss: 69.608 - val loss: 115.683, val task loss: 115.683 [*] Best so far

Epoch: 7/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:04<00:00, 270.21it/s]


model_1: train loss: 142.611, train task loss: 142.611 - val loss: 229.923, val task loss: 229.923
model_2: train loss: 224.410, train task loss: 224.410 - val loss: 233.584, val task loss: 233.584 [*] Best so far
model_3: train loss: 285.973, train task loss: 285.973 - val loss: 284.690, val task loss: 284.690
model_4: train loss: 284.337, train task loss: 284.337 - val loss: 277.151, val task loss: 277.151
model_5: train loss: 59.709, train task loss: 59.709 - val loss: 112.987, val task loss: 112.987 [*] Best so far

Epoch: 8/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:05<00:00, 235.38it/s]


model_1: train loss: 129.928, train task loss: 129.928 - val loss: 254.280, val task loss: 254.280
model_2: train loss: 221.844, train task loss: 221.844 - val loss: 235.590, val task loss: 235.590
model_3: train loss: 284.559, train task loss: 284.559 - val loss: 308.989, val task loss: 308.989
model_4: train loss: 283.690, train task loss: 283.690 - val loss: 288.314, val task loss: 288.314
model_5: train loss: 48.655, train task loss: 48.655 - val loss: 104.864, val task loss: 104.864 [*] Best so far

Epoch: 9/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:04<00:00, 265.83it/s]


model_1: train loss: 115.281, train task loss: 115.281 - val loss: 239.976, val task loss: 239.976
model_2: train loss: 214.743, train task loss: 214.743 - val loss: 232.251, val task loss: 232.251 [*] Best so far
model_3: train loss: 280.226, train task loss: 280.226 - val loss: 290.313, val task loss: 290.313
model_4: train loss: 281.346, train task loss: 281.346 - val loss: 274.969, val task loss: 274.969
model_5: train loss: 33.164, train task loss: 33.164 - val loss: 100.566, val task loss: 100.566 [*] Best so far

Epoch: 10/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:04<00:00, 258.26it/s]


model_1: train loss: 101.788, train task loss: 101.788 - val loss: 243.450, val task loss: 243.450
model_2: train loss: 213.147, train task loss: 213.147 - val loss: 231.299, val task loss: 231.299 [*] Best so far
model_3: train loss: 280.165, train task loss: 280.165 - val loss: 282.494, val task loss: 282.494
model_4: train loss: 278.605, train task loss: 278.605 - val loss: 276.133, val task loss: 276.133
model_5: train loss: 21.151, train task loss: 21.151 - val loss: 97.003, val task loss: 97.003 [*] Best so far
Finished training benchmark models!
Method: (modality_1), Test_MSE: 227.9994659423828
Method: (modality_2), Test_MSE: 214.0945281982422
Method: (modality_3), Test_MSE: 281.72479248046875
Method: (modality_4), Test_MSE: 271.2962646484375
Method: (early_fusion), Test_MSE: 110.87876892089844
Method: (late_fusion), Test_MSE: 202.2051239013672
Finished running basic benchmarks!
Start training benchmark models...
Training with disagreement penalty = 0

Epoch: 1/10 - LR: 0.001000

100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:05<00:00, 246.31it/s]


model_1: train loss: 481.571, train task loss: 481.571 - val loss: 279.113, val task loss: 279.113 [*] Best so far
model_2: train loss: 469.070, train task loss: 469.070 - val loss: 283.910, val task loss: 283.910 [*] Best so far
model_3: train loss: 471.323, train task loss: 471.323 - val loss: 304.201, val task loss: 304.201 [*] Best so far
model_4: train loss: 459.965, train task loss: 459.965 - val loss: 305.754, val task loss: 305.754 [*] Best so far
model_5: train loss: 485.677, train task loss: 485.677 - val loss: 274.022, val task loss: 274.022 [*] Best so far

Epoch: 2/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:04<00:00, 257.81it/s]


model_1: train loss: 291.015, train task loss: 291.015 - val loss: 244.069, val task loss: 244.069 [*] Best so far
model_2: train loss: 299.899, train task loss: 299.899 - val loss: 252.432, val task loss: 252.432 [*] Best so far
model_3: train loss: 332.833, train task loss: 332.833 - val loss: 280.489, val task loss: 280.489 [*] Best so far
model_4: train loss: 326.167, train task loss: 326.167 - val loss: 278.164, val task loss: 278.164 [*] Best so far
model_5: train loss: 249.763, train task loss: 249.763 - val loss: 191.772, val task loss: 191.772 [*] Best so far

Epoch: 3/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:05<00:00, 255.98it/s]


model_1: train loss: 258.669, train task loss: 258.669 - val loss: 234.151, val task loss: 234.151 [*] Best so far
model_2: train loss: 261.833, train task loss: 261.833 - val loss: 237.659, val task loss: 237.659 [*] Best so far
model_3: train loss: 308.273, train task loss: 308.273 - val loss: 275.636, val task loss: 275.636 [*] Best so far
model_4: train loss: 298.246, train task loss: 298.246 - val loss: 273.469, val task loss: 273.469 [*] Best so far
model_5: train loss: 172.348, train task loss: 172.348 - val loss: 151.254, val task loss: 151.254 [*] Best so far

Epoch: 4/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:05<00:00, 249.15it/s]


model_1: train loss: 240.577, train task loss: 240.577 - val loss: 223.684, val task loss: 223.684 [*] Best so far
model_2: train loss: 246.701, train task loss: 246.701 - val loss: 233.881, val task loss: 233.881 [*] Best so far
model_3: train loss: 299.953, train task loss: 299.953 - val loss: 277.640, val task loss: 277.640
model_4: train loss: 288.951, train task loss: 288.951 - val loss: 272.903, val task loss: 272.903 [*] Best so far
model_5: train loss: 132.886, train task loss: 132.886 - val loss: 129.831, val task loss: 129.831 [*] Best so far

Epoch: 5/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:03<00:00, 381.30it/s]


model_1: train loss: 230.907, train task loss: 230.907 - val loss: 220.059, val task loss: 220.059 [*] Best so far
model_2: train loss: 240.126, train task loss: 240.126 - val loss: 232.288, val task loss: 232.288 [*] Best so far
model_3: train loss: 295.348, train task loss: 295.348 - val loss: 275.659, val task loss: 275.659
model_4: train loss: 286.338, train task loss: 286.338 - val loss: 274.487, val task loss: 274.487
model_5: train loss: 114.939, train task loss: 114.939 - val loss: 121.590, val task loss: 121.590 [*] Best so far

Epoch: 6/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:05<00:00, 228.42it/s]


model_1: train loss: 223.625, train task loss: 223.625 - val loss: 219.529, val task loss: 219.529 [*] Best so far
model_2: train loss: 235.599, train task loss: 235.599 - val loss: 231.741, val task loss: 231.741 [*] Best so far
model_3: train loss: 293.304, train task loss: 293.304 - val loss: 280.368, val task loss: 280.368
model_4: train loss: 283.597, train task loss: 283.597 - val loss: 281.353, val task loss: 281.353
model_5: train loss: 105.253, train task loss: 105.253 - val loss: 117.021, val task loss: 117.021 [*] Best so far

Epoch: 7/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:04<00:00, 263.13it/s]


model_1: train loss: 213.813, train task loss: 213.813 - val loss: 220.995, val task loss: 220.995
model_2: train loss: 232.300, train task loss: 232.300 - val loss: 230.876, val task loss: 230.876 [*] Best so far
model_3: train loss: 290.449, train task loss: 290.449 - val loss: 285.997, val task loss: 285.997
model_4: train loss: 282.495, train task loss: 282.495 - val loss: 275.728, val task loss: 275.728
model_5: train loss: 97.690, train task loss: 97.690 - val loss: 115.730, val task loss: 115.730 [*] Best so far

Epoch: 8/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:05<00:00, 247.72it/s]


model_1: train loss: 204.463, train task loss: 204.463 - val loss: 212.206, val task loss: 212.206 [*] Best so far
model_2: train loss: 228.429, train task loss: 228.429 - val loss: 229.743, val task loss: 229.743 [*] Best so far
model_3: train loss: 287.418, train task loss: 287.418 - val loss: 272.247, val task loss: 272.247 [*] Best so far
model_4: train loss: 280.410, train task loss: 280.410 - val loss: 270.807, val task loss: 270.807 [*] Best so far
model_5: train loss: 89.249, train task loss: 89.249 - val loss: 108.080, val task loss: 108.080 [*] Best so far

Epoch: 9/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:04<00:00, 258.11it/s]


model_1: train loss: 192.730, train task loss: 192.730 - val loss: 206.587, val task loss: 206.587 [*] Best so far
model_2: train loss: 224.431, train task loss: 224.431 - val loss: 230.194, val task loss: 230.194
model_3: train loss: 285.280, train task loss: 285.280 - val loss: 274.890, val task loss: 274.890
model_4: train loss: 278.976, train task loss: 278.976 - val loss: 269.680, val task loss: 269.680 [*] Best so far
model_5: train loss: 77.069, train task loss: 77.069 - val loss: 103.579, val task loss: 103.579 [*] Best so far

Epoch: 10/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:04<00:00, 259.86it/s]


model_1: train loss: 180.785, train task loss: 180.785 - val loss: 208.671, val task loss: 208.671
model_2: train loss: 221.589, train task loss: 221.589 - val loss: 227.193, val task loss: 227.193 [*] Best so far
model_3: train loss: 283.575, train task loss: 283.575 - val loss: 274.963, val task loss: 274.963
model_4: train loss: 276.779, train task loss: 276.779 - val loss: 276.649, val task loss: 276.649
model_5: train loss: 63.327, train task loss: 63.327 - val loss: 95.248, val task loss: 95.248 [*] Best so far
Finished training benchmark models!
Method: (modality_1), Test_MSE: 198.72540283203125
Method: (modality_2), Test_MSE: 210.79225158691406
Method: (modality_3), Test_MSE: 280.873779296875
Method: (modality_4), Test_MSE: 265.84417724609375
Method: (early_fusion), Test_MSE: 93.92153930664062
Method: (late_fusion), Test_MSE: 186.94775390625
Finished running oracle benchmarks!
Start training student cohort...
Training with disagreement penalty = 0

Epoch: 1/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 211.66it/s]


model_1: train loss: 452.943, train task loss: 452.943 - val loss: 235.212, val task loss: 235.212 [*] Best so far
model_2: train loss: 442.088, train task loss: 442.088 - val loss: 265.077, val task loss: 265.077 [*] Best so far
model_3: train loss: 440.913, train task loss: 440.913 - val loss: 256.667, val task loss: 256.667 [*] Best so far
model_4: train loss: 434.685, train task loss: 434.685 - val loss: 268.245, val task loss: 268.245 [*] Best so far
model_5: train loss: 438.722, train task loss: 438.722 - val loss: 270.752, val task loss: 270.752 [*] Best so far
model_6: train loss: 427.837, train task loss: 427.837 - val loss: 274.006, val task loss: 274.006 [*] Best so far

Epoch: 2/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:05<00:00, 215.92it/s]


model_1: train loss: 213.306, train task loss: 213.306 - val loss: 168.125, val task loss: 168.125 [*] Best so far
model_2: train loss: 224.006, train task loss: 224.006 - val loss: 170.598, val task loss: 170.598 [*] Best so far
model_3: train loss: 226.246, train task loss: 226.246 - val loss: 172.746, val task loss: 172.746 [*] Best so far
model_4: train loss: 221.306, train task loss: 221.306 - val loss: 170.102, val task loss: 170.102 [*] Best so far
model_5: train loss: 228.447, train task loss: 228.447 - val loss: 174.455, val task loss: 174.455 [*] Best so far
model_6: train loss: 226.949, train task loss: 226.949 - val loss: 175.500, val task loss: 175.500 [*] Best so far

Epoch: 3/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 204.46it/s]


model_1: train loss: 142.380, train task loss: 142.380 - val loss: 141.039, val task loss: 141.039 [*] Best so far
model_2: train loss: 145.014, train task loss: 145.014 - val loss: 139.456, val task loss: 139.456 [*] Best so far
model_3: train loss: 150.007, train task loss: 150.007 - val loss: 141.353, val task loss: 141.353 [*] Best so far
model_4: train loss: 146.735, train task loss: 146.735 - val loss: 142.662, val task loss: 142.662 [*] Best so far
model_5: train loss: 151.701, train task loss: 151.701 - val loss: 143.788, val task loss: 143.788 [*] Best so far
model_6: train loss: 153.769, train task loss: 153.769 - val loss: 144.950, val task loss: 144.950 [*] Best so far

Epoch: 4/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 203.29it/s]


model_1: train loss: 117.104, train task loss: 117.104 - val loss: 134.809, val task loss: 134.809 [*] Best so far
model_2: train loss: 115.811, train task loss: 115.811 - val loss: 130.477, val task loss: 130.477 [*] Best so far
model_3: train loss: 119.135, train task loss: 119.135 - val loss: 131.053, val task loss: 131.053 [*] Best so far
model_4: train loss: 119.643, train task loss: 119.643 - val loss: 133.624, val task loss: 133.624 [*] Best so far
model_5: train loss: 121.536, train task loss: 121.536 - val loss: 133.673, val task loss: 133.673 [*] Best so far
model_6: train loss: 124.235, train task loss: 124.235 - val loss: 133.315, val task loss: 133.315 [*] Best so far

Epoch: 5/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:08<00:00, 145.95it/s]


model_1: train loss: 106.280, train task loss: 106.280 - val loss: 129.665, val task loss: 129.665 [*] Best so far
model_2: train loss: 102.760, train task loss: 102.760 - val loss: 125.984, val task loss: 125.984 [*] Best so far
model_3: train loss: 105.323, train task loss: 105.323 - val loss: 125.076, val task loss: 125.076 [*] Best so far
model_4: train loss: 107.696, train task loss: 107.696 - val loss: 130.571, val task loss: 130.571 [*] Best so far
model_5: train loss: 108.016, train task loss: 108.016 - val loss: 129.137, val task loss: 129.137 [*] Best so far
model_6: train loss: 110.650, train task loss: 110.650 - val loss: 128.248, val task loss: 128.248 [*] Best so far

Epoch: 6/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 200.13it/s]


model_1: train loss: 98.895, train task loss: 98.895 - val loss: 126.303, val task loss: 126.303 [*] Best so far
model_2: train loss: 93.312, train task loss: 93.312 - val loss: 122.480, val task loss: 122.480 [*] Best so far
model_3: train loss: 95.984, train task loss: 95.984 - val loss: 121.648, val task loss: 121.648 [*] Best so far
model_4: train loss: 99.748, train task loss: 99.748 - val loss: 126.609, val task loss: 126.609 [*] Best so far
model_5: train loss: 99.288, train task loss: 99.288 - val loss: 124.842, val task loss: 124.842 [*] Best so far
model_6: train loss: 101.878, train task loss: 101.878 - val loss: 124.069, val task loss: 124.069 [*] Best so far

Epoch: 7/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 212.61it/s]


model_1: train loss: 91.475, train task loss: 91.475 - val loss: 124.253, val task loss: 124.253 [*] Best so far
model_2: train loss: 82.510, train task loss: 82.510 - val loss: 124.516, val task loss: 124.516
model_3: train loss: 83.523, train task loss: 83.523 - val loss: 121.336, val task loss: 121.336 [*] Best so far
model_4: train loss: 91.111, train task loss: 91.111 - val loss: 124.780, val task loss: 124.780 [*] Best so far
model_5: train loss: 88.643, train task loss: 88.643 - val loss: 127.083, val task loss: 127.083
model_6: train loss: 91.987, train task loss: 91.987 - val loss: 124.928, val task loss: 124.928

Epoch: 8/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 210.56it/s]


model_1: train loss: 83.780, train task loss: 83.780 - val loss: 121.507, val task loss: 121.507 [*] Best so far
model_2: train loss: 71.461, train task loss: 71.461 - val loss: 112.110, val task loss: 112.110 [*] Best so far
model_3: train loss: 71.556, train task loss: 71.556 - val loss: 103.304, val task loss: 103.304 [*] Best so far
model_4: train loss: 81.897, train task loss: 81.897 - val loss: 120.838, val task loss: 120.838 [*] Best so far
model_5: train loss: 78.830, train task loss: 78.830 - val loss: 113.453, val task loss: 113.453 [*] Best so far
model_6: train loss: 82.822, train task loss: 82.822 - val loss: 112.472, val task loss: 112.472 [*] Best so far

Epoch: 9/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 198.03it/s]


model_1: train loss: 73.045, train task loss: 73.045 - val loss: 123.560, val task loss: 123.560
model_2: train loss: 56.333, train task loss: 56.333 - val loss: 121.892, val task loss: 121.892
model_3: train loss: 55.649, train task loss: 55.649 - val loss: 107.709, val task loss: 107.709
model_4: train loss: 68.693, train task loss: 68.693 - val loss: 121.949, val task loss: 121.949
model_5: train loss: 64.138, train task loss: 64.138 - val loss: 121.106, val task loss: 121.106
model_6: train loss: 68.715, train task loss: 68.715 - val loss: 121.611, val task loss: 121.611

Epoch: 10/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 211.48it/s]


model_1: train loss: 60.067, train task loss: 60.067 - val loss: 113.316, val task loss: 113.316 [*] Best so far
model_2: train loss: 39.727, train task loss: 39.727 - val loss: 108.743, val task loss: 108.743 [*] Best so far
model_3: train loss: 38.655, train task loss: 38.655 - val loss: 98.901, val task loss: 98.901 [*] Best so far
model_4: train loss: 54.072, train task loss: 54.072 - val loss: 112.480, val task loss: 112.480 [*] Best so far
model_5: train loss: 47.555, train task loss: 47.555 - val loss: 106.079, val task loss: 106.079 [*] Best so far
model_6: train loss: 52.103, train task loss: 52.103 - val loss: 102.987, val task loss: 102.987 [*] Best so far
Training with disagreement penalty = 0.99
Computing divergence weights by clustering method...
Initialization complete
Iteration 0, inertia 24.489870383462403.
Iteration 1, inertia 12.244935191731201.
Converged at iteration 1: strict convergence.
Computed divergence weights by clustering method, weights are [0.  0.  0.5 0.

100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:10<00:00, 126.56it/s]


model_1: train loss: 59.421, train task loss: 44.750 - val loss: 130.481, val task loss: 109.740 [*] Best so far
Created directory: ./checkpoints/regression_quadratic_early/scale0_seed1/0.99
model_2: train loss: 36.791, train task loss: 25.815 - val loss: 125.651, val task loss: 106.396 [*] Best so far
model_3: train loss: 30.874, train task loss: 24.950 - val loss: 111.622, val task loss: 102.484 [*] Best so far
model_4: train loss: 55.106, train task loss: 39.691 - val loss: 144.619, val task loss: 113.567 [*] Best so far
model_5: train loss: 42.178, train task loss: 32.348 - val loss: 126.392, val task loss: 108.406 [*] Best so far
model_6: train loss: 41.310, train task loss: 35.386 - val loss: 121.669, val task loss: 112.532 [*] Best so far

Epoch: 2/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:10<00:00, 124.53it/s]


model_1: train loss: 50.294, train task loss: 36.942 - val loss: 138.950, val task loss: 107.293 [*] Best so far
model_2: train loss: 32.695, train task loss: 24.261 - val loss: 120.566, val task loss: 100.465 [*] Best so far
model_3: train loss: 23.505, train task loss: 19.249 - val loss: 107.321, val task loss: 96.542 [*] Best so far
model_4: train loss: 57.576, train task loss: 38.980 - val loss: 178.133, val task loss: 124.106
model_5: train loss: 34.494, train task loss: 26.180 - val loss: 118.994, val task loss: 99.029 [*] Best so far
model_6: train loss: 30.852, train task loss: 26.595 - val loss: 108.002, val task loss: 97.222 [*] Best so far

Epoch: 3/10 - LR: 0.001000


100%|██████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:13<00:00, 92.84it/s]


model_1: train loss: 47.799, train task loss: 30.556 - val loss: 125.174, val task loss: 102.691 [*] Best so far
model_2: train loss: 30.767, train task loss: 21.797 - val loss: 130.302, val task loss: 108.810
model_3: train loss: 17.689, train task loss: 14.387 - val loss: 110.857, val task loss: 99.232
model_4: train loss: 58.021, train task loss: 34.170 - val loss: 129.172, val task loss: 103.749 [*] Best so far
model_5: train loss: 29.504, train task loss: 22.190 - val loss: 129.596, val task loss: 108.110
model_6: train loss: 22.715, train task loss: 19.414 - val loss: 112.385, val task loss: 100.760

Epoch: 4/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:11<00:00, 108.68it/s]


model_1: train loss: 36.500, train task loss: 22.504 - val loss: 129.329, val task loss: 101.483 [*] Best so far
model_2: train loss: 40.444, train task loss: 24.569 - val loss: 129.736, val task loss: 100.771
model_3: train loss: 12.550, train task loss: 10.163 - val loss: 112.935, val task loss: 101.009
model_4: train loss: 42.561, train task loss: 25.334 - val loss: 129.606, val task loss: 98.479 [*] Best so far
model_5: train loss: 36.707, train task loss: 24.370 - val loss: 129.218, val task loss: 104.061
model_6: train loss: 15.721, train task loss: 13.335 - val loss: 108.330, val task loss: 96.404 [*] Best so far

Epoch: 5/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:11<00:00, 110.78it/s]


model_1: train loss: 29.607, train task loss: 16.209 - val loss: 129.278, val task loss: 99.055 [*] Best so far
model_2: train loss: 36.155, train task loss: 18.820 - val loss: 126.827, val task loss: 98.350 [*] Best so far
model_3: train loss: 9.108, train task loss: 7.311 - val loss: 107.129, val task loss: 94.688 [*] Best so far
model_4: train loss: 31.730, train task loss: 17.043 - val loss: 126.873, val task loss: 97.165 [*] Best so far
model_5: train loss: 36.051, train task loss: 20.326 - val loss: 124.819, val task loss: 94.919 [*] Best so far
model_6: train loss: 10.936, train task loss: 9.140 - val loss: 108.121, val task loss: 95.680 [*] Best so far

Epoch: 6/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:11<00:00, 112.38it/s]


model_1: train loss: 19.444, train task loss: 10.911 - val loss: 127.666, val task loss: 101.171
model_2: train loss: 22.136, train task loss: 11.927 - val loss: 127.572, val task loss: 100.859
model_3: train loss: 6.206, train task loss: 4.665 - val loss: 107.330, val task loss: 94.598 [*] Best so far
model_4: train loss: 20.079, train task loss: 11.349 - val loss: 129.851, val task loss: 100.695
model_5: train loss: 22.309, train task loss: 12.285 - val loss: 119.448, val task loss: 95.736
model_6: train loss: 7.553, train task loss: 6.012 - val loss: 103.777, val task loss: 91.046 [*] Best so far

Epoch: 7/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:10<00:00, 124.84it/s]


model_1: train loss: 17.141, train task loss: 9.097 - val loss: 132.036, val task loss: 103.303
model_2: train loss: 17.168, train task loss: 8.814 - val loss: 125.269, val task loss: 99.866
model_3: train loss: 5.100, train task loss: 3.679 - val loss: 105.474, val task loss: 92.614 [*] Best so far
model_4: train loss: 17.346, train task loss: 8.976 - val loss: 134.536, val task loss: 104.303
model_5: train loss: 17.648, train task loss: 9.662 - val loss: 122.225, val task loss: 96.006
model_6: train loss: 6.174, train task loss: 4.753 - val loss: 101.880, val task loss: 89.021 [*] Best so far

Epoch: 8/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:10<00:00, 124.14it/s]


model_1: train loss: 12.203, train task loss: 6.504 - val loss: 125.487, val task loss: 99.274
model_2: train loss: 11.204, train task loss: 5.965 - val loss: 122.379, val task loss: 97.836 [*] Best so far
model_3: train loss: 4.088, train task loss: 2.681 - val loss: 106.968, val task loss: 94.264
model_4: train loss: 12.835, train task loss: 7.051 - val loss: 125.019, val task loss: 97.161 [*] Best so far
model_5: train loss: 13.340, train task loss: 7.391 - val loss: 120.633, val task loss: 95.050
model_6: train loss: 4.486, train task loss: 3.079 - val loss: 101.223, val task loss: 88.520 [*] Best so far

Epoch: 9/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:10<00:00, 117.40it/s]


model_1: train loss: 11.108, train task loss: 5.547 - val loss: 123.974, val task loss: 97.836 [*] Best so far
model_2: train loss: 9.717, train task loss: 4.879 - val loss: 119.584, val task loss: 95.589 [*] Best so far
model_3: train loss: 3.441, train task loss: 2.143 - val loss: 106.628, val task loss: 93.415
model_4: train loss: 10.860, train task loss: 5.599 - val loss: 123.795, val task loss: 94.941 [*] Best so far
model_5: train loss: 10.917, train task loss: 5.864 - val loss: 118.815, val task loss: 93.269 [*] Best so far
model_6: train loss: 3.716, train task loss: 2.418 - val loss: 102.532, val task loss: 89.320

Epoch: 10/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:12<00:00, 103.63it/s]


model_1: train loss: 8.624, train task loss: 4.107 - val loss: 127.778, val task loss: 99.248
model_2: train loss: 8.024, train task loss: 3.810 - val loss: 118.056, val task loss: 94.082 [*] Best so far
model_3: train loss: 3.236, train task loss: 1.685 - val loss: 104.754, val task loss: 91.587 [*] Best so far
model_4: train loss: 7.900, train task loss: 3.904 - val loss: 123.877, val task loss: 95.720
model_5: train loss: 9.187, train task loss: 4.529 - val loss: 115.969, val task loss: 91.000 [*] Best so far
model_6: train loss: 3.488, train task loss: 1.937 - val loss: 100.965, val task loss: 87.797 [*] Best so far
Training with disagreement penalty = 3

Epoch: 1/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:11<00:00, 113.57it/s]


model_1: train loss: 465.772, train task loss: 454.360 - val loss: 264.138, val task loss: 241.149 [*] Best so far
Created directory: ./checkpoints/regression_quadratic_early/scale0_seed1/3
model_2: train loss: 445.815, train task loss: 442.062 - val loss: 281.356, val task loss: 266.070 [*] Best so far
model_3: train loss: 444.629, train task loss: 442.515 - val loss: 266.857, val task loss: 260.559 [*] Best so far
model_4: train loss: 439.843, train task loss: 436.720 - val loss: 269.151, val task loss: 256.496 [*] Best so far
model_5: train loss: 445.434, train task loss: 442.162 - val loss: 281.123, val task loss: 265.633 [*] Best so far
model_6: train loss: 433.521, train task loss: 431.408 - val loss: 255.300, val task loss: 249.002 [*] Best so far

Epoch: 2/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:07<00:00, 178.69it/s]


model_1: train loss: 267.201, train task loss: 243.100 - val loss: 238.196, val task loss: 198.514 [*] Best so far
model_2: train loss: 235.830, train task loss: 227.711 - val loss: 181.337, val task loss: 176.529 [*] Best so far
model_3: train loss: 234.532, train task loss: 230.874 - val loss: 177.329, val task loss: 174.778 [*] Best so far
model_4: train loss: 228.287, train task loss: 221.130 - val loss: 175.901, val task loss: 171.109 [*] Best so far
model_5: train loss: 236.892, train task loss: 229.160 - val loss: 179.625, val task loss: 175.283 [*] Best so far
model_6: train loss: 227.156, train task loss: 223.498 - val loss: 176.779, val task loss: 174.227 [*] Best so far

Epoch: 3/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 191.88it/s]


model_1: train loss: 177.027, train task loss: 159.756 - val loss: 162.601, val task loss: 152.677 [*] Best so far
model_2: train loss: 151.671, train task loss: 147.374 - val loss: 153.454, val task loss: 148.182 [*] Best so far
model_3: train loss: 152.938, train task loss: 150.840 - val loss: 149.989, val task loss: 147.032 [*] Best so far
model_4: train loss: 150.524, train task loss: 146.088 - val loss: 153.370, val task loss: 148.479 [*] Best so far
model_5: train loss: 154.492, train task loss: 150.312 - val loss: 153.294, val task loss: 148.259 [*] Best so far
model_6: train loss: 151.491, train task loss: 149.394 - val loss: 149.765, val task loss: 146.808 [*] Best so far

Epoch: 4/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 199.00it/s]


model_1: train loss: 136.683, train task loss: 129.302 - val loss: 151.955, val task loss: 141.629 [*] Best so far
model_2: train loss: 124.874, train task loss: 121.197 - val loss: 137.627, val task loss: 132.620 [*] Best so far
model_3: train loss: 125.840, train task loss: 123.571 - val loss: 133.872, val task loss: 131.104 [*] Best so far
model_4: train loss: 125.862, train task loss: 122.152 - val loss: 139.693, val task loss: 134.939 [*] Best so far
model_5: train loss: 127.682, train task loss: 123.846 - val loss: 139.095, val task loss: 134.073 [*] Best so far
model_6: train loss: 125.079, train task loss: 122.810 - val loss: 135.175, val task loss: 132.407 [*] Best so far

Epoch: 5/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 202.74it/s]


model_1: train loss: 132.710, train task loss: 119.960 - val loss: 140.260, val task loss: 133.430 [*] Best so far
model_2: train loss: 112.404, train task loss: 108.971 - val loss: 133.355, val task loss: 127.855 [*] Best so far
model_3: train loss: 112.985, train task loss: 110.848 - val loss: 129.252, val task loss: 126.379 [*] Best so far
model_4: train loss: 114.349, train task loss: 110.813 - val loss: 134.807, val task loss: 129.996 [*] Best so far
model_5: train loss: 115.297, train task loss: 111.613 - val loss: 133.875, val task loss: 128.768 [*] Best so far
model_6: train loss: 112.808, train task loss: 110.671 - val loss: 130.583, val task loss: 127.710 [*] Best so far

Epoch: 6/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 194.42it/s]


model_1: train loss: 134.812, train task loss: 115.435 - val loss: 174.701, val task loss: 137.595
model_2: train loss: 102.103, train task loss: 98.882 - val loss: 130.999, val task loss: 124.166 [*] Best so far
model_3: train loss: 102.818, train task loss: 100.981 - val loss: 125.388, val task loss: 121.791 [*] Best so far
model_4: train loss: 104.969, train task loss: 101.694 - val loss: 133.002, val task loss: 126.625 [*] Best so far
model_5: train loss: 105.538, train task loss: 102.069 - val loss: 131.718, val task loss: 125.144 [*] Best so far
model_6: train loss: 103.006, train task loss: 101.169 - val loss: 126.933, val task loss: 123.335 [*] Best so far

Epoch: 7/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 192.13it/s]


model_1: train loss: 139.021, train task loss: 110.740 - val loss: 168.705, val task loss: 139.311
model_2: train loss: 92.166, train task loss: 88.368 - val loss: 130.072, val task loss: 121.233 [*] Best so far
model_3: train loss: 91.890, train task loss: 89.626 - val loss: 121.886, val task loss: 117.025 [*] Best so far
model_4: train loss: 95.849, train task loss: 91.684 - val loss: 133.421, val task loss: 125.328 [*] Best so far
model_5: train loss: 95.374, train task loss: 91.490 - val loss: 129.458, val task loss: 121.311 [*] Best so far
model_6: train loss: 92.407, train task loss: 90.143 - val loss: 126.581, val task loss: 121.720 [*] Best so far

Epoch: 8/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 193.90it/s]


model_1: train loss: 132.874, train task loss: 105.726 - val loss: 153.310, val task loss: 125.709 [*] Best so far
model_2: train loss: 84.029, train task loss: 79.436 - val loss: 132.099, val task loss: 118.232 [*] Best so far
model_3: train loss: 82.154, train task loss: 79.211 - val loss: 118.165, val task loss: 110.650 [*] Best so far
model_4: train loss: 89.671, train task loss: 84.552 - val loss: 136.031, val task loss: 122.953 [*] Best so far
model_5: train loss: 86.518, train task loss: 81.730 - val loss: 128.313, val task loss: 114.865 [*] Best so far
model_6: train loss: 84.534, train task loss: 81.591 - val loss: 124.692, val task loss: 117.177 [*] Best so far

Epoch: 9/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 194.25it/s]


model_1: train loss: 119.059, train task loss: 94.548 - val loss: 159.267, val task loss: 127.442
model_2: train loss: 72.083, train task loss: 66.545 - val loss: 136.190, val task loss: 116.090 [*] Best so far
model_3: train loss: 68.746, train task loss: 65.043 - val loss: 119.871, val task loss: 109.289 [*] Best so far
model_4: train loss: 81.183, train task loss: 73.319 - val loss: 143.508, val task loss: 122.054 [*] Best so far
model_5: train loss: 72.993, train task loss: 67.230 - val loss: 132.727, val task loss: 113.688 [*] Best so far
model_6: train loss: 70.778, train task loss: 67.075 - val loss: 126.686, val task loss: 116.103 [*] Best so far

Epoch: 10/10 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:06<00:00, 188.75it/s]


model_1: train loss: 160.656, train task loss: 95.260 - val loss: 179.879, val task loss: 120.896 [*] Best so far
model_2: train loss: 63.132, train task loss: 54.365 - val loss: 135.285, val task loss: 105.327 [*] Best so far
model_3: train loss: 58.038, train task loss: 53.263 - val loss: 113.428, val task loss: 97.978 [*] Best so far
model_4: train loss: 73.644, train task loss: 60.864 - val loss: 140.970, val task loss: 109.802 [*] Best so far
model_5: train loss: 63.264, train task loss: 53.917 - val loss: 130.176, val task loss: 101.707 [*] Best so far
model_6: train loss: 56.146, train task loss: 51.371 - val loss: 116.536, val task loss: 101.086 [*] Best so far
Finished training student cohort!
Selecting the optimal disgreement penalty via cross-validation...
Best rho: 0.99 with average task loss: 89.6919
Done!
Selecting greedy ensemble on the best cohort...
Pruned 2 worst models, keeping 4 models
Initial best models: [5, 4, 2] with losses: [tensor(87.7970, grad_fn=<MseLossBack

In [92]:
# Convert tensor values in the 'Test_metric' column to plain Python numbers
def convert_to_float(value):
    """Convert torch tensors (alone or inside a list/tuple) to Python numbers.

    Parameters
    ----------
    value : Any
        A cell value from ``results['Test_metric']``. Typically a scalar
        ``torch.Tensor`` (possibly still carrying a ``grad_fn``), a list of
        such tensors, or an already-converted number.

    Returns
    -------
    Any
        * single-element tensor -> Python scalar via ``.item()``
        * multi-element or empty tensor -> (nested) Python list via
          ``.tolist()`` (``.item()`` would raise for these)
        * list -> new list with every element converted recursively
        * tuple -> new tuple with every element converted recursively
        * anything else -> returned unchanged
    """
    if isinstance(value, torch.Tensor):
        # detach() drops any autograd history; cpu() lets device tensors convert too.
        plain = value.detach().cpu()
        if plain.numel() == 1:
            return plain.item()
        return plain.tolist()
    if isinstance(value, list):
        return [convert_to_float(v) for v in value]
    if isinstance(value, tuple):
        return tuple(convert_to_float(v) for v in value)
    return value

In [93]:
import matplotlib.pyplot as plt

# Replace each Test_metric entry (tensor or list of tensors) with plain
# Python numbers, element by element, before the aggregation cells below.
results['Test_metric'] = results['Test_metric'].map(convert_to_float)

In [95]:
# Keep only the methods listed here; edit the list to change what is compared.
# NOTE(review): 'meta_learner' does not appear in the grouped statistics
# printed below — confirm that this run actually produces it.
methods_to_include = [
    'early_fusion', 'late_fusion',
    'oracle_early_fusion', 'oracle_late_fusion',
    'simple_average', 'weighted_average',
    'meta_learner', 'greedy_ensemble',
    'best_single', 'indep_best_single',
]
df_filtered = results.loc[results['Method'].isin(methods_to_include)]

In [96]:
# Per-method mean and standard deviation of the 'Test_metric' column.
# pandas' std uses ddof=1, so a method with only one row (e.g. a single
# random_state) gets NaN — which is what the output below shows.
grouped_stats = (
    df_filtered
    .groupby('Method')['Test_metric']
    .agg(['mean', 'std'])
)

print(grouped_stats)

                           mean  std
Method                              
best_single           91.989861  NaN
early_fusion         110.878769  NaN
greedy_ensemble       85.793556  NaN
indep_best_single    105.578903  NaN
late_fusion          202.205124  NaN
oracle_early_fusion   93.921539  NaN
oracle_late_fusion   186.947754  NaN
simple_average         86.19706  NaN
weighted_average      86.105324  NaN


In [94]:
# Display the full results table: one row per Method, with Test_metric,
# best_rho, cohort_pairs, random_state and the data-size columns
# (dim_modalities, n, n_train, n_val, n_test, scale).
results

Unnamed: 0,Method,Test_metric,best_rho,cohort_pairs,random_state,dim_modalities,n,n_train,n_val,n_test,scale
0,modality_1,227.999466,,,1,"[300, 200, 100, 50]",2000,1280,320,400,0
1,modality_2,214.094528,,,1,"[300, 200, 100, 50]",2000,1280,320,400,0
2,modality_3,281.724792,,,1,"[300, 200, 100, 50]",2000,1280,320,400,0
3,modality_4,271.296265,,,1,"[300, 200, 100, 50]",2000,1280,320,400,0
4,early_fusion,110.878769,,,1,"[300, 200, 100, 50]",2000,1280,320,400,0
5,late_fusion,202.205124,,,1,"[300, 200, 100, 50]",2000,1280,320,400,0
6,oracle_modality_1,198.725403,,,1,"[300, 200, 100, 50]",2000,1280,320,400,0
7,oracle_modality_2,210.792252,,,1,"[300, 200, 100, 50]",2000,1280,320,400,0
8,oracle_modality_3,280.873779,,,1,"[300, 200, 100, 50]",2000,1280,320,400,0
9,oracle_modality_4,265.844177,,,1,"[300, 200, 100, 50]",2000,1280,320,400,0
