In [1]:
import torch
from transformer_lens import HookedTransformer
from functools import partial
import torch.nn.functional as F
from eap.metrics import logit_diff, direct_logit
import transformer_lens.utils as utils
from eap.graph import Graph
from eap.dataset import EAPDataset
from eap.attribute import attribute
import time
from rich import print as rprint
import pandas as pd
from eap.evaluate import evaluate_graph, evaluate_baseline,get_circuit_logits

In [None]:
LLAMA_2_7B_CHAT_PATH = "meta-llama/Llama-2-7b-chat-hf"
from transformers import LlamaForCausalLM

# Load the chat model into TransformerLens with every weight-preprocessing
# option (LayerNorm folding, weight/unembed centering) switched off.
model = HookedTransformer.from_pretrained(
    LLAMA_2_7B_CHAT_PATH,
    device="cuda",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
)

# Enable the optional hook points on the model config.
# NOTE(review): presumably required by the EAP graph code so that per-head
# q/k/v inputs, per-head attention results and MLP inputs can be hooked — confirm.
for hook_flag in ("use_split_qkv_input", "use_attn_result", "use_hook_mlp_in"):
    setattr(model.cfg, hook_flag, True)

In [5]:
# Build one clean / corrupted prompt pair that differs only in the subject.
clean_subject = 'Eiffel Tower'
corrupted_subject = 'the Great Walls'
# NOTE(review): "loacted" is a typo for "located". It is intentionally left
# untouched here because the string is the model input: changing it changes the
# tokenization and every attribution/evaluation number produced further down.
clean = f'The official currency of the country where {clean_subject} is loacted in is the'
corrupted = f'The official currency of the country where {corrupted_subject} is loacted in is the'
# The prompts are already interpolated f-strings, so they are compared directly
# (a further ``.format(...)`` would be a no-op and would raise if a subject ever
# contained a literal brace).
assert len(model.to_str_tokens(clean)) == len(model.to_str_tokens(corrupted)), (
    "clean and corrupted prompts must tokenize to the same number of tokens"
)
# Answer strings for the clean and the corrupted prompt, in that order.
labels = ['Euro','Chinese']
# Only the first sub-token id of each label is kept.
# NOTE(review): despite the names, these hold ids of the *answer* tokens
# ('Euro' / 'Chinese'), not of a country name.
country_idx = model.tokenizer(labels[0],add_special_tokens=False).input_ids[0]
corrupted_country_idx = model.tokenizer(labels[1],add_special_tokens=False).input_ids[0]
# Optional: persist the pair as a one-row CSV named 'capital_city.csv'.
# dataset = {k:[] for k in ['clean','country_idx', 'corrupted',  'corrupted_country_idx']}
# for k, v in zip(['clean', 'country_idx', 'corrupted', 'corrupted_country_idx'], [clean, country_idx, corrupted, corrupted_country_idx]):
#     dataset[k].append(v)
# df2 = pd.DataFrame.from_dict(dataset)
# df2.to_csv(f'capital_city.csv', index=False)

In [6]:
# One-example batch: a (1, 2) tensor holding [clean answer id, corrupted answer id],
# bundled with the clean and corrupted prompts as single-element lists.
label = torch.tensor([[country_idx, corrupted_country_idx]])
data = ([clean], [corrupted], label)

In [5]:
# ds = EAPDataset(filename='capital_city.csv',task='fact-retrieval')
# dataloader = ds.to_dataloader(1)

In [12]:
g = Graph.from_model(model)
start_time = time.time()

# Attribute using the model, graph, clean / corrupted data and labels, as well as a metric.
attribution_metric = partial(logit_diff, loss=True, mean=True)
attribute(model, g, data, attribution_metric, method='EAP-IG-case', ig_steps=100)
# Alternative configurations kept for reference:
# attribute(model, g, data, partial(direct_logit, loss=True, mean=True), method='EAP-IG-case', ig_steps=30)
# attribute(model, g, dataloader, partial(logit_diff, loss=True, mean=True), method='EAP-IG', ig_steps=30)

# NOTE(review): presumably keeps the 5000 highest-|score| edges and then drops
# nodes left without any kept edge — confirm against eap.graph.
g.apply_topn(5000, absolute=True)
g.prune_dead_nodes()

# Persist the pruned graph as JSON and render it to a PNG with Graphviz `dot`.
g.to_json('graph.json')
gz = g.to_graphviz()
gz.draw('graph.png', prog='dot')

# Elapsed wall-clock time covers attribution, pruning, export and drawing
# (the printed message reads "program execution time: ... seconds").
end_time = time.time()
execution_time = end_time - start_time
print(f"程序执行时间：{execution_time}秒")

100%|██████████| 1592881/1592881 [00:01<00:00, 1062625.82it/s]


程序执行时间：43.55915355682373秒


In [None]:
def get_component_logits(logits, model, answer_token, top_k=10):
    """Print how the final-position prediction ranks ``answer_token``, plus the top tokens.

    Args:
        logits: Logits for a single prompt. A leading batch dimension is removed
            with ``utils.remove_batch_dim``; the last row of the result is then
            treated as the next-token distribution (assumes (seq, vocab) after
            the batch dimension is removed — TODO confirm for other callers).
        model: Object exposing ``to_string`` to decode token ids.
        answer_token: Id of the expected answer token, as an int or a
            one-element tensor. It must occur exactly once in the vocabulary
            ranking, otherwise ``.item()`` raises.
        top_k: Number of highest-probability tokens to list. Values larger than
            the vocabulary size are clamped instead of raising ``IndexError``.

    Returns:
        None. The report is written with ``rprint`` / ``print``.
    """
    logits = utils.remove_batch_dim(logits)
    # Only the final position is reported, so softmax just that row rather than
    # the whole sequence (row-wise softmax gives the same values).
    final_logits = logits[-1]
    token_probs = final_logits.softmax(dim=-1)
    answer_str_token = model.to_string(answer_token)
    sorted_token_probs, sorted_token_values = token_probs.sort(descending=True)
    # Rank of the answer = position of its id in the descending-probability order.
    correct_rank = (sorted_token_values == answer_token).nonzero().item()
    # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places.
    # rprint gives rich text printing
    rprint(
        f"Performance on answer token:\n[b]Rank: {correct_rank: <8} Logit: {final_logits[answer_token].item():5.2f} Prob: {token_probs[answer_token].item():6.2%} Token: |{answer_str_token}|[/b]"
    )
    for i in range(min(top_k, len(sorted_token_values))):
        print(
            f"Top {i}th token. Logit: {final_logits[sorted_token_values[i]].item():5.2f} Prob: {sorted_token_probs[i].item():6.2%} Token: |{model.to_string(sorted_token_values[i])}|"
        )

In [13]:
# Run the pruned circuit on the prompt pair and report how it ranks the clean answer.
logits = get_circuit_logits(model, g, data)
# First (only) row of the token batch for 'Euro', tokenized without a BOS token.
answer_token = model.to_tokens('Euro', prepend_bos=False)[0]
get_component_logits(logits, model, answer_token=answer_token, top_k=5)

Top 0th token. Logit: 16.94 Prob: 56.56% Token: |Euro|
Top 1th token. Logit: 15.96 Prob: 21.39% Token: |French|
Top 2th token. Logit: 14.06 Prob:  3.18% Token: |_|
Top 3th token. Logit: 13.95 Prob:  2.85% Token: |euro|
Top 4th token. Logit: 13.91 Prob:  2.74% Token: |Eu|


In [69]:
# `dataloader` is only ever created in a commented-out cell, so on a fresh
# kernel this cell raised NameError. Wrap the in-memory (clean, corrupted, label)
# batch in a list so it can be iterated like a one-batch dataloader.
# NOTE(review): assumes evaluate_baseline / evaluate_graph only iterate over
# their dataloader argument and unpack (clean, corrupted, label) per batch —
# confirm against eap.evaluate.
dataloader = [data]
# Per-example logit difference (no loss negation, no averaging); the mean is taken below.
eval_metric = partial(logit_diff, loss=False, mean=False)
baseline = evaluate_baseline(model, dataloader, eval_metric).mean().item()
results = evaluate_graph(model, g, dataloader, eval_metric).mean().item()
print(f"Original performance was {baseline}; the circuit's performance is {results}")

100%|██████████| 1/1 [00:00<00:00,  8.79it/s]
100%|██████████| 1/1 [00:00<00:00,  8.91it/s]
100%|██████████| 1/1 [00:00<00:00,  6.82it/s]

Original performance was 10.043922424316406; the circuit's performance is 6.337347984313965



