In [1]:
import numpy as np
import sys
import os
import pickle
import torch
sys.path.append(os.getcwd())
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))
from supervised_learning import test
from ensemble import test as test_ensemble
from get_data import get_dataloader
from rus import *
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize
from sklearn.cluster import KMeans

#### dataset clustering

In [3]:
def clustering(X, pca=False, n_clusters=20, n_components=5):
  X = np.nan_to_num(X)
  if len(X.shape) > 2:
    X = X.reshape(X.shape[0],-1)
  if pca:
    # print(np.any(np.isnan(X)), np.all(np.isfinite(X)))
    X = normalize(X)
    X = PCA(n_components=n_components).fit_transform(X)
  kmeans = KMeans(n_clusters=n_clusters).fit(X)
  return kmeans.labels_, X

In [3]:
for i in range(1, 6):
    data_dir = 'synthetic/model_selection/DATA_synthetic{}.pickle'.format(i)
    dataset = pd.read_pickle(data_dir)
    n_components = 2
    data_cluster = dict()
    for split in ['valid2', 'test']:
        data_cluster[split] = dict()
        data = dataset[split]
        kmeans_0, data_0 = clustering(data['0'], pca=True, n_components=n_components, n_clusters=20)
        data_cluster[split]['0'] = kmeans_0.reshape(-1,1)
        kmeans_1, data_1 = clustering(data['1'], pca=True, n_components=n_components, n_clusters=20)
        data_cluster[split]['1'] = kmeans_1.reshape(-1,1)
        data_cluster[split]['label'] = data['label']
    with open('synthetic/model_selection/DATA_synthetic{}_cluster.pickle'.format(i), 'wb') as f:
        pickle.dump(data_cluster, f)

#### get dataset measures

In [4]:
results = dict()
for i in range(1,6):
    with open('synthetic/model_selection/DATA_synthetic{}_cluster.pickle'.format(i), 'rb') as f:
        dataset = pickle.load(f)
    print('synthetic', i)
    data = (dataset['test']['0'], dataset['test']['1'], dataset['test']['label'])
    P, maps = convert_data_to_distribution(*data)
    result = get_measure(P)
    results['synthetic{}'.format(i)] = result
    print()

with open('synthetic/model_selection/datasets.pickle', 'wb') as f:
    pickle.dump(results, f)

synthetic 1
Redundancy 0.042766259538634416
Unique 0.02596364321132929
Unique 2.17157384344753e-15
Synergy 0.10614122901494483

synthetic 2
Redundancy 0.012672691701037832
Unique 2.8259215299914983e-05
Unique 0.0061237290331549
Synergy 0.05926624624675512

synthetic 3
Redundancy 0.027742116209452286
Unique 0.025980735076023997
Unique 1.4684105273299479e-15
Synergy 0.06713126114497113

synthetic 4
Redundancy 0.06557045989622268
Unique 2.1518739480906317e-15
Unique 0.056037809617245915
Synergy 0.053704370671353585

synthetic 5
Redundancy 0.03943597142606828
Unique 0.043663108779810396
Unique 4.6119695336763056e-15
Synergy 0.06589532206160133



#### post-process model predictions

In [8]:
METHODS = ['additive', 'agree', 'align', 'early_fusion', 'elem', 'mfm', 'mi', 'mult', 'outer', 'lower']
MEASURES = ['redundancy', 'unique1', 'unique2', 'synergy']
SETTINGS = ['synthetic{}'.format(i) for i in range(1,6)] + ['maps', 'maps2']

In [6]:
if os.path.isfile('synthetic/model_selection/results.pickle'):
    with open('synthetic/model_selection/results.pickle', 'rb') as f:
        results = pickle.load(f)
else:
    results = dict()
for method in METHODS:
    results[method] = dict()
    for SETTING in SETTINGS:
        print(SETTING)
        data_path = 'synthetic/model_selection/DATA_{}.pickle'.format(SETTING)
        saved_model = 'synthetic/model_selection/{}/{}_{}_best.pt'.format(SETTING, SETTING, method)
        saved_cluster = 'synthetic/model_selection/{}/{}_{}_cluster.pickle'.format(SETTING, SETTING, method)
        _, _, _, testdata = get_dataloader(path=data_path, keys=['0','1','label'], modalities=[1,1], batch_size=128, num_workers=4)
        model = torch.load(saved_model).cuda()
        num_params = sum([p.numel() for p in model.parameters()])
        if method in ['additive', 'agree', 'align']:
            acc = test_ensemble(model, testdata, no_robust=True, criterion=torch.nn.CrossEntropyLoss(), save_preds=saved_cluster)
        else:
            acc = test(model, testdata, no_robust=True, criterion=torch.nn.CrossEntropyLoss(), save_preds=saved_cluster)
        with open(saved_cluster, 'rb') as f:
            preds = pickle.load(f)
        with open('synthetic/model_selection/DATA_{}_cluster.pickle'.format(SETTING), 'rb') as f:
            cluster = pickle.load(f)
        pred_results = (cluster['test']['0'], cluster['test']['1'], preds.reshape(-1,1))
        P, maps = convert_data_to_distribution(*pred_results)
        results[method][SETTING] = get_measure(P)
        results[method][SETTING]['acc'] = acc
        results[method][SETTING]['params'] = num_params
        print()

with open('synthetic/model_selection/results.pickle', 'wb') as f:
    pickle.dump(results, f)

synthetic1
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
loss: tensor(0.8502, device='cuda:0')
acc: 0.6283333333333333
Redundancy 0.2707825379424073
Unique 0.16334202068341028
Unique 3.266602334664933e-07
Synergy 0.09782503764897238

synthetic2
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




loss: tensor(0.8568, device='cuda:0')
acc: 0.5616666666666666
Redundancy 0.2539995938687824
Unique 0.23269513344692727
Unique 2.354576132383619e-07
Synergy 0.05884751091946233

synthetic3
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




loss: tensor(1.0994, device='cuda:0')
acc: 0.4261111111111111
Redundancy 0.06405922085532612
Unique 0.23564387158508435
Unique 1.0616968579028281e-08
Synergy 0.14250508125130634

synthetic4
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




loss: tensor(0.7412, device='cuda:0')
acc: 0.6127777777777778
Redundancy 0.1908922899482186
Unique 0.01327551736675702
Unique 0.02295801843955826
Synergy 0.23860715986087638

synthetic5
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




loss: tensor(0.7490, device='cuda:0')
acc: 0.6061111111111112
Redundancy 0.2944541194736107
Unique 3.002523193074523e-07
Unique 0.06322643586213239
Synergy 0.04327024215039316

maps
Train data: 672
Valid data: 84
Test data: 88




loss: tensor(32.1191, device='cuda:0')
acc: 0.38636363636363635
Redundancy 0.1404448135200927
Unique 0.3070674037705854
Unique 6.514528330414106e-08
Synergy 0.22755536867441323

synthetic1
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




loss: tensor(0.6294, device='cuda:0')
acc: 0.6566666666666666
Redundancy 0.15878912442105286
Unique 2.3108788509712802e-14
Unique 0.06375439449212335
Synergy 0.18670017342710815

synthetic2
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
loss: tensor(0.8391, device='cuda:0')
acc: 0.5611111111111111
Redundancy 0.3284486986573478
Unique 0.0965686571055237
Unique 0.0002948888114399769
Synergy 0.09728481616062967

synthetic3
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




loss: tensor(0.7082, device='cuda:0')
acc: 0.6061111111111112
Redundancy 0.29935997960161165
Unique 0.021428142617372022
Unique 0.0013107095331416508
Synergy 0.15683897633730787

synthetic4
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




loss: tensor(0.5708, device='cuda:0')
acc: 0.7066666666666667
Redundancy 0.16042916833048892
Unique 1.9288911172398268e-07
Unique 0.30688436833173954
Synergy 0.029394559648220353

synthetic5
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




loss: tensor(0.9230, device='cuda:0')
acc: 0.4988888888888889
Redundancy 0.17356000715369602
Unique 0.16934852359610825
Unique 1.7879144494800756e-07
Synergy 0.10330239420087672

maps
Train data: 672
Valid data: 84
Test data: 88




loss: tensor(21.7319, device='cuda:0')
acc: 0.3977272727272727
Redundancy 0.12287974707883487
Unique 0.3385740612013276
Unique 6.566806033304332e-08
Synergy 0.21180201951570998

synthetic1
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




loss: tensor(0.7756, device='cuda:0')
acc: 0.5238888888888888
Redundancy 0.15532672973063116
Unique 1.8309746972723633e-14
Unique 0.0481686017490063
Synergy 0.2054566824898213

synthetic2
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
loss: tensor(0.7483, device='cuda:0')
acc: 0.5611111111111111
Redundancy 0.12777011667835125
Unique 0.2755450275779032
Unique 3.230175750635837e-09
Synergy 0.06437572626653204

synthetic3
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




loss: tensor(0.9251, device='cuda:0')
acc: 0.4861111111111111
Redundancy 0.14201343884174356
Unique 0.08356761172566413
Unique 1.1855105881503844e-14
Synergy 0.20429857388128242

synthetic4
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
loss: tensor(0.6789, device='cuda:0')
acc: 0.6216666666666667
Redundancy 0.08751141689661879
Unique 7.025224792140894e-16
Unique 0.1819599013173871
Synergy 0.1928610622007646

synthetic5
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
loss: tensor(0.8072, device='cuda:0')
acc: 0.5166666666666667
Redundancy 0.1859622933397237
Unique 0.08313296371646638
Unique 1.2541910737501424e-11
Synergy 0.07519010064887716

maps
Train data: 672
Valid data: 84
Test data: 88
loss: tensor(11.9608, device='cuda:0')
acc: 0.38636363636363635




Redundancy 0.20905235753934148
Unique 0.46367149857812057
Unique 0.005455627361284126
Synergy 0.14743277840214086

synthetic1
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.6777777777777778
Redundancy 0.08618767009564049
Unique 0.1478020164222811
Unique 9.566669550286212e-17
Synergy 0.19242797099226355

synthetic2
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.5694444444444444
Redundancy 0.06787806065761096
Unique 4.235057627976897e-17
Unique 0.16098405633895307
Synergy 0.09259909848076212

synthetic3
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.6366666666666667
Redundancy 0.11609462584523028
Unique 0.05341603242862995
Unique 9.890010965064085e-16
Synergy 0.17822541538307207

synthetic4
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.7038888888888889
Redundancy 0.16942094857145884
Unique 6.149740506924903e-09
Unique 0.19325739854783458
Synergy 0.0952159236667



acc: 0.6588888888888889
Redundancy 0.13494035317652653
Unique 0.18570412495941221
Unique 1.089736656045533e-15
Synergy 0.08558406311556876

maps
Train data: 672
Valid data: 84
Test data: 88
acc: 0.42045454545454547
Redundancy 0.14994052418047799
Unique 0.447735139462754
Unique 0.0038598854631005037
Synergy 0.11095042358349771

synthetic1
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




acc: 0.5966666666666667
Redundancy 0.03337934033381815
Unique 0.03851051769304159
Unique 4.840201261536444e-15
Synergy 0.07619296193485867

synthetic2
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.5033333333333333
Redundancy 0.015074563180853567
Unique 0.00012426934807638015
Unique 0.0033122852589436177
Synergy 0.05653113154947864

synthetic3
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.5622222222222222
Redundancy 0.033328647677820615
Unique 0.009507784955423818
Unique 0.0007612155052617252
Synergy 0.08083738727055859

synthetic4
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.6305555555555555
Redundancy 0.06675257326672898
Unique 6.0452735493396886e-15
Unique 0.0441738172004701
Synergy 0.07654911296960754

synthetic5
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.5766666666666667
Redundancy 0.04723567218413219
Unique 0.04047391226474971
Unique 9.93467223755



acc: 0.6938888888888889
Redundancy 0.07280928242115736
Unique 0.28393534561598677
Unique 1.2020176513184946e-08
Synergy 0.22175160776922243

synthetic2
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




acc: 0.5944444444444444
Redundancy 0.13096347933907365
Unique 1.2861476318174074e-15
Unique 0.08858449315646502
Synergy 0.2796268224545085

synthetic3
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.6555555555555556
Redundancy 0.10980141315140374
Unique 0.13220896647385083
Unique 5.208511697530409e-16
Synergy 0.2346480066798905

synthetic4
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.7083333333333334
Redundancy 0.18375387086002104
Unique 1.498659543573323e-07
Unique 0.28068806373675614
Synergy 0.1019438057732398

synthetic5
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




acc: 0.6738888888888889
Redundancy 0.15808898729715626
Unique 0.24215221187041677
Unique 1.26238299687121e-07
Synergy 0.15050771802014606

maps
Train data: 672
Valid data: 84
Test data: 88




acc: 0.5568181818181818




Redundancy 0.17289438215623415
Unique 0.1745283161713936
Unique 0.005945067822097273
Synergy 0.23531432627157528

synthetic1
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.5733333333333334
Redundancy 0.02538581906757799
Unique 0.008914696139185638
Unique 0.00047549708608233084
Synergy 0.05475519051321456

synthetic2
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.5122222222222222
Redundancy 0.01545316209259949
Unique 0.0005461218186865164
Unique 0.003529586588757831
Synergy 0.0951218946633707

synthetic3
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.5461111111111111
Redundancy 0.058444428849140945
Unique 0.011360755604213782
Unique 7.456743569475827e-06
Synergy 0.11547567560166236

synthetic4
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.6105555555555555
Redundancy 0.08989229096834728
Unique 0.0008958762285644825
Unique 0.011097121951945385
Synergy 0.06725448



Redundancy 0.21080083811757722
Unique 0.18530327050094345
Unique 0.005946063298328629
Synergy 0.37625064688600496

synthetic1
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.6916666666666667




Redundancy 0.1583677418071135
Unique 0.07435460948816144
Unique 9.991994989891182e-08
Synergy 0.27731614112839453

synthetic2
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.5527777777777778




Redundancy 0.22963861878384084
Unique 0.04533441477216857
Unique 6.247972810597319e-08
Synergy 0.12510329850294555

synthetic3
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.6472222222222223
Redundancy 0.22854004588262783
Unique 0.030753586282733315
Unique 0.0027629118227717516
Synergy 0.24915154203645068

synthetic4
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.7011111111111111




Redundancy 0.13836555285116947
Unique 1.0204914957962936e-07
Unique 0.2517058240940211
Synergy 0.07149941745430843

synthetic5
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.6688888888888889




Redundancy 0.11889499605630177
Unique 0.24892403369420704
Unique 1.7867267075186927e-10
Synergy 0.13199549493956658

maps
Train data: 672
Valid data: 84
Test data: 88
acc: 0.4772727272727273
Redundancy 0.12116486602944188
Unique 0.3712700206802071
Unique 1.949454515693101e-08
Synergy 0.025561392210076805

synthetic1
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




acc: 0.6811111111111111
Redundancy 0.08089932041380615
Unique 0.16949608997378
Unique 2.248332153660169e-16
Synergy 0.182319213857005

synthetic2
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.5338888888888889
Redundancy 0.06388768248240623
Unique 0.018853783253713872
Unique 2.771177848121125e-06
Synergy 0.10365824343417751

synthetic3
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.625
Redundancy 0.09414570169491798
Unique 0.12572443045454873
Unique 1.868797019408191e-16
Synergy 0.15776341908835556

synthetic4
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.7027777777777777
Redundancy 0.15302155312345944
Unique 4.434213186523049e-09
Unique 0.2334888351470007
Synergy 0.0925846018578847

synthetic5
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




acc: 0.6583333333333333
Redundancy 0.11106344322853678
Unique 0.27327203159650765
Unique 1.3586243769602145e-10
Synergy 0.06888141696302097

maps
Train data: 672
Valid data: 84
Test data: 88




acc: 0.45454545454545453
Redundancy 0.11795716724665539
Unique 0.38781048350030767
Unique 0.013521096062297993
Synergy 0.14609965099677924

synthetic1
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




acc: 0.69
Redundancy 0.09527345016671326
Unique 0.13641131037709606
Unique 2.382427525680957e-16
Synergy 0.20252465215162743

synthetic2
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.5477777777777778
Redundancy 0.06847499976948997
Unique 0.022903981744335285
Unique 0.007418392437419652
Synergy 0.12175143886396504

synthetic3
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.6366666666666667
Redundancy 0.10357419518084665
Unique 0.12297961562256464
Unique 1.1925369030011689e-15
Synergy 0.18935551235983905

synthetic4
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800
acc: 0.71
Redundancy 0.17547741485622148
Unique 5.611408306976975e-08
Unique 0.227133185554622
Synergy 0.08744457427872937

synthetic5
Train data: 14000
Valid data 1: 2100
Valid data 2: 2100
Test data: 1800




acc: 0.6611111111111111
Redundancy 0.11364093525405028
Unique 0.2545367324290119
Unique 1.0846433085366024e-10
Synergy 0.07272122364189781

maps
Train data: 672
Valid data: 84
Test data: 88




acc: 0.4090909090909091
Redundancy 0.1590486890256423
Unique 0.38409604269174047
Unique 2.344669587668299e-07
Synergy 0.20995459666534433





#### maps

In [5]:
with open ('synthetic/model_selection/maps.pkl', 'rb') as f:
    maps = pickle.load(f)
N = len(maps['bow'])
print(N, maps.keys())
data = dict()
data['train'] = dict()
data['valid'] = dict()
data['test'] = dict()
for i, f in enumerate(['bow', 'typing']):
    data['train'][str(i)] = np.array(maps[f][:N//10*8], dtype=float)
    data['valid'][str(i)] = np.array(maps[f][N//10*8:N//10*9], dtype=float)
    data['test'][str(i)] = np.array(maps[f][N//10*9:], dtype=float)
data['train']['label'] = maps['label'][:N//10*8]
data['valid']['label'] = maps['label'][N//10*8:N//10*9]
data['test']['label'] = maps['label'][N//10*9:]
with open ('synthetic/model_selection/DATA_maps2.pickle', 'wb') as f:
    pickle.dump(data, f)

844 dict_keys(['bow', 'apps', 'typing', 'text_char', 'epoch_diff', 'word_epoch_diff', 'bow_onehot', 'apps_onehot', 'typing_onehot', 'label'])


In [4]:
with open ('synthetic/model_selection/DATA_maps.pickle', 'rb') as f:
    dataset = pickle.load(f)
print(dataset.keys())
print(dataset['train']['0'].shape, dataset['train']['1'].shape)
print(np.unique(dataset['train']['label']))

dict_keys(['train', 'valid', 'test'])
(672, 1000) (672, 137)
[0 1 2]


In [8]:
n_components = 2
data_cluster = dict()
for split in ['test']:
    data_cluster[split] = dict()
    data = dataset[split]
    kmeans_0, data_0 = clustering(data['0'], pca=True, n_components=n_components, n_clusters=20)
    data_cluster[split]['0'] = kmeans_0.reshape(-1,1)
    kmeans_1, data_1 = clustering(data['1'], pca=True, n_components=n_components, n_clusters=20)
    data_cluster[split]['1'] = kmeans_1.reshape(-1,1)
    data_cluster[split]['label'] = data['label']
with open('synthetic/model_selection/DATA_maps2_cluster.pickle', 'wb') as f:
    pickle.dump(data_cluster, f)


In [13]:
if os.path.isfile('synthetic/model_selection/datasets.pickle'):
    with open('synthetic/model_selection/datasets.pickle', 'rb') as f:
        results = pickle.load(f)
else:
    results = dict()
with open('synthetic/model_selection/DATA_maps2_cluster.pickle', 'rb') as f:
    dataset = pickle.load(f)
print('maps')
data = (dataset['test']['0'], dataset['test']['1'], dataset['test']['label'])
P, maps = convert_data_to_distribution(*data)
result = get_measure(P)
results['maps2'] = result     

with open('synthetic/model_selection/datasets.pickle', 'wb') as f:
    pickle.dump(results, f)

maps
Redundancy 0.12120173624349637
Unique 0.004155439963112993
Unique 0.03654196259113951
Synergy 0.40223033982849343




In [15]:
if os.path.isfile('synthetic/model_selection/results.pickle'):
    with open('synthetic/model_selection/results.pickle', 'rb') as f:
        results = pickle.load(f)
else:
    results = dict()
for method in ['mi', 'mult']:
    for SETTING in ['maps2']:
        print(SETTING)
        data_path = 'synthetic/model_selection/DATA_{}.pickle'.format(SETTING)
        saved_model = 'synthetic/model_selection/{}/{}_{}_best.pt'.format(SETTING, SETTING, method)
        saved_cluster = 'synthetic/model_selection/{}/{}_{}_cluster.pickle'.format(SETTING, SETTING, method)
        _, _, _, testdata = get_dataloader(path=data_path, keys=['0','1','label'], modalities=[1,1], batch_size=128, num_workers=4)
        model = torch.load(saved_model).cuda()
        num_params = sum([p.numel() for p in model.parameters()])
        if method in ['additive', 'agree', 'align']:
            acc = test_ensemble(model, testdata, no_robust=True, criterion=torch.nn.CrossEntropyLoss(), save_preds=saved_cluster)
        else:
            acc = test(model, testdata, no_robust=True, criterion=torch.nn.CrossEntropyLoss(), save_preds=saved_cluster)
        with open(saved_cluster, 'rb') as f:
            preds = pickle.load(f)
        with open('synthetic/model_selection/DATA_{}_cluster.pickle'.format(SETTING), 'rb') as f:
            cluster = pickle.load(f)
        pred_results = (cluster['test']['0'], cluster['test']['1'], preds.reshape(-1,1))
        P, maps = convert_data_to_distribution(*pred_results)
        results[method][SETTING] = get_measure(P)
        results[method][SETTING]['acc'] = acc
        results[method][SETTING]['params'] = num_params
        print()

with open('synthetic/model_selection/results.pickle', 'wb') as f:
    pickle.dump(results, f)

maps2
Train data: 672
Valid data: 84
Test data: 88
acc: 0.36363636363636365




Redundancy 0.25493315428829044
Unique 0.24084311647191153
Unique 0.024454060329114166
Synergy 0.4022239205538314

maps2
Train data: 672
Valid data: 84
Test data: 88
acc: 0.5
Redundancy 0.10450894868078536
Unique 0.26887867229181706
Unique 7.651484056453488e-08
Synergy 0.18770946896049534





#### model selection

In [2]:
with open('synthetic/model_selection/results.pickle', 'rb') as f:
    results = pickle.load(f)
print(results.keys())
with open('synthetic/model_selection/datasets.pickle', 'rb') as f:
    datasets = pickle.load(f)
print(datasets.keys())

dict_keys(['additive', 'agree', 'align', 'early_fusion', 'elem', 'mfm', 'mi', 'mult', 'outer', 'tf', 'lower'])
dict_keys(['synthetic1', 'synthetic2', 'synthetic3', 'synthetic4', 'synthetic5', 'maps', 'maps2'])


In [5]:
with open('synthetic/experiments2/datasets.pickle', 'rb') as f:
    old_datasets = pickle.load(f)
with open('synthetic/experiments2/results.pickle', 'rb') as f:
    old_results = pickle.load(f)
normalized_old_datasets = dict()
for setting in old_datasets:
    normalized_old_datasets[setting] = dict()
    datasets_total = sum([old_datasets[setting][measure] for measure in MEASURES])
    for measure in MEASURES:
        normalized_old_datasets[setting][measure] = old_datasets[setting][measure] / datasets_total

In [9]:
methods = ['additive', 'agree', 'align', 'early_fusion', 'elem', 'recon', 'mi', 'mult', 'outer', 'lower']
model_selection = dict()
for setting in old_datasets:
    measures = sorted([(old_results[method][setting]['acc'], method) for method in methods])
    model_selection[setting] = measures[-3:]
# model_selection

In [11]:
normalized_results = dict()
normalized_datasets = dict()
for method in METHODS:
    normalized_results[method] = dict()
    for setting in SETTINGS:
        normalized_results[method][setting] = dict()
        normalized_datasets[setting] = dict()
        results_total = sum([results[method][setting][measure] for measure in MEASURES])
        datasets_total = sum([datasets[setting][measure] for measure in MEASURES])
        for measure in MEASURES:
            normalized_results[method][setting][measure] = results[method][setting][measure] / results_total
            normalized_datasets[setting][measure] = datasets[setting][measure] / datasets_total
accs = [results[method][setting]['acc'] for setting in SETTINGS for method in METHODS]
model_selection_best = dict()
for setting in SETTINGS:
    measures = sorted([(results[method][setting]['acc'], method) for method in METHODS])
    model_selection_best[setting] = results[measures[-1][1]][setting]['acc']
    print(setting, measures)

synthetic1 [(0.5238888888888888, 'align'), (0.5733333333333334, 'mi'), (0.5966666666666667, 'elem'), (0.6283333333333333, 'additive'), (0.6566666666666666, 'agree'), (0.6777777777777778, 'early_fusion'), (0.6811111111111111, 'outer'), (0.69, 'lower'), (0.6916666666666667, 'mult'), (0.6938888888888889, 'mfm')]
synthetic2 [(0.5033333333333333, 'elem'), (0.5122222222222222, 'mi'), (0.5338888888888889, 'outer'), (0.5477777777777778, 'lower'), (0.5527777777777778, 'mult'), (0.5611111111111111, 'agree'), (0.5611111111111111, 'align'), (0.5616666666666666, 'additive'), (0.5694444444444444, 'early_fusion'), (0.5944444444444444, 'mfm')]
synthetic3 [(0.4261111111111111, 'additive'), (0.4861111111111111, 'align'), (0.5461111111111111, 'mi'), (0.5622222222222222, 'elem'), (0.6061111111111112, 'agree'), (0.625, 'outer'), (0.6366666666666667, 'early_fusion'), (0.6366666666666667, 'lower'), (0.6472222222222223, 'mult'), (0.6555555555555556, 'mfm')]
synthetic4 [(0.6105555555555555, 'mi'), (0.612777777

In [15]:
best_model_scores = []
for setting in datasets:
    tmp = []
    for old_setting in old_datasets:
        tmp.append((np.sum(np.absolute([normalized_datasets[setting][measure]-normalized_old_datasets[old_setting][measure] for measure in MEASURES])), old_setting))
    tmp = sorted(tmp)
    for i in range(1):
        print(setting, "most similar dataset:", tmp[i], "selected models", model_selection[tmp[i][1]])
        selected_models = model_selection[tmp[i][1]]
        model_scores = []
        for _, method in selected_models:
            if method == 'recon':
                method = 'mfm'
            model_scores.append((results[method][setting]['acc'], method))
        model_scores = sorted(model_scores)
        best_score = model_scores[-1][0] / model_selection_best[setting] * 100
        best_model_scores.append(best_score)
        print('{:.2f}%'.format(best_score))
print('Overall:{:.2f}%'.format(np.mean(best_model_scores)))

synthetic1 most similar dataset: (0.14633150964283043, 'synthetic5') selected models [(0.6844444444444444, 'early_fusion'), (0.685, 'recon'), (0.6877777777777778, 'agree')]
100.00%
synthetic2 most similar dataset: (0.18355423842472463, 'synthetic3') selected models [(0.5461111111111111, 'recon'), (0.5538888888888889, 'align'), (0.555, 'mult')]
100.00%
synthetic3 most similar dataset: (0.2793375138069869, 'synthetic5') selected models [(0.6844444444444444, 'early_fusion'), (0.685, 'recon'), (0.6877777777777778, 'agree')]
100.00%
synthetic4 most similar dataset: (0.11088404258613735, 'synthetic2') selected models [(0.7138888888888889, 'additive'), (0.7183333333333334, 'align'), (0.72, 'agree')]
99.53%
synthetic5 most similar dataset: (0.14912201008589812, 'mix2') selected models [(0.7116666666666667, 'recon'), (0.7205555555555555, 'agree'), (0.7233333333333334, 'align')]
100.00%
maps most similar dataset: (0.09856786863625443, 'synthetic4') selected models [(0.6522222222222223, 'recon'),

In [4]:
for setting in datasets:
    print(setting, datasets[setting])

synthetic1 {'redundancy': 0.042766259538634416, 'unique1': 0.02596364321132929, 'unique2': 2.17157384344753e-15, 'synergy': 0.10614122901494483}
synthetic2 {'redundancy': 0.012672691701037832, 'unique1': 2.8259215299914983e-05, 'unique2': 0.0061237290331549, 'synergy': 0.05926624624675512}
synthetic3 {'redundancy': 0.027742116209452286, 'unique1': 0.025980735076023997, 'unique2': 1.4684105273299479e-15, 'synergy': 0.06713126114497113}
synthetic4 {'redundancy': 0.06557045989622268, 'unique1': 2.1518739480906317e-15, 'unique2': 0.056037809617245915, 'synergy': 0.053704370671353585}
synthetic5 {'redundancy': 0.03943597142606828, 'unique1': 0.043663108779810396, 'unique2': 4.6119695336763056e-15, 'synergy': 0.06589532206160133}
maps {'redundancy': 0.08691429284341395, 'unique1': 1.1195372689354384e-07, 'unique2': 0.08681460221097399, 'synergy': 0.2624709049842621}
maps2 {'redundancy': 0.12120173624349637, 'unique1': 0.004155439963112993, 'unique2': 0.03654196259113951, 'synergy': 0.4022303

In [11]:
if os.path.isfile('synthetic/model_selection/datasets.pickle'):
    with open('synthetic/model_selection/datasets.pickle', 'rb') as f:
        results = pickle.load(f)
else:
    results = dict()
with open('synthetic/model_selection/DATA_maps2_cluster.pickle', 'rb') as f:
    dataset = pickle.load(f)
print('maps2')
data = (dataset['test']['0'], dataset['test']['1'], dataset['test']['label'])
P, maps = convert_data_to_distribution(*data)
result = get_measure(P)
results['maps2'] = result     

with open('synthetic/model_selection/datasets.pickle', 'wb') as f:
    pickle.dump(results, f)

maps2
Redundancy 0.12120173624349637
Unique 0.004155439963112993
Unique 0.03654196259113951
Synergy 0.40223033982849343


Format results in $\LaTeX$

In [16]:
for setting in ['maps', 'maps2']:
    # for method in ['early_fusion', 'additive', 'agree', 'align']:
    # for method in ['elem', 'outer', 'mi', 'mult']:
    for method in ['lower', 'mfm']:
        for measure in MEASURES:
            print('&', '${:.2f}$'.format(results[method][setting][measure]), end=' ')
        print('&', '${:.2f}$'.format(results[method][setting]['acc']), end=' ')
    print('\\\\')

& $0.53$ & $0.04$ & $0.04$ & $0.09$ & $0.41$ & $0.28$ & $0.06$ & $0.07$ & $0.10$ & $0.56$ \\
& $0.12$ & $0.24$ & $0.00$ & $0.22$ & $0.51$ & $0.23$ & $0.17$ & $0.00$ & $0.18$ & $0.52$ \\
