# Put all figures from experiments into a large grid

In [None]:
import os
import gc
import sys
import json

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.lines as mlines
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = 'Times New Roman'

import numpy as np
import pandas as pd

project_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(project_dir)

from function_vectors.src.utils import custom_utils as pqd
from config import (
    MODEL_FP_MAP,
    HF_NAME_MAP,
    STAGE_NAME_LONG,
    STAGE_NAME,
    STAGE_NAME_DATA
)

def get_baseline(model_name: str, data_name: str):
    """Load the saved n-shot baseline results for one model/task pair.

    Reads ``../function_vectors/results/{model_name}/{data_name}/baseline_n_shots.json``
    (resolved relative to the current working directory).

    Args:
        model_name: Name of the model sub-directory under the results folder.
        data_name: Name of the task sub-directory under the model folder.

    Returns:
        The parsed JSON content, or ``None`` when the file is missing,
        unreadable, or does not contain valid JSON. The error is printed in
        that case so callers can simply skip the combination.
    """
    file_path = f"../function_vectors/results/{model_name}/{data_name}/baseline_n_shots.json"
    try:
        with open(file_path, 'r') as json_file:
            return json.load(json_file)
    except (OSError, ValueError) as e:
        # OSError covers a missing/unreadable file; ValueError covers malformed
        # JSON (json.JSONDecodeError) and text decoding failures.
        print(f"An error occurred: {e}")
        return None

In [None]:
# Display the imported model-name mapping in the cell output for reference
# (keys presumably are short model names, values Hugging Face repo ids).
HF_NAME_MAP

In [None]:
# Subset of HF_NAME_MAP that skips its first 15 entries (insertion order).
tmp = dict(list(HF_NAME_MAP.items())[15:])

In [None]:
# Single-model override of `tmp` for quick runs; it only takes effect in loops
# that iterate over `tmp.items()` instead of `HF_NAME_MAP.items()`.
# tmp = {'gptj-6b': 'EleutherAI/gpt-j-6b',}
tmp = {'gemma2-27b': 'google/gemma-2-27b',}

In [None]:
# Takes ~60 min for all tasks and models
#
# Appendix figures: for every (model, task) pair that has a saved baseline,
# draw a len(LAMBDAS) x len(N_HEADS) grid of panels (rows: function-vector
# strength lambda, columns: number of heads) and save the grid as one PDF.
save_location = 'param_search_appendix'

# These settings do not depend on the model, so they are defined once.
N_TEST = 50
TASKS = [
    'antonym',
    'english-french',
    'english-german',
    'english-spanish',
    'french-english',
    'german-english',
    'spanish-english',
    'present-past',
    'country-capital',
]
palette = {
    '0_shot_FV': '#000000',
    '0_shot': '#D55E00',
    '1_shot': '#44AA99',  # kept for when the 1-shot baseline is drawn again
    '5_shot': '#332288',
}

# for model_name, hf_name_full in tmp.items():
for model_name, hf_name_full in HF_NAME_MAP.items():
    # Set configs
    model_fp = MODEL_FP_MAP[model_name]
    hf_name = hf_name_full.split('/')[1]
    # Use a context manager so the config file handle is closed right away.
    with open(os.path.join(model_fp, 'config.json')) as config_file:
        model_config = json.load(config_file)
    # The gptj-6b config stores the layer count under a different key.
    if model_name in ['gptj-6b']:
        n_layers = model_config['n_layer']
    else:
        n_layers = model_config['num_hidden_layers']
    # The text after the first '-' in the model name is parsed as '<number><unit>'
    # (e.g. '6b', '27b'); models above 27 use a smaller search grid.
    if float(model_name.split('-')[1][:-1]) > 27:
        LAMBDAS = [1, 4, 16]
        N_HEADS = [2, 64, 512, 1024]
    else:
        LAMBDAS = [0.5, 1, 2, 4, 8, 16, 32, 64]
        N_HEADS = [2, 16, 32, 64, 128, 256, 512]
    layer_keys = [str(i) for i in range(n_layers)]

    for data_name in TASKS:
        # Get baseline performance
        baseline_n_shots = get_baseline(model_name, data_name)

        # Make sure the data model combination exists
        if baseline_n_shots is None:
            continue
        save_dp = f'../figures/activation_patching/{save_location}/{model_name}'
        os.makedirs(save_dp, exist_ok=True)
        save_path = os.path.join(save_dp, f'{model_name}_{data_name}_full.pdf')
        print(model_name, hf_name, data_name)

        fig, axs = plt.subplots(len(LAMBDAS), len(N_HEADS), figsize=(10,10))

        # Baselines are constant across layers: repeat the scalar for every
        # layer key so they plot as horizontal reference lines.
        zero_shot = dict.fromkeys(layer_keys, baseline_n_shots['0'])
        five_shot = dict.fromkeys(layer_keys, baseline_n_shots['5'])

        # Directory holding the per-(n_heads, lambda) patching results.
        dp = f"../function_vectors/results/{model_name}/{data_name}/0_shot_w_FV"

        for row_id, lda in enumerate(LAMBDAS):
            for col_id, n_head in enumerate(N_HEADS):
                ax = axs[row_id][col_id]
                # clean_numbers returns the two values in the form used in the
                # saved file names.
                n_head_clean, lda_clean = pqd.clean_numbers(n_head, lda)
                # Get results from Causal Indirect Effect
                saved_fname = f'{hf_name}_{data_name}_{n_head_clean}-top-heads_{lda_clean}-lambda_{N_TEST}-test-samples'
                results_fp = os.path.join(dp, f"{saved_fname}.json")
                # A missing results file leaves the panel empty but formatted.
                if os.path.isfile(results_fp):
                    with open(results_fp) as results_file:
                        zero_shot_fv = json.load(results_file)
                    to_plot = pd.DataFrame({
                        '0_shot_FV': zero_shot_fv,
                        '0_shot': zero_shot,
                        '5_shot': five_shot,
                    })
                    # Layer keys are strings; convert and shift to 1-based layers.
                    to_plot.index = to_plot.index.astype(int)
                    to_plot.index += 1
                    sns.lineplot(to_plot, ax=ax, legend=False, linewidth=2, palette=palette)
                # plot formatting
                if col_id == 0:
                    # Row label (lambda value) on the left-most column only.
                    ax.set_ylabel(f'{lda}', fontsize=25, rotation=0, labelpad=20)
                    ax.yaxis.label.set_position((0, 0.33))
                # yticks: only the right-most column shows them
                ax.yaxis.tick_right()
                ax.set_yticks([0, 1])
                if col_id != len(N_HEADS) - 1:
                    ax.set_yticklabels([])
                    ax.set_yticks([])

                if row_id == 0:
                    # Column header (number of heads) on the top row only.
                    ax.set_title(f'{n_head}', fontsize=25, rotation=0)
                # xticks: only the bottom row shows them
                ax.set_xticks([1, n_layers//2, n_layers])
                if row_id != len(LAMBDAS) - 1:
                    ax.set_xticks([])

                ax.tick_params(axis='x', labelsize=20)
                ax.tick_params(axis='y', labelsize=17)

        # NOTE(review): these handles hard-code '-', '--' and '-.', while the
        # plotted dash patterns are chosen by seaborn (no `dashes=` is passed).
        # Confirm the 5-shot entry still matches the drawn line now that the
        # 1-shot series is not plotted.
        legend_labels = [
            mlines.Line2D([0], [0], color=palette['0_shot_FV'], linestyle='-', label='0-shot with Function Vector'),
            mlines.Line2D([0], [0], color=palette['0_shot'], linestyle='--', label='0-shot'),
            mlines.Line2D([0], [0], color=palette['5_shot'], linestyle='-.', label='5-shot'),
        ]
        # Add the customized legend
        fig.legend(
            handles=legend_labels,
            loc='upper center', bbox_to_anchor=(0.5, 1.01),
            ncol=len(legend_labels), fontsize=20, framealpha=0,
        )
        # Blank two-line suptitle reserves vertical space for the text below.
        fig.suptitle(' \n ', fontsize=35)
        fig.text(0.55, .90, r'Number of Heads ($\mathcal{A}_{n}$)', ha='center', fontsize=30)
        fig.text(0.55, 1.02, f'{STAGE_NAME_LONG[model_name]} | {STAGE_NAME_DATA[data_name]}', ha='center', fontsize=30)
        fig.supylabel(r'Function Vector Strength ($\lambda$)', fontsize=30)
        fig.supxlabel('Activation Patching Layer', fontsize=30)

        fig.tight_layout()

        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        # zero_shot_fv / to_plot are deliberately not deleted here: they are
        # only bound when at least one results file was found, so deleting
        # them unconditionally raises NameError for tasks without results.
        del fig, axs, legend_labels
        gc.collect()

In [None]:
from matplotlib.ticker import MultipleLocator

# Main-paper figure: same grid layout as the appendix figures, but for a
# hand-picked model/task and a reduced (lambda, n_heads) grid, with larger
# fonts and thicker lines.
save_location = 'param_search_main'
tmp = {'mistral3-7b': 'mistralai/Mistral-7B-v0.3',}
# tmp = {'gptj-6b': 'EleutherAI/gpt-j-6b',}

N_TEST = 50
TASKS = [
    # 'antonym',
    # 'english-french',
    # 'english-german',
    # 'english-spanish',
    # 'french-english',
    # 'german-english',
    # 'spanish-english',
    # 'present-past',
    'country-capital',
]
palette = {
    '0_shot_FV': '#000000',
    '0_shot': '#D55E00',
    '1_shot': '#44AA99',  # kept for when the 1-shot baseline is drawn again
    '5_shot': '#332288',
}

# for model_name, hf_name_full in HF_NAME_MAP.items():
for model_name, hf_name_full in tmp.items():
    # Set configs
    model_fp = MODEL_FP_MAP[model_name]
    hf_name = hf_name_full.split('/')[1]
    # Use a context manager so the config file handle is closed right away.
    with open(os.path.join(model_fp, 'config.json')) as config_file:
        model_config = json.load(config_file)
    # The gptj-6b config stores the layer count under a different key.
    if model_name in ['gptj-6b']:
        n_layers = model_config['n_layer']
    else:
        n_layers = model_config['num_hidden_layers']
    # The text after the first '-' in the model name is parsed as '<number><unit>'
    # (e.g. '7b'); models above 27 use a different search grid.
    if float(model_name.split('-')[1][:-1]) > 27:
        LAMBDAS = [1, 4, 16]
        N_HEADS = [2, 64, 512, 1024]
    else:
        LAMBDAS = [0.5, 1, 2] # 3 here
        N_HEADS = [64, 128, 256, 512] # 4 here
    layer_keys = [str(i) for i in range(n_layers)]

    for data_name in TASKS:
        # Get baseline performance
        baseline_n_shots = get_baseline(model_name, data_name)

        # Make sure the data model combination exists
        if baseline_n_shots is None:
            continue
        save_dp = f'../figures/activation_patching/{save_location}/{model_name}'
        os.makedirs(save_dp, exist_ok=True)
        save_path = os.path.join(save_dp, f'{model_name}_{data_name}.pdf')
        print(model_name, hf_name, data_name)

        fig, axs = plt.subplots(len(LAMBDAS), len(N_HEADS), figsize=(10,6))

        # Baselines are constant across layers: repeat the scalar for every
        # layer key so they plot as horizontal reference lines.
        zero_shot = dict.fromkeys(layer_keys, baseline_n_shots['0'])
        five_shot = dict.fromkeys(layer_keys, baseline_n_shots['5'])

        # Directory holding the per-(n_heads, lambda) patching results.
        dp = f"../function_vectors/results/{model_name}/{data_name}/0_shot_w_FV"

        for row_id, lda in enumerate(LAMBDAS):
            for col_id, n_head in enumerate(N_HEADS):
                ax = axs[row_id][col_id]
                # clean_numbers returns the two values in the form used in the
                # saved file names.
                n_head_clean, lda_clean = pqd.clean_numbers(n_head, lda)
                # Get results from Causal Indirect Effect
                saved_fname = f'{hf_name}_{data_name}_{n_head_clean}-top-heads_{lda_clean}-lambda_{N_TEST}-test-samples'
                results_fp = os.path.join(dp, f"{saved_fname}.json")
                # A missing results file leaves the panel empty but formatted.
                if os.path.isfile(results_fp):
                    with open(results_fp) as results_file:
                        zero_shot_fv = json.load(results_file)
                    to_plot = pd.DataFrame({
                        '0_shot_FV': zero_shot_fv,
                        '0_shot': zero_shot,
                        '5_shot': five_shot,
                    })
                    # Layer keys are strings; convert and shift to 1-based layers.
                    to_plot.index = to_plot.index.astype(int)
                    to_plot.index += 1
                    sns.lineplot(to_plot, ax=ax, legend=False, linewidth=3, palette=palette)
                # plot formatting
                ax.xaxis.set_minor_locator(MultipleLocator(2))
                if col_id == 0:
                    # Row label (lambda value); 0.5 is shortened to '.5'.
                    lda_ylabel = '.5' if lda == 0.5 else lda
                    ax.set_ylabel(f'{lda_ylabel}', fontsize=30, rotation=0, labelpad=20)
                    ax.yaxis.label.set_position((0, 0.33))
                # yticks: only the right-most column shows them
                ax.yaxis.tick_right()
                ax.set_yticks([0, 1])
                if col_id != len(N_HEADS) - 1:
                    ax.set_yticklabels([])
                    ax.set_yticks([])

                if row_id == 0:
                    # Column header (number of heads) on the top row only.
                    ax.set_title(f'{n_head}', fontsize=30, rotation=0)
                # xticks: major tick every 16 layers plus one at n_layers - 1
                # NOTE(review): the data index was shifted to 1..n_layers above,
                # while the ticks and limits here span 0..n_layers-1, so the
                # point at x = n_layers lies outside the x-limits. Confirm this
                # is intended.
                ax.set_xticks(np.append(np.arange(0, n_layers, 16), n_layers-1))
                ax.set_xlim(0, n_layers-1)
                ax.tick_params(axis='x', which='minor', length=5, width=1)
                if row_id != len(LAMBDAS) - 1:
                    # Keep the tick marks but hide labels on all but the bottom row.
                    ax.set_xticklabels([])

                ax.tick_params(axis='x', labelsize=30, length=10, width=1)
                ax.tick_params(axis='y', labelsize=25)

        # Manually create legend handles
        # NOTE(review): these handles hard-code '-', '--' and '-.', while the
        # plotted dash patterns are chosen by seaborn (no `dashes=` is passed).
        # Confirm the 5-shot entry still matches the drawn line now that the
        # 1-shot series is not plotted.
        legend_labels = [
            mlines.Line2D([0], [0], color=palette['0_shot_FV'], linestyle='-', label='0-shot with Function Vector'),
            mlines.Line2D([0], [0], color=palette['0_shot'], linestyle='--', label='0-shot'),
            mlines.Line2D([0], [0], color=palette['5_shot'], linestyle='-.', label='5-shot'),
        ]
        # Add the customized legend
        leg = fig.legend(
            handles=legend_labels,
            loc='upper center', bbox_to_anchor=(0.5, 1.095),
            ncol=2, fontsize=30, framealpha=0
        )
        # Thicker legend lines so they stay readable at the large font size.
        for line in leg.get_lines():
            line.set_linewidth(4.0)
        # Blank two-line suptitle reserves vertical space for the text below.
        fig.suptitle(' \n ', fontsize=45)
        fig.text(0.55, .8, r'Number of Heads ($\mathcal{A}_{n}$)', ha='center', fontsize=28)
        fig.supylabel(r'Function Vector Strength ($\lambda$)', fontsize=28, y=.43)
        fig.supxlabel(r'Activation Patching Layer ($\ell$)', fontsize=28, x=.54)

        fig.tight_layout()

        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        # The figure is not closed here (presumably so the notebook still
        # renders it inline after the cell finishes).
        # zero_shot_fv / to_plot are deliberately not deleted: they are only
        # bound when at least one results file was found, so deleting them
        # unconditionally raises NameError for tasks without results.
        del fig, axs, legend_labels
        gc.collect()

# Check 5-shot vs. 10-shot performance

In [None]:
# Collect the 10-, 5- and 0-shot baseline values for every model/task pair
# that has a saved baseline file; pairs without one are skipped.
TASKS = [
    'antonym',
    'english-french',
    'english-german',
    'english-spanish',
    'french-english',
    'german-english',
    'spanish-english',
    'present-past',
    'country-capital',
]
all_data = []
for model_name in HF_NAME_MAP:
    for data_name in TASKS:
        # None means the baseline could not be loaded for this pair.
        baseline_n_shots = get_baseline(model_name, data_name)
        if baseline_n_shots is not None:
            zero_shot, five_shot, ten_shot = (
                baseline_n_shots[n_shot] for n_shot in ('0', '5', '10')
            )
            # Row layout: (model, task, 10-shot, 5-shot, 0-shot)
            all_data.append((model_name, data_name, ten_shot, five_shot, zero_shot))

In [None]:
# Line plot of ICL accuracy per model for 0-, 5- and 10-shot prompting, built
# from the (model, task, 10-shot, 5-shot, 0-shot) rows in `all_data`.
palette = {
    '0-shot': '#000000',
    '10-shot': '#0072B2',
    '5-shot': '#D55E00',
}
legend_labels = [
    mlines.Line2D([0], [0], color=palette['10-shot'], linestyle='-', label='10-shot'),
    mlines.Line2D([0], [0], color=palette['5-shot'], linestyle='--', label='5-shot'),
    mlines.Line2D([0], [0], color=palette['0-shot'], linestyle=':', label='0-shot'),
]
df = pd.DataFrame(all_data, columns=['model', 'data', '10-shot', '5-shot', '0-shot'])
# Long format: one row per (model, n-shot setting, task value); the task
# column is dropped so values for the same model share an x position.
df2 = df.drop('data', axis=1)
df2 = df2.melt(id_vars='model', var_name='nshot', value_name='ICL Acc')
df2['model'] = df2['model'].map(STAGE_NAME)

# Derive the counts from the data instead of hard-coding them, so the ticks
# and the title stay correct when the set of models or tasks changes.
n_models = df2['model'].nunique()
n_tasks = df['data'].nunique()

plt.figure(figsize=(14, 4.5))
ax = sns.lineplot(data=df2, x='model', y='ICL Acc', hue='nshot', style='nshot', palette=palette)
ax.set_xlabel('')
ax.set_yticks(np.arange(0, 1.1, .2))
# One tick per model category on the x axis.
ax.set_xticks(np.arange(n_models))
ax.set_ylabel('ICL Accuracy', fontsize=25)
ax.tick_params('x', rotation=45, labelsize=15)
ax.tick_params('y', labelsize=20)
# Replace seaborn's automatic legend with the hand-built handles above.
plt.legend(handles=legend_labels, fontsize=20, framealpha=0, loc='center left', bbox_to_anchor=(0, 0.4))
plt.title(f'5 vs 10 Shot Performance Across {n_tasks} ICL Tasks', fontsize=35)
plt.tight_layout()
plt.savefig('../figures/activation_patching/5_vs_10_shot.pdf', dpi=300, bbox_inches='tight')