In [None]:
import os
import json
import numpy as np
import pandas as pd
from scipy.stats import wasserstein_distance
import helpers as ph
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm
import dataframe_image as dfi

# Select seaborn's 'darkgrid' style as the default for axes created from here on.
sns.set_style('darkgrid')

In [None]:
# Notebook-wide settings.
RESULTS_DIR = './data/distributions/'  # directory the per-context result CSVs are read from
CONTEXT = 'default'                    # prompting-context label ('default' is the unsteered one)
SAVEFIG = True                         # when True, figures/tables are also written to ./figures

In [None]:
# Read one combined-results CSV per prompting context and stack them into a single frame.
context_frames = []
for context in ('default', 'steer-qa', 'steer-bio', 'steer-portray'):
    csv_path = os.path.join(
        RESULTS_DIR,
        f'Pew_American_Trends_Panel_disagreement_500_{context}_combined.csv',
    )
    context_df = pd.read_csv(csv_path)
    context_df['survey'] = 'Pew_American_Trends_Panel_disagreement_500'
    context_df['context'] = context
    if context != 'default':
        # Steered runs: keep only rows whose 'group' equals 'steer_group'
        # (presumably the group the model was steered towards — confirm with the data schema).
        context_df = context_df[context_df['group'] == context_df['steer_group']]
    context_frames.append(context_df)
combined_df = pd.concat(context_frames)

# Tag each row with the model provider: names containing 'j1-' (case-insensitive)
# are labelled 'AI21 Labs', everything else 'OpenAI'.
combined_df['Source'] = [
    'AI21 Labs' if 'j1-' in name.lower() else 'OpenAI'
    for name in combined_df['model_name']
]

## Measure steerability

In [None]:
# Columns that identify one aggregation cell. The *_order columns are part of the
# key so they survive the groupby and can be used for sorting afterwards.
KEYS = ['model_name', 'context', 'attribute', 'group', 'group_order', 'model_order']

# Mean 'WD' per cell. The aggregation is named by string: passing the NumPy callable
# to groupby.agg is deprecated in recent pandas and emits a FutureWarning, while the
# string keeps the same (pandas-native mean) behaviour.
grouped = (
    combined_df
    .groupby(KEYS, as_index=False)
    .agg({'WD': 'mean'})
    .sort_values(by=['context', 'model_order', 'group_order'])
)
# 'Rep' is defined as one minus the mean 'WD'.
grouped['Rep'] = 1 - grouped['WD']

In [None]:
# Pair every (model, group) row from the 'default' context with its best steered row.
group_keys = [key for key in KEYS if key != 'context']
is_default = grouped['context'] == 'default'

# Default-context scores, suffixed with _d.
unsteered = grouped[is_default].rename(columns={'WD': 'WD_d', 'Rep': 'Rep_d'})

# Sorting by 'Rep' ascending and then taking the last row of each group keeps, for every
# key combination, the steering context with the highest 'Rep'; scores are suffixed with _s.
steered = (
    grouped[~is_default]
    .sort_values(by='Rep')
    .groupby(group_keys, as_index=False)
    .last()
    .rename(columns={'WD': 'WD_s', 'Rep': 'Rep_s'})
)

# Inner join on the shared keys: one row per (model, group) with both score sets.
result = pd.merge(unsteered, steered, on=group_keys)

In [None]:
# Steered vs. default representativeness: one scatter + regression line per model.
#
# The style is selected BEFORE the axes are created: seaborn's style settings are picked
# up when an axes is constructed, so choosing it afterwards would leave this figure with
# whatever style happened to be active earlier (and make the output depend on how many
# times the cell has been run).
sns.set_style('whitegrid')
fig, ax = plt.subplots(figsize=(6, 4))

# Only plot models that actually have rows in `result`; the set is built once.
models_with_results = set(result['model_name'].values)
for model in ph.MODEL_NAMES.values():
    if model not in models_with_results:
        continue
    model_rows = result[result['model_name'] == model]

    sns.regplot(data=model_rows, x='Rep_d', y='Rep_s', ax=ax,
                label=model,
                line_kws={'linewidth': 3},
                scatter_kws={'s': 14})

# loc=4 places the legend in the lower-right corner.
ax.legend(loc=4, fontsize=9, ncol=2)

# Dashed y = x reference line; points above it have a higher steered than default score.
diagonal = np.linspace(0.66, 0.9, 10)
ax.plot(diagonal, diagonal, 'k--')
ax.set_xlim([0.68, 0.88])
ax.set_ylim([0.68, 0.88])

ax.set_xlabel('Default subgroup representativeness', fontsize=12)
ax.set_ylabel('Steered subgroup representativeness', fontsize=12)
ax.grid(linestyle='-', linewidth=0.5)
if SAVEFIG:
    # Make sure the output directory exists so saving works on a fresh checkout.
    os.makedirs('./figures', exist_ok=True)
    fig.savefig('./figures/steerability.png', bbox_inches="tight")
plt.show()

## Steerability by topic

In [None]:
# Load the per-question topic metadata; later cells index it as topic_info[question]['cg'].
# `.item()` unwraps the single Python object stored in the loaded array.
# NOTE(review): allow_pickle=True unpickles the file's contents, which can execute
# arbitrary code — only load this file from a trusted source.
topic_info = np.load('./data/human_resp/topic_mapping.npy', allow_pickle=True).item()

In [None]:
# Look up the 'cg' entry of topic_info for every row's question and store it as 'topic',
# then explode so that each element of that entry gets its own row.
# NOTE(review): this cell rebinds combined_df to the exploded frame, so running it a
# second time re-attaches the full 'cg' entry to already-exploded rows and explodes
# again (multiplying rows for questions with more than one topic). Re-run from the
# loading cell instead of re-running this cell in isolation.
question_topics = combined_df['question'].map(lambda question: topic_info[question]['cg'])
combined_df['topic'] = question_topics
combined_df = combined_df.explode(['topic'])

In [None]:
# Aggregation keys for the per-topic analysis: the earlier key set plus 'Source' and 'topic'.
# NOTE: this rebinds the notebook-level KEYS used by the cells above.
KEYS = ['Source', 'model_name', 'context', 'attribute', 'group', 'group_order', 'model_order', 'topic']

# Mean 'WD' per key combination. The aggregation is named by string: passing the NumPy
# callable to groupby.agg is deprecated in recent pandas and emits a FutureWarning, while
# the string keeps the same (pandas-native mean) behaviour.
grouped_topic = (
    combined_df
    .groupby(KEYS, as_index=False)
    .agg({'WD': 'mean'})
    .sort_values(by=['context', 'model_order', 'group_order'])
)
# 'Rep' is defined as one minus the mean 'WD'.
grouped_topic['Rep'] = 1 - grouped_topic['WD']

In [None]:
# For every key combination except 'context', keep the steered row with the highest 'Rep':
# sort ascending by 'Rep', then take the last row of each group.
topic_group_keys = [key for key in KEYS if key != 'context']
steered_topic = (
    grouped_topic[grouped_topic['context'] != 'default']
    .sort_values(by='Rep')
    .groupby(topic_group_keys, as_index=False)
    .last()
    # 'model_name' is renamed to the empty string so the pivot tables built further
    # down can use it as an unlabeled column level.
    .rename(columns={'WD': 'WD_s', 'Rep': 'Rep_s', 'model_name': ''})
)

In [None]:
# Show the best-steered per-topic table (cell output) for a quick visual check.
steered_topic

In [None]:
# Build the table styles from a copy of ph.VIS_STYLES so the tweak below stays local to
# this cell and the helper module's object is left untouched.
# (assumes ph.VIS_STYLES is a list of dicts with a list under 'props' — TODO confirm)
styles = [dict(style) for style in ph.VIS_STYLES]
last_props = list(styles[-1]['props'])
# Keep the name of the final property of the last style entry, but set its value to "105%".
last_props[-1] = (last_props[-1][0], "105%")
styles[-1]['props'] = last_props

if SAVEFIG:
    # Make sure the output directory exists so exporting works on a fresh checkout.
    os.makedirs('./figures', exist_ok=True)

# One table per demographic attribute (np.unique yields the attributes in sorted order).
for attribute in np.unique(steered_topic['attribute']):
    print(attribute)
    attribute_rows = steered_topic[steered_topic['attribute'] == attribute]
    # Rows: groups; columns: (Source, model). pivot_table's default aggregation (mean)
    # averages 'Rep_s' over all rows that share a group and model.
    table = pd.pivot_table(attribute_rows,
                           columns=['Source', ''],
                           index='group',
                           values='Rep_s',
                           sort=False)
    # Red gradient computed per row (axis=1), shared table styles, 3-digit precision.
    table_vis = (
        table.style
        .background_gradient('Reds', axis=1)
        .set_table_styles(styles)
        .set_properties(**{"font-size": "0.8rem"})
        .format(precision=3)
    )
    if SAVEFIG:
        dfi.export(table_vis, f'./figures/steerability_{attribute}.png')

    display(table_vis)