In [None]:
import sys
sys.path.append('../implementation/')
from gotz_adaptive_contextualization import AC
import pandas as pd
import ast
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import warnings
# NOTE(review): with no category/module arguments this suppresses every warning
# for the rest of the kernel session, so deprecation notices raised by later
# cells will not be displayed. Consider narrowing the filter.
warnings.filterwarnings('ignore')

In [None]:
# --- Underlying (candidate) data ---------------------------------------------
underlying_data = pd.read_csv('../data/political/final/political.csv')

# Ids are strings made of a 'p' marker and a number; turn them into zero-based ints.
# int() accepts leading zeros on its own, so the digits are passed through as-is
# (manually stripping zeros would turn an all-zero id into '' and make int() raise).
clean_id = [int(s.replace("p", "")) - 1 for s in underlying_data['id']]
underlying_data['id'] = clean_id
underlying_data = underlying_data.set_index('id').sort_index()

# Destination of the pickled results written by the processing cell.
output_path = '../output/political/political_ac_mouse.pkl'

# --- Interaction logs ---------------------------------------------------------
interaction_data = pd.read_csv('../data/political/final/wall_political_interactions_mouse.csv')

# These two columns are stored in the CSV as Python-literal strings; parse each
# cell back into the object it represents.
for literal_col in ('interaction_session', 'interaction_type'):
    interaction_data[literal_col] = interaction_data[literal_col].map(ast.literal_eval)

# Attribute names handed to the AC model: c_attrs and d_attrs are passed as its
# second and third constructor arguments respectively.
c_attrs = ['age', 'political_experience', 'policy_strength_ban_abortion_after_6_weeks',
           'policy_strength_legalize_medical_marijuana', 'policy_strength_increase_medicare_funding',
           'policy_strength_ban_alcohol_sales_sundays']
d_attrs = ['party', 'gender', 'occupation']

In [None]:
# Replay each participant's interaction session through a fresh AC model and
# collect the per-attribute bias arrays it reports.
#
# Rows are accumulated in a plain list and turned into a DataFrame once at the
# end: DataFrame.append was removed in pandas 2.0 (and copied the whole frame on
# every call before that).
participant_records = []
for _, row in interaction_data.iterrows():
    print(f'Processing user {row.user}')
    results = {'participant_id': row.user}

    # One model per participant so no state carries over between users.
    ac_model = AC(underlying_data, c_attrs, d_attrs)

    # Use the row already in hand instead of re-fetching it by position on every
    # step; this also stays correct if interaction_data has a non-default index.
    for interaction in tqdm(row.interaction_session):
        ac_model.update(interaction)

    # get_attribute_bias() returns a frame-like object; every column is stored
    # as a NumPy array under a 'bias-<column>' key.
    bias = ac_model.get_attribute_bias()
    for col in bias.columns:
        results[f'bias-{col}'] = bias[col].to_numpy()

    participant_records.append(results)

# One row per participant; each 'bias-*' cell holds that participant's array.
ac_results = pd.DataFrame(participant_records)
ac_results.to_pickle(output_path)

In [None]:
# Font sizes are configured before the figure is created so they are already in
# effect when the axes are built. The original dict listed 'xtick.labelsize'
# twice; the second entry is taken to be a typo for 'ytick.labelsize'.
plt.rcParams.update({'axes.titlesize': 15, 'axes.labelsize': 15,
                     'xtick.labelsize': 12, 'ytick.labelsize': 12})

# Legend label of each attribute -> ac_results column expected to hold its bias
# values. The last six entries are looked up under a '_disc'-suffixed name.
bias_metric_per_task = {'party': 'bias-party', 'gender': 'bias-gender', 'occupation': 'bias-occupation',
                        'age': 'bias-age_disc', 'political_experience': 'bias-political_experience_disc',
                        'policy_strength_ban_abortion_after_6_weeks': 'bias-policy_strength_ban_abortion_after_6_weeks_disc',
                        'policy_strength_legalize_medical_marijuana': 'bias-policy_strength_legalize_medical_marijuana_disc',
                        'policy_strength_increase_medicare_funding': 'bias-policy_strength_increase_medicare_funding_disc',
                        'policy_strength_ban_alcohol_sales_sundays': 'bias-policy_strength_ban_alcohol_sales_sundays_disc'}

# Plot order of the attributes (same order as the mapping above).
columns = list(bias_metric_per_task)

# Size the grid from the data instead of assuming exactly 12 participants:
# two panels per row, 2.5 inches of height per row (12 participants -> 6x2, 20x15).
n_participants = len(ac_results)
grid_cols = 2
grid_rows = max(1, (n_participants + grid_cols - 1) // grid_cols)
fig, axs = plt.subplots(grid_rows, grid_cols, sharey=True, squeeze=False,
                        figsize=(20, max(2.5 * grid_rows, 5)))
fig.tight_layout(pad=4)
fig.suptitle('AC Bias Detection for Political Data', fontsize=20)
# Shared axis captions for the whole figure rather than one per panel.
fig.text(0.5, 0.03, 'Interactions Observed', ha='center')
fig.text(0.03, 0.5, 'Bias', va='center', rotation='vertical')

# Common look for every panel: fixed y-range, only left/bottom spines in black.
for ax in axs.flat:
    ax.set_ylim((0, 1.05))
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_color('black')
    ax.spines['bottom'].set_color('black')

# One panel per participant, filled row by row.
last_ax = None
for ax, (_, row) in zip(axs.flat, ac_results.iterrows()):
    # Wrapping each array in a Series lets pandas align on position and pad
    # with NaN if the per-attribute arrays differ in length.
    bias_over_time = pd.DataFrame(
        {attr: pd.Series(row[bias_metric_per_task[attr]]) for attr in columns}
    )
    bias_over_time.plot(ax=ax, title=f'Bias Detection for {row["participant_id"]}', legend=False)
    last_ax = ax

# Hide trailing panels that received no participant (odd count or empty input).
for ax in axs.flat[n_participants:]:
    ax.set_visible(False)

# A single legend for the whole figure, built from the last panel drawn
# (every panel plots the same attributes in the same order).
if last_ax is not None:
    handles, labels = last_ax.get_legend_handles_labels()
    n_cols = round(len(columns)/2)
    fig.legend(handles, labels, loc='lower center', ncol=n_cols, bbox_to_anchor=(0.5, -0.07))