In [52]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import os
import random
import shutil
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

sys.path.append('../')
from meta_fusion.benchmarks import *
from meta_fusion.methods import *
from meta_fusion.models import *
from meta_fusion.utils import *
from meta_fusion.third_party import *
from meta_fusion.synthetic_data import Prepare_synthetic_data
from meta_fusion.config import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [53]:
# ---------------------------------------------------------------------------
# Experiment bookkeeping
# ---------------------------------------------------------------------------
repetition = 1   # number of repetitions run by the driver loop
seed = 1         # combined with `repetition` to derive each run's random_state
scale = 0        # written to config['divergence_weight_scale'] and used in ckpt_dir

# ---------------------------------------------------------------------------
# Synthetic data model (forwarded to data_preparer.get_data_loaders)
# ---------------------------------------------------------------------------
n = 2000
dim_modalities = [500, 100]
d1, d2 = dim_modalities
dim_latent = [20, 20, 0]
noise_ratios = [0.5, 0.1]
trans_type = ["quadratic", "linear", "linear"]
mod_prop = [1, 1, 0, 1]
interactive_prop = 0.1

# ---------------------------------------------------------------------------
# Cohort architecture
# ---------------------------------------------------------------------------
# Wider search grid, currently disabled:
# mod1_outs = [40, 60, 80, 100]
# mod2_outs = [40, 60, 80, 100]
mod1_outs = [60]
mod2_outs = [40, 60]
combined_hiddens = [300, 200, 100]
mod1_hiddens = [128]
mod2_hiddens = mod1_hiddens  # intentionally the same list object for both modalities

# ---------------------------------------------------------------------------
# Task / naming
# ---------------------------------------------------------------------------
data_name = 'regression'
exp_name = "_".join([data_name, "quadratic_early"])
output_dim = 1  # regression target dimension

# ---------------------------------------------------------------------------
# Extractor options
# ---------------------------------------------------------------------------
extractor_type = 'PCA'
# NOTE: spelling matches the `seperate=` keyword it is forwarded to; do not rename.
seperate = False
is_static_mod1 = is_static_mod2 = True
freeze_extractor_mod1 = freeze_extractor_mod2 = False

# ---------------------------------------------------------------------------
# Default model configurations, then per-experiment overrides
# ---------------------------------------------------------------------------
config = load_config('../experiments_synthetic/config.json')
extractor_config = load_config('../experiments_synthetic/config_extractor.json')

# Checkpoints for this (experiment, scale, seed) combination live here.
ckpt_dir = f"./checkpoints/{exp_name}/scale{scale}_seed{seed}/"
config['ckpt_dir'] = ckpt_dir
extractor_config['ckpt_dir'] = ckpt_dir

config['divergence_weight_scale'] = scale
config['output_dim'] = output_dim
extractor_config['output_dim'] = output_dim
config["init_lr"] = 0.001
config["epochs"] = 30
extractor_config["init_lr"] = 0.001

# ---------------------------------------------------------------------------
# Dataset factory (splits are drawn later, once per repetition)
# ---------------------------------------------------------------------------
data_preparer = Prepare_synthetic_data(data_name=data_name, test_size=0.2, val_size=0.2)

In [54]:
def run_single_experiment(config, extractor_config, n,
                          random_state,
                          run_oracle=False, run_coop=True, run_all_at_once=False):
    """Run one repetition of the synthetic experiment and tabulate the test metrics.

    The function draws a fresh train/val/test split, builds the benchmark,
    (optional) oracle and meta-fusion cohorts, trains the requested methods and
    returns one row per reported method.

    Besides its arguments it reads these notebook-level globals:
    ``data_preparer``, ``data_name``, ``trans_type``, ``mod_prop``,
    ``interactive_prop``, ``dim_modalities``, ``dim_latent``, ``noise_ratios``,
    ``combined_hiddens``, ``output_dim``, ``mod1_outs``, ``mod2_outs``,
    ``mod1_hiddens``, ``mod2_hiddens``, ``extractor_type``, ``seperate``,
    ``is_static_mod1/2``, ``freeze_extractor_mod1/2`` and ``scale``.

    Parameters
    ----------
    config : mapping
        Training configuration handed to ``Benchmarks``, ``Coop``, ``Trainer``
        and ``Trainer_all_at_once``. A shallow copy is taken, so the caller's
        object is not modified.
    extractor_config : mapping
        Configuration handed to ``get_encoder_extractors``. Also shallow-copied.
    n : int
        Sample size requested from ``data_preparer.get_data_loaders``. The value
        reported in the output is the one returned by ``get_data_info()``.
    random_state : int
        Seed stored in both configs and passed to the data generator.
    run_oracle : bool, default False
        Also train/test ``Benchmarks`` on the oracle loaders; its result keys are
        prefixed with ``"oracle_"``.
    run_coop : bool, default True
        Also train/test ``Coop`` on the benchmark cohort.
    run_all_at_once : bool, default False
        Also train/test ``Trainer_all_at_once`` on the meta-fusion cohort.

    Returns
    -------
    pandas.DataFrame
        Columns ``Method``, ``Test_metric``, ``best_rho``, ``cohort_pairs`` plus
        the constant columns ``random_state``, ``d1``, ``d2``, ``n``,
        ``n_train``, ``n_val``, ``n_test`` and ``scale``.

    Raises
    ------
    ValueError
        If the global ``extractor_type`` is neither ``'encoder'`` nor ``'PCA'``.
    """
    # Imported locally so the name is guaranteed to be the stdlib module even if
    # one of the notebook's wildcard imports rebinds `copy`.
    import copy

    # Fail fast instead of silently building a cohort without any extractors.
    if extractor_type not in ('encoder', 'PCA'):
        raise ValueError(
            f"Unsupported extractor_type {extractor_type!r}; expected 'encoder' or 'PCA'."
        )

    # Work on shallow copies so per-run overrides (random_state, init_lr, epochs)
    # never leak back into the caller's configuration objects.
    config = copy.copy(config)
    extractor_config = copy.copy(extractor_config)
    config['random_state'] = random_state
    extractor_config['random_state'] = random_state

    res_list = []
    best_rho = {}
    cohort_pairs = {}

    #----------------#
    # Split dataset  #
    #----------------#
    train_loader, val_loader, test_loader, oracle_train_loader, oracle_val_loader, oracle_test_loader =\
    data_preparer.get_data_loaders(n, trans_type=trans_type, mod_prop=mod_prop,
                                   interactive_prop = interactive_prop,
                                   dim_modalities=dim_modalities, dim_latent=dim_latent,
                                   noise_ratios=noise_ratios, random_state=random_state)
    # Get data info: (d1, d2, n, n_train, n_val, n_test) by position.
    data_info = data_preparer.get_data_info()
    d1, d2, n, n_train, n_val, n_test = (data_info[k] for k in range(6))

    print(f"Finished splitting {data_name} dataset. Data information are summarized below:\n"
            f"Modality 1 dimension: {d1}\n"
            f"Modality 2 dimension: {d2}\n"
            f"Data size: {n}\n"
            f"Train size: {n_train}\n"
            f"Val size: {n_val}\n"
            f"Test size: {n_test}")
    sys.stdout.flush()

    #------------------#
    # Benchmark models #
    #------------------#
    # Built unconditionally (even when run_coop is False) to keep the order of
    # model construction, and hence any random initialisation, unchanged.
    bm_extractor = Extractors([0, d1], [0,d2], d1, d2, train_loader, val_loader)
    _ = bm_extractor.get_dummy_extractors()
    bm_cohort = Cohorts(extractors=bm_extractor, combined_hidden_layers=combined_hiddens, output_dim=output_dim)

    if run_oracle:
        oracle_d1 = dim_latent[0]
        oracle_d2 = dim_latent[1]+dim_latent[2]
        oracle_extractor = Extractors([0, oracle_d1], [0, oracle_d2], oracle_d1, oracle_d2, oracle_train_loader, oracle_val_loader)
        _ = oracle_extractor.get_dummy_extractors()
        oracle_cohort = Cohorts(extractors=oracle_extractor, combined_hidden_layers=combined_hiddens, output_dim=output_dim)

    #----------------------------#
    # Proposed model: MetaJoint  #
    #----------------------------#
    meta_extractor = Extractors(mod1_outs, mod2_outs, d1, d2, train_loader, val_loader)
    if extractor_type == 'encoder':
        _ = meta_extractor.get_encoder_extractors(mod1_hiddens, mod2_hiddens, seperate=seperate, config=extractor_config)
    else:  # 'PCA' (validated above)
        _ = meta_extractor.get_PCA_extractors()
    meta_cohort = Cohorts(extractors=meta_extractor, combined_hidden_layers=combined_hiddens, output_dim=output_dim,
                          is_static_mod1=is_static_mod1, is_static_mod2=is_static_mod2,
                          freeze_extractor_mod1=freeze_extractor_mod1, freeze_extractor_mod2=freeze_extractor_mod2)

    #------------------------------#
    #  Train and test benchmarks   #
    #------------------------------#
    # Basic benchmarks are currently disabled; kept for reference.
    # bm_models = bm_cohort.get_cohort_models()
    # _, bm_dims = bm_cohort.get_cohort_info()
    # bm = Benchmarks(config, bm_models, bm_dims, [train_loader, val_loader])
    # bm.train()
    # res = bm.test(test_loader)
    # res_list.append(res)
    # print(f"Finished running basic benchmarks!")

    if run_oracle:
        # Independent copy: the oracle's learning-rate override must not affect
        # the configuration used by the methods trained afterwards.
        oracle_config = copy.copy(config)
        oracle_config["init_lr"] = 0.001
        oracle_models = oracle_cohort.get_cohort_models()
        _, oracle_dims = oracle_cohort.get_cohort_info()
        oracle = Benchmarks(oracle_config, oracle_models, oracle_dims, [oracle_train_loader, oracle_val_loader])
        oracle.train()
        res = oracle.test(oracle_test_loader)
        res = {f"oracle_{key}": value for key, value in res.items()}
        res_list.append(res)
        print("Finished running oracle benchmarks!")

    if run_coop:
        bm_models = bm_cohort.get_cohort_models()
        _, bm_dims = bm_cohort.get_cohort_info()
        coop = Coop(config, bm_models, bm_dims, [train_loader, val_loader])
        coop.train()
        res = coop.test(test_loader)
        res_list.append(res)
        best_rho['coop'] = coop.best_rho
        print("Finished running coop!")

    #------------------------------#
    #  Train and test MetaJoint   #
    #------------------------------#
    cohort_models = meta_cohort.get_cohort_models()
    _, dim_pairs = meta_cohort.get_cohort_info()
    metafuse = Trainer(config, cohort_models, [train_loader, val_loader])
    metafuse.train()
    res = metafuse.test(test_loader)
    res_list.append(res)

    metafuse.train_ablation()
    res = metafuse.test_ablation(test_loader)
    res_list.append(res)

    best_rho['meta_learner'] = metafuse.best_rho
    cohort_pairs['cohort'] = dim_pairs
    print("Finished running meta fusion!")

    if run_all_at_once:
        cohort_models = meta_cohort.get_cohort_models()
        config['epochs']=30  # only touches this run's local copy
        allin1 = Trainer_all_at_once(config, cohort_models, [train_loader, val_loader])
        allin1.train()
        res = allin1.test(test_loader)
        res_list.append(res)
        best_rho['all_at_once'] = allin1.best_rho
        print("Finished running meta fusion all in one model!")

    # One row per (method, metric) reported by any trained approach; best_rho /
    # cohort_pairs are filled only where a key with the same method name exists.
    records = [
        {'Method': method, 'Test_metric': val,
         'best_rho': best_rho.get(method), 'cohort_pairs': cohort_pairs.get(method)}
        for res in res_list
        for method, val in res.items()
    ]
    results = pd.DataFrame(records)

    results['random_state']=random_state
    results["d1"] = d1
    results["d2"] = d2
    results['n'] = n
    results['n_train'] = n_train
    results['n_val'] = n_val
    results['n_test'] = n_test
    results['scale'] = scale

    return results

In [55]:
# Collect one results frame per repetition and stack them once at the end,
# so `results` is only ever bound to the final DataFrame.
result_frames = []

#for i in tqdm(range(1, repetition+1), desc="Repetitions", leave=True, position=0):
for i in range(1, repetition+1):
    print(f'Running with repetition {i}...')
    # Distinct seed per (seed, repetition index) pair.
    random_state = repetition * (seed-1) + i
    set_random_seed(random_state)

    # Run experiment
    tmp = run_single_experiment(config, extractor_config, n, run_oracle=False, run_coop=False, random_state=random_state)

    result_frames.append(tmp)

# pd.concat raises on an empty list, so fall back to an empty frame when
# `repetition` is set below 1.
results = pd.concat(result_frames, ignore_index=True) if result_frames else pd.DataFrame()


#####################
#    Save Results   #
#####################
# Saving is disabled; `outfile` must be defined before re-enabling these lines.
# results.to_csv(outfile, index=False)
# print("\nResults written to {:s}\n".format(outfile))
# sys.stdout.flush()

# After the job is done, remove the model directory to free up space.
# isdir (not exists) because shutil.rmtree only accepts directories.
if os.path.isdir(ckpt_dir):
    print(f"Deleting the model checkpoint directory: {ckpt_dir}")
    shutil.rmtree(ckpt_dir)
    print(f"Model checkpoint directory {ckpt_dir} has been deleted.")

Running with repetition 1...
Finished splitting regression dataset. Data information are summarized below:
Modality 1 dimension: 500
Modality 2 dimension: 100
Data size: 2000
Train size: 1280
Val size: 320
Test size: 400
Start training student cohort...
Training with disagreement penalty = 0

Epoch: 1/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 600.94it/s]


model_1: train loss: 408.705, train task loss: 408.705 - val loss: 279.459, val task loss: 279.459 [*] Best so far
model_2: train loss: 440.947, train task loss: 440.947 - val loss: 262.540, val task loss: 262.540 [*] Best so far

Epoch: 2/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 583.37it/s]


model_1: train loss: 218.667, train task loss: 218.667 - val loss: 194.831, val task loss: 194.831 [*] Best so far
model_2: train loss: 232.896, train task loss: 232.896 - val loss: 205.024, val task loss: 205.024 [*] Best so far

Epoch: 3/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 505.18it/s]


model_1: train loss: 166.180, train task loss: 166.180 - val loss: 174.376, val task loss: 174.376 [*] Best so far
model_2: train loss: 176.177, train task loss: 176.177 - val loss: 182.301, val task loss: 182.301 [*] Best so far

Epoch: 4/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 495.44it/s]


model_1: train loss: 144.897, train task loss: 144.897 - val loss: 172.473, val task loss: 172.473 [*] Best so far
model_2: train loss: 152.241, train task loss: 152.241 - val loss: 177.110, val task loss: 177.110 [*] Best so far

Epoch: 5/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 506.78it/s]


model_1: train loss: 132.107, train task loss: 132.107 - val loss: 161.203, val task loss: 161.203 [*] Best so far
model_2: train loss: 138.863, train task loss: 138.863 - val loss: 163.188, val task loss: 163.188 [*] Best so far

Epoch: 6/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 504.57it/s]


model_1: train loss: 118.716, train task loss: 118.716 - val loss: 154.478, val task loss: 154.478 [*] Best so far
model_2: train loss: 125.767, train task loss: 125.767 - val loss: 157.276, val task loss: 157.276 [*] Best so far

Epoch: 7/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 450.93it/s]


model_1: train loss: 104.479, train task loss: 104.479 - val loss: 148.378, val task loss: 148.378 [*] Best so far
model_2: train loss: 113.206, train task loss: 113.206 - val loss: 151.336, val task loss: 151.336 [*] Best so far

Epoch: 8/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 461.36it/s]


model_1: train loss: 90.017, train task loss: 90.017 - val loss: 138.058, val task loss: 138.058 [*] Best so far
model_2: train loss: 99.632, train task loss: 99.632 - val loss: 139.174, val task loss: 139.174 [*] Best so far

Epoch: 9/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 468.87it/s]


model_1: train loss: 68.111, train task loss: 68.111 - val loss: 130.390, val task loss: 130.390 [*] Best so far
model_2: train loss: 79.440, train task loss: 79.440 - val loss: 132.483, val task loss: 132.483 [*] Best so far

Epoch: 10/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 489.10it/s]


model_1: train loss: 45.512, train task loss: 45.512 - val loss: 122.762, val task loss: 122.762 [*] Best so far
model_2: train loss: 56.370, train task loss: 56.370 - val loss: 126.820, val task loss: 126.820 [*] Best so far

Epoch: 11/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 459.83it/s]


model_1: train loss: 31.065, train task loss: 31.065 - val loss: 109.724, val task loss: 109.724 [*] Best so far
model_2: train loss: 36.460, train task loss: 36.460 - val loss: 108.902, val task loss: 108.902 [*] Best so far

Epoch: 12/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 457.04it/s]


model_1: train loss: 20.935, train task loss: 20.935 - val loss: 105.221, val task loss: 105.221 [*] Best so far
model_2: train loss: 24.844, train task loss: 24.844 - val loss: 98.588, val task loss: 98.588 [*] Best so far

Epoch: 13/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 484.59it/s]


model_1: train loss: 13.753, train task loss: 13.753 - val loss: 96.523, val task loss: 96.523 [*] Best so far
model_2: train loss: 15.599, train task loss: 15.599 - val loss: 90.256, val task loss: 90.256 [*] Best so far

Epoch: 14/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 475.45it/s]


model_1: train loss: 8.610, train task loss: 8.610 - val loss: 93.181, val task loss: 93.181 [*] Best so far
model_2: train loss: 10.426, train task loss: 10.426 - val loss: 83.801, val task loss: 83.801 [*] Best so far

Epoch: 15/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 451.12it/s]


model_1: train loss: 6.198, train task loss: 6.198 - val loss: 88.932, val task loss: 88.932 [*] Best so far
model_2: train loss: 7.705, train task loss: 7.705 - val loss: 86.961, val task loss: 86.961

Epoch: 16/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 484.26it/s]


model_1: train loss: 3.710, train task loss: 3.710 - val loss: 87.048, val task loss: 87.048 [*] Best so far
model_2: train loss: 4.994, train task loss: 4.994 - val loss: 80.288, val task loss: 80.288 [*] Best so far

Epoch: 17/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 463.81it/s]


model_1: train loss: 2.323, train task loss: 2.323 - val loss: 87.732, val task loss: 87.732
model_2: train loss: 3.526, train task loss: 3.526 - val loss: 79.985, val task loss: 79.985 [*] Best so far

Epoch: 18/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 457.12it/s]


model_1: train loss: 1.703, train task loss: 1.703 - val loss: 86.398, val task loss: 86.398 [*] Best so far
model_2: train loss: 2.912, train task loss: 2.912 - val loss: 80.521, val task loss: 80.521

Epoch: 19/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 456.20it/s]


model_1: train loss: 1.217, train task loss: 1.217 - val loss: 84.460, val task loss: 84.460 [*] Best so far
model_2: train loss: 1.989, train task loss: 1.989 - val loss: 78.373, val task loss: 78.373 [*] Best so far

Epoch: 20/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 451.81it/s]


model_1: train loss: 0.850, train task loss: 0.850 - val loss: 83.209, val task loss: 83.209 [*] Best so far
model_2: train loss: 1.429, train task loss: 1.429 - val loss: 77.490, val task loss: 77.490 [*] Best so far

Epoch: 21/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 478.28it/s]


model_1: train loss: 0.647, train task loss: 0.647 - val loss: 84.297, val task loss: 84.297
model_2: train loss: 1.224, train task loss: 1.224 - val loss: 76.425, val task loss: 76.425 [*] Best so far

Epoch: 22/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 445.32it/s]


model_1: train loss: 0.496, train task loss: 0.496 - val loss: 86.025, val task loss: 86.025
model_2: train loss: 0.751, train task loss: 0.751 - val loss: 77.237, val task loss: 77.237

Epoch: 23/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 467.85it/s]


model_1: train loss: 0.407, train task loss: 0.407 - val loss: 84.755, val task loss: 84.755
model_2: train loss: 0.434, train task loss: 0.434 - val loss: 77.527, val task loss: 77.527

Epoch: 24/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 461.70it/s]


model_1: train loss: 0.235, train task loss: 0.235 - val loss: 84.899, val task loss: 84.899
model_2: train loss: 0.359, train task loss: 0.359 - val loss: 77.889, val task loss: 77.889

Epoch: 25/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 462.12it/s]


model_1: train loss: 0.178, train task loss: 0.178 - val loss: 84.156, val task loss: 84.156
model_2: train loss: 0.295, train task loss: 0.295 - val loss: 76.542, val task loss: 76.542

Epoch: 26/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 474.28it/s]


model_1: train loss: 0.138, train task loss: 0.138 - val loss: 84.428, val task loss: 84.428
model_2: train loss: 0.213, train task loss: 0.213 - val loss: 76.794, val task loss: 76.794

Epoch: 27/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 470.95it/s]


model_1: train loss: 0.086, train task loss: 0.086 - val loss: 84.491, val task loss: 84.491
model_2: train loss: 0.179, train task loss: 0.179 - val loss: 76.633, val task loss: 76.633

Epoch: 28/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 500.72it/s]


model_1: train loss: 0.058, train task loss: 0.058 - val loss: 84.252, val task loss: 84.252
model_2: train loss: 0.147, train task loss: 0.147 - val loss: 75.803, val task loss: 75.803 [*] Best so far

Epoch: 29/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 737.08it/s]


model_1: train loss: 0.047, train task loss: 0.047 - val loss: 84.372, val task loss: 84.372
model_2: train loss: 0.109, train task loss: 0.109 - val loss: 75.870, val task loss: 75.870

Epoch: 30/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 860.91it/s]


model_1: train loss: 0.037, train task loss: 0.037 - val loss: 84.240, val task loss: 84.240
model_2: train loss: 0.077, train task loss: 0.077 - val loss: 75.595, val task loss: 75.595 [*] Best so far
Training with disagreement penalty = 0.99

Epoch: 1/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 743.64it/s]


model_1: train loss: 411.694, train task loss: 408.857 - val loss: 269.849, val task loss: 261.042 [*] Best so far
model_2: train loss: 442.252, train task loss: 439.415 - val loss: 274.474, val task loss: 265.667 [*] Best so far

Epoch: 2/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 797.03it/s]


model_1: train loss: 215.263, train task loss: 212.236 - val loss: 195.959, val task loss: 193.211 [*] Best so far
model_2: train loss: 239.780, train task loss: 236.700 - val loss: 209.490, val task loss: 206.742 [*] Best so far

Epoch: 3/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 763.24it/s]


model_1: train loss: 165.417, train task loss: 164.146 - val loss: 177.087, val task loss: 174.618 [*] Best so far
model_2: train loss: 178.389, train task loss: 177.029 - val loss: 185.021, val task loss: 182.552 [*] Best so far

Epoch: 4/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 731.80it/s]


model_1: train loss: 143.319, train task loss: 142.439 - val loss: 165.357, val task loss: 163.222 [*] Best so far
model_2: train loss: 153.746, train task loss: 152.827 - val loss: 171.956, val task loss: 169.821 [*] Best so far

Epoch: 5/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 765.42it/s]


model_1: train loss: 128.544, train task loss: 127.669 - val loss: 166.557, val task loss: 164.112
model_2: train loss: 140.180, train task loss: 139.270 - val loss: 170.321, val task loss: 167.876 [*] Best so far

Epoch: 6/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 731.59it/s]


model_1: train loss: 115.343, train task loss: 114.220 - val loss: 156.112, val task loss: 153.014 [*] Best so far
model_2: train loss: 130.096, train task loss: 128.947 - val loss: 164.265, val task loss: 161.166 [*] Best so far

Epoch: 7/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 753.14it/s]


model_1: train loss: 98.953, train task loss: 97.433 - val loss: 156.160, val task loss: 151.992 [*] Best so far
model_2: train loss: 117.619, train task loss: 116.019 - val loss: 165.019, val task loss: 160.851 [*] Best so far

Epoch: 8/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 726.34it/s]


model_1: train loss: 81.510, train task loss: 78.668 - val loss: 143.389, val task loss: 136.876 [*] Best so far
model_2: train loss: 107.063, train task loss: 104.055 - val loss: 153.404, val task loss: 146.891 [*] Best so far

Epoch: 9/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 710.44it/s]


model_1: train loss: 61.159, train task loss: 56.022 - val loss: 137.839, val task loss: 122.266 [*] Best so far
model_2: train loss: 92.221, train task loss: 86.709 - val loss: 158.157, val task loss: 142.584 [*] Best so far

Epoch: 10/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 737.37it/s]


model_1: train loss: 42.973, train task loss: 36.427 - val loss: 132.293, val task loss: 117.657 [*] Best so far
model_2: train loss: 71.712, train task loss: 64.079 - val loss: 138.932, val task loss: 124.296 [*] Best so far

Epoch: 11/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 685.16it/s]


model_1: train loss: 28.234, train task loss: 23.022 - val loss: 123.885, val task loss: 103.698 [*] Best so far
model_2: train loss: 46.338, train task loss: 40.831 - val loss: 135.292, val task loss: 115.104 [*] Best so far

Epoch: 12/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 490.50it/s]


model_1: train loss: 20.270, train task loss: 16.271 - val loss: 124.887, val task loss: 99.996 [*] Best so far
model_2: train loss: 31.575, train task loss: 27.136 - val loss: 130.003, val task loss: 105.113 [*] Best so far

Epoch: 13/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 501.79it/s]


model_1: train loss: 15.949, train task loss: 12.372 - val loss: 112.714, val task loss: 93.619 [*] Best so far
model_2: train loss: 22.602, train task loss: 18.842 - val loss: 111.171, val task loss: 92.076 [*] Best so far

Epoch: 14/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 674.31it/s]


model_1: train loss: 11.389, train task loss: 8.556 - val loss: 113.007, val task loss: 92.476 [*] Best so far
model_2: train loss: 15.449, train task loss: 12.663 - val loss: 112.775, val task loss: 92.243

Epoch: 15/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 713.53it/s]


model_1: train loss: 8.797, train task loss: 6.532 - val loss: 109.643, val task loss: 88.134 [*] Best so far
model_2: train loss: 11.831, train task loss: 9.572 - val loss: 111.872, val task loss: 90.363 [*] Best so far

Epoch: 16/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 748.11it/s]


model_1: train loss: 6.844, train task loss: 4.747 - val loss: 111.700, val task loss: 90.426
model_2: train loss: 9.123, train task loss: 6.973 - val loss: 106.194, val task loss: 84.920 [*] Best so far

Epoch: 17/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 723.64it/s]


model_1: train loss: 5.940, train task loss: 3.789 - val loss: 108.944, val task loss: 86.416 [*] Best so far
model_2: train loss: 6.932, train task loss: 4.911 - val loss: 106.146, val task loss: 83.619 [*] Best so far

Epoch: 18/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 796.44it/s]


model_1: train loss: 5.138, train task loss: 2.949 - val loss: 109.314, val task loss: 87.888
model_2: train loss: 5.996, train task loss: 3.878 - val loss: 102.987, val task loss: 81.561 [*] Best so far

Epoch: 19/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 756.16it/s]


model_1: train loss: 4.509, train task loss: 2.401 - val loss: 109.593, val task loss: 88.323
model_2: train loss: 4.584, train task loss: 2.628 - val loss: 101.409, val task loss: 80.139 [*] Best so far

Epoch: 20/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 789.90it/s]


model_1: train loss: 3.978, train task loss: 2.119 - val loss: 108.656, val task loss: 87.782
model_2: train loss: 3.631, train task loss: 1.944 - val loss: 98.680, val task loss: 77.806 [*] Best so far

Epoch: 21/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 782.90it/s]


model_1: train loss: 3.338, train task loss: 1.729 - val loss: 107.290, val task loss: 85.452 [*] Best so far
model_2: train loss: 2.838, train task loss: 1.411 - val loss: 102.950, val task loss: 81.112

Epoch: 22/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 766.50it/s]


model_1: train loss: 3.627, train task loss: 1.754 - val loss: 111.343, val task loss: 88.533
model_2: train loss: 3.096, train task loss: 1.318 - val loss: 101.393, val task loss: 78.583

Epoch: 23/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 773.59it/s]


model_1: train loss: 3.491, train task loss: 1.559 - val loss: 106.895, val task loss: 83.657 [*] Best so far
model_2: train loss: 2.905, train task loss: 1.190 - val loss: 104.894, val task loss: 81.656

Epoch: 24/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 745.08it/s]


model_1: train loss: 3.282, train task loss: 1.446 - val loss: 112.302, val task loss: 87.867
model_2: train loss: 2.916, train task loss: 1.125 - val loss: 101.278, val task loss: 76.843 [*] Best so far

Epoch: 25/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 797.24it/s]


model_1: train loss: 3.641, train task loss: 1.569 - val loss: 106.854, val task loss: 84.535
model_2: train loss: 2.879, train task loss: 1.067 - val loss: 102.908, val task loss: 80.589

Epoch: 26/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 845.28it/s]


model_1: train loss: 2.851, train task loss: 1.249 - val loss: 105.206, val task loss: 83.019 [*] Best so far
model_2: train loss: 2.367, train task loss: 0.839 - val loss: 103.892, val task loss: 81.705

Epoch: 27/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 752.75it/s]


model_1: train loss: 2.248, train task loss: 0.951 - val loss: 105.482, val task loss: 84.907
model_2: train loss: 1.996, train task loss: 0.719 - val loss: 99.008, val task loss: 78.433

Epoch: 28/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 829.97it/s]


model_1: train loss: 2.305, train task loss: 0.910 - val loss: 104.792, val task loss: 84.351
model_2: train loss: 2.011, train task loss: 0.723 - val loss: 98.682, val task loss: 78.240

Epoch: 29/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 791.34it/s]


model_1: train loss: 2.697, train task loss: 1.058 - val loss: 107.698, val task loss: 81.409 [*] Best so far
model_2: train loss: 2.335, train task loss: 0.815 - val loss: 109.487, val task loss: 83.199

Epoch: 30/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 728.31it/s]


model_1: train loss: 4.102, train task loss: 1.605 - val loss: 111.416, val task loss: 86.422
model_2: train loss: 3.790, train task loss: 1.238 - val loss: 105.449, val task loss: 80.455
Training with disagreement penalty = 3

Epoch: 1/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 694.47it/s]


model_1: train loss: 402.841, train task loss: 397.166 - val loss: 270.244, val task loss: 255.632 [*] Best so far
model_2: train loss: 432.583, train task loss: 426.908 - val loss: 283.282, val task loss: 268.670 [*] Best so far

Epoch: 2/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 853.62it/s]


model_1: train loss: 221.483, train task loss: 215.250 - val loss: 201.958, val task loss: 194.727 [*] Best so far
model_2: train loss: 241.732, train task loss: 235.181 - val loss: 213.329, val task loss: 206.098 [*] Best so far

Epoch: 3/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 691.67it/s]


model_1: train loss: 167.791, train task loss: 165.058 - val loss: 181.603, val task loss: 174.642 [*] Best so far
model_2: train loss: 180.275, train task loss: 177.383 - val loss: 191.705, val task loss: 184.744 [*] Best so far

Epoch: 4/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 737.10it/s]


model_1: train loss: 145.148, train task loss: 142.603 - val loss: 167.718, val task loss: 162.958 [*] Best so far
model_2: train loss: 156.893, train task loss: 154.201 - val loss: 173.678, val task loss: 168.919 [*] Best so far

Epoch: 5/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 772.55it/s]


model_1: train loss: 129.841, train task loss: 127.407 - val loss: 162.318, val task loss: 156.549 [*] Best so far
model_2: train loss: 144.928, train task loss: 142.405 - val loss: 169.228, val task loss: 163.459 [*] Best so far

Epoch: 6/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 672.22it/s]


model_1: train loss: 114.778, train task loss: 111.764 - val loss: 161.088, val task loss: 151.311 [*] Best so far
model_2: train loss: 133.643, train task loss: 130.495 - val loss: 173.129, val task loss: 163.351 [*] Best so far

Epoch: 7/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 771.42it/s]


model_1: train loss: 101.302, train task loss: 96.731 - val loss: 163.467, val task loss: 150.923 [*] Best so far
model_2: train loss: 125.591, train task loss: 120.657 - val loss: 179.677, val task loss: 167.134

Epoch: 8/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 627.83it/s]


model_1: train loss: 86.322, train task loss: 80.052 - val loss: 148.851, val task loss: 134.563 [*] Best so far
model_2: train loss: 114.959, train task loss: 108.015 - val loss: 161.756, val task loss: 147.469 [*] Best so far

Epoch: 9/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 738.09it/s]


model_1: train loss: 72.039, train task loss: 62.968 - val loss: 146.812, val task loss: 128.110 [*] Best so far
model_2: train loss: 104.385, train task loss: 94.444 - val loss: 158.888, val task loss: 140.186 [*] Best so far

Epoch: 10/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 700.90it/s]


model_1: train loss: 58.049, train task loss: 47.453 - val loss: 152.789, val task loss: 124.540 [*] Best so far
model_2: train loss: 87.887, train task loss: 76.293 - val loss: 168.402, val task loss: 140.153 [*] Best so far

Epoch: 11/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 714.71it/s]


model_1: train loss: 42.383, train task loss: 33.620 - val loss: 152.161, val task loss: 114.788 [*] Best so far
model_2: train loss: 61.349, train task loss: 51.487 - val loss: 158.832, val task loss: 121.459 [*] Best so far

Epoch: 12/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 755.83it/s]


model_1: train loss: 32.200, train task loss: 22.619 - val loss: 155.732, val task loss: 107.228 [*] Best so far
model_2: train loss: 41.860, train task loss: 31.722 - val loss: 159.748, val task loss: 111.244 [*] Best so far

Epoch: 13/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 692.26it/s]


model_1: train loss: 35.396, train task loss: 18.990 - val loss: 162.486, val task loss: 102.158 [*] Best so far
model_2: train loss: 39.353, train task loss: 22.332 - val loss: 163.433, val task loss: 103.105 [*] Best so far

Epoch: 14/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 795.42it/s]


model_1: train loss: 33.604, train task loss: 17.034 - val loss: 152.175, val task loss: 98.973 [*] Best so far
model_2: train loss: 34.035, train task loss: 17.312 - val loss: 149.048, val task loss: 95.846 [*] Best so far

Epoch: 15/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 679.03it/s]


model_1: train loss: 30.320, train task loss: 14.024 - val loss: 145.992, val task loss: 94.353 [*] Best so far
model_2: train loss: 29.408, train task loss: 13.627 - val loss: 143.411, val task loss: 91.772 [*] Best so far

Epoch: 16/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 646.69it/s]


model_1: train loss: 24.050, train task loss: 10.442 - val loss: 147.234, val task loss: 92.246 [*] Best so far
model_2: train loss: 23.893, train task loss: 10.657 - val loss: 146.719, val task loss: 91.732 [*] Best so far

Epoch: 17/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 750.04it/s]


model_1: train loss: 20.903, train task loss: 8.353 - val loss: 152.248, val task loss: 96.561
model_2: train loss: 21.082, train task loss: 8.602 - val loss: 142.113, val task loss: 86.426 [*] Best so far

Epoch: 18/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 839.06it/s]


model_1: train loss: 20.360, train task loss: 7.269 - val loss: 143.583, val task loss: 90.115 [*] Best so far
model_2: train loss: 18.875, train task loss: 7.158 - val loss: 141.977, val task loss: 88.509

Epoch: 19/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 668.58it/s]


model_1: train loss: 17.274, train task loss: 6.355 - val loss: 144.724, val task loss: 88.938 [*] Best so far
model_2: train loss: 16.532, train task loss: 5.807 - val loss: 144.356, val task loss: 88.570

Epoch: 20/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 632.09it/s]


model_1: train loss: 13.253, train task loss: 4.908 - val loss: 157.371, val task loss: 88.034 [*] Best so far
model_2: train loss: 13.162, train task loss: 4.851 - val loss: 157.005, val task loss: 87.668

Epoch: 21/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 691.57it/s]


model_1: train loss: 11.937, train task loss: 3.859 - val loss: 138.970, val task loss: 88.585
model_2: train loss: 12.242, train task loss: 4.198 - val loss: 133.352, val task loss: 82.967 [*] Best so far

Epoch: 22/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 750.33it/s]


model_1: train loss: 15.597, train task loss: 4.235 - val loss: 143.166, val task loss: 87.513 [*] Best so far
model_2: train loss: 14.694, train task loss: 4.053 - val loss: 139.855, val task loss: 84.202

Epoch: 23/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 675.50it/s]


model_1: train loss: 18.433, train task loss: 4.704 - val loss: 139.629, val task loss: 88.567
model_2: train loss: 17.287, train task loss: 4.078 - val loss: 132.708, val task loss: 81.646 [*] Best so far

Epoch: 24/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 769.54it/s]


model_1: train loss: 13.897, train task loss: 3.738 - val loss: 133.739, val task loss: 87.342 [*] Best so far
model_2: train loss: 12.864, train task loss: 3.499 - val loss: 129.133, val task loss: 82.737

Epoch: 25/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 745.02it/s]


model_1: train loss: 10.591, train task loss: 2.908 - val loss: 141.288, val task loss: 85.298 [*] Best so far
model_2: train loss: 9.970, train task loss: 2.693 - val loss: 140.001, val task loss: 84.010

Epoch: 26/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 726.93it/s]


model_1: train loss: 10.263, train task loss: 2.570 - val loss: 142.467, val task loss: 84.941 [*] Best so far
model_2: train loss: 9.983, train task loss: 2.406 - val loss: 142.157, val task loss: 84.631

Epoch: 27/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 791.50it/s]


model_1: train loss: 7.685, train task loss: 2.040 - val loss: 133.251, val task loss: 85.163
model_2: train loss: 7.614, train task loss: 1.988 - val loss: 130.769, val task loss: 82.681

Epoch: 28/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 767.38it/s]


model_1: train loss: 6.307, train task loss: 1.589 - val loss: 133.494, val task loss: 86.742
model_2: train loss: 6.244, train task loss: 1.664 - val loss: 126.343, val task loss: 79.592 [*] Best so far

Epoch: 29/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 743.53it/s]


model_1: train loss: 5.771, train task loss: 1.450 - val loss: 133.912, val task loss: 83.564 [*] Best so far
model_2: train loss: 5.340, train task loss: 1.375 - val loss: 133.651, val task loss: 83.304

Epoch: 30/30 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 734.82it/s]


model_1: train loss: 5.594, train task loss: 1.345 - val loss: 131.457, val task loss: 85.952
model_2: train loss: 5.533, train task loss: 1.298 - val loss: 125.179, val task loss: 79.674
Finished training student cohort!
Selecting the optimal disgreement penalty via cross-validation...
Best rho: 0.99 with average task loss: 79.1259
Done!
Training meta learner on the best cohort...


1280it [00:01, 1123.34it/s]                                                                                            


meta_learner: train task loss: 199.611 - val task loss: 134.955 [*] Best so far


1280it [00:01, 1038.46it/s]                                                                                            


meta_learner: train task loss: 47.713 - val task loss: 101.165 [*] Best so far


1280it [00:01, 1099.88it/s]                                                                                            


meta_learner: train task loss: 24.305 - val task loss: 94.734 [*] Best so far


1280it [00:01, 1054.69it/s]                                                                                            


meta_learner: train task loss: 20.933 - val task loss: 95.073


1280it [00:01, 1097.65it/s]                                                                                            


meta_learner: train task loss: 19.983 - val task loss: 93.290 [*] Best so far
Done!
Method: (simple_average), Test_MSE: 60.048126220703125
Method: (weighted_average), Test_MSE: 60.01173400878906
Method: (meta_learner), Test_MSE: 63.95914077758789
Method: (best_single), Test_MSE: 63.67018508911133
Method: (cohort), Test_MSE: [tensor(66.4824), tensor(63.6702)]
Training meta learner with no disagreement penalty...


1280it [00:01, 988.54it/s]                                                                                             


meta_learner: train task loss: 525.323 - val task loss: 425.372 [*] Best so far


1280it [00:01, 1145.76it/s]                                                                                            


meta_learner: train task loss: 458.939 - val task loss: 377.778 [*] Best so far


1280it [00:01, 1073.28it/s]                                                                                            


meta_learner: train task loss: 405.014 - val task loss: 333.191 [*] Best so far


1280it [00:01, 1066.47it/s]                                                                                            


meta_learner: train task loss: 357.224 - val task loss: 297.993 [*] Best so far


1280it [00:01, 1069.67it/s]                                                                                            


meta_learner: train task loss: 316.696 - val task loss: 268.519 [*] Best so far
Done!
Method: (simple_average), Test_MSE: 57.140743255615234
Method: (weighted_average), Test_MSE: 57.05751037597656
Method: (meta_learner), Test_MSE: 305.0787353515625
Method: (best_single), Test_MSE: 60.327919006347656
Method: (cohort), Test_MSE: [tensor(64.2954), tensor(60.3279)]
Finished running meta fusion!
Deleting the model checkpoint directory: ./checkpoints/regression_quadratic_early/scale0_seed1/


NameError: name 'shutil' is not defined

In [None]:
# Convert tensor values (e.g. in the 'Test_metric' column) to plain Python numbers
def convert_to_float(value):
    """Turn torch tensors, possibly nested inside lists, into plain Python numbers.

    Parameters
    ----------
    value : object
        A single cell value. May be a ``torch.Tensor``, a ``list`` (whose
        elements may themselves be tensors or lists), or anything else.

    Returns
    -------
    object
        * single-element ``torch.Tensor`` -> Python scalar via ``.item()``
        * any other ``torch.Tensor``      -> (nested) list via ``.tolist()``
        * ``list``                        -> new list, each element converted recursively
        * anything else                   -> returned unchanged
    """
    if isinstance(value, torch.Tensor):
        # .item() only works on tensors holding exactly one element; use
        # .tolist() for everything else instead of raising.
        return value.item() if value.numel() == 1 else value.tolist()
    if isinstance(value, list):
        # Recurse so nested lists of tensors are converted as well.
        return [convert_to_float(v) for v in value]
    return value

In [None]:
import matplotlib.pyplot as plt
# Run every entry of the 'Test_metric' column through convert_to_float so that
# tensor values (and lists of tensors) become plain Python numbers. The column
# is overwritten in place on results_pd.
results_pd['Test_metric'] = results_pd['Test_metric'].apply(convert_to_float)

In [None]:
# Keep only the rows of results_pd whose 'Method' is listed below.
# Edit this list to change which methods are included.
methods_to_include = [
    'oracle_mod1',
    'oracle_mod2',
    'oracle_early_fusion',
    'oracle_late_fusion',
    'mod1',
    'mod2',
    'early_fusion',
    'late_fusion',
    'meta_learner',
    'coop',
]
df_filtered = results_pd.loc[results_pd['Method'].isin(methods_to_include)]

In [None]:
# Per-method mean and standard deviation of the 'Test_metric' column
grouped_stats = (
    df_filtered
    .groupby('Method')['Test_metric']
    .agg(['mean', 'std'])
)

print(grouped_stats)

In [None]:
# Display the full (unfiltered) results frame as the cell output.
results_pd