In [20]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import random
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import shutil

sys.path.append('../')
from meta_fusion.benchmarks import *
from meta_fusion.methods import *
from meta_fusion.models import *
from meta_fusion.utils import *
from meta_fusion.third_party import *
from meta_fusion.synthetic_data import Prepare_synthetic_data
from meta_fusion.config import *

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
# ---------------------------------------------------------------------------
# Experiment bookkeeping
# ---------------------------------------------------------------------------
# `repetition` and `seed` together determine the per-repetition random states
# used by the experiment loop; `seed` and `scale` also appear in `ckpt_dir`.
repetition = 1
seed = 1
scale = 0  # written to config['divergence_weight_scale']

# ---------------------------------------------------------------------------
# Synthetic data-generating model
# ---------------------------------------------------------------------------
n = 2000
dim_modalities = [2000, 400]
d1, d2 = dim_modalities
dim_latent = [50, 30, 20]
noise_ratios = [0.4, 0.3]
trans_type = ["linear", "linear", "quadratic"]
mod_prop = [0, 0, 1, 0]
interactive_prop = 0

# ---------------------------------------------------------------------------
# Cohort architecture
# ---------------------------------------------------------------------------
# Candidate per-modality extractor sizes handed to `Extractors` (presumably the
# extractor output widths). A wider sweep, [40, 60, 80, 100] for both
# modalities, was used previously.
# NOTE(review): the saved results table lists cohort pairs such as (100, 40),
# which these lists cannot produce -- the displayed outputs look stale; re-run.
mod1_outs = [0, 60, 80]
mod2_outs = [0, 60, 80]
combined_hiddens = [300, 200, 100]
mod1_hiddens = mod2_hiddens = [128]

# ---------------------------------------------------------------------------
# Task / naming
# ---------------------------------------------------------------------------
data_name = 'regression'
exp_name = f"{data_name}_quadratic_late"
output_dim = 1  # regression output dimension

# ---------------------------------------------------------------------------
# Feature-extractor options (read as globals by run_single_experiment)
# ---------------------------------------------------------------------------
extractor_type = 'encoder'
separate = False
is_static_mod1 = False
is_static_mod2 = False
freeze_extractor_mod1 = False
freeze_extractor_mod2 = False

# ---------------------------------------------------------------------------
# Training configuration: defaults from JSON, then per-experiment overrides
# ---------------------------------------------------------------------------
config = load_config('../experiments_synthetic/config.json')
extractor_config = load_config('../experiments_synthetic/config_extractor.json')

# Checkpoints are grouped by experiment name, divergence scale and seed.
ckpt_dir = f"./checkpoints/{exp_name}/scale{scale}_seed{seed}/"

# NOTE(review): epochs = 1 looks like a quick smoke-test setting -- confirm
# before reporting results.
config['ckpt_dir'] = ckpt_dir
config['divergence_weight_scale'] = scale
config['output_dim'] = output_dim
config["init_lr"] = 0.001
config["epochs"] = 1
config["ensemble_methods"] = [
    "simple_average",
    "weighted_average",
    "greedy_ensemble",
]
config["rho_list"] = [0]

extractor_config['ckpt_dir'] = ckpt_dir
extractor_config['output_dim'] = output_dim
extractor_config["init_lr"] = 0.001
extractor_config["epochs"] = 1

# ---------------------------------------------------------------------------
# Dataset generator / splitter
# ---------------------------------------------------------------------------
data_preparer = Prepare_synthetic_data(data_name=data_name, test_size=0.2, val_size=0.2)

In [37]:
def run_single_experiment(config, extractor_config, n,
                          random_state, 
                          run_oracle=False, run_coop=True, 
                          run_coop_linear=False, run_all_at_once=False):
    """Run one repetition of the synthetic experiment and return its test metrics.

    Generates and splits a synthetic dataset, trains the basic benchmarks
    (optionally also the oracle, Coop and CoopLinear baselines), then trains
    and evaluates the meta-fusion cohort, its ablation and, optionally, the
    all-at-once variant.

    The function also reads notebook-level globals: ``data_preparer``,
    ``data_name``, ``trans_type``, ``mod_prop``, ``interactive_prop``,
    ``dim_modalities``, ``dim_latent``, ``noise_ratios``, ``mod1_outs``,
    ``mod2_outs``, ``mod1_hiddens``, ``mod2_hiddens``, ``combined_hiddens``,
    ``output_dim``, ``extractor_type``, ``separate``, ``is_static_mod1``,
    ``is_static_mod2``, ``freeze_extractor_mod1``, ``freeze_extractor_mod2``
    and ``scale``.

    Args:
        config: training configuration shared by the benchmark and meta-fusion
            trainers. ``config['random_state']`` is set in place. Per-method
            overrides (oracle learning rate, all-at-once epochs) are applied to
            shallow copies, so later calls with the same ``config`` are not
            affected.
        extractor_config: configuration used to build encoder extractors;
            ``extractor_config['random_state']`` is set in place.
        n: requested data size passed to ``data_preparer.get_data_loaders``.
        random_state: seed forwarded to the data split and stored in both configs.
        run_oracle: also train/test benchmarks on the oracle (latent) loaders.
        run_coop: also train/test ``Coop``.
        run_coop_linear: also train/test ``CoopLinear``.
        run_all_at_once: also train/test ``Trainer_all_at_once`` (30 epochs).

    Returns:
        pandas.DataFrame with one row per reported method and columns
        ``Method``, ``Test_metric``, ``best_rho``, ``cohort_pairs``,
        ``ensemble_idxs``, ``random_state``, ``d1``, ``d2``, ``n``,
        ``n_train``, ``n_val``, ``n_test``, ``scale``.

    Raises:
        ValueError: if the global ``extractor_type`` is neither ``'encoder'``
            nor ``'PCA'`` (previously this silently built no extractors).
    """
    import copy  # local import: the notebook's import cell does not import `copy`

    # Fail fast, before any data generation or training is spent.
    if extractor_type not in ('encoder', 'PCA'):
        raise ValueError(
            f"Unsupported extractor_type {extractor_type!r}; expected 'encoder' or 'PCA'."
        )

    config['random_state'] = random_state
    extractor_config['random_state'] = random_state
    res_list = []
    best_rho = {}
    cohort_pairs = {}
    ens_idxs = {}

    #----------------#
    # Split dataset  #
    #----------------#
    (train_loader, val_loader, test_loader,
     oracle_train_loader, oracle_val_loader, oracle_test_loader) = data_preparer.get_data_loaders(
        n, trans_type=trans_type, mod_prop=mod_prop,
        interactive_prop=interactive_prop,
        dim_modalities=dim_modalities, dim_latent=dim_latent,
        noise_ratios=noise_ratios, random_state=random_state)

    # get_data_info() is read positionally as (d1, d2, n, n_train, n_val, n_test);
    # `n` is replaced by the size reported by the data preparer.
    data_info = data_preparer.get_data_info()
    d1 = data_info[0]
    d2 = data_info[1]
    n = data_info[2]
    n_train = data_info[3]
    n_val = data_info[4]
    n_test = data_info[5]

    print(f"Finished splitting {data_name} dataset. Data information are summarized below:\n"
            f"Modality 1 dimension: {d1}\n"
            f"Modality 2 dimension: {d2}\n"
            f"Data size: {n}\n"
            f"Train size: {n_train}\n"
            f"Val size: {n_val}\n"
            f"Test size: {n_test}")
    sys.stdout.flush() 

    #------------------#
    # Benchmark models #
    #------------------#
    bm_extractor = Extractors([0, d1], [0, d2], d1, d2, train_loader, val_loader)
    _ = bm_extractor.get_dummy_extractors()
    bm_cohort = Cohorts(extractors=bm_extractor, combined_hidden_layers=combined_hiddens, output_dim=output_dim)

    if run_oracle:
        # Oracle inputs: modality 1 sees the first latent block, modality 2 the other two.
        oracle_d1 = dim_latent[0]
        oracle_d2 = dim_latent[1] + dim_latent[2]
        oracle_extractor = Extractors([0, oracle_d1], [0, oracle_d2], oracle_d1, oracle_d2,
                                      oracle_train_loader, oracle_val_loader)
        _ = oracle_extractor.get_dummy_extractors()
        oracle_cohort = Cohorts(extractors=oracle_extractor, combined_hidden_layers=combined_hiddens, output_dim=output_dim)

    #----------------------------#
    # Proposed model: MetaJoint  #
    #----------------------------#
    meta_extractor = Extractors(mod1_outs, mod2_outs, d1, d2, train_loader, val_loader)
    if extractor_type == 'encoder':
        _ = meta_extractor.get_encoder_extractors(mod1_hiddens, mod2_hiddens, separate=separate, config=extractor_config)
    else:  # 'PCA' (validated above)
        _ = meta_extractor.get_PCA_extractors()
    meta_cohort = Cohorts(extractors=meta_extractor, combined_hidden_layers=combined_hiddens, output_dim=output_dim,
                          is_static_mod1=is_static_mod1, is_static_mod2=is_static_mod2,
                          freeze_extractor_mod1=freeze_extractor_mod1, freeze_extractor_mod2=freeze_extractor_mod2)

    #------------------------------#
    #  Train and test benchmarks   #
    #------------------------------#
    bm_models = bm_cohort.get_cohort_models()
    _, bm_dims = bm_cohort.get_cohort_info()
    bm = Benchmarks(config, bm_models, bm_dims, [train_loader, val_loader])
    bm.train()
    res = bm.test(test_loader)
    res_list.append(res)
    print("Finished running basic benchmarks!")

    if run_oracle:
        # Copy instead of aliasing: the learning-rate override must not leak into
        # `config`, which is shared with the other methods and later repetitions.
        oracle_config = copy.copy(config)
        oracle_config["init_lr"] = 0.001
        oracle_models = oracle_cohort.get_cohort_models()
        _, oracle_dims = oracle_cohort.get_cohort_info()
        oracle = Benchmarks(oracle_config, oracle_models, oracle_dims, [oracle_train_loader, oracle_val_loader])
        oracle.train()
        res = oracle.test(oracle_test_loader)
        res = {f"oracle_{key}": value for key, value in res.items()}
        res_list.append(res)
        print("Finished running oracle benchmarks!")

    if run_coop:
        # Models are requested from the cohort again for Coop.
        # NOTE(review): confirm get_cohort_models() returns fresh (untrained) models.
        bm_models = bm_cohort.get_cohort_models()
        _, bm_dims = bm_cohort.get_cohort_info()
        coop = Coop(config, bm_models, bm_dims, [train_loader, val_loader])
        coop.train()
        res = coop.test(test_loader)
        res_list.append(res)
        best_rho['coop'] = coop.best_rho
        print("Finished running coop!")

    if run_coop_linear:
        cooplinear = CoopLinear([train_loader, val_loader], rho_list=config['rho_list'])
        cooplinear.train()
        res = cooplinear.test(test_loader)
        res_list.append(res)
        print("Finished running coop linear!")

    #------------------------------#
    #  Train and test MetaJoint    #
    #------------------------------#
    cohort_models = meta_cohort.get_cohort_models()
    _, dim_pairs = meta_cohort.get_cohort_info()
    metafuse = Trainer(config, cohort_models, [train_loader, val_loader])
    metafuse.train()
    res = metafuse.test(test_loader)
    res_list.append(res)

    metafuse.train_ablation()
    res = metafuse.test_ablation(test_loader)
    res_list.append(res)

    best_rho['meta_learner'] = metafuse.best_rho
    cohort_pairs['cohort'] = dim_pairs
    cohort_pairs['indep_cohort'] = dim_pairs
    if "greedy_ensemble" in config["ensemble_methods"]:
        ens_idxs['greedy_ensemble'] = metafuse.ens_idxs

    print("Finished running meta fusion!")

    if run_all_at_once:
        cohort_models = meta_cohort.get_cohort_models()
        # Copy so the 30-epoch override does not persist into `config`; the
        # original in-place write made every later repetition train for 30 epochs.
        allin1_config = copy.copy(config)
        allin1_config['epochs'] = 30
        allin1 = Trainer_all_at_once(allin1_config, cohort_models, [train_loader, val_loader])
        allin1.train()
        res = allin1.test(test_loader)
        res_list.append(res)
        best_rho['all_at_once'] = allin1.best_rho
        print("Finished running meta fusion all in one model!")

    # One row per (method, metric) pair reported by any of the runs above;
    # per-method extras are looked up by method name and are None when absent.
    records = [
        {'Method': method, 'Test_metric': val,
         'best_rho': best_rho.get(method), 'cohort_pairs': cohort_pairs.get(method),
         'ensemble_idxs': ens_idxs.get(method)}
        for res in res_list
        for method, val in res.items()
    ]

    # Scalar metadata is broadcast to every row.
    results = pd.DataFrame(records).assign(
        random_state=random_state,
        d1=d1,
        d2=d2,
        n=n,
        n_train=n_train,
        n_val=n_val,
        n_test=n_test,
        scale=scale,
    )

    return results

In [38]:
#####################
#  Run Experiments  #
#####################
# Each `seed` owns a disjoint block of `repetition` consecutive random states:
# seed=1 -> 1..repetition, seed=2 -> repetition+1..2*repetition, and so on.
results = []

for i in range(1, repetition + 1):
    random_state = (seed - 1) * repetition + i
    print(f'Running with repetition {i}...')
    set_random_seed(random_state)

    tmp = run_single_experiment(
        config,
        extractor_config,
        n,
        random_state=random_state,
        run_oracle=False,
        run_coop=False,
    )
    results.append(tmp)

# Stack the per-repetition tables into a single frame with a fresh 0..N-1 index.
results = pd.concat(results, ignore_index=True)

Running with repetition 1...
Finished splitting regression dataset. Data information are summarized below:
Modality 1 dimension: 2000
Modality 2 dimension: 400
Data size: 2000
Train size: 1280
Val size: 320
Test size: 400


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:01<00:00, 928.56it/s]


Start training benchmark models...
Training with disagreement penalty = 0

Epoch: 1/1 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:02<00:00, 541.23it/s]


model_1: train loss: 180.765, train task loss: 180.765 - val loss: 129.376, val task loss: 129.376 [*] Best so far
model_2: train loss: 178.988, train task loss: 178.988 - val loss: 114.404, val task loss: 114.404 [*] Best so far
model_3: train loss: 200.411, train task loss: 200.411 - val loss: 104.654, val task loss: 104.654 [*] Best so far
Finished training benchmark models!
Method: (mod1), Test_MSE: 134.7310791015625
Method: (mod2), Test_MSE: 147.7878875732422
Method: (early_fusion), Test_MSE: 147.08448791503906
Method: (late_fusion), Test_MSE: 129.3821258544922
Finished running basic benchmarks!
Start training student cohort...
Training with disagreement penalty = 0

Epoch: 1/1 - LR: 0.001000


100%|█████████████████████████████████████████████████████████████████████████████| 1280/1280 [00:08<00:00, 152.12it/s]


model_1: train loss: 154.208, train task loss: 154.208 - val loss: 110.078, val task loss: 110.078 [*] Best so far
model_2: train loss: 151.313, train task loss: 151.313 - val loss: 110.494, val task loss: 110.494 [*] Best so far
model_3: train loss: 154.951, train task loss: 154.951 - val loss: 110.940, val task loss: 110.940 [*] Best so far
model_4: train loss: 150.588, train task loss: 150.588 - val loss: 112.007, val task loss: 112.007 [*] Best so far
model_5: train loss: 185.704, train task loss: 185.704 - val loss: 109.613, val task loss: 109.613 [*] Best so far
Finished training student cohort!
Selecting the optimal disgreement penalty via cross-validation...
Best rho: 0 with average task loss: 110.6265
Done!
Selecting greedy ensemble on the best cohort...
Pruned 1 worst models, keeping 4 models
Initial best models: [4, 0, 1] with losses: [tensor(109.6128, grad_fn=<MseLossBackward0>), tensor(110.0782, grad_fn=<MseLossBackward0>), tensor(110.4945, grad_fn=<MseLossBackward0>)]
Don

In [39]:
# Display the concatenated per-method results table produced by the experiment loop.
results

Unnamed: 0,Method,Test_metric,best_rho,cohort_pairs,ensemble_idxs,random_state,d1,d2,n,n_train,n_val,n_test,scale
0,mod1,134.731079,,,,1,2000,400,2000,1280,320,400,0
1,mod2,147.787888,,,,1,2000,400,2000,1280,320,400,0
2,early_fusion,147.084488,,,,1,2000,400,2000,1280,320,400,0
3,late_fusion,129.382126,,,,1,2000,400,2000,1280,320,400,0
4,simple_average,132.346161,,,,1,2000,400,2000,1280,320,400,0
5,weighted_average,132.287064,,,,1,2000,400,2000,1280,320,400,0
6,greedy_ensemble,128.919693,,,"[4, 0, 1]",1,2000,400,2000,1280,320,400,0
7,best_single,137.186935,,,,1,2000,400,2000,1280,320,400,0
8,cohort,"[140.25405883789062, 139.87057495117188, 140.0...",,"[(100, 40), (100, 0), (60, 40), (60, 0), (0, 40)]",,1,2000,400,2000,1280,320,400,0
9,indep_simple_average,132.346161,,,,1,2000,400,2000,1280,320,400,0


In [1]:
from collections import Counter

In [2]:
# NOTE(review): scratch cell -- execution counts restart at In [1] (a different
# kernel session) and nothing in the experiment uses `ensemble`; consider removing.
ensemble = Counter()

In [3]:
# Counter.update with an iterable adds 1 to the count of each element (1, 2 and 5).
ensemble.update([1,2,5])

In [5]:
# Iterating a Counter yields its distinct keys in first-insertion order ([1, 2, 5] here).
list(ensemble)

[1, 2, 5]