This notebook runs the proposed information estimation methods on the synthetic data generated by the `generate_simulation_data.py` script.

In [1]:
import sys
sys.path.append('../')

from utils_generation import *
from utils_classification import *

# For reproducibility
np.random.seed(0)

import pickle
from sklearn.metrics import root_mean_squared_error

import matplotlib.pyplot as plt
import seaborn as sns
# Plot styling shared by every figure: all four axis spines are switched off
# and figures are rendered and saved at 300 dpi.
custom_params = dict.fromkeys(
    ("axes.spines.right", "axes.spines.top", "axes.spines.left", "axes.spines.bottom"),
    False,
)
custom_params["figure.dpi"] = 300
custom_params["savefig.dpi"] = 300
sns.set_theme(style="whitegrid", rc=custom_params, font_scale=1.75)

In [2]:
# Experiment configuration -- keep these values in sync with the ones used
# when the synthetic data was generated.
N_SAMPLES = 10000
CLASSIFICATION_FLIP_PROB = 0.2
DATA_PARENT_PATH = './data'
RESULTS_PARENT_PATH = './results'

# Inputs and outputs live in sub-folders named after the configuration above.
DATA_PATH = os.path.join(DATA_PARENT_PATH, f'nsamples_{N_SAMPLES}_flipprob_{CLASSIFICATION_FLIP_PROB}')
RESULTS_PATH = os.path.join(RESULTS_PARENT_PATH, f'nsamples_{N_SAMPLES}_flipprob_{CLASSIFICATION_FLIP_PROB}')
PLOT_PATH = os.path.join(RESULTS_PATH, 'plots')
PRED_PATH = os.path.join(RESULTS_PATH, 'preds')

# Prepare the output folders through the project helper
# (presumably creates the folder and any missing parents -- see its definition).
for output_dir in (RESULTS_PATH, PLOT_PATH, PRED_PATH):
    create_path_and_all_parents(output_dir)

print(f'Saving results to {RESULTS_PATH}')

# Locations of the input CSV files, all inside DATA_PATH.
data_path_dict = {
    path_key: os.path.join(DATA_PATH, file_name)
    for path_key, file_name in (
        ('x1_feature_path', 'continuous_x1_features.csv'),
        ('x2_feature_path', 'continuous_x2_features.csv'),
        ('classification_label_path', 'classification_labels.csv'),
        ('miss_probs_path', 'miss_probs.csv'),
        ('miss_label_path', 'miss_labels.csv'),
    )
}

In [3]:
# Load the two feature modalities and their column-wise concatenation.
x1_df, x2_df = (
    pd.read_csv(data_path_dict[path_key])
    for path_key in ('x1_feature_path', 'x2_feature_path')
)
x1_x2_df = pd.concat([x1_df, x2_df], axis=1)

# The label table has a single header row; the two missingness tables are read
# with a two-level column header and are later indexed as (setting, '0.5').
label_df = pd.read_csv(data_path_dict['classification_label_path'])
miss_label_df, miss_probs_df = (
    pd.read_csv(data_path_dict[path_key], header=[0, 1])
    for path_key in ('miss_label_path', 'miss_probs_path')
)

# Every label column except the last one names a classification setting
# (the last column is presumably 'data_split' -- confirm against the generator).
classification_settings = label_df.columns[:-1]

In [4]:
# Hyper-parameter grid handed to the MLP trainer and to QEstimator below:
# a single candidate with two entries of 32 (presumably hidden-layer widths).
grid_search = {'layers': [[32, 32]]}

In [5]:
# Missingness probabilities per classification setting, keyed by setting name.
path = os.path.join(PRED_PATH, 'p_m_hat.pickle')

if os.path.isfile(path):
    # Reuse a previously stored dictionary. The context manager closes the file
    # handle deterministically (the bare `open(...)` used before never did).
    # NOTE: unpickling can execute arbitrary code; only load files produced by
    # this project. Nothing in this notebook writes this file, so it has to be
    # created elsewhere for this branch to be taken.
    with open(path, 'rb') as handle:
        p_m_hat = pickle.load(handle)
else:
    # Missingness probabilities - ASSUMING MAR.
    # No model is fitted here: the values are read directly from the '0.5'
    # column of the missingness-probability table for each setting.
    p_m_hat = {
        classification_setting: miss_probs_df[(classification_setting, '0.5')]
        for classification_setting in classification_settings
    }

In [6]:
# Row labels belonging to each data split, as given by the 'data_split' column
# of the label table.
split_column = label_df['data_split']
train_index, val_index, test_index = (
    label_df.index[(split_column == split_name).to_numpy()]
    for split_name in ('train', 'val', 'test')
)

# Compare performances

In [7]:
# Train one MLP per (classification setting, modality, training regime),
# evaluate it on the test split and cache predictions and metrics on disk.
#
# Result layout:
#   all_setting_predictions[setting][modality][training]          -> predictions for every row
#   all_setting_metrics[setting][modality][training][evaluation]  -> metric dictionary
# with training in {'all', 'observed', 'corrected'} and evaluation in
# {'all', 'observed', 'corrected'} ('all' training is only evaluated on 'all').
path_predictions = os.path.join(PRED_PATH, 'all_setting_predictions.pickle')
path_metrics = os.path.join(PRED_PATH, 'all_setting_metrics.pickle')

if os.path.isfile(path_predictions):
    # Resume from a previous run. Context managers make sure the handles are
    # closed (the bare `open(...)` calls used before were never closed).
    # NOTE: unpickling can execute arbitrary code; only load files written by
    # this notebook.
    with open(path_predictions, 'rb') as handle:
        all_setting_predictions = pickle.load(handle)
    with open(path_metrics, 'rb') as handle:
        all_setting_metrics = pickle.load(handle)
else:
    # Predictions across info settings
    all_setting_predictions = {}
    all_setting_metrics = {}

# Loop across settings
for classification_setting in classification_settings:
    print(f'Classification setting: {classification_setting}')

    if classification_setting not in all_setting_predictions:
        all_setting_predictions[classification_setting] = {}
        all_setting_metrics[classification_setting] = {}

    # Everything below depends on the setting only, so it is computed once
    # per setting instead of once per modality.
    labels = label_df[classification_setting]

    # Rows whose missingness label is 0 are the ones treated as observed.
    observed = miss_label_df[(classification_setting, '0.5')] == 0
    train_index_class = train_index[observed.loc[train_index]]
    val_index_class = val_index[observed.loc[val_index]]
    test_index_class = test_index[observed.loc[test_index]]

    # Inverse-probability weights: p_hat is the missingness probability, so
    # 1 - p_hat is the probability of being observed. The 0.5 numerator
    # presumably is the marginal rate matching the '0.5' column -- confirm.
    p_hat = p_m_hat[classification_setting]
    ipw_weights = 0.5 / (1 - p_hat)

    for modality, modality_name in zip([x1_df, x2_df, x1_x2_df], ['x1', 'x2', 'x1_x2']):
        # Skip modalities already present in the cache.
        if modality_name in all_setting_predictions[classification_setting]:
            continue
        predictions = all_setting_predictions[classification_setting][modality_name] = {}
        metrics = all_setting_metrics[classification_setting][modality_name] = {}

        # 'all': train and evaluate on every row of the respective split.
        predictions['all'] = train_mlp_and_get_prediction_probabilities(
            modality.loc[train_index], labels.loc[train_index],
            modality.loc[val_index], labels.loc[val_index],
            modality, grid_search=grid_search)
        metrics['all'] = {
            'all': get_classification_metric_dict(
                y_true=labels.loc[test_index],
                y_pred=predictions['all'].loc[test_index]),
        }

        # 'observed': train on observed rows only, without weights.
        # 'corrected': same rows, with IPW weights for training and validation.
        training_variants = (
            ('observed', {}),
            ('corrected', {'sample_weight': ipw_weights.loc[train_index_class],
                           'weight_val': ipw_weights.loc[val_index_class]}),
        )
        for training_name, weight_kwargs in training_variants:
            predictions[training_name] = train_mlp_and_get_prediction_probabilities(
                modality.loc[train_index_class], labels.loc[train_index_class],
                modality.loc[val_index_class], labels.loc[val_index_class],
                modality, grid_search=grid_search, **weight_kwargs)

            variant_predictions = predictions[training_name]
            metrics[training_name] = {
                # Full test split.
                'all': get_classification_metric_dict(
                    y_true=labels.loc[test_index],
                    y_pred=variant_predictions.loc[test_index]),
                # Observed test rows, unweighted.
                'observed': get_classification_metric_dict(
                    y_true=labels.loc[test_index_class],
                    y_pred=variant_predictions.loc[test_index_class]),
                # Observed test rows, IPW-weighted.
                'corrected': get_classification_metric_dict(
                    y_true=labels.loc[test_index_class],
                    y_pred=variant_predictions.loc[test_index_class],
                    ipw_weights=ipw_weights.loc[test_index_class]),
            }

        # Save predictions and metrics after every modality so that an
        # interrupted run can resume; `with` flushes and closes each file.
        with open(path_predictions, 'wb') as handle:
            pickle.dump(all_setting_predictions, handle)
        with open(path_metrics, 'wb') as handle:
            pickle.dump(all_setting_metrics, handle)

# Compare decomposition

In [8]:
from information_decomposition import *

### Decomposition

In [9]:
# Compute the PID decomposition for every classification setting and cache it.
#
# Result layout:
#   all_setting_decompositions[setting][training][evaluation]
# with training in {'all', 'observed', 'corrected'}; 'all' is evaluated on the
# full test split, the other two on the observed test rows, once without
# ('observed') and once with ('corrected') IPW weights.
path_decomposition = os.path.join(PRED_PATH, 'all_setting_decompositions.pickle')


def _estimator_inputs(fit_rows, val_rows, p_y_given_x1, p_y_given_x2):
    """Return the eight leading positional arguments of QEstimator for the given rows."""
    return (x1_df.loc[fit_rows].values, x2_df.loc[fit_rows].values,
            x1_df.loc[val_rows].values, x2_df.loc[val_rows].values,
            p_y_given_x1.loc[fit_rows].values, p_y_given_x2.loc[fit_rows].values,
            p_y_given_x1.loc[val_rows].values, p_y_given_x2.loc[val_rows].values)


def _pid_inputs(rows, p_y_given_x1, p_y_given_x2, p_y_given_x1_x2, labels):
    """Return the positional data arguments of pid_decomposition_batched for the given rows."""
    return (x1_df.loc[rows].values, x2_df.loc[rows].values,
            p_y_given_x1.loc[rows].values, p_y_given_x2.loc[rows].values,
            p_y_given_x1_x2.loc[rows].values,
            labels.loc[rows].values)


if os.path.isfile(path_decomposition):
    # Resume from a previous run; the context manager closes the handle.
    # NOTE: unpickling can execute arbitrary code; only load files written by
    # this notebook.
    with open(path_decomposition, 'rb') as handle:
        all_setting_decompositions = pickle.load(handle)
else:
    # Decompositions across info settings
    all_setting_decompositions = {}

np.random.seed(0)
# Loop across settings, visited in a seeded random order.
for classification_setting in np.random.choice(classification_settings, replace = False, size = len(classification_settings)):
    print(f'Classification setting: {classification_setting}')
    # Settings already present in the cache are not recomputed.
    if classification_setting in all_setting_decompositions:
        continue
    all_setting_decompositions[classification_setting] = {}

    labels = label_df[classification_setting]
    setting_predictions = all_setting_predictions[classification_setting]

    # --- 'all': fitted and evaluated on every row of the respective split ---
    p_y_given_x1_x2 = setting_predictions['x1_x2']['all']
    p_y_given_x1 = setting_predictions['x1']['all']
    p_y_given_x2 = setting_predictions['x2']['all']

    # NOTE(review): the device is hard-coded to 'cuda:0'; presumably this cell
    # needs a CUDA-capable GPU -- confirm against QEstimator.
    estimator = QEstimator(*_estimator_inputs(train_index, val_index, p_y_given_x1, p_y_given_x2),
                           grid_search=grid_search, epochs = 100, device='cuda:0')

    all_setting_decompositions[classification_setting]['all'] = {
        'all': pid_decomposition_batched(
            estimator, *_pid_inputs(test_index, p_y_given_x1, p_y_given_x2, p_y_given_x1_x2, labels)),
    }

    # Inverse-probability weights and the rows treated as observed
    # (missingness label equal to 0).
    p_hat = p_m_hat[classification_setting]
    ipw_weights = 0.5 / (1 - p_hat)

    observed = miss_label_df[(classification_setting, '0.5')] == 0
    train_index_class = train_index[observed.loc[train_index]]
    val_index_class = val_index[observed.loc[val_index]]
    test_index_class = test_index[observed.loc[test_index]]

    # --- 'observed' and 'corrected': fitted on observed rows only ---
    for training_name in ('observed', 'corrected'):
        p_y_given_x1_x2 = setting_predictions['x1_x2'][training_name]
        p_y_given_x1 = setting_predictions['x1'][training_name]
        p_y_given_x2 = setting_predictions['x2'][training_name]

        # Only the 'corrected' estimator receives the IPW weights (train and
        # validation) as two extra positional arguments.
        if training_name == 'corrected':
            estimator_weights = (ipw_weights.loc[train_index_class].values,
                                 ipw_weights.loc[val_index_class].values)
        else:
            estimator_weights = ()

        estimator = QEstimator(*_estimator_inputs(train_index_class, val_index_class, p_y_given_x1, p_y_given_x2),
                               *estimator_weights,
                               grid_search=grid_search, epochs = 100, device='cuda:0')

        all_setting_decompositions[classification_setting][training_name] = {
            # Observed test rows, unweighted.
            'observed': pid_decomposition_batched(
                estimator, *_pid_inputs(test_index_class, p_y_given_x1, p_y_given_x2, p_y_given_x1_x2, labels)),
            # Observed test rows, IPW-weighted.
            'corrected': pid_decomposition_batched(
                estimator, *_pid_inputs(test_index_class, p_y_given_x1, p_y_given_x2, p_y_given_x1_x2, labels),
                ipw_weights.loc[test_index_class].values),
        }

    # Persist after every setting so an interrupted run can resume; `with`
    # flushes and closes the file.
    with open(path_decomposition, 'wb') as handle:
        pickle.dump(all_setting_decompositions, handle)

# Visualize correlation

In [10]:
# Maps the display names used in the plots to the keys under which results are
# stored: PID term names to decomposition keys, and modality labels to the
# modality keys of the prediction/metric dictionaries ('x1', 'x2', 'x1_x2').
# The X_1 / X_2 labels previously pointed at the opposite modality keys.
naming = {
    'Unique 1': 'unique1',
    'Unique 2': 'unique2',
    'Shared': 'shared',
    'Complementary': 'complementary',
    r'$X_1$': 'x1',
    r'$X_2$': 'x2',
    r'$X_1 + X_2$': 'x1_x2'
}

In [49]:
# Select which cached results the two scatter plots below compare against the
# oracle values stored under ['all']['all']: results are looked up as
# [...][training][evaluation].
training = 'observed' # 'corrected' or 'observed'
evaluation = 'observed' # 'all', 'observed', 'corrected'

In [50]:
# Scatter the selected PID estimates against the oracle ('all'/'all') ones,
# one colour per PID term, with a bootstrapped linear fit and RMSE per term.
current_parent_path = os.path.join(PLOT_PATH, 'plots')
create_path_and_all_parents(current_parent_path)

pid_names = ['Unique 1', 'Unique 2', 'Shared', 'Complementary']
pid_colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
n_bootstrap = 100
# Abscissa on which every bootstrapped regression line is evaluated.
x = np.linspace(0, 0.25, 100)

plt.figure()
plt.xlabel('Estimated PID')
plt.ylabel('Oracle PID')
for pid, color in zip(pid_names, pid_colors):
    pid_key = naming[pid]
    pids = [decomposition[training][evaluation][pid_key]
            for decomposition in all_setting_decompositions.values()]
    oracle_pids = [decomposition['all']['all'][pid_key]
                   for decomposition in all_setting_decompositions.values()]
    plt.scatter(pids, oracle_pids, alpha = 0.25, color = color, s = 100)

    # Calculate RMSE and fit on bootstrap resamples of the settings; the seed
    # is reset for every term so each one sees the same resampling sequence.
    perfs, perfs_oracle = np.array(pids), np.array(oracle_pids)
    np.random.seed(0)
    resamples = [np.random.choice(len(perfs), size = len(perfs), replace = True)
                 for _ in range(n_bootstrap)]

    curves, slopes, rmse = [], [], []
    for sample in resamples:
        m, c = np.polyfit(perfs[sample], perfs_oracle[sample], 1)
        curves.append(m * x + c)
        slopes.append(m)
        rmse.append(root_mean_squared_error(perfs_oracle[sample], perfs[sample]))

    # Mean regression line with a one-standard-deviation band.
    curve_stack = np.array(curves)
    mean = curve_stack.mean(axis = 0)
    std = curve_stack.std(axis = 0)
    plt.fill_between(x, mean - std, mean + std, color = color, alpha = 0.25)
    plt.plot(x, mean, color = color, ls = '--', alpha = 0.75, lw = 2, label = r'$\epsilon =$' + '{:.3f} ({:.3f})'.format(np.mean(rmse), np.std(rmse)))

# Fixed axis limits for the whole figure.
plt.xlim(-0.01, 0.25)
plt.ylim(-0.01, 0.35)

plt.legend(loc='upper left')
plt.savefig(os.path.join(current_parent_path, f'{training}_{evaluation}_pid.png'), bbox_inches = 'tight', dpi = 400)
plt.show()

In [51]:
# Scatter the selected AUROC values against the oracle ('all'/'all') ones,
# one colour per model, with a bootstrapped linear fit and RMSE per model.
current_parent_path = os.path.join(PLOT_PATH, 'plots')
create_path_and_all_parents(current_parent_path)

model_names = [r'$X_1$', r'$X_2$', r'$X_1 + X_2$']
model_colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
n_bootstrap = 100
# Abscissa on which every bootstrapped regression line is evaluated.
x = np.linspace(0, 1, 100)

plt.figure()
plt.xlabel('Estimated AUC')
plt.ylabel('Oracle AUC')
for model, color in zip(model_names, model_colors):
    model_key = naming[model]
    perfs = [setting_metrics[model_key][training][evaluation]['auroc']
             for setting_metrics in all_setting_metrics.values()]
    perfs_oracle = [setting_metrics[model_key]['all']['all']['auroc']
                    for setting_metrics in all_setting_metrics.values()]

    plt.scatter(perfs, perfs_oracle, alpha = 0.25, color = color, s = 100)

    # Bootstrap over settings; the seed is reset for every model so each one
    # sees the same resampling sequence.
    perfs, perfs_oracle = np.array(perfs), np.array(perfs_oracle)
    np.random.seed(0)
    resamples = [np.random.choice(len(perfs), size = len(perfs), replace = True)
                 for _ in range(n_bootstrap)]

    curves, slopes, rmse = [], [], []
    for sample in resamples:
        m, c = np.polyfit(perfs[sample], perfs_oracle[sample], 1)
        curves.append(m * x + c)
        slopes.append(m)
        rmse.append(root_mean_squared_error(perfs_oracle[sample], perfs[sample]))

    # Mean regression line with a one-standard-deviation band.
    curve_stack = np.array(curves)
    mean = curve_stack.mean(axis = 0)
    std = curve_stack.std(axis = 0)
    plt.fill_between(x, mean - std, mean + std, color = color, alpha = 0.25)
    plt.plot(x, mean, color = color, ls = '--', alpha = 0.75, lw = 2, label = r'$\epsilon =$' + '{:.3f} ({:.3f})'.format(np.mean(rmse), np.std(rmse)))

# Fixed axis limits for the whole figure.
plt.xlim(0.5, 0.8)
plt.ylim(0.5, 0.8)

plt.legend(loc='lower right')
plt.savefig(os.path.join(current_parent_path, f'{training}_{evaluation}_perf.png'), bbox_inches = 'tight', dpi = 400)
plt.show()

# Visualize correlation PID and AUC

In [71]:
# Result variant shown in the plot below: it is used for both index levels of
# the decomposition lookup and also selects the figure title and file name.
evaluation = 'all' # 'all', 'observed', 'corrected'

In [72]:
# One panel per PID term: oracle AUROC of each model plotted against the PID
# value of the selected variant, with a bootstrapped linear fit whose mean
# slope is reported in the legend.
current_parent_path = os.path.join(PLOT_PATH, 'plots')
create_path_and_all_parents(current_parent_path)

pid_names = ['Unique 1', 'Unique 2', 'Shared', 'Complementary']
model_names = [r'$X_1$', r'$X_2$', r'$X_1 + X_2$']
model_colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
n_bootstrap = 100
# Settings are taken from the metrics dictionary and used to index both the
# metrics and the decompositions.
settings = list(all_setting_metrics)

fig, axes = plt.subplots(1, 4, sharey=True, figsize=(14, 5))
axes[0].set_ylabel('Oracle AUC')
for pid, ax in zip(pid_names, axes):
    ax.set_xlabel(pid)
    pid_key = naming[pid]
    for model, color in zip(model_names, model_colors):
        model_key = naming[model]
        pids = [all_setting_decompositions[setting][evaluation][evaluation][pid_key]
                for setting in settings]
        perfs_oracle = [all_setting_metrics[setting][model_key]['all']['all']['auroc']
                        for setting in settings]

        ax.scatter(pids, perfs_oracle, alpha = 0.25, color = color, s = 100)

        pids, perfs_oracle = np.array(pids), np.array(perfs_oracle)
        # Regression lines are drawn from 0 up to the largest PID value.
        x = np.linspace(0, pids.max(), 100)

        # Bootstrap over settings with a fixed seed per (term, model) pair.
        np.random.seed(0)
        resamples = [np.random.choice(len(pids), size = len(pids), replace = True)
                     for _ in range(n_bootstrap)]

        curves, slopes, corr = [], [], []
        for sample in resamples:
            m, c = np.polyfit(pids[sample], perfs_oracle[sample], 1)
            curves.append(m * x + c)
            slopes.append(m)
            corr.append(np.corrcoef(pids[sample], perfs_oracle[sample])[0, 1])

        # Mean regression line with a one-standard-deviation band.
        curve_stack = np.array(curves)
        mean = curve_stack.mean(axis = 0)
        std = curve_stack.std(axis = 0)
        ax.fill_between(x, mean - std, mean + std, color = color, alpha = 0.25)
        ax.plot(x, mean, color = color, ls = '--', alpha = 0.75, lw = 2, label = r'$\alpha =$' + '{:.2f}'.format(np.mean(slopes)))

    # Per-panel cosmetics, applied once all three models are drawn.
    ax.set_ylim(0.5, 1)
    ax.legend(loc='upper center',  bbox_to_anchor=(0.5, 1.))

# Figure title by variant; anything other than the two named variants is
# labelled as the oracle.
figure_titles = {'observed': 'Current Practice', 'corrected': r'$ICYM^2I$'}
plt.suptitle(figure_titles.get(evaluation, 'Oracle'))
plt.savefig(os.path.join(current_parent_path, f'{evaluation}_connection.png'), bbox_inches = 'tight', dpi = 400)
plt.show()