In [None]:
import os
import json
import numpy as np
import pandas as pd
from scipy.stats import wasserstein_distance
import helpers as ph
import seaborn as sns
import dataframe_image as dfi

# Shared colour scheme: a blue-to-white blend, turned into a matplotlib
# colormap that shades every representativeness table below.
COLOR_STR = ",".join(["#0A3EA4", "#4874F9", "#84A0F5", "#F1F4FB", "#FFFFFF"])
palette = sns.color_palette(f"blend:{COLOR_STR}", n_colors=12, as_cmap=True)

# Table CSS applied to every Styler below. NOTE: this binds the helpers
# module's own object (an alias, not a copy), so in-place edits leak into `ph`.
styles = ph.VIS_STYLES

In [None]:
# Run configuration.
RESULTS_DIR = './data/distributions/'  # holds the per-wave *_combined.csv / *_baseline.csv files
CONTEXT = 'default'                     # context tag embedded in those CSV file names
SAVEFIG = False                         # when True, export the styled tables as PNGs under ./figures/

## Load human and LM opinion distributions

In [None]:
# Read every survey wave's CSVs and stack them into two frames. Each row is
# tagged with its wave ('ATP <wave>') in a `survey` column.
combined_parts, human_parts = [], []
for wave in ph.PEW_SURVEY_LIST:
    SURVEY_NAME = f'American_Trends_Panel_W{wave}'
    survey_label = f'ATP {wave}'

    # Model results for this wave (`*_combined.csv`).
    cdf = pd.read_csv(os.path.join(RESULTS_DIR, f'{SURVEY_NAME}_{CONTEXT}_combined.csv'))
    combined_parts.append(cdf.assign(survey=survey_label))

    # Human baseline for this wave (`*_baseline.csv`).
    hdf = pd.read_csv(os.path.join(RESULTS_DIR, f'{SURVEY_NAME}_{CONTEXT}_baseline.csv'))
    human_parts.append(hdf.assign(survey=survey_label))

# ignore_index: every CSV starts its own RangeIndex at 0, so a plain concat
# would leave duplicate row labels.
combined_df = pd.concat(combined_parts, ignore_index=True)
human_df = pd.concat(human_parts, ignore_index=True)

# Model provider, derived from the model name in one vectorised pass
# ('j1-' models -> AI21 Labs, everything else -> OpenAI).
is_ai21 = combined_df['model_name'].str.lower().str.contains('j1-', regex=False, na=False)
combined_df['Source'] = np.where(is_ai21, 'AI21 Labs', 'OpenAI')

In [None]:
# Number of distinct questions covered by all loaded survey waves.
n_questions = len(set(combined_df['question']))
print('# Questions:', n_questions)

## Compute average representativeness across the dataset

In [None]:
# Average the per-question `WD` score (presumably the Wasserstein distance between
# model and human opinion distributions) for every model x demographic group,
# then express it as representativeness Rep = 1 - WD.
KEYS = ['Source', 'model_name', 'attribute', 'group', 'group_order', 'model_order']

grouped = (
    combined_df
    .groupby(KEYS, as_index=False)
    .agg({'WD': 'mean'})  # string alias: passing np.mean warns on pandas >= 2.1
    .sort_values(by=['model_order', 'group_order'])
)
grouped['Rep'] = 1 - grouped['WD']

### Overall representativeness

In [None]:
# Human baseline:
#   1. mean WD per human group (`group_x`),
#   2. converted to Rep = 1 - WD,
#   3. summarised over groups as the average ('Avg') and minimum ('Worst') Rep.
human_baseline = human_df.groupby(['group_x'], as_index=False).agg({'WD': 'mean'})
human_baseline['Rep'] = 1 - human_baseline['WD']
human_baseline = human_baseline.agg({'Rep': ['mean', 'min']}).reset_index()
human_baseline['model_name'] = human_baseline['index'].map({'mean': 'Avg', 'min': 'Worst'})
human_baseline['model_order'] = -1
human_baseline['Source'] = "Humans"

# One-row table: the human rows first, then every model's 'Overall' representativeness.
# `model_name` is renamed to '' so the second column-header level shows only the names.
g = pd.concat([human_baseline, grouped[grouped['attribute'] == 'Overall']]).rename(
    columns={'model_name': '', 'Rep': 'R'}
)

table = pd.pivot_table(g, columns=['Source', ''], values='R', sort=False)
table_vis = (
    table.style
    .background_gradient(palette, axis=1)
    .set_table_styles(styles)
    .set_properties(**{"font-size": "0.75rem"})
    .format(precision=3)
)

if SAVEFIG:
    os.makedirs('./figures', exist_ok=True)
    # Styler.hide_index() was removed in pandas 2.0; hide(axis="index") (pandas >= 1.4) replaces it.
    table_vis.hide(axis="index").export_png('./figures/representativeness.png')
display(table_vis)

### Subgroup representativeness

In [None]:
# Subgroup tables use a different value (105%) for the last property of the last
# style rule. Build a modified copy and rebind `styles`, rather than assigning into
# it: `styles` may still alias `ph.VIS_STYLES` (and objects the overall table's
# Styler holds), so an in-place edit would leak into the helper module and change
# the overall table if its cell is re-run.
last_style = styles[-1]
last_props = list(last_style['props'])
last_props[-1] = (last_props[-1][0], "105%")
styles = [*styles[:-1], {**last_style, 'props': last_props}]

In [None]:
# One table per demographic attribute, skipping the first entry of
# ph.DEMOGRAPHIC_ATTRIBUTES (presumably 'Overall', shown above -- TODO confirm).
# Rows are the attribute's groups; columns are (Source, Model); cells are Rep.
for attribute in ph.DEMOGRAPHIC_ATTRIBUTES[1:]:
    print(f'-----{attribute}----')

    g = grouped[grouped['attribute'] == attribute].rename(
        columns={'model_name': 'Model', 'group': attribute, 'Source': ''}
    )

    table = pd.pivot_table(g, index=[attribute], columns=['', 'Model'], values="Rep", sort=False)

    # Gradient runs across each row for 'Overall' and down each column otherwise.
    # Spelled out as 0/1 instead of passing the comparison's bool as the axis.
    gradient_axis = 1 if attribute == 'Overall' else 0
    table_vis = (
        table.style
        .background_gradient(palette, axis=gradient_axis)
        .set_table_styles(styles)
        .set_properties(**{"font-size": "1.3rem"})
        .format(precision=3)
    )
    if SAVEFIG:
        os.makedirs('./figures', exist_ok=True)
        table_vis.export_png(f'./figures/representativeness_{attribute}.png')

    display(table_vis)