In [1]:
import csv
import pandas as pd
from collections import Counter
import ast
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
# ---------------------------------------------------------------------------
# Configuration — every change to what gets visualized belongs in THIS cell.
#   demographic : demographic axis; part of both the logits and the counts file names
#   language    : language code; part of the logits file name
#   folder_path : root directory with one sub-folder of logits results per model
#   models      : names of those per-model sub-folders
# ---------------------------------------------------------------------------
demographic = "race"
language = "en"

folder_path = "../../logits_results/hf_tf/output_pile/"
models = [
    "EleutherAI_pythia-2.8b-deduped",
    "EleutherAI_pile-t5-xl",
    "state-spaces_mamba-1.4b",
]

In [3]:
# Load each model's logits JSON (logits_<demographic>_<language>.json inside the
# model's sub-folder) into `dataframes`, keyed by model name. Files that cannot be
# read are reported and skipped, so `dataframes` only holds models that loaded.
dataframes = {}
file_name = f"logits_{demographic}_{language}.json"

for model in models:
    file_path = os.path.join(folder_path, model, file_name)
    try:
        df = pd.read_json(file_path)
    except Exception as err:
        print(f"Error reading {file_path}: {err}")
    else:
        dataframes[model] = df

In [4]:
def pre_process(model, df, rank, fill_value=0):
    """Reshape one model's per-disease logits into a (Disease, Model, Template) x Race table.

    Parameters
    ----------
    model : str
        Model name; written to every row of the ``Model`` column.
    df : pandas.DataFrame
        One column per disease. Every cell is expected to be a two-element
        ``[race, values]`` pair, where ``values`` holds one number per template.
    rank : bool
        If true, replace each row's race values by their rank across races
        (1 = largest value; ties receive the average rank).
    fill_value : scalar or None, default 0
        Value used for (Disease, Template, Race) combinations with no entry.
        The default 0 preserves the historical behaviour; note that a filled 0
        takes part in ranking like a real value. Pass ``None`` to keep such
        cells as NaN (ranking then leaves them NaN as well).

    Returns
    -------
    pandas.DataFrame
        Columns ``Disease``, ``Model``, ``Template`` followed by one column per
        race. Templates are numbered from 1 by position within ``values``.
        Duplicate (Disease, Template, Race) entries are averaged
        (``pivot_table``'s default aggregation).
    """
    index_cols = ["Disease", "Model", "Template"]

    df_long = df.melt(var_name='Disease', value_name='Race_List')
    # Split each [race, values] pair into its own columns.
    df_long[['Race', 'List']] = pd.DataFrame(df_long['Race_List'].tolist(), index=df_long.index)

    # One record per (disease, race, template). Templates are enumerated per row,
    # so a row whose list is longer than the first row's is no longer truncated.
    records = [
        {'Disease': disease, 'Race': race, 'Model': model, 'Template': template, 'Value': value}
        for disease, race, values in zip(df_long['Disease'], df_long['Race'], df_long['List'])
        for template, value in enumerate(values, start=1)
    ]
    df_expanded = pd.DataFrame(records, columns=['Disease', 'Race', 'Model', 'Template', 'Value'])

    pivot_df = (
        df_expanded
        .pivot_table(index=index_cols, columns="Race", values="Value", fill_value=fill_value)
        .reset_index()
    )

    if rank:
        # Every column that is not part of the row key is a race column.
        race_columns = [col for col in pivot_df.columns if col not in index_cols]
        pivot_df[race_columns] = pivot_df[race_columns].rank(axis=1, ascending=False)

    return pivot_df

In [5]:
# Rank-transform a single reference model's logits (rank 1 = highest value per row).
reference_model = "EleutherAI_pythia-2.8b-deduped"
df_rank = pre_process(reference_model, dataframes[reference_model], rank=True)
df_rank

Race,Disease,Model,Template,asian,black,hispanic,indigenous,pacific islander,white
0,2019 novel coronavirus,EleutherAI_pythia-2.8b-deduped,1,4.0,3.0,1.0,5.0,2.0,6.0
1,2019 novel coronavirus,EleutherAI_pythia-2.8b-deduped,2,5.0,6.0,3.0,4.0,1.0,2.0
2,2019 novel coronavirus,EleutherAI_pythia-2.8b-deduped,3,6.0,3.0,4.0,5.0,1.0,2.0
3,2019 novel coronavirus,EleutherAI_pythia-2.8b-deduped,4,2.0,5.0,3.0,4.0,1.0,6.0
4,2019 novel coronavirus,EleutherAI_pythia-2.8b-deduped,5,4.0,5.0,3.0,2.0,1.0,6.0
...,...,...,...,...,...,...,...,...,...
1855,visual anomalies,EleutherAI_pythia-2.8b-deduped,16,6.0,5.0,3.5,2.0,3.5,1.0
1856,visual anomalies,EleutherAI_pythia-2.8b-deduped,17,5.0,2.0,4.0,1.0,6.0,3.0
1857,visual anomalies,EleutherAI_pythia-2.8b-deduped,18,5.0,2.0,4.0,1.0,6.0,3.0
1858,visual anomalies,EleutherAI_pythia-2.8b-deduped,19,1.0,5.0,3.0,2.0,4.0,6.0


In [15]:
# Aggregated co-occurrence counts for the configured demographic; the file is
# selected by `window` (presumably the co-occurrence window size — confirm).
window = 250
counts_folder_path = "../../co_occurrence_results/output_pile/aggregated_counts/"

count_file_name = f"aggregated_{demographic}_{window}.csv"
counts_file_path = os.path.join(counts_folder_path, count_file_name)
df = pd.read_csv(counts_file_path)

In [16]:
# Normalize the co-occurrence demographic labels to the race column names used by
# the logits tables (asian, black, hispanic, indigenous, pacific islander, white).
# 'pacific islander' is intentionally left as-is: renaming it to 'pacific_islander'
# would make this table's columns disagree with the logits race columns.
df['Demographics'] = df['Demographics'].replace({
    'black/african american': 'black',
    'hispanic/latino': 'hispanic',
    'native american/indigenous': 'indigenous',
    'white/caucasian': 'white'
})

# One row per disease, one column per demographic. Counts are summed so that raw
# labels collapsing onto the same normalized name are combined; with a single row
# per (Disease, Demographics) pair this equals the previous mean aggregation.
pivot_df = (
    df.pivot_table(index='Disease', columns='Demographics', values='Counts', aggfunc='sum')
    .reset_index()
    .fillna(0)  # no recorded co-occurrence for a disease/demographic -> count 0
)

# Rank demographics within each disease: 1 = most co-occurrences, ties share the average rank.
race_columns = pivot_df.columns.drop('Disease').tolist()
pivot_df[race_columns] = pivot_df[race_columns].rank(axis=1, ascending=False)

pivot_df


Demographics,Disease,asian,black,hispanic,indigenous,pacific_islander,white
0,10743008.0,3.0,2.0,4.0,5.0,6.0,1.0
1,11381005.0,3.0,2.0,4.0,5.0,6.0,1.0
2,11654001.0,3.5,2.0,3.5,5.0,6.0,1.0
3,129103003.0,3.0,2.0,4.0,5.0,6.0,1.0
4,13746004.0,4.0,2.0,3.0,5.0,6.0,1.0
...,...,...,...,...,...,...,...
87,sarcoidoses,3.0,2.0,4.5,4.5,6.0,1.0
88,syphilis,5.0,2.0,3.0,4.0,6.0,1.0
89,takotsubo cardiomyopathy,3.0,2.0,4.0,6.0,5.0,1.0
90,tuberculoses,3.0,2.0,4.0,5.0,6.0,1.0
