In [None]:
import os
import gc

import numpy as np
import pandas as pd

import seaborn as sns
from matplotlib import pyplot as plt

# Switch into the icl_task_vectors checkout, assumed to be a sibling of this
# notebook's folder. Presumably this is so `scripts.figures` resolves via the
# current directory. Later cells also write to paths relative to this new
# working directory ('./for_fv', '../figures/...').
# NOTE(review): the path is relative to wherever the kernel was started, so the
# notebook must be launched from its own folder.
os.chdir('../icl_task_vectors')

from scripts.figures import main as figs

import sys
# Make the directory above icl_task_vectors importable, for the shared `config` module.
sys.path.append('..')
from config import (
    STAGE_NAME_LONG,  # display titles; indexed below by the model name with '_' -> '-'
    STAGE_NAME_DATA
)

In [None]:
# Identifier of the saved experiment run whose results are loaded below.
experiment_id = "2025_01_16"
# Nested results keyed by model name, then by task name. The plotting cell reads
# 'tv_dev_accruacy_by_layer', 'baseline_accuracy' and 'icl_accuracy' from each
# results[model][task] entry.
# TODO confirm the full schema against figs.load_main_results.
results = figs.load_main_results(experiment_id)
# Summary views derived from `results`. They are not referenced by the
# per-layer plotting cell.
accuracies = figs.extract_accuracies(results)
accuracies_df = figs.create_accuracies_df(results)
grouped_accuracies_df = figs.create_grouped_accuracies_df(accuracies_df)

In [None]:
# Task id -> column label used in the exported CSVs.
# Key order is significant: the plotting cell walks these keys to assign one
# subplot per task, left to right.
task_map = {
    'linguistic_antonyms': 'antonym',
    'linguistic_present_simple_past_simple': 'present-past',
    'knowledge_country_capital': 'country-capital',
    # Translation tasks (`translation_<src>_<dst>`) keep their full id as label.
    **{
        f'translation_{src}_{dst}': f'translation_{src}_{dst}'
        for src, dst in (
            ('en', 'fr'), ('en', 'it'), ('en', 'es'),
            ('fr', 'en'), ('it', 'en'), ('es', 'en'),
        )
    },
}

# Task id -> short subplot title.
task_short = dict(
    linguistic_antonyms='antonym',
    knowledge_country_capital='country-capital',
    linguistic_present_simple_past_simple='present-past',
    translation_es_en='sp-eng',
    translation_en_es='eng-sp',
    translation_en_it='eng-it',
    translation_it_en='it-eng',
    translation_fr_en='fr-eng',
    translation_en_fr='eng-fr',
)

# Series name -> hex colour.
palette = dict(
    Predicted='#44AA99',
    Top='#000000',
    Correct='#332288',
    Incorrect='#D55E00',
)

In [None]:
# Per-model task-vector figures. For every model:
# - plot task-vector accuracy by layer for each task in `task_map` (one panel
#   per task), against the 0-shot baseline and 5-shot ICL accuracy;
# - export the numbers, plus direction-averaged translation scores, as CSVs
#   for the function-vector (FV) plots.

# Translation task ids are `translation_<src>_<dst>` (cf. task_short:
# 'translation_es_en' -> 'sp-eng'), so "to English" means dst == 'en'.
TO_ENG_TASKS = ['translation_fr_en', 'translation_es_en', 'translation_it_en']
FROM_ENG_TASKS = ['translation_en_fr', 'translation_en_es', 'translation_en_it']

FV_EXPORT_DIR = './for_fv'
TV_FIGURE_DIR = '../figures/activation_patching/task_vector'

# NOTE(review): not used in this cell; kept in case later cells read it.
regular_accuracy_threshold = 0


def _add_direction_means(frame: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `frame` with 'to-eng' and 'from-eng' columns appended.

    Each new column is the row-wise mean over the three translation tasks in
    that direction. `frame` must have every id in TO_ENG_TASKS and
    FROM_ENG_TASKS as a column (raw task ids, i.e. before renaming).
    """
    return frame.assign(**{
        'to-eng': frame[TO_ENG_TASKS].mean(axis=1),
        'from-eng': frame[FROM_ENG_TASKS].mean(axis=1),
    })


os.makedirs(FV_EXPORT_DIR, exist_ok=True)
os.makedirs(TV_FIGURE_DIR, exist_ok=True)

for model_name in results:
    model_results = results[model_name]

    # Tasks may report different numbers of layers. Truncate every curve to the
    # shortest so they fit in one DataFrame and share an x-range.
    min_num_layers = min(
        len(task_res["tv_dev_accruacy_by_layer"]) for task_res in model_results.values()
    )
    all_tv_dev_accruacy_by_layer = {
        task_name: np.array(list(task_res["tv_dev_accruacy_by_layer"].values())[:min_num_layers])
        for task_name, task_res in model_results.items()
    }
    # To save for FV plotting
    df = _add_direction_means(pd.DataFrame(all_tv_dev_accruacy_by_layer)).rename(columns=task_map)
    df.to_csv(f'{FV_EXPORT_DIR}/{model_name}.csv', index=False)

    # Only the first panel shows y ticks. Sharing y keeps every panel on that
    # same scale, so curves are comparable across tasks.
    fig, axs = plt.subplots(1, len(task_map), figsize=(12, 3), sharey=True, squeeze=False)
    axs = axs[0]
    baseline_rows = {}
    for i, (task, ax) in enumerate(zip(task_map, axs)):
        baseline = model_results[task]['baseline_accuracy']
        icl = model_results[task]['icl_accuracy']
        baseline_rows[task] = [baseline, icl]

        sns.lineplot(df[task_map[task]], ax=ax, color='black', label='Task Vector', legend=False)
        ax.set_title(task_short[task])
        ax.axhline(baseline, color='#D55E00', linestyle='--', label='0-shot')
        ax.axhline(icl, color='#332288', linestyle=':', label='5-shot')
        ax.set_ylabel('')
        if i > 0:
            # Per-panel tick styling. Calling set_yticks([]) here would clear
            # the shared locator for every panel.
            ax.tick_params(axis='y', left=False, labelleft=False)
        ax.set_xticks(np.linspace(0, min_num_layers, 5))
        ax.set_xlim(0, min_num_layers)

    # Every panel carries the same three labelled artists; take them from the last one.
    handles, labels = axs[-1].get_legend_handles_labels()
    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.92), ncol=3, framealpha=0, fontsize=12)

    # One column per task, with rows 'baseline' (0-shot) then 'icl' (5-shot).
    # The CSV is written without row labels, so row order is significant.
    baselines = _add_direction_means(
        pd.DataFrame(baseline_rows, index=['baseline', 'icl'])
    ).rename(columns=task_map)
    baselines.to_csv(f'{FV_EXPORT_DIR}/{model_name}_baselines.csv', index=False)

    # Because y is shared, these ticks apply to every panel. set_yticks also
    # expands the shared y-limits to include [0, 1] if needed.
    axs[0].set_yticks(np.linspace(0, 1, 6))
    fig.suptitle(f"{STAGE_NAME_LONG[model_name.replace('_', '-')]}\n", fontsize=15)
    fig.supxlabel('Layer', y=0.1)
    fig.supylabel('ICL Accuracy', y=0.47)
    fig.tight_layout()
    fig.savefig(f'{TV_FIGURE_DIR}/tv_{model_name}.pdf', bbox_inches='tight', dpi=300)
    plt.close(fig)
    del fig, axs, df
    gc.collect()