This notebook reproduces the semi-synthetic results on the humour dataset.

Please download the dataset from []().

In [None]:
import sys
sys.path.append('../')

from utils_generation import *
from utils_classification import *

from tqdm import tqdm
import pickle

import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import seaborn as sns
# Theme shared by every figure in this notebook: white grid, all four axes
# spines hidden, 300 dpi both for display and for saved files.
hidden_spines = ["axes.spines.right", "axes.spines.top", "axes.spines.left", "axes.spines.bottom"]
custom_params = dict.fromkeys(hidden_spines, False)
custom_params["figure.dpi"] = 300
custom_params["savefig.dpi"] = 300
sns.set_theme(style="whitegrid", rc=custom_params, font_scale=1.75)

# Seed NumPy's global generator so the random draws made below
# (np.random.binomial for the missingness indicators) are repeatable.
np.random.seed(0)

In [None]:
# Where the preprocessed features live and where this notebook writes to
DATA_PATH = '../../data/multimodal/UR-FUNNY_preprocessed/'
RESULTS_PATH = './results_humour'
PLOT_PATH = os.path.join(RESULTS_PATH, 'plots')
PRED_PATH = os.path.join(RESULTS_PATH, 'preds')

# Hand each output location to the project helper, results root first
for output_dir in (RESULTS_PATH, PLOT_PATH, PRED_PATH):
    create_path_and_all_parents(output_dir)

print(f'Saving results to {RESULTS_PATH}')

# CSV file behind each raw feature table and behind the labels
data_path_dict = dict(
    x1_feature_path=os.path.join(DATA_PATH, 'audio.csv'),
    x2_feature_path=os.path.join(DATA_PATH, 'vision.csv'),
    x3_feature_path=os.path.join(DATA_PATH, 'text.csv'),
    classification_label_path=os.path.join(DATA_PATH, 'labels.csv'),
)

In [None]:
def _standardize(frame):
    """Return `frame` passed through a freshly fitted StandardScaler, keeping its row and column labels."""
    scaled_values = StandardScaler().fit_transform(frame)
    return pd.DataFrame(scaled_values, columns=frame.columns, index=frame.index)


# X1 is the text modality (note: it is stored under the 'x3_feature_path' key)
x1_df = pd.read_csv(data_path_dict['x3_feature_path'], index_col=0)

# X2 stacks the audio and vision features side by side ("video");
# rows with any missing value after the join are dropped
audio_df = pd.read_csv(data_path_dict['x1_feature_path'], index_col=0)
vision_df = pd.read_csv(data_path_dict['x2_feature_path'], index_col=0)
x2_df = pd.concat([audio_df, vision_df], axis=1).dropna()

# Standardize each modality
# NOTE(review): the scalers are fitted on every row, i.e. before the
# train/valid/test split is applied - confirm this is intended.
x1_df = _standardize(x1_df)
x2_df = _standardize(x2_df)

# Multimodal input (X1 and X2 columns together) and the label table
x1_x2_df = pd.concat([x1_df, x2_df], axis=1)
label_df = pd.read_csv(data_path_dict['classification_label_path'], index_col=0)

In [None]:
def _sample_missingness(probabilities):
    """Draw one Bernoulli(p) indicator for every entry of `probabilities`."""
    return np.random.binomial(1, probabilities, len(probabilities))


# Synthetic missingness scenarios, one per value of p_m_1_array
# (the original note: text is not observed for everyone)
p_m_1_array = np.arange(0.3, 0.8, 0.1).round(3)

# Missingness probabilities from the project helper, then sampled indicators.
# Later cells index both by p_m_1 (`.loc[:, p_m_1]`) and treat an indicator
# of 0 as "observed".
missingness_prob = generate_missingness(x1_df, x2_df, label_df['classification_label'], p_m_1_array)
missingness_label = missingness_prob.apply(_sample_missingness)

In [None]:
# Row labels of the predefined splits recorded in label_df['data_split']
split_column = label_df['data_split']
train_index = label_df[split_column == 'train'].index
val_index = label_df[split_column == 'valid'].index
test_index = label_df[split_column == 'test'].index

In [None]:
# Search grid handed to every training call via `grid_search=`; it holds a
# single candidate for 'layers': [32, 32]
# (presumably two hidden layers of width 32 - confirm against the training helper).
grid_search = {'layers': [[32, 32]]}

# Compare performances

In [None]:
# Cache files: per-setting predictions and the metrics computed from them
path_predictions = os.path.join(PRED_PATH, 'all_setting_predictions.pickle')
path_metrics = os.path.join(PRED_PATH, 'all_setting_metrics.pickle')

# Resume from the cache only when BOTH files exist, so the two dictionaries
# always describe the same settings; otherwise start from scratch.
# NOTE: unpickling can execute arbitrary code - only load files written by
# this notebook.
if os.path.isfile(path_predictions) and os.path.isfile(path_metrics):
    with open(path_predictions, 'rb') as cache_file:
        all_setting_predictions = pickle.load(cache_file)
    with open(path_metrics, 'rb') as cache_file:
        all_setting_metrics = pickle.load(cache_file)
else:
    all_setting_predictions = {}
    all_setting_metrics = {}

labels = label_df.classification_label

# Settings are keyed by (modality_name, observation):
#   'all'       - trained on the full training split (no missingness),
#   'observed'  - trained on observed rows only, unweighted,
#   'corrected' - trained on observed rows only, with IPW weights.
# The last two hold one entry per missingness scenario p_m_1.
for observation in ['all', 'observed', 'corrected']:
    for modality, modality_name in zip([x1_df, x2_df, x1_x2_df], ['x1', 'x2', 'x1_x2']):
        setting = (modality_name, observation)
        if setting in all_setting_predictions:
            continue

        if observation == 'all':
            # The full `modality` table is passed as the last positional
            # argument; the returned predictions are indexed by row label below.
            setting_predictions = train_mlp_and_get_prediction_probabilities(
                modality.loc[train_index], labels.loc[train_index],
                modality.loc[val_index], labels.loc[val_index],
                modality, grid_search=grid_search)

            # Evaluate on the held-out test split
            setting_metrics = get_classification_metric_dict(
                y_true=labels.loc[test_index],
                y_pred=setting_predictions.loc[test_index])

        else:
            use_ipw = observation == 'corrected'
            setting_predictions = {}
            setting_metrics = {}

            for p_m_1 in p_m_1_array:
                # Keep only the rows whose sampled indicator is 0 (observed)
                observed = missingness_label.loc[:, p_m_1] == 0
                data = modality.loc[observed]

                # Split the observed rows along the predefined splits
                observed_split = label_df.loc[observed, 'data_split']
                train = data.loc[observed_split == 'train']
                val = data.loc[observed_split == 'valid']
                test = data.loc[observed_split == 'test']

                # IPW weights: empirical rate of being observed (1 - p_m)
                # divided by each row's probability of being observed (1 - p_hat)
                p_m = missingness_label.loc[:, p_m_1].mean()
                p_hat = missingness_prob.loc[:, p_m_1]
                ipw_weights = (1 - p_m) / (1 - p_hat)

                # Train on observed rows; weights are used only when correcting
                setting_predictions[p_m_1] = train_mlp_and_get_prediction_probabilities(
                    train, labels.loc[train.index],
                    val, labels.loc[val.index],
                    modality,
                    sample_weight=ipw_weights.loc[train.index] if use_ipw else None,
                    weight_val=ipw_weights.loc[val.index] if use_ipw else None,
                    grid_search=grid_search)

                # Evaluate on the observed test rows (weighted when correcting)
                setting_metrics[p_m_1] = get_classification_metric_dict(
                    y_true=labels.loc[test.index],
                    y_pred=setting_predictions[p_m_1].loc[test.index],
                    ipw_weights=ipw_weights.loc[test.index] if use_ipw else None)

        # Register the setting only once it is complete, so an interrupted
        # run is recomputed instead of being skipped by the check above.
        all_setting_predictions[setting] = setting_predictions
        all_setting_metrics[setting] = setting_metrics

        # Checkpoint after every finished setting
        with open(path_predictions, 'wb') as cache_file:
            pickle.dump(all_setting_predictions, cache_file)
        with open(path_metrics, 'wb') as cache_file:
            pickle.dump(all_setting_metrics, cache_file)

In [None]:
# One marker per training setting, shared by the figures below
unadjusted_marker, adjusted_marker, all_marker = 'x', 'o', 'D'
alpha_value = 0.5

# One colour and one horizontal offset per plotted series, so that points
# drawn at the same missingness rate do not overlap
width = 0.02
x_offset = [-width, 0, width]
x_color = ['#648fff', '#dc267f', '#fe6100']

# Display names for the decomposition terms
clean = dict(
    causal_shared='Shared',
    causal_unique_1='Unique 1',
    causal_unique_2='Unique 2',
    causal_complementary='Complementary',
)

In [None]:
# Figure per metric: top panel shows the metric for each modality and training
# setting at every missingness rate; bottom panel shows how much closer the
# adjusted model is to the oracle than the unadjusted one.
ipw_parent_path = os.path.join(PLOT_PATH, 'ipw_plots')
create_path_and_all_parents(ipw_parent_path)

for current_metric in ['auroc']:
    # Naming
    if current_metric == 'bce_loss':
        metric_name_for_title = 'Binary Cross Entropy'
    else:
        metric_name_for_title = current_metric.upper()

    # NOTE: `super_title` is also read by the PID plotting cell (file name)
    super_title = f'{metric_name_for_title}'

    fig, ax = plt.subplots(2, 1, figsize=(12, 7))
    ax = ax.flatten()

    for i, p_m_1 in enumerate(p_m_1_array):
        p_m_1_float = float(p_m_1)
        for j, (modality, name) in enumerate(zip(['x1', 'x2', 'x1_x2'], ['$X_1$', '$X_2$', 'Multimodal'])):
            observed_metrics = all_setting_metrics[(modality, 'observed')][p_m_1]
            corrected_metrics = all_setting_metrics[(modality, 'corrected')][p_m_1]
            oracle_metrics = all_setting_metrics[(modality, 'all')]

            unadjusted = observed_metrics[current_metric]
            adjusted = corrected_metrics[current_metric]
            # Named `oracle` rather than `all` so the builtin is not shadowed
            oracle = oracle_metrics[current_metric]

            unadjusted_err = observed_metrics[current_metric + '_std']
            adjusted_err = corrected_metrics[current_metric + '_std']
            oracle_err = oracle_metrics[current_metric + '_std']

            x_position = p_m_1_float + x_offset[j]

            # (value, error bar, marker, legend suffix, extra scatter kwargs);
            # legend entries are created only for the first rate (i == 0)
            series = (
                (unadjusted, unadjusted_err, unadjusted_marker, ' Unadjusted', {'linewidths': 2}),
                (adjusted, adjusted_err, adjusted_marker, ' Adjusted', {}),
                (oracle, oracle_err, all_marker, ' Oracle', {}),
            )
            for value, value_err, marker, suffix, extra in series:
                ax[0].scatter(x_position, value, color=x_color[j], marker=marker,
                              label=name + suffix if (i == 0) else None,
                              alpha=alpha_value, s=100, **extra)
                ax[0].errorbar(x_position, value, yerr=value_err, color=x_color[j], label=None)

            # Positive values: the adjusted estimate is closer to the oracle
            ax[1].scatter(x_position, np.abs(unadjusted - oracle) - np.abs(adjusted - oracle),
                          color=x_color[j], label=None, alpha=alpha_value, s=100)

        # Dotted separator between consecutive missingness rates
        if i < len(p_m_1_array) - 1:
            ax[0].axvline(p_m_1_float + 0.05, color='grey', alpha=0.25, linestyle='dotted')
            ax[1].axvline(p_m_1_float + 0.05, color='grey', alpha=0.25, linestyle='dotted')

    ax[1].axhline(0, color='grey', alpha=0.25, linestyle='dotted')
    ax[0].set_xticklabels([])
    ax[1].set_xlabel('Missingness Rate - $P(m_2=1)$', fontweight='bold')
    ax[0].legend(bbox_to_anchor=(1, 1))

    ax[0].set_ylabel(metric_name_for_title)
    ax[1].set_ylabel('Absolute error difference')

    ax[0].grid(axis='x', visible=False)
    ax[1].grid(axis='x', visible=False)

    if current_metric == 'auroc':
        ax[0].set_ylim([.4, .8])

    fig.suptitle(super_title, fontweight='bold')
    fig.tight_layout()

    current_parent_path = os.path.join(ipw_parent_path, current_metric)
    create_path_and_all_parents(current_parent_path)
    plt.subplots_adjust(hspace=0.1)
    plt.savefig(os.path.join(current_parent_path, f'{super_title}.png'), bbox_inches='tight', dpi=400)

    plt.show()
    plt.close()

# Compare decomposition

In [None]:
from information_decomposition import *

### Decomposition

In [None]:
# Cache of decomposition results, keyed by 'all' (no missingness) and by
# ('corrected', p_m_1) / ('observed', p_m_1) for every missingness scenario.
path_decomposition = os.path.join(PRED_PATH, 'all_setting_decompositions.pickle')

# NOTE: unpickling can execute arbitrary code - only load files written by
# this notebook.
if os.path.isfile(path_decomposition):
    with open(path_decomposition, 'rb') as cache_file:
        all_setting_decompositions = pickle.load(cache_file)
else:
    all_setting_decompositions = {}

if 'all' not in all_setting_decompositions:
    # Reference decomposition from the models trained without missingness
    p_y_given_x1_x2 = all_setting_predictions[('x1_x2', 'all')]
    p_y_given_x1 = all_setting_predictions[('x1', 'all')]
    p_y_given_x2 = all_setting_predictions[('x2', 'all')]

    estimator = QEstimator(x1_df.loc[train_index].values, x2_df.loc[train_index].values,
                           x1_df.loc[val_index].values, x2_df.loc[val_index].values,
                           p_y_given_x1.loc[train_index].values, p_y_given_x2.loc[train_index].values,
                           p_y_given_x1.loc[val_index].values, p_y_given_x2.loc[val_index].values,
                           grid_search=grid_search, epochs=100)

    # Insert the entry instead of rebinding the dictionary, so per-rate
    # results already loaded from the cache are kept.
    all_setting_decompositions['all'] = pid_decomposition_batched(
        estimator, x1_df.loc[test_index].values, x2_df.loc[test_index].values,
        p_y_given_x1.loc[test_index].values, p_y_given_x2.loc[test_index].values,
        p_y_given_x1_x2.loc[test_index].values,
        label_df.classification_label.loc[test_index].values)

    with open(path_decomposition, 'wb') as cache_file:
        pickle.dump(all_setting_decompositions, cache_file)
    print('all', all_setting_decompositions['all'])

for p_m_1 in tqdm(p_m_1_array):
    # The 'observed' entry is stored last in each iteration, so its presence
    # means this rate has been fully processed.
    if ('observed', p_m_1) in all_setting_decompositions:
        continue

    # Rows whose sampled missingness indicator is 0 (observed)
    observed = missingness_label.loc[:, p_m_1] == 0

    # Split the observed rows; only the `.index` of each selection is used below
    observed_split = label_df.loc[observed, 'data_split']
    train = observed_split[observed_split == 'train']
    val = observed_split[observed_split == 'valid']
    test = observed_split[observed_split == 'test']

    # IPW weights: empirical rate of being observed (1 - p_m) divided by
    # each row's probability of being observed (1 - p_hat)
    p_m = missingness_label.loc[:, p_m_1].mean()
    p_hat = missingness_prob.loc[:, p_m_1]
    ipw_weights = (1 - p_m) / (1 - p_hat)

    # Estimate with IPW weights, using the IPW-trained predictors
    p_y_given_x1_x2 = all_setting_predictions[('x1_x2', 'corrected')][p_m_1]
    p_y_given_x1 = all_setting_predictions[('x1', 'corrected')][p_m_1]
    p_y_given_x2 = all_setting_predictions[('x2', 'corrected')][p_m_1]

    estimator = QEstimator(x1_df.loc[train.index].values, x2_df.loc[train.index].values,
                           x1_df.loc[val.index].values, x2_df.loc[val.index].values,
                           p_y_given_x1.loc[train.index].values, p_y_given_x2.loc[train.index].values,
                           p_y_given_x1.loc[val.index].values, p_y_given_x2.loc[val.index].values,
                           ipw_weights.loc[train.index].values, ipw_weights.loc[val.index].values,
                           grid_search=grid_search, epochs=100)

    all_setting_decompositions[('corrected', p_m_1)] = pid_decomposition_batched(
        estimator, x1_df.loc[test.index].values, x2_df.loc[test.index].values,
        p_y_given_x1.loc[test.index].values, p_y_given_x2.loc[test.index].values,
        p_y_given_x1_x2.loc[test.index].values,
        label_df.classification_label.loc[test.index].values,
        ipw_weights.loc[test.index].values)

    # Compute with no correction, using the unweighted predictors
    p_y_given_x1_x2 = all_setting_predictions[('x1_x2', 'observed')][p_m_1]
    p_y_given_x1 = all_setting_predictions[('x1', 'observed')][p_m_1]
    p_y_given_x2 = all_setting_predictions[('x2', 'observed')][p_m_1]

    # No explicit `device` here: the other two estimators rely on the default,
    # and a hard-coded 'cuda:1' fails on machines without a second GPU.
    estimator = QEstimator(x1_df.loc[train.index].values, x2_df.loc[train.index].values,
                           x1_df.loc[val.index].values, x2_df.loc[val.index].values,
                           p_y_given_x1.loc[train.index].values, p_y_given_x2.loc[train.index].values,
                           p_y_given_x1.loc[val.index].values, p_y_given_x2.loc[val.index].values,
                           grid_search=grid_search, epochs=100)

    all_setting_decompositions[('observed', p_m_1)] = pid_decomposition_batched(
        estimator, x1_df.loc[test.index].values, x2_df.loc[test.index].values,
        p_y_given_x1.loc[test.index].values, p_y_given_x2.loc[test.index].values,
        p_y_given_x1_x2.loc[test.index].values,
        label_df.classification_label.loc[test.index].values)

    # Checkpoint after every finished rate
    with open(path_decomposition, 'wb') as cache_file:
        pickle.dump(all_setting_decompositions, cache_file)

In [None]:
# Single figure comparing the decomposition terms obtained from observed data,
# from the IPW-corrected estimate and from the no-missingness reference,
# at every missingness rate.
pid_parent_path = os.path.join(PLOT_PATH, 'pid_plots10')
create_path_and_all_parents(pid_parent_path)

# The output file has always been named after `super_title`, a variable left
# over from the metric-plot cell. Keep that name when it exists, but do not
# fail with a NameError when this cell is run on its own.
try:
    pid_figure_name = super_title
except NameError:
    pid_figure_name = 'PID'

# Four series (one per decomposition term), so narrower offsets are used here
width = 0.015
x_color = ['#648fff', '#dc267f', '#fe6100', '#50C878']
x_offset = [-2*width, -width, width, 2*width]


fig, ax = plt.subplots(figsize=(10, 4))

# Empty scatters that only create the grey marker entries of the legend
ax.scatter([], [], color='grey', marker=unadjusted_marker, linewidths=2, label='Observed', s=100)
ax.scatter([], [], color='grey', marker=adjusted_marker, linewidths=2, label=r'ICYM$^2$I', s=100)
ax.scatter([], [], color='grey', marker=all_marker, linewidths=2, label='Oracle', s=100)
ax.scatter([], [], alpha=0, label=' ')

for i, p_m_1 in enumerate(p_m_1_array):
    p_m_1_float = float(p_m_1)

    for j, decomposition in enumerate(['unique1', 'unique2', 'complementary', 'shared']):
        observed_res = all_setting_decompositions[('observed', p_m_1)]
        corrected_res = all_setting_decompositions[('corrected', p_m_1)]
        oracle_res = all_setting_decompositions['all']

        # Negative estimates are clipped to zero before plotting
        unadjusted = max(observed_res[decomposition], 0.)
        adjusted = max(corrected_res[decomposition], 0.)
        all_res = max(oracle_res[decomposition], 0)

        unadjusted_err = observed_res[decomposition + '_std']
        adjusted_err = corrected_res[decomposition + '_std']
        all_err = oracle_res[decomposition + '_std']

        x_position = p_m_1_float + x_offset[j]

        # (value, error bar, marker, extra scatter kwargs)
        series = (
            (unadjusted, unadjusted_err, unadjusted_marker, {'linewidths': 2}),
            (adjusted, adjusted_err, adjusted_marker, {}),
            (all_res, all_err, all_marker, {}),
        )
        for value, value_err, marker, extra in series:
            ax.scatter(x_position, value, color=x_color[j], marker=marker,
                       alpha=alpha_value, s=200, **extra)
            ax.errorbar(x_position, value, yerr=value_err, color=x_color[j])

        # Coloured square legend entry per decomposition term (added once)
        if i == 0:
            ax.scatter([], [], color=x_color[j], marker='s', label=decomposition.capitalize(), s=200)

    # Dotted separator between consecutive missingness rates
    if i < len(p_m_1_array) - 1:
        ax.axvline(p_m_1_float + 0.05, color='grey', alpha=0.25, linestyle='dotted')

plt.legend(bbox_to_anchor=(1, 1))
plt.grid(axis='x', visible=False)

plt.ylabel('PID Values')
plt.xlabel('Missingness Rate')
ax.xaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0, decimals=1))

plt.savefig(os.path.join(pid_parent_path, f'{pid_figure_name}.png'), bbox_inches='tight', dpi=400)

plt.show()
plt.close()