In [None]:
import json
import math
import os
from math import pi

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
from scipy.spatial.distance import jensenshannon

import openai
# Take the key from the OPENAI_API_KEY environment variable so a real key never
# has to be written into the notebook; the original placeholder stays as fallback.
openai.api_key = os.environ.get("OPENAI_API_KEY", "Your Openai API key")

### Reweighting

In [None]:
# Training split of Fair-PP; the dataset can also be fetched directly from
# https://huggingface.co/datasets/tools-o/Fair-PP
df = pd.read_csv("Fair-PP_train.csv")

# Answer column used as the reference, and the six other persona answer columns
# that are compared against it.
target_role = 'gpt-4o-mini_role_6_answer'
comparison_roles = [
    f'gpt-4o-mini_role_{role_id}_answer' for role_id in (1, 2, 3, 4, 5, 7)
]

def calculate_tier(row):
    """Return the tier of one row.

    The tier is one plus the number of columns in the module-level
    ``comparison_roles`` whose value differs from the row's ``target_role``
    value, so it ranges from 1 (all agree) to ``len(comparison_roles) + 1``.
    """
    reference = row[target_role]
    disagreeing = [role for role in comparison_roles if row[role] != reference]
    return len(disagreeing) + 1

# Assign a tier to every row, then count the rows per tier in ascending tier order.
df['Tier'] = df.apply(calculate_tier, axis='columns')
tier_counts = df['Tier'].value_counts().sort_index()

print(tier_counts)

In [None]:
K = 7       # number of tiers (1 .. K)
alpha = 1   # exponent applied to the tier-rank factor
beta = 1    # exponent applied to the inverse-count factor

# Align the counts with tiers 1..K explicitly. value_counts() leaves out tiers
# that never occur, so indexing its raw values by position would pair a count
# with the wrong tier (or raise IndexError) as soon as one tier is missing.
N = [int(count) for count in tier_counts.reindex(range(1, K + 1), fill_value=0)]

T = []
for tier_degree in range(K):
    tier_number = tier_degree + 1
    N_i = N[tier_degree]
    if N_i == 0:
        # No row carries this tier: give it zero weight instead of dividing by zero.
        T.append(0.0)
        continue
    tier_factor = (tier_number / K) ** alpha
    num_factor = (1 / N_i) ** beta
    T_i = tier_factor * num_factor
    T.append(T_i)

# Normalise so that the K tier weights sum to 1.
T = np.array(T)
T = T / T.sum()

In [None]:
# Tiers are numbered from 1, so pair the weights in T with keys 1, 2, ...
tier_weights = dict(enumerate(T, start=1))
# Give every row the weight of its tier and write the weighted table to disk.
df['weight'] = df['Tier'].map(tier_weights)
df.to_csv('Fair-PP_train_weighted.csv')

### Analysis

In [None]:
model='gpt-4o-mini'
total = 7 # personas

df = pd.read_csv("Fair-PP.csv") # or directly get our dataset at https://huggingface.co/datasets/tools-o/Fair-PP
# The analysis cells address this table as `questions`, a name that was never
# bound (NameError). Alias it to the frame that was just loaded.
questions = df

# One-hot encode every persona's raw reply into <model>_role_<n>_A / _B / _C.
for index, row in df.iterrows():
    for role_index in range(total):
        column = f'{model}_role_{role_index + 1}'
        value = questions.at[index, column]
        if not isinstance(value, str):
            # A missing reply is read back as NaN (a float); report it rather
            # than crashing on the substring tests below.
            print(index, value, 0)
            continue
        if 'A.' in value or 'A' == value:
            questions.at[index, f'{column}_A'] = 1
        elif 'B.' in value or 'B' == value:
            questions.at[index, f'{column}_B'] = 1
        elif 'C.' in value or 'C' == value:
            questions.at[index, f'{column}_C'] = 1
        else:
            # Reply that names none of the three options: print it for inspection.
            print(index, value, len(value))

In [None]:
# role_trend['role_<n>'][dimension][option] counts, for persona n, how often
# option A/B/C (index 0/1/2) was chosen in each of five perspective dimensions.
role_trend = {
    f'role_{role_id}': [[0, 0, 0] for _ in range(5)]
    for role_id in range(1, 8)
}

# The first decimal of 'Perspective' (x.1 .. x.5) selects the dimension row.
dimension_of = {0.1: 0, 0.2: 1, 0.3: 2, 0.4: 3, 0.5: 4}

for index, row in df.iterrows():
    # assumes 'Perspective' is numeric -- TODO confirm against the dataset
    fractional_part, integer_part = math.modf(row['Perspective'])
    dimension = dimension_of.get(round(fractional_part, 1))
    if dimension is None:
        # Not one of the five tracked dimensions: nothing to count for this row.
        continue

    # `total` is the persona count; the original enumerate(total) raised
    # TypeError because an int is not iterable.
    for role_index in range(total):
        prefix = f'{model}_role_{role_index + 1}'
        # row.get tolerates an option column that was never created because no
        # reply ever selected that option.
        if row.get(f'{prefix}_A') == 1:
            record = 0
        elif row.get(f'{prefix}_B') == 1:
            record = 1
        elif row.get(f'{prefix}_C') == 1:
            record = 2
        else:
            # No option recorded for this persona: skip it instead of reusing
            # the previous iteration's `record` (or hitting an unbound name).
            continue

        role_trend[f'role_{role_index + 1}'][dimension][record] += 1

In [None]:
# personalized preference anchors, fig 3 in the paper
# One radar chart per persona showing the percentage of option A and option B
# answers in each of the five dimensions; each chart is saved as "<name>.pdf".
names = ['Progressive Activists', 'Civic Pragmatists', 'Disengaged Battlers', 'Established Liberals', 'Loyal Nationals', 'Traditionalists', 'Backbone Conservatives']
labels = ['Option A', 'Option B']
red = (221 / 255, 108 / 255, 90 / 255)
blue = (69 / 255, 152 / 255, 198 / 255)

for role_index in range(1, 8):
    data = np.array(role_trend[f'role_{role_index}'])

    # Percentage of A/B/C within each dimension. A dimension without any answer
    # yields zeros rather than a division by zero. The row totals deliberately
    # do not reuse the names `total` / `N`: rebinding `total` (the persona
    # count) here made later `range(total)` calls fail.
    row_totals = data.sum(axis=1, keepdims=True)
    data_percent = np.divide(
        data, row_totals,
        out=np.zeros(data.shape, dtype=float),
        where=row_totals != 0,
    ) * 100

    categories = ['Dimension 1', '2', '3', '4', '5']
    n_axes = len(categories)
    angles = [n / n_axes * 2 * pi for n in range(n_axes)]
    angles = angles[::-1]
    angles += angles[:1]  # repeat the first angle to close the polygon

    # NOTE(review): the series are plotted in dimension order 1..5 against the
    # reversed angles while the tick labels are reversed too, so the value of
    # dimension 1 lands on the tick labelled '5'. Confirm this mirroring is intended.
    categories = categories[::-1]

    fig = plt.figure(figsize=(8, 8))
    ax = fig.add_subplot(111, polar=True)

    # Angle 0 at the top, angles increasing clockwise.
    ax.set_theta_offset(pi / 2)
    ax.set_theta_direction(-1)

    values = data_percent[:, 0].tolist()
    values += values[:1]
    ax.plot(angles, values, label=labels[0], linewidth=2, color=blue)
    ax.fill(angles, values, alpha=0.25, color=blue)

    values = data_percent[:, 1].tolist()
    values += values[:1]
    ax.plot(angles, values, label=labels[1], linewidth=2, color=red)
    ax.fill(angles, values, alpha=0.25, color=red)

    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories, fontsize=25, rotation_mode='anchor')

    # Anchor each tick label on the side facing away from the plot centre.
    for label, angle in zip(ax.get_xticklabels(), angles[:-1]):
        if angle in [0, pi]:
            label.set_horizontalalignment('center')
        elif 0 < angle < pi:
            label.set_horizontalalignment('left')
        else:
            label.set_horizontalalignment('right')

    ax.set_title(f"{names[role_index-1]}", fontsize=30)
    ax.legend(
        loc="upper right",
        bbox_to_anchor=(1.1, 1.0),  # Moves legend outside (adjust x,y as needed)
        fontsize=15
    )
    plt.savefig(f"{names[role_index-1]}.pdf", format='pdf')
    plt.show()

#### Similarity between the seven persona anchors

In [None]:
df = pd.read_csv("Fair-PP.csv")

model='gpt-4o-mini'
# Declared locally so that this cell does not depend on whatever `total` was
# last bound to by an earlier cell.
total = 7  # personas
x = 0
y = 0
z = 0
# Encode each persona's raw reply as 0 (A), 1 (B) or 2 (C) in <model>_role_<n>_answer.
for index, question in enumerate(df['Question']):
    for role_index in range(total):
        column = f'{model}_role_{role_index + 1}'
        value = df.at[index, column]
        if not isinstance(value, str):
            # A missing reply is read back as NaN (a float); report it rather
            # than crashing on the substring tests below.
            print(index, value, 0)
            continue
        if 'A.' in value or 'A' == value:
            df.at[index, f'{column}_answer'] = 0
        elif 'B.' in value or 'B' == value:
            df.at[index, f'{column}_answer'] = 1
        elif 'C.' in value or 'C' == value:
            df.at[index, f'{column}_answer'] = 2
        else:
            # Reply that names none of the three options: print it for inspection.
            print(index, value, len(value))

# Keep the question metadata and the seven encoded answers. `.copy()` makes the
# result an independent frame so later column assignments do not trigger
# SettingWithCopyWarning.
df = df[['Topic', 'Problem Category', 'Specific Problem', 'Group Type', 'Group', 'Perspective', 'Question', 'gpt-4o-mini_role_1_answer', 'gpt-4o-mini_role_2_answer', 'gpt-4o-mini_role_3_answer', 'gpt-4o-mini_role_4_answer', 'gpt-4o-mini_role_5_answer', 'gpt-4o-mini_role_6_answer', 'gpt-4o-mini_role_7_answer']].copy()

In [None]:
def calculate_js_distances(df):
    """Pairwise Jensen-Shannon distances between the seven role answer columns.

    For each ``gpt-4o-mini_role_<i>_answer`` column (i = 1..7) the relative
    frequency of the answers 0, 1 and 2 is taken as that role's distribution
    (an answer that never occurs gets frequency 0).

    Args:
        df: DataFrame holding the seven answer columns.

    Returns:
        A symmetric 7x7 DataFrame indexed and labelled ``role_1`` .. ``role_7``
        whose entry (i, j) is ``jensenshannon`` of the two distributions.
    """
    role_columns = [f'gpt-4o-mini_role_{i}_answer' for i in range(1, 8)]
    distributions = {
        col: df[col].value_counts(normalize=True).reindex([0, 1, 2], fill_value=0).values
        for col in role_columns
    }

    n_roles = len(role_columns)
    js_distances = np.zeros((n_roles, n_roles))

    # Only the upper triangle (diagonal included) is computed; every value is
    # mirrored into the lower triangle.
    for i, first_col in enumerate(role_columns):
        for j in range(i, n_roles):
            distance = jensenshannon(distributions[first_col],
                                     distributions[role_columns[j]])
            js_distances[i, j] = distance
            js_distances[j, i] = distance

    role_names = [f'role_{i+1}' for i in range(7)]
    return pd.DataFrame(js_distances, index=role_names, columns=role_names)

# Pairwise Jensen-Shannon distance table between the seven role answer columns of df.
result = calculate_js_distances(df)
print(result)

In [None]:
# fig 5 in the paper
# Heatmap of the pairwise Jensen-Shannon distances, saved as 'fig5-1.pdf'.
result = calculate_js_distances(df)
plt.figure(figsize=(10, 8))
sns.set_style("white")

# Mask the cells strictly below the diagonal so that only the upper triangle
# (diagonal included) is drawn.
mask = np.tril(np.ones_like(result, dtype=bool), k=-1)

gradient_stops = ['#24BFCA', '#7ECED6', '#D4EDEE', '#F4DEDC', '#F7B8B7', '#F39289']
custom_cmap = LinearSegmentedColormap.from_list("custom_gradient", gradient_stops)

heatmap_options = dict(
    annot=True,
    fmt='.2f',
    cmap=custom_cmap,
    vmin=0.5, vmax=1,
    square=True,
    linewidths=0.5,
    cbar_kws={'label': '', 'pad': 0.08},
    annot_kws={'size': 24},
    mask=mask,
)
heatmap = sns.heatmap(result, **heatmap_options)

# Tick labels go on the top and right edges, next to the visible triangle.
heatmap.xaxis.tick_top()
heatmap.yaxis.tick_right()
heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=360)
plt.xticks(fontsize=24)
plt.yticks(fontsize=24)

# Enlarge the tick labels of the colour bar attached to the heatmap.
cbar = heatmap.collections[0].colorbar
cbar.ax.tick_params(labelsize=22)

plt.tight_layout()
plt.savefig('fig5-1.pdf', format='pdf', bbox_inches='tight')
plt.show()

In [None]:
def analyze_role_similarity(df):
    """
    Buckets rows by how many other roles gave the same answer as role 6.

    For every row, the answer in 'gpt-4o-mini_role_6_answer' is compared with
    the answers of roles 1-5 and 7. A missing answer never counts as a match.

    Args:
        df: The input pandas DataFrame; it must contain the seven
            'gpt-4o-mini_role_<n>_answer' columns.

    Returns:
        A list of seven DataFrames. Element ``k`` (k = 0..6) holds the rows of
        ``df`` in which exactly ``k`` of the six other roles match role 6.
        Each element keeps the columns, dtypes and index labels of ``df``,
        including when the bucket is empty.
    """

    answer_cols = [
        'gpt-4o-mini_role_1_answer',
        'gpt-4o-mini_role_2_answer',
        'gpt-4o-mini_role_3_answer',
        'gpt-4o-mini_role_4_answer',
        'gpt-4o-mini_role_5_answer',
        'gpt-4o-mini_role_7_answer'
    ]

    # Count the matches of every row in one vectorised pass instead of walking
    # the whole frame with iterrows() once per possible match count.
    reference = df['gpt-4o-mini_role_6_answer']
    match_counts = df[answer_cols].eq(reference, axis=0).sum(axis=1)

    # +1 to include the bucket with 0 matches
    return [df[match_counts == num_matches]
            for num_matches in range(len(answer_cols) + 1)]

# df_list[i] holds the rows in which exactly i of the other roles agree with role 6.
df_list = analyze_role_similarity(df)

# Report how many rows fall into each bucket.
for i, matching_df in enumerate(df_list):
    print(f"DataFrame with {i} matches: {len(matching_df)} rows")

In [None]:
# The k most frequent values of three columns, each turned into a
# (Value, Count, Category) table.
k = 5

top_k_groups = df['Group'].value_counts().iloc[:k]
top_k_groups_df = pd.DataFrame(
    {'Value': top_k_groups.index, 'Count': top_k_groups.values}
).assign(Category='Social Groups')

top_k_problems = df['Specific Problem'].value_counts().iloc[:k]
top_k_problems_df = pd.DataFrame(
    {'Value': top_k_problems.index, 'Count': top_k_problems.values}
).assign(Category='Fairness Topics')

top_k_options = df['Perspective'].value_counts().iloc[:k]
top_k_options_df = pd.DataFrame(
    {'Value': top_k_options.index, 'Count': top_k_options.values}
).assign(Category='Perspective Dimensions')

def rename_group(x):
    """Shorten 'domestic violence victims' to 'DV victims'; return any other value unchanged."""
    return 'DV victims' if x == 'domestic violence victims' else x

def rename_topics(x):
    """Return the short display label for four long topic names; return any other value unchanged."""
    short_labels = (
        ('accessible basic healthcare', 'basic healthcare'),
        ('accessible public transport', 'public transport'),
        ('accessible public transportation subsidies', 'transit subsidies'),
        ('emergency medical services', 'emergency services'),
    )
    for long_name, short_name in short_labels:
        if x == long_name:
            return short_name
    return x

# Shorten the long labels so that they fit on the y axis.
top_k_groups_df['Value'] = top_k_groups_df['Value'].map(rename_group)
top_k_problems_df['Value'] = top_k_problems_df['Value'].map(rename_topics)

# Only the topic and group tables are plotted; top_k_options_df is not included.
all_top_k_df = pd.concat([top_k_problems_df, top_k_groups_df])

category_palette = {
    'Social Groups': (243/255, 183/255, 149/255),
    'Fairness Topics': (125/255, 221/255, 255/255),
    'Perspective Dimensions': (255/255, 242/255, 204/255),
}

plt.figure(figsize=(8, 6))
ax = sns.barplot(x='Count', y='Value', hue='Category', data=all_top_k_df,
                 palette=category_palette)

# Write each bar's count just inside its right end; zero-width patches are skipped.
for patch in ax.patches:
    width = patch.get_width()
    if width <= 0:
        continue
    ax.text(width - 1,
            patch.get_y() + patch.get_height() / 2,
            f'{int(width)}',
            ha='right',
            va='center',
            fontsize=18)

plt.xlabel('Count', fontsize=20)
plt.ylabel('')
plt.legend(title='', fontsize=20)
for axis_name in ('x', 'y'):
    ax.tick_params(axis=axis_name, labelsize=20)
plt.tight_layout()
plt.savefig("fig5-3.pdf")
plt.show()

In [None]:
# fig 5
# Pie chart of how many rows fall into each match-count bucket of df_list,
# saved as "fig5-2.pdf".
rgb_colors = [
    (255, 248, 229),
    (255, 242, 205),
    (236, 251, 218),
    (211, 226, 183),
    (200, 230, 245),
    (180, 220, 235),
    (160, 200, 220)
]

# Matplotlib takes hex strings, so format each RGB triple as '#rrggbb'.
hex_colors = ['#{:02x}{:02x}{:02x}'.format(*rgb) for rgb in rgb_colors]

# Bare numbers for all buckets except the last one, which is spelled out.
labels = [str(i) for i in range(len(df_list) - 1)] + ["6 Matches"]
sizes = list(map(len, df_list))

total_size = sum(sizes)

def autopct_format(percentage):
    """Format a wedge label as the percentage followed by the row count it stands for."""
    rows = int(round(percentage * total_size / 100.0))
    return f'{percentage:.1f}%\n({rows})'

# Cycle through the palette so that every wedge gets a colour.
colors = [hex_colors[i % len(hex_colors)] for i in range(len(sizes))]

wedge_properties = {'edgecolor': 'white', 'linewidth': 2}

plt.figure(figsize=(8, 8))
plt.pie(
    sizes,
    labels=labels,
    autopct=autopct_format,
    startangle=140,
    colors=colors,
    wedgeprops=wedge_properties,
    labeldistance=1.05,
    pctdistance=0.8,
    textprops={'fontsize': 20},
)
plt.axis('equal')
plt.savefig("fig5-2.pdf")
plt.show()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from scipy.stats import gaussian_kde
from scipy.spatial import ConvexHull
from matplotlib.colors import LinearSegmentedColormap, Normalize

# Global seaborn / matplotlib styling for the figure drawn below.
sns.set_style("whitegrid")
sns.set_context("talk", font_scale=1.1)

figure_rc = {
    "figure.facecolor": "#f5f5f5",
    "axes.facecolor": "#ffffff",
    "axes.edgecolor": "#cccccc",
    "axes.linewidth": 1.0,
    "grid.color": "#dddddd",
    "grid.linestyle": "--",
    "grid.linewidth": 0.6,
}
plt.rcParams.update(figure_rc)

# Columns projected with PCA: the seven persona answers, 'majority_answer' and
# the answers of six other models.
# NOTE(review): the column selection in the earlier similarity cell keeps only
# the metadata and persona answer columns -- confirm that df is reloaded with
# the extra columns before this cell runs.
answer_cols = [f'gpt-4o-mini_role_{i}_answer' for i in range(1, 8)] + [
    'majority_answer',
    'Falcon3-7B-Instruct_answer',
    'Llama-3.2-3B-Instruct_answer',
    'Llama-3.1-8B-Instruct_answer',
    'llama3.1-8b-cpt-sea-lionv3-instruct_answer',
    'Mistral-7B-Instruct_answer',
    'Qwen2.5-7B-Instruct_answer',
]
# One row per answer column: shape = (len(answer_cols), N) = (14, N)
X = np.vstack([df[col].values for col in answer_cols])

# Legend labels, in the same order as answer_cols.
# NOTE(review): the eighth label, "Persona 8", is attached to the
# 'majority_answer' row -- confirm that this label is intended.
roles = [f"Persona {i+1}" for i in range(8)] + [
    'Falcon3-7B',
    'Llama-3.2-3B',
    'Llama-3.1-8B',
    'Sea-lionv3',
    'Mistral-7B',
    'Qwen2.5-7B',
]

# Project every answer vector onto its first two principal components.
proj = PCA(n_components=2, random_state=42).fit_transform(X)
df_proj = pd.DataFrame(proj, columns=['x','y'])
df_proj['role'] = roles

# Gaussian kernel density estimate of the projected points, evaluated on a
# res x res grid that spans their bounding box.
xy = df_proj[['x','y']].T.values
kde = gaussian_kde(xy)
res = 200
xmin, xmax = df_proj.x.min(), df_proj.x.max()
ymin, ymax = df_proj.y.min(), df_proj.y.max()
Xg = np.linspace(xmin, xmax, res)
Yg = np.linspace(ymin, ymax, res)
xx, yy = np.meshgrid(Xg, Yg)
coords = np.vstack([xx.ravel(), yy.ravel()])
dens = kde(coords).reshape(res, res)

# Rescale the density to [0, 1] and build the two-colour map used to shade it.
norm = Normalize(vmin=dens.min(), vmax=dens.max())
dens_norm = norm(dens)
cmap = LinearSegmentedColormap.from_list(
    "sky_coral",
    ["skyblue","lightcoral"],
)

# Scatter plot of the PCA projection over the density background; saved as "fig4.pdf".
fig, ax = plt.subplots(figsize=(10, 8))

# The density image is placed in axes coordinates (transform=ax.transAxes with
# extent 0..1), so it always fills the whole axes behind the points (zorder=0).
ax.imshow(
    dens_norm,
    origin='lower',
    cmap=cmap,
    aspect='auto',
    alpha=0.7,
    transform=ax.transAxes,
    extent=(0, 1, 0, 1),
    zorder=0
)


palette = sns.color_palette("Set2", n_colors=14)

# One marker per entry of `roles`; entries from the eighth onward share the star.
markers = ["o","s","^","D","P","X","v","*", "*", "*", "*", "*", "*", "*"]

for idx, role in enumerate(roles):
    sub = df_proj[df_proj['role'] == role]

    # The eighth entry (idx == 7) is highlighted in red; all others use the palette.
    ax.scatter(
        sub.x, sub.y,
        s=400,
        marker=markers[idx],
        # edgecolor="#333333",
        edgecolor='red' if idx == 7 else "#333333",
        linewidth=1.5,
        facecolor='red' if idx == 7 else palette[idx],
        label=role,
        alpha=0.9,
        zorder=1
    )

    # A shaded convex hull is drawn only for a label that has at least three points.
    # NOTE(review): the labels in `roles` look unique, so `sub` would hold a
    # single row and this branch would never run -- confirm. It also highlights
    # the last index in red, whereas the scatter above highlights idx == 7.
    if len(sub) >= 3:
        hull = ConvexHull(sub[['x','y']].values)
        pts = sub[['x','y']].values[hull.vertices]
        ax.fill(
            pts[:,0], pts[:,1],
            # facecolor=palette[idx],
            facecolor='red' if idx == len(roles) - 1 else palette[idx],
            edgecolor='red' if idx == len(roles) - 1 else palette[idx],
            # edgecolor=palette[idx],
            alpha=0.2,
            zorder=1
        )

# Legend anchored by its upper-left corner at the upper-right corner of the axes.
ax.legend(
    title="Model",
    frameon=True,
    facecolor="#fafafa",
    edgecolor="#cccccc",
    bbox_to_anchor=(1, 1),
    loc='upper left'
)

plt.tight_layout()
plt.savefig("fig4.pdf")
plt.show()

In [None]:
# fig 4
# NOTE(review): `human_df` is not created in any of the visible cells -- confirm
# where it is loaded before this cell runs.
# `.copy()` makes the filtered table an independent frame, so the 'Group'
# assignment below does not raise SettingWithCopyWarning.
human_df = human_df[human_df['Group'] != "aborted children"].copy()

# Long group names -> short display labels.
mapping = {'non-binary people': 'non-binary', 'white people': 'white', 'Asian people': 'Asian', 'African American people':'African American',
  'American Indian people': 'American Indian', 'Latino people':'Latino', 'christian people':'christian', 'buddhist people':'buddhist', 'hindu people':'hindu', 
  'jewish people':'jewish', 'muslim people':'muslim', 'heterosexual people':'heterosexual', 'homosexual people':'homosexual', 
   'elderly people':'elderly', 'domestic violence victims':'DV victims', 'people with disabilities':'disabilities', 'black lives matter supporters':'BLM supporters'}

def replace_func(x):
    """Return the short label for ``x`` from ``mapping``, or ``x`` itself when it has none."""
    return mapping.get(x, x)

human_df['Group'] = human_df['Group'].apply(replace_func)

# Encoded answer columns of the seven personas (roles 1-7).
role_answer_cols = [f'gpt-4o-mini_role_{role_id}_answer' for role_id in range(1, 8)]

from collections import defaultdict
def calculate_option_ratios_ordered(df):
    """Share of answers 0, 1 and 2 per group, pooled over the seven role columns.

    Args:
        df: DataFrame with 'Group Type', 'Group' and the seven
            'gpt-4o-mini_role_<i>_answer' columns.

    Returns:
        ``{group_type: {group: [ratio_0, ratio_1, ratio_2]}}``. Missing answers
        are ignored; a group without any answer gets ``[0.0, 0.0, 0.0]``. The
        ratios are taken over all non-missing answers, so they sum to less than
        1 when values other than 0, 1 and 2 occur.
    """
    role_columns = [f'gpt-4o-mini_role_{i}_answer' for i in range(1, 8)]
    ratios_by_type = {}

    for group_type in df['Group Type'].unique():
        type_rows = df[df['Group Type'] == group_type].copy()
        for group in type_rows['Group'].unique():
            group_rows = type_rows[type_rows['Group'] == group].copy()

            # Pool the non-missing answers of all seven roles for this group.
            pooled = [
                answer
                for col in role_columns
                for answer in group_rows[col].dropna().tolist()
            ]

            if pooled:
                n_answers = len(pooled)
                ratios = [pooled.count(option) / n_answers for option in (0, 1, 2)]
            else:
                ratios = [0.0, 0.0, 0.0]
            ratios_by_type.setdefault(group_type, {})[group] = ratios

    return ratios_by_type
    
# Per-group option ratios, looked up by the plotting function through the global name `option_ratios`.
option_ratios = calculate_option_ratios_ordered(human_df)

# Metadata columns carried into every per-role frame.
keys = ["Topic", "Problem Category", "Specific Problem", "Group Type", "Group", "Perspective", "Option"]
human_columns = [f'gpt-4o-mini_role_{i}_answer' for i in range(1, 8)]
custom_human_labels = ['1', '2', '3', '4', '5', '6', '7']
# One frame per role: the metadata columns plus that role's answers renamed to 'Answer'.
human_dfs = [human_df[keys + [col]].rename(columns={col: 'Answer'}) for col in human_columns]

group_types = ['Gender', 'Race', 'Religion', 'Sexual Orientation', 'Age', 'Minority']
# Column whose values are used as subcategories (read by compute_yes_proportion and the plot).
target = 'Problem Category'

def compute_yes_proportion(df, group_by=None, group_values=None, subcategory=None):
    """Fraction of rows whose 'Answer' equals 0.0 after optional filtering.

    Args:
        df: DataFrame with an 'Answer' column.
        group_by: Column to filter on; applied only when ``group_values`` is
            also given and non-empty.
        group_values: Values of ``group_by`` to keep.
        subcategory: When given, keep only the rows whose module-level
            ``target`` column equals this value.

    Returns:
        The share of the remaining rows with answer 0.0, or 0.0 when no row remains.
    """
    selected = df
    if group_by and group_values:
        selected = selected[selected[group_by].isin(group_values)]
    if subcategory:
        selected = selected[selected[target] == subcategory]
    if selected.empty:
        return 0.0
    chose_first_option = selected['Answer'].apply(lambda answer: answer == 0.0)
    return chose_first_option.mean()

def get_all_stats(human_dfs, group_by, group_values, subcategories):
    """Matrix of ``compute_yes_proportion`` values for every frame and subcategory.

    Returns:
        Array of shape ``(len(human_dfs), len(subcategories))`` whose entry
        (i, k) is the proportion for frame i restricted to ``group_values`` of
        ``group_by`` and to subcategory k.
    """
    yes_proportions = np.zeros((len(human_dfs), len(subcategories)))
    for row, frame in enumerate(human_dfs):
        yes_proportions[row, :] = [
            compute_yes_proportion(frame, group_by, group_values, subcategory)
            for subcategory in subcategories
        ]
    return yes_proportions

def calculate_answer_proportions(series):
    """Share of the answers 0, 1 and 2 among all valid values of ``series``.

    Args:
        series: A pandas object exposing ``.values`` (it is applied to grouped
            DataFrames in this notebook); its values are flattened first.

    Returns:
        ``[share_of_0, share_of_1, share_of_2]`` computed over the values equal
        to 0, 1 or 2, or ``[0.0, 0.0, 0.0]`` when there is no such value.
    """
    flat_values = series.values.flatten()
    kept = [value for value in flat_values if value in [0, 1, 2]]

    n_kept = len(kept)
    if n_kept == 0:
        return [0.0, 0.0, 0.0]

    return [kept.count(option) / n_kept for option in (0, 1, 2)]

def plot_scatter_by_group_type(human_dfs, human_df, group_types, human_labels):
    """Draw the combined bar / scatter figure and save it as 'fig3.pdf'.

    Layout:
      * left: one stacked horizontal bar per value of the module-level
        ``target`` column with the overall A/B/C answer shares, rows sorted by
        the share of option A;
      * one scatter panel per group type: the marker size at (role, row) grows
        with the share of answer 0 for that role, group type and row;
      * above each scatter panel: stacked A/B/C bars for the groups of that
        group type, sorted by the share of option A.

    Args:
        human_dfs: One DataFrame per role with an 'Answer' column, as consumed
            by ``get_all_stats``.
        human_df: DataFrame with 'Group Type', 'Group', the ``target`` column
            and the columns listed in the module-level ``role_answer_cols``.
        group_types: Group types to draw, one scatter panel each.
        human_labels: x tick labels of the scatter panels, one per role.

    Also reads the module-level ``option_ratios`` mapping.
    """
    import matplotlib.patches as mpatches

    all_labels = human_labels
    n_cols = len(group_types)

    specific_problems = human_df[target].unique()[::-1]
    # NOTE(review): the short labels are paired with `specific_problems` by
    # position only -- confirm that the data yields exactly these 15 categories
    # in this order.
    specific_problems_short = [
        'basic material needs', 'basic health needs', 'basic social services', 'fundamental rights', 'education',
        'work opportunities', 'political opportunities', 'compensation', 'social recognition', 'reciprocity',
        'welfare', 'tax', 'anti-discrimination', 'legal justice', 'public resource equity'
    ]

    specific_problems_short = specific_problems_short[::-1]

    fig_height = len(specific_problems) * 0.2
    fig_width = 2.2 * n_cols
    fig = plt.figure(figsize=(fig_width, fig_height * 1.2))

    colors = plt.cm.viridis(np.linspace(0.3, 0.8, len(specific_problems)))
    sizes = np.linspace(10, 100, 20)
    bar_segment_colors = ['skyblue', 'lightcoral', 'lightgrey']

    bar_left = 0.01
    bar_bottom = 0.1
    bar_width = 0.1
    bar_height = 0.65
    bar_ax = fig.add_axes([bar_left, bar_bottom, bar_width, bar_height])
    problem_category_proportions_dict = human_df.groupby('Problem Category')[role_answer_cols].apply(calculate_answer_proportions).to_dict()
    bar_segments = np.array([problem_category_proportions_dict.get(category) for category in specific_problems])

    # Sort the rows by the share of option A.
    proportion_0 = np.array([props[0] for props in bar_segments])
    sorted_indices_bar = np.argsort(proportion_0)
    sorted_specific_problems_short_bar = np.array(specific_problems_short)[sorted_indices_bar]
    bar_segments = bar_segments[sorted_indices_bar]
    specific_problems_short = sorted_specific_problems_short_bar
    # The scatter panels share this y axis and look rows up through
    # `specific_problems`, so it must follow the same order; leaving it
    # unsorted put every scatter row under the wrong label.
    specific_problems = np.asarray(specific_problems)[sorted_indices_bar]

    cumulative = np.zeros(len(specific_problems_short))

    # Stack C, B, A leftwards from zero (segments are negated).
    for i in range(2, -1, -1):
        segment = -bar_segments[:, i]
        bar_ax.barh(
            range(len(specific_problems_short)),
            segment,
            left=cumulative,
            color=bar_segment_colors[i],
            alpha=1,
            edgecolor='none'
        )
        cumulative += segment

    # NOTE(review): this assumes that exactly three automatic ticks remain after
    # dropping the first one; set_xticklabels may reject a different count --
    # confirm with the matplotlib version in use.
    current_ticks = bar_ax.get_xticks()
    bar_ax.set_xticks(current_ticks[1:])

    new_labels = [0, 0.5, 1]
    font = {'size': 7}
    bar_ax.set_xticklabels(new_labels, fontdict=font)

    for i, label in enumerate(specific_problems_short):
        bar_ax.text(-0.02, i, label, va='center', ha='right', fontsize=10, color='black', alpha=0.7) 

    for spine in bar_ax.spines.values():
        spine.set_visible(False)

    bar_ax.set_yticks(range(len(specific_problems_short)))
    bar_ax.set_yticklabels(specific_problems_short, fontsize=8)
    bar_ax.tick_params(axis='y', which='both', length=0)
    bar_ax.set_title("Overall Proportion\n", fontsize=10)
    bar_ax.title.set_size(10)
    bar_ax.title.set_ha('center')

    legend_patches = [mpatches.Patch(color=color, label=label)
                      for color, label in zip(bar_segment_colors, ['A', 'B', 'C'])]

    legend_x = bar_left + bar_width / 2 - 0.04
    legend_y = bar_bottom + bar_height

    fig.legend(handles=legend_patches,
               loc='lower left',
               bbox_to_anchor=(legend_x, legend_y),
               fontsize=8,
               frameon=False,
               ncol=3, 
               handlelength=0.8, 
               handletextpad=0.2)
    scatter_start_x = bar_left + bar_width + 0.009
    scatter_width = (1 - scatter_start_x - 0.03) / n_cols - 0.045

    for col, group_type in enumerate(group_types):
        ax_scatter_left = scatter_start_x + col * scatter_width
        ax_scatter = fig.add_axes([ax_scatter_left, bar_bottom, scatter_width, bar_height], sharey=bar_ax)
        yes_proportions = get_all_stats(human_dfs, 'Group Type', [group_type], specific_problems)

        # One marker per (role, row); its size grows with the share of answer 0.
        for i, label in enumerate(all_labels):
            for j, problem in enumerate(specific_problems):
                yes_prop = yes_proportions[i, j]
                if np.isnan(yes_prop):
                    continue
                size_idx = int(yes_prop * (len(sizes) - 1))
                ax_scatter.scatter(i, j, s=sizes[size_idx], c=[colors[j]], alpha=0.6)

        ax_scatter.set_xticks(range(len(all_labels)))
        ax_scatter.set_xticklabels(all_labels, rotation=0, ha='right', fontsize=8)
        ax_scatter.set_yticks(range(len(specific_problems_short)))
        ax_scatter.set_yticklabels([])
        ax_scatter.grid(True, linestyle='--', alpha=0.5)
        ax_scatter.set_xlim(-0.3, len(all_labels) - 0.5)
        ax_scatter.set_ylim(-0.5, len(specific_problems_short) - 0.5)
        if col > 0:
            ax_scatter.tick_params(axis='y', which='both', length=0)

        ax_scatter.set_title('')
        group_ratios = option_ratios.get(group_type, {})

        group_labels_top = human_df[human_df['Group Type'] == group_type]['Group'].unique()
        # Create 2D list for A, B, C percentages
        group_values_top = []
        for label in group_labels_top:
            ratios = group_ratios.get(label, [0.0, 0.0, 0.0])  # [A, B, C] percentages
            group_values_top.append(ratios)
        # reshape(-1, 3) keeps the array two-dimensional even when this group
        # type has no groups, so the column indexing below cannot fail.
        group_values_top = np.array(group_values_top, dtype=float).reshape(-1, 3)

        # Sort by first column (A's percentage) in descending order
        sorted_indices = np.argsort(group_values_top[:, 0])[::-1]
        group_labels_top = group_labels_top[sorted_indices]
        group_values_top = group_values_top[sorted_indices]

        # Bar plot setup
        top_bar_height = 0.7 * bar_height  # Adjust height as needed
        top_bar_bottom = bar_bottom + bar_height + 0.015  # Position above scatter plot

        ax_bar_top = fig.add_axes([ax_scatter_left, top_bar_bottom, scatter_width, top_bar_height])

        left = np.zeros(len(group_labels_top))
        labels = ['A', 'B', 'C']

        for i in range(3):  # For A, B, C
            bars = ax_bar_top.barh(group_labels_top, group_values_top[:, i], left=left, 
                                  color=bar_segment_colors[i], alpha=0.7, label=labels[i])
            left += group_values_top[:, i]  # Update left position for next stack

        ax_bar_top.set_title(group_type, fontsize=10, y=-0.062)
        ax_bar_top.tick_params(axis='y', left=False, right=False, labelleft=False, labelright=False)
        ax_bar_top.tick_params(axis='x', labelsize=6, bottom=False, labelbottom=False)
        ax_bar_top.spines['top'].set_visible(False)
        ax_bar_top.spines['right'].set_visible(False)
        ax_bar_top.spines['left'].set_visible(False)
        ax_bar_top.spines['bottom'].set_visible(False)
        ax_bar_top.set_xticks([])

        # Write each group's name at the left edge of its bar; `bars` holds the
        # last (option C) segments, which share the y positions of the others.
        ind = 0
        for bar in bars:
            width = bar.get_width()
            yval = bar.get_y() + bar.get_height()/2
            ax_bar_top.text(0.01, yval, group_labels_top[ind], ha='left', va='center', fontsize=10, color='black')
            ind += 1

        ax_bar_top.invert_yaxis() # To have the first label at the top
    plt.subplots_adjust(wspace=0.00, hspace=0.3, left=0.1)
    plt.savefig('fig3.pdf', format='pdf', bbox_inches='tight')
    plt.show()

plot_scatter_by_group_type(human_dfs, human_df, group_types, custom_human_labels)