In [None]:
! pip uninstall -y transformer_lens
! pip install git+https://github.com/taufeeque9/TransformerLens/
! pip install git+https://github.com/minyoungg/vqtorch/
! pip install termcolor
! pip install -U accelerate
! pip install -U kaleido

In [None]:
from termcolor import colored
import plotly.express as px
import transformers
import codebook_features
import torch
import evaluate
import numpy as np
import copy
import wandb
import json
import transformer_lens.utils as utils
from collections import namedtuple
from functools import partial
from torch.nn import functional as F
import itertools
from tqdm import tqdm


from transformers import (
    GPTNeoXConfig,
    GPTNeoXForCausalLM,
    GPT2TokenizerFast,
    pipeline,
    set_seed,
)
from torch.utils.data import IterableDataset
from codebook_features import models, run_clm, train_toy_model, trainer as cb_trainer   
from codebook_features.utils import *
from codebook_features.toy_utils import *
import os
import math

# Disable autograd globally. The cells here only run forward passes (evaluation,
# generation, activation patching), so no gradient graphs need to be built.
torch.set_grad_enabled(False)

In [None]:
model_name_or_path = "EleutherAI/pythia-410m-deduped"
# pretrained_path = "/data/outputs/2023-04-25/00-10-18/output_main/" # mlp 70m
pretrained_path = "/data/outputs/2023-05-08/01-42-10/output_main/" # mlp 410m
device = "cuda"

# Wrap the base HF model with the trained codebooks stored under `pretrained_path`.
orig_cb_model = models.wrap_codebook(model_or_path=model_name_or_path, pretrained_path=pretrained_path)

# Time only the device move / eval switch / logging shutdown of the wrapped model.
from time import time
t0 = time()
orig_cb_model = orig_cb_model.to(device).eval()
orig_cb_model.disable_logging()
t1 = time()
print("Loaded original cb model. Post time:", t1 - t0)

# Conversion options for the hooked copy (presumably forwarded to TransformerLens'
# weight-processing options — confirm in models.convert_to_hooked_model).
hooked_kwargs = {
    "center_unembed": False,
    "fold_value_biases": False,
    "center_writing_weights": True,
    "fold_ln": True,
    "refactor_factored_attn_matrices": False,
    "device": device,
}
cb_model = models.convert_to_hooked_model(model_name_or_path, orig_cb_model, hooked_kwargs=hooked_kwargs)
cb_model.disable_logging()
cb_model = cb_model.to(device).eval()

# `model` aliases the original (non-hooked) codebook model; the tokenizer is taken
# from the hooked model.
model = orig_cb_model
tokenizer = cb_model.tokenizer

In [None]:
# Reporting integrations for the HF Trainer: "none" disables them, "all" enables
# every installed one (wandb is imported above, so it would be picked up).
report_to = "none"
# report_to = "all"
training_args = run_clm.TrainingArguments(
    # no_cuda=True,
    output_dir="temp_mlp/",
    do_eval=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    # `report_to` used to be defined above but never passed here, so the library
    # default reporting applied regardless of the setting.
    report_to=report_to,
)

model_args = run_clm.ModelArguments(model_name_or_path=pretrained_path, cache_dir="/data/.cache/huggingface/")
data_args = run_clm.DataTrainingArguments(
    dataset_name="wikitext", dataset_config_name="wikitext-103-v1", streaming=False,
)

# Build the trainer around `model` plus the processed datasets.
# optimizers=(None, None) is presumably forwarded to the HF Trainer, which then creates
# its default optimizer/scheduler — confirm in run_clm.get_trainer_and_dataset.
trainer, lm_datasets, last_checkpoint = run_clm.get_trainer_and_dataset(
    model_args,
    data_args,
    training_args,
    model,
    optimizers=(None, None),
)

In [None]:
# Evaluate on a random subset of the training split. A fixed seed makes the subset,
# and therefore the saved tokens/activations, reproducible across runs.
max_samples = 2000
sample_seed = 0
n_train_rows = len(lm_datasets["train"])
sample_rng = np.random.default_rng(sample_seed)
# min(): choice(..., replace=False) raises if asked for more rows than exist.
sample_idx = sample_rng.choice(n_train_rows, min(max_samples, n_train_rows), replace=False)
dataset = lm_datasets["train"].select(sample_idx)
tokens = dataset["input_ids"]

# Was `report_on`, which is not a TrainingArguments field, so the line had no effect.
# NOTE(review): the HF Trainer registers reporting callbacks when it is constructed,
# so changing this attribute afterwards presumably does not remove them; pass
# `report_to` to TrainingArguments to disable reporting.
trainer.args.report_to = "none"
codebook_acts = {}

def store_cb_activations(key, codebook_ids, codebook_acts=codebook_acts):
    """Append one batch of codebook ids to the list kept under ``key``.

    The default for ``codebook_acts`` is bound when this ``def`` executes, to the
    module-level ``codebook_acts`` dict existing at that moment. Calls without an
    explicit dict therefore accumulate into that same dict.

    Args:
        key: Identifier of the codebook that produced the ids; used as the dict key.
        codebook_ids: Rank-3 array/tensor. The assert enforces rank 3; the layout is
            presumably (bs, seq_len, k_codebook), as the inline note says.
        codebook_acts: Dict mapping key -> list of stored batches.
    """
    assert len(codebook_ids.shape) == 3  # (bs, seq_len, k_codebook)
    # NOTE(review): ids are stored exactly as passed; CUDA tensors stay on the GPU.
    codebook_acts.setdefault(key, []).append(codebook_ids)

# Register the collector on the original codebook model. It is presumably called
# once per codebook forward with (key, codebook_ids) — confirm in set_hook_fn.
orig_cb_model.set_hook_fn(store_cb_activations)

# Evaluate on the sampled subset; `metrics` is the dict returned by trainer.evaluate.
metrics = trainer.evaluate(dataset)
print(metrics)

# Alias for the collected activations: the same dict object, not a copy.
cb_acts = codebook_acts

In [None]:
# Persist tokens, metrics and collected codebook ids under a timestamped directory.
from datetime import datetime
import pickle

run_stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = "/homedir/cb_eval_acts/mlp_" + run_stamp
os.makedirs(output_dir, exist_ok=True)

np.save(f"{output_dir}/tokens.npy", tokens)
# `metrics` is a dict, so this writes a 0-d object array; loading it back requires
# np.load(..., allow_pickle=True).
np.save(f"{output_dir}/metrics.npy", metrics)

with open(f"{output_dir}/cb_acts.pkl", "wb") as pickle_file:
    pickle.dump(codebook_acts, pickle_file)


In [None]:
# Sample a continuation with no prompt passed (temperature 0.7, up to 100 tokens)
# and display the decoded first sequence.
seq = orig_cb_model.generate(
    max_length=100,
    temperature=0.7,
    do_sample=True,
)
tokenizer.decode(seq[0].tolist(), skip_special_tokens=True)

In [None]:
# Greedy (do_sample=False) 10-token continuation of the prompt from the hooked model.
cb_model.generate(
    "Resident Evil",
    do_sample=False,
    max_new_tokens=10,
)

In [None]:
# Check how the hooked model ranks " States" after " American" (no BOS; the answer
# already carries its leading space), showing the top-10 predictions.
example_prompt, example_answer = " American", " States"

utils.test_prompt(
    example_prompt,
    example_answer,
    cb_model,
    top_k=10,
    prepend_space_to_answer=False,
    prepend_bos=False,
)

In [None]:
# Same check for a factual-recall prompt: the rank of " France" and the top-10 predictions.
example_prompt, example_answer = "Paris is the capital city of", " France"

utils.test_prompt(
    example_prompt,
    example_answer,
    cb_model,
    top_k=10,
    prepend_space_to_answer=False,
    prepend_bos=False,
)

In [None]:
# Run two prompts through the hooked model, cache activations, and compare their
# outputs with JSD (from the codebook_features star imports; presumably Jensen-Shannon
# divergence).
prompt_a = "The Italian dish"
prompt_b = "The football club Manchester"

# Use the configured `device` rather than a hard-coded "cuda" string.
input_a = cb_model.to_tokens(prompt_a, prepend_bos=False).to(device)
input_b = cb_model.to_tokens(prompt_b, prepend_bos=False).to(device)

logits_a, cache_a = cb_model.run_with_cache(input_a)
logits_b, cache_b = cb_model.run_with_cache(input_b)

print(f"JSD(a, b) = {JSD(logits_a, logits_b)}")

In [None]:
# Two prompts that differ only in the city name: show their token ids, cache
# activations, and compare outputs with JSD (codebook_features star import).
prompt_a = "Washington is the capital of The United"
prompt_b = "London is the capital of The United"

# Use the configured `device` rather than a hard-coded "cuda" string.
input_a = cb_model.to_tokens(prompt_a, prepend_bos=False).to(device)
input_b = cb_model.to_tokens(prompt_b, prepend_bos=False).to(device)
print(input_a)
print(input_b)
logits_a, cache_a = cb_model.run_with_cache(input_a)
logits_b, cache_b = cb_model.run_with_cache(input_b)

print(f"JSD(a, b) = {JSD(logits_a, logits_b)}")

In [None]:
# Source (b) and target (a) runs used by the patching cells that follow. `cache_b`
# supplies the codes and `logits_b` is the reference distribution.
prompt_a = "Paris is"
prompt_b = "The football club Manchester United is at the top of the league. Manchester"

# Use the configured `device` rather than a hard-coded "cuda" string.
input_a = cb_model.to_tokens(prompt_a, prepend_bos=False).to(device)
input_b = cb_model.to_tokens(prompt_b, prepend_bos=False).to(device)
print(input_a)
print(input_b)
logits_a, cache_a = cb_model.run_with_cache(input_a)
logits_b, cache_b = cb_model.run_with_cache(input_b)

print(f"JSD(a, b) = {JSD(logits_a, logits_b)}")

In [None]:
# Bare expression: displays a list of layer indices but does not bind it to a name.
[0,17,19,20]

In [None]:
# Patch the MLP codebook codes at position `pos` for every (layer, codebook head) in
# one run. Codes are read from run b (cache_b) and written into run a (input_a); the
# patched prediction is then compared against run b.
pos = -1
# range(1): the body runs exactly once (layer_idx == 0) and patches all layers together.
for layer_idx in range(1):
    n_layers = cb_model.cfg.n_layers
    # NOTE(review): cfg.n_heads (attention heads) is used as the number of MLP codebook
    # heads — confirm this matches the trained codebook configuration.
    n_cb_heads = cb_model.cfg.n_heads
    cb_at = ["mlp"] * (n_cb_heads * n_layers)
    head = list(range(n_cb_heads)) * n_layers
    # Layer index repeated once per codebook head: [0]*h + [1]*h + ...
    # Built inline because `range_over_repeat` is defined only in a later cell (unless a
    # star import also provides it), so calling it here fails under Restart & Run All.
    layer = [layer_id for layer_id in range(n_layers) for _ in range(n_cb_heads)]
    code = [cache_b[get_cb_layer_name(cb_at[i], layer[i], head[i])][0, pos, :] for i in range(len(cb_at))]
    # Indices of the patch sites to apply (all of them); narrow this to patch a subset.
    ind = range(len(cb_at))
    mod_logits, mod_cache = run_with_codes(
        input_a,
        cb_model,
        [code[i] for i in ind],
        [cb_at[i] for i in ind],
        [layer[i] for i in ind],
        [head[i] for i in ind],
        pos=[pos],
    )
    print("Layer:", layer_idx)
    print(logits_to_pred(mod_logits, tokenizer, k=5))
    # Compare at the same position that was patched.
    print(f"JSD(a <- b, b) = {JSD(mod_logits, logits_b, pos=pos)}")


In [None]:
def patch_in_codes(run_cb_ids, hook, pos, code):
    """Overwrite codebook ids with ``code`` in place and return the same tensor.

    The ``(tensor, hook)`` leading arguments presumably match the TransformerLens hook
    signature, with ``pos``/``code`` bound via ``functools.partial`` — confirm at call site.

    Args:
        run_cb_ids: Tensor indexed as ``run_cb_ids[:, pos]``; presumably shaped
            (batch, seq_len, k_codebook) — TODO confirm.
        hook: Hook object supplied by the caller; unused.
        pos: Position index (int, slice, or list of ints) to patch, or ``None`` to patch
            every position. An empty list patches nothing.
        code: Values assigned at the selected positions (broadcast by torch).

    Returns:
        ``run_cb_ids`` after mutation.
    """
    if pos is None:
        run_cb_ids[:, :] = code
    else:
        # Explicit None check: pos == 0 (first token) is falsy, and a truthiness test
        # would wrongly overwrite the whole sequence.
        run_cb_ids[:, pos] = code
    return run_cb_ids



In [None]:
# Compare the codebook codes cached for the most recent patched run (mod_cache) with
# those of run b (cache_b). `find_code_changes` comes from the codebook_features star
# imports; presumably it reports which codes differ — confirm against its definition.
find_code_changes(mod_cache, cache_b)

In [None]:
# Per-(layer, head) scan. For each MLP codebook head, patch only that head's code at
# `scan_pos` from run b into run a, then record JSD(patched a, b) in a
# (layer x head) grid.
n_scan_layers = min(2, cb_model.cfg.n_layers)  # use cb_model.cfg.n_layers for a full scan
scan_pos = -1
js_divs = torch.zeros((cb_model.cfg.n_layers, cb_model.cfg.n_heads))
for layer_idx in tqdm(range(n_scan_layers)):
    for head_idx in range(cb_model.cfg.n_heads):
        cb_at = ["mlp"]
        head = [head_idx]
        layer = [layer_idx] * len(cb_at)
        code = [cache_b[get_cb_layer_name(cb_at[i], layer[i], head[i])][0, scan_pos, :] for i in range(len(cb_at))]
        ind = range(len(cb_at))
        mod_logits, mod_cache = run_with_codes(
            input_a,
            cb_model,
            [code[i] for i in ind],
            [cb_at[i] for i in ind],
            [layer[i] for i in ind],
            [head[i] for i in ind],
            pos=[scan_pos],
        )
        js_divs[layer_idx, head_idx] = JSD(mod_logits, logits_b, pos=scan_pos)

# Plot only the scanned rows. Unscanned rows of `js_divs` are still zero and would
# otherwise look like layers with no divergence. Scanning starts at layer 0, so row
# indices equal layer indices.
# NOTE(review): the colour range is whatever `imshow` (codebook_features star import)
# chooses; nothing here forces it to start at 0.
imshow(js_divs[:n_scan_layers], xaxis="head", yaxis="layer")



In [None]:
def range_over_repeat(end_or_list, repeat=1):
    """Return a list where each element appears ``repeat`` times in a row, in order.

    An int argument ``n`` is treated as ``range(n)``; any other iterable is used as-is.

    Example: ``range_over_repeat(3, repeat=2) == [0, 0, 1, 1, 2, 2]``.
    """
    items = range(end_or_list) if isinstance(end_or_list, int) else end_or_list
    return [item for item in items for _ in range(repeat)]