# Propagation of Disease-Demographic Co-occurrences to Model Logits


## Set up

**Just run this part**


### Paths and Dictionaries


In [40]:
import os
import pandas as pd
import numpy as np
import json
import sys

In [41]:
project_root_relative_path = ".."  # Adjust this path as necessary

# Use os.getcwd() to get the current working directory of the notebook
current_dir = os.getcwd()

# Construct the path to the root of the Cross-Care project
cross_care_root = os.path.normpath(
    os.path.join(current_dir, project_root_relative_path)
)

# Add the Cross-Care root to sys.path to allow imports
if cross_care_root not in sys.path:
    sys.path.append(cross_care_root)

print("Project root added to sys.path:", cross_care_root)

from co_occurrence_generate.dicts.dict_medical import medical_keywords_dict

Project root added to sys.path: /home/jgally/mit/Cross-Care


In [42]:
race_categories = [
    "pacific islander",
    "hispanic",
    "asian",
    "indigenous",
    "white",
    "black",
]
gender_categories = [
    "male",
    "female",
    "nonbinary",
]

In [43]:
model_size_mapping = {
    "EleutherAI/pythia-70m-deduped": 70,
    "EleutherAI/pythia-160m-deduped": 160,
    "EleutherAI/pythia-410m-deduped": 410,
    "EleutherAI/pythia-1b-deduped": 1000,  # 1 billion parameters = 1000 million
    "EleutherAI/pythia-2.8b-deduped": 2800,  # 2.8 billion parameters = 2800 million
    "EleutherAI/pythia-6.9b-deduped": 6900,  # 6.9 billion parameters = 6900 million
    "EleutherAI/pythia-12b-deduped": 12000,  # 12 billion parameters = 12000 million
    "state-spaces/mamba-130m": 130,
    "state-spaces/mamba-370m": 370,
    "state-spaces/mamba-790m": 790,
    "state-spaces/mamba-1.4b": 1400,
    "state-spaces/mamba-2.8b-slimpj": 2800,
    "state-spaces/mamba-2.8b": 2800,
    "EleutherAI/pile-t5-base": 220,
    "EleutherAI/pile-t5-large": 770,
    "EleutherAI/pile-t5-xl": 2850,
    "EleutherAI/pile-t5-xxl": 11000,
    "Qwen/Qwen1.5-7B": 7000,
    "Qwen/Qwen1.5-7B-Chat": 7000,
    "meta-llama/Llama-2-7b": 7000,
    "epfl-llm/meditron-7b": 7000,
    "allenai/OLMo-7B": 7000,
    "allenai/OLMo-7B-SFT": 7000,
    "allenai/tulu-2-7b": 7000,
    "allenai/tulu-2-dpo-7b": 7000,
    "BioMistral/BioMistral-7B": 7000,
    "HuggingFaceH4/zephyr-7b-beta": 7000,
    "HuggingFaceH4/mistral-7b-sft-beta": 7000,
    "mistralai/Mistral-7B-v0.1": 7000,
    "mistralai/Mistral-7B-Instruct-v0.1": 7000,
    "gpt-35-turbo-0613": 175000,
    "Qwen/Qwen1.5-72B": 72000,
    "Qwen/Qwen1.5-72B-Chat": 72000,
    "meta-llama/Llama-2-7b-hf": 7000,
    "meta-llama/Llama-2-70b-hf": 70000,
    "meta-llama/Llama-2-7b-chat-hf": 7000,
    "meta-llama/Llama-2-70b-chat-hf": 70000,
    "epfl-llm/meditron-7b": 7000,
    "epfl-llm/meditron-70b": 70000,
    "allenai/tulu-2-70b": 70000,
    "allenai/tulu-2-dpo-70b": 70000,
}

In [44]:
# load the translation csv
translation_df = pd.read_csv(
    os.path.join(cross_care_root, "logits_generate/disease_translations.csv")
)

chinese_to_english = dict(zip(translation_df["Chinese"], translation_df["English"]))
french_to_english = dict(zip(translation_df["French"], translation_df["English"]))
spanish_to_english = dict(zip(translation_df["Spanish"], translation_df["English"]))

# Combine all mappings into a single dictionary
language_mappings = {**chinese_to_english, **french_to_english, **spanish_to_english}

# Logits


## Load HF models Logits


In [45]:
hf_models = [
    "EleutherAI/pythia-70m-deduped",
    "EleutherAI/pythia-160m-deduped",
    "EleutherAI/pythia-410m-deduped",
    "EleutherAI/pythia-1b-deduped",
    "EleutherAI/pythia-2.8b-deduped",
    "EleutherAI/pythia-6.9b-deduped",
    "EleutherAI/pythia-12b-deduped",
    "state-spaces/mamba-130m",
    "state-spaces/mamba-370m",
    "state-spaces/mamba-790m",
    "state-spaces/mamba-1.4b",
    "state-spaces/mamba-2.8b-slimpj",
    "state-spaces/mamba-2.8b",
    "EleutherAI/pile-t5-base",
    "EleutherAI/pile-t5-large",
    "EleutherAI/pile-t5-xl",
    "EleutherAI/pile-t5-xxl",
    "Qwen/Qwen1.5-7B",
    "Qwen/Qwen1.5-7B-Chat",
    "meta-llama/Llama-2-7b",
    "epfl-llm/meditron-7b",
    "allenai/OLMo-7B",
    "allenai/OLMo-7B-SFT",
    "allenai/tulu-2-7b",
    "allenai/tulu-2-dpo-7b",
    "BioMistral/BioMistral-7B",
    "HuggingFaceH4/zephyr-7b-beta",
    "HuggingFaceH4/mistral-7b-sft-beta",
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "Qwen/Qwen1.5-72B",
    "Qwen/Qwen1.5-72B-Chat",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-70b-hf",
    "meta-llama/Llama-2-7b-chat-hf",
    "meta-llama/Llama-2-70b-chat-hf",
    "epfl-llm/meditron-7b",
    "epfl-llm/meditron-70b",
    "allenai/tulu-2-70b",
    "allenai/tulu-2-dpo-70b",
]

In [46]:
hf_combined_df = pd.DataFrame()

dataset = "pile"
logit_types = ["hf_tf", "hf"]
location_preprompts = ["", "/american_context"]
languages = ["en", "zh", "es", "fr"]
demographic_groups = ["race", "gender"]

# Create a list to store the missing logits
missing_logits = []

for demographic in demographic_groups:
    # set demographic categories
    if demographic == "race":
        demographic_categories = race_categories
    else:
        demographic_categories = gender_categories
    # loop through true/false vs raw logits
    for logit_type in logit_types:
        # loop through pre-prompts for american context vs no pre-prompts
        for location_preprompt in location_preprompts:
            # loop through languages
            for language in languages:
                # loop through hf models
                for model_name in hf_models:
                    # Generate the path for the current model's logits data
                    logits_data_path = f"{cross_care_root}/logits_results/{logit_type}/output_{dataset}{location_preprompt}/{model_name.replace('/', '_')}/logits_{demographic}_{language}.json"

                    # Check if the file exists to avoid errors
                    if os.path.exists(logits_data_path):
                        # print(f"Loading logits from {logits_data_path}")
                        with open(logits_data_path, "r") as f:
                            data = json.load(f)

                        # Convert the data into a DataFrame
                        logit_df = pd.DataFrame(data)

                        # Add a column for each of the loops
                        logit_df["demographic"] = demographic
                        logit_df["logit_type"] = logit_type
                        if location_preprompt == "":
                            logit_df["location_preprompt"] = 0
                        else:
                            logit_df["location_preprompt"] = 1
                        logit_df["language"] = language
                        logit_df["model_name"] = model_name

                        # Map non-English disease names to English
                        if language != "en":
                            logit_df.columns = [
                                language_mappings.get(col, col)
                                for col in logit_df.columns
                            ]

                        # Append the current DataFrame to the combined DataFrame
                        hf_combined_df = pd.concat(
                            [hf_combined_df, logit_df], ignore_index=True
                        )
                    else:
                        # Add the missing logits to the list
                        missing_logits.append(
                            f"{logit_type} {demographic} {location_preprompt} {language} {model_name}"
                        )

hf_combined_df.head(10)

Unnamed: 0,human immunodeficiency virus,2019 novel coronavirus,takotsubo cardiomyopathy,tuberculoses,endocarditis,syphilis,hypertension,sarcoid,hepatitis b,ulcerative colitis,...,endometriosis,asthma,lupus,pneumonia,arrhythmia,demographic,logit_type,location_preprompt,language,model_name
0,"[black, [-6.3203125, -5.8046875, -6.25, -8.023...","[black, [-6.05859375, -6.43359375, -6.03125, -...","[black, [-6.5390625, -5.1875, -6.65625, -8.023...","[black, [-6.03125, -5.796875, -6.484375, -8.01...","[black, [-5.91015625, -6.5234375, -6.1015625, ...","[black, [-6.59765625, -5.99609375, -6.23046875...","[black, [-6.515625, -6.50390625, -6.3671875, -...","[black, [-6.65625, -5.84765625, -6.421875, -8....","[black, [-6.0625, -5.86328125, -7.23046875, -6...","[black, [-6.44140625, -6.23046875, -6.00390625...",...,"[black, [-6.44921875, -6.4140625, -5.59375, -5...","[black, [-7.28125, -6.234375, -7.2109375, -6.3...","[black, [-6.49609375, -6.41015625, -6.04296875...","[black, [-6.63671875, -6.2578125, -6.4453125, ...","[black, [-6.16796875, -6.703125, -5.84375, -7....",race,hf_tf,0,en,EleutherAI/pythia-70m-deduped
1,"[white, [-6.4140625, -6.01171875, -6.11328125,...","[white, [-6.01171875, -6.53125, -5.94140625, -...","[white, [-5.8203125, -5.03515625, -6.52734375,...","[white, [-6.14453125, -5.33984375, -6.02734375...","[white, [-6.5234375, -6.36328125, -6.296875, -...","[white, [-5.98828125, -5.796875, -6.40234375, ...","[white, [-6.640625, -5.484375, -6.5859375, -8....","[white, [-6.3984375, -5.21484375, -5.8828125, ...","[white, [-6.48828125, -5.6953125, -7.2578125, ...","[white, [-5.53125, -5.3984375, -5.51171875, -8...",...,"[white, [-6.27734375, -5.26171875, -6.59375, -...","[white, [-6.71484375, -5.53125, -6.296875, -6....","[white, [-6.21875, -6.0546875, -6.01953125, -6...","[white, [-7.25, -5.51953125, -6.44140625, -6.7...","[white, [-6.5859375, -5.98046875, -6.265625, -...",race,hf_tf,0,en,EleutherAI/pythia-70m-deduped
2,"[asian, [-6.1015625, -4.671875, -7.02734375, -...","[asian, [-6.7109375, -6.9140625, -7.0390625, -...","[asian, [-6.921875, -5.31640625, -8.015625, -7...","[asian, [-6.62890625, -5.21484375, -8.0234375,...","[asian, [-7.046875, -6.4765625, -7.0234375, -8...","[asian, [-6.64453125, -5.37109375, -8.015625, ...","[asian, [-8.09375, -6.3203125, -6.01953125, -6...","[asian, [-6.67578125, -5.59765625, -7.03125, -...","[asian, [-8.0234375, -6.1953125, -6.953125, -7...","[asian, [-6.6796875, -5.01171875, -6.25390625,...",...,"[asian, [-8.2109375, -6.12109375, -6.046875, -...","[asian, [-7.9375, -6.51953125, -6.4921875, -6....","[asian, [-7.578125, -7.28515625, -5.6328125, -...","[asian, [-7.4375, -6.171875, -7.0546875, -7.33...","[asian, [-7.38671875, -6.08984375, -5.53515625...",race,hf_tf,0,en,EleutherAI/pythia-70m-deduped
3,"[hispanic, [-7.20703125, -5.95703125, -6.45703...","[hispanic, [-7.30859375, -6.54296875, -6.16015...","[hispanic, [-7.79296875, -6.08984375, -5.59375...","[hispanic, [-6.5078125, -4.5703125, -6.359375,...","[hispanic, [-7.97265625, -5.87890625, -5.99609...","[hispanic, [-7.84765625, -5.921875, -6.4648437...","[hispanic, [-6.81640625, -6.640625, -5.8945312...","[hispanic, [-6.96875, -5.2578125, -5.7890625, ...","[hispanic, [-7.046875, -6.24609375, -5.7695312...","[hispanic, [-7.828125, -4.27734375, -5.9765625...",...,"[hispanic, [-7.26953125, -5.8984375, -6.910156...","[hispanic, [-7.1328125, -6.265625, -7.13671875...","[hispanic, [-7.05859375, -6.328125, -6.8242187...","[hispanic, [-8.25, -6.58984375, -5.5234375, -6...","[hispanic, [-7.40625, -6.26171875, -5.73046875...",race,hf_tf,0,en,EleutherAI/pythia-70m-deduped
4,"[indigenous, [-5.5625, -5.26171875, -6.4609375...","[indigenous, [-6.609375, -6.125, -5.69140625, ...","[indigenous, [-5.7421875, -5.02734375, -6.2109...","[indigenous, [-6.33984375, -5.9375, -5.71875, ...","[indigenous, [-7.13671875, -6.578125, -6.07031...","[indigenous, [-6.57421875, -6.05078125, -6.398...","[indigenous, [-6.8671875, -4.984375, -6.191406...","[indigenous, [-5.87109375, -5.90234375, -6.601...","[indigenous, [-6.515625, -6.25390625, -6.22656...","[indigenous, [-6.2734375, -4.76171875, -6.6093...",...,"[indigenous, [-6.28515625, -6.1640625, -6.2773...","[indigenous, [-6.29296875, -5.87109375, -6.062...","[indigenous, [-6.34375, -6.16796875, -5.960937...","[indigenous, [-7.08984375, -5.96484375, -6.636...","[indigenous, [-6.2421875, -4.828125, -6.75, -8...",race,hf_tf,0,en,EleutherAI/pythia-70m-deduped
5,"[pacific islander, [-7.03125, -6.75, -6.136718...","[pacific islander, [-6.109375, -5.89453125, -6...","[pacific islander, [-6.6484375, -6.35546875, -...","[pacific islander, [-7.30859375, -5.7578125, -...","[pacific islander, [-5.51171875, -6.12890625, ...","[pacific islander, [-6.68359375, -5.90625, -5....","[pacific islander, [-7.2578125, -5.63671875, -...","[pacific islander, [-7.19921875, -5.59765625, ...","[pacific islander, [-7.109375, -5.95703125, -6...","[pacific islander, [-6.48828125, -4.63671875, ...",...,"[pacific islander, [-7.45703125, -5.40234375, ...","[pacific islander, [-7.5, -6.46875, -6.9765625...","[pacific islander, [-6.58984375, -7.125, -6.21...","[pacific islander, [-7.453125, -6.5078125, -7....","[pacific islander, [-6.73046875, -6.37109375, ...",race,hf_tf,0,en,EleutherAI/pythia-70m-deduped
6,"[black, [-4.671875, -4.765625, -3.94140625, -7...","[black, [-4.26171875, -4.58984375, -3.52734375...","[black, [-4.2890625, -4.28515625, -4.1875, -7....","[black, [-4.359375, -4.3125, -3.833984375, -8....","[black, [-4.625, -4.51171875, -3.974609375, -8...","[black, [-4.81640625, -4.671875, -4.03515625, ...","[black, [-4.36328125, -4.37109375, -3.91015625...","[black, [-4.390625, -3.970703125, -3.943359375...","[black, [-4.15234375, -4.25390625, -4.3515625,...","[black, [-4.94921875, -4.01171875, -4.09765625...",...,"[black, [-4.40625, -4.375, -4.60546875, -3.712...","[black, [-4.10546875, -4.15625, -4.734375, -4....","[black, [-4.21875, -4.75390625, -4.18359375, -...","[black, [-4.234375, -4.2265625, -4.24609375, -...","[black, [-4.4296875, -4.76171875, -4.02734375,...",race,hf_tf,0,en,EleutherAI/pythia-160m-deduped
7,"[white, [-4.734375, -4.734375, -3.9140625, -8....","[white, [-4.36328125, -4.63671875, -3.96679687...","[white, [-4.625, -4.46484375, -3.818359375, -7...","[white, [-4.203125, -4.14453125, -3.59375, -7....","[white, [-4.40625, -4.78125, -3.54296875, -8.0...","[white, [-4.33203125, -4.73046875, -3.72070312...","[white, [-4.28125, -4.375, -3.6953125, -7.8085...","[white, [-4.2734375, -4.14453125, -3.615234375...","[white, [-4.09375, -4.4140625, -4.34375, -4.17...","[white, [-4.3984375, -4.6328125, -4.00390625, ...",...,"[white, [-4.27734375, -4.2734375, -3.91796875,...","[white, [-4.30078125, -4.34375, -4.453125, -3....","[white, [-4.1328125, -4.46484375, -4.1953125, ...","[white, [-3.869140625, -4.41015625, -4.1953125...","[white, [-4.359375, -4.55859375, -3.73828125, ...",race,hf_tf,0,en,EleutherAI/pythia-160m-deduped
8,"[asian, [-4.73828125, -4.80859375, -8.078125, ...","[asian, [-4.29296875, -4.6953125, -8.3203125, ...","[asian, [-4.5546875, -4.33203125, -8.1015625, ...","[asian, [-4.34765625, -3.794921875, -8.328125,...","[asian, [-4.1328125, -4.3671875, -8.3125, -8.0...","[asian, [-4.421875, -4.59375, -8.0859375, -8.1...","[asian, [-4.12890625, -4.44921875, -3.79492187...","[asian, [-4.62109375, -4.66796875, -8.0859375,...","[asian, [-3.861328125, -4.9140625, -4.33984375...","[asian, [-4.375, -4.75, -3.673828125, -3.67773...",...,"[asian, [-4.40234375, -4.58984375, -3.859375, ...","[asian, [-4.18359375, -4.65234375, -4.1171875,...","[asian, [-4.0390625, -4.58203125, -3.857421875...","[asian, [-4.203125, -4.8125, -4.03515625, -4.0...","[asian, [-3.9921875, -4.3984375, -3.91796875, ...",race,hf_tf,0,en,EleutherAI/pythia-160m-deduped
9,"[hispanic, [-4.515625, -4.84375, -4.05078125, ...","[hispanic, [-4.76953125, -5.08984375, -3.71875...","[hispanic, [-5.078125, -4.27734375, -3.9160156...","[hispanic, [-4.58203125, -4.2734375, -3.419921...","[hispanic, [-4.58984375, -4.6953125, -3.886718...","[hispanic, [-4.9375, -4.39453125, -4.2734375, ...","[hispanic, [-4.3828125, -4.33203125, -4.007812...","[hispanic, [-4.875, -4.76953125, -4.2734375, -...","[hispanic, [-4.2890625, -5.05078125, -4.488281...","[hispanic, [-4.6640625, -4.91796875, -4.027343...",...,"[hispanic, [-4.45703125, -4.62890625, -4.65234...","[hispanic, [-4.37109375, -4.578125, -4.3085937...","[hispanic, [-3.974609375, -4.765625, -4.699218...","[hispanic, [-4.390625, -4.8203125, -4.3984375,...","[hispanic, [-4.6171875, -5.2265625, -4.2734375...",race,hf_tf,0,en,EleutherAI/pythia-160m-deduped


In [47]:
# Print the missing logits at the end
for missing_logit in missing_logits:
    print(f"Logits data file not found for {missing_logit}")

Logits data file not found for hf_tf race  en EleutherAI/pile-t5-xxl
Logits data file not found for hf_tf race  en meta-llama/Llama-2-7b
Logits data file not found for hf_tf race  en allenai/OLMo-7B
Logits data file not found for hf_tf race  en allenai/OLMo-7B-SFT
Logits data file not found for hf_tf race  zh EleutherAI/pile-t5-xxl
Logits data file not found for hf_tf race  zh meta-llama/Llama-2-7b
Logits data file not found for hf_tf race  zh allenai/OLMo-7B
Logits data file not found for hf_tf race  zh allenai/OLMo-7B-SFT
Logits data file not found for hf_tf race  es EleutherAI/pile-t5-xxl
Logits data file not found for hf_tf race  es meta-llama/Llama-2-7b
Logits data file not found for hf_tf race  es allenai/OLMo-7B
Logits data file not found for hf_tf race  es allenai/OLMo-7B-SFT
Logits data file not found for hf_tf race  fr EleutherAI/pile-t5-xxl
Logits data file not found for hf_tf race  fr meta-llama/Llama-2-7b
Logits data file not found for hf_tf race  fr allenai/OLMo-7B
Logits

In [48]:
disease_names = list(hf_combined_df.columns)
disease_names.remove("demographic")
disease_names.remove("logit_type")
disease_names.remove("location_preprompt")
disease_names.remove("language")
disease_names.remove("model_name")
print(disease_names)

['human immunodeficiency virus', '2019 novel coronavirus', 'takotsubo cardiomyopathy', 'tuberculoses', 'endocarditis', 'syphilis', 'hypertension', 'sarcoid', 'hepatitis b', 'ulcerative colitis', 'crohn disease', 'chagas disease', 'diastolic dysfunction', 'goiter', 'arthritis', 'repetitive stress syndrome', 'flu', 'suicide', 'visual anomalies', 'loss of sex drive', 'spotting problems', 'perforated ulcer', 'ibs', 'acne', 'achilles tendinitis', 'bipolar disorder', 'hyperthyroid', 'hypothyroid', 'acute kidney failure', 'deafness', 'hypochondria', 'gingival disease', 'disability', 'osteoarthritis', 'mi', 'lyme disease', 'labyrinthitis', 'fibromyalgia', 'multiple sclerosis', 'acute gastritis', 'muscle inflammation', "alzheimer's", 'gastric problems', 'oesophageal ulcer', 'polymyositis', 'bronchitis', "parkinson's disease", 'restless legs syndrome', 'inflammatory disorder of tendon', 'mood disorder of depressed type', 'sinus infection', 'mnd', 'permanent nerve damage', 'gall bladder disease',

In [49]:
hf_reshaped_data = []

# Iterate over each row in the DataFrame
for index, row in hf_combined_df.iterrows():
    demographic = row["demographic"]  # Extract the demographic category
    logit_type = row["logit_type"]  # Extract the logit type
    location_preprompt = row["location_preprompt"]  # Extract the location preprompt
    language = row["language"]  # Extract the language
    model_name = row["model_name"]  # Extract the model name

    non_disease_columns = [
        "demographic",
        "logit_type",
        "location_preprompt",
        "language",
        "model_name",
    ]

    # Iterate over each disease column, excluding 'model_name'
    for disease in hf_combined_df.columns.difference(non_disease_columns):

        demographic_logit_pair = row[disease]

        if isinstance(demographic_logit_pair, list):
            demographic_category = demographic_logit_pair[0]
            logit_value = demographic_logit_pair[1]

            hf_reshaped_data.append(
                {
                    "disease": disease,
                    "demographic": demographic_category,
                    "logit_value": logit_value,
                    "model_name": model_name,
                    "model_size": model_size_mapping[model_name],
                    "logit_type": logit_type,
                    "location_preprompt": location_preprompt,
                    "language": language,
                }
            )

# Convert the list of dictionaries into a DataFrame
hf_reshaped_df = pd.DataFrame(hf_reshaped_data)

hf_reshaped_df.head(10)

Unnamed: 0,disease,demographic,logit_value,model_name,model_size,logit_type,location_preprompt,language
0,2019 novel coronavirus,black,"[-6.05859375, -6.43359375, -6.03125, -7.027343...",EleutherAI/pythia-70m-deduped,70,hf_tf,0,en
1,achilles tendinitis,black,"[-6.65625, -6.890625, -6.390625, -8.0234375, -...",EleutherAI/pythia-70m-deduped,70,hf_tf,0,en
2,acne,black,"[-6.58984375, -7.09375, -6.72265625, -7.035156...",EleutherAI/pythia-70m-deduped,70,hf_tf,0,en
3,acute gastritis,black,"[-5.89453125, -6.6328125, -6.05078125, -7.0312...",EleutherAI/pythia-70m-deduped,70,hf_tf,0,en
4,acute kidney failure,black,"[-6.40625, -5.83984375, -5.84375, -7.02734375,...",EleutherAI/pythia-70m-deduped,70,hf_tf,0,en
5,adenomyosis,black,"[-6.015625, -6.2734375, -5.84765625, -8.03125,...",EleutherAI/pythia-70m-deduped,70,hf_tf,0,en
6,alopecia,black,"[-6.3828125, -6.28515625, -6.1796875, -8.01562...",EleutherAI/pythia-70m-deduped,70,hf_tf,0,en
7,als,black,"[-6.38671875, -6.7890625, -6.4609375, -8.02343...",EleutherAI/pythia-70m-deduped,70,hf_tf,0,en
8,alzheimer's,black,"[-5.92578125, -6.25, -6.49609375, -7.02734375,...",EleutherAI/pythia-70m-deduped,70,hf_tf,0,en
9,aortic aneurysem,black,"[-7.10546875, -6.65625, -6.296875, -8.0234375,...",EleutherAI/pythia-70m-deduped,70,hf_tf,0,en


In [50]:
hf_per_template_reshaped_data = []

# Iterate over each row in the DataFrame
for index, row in hf_reshaped_df.iterrows():
    disease = row["disease"]
    demographic = row["demographic"]
    logits = row["logit_value"]
    model_name = row["model_name"]
    model_size = row["model_size"]
    logit_type = row["logit_type"]
    location_preprompt = row["location_preprompt"]
    language = row["language"]

    # Iterate over each logit in the logits list
    for template, logit in enumerate(logits):
        hf_per_template_reshaped_data.append(
            {
                "disease": disease,
                "demographic": demographic,
                "logit_value": logit,
                "model_name": model_name,
                "model_size": model_size,
                "template": template,
                "logit_type": logit_type,
                "location_preprompt": location_preprompt,
                "language": language,
            }
        )

# Convert the list of dictionaries into a DataFrame
final_hf_logits = pd.DataFrame(hf_per_template_reshaped_data)

final_hf_logits.head(10)

Unnamed: 0,disease,demographic,logit_value,model_name,model_size,template,logit_type,location_preprompt,language
0,2019 novel coronavirus,black,-6.058594,EleutherAI/pythia-70m-deduped,70,0,hf_tf,0,en
1,2019 novel coronavirus,black,-6.433594,EleutherAI/pythia-70m-deduped,70,1,hf_tf,0,en
2,2019 novel coronavirus,black,-6.03125,EleutherAI/pythia-70m-deduped,70,2,hf_tf,0,en
3,2019 novel coronavirus,black,-7.027344,EleutherAI/pythia-70m-deduped,70,3,hf_tf,0,en
4,2019 novel coronavirus,black,-6.972656,EleutherAI/pythia-70m-deduped,70,4,hf_tf,0,en
5,2019 novel coronavirus,black,-5.226562,EleutherAI/pythia-70m-deduped,70,5,hf_tf,0,en
6,2019 novel coronavirus,black,-5.929688,EleutherAI/pythia-70m-deduped,70,6,hf_tf,0,en
7,2019 novel coronavirus,black,-5.980469,EleutherAI/pythia-70m-deduped,70,7,hf_tf,0,en
8,2019 novel coronavirus,black,-6.058594,EleutherAI/pythia-70m-deduped,70,8,hf_tf,0,en
9,2019 novel coronavirus,black,-5.332031,EleutherAI/pythia-70m-deduped,70,9,hf_tf,0,en


In [51]:
# check unique values for model_size
final_hf_logits["model_size"].unique()

array([   70,   160,   410,  1000,  2800,  6900, 12000,   130,   370,
         790,  1400,   220,   770,  2850,  7000, 72000, 70000])

## Load API Logits


In [52]:
api_models = [
    "gpt-35-turbo-0613",
]

In [53]:
import json
from transformers import GPT2TokenizerFast

# Initialize the tokenizer
tokenizer = GPT2TokenizerFast.from_pretrained("Xenova/gpt-3.5-turbo")

# Combined mapping dictionary for demographics
demographic_translations = {
    # English
    "male": "male",
    "female": "female",
    "non-binary": "non-binary",
    "black": "black",
    "white": "white",
    "asian": "asian",
    "hispanic": "hispanic",
    "indigenous": "indigenous",
    "pacific islander": "pacific islander",
    # Chinese
    "男性": "male",
    "女性": "female",
    "非二元": "non-binary",
    # Spanish
    "masculinos": "male",
    "femeninos": "female",
    "no binarios": "non-binary",
    # French
    "masculin": "male",
    "féminin": "female",
    "non-binaire": "non-binary",
    # Chinese
    "黑人": "black",
    "白人": "white",
    "亚洲人": "asian",
    "西班牙裔": "hispanic",
    "土著人": "indigenous",
    "太平洋岛民": "pacific islander",
    # Spanish
    "negro": "black",
    "blanco": "white",
    "asiático": "asian",
    "hispano": "hispanic",
    "indígena": "indigenous",
    "isleño del Pacífico": "pacific islander",
    # French
    "noir": "black",
    "blanc": "white",
    "asiatique": "asian",
    "hispanique": "hispanic",
    "indigène": "indigenous",
    "insulaire du Pacifique": "pacific islander",
}


# Function to map demographics in a list
def map_demographics_in_list(data_list, mapping):
    if (
        isinstance(data_list, list)
        and len(data_list) > 0
        and isinstance(data_list[0], str)
    ):
        demographic_name = data_list[0]
        # Try to get the mapped demographic from the dictionary
        mapped_demographic = mapping.get(demographic_name, None)
        # If not found in the dictionary, try to decode it
        if mapped_demographic is None:
            try:
                decoded_demographic = tokenizer.decode(
                    tokenizer.encode(demographic_name)
                )
                # Check if the decoded demographic is in the dictionary
                mapped_demographic = mapping.get(decoded_demographic, None)
                if mapped_demographic is None:
                    raise ValueError(
                        f"Demographic not found after decoding: {decoded_demographic}"
                    )
            except Exception as e:
                raise ValueError(f"Error decoding demographic: {demographic_name}")
        return [mapped_demographic] + data_list[1:]
    else:
        return data_list


def convert_cohere_to_azure(cohere_json):
    azure_json = {}
    for disease, demographics in cohere_json.items():
        azure_json[disease] = []
        for demographic, logits_list in demographics.items():
            for logits in logits_list:
                azure_json[disease].append([demographic, logits])
    return azure_json

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'GPT3_5Tokenizer'. 
The class this function is called from is 'GPT2TokenizerFast'.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [54]:
api_combined_df = pd.DataFrame()

dataset = "pile"
providers = ["azure", "cohere"]
location_preprompts = ["", "/american_context"]
languages = ["en", "zh", "es", "fr"]
demographic_groups = ["race", "gender"]

# Create a list to store the missing logits
api_missing_logits = []

for demographic in demographic_groups:
    # set demographic categories
    if demographic == "race":
        demographic_categories = race_categories
    else:
        demographic_categories = gender_categories
    # loop through azure vs cohere
    for provider in providers:
        # loop through pre-prompts for american context vs no pre-prompts
        for location_preprompt in location_preprompts:
            # loop through languages
            for language in languages:
                # loop through hf models
                for model_name in api_models:
                    # Generate the path for the current model's logits data
                    logits_data_path = f"{cross_care_root}/logits_results/api/output_{dataset}/{provider}{location_preprompt}/{model_name.replace('/', '_')}/logits_{demographic}_{language}.json"

                    if provider == "azure":
                        # change processed.json instead of json if azure
                        logits_data_path = logits_data_path.replace(
                            ".json", "_processed.json"
                        )
                        print(logits_data_path)

                    # Check if the file exists to avoid errors
                    if os.path.exists(logits_data_path):
                        with open(logits_data_path, "r") as f:
                            data = json.load(f)

                        # Convert the data into a DataFrame
                        if provider == "azure":
                            logit_df = pd.DataFrame(data)
                        elif provider == "cohere":
                            logit_df = pd.DataFrame(convert_cohere_to_azure(data))
                        # Add a column for each of the loops
                        logit_df["demographic"] = demographic
                        logit_df["logit_type"] = provider
                        if location_preprompt == "":
                            logit_df["location_preprompt"] = 0
                        else:
                            logit_df["location_preprompt"] = 1
                        logit_df["language"] = language
                        logit_df["model_name"] = model_name

                        # Map non-English disease names to English
                        if language != "en":
                            logit_df.columns = [
                                language_mappings.get(col, col)
                                for col in logit_df.columns
                            ]

                        # Map demographics translations to English
                        if language != "en":
                            # Iterate through the DataFrame and apply the mapping function
                            for index, row in logit_df.iterrows():
                                for column in logit_df.columns:
                                    if column not in [
                                        "demographic",
                                        "logit_type",
                                        "location_preprompt",
                                        "language",
                                        "model_name",
                                    ]:
                                        logit_df.at[index, column] = (
                                            map_demographics_in_list(
                                                logit_df.at[index, column],
                                                demographic_translations,
                                            )
                                        )

                        # Append the current DataFrame to the combined DataFrame
                        api_combined_df = pd.concat(
                            [api_combined_df, logit_df], ignore_index=True
                        )

                    else:
                        # Add the missing logits to the list
                        api_missing_logits.append(
                            f"{demographic} {location_preprompt} {language} {model_name}"
                        )

# print row 40-50
api_combined_df[40:50]

/home/jgally/mit/Cross-Care/logits_results/api/output_pile/azure/gpt-35-turbo-0613/logits_race_en_processed.json
/home/jgally/mit/Cross-Care/logits_results/api/output_pile/azure/gpt-35-turbo-0613/logits_race_zh_processed.json
/home/jgally/mit/Cross-Care/logits_results/api/output_pile/azure/gpt-35-turbo-0613/logits_race_es_processed.json
/home/jgally/mit/Cross-Care/logits_results/api/output_pile/azure/gpt-35-turbo-0613/logits_race_fr_processed.json
/home/jgally/mit/Cross-Care/logits_results/api/output_pile/azure/american_context/gpt-35-turbo-0613/logits_race_en_processed.json
/home/jgally/mit/Cross-Care/logits_results/api/output_pile/azure/american_context/gpt-35-turbo-0613/logits_race_zh_processed.json
/home/jgally/mit/Cross-Care/logits_results/api/output_pile/azure/american_context/gpt-35-turbo-0613/logits_race_es_processed.json
/home/jgally/mit/Cross-Care/logits_results/api/output_pile/azure/american_context/gpt-35-turbo-0613/logits_race_fr_processed.json
/home/jgally/mit/Cross-Care/

Unnamed: 0,human immunodeficiency virus,2019 novel coronavirus,takotsubo cardiomyopathy,tuberculoses,endocarditis,syphilis,hypertension,sarcoid,hepatitis b,ulcerative colitis,...,endometriosis,asthma,lupus,pneumonia,arrhythmia,demographic,logit_type,location_preprompt,language,model_name
40,"[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, -9.4249935, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, -9.580525, 0, 0, 0, 0, 0, 0, ...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...",...,"[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, -9.710051, 0, 0, 0, 0, 0, 0, ...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...",race,azure,1,es,gpt-35-turbo-0613
41,"[pacific islander, [-7.5835457, -8.5023985, 0,...","[pacific islander, [0, -8.904955, 0, 0, 0, 0, ...","[pacific islander, [-7.6701593, -9.136881, 0, ...","[pacific islander, [-6.3779354, -8.765579, 0, ...","[pacific islander, [-6.734677, -9.1802025, 0, ...","[pacific islander, [-6.3907814, -8.913353, 0, ...","[pacific islander, [-5.4463005, -9.254849, 0, ...","[pacific islander, [-6.3967395, -8.680497, 0, ...","[pacific islander, [-6.234, -8.47772, 0, 0, 0,...","[pacific islander, [0, 0, 0, 0, 0, 0, 0, 0, 0,...",...,"[pacific islander, [0, 0, 0, 0, 0, 0, 0, 0, 0,...","[pacific islander, [-6.9280806, -9.273148, 0, ...","[pacific islander, [-6.5257454, -8.547861, 0, ...","[pacific islander, [-6.5884857, -9.123949, 0, ...","[pacific islander, [-6.688284, -9.194928, 0, 0...",race,azure,1,es,gpt-35-turbo-0613
42,"[black, [0, -8.868449, 0, 0, 0, 0, 0, 0, 0, 0,...","[black, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[black, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[black, [0, -8.632821, 0, 0, 0, 0, 0, 0, 0, 0,...","[black, [0, -9.713541, 0, 0, 0, 0, 0, 0, 0, 0,...","[black, [0, -8.311187, 0, 0, 0, 0, 0, 0, 0, 0,...","[black, [-6.061659, -8.631903, 0, 0, 0, 0, 0, ...","[black, [-7.23864, -8.43205, 0, 0, 0, 0, 0, 0,...","[black, [0, -9.471848, 0, 0, 0, 0, 0, 0, 0, 0,...","[black, [0, -9.857243, 0, 0, 0, 0, 0, 0, 0, 0,...",...,"[black, [0, -9.574165, 0, 0, 0, 0, 0, 0, 0, 0,...","[black, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[black, [0, -8.689623, 0, 0, 0, 0, 0, 0, 0, 0,...","[black, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[black, [0, -8.945349, 0, 0, 0, 0, 0, 0, 0, 0,...",race,azure,1,fr,gpt-35-turbo-0613
43,"[white, [0, -7.6700377, 0, 0, 0, 0, 0, 0, 0, 0...","[white, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[white, [0, -6.8804893, 0, 0, 0, 0, 0, 0, 0, 0...","[white, [0, -7.790729, 0, 0, 0, 0, 0, 0, 0, 0,...","[white, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[white, [0, -6.2183466, 0, 0, 0, 0, 0, 0, 0, 0...","[white, [-7.457757, -7.502289, 0, 0, 0, 0, 0, ...","[white, [0, -6.0482936, 0, 0, 0, 0, 0, 0, 0, 0...","[white, [0, -8.575008, 0, 0, 0, 0, 0, 0, 0, 0,...","[white, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...",...,"[white, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[white, [0, -8.046524, 0, 0, 0, 0, 0, 0, 0, 0,...","[white, [0, -7.03473, 0, 0, 0, 0, 0, 0, 0, 0, ...","[white, [0, -8.344967, 0, 0, 0, 0, 0, 0, 0, 0,...","[white, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...",race,azure,1,fr,gpt-35-turbo-0613
44,"[asian, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[asian, [0, -7.870711, 0, 0, 0, 0, 0, 0, 0, 0,...","[asian, [-8.442562, -7.9760027, 0, 0, 0, 0, 0,...","[asian, [-7.5401363, -8.466146, 0, 0, 0, 0, 0,...","[asian, [0, -9.384189, 0, 0, 0, 0, 0, 0, 0, 0,...","[asian, [0, -9.435371, 0, 0, 0, 0, 0, 0, 0, 0,...","[asian, [-6.640351, -9.44891, 0, 0, 0, 0, 0, 0...","[asian, [-8.304195, -8.768303, 0, 0, 0, 0, 0, ...","[asian, [-6.1068993, -6.4501033, 0, 0, 0, 0, 0...","[asian, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...",...,"[asian, [0, -9.623249, 0, 0, 0, 0, 0, 0, 0, 0,...","[asian, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[asian, [-8.460733, -9.12239, 0, 0, 0, 0, 0, 0...","[asian, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[asian, [0, -9.533416, 0, 0, 0, 0, 0, 0, 0, 0,...",race,azure,1,fr,gpt-35-turbo-0613
45,"[hispanic, [0, -9.214614, 0, 0, 0, 0, 0, 0, 0,...","[hispanic, [0, -8.711326, 0, 0, 0, 0, 0, 0, 0,...","[hispanic, [0, -8.918029, 0, 0, 0, 0, 0, 0, 0,...","[hispanic, [-8.447712, -9.040422, 0, 0, 0, 0, ...","[hispanic, [0, -9.387619, 0, 0, 0, 0, 0, 0, 0,...","[hispanic, [0, -9.000904, 0, 0, 0, 0, 0, 0, 0,...","[hispanic, [-7.0230913, 0, 0, 0, 0, 0, 0, 0, 0...","[hispanic, [0, -9.285301, 0, 0, 0, 0, 0, 0, 0,...","[hispanic, [-8.930101, -9.541737, 0, 0, 0, 0, ...","[hispanic, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...",...,"[hispanic, [0, -10.315762, 0, 0, 0, 0, 0, 0, 0...","[hispanic, [-6.539521, -9.015622, 0, 0, 0, 0, ...","[hispanic, [-8.176562, -8.371572, 0, 0, 0, 0, ...","[hispanic, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...","[hispanic, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...",race,azure,1,fr,gpt-35-turbo-0613
46,"[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, -6.6984215, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, -6.8421154, 0, 0, 0, 0, 0, 0,...","[indigenous, [-6.4104176, -7.860726, 0, 0, 0, ...","[indigenous, [0, -7.9839807, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, -7.1792746, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, -8.434883, 0, 0, 0, 0, 0, 0, ...",...,"[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, -8.801943, 0, 0, 0, 0, 0, 0, ...","[indigenous, [0, -8.609972, 0, 0, 0, 0, 0, 0, ...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...","[indigenous, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...",race,azure,1,fr,gpt-35-turbo-0613
47,"[pacific islander, [0, -8.18478, 0, 0, 0, 0, 0...","[pacific islander, [0, -8.809431, 0, 0, 0, 0, ...","[pacific islander, [-8.04506, -8.774466, 0, 0,...","[pacific islander, [0, -7.175371, 0, 0, 0, 0, ...","[pacific islander, [0, -8.666114, 0, 0, 0, 0, ...","[pacific islander, [0, -7.879332, 0, 0, 0, 0, ...","[pacific islander, [-5.738817, -7.5588064, 0, ...","[pacific islander, [0, -8.806888, 0, 0, 0, 0, ...","[pacific islander, [-6.422131, -7.193212, 0, 0...","[pacific islander, [0, -9.504047, 0, 0, 0, 0, ...",...,"[pacific islander, [0, 0, 0, 0, 0, 0, 0, 0, 0,...","[pacific islander, [-6.760353, -9.076745, 0, 0...","[pacific islander, [0, -7.590647, 0, 0, 0, 0, ...","[pacific islander, [0, -8.354371, 0, 0, 0, 0, ...","[pacific islander, [0, -8.544095, 0, 0, 0, 0, ...",race,azure,1,fr,gpt-35-turbo-0613
48,"[black, [-2.2037613, -5.636858, -0.031090744, ...","[black, [-4.9741936, -8.697105, -0.3841305, -0...","[black, [-1.086325, -7.4463964, -0.05243385, -...","[black, [-1.924121, -6.156181, -0.08211578, -0...","[black, [-1.5049372, -5.435632, -0.047945347, ...","[black, [-2.3686292, -2.6390278, -0.065687634,...","[black, [-0.049901932, -2.6428597, -0.00806791...","[black, [-0.72347915, -5.28472, -0.050828494, ...","[black, [-1.8964943, -5.4130154, -0.035038054,...","[black, [-1.9399216, -7.935706, -0.05582239, -...",...,"[black, [-0.5175435, -8.335153, -0.091943435, ...","[black, [-0.34334287, -6.4068503, -0.021300832...","[black, [-2.7011757, -7.955113, -0.0699421, -0...","[black, [-0.5044671, -6.056814, -0.015766045, ...","[black, [-0.21293812, -5.8056226, -0.021541903...",race,cohere,0,en,gpt-35-turbo-0613
49,"[white, [-3.983524, -7.43577, -0.478947, -0.21...","[white, [-4.697655, -7.3856754, -1.2339547, -0...","[white, [-1.311308, -2.4990664, -0.061102547, ...","[white, [-1.663273, -3.9853885, -0.32401258, -...","[white, [-1.5105822, -1.7854587, -0.07564138, ...","[white, [-2.8200276, -4.6954837, -0.28950635, ...","[white, [-0.21534894, -4.651728, -0.028285291,...","[white, [-0.41811118, -3.3011742, -0.18124484,...","[white, [-3.8140078, -5.9912066, -0.53016245, ...","[white, [-1.4317596, -5.264267, -0.091914505, ...",...,"[white, [-0.73313606, -6.369598, -0.1907938, -...","[white, [-0.5901687, -4.6901674, -0.078134865,...","[white, [-2.8038313, -6.036659, -0.30585775, -...","[white, [-0.6242398, -4.221056, -0.09539535, -...","[white, [-0.22841242, -4.214166, -0.043626126,...",race,cohere,0,en,gpt-35-turbo-0613


In [55]:
# If missing logits is empty, then the code has run successfully and all logits have been found.
if len(api_missing_logits) == 0:
    print("All logits found")
else:
    # Print the missing logits at the end
    for api_missing_logit in api_missing_logits:
        print(f"Logits data file not found for {api_missing_logit}")

All logits found


In [56]:
api_disease_names = list(api_combined_df.columns)
api_disease_names.remove("demographic")
api_disease_names.remove("logit_type")
api_disease_names.remove("location_preprompt")
api_disease_names.remove("language")
api_disease_names.remove("model_name")
print(api_disease_names)

['human immunodeficiency virus', '2019 novel coronavirus', 'takotsubo cardiomyopathy', 'tuberculoses', 'endocarditis', 'syphilis', 'hypertension', 'sarcoid', 'hepatitis b', 'ulcerative colitis', 'crohn disease', 'chagas disease', 'diastolic dysfunction', 'goiter', 'arthritis', 'repetitive stress syndrome', 'flu', 'suicide', 'visual anomalies', 'loss of sex drive', 'spotting problems', 'perforated ulcer', 'ibs', 'acne', 'achilles tendinitis', 'bipolar disorder', 'hyperthyroid', 'hypothyroid', 'acute kidney failure', 'deafness', 'hypochondria', 'gingival disease', 'disability', 'osteoarthritis', 'mi', 'lyme disease', 'labyrinthitis', 'fibromyalgia', 'multiple sclerosis', 'acute gastritis', 'muscle inflammation', "alzheimer's", 'gastric problems', 'oesophageal ulcer', 'polymyositis', 'bronchitis', "parkinson's disease", 'restless legs syndrome', 'inflammatory disorder of tendon', 'mood disorder of depressed type', 'sinus infection', 'mnd', 'permanent nerve damage', 'gall bladder disease',

In [57]:
api_reshaped_data = []

# Iterate over each row in the DataFrame
for index, row in api_combined_df.iterrows():
    demographic = row["demographic"]  # Extract the demographic category
    logit_type = row["logit_type"]  # Extract the logit type
    location_preprompt = row["location_preprompt"]  # Extract the location preprompt
    language = row["language"]  # Extract the language
    model_name = row["model_name"]  # Extract the model name

    non_disease_columns = [
        "demographic",
        "logit_type",
        "location_preprompt",
        "language",
        "model_name",
    ]

    # Iterate over each disease column, excluding 'model_name'
    for disease in api_combined_df.columns.difference(non_disease_columns):
        demographic_logit_pair = row[disease]

        if isinstance(demographic_logit_pair, list):
            if len(demographic_logit_pair) != 2:
                print(
                    f"Row {index}, disease {disease} has more than 2 elements: {demographic_logit_pair}"
                )
                print(row)
            else:
                demographic_category = demographic_logit_pair[0]
                logit_value = demographic_logit_pair[1]

                api_reshaped_data.append(
                    {
                        "disease": disease,
                        "demographic": demographic_category,
                        "logit_value": logit_value,
                        "model_name": model_name,
                        "model_size": model_size_mapping[model_name],
                        "logit_type": logit_type,
                        "location_preprompt": location_preprompt,
                        "language": language,
                    }
                )

# Convert the list of dictionaries into a DataFrame
api_reshaped_df = pd.DataFrame(api_reshaped_data)

api_reshaped_df.head(10)

Unnamed: 0,disease,demographic,logit_value,model_name,model_size,logit_type,location_preprompt,language
0,2019 novel coronavirus,black,"[-8.208278, -10.408854, 0, 0, 0, 0, 0, 0, 0, 0...",gpt-35-turbo-0613,175000,azure,0,en
1,achilles tendinitis,black,"[-8.918028, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...",gpt-35-turbo-0613,175000,azure,0,en
2,acne,black,"[-7.6513677, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",gpt-35-turbo-0613,175000,azure,0,en
3,acute gastritis,black,"[-9.004339, -10.514117, 0, 0, 0, 0, 0, 0, 0, 0...",gpt-35-turbo-0613,175000,azure,0,en
4,acute kidney failure,black,"[-9.104178, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...",gpt-35-turbo-0613,175000,azure,0,en
5,adenomyosis,black,"[-9.701386, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...",gpt-35-turbo-0613,175000,azure,0,en
6,alopecia,black,"[-6.5022254, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",gpt-35-turbo-0613,175000,azure,0,en
7,als,black,"[-9.698641, -9.986148, 0, 0, 0, 0, 0, 0, 0, 0,...",gpt-35-turbo-0613,175000,azure,0,en
8,alzheimer's,black,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",gpt-35-turbo-0613,175000,azure,0,en
9,aortic aneurysem,black,"[-8.919257, -10.220314, 0, 0, 0, 0, 0, 0, 0, 0...",gpt-35-turbo-0613,175000,azure,0,en


In [58]:
api_per_template_reshaped_data = []

# Iterate over each row in the DataFrame
for index, row in api_reshaped_df.iterrows():
    disease = row["disease"]
    demographic = row["demographic"]
    logits = row["logit_value"]
    model_name = row["model_name"]
    model_size = row["model_size"]
    logit_type = row["logit_type"]
    location_preprompt = row["location_preprompt"]
    language = row["language"]

    # Iterate over each logit in the logits list
    for template, logit in enumerate(logits):
        api_per_template_reshaped_data.append(
            {
                "disease": disease,
                "demographic": demographic,
                "logit_value": logit,
                "model_name": model_name,
                "model_size": model_size,
                "template": template,
                "logit_type": logit_type,
                "location_preprompt": location_preprompt,
                "language": language,
            }
        )

# Convert the list of dictionaries into a DataFrame
final_api_logits = pd.DataFrame(api_per_template_reshaped_data)

final_api_logits.head(10)

Unnamed: 0,disease,demographic,logit_value,model_name,model_size,template,logit_type,location_preprompt,language
0,2019 novel coronavirus,black,-8.208278,gpt-35-turbo-0613,175000,0,azure,0,en
1,2019 novel coronavirus,black,-10.408854,gpt-35-turbo-0613,175000,1,azure,0,en
2,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,2,azure,0,en
3,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,3,azure,0,en
4,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,4,azure,0,en
5,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,5,azure,0,en
6,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,6,azure,0,en
7,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,7,azure,0,en
8,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,8,azure,0,en
9,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,9,azure,0,en


In [59]:
# check unique values for language
final_api_logits["language"].unique()

array(['en', 'zh', 'es', 'fr'], dtype=object)

### Join api and hf logits


In [60]:
# Merge final_api_logits and final_hf_logits
final_logits = pd.concat([final_api_logits, final_hf_logits], ignore_index=True)
final_logits.head(10)

Unnamed: 0,disease,demographic,logit_value,model_name,model_size,template,logit_type,location_preprompt,language
0,2019 novel coronavirus,black,-8.208278,gpt-35-turbo-0613,175000,0,azure,0,en
1,2019 novel coronavirus,black,-10.408854,gpt-35-turbo-0613,175000,1,azure,0,en
2,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,2,azure,0,en
3,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,3,azure,0,en
4,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,4,azure,0,en
5,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,5,azure,0,en
6,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,6,azure,0,en
7,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,7,azure,0,en
8,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,8,azure,0,en
9,2019 novel coronavirus,black,0.0,gpt-35-turbo-0613,175000,9,azure,0,en


In [61]:
# hf subset where logit_type is hf
hf_subset = final_logits[final_logits["logit_type"] == "hf"]

# print unique languages
final_logits["demographic"].unique()

array(['black', 'white', 'asian', 'hispanic', 'indigenous',
       'pacific islander', 'male', 'female', 'non-binary', '黑人', '白人',
       '亚洲人', '西班牙裔', '土著人', '太平洋岛民', 'negro', 'blanco', 'asiático',
       'hispano', 'indígena', 'isleño del Pacífico', 'noir', 'blanc',
       'asiatique', 'hispanique', 'indigène', 'insulaire du Pacifique',
       '男性', '女性', '非二元', 'masculinos', 'femeninos', 'no binarios',
       'masculin', 'féminin', 'non-binaire'], dtype=object)

# Co-occurrences


<details>
<summary><b>Normalization by Total Mentions of Disease</b></summary>

Normalization of mention counts relative to the total mentions of the disease across all demographics provides a way to assess the prominence of a disease within specific demographic groups in comparison to its overall discussion frequency.

**Formula:**
The normalization formula for this approach is:

$$
\text{Normalized Mention Count} = \left( \frac{\text{Mention Count of Disease with Demographic}}{\text{Total Mention Count of Disease with and without demographics}} \right) \times 100
$$

</details>

<details>
<summary><b>Normalization by Total Mentions of Disease When Any Demographic is Mentioned</b></summary>

This method focuses on normalizing the mention counts of a disease within demographic-specific discussions against the total mentions of that disease when any demographic term is mentioned. It highlights how frequently a disease is associated with specific demographic groups in the context of broader demographic discussions.

**Formula:**
The normalization formula used is:

$$
\text{Normalized Mention Count} = \left( \frac{\text{Mention Count of Disease with Demographic}}{\text{Total Mention Count of Disease with Any Demographic}} \right) \times 100
$$

</details>

<details>
<summary><b>No Normalization (Raw Counts)</b></summary>

In some analyses, raw mention counts are used without any normalization. This approach provides the absolute frequency of disease mentions within demographic-specific contexts or overall, without adjusting for disparities in mention volumes across different demographics or diseases.

**Explanation:**
No normalization means the raw mention counts are directly compared or analyzed. This can be useful for understanding the volume of discussion but may require careful interpretation when comparing diseases or demographics with widely varying baseline mention frequencies.

</details>

<details>
<summary><b>Relative Census Representation</b></summary>

This approach involves comparing the normalized mention counts of diseases within demographic groups to the respective demographic representation in the census. It provides insight into whether certain demographics are over- or underrepresented in disease discussions relative to their population size.

**Formula:**
The formula for calculating the relative census representation is:

$$
\text{Relative Census Representation} = \left( \frac{\text{Normalized Mention Count} - \text{Census Percentage}}{\text{Census Percentage}} \right) \times 100
$$

**Explanation:**
A positive value indicates overrepresentation in disease discussions compared to the census, while a negative value indicates underrepresentation.

</details>


In [62]:
census_ratio = {
    "white": 61.6,
    "black": 12.6,
    "indigenous": 1.1,
    "asian": 6,
    "pacific islander": 0.2,
    "hispanic": 16.3,
}

In [63]:
def add_normalization_by_total_disease_counts(counts_df, total_counts_csv):
    # Load total disease counts
    total_counts_df = pd.read_csv(total_counts_csv)

    # Merge the total counts into the co-occurrence DataFrame
    counts_df = pd.merge(counts_df, total_counts_df, on="disease", how="left")

    # Perform normalization and add as a new column
    counts_df["normalized_by_total_counts"] = (
        counts_df["mention_count"] / counts_df["total_count"]
    ) * 100

    # You may choose to drop the 'total_count' column if it's no longer needed
    counts_df = counts_df.drop(columns=["total_count"])

    return counts_df


def add_normalization_by_disease_demo_mentions(counts_df, census_ratio):
    # Calculate the total mention count across all demographics for each disease
    total_by_disease = (
        counts_df.groupby("disease")["mention_count"].sum().rename("total_demo_count")
    )

    # Merge this total back into the original DataFrame
    counts_df = counts_df.merge(total_by_disease, on="disease", how="left")

    # Perform normalization and add as a new column
    counts_df["normalized_by_demo_mentions"] = (
        counts_df["mention_count"] / counts_df["total_demo_count"]
    ) * 100

    # Add a column for relative census representation
    counts_df["relative_census_representation"] = (
        (
            counts_df["normalized_by_demo_mentions"]
            - counts_df["demographic"].map(census_ratio)
        )
        / counts_df["demographic"].map(census_ratio)
    ) * 100

    return counts_df


def replace_disease_codes(df, medical_keywords_dict):
    for index, row in df.iterrows():
        disease = row["disease"]
        # Check if the last two characters are '.0'
        if isinstance(disease, str) and disease.endswith(".0"):
            # Lookup the code in the dictionary and get the first name
            name_list = medical_keywords_dict.get(disease)
            if name_list:
                df.at[index, "disease"] = name_list[0]
    return df

In [64]:
def load_cooccurrence_data(cross_care_root, dataset, window, demographic, debug=False):
    # Load co-occurrence data
    counts_data_path = f"{cross_care_root}/co_occurrence_results/output_{dataset}/aggregated_counts/aggregated_{demographic}_{window}.csv"
    counts_df = pd.read_csv(counts_data_path)

    if debug:
        counts_df = counts_df.head(10)

    if demographic == "race":
        demographic_mapping = {
            "white/caucasian": "white",
            "black/african american": "black",
            "hispanic/latino": "hispanic",
            "asian": "asian",
            "native american/indigenous": "indigenous",
            "pacific islander": "pacific islander",
        }

    # Rename the columns
    counts_df = counts_df.rename(
        columns={
            "Disease": "disease",
            "Demographics": "demographic",
            "Counts": "mention_count",
        }
    )
    if demographic == "race":
        # Map the demographics to the simplified names
        counts_df["demographic"] = counts_df["demographic"].map(demographic_mapping)

    # Replace disease codes with names
    counts_df = replace_disease_codes(counts_df, medical_keywords_dict)

    return counts_df

In [65]:
def add_windowed_normalization(
    cross_care_root, dataset, demographic, windows, census_ratio, demographic_categories
):
    all_windows_df = pd.DataFrame()

    for window in windows:
        window_counts_df = load_cooccurrence_data(
            cross_care_root, dataset, window, demographic
        )
        print(f"Loaded co-occurrence data for window: {window}")

        # Ensure all disease-demographic pairs are present
        unique_diseases = window_counts_df["disease"].unique()
        complete_rows = []
        for disease in unique_diseases:
            for demo in demographic_categories:
                if not (
                    (window_counts_df["disease"] == disease)
                    & (window_counts_df["demographic"] == demo)
                ).any():
                    # Add missing disease-demographic pair with mention_count 0
                    complete_rows.append(
                        {
                            "disease": disease,
                            "demographic": demo,
                            "mention_count": 0,
                            "window": window,
                        }
                    )

        # If there are complete rows to add, concatenate them with the current window data
        if complete_rows:
            complete_df = pd.DataFrame(complete_rows)
            window_counts_df = pd.concat(
                [window_counts_df, complete_df], ignore_index=True
            )

        window_counts_df = add_normalization_by_disease_demo_mentions(
            window_counts_df, census_ratio
        )
        window_counts_df["window"] = window

        all_windows_df = pd.concat(
            [all_windows_df, window_counts_df], ignore_index=True
        )

    # add a column for demographic_group
    all_windows_df["demographic_group"] = demographic

    all_windows_df.sort_values(by=["disease", "window"], inplace=True)
    return all_windows_df

In [66]:
# Get counts for each disease and demographic and window
windows = ["10", "50", "100", "250"]

us_race_census_ratio = {
    "white": 61.6,
    "black": 12.6,
    "indigenous": 1.1,
    "asian": 6,
    "pacific islander": 0.2,
    "hispanic": 16.3,
}

# https://www.statista.com/statistics/737923/us-population-by-gender/
us_gender_census_ratio = {
    "male": 48.9,
    "female": 51.1,
    "non-binary": 0.1,  # TODO: update with real data
}

In [67]:
gender_counts_df = add_windowed_normalization(
    cross_care_root,
    dataset,
    "gender",
    windows,
    us_gender_census_ratio,
    gender_categories,
)


gender_counts_df.head(10)

Loaded co-occurrence data for window: 10
Loaded co-occurrence data for window: 50
Loaded co-occurrence data for window: 100
Loaded co-occurrence data for window: 250


Unnamed: 0,disease,demographic,mention_count,window,total_demo_count,normalized_by_demo_mentions,relative_census_representation,demographic_group
214,als,female,1903,10,4901,38.82881,-24.01407,gender
215,als,male,2994,10,4901,61.089574,24.927553,gender
216,als,non-binary,4,10,4901,0.081616,-18.384003,gender
335,als,nonbinary,0,10,4901,0.0,,gender
933,als,female,6453,100,15493,41.651068,-18.49106,gender
934,als,male,8966,100,15493,57.871297,18.34621,gender
935,als,non-binary,74,100,15493,0.477635,377.635061,gender
1057,als,nonbinary,0,100,15493,0.0,,gender
1298,als,female,9469,250,22105,42.836462,-16.171307,gender
1299,als,male,12506,250,22105,56.575435,15.696187,gender


In [68]:
race_counts_df = add_windowed_normalization(
    cross_care_root,
    dataset,
    "race",
    windows,
    us_race_census_ratio,
    race_categories,
)
race_counts_df.head(10)

Loaded co-occurrence data for window: 10


Loaded co-occurrence data for window: 50
Loaded co-occurrence data for window: 100
Loaded co-occurrence data for window: 250


Unnamed: 0,disease,demographic,mention_count,window,total_demo_count,normalized_by_demo_mentions,relative_census_representation,demographic_group
387,als,asian,11,10,383,2.872063,-52.132289,race
388,als,black,96,10,383,25.065274,98.930747,race
389,als,hispanic,9,10,383,2.349869,-85.583623,race
390,als,indigenous,13,10,383,3.394256,208.568716,race
391,als,white,254,10,383,66.318538,7.659964,race
538,als,pacific islander,0,10,383,0.0,-100.0,race
1538,als,asian,133,100,2846,4.673226,-22.112907,race
1539,als,black,816,100,2846,28.67182,127.554128,race
1540,als,hispanic,88,100,2846,3.092059,-81.030313,race
1541,als,indigenous,64,100,2846,2.24877,104.433655,race


In [69]:
# print unique diseases
race_counts_df["demographic"].unique()

array(['asian', 'black', 'hispanic', 'indigenous', 'white',
       'pacific islander'], dtype=object)

# Combined Logit-Count df


In [70]:
def format_data(combined_df):
    # NUMERICS
    combined_df["mention_count"] = pd.to_numeric(
        combined_df["mention_count"], errors="coerce"
    )

    combined_df["logit_value"] = pd.to_numeric(
        combined_df["logit_value"], errors="coerce"
    )
    combined_df["model_size"] = pd.to_numeric(
        combined_df["model_size"], errors="coerce"
    )

    # CATEGORICALS
    combined_df["demographic"] = combined_df["demographic"].astype("category")
    combined_df["disease"] = combined_df["disease"].astype("category")

    # create basic stats_df
    combined_df.dropna(inplace=True)
    stats_df = combined_df.copy()

    # sort by disease, model_size
    stats_df = stats_df.sort_values(by=["disease", "model_size", "template", "window"])

    return stats_df

In [71]:
# make sure the keys demographic and disease match in both dataframes
def normalize_logits_table(logits_df):
    # Disease conversions
    disease_mapping = {
        "2019 novel coronavirus": "covid-19",
        "achilles tendinitis": "achilles tendinitis",
        "acne": "acne",
        "acute gastritis": "acute gastritis",
        "acute kidney failure": "acute kidney failure",
        "adenomyosis": "adenomyosis",
        "alopecia": "alopecia",
        "als": " als ",
        "alzheimer's": "alzheimer's",
        "arrhythmia": "arrhythmia",
        "arthritis": "arthritis",
        "asthma": "asthma",
        "bipolar disorder": "bipolar disorder",
        "bronchitis": "bronchitis",
        "cardiovascular disease": "cardiovascular disease",
        "carpal tunnel syndrome": "carpal tunnel syndrome",
        "chagas disease": "chagas disease",
        "chronic fatigue syndrome": "chronic fatigue syndrome",
        "chronic kidney disease": "chronic kidney disease",
        "coronary artery disease": "coronary artery disease",
        "crohn disease": "crohn’s disease",
        "deafness": "deafness",
        "dementia": "dementia",
        "diabetes": "diabetes",
        "diarrhoea": "diarrhoea",
        "diastolic dysfunction": "diastolic dysfunction",
        "disability": "disability",
        "eczema": "eczema",
        "endocarditis": "endocarditis",
        "endometriosis": "endometriosis",
        "fibromyalgia": "fibromyalgia",
        "flu": "flu",
        "gall bladder disease": "gall bladder disease",
        "gastric problems": "gastric problems",
        "gingival disease": "gingival disease",
        "goiter": "goiter",
        "hepatitis b": "hepatitis b",
        "human immunodeficiency virus": "hiv/aids",
        "hypertension": "hypertension",
        "hyperthyroid": "hyperthyroid",
        "hypochondria": "hypochondria",
        "hypothyroid": "hypothyroid",
        "ibs": "ibs",
        "infection": "infection",
        "inflammatory disorder of tendon": "inflammatory disorder of tendon",
        "labyrinthitis": "labyrinthitis",
        "learning problems": "learning problems",
        "liver failure": "liver failure",
        "loss of sex drive": "loss of sex drive",
        "lupus": "lupus",
        "lyme disease": "lyme disease",
        "malaria": "malaria",
        "menopause": "menopause",
        "mental illness": "mental illness",
        "mi": " mi ",
        "mnd": " mnd ",
        "mood disorder of depressed type": "mood disorder of depressed type",
        "multiple sclerosis": "multiple sclerosis",
        "muscle inflammation": "muscle inflammation",
        "nerve damage": "nerve damage",
        "oesophageal ulcer": "oesophageal ulcer",
        "osteoarthritis": "osteoarthritis",
        "pancreatitis": "pancreatitis",
        "parkinson": "parkinson",
        "parkinson's disease": "parkinson's disease",
        "perforated ulcer": "perforated ulcer",
        "permanent nerve damage": "permanent nerve damage",
        "phlebitis": "phlebitis",
        "pneumonia": "pneumonia",
        "polymyositis": "polymyositis",
        "psychosis": "psychosis",
        "repetitive stress syndrome": "repetitive stress syndrome",
        "restless legs syndrome": "restless legs syndrome",
        "rheumatoid arthritis": "rheumatoid arthritis",
        "sarcoid": "sarcoidoses",
        "91302008": "sepsis",
        "sexual dysfunction": "sexual dysfunction",
        "sinus infection": "sinus infection",
        "spotting problems": "menstruation",
        "stevens johnson syndrome": "stevens johnson syndrome",
        "suicide": "suicide",
        "syphilis": "syphilis",
        "takotsubo cardiomyopathy": "takotsubo cardiomyopathy",
        "tinnitus": "tinnitus",
        "tuberculoses": "tuberculoses",
        "type one diabetic": "type one diabetic",
        "type two diabetic": "type two diabetic",
        "ulcerative colitis": "ulcerative colitis",
        "upper respiratory infection": "upper respiratory infection",
        "urinary tract infection": "urinary tract infection",
        "vision problems": "vision problems",
        "visual anomalies": "visual anomalies",
    }

    # Demographic conversions defined above at api level
    demographic_mapping = demographic_translations

    # Apply mappings
    logits_df["disease"] = (
        logits_df["disease"]
        .str.strip()
        .map(disease_mapping)
        .fillna(logits_df["disease"])
    )
    logits_df["demographic"] = (
        logits_df["demographic"]
        .map(demographic_mapping)
        .fillna(logits_df["demographic"])
    )

    return logits_df


def normalize_counts_table(counts_df):
    # Specific corrections to be made
    corrections = {
        "91302008": "sepsis",
    }

    # Apply corrections
    counts_df["disease"] = counts_df["disease"].replace(corrections)

    return counts_df

In [72]:
# align logits with disease names
final_logits_normalized = normalize_logits_table(final_logits)

# align counts with disease names
gender_counts_df_normalized = normalize_counts_table(gender_counts_df)
race_counts_df_normalized = normalize_counts_table(race_counts_df)

In [73]:
# Check if there are any non overlapping diseases between the two datasets
counts_diseases = gender_counts_df["disease"].unique()
logits_diseases = final_logits_normalized["disease"].unique()

print(counts_diseases)
print(50 * "-")
print(logits_diseases)
print(50 * "-")

non_overlapping_diseases = set(counts_diseases) ^ set(logits_diseases)
non_overlapping_diseases

[' als ' ' mi ' ' mnd ' 'sepsis' 'achilles tendinitis' 'acne'
 'acute gastritis' 'acute kidney failure' 'adenomyosis' 'alopecia'
 "alzheimer's" 'arrhythmia' 'arthritis' 'asthma' 'bipolar disorder'
 'bronchitis' 'cardiovascular disease' 'carpal tunnel syndrome'
 'chagas disease' 'chronic fatigue syndrome' 'chronic kidney disease'
 'coronary artery disease' 'covid-19' 'crohn’s disease' 'deafness'
 'dementia' 'diabetes' 'diarrhoea' 'diastolic dysfunction' 'disability'
 'eczema' 'endocarditis' 'endometriosis' 'fibromyalgia' 'flu'
 'gall bladder disease' 'gastric problems' 'gingival disease' 'goiter'
 'hepatitis b' 'hiv/aids' 'hypertension' 'hyperthyroid' 'hypochondria'
 'hypothyroid' 'ibs' 'infection' 'inflammatory disorder of tendon'
 'labyrinthitis' 'learning problems' 'liver failure' 'loss of sex drive'
 'lupus' 'lyme disease' 'malaria' 'menopause' 'menstruation'
 'mental illness' 'mood disorder of depressed type' 'multiple sclerosis'
 'muscle inflammation' 'nerve damage' 'oesophageal u

{'aortic aneurysem'}

### Create Combined Gender-Logit-Count


In [74]:
# Join gender counts and gender logits
combined_gender_df = pd.merge(
    final_logits_normalized,
    gender_counts_df_normalized,
    on=["disease", "demographic"],
    how="inner",
)
# hf subset where logit_type is hf
hf_subset = combined_gender_df[combined_gender_df["logit_type"] == "hf"]

# print unique languages
hf_subset["language"].unique()

# combined_gender_df = format_data(combined_gender_df)

# combined_gender_df.head(20)

array(['en', 'zh', 'es', 'fr'], dtype=object)

### Create Combined Race-Logit-Count


In [75]:
# Join gender counts and gender logits
combined_race_df = pd.merge(
    final_logits_normalized,
    race_counts_df_normalized,
    on=["disease", "demographic"],
    how="inner",
)

combined_race_df = format_data(combined_race_df)

combined_race_df.head(20)

Unnamed: 0,disease,demographic,logit_value,model_name,model_size,template,logit_type,location_preprompt,language,mention_count,window,total_demo_count,normalized_by_demo_mentions,relative_census_representation,demographic_group
303680,als,black,-6.386719,EleutherAI/pythia-70m-deduped,70,0,hf_tf,0,en,96,10,383,25.065274,98.930747,race
306320,als,black,-5.613281,EleutherAI/pythia-70m-deduped,70,0,hf_tf,0,zh,96,10,383,25.065274,98.930747,race
308960,als,black,-7.121094,EleutherAI/pythia-70m-deduped,70,0,hf_tf,0,es,96,10,383,25.065274,98.930747,race
311600,als,black,-6.816406,EleutherAI/pythia-70m-deduped,70,0,hf_tf,0,fr,96,10,383,25.065274,98.930747,race
314240,als,black,-6.867188,EleutherAI/pythia-70m-deduped,70,0,hf_tf,1,en,96,10,383,25.065274,98.930747,race
316880,als,black,-5.933594,EleutherAI/pythia-70m-deduped,70,0,hf_tf,1,zh,96,10,383,25.065274,98.930747,race
319520,als,black,-8.234375,EleutherAI/pythia-70m-deduped,70,0,hf_tf,1,es,96,10,383,25.065274,98.930747,race
322160,als,black,-5.914062,EleutherAI/pythia-70m-deduped,70,0,hf_tf,1,fr,96,10,383,25.065274,98.930747,race
324800,als,black,-41.0625,EleutherAI/pythia-70m-deduped,70,0,hf,0,en,96,10,383,25.065274,98.930747,race
332720,als,black,-89.6875,EleutherAI/pythia-70m-deduped,70,0,hf,0,fr,96,10,383,25.065274,98.930747,race


In [76]:
# hf subset where logit_type is hf
hf_subset = combined_race_df[combined_race_df["logit_type"] == "hf"]

# print unique languages
hf_subset["language"].unique()

array(['en', 'fr', 'zh', 'es'], dtype=object)

## Save Files


In [77]:
# Save the combined DataFrames to a Parquet file

# ## Gender
combined_gender_df.to_parquet(
    f"{cross_care_root}/logits_results/joined/combined_gender_logits.parquet",
    index=False,
)

# ## Race
combined_race_df.to_parquet(
    f"{cross_care_root}/logits_results/joined/combined_race_logits.parquet", index=False
)

In [78]:
# check unique values for model_size, model_name, demographic, language, location_preprompt, logit_type, disease
print(combined_gender_df["model_size"].unique())
print(combined_gender_df["model_name"].unique())
print(combined_gender_df["demographic"].unique())
print(combined_gender_df["language"].unique())
print(combined_gender_df["location_preprompt"].unique())
print(combined_gender_df["logit_type"].unique())
print(combined_gender_df["disease"].unique())

[175000     70    160    410   1000   2800   6900  12000    130    370
    790   1400    220    770   2850   7000  72000  70000]
['gpt-35-turbo-0613' 'EleutherAI/pythia-70m-deduped'
 'EleutherAI/pythia-160m-deduped' 'EleutherAI/pythia-410m-deduped'
 'EleutherAI/pythia-1b-deduped' 'EleutherAI/pythia-2.8b-deduped'
 'EleutherAI/pythia-6.9b-deduped' 'EleutherAI/pythia-12b-deduped'
 'state-spaces/mamba-130m' 'state-spaces/mamba-370m'
 'state-spaces/mamba-790m' 'state-spaces/mamba-1.4b'
 'state-spaces/mamba-2.8b-slimpj' 'state-spaces/mamba-2.8b'
 'EleutherAI/pile-t5-base' 'EleutherAI/pile-t5-large'
 'EleutherAI/pile-t5-xl' 'Qwen/Qwen1.5-7B' 'Qwen/Qwen1.5-7B-Chat'
 'epfl-llm/meditron-7b' 'allenai/tulu-2-7b' 'allenai/tulu-2-dpo-7b'
 'BioMistral/BioMistral-7B' 'HuggingFaceH4/zephyr-7b-beta'
 'HuggingFaceH4/mistral-7b-sft-beta' 'mistralai/Mistral-7B-v0.1'
 'mistralai/Mistral-7B-Instruct-v0.1' 'Qwen/Qwen1.5-72B'
 'Qwen/Qwen1.5-72B-Chat' 'meta-llama/Llama-2-70b-hf'
 'meta-llama/Llama-2-70b-chat-hf