In [None]:
import torch as t
from fancy_einsum import einsum
import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer
from rich import print as rprint
import json

In [None]:
# This notebook only runs inference and edits weights by hand, so autograd is
# switched off globally for every subsequent cell.
t.set_grad_enabled(mode=False)
print("Disabled automatic differentiation")

In [None]:
# Weight-processing flags passed to HookedTransformer.from_pretrained for GPT-2 small.
gpt2_load_kwargs = dict(
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)
gpt2_small_model = HookedTransformer.from_pretrained("gpt2-small", **gpt2_load_kwargs)

# Report the device returned by TransformerLens' utils.get_device().
device: t.device = utils.get_device()
print(device)

In [None]:
IOI_DATA_PATH = "../gpt2_small_data/IOI.json"
N_SAMPLES = 1  # number of IOI prompts to analyse
ANCHOR_TOKEN = " A"  # anchored (distractor) answer compared against each sample's label

# JSON is UTF-8 by specification; be explicit so the platform default codec is never used.
with open(IOI_DATA_PATH, "r", encoding="utf-8") as f:
    data_samples = json.load(f)

selected_samples = data_samples[:N_SAMPLES]
anchored_data = [sample["sentence"] for sample in selected_samples]
# Each pair is (correct answer, anchored answer).
answers = [(" " + sample["label"], ANCHOR_TOKEN) for sample in selected_samples]

rprint(anchored_data[0])
rprint(answers[0])
# Tokenising a pair yields one row per answer; .T turns that into one
# (correct, anchored) row per sample. Assumes each answer is a single GPT-2
# token — TODO confirm for every label.
answer_tokens = t.concat([
    gpt2_small_model.to_tokens(names, prepend_bos=False).T for names in answers
])
rprint(answer_tokens[0])
print(len(data_samples))

In [None]:
# Tokenise the first prompt (with BOS) and keep the unedited logits and activation cache.
first_prompt = anchored_data[0]
tokens = gpt2_small_model.to_tokens(first_prompt, prepend_bos=True)
original_logits, gpt2_cache = gpt2_small_model.run_with_cache(tokens)

In [None]:
# Every layer index of the loaded model, in ascending order.
layers_ls = list(range(gpt2_small_model.cfg.n_layers))

In [None]:
def mlp_lens_detector(model, model_cache, anchored_pos, correct_pos, layer_index, position=-1):
    """Logit-lens one MLP layer's output for the anchored vs. the correct token.

    Takes ``model_cache["mlp_out", layer_index]`` at batch 0 and ``position``, passes it
    through ``model.ln_final``, and projects it onto the unembedding columns of both
    tokens (adding their unembedding bias). Both scores are printed.

    Args:
        model: object exposing ``W_U`` (d_model, d_vocab), ``b_U`` and ``ln_final``.
        model_cache: cache indexable as ``model_cache["mlp_out", layer_index]``,
            indexed as (batch, pos, d_model).
        anchored_pos: vocabulary id of the anchored (distractor) token.
        correct_pos: vocabulary id of the correct token.
        layer_index: MLP layer to inspect.
        position: sequence position to read (default: last).

    Returns:
        ``(anchored - correct, correct, anchored)`` logit-lens scores as Python floats.
    """
    mlp_out = model_cache["mlp_out", layer_index]
    # Normalise once; both token scores read the same vector.
    normed = model.ln_final(mlp_out[0, position, :])
    correct_pos_mlp_lens = normed @ model.W_U[:, correct_pos] + model.b_U[correct_pos]  # logit lens for correct pos
    anchored_pos_mlp_lens = normed @ model.W_U[:, anchored_pos] + model.b_U[anchored_pos]  # logit lens for anchored pos
    print(f"Correct pos mlp lens: {correct_pos_mlp_lens}")
    print(f"Anchored pos mlp lens: {anchored_pos_mlp_lens}")

    return (anchored_pos_mlp_lens - correct_pos_mlp_lens).item(), correct_pos_mlp_lens.item(), anchored_pos_mlp_lens.item()

In [None]:
def coefficient_dim_contribution_detector(model, model_cache, anchored_pos, correct_pos, layer_index, gap_threshold=4):
    coefficients_final_pos = model_cache["post", layer_index].squeeze(0)[-1]
    W_out = model.W_out[layer_index]

    contributions_ls = []
    for i in range(coefficients_final_pos.shape[0]):
        coefficient_abs = t.abs(coefficients_final_pos[i])
        value_2_norm = t.linalg.vector_norm(W_out[i], ord=2)
        contributions_ls.append(coefficient_abs * value_2_norm)

    contributions = t.stack(contributions_ls)

    topk_contributions, topk_indices = t.topk(contributions, k=10, largest=True)

    topk_correct_pos_contributions = [] # abs value of correct pos contribution
    for i in range(topk_indices.shape[0]):
        correct_pos_contribution = einsum("d_model, d_model->", model.W_U[:, correct_pos], model.ln_final(coefficients_final_pos[topk_indices[i]] * W_out[topk_indices[i]] + model.b_out[layer_index])) + model.b_U[correct_pos]
        topk_correct_pos_contributions.append(correct_pos_contribution)

    topk_anchor_pos_contributions = [] # abs value of anchored pos contribution
    for i in range(topk_indices.shape[0]):
        anchor_pos_contribution = einsum("d_model, d_model->", model.W_U[:, anchored_pos], model.ln_final(coefficients_final_pos[topk_indices[i]] * W_out[topk_indices[i]] + model.b_out[layer_index])) + model.b_U[anchored_pos]
        topk_anchor_pos_contributions.append(anchor_pos_contribution)

    topk_words_ls = [] # visualize top k words stored in model's W_out based on specific dimension
    for i in range(topk_indices.shape[0]):
        per_value_words = []
        value_unembed = einsum("d_model, d_model d_vocab-> d_vocab", model.ln_final(W_out[topk_indices[i]] + model.b_out[layer_index]), model.W_U) + model.b_U
        prob_unembed = value_unembed.softmax(dim=-1)
        prob_unembed_values, prob_unembed_indices = prob_unembed.sort(descending=True)
        for j in range(20):
            per_value_words.append(model.to_string(prob_unembed_indices[j]))
        
        topk_words_ls.append(per_value_words)

    diff_contributions = t.stack(topk_anchor_pos_contributions) - t.stack(topk_correct_pos_contributions)
    large_gap_indices = t.where(diff_contributions > gap_threshold)[0]
    large_gap_topk_indices = topk_indices[large_gap_indices]

    large_gap_topk_words_ls = [topk_words_ls[i] for i in large_gap_indices]
    
    return large_gap_topk_indices, diff_contributions[large_gap_indices].tolist(), large_gap_topk_words_ls, topk_contributions.tolist(), topk_indices.tolist(), topk_correct_pos_contributions, topk_anchor_pos_contributions


In [None]:
def fix_mlp(model, gap_dim, layer_index, anchored_pos, correct_pos, alpha_1 = 1, alpha_2 = 8):
    """Edit one MLP value vector in place so that it favours the correct token over the anchored one.

    Applies the following update to ``model.blocks[layer_index].mlp.W_out``:

        W_out[gap_dim] <- W_out[gap_dim] - alpha_1 * W_U[:, anchored_pos] + alpha_2 * W_U[:, correct_pos]

    Args:
        model: object exposing ``blocks[l].mlp.W_out`` and ``W_U``.
        gap_dim: row (neuron) of ``W_out`` to edit; an int or a 0-d integer tensor.
        layer_index: MLP layer containing the neuron.
        anchored_pos: vocabulary id whose unembedding direction is subtracted.
        correct_pos: vocabulary id whose unembedding direction is added.
        alpha_1: scale of the anchored-token subtraction.
        alpha_2: scale of the correct-token addition.

    Returns:
        ``(original_W_out, model)``. The first item is a detached copy of the layer's
        ``W_out`` taken before the edit; the second is the same model object, now modified.
    """
    mlp = model.blocks[layer_index].mlp
    original_W_out = mlp.W_out.detach().clone()
    # W_out is normally a Parameter; an in-place write to it is only legal with autograd
    # off, so do not rely on the notebook's global switch.
    with t.no_grad():
        mlp.W_out[gap_dim, :] = mlp.W_out[gap_dim, :] - alpha_1 * model.W_U[:, anchored_pos] + alpha_2 * model.W_U[:, correct_pos]
    return original_W_out, model

In [None]:
def _final_pos_top(logits, n_words):
    """Softmax the last-position logits; return (all probs, top probs, top ids), with at least one top entry."""
    probs = logits.squeeze(0)[-1].softmax(dim=-1)
    top_values, top_ids = probs.topk(max(1, min(n_words, probs.shape[-1])))
    return probs, top_values, top_ids


def prediction_compare(modified_model, tokens, original_logits, correct_pos, n_words=20):
    """Compare the next-token prediction of an edited model with the unedited run.

    Args:
        modified_model: model after ``fix_mlp``. It is called directly for logits and
            is also used for ``to_string``.
        tokens: token ids fed to ``modified_model``.
        original_logits: logits from the unedited run, with a leading batch dim of size 1
            (assumed (1, seq, d_vocab) — TODO confirm).
        correct_pos: vocabulary id of the correct answer.
        n_words: how many top tokens to list per model (clamped to d_vocab).

    Returns:
        A tuple of seven items:
        - whether the edited model's top-1 token is ``correct_pos``;
        - the edited model's top-1 prob;
        - the original top-1 prob;
        - the edited model's top words;
        - the original top words;
        - the edited model's prob of ``correct_pos``;
        - the original prob of ``correct_pos``.
    """
    # Only logits are needed; run_with_cache would store every activation for nothing.
    modified_logits = modified_model(tokens)
    modified_probs, modified_top_values, modified_top_ids = _final_pos_top(modified_logits, n_words)
    original_probs, original_top_values, original_top_ids = _final_pos_top(original_logits, n_words)

    top_20_words_modified = [modified_model.to_string(token_id) for token_id in modified_top_ids[:n_words]]
    top_20_words_original = [modified_model.to_string(token_id) for token_id in original_top_ids[:n_words]]

    return (
        modified_top_ids[0].item() == correct_pos,
        modified_top_values[0],
        original_top_values[0],
        top_20_words_modified,
        top_20_words_original,
        modified_probs[correct_pos],
        original_probs[correct_pos],
    )

In [None]:
print(f"Sent: {gpt2_small_model.to_string(tokens)}")

# answer_tokens[0] holds the (correct, anchored) token ids of the first sample.
correct_token_id = answer_tokens[0][0].item()
anchored_token_id = answer_tokens[0][1].item()

# Walk the layers from last to first.
for layer_index in reversed(layers_ls):

    print(f"Layer: {layer_index}")

    mlp_lens, correct_pos_lens, anchored_pos_lens = mlp_lens_detector(
        gpt2_small_model, gpt2_cache, anchored_token_id, correct_token_id, layer_index
    )

    print(f"MLP logit difference at layer {layer_index}: {mlp_lens}")

    (
        large_gap_topk_indices,
        diff_contributions,
        large_gap_topk_words_ls,
        topk_contributions,
        topk_indices,
        topk_correct_pos_contributions,
        topk_anchor_pos_contributions,
    ) = coefficient_dim_contribution_detector(
        gpt2_small_model, gpt2_cache, anchored_token_id, correct_token_id, layer_index
    )

    print(f"Large diff dimensions at layer {layer_index} are {[i.item() for i in large_gap_topk_indices]}")
    print(f"Top 20 words for large diff dimensions at layer {layer_index} are {large_gap_topk_words_ls}")
    print(f"Diff logit contributions for large diff dimensions at layer {layer_index} are {diff_contributions}")

    if len(large_gap_topk_indices) == 0:
        continue

    # Edit only the first flagged neuron (top-k order puts the largest contribution first).
    gap_dim = large_gap_topk_indices[0]
    original_W_out, modified_model = fix_mlp(
        gpt2_small_model, gap_dim, layer_index, anchored_token_id, correct_token_id
    )
    try:
        (
            prediction_result,
            modified_final_pos_top1_value,
            original_final_pos_top1_value,
            top_20_words_modified,
            top_20_words_original,
            modified_correct_prob,
            original_correct_prob,
        ) = prediction_compare(modified_model, tokens.squeeze(0), original_logits, correct_token_id)

        print(f"Modified prediction result is {prediction_result}, modified next token prob is {modified_final_pos_top1_value}, original next token prob is {original_final_pos_top1_value}")
        print(f"Correct token prob: modified {modified_correct_prob}, original {original_correct_prob}")
        print(f"Top 20 words for modified model are {top_20_words_modified}")
        print(f"Top 20 words for original model are {top_20_words_original}")
        print("--------------------------------------------------------------------------------------------------------------------")
    finally:
        # Always undo the edit, even if the comparison raised, so the shared model is
        # never left modified for later layers or later cells.
        with t.no_grad():
            gpt2_small_model.blocks[layer_index].mlp.W_out[gap_dim, :] = original_W_out[gap_dim, :]