# Quantify differences between models and across tasks with different parameters

In [None]:
import os
import sys
import json

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.ticker import ScalarFormatter
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = 'Times New Roman'
import seaborn as sns

import numpy as np
import pandas as pd
from collections import defaultdict

sys.path.append('..')
from function_vectors.src.utils import custom_utils as pqd
from config import IMPLEMENTED_MODELS, MODEL_FP_MAP, HF_NAME_MAP, STAGE_NAME

def get_baseline(model_name, data_name):
    """Load the n-shot baseline results for one model/task pair.

    Reads ``../function_vectors/results/{model_name}/{data_name}/baseline_n_shots.json``
    relative to the current working directory.

    Args:
        model_name: Name of the model directory under the results folder.
        data_name: Name of the task directory under the model folder.

    Returns:
        The parsed JSON content, or ``None`` when the file is missing,
        unreadable, or does not contain valid JSON, so that callers can
        skip the pair.
    """
    file_path = f"../function_vectors/results/{model_name}/{data_name}/baseline_n_shots.json"
    try:
        with open(file_path, 'r') as json_file:
            return json.load(json_file)
    except (OSError, ValueError):
        # OSError: missing or unreadable file. ValueError: malformed JSON
        # (json.JSONDecodeError) or undecodable bytes (UnicodeDecodeError).
        # Anything else is a programming error and should surface.
        return None

# Tasks analysed below. Each name is used as the {data_name} directory under
# ../function_vectors/results/{model_name}/ and inside the result file names.
TASKS = [
    'antonym',
    'english-french',
    'english-german',
    'english-spanish',
    'french-english',
    'german-english',
    'spanish-english',
    'present-past',
    'country-capital'
]
# Interpolated into the result file names as the "{N_TEST}-test-samples" suffix.
N_TEST = 50

# To plot across heads

In [None]:
# Sweep over the number of top heads. For every (task, model) pair and every
# head count, keep the best mean / peak performance recovery reached over the
# lambda sweep. The outer loop runs twice: first with the default parameters,
# then with the full grid search, so `all_subplots_data` ends up as
# [default means, default peaks, grid-search means, grid-search peaks].
all_subplots_data = []
all_recovery = []
for default_params in [True, False]:
    all_avgs, all_peaks, all_scores, peak_layer = [], [], [], []
    # Per (task, model) flags: does the best function-vector accuracy beat the
    # zero-shot baseline ('1'), or reach 100/90/75/50% of the 5-shot baseline?
    recovery = {'1': [], '5': [], '5_90': [], '5_75': [], '5_50': []}
    for data_name in TASKS:
        for model_name, hf_name_full in HF_NAME_MAP.items():
            # Set configs
            model_fp = MODEL_FP_MAP[model_name]
            hf_name = hf_name_full.split('/')[1]
            with open(os.path.join(model_fp, 'config.json')) as config_file:
                model_config = json.load(config_file)
            if model_name in ['gptj-6b']:
                n_layers = model_config['n_layer']
            else:
                n_layers = model_config['num_hidden_layers']
            # The size is parsed from the model name, e.g. 'llama2-70b' -> 70.0;
            # models above 27 use a coarser grid.
            is_large = float(model_name.split('-')[1][:-1]) > 27
            if default_params:
                LAMBDAS = [1]
                N_HEADS = [2, 64] if is_large else [2, 16]
            elif is_large:
                LAMBDAS = [1, 4, 16]
                N_HEADS = [2, 64, 512, 1024]
            else:
                LAMBDAS = [0.5, 1, 2, 4, 8, 16, 32, 64]
                N_HEADS = [2, 16, 32, 64, 128, 256, 512]
            # Get baseline performance
            baseline_n_shots = get_baseline(model_name, data_name)
            if baseline_n_shots is None:
                continue
            zero_shot = baseline_n_shots['0']
            one_shot = baseline_n_shots['1']
            five_shot = baseline_n_shots['5']
            by_layer, avgs, peaks, scores, peak_layer = [defaultdict(list) for _ in range(5)]
            # Best raw (un-normalised) accuracy of every result file found
            max_perf_recovery = []
            for n_head in N_HEADS:
                for lda in LAMBDAS:
                    n_head_clean, lda_clean = pqd.clean_numbers(n_head, lda)
                    # Get results from Causal Indirect Effect
                    saved_fname = f'{hf_name}_{data_name}_{n_head_clean}-top-heads_{lda_clean}-lambda_{N_TEST}-test-samples'
                    dp = f"../function_vectors/results/{model_name}/{data_name}/0_shot_w_FV"
                    results_fp = os.path.join(dp, f"{saved_fname}.json")
                    if not os.path.isfile(results_fp):
                        continue
                    with open(results_fp) as results_file:
                        zero_shot_fv = json.load(results_file)
                    layer_values = list(zero_shot_fv.values())
                    # Accuracy of each entry relative to the 5-shot baseline
                    recovery_per_layer = np.array(layer_values)/five_shot
                    by_layer[n_head].append(recovery_per_layer)
                    avgs[n_head].append(np.nanmean(recovery_per_layer)) # AUC type score
                    peaks[n_head].append(np.nanmax(recovery_per_layer)) # peak performance recovery
                    scores[n_head].append(np.nanmean(recovery_per_layer)*np.nanmax(recovery_per_layer)) # AUC*peak
                    max_perf_recovery.append(np.nanmax(layer_values))
                # Collapse the lambda sweep: keep the best value per head count.
                # Head counts without any result file get no entry at all.
                if by_layer[n_head]:
                    avgs[n_head] = np.max(avgs[n_head])
                    peaks[n_head] = np.max(peaks[n_head])
                    scores[n_head] = np.max(scores[n_head])
            # NaN when no result file exists for this pair; every comparison
            # below is then False instead of np.nanmax raising on an empty list.
            best_perf = np.nanmax(max_perf_recovery) if max_perf_recovery else np.nan
            recovery['1'].append(best_perf > zero_shot)
            recovery['5'].append(best_perf >= five_shot)
            recovery['5_90'].append(best_perf >= 0.90*five_shot)
            recovery['5_75'].append(best_perf >= 0.75*five_shot)
            recovery['5_50'].append(best_perf >= 0.50*five_shot)

            # NOTE(review): `all_scores` stores the mean recovery (`avgs`), not
            # the AUC*peak values in `scores` -- confirm this is intended.
            fv_df = pd.DataFrame([avgs])
            fv_df.insert(0, 'data_name', data_name)
            fv_df.insert(0, 'model_name', model_name)
            all_scores.append(fv_df)

            fv_df = pd.DataFrame([peaks])
            fv_df.insert(0, 'data_name', data_name)
            fv_df.insert(0, 'model_name', model_name)
            all_peaks.append(fv_df)
    all_subplots_data.append(all_scores)
    all_subplots_data.append(all_peaks)
    all_recovery.append(recovery)

# To plot across lambdas or layers
- Expected warning: `RuntimeWarning: All-NaN slice encountered`, raised by `all_layers.append(np.nanmax(np.array(list(by_layer2.values())), axis=0))`.

In [None]:
# Sweep over lambda (function-vector strength). For every (model, task) pair and
# every lambda, keep the best mean / peak performance recovery over the head
# counts, plus the best recovery at each position of the per-layer result
# (`all_layers`). The outer loop runs twice: default parameters first, then the
# full grid search, so `all_subplots_data` ends up as
# [default means, default peaks, grid-search means, grid-search peaks].
N_TEST = 50
all_subplots_data = []
all_recovery = []
for default_params in [True, False]:
    all_avgs, all_peaks, all_scores, peak_layer, all_layers, model_names_l, tasks_l = [], [], [], [], [], [], []
    n_layers_l = {}
    # Per (model, task) flags: does the best function-vector accuracy beat the
    # zero-shot baseline ('1'), or reach 100/90/75/50% of the 5-shot baseline?
    recovery = {'1': [], '5': [], '5_90': [], '5_75': [], '5_50': []}
    for model_name, hf_name_full in HF_NAME_MAP.items():
        # Set configs
        model_fp = MODEL_FP_MAP[model_name]
        hf_name = hf_name_full.split('/')[1]
        with open(os.path.join(model_fp, 'config.json')) as config_file:
            model_config = json.load(config_file)
        if model_name in ['gptj-6b']:
            n_layers = model_config['n_layer']
        else:
            n_layers = model_config['num_hidden_layers']
        # The size is parsed from the model name, e.g. 'llama2-70b' -> 70.0;
        # models above 27 use a coarser grid.
        is_large = float(model_name.split('-')[1][:-1]) > 27
        if default_params:
            LAMBDAS = [1]
            N_HEADS = [2, 64] if is_large else [2, 16, 32]
        elif is_large:
            LAMBDAS = [1, 4, 16]
            N_HEADS = [2, 64, 512, 1024]
        else:
            LAMBDAS = [0.5, 1, 2, 4, 8, 16, 32, 64]
            N_HEADS = [2, 16, 32, 64, 128, 256, 512]
        n_layers_l[model_name] = n_layers
        for data_name in TASKS:
            # Get baseline performance
            baseline_n_shots = get_baseline(model_name, data_name)
            if baseline_n_shots is None:
                continue
            zero_shot = baseline_n_shots['0']
            one_shot = baseline_n_shots['1']
            five_shot = baseline_n_shots['5']
            by_layer, by_layer2, avgs, peaks, scores, peak_layer = [defaultdict(list) for _ in range(6)]
            # Best raw (un-normalised) accuracy of every result file found
            max_perf_recovery = []
            for lda in LAMBDAS:
                for n_head in N_HEADS:
                    n_head_clean, lda_clean = pqd.clean_numbers(n_head, lda)
                    # Get results from Causal Indirect Effect
                    saved_fname = f'{hf_name}_{data_name}_{n_head_clean}-top-heads_{lda_clean}-lambda_{N_TEST}-test-samples'
                    dp = f"../function_vectors/results/{model_name}/{data_name}/0_shot_w_FV"
                    results_fp = os.path.join(dp, f"{saved_fname}.json")
                    if not os.path.isfile(results_fp):
                        continue
                    with open(results_fp) as results_file:
                        zero_shot_fv = json.load(results_file)
                    layer_values = list(zero_shot_fv.values())
                    # Accuracy of each entry relative to the 5-shot baseline
                    recovery_per_layer = np.array(layer_values)/five_shot
                    by_layer[lda].append(recovery_per_layer)
                    by_layer2[lda].append(np.array(recovery_per_layer))
                    avgs[lda].append(np.nanmean(recovery_per_layer)) # AUC type score
                    peaks[lda].append(np.nanmax(recovery_per_layer)) # peak performance recovery
                    scores[lda].append(np.nanmean(recovery_per_layer)*np.nanmax(recovery_per_layer)) # AUC*peak
                    max_perf_recovery.append(np.nanmax(layer_values))
                # Collapse the head sweep: keep the best value per lambda
                if by_layer[lda]:
                    stacked = np.array(by_layer[lda])
                    avgs[lda] = np.max(avgs[lda])
                    # get peak layer if it is at least better than 1 shot
                    # NOTE(review): `stacked` is normalised by the 5-shot baseline
                    # while `one_shot` is a raw baseline value -- confirm that
                    # mixing the two scales here is intended.
                    use_only_gt_one_shot = stacked - one_shot > -0.05
                    gt_one_shot = np.where(use_only_gt_one_shot, stacked, -1)
                    # The sum equals -size only when every entry was masked to -1
                    if gt_one_shot.sum() > -gt_one_shot.size:
                        peak_layer[lda] = np.argmax(np.max(gt_one_shot, axis=0)) + 1
                    else:
                        peak_layer[lda] = np.nan
                    peaks[lda] = np.max(peaks[lda])
                    scores[lda] = np.max(scores[lda])
                    # Best recovery per position over the head counts, NaN-padded
                    # to a fixed width of 200 so that all models can be stacked.
                    layer_perf_lda_avg = np.array(by_layer2[lda]).max(axis=0)
                    by_layer2[lda] = np.pad(layer_perf_lda_avg, (0, 200 - layer_perf_lda_avg.size), 'constant', constant_values=np.nan)
                else:
                    # No result file for this lambda: contribute an all-NaN row
                    by_layer2[lda] = np.full(200, np.nan)
            all_layers.append(np.nanmax(np.array(list(by_layer2.values())), axis=0))
            model_names_l.append(model_name)
            tasks_l.append(data_name)
            # NaN when no result file exists for this pair; every comparison
            # below is then False instead of np.nanmax raising on an empty list.
            best_perf = np.nanmax(max_perf_recovery) if max_perf_recovery else np.nan
            recovery['1'].append(best_perf > zero_shot)
            recovery['5'].append(best_perf >= five_shot)
            recovery['5_90'].append(best_perf >= 0.90*five_shot)
            recovery['5_75'].append(best_perf >= 0.75*five_shot)
            recovery['5_50'].append(best_perf >= 0.50*five_shot)

            # NOTE(review): `all_scores` stores the mean recovery (`avgs`), not
            # the AUC*peak values in `scores` -- confirm this is intended.
            fv_df = pd.DataFrame([avgs])
            fv_df.insert(0, 'data_name', data_name)
            fv_df.insert(0, 'model_name', model_name)
            all_scores.append(fv_df)

            fv_df = pd.DataFrame([peaks])
            fv_df.insert(0, 'data_name', data_name)
            fv_df.insert(0, 'model_name', model_name)
            all_peaks.append(fv_df)
    all_subplots_data.append(all_scores)
    all_subplots_data.append(all_peaks)
    all_recovery.append(recovery)

# Run this for all plots

In [None]:
# Stack the per-pair peak-recovery frames held in `all_peaks` into one table.
# Run this before the task-vector cell below, which rebinds `all_peaks` to a
# list of plain rows.
fv_df = pd.concat(all_peaks, axis=0)
fv_df

In [None]:
# Row order used when reindexing per-model tables for the heatmaps.
# Models are listed family by family; the groups are flattened in order.
_MODEL_FAMILIES = (
    ('gptj-6b', 'pythia-6.9b'),
    ('llama-7b', 'llama2-7b', 'llama2i-7b', 'llama2-13b', 'llama2i-13b', 'llama2-70b'),
    ('llama3-8b', 'llama3i-8b', 'llama3-70b', 'llama3i-70b', 'llama3.2-3b'),
    ('mistral1-7b', 'mistral3-7b', 'mistral3i-7b'),
    ('amber-7b', 'falcon3-7b'),
    ('gemma2-2b', 'gemma2-9b', 'gemma2i-9b', 'gemma2-27b'),
    ('qwen2-1.5b', 'qwen2-7b', 'qwen2i-7b', 'qwen2-72b'),
    ('qwen2.5-3b', 'qwen2.5i-3b', 'qwen2.5-7b', 'qwen2.5i-7b', 'qwen2.5-14b'),
    ('olmo-7b', 'olmoi-7b', 'olmo2-7b', 'olmo2i-7b', 'olmo2-13b'),
)
model_order = [name for family in _MODEL_FAMILIES for name in family]
# (base model, post-trained counterpart) pairs. The three lists below are all
# derived from this single table so that they cannot drift apart.
_INSTR_PAIRS = [
    ('llama2-7b', 'llama2i-7b'),
    ('llama2-13b', 'llama2i-13b'),
    ('llama3-8b', 'llama3i-8b'),
    ('llama3-70b', 'llama3i-70b'),
    ('mistral3-7b', 'mistral3i-7b'),
    ('gemma2-9b', 'gemma2i-9b'),
    ('qwen2-7b', 'qwen2i-7b'),
    ('qwen2.5-3b', 'qwen2.5i-3b'),
    ('qwen2.5-7b', 'qwen2.5i-7b'),
    ('olmo-7b', 'olmoi-7b'),
    ('olmo2-7b', 'olmo2i-7b'),
]

# Interleaved order: base, post-trained, base, post-trained, ...
# The plotting cells pair rows by even/odd position, so this order matters.
instr_order = [name for pair in _INSTR_PAIRS for name in pair]

# Base models only; used to label the (post-trained - base) differences.
diff_instr = [base for base, _ in _INSTR_PAIRS]

# Post-trained models only.
instr_only = [tuned for _, tuned in _INSTR_PAIRS]

# Short display labels for the task names. The insertion order is meaningful:
# the heads-by-task plot sorts its rows by the position of each label in
# list(data_short_map.values()).
data_short_map = {
    'antonym': 'ant',
    'present-past': 'pres-past',
    'country-capital': 'cntry-cap',
    'french-english': 'fr-eng',
    'german-english': 'ge-eng',
    'spanish-english': 'sp-eng',
    'english-french': 'eng-fr',
    'english-german': 'eng-ge',
    'english-spanish': 'eng-sp',
}

# Load in Task Vector Performance

In [None]:
# Load the task-vector results of every model that has a CSV on disk and build
#   * tv_recovery: per (model, task) flags, same thresholds as the FV cells
#   * tv_scores / tv_peaks: one row per model with mean*peak and peak recovery
# NOTE: this cell rebinds `all_scores` and `all_peaks` (lists of plain rows).
tv_recovery = {'0': [], '5': [], '5_90': [], '5_75': [], '5_50': []}
all_scores, all_peaks = [], []
recovery_tasks = ['translation_fr_en', 'country-capital', 'translation_es_en', 'translation_en_fr', 'translation_en_es', 'translation_en_it', 'translation_it_en', 'present-past', 'antonym']
summary_tasks = ['antonym', 'present-past', 'country-capital', 'to-eng', 'from-eng']


def _tv_baseline_column(task, available):
    """Return the column of `baselines` that `task` should be compared with."""
    if task in available:
        return task
    # NOTE(review): assumes the baselines file only carries the aggregated
    # translation columns ('to-eng' / 'from-eng') that the summary loop reads,
    # and that 'translation_<src>_<tgt>' ending in '_en' means "into English".
    return 'to-eng' if task.endswith('_en') else 'from-eng'


for model_name in model_order:
    # Files on disk use underscores where the model names use hyphens
    file_stem = model_name.replace('-', '_')
    results_csv = f'../icl_task_vectors/for_fv/{file_stem}.csv'
    if not os.path.exists(results_csv):
        continue
    tv_fv_df = pd.read_csv(results_csv)
    baselines = pd.read_csv(f'../icl_task_vectors/for_fv/{file_stem}_baselines.csv')
    scores, peaks = [model_name], [model_name]

    # For performance recovery. The baselines are looked up for the current
    # model and task; previously this loop read `zero_shot` / `five_shot`
    # before they were assigned for this model, i.e. stale values.
    for task in recovery_tasks:
        baseline_col = _tv_baseline_column(task, baselines.columns)
        task_zero_shot = baselines.loc[0, baseline_col]  # row 0: zero-shot
        task_five_shot = baselines.loc[1, baseline_col]  # row 1: five-shot
        best = tv_fv_df.loc[:, task].max(skipna=True, axis=0)
        tv_recovery['0'].append(best > task_zero_shot)
        tv_recovery['5'].append(best >= task_five_shot)
        tv_recovery['5_90'].append(best >= 0.9*task_five_shot)
        tv_recovery['5_75'].append(best >= 0.75*task_five_shot)
        tv_recovery['5_50'].append(best >= 0.50*task_five_shot)

    for task in summary_tasks:
        zero_shot = baselines.loc[0, task]
        five_shot = baselines.loc[1, task]

        score = tv_fv_df[task].mean()*tv_fv_df[task].max()/five_shot
        scores.append(score)

        peak = tv_fv_df[task].max()/five_shot
        peaks.append(peak)
    all_scores.append(scores)
    all_peaks.append(peaks)

# 'for_code_consistency' is a dummy second column: the heatmap cell treats
# everything from the third column onwards as values for every table.
tv_scores = pd.DataFrame(all_scores, columns=['model_name', 'antonym', 'present-past', 'country-capital', 'to-eng', 'from-eng'])
tv_scores.insert(1, 'for_code_consistency', 0)
tv_peaks = pd.DataFrame(all_peaks, columns=['model_name', 'antonym', 'present-past', 'country-capital', 'to-eng', 'from-eng'])
tv_peaks.insert(1, 'for_code_consistency', 0)

# Plot heads against instr-base models

In [None]:
# Difference in performance recovery between each post-trained model and its
# base model, as a function of the swept parameter (the value columns of fv_df).
head_n_models = fv_df.copy()

# Average over tasks, then order the rows as base, post-trained, base, ...
head_n_models = head_n_models.drop(columns=['data_name']).groupby('model_name').mean()
head_n_models = head_n_models.reindex(instr_order)

# Even positions hold the base models and odd positions their post-trained
# counterparts, so odd - even is the post-training gain of each pair.
even_rows = head_n_models.iloc[0::2].reset_index(drop=True)
odd_rows = head_n_models.iloc[1::2].reset_index(drop=True)
head_n_models = odd_rows - even_rows
head_n_models.index = diff_instr

plt.figure(figsize=(3.5, 3))

head_n_models = head_n_models.reset_index(names='model_name')
head_n_models = pd.melt(head_n_models, id_vars='model_name', var_name='N Heads', value_name='Performance Recovery')
ax = sns.lineplot(data=head_n_models, x='N Heads', y='Performance Recovery', linewidth=2, color='black')

plt.axhline(0, color='gray', linestyle='--')

# Log-scaled x axis labelled with the plain head counts
ax.set_xscale('log')
ax.set_xticks([2, 16, 32, 64, 128, 256, 512])
ax.set_xlim([2, 512])
ax.xaxis.set_major_formatter(ScalarFormatter())
ax.xaxis.get_major_formatter().set_useOffset(False)
ax.minorticks_off()
ax.set_yticks(np.linspace(0, 0.4, 5))
ax.tick_params('x', labelsize=12)
ax.tick_params('y', labelsize=13)

ax.set_xlabel("Number of Heads", fontsize=13)
# Raw string: '\D' is not a valid escape sequence in a normal string literal
ax.set_ylabel(r"$\Delta$ Performance Recovery", fontsize=13)

plt.title('Post-trained Performance over Base', fontsize=13)
plt.tight_layout()
plt.savefig('../figures/activation_patching/head_instr.pdf', dpi=300, bbox_inches='tight')

# Plot heads against task

In [None]:
# Peak performance recovery against the swept parameter, with one line per
# task group (linguistic/factual, English -> language, language -> English).
plt.figure(figsize=(3.5,3))
head_n_tasks = pd.melt(
    fv_df.copy().drop(columns=['model_name']),
    id_vars='data_name',
    var_name='N Heads',
    value_name='Performance Recovery',
)
# Formatting for plot: switch to the short labels and order the rows the way
# the labels are listed in data_short_map
short_labels = list(data_short_map.values())
head_n_tasks['data_name'] = head_n_tasks['data_name'].map(data_short_map)
head_n_tasks['sort_order'] = head_n_tasks['data_name'].map(short_labels.index)
head_n_tasks = head_n_tasks.sort_values('sort_order').drop(columns=['sort_order'])

palette = {
    'ling and fact': '#D55E00',
    'eng to [lang]': '#000000',
    '[lang] to eng': '#332288',
}
# Collapse the nine short labels into the three groups coloured above
task_groups = {
    'ling and fact': ('ant', 'pres-past', 'cntry-cap'),
    'eng to [lang]': ('eng-fr', 'eng-ge', 'eng-sp'),
    '[lang] to eng': ('fr-eng', 'ge-eng', 'sp-eng'),
}
group_of = {short: group for group, members in task_groups.items() for short in members}
hm2 = head_n_tasks.set_index('data_name').rename(index=group_of).reset_index()

ax = sns.lineplot(data=hm2, x='N Heads', y='Performance Recovery', hue='data_name', style='data_name', dashes=True, linewidth=2, palette=palette)

# Log-scaled x axis labelled with the plain head counts
head_ticks = [2, 16, 32, 64, 128, 256, 512]
ax.set_xscale('log')
ax.set_xticks(head_ticks)
ax.set_xlim([head_ticks[0], head_ticks[-1]])
ax.xaxis.set_major_formatter(ScalarFormatter())
ax.xaxis.get_major_formatter().set_useOffset(False)
ax.minorticks_off()
ax.set_yticks(np.linspace(0, 1.0, 6))
# Both axes end up with size-12 tick labels
ax.tick_params(labelsize=12)

ax.set_xlabel("Number of Heads", fontsize=12)
ax.set_ylabel("Peak Performance Recovery", fontsize=13)

plt.legend(ncol=1, loc='lower right', bbox_to_anchor=(0.97, -0.04), fontsize=11, framealpha=0)

plt.title('Localization can be Task-Dependent', fontsize=13)

plt.tight_layout()
plt.savefig('../figures/activation_patching/head_tasks.pdf', dpi=300, bbox_inches='tight')

# Joint Heads analysis

In [None]:
# Two-panel figure combining the per-task-group curves (left, from `hm2`) with
# the post-trained minus base difference (right, from `head_n_models`).
fig, axs = plt.subplots(1, 2, figsize=(6, 3.5))

# Left plot tasks/heads
sns.lineplot(data=hm2, x='N Heads', y='Performance Recovery', hue='data_name', style='data_name', dashes=True, linewidth=2, palette=palette, ax=axs[0])
axs[0].set_xscale('log')
axs[0].set_xticks([2, 16, 32, 64, 128, 256, 512])
axs[0].set_xlim([2, 512])
axs[0].xaxis.set_major_formatter(ScalarFormatter())
axs[0].xaxis.get_major_formatter().set_useOffset(False)
axs[0].minorticks_off()
axs[0].set_yticks(np.linspace(0, 0.9, 4))
axs[0].tick_params('x', labelsize=15, rotation=45)
axs[0].tick_params('y', labelsize=15)

axs[0].set_ylabel("Peak Performance Recovery", fontsize=15, y = 0.5)
# The x label is shared by both panels and set once on the figure below
axs[0].set_xlabel('')

axs[0].legend(ncol=1, loc='lower right', bbox_to_anchor=(0.99, -0.04), fontsize=12, framealpha=0)

axs[0].set_title('Task Localization', fontsize=17)

# Heads and instruction tuning
sns.lineplot(data=head_n_models, x='N Heads', y='Performance Recovery', linewidth=2, color='black', ax=axs[1])

axs[1].axhline(0, color='gray', linestyle='--')

axs[1].set_xscale('log')
axs[1].set_xticks([2, 16, 32, 64, 128, 256, 512])
axs[1].set_xlim([2, 512])
axs[1].xaxis.set_major_formatter(ScalarFormatter())
axs[1].xaxis.get_major_formatter().set_useOffset(False)
axs[1].minorticks_off()
axs[1].set_yticks(np.linspace(0, 0.4, 5))
axs[1].tick_params('x', labelsize=15, rotation=45)
axs[1].tick_params('y', labelsize=15)

axs[1].set_ylabel('')
axs[1].set_xlabel('')

axs[1].set_title('Post-Trained minus Base', fontsize=17)

# Raw string: '\m' is not a valid escape sequence in a normal string literal
fig.supxlabel(r'Number of Heads ($\mathcal{A}_{n}$)', fontsize=15, y=0.07)

plt.tight_layout(w_pad=0.001)
plt.savefig('../figures/activation_patching/head_tasks_instr.pdf', dpi=300, bbox_inches='tight')

# Instruction performance of Task vectors

In [None]:
# Task-vector counterpart of the post-trained minus base comparison: one point
# per task instead of per head count.
# NOTE: this rebinds `head_n_models`, which the joint heads figure also reads.
head_n_models = tv_peaks.copy()

head_n_models = head_n_models.drop(columns=['for_code_consistency'])
head_n_models = head_n_models.set_index('model_name')
# Use the short task labels on the x axis. The result used to be assigned to a
# misspelled name (`heads_n_models`), so the rename was silently discarded.
head_n_models = head_n_models.rename(columns=data_short_map)
head_n_models = head_n_models.reindex(instr_order)

# Even positions hold the base models and odd positions their post-trained
# counterparts, so odd - even is the post-training gain of each pair.
even_rows = head_n_models.iloc[0::2].reset_index(drop=True)
odd_rows = head_n_models.iloc[1::2].reset_index(drop=True)
head_n_models = odd_rows - even_rows
head_n_models.index = diff_instr

plt.figure(figsize=(4.5, 4))

head_n_models = head_n_models.reset_index(names='model_name')
head_n_models = pd.melt(head_n_models, id_vars='model_name', var_name='Task', value_name='Performance Recovery')
ax = sns.lineplot(data=head_n_models, x='Task', y='Performance Recovery', linewidth=2, color='black')

plt.axhline(0, color='gray', linestyle='--')

ax.set_yticks(np.linspace(-0.4, 0.2, 7))
ax.tick_params('x', labelsize=14, rotation=45)
ax.tick_params('y', labelsize=14)

ax.set_xlabel("ICL Task", fontsize=15)
# Raw string: '\D' is not a valid escape sequence in a normal string literal
ax.set_ylabel(r"$\Delta$ Performance Recovery", fontsize=15)

plt.title('Post-Trained minus Base', fontsize=18)
plt.tight_layout()
plt.savefig('../figures/activation_patching/tv_instr.pdf', dpi=300, bbox_inches='tight')

# 6 subplots with Performance (default/gridsearch/TV, avgs/peaks)

In [None]:
# Six heatmaps side by side (models x task groups): function vectors with
# default parameters, function vectors with parameter search, and task vectors,
# each as an "Average" and a "Peak" panel.
# all_subplots_data is [default means, default peaks, search means, search peaks].
d_scores = pd.concat(all_subplots_data[0], axis=0)
d_peaks = pd.concat(all_subplots_data[1], axis=0)
g_scores = pd.concat(all_subplots_data[2], axis=0)
g_peaks = pd.concat(all_subplots_data[3], axis=0)
to_plot = [d_scores, d_peaks, g_scores, g_peaks, tv_scores, tv_peaks]

fig, axs = plt.subplots(1, len(to_plot), figsize=(13.5, 21))

vmin = 0
vmax = 1.0
norm = plt.Normalize(vmin=vmin, vmax=vmax)

# One colormap per method; each method owns two consecutive panels
method_cmaps = ['rocket', 'mako', 'cividis']
heatmap_columns = ['antonym', 'present-past', 'country-capital', 'to-eng', 'from-eng']

# The loop variables are deliberately not called `fv_df` / `palette`: other
# cells read those globals (the peak table and the task-group colour dict).
for i, table in enumerate(to_plot):
    # plot best performance across tasks and models
    bp = table.copy()
    bp = bp.fillna(0)
    # Everything from the third column onwards is a value column
    bp['best_perf'] = bp.iloc[:, 2:].max(axis=1)

    if i < 4:
        # Function-vector tables have one row per (model, task): pivot to
        # models x tasks and average the three translation pairs per direction
        bp = bp.pivot(index='model_name', columns='data_name', values='best_perf')
        bp['from-eng'] = bp[['english-french', 'english-german', 'english-spanish']].mean(axis=1)
        bp['to-eng'] = bp[['french-english', 'german-english', 'spanish-english']].mean(axis=1)
    else:
        # Task-vector tables already have one row per model
        bp.index = bp['model_name']
        bp = bp.drop('model_name', axis=1)

    bp = bp.reindex(model_order)
    bp.index = bp.index.map(STAGE_NAME)
    cmap = sns.color_palette(method_cmaps[i // 2], as_cmap=True)
    ax = sns.heatmap(bp[heatmap_columns], annot=False, cmap=cmap, fmt='.2f', ax=axs[i], cbar=False, norm=norm)
    ax.hlines(np.arange(0, bp.shape[0]), *ax.get_xlim(), color='black', linewidth=1.5)

    axs[i].set_ylabel('')
    # Columns are labelled a-e instead of with the task names
    xticks = [0.5, 1.5, 2.5, 3.5, 4.5]
    xtick_labels = ['a', 'b', 'c', 'd', 'e']
    axs[i].set_xticks(xticks, xtick_labels)
    axs[i].tick_params(axis='x', labelsize=45, rotation=0, length=10)
    axs[i].set_xlabel('')

    if (i+1)%2 == 0:
        axs[i].set_title('Peak', fontsize=40)
    else:
        axs[i].set_title('Average', fontsize=40)

    # Only the left-most panel keeps the model names
    if i > 0:
        axs[i].set_yticks([])
    else:
        axs[i].tick_params(axis='y', labelsize=45, rotation=0)

cbar_ax1 = fig.add_axes([0.235, 0.08, 0.25, 0.02])  # [left, bottom, width, height]
cbar_ax2 = fig.add_axes([0.568, 0.08, 0.25, 0.02])
cbar_ax3 = fig.add_axes([0.908, 0.08, 0.25, 0.02])
for palette_name, cbar_ax in zip(method_cmaps, [cbar_ax1, cbar_ax2, cbar_ax3]):
    cbar = fig.colorbar(plt.cm.ScalarMappable(cmap=sns.color_palette(palette_name, as_cmap=True), norm=norm), cax=cbar_ax, ticks=np.linspace(0, 1.0, 3), orientation="horizontal")
    cbar.ax.tick_params(labelsize=45)

fig.text(0.5275, 1.04, 'Function Vector', ha='center', fontsize=50, transform=fig.transFigure)
fig.text(0.36, 1, 'Default Param', ha='center', fontsize=50, transform=fig.transFigure)
fig.text(0.695, 1, 'Param Search', ha='center', fontsize=50, transform=fig.transFigure)
fig.text(1.025, 1, 'Task Vector', ha='center', fontsize=50, transform=fig.transFigure)
plt.tight_layout(rect=[0, 0.1, 1.2, 1])
# The extension used to read '.pfv_df', which is not a format savefig can write
plt.savefig('../figures/activation_patching/6_plots_acc_all_models.pdf', dpi=300, bbox_inches='tight')

# Plot best Layer, Lambda, Heads

In [None]:
# Heatmap (models x tasks) of the parameter value that gives the best
# performance recovery. Pick the parameter by (un)commenting one line; only the
# first letter of `best_param` is used to select the branches below.
# best_param = r'Number of Heads ($\mathcal{A}_{n}$)'
# best_param = r'Function Vector Strength ($\lambda$)'
best_param = r'Activation Patching Layer ($\ell$)'

# plot best performance across tasks and models
bp = fv_df.copy()

if best_param[0] in ['A']:
    # Layer analysis: one column per position of the padded per-layer arrays
    layer_perf = pd.DataFrame(all_layers)
    layer_perf.insert(0, 'model_name', model_names_l)
    layer_perf.insert(1, 'data_name', tasks_l)
    bp = layer_perf

bp = bp.fillna(0)
# Everything from the third column onwards is a value column
param_values = bp.iloc[:, 2:].to_numpy()
# Keep only the pairs whose best recovery exceeds 0.2
keep = (bp.iloc[:, 2:].max(axis=1) > 0.2).to_numpy()
# Explicit copy so that the assignment below does not write into a slice
bp = bp.loc[keep].copy()
# Position (0-based) of the best value column for each remaining row
bp['best_perf'] = np.argmax(param_values[keep], axis=1)
tasks = ['antonym', 'present-past', 'country-capital', 'english-french', 'english-german', 'english-spanish', 'french-english', 'german-english', 'spanish-english']

bp_raw = bp.copy()

bp = bp.pivot(index='model_name', columns='data_name', values='best_perf')
bp['from-english'] = bp[['english-french', 'english-german', 'english-spanish']].mean(axis=1)
bp['to-english'] = bp[['french-english', 'german-english', 'spanish-english']].mean(axis=1)

if best_param[0] in ['N', 'F']:
    # Discrete colormap: one colour per grid value
    cmap = mcolors.ListedColormap(
        sns.color_palette("mako", n_colors=8)
    )
    vmin = bp.min().min()
    vmax = bp.max().max()
else:
    # Express the best position as a percentage of the model's layer count
    bp = bp.div(bp.index.map(n_layers_l), axis=0) * 100
    cmap = sns.color_palette('mako', as_cmap=True)
    vmin = 0
    vmax = 100

bp = bp.reindex(model_order)
bp.index = bp.index.map(STAGE_NAME)
bp = bp[tasks]
bp = bp.rename(columns=data_short_map)
plt.figure(figsize=(13.5, 20))

ax = sns.heatmap(bp, cmap=cmap, vmin=vmin, vmax=vmax)
ax.hlines(np.arange(0, bp.shape[0]), *ax.get_xlim(), color='black', linewidth=1.5)

colorbar = ax.collections[0].colorbar
if best_param[0] in ['N', 'F']:
    # Label the colorbar with the grid values instead of column positions
    if best_param[0] == 'N':
        category_mapping = {0: 2, 1: 16, 2: 32, 3: 64, 4: 128, 5: 256, 6: 512, 7: 1024}
        colorbar.set_ticks(np.linspace(0.5, 6.5, 8))
    elif best_param[0] == 'F':
        category_mapping = {0: 0.5, 1: 1.0, 2: 2.0, 3: 4.0, 4: 8.0, 5: 16.0, 6: 32.0, 7: 64.0}
        colorbar.set_ticks(np.linspace(0.5, 6.5, 8))

    colorbar.set_ticklabels(list(category_mapping.values()), fontsize=30)
else:
    colorbar.set_ticks(np.linspace(0, 100, 5))  # percentage ticks from 0 to 100
    colorbar.ax.tick_params(labelsize=30)

ax.tick_params('x', rotation=45, labelsize=40)
ax.tick_params('y', labelsize=40)
plt.title(rf'Best {best_param}', fontsize=40)
plt.xlabel('')
plt.ylabel('')
plt.tight_layout()
plt.savefig(f'../figures/activation_patching/best_param/best_{best_param}.pdf', dpi=300, bbox_inches='tight')

# Plot 5-shot vs. Peak Performance Recovery

In [None]:
# Fraction of (model, task) pairs that reach each share of 5-shot performance,
# for the two function-vector settings and for task vectors.
shots = ['5', '5_90', '5_75', '5_50']
experiments = ['FV Default Param', 'FV Param Search']
di = {name: [] for name in experiments + ['Task Vector']}


def _report_recovery(label, flags_by_shot):
    """Append the recovered fraction per threshold to di[label] and print it."""
    for shot in shots:
        flags = flags_by_shot[shot]
        recovered, total_n = sum(flags), len(flags)
        di[label].append(recovered/total_n)
        print(f'{label}, {shot}-shot, {recovered} out of {total_n}, {recovered/total_n}')


# all_recovery holds the default-parameter run first, then the parameter search
for recovery, experiment in zip(all_recovery, experiments):
    _report_recovery(experiment, recovery)
    print()

_report_recovery('Task Vector', tv_recovery)

In [None]:
# Line plot of the recovered fraction against the share of 5-shot performance.
# `di` lists the thresholds from strictest (100%) to loosest (50%), so the rows
# are reversed before being labelled in ascending order.
thresholds = [0.5, 0.75, 0.9, 1.0]
mr = pd.DataFrame(di).iloc[::-1, :].reset_index(drop=True)
mr.index = thresholds
display(mr)

plt.figure(figsize=(3.5,3))
ax = sns.lineplot(mr, linewidth=3)
ax.set_xticks(thresholds)
ax.set_yticks(np.linspace(0, 1, 5))
ax.set_xlabel('Percent of 5-shot Performance')
ax.set_ylabel('Samples Above Threshold')
plt.legend(loc='upper right', framealpha=0)
plt.show()