In [3]:
import os
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from omegaconf import OmegaConf
from catboost import CatBoostClassifier
from transformers import AutoTokenizer
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
import evaluate
from transformer_lens import HookedTransformer

from metrics import (
    EM_compute,
    has_answer, # InAcc
    F1_compute,
)

bleu_metric = evaluate.load("bleu")
rouge_metric = evaluate.load("rouge")

In [2]:
cfg = OmegaConf.load("config.yaml")
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(cfg.model_id)
model = HookedTransformer.from_pretrained(
    cfg.model_id,
    device=device,
    tokenizer=tokenizer,
)
model.set_use_attn_result(True)
model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Llama-3.1-8B-Instruct into HookedTransformer


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-31): 32 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_att

In [3]:
nq_dataset = pd.read_csv('data/adaptive_rag_natural_questions.csv')
wiki_dataset = pd.read_csv('data/adaptive_rag_2wikimultihopqa.csv')
hotpot_dataset = pd.read_csv('data/adaptive_rag_hotpotqa.csv')
musique_dataset = pd.read_csv('data/adaptive_rag_musique.csv')

In [4]:
nq_train, nq_test = train_test_split(nq_dataset, test_size=0.2, random_state=cfg.seed)
nq_train_idx = nq_train.index
nq_test_idx = nq_test.index

wiki_train, wiki_test = train_test_split(wiki_dataset, test_size=0.2, random_state=cfg.seed)
wiki_train_idx = wiki_train.index
wiki_test_idx = wiki_test.index

hotpot_train, hotpot_test = train_test_split(hotpot_dataset, test_size=0.2, random_state=cfg.seed)
hotpot_train_idx = hotpot_train.index
hotpot_test_idx = hotpot_test.index

musique_train, musique_test = train_test_split(musique_dataset, test_size=0.2, random_state=cfg.seed)
musique_train_idx = musique_train.index
musique_test_idx = musique_test.index

In [5]:
def store_activations(activation_dict):
    def hook_fn(activation, hook):
        activation_dict[hook.name] = activation.detach()
    return hook_fn

In [6]:
def make_data(dataset, activation, layer):
    X = torch.empty((dataset.shape[0], 4096), dtype=torch.float32)
    y = torch.empty((dataset.shape[0]), dtype=torch.int64)
    
    idx = 0
    activation_store = {}
    
    for _, row in tqdm(dataset.iterrows(), total=len(dataset)):
    
        text = tokenizer.apply_chat_template([
            {"role": "user", "content": f"question: {row['question']}\nanswer: {row['our_answer_wo_context']}"}
        ], tokenize=False)
        
        tokens = model.to_tokens(text, prepend_bos=False).to(device)
        with torch.no_grad():
            # (1, seq_len, vocab_size)
            logits = model.run_with_hooks(
                tokens,
                return_type="logits",
                fwd_hooks=[(f"blocks.{i}.{activation}", store_activations(activation_store)) 
                    for i in range(model.cfg.n_layers)],
            )
        resid = activation_store[f"blocks.{layer}.{activation}"] # last hidden_state before lm head (1, seq_len, hidden_size)
        final_logits = resid[:, -1, :].squeeze(dim=0) # (vocab_size)
    
        X[idx] = final_logits.cpu()
        f1 = F1_compute(row['reference'].split(';'), row['our_answer_wo_context'])
        y[idx] = 1 if f1 >= 0.55 else 0
        idx += 1
    return X, y

In [7]:
def process_dataframe(df, pred_col, gt_col):
    total_has_answer = 0
    total_em = 0
    total_f1 = 0
    count = len(df)
    has_answer_arr = []
    em_arr = []
    f1_arr = []

    predictions_for_bleu_rouge = []
    references_for_bleu_rouge = []

    for _, row in df.iterrows():
        prediction = row[pred_col]
        ground_truths = row[gt_col].split(';')

        has_ans = has_answer(ground_truths, prediction)
        has_answer_arr.append(has_ans)

        em = EM_compute(ground_truths, prediction)
        em_arr.append(em)

        f1 = F1_compute(ground_truths, prediction)
        f1_arr.append(f1)

        total_has_answer += has_ans
        total_em += em
        total_f1 += f1

        predictions_for_bleu_rouge.append(prediction)
        references_for_bleu_rouge.append(ground_truths)

    mean_has_answer = total_has_answer / count if count > 0 else 0
    mean_em = total_em / count if count > 0 else 0
    mean_f1 = total_f1 / count if count > 0 else 0

    bleu_results = bleu_metric.compute(
        predictions=predictions_for_bleu_rouge,
        references=references_for_bleu_rouge
    )
    rouge_results = rouge_metric.compute(
        predictions=predictions_for_bleu_rouge,
        references=references_for_bleu_rouge
    )
    bleu_score = bleu_results["bleu"]
    rouge_score = rouge_results["rougeL"]

    return mean_has_answer, mean_em, mean_f1, bleu_score, rouge_score, df

# Make activations data

In [4]:
for activation in ('hook_resid_post'): #, 'hook_resid_pre', 'hook_resid_mid'):
    for layer in (2, 12, 28, 31):
        for dataset, dataset_name in zip(
        (nq_dataset, wiki_dataset, hotpot_dataset, musique_dataset),
        ('NQ', 'wiki', 'hotpot', 'musique')
    ):
            X_path = f'../data/2stage_forward_data/X_{dataset_name}_{activation}_{layer}_layer.pt'
            y_path = f'../data/2stage_forward_data/y_{dataset_name}_{activation}_{layer}_layer.pt'
            if os.path.exists(X_path):
                X = torch.load(X_path)
                y = torch.load(y_path)
            else:
                X, y = make_data(dataset, activation, layer)
                torch.save(X, X_path)
                torch.save(y, y_path)

# Models

In [5]:
rag_calls = {}
pred_datasets = {}

for dataset, dataset_name in zip(
        (nq_dataset, wiki_dataset, hotpot_dataset, musique_dataset),
        ('NQ', 'wiki', 'hotpot', 'musique')
    ):
    train_dataset, test_dataset = train_test_split(dataset, test_size=0.2, random_state=cfg.seed)
    train_idx = train_dataset.index
    test_idx = test_dataset.index
    rag_calls[dataset_name] = {
        'Never RAG': 0, 
        'Dola': 0,
        'Always RAG': test_dataset.shape[0], 
        'Always RAG (Supportive)': test_dataset[test_dataset["is_supportive"] == True].shape[0],
        'Always RAG (Non-Supportive)': test_dataset[test_dataset["is_supportive"] == False].shape[0]
    }
    
    def call_rag(preds_df, preds, method_name, rag_calls):
        answers = []
        for index, pred_class in tqdm(enumerate(preds)):
            
            if pred_class == 0: # call rag and get context
                final_answer = preds_df.iloc[index]['our_answer_w_context']
            else: # answer is correct, do not call rag
                final_answer = preds_df.iloc[index]['our_answer_wo_context']
    
            answers.append(final_answer)
            
        preds_df[method_name] = answers
        rag_calls[method_name] = list(preds).count(0)
        
        return preds_df

    for activation in ('hook_resid_post'): #, 'hook_resid_pre', 'hook_resid_mid'):
        for layer in (2, 12, 28, 31):
            
            X_path = f'data/2stage_forward_data/X_{dataset_name}_{activation}_{layer}_layer.pt'
            y_path = f'data/2stage_forward_data/y_{dataset_name}_{activation}_{layer}_layer.pt'
            if os.path.exists(X_path):
                X = torch.load(X_path)
                y = torch.load(y_path)
            else:
                continue
                
            X_train = X[train_idx]
            y_train = y[train_idx]
            X_test = X[test_idx]
            y_test = y[test_idx]            

            # Unbalanced data + LogisticRegression
            clf = LogisticRegression(max_iter=1000)
            clf.fit(X_train, y_train)
            y_pred = clf.predict(X_test)
            test_dataset = call_rag(
                test_dataset, 
                y_pred, 
                f'log_reg_{layer}_layer_{activation}', 
                rag_calls[dataset_name]
            )
            # Downsampling + LogisticRegression
            class_0_indices = np.where(y_train == 0)[0]
            class_1_indices = np.where(y_train == 1)[0]
            
            downsampled_class_0_indices = np.random.choice(class_0_indices, size=int(len(class_1_indices)*1.5), replace=False)
            balanced_indices = np.concatenate([downsampled_class_0_indices, class_1_indices])
            np.random.shuffle(balanced_indices)
            
            X_train_balanced = X_train[balanced_indices]
            y_train_balanced = y_train[balanced_indices]

            clf = LogisticRegression(max_iter=1000)
            clf.fit(X_train_balanced, y_train_balanced)
            y_pred = clf.predict(X_test)
            test_dataset = call_rag(
                test_dataset, 
                y_pred, 
                f'log_reg_downsampling_{layer}_layer_{activation}',
                rag_calls[dataset_name]
            )
            # Catboost
            pca = PCA(n_components=512)
            X_reduced = pca.fit_transform(X)
            X_reduced_train = X_reduced[train_idx]
            X_reduced_test = X_reduced[test_idx]
            
            clf = CatBoostClassifier(
                iterations=1000,
                depth=6,
                learning_rate=0.05,
                loss_function='CrossEntropy',
                verbose=100,
                task_type="GPU"
            )
            clf.fit(X_reduced_train, y_train.numpy(), eval_set=(X_reduced_test, y_test.numpy()), early_stopping_rounds=50, verbose=False)
            y_pred = clf.predict(X_reduced_test)
            test_dataset = call_rag(
                test_dataset, 
                y_pred, 
                f'catboost_{layer}_layer_{activation}',
                rag_calls[dataset_name]
            )
            
    pred_datasets[dataset_name] = test_dataset.copy()
    

# Metrics

In [61]:
methods = {
    "Never RAG": lambda df: (df, "our_answer_wo_context"),
    "Always RAG": lambda df: (df, "our_answer_w_context"),
    "Always RAG (Supportive)": lambda df: (
        df[df["is_supportive"] == True].copy(), 
        "our_answer_w_context"
    ),
    "Always RAG (Non-Supportive)": lambda df: (
        df[df["is_supportive"] == False].copy(), 
        "our_answer_w_context"
    ),
    "Dola": lambda df: (df, "our_answer_dola"),
    "log_reg_2_layer_hook_resid_post": lambda df: (df, "log_reg_2_layer_hook_resid_post"),
    "log_reg_12_layer_hook_resid_post": lambda df: (df, "log_reg_12_layer_hook_resid_post"),
    "log_reg_28_layer_hook_resid_post": lambda df: (df, "log_reg_28_layer_hook_resid_post"),
    "log_reg_31_layer_hook_resid_post": lambda df: (df, "log_reg_31_layer_hook_resid_post"),
    "log_reg_downsampling_2_layer_hook_resid_post": lambda df: (df, "log_reg_downsampling_2_layer_hook_resid_post"),
    "log_reg_downsampling_12_layer_hook_resid_post": lambda df: (df, "log_reg_downsampling_12_layer_hook_resid_post"),
    "log_reg_downsampling_28_layer_hook_resid_post": lambda df: (df, "log_reg_downsampling_28_layer_hook_resid_post"),
    "log_reg_downsampling_31_layer_hook_resid_post": lambda df: (df, "log_reg_downsampling_31_layer_hook_resid_post"),
    "catboost_2_layer_hook_resid_post": lambda df: (df, "catboost_2_layer_hook_resid_post"),
    "catboost_12_layer_hook_resid_post": lambda df: (df, "catboost_12_layer_hook_resid_post"),
    "catboost_28_layer_hook_resid_post": lambda df: (df, "catboost_28_layer_hook_resid_post"),
    "catboost_31_layer_hook_resid_post": lambda df: (df, "catboost_31_layer_hook_resid_post"),
}

In [62]:
rows = list(methods.keys())
columns = pd.MultiIndex.from_product(
    [pred_datasets.keys(), ["F1", "Exact Match", "InAcc", "BLEU", "ROUGE", "RAG calls"]],
    names=["Dataset", "Metric"]
)

data = []

In [63]:
for method_name, method_func in methods.items():
    row_data = []
    
    for dataset_name, df_test in pred_datasets.items():
        df_local, pred_col = method_func(df_test)
        gt_col = 'reference'

        inacc_score, mean_em, mean_f1, bleu_score, rouge_score, _ = process_dataframe(
            df_local, pred_col, gt_col
        )
        rag_calls_value = rag_calls.get(dataset_name, {}).get(method_name, 0)

        row_data.extend([mean_f1, mean_em, inacc_score, bleu_score, rouge_score, rag_calls_value])

    data.append(row_data)

metrics_df = pd.DataFrame(data, index=rows, columns=columns)

In [64]:
metrics_df

Dataset,NQ,NQ,NQ,NQ,NQ,NQ,wiki,wiki,wiki,wiki,...,hotpot,hotpot,hotpot,hotpot,musique,musique,musique,musique,musique,musique
Metric,F1,Exact Match,InAcc,BLEU,ROUGE,RAG calls,F1,Exact Match,InAcc,BLEU,...,InAcc,BLEU,ROUGE,RAG calls,F1,Exact Match,InAcc,BLEU,ROUGE,RAG calls
Never RAG,0.430445,0.291396,0.434794,0.069344,0.425247,0,0.26905,0.1095,0.3245,0.098352,...,0.256269,0.052803,0.282901,0,0.099673,0.025256,0.048512,0.0144,0.098662,0
Always RAG,0.427553,0.334957,0.419372,0.108941,0.423733,3696,0.175587,0.1345,0.163,0.12741,...,0.231194,0.093501,0.283401,1994,0.078807,0.036759,0.042261,0.020306,0.078343,3999
Always RAG (Supportive),0.662747,0.54362,0.697329,0.170131,0.657159,1685,0.367547,0.303983,0.362683,0.24343,...,0.508816,0.23312,0.595278,397,0.278069,0.169776,0.186567,0.083195,0.277037,536
Always RAG (Non-Supportive),0.230486,0.160119,0.186474,0.045728,0.228499,2011,0.115465,0.081418,0.10046,0.091689,...,0.162179,0.064518,0.205082,1597,0.047965,0.016171,0.019925,0.012872,0.047458,3463
Dola,0.432136,0.295184,0.439665,0.069343,0.427191,0,0.271028,0.11,0.3265,0.097489,...,0.249248,0.050875,0.276229,0,0.097032,0.025006,0.046262,0.013919,0.096094,0
log_reg_2_layer_hook_resid_post,0.449517,0.357143,0.438582,0.112788,0.445165,2641,0.191811,0.1415,0.183,0.144931,...,0.236209,0.093744,0.288911,1945,0.078807,0.036759,0.042261,0.020306,0.078343,3999
log_reg_12_layer_hook_resid_post,0.518909,0.395833,0.50974,0.134425,0.51254,2258,0.283889,0.1675,0.297,0.173257,...,0.297392,0.109928,0.353082,1540,0.104951,0.052263,0.062766,0.030186,0.103227,3792
log_reg_28_layer_hook_resid_post,0.530393,0.404491,0.518398,0.140147,0.523665,2231,0.307634,0.182,0.3235,0.1851,...,0.302909,0.109562,0.366072,1497,0.107694,0.053263,0.065766,0.030797,0.106132,3772
log_reg_31_layer_hook_resid_post,0.530393,0.404491,0.518398,0.140147,0.523665,2231,0.307634,0.182,0.3235,0.1851,...,0.302909,0.109562,0.366072,1497,0.107694,0.053263,0.065766,0.030797,0.106132,3772
log_reg_downsampling_2_layer_hook_resid_post,0.452923,0.36066,0.441558,0.112956,0.448359,2569,0.233585,0.1375,0.2495,0.146103,...,0.250752,0.10037,0.305808,1621,0.0928,0.042761,0.049512,0.025052,0.09213,3492
