In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from transformer_lens import HookedTransformer
import pandas as pd
import torch
from tqdm import tqdm
from pathlib import Path

In [40]:
def load_model(adapter_path, hf_model_name, translens_model_name, scratch_cache_dir = None):
    base_model = AutoModelForCausalLM.from_pretrained(hf_model_name, cache_dir=scratch_cache_dir)
    model_with_lora = PeftModel.from_pretrained(base_model, adapter_path)
    model_with_lora = model_with_lora.merge_and_unload()
    model = HookedTransformer.from_pretrained(model_name=translens_model_name, hf_model=model_with_lora, cache_dir=scratch_cache_dir)  

    model.cfg.use_split_qkv_input = True
    model.cfg.use_attn_result = True
    model.cfg.use_hook_mlp_in = True
    model.cfg.ungroup_grouped_query_attention = True
    return model

In [41]:
def evaluate_prompt(model: HookedTransformer, prompt: str, max_new_tokens = 5):
    tokens = model.to_tokens(prompt)
    generated_tokens = tokens

    for _ in range(max_new_tokens):
        logits, cache = model.run_with_cache(generated_tokens)
        next_token_logits = logits[0, -1, :]
        probs = torch.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens = torch.cat((generated_tokens, next_token.unsqueeze(0)), dim=1)

    generated_text = model.to_string(generated_tokens[0])
    cleaned_text = generated_text.replace('<|endoftext|>', '')
    cleaned_text = cleaned_text.replace(prompt, '')
    return cleaned_text

In [70]:
def run_eval(model: HookedTransformer, df: pd.DataFrame):
    correct = 0
    incorrect = 0
    for index, row in tqdm(df.iterrows()):
        prompt = row['clean']
        correct_answer = row['label']
        pred_answer = evaluate_prompt(model, prompt)
        if str(correct_answer) == str(pred_answer):
            correct += 1 
        else:
            incorrect +=1
    
    perc = correct / (correct + incorrect)
    return correct, incorrect, perc

In [None]:
scratch_cache_dir = "/mnt/fast0/rje41/.cache/huggingface"    
hf_model_name = "EleutherAI/pythia-1.4B-deduped"
translens_model_name="pythia-1.4B-deduped"
eap_dataset_path = "../datasets/add_sub_test.csv"

In [None]:
folder = Path("../fine-tuning/add_sub/results2")
all_paths = list(folder.rglob("*"))
adapter_paths = [p for p in all_paths if p.is_dir()]

results = []
df = pd.read_csv('../datasets/add_sub.csv')
for adapter_path in adapter_paths:
    model = load_model(adapter_path, hf_model_name, translens_model_name, scratch_cache_dir)
    correct, incorrect, perc = run_eval(model, df)

    results.append({
        'percentage_correct': perc,
        'adapter_path': adapter_path
    })

In [73]:
scratch_cache_dir = "/mnt/fast0/rje41/.cache/huggingface"    
hf_model_name = "EleutherAI/pythia-1.4B-deduped"
translens_model_name="pythia-1.4B-deduped"
eap_dataset_path = "../datasets/add_sub_test.csv"
adapter_path = "../fine-tuning/add_sub/results2/checkpoint-10000"
model = load_model(adapter_path, hf_model_name, translens_model_name, scratch_cache_dir)

Loaded pretrained model pythia-1.4B-deduped into HookedTransformer


In [55]:
df = pd.read_csv('../datasets/add_sub_test.csv')

print(df.head())

       clean  corrupted  label
0  46 + 20 =  46 + 64 =     66
1  95 - 26 =  95 - 28 =     69
2  68 + 45 =  68 + 14 =    113
3  16 + 59 =  16 + 30 =     75
4  54 + 66 =   54 + 7 =    120


In [74]:
correct, incorrect, perc = run_eval(model, df)

500it [04:41,  1.78it/s]


In [76]:
perc

0.396