This notebook applies the proposed methods to the chest X-ray (CXR) and ECG data. As these data are not publicly available, the notebook cannot be run directly, but it can easily be adapted to run on your own data.

In [None]:
import sys
sys.path.append('../')

from utils_generation import *
from utils_classification import *

import pickle

# Seed NumPy's global RNG so NumPy-based randomness in this notebook is repeatable
SEED = 0
np.random.seed(SEED)

In [None]:
DATA_PATH = '../../data/multimodal/missingness_shd_embeddings/'
RESULTS_PATH = './results_shd'

# Output locations: figures and cached predictions are stored under RESULTS_PATH
PLOT_PATH = os.path.join(RESULTS_PATH, 'plots')
PRED_PATH = os.path.join(RESULTS_PATH, 'preds')

# Create every output directory before anything is written to it
for output_dir in (RESULTS_PATH, PLOT_PATH, PRED_PATH):
    create_path_and_all_parents(output_dir)

print(f'Saving results to {RESULTS_PATH}')

# Input files (all under DATA_PATH): one CSV per modality, plus demographics and labels
_input_files = {
    'x1_feature_path': 'cxr.csv',               # modality 1: CXR features
    'x2_feature_path': 'ecg.csv',               # modality 2: ECG features
    'demo_feature_path': 'demo.csv',            # demographics (used for IPW estimation)
    'classification_label_path': 'labels.csv',  # labels and data split
}
data_path_dict = {key: os.path.join(DATA_PATH, file_name) for key, file_name in _input_files.items()}

In [None]:
def _standardize(frame):
    """Return a z-scored copy of `frame` that keeps its row index and column names."""
    scaled = StandardScaler().fit_transform(frame)
    return pd.DataFrame(scaled, columns=frame.columns, index=frame.index)


# Open data: each file holds one modality (rows are patients, columns are features),
# and every feature set is standardized independently.
x1_df = _standardize(pd.read_csv(data_path_dict['x1_feature_path'], index_col=0))
x2_df = _standardize(pd.read_csv(data_path_dict['x2_feature_path'], index_col=0))

# Extra demographic features to improve IPW estimation, restricted to the ECG patients
demo_df = _standardize(pd.read_csv(data_path_dict['demo_feature_path'], index_col=0).loc[x2_df.index])

# Joint feature set: columns of both modalities, dropping any patient with a missing value
x1_x2_df = pd.concat([x1_df, x2_df], axis=1).dropna()

label_df = pd.read_csv(data_path_dict['classification_label_path'], index_col=0)
# True where `cxr_observed_label` is 0, i.e. the CXR is not observed for that patient
miss_label_df = label_df['cxr_observed_label'] == 0

In [None]:
# The label file also stores which data split each patient belongs to
split_labels = label_df['missingness_data_split']
train_index, val_index, test_index = (label_df.index[split_labels == split_name]
                                      for split_name in ('train', 'val', 'test'))

In [None]:
grid_search = {'layers': [[32, 32]]}  # single candidate: presumably two hidden layers of 32 units

In [None]:
# Estimate, for every patient, the predicted probability of the missingness label
# (CXR not observed); used to build IPW weights. Cached so reruns skip the fit.
path = os.path.join(PRED_PATH, 'p_m_hat.pickle')

if os.path.isfile(path):
    # NOTE: pickle can execute arbitrary code on load; only load caches this notebook wrote
    with open(path, 'rb') as cache_file:
        p_m_hat = pickle.load(cache_file)
else:
    # Estimate the missingness probabilities - ASSUMING MAR
    # (missingness is modeled from the demographic and ECG features only)
    regressor = pd.concat([demo_df, x2_df], axis=1)
    p_m_hat = train_logistic_regression_and_get_prediction_probabilities(
        regressor.loc[train_index], miss_label_df.loc[train_index],
        regressor.loc[val_index], miss_label_df.loc[val_index],
        regressor, clip=True)

    # Context manager guarantees the cache file is flushed and closed
    with open(path, 'wb') as cache_file:
        pickle.dump(p_m_hat, cache_file)

# Compare performances

In [None]:
path_predictions = os.path.join(PRED_PATH, 'all_setting_predictions.pickle')
path_metrics = os.path.join(PRED_PATH, 'all_setting_metrics.pickle')

# Resume from the cache only when BOTH files exist; otherwise start from scratch
# (a lone predictions file used to crash when opening the missing metrics file).
if os.path.isfile(path_predictions) and os.path.isfile(path_metrics):
    with open(path_predictions, 'rb') as cache_file:
        all_setting_predictions = pickle.load(cache_file)
    with open(path_metrics, 'rb') as cache_file:
        all_setting_metrics = pickle.load(cache_file)
else:
    # Predictions and metrics keyed by (modality_name, observation)
    all_setting_predictions = {}
    all_setting_metrics = {}

# Loop-invariant quantities, computed once rather than per setting.
# Models are trained and evaluated only on patients whose CXR is observed.
observed = ~miss_label_df
observed_split = label_df.loc[observed, 'missingness_data_split']

# IPW weights: marginal observed rate divided by the estimated (1 - p_m_hat)
p_m = miss_label_df.mean()
ipw_weights = (1 - p_m) / (1 - p_m_hat)

shd_labels = label_df.shd_composite_label

for observation in ['corrected', 'observed']:
    # Only the 'corrected' setting uses IPW weights for training and evaluation
    use_ipw = observation == 'corrected'

    for modality, modality_name in zip([x1_df, x2_df, x1_x2_df], ['x1', 'x2', 'x1_x2']):
        key = (modality_name, observation)
        # Settings restored from the cache are not retrained
        if key in all_setting_predictions:
            continue

        # Split the observed patients of this modality using the label file
        data = modality.loc[observed]
        split = observed_split.loc[data.index]
        train = data.loc[split == 'train']
        val = data.loc[split == 'val']
        test = data.loc[split == 'test']

        # Train (IPW-weighted if corrected); predictions are made for every row of `modality`
        predictions = train_mlp_and_get_prediction_probabilities(
            train, shd_labels.loc[train.index],
            val, shd_labels.loc[val.index],
            modality,
            sample_weight=ipw_weights.loc[train.index] if use_ipw else None,
            weight_val=ipw_weights.loc[val.index] if use_ipw else None,
            grid_search=grid_search)
        all_setting_predictions[key] = predictions

        # Evaluate on the observed test patients (IPW-weighted metrics if corrected)
        all_setting_metrics[key] = get_classification_metric_dict(
            y_true=shd_labels.loc[test.index],
            y_pred=predictions.loc[test.index],
            ipw_weights=ipw_weights.loc[test.index] if use_ipw else None)

        # Checkpoint after every model so an interrupted run can resume
        with open(path_predictions, 'wb') as cache_file:
            pickle.dump(all_setting_predictions, cache_file)
        with open(path_metrics, 'wb') as cache_file:
            pickle.dump(all_setting_metrics, cache_file)

In [None]:
# Display the test AUROC (and its std) for every (modality, observation) setting
for (modality_key, setting), metrics in all_setting_metrics.items():
    pretty_modality = modality_key.replace('x1', "CXR").replace('x2', "ECG").replace('_', ' + ')
    auc_summary = ' - AUC = {:.2f} ({:.2f})'.format(metrics['auroc'], metrics['auroc_std'])
    print(pretty_modality, setting, auc_summary)

# Compare decomposition

In [None]:
from information_decomposition import *

### Decomposition

In [None]:
path_decomposition = os.path.join(PRED_PATH, 'all_setting_decompositions.pickle')

if os.path.isfile(path_decomposition):
    # Reuse the decompositions cached by a previous run
    with open(path_decomposition, 'rb') as cache_file:
        all_setting_decompositions = pickle.load(cache_file)
else:
    # PID decomposition per setting: 'corrected' (IPW-weighted) vs. 'observed' (unweighted)
    all_setting_decompositions = {}

    # Restrict to patients whose CXR is observed, then split them using the label file
    observed = ~miss_label_df
    observed_split = label_df.loc[observed, 'missingness_data_split']
    train_idx = observed_split.index[observed_split == 'train']
    val_idx = observed_split.index[observed_split == 'val']
    test_idx = observed_split.index[observed_split == 'test']

    # IPW weights: marginal observed rate divided by the estimated (1 - p_m_hat)
    p_m = miss_label_df.mean()
    ipw_weights = (1 - p_m) / (1 - p_m_hat)

    for observation in ['corrected', 'observed']:
        # Outcome predictions from the classifiers trained in this setting
        p_y_given_x1_x2 = all_setting_predictions[('x1_x2', observation)]
        p_y_given_x1 = all_setting_predictions[('x1', observation)]
        p_y_given_x2 = all_setting_predictions[('x2', observation)]

        # Weights are passed only in the corrected setting; the observed setting calls
        # QEstimator / pid_decomposition_batched with no weight arguments at all.
        if observation == 'corrected':
            fit_weights = (ipw_weights.loc[train_idx].values, ipw_weights.loc[val_idx].values)
            test_weights = (ipw_weights.loc[test_idx].values,)
        else:
            fit_weights, test_weights = (), ()

        estimator = QEstimator(x1_df.loc[train_idx].values, x2_df.loc[train_idx].values,
                               x1_df.loc[val_idx].values, x2_df.loc[val_idx].values,
                               p_y_given_x1.loc[train_idx].values, p_y_given_x2.loc[train_idx].values,
                               p_y_given_x1.loc[val_idx].values, p_y_given_x2.loc[val_idx].values,
                               *fit_weights,
                               grid_search=grid_search, epochs=100, device='cuda:0')

        all_setting_decompositions[observation] = pid_decomposition_batched(
            estimator, x1_df.loc[test_idx].values, x2_df.loc[test_idx].values,
            p_y_given_x1.loc[test_idx].values, p_y_given_x2.loc[test_idx].values,
            p_y_given_x1_x2.loc[test_idx].values,
            label_df.shd_composite_label.loc[test_idx].values,
            *test_weights)

    # Context manager guarantees the cache file is flushed and closed
    with open(path_decomposition, 'wb') as cache_file:
        pickle.dump(all_setting_decompositions, cache_file)

In [None]:
# Display every PID term with its std; '1'/'2' in term names are renamed to CXR/ECG
for setting in all_setting_decompositions:
    print(setting)
    decomposition = all_setting_decompositions[setting]
    for term in decomposition:
        if '_std' in term:
            continue  # std entries are printed alongside their term
        label = term.replace('1', "CXR").replace('2', "ECG")
        print(label, ' = {:.2f} ({:.2f})'.format(decomposition[term], decomposition[term + '_std']))