In [None]:
import os
import gc
import sys

import torch
import numpy as np

from tqdm import tqdm

from matplotlib import pyplot as plt
# Render all figure text through LaTeX (needs a working LaTeX installation).
plt.rcParams['text.usetex'] = True
plt.rcParams['font.family'] = 'serif'
# Under usetex, matplotlib only honours serif names it can map to a LaTeX
# package; 'Times New Roman' is not one of them (it would silently fall back to
# Computer Modern), so 'Times' is listed as the usetex-compatible fallback.
# Non-usetex rendering still prefers 'Times New Roman'.
plt.rcParams['font.serif'] = ['Times New Roman', 'Times']
import seaborn as sns

# Make the parent directory importable so `config` resolves; assumes the
# notebook is run from its own directory.
sys.path.append('..')
from config import HF_NAME_MAP

In [None]:
# For every model in HF_NAME_MAP, draw one heatmap per task of the indirect
# effect averaged over dim 0 (the causal indirect effect, CIE), with layers on
# the x-axis and head index on the y-axis, and save the row of panels to
# ../figures/locality/<model_name>.pdf.
tasks = [
    'antonym',
    'english-french',
    'english-german',
    'english-spanish',
    'french-english',
    'german-english',
    'spanish-english',
]
RESULTS_DIR = '../function_vectors/results'
FIGURES_DIR = '../figures/locality'
N_YTICKS = 5  # evenly spaced head-index ticks per panel


def _load_cie(model_name, task):
    """Load the saved indirect effect for (model_name, task) and average it over dim 0.

    Returns a float32 numpy array, or None when the result file does not exist.
    The result is assumed to be (layers, heads) -- the plot transposes it and
    labels x as 'Layer' and y as 'Head Index'; TODO confirm against the code
    that writes *_indirect_effect.pt.
    """
    path = f'{RESULTS_DIR}/{model_name}/{task}/{task}_indirect_effect.pt'
    if not os.path.isfile(path):
        return None
    # weights_only=False unpickles arbitrary objects: only point this at
    # trusted, locally produced result files.
    # map_location='cpu' lets tensors saved from a GPU load on any machine and
    # be converted to numpy for seaborn.
    indirect_effect = torch.load(path, map_location='cpu', weights_only=False)
    return torch.mean(indirect_effect.detach().float(), dim=0).numpy()


def _plot_model(model_name):
    """Draw the per-task CIE heatmaps for one model and save them as a PDF."""
    cies = [_load_cie(model_name, task) for task in tasks]
    plotted = [i for i, cie in enumerate(cies) if cie is not None]

    # One shared color scale for every panel, so the single colorbar drawn on
    # the last available panel is valid for all of them (seaborn would
    # otherwise scale each heatmap to its own min/max).
    if plotted:
        vmin = min(np.nanmin(cies[i]) for i in plotted)
        vmax = max(np.nanmax(cies[i]) for i in plotted)
    else:
        vmin = vmax = None

    # squeeze=False keeps the axes indexable even when there is only one task.
    fig, axs = plt.subplots(1, len(tasks), figsize=(10, 3), squeeze=False)
    for i, (task, cie, ax) in enumerate(zip(tasks, cies, axs[0])):
        ax.set_title(task)
        if cie is None:
            ax.set_axis_off()  # no result file: leave an empty, titled slot
            continue
        heads_by_layers = cie.T  # rows -> y ('Head Index'), columns -> x ('Layer')
        sns.heatmap(
            heads_by_layers,
            cmap=sns.color_palette("vlag_r", as_cmap=True),
            ax=ax,
            vmin=vmin,
            vmax=vmax,
            cbar=(i == plotted[-1]),
        )
        ax.set_xlabel('Layer')
        n_heads = heads_by_layers.shape[0]
        yticks = np.linspace(0, n_heads, N_YTICKS)
        # ':g' drops the trailing '.0' that str() would put on float labels.
        ax.set_yticks(yticks, [f'{t:g}' for t in yticks], rotation=0)
        if i == plotted[0]:
            ax.set_ylabel('Head Index')
        else:
            # Panels share the head axis, so only the first one shows its labels.
            ax.set_ylabel('')
            ax.tick_params(labelleft=False, left=False)

    fig.suptitle(f'Causal Indirect Effect | {model_name}')
    fig.tight_layout()
    os.makedirs(FIGURES_DIR, exist_ok=True)
    fig.savefig(f'{FIGURES_DIR}/{model_name}.pdf', dpi=300)
    plt.close(fig)


for model_name in tqdm(HF_NAME_MAP):
    _plot_model(model_name)
    gc.collect()  # release figure/tensor memory before the next model