In [None]:
import json
import os
import sys

# Make the parent directory of the notebook's working directory importable so
# the project-local `config` module below can be found.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(project_dir)

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from matplotlib.ticker import MultipleLocator

# All figure text is typeset by LaTeX (text.usetex) in a Times serif face.
plt.rcParams.update({
    'text.usetex': True,
    'font.family': 'serif',
    'font.serif': 'Times New Roman',
})

from config import IMPLEMENTED_MODELS, MODEL_FP_MAP, STAGE_NAME_LONG, STAGE_NAME_DATA

In [None]:
# Model keys to plot; each key is looked up in STAGE_NAME_LONG (and
# MODEL_FP_MAP for Figure 1). Other keys previously used with this notebook:
#   gptj-6b, pythia-6.9b,
#   llama-7b, llama2-7b, llama2i-7b, llama2-13b, llama2i-13b, llama2-70b,
#   llama3_0-8b, llama3-8b, llama3i-8b, llama3-70b, llama3i-70b, llama3.2-3b,
#   mistral1-7b, mistral3-7b, mistral3i-7b,
#   gemma2-2b, gemma2i-9b, gemma2-27b,
#   qwen2-1.5b, qwen2-7b, qwen2i-7b, qwen2-72b,
#   qwen2.5-3b, qwen2.5i-3b, qwen2.5-7b, qwen2.5i-7b, qwen2.5-14b,
#   olmo-7b-20BT, olmo-7b-50BT, olmo-7b-100BT, olmo-7b-2700BT,
#   olmo-7b, olmos-7b, olmoi-7b, olmo2-7b, olmo2i-7b, olmo2-13b,
#   amber-7b-21BT, amber-7b-49BT, amber-7b-101BT, amber-7b,
#   falcon3-7b
models = ['gemma2-9b']

# Dataset keys to plot. Alternatives previously used:
#   antonym, english-french, english-german, english-spanish,
#   french-english, german-english, spanish-english,
#   present-past, country-capital, colors
data_names = ['tqa']
def get_task_accuracy(model_name, data_name,
                      results_dir='../function_vectors/logit_lens/results/token_probs'):
    """Top-1 accuracy at the last projected layer for one model/dataset pair.

    Reads ``<results_dir>/<model_name>/<data_name>/token_rank_by_layer.csv``,
    keeps only the rows whose ``layerid`` equals the largest ``layerid`` in the
    file, and measures how often ``correct_rank`` is 1.

    Args:
        model_name: Model key (sub-directory of ``results_dir``).
        data_name: Dataset key (sub-directory of the model directory).
        results_dir: Root directory of the logit-lens results. Defaults to the
            path the notebook has always used.

    Returns:
        ``(task_acc, last_layer)``: the fraction of last-layer rows with
        ``correct_rank == 1``, and the largest ``layerid`` value. Callers use
        ``last_layer`` as the x-axis extent of their plots.
    """
    acc_fp = os.path.join(results_dir, model_name, data_name, 'token_rank_by_layer.csv')
    token_rank_by_layer = pd.read_csv(acc_fp)
    last_layer = token_rank_by_layer['layerid'].max()
    # Boolean mask rather than an f-string `query`: avoids interpolating the
    # value's repr into an expression string.
    is_last_layer = token_rank_by_layer['layerid'] == last_layer
    last_layer_token_ranks = token_rank_by_layer.loc[is_last_layer]
    task_acc = np.mean(last_layer_token_ranks['correct_rank'] == 1)
    return task_acc, last_layer

# Plot Logit Lens

In [None]:
# One figure per (model, dataset): projected probability of the top token,
# the correct token and the `second_prob` token (shown as 'Incorrect') at
# every layer. Pairs without a token_probabilities.csv are skipped.
palette = {
    'Predicted': '#44AA99',
    'Correct': '#117733',
    'Incorrect': '#CC6677',
    'Top': '#5C5C5D',
}
os.makedirs('../figures/logit_lens', exist_ok=True)
for model_name in models:
    for data_name in data_names:
        ll_fp = f'../function_vectors/logit_lens/results/token_probs/{model_name}/{data_name}/token_probabilities.csv'
        if not os.path.exists(ll_fp):
            continue
        task_acc, n_layers = get_task_accuracy(model_name, data_name)
        token_probs = pd.read_csv(ll_fp)
        token_probs = token_probs.rename(columns={'pred_prob': 'Predicted', 'correct_prob': 'Correct', 'second_prob': 'Incorrect', 'top_prob': 'Top'})
        to_plot = ['layerid', 'Top', 'Correct', 'Incorrect']
        token_probs = pd.melt(
            token_probs.loc[:, to_plot], id_vars='layerid', var_name='metric_name', value_name='metric'
        )
        fig = plt.figure(figsize=(4, 2.5))
        ax = sns.lineplot(data=token_probs, x='layerid', y='metric', hue='metric_name', style='metric_name', palette=palette, linewidth=2)
        ax.set_title(f'{STAGE_NAME_LONG[model_name]}', fontsize=20)
        ax.set_yticks(np.linspace(0, 1, 6))
        # Major x-tick spacing depends on the size suffix of the model key
        # ('9b' -> 9.0). Keys whose last segment is not a plain size (e.g.
        # 'olmo-7b-20BT') fall back to the small-model spacing. Previously a
        # bare `except` left `model_size` undefined or stale from an earlier
        # model in that case.
        try:
            model_size = float(model_name.split('-')[-1][:-1])
        except ValueError:
            model_size = 0.0
        spaces = 16 if model_size > 13 else 8
        ax.set_xticks(np.arange(0, n_layers + 1, spaces))
        ax.set_xlim([0, n_layers])
        ax.tick_params('both', labelsize=12)
        ax.set_xlabel('Layer', fontsize=12)
        ax.set_ylabel('Token Probability', fontsize=15)
        ax.legend(title='', framealpha=0, loc='lower left', bbox_to_anchor=(-0.02, 0), fontsize=13)
        fig.tight_layout()
        # Include the dataset in the filename so several datasets do not
        # overwrite one another; 'tqa' keeps its historical '_tfqa' name.
        out_suffix = 'tfqa' if data_name == 'tqa' else data_name
        fig.savefig(f'../figures/logit_lens/{model_name}_{out_suffix}.pdf', bbox_inches='tight', dpi=300)
        plt.show()
        # Free the figure; otherwise every loop iteration keeps one alive.
        plt.close(fig)

In [None]:
# 2x2 logit-lens comparison across model families on the first dataset in
# `data_names`. Panels are filled row-major; a panel stays empty when its
# token_probabilities.csv is missing.
four_panel_models = ['llama2-70b', 'llama3_0-70b', 'qwen2-7b', 'gemma2-9b']
palette = {
    'Correct': '#117733',
    'Incorrect': '#CC6677',
    'Top': '#5C5C5D',
}
fig, axs = plt.subplots(2, 2, figsize=(7, 6))
data_name = data_names[0]
for k, model_name in enumerate(four_panel_models):
    # Size in billions from the key suffix, e.g. 'qwen2-7b' -> 7.0.
    model_size = float(model_name.split('-')[1][:-1])
    i, j = divmod(k, 2)
    ax = axs[i][j]
    # Only the top-left panel carries the shared legend.
    legend = i == 0 and j == 0

    ll_fp = f'../function_vectors/logit_lens/results/token_probs/{model_name}/{data_name}/token_probabilities.csv'
    if not os.path.exists(ll_fp):
        continue
    task_acc, n_layers = get_task_accuracy(model_name, data_name)
    token_probs = pd.read_csv(ll_fp)
    token_probs = token_probs.rename(columns={'correct_prob': 'Correct', 'second_prob': 'Incorrect', 'top_prob': 'Top'})
    to_plot = ['layerid', 'Top', 'Correct', 'Incorrect']
    token_probs = pd.melt(
        token_probs.loc[:, to_plot], id_vars='layerid', var_name='metric_name', value_name='metric'
    )

    sns.lineplot(data=token_probs, x='layerid', y='metric', hue='metric_name', style='metric_name', ax=ax, linewidth=2.5, legend=legend, palette=palette)

    ax.set_yticks(np.arange(0, 1.1, 0.5))
    # Major x-tick step depends only on the row. The old code also branched
    # on the column, but both columns used identical values.
    incr = 16 if i == 0 else 7
    ax.set_xticks(np.append(np.arange(0, n_layers, incr), n_layers))
    ax.tick_params(axis='x', labelsize=25, rotation=0, length=12, width=1)
    ax.tick_params(axis='y', labelsize=20)

    ax.set_title(STAGE_NAME_LONG[model_name], fontsize=25)
    # Minor ticks every 2 layers for models above 9B, every layer otherwise.
    ax.xaxis.set_minor_locator(MultipleLocator(2 if model_size > 9 else 1))
    ax.tick_params(axis='x', which='minor', length=6, width=1)
    ax.set_xlim(0, n_layers)

    ax.set_xlabel('')
    ax.set_ylabel('')
    if j == 1:
        # Right column shares the left column's y scale; hide its ticks.
        ax.set_yticks([])

fig.supxlabel(r'Projected Layer ($\ell$)', fontsize=25, y=0.045, x=0.55)
fig.supylabel('Projected Token Probability', fontsize=25, x=0.03, y=0.56)
axs[0][0].legend(loc='upper left', fontsize=17, framealpha=0, bbox_to_anchor=(-0.04, 1.08))
fig.tight_layout()
os.makedirs('../figures/logit_lens', exist_ok=True)
fig.savefig('../figures/logit_lens/4_logit_lens_others.pdf', dpi=300, bbox_inches='tight')

# Plot Apathy

In [None]:
import function_vectors.logit_lens.utils.ll_helpers as llh

In [None]:
# Apathy of the MHA / MLP outputs per layer (from layer_analyses.csv,
# averaged per (layerid, hidden_name) and rescaled with llh.minmax; presumably
# min-max normalisation, see ll_helpers), overlaid with the equally rescaled
# top-token probability. One panel per model in a 5x2 grid, filled column by
# column.
palette = {
    'Correct': '#117733',
    'Incorrect': '#CC6677',
    'Top': '#5C5C5D',
}
metric_map = {'h_mha': 'MHA', 'h_mlp': 'MLP', 'Top': 'Top'}
n_rows = 5
fig, axs = plt.subplots(n_rows, 2, figsize=(6, 12))
for idx, model_name in enumerate(models):
    # Column-major placement: models 0-4 fill the left column, 5-9 the right.
    row, col = idx % n_rows, idx // n_rows
    ax = axs[row][col]
    legend = row == 0 and col == 0
    # x-tick spacing from THIS model's size suffix. Previously a `model_size`
    # left over from an earlier cell was used here.
    try:
        model_size = float(model_name.split('-')[-1][:-1])
    except ValueError:
        model_size = 0.0
    spaces = 32 if model_size > 13 else 8

    for data_name in data_names:
        la_fp = f'../function_vectors/logit_lens/results/token_probs/{model_name}/{data_name}/layer_analyses.csv'
        if not os.path.exists(la_fp):
            continue
        task_acc, n_layers = get_task_accuracy(model_name, data_name)
        layer_analyses = pd.read_csv(la_fp)
        layer_simple = (
            layer_analyses[['layerid', 'hidden_name', 'apathy']]
            .groupby(['layerid', 'hidden_name'])
            .agg('mean')
            .reset_index()
        )
        # .copy() so the column assignments below write to an independent
        # frame (no SettingWithCopyWarning on the filtered slice).
        layer_simple = layer_simple[layer_simple['hidden_name'].isin(['h_mha', 'h_mlp'])].copy()
        layer_simple['apathy'] = llh.minmax(layer_simple['apathy'])
        layer_simple['hidden_name'] = layer_simple['hidden_name'].map(metric_map)
        sns.lineplot(data=layer_simple, x='layerid', y='apathy', hue='hidden_name', style='hidden_name', ax=ax, linewidth=2, legend=legend, dashes=[(4, 1), (1, 1)])

        ll_fp = f'../function_vectors/logit_lens/results/token_probs/{model_name}/{data_name}/token_probabilities.csv'
        if not os.path.exists(ll_fp):
            continue
        token_probs = pd.read_csv(ll_fp)
        token_probs = token_probs.rename(columns={'pred_prob': 'Predicted', 'correct_prob': 'Correct', 'second_prob': 'Incorrect', 'top_prob': 'Top'})
        to_plot = ['layerid', 'Top']
        token_probs = pd.melt(
            token_probs.loc[:, to_plot], id_vars='layerid', var_name='metric_name', value_name='metric'
        )
        token_probs['metric'] = llh.minmax(token_probs['metric'])
        sns.lineplot(data=token_probs, x='layerid', y='metric', hue='metric_name', style='metric_name', palette=palette, linewidth=2, ax=ax, legend=legend)

        ax.set_title(f'{STAGE_NAME_LONG[model_name]}', fontsize=22)
        ax.set_yticks(np.linspace(0, 1, 6))
        ax.set_xticks(np.arange(0, n_layers + 1, spaces))
        ax.set_xlim([0, n_layers])
        if col == 1:
            ax.set_yticks([])
        ax.set_ylabel('')
        ax.set_xlabel('')
        ax.tick_params('both', labelsize=15)
        if legend:
            ax.legend(title='', framealpha=0, loc='upper left', bbox_to_anchor=(0, 1), fontsize=13)

# Figure-level decoration and saving happen once, after every panel is drawn.
fig.supxlabel('Layer', fontsize=25)
fig.supylabel('Token Probability and Apathy', fontsize=25)
fig.suptitle('Apathy and Token Probabilities on TruthfulQA', fontsize=20)
fig.tight_layout()
os.makedirs('../figures/logit_lens', exist_ok=True)
fig.savefig('../figures/logit_lens/apathy.pdf', bbox_inches='tight', dpi=300)

# Plot Figure 1

In [None]:
def get_baseline(model_name, data_name, results_dir='../function_vectors/results'):
    """Load the n-shot baseline results for one model/dataset pair.

    Reads ``<results_dir>/<model_name>/<data_name>/baseline_n_shots.json``.

    Args:
        model_name: Model key (sub-directory of ``results_dir``).
        data_name: Dataset key (sub-directory of the model directory).
        results_dir: Root of the function-vector results. Defaults to the
            path the notebook has always used.

    Returns:
        The parsed JSON content, or ``None`` if the file cannot be opened,
        read or parsed. The error is printed in that case.
    """
    file_path = os.path.join(results_dir, model_name, data_name, 'baseline_n_shots.json')
    try:
        with open(file_path, 'r', encoding='utf-8') as json_file:
            return json.load(json_file)
    except (OSError, ValueError) as e:
        # Best effort for missing or corrupt files only. ValueError covers
        # json.JSONDecodeError and UnicodeDecodeError. Programming errors
        # are no longer swallowed.
        print(f"An error occurred: {e}")
        return None

In [None]:
# Figure 1 (3x3 grid):
#   row 0 - logit-lens token probabilities,
#   row 1 - function-vector ICL accuracy vs. 0-/5-shot baselines,
#   row 2 - task-vector ICL accuracy vs. 0-/5-shot baselines.
# A panel whose result files are missing is left empty.
fig, axs = plt.subplots(3, 3, figsize=(8.5, 7))
# Dataset per row. Kept under its own name so that `data_names`, used by the
# earlier cells, is not clobbered.
fig1_data_names = ['tqa', 'present-past', 'present-past']  # alt: country-capital
# Per panel: (model key, HF model name used in function-vector filenames).
fv_tv_tp = [
    [
        ('olmo-7b', ''),
        ('qwen2-72b', ''),
        ('llama-7b', ''),
    ],
    [
        ('mistral3-7b', 'Mistral-7B-v0.3'),
        ('llama3i-70b', 'Llama-3.1-70B-Instruct'),
        ('llama2-7b', 'Llama-2-7b-hf'),
    ],
    [
        ('pythia-6.9b', 'pythia-6.9b'),
        ('qwen2.5-7b', 'Qwen2.5-7B'),
        ('gptj-6b', 'gpt-j-6b'),
    ],
]
# (top-heads count, lambda) strings embedded in the function-vector result
# filenames, indexed by column of row 1.
heads_lambdas = [
    ['016', '002'],  # column 0: mistral3-7b
    ['064', '001'],  # column 1: llama3i-70b
    ['002', '016'],  # column 2: llama2-7b
]
# Major x-tick step per panel.
layers = [
    [8, 16, 8],
    [8, 16, 8],
    [8, 7, 7],
]
palette = {
    'Function Vector': '#000000',
    'Task Vector': '#000000',
    '0-shot': '#D55E00',
    '5-shot': '#332288',
    # For logit lens
    'Correct': '#117733',
    'Incorrect': '#CC6677',
    'Top': '#5C5C5D',
}
# No stray '{' here: with text.usetex an unbalanced brace breaks rendering.
ylabels = ['(DoLA)\nToken Probability', '(Function Vectors)\nICL Accuracy', '(Task Vectors)\nICL Accuracy']

for i, row_models in enumerate(fv_tv_tp):
    data_name = fig1_data_names[i]
    for j, (model_name, hf_name) in enumerate(row_models):
        ax = axs[i][j]
        legend = j == 0

        # Layer count from the HF config (GPT-J names the key `n_layer`).
        with open(os.path.join(MODEL_FP_MAP[model_name], 'config.json')) as config_file:
            model_config = json.load(config_file)
        if model_name in ['gptj-6b']:
            n_layers = model_config['n_layer']
        else:
            n_layers = model_config['num_hidden_layers']

        if i == 0:
            ll_fp = f'../function_vectors/logit_lens/results/token_probs/{model_name}/{data_name}/token_probabilities.csv'
            if not os.path.exists(ll_fp):
                continue
            task_acc, n_layers = get_task_accuracy(model_name, data_name)
            token_probs = pd.read_csv(ll_fp)
            token_probs = token_probs.rename(columns={'pred_prob': 'Predicted', 'correct_prob': 'Correct', 'second_prob': 'Incorrect', 'top_prob': 'Top'})
            to_plot = ['layerid', 'Top', 'Correct', 'Incorrect']
            token_probs = pd.melt(
                token_probs.loc[:, to_plot], id_vars='layerid', var_name='metric_name', value_name='metric'
            )
            sns.lineplot(data=token_probs, ax=ax, x='layerid', y='metric', hue='metric_name', style='metric_name', palette=palette, legend=legend, linewidth=3)
            ax.set_xlabel('')
            ax.set_ylabel('')
        else:
            # Reset per panel so a missing file never re-plots the previous
            # panel's data.
            to_plot = None
            if i == 1:
                baseline_n_shots = get_baseline(model_name, data_name)
                saved_fname = f'{hf_name}_{data_name}_{heads_lambdas[j][0]}-top-heads_{heads_lambdas[j][1]}-lambda_50-test-samples'
                dp = f"../function_vectors/results/{model_name}/{data_name}/0_shot_w_FV"
                results_fp = os.path.join(dp, f"{saved_fname}.json")
                if baseline_n_shots is not None and os.path.isfile(results_fp):
                    with open(results_fp) as results_file:
                        zero_shot_fv = json.load(results_file)
                    # Constant baselines keyed like the FV results: '0'..'n_layers-1'.
                    layer_keys = [str(layer) for layer in range(n_layers)]
                    to_plot = pd.DataFrame({
                        'Function Vector': zero_shot_fv,
                        '0-shot': dict.fromkeys(layer_keys, baseline_n_shots['0']),
                        '5-shot': dict.fromkeys(layer_keys, baseline_n_shots['5']),
                    })
            elif i == 2:
                # Task-vector CSVs use underscores in the model key. Use a
                # separate name so `model_name` stays valid for the title lookup.
                tv_key = model_name.replace('-', '_')
                tv_fp = f'../icl_task_vectors/for_fv/{tv_key}.csv'
                if os.path.exists(tv_fp):
                    tv_df = pd.read_csv(tv_fp)
                    baselines = pd.read_csv(f'../icl_task_vectors/for_fv/{tv_key}_baselines.csv')
                    # Row 0 holds the 0-shot and row 1 the 5-shot baseline.
                    to_plot = pd.DataFrame({
                        'Task Vector': tv_df[data_name],
                        '0-shot': baselines.loc[0, data_name],
                        '5-shot': baselines.loc[1, data_name],
                    })
            if to_plot is None:
                continue
            sns.lineplot(to_plot, ax=ax, legend=legend, linewidth=3, palette=palette)

        if j == 0:
            ax.set_ylabel(ylabels[i], fontsize=18, rotation=90)
            ax.yaxis.label.set_position((-1, 0.5))
        # y ticks on the right; only the last column keeps its 0/1 ticks.
        ax.yaxis.tick_right()
        ax.set_yticks([0, 1])
        if j in (0, 1):
            ax.set_yticklabels([])
            ax.set_yticks([])

        ax.set_title(f'{STAGE_NAME_LONG[model_name]}', fontsize=20, rotation=0)
        ax.set_xticks(np.append(np.arange(0, n_layers, layers[i][j]), n_layers - 1))
        ax.set_xlim(0, n_layers - 1)
        ax.tick_params(axis='x', labelsize=15, length=6, width=1)
        ax.tick_params(axis='y', labelsize=15)
        if j == 0:
            # Row 0's legend sits a little higher with a slightly larger font.
            if i == 0:
                ax.legend(loc='upper left', fontsize=16, framealpha=0, bbox_to_anchor=(-.05, 1.0))
            else:
                ax.legend(loc='upper left', fontsize=15.5, framealpha=0, bbox_to_anchor=(-.05, 0.95))

# Blank suptitle, presumably to reserve headroom for the two group headers below.
plt.suptitle(' ', fontsize=25)
fig.text(0.365, .935, r'Unreliable Behavior', ha='center', fontsize=30)
fig.text(0.833, .935, r'Prior Work', ha='center', fontsize=30)
fig.supxlabel(r'Model layer ($\ell$)', fontsize=20, y=0.03)

fig.tight_layout(w_pad=0.0001)
os.makedirs('../figures', exist_ok=True)
fig.savefig('../figures/figure1.pdf', dpi=300, bbox_inches='tight')
