In [4]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Function to create bar plots for specified CSV files
def create_plots(llm_folder, llm_names, id_dataset_subjects, ood_dataset_subjects):
    # Define colors
    colors = plt.cm.tab10(np.linspace(0, 1, 4))  # Adjust color map for more subtle colors
    print(colors)
    colors[0] = np.array([76, 146, 195, 255])/255
    colors[1] = np.array([255, 152, 62, 255])/255
    colors[2] = np.array([86, 179, 86, 255])/255
    colors[3] = np.array([222, 83, 83, 255])/255
    print(colors)

    # Iterate through LLM folders
    for llm_name in llm_names:
        llm_dir = os.path.join(llm_folder, llm_name)

        # Iterate through 'varying_option' and 'varying_position' folders
        for folder_name in ['varying_option', 'varying_position']:
            folder_dir = os.path.join(llm_dir, folder_name)

            # Check if the folder exists
            if os.path.exists(folder_dir):
                # Iterate through specified CSV files
                for csv_file in ['incorrect_likelihoods_ID.csv', 'incorrect_likelihoods_OOD.csv',
                                 'ppa_scores_ID.csv', 'ppa_scores_OOD.csv',
                                 'recall_imbalance_ID.csv', 'recall_imbalance_OOD.csv']:
                    csv_path = os.path.join(folder_dir, csv_file)

                    # Check if the CSV file exists
                    if os.path.exists(csv_path):

                      # Determine dataset subjects based on file name endings
                        # if folder_name == 'varying_option':
                        #     x_labels = ['A', 'B', 'C', 'D']
                        # else:
                        #     x_labels = ['0', '1', '2', '3']

                        # Filter rows based on specified dataset subjects
                        if csv_file.endswith('_ID.csv'):
                            dataset_subjects = id_dataset_subjects
                        elif csv_file.endswith('_OOD.csv'):
                            dataset_subjects = ood_dataset_subjects
                        else:
                            print(f"Invalid file format: '{csv_file}'")
                            continue

                        # Read the CSV file
                        df = pd.read_csv(csv_path)

                        df_filtered = df[df['Dataset'].isin(dataset_subjects)]

                        # Plotting
                        plt.figure(figsize=(12, 6))
                        ax = df_filtered.plot(x='Dataset', kind='bar', color=colors)
                        # plt.xlabel('Dataset Subject')
                        plt.ylabel(f'{csv_file[:-4]}')
                        plt.title(f'Plot for {csv_file[:-4]} in {folder_name} for {llm_name}')
                        plt.xticks(rotation=45, ha='right')
                        plt.tight_layout()

                        # # Customize x-axis labels and ticks
                        # num_subjects = len(dataset_subjects)
                        # num_labels = len(x_labels)
                        # x_ticks = np.arange(num_subjects*num_labels) + 0.5
                        # print(num_subjects,num_labels, x_ticks)
                        # ax.set_xticks(x_ticks)
                        # ax.set_xticklabels([f"{label}\n{subject}" for subject in dataset_subjects for label in x_labels], rotation=45, ha='right')

                        # Customize x-axis labels and ticks
                        # num_subjects = len(dataset_subjects)
                        # num_labels = len(x_labels)
                        # x_ticks = np.arange(num_subjects) * num_labels + num_labels / 2
                        # ax.set_xticks(x_ticks)
                        # ax.set_xticklabels([[[label for label in x_labels], subject] for subject in dataset_subjects], rotation=45, ha='right')


                        # Save the plot inside the folder
                        plot_file = os.path.join(folder_dir, f'{csv_file[:-4]}_plot.png')
                        plt.savefig(plot_file)
                        plt.close()

                        print(f"Plot saved as '{plot_file}'")


# Specify the LLM folders, dataset subjects for ID and OOD plots, and path to the main folder
llm_folders = ['gemma-2b-it']  # Update with LLM folder names
#, 'gemma-7b-it', 'gemma-2b-it', 'llama-2-7b', 'llama-2-13b', 'mistral-7b', 'mistral-7b-instruct', 'llama-3-8b', 'vicuna-7b', 'vicuna-13b', 'llama-7b', 'gemma-7b'
id_dataset_subjects = ['anatomy', 'astronomy', 'business_ethics', 'clinical_knowledge', 'college_biology']  # Update with ID dataset subjects
# 'college_chemistry', 'college_computer_science', 'college_mathematics', 'college_physics', 'computer_security', 'econometrics', 'electrical_engineering', 'elementary_mathematics', 'formal_logic', 'global_facts', 'high_school_biology', 'high_school_chemistry', 'high_school_computer_science', 'high_school_european_history', 'high_school_geography', 'high_school_government_and_politics', 'high_school_macroeconomics', 'high_school_microeconomics', 'high_school_physics', 'high_school_psychology', 'high_school_statistics', 'high_school_us_history', 'high_school_world_history', 'human_aging', 'human_sexuality', 'international_law', 'jurisprudence', 'logical_fallacies', 'machine_learning', 'management', 'marketing', 'medical_genetics', 'miscellaneous', 'moral_disputes', 'moral_scenarios', 'nutrition', 'professional_accounting', 'professional_medicine', 'professional_psychology', 'public_relations', 'security_studies', 'sociology', 'us_foreign_policy', 'virology', 'world_religions'
ood_dataset_subjects = ['professional_law', 'prehistory', 'philosophy', 'high_school_mathematics', 'conceptual_physics', 'college_medicine', 'abstract_algebra']  # Update with OOD dataset subjects
llm_folder = '/content/drive/MyDrive/685 Project/mmlu 03'  # Update with the path to your LLM folder

# Call the function to create plots
create_plots(llm_folder, llm_folders, id_dataset_subjects, ood_dataset_subjects)

[[0.12156863 0.46666667 0.70588235 1.        ]
 [0.83921569 0.15294118 0.15686275 1.        ]
 [0.89019608 0.46666667 0.76078431 1.        ]
 [0.09019608 0.74509804 0.81176471 1.        ]]
[[0.29803922 0.57254902 0.76470588 1.        ]
 [1.         0.59607843 0.24313725 1.        ]
 [0.3372549  0.70196078 0.3372549  1.        ]
 [0.87058824 0.3254902  0.3254902  1.        ]]
Plot saved as '/content/drive/MyDrive/685 Project/mmlu 03/gemma-2b-it/varying_option/incorrect_likelihoods_ID_plot.png'
Plot saved as '/content/drive/MyDrive/685 Project/mmlu 03/gemma-2b-it/varying_option/incorrect_likelihoods_OOD_plot.png'
Plot saved as '/content/drive/MyDrive/685 Project/mmlu 03/gemma-2b-it/varying_option/ppa_scores_ID_plot.png'
Plot saved as '/content/drive/MyDrive/685 Project/mmlu 03/gemma-2b-it/varying_option/ppa_scores_OOD_plot.png'
Plot saved as '/content/drive/MyDrive/685 Project/mmlu 03/gemma-2b-it/varying_option/recall_imbalance_ID_plot.png'
Plot saved as '/content/drive/MyDrive/685 Proje

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

<Figure size 1200x600 with 0 Axes>

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive
