In [2]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import random
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import shutil


sys.path.append('../')
from meta_fusion.benchmarks_general import *
from meta_fusion.methods_general import *
from meta_fusion.models_general import *
from meta_fusion.utils_general import *
from meta_fusion.third_party import *
from meta_fusion.config import load_config
from meta_fusion.real_data_general import PrepareRealData

In [3]:
def run_single_experiment(config, extractor_config, random_state, 
                          mod_outs, combined_hiddens, mod_hiddens,
                          is_mod_static, freeze_mod_extractors):
    """Run one repetition: split the data, train/test the benchmarks and meta fusion.

    Besides its arguments, this function reads the notebook-level globals
    ``data_preparer``, ``data_name``, ``output_dim``, ``extractor_type`` and
    (for the 'encoder'/'separate' extractor types) ``separate``, so the
    configuration cell must have been executed before calling it.

    Parameters
    ----------
    config : dict
        Training configuration handed to ``Benchmarks`` and ``Trainer``.
        Its ``'random_state'`` entry is overwritten IN PLACE; the keys
        ``'ensemble_methods'`` and ``'divergence_weight_type'`` are read here.
    extractor_config : dict
        Configuration forwarded to ``get_encoder_extractors``. Its
        ``'random_state'`` entry is overwritten IN PLACE.
    random_state : int
        Seed used for the data split and stored in both config dicts.
    mod_outs : list of list
        One list per modality. The full modality dimension is appended to each
        list (on a copy; the caller's lists are not modified) before the meta
        extractors are built.
        NOTE(review): presumably the candidate latent sizes per modality --
        confirm against ``Extractors``.
    combined_hiddens, mod_hiddens, is_mod_static, freeze_mod_extractors
        Forwarded unchanged to ``Cohorts`` / ``get_encoder_extractors``.

    Returns
    -------
    pandas.DataFrame
        One row per ``(method, value)`` pair reported by ``bm.test``,
        ``metafuse.test`` and ``metafuse.test_ablation``, with columns
        ``Method``, ``Test_metric``, ``best_rho``, ``cohort_pairs``,
        ``ensemble_idxs``, ``cluster_idxs``, ``random_state``,
        ``dim_modalities``, ``n``, ``n_train``, ``n_val`` and ``n_test``.

    Raises
    ------
    ValueError
        If the global ``extractor_type`` is not 'encoder', 'separate' or 'PCA'.
    """
    # Fail fast: an unrecognised extractor type used to skip extractor
    # construction silently, and only surfaced (if at all) after the
    # benchmarks had already been trained.
    supported_extractor_types = ('encoder', 'separate', 'PCA')
    if extractor_type not in supported_extractor_types:
        raise ValueError(
            f"Unsupported extractor_type {extractor_type!r}; "
            f"expected one of {supported_extractor_types}."
        )

    config['random_state'] = random_state
    extractor_config['random_state'] = random_state
    res_list = []
    # Per-method metadata, looked up by method name when the result rows are
    # assembled; methods without an entry get None in the matching column.
    best_rho = {}
    cohort_pairs = {}
    ens_idxs = {}
    cluster_idxs = {}

    #----------------#
    # Split dataset  #
    #----------------#
    train_loader, val_loader, test_loader = data_preparer.get_data_loaders(random_state=random_state)
    # Get data info
    data_info = data_preparer.get_data_info()
    dim_modalities = data_info[0]
    n = data_info[1]
    n_train = data_info[2]
    n_val = data_info[3]
    n_test = data_info[4]

    print(f"Finished splitting {data_name} dataset. Data information are summarized below:\n"
            f"Modality dimensions: {dim_modalities}\n"
            f"Data size: {n}\n"
            f"Train size: {n_train}\n"
            f"Val size: {n_val}\n"
            f"Test size: {n_test}")
    sys.stdout.flush()

    #------------------#
    # Benchmark models #
    #------------------#
    # Each modality gets the output options [d, 0], where d is its own dimension.
    bm_extractor = Extractors([[d, 0] for d in dim_modalities], dim_modalities, train_loader, val_loader)
    _ = bm_extractor.get_dummy_extractors()
    bm_cohort = Cohorts(extractors=bm_extractor, combined_hidden_layers=combined_hiddens, output_dim=output_dim)

    #----------------------------#
    # Proposed model: MetaJoint  #
    #----------------------------#
    # Include the full modalities. List concatenation builds new lists, so the
    # caller's ``mod_outs`` stays untouched across repetitions.
    mod_outs = [out + [dim] for out, dim in zip(mod_outs, dim_modalities)]
    meta_extractor = Extractors(mod_outs, dim_modalities, train_loader, val_loader)
    if extractor_type in ('encoder', 'separate'):
        _ = meta_extractor.get_encoder_extractors(mod_hiddens, separate=separate, config=extractor_config)
    else:  # 'PCA' (validated at the top of the function)
        _ = meta_extractor.get_PCA_extractors()
    meta_cohort = Cohorts(extractors=meta_extractor, combined_hidden_layers=combined_hiddens, output_dim=output_dim,
                          is_mod_static=is_mod_static, freeze_mod_extractors=freeze_mod_extractors)

    #------------------------------#
    #  Train and test benchmarks   #
    #------------------------------#
    bm_models = bm_cohort.get_cohort_models()
    _, bm_dims = bm_cohort.get_cohort_info()
    bm = Benchmarks(config, bm_models, bm_dims, [train_loader, val_loader])
    bm.train()
    res_list.append(bm.test(test_loader))
    print("Finished running basic benchmarks!")

    #------------------------------#
    #  Train and test MetaJoint    #
    #------------------------------#
    cohort_models = meta_cohort.get_cohort_models()
    _, dim_pairs = meta_cohort.get_cohort_info()
    metafuse = Trainer(config, cohort_models, [train_loader, val_loader])
    metafuse.train()
    res_list.append(metafuse.test(test_loader))

    metafuse.train_ablation()
    res_list.append(metafuse.test_ablation(test_loader))

    best_rho['cohort'] = metafuse.best_rho
    cohort_pairs['cohort'] = dim_pairs
    cohort_pairs['indep_cohort'] = dim_pairs

    if "greedy_ensemble" in config["ensemble_methods"]:
        ens_idxs['greedy_ensemble'] = metafuse.ens_idxs

    if config['divergence_weight_type'] == "clustering":
        cluster_idxs['cohort'] = metafuse.cluster_idxs

    print("Finished running meta fusion!")

    # Flatten the per-stage result dicts into one record per method.
    records = [
        {'Method': method, 'Test_metric': val,
         'best_rho': best_rho.get(method), 'cohort_pairs': cohort_pairs.get(method),
         'ensemble_idxs': ens_idxs.get(method), 'cluster_idxs': cluster_idxs.get(method)}
        for res in res_list
        for method, val in res.items()
    ]
    results = pd.DataFrame(records)

    results['random_state'] = random_state
    # Repeat the list once per row; assigning the bare list would be treated
    # as a column of per-row values instead of one list per row.
    results["dim_modalities"] = [dim_modalities] * len(results)
    results['n'] = n
    results['n_train'] = n_train
    results['n_val'] = n_val
    results['n_test'] = n_test

    return results

In [6]:
# ---------------------------------------------------------------------------
# Experiment configuration. The names defined in this cell (extractor_type,
# separate, data_name, output_dim, data_preparer, ...) are read as globals by
# run_single_experiment, so this cell must run before the experiment cell.
# ---------------------------------------------------------------------------
seed = 1
extractor_type = 'encoder'
fine_grained = True
num_modalities = 4 if fine_grained else 3

mod_outs = [[50], [], [], [50]]
# run_single_experiment zips mod_outs with the modality dimensions; a length
# mismatch would be truncated silently there, so reject it here instead.
if len(mod_outs) != num_modalities:
    raise ValueError(
        f"mod_outs has {len(mod_outs)} entries but num_modalities is {num_modalities}; "
        "provide one list per modality."
    )
combined_hiddens = [64, 32]
# Build an independent list per modality; `[[128]] * k` would repeat the SAME
# inner list object k times, so mutating one entry would change all of them.
mod_hiddens = [[128] for _ in range(num_modalities)]

data_name = 'NACC'
output_dim = 4

# Validate explicitly: without this, an unknown type would leave `separate`,
# `is_mod_static` and `freeze_mod_extractors` undefined (or stale from an
# earlier execution of this cell).
supported_extractor_types = ('encoder', 'separate', 'PCA')
if extractor_type not in supported_extractor_types:
    raise ValueError(
        f"Unsupported extractor_type {extractor_type!r}; "
        f"expected one of {supported_extractor_types}."
    )
separate = (extractor_type == 'separate')
# Only the PCA extractors are flagged as static; nothing is frozen for any type.
is_mod_static = [extractor_type == 'PCA'] * num_modalities
freeze_mod_extractors = [False] * num_modalities

config = load_config('../experiments_real/config.json')
extractor_config = load_config('../experiments_real/config_extractor.json')

# Model files directory
ckpt_dir = f"./checkpoints/{data_name}/{extractor_type}_seed{seed}/"
config['ckpt_dir'] = extractor_config['ckpt_dir'] = ckpt_dir

# Update other training parameters
config['divergence_weight_type'] = 'clustering'
config['optimal_k'] = None
config['output_dim'] = extractor_config['output_dim'] = output_dim
config["init_lr"] = 0.001
# Short run for quick iteration; the previously used value was 20.
config["epochs"] = 2
config["ensemble_methods"] = [
        "simple_average",
        "weighted_average",
        "meta_learner",
        "greedy_ensemble"
        ]
extractor_config["init_lr"] = 0.001
# Short run for quick iteration; the previously used value was 20.
# NOTE(review): the key is "epoch" here but "epochs" in `config` -- confirm
# that this is the key the extractor training code actually reads.
extractor_config["epoch"] = 2

#####################
#    Load Dataset   #
#####################
data_preparer = PrepareRealData(data_name = data_name, test_size = 0.25, val_size = 0.25, fine_grained=fine_grained)
print(f"Finished loading {data_name} dataset.")
sys.stdout.flush()

Finished loading NACC dataset.


In [7]:
# Run the experiment `repetition` times and stack the per-repetition result
# frames into a single DataFrame named `results`.
repetition = 1
per_repetition_frames = []

progress = tqdm(range(1, repetition+1), desc="Repetitions", leave=True, position=0)
for i in progress:
    print(f'Running with repetition {i}...')
    # Seed s covers the states repetition*(s-1)+1 ... repetition*s, so runs
    # launched with different seeds never share a random state.
    random_state = repetition * (seed-1) + i
    set_random_seed(random_state)

    # Run experiment
    per_repetition_frames.append(
        run_single_experiment(
            config=config,
            extractor_config=extractor_config,
            random_state=random_state,
            mod_outs=mod_outs,
            combined_hiddens=combined_hiddens,
            mod_hiddens=mod_hiddens,
            is_mod_static=is_mod_static,
            freeze_mod_extractors=freeze_mod_extractors,
        )
    )

results = pd.concat(per_repetition_frames, ignore_index=True)

Repetitions:   0%|                                                                               | 0/1 [00:00<?, ?it/s]

Running with repetition 1...




Finished splitting NACC dataset. Data information are summarized below:
Modality dimensions: [104, 39, 28, 192]
Data size: 1426
Train size: 801
Val size: 268
Test size: 357



  0%|                                                                                          | 0/801 [00:00<?, ?it/s][A
  8%|██████▍                                                                         | 64/801 [00:00<00:01, 493.35it/s][A
 16%|████████████▌                                                                  | 128/801 [00:00<00:01, 445.25it/s][A
 24%|██████████████████▉                                                            | 192/801 [00:00<00:01, 400.90it/s][A
 32%|█████████████████████████▏                                                     | 256/801 [00:00<00:01, 343.61it/s][A
 40%|███████████████████████████████▌                                               | 320/801 [00:00<00:01, 338.93it/s][A
 48%|█████████████████████████████████████▊                                         | 384/801 [00:01<00:01, 339.95it/s][A
 56%|████████████████████████████████████████████▏                                  | 448/801 [00:01<00:01, 338.10it/s][A
 64%|██████████

Start training benchmark models...
Training with disagreement penalty = 0

Epoch: 1/2 - LR: 0.001000



  0%|                                                                                          | 0/801 [00:00<?, ?it/s][A
  8%|██████▍                                                                         | 64/801 [00:00<00:02, 277.59it/s][A
 16%|████████████▌                                                                  | 128/801 [00:00<00:01, 414.51it/s][A
 24%|██████████████████▉                                                            | 192/801 [00:00<00:01, 450.20it/s][A
 32%|█████████████████████████▏                                                     | 256/801 [00:00<00:01, 461.58it/s][A
 40%|███████████████████████████████▌                                               | 320/801 [00:00<00:00, 509.75it/s][A
 48%|█████████████████████████████████████▊                                         | 384/801 [00:00<00:00, 470.56it/s][A
 56%|████████████████████████████████████████████▏                                  | 448/801 [00:00<00:00, 474.60it/s][A
 64%|██████████

model_1: train loss: 1.319, train task loss: 1.319 - val loss: 1.176, val task loss: 1.176 [*] Best so far
model_2: train loss: 1.180, train task loss: 1.180 - val loss: 0.972, val task loss: 0.972 [*] Best so far
model_3: train loss: 1.421, train task loss: 1.421 - val loss: 1.329, val task loss: 1.329 [*] Best so far
model_4: train loss: 1.248, train task loss: 1.248 - val loss: 1.122, val task loss: 1.122 [*] Best so far
model_5: train loss: 1.204, train task loss: 1.204 - val loss: 0.952, val task loss: 0.952 [*] Best so far

Epoch: 2/2 - LR: 0.001000



  0%|                                                                                          | 0/801 [00:00<?, ?it/s][A
  8%|██████▍                                                                         | 64/801 [00:00<00:01, 396.67it/s][A
 16%|████████████▌                                                                  | 128/801 [00:00<00:01, 493.95it/s][A
 24%|██████████████████▉                                                            | 192/801 [00:00<00:01, 478.98it/s][A
 32%|█████████████████████████▏                                                     | 256/801 [00:00<00:01, 487.45it/s][A
 40%|███████████████████████████████▌                                               | 320/801 [00:00<00:00, 519.73it/s][A
 48%|█████████████████████████████████████▊                                         | 384/801 [00:00<00:00, 530.91it/s][A
 56%|████████████████████████████████████████████▏                                  | 448/801 [00:00<00:00, 541.28it/s][A
 64%|██████████

model_1: train loss: 1.129, train task loss: 1.129 - val loss: 1.005, val task loss: 1.005 [*] Best so far
model_2: train loss: 0.997, train task loss: 0.997 - val loss: 0.894, val task loss: 0.894 [*] Best so far
model_3: train loss: 1.272, train task loss: 1.272 - val loss: 1.166, val task loss: 1.166 [*] Best so far
model_4: train loss: 1.068, train task loss: 1.068 - val loss: 0.970, val task loss: 0.970 [*] Best so far
model_5: train loss: 0.906, train task loss: 0.906 - val loss: 0.784, val task loss: 0.784 [*] Best so far
Finished training benchmark models!
Method: (modality_1), Test_Accuracy: 0.6218487394957983
Method: (modality_2), Test_Accuracy: 0.6414565826330533
Method: (modality_3), Test_Accuracy: 0.5938375350140056
Method: (modality_4), Test_Accuracy: 0.6694677871148459
Method: (early_fusion), Test_Accuracy: 0.7170868347338936
Method: (late_fusion), Test_Accuracy: 0.6218487394957983
Finished running basic benchmarks!
Start training student cohort...
Training with disagree


  0%|                                                                                          | 0/801 [00:00<?, ?it/s][A
  8%|██████▍                                                                         | 64/801 [00:00<00:02, 360.14it/s][A
 16%|████████████▌                                                                  | 128/801 [00:00<00:01, 442.75it/s][A
 24%|██████████████████▉                                                            | 192/801 [00:00<00:01, 449.01it/s][A
 32%|█████████████████████████▏                                                     | 256/801 [00:00<00:01, 447.28it/s][A
 40%|███████████████████████████████▌                                               | 320/801 [00:00<00:01, 416.15it/s][A
 48%|█████████████████████████████████████▊                                         | 384/801 [00:00<00:01, 404.57it/s][A
 56%|████████████████████████████████████████████▏                                  | 448/801 [00:01<00:00, 403.05it/s][A
 64%|██████████

model_1: train loss: 1.228, train task loss: 1.228 - val loss: 0.957, val task loss: 0.957 [*] Best so far
model_2: train loss: 1.201, train task loss: 1.201 - val loss: 0.956, val task loss: 0.956 [*] Best so far
model_3: train loss: 1.122, train task loss: 1.122 - val loss: 0.900, val task loss: 0.900 [*] Best so far
model_4: train loss: 1.339, train task loss: 1.339 - val loss: 1.118, val task loss: 1.118 [*] Best so far

Epoch: 2/2 - LR: 0.001000



  0%|                                                                                          | 0/801 [00:00<?, ?it/s][A
  8%|██████▍                                                                         | 64/801 [00:00<00:02, 255.44it/s][A
 16%|████████████▌                                                                  | 128/801 [00:00<00:02, 327.43it/s][A
 24%|██████████████████▉                                                            | 192/801 [00:00<00:01, 369.74it/s][A
 32%|█████████████████████████▏                                                     | 256/801 [00:00<00:01, 386.96it/s][A
 40%|███████████████████████████████▌                                               | 320/801 [00:00<00:01, 412.47it/s][A
 48%|█████████████████████████████████████▊                                         | 384/801 [00:01<00:01, 404.35it/s][A
 56%|████████████████████████████████████████████▏                                  | 448/801 [00:01<00:00, 427.60it/s][A
 64%|██████████

model_1: train loss: 0.913, train task loss: 0.913 - val loss: 0.797, val task loss: 0.797 [*] Best so far
model_2: train loss: 0.903, train task loss: 0.903 - val loss: 0.784, val task loss: 0.784 [*] Best so far
model_3: train loss: 0.871, train task loss: 0.871 - val loss: 0.755, val task loss: 0.755 [*] Best so far
model_4: train loss: 0.963, train task loss: 0.963 - val loss: 0.806, val task loss: 0.806 [*] Best so far
Training with disagreement penalty = 0.99
Computing divergence weights by clustering method...
Initialization complete
Iteration 0, inertia 0.000890118685575203.
Iteration 1, inertia 0.0004450593427876015.
Converged at iteration 1: strict convergence.
Computed divergence weights by clustering method, weights are [0.  0.5 0.5 0. ]

Epoch: 1/2 - LR: 0.001000



  0%|                                                                                          | 0/801 [00:00<?, ?it/s][A
  8%|██████▍                                                                         | 64/801 [00:00<00:04, 162.59it/s][A
 16%|████████████▌                                                                  | 128/801 [00:00<00:03, 179.58it/s][A
 24%|██████████████████▉                                                            | 192/801 [00:01<00:03, 186.60it/s][A
 32%|█████████████████████████▏                                                     | 256/801 [00:01<00:02, 212.85it/s][A
 40%|███████████████████████████████▌                                               | 320/801 [00:01<00:02, 198.72it/s][A
 48%|█████████████████████████████████████▊                                         | 384/801 [00:01<00:01, 221.15it/s][A
 56%|████████████████████████████████████████████▏                                  | 448/801 [00:02<00:01, 211.95it/s][A
 64%|██████████

model_1: train loss: 0.793, train task loss: 0.759 - val loss: 0.685, val task loss: 0.664 [*] Best so far
model_2: train loss: 0.744, train task loss: 0.725 - val loss: 0.647, val task loss: 0.635 [*] Best so far
model_3: train loss: 0.743, train task loss: 0.724 - val loss: 0.641, val task loss: 0.628 [*] Best so far
model_4: train loss: 0.776, train task loss: 0.744 - val loss: 0.690, val task loss: 0.664 [*] Best so far

Epoch: 2/2 - LR: 0.001000



  0%|                                                                                          | 0/801 [00:00<?, ?it/s][A
  8%|██████▍                                                                         | 64/801 [00:00<00:04, 180.17it/s][A
 16%|████████████▌                                                                  | 128/801 [00:00<00:03, 201.99it/s][A
 24%|██████████████████▉                                                            | 192/801 [00:00<00:03, 195.05it/s][A
 32%|█████████████████████████▏                                                     | 256/801 [00:01<00:02, 185.89it/s][A
 40%|███████████████████████████████▌                                               | 320/801 [00:01<00:02, 198.53it/s][A
 48%|█████████████████████████████████████▊                                         | 384/801 [00:02<00:02, 190.65it/s][A
 56%|████████████████████████████████████████████▏                                  | 448/801 [00:02<00:01, 189.30it/s][A
 64%|██████████

model_1: train loss: 0.696, train task loss: 0.672 - val loss: 0.633, val task loss: 0.606 [*] Best so far
model_2: train loss: 0.637, train task loss: 0.620 - val loss: 0.599, val task loss: 0.580 [*] Best so far
model_3: train loss: 0.643, train task loss: 0.626 - val loss: 0.588, val task loss: 0.569 [*] Best so far
model_4: train loss: 0.650, train task loss: 0.625 - val loss: 0.615, val task loss: 0.589 [*] Best so far
Training with disagreement penalty = 3

Epoch: 1/2 - LR: 0.001000



  0%|                                                                                          | 0/801 [00:00<?, ?it/s][A
  8%|██████▍                                                                         | 64/801 [00:00<00:03, 229.94it/s][A
 16%|████████████▌                                                                  | 128/801 [00:00<00:02, 262.46it/s][A
 24%|██████████████████▉                                                            | 192/801 [00:00<00:02, 242.77it/s][A
 32%|█████████████████████████▏                                                     | 256/801 [00:01<00:02, 241.54it/s][A
 40%|███████████████████████████████▌                                               | 320/801 [00:01<00:01, 245.88it/s][A
 48%|█████████████████████████████████████▊                                         | 384/801 [00:01<00:01, 247.93it/s][A
 56%|████████████████████████████████████████████▏                                  | 448/801 [00:01<00:01, 253.36it/s][A
 64%|██████████

model_1: train loss: 1.273, train task loss: 1.231 - val loss: 1.020, val task loss: 0.974 [*] Best so far
model_2: train loss: 1.219, train task loss: 1.202 - val loss: 0.972, val task loss: 0.957 [*] Best so far
model_3: train loss: 1.150, train task loss: 1.132 - val loss: 0.915, val task loss: 0.900 [*] Best so far
model_4: train loss: 1.496, train task loss: 1.335 - val loss: 1.319, val task loss: 1.102 [*] Best so far

Epoch: 2/2 - LR: 0.001000



  0%|                                                                                          | 0/801 [00:00<?, ?it/s][A
  8%|██████▍                                                                         | 64/801 [00:00<00:03, 188.61it/s][A
 16%|████████████▌                                                                  | 128/801 [00:00<00:03, 204.84it/s][A
 24%|██████████████████▉                                                            | 192/801 [00:00<00:02, 205.26it/s][A
 32%|█████████████████████████▏                                                     | 256/801 [00:01<00:02, 223.76it/s][A
 40%|███████████████████████████████▌                                               | 320/801 [00:01<00:02, 221.58it/s][A
 48%|█████████████████████████████████████▊                                         | 384/801 [00:01<00:01, 214.47it/s][A
 56%|████████████████████████████████████████████▏                                  | 448/801 [00:02<00:01, 213.22it/s][A
 64%|██████████

model_1: train loss: 0.999, train task loss: 0.947 - val loss: 0.881, val task loss: 0.831 [*] Best so far
model_2: train loss: 0.953, train task loss: 0.924 - val loss: 0.833, val task loss: 0.807 [*] Best so far
model_3: train loss: 0.896, train task loss: 0.868 - val loss: 0.788, val task loss: 0.763 [*] Best so far
model_4: train loss: 1.193, train task loss: 1.046 - val loss: 1.075, val task loss: 0.932 [*] Best so far
Finished training student cohort!
Selecting the optimal disgreement penalty via cross-validation...
Best rho: 0.99 with average task loss: 0.5744
Done!
Training meta learner on the best cohort...



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 486.51it/s]                                                                                               [A
128it [00:00, 447.03it/s][A
192it [00:00, 432.48it/s][A
256it [00:00, 404.39it/s][A
320it [00:00, 403.12it/s][A
384it [00:00, 380.43it/s][A
448it [00:01, 366.38it/s][A
512it [00:01, 385.68it/s][A
576it [00:01, 395.47it/s][A
640it [00:01, 407.84it/s][A
704it [00:01, 402.34it/s][A
801it [00:02, 389.05it/s][A


meta_learner: train task loss: 0.766 - val task loss: 0.555 [*] Best so far



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 326.35it/s]                                                                                               [A
128it [00:00, 393.23it/s][A
192it [00:00, 403.31it/s][A
256it [00:00, 406.85it/s][A
320it [00:00, 425.80it/s][A
384it [00:00, 470.43it/s][A
448it [00:00, 503.17it/s][A
512it [00:01, 521.86it/s][A
576it [00:01, 529.67it/s][A
640it [00:01, 509.65it/s][A
704it [00:01, 486.23it/s][A
801it [00:01, 458.89it/s][A


meta_learner: train task loss: 0.562 - val task loss: 0.543 [*] Best so far



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 349.19it/s]                                                                                               [A
128it [00:00, 390.51it/s][A
192it [00:00, 436.37it/s][A
256it [00:00, 457.14it/s][A
320it [00:00, 444.52it/s][A
384it [00:00, 446.92it/s][A
448it [00:01, 431.79it/s][A
512it [00:01, 428.66it/s][A
576it [00:01, 407.10it/s][A
640it [00:01, 408.19it/s][A
704it [00:01, 414.06it/s][A
801it [00:01, 405.61it/s][A


meta_learner: train task loss: 0.513 - val task loss: 0.581



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 326.36it/s]                                                                                               [A
128it [00:00, 437.59it/s][A
192it [00:00, 465.03it/s][A
256it [00:00, 479.68it/s][A
320it [00:00, 463.68it/s][A
384it [00:00, 459.17it/s][A
448it [00:00, 472.82it/s][A
512it [00:01, 486.11it/s][A
576it [00:01, 475.47it/s][A
640it [00:01, 457.12it/s][A
704it [00:01, 430.00it/s][A
801it [00:01, 426.74it/s][A


meta_learner: train task loss: 0.493 - val task loss: 0.580



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 331.48it/s]                                                                                               [A
128it [00:00, 432.10it/s][A
192it [00:00, 433.75it/s][A
256it [00:00, 429.92it/s][A
320it [00:00, 443.44it/s][A
384it [00:00, 458.94it/s][A
448it [00:01, 443.51it/s][A
512it [00:01, 435.19it/s][A
576it [00:01, 422.35it/s][A
640it [00:01, 441.40it/s][A
704it [00:01, 436.75it/s][A
801it [00:01, 428.35it/s][A


meta_learner: train task loss: 0.487 - val task loss: 0.620



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 357.45it/s]                                                                                               [A
128it [00:00, 405.06it/s][A
192it [00:00, 407.43it/s][A
256it [00:00, 413.47it/s][A
320it [00:00, 409.22it/s][A
384it [00:00, 397.78it/s][A
448it [00:01, 399.91it/s][A
512it [00:01, 416.17it/s][A
576it [00:01, 411.84it/s][A
640it [00:01, 402.91it/s][A
704it [00:01, 400.18it/s][A
801it [00:02, 389.84it/s][A


meta_learner: train task loss: 0.461 - val task loss: 0.621



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 261.65it/s]                                                                                               [A
128it [00:00, 341.89it/s][A
192it [00:00, 374.39it/s][A
256it [00:00, 389.01it/s][A
320it [00:00, 385.99it/s][A
384it [00:01, 384.13it/s][A
448it [00:01, 383.22it/s][A
512it [00:01, 377.90it/s][A
576it [00:01, 380.23it/s][A
640it [00:01, 398.17it/s][A
704it [00:01, 391.45it/s][A
801it [00:02, 369.29it/s][A


meta_learner: train task loss: 0.463 - val task loss: 0.609



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 280.89it/s]                                                                                               [A
128it [00:00, 352.83it/s][A
192it [00:00, 371.63it/s][A
256it [00:00, 385.61it/s][A
320it [00:00, 448.78it/s][A
384it [00:00, 463.64it/s][A
448it [00:01, 472.66it/s][A
512it [00:01, 473.70it/s][A
576it [00:01, 466.79it/s][A
640it [00:01, 448.60it/s][A
704it [00:01, 439.76it/s][A
801it [00:01, 413.83it/s][A


meta_learner: train task loss: 0.482 - val task loss: 0.758



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 291.51it/s]                                                                                               [A
128it [00:00, 373.24it/s][A
192it [00:00, 443.34it/s][A
256it [00:00, 474.42it/s][A
320it [00:00, 461.09it/s][A
384it [00:00, 426.07it/s][A
448it [00:01, 465.87it/s][A
512it [00:01, 465.97it/s][A
576it [00:01, 479.04it/s][A
640it [00:01, 459.82it/s][A
704it [00:01, 426.88it/s][A
801it [00:01, 422.95it/s][A


meta_learner: train task loss: 0.499 - val task loss: 0.639



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 318.50it/s]                                                                                               [A
128it [00:00, 367.06it/s][A
192it [00:00, 386.84it/s][A
256it [00:00, 397.29it/s][A
320it [00:00, 393.33it/s][A
384it [00:00, 416.28it/s][A
448it [00:01, 403.06it/s][A
512it [00:01, 397.63it/s][A
576it [00:01, 402.10it/s][A
640it [00:01, 410.71it/s][A
704it [00:01, 441.93it/s][A
801it [00:01, 407.09it/s][A


meta_learner: train task loss: 0.498 - val task loss: 0.617



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 338.78it/s]                                                                                               [A
128it [00:00, 386.94it/s][A
192it [00:00, 450.68it/s][A
256it [00:00, 461.47it/s][A
320it [00:00, 479.92it/s][A
384it [00:00, 458.81it/s][A
448it [00:01, 447.94it/s][A
512it [00:01, 440.78it/s][A
576it [00:01, 439.20it/s][A
640it [00:01, 432.60it/s][A
704it [00:01, 437.05it/s][A
801it [00:01, 416.32it/s][A


meta_learner: train task loss: 0.496 - val task loss: 0.662



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 321.90it/s]                                                                                               [A
128it [00:00, 384.18it/s][A
192it [00:00, 373.97it/s][A
256it [00:00, 391.66it/s][A
320it [00:00, 399.21it/s][A
384it [00:00, 392.47it/s][A
448it [00:01, 426.14it/s][A
512it [00:01, 445.88it/s][A
576it [00:01, 450.78it/s][A
640it [00:01, 486.31it/s][A
704it [00:01, 512.11it/s][A
801it [00:01, 442.80it/s][A


meta_learner: train task loss: 0.481 - val task loss: 0.593



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 411.29it/s]                                                                                               [A
128it [00:00, 515.40it/s][A
192it [00:00, 506.61it/s][A
256it [00:00, 476.91it/s][A
320it [00:00, 513.40it/s][A
384it [00:00, 542.03it/s][A
448it [00:00, 551.64it/s][A
512it [00:00, 559.48it/s][A
576it [00:01, 572.53it/s][A
640it [00:01, 542.24it/s][A
704it [00:01, 545.11it/s][A
801it [00:01, 511.64it/s][A


meta_learner: train task loss: 0.451 - val task loss: 0.659



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 412.49it/s]                                                                                               [A
128it [00:00, 459.67it/s][A
192it [00:00, 499.21it/s][A
256it [00:00, 516.64it/s][A
320it [00:00, 542.23it/s][A
384it [00:00, 551.28it/s][A
448it [00:00, 551.73it/s][A
512it [00:00, 550.81it/s][A
576it [00:01, 532.69it/s][A
640it [00:01, 515.34it/s][A
704it [00:01, 513.54it/s][A
801it [00:01, 489.24it/s][A


meta_learner: train task loss: 0.490 - val task loss: 0.600



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 358.14it/s]                                                                                               [A
128it [00:00, 402.66it/s][A
192it [00:00, 405.33it/s][A
256it [00:00, 420.74it/s][A
320it [00:00, 411.63it/s][A
384it [00:00, 417.56it/s][A
448it [00:01, 406.71it/s][A
512it [00:01, 420.79it/s][A
576it [00:01, 432.43it/s][A
640it [00:01, 437.46it/s][A
704it [00:01, 418.59it/s][A
801it [00:01, 403.67it/s][A


meta_learner: train task loss: 0.494 - val task loss: 0.666
Done!
Selecting greedy ensemble on the best cohort...
Pruned 1 worst models, keeping 3 models
Initial best models: [2, 1, 3] with losses: [tensor(0.5688, grad_fn=<NllLossBackward0>), tensor(0.5800, grad_fn=<NllLossBackward0>), tensor(0.5893, grad_fn=<NllLossBackward0>)]
Done!
Method: (simple_average), Test_Accuracy: 0.7591036414565826
Method: (weighted_average), Test_Accuracy: 0.7591036414565826
Method: (meta_learner), Test_Accuracy: 0.7703081232492998
Method: (greedy_ensemble), Test_Accuracy: 0.7591036414565826
Method: (best_single), Test_Accuracy: 0.7394957983193278
Method: (cohort), Test_Accuracy: [0.7366946778711485, 0.7366946778711485, 0.7394957983193278, 0.7619047619047619]
Training meta learner with no disagreement penalty...



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 292.34it/s]                                                                                               [A
128it [00:00, 361.39it/s][A
192it [00:00, 395.39it/s][A
256it [00:00, 395.12it/s][A
320it [00:00, 414.84it/s][A
384it [00:00, 421.14it/s][A
448it [00:01, 438.18it/s][A
512it [00:01, 441.48it/s][A
576it [00:01, 468.10it/s][A
640it [00:01, 468.66it/s][A
704it [00:01, 468.55it/s][A
801it [00:01, 422.42it/s][A


meta_learner: train task loss: 0.935 - val task loss: 0.647 [*] Best so far



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 286.77it/s]                                                                                               [A
128it [00:00, 329.01it/s][A
192it [00:00, 365.30it/s][A
256it [00:00, 369.96it/s][A
320it [00:00, 366.59it/s][A
384it [00:01, 379.82it/s][A
448it [00:01, 393.84it/s][A
512it [00:01, 427.18it/s][A
576it [00:01, 433.89it/s][A
640it [00:01, 454.04it/s][A
704it [00:01, 470.48it/s][A
801it [00:01, 408.25it/s][A


meta_learner: train task loss: 0.684 - val task loss: 0.589 [*] Best so far



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 315.99it/s]                                                                                               [A
128it [00:00, 403.03it/s][A
192it [00:00, 457.60it/s][A
256it [00:00, 474.82it/s][A
320it [00:00, 484.15it/s][A
384it [00:00, 495.45it/s][A
448it [00:00, 494.89it/s][A
576it [00:01, 562.88it/s][A
640it [00:01, 560.56it/s][A
704it [00:01, 537.96it/s][A
801it [00:01, 484.11it/s][A


meta_learner: train task loss: 0.630 - val task loss: 0.600



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 347.94it/s]                                                                                               [A
128it [00:00, 397.52it/s][A
192it [00:00, 416.84it/s][A
256it [00:00, 462.29it/s][A
320it [00:00, 469.02it/s][A
384it [00:00, 460.15it/s][A
448it [00:01, 451.11it/s][A
512it [00:01, 440.01it/s][A
576it [00:01, 424.41it/s][A
640it [00:01, 403.14it/s][A
704it [00:01, 400.59it/s][A
801it [00:01, 406.87it/s][A


meta_learner: train task loss: 0.581 - val task loss: 0.557 [*] Best so far



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 323.70it/s]                                                                                               [A
128it [00:00, 377.50it/s][A
192it [00:00, 394.44it/s][A
256it [00:00, 421.15it/s][A
320it [00:00, 442.13it/s][A
384it [00:00, 438.77it/s][A
448it [00:01, 457.19it/s][A
512it [00:01, 495.35it/s][A
576it [00:01, 494.84it/s][A
640it [00:01, 486.06it/s][A
704it [00:01, 481.75it/s][A
801it [00:01, 438.44it/s][A


meta_learner: train task loss: 0.557 - val task loss: 0.575



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 302.23it/s]                                                                                               [A
128it [00:00, 355.91it/s][A
192it [00:00, 396.31it/s][A
256it [00:00, 427.81it/s][A
320it [00:00, 462.03it/s][A
384it [00:00, 487.96it/s][A
448it [00:00, 501.93it/s][A
512it [00:01, 517.08it/s][A
576it [00:01, 499.30it/s][A
640it [00:01, 504.63it/s][A
704it [00:01, 511.51it/s][A
801it [00:01, 462.70it/s][A


meta_learner: train task loss: 0.587 - val task loss: 0.582



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 359.99it/s]                                                                                               [A
128it [00:00, 438.82it/s][A
192it [00:00, 428.87it/s][A
256it [00:00, 418.44it/s][A
320it [00:00, 420.44it/s][A
384it [00:00, 443.07it/s][A
448it [00:01, 471.26it/s][A
512it [00:01, 494.13it/s][A
576it [00:01, 490.46it/s][A
640it [00:01, 451.89it/s][A
704it [00:01, 442.35it/s][A
801it [00:01, 424.44it/s][A


meta_learner: train task loss: 0.549 - val task loss: 0.585



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 370.23it/s]                                                                                               [A
128it [00:00, 444.75it/s][A
192it [00:00, 474.44it/s][A
256it [00:00, 465.23it/s][A
320it [00:00, 495.46it/s][A
384it [00:00, 474.67it/s][A
448it [00:00, 487.39it/s][A
512it [00:01, 474.58it/s][A
576it [00:01, 459.41it/s][A
640it [00:01, 424.51it/s][A
704it [00:01, 444.89it/s][A
801it [00:01, 440.36it/s][A


meta_learner: train task loss: 0.542 - val task loss: 0.620



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 370.18it/s]                                                                                               [A
128it [00:00, 420.94it/s][A
192it [00:00, 433.52it/s][A
256it [00:00, 493.78it/s][A
320it [00:00, 452.84it/s][A
384it [00:00, 432.85it/s][A
448it [00:01, 420.55it/s][A
512it [00:01, 422.20it/s][A
576it [00:01, 438.70it/s][A
640it [00:01, 452.20it/s][A
704it [00:01, 452.86it/s][A
801it [00:01, 431.91it/s][A


meta_learner: train task loss: 0.543 - val task loss: 0.694



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 323.87it/s]                                                                                               [A
128it [00:00, 393.28it/s][A
192it [00:00, 409.80it/s][A
256it [00:00, 393.03it/s][A
320it [00:00, 398.38it/s][A
384it [00:00, 400.69it/s][A
448it [00:01, 419.46it/s][A
512it [00:01, 434.47it/s][A
576it [00:01, 443.27it/s][A
640it [00:01, 466.38it/s][A
704it [00:01, 459.41it/s][A
801it [00:01, 422.19it/s][A


meta_learner: train task loss: 0.534 - val task loss: 0.570



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 409.94it/s]                                                                                               [A
128it [00:00, 436.64it/s][A
192it [00:00, 470.83it/s][A
256it [00:00, 478.92it/s][A
320it [00:00, 473.04it/s][A
384it [00:00, 500.21it/s][A
448it [00:00, 501.79it/s][A
512it [00:01, 517.61it/s][A
576it [00:01, 509.40it/s][A
640it [00:01, 509.09it/s][A
704it [00:01, 534.23it/s][A
801it [00:01, 491.32it/s][A


meta_learner: train task loss: 0.532 - val task loss: 0.586



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 357.84it/s]                                                                                               [A
128it [00:00, 379.24it/s][A
192it [00:00, 448.50it/s][A
256it [00:00, 469.62it/s][A
320it [00:00, 478.37it/s][A
384it [00:00, 505.17it/s][A
448it [00:00, 519.67it/s][A
512it [00:01, 522.98it/s][A
576it [00:01, 502.91it/s][A
640it [00:01, 482.95it/s][A
704it [00:01, 458.02it/s][A
801it [00:01, 431.66it/s][A


meta_learner: train task loss: 0.564 - val task loss: 0.625



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 309.26it/s]                                                                                               [A
128it [00:00, 375.60it/s][A
192it [00:00, 376.18it/s][A
256it [00:00, 434.49it/s][A
320it [00:00, 468.48it/s][A
384it [00:00, 493.07it/s][A
448it [00:01, 483.99it/s][A
512it [00:01, 500.01it/s][A
576it [00:01, 505.53it/s][A
640it [00:01, 496.78it/s][A
704it [00:01, 507.47it/s][A
801it [00:01, 445.17it/s][A


meta_learner: train task loss: 0.567 - val task loss: 0.575



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 296.17it/s]                                                                                               [A
128it [00:00, 347.85it/s][A
192it [00:00, 413.86it/s][A
256it [00:00, 419.96it/s][A
320it [00:00, 419.86it/s][A
384it [00:00, 426.85it/s][A
448it [00:01, 457.54it/s][A
512it [00:01, 475.83it/s][A
576it [00:01, 478.43it/s][A
640it [00:01, 482.29it/s][A
704it [00:01, 489.54it/s][A
801it [00:01, 432.43it/s][A


meta_learner: train task loss: 0.592 - val task loss: 0.571



  0%|                                                                                           | 0/13 [00:00<?, ?it/s][A
64it [00:00, 406.50it/s]                                                                                               [A
128it [00:00, 444.53it/s][A
192it [00:00, 472.44it/s][A
256it [00:00, 509.58it/s][A
320it [00:00, 503.55it/s][A
384it [00:00, 491.53it/s][A
448it [00:00, 497.95it/s][A
512it [00:01, 525.58it/s][A
576it [00:01, 530.92it/s][A
640it [00:01, 516.32it/s][A
704it [00:01, 523.27it/s][A
801it [00:01, 485.41it/s][A
Repetitions: 100%|██████████████████████████████████████████████████████████████████████| 1/1 [02:16<00:00, 136.66s/it]

meta_learner: train task loss: 0.566 - val task loss: 0.672
Done!
Method: (simple_average), Test_Accuracy: 0.6946778711484594
Method: (weighted_average), Test_Accuracy: 0.6946778711484594
Method: (meta_learner), Test_Accuracy: 0.7647058823529411
Method: (greedy_ensemble), Test_Accuracy: 0.6834733893557423
Method: (best_single), Test_Accuracy: 0.7198879551820728
Method: (cohort), Test_Accuracy: [0.7282913165266106, 0.7030812324929971, 0.7198879551820728, 0.6666666666666666]
Finished running meta fusion!





In [10]:
results

Unnamed: 0,Method,Test_metric,best_rho,cohort_pairs,ensemble_idxs,cluster_idxs,random_state,dim_modalities,n,n_train,n_val,n_test
0,modality_1,0.621849,,,,,1,"[104, 39, 28, 192]",1426,801,268,357
1,modality_2,0.641457,,,,,1,"[104, 39, 28, 192]",1426,801,268,357
2,modality_3,0.593838,,,,,1,"[104, 39, 28, 192]",1426,801,268,357
3,modality_4,0.669468,,,,,1,"[104, 39, 28, 192]",1426,801,268,357
4,early_fusion,0.717087,,,,,1,"[104, 39, 28, 192]",1426,801,268,357
5,late_fusion,0.621849,,,,,1,"[104, 39, 28, 192]",1426,801,268,357
6,simple_average,0.759104,,,,,1,"[104, 39, 28, 192]",1426,801,268,357
7,weighted_average,0.759104,,,,,1,"[104, 39, 28, 192]",1426,801,268,357
8,meta_learner,0.770308,,,,,1,"[104, 39, 28, 192]",1426,801,268,357
9,greedy_ensemble,0.759104,,,"[2, 1, 3]",,1,"[104, 39, 28, 192]",1426,801,268,357
