In [1]:
import pandas as pd
import jsonlines
import json
import os
import plotly.express as px
import seaborn as sns
sns.set_theme()
%config InlineBackend.figure_format = 'svg'
%matplotlib inline
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import pycountry_convert as pc

In [2]:
# Stretch the notebook container to the full page width and stop pandas
# from truncating long cell values when a DataFrame is displayed.
from IPython.display import HTML, display

display(HTML("<style>.container { width:100% !important; }</style>"))
pd.options.display.max_colwidth = None

In [40]:
# Read the VQA questions from a JSON Lines file into a DataFrame
vqa_questions_df = pd.read_json(
    path_or_buf='/home/kebl7383/mt23-culture-bias-wip/vqa/experiments/llava-hf/llava-v1.6-mistral-7b-hf/selected_dishes_questions.jsonl',
    lines=True,
)
print("Shape of the vqa questions:", vqa_questions_df.shape)

Shape of the vqa questions: (40208, 4)


In [44]:
def extract_model_name(file_path):
    """Return the model folder name encoded in a generated-image path.

    The path is split on ``os.sep`` and searched for a root folder named
    ``'Feedback-generated images-1'`` or, failing that,
    ``'Feedback-generated images'``. The model name is the component two
    levels below that root (``<root>/<country>/<model>/...``).

    Args:
        file_path: Image path as a string.

    Returns:
        The model component, or ``None`` when ``file_path`` is not a string,
        no root folder is present, or the path ends before the model level.
    """
    if not isinstance(file_path, str):
        # e.g. NaN from a missing DataFrame cell; nothing to parse
        return None

    path_components = file_path.split(os.sep)

    # Checked in the same order as before: the '-1' variant first.
    for root_dir in ('Feedback-generated images-1', 'Feedback-generated images'):
        if root_dir not in path_components:
            continue
        target = path_components.index(root_dir) + 2
        # Guard against truncated paths instead of raising IndexError
        if target < len(path_components):
            return path_components[target]

    return None

def extract_dish_name(file_path):
    """Return the dish folder name encoded in a generated-image path.

    The path is split on ``os.sep`` and searched for a root folder named
    ``'Feedback-generated images-1'`` or, failing that,
    ``'Feedback-generated images'``. The dish name is the component three
    levels below that root (``<root>/<country>/<model>/<dish>/...``).

    Args:
        file_path: Image path as a string.

    Returns:
        The dish component, or ``None`` when ``file_path`` is not a string,
        no root folder is present, or the path ends before the dish level.
    """
    if not isinstance(file_path, str):
        # e.g. NaN from a missing DataFrame cell; nothing to parse
        return None

    path_components = file_path.split(os.sep)

    # Checked in the same order as before: the '-1' variant first.
    for root_dir in ('Feedback-generated images-1', 'Feedback-generated images'):
        if root_dir not in path_components:
            continue
        target = path_components.index(root_dir) + 3
        # Guard against truncated paths instead of raising IndexError
        if target < len(path_components):
            return path_components[target]

    return None

def extract_country_name(file_path):
    """Return the country folder name encoded in a generated-image path.

    The path is split on ``os.sep`` and searched for a root folder named
    ``'Feedback-generated images-1'`` or, failing that,
    ``'Feedback-generated images'``. The country name is the component
    directly below that root (``<root>/<country>/...``).

    Args:
        file_path: Image path as a string.

    Returns:
        The country component, or ``None`` when ``file_path`` is not a
        string, no root folder is present, or the path ends at the root.
    """
    if not isinstance(file_path, str):
        # e.g. NaN from a missing DataFrame cell; nothing to parse
        return None

    path_components = file_path.split(os.sep)

    # Checked in the same order as before: the '-1' variant first.
    for root_dir in ('Feedback-generated images-1', 'Feedback-generated images'):
        if root_dir not in path_components:
            continue
        target = path_components.index(root_dir) + 1
        # Guard against truncated paths instead of raising IndexError
        if target < len(path_components):
            return path_components[target]

    return None

In [48]:
# Spot-check the three path extractors on one randomly drawn image path
sample_img_path = vqa_questions_df['image_path'].sample(1).iloc[0]
print(sample_img_path)

model_name, dish_name, country_name = (
    extract(sample_img_path)
    for extract in (extract_model_name, extract_dish_name, extract_country_name)
)

print("Model name:", model_name)
print("Dish name:", dish_name)
print("Country name:", country_name)

/food-bias/selected_dishes/Feedback-generated images/United States/sd21/hotdish/the United States of America_006_hotdish_03.png
Model name: sd21
Dish name: hotdish
Country name: United States


In [49]:
# Add model, dish and country columns parsed from each row's image_path
vqa_questions_df['model_name'] = vqa_questions_df['image_path'].map(extract_model_name)
vqa_questions_df['dish_name'] = vqa_questions_df['image_path'].map(extract_dish_name)
vqa_questions_df['country_name'] = vqa_questions_df['image_path'].map(extract_country_name)

In [52]:
# Preview two randomly sampled rows, including the columns derived from image_path
vqa_questions_df.sample(2)

Unnamed: 0,image_name,question_id,question,image_path,model_name,dish_name,country_name
22755,Nigeria_429_Puff-Puff_04.png,22755,Question: Is the dish placed on a table?. Choices: A: Yes B: No. Please respond with the letter corresponding to your choice. Answer:,/food-bias/selected_dishes/Feedback-generated images-1/Nigeria/sd21/Puff-Puff/Nigeria_429_Puff-Puff_04.png,sd21,Puff-Puff,Nigeria
20,Algeria_010_tahbult tmellalin_03.png,20,Question: What material is the dish or plate in the image most likely made of?. Choices: A: Ceramic B: Glass C: Metal D: Plastic E: Wood F: Paper G: Clay. Please respond with the letter corresponding to your choice. Answer:,/food-bias/selected_dishes/Feedback-generated images/Algeria/dalle3/tahbult tmellalin/Algeria_010_tahbult tmellalin_03.png,dalle3,tahbult tmellalin,Algeria


In [54]:
# Read the VQA model's responses from a JSON Lines file into a DataFrame
vqa_answers_df = pd.read_json(
    path_or_buf='/home/kebl7383/mt23-culture-bias-wip/vqa/selected_dishes_answers.jsonl',
    lines=True,
)
print("Shape of the vqa responses:", vqa_answers_df.shape)

Shape of the vqa responses: (40208, 6)


In [55]:
# Preview the first rows of the loaded model responses
vqa_answers_df.head()

Unnamed: 0,question_id,prompt,text,model_id,answer_id,metadata
0,0,Question: Is this a picture of food?. Choices: A: Yes B: No. Please respond with the letter corresponding to your choice. Answer:,A,liuhaotian/llava-v1.6-34b,0,{}
1,1,Question: Is the dish placed outdoors or indoors?. Choices: A: Outdoors B: Indoors. Please respond with the letter corresponding to your choice. Answer:,B,liuhaotian/llava-v1.6-34b,1,{}
2,2,"Question: What utensils, if any, are shown in this image?. Choices: A: Fork B: Spoon C: Knife D: Chopsticks E: No utensils shown. Please respond with the letter corresponding to your choice. Answer:",C,liuhaotian/llava-v1.6-34b,2,{}
3,3,Question: Is the dish placed on a table?. Choices: A: Yes B: No. Please respond with the letter corresponding to your choice. Answer:,A,liuhaotian/llava-v1.6-34b,3,{}
4,4,Question: What material is the dish or plate in the image most likely made of?. Choices: A: Ceramic B: Glass C: Metal D: Plastic E: Wood F: Paper G: Clay. Please respond with the letter corresponding to your choice. Answer:,E,liuhaotian/llava-v1.6-34b,4,{}


In [56]:
# Load the question definitions.
# The file is opened explicitly as UTF-8 so that parsing does not depend on
# the platform's default text encoding.
with open('/home/kebl7383/mt23-culture-bias-wip/vqa/experiments/llava-hf/llava-v1.6-mistral-7b-hf/questions.json', "r", encoding="utf-8") as f:
    questions_json = json.load(f)

# The question specs live under the top-level "questions" key
questions = questions_json["questions"]

In [57]:
def format_prompt(question):
    """Build the text prompt for a single question spec.

    Args:
        question: Mapping with a ``"type"`` and ``"text"`` entry, plus a
            ``"choices"`` mapping (answer letter -> label) for the
            ``"multiple_choice"`` and ``"list"`` types.

    Returns:
        The prompt string. ``"multiple_choice"`` and ``"list"`` questions
        share one layout that enumerates the choices; ``"free_form"``
        questions get the bare question. Any other type yields ``None``.
    """
    kind = question["type"]

    if kind == "free_form":
        return f"Question: {question['text']}. Answer:"

    if kind in ("multiple_choice", "list"):
        # Both choice-based types are rendered with the same template
        choices_str = " ".join(
            f"{key}: {value}" for key, value in question["choices"].items()
        )
        return f"Question: {question['text']}. Choices: {choices_str}. Please respond with the letter corresponding to your choice. Answer:"

    # Unrecognised question types produce no prompt
    return None

In [83]:
import os
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.io as pio

# Define the plot_multiple_choice function using Plotly
def plot_multiple_choice_countries_level_pdf(grouped_by_model_name, q4_json, group_by = 'country_name', color_sequence = {
    'Kenya': '#636efb',
    'Nigeria': '#ef563b',
    'Algeria': '#41cc96',
    'Cameroon': '#ab63fb',
    'South Africa': '#f7a15b',
    'United States of America (USA)': '#47d3f3'
}, legend_map =  {
    'Kenya': 'Kenya',
    'Nigeria': 'Nigeria',
    'Algeria': 'Algeria',
    'Cameroon': 'Cameroon',
    'South Africa': 'South Africa',
    'United States of America (USA)': 'USA'
}, model_names_dict = {
    "dalle2": "Dall-E 2",
    "dalle3": "Dall-E 3",
    "sd21": "Stable Diffusion 2.1",
}, legend_title='Country', results_dir='results/selected_regions'):
    """Plot the answer distribution of one multiple-choice question and save it as a PDF.

    One subplot is drawn per distinct ``model_name``. Inside each subplot the
    bars are grouped by the values of the ``group_by`` column, and each bar
    shows the percentage of that group's answers that picked a given choice.
    The figure is written to ``<results_dir>/<question text without '?'>_subplots.pdf``
    and then displayed.

    Args:
        grouped_by_model_name: DataFrame with a ``'model_name'`` column, the
            ``group_by`` column, and a ``'text'`` column holding the chosen
            answer key.
        q4_json: Question spec with a ``'text'`` entry (used as title and file
            name) and a ``'choices'`` mapping of answer key -> label.
        group_by: Column whose distinct values become the bar groups.
        color_sequence: Mapping of group value -> bar colour. Groups missing
            from the mapping fall back to Plotly's default colouring.
        legend_map: Mapping of group value -> legend label. Missing groups
            are labelled with their raw value.
        model_names_dict: Mapping of model id -> subplot title. Missing ids
            are titled with their raw value.
        legend_title: Title shown above the legend.
        results_dir: Output directory; created if it does not exist.

    Returns:
        None.
    """
    models = grouped_by_model_name['model_name'].unique()
    countries = grouped_by_model_name[group_by].unique()

    # Choice keys/labels do not change per trace, so compute them once.
    all_choices = list(q4_json['choices'].keys())
    choice_labels = list(q4_json['choices'].values())

    num_models = len(models)
    fig = make_subplots(
        rows=1,
        cols=num_models,
        shared_yaxes=True,
        horizontal_spacing=0.05,
        # Fall back to the raw id so an unlisted model does not raise KeyError
        subplot_titles=[model_names_dict.get(model_name, str(model_name)) for model_name in models],
    )

    for i, model_name in enumerate(models):
        model_data = grouped_by_model_name[grouped_by_model_name['model_name'] == model_name]
        for country in countries:
            country_data = model_data[model_data[group_by] == country]
            response_counts = country_data.groupby('text').size().reset_index(name='count')
            total_responses = response_counts['count'].sum()
            response_counts['proportion'] = (response_counts['count'] / total_responses) * 100

            # Ensure all choices are present, in choice order, even if their count is 0
            response_counts = response_counts.set_index('text').reindex(all_choices, fill_value=0).reset_index()

            trace = go.Bar(
                # After the reindex the rows follow all_choices, so the labels line up
                x=choice_labels,
                y=response_counts['proportion'],
                name=legend_map.get(country, country),
                legendgroup=country,
                # .get mirrors legend_map: an unlisted group gets Plotly's default colour
                marker=dict(color=color_sequence.get(country)),
                # Show each group once in the legend (from the first subplot only)
                showlegend=(i == 0)
            )
            fig.add_trace(trace, row=1, col=i+1)

        # Pin the category order so every subplot lists the choices identically
        fig.update_xaxes(
            tickvals=choice_labels,
            ticktext=choice_labels,
            categoryorder='array',
            categoryarray=choice_labels,
            row=1, col=i+1
        )

    fig.update_yaxes(title_text='Proportion of Images (%)', row=1, col=1)

    fig.update_layout(
        legend=dict(
            orientation="v",
            yanchor="top",
            y=1,
            xanchor="right",
            x=1.15,
            itemwidth=30,
            itemsizing='constant',
            traceorder='normal',
            bordercolor="grey",
            borderwidth=1
        ),
        title={
            'text': f"{q4_json['text']}",
            'y': 0.9,
            'x': 0.5,
            'xanchor': 'center',
            'yanchor': 'top'
        },
        font=dict(
            family="Helvetica Neue",
            size=12,
        ),
        title_font=dict(
            family="Helvetica Neue",
            size=18,
            color='black'
        ),
        plot_bgcolor='white',
        legend_title_text=legend_title,
        barmode='group',
        width=1200,
        height=400
    )

    # Draw a grey frame around every subplot (all x and y axes of the figure)
    frame_style = dict(showline=True, linewidth=2, linecolor='grey', mirror=True)
    fig.update_xaxes(**frame_style)
    fig.update_yaxes(**frame_style)

    # Save the plot with separate subplots for each model.
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(results_dir, exist_ok=True)

    pio.write_image(fig, os.path.join(results_dir, f"{q4_json['text'].replace('?', '')}_subplots.pdf"), scale=6)
    fig.show()

In [84]:
# Plot configuration for the country-level analysis.

# Bar colour for each country label
country_color_sequence = dict([
    ('Kenya', '#636efb'),
    ('Nigeria', '#ef563b'),
    ('Algeria', '#41cc96'),
    ('Cameroon', '#ab63fb'),
    ('South Africa', '#f7a15b'),
    ('United States', '#47d3f3'),
])

# Legend label for each country; every country is shown under its own name
country_name_map = {country: country for country in country_color_sequence}

# Directory the country-level figures are written to
country_results_dir = 'results/selected_regions'

# Display name for each image-generation model id
model_names_dict = dict(
    dalle2="Dall-E 2",
    dalle3="Dall-E 3",
    sd21="Stable Diffusion 2.1",
)

In [85]:
# Country-level analysis: for every multiple-choice question, collect the
# model's answers, attach the image metadata, and plot the answer
# distribution per country and image-generation model.
for question_json in questions:
    # Only multiple-choice questions are plotted here
    if question_json.get('type') != 'multiple_choice':
        continue

    # Answers are tied to a question through the exact prompt text
    prompt = format_prompt(question_json)
    question_answers = vqa_answers_df.loc[vqa_answers_df["prompt"] == prompt]
    question_ids = question_answers['question_id'].values
    questions_df = vqa_questions_df.loc[vqa_questions_df['question_id'].isin(question_ids)]

    # Left join keeps every answer row and adds model / image / country columns
    merged_df = question_answers.merge(
        questions_df[['question_id', 'model_name', 'image_path', 'country_name']],
        on='question_id',
        how='left',
    )

    plot_multiple_choice_countries_level_pdf(
        merged_df,
        question_json,
        group_by='country_name',
        color_sequence=country_color_sequence,
        legend_map=country_name_map,
        legend_title='Country',
        results_dir=country_results_dir,
    )