In [23]:
from typing import List, Optional, Tuple, Union

import copy
import json
import numpy
import random
import torch
import pandas
import seaborn as sns
import scipy.stats as sts
from scipy.stats import lognorm
from tqdm import tqdm as tqdm

# BED imports
from boed.networks.fullyconnected import FullyConnected
from boed.networks.summstats import NeuralSummStats, CAT_NSS
from boed.simulators.bandits import simulate_bandit_batch, sim_bandit_prior
from boed.utils.utils_human_participant_study import *

# matplotlib imports
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import to_rgb
from matplotlib.patches import Patch
from matplotlib import rc
%matplotlib inline
# Global matplotlib appearance: default style sheet, wide figures, 16pt serif
# (Computer Modern) text rendered through LaTeX with amsmath loaded.
plt.style.use('default')
plt.rcParams.update({
    'figure.figsize': (16.0, 8.0),
    'font.size': 16,
})
rc('font', family='serif', serif=['Computer Modern'])
rc('text', usetex=True)
plt.rcParams['text.latex.preamble'] = r'\usepackage{amsmath}'

# Colours of the active property cycle, kept for later plotting.
cycle = plt.rcParams['axes.prop_cycle'].by_key()['color']

In [24]:
# All tensors created in this notebook are placed on the CPU.
device = torch.device("cpu")

# Load in Human Participant Data

## Load in raw data

In [25]:
# Parse the raw study export directly from the open file handle.
with open('../data/two_stage_data.json', 'r') as fh:
    data = json.load(fh)

# Participant records per design group; the 'naive' key holds the baseline group.
optimal_data, base_data = data['optimal'], data['naive']

## Specify desired designs

In [26]:
# Baseline designs for the MD and PE phases, loaded from disk.
# NOTE(review): torch.load unpickles the file contents -- only use with trusted files.
d_md_base, d_pe_base = (
    torch.load(design_path)
    for design_path in (
        '../data/designs/md_designs_baseline.pt',
        '../data/designs/pe_designs_baseline.pt',
    )
)

In [27]:
# Optimal design used in the MD phase: two rows of three values each.
d_md_opt = [
    [0, 0, 0.6],
    [1, 1, 0],
]

# Optimal designs used in the PE phase, one per model (WSLTS / AEG / GLS):
# three rows of three values each.
d_pe_wslts_opt = [
    [0, 0, 1],
    [0, 1, 1],
    [1, 0, 1],
]
d_pe_aeg_opt = [
    [1, 0, 0],
    [0, 0, 1],
    [1, 0, 1],
]
d_pe_gls_opt = [
    [0, 1, 0],
    [0, 0, 1],
    [0, 0, 1],
]

## Transform raw data

In [28]:
# Reduce every participant record to its MD-phase data (two MD blocks).
optimal_md_only = []
for participant in optimal_data:
    optimal_md_only.append(get_md_data_only(participant, md_blocks=2))

base_md_only = []
for participant in base_data:
    base_md_only.append(get_md_data_only(participant, md_blocks=2))

# Bring both groups into a consistent format. Optimal users all share d_md_opt;
# each baseline user is paired with the baseline design of their 'conditionS1'.
# NOTE(review): the optimal call relies on transform_data's default num_arms,
# whereas every other call passes num_arms=3 -- confirm the default matches.
optimal_md_transf = []
for participant in optimal_md_only:
    optimal_md_transf.append(transform_data(participant, d_md_opt))

base_md_transf = []
for participant in base_md_only:
    baseline_design = d_md_base[participant['conditionS1']]
    base_md_transf.append(transform_data(participant, baseline_design, num_arms=3))

In [29]:
# Reduce every participant record to its PE-phase data.
optimal_pe_only = [get_pe_data_only(p, md_blocks=2) for p in optimal_data]
base_pe_only = [get_pe_data_only(p, md_blocks=2) for p in base_data]

# Optimal users are split by their 'conditionS2' label; each label has its own
# optimal PE design and its own output list.
optimal_pe_wslts_transf, optimal_pe_aeg_transf, optimal_pe_gls_transf = [], [], []
pe_opt_routes = {
    'wslts': (d_pe_wslts_opt, optimal_pe_wslts_transf),
    'aeg': (d_pe_aeg_opt, optimal_pe_aeg_transf),
    'gls': (d_pe_gls_opt, optimal_pe_gls_transf),
}
for participant in optimal_pe_only:
    route = pe_opt_routes.get(participant['conditionS2'])
    if route is None:
        # Any other label is skipped.
        continue
    design, bucket = route
    bucket.append(transform_data(participant, design, num_arms=3))

# Baseline users are not split: each one is transformed with the baseline PE
# design selected by their 'conditionS1'.
base_pe_transf = [
    transform_data(p, d_pe_base[p['conditionS1']], num_arms=3)
    for p in base_pe_only
]

# Load in trained neural networks

## MD Models

In [30]:
# Architecture settings handed to get_trained_models when restoring the MD networks.
modelparams_md_loading = dict(
    layers=2,
    hidden=[32, 32],
    num_measurements=2,
    summ_L=2,
    summ_H=[64, 32],
    summ_out=6,
)

In [31]:
# Load the 50 MD networks trained for the optimal designs. Each checkpoint pair
# yields two models: one later called as model(X, Y) and a summary-statistics
# network later called on the raw observation.
fc_template = "../data/models/md_model_trained_optimal_new_repeat{}.pt"
summ_template = "../data/models/md_model_summ_trained_optimal_new_repeat{}.pt"

model_md_opt_list, model_summ_md_opt_list = [], []
for repeat in range(50):
    net, summ_net = get_trained_models(
        modelparams_md_loading,
        fc_template.format(repeat),
        summ_template.format(repeat),
        dim1=1,
    )
    model_md_opt_list.append(net)
    model_summ_md_opt_list.append(summ_net)

In [32]:
# Load the MD networks trained for the ten baseline designs: ten repeats per
# design, collected in lists keyed by the baseline index as a string ('0'..'9').
fc_template = "../data/models/md_model_trained_baseline{}_ensemble_new_repeat{}.pt"
summ_template = "../data/models/md_model_summ_trained_baseline{}_ensemble_new_repeat{}.pt"

model_md_base_dict = {str(b): [] for b in range(10)}
model_summ_md_base_dict = {str(b): [] for b in range(10)}
for repeat in range(10):
    for baseline in range(10):
        net, summ_net = get_trained_models(
            modelparams_md_loading,
            fc_template.format(baseline, repeat),
            summ_template.format(baseline, repeat),
            dim1=1,
        )
        model_md_base_dict[str(baseline)].append(net)
        model_summ_md_base_dict[str(baseline)].append(summ_net)

## PE-WSLTS Models

In [33]:
# Architecture settings handed to get_trained_models when restoring the PE-WSLTS networks.
modelparams_pe_wslts_loading = dict(
    layers=2,
    hidden=[64, 32],
    num_measurements=3,
    summ_L=2,
    summ_H=[64, 32],
    summ_out=8,
)

In [34]:
# Load the 50 PE-WSLTS networks (plus their summary networks) trained for the
# optimal designs.
fc_template = "../data/models/pe_wslts_model_trained_optimal_ensemble_new_repeat{}.pt"
summ_template = "../data/models/pe_wslts_model_summ_trained_optimal_ensemble_new_repeat{}.pt"

model_pe_wslts_opt_list, model_summ_pe_wslts_opt_list = [], []
for repeat in range(50):
    net, summ_net = get_trained_models(
        modelparams_pe_wslts_loading,
        fc_template.format(repeat),
        summ_template.format(repeat),
        dim1=3,
    )
    model_pe_wslts_opt_list.append(net)
    model_summ_pe_wslts_opt_list.append(summ_net)

In [35]:
# Load the PE-WSLTS networks trained for the ten baseline designs: ten repeats
# per design, collected in lists keyed by the baseline index as a string.
fc_template = "../data/models/pe_wslts_model_trained_baseline{}_ensemble_new_repeat{}.pt"
summ_template = "../data/models/pe_wslts_model_summ_trained_baseline{}_ensemble_new_repeat{}.pt"

model_pe_wslts_base_dict = {str(b): [] for b in range(10)}
model_summ_pe_wslts_base_dict = {str(b): [] for b in range(10)}
for repeat in range(10):
    for baseline in range(10):
        net, summ_net = get_trained_models(
            modelparams_pe_wslts_loading,
            fc_template.format(baseline, repeat),
            summ_template.format(baseline, repeat),
            dim1=3,
        )
        model_pe_wslts_base_dict[str(baseline)].append(net)
        model_summ_pe_wslts_base_dict[str(baseline)].append(summ_net)

## PE-AEG Models

In [36]:
# Architecture settings handed to get_trained_models when restoring the PE-AEG networks.
modelparams_pe_aeg_loading = dict(
    layers=2,
    hidden=[64, 32],
    num_measurements=3,
    summ_L=2,
    summ_H=[64, 32],
    summ_out=6,
)

In [37]:
# Load the 50 PE-AEG networks (plus their summary networks) trained for the
# optimal designs.
fc_template = "../data/models/pe_aeg_model_trained_optimal_ensemble_new_repeat{}.pt"
summ_template = "../data/models/pe_aeg_model_summ_trained_optimal_ensemble_new_repeat{}.pt"

model_pe_aeg_opt_list, model_summ_pe_aeg_opt_list = [], []
for repeat in range(50):
    net, summ_net = get_trained_models(
        modelparams_pe_aeg_loading,
        fc_template.format(repeat),
        summ_template.format(repeat),
        dim1=2,
    )
    model_pe_aeg_opt_list.append(net)
    model_summ_pe_aeg_opt_list.append(summ_net)

In [38]:
# Load the PE-AEG networks trained for the ten baseline designs: ten repeats
# per design, collected in lists keyed by the baseline index as a string.
fc_template = "../data/models/pe_aeg_model_trained_baseline{}_ensemble_new_repeat{}.pt"
summ_template = "../data/models/pe_aeg_model_summ_trained_baseline{}_ensemble_new_repeat{}.pt"

model_pe_aeg_base_dict = {str(b): [] for b in range(10)}
model_summ_pe_aeg_base_dict = {str(b): [] for b in range(10)}
for repeat in range(10):
    for baseline in range(10):
        net, summ_net = get_trained_models(
            modelparams_pe_aeg_loading,
            fc_template.format(baseline, repeat),
            summ_template.format(baseline, repeat),
            dim1=2,
        )
        model_pe_aeg_base_dict[str(baseline)].append(net)
        model_summ_pe_aeg_base_dict[str(baseline)].append(summ_net)

## PE-GLS Models

In [39]:
# Architecture settings handed to get_trained_models when restoring the PE-GLS networks.
modelparams_pe_gls_loading = dict(
    layers=2,
    hidden=[64, 32],
    num_measurements=3,
    summ_L=2,
    summ_H=[64, 32],
    summ_out=8,
)

In [40]:
# Load the 50 PE-GLS networks (plus their summary networks) trained for the
# optimal designs.
fc_template = "../data/models/pe_gls_model_trained_optimal_ensemble_new_repeat{}.pt"
summ_template = "../data/models/pe_gls_model_summ_trained_optimal_ensemble_new_repeat{}.pt"

model_pe_gls_opt_list, model_summ_pe_gls_opt_list = [], []
for repeat in range(50):
    net, summ_net = get_trained_models(
        modelparams_pe_gls_loading,
        fc_template.format(repeat),
        summ_template.format(repeat),
        dim1=5,
    )
    model_pe_gls_opt_list.append(net)
    model_summ_pe_gls_opt_list.append(summ_net)

In [41]:
# Load the PE-GLS networks trained for the ten baseline designs: ten repeats
# per design, collected in lists keyed by the baseline index as a string.
fc_template = "../data/models/pe_gls_model_trained_baseline{}_ensemble_new_repeat{}.pt"
summ_template = "../data/models/pe_gls_model_summ_trained_baseline{}_ensemble_new_repeat{}.pt"

model_pe_gls_base_dict = {str(b): [] for b in range(10)}
model_summ_pe_gls_base_dict = {str(b): [] for b in range(10)}
for repeat in range(10):
    for baseline in range(10):
        net, summ_net = get_trained_models(
            modelparams_pe_gls_loading,
            fc_template.format(baseline, repeat),
            summ_template.format(baseline, repeat),
            dim1=5,
        )
        model_pe_gls_base_dict[str(baseline)].append(net)
        model_summ_pe_gls_base_dict[str(baseline)].append(summ_net)

# Compute Posterior Distributions

## MD Posteriors

In [42]:
# OPTIMAL
# Posterior over the three model indicators (0, 1, 2) for every participant in
# the optimal-design group, computed separately by each ensemble member.
# Result shape: (num users, ensemble repeats, num indicators).

# Candidate indicators as a (3, 1) float tensor. It is identical for every user
# and network, so it is created once, directly on `device`. (The former
# `X.to(X)` / `Y.to(device);` calls were no-ops: Tensor.to returns a new tensor
# and the results were discarded.)
X = torch.tensor(numpy.arange(0, 3).reshape(-1, 1), dtype=torch.float, device=device)
prior_weight = 1 / 3.  # uniform prior over the three indicators

posts_md_opt_ensemble = list()
for user in tqdm(optimal_md_transf):

    # real-world observation, with a leading batch dimension added
    y_obs = combine_choices_rewards(user).unsqueeze(0)

    posts_single_network = list()
    for model_tr, model_summ_tr in zip(model_md_opt_list, model_summ_md_opt_list):

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            Sy_obs = model_summ_tr(y_obs)

            # Pair the same summary statistics with every candidate indicator.
            Y = torch.cat(len(X) * [Sy_obs]).to(device)

            T = model_tr(X, Y).cpu().numpy().reshape(-1)

        # exp(T - 1) weights each indicator relative to the prior (presumably a
        # density-ratio estimate from the trained network -- confirm against
        # the training objective); normalise to get posterior probabilities.
        post_weights = numpy.exp(T - 1) * prior_weight
        post_norm = post_weights / numpy.sum(post_weights)
        posts_single_network.append(post_norm)

    posts_md_opt_ensemble.append(posts_single_network)
posts_md_opt_ensemble = numpy.array(posts_md_opt_ensemble)

100%|███████████████████████████████████████████████████| 166/166 [00:04<00:00, 38.89it/s]


In [43]:
# BASELINE
# Posterior over the three model indicators (0, 1, 2) for every participant in
# the baseline-design group. Each participant is evaluated with the ensemble
# trained for their own baseline design (selected by 'conditionS1').
# Result shape: (num users, ensemble repeats, num indicators).

# Candidate indicators as a (3, 1) float tensor. It is identical for every user
# and network, so it is created once, directly on `device`. (The former
# `X.to(X)` / `Y.to(device);` calls were no-ops: Tensor.to returns a new tensor
# and the results were discarded.)
X = torch.tensor(numpy.arange(0, 3).reshape(-1, 1), dtype=torch.float, device=device)
prior_weight = 1 / 3.  # uniform prior over the three indicators

posts_md_base_ensemble = list()
for user in tqdm(base_md_transf):

    # key into the per-design model dictionaries
    condition = user['conditionS1']

    # real-world observation, with a leading batch dimension added
    y_obs = combine_choices_rewards(user).unsqueeze(0)

    posts_single_network = list()
    for model_tr, model_summ_tr in zip(model_md_base_dict[condition],
                                       model_summ_md_base_dict[condition]):

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            Sy_obs = model_summ_tr(y_obs)

            # Pair the same summary statistics with every candidate indicator.
            Y = torch.cat(len(X) * [Sy_obs]).to(device)

            T = model_tr(X, Y).cpu().numpy().reshape(-1)

        # exp(T - 1) weights each indicator relative to the prior (presumably a
        # density-ratio estimate from the trained network -- confirm against
        # the training objective); normalise to get posterior probabilities.
        post_weights = numpy.exp(T - 1) * prior_weight
        post_norm = post_weights / numpy.sum(post_weights)
        posts_single_network.append(post_norm)

    posts_md_base_ensemble.append(posts_single_network)
posts_md_base_ensemble = numpy.array(posts_md_base_ensemble)

100%|██████████████████████████████████████████████████| 160/160 [00:00<00:00, 177.28it/s]


In [44]:
# Average each baseline participant's posterior over the ensemble axis, then
# pick the indicator with the largest averaged posterior probability.
baseline_md_postargmax = posts_md_base_ensemble.mean(axis=1).argmax(axis=1)
print('Optimal model indicators for each participant:', baseline_md_postargmax, sep='\n')

Optimal model indicators for each participant:
[2 2 2 0 0 0 0 0 2 0 1 2 2 2 2 1 2 0 1 0 1 2 0 2 0 2 2 0 0 2 0 2 0 2 1 2 1
 1 2 2 2 0 0 2 0 0 2 0 2 2 2 2 2 0 0 0 0 2 2 2 2 2 2 0 1 0 0 2 2 1 0 0 2 1
 2 0 0 2 2 2 0 0 2 0 2 2 1 0 0 2 2 1 0 0 2 2 2 2 0 2 2 1 0 2 2 2 0 2 2 2 1
 0 2 2 2 0 0 2 2 2 0 1 0 0 0 1 2 2 2 2 2 2 0 2 0 0 0 1 0 2 2 0 0 2 1 1 0 2
 1 0 2 2 0 2 2 2 1 2 2 1]


In [45]:
# Optimal group: count participants per 'conditionS2' label.
count_dict_opt = dict.fromkeys(('wslts', 'aeg', 'gls'), 0)
for participant in optimal_md_transf:
    count_dict_opt[participant['conditionS2']] += 1

# Baseline group: count participants per MAP model indicator
# (0 -> 'wslts', 1 -> 'aeg', 2 -> 'gls'); any other value is ignored.
indicator_to_model = {0: 'wslts', 1: 'aeg', 2: 'gls'}
count_dict_base = dict.fromkeys(('wslts', 'aeg', 'gls'), 0)
for indicator in baseline_md_postargmax:
    if indicator in indicator_to_model:
        count_dict_base[indicator_to_model[indicator]] += 1

In [46]:
# Report the raw per-model participant counts of both design groups.
print(
    'MD phase allocation of participants in the optimal design group:',
    count_dict_opt,
    'MD phase allocation of participants in the baseline design group:',
    count_dict_base,
    sep='\n',
)

MD phase allocation of participants in the optimal design group:
{'wslts': 62, 'aeg': 75, 'gls': 29}
MD phase allocation of participants in the baseline design group:
{'wslts': 57, 'aeg': 22, 'gls': 81}


In [47]:
# Convert the counts to fractions of each group (3 decimals), overwriting the
# count dictionaries in place.
# NOTE(review): not idempotent -- re-running this cell divides the fractions again.
count_dict_opt.update({
    model: round(count / len(optimal_md_transf), 3)
    for model, count in count_dict_opt.items()
})
count_dict_base.update({
    model: round(count / len(baseline_md_postargmax), 3)
    for model, count in count_dict_base.items()
})
print(
    'MD phase fractional allocation of participants in the optimal design group:',
    count_dict_opt,
    'MD phase fractional allocation of participants in the baseline design group:',
    count_dict_base,
    sep='\n',
)

MD phase fractional allocation of participants in the optimal design group:
{'wslts': 0.373, 'aeg': 0.452, 'gls': 0.175}
MD phase fractional allocation of participants in the baseline design group:
{'wslts': 0.356, 'aeg': 0.138, 'gls': 0.506}


## PE-WSLTS Posteriors

In [48]:
# Prior samples for simulator model 0 (WSLTS).
SIMMODEL = 0
DATASIZE = 5_000  # do 50_000 for best results
prior_0 = sim_bandit_prior(DATASIZE, prior='uninformed', simmodel=SIMMODEL)

# Number of re-samples used for the posterior histograms.
K = 10_000  # do 100_000 for best results

# Histogram bin edges, one array per parameter: two on [0, 1], one on [0.01, 5].
BINS = 50
bins_all_0 = [numpy.linspace(0, 1, BINS) for _ in range(2)]
bins_all_0.append(numpy.linspace(0.01, 5, BINS))

In [49]:
# WSLTS, optimal designs: the helper returns two objects, stored as the
# posterior histograms and the correlations (semantics as named by the helper
# in boed.utils.utils_human_participant_study).
(posts_pe_wslts_ensemble,
 corrs_pe_wslts_ensemble) = get_pe_posterior_histograms_optimal(
    users=optimal_pe_wslts_transf,
    simmodel=SIMMODEL,
    prior_samples=prior_0,
    num_resample=K,
    hist_bins_list=bins_all_0,
    model_list=model_pe_wslts_opt_list,
    model_summ_list=model_summ_pe_wslts_opt_list,
)

100%|█████████████████████████████████████████████████████| 62/62 [03:11<00:00,  3.09s/it]


In [50]:
# WSLTS, baseline designs: all baseline participants are passed in together
# with their MAP model indicators from the MD phase (baseline_allocation).
(posts_pe_wslts_ensemble_base,
 corrs_pe_wslts_ensemble_base) = get_pe_posterior_histograms_baseline(
    users=base_pe_transf,
    baseline_allocation=baseline_md_postargmax,
    simmodel=SIMMODEL,
    prior_samples=prior_0,
    num_resample=K,
    hist_bins_list=bins_all_0,
    model_dict=model_pe_wslts_base_dict,
    model_summ_dict=model_summ_pe_wslts_base_dict,
)

100%|███████████████████████████████████████████████████| 160/160 [00:34<00:00,  4.70it/s]


## PE-AEG Posteriors

In [51]:
# Prior samples for simulator model 1 (AEG).
SIMMODEL = 1
DATASIZE = 5_000  # do 50_000 for best results
prior_1 = sim_bandit_prior(DATASIZE, prior='uninformed', simmodel=SIMMODEL)

# Number of re-samples used for the posterior histograms.
K = 10_000  # do 100_000 for best results

# Histogram bin edges, one array per parameter: two on [0, 1].
BINS = 50
bins_all_1 = [numpy.linspace(0, 1, BINS) for _ in range(2)]

In [52]:
# AEG, optimal designs: the helper returns two objects, stored as the posterior
# histograms and the correlations (semantics as named by the helper).
(posts_pe_aeg_ensemble,
 corrs_pe_aeg_ensemble) = get_pe_posterior_histograms_optimal(
    users=optimal_pe_aeg_transf,
    simmodel=SIMMODEL,
    prior_samples=prior_1,
    num_resample=K,
    hist_bins_list=bins_all_1,
    model_list=model_pe_aeg_opt_list,
    model_summ_list=model_summ_pe_aeg_opt_list,
)

100%|█████████████████████████████████████████████████████| 75/75 [04:25<00:00,  3.54s/it]


In [53]:
# AEG, baseline designs: all baseline participants are passed in together with
# their MAP model indicators from the MD phase (baseline_allocation).
(posts_pe_aeg_ensemble_base,
 corrs_pe_aeg_ensemble_base) = get_pe_posterior_histograms_baseline(
    users=base_pe_transf,
    baseline_allocation=baseline_md_postargmax,
    simmodel=SIMMODEL,
    prior_samples=prior_1,
    num_resample=K,
    hist_bins_list=bins_all_1,
    model_dict=model_pe_aeg_base_dict,
    model_summ_dict=model_summ_pe_aeg_base_dict,
)

100%|███████████████████████████████████████████████████| 160/160 [00:18<00:00,  8.83it/s]


## PE-GLS Posteriors

In [54]:
# Prior samples for simulator model 2 (GLS).
SIMMODEL = 2
DATASIZE = 5_000  # do 50_000 for best results
prior_2 = sim_bandit_prior(DATASIZE, prior='uninformed', simmodel=SIMMODEL)

# Number of re-samples used for the posterior histograms.
K = 10_000  # do 100_000 for best results

# Histogram bin edges, one array per parameter: five on [0, 1].
BINS = 50
bins_all_2 = [numpy.linspace(0, 1, BINS) for _ in range(5)]

In [55]:
# GLS, optimal designs: the helper returns two objects, stored as the posterior
# histograms and the correlations (semantics as named by the helper).
(posts_pe_gls_ensemble,
 corrs_pe_gls_ensemble) = get_pe_posterior_histograms_optimal(
    users=optimal_pe_gls_transf,
    simmodel=SIMMODEL,
    prior_samples=prior_2,
    num_resample=K,
    hist_bins_list=bins_all_2,
    model_list=model_pe_gls_opt_list,
    model_summ_list=model_summ_pe_gls_opt_list,
)

100%|█████████████████████████████████████████████████████| 29/29 [01:45<00:00,  3.64s/it]


In [56]:
# GLS, baseline designs: all baseline participants are passed in together with
# their MAP model indicators from the MD phase (baseline_allocation).
(posts_pe_gls_ensemble_base,
 corrs_pe_gls_ensemble_base) = get_pe_posterior_histograms_baseline(
    users=base_pe_transf,
    baseline_allocation=baseline_md_postargmax,
    simmodel=SIMMODEL,
    prior_samples=prior_2,
    num_resample=K,
    hist_bins_list=bins_all_2,
    model_dict=model_pe_gls_base_dict,
    model_summ_dict=model_summ_pe_gls_base_dict,
)

100%|███████████████████████████████████████████████████| 160/160 [01:00<00:00,  2.66it/s]


# Compute Posterior Entropies

## MD

In [57]:
# Entropy of every MD posterior: for each user, scipy's entropy is taken along
# axis 1 of that user's (ensemble repeats, num indicators) block, giving one
# value per ensemble member.
md_entropies_opt, md_entropies_base = (
    numpy.array([sts.entropy(user_posts, axis=1) for user_posts in group_posts])
    for group_posts in (posts_md_opt_ensemble, posts_md_base_ensemble)
)

## WSLTS

In [58]:
# Entropy settings for simulator model 0 (WSLTS).
SIMMODEL = 0

# Number of re-samples.
K = 1_000  # Select at least 10_000 for best results

# Evaluation grid, one array per parameter: two on [0, 1], one on [0.01, 5].
N_GRID = 10  # Select at least 30 for best results
grid_all_0 = [numpy.linspace(0, 1, N_GRID) for _ in range(2)]
grid_all_0.append(numpy.linspace(0.01, 5, N_GRID))

In [59]:
# WSLTS posterior entropies under the optimal designs. The helper returns two
# objects, stored under the `_avg` and `_ind` names (presumably ensemble-averaged
# and per-network entropies -- confirm in boed.utils).
(entropy_pe_wslts_avg,
 entropy_pe_wslts_ind) = get_pe_entropies_optimal(
    users=optimal_pe_wslts_transf,
    simmodel=SIMMODEL,
    prior_samples=prior_0,
    num_resample=K,
    grid_list=grid_all_0,
    model_list=model_pe_wslts_opt_list,
    model_summ_list=model_summ_pe_wslts_opt_list,
)

100%|█████████████████████████████████████████████████████| 62/62 [02:54<00:00,  2.81s/it]


In [60]:
# WSLTS posterior entropies under the baseline designs, using each baseline
# participant's MAP model indicator from the MD phase (baseline_allocation).
(entropy_pe_wslts_base_avg,
 entropy_pe_wslts_base_ind) = get_pe_entropies_baseline(
    users=base_pe_transf,
    baseline_allocation=baseline_md_postargmax,
    simmodel=SIMMODEL,
    prior_samples=prior_0,
    num_resample=K,
    grid_list=grid_all_0,
    model_dict=model_pe_wslts_base_dict,
    model_summ_dict=model_summ_pe_wslts_base_dict,
)

100%|███████████████████████████████████████████████████| 160/160 [00:34<00:00,  4.64it/s]


## AEG

In [61]:
# Entropy settings for simulator model 1 (AEG).
SIMMODEL = 1

# Number of re-samples.
K = 1_000  # Select at least 10_000 for best results

# Evaluation grid, one array per parameter: two on [0, 1].
N_GRID = 10  # Select at least 100 for best results
grid_all_1 = [numpy.linspace(0, 1, N_GRID) for _ in range(2)]

In [62]:
# AEG posterior entropies under the optimal designs. The helper returns two
# objects, stored under the `_avg` and `_ind` names (presumably ensemble-averaged
# and per-network entropies -- confirm in boed.utils).
(entropy_pe_aeg_avg,
 entropy_pe_aeg_ind) = get_pe_entropies_optimal(
    users=optimal_pe_aeg_transf,
    simmodel=SIMMODEL,
    prior_samples=prior_1,
    num_resample=K,
    grid_list=grid_all_1,
    model_list=model_pe_aeg_opt_list,
    model_summ_list=model_summ_pe_aeg_opt_list,
)

100%|█████████████████████████████████████████████████████| 75/75 [02:34<00:00,  2.06s/it]


In [63]:
# AEG posterior entropies under the baseline designs, using each baseline
# participant's MAP model indicator from the MD phase (baseline_allocation).
(entropy_pe_aeg_base_avg,
 entropy_pe_aeg_base_ind) = get_pe_entropies_baseline(
    users=base_pe_transf,
    baseline_allocation=baseline_md_postargmax,
    simmodel=SIMMODEL,
    prior_samples=prior_1,
    num_resample=K,
    grid_list=grid_all_1,
    model_dict=model_pe_aeg_base_dict,
    model_summ_dict=model_summ_pe_aeg_base_dict,
)

100%|███████████████████████████████████████████████████| 160/160 [00:09<00:00, 16.85it/s]


## GLS

In [64]:
# Entropy settings for simulator model 2 (GLS).
SIMMODEL = 2

# Number of re-samples.
K = 100  # Select at least 10_000 for best results

# Evaluation grid, one array per parameter: five on [0, 1].
N_GRID = 10  # Select at least 10 for best results
grid_all_2 = [numpy.linspace(0, 1, N_GRID) for _ in range(5)]

In [65]:
# GLS posterior entropies under the optimal designs. The helper returns two
# objects, stored under the `_avg` and `_ind` names (presumably ensemble-averaged
# and per-network entropies -- confirm in boed.utils).
(entropy_pe_gls_avg,
 entropy_pe_gls_ind) = get_pe_entropies_optimal(
    users=optimal_pe_gls_transf,
    simmodel=SIMMODEL,
    prior_samples=prior_2,
    num_resample=K,
    grid_list=grid_all_2,
    model_list=model_pe_gls_opt_list,
    model_summ_list=model_summ_pe_gls_opt_list,
)

100%|█████████████████████████████████████████████████████| 29/29 [08:56<00:00, 18.49s/it]


In [66]:
# GLS posterior entropies under the baseline designs, using each baseline
# participant's MAP model indicator from the MD phase (baseline_allocation).
(entropy_pe_gls_base_avg,
 entropy_pe_gls_base_ind) = get_pe_entropies_baseline(
    users=base_pe_transf,
    baseline_allocation=baseline_md_postargmax,
    simmodel=SIMMODEL,
    prior_samples=prior_2,
    num_resample=K,
    grid_list=grid_all_2,
    model_dict=model_pe_gls_base_dict,
    model_summ_dict=model_summ_pe_gls_base_dict,
)

100%|███████████████████████████████████████████████████| 160/160 [05:18<00:00,  1.99s/it]


# Save Data

## Posterior Data

In [67]:
# Bundle the posterior results of each phase/model into one dictionary per
# output file.

# Number of re-samples that was used for the posterior histograms. It mirrors
# the K set in the three "PE-* Posteriors" cells (10_000) and must be kept in
# sync with them. The global K cannot be used here: the entropy cells re-assign
# it afterwards (to 1_000 and finally 100), so by the time this cell runs it no
# longer describes the histogram data.
K_POSTERIOR = 10_000

# MD Data
md_savedata = {
    'posts_opt_ensemble': posts_md_opt_ensemble,
    'user_opt_data': optimal_md_transf,
    'posts_base_ensemble': posts_md_base_ensemble,
    'user_base_data': base_md_transf,
    'postargmax_base': baseline_md_postargmax,
    'post_shapes': ['num users', 'ensemble repeats', 'num indicators']}

# PE WSLTS Data
pe_wslts_savedata = {
    'posts_opt_ensemble': posts_pe_wslts_ensemble,
    'corrs_opt_ensemble': corrs_pe_wslts_ensemble,
    'user_opt_data': optimal_pe_wslts_transf,
    'posts_base_ensemble': posts_pe_wslts_ensemble_base,
    'corrs_base_ensemble': corrs_pe_wslts_ensemble_base,
    'user_base_data': base_pe_transf,
    'prior_samples': prior_0,
    'bins': bins_all_0,
    'samples': K_POSTERIOR,
    'post_shapes': ['num users', 'ensemble repeats', 'num parameters', 'num_bins']}

# PE AEG Data
pe_aeg_savedata = {
    'posts_opt_ensemble': posts_pe_aeg_ensemble,
    'corrs_opt_ensemble': corrs_pe_aeg_ensemble,
    'user_opt_data': optimal_pe_aeg_transf,
    'posts_base_ensemble': posts_pe_aeg_ensemble_base,
    'corrs_base_ensemble': corrs_pe_aeg_ensemble_base,
    'user_base_data': base_pe_transf,
    'prior_samples': prior_1,
    'bins': bins_all_1,
    'samples': K_POSTERIOR,
    'post_shapes': ['num users', 'ensemble repeats', 'num parameters', 'num_bins']}

# PE GLS Data
pe_gls_savedata = {
    'posts_opt_ensemble': posts_pe_gls_ensemble,
    'corrs_opt_ensemble': corrs_pe_gls_ensemble,
    'user_opt_data': optimal_pe_gls_transf,
    'posts_base_ensemble': posts_pe_gls_ensemble_base,
    'corrs_base_ensemble': corrs_pe_gls_ensemble_base,
    'user_base_data': base_pe_transf,
    'prior_samples': prior_2,
    'bins': bins_all_2,
    'samples': K_POSTERIOR,
    'post_shapes': ['num users', 'ensemble repeats', 'num parameters', 'num_bins']}

In [68]:
# Write each posterior bundle to its own file under ../data.
for payload, target in (
    (md_savedata, '../data/md_posts_savedata.pt'),
    (pe_wslts_savedata, '../data/pe_wslts_posts_savedata.pt'),
    (pe_aeg_savedata, '../data/pe_aeg_posts_savedata.pt'),
    (pe_gls_savedata, '../data/pe_gls_posts_savedata.pt'),
):
    torch.save(payload, target)

## Entropy Data

In [69]:
# Bundle and save the posterior entropies, one file per phase/model.
# 'extra' records the K / N_GRID values set in the matching entropy cell; keep
# them in sync by hand when those settings change.

# MD Entropy Data
md_entropies_data = {
    'opt_ind': md_entropies_opt,
    'base_ind': md_entropies_base,
}
torch.save(md_entropies_data, '../data/md_entropies.pt')

# PE WSLTS Entropy Data
pe_wslts_entropies_data = {
    'opt_avg': entropy_pe_wslts_avg,
    'base_avg': entropy_pe_wslts_base_avg,
    'opt_ind': entropy_pe_wslts_ind,
    'base_ind': entropy_pe_wslts_base_ind,
    'extra': {'K': 1_000, 'N_GRID': 10}
}
torch.save(pe_wslts_entropies_data, '../data/pe_wslts_entropies.pt')

# PE AEG Entropy Data
pe_aeg_entropies_data = {
    'opt_avg': entropy_pe_aeg_avg,
    'base_avg': entropy_pe_aeg_base_avg,
    'opt_ind': entropy_pe_aeg_ind,
    'base_ind': entropy_pe_aeg_base_ind,
    'extra': {'K': 1_000, 'N_GRID': 10}
}
# Save the AEG dictionary itself (previously the WSLTS dictionary was written
# to this file by mistake).
torch.save(pe_aeg_entropies_data, '../data/pe_aeg_entropies.pt')

# PE GLS Entropy Data
pe_gls_entropies_data = {
    'opt_avg': entropy_pe_gls_avg,
    'base_avg': entropy_pe_gls_base_avg,
    'opt_ind': entropy_pe_gls_ind,
    'base_ind': entropy_pe_gls_base_ind,
    'extra': {'K': 100, 'N_GRID': 10}
}
# Save the GLS dictionary itself (previously the WSLTS dictionary was written
# to this file by mistake).
torch.save(pe_gls_entropies_data, '../data/pe_gls_entropies.pt')