In [1]:
import pandas as pd
import jsonlines
import json
import os
import plotly.express as px
import seaborn as sns
sns.set_theme()
%config InlineBackend.figure_format = 'svg'
%matplotlib inline
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import pycountry_convert as pc

In [2]:
# Stretch the notebook to the full browser width and show DataFrame cell text untruncated.
from IPython.display import HTML, display

display(HTML("<style>.container { width:100% !important; }</style>"))
pd.options.display.max_colwidth = None

In [3]:
# Directory two levels above the working directory; data/countries_with_continent.csv is read relative to it.
current_directory = os.getcwd()
parent_directory = os.path.abspath(os.path.join(current_directory, os.pardir, os.pardir))

In [4]:
# Load the VQA question sets (JSON Lines) for the three image collections:
# DALL-E 2 / Stable Diffusion, DALL-E 3, and the unspecific-prompt images.
dalle_2_sd_questions_df, dalle_3_questions_df, unspecific_questions_df = (
    pd.read_json(questions_path, lines=True)
    for questions_path in (
        '../../vqa/experiments/paper_plots/vqa_questions.jsonl',
        '../../vqa/experiments/dale3_paper_plots/vqa_questions.jsonl',
        '../../vqa/experiments/unspecific_prompts/vqa_questions.jsonl',
    )
)

In [5]:
print("Shape of the vqa questions:", *(frame.shape for frame in (dalle_2_sd_questions_df, dalle_3_questions_df, unspecific_questions_df)))

Shape of the vqa questions: (15178, 4) (7726, 4) (71472, 4)


In [6]:
# Extract the dish name, model name and country name(s) from an image path.
#
# Three directory layouts are recognised, keyed by an "anchor" directory name
# (offsets are counted from the anchor component):
#   .../all_submitted_dishes/<model>/<countries>/<dish>/...
#   .../countries_all/<index>_<country>/<model>/<image file>
#   .../all_submitted_dishes_dalle3/<countries>/<dish>/<model>/...
# <countries> is a ", "-separated list. For countries_all the "dish" is the image file name itself.
_PATH_LAYOUTS = {
    "all_submitted_dishes": {"model": 1, "countries": 2, "dish": 3},
    "countries_all": {"countries": 1, "model": 2, "dish": 3},
    "all_submitted_dishes_dalle3": {"countries": 1, "dish": 2, "model": 3},
}


def _path_component(file_path, field):
    """Return (anchor, component) for `field` ("dish", "model" or "countries") of `file_path`.

    The first anchor from _PATH_LAYOUTS present in the path decides the layout. Returns
    (anchor, None) when the path is too short for that layout, and (None, None) when no anchor
    is present or `file_path` is not a string (e.g. a NaN from pandas).
    """
    if not isinstance(file_path, str):
        return None, None
    # The stored image paths use "/" (see the image_path values), so split on "/" and also
    # accept the platform separator; splitting on os.sep alone breaks on Windows.
    components = file_path.replace(os.sep, "/").split("/")
    for anchor, offsets in _PATH_LAYOUTS.items():
        if anchor not in components:
            continue
        position = components.index(anchor) + offsets[field]
        return anchor, (components[position] if position < len(components) else None)
    return None, None


def extract_dish_name(file_path):
    """Return the dish name encoded in an image path, or None if it cannot be located."""
    return _path_component(file_path, "dish")[1]


def extract_model_name(file_path):
    """Return the model name (e.g. "dalle2") encoded in an image path, or None if it cannot be located."""
    return _path_component(file_path, "model")[1]


def extract_country_name(file_path):
    """Return the list of country names encoded in an image path, or None if it cannot be located."""
    anchor, component = _path_component(file_path, "countries")
    if component is None:
        return None
    if anchor == "countries_all":
        # Directory is "<index>_<country>", e.g. "254_Palmyra Atoll"; split only once so that
        # country names containing "_" are kept whole.
        return [component.split("_", 1)[-1]]
    return component.split(", ")

In [7]:
# Derive dish, model and country columns from each DALL-E 2 / SD image path
dalle_2_sd_questions_df = dalle_2_sd_questions_df.assign(
    dish_name=dalle_2_sd_questions_df['image_path'].map(extract_dish_name),
    model_name=dalle_2_sd_questions_df['image_path'].map(extract_model_name),
    country_name=dalle_2_sd_questions_df['image_path'].map(extract_country_name),
)

In [8]:
# Derive dish, model and country columns from each DALL-E 3 image path
dalle_3_questions_df = dalle_3_questions_df.assign(
    dish_name=dalle_3_questions_df['image_path'].map(extract_dish_name),
    model_name=dalle_3_questions_df['image_path'].map(extract_model_name),
    country_name=dalle_3_questions_df['image_path'].map(extract_country_name),
)

In [9]:
# Derive dish, model and country columns from each unspecific-prompt image path
unspecific_questions_df = unspecific_questions_df.assign(
    dish_name=unspecific_questions_df['image_path'].map(extract_dish_name),
    model_name=unspecific_questions_df['image_path'].map(extract_model_name),
    country_name=unspecific_questions_df['image_path'].map(extract_country_name),
)

In [10]:
# Country -> continent lookup table, read from the data/ folder under parent_directory
countries_csv_path = os.path.join(parent_directory, 'data/countries_with_continent.csv')
countries_df = pd.read_csv(countries_csv_path)

In [11]:
# DALL-E 2 / SD questions: attach the continent name(s) of every country a question refers to.
# One row per (question, country) pair
vqa_questions_exploded = dalle_2_sd_questions_df.explode('country_name')

# Left-join the continent; a country listed more than once in countries_df yields one row per listing
merged_df = vqa_questions_exploded.merge(
    countries_df[['Name', 'Continent Name']],
    how='left',
    left_on='country_name',
    right_on='Name',
)

# Back to one row per question_id holding the list of continent names
result_df = (
    merged_df.groupby('question_id')['Continent Name']
    .apply(list)
    .reset_index()
)

dalle_2_sd_vqa_questions_with_countries = dalle_2_sd_questions_df.join(
    result_df.set_index('question_id'),
    on='question_id',
)

In [12]:
# Unspecific-prompt questions: attach the continent name(s) of every country a question refers to.
# One row per (question, country) pair
vqa_questions_exploded = unspecific_questions_df.explode('country_name')

# Left-join the continent; a country listed more than once in countries_df yields one row per listing
merged_df = vqa_questions_exploded.merge(
    countries_df[['Name', 'Continent Name']],
    how='left',
    left_on='country_name',
    right_on='Name',
)

# Back to one row per question_id holding the list of continent names
result_df = (
    merged_df.groupby('question_id')['Continent Name']
    .apply(list)
    .reset_index()
)

unspecific_vqa_questions_with_countries = unspecific_questions_df.join(
    result_df.set_index('question_id'),
    on='question_id',
)

In [13]:
# DALL-E 3 questions: attach the continent name(s) of every country a question refers to.
# One row per (question, country) pair
vqa_questions_exploded = dalle_3_questions_df.explode('country_name')

# Left-join the continent; a country listed more than once in countries_df yields one row per listing
merged_df = vqa_questions_exploded.merge(
    countries_df[['Name', 'Continent Name']],
    how='left',
    left_on='country_name',
    right_on='Name',
)

# Back to one row per question_id holding the list of continent names
result_df = (
    merged_df.groupby('question_id')['Continent Name']
    .apply(list)
    .reset_index()
)

dalle_3_vqa_questions_with_countries = dalle_3_questions_df.join(
    result_df.set_index('question_id'),
    on='question_id',
)

In [14]:
# DALL-E 2 / SD questions whose continent list contains at least one missing (NaN) entry
continent_is_missing = dalle_2_sd_vqa_questions_with_countries['Continent Name'].map(lambda continents: any(pd.isna(continents)))
rows_with_null = dalle_2_sd_vqa_questions_with_countries.loc[continent_is_missing]
len(rows_with_null)

0

In [15]:
# DALL-E 3 questions whose continent list contains at least one missing (NaN) entry
continent_is_missing = dalle_3_vqa_questions_with_countries['Continent Name'].map(lambda continents: any(pd.isna(continents)))
rows_with_null = dalle_3_vqa_questions_with_countries.loc[continent_is_missing]
len(rows_with_null)

0

In [16]:
# Unspecific-prompt questions whose continent list contains at least one missing (NaN) entry
continent_is_missing = unspecific_vqa_questions_with_countries['Continent Name'].map(lambda continents: any(pd.isna(continents)))
rows_with_null = unspecific_vqa_questions_with_countries.loc[continent_is_missing]
len(rows_with_null)

5680

In [126]:
# Show the unspecific-prompt questions whose continent list contains NaN (computed in the cell above)
rows_with_null

Unnamed: 0,image_name,question_id,question,image_path,dish_name,model_name,country_name,Continent Name
2800,Palmyra Atoll_02.png,2800,Question: Is this a picture of food?. Choices: A: Yes B: No. Please respond with the letter corresponding to your choice. Answer:,/food-bias/countries_all/254_Palmyra Atoll/dalle2/Palmyra Atoll_02.png,Palmyra Atoll_02.png,dalle2,[Palmyra Atoll],"[nan, Oceania]"
2801,Palmyra Atoll_02.png,2801,Question: Is the dish placed outdoors or indoors?. Choices: A: Outdoors B: Indoors. Please respond with the letter corresponding to your choice. Answer:,/food-bias/countries_all/254_Palmyra Atoll/dalle2/Palmyra Atoll_02.png,Palmyra Atoll_02.png,dalle2,[Palmyra Atoll],"[nan, Oceania]"
2802,Palmyra Atoll_02.png,2802,"Question: What utensils, if any, are shown in this image?. Choices: A: Fork B: Spoon C: Knife D: Chopsticks E: No utensils shown. Please respond with the letter corresponding to your choice. Answer:",/food-bias/countries_all/254_Palmyra Atoll/dalle2/Palmyra Atoll_02.png,Palmyra Atoll_02.png,dalle2,[Palmyra Atoll],"[nan, Oceania]"
2803,Palmyra Atoll_02.png,2803,Question: Is the dish placed on a table?. Choices: A: Yes B: No. Please respond with the letter corresponding to your choice. Answer:,/food-bias/countries_all/254_Palmyra Atoll/dalle2/Palmyra Atoll_02.png,Palmyra Atoll_02.png,dalle2,[Palmyra Atoll],"[nan, Oceania]"
2804,Palmyra Atoll_02.png,2804,Question: What material is the dish or plate in the image most likely made of?. Choices: A: Ceramic B: Glass C: Metal D: Plastic E: Wood F: Paper G: Clay. Please respond with the letter corresponding to your choice. Answer:,/food-bias/countries_all/254_Palmyra Atoll/dalle2/Palmyra Atoll_02.png,Palmyra Atoll_02.png,dalle2,[Palmyra Atoll],"[nan, Oceania]"
...,...,...,...,...,...,...,...,...
69547,Azad Kashmir_01.png,69547,Question: What type of lighting is used in the image?. Choices: A: Natural light B: Low light C: High contrast light D: Soft and diffused light E: Mixed lighting F: No visible lighting source. Please respond with the letter corresponding to your choice. Answer:,/food-bias/countries_all/284_Azad Kashmir/dalle3/Azad Kashmir_01.png,Azad Kashmir_01.png,dalle3,[Azad Kashmir],"[nan, Asia]"
69548,Azad Kashmir_01.png,69548,"Question: Are there any additional elements in the image? (e.g., drinks, side dishes, condiments). Choices: A: Yes, drinks B: Yes, side dishes C: Yes, condiments D: Yes, multiple elements E: No, just the main dish. Please respond with the letter corresponding to your choice. Answer:",/food-bias/countries_all/284_Azad Kashmir/dalle3/Azad Kashmir_01.png,Azad Kashmir_01.png,dalle3,[Azad Kashmir],"[nan, Asia]"
69549,Azad Kashmir_01.png,69549,Question: Are there any utencils shown in this image?. Choices: A: Yes B: No. Please respond with the letter corresponding to your choice. Answer:,/food-bias/countries_all/284_Azad Kashmir/dalle3/Azad Kashmir_01.png,Azad Kashmir_01.png,dalle3,[Azad Kashmir],"[nan, Asia]"
69550,Azad Kashmir_01.png,69550,Question: Is there a person shown in this image?. Choices: A: Yes B: No. Please respond with the letter corresponding to your choice. Answer:,/food-bias/countries_all/284_Azad Kashmir/dalle3/Azad Kashmir_01.png,Azad Kashmir_01.png,dalle3,[Azad Kashmir],"[nan, Asia]"


In [18]:
# Distinct countries appearing in the rows with a missing (NaN) continent
unique_countries = {country for question_countries in rows_with_null['country_name'] for country in question_countries}
unique_countries

{'Abkhazia',
 'Antarctica',
 'Azad Kashmir',
 'Baker Island',
 'Bavaria',
 'Corsica',
 'Crimea',
 'East Timor',
 'French Southern Territories',
 'Holy See',
 'Howland Island',
 'Jarvis Island',
 'Johnston Atoll',
 'Kingman Reef',
 'Midway Atoll',
 'Navassa Island',
 'North Ossetia',
 'Palmyra Atoll',
 'Pitcairn',
 'South Moluccas',
 'Timor-Leste',
 'Vatican City',
 'Wake Island',
 'Western Sahara'}

In [32]:
# Load the VQA model answers (JSON Lines) for the DALL-E 2 / SD, DALL-E 3 and unspecific-prompt images
dalle2_sd_vqa_answers_df, dalle3_vqa_answers_df, unspecified_vqa_answers_df = (
    pd.read_json(answers_path, lines=True)
    for answers_path in (
        '../../vqa/experiments/paper_plots/vqa_paper_plots_answers.jsonl',
        '../../vqa/experiments/dale3_paper_plots/dalle3_vqa_paper_plot_answers.jsonl',
        '../../vqa/unspecific__prompts_dish_answers.jsonl',
    )
)

In [33]:
print("Shape of the vqa responses:", *(frame.shape for frame in (dalle2_sd_vqa_answers_df, dalle3_vqa_answers_df, unspecified_vqa_answers_df)))

Shape of the vqa responses: (15178, 6) (7726, 6) (71472, 6)


In [50]:
# Load the VQA question definitions (type, text, choices) used to rebuild each prompt string.
# The same question list drives the analysis of every image collection below.
# JSON is UTF-8 by specification, so decode it explicitly rather than with the platform default.
with open('../../vqa/experiments/unspecific_prompts/questions.json', "r", encoding="utf-8") as f:
    questions_json = json.load(f)

questions = questions_json["questions"]

In [51]:
def format_prompt(question):
    """Render a question dict as the exact prompt string that was sent to the VQA model.

    The result must match the "prompt" column of the answer files byte-for-byte, because the
    analysis cells select answers with `answers_df["prompt"] == format_prompt(question)`.

    Args:
        question: dict with "type" ("multiple_choice", "list" or "free_form") and "text"; the
            two choice types also need "choices" (answer key -> label, rendered in dict order).

    Returns:
        The prompt string.

    Raises:
        ValueError: if the question type is not one of the three supported types, instead of
            silently returning None (which would match no answers).
    """
    question_type = question["type"]
    if question_type in ("multiple_choice", "list"):
        # Both choice-based types share the same template
        choices_str = " ".join(f"{key}: {value}" for key, value in question["choices"].items())
        return f"Question: {question['text']}. Choices: {choices_str}. Please respond with the letter corresponding to your choice. Answer:"
    if question_type == "free_form":
        return f"Question: {question['text']}. Answer:"
    raise ValueError(f"Unsupported question type: {question_type!r}")

In [52]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.io as pio


# Plotly grouped bar chart of answer counts per model for one multiple-choice question
def plot_multiple_choice_pdf(grouped_by_model_name, q4_json):
    """Return a grouped bar chart with one bar series per model.

    Args:
        grouped_by_model_name: pandas GroupBy whose keys include 'model_name' and 'text' (the
            answer key); each group's size becomes a bar height.
        q4_json: question dict; 'text' is used as the chart title and 'choices' maps the
            answer keys on the x-axis to their tick labels.

    Returns:
        The plotly Figure; it is neither shown nor saved here.
    """
    counts = grouped_by_model_name.size().reset_index(name='count')

    figure = make_subplots(rows=1, cols=1, shared_xaxes=True, vertical_spacing=0.1)
    for model in counts['model_name'].unique():
        model_rows = counts.loc[counts['model_name'] == model]
        figure.add_trace(go.Bar(x=model_rows['text'], y=model_rows['count'], name=model))

    helvetica = "Helvetica Neue"
    framed_axis = dict(showline=True, linewidth=2, linecolor='grey', mirror=True)
    figure.update_layout(
        title=dict(text=q4_json['text'], y=0.9, x=0.5, xanchor='center', yanchor='top'),
        xaxis_title='Response',
        yaxis_title='Count',
        font=dict(family=helvetica, size=12),
        title_font=dict(family=helvetica, size=16, color='black'),
        plot_bgcolor='white',
        xaxis=dict(
            framed_axis,
            tickvals=list(q4_json['choices'].keys()),
            ticktext=list(q4_json['choices'].values()),
        ),
        yaxis=dict(framed_axis, title='Count'),
        barmode='group',  # bars of different models side by side
        width=1200,
        height=600,
    )
    return figure

In [53]:
# NOTE(review): disabled per-model ("worldwide") plots built with plot_multiple_choice_pdf.
# It refers to `vqa_answers_df` and `vqa_questions_df`, which are not defined anywhere in this
# notebook, so it cannot be re-enabled as-is; delete it or port it to the dataset-specific frames.
# for question in questions:
#     if question.get('type') == 'multiple_choice':
#         q4_json = question
#         q4_prompt = format_prompt(q4_json)
#         q4_responses = vqa_answers_df[vqa_answers_df["prompt"] == q4_prompt]
#         question_ids = q4_responses['question_id'].values
#         questions_df = vqa_questions_df[vqa_questions_df['question_id'].isin(question_ids)]
#         merged_df = pd.merge(q4_responses, questions_df[['question_id', 'model_name', 'image_path']], on='question_id', how='left')
#         grouped_by_model_name = merged_df.groupby(['model_name','text'])
        
#         fig = plot_multiple_choice_pdf(grouped_by_model_name, q4_json)
#         pio.write_image(fig, os.path.join("results", "worldwide", f"{q4_json['text'].replace('?', '')}.pdf"), scale=2)
#         fig.show()

In [54]:
import os
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.io as pio
import pandas as pd
import numpy as np

def _as_group_list(entry):
    """Normalise one `group_by` cell into a list of its non-missing groups.

    List-like cells (list/tuple/set/ndarray/Series) contribute each element. Any other value is
    one group, so a plain string is a single group rather than a sequence of characters.
    Missing scalars (None/NaN) are dropped.
    """
    if isinstance(entry, (list, tuple, set, np.ndarray, pd.Series)):
        items = list(entry)
    else:
        items = [entry]
    return [item for item in items if not (pd.api.types.is_scalar(item) and pd.isna(item))]


def _choice_percentages(group_rows, choice_keys):
    """Percentage of `group_rows` whose 'text' answer equals each key of `choice_keys` (0 if unused).

    The denominator is every non-missing answer in `group_rows`, including answers that are not
    one of `choice_keys`; an empty selection yields all zeros.
    """
    counts = group_rows.groupby('text').size()
    total = counts.sum()
    percentages = counts / total * 100 if total else counts.astype(float)
    return percentages.reindex(choice_keys, fill_value=0)


def plot_multiple_choice_countries_pdf(grouped_by_model_name, q4_json, group_by='Continent Name', color_sequence=None, legend_map=None, legend_title='Continent', results_dir='results/continent', models=None, subplot_titles=None):
    """Plot, per model, the share of each multiple-choice answer broken down by group; save as PDF.

    One subplot per model (Stable Diffusion v2.1, DALL-E 2 and DALL-E 3 by default). Each subplot
    has one bar series per group, whose heights are the percentage of that group's answers
    choosing each option.

    Args:
        grouped_by_model_name: DataFrame (not a GroupBy, despite the name) with 'model_name',
            'text' (the chosen answer key) and the `group_by` column. A `group_by` cell may be a
            list of groups (the row counts towards each of them) or a single group; missing
            values are ignored.
        q4_json: question dict; 'choices' maps answer key -> label (its order fixes the x-axis
            order) and 'text' names the output file.
        group_by: column holding each row's group(s).
        color_sequence: group -> bar colour; groups without a colour are drawn black.
        legend_map: group -> legend label; groups without a label use the group itself.
        legend_title: legend title text.
        results_dir: directory for the PDF (created if missing).
        models: 'model_name' values to plot, one subplot each, in order.
        subplot_titles: subplot titles parallel to `models` (defaults to the model names when
            `models` is given).

    Side effects: writes "<question text without '?'>_subplots.pdf" into `results_dir` and shows
    the figure. Returns None.
    """
    if color_sequence is None:
        color_sequence = {
            'Africa': '#F2545B',
            'Antarctica': '#74c2e1',
            'Asia': '#785EF0',
            'Europe': '#648FFF',
            'North America': '#078BA0',
            'Oceania': '#3DD24E',
            'South America': '#FFB000'
        }

    if legend_map is None:
        legend_map = {
            'Africa': 'Africa',
            'Asia': 'Asia',
            'Europe': 'Europe',
            'North America': 'North America',
            'Oceania': 'Oceania',
            'South America': 'South America'
        }

    if models is None:
        models = ['sd21', 'dalle2', 'dalle3']
        default_titles = ["Stable Diffusion v2.1", "DALL-E 2", "DALL-E 3"]
    else:
        models = list(models)
        default_titles = [str(model) for model in models]
    if subplot_titles is None:
        subplot_titles = default_titles

    present_groups = set()
    for entry in grouped_by_model_name[group_by]:
        present_groups.update(_as_group_list(entry))
    # Deterministic bar/legend order (iteration order of a set of strings changes between Python
    # processes): groups with a colour first, in color_sequence order, then the rest alphabetically.
    ordered_groups = [group for group in color_sequence if group in present_groups]
    ordered_groups += sorted((group for group in present_groups if group not in color_sequence), key=str)

    choice_keys = list(q4_json['choices'].keys())
    choice_labels = [q4_json['choices'][key] for key in choice_keys]

    num_models = len(models)
    fig = make_subplots(rows=1, cols=num_models, shared_yaxes=True, horizontal_spacing=0.05, subplot_titles=subplot_titles)

    for col, model_name in enumerate(models, start=1):
        model_data = grouped_by_model_name[grouped_by_model_name['model_name'] == model_name]
        group_cells = [_as_group_list(entry) for entry in model_data[group_by]]
        for group in ordered_groups:
            # Explicit bool array: an empty object-dtype mask would be treated as column labels
            in_group = np.array([group in groups for groups in group_cells], dtype=bool)
            percentages = _choice_percentages(model_data.loc[in_group], choice_keys)
            fig.add_trace(
                go.Bar(
                    x=choice_labels,
                    y=percentages.values,
                    name=legend_map.get(group, group),
                    legendgroup=str(group),
                    marker=dict(color=color_sequence.get(group, '#000000')),
                    # One legend entry per group: only the first subplot's traces are listed
                    showlegend=(col == 1),
                ),
                row=1, col=col,
            )

    # Same category order on every x-axis so bars follow the question's choice order
    fig.update_xaxes(tickvals=choice_labels, ticktext=choice_labels, categoryorder='array', categoryarray=choice_labels)
    # Framed axes on every subplot
    fig.update_xaxes(showline=True, linewidth=1, linecolor='grey', mirror=True)
    fig.update_yaxes(showline=True, linewidth=1, linecolor='grey', mirror=True)
    fig.update_yaxes(title_text='Proportion of Images (%)', row=1, col=1)
    # Shared x-axis title under the middle subplot
    fig.update_xaxes(
        title_text="VQA Responses",
        title_font=dict(family="Helvetica Neue", size=16),
        row=1,
        col=(num_models + 1) // 2,
    )

    fig.update_layout(
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.2,
            xanchor="right",
            x=1,
            bgcolor='rgba(255, 255, 255, 0.8)',
            bordercolor='grey',
            borderwidth=1,
            title=dict(
                text=legend_title,
                font=dict(family="Helvetica Neue", size=18, color='black'),
            ),
        ),
        font=dict(family="Helvetica Neue", size=18, style='normal', color='black'),
        title_font=dict(family="Helvetica Neue", size=18, color='black'),
        plot_bgcolor='white',
        barmode='group',
        width=1400,
        height=500,
        margin=dict(l=5, r=5, b=5, t=30),
    )
    # Subplot titles are layout annotations
    fig.update_annotations(font=dict(family="Helvetica Neue", size=18, color='black'))

    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
    # Path separators in the question text would otherwise be read as directories
    file_stem = q4_json['text'].replace('?', '').replace('/', '-').replace('\\', '-')
    pio.write_image(fig, os.path.join(results_dir, f"{file_stem}_subplots.pdf"), scale=6)
    fig.show()

In [55]:
# Continent-level analysis of the unspecific-prompt images: one figure per multiple-choice question
for question_json in questions:
    if question_json.get('type') != 'multiple_choice':
        continue
    prompt = format_prompt(question_json)
    question_answers = unspecified_vqa_answers_df.loc[unspecified_vqa_answers_df["prompt"] == prompt]
    question_ids = question_answers['question_id'].values
    is_asked = unspecific_vqa_questions_with_countries['question_id'].isin(question_ids)
    questions_df = unspecific_vqa_questions_with_countries.loc[is_asked]
    merged_df = question_answers.merge(
        questions_df[['question_id', 'model_name', 'image_path', 'Continent Name']],
        on='question_id',
        how='left',
    )

    plot_multiple_choice_countries_pdf(merged_df, question_json, results_dir='results/continent/unspecified')

In [111]:
# Continent-level analysis of the DALL-E 2 / SD and DALL-E 3 images: each answer file is joined
# with its own question file, then both are combined so every figure has one subplot per model.
for question_json in questions:
    if question_json.get('type') != 'multiple_choice':
        continue
    prompt = format_prompt(question_json)

    # DALL-E 2 / Stable Diffusion images
    question_answers = dalle2_sd_vqa_answers_df[dalle2_sd_vqa_answers_df["prompt"] == prompt]
    question_ids = question_answers['question_id'].values
    questions_df = dalle_2_sd_vqa_questions_with_countries[dalle_2_sd_vqa_questions_with_countries['question_id'].isin(question_ids)]
    merged_df = pd.merge(question_answers, questions_df[['question_id', 'model_name', 'image_path', 'Continent Name']], on='question_id', how='left')

    # DALL-E 3 images: filter the DALL-E 3 questions by the DALL-E 3 answer ids
    # (not by the DALL-E 2 / SD `question_ids` above)
    dalle3_question_answers = dalle3_vqa_answers_df[dalle3_vqa_answers_df["prompt"] == prompt]
    dalle3_question_ids = dalle3_question_answers['question_id'].values
    dalle3_questions_df = dalle_3_vqa_questions_with_countries[dalle_3_vqa_questions_with_countries['question_id'].isin(dalle3_question_ids)]
    dalle_3_merged_df = pd.merge(dalle3_question_answers, dalle3_questions_df[['question_id', 'model_name', 'image_path', 'Continent Name']], on='question_id', how='left')

    concat_df = pd.concat([merged_df, dalle_3_merged_df], ignore_index=True)

    # Written to the function's default directory: 'results/continent/unspecified' is where the
    # unspecific-prompt cell writes identically named PDFs, which this cell used to overwrite.
    # NOTE(review): confirm 'results/continent' is the intended location for these figures.
    plot_multiple_choice_countries_pdf(concat_df, question_json, results_dir='results/continent')

In [81]:
# Countries compared in the country-level plots
countries = [
    "Kenya",
    "Nigeria",
    "Algeria",
    "Cameroon",
    "South Africa",
    "United States of America (USA)",
]

In [None]:
# NOTE(review): this cell's source was lost -- only a stray "/" remained, which is a SyntaxError
# under Restart & Run All. Its saved output was a FileNotFoundError for 'taste_results.csv'.
# Restore the original code (and that CSV) or delete this cell.

FileNotFoundError: [Errno 2] No such file or directory: 'taste_results.csv'

In [173]:
# Bar colour for each country in the country-level plots
country_color_sequence = {
    'Kenya': '#636efb',
    'Nigeria': '#ef563b',
    'Algeria': '#41cc96',
    'Cameroon': '#ab63fb',
    'South Africa': '#f7a15b',
    'United States of America (USA)': '#47d3f3'
}
# Legend label for each country: its own name, except the USA, which gets a short label
country_name_map = {country: country for country in country_color_sequence}
country_name_map['United States of America (USA)'] = 'USA'
# Output directory for the country-level PDFs
country_results_dir = 'results/country'

In [174]:
# Country-level analysis: for every multiple-choice question, plot the answer distribution per
# selected country (see `countries`) and model.
# NOTE(review): this cell previously used `vqa_answers_df`, `vqa_questions_with_countries` and
# `plot_multiple_choice_countries_level_pdf`, none of which is defined in this notebook. It now uses
# the DALL-E 2 / SD and DALL-E 3 frames, like the continent-level cell; confirm that this is the
# intended data.
country_vqa_sources = (
    (dalle2_sd_vqa_answers_df, dalle_2_sd_vqa_questions_with_countries),
    (dalle3_vqa_answers_df, dalle_3_vqa_questions_with_countries),
)

for question_json in questions:
    if question_json.get('type') != 'multiple_choice':
        continue
    prompt = format_prompt(question_json)

    per_source_frames = []
    # Each answer file is joined only with its own question file
    for answers_df, questions_with_countries_df in country_vqa_sources:
        question_answers = answers_df[answers_df["prompt"] == prompt]
        question_ids = question_answers['question_id'].values
        questions_df = questions_with_countries_df[questions_with_countries_df['question_id'].isin(question_ids)]
        # Keep only questions about at least one selected country
        mentions_country = questions_df['country_name'].map(
            lambda names: isinstance(names, list) and any(name in names for name in countries)
        )
        country_questions = questions_df[mentions_country.astype(bool)].copy()
        # Group by the first selected country, wrapped in a one-element list because the plotting
        # function treats each group_by entry as a list of groups
        country_questions['country'] = country_questions['country_name'].map(
            lambda names: [next(name for name in names if name in countries)]
        )
        per_source_frames.append(pd.merge(
            question_answers,
            country_questions[['question_id', 'model_name', 'image_path', 'country_name', 'country']],
            on='question_id',
            how='right',
        ))

    merged_df = pd.concat(per_source_frames, ignore_index=True)
    plot_multiple_choice_countries_pdf(merged_df, question_json, group_by='country', color_sequence=country_color_sequence, legend_map=country_name_map, legend_title='Country', results_dir=country_results_dir)