In [None]:
! pip uninstall -y transformer_lens
! pip install git+https://github.com/taufeeque9/TransformerLens/
! pip install git+https://github.com/minyoungg/vqtorch/
! pip install termcolor
! pip install -U accelerate
! pip install -U kaleido

In [1]:
from termcolor import colored
import plotly.express as px
import transformers
import codebook_features
import torch
import evaluate
import numpy as np
import copy
import wandb
import json
import transformer_lens.utils as utils
from collections import namedtuple
from functools import partial
from torch.nn import functional as F

from transformers import GPTNeoXConfig, GPTNeoXForCausalLM, GPT2TokenizerFast, pipeline, set_seed
from torch.utils.data import IterableDataset
from codebook_features import models, run_clm, train_toy_model, trainer as cb_trainer
from codebook_features.utils import *
from codebook_features.toy_utils import *
import os

# Turn off gradient tracking globally for every cell that follows.
torch.set_grad_enabled(False)



The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path="config", config_name="toy_main")


<torch.autograd.grad_mode.set_grad_enabled at 0x7f7dcf7edf90>

In [2]:
# Hyperparameters for the default 100-state toy graph and its GPT-NeoX model.
hp = {
    "run_name": "cb_model_neox",
    "tags": [],
    # Toy graph size.
    "num_states": 100,
    "num_edges": 10,
    # Sequence length and vocabulary.
    "seq_len": 128,
    "vocab_size": 11,
    # Model width (a smaller variant used hidden_size=64, intermediate_size=256).
    "hidden_size": 128,
    "intermediate_size": 512,
    "num_hidden_layers": 4,
    "num_attention_heads": 4,
    "rotary_emb_base": 10000,
}

In [None]:
# Hyperparameters for the "4circle" setup (4 states, 1 edge per state).
# Running this cell replaces the `hp` defined in the previous cell.
hp = {
    "run_name": "cb_model_neox",
    "tags": [],
    # Toy graph size.
    "num_states": 4,
    "num_edges": 1,
    # Sequence length and vocabulary.
    "seq_len": 128,
    "vocab_size": 3,
    # Model width (a larger variant used hidden_size=64, intermediate_size=256).
    "hidden_size": 16,
    "intermediate_size": 64,
    "num_hidden_layers": 1,
    "num_attention_heads": 1,
    "rotary_emb_base": 10000,
}

## Dataset Generation

In [None]:
# Run selector: model pretrained on the "4circle" graph.
# This is one of several alternative base_path/checkpoint cells; whichever is
# executed last decides which run the loading cells below read from.
base_path = "/data/outputs/2023-06-02/03-05-17/"
checkpoint = "checkpoint-6500/"

In [3]:
# Run selector: model pretrained on the 100-state graph ("100s").
# Alternative to the other base_path/checkpoint cells; the last one executed wins.
base_path = "/data/outputs/2023-06-02/03-38-51/"
checkpoint = "checkpoint-9500/"

In [4]:
# Run selector: same run id and checkpoint as the "pretrained on 100s" cell, but
# addressed through a relative path.
# NOTE(review): "loki" presumably names the machine this path layout belongs to — confirm.
base_path = "../outputs/2023-06-02/03-38-51/"
checkpoint = "checkpoint-9500/"

In [None]:
# Run selector: "ft-cb" run on the "4circle" graph.
# NOTE(review): "ft-cb" presumably means a codebook model fine-tuned from a pretrained one — confirm.
base_path = "/data/outputs/2023-05-30/11-59-03/"
checkpoint = "checkpoint-500/"

In [None]:
# Run selector: model pretrained on the 100-state graph (run from 2023-05-25,
# stored under the codebook_features package's own outputs directory).
base_path = "/data/codebook-features/codebook_features/outputs/2023-05-25/12-56-03/"
checkpoint = "checkpoint-4500/"

In [4]:
# Run selector: "ft" run on the 100-state graph.
# NOTE(review): "ft" presumably means fine-tuned — confirm.
base_path = "/data/outputs/2023-06-02/06-12-08/"
checkpoint = "checkpoint-15500/"

In [None]:
# Run selector without a description.
# With an empty `checkpoint`, the model path built below resolves to
# base_path + "output_toy/" itself rather than a checkpoint-N sub-directory.
base_path = "/data/codebook-features/codebook_features/outputs/2023-05-20/13-27-44/"
checkpoint = ""

In [None]:
# Run selector: "ft-cb" run on the 100-state graph.
# NOTE(review): "ft-cb" presumably means a codebook model fine-tuned from a pretrained one — confirm.
base_path = "/data/codebook-features/codebook_features/outputs/2023-05-26/11-01-04/"
checkpoint = "checkpoint-17500/"

In [4]:
# Tokenizer, graph and datasets belonging to the run selected via `base_path`.
tokenizer = train_toy_model.create_tokenizer(base_path + "toy", hp["vocab_size"])

# The graph is read back from the run directory. To build a new one instead:
#   train_toy_model.ToyGraph(N=hp["num_states"], edges=hp["num_edges"], seed=42)
automata = train_toy_model.ToyGraph.load(
    base_path + "toy/automata.npy",
    representation_base=hp["vocab_size"] - 1,
    seed=42,
)

# The training set has no sample cap; the eval set is capped at 2048 samples.
train_dataset = train_toy_model.ToyDataset(
    automata,
    tokenizer=tokenizer,
    seq_len=hp["seq_len"],
)
eval_dataset = train_toy_model.ToyDataset(
    automata,
    tokenizer=tokenizer,
    seq_len=hp["seq_len"],
    max_samples=2048,
)

Vocab:
{"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "<|endoftext|>": 10}


## Load Model

In [5]:
# base model
from transformers import AutoModelForCausalLM
# Load the plain (non-codebook) checkpoint onto the GPU in eval mode.
# `model` and `base_model` refer to the same object.
model_path = f"{base_path}output_toy/{checkpoint}"
base_model = AutoModelForCausalLM.from_pretrained(model_path).to("cuda").eval()
model = base_model
# No explicit HF config here; the hooked-model cell passes this value on as-is.
config = None

In [6]:
import transformer_lens
from transformer_lens import loading

# Options for loading weights into the hooked model: every processing flag is off.
hooked_kwargs = {
    "center_unembed": False,
    "center_writing_weights": False,
    "fold_ln": False,
    "fold_value_biases": False,
    "refactor_factored_attn_matrices": False,
    "device": "cuda",
}

# Build a HookedTransformer from the checkpoint's converted config ...
hooked_config = loading.convert_hf_model_config(model_path, config)
hooked_model = transformer_lens.HookedTransformer(hooked_config)

# ... drop "device" so it is not forwarded with the remaining options ...
hooked_kwargs.pop("device", None)

# ... and load the base model's converted weights into it.
state_dict = models.convert_state_dict(base_model, hooked_model.cfg)  # type: ignore
hooked_model.load_and_process_state_dict(state_dict, **hooked_kwargs)


In [None]:
# Sample one sequence from `model` and check it against the graph.
# generate() is called without a prompt; only max_length and do_sample are set.
gen_seq = tokenizer.decode(model.generate(max_length=hp["seq_len"], do_sample=True)[0])
# Convert the decoded string into a trajectory using the graph's own helper.
traj = automata.seq_to_traj(gen_seq)
# Only the first element of the returned pair is kept; the second is discarded.
acc, _ = automata.transition_accuracy(traj)
print(gen_seq)
print(traj)
print(acc)

In [6]:
# Codebook model: instantiate a GPT-NeoX from `hp`, then wrap it with codebooks,
# pointing `pretrained_path` at the selected checkpoint.
config = GPTNeoXConfig(
    vocab_size=hp["vocab_size"],
    hidden_size=hp["hidden_size"],
    num_hidden_layers=hp["num_hidden_layers"],
    num_attention_heads=hp["num_attention_heads"],
    intermediate_size=hp["intermediate_size"],
    rotary_emb_base=hp["rotary_emb_base"],
    # The last vocabulary id is used as both BOS and EOS.
    bos_token_id=hp["vocab_size"] - 1,
    eos_token_id=hp["vocab_size"] - 1,
    max_position_embeddings=hp["seq_len"],
)
config.architectures = ["GPTNeoXForCausalLM"]

model = GPTNeoXForCausalLM(config=config).to("cuda").eval()
orig_cb_model = models.wrap_codebook(
    model_or_path=model,
    pretrained_path=f"{base_path}output_toy/{checkpoint}",
).to("cuda").eval()

In [7]:
# Hooked version of the codebook model, built by the project's converter.
hooked_kwargs = {
    "center_unembed": False,
    "center_writing_weights": False,
    "fold_ln": False,
    "fold_value_biases": False,
    "refactor_factored_attn_matrices": False,
    "device": "cuda",
}
cb_model = models.convert_to_hooked_model_for_toy(
    f"{base_path}output_toy/{checkpoint}",
    orig_cb_model,
    config=config,
    hooked_kwargs=hooked_kwargs,
).to("cuda").eval()

In [32]:
# Point `model` back at the plain checkpoint; the trainer cell below is built around `model`.
model = base_model

In [31]:
# Trainer configuration. Logging integrations are off by default; switch the
# value below to "all" to enable them.
report_to = "none"
# report_to = "all"
training_args = run_clm.TrainingArguments(
#     no_cuda=True,
    output_dir="toy/output",
    do_train=True,
    do_eval=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=512,
    per_device_eval_batch_size=512,
    learning_rate=1e-3,
#     weight_decay=1e-1,
    max_steps=20000,
#     lr_scheduler_type="linear",
    lr_scheduler_type="constant",
    # NOTE(review): with lr_scheduler_type="constant" the warmup ratio is presumably unused — confirm.
    warmup_ratio=0.1,
    logging_first_step=True,
    logging_steps=500,
    eval_steps=500,
    overwrite_output_dir=True,
    seed=42,
    # Project-specific options; both are read by the optimizer setup further down in this cell.
    train_model_params=True,
    model_lr_factor=1.0,
    report_to=report_to,
    dataloader_num_workers=8,

)

# Flat record of the run configuration: `hp` merged with the training args and
# (if a config object exists) the model config's attributes.
cfg_dict = {"training_args": training_args.__dict__, "model_args": config.__dict__ if config else None}
cfg_dict = {**hp, **cfg_dict}
# Argument objects constructed for completeness; the cells shown here do not read them again.
model_args = run_clm.ModelArguments(model_name_or_path="toy/model")
data_args = run_clm.DataTrainingArguments(dataset_name="toy_graph", max_eval_samples=2048)

# By default no optimizer or scheduler is supplied to the trainer.
optimizers = (None, None)

# Codebook models get an explicit AdamW so that codebook and backbone
# parameters can be given different learning rates and weight decay.
if isinstance(model, models.CodebookModel):
    if not training_args.train_model_params:
        # Only the codebook parameters are optimised.
        params = model.get_codebook_params()
    else:
        params = [
            {
                "params": model.get_codebook_params(),
                "lr": training_args.learning_rate,
                # Codebook weight decay is handled via the
                # `codebook_weight_decay` param, which feeds directly into
                # the regularized loss, so none is applied here.
                "weight_decay": 0.0,
            },
            {
                "params": model.get_model_params(),
                "lr": training_args.model_lr_factor * training_args.learning_rate,
                "weight_decay": training_args.weight_decay,
            },
        ]
    if len(params) > 0:
        optimizer = torch.optim.AdamW(params, training_args.learning_rate)
        optimizers = (optimizer, None)

# No extra trainer callbacks; the W&B callback stays disabled:
#   if report_to == "all": callbacks = [cb_trainer.WandbCallback()]
callbacks = []

def preprocess_logits_for_metrics(logits, labels):
    """Reduce raw model outputs to predicted token ids.

    Args:
        logits: Either the logits tensor itself, or a tuple whose first
            element is the logits tensor (further elements are ignored).
        labels: Not used; present because the trainer passes it.

    Returns:
        The argmax of the logits over the last dimension.
    """
    # Tuple outputs carry extra tensors after the logits; keep only the logits.
    scores = logits[0] if isinstance(logits, tuple) else logits
    return scores.argmax(dim=-1)

# Accuracy metric used by `compute_metrics` below.
metric = evaluate.load("accuracy")


def compute_metrics(eval_preds):
    """Compute next-token accuracy from (predictions, labels).

    Args:
        eval_preds: Pair ``(preds, labels)``; ``preds`` are the token ids
            produced by ``preprocess_logits_for_metrics`` and have the same
            shape as ``labels``.

    Returns:
        Whatever ``metric.compute`` returns for the flattened, shifted pair.
    """
    predictions, references = eval_preds
    # The prediction at position t targets the token at t + 1, so drop the
    # first label and the last prediction before flattening.
    shifted_references = references[:, 1:].reshape(-1)
    shifted_predictions = predictions[:, :-1].reshape(-1)
    return metric.compute(predictions=shifted_predictions, references=shifted_references)

# Trainer for the toy task. Besides the usual trainer arguments it receives the
# graph (`toy_graph`) and a generation length (`gen_seq_len`).
trainer = train_toy_model.ToyModelTrainer(
    model=model,
    toy_graph=automata,
    gen_seq_len=hp["seq_len"],
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # Accuracy is computed on argmax-ed logits (see the two functions above).
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    # (None, None) unless `model` is a CodebookModel (see optimizer setup above).
    optimizers=optimizers,
    callbacks=callbacks,
)


Moving model to device:  cuda


In [None]:
from transformers import pipeline, set_seed

# Text-generation pipeline around the current `model`, placed on device 0.
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=0)
# Sample up to 50 tokens from an empty prompt and keep the text of the first result.
gen_seq = generator("", max_length=50, do_sample=True, temperature=0.7)[0]['generated_text']

## Eval

In [18]:
# Choose the model used by the evaluation cells below.
# Alternatives (uncomment one to evaluate a codebook model instead):
# model = orig_cb_model
# model = cb_model
model = hooked_model
# NOTE(review): this rebinds `cb_model` to the non-codebook hooked model, hiding
# the codebook model built earlier under the same name — confirm this is intended.
cb_model = hooked_model

In [None]:
# Show the model's top-10 predictions for one prompt/answer pair.
# Other prompts tried, with the author's digit annotations:
#   "33"  # 3 4 6 1 2
#   "63"  # 9 8 5 7 4 6
#   "47"  # 0 1 2 3 4 5 8
#   "71"  # 1 3 4 5 8 0 2
#   "72"  # 8 9 4 5 7 3
example_prompt = "63"  # 4 5 7 1 6 8
example_answer = "98"
utils.test_prompt(
    example_prompt,
    example_answer,
    cb_model,
    prepend_bos=True,
    prepend_space_to_answer=False,
    top_k=10,
)

In [None]:
# Print each state's neighbour list followed by the state id.
# Iterating over automata.N instead of a hard-coded 100 keeps this cell valid
# for graphs of any size (e.g. the 4-state configuration).
for state in range(automata.N):
    print(automata.nbrs(state), "-", state)

## Token Code Maps

In [30]:
# Dataset that keeps the token sequences it produces (save_tokens=True) so they
# can be read back from `.tokens` after evaluation.
max_samples = 10 * 1024
# seq_len comes from `hp`, like the train/eval datasets, instead of a separate
# hard-coded 128 that could drift from the model's configured length.
train_dataset_tkns = train_toy_model.ToyDataset(
    automata,
    tokenizer=tokenizer,
    seq_len=hp["seq_len"],
    max_samples=max_samples,
    save_tokens=True,
)

In [32]:
# Evaluate without dataloader worker processes.
# NOTE(review): presumably required so that the tokens saved by the dataset end
# up in this process rather than in worker copies — confirm.
trainer.args.dataloader_num_workers = 0
# The argument is called `report_to`; the previous `report_on` only created an
# unused attribute. An empty list is the form TrainingArguments itself stores
# for report_to="none".
trainer.args.report_to = []

# Model under evaluation (alternative: hooked_model).
trainer.model = base_model
# trainer.model = hooked_model

metrics = trainer.evaluate(eval_dataset=train_dataset_tkns)
print(metrics)

# Imported here because later cells rely on `tqdm` being made available by this cell.
from tqdm import tqdm

# Stack the saved per-sample token arrays into one 2-D array.
tokens = np.vstack(train_dataset_tkns.tokens)
print(tokens.shape)


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:10 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


{'eval_runtime': 22.9548, 'eval_samples_per_second': 446.093, 'eval_steps_per_second': 0.871, 'eval_transition_accuracy': 0.9296774193548387, 'eval_first_transition_accuracy': 1.0}
(10240, 128)


## Eval Non-Codebook

In [7]:
# All cells in this section operate on the non-codebook hooked model.
model = hooked_model

In [11]:
# Prompt pair for activation patching: activations cached from the clean run
# are patched into runs on the corrupted prompt further below.
clean_prompt = "639"
corrupted_prompt = "589"

# Tokenise both prompts through the hooked model.
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer="8", incorrect_answer="1"):
    """Logit of `correct_answer` minus logit of `incorrect_answer`.

    Both logits are read at the last position of the first batch element.

    Args:
        logits: Model output indexed as [batch, position, vocab].
        correct_answer: String that must map to exactly one token.
        incorrect_answer: String that must map to exactly one token.

    Returns:
        The difference of the two logits as a scalar tensor.
    """
    # to_single_token maps a single-token string to its vocabulary index and
    # raises an error if the string is not a single token.
    answer_ids = [model.to_single_token(answer) for answer in (correct_answer, incorrect_answer)]
    final_logits = logits[0, -1]
    return final_logits[answer_ids[0]] - final_logits[answer_ids[1]]

# Run the clean prompt with caching so its activations can be patched in later.
# NOTE(review): the inputs are already token tensors; `prepend_bos` presumably
# only matters for string inputs — confirm.
clean_logits, clean_cache = model.run_with_cache(clean_tokens,prepend_bos=True)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

# The corrupted prompt only needs a plain forward pass (no cache).
corrupted_logits = model(corrupted_tokens,prepend_bos=True)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")

Clean logit difference: 8.976
Corrupted logit difference: -8.575


In [None]:
# Residual-stream activation patching sweep.
# For every (layer, position), the corrupted prompt is run with a hook on
# `resid_pre` (the residual stream at the start of the layer) that receives the
# clean cache and the position to patch.
# NOTE(review): `residual_stream_patching_hook` and `imshow` are not defined in
# this notebook; presumably they come from the codebook_features star imports — confirm.

num_positions = clean_tokens.shape[1]
n_layers = model.cfg.n_layers

# Results are kept on the model's device to avoid GPU<->CPU copies in the loop.
nsp_patching_result = torch.zeros((n_layers, num_positions), device=model.cfg.device)

# Normaliser: the clean-vs-corrupted gap, so stored values fall in 0..1 (ish).
logit_diff_range = clean_logit_diff - corrupted_logit_diff

for layer in tqdm(range(n_layers)):
    hook_name = utils.get_act_name("resid_pre", layer)
    for position in range(num_positions):
        # functools.partial fixes the cache and position for this run's hook.
        temp_hook_fn = partial(residual_stream_patching_hook, cache=clean_cache, position=position)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(hook_name, temp_hook_fn)],
        )
        patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
        nsp_patching_result[layer, position] = (patched_logit_diff - corrupted_logit_diff) / logit_diff_range

# The position index is appended to each label because plotly doesn't like duplicate labels.
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
imshow(
    nsp_patching_result,
    x=token_labels,
    xaxis="Position",
    yaxis="Layer",
    title="Normalized Logit Difference After Patching Residual Stream",
)

In [14]:
# Compare the neighbour lists of two states and find the states they share.
clean_state = 29
corrupted_state = 51

nclean = automata.nbrs(clean_state)
ncorr = automata.nbrs(corrupted_state)
# States that appear in both neighbour lists.
intersection = list(set(nclean) & set(ncorr))
# Note: the values printed next to each label are the neighbour lists, not the state ids.
print("clean_state", nclean)
print("corrupted_state", ncorr)
print(intersection)


clean_state [29 35 47 53 70 76 77 84 92 98]
corrupted_state [ 3  9 14 21 27 35 38 44 45 56]
[35]


In [23]:
# Run the model once on every state's token representation and cache the activations.
all_states = [automata.token_repr(state) for state in range(automata.N)]
all_states_tokens = tokenizer(all_states, return_tensors="pt")['input_ids'].to(model.cfg.device)
# Prepend one position filled with the BOS id to every row.
all_states_tokens = F.pad(all_states_tokens, (1, 0), value=tokenizer.bos_token_id)
all_states_logits, all_states_cache = model.run_with_cache(all_states_tokens)

# Source (A) and target (B) state sets for the patching experiment below;
# here both cover every state.
state_as = list(range(automata.N))
state_bs = list(range(automata.N))

# Token batch for the B states, built the same way as `all_states_tokens`.
state_bs_str = [automata.token_repr(state) for state in state_bs]
state_bs_tokens = tokenizer(state_bs_str, return_tensors="pt")['input_ids'].to(model.cfg.device)
state_bs_tokens = F.pad(state_bs_tokens, (1, 0), value=tokenizer.bos_token_id)

In [29]:
import plotly.graph_objects as go

from torch.nn import functional as F
import pandas as pd
import re
# Token position(s) handed to the patching hook: an int or a list.
# The plot title below renders a two-element list as "first-second".
pos = [1,2]
# pos = 2
# Activation name(s) to patch. Alternatives tried:
# patch_types = "resid_pre"
# patch_types = "attn_out"
patch_types = ["attn_out", "mlp_out"]

# Accept a single name as well: wrap it so the code below always sees a list.
if isinstance(patch_types, str):
    patch_types = [patch_types]

def remove_idx_from_tensor(tensor, idx):
    """Return a copy of `tensor` without the entry at `idx` along dimension 0."""
    head, tail = tensor[:idx], tensor[idx + 1 :]
    return torch.cat([head, tail])

# Patch state A's cached activations into runs on every other state B.
# For each patched layer (plus one final row where all layers are patched at
# once) and each (A, B) pair with A != B, this records:
#   * nsp_patching_result[..., 0] / [..., 1]: next-state scores of the patched
#     run judged against A and against B respectively,
#   * kl_div_result: the JSD-based divergence between A's clean logits and the
#     patched logits (the variable keeps its historical "kl_div" name).
common_states_in_a_b = set(state_as).intersection(set(state_bs))
# One slot per ordered (A, B) pair, excluding the pairs where A and B coincide.
len_cross_prod = len(state_as) * len(state_bs) - len(common_states_in_a_b)
nsp_patching_result = torch.zeros((model.cfg.n_layers + 1, len_cross_prod, 2), device=model.cfg.device)
kl_div_result = torch.zeros((model.cfg.n_layers + 1, len_cross_prod), device=model.cfg.device)

for layer in tqdm(range(model.cfg.n_layers + 1)):
    offset = 0
    for state_a in state_as:
        # Slice A's row out of the all-states cache, keeping a batch dimension of 1.
        state_a_cache = {
            k: v[state_a].unsqueeze(0) for k, v in all_states_cache.items()
        }
        temp_hook_fn = partial(
            residual_stream_patching_hook, cache=state_a_cache, position=pos
        )
        if layer == model.cfg.n_layers:
            # Extra final row: hook every layer simultaneously.
            fwd_hooks = [(utils.get_act_name(patch_type, layer_inner), temp_hook_fn) for layer_inner in range(model.cfg.n_layers) for patch_type in patch_types]
        else:
            fwd_hooks = [(utils.get_act_name(patch_type, layer), temp_hook_fn) for patch_type in patch_types]
        # (A per-iteration tokenisation of `str(state_a)` used to sit here; its
        # result was never read, so it has been removed.)
        try:
            # Exclude A itself from the B batch.
            a_index_in_b = state_bs.index(state_a)
            state_bs_wo_a = remove_idx_from_tensor(state_bs_tokens, a_index_in_b)
            state_bs_str_wo_a = (
                state_bs_str[:a_index_in_b] + state_bs_str[a_index_in_b + 1 :]
            )
        except ValueError:
            # A is not among the B states: use the full batch.
            state_bs_wo_a = state_bs_tokens
            state_bs_str_wo_a = state_bs_str

        # state b <- state a
        len_stats = len(state_bs_wo_a)
        mod_logits = model.run_with_hooks(state_bs_wo_a, fwd_hooks=fwd_hooks)
        # Broadcast A's clean logits to the batch size of the patched run.
        gt_logits = (
            all_states_logits[state_a].unsqueeze(0).repeat(mod_logits.shape[0], 1, 1)
        )
        kl_div_batch = JSD(gt_logits, mod_logits, reduction="none").sum(dim=-1)
        kl_div_result[layer, offset:offset+len_stats] = kl_div_batch
        nsp, _ = get_next_state_probs(
            state_bs_wo_a, model, automata, fwd_hooks=fwd_hooks
        )
        acc_a = correct_next_state_probs(state_a, nsp, automata, print_info="")
        acc_b = correct_next_state_probs(state_bs_str_wo_a, nsp, automata, print_info="")
        acc_a = torch.tensor(acc_a, device=model.cfg.device)
        acc_b = torch.tensor(acc_b, device=model.cfg.device)

        nsp_patching_result[layer, offset:offset+len_stats, 0] = acc_a
        nsp_patching_result[layer, offset:offset+len_stats, 1] = acc_b
        offset += len_stats
    # Every (A, B) slot must have been filled exactly once for this layer.
    assert offset == len_cross_prod

# Scale the scores by 100 in place.
# NOTE(review): presumably converts fractions to percentages (the plot's accuracy
# axis runs from -5 to 105) — confirm.
nsp_patching_result *= 100

# One row per patched layer plus a final "All" row (all layers patched at once);
# each statistic is taken over all (A, B) pairs.
df = pd.DataFrame(
    {
        "Layer": [str(i) for i in range(model.cfg.n_layers)] + ["All"],
        "NSP Acc Sa": nsp_patching_result[:, :, 0].mean(dim=1).tolist(),
        "NSP Acc Sb": nsp_patching_result[:, :, 1].mean(dim=1).tolist(),
        # Named "KL Div", but the values come from the JSD call in the loop above.
        "KL Div": kl_div_result.mean(dim=1).tolist(),
        "NSP Acc Sa Std": nsp_patching_result[:, :, 0].std(dim=1).tolist(),
        "NSP Acc Sb Std": nsp_patching_result[:, :, 1].std(dim=1).tolist(),
        "KL Div Std": kl_div_result.std(dim=1).tolist(),
    }
)
# NOTE(review): the accuracy std columns are halved before being used as error
# bars; the reason is not evident from this notebook — confirm.
df["NSP Acc Sa Std"] /= 2
df["NSP Acc Sb Std"] /= 2

# Human-readable description of the patched activations, used in the plot title.
fixed_labels = {
    "attn_out": "Attn",
    "mlp_out": "MLP",
    "resid_pre": "Pre-Residual Stream",
    "resid_post": "Post-Residual Stream",
}
label_parts = []
for patch_type in patch_types:
    # Per-head / per-layer names carry single-digit indices, e.g. "attn_1_head_2", "mlp_3".
    attn_head = re.match(r"attn_(\d)_head_(\d)", patch_type)
    mlp = re.match(r"mlp_(\d)", patch_type)
    if patch_type in fixed_labels:
        label_parts.append(fixed_labels[patch_type])
    elif attn_head:
        label_parts.append(f"Attn {attn_head.group(1)} Head {attn_head.group(2)}")
    elif mlp:
        label_parts.append(f"MLP {mlp.group(1)}")
    else:
        raise ValueError(f"Unknown patch type: {patch_type}")
patch_types_res = " & ".join(label_parts)

# Position label: a single value, or "first-second" for a longer list.
if isinstance(pos, int):
    pos_str = str(pos)
elif len(pos) == 1:
    pos_str = str(pos[0])
else:
    pos_str = f"{pos[0]}-{pos[1]}"

fig = go.Figure()

# NSP accuracy traces on the left y-axis, with the std columns as error bars.
for column, subscript in (("NSP Acc Sa", "A"), ("NSP Acc Sb", "B")):
    fig.add_trace(
        go.Scatter(
            x=df["Layer"],
            y=df[column],
            error_y=dict(type='data', array=df[f"{column} Std"]),
            name=f"NSP Accuracy (S<sub>{subscript}</sub>)",
        )
    )

# Divergence trace on a secondary y-axis. The column is called "KL Div", but the
# figure labels it as JS divergence.
fig.add_trace(
    go.Scatter(
        x=df["Layer"],
        y=df["KL Div"],
        error_y=dict(type='data', array=df["KL Div Std"]),
        name="Token 3 JS Div with S<sub>A</sub>",
        yaxis="y2",
    )
)

title = f'NSP and JS Div Result. Patching S<sub>A</sub> &#8594; S<sub>B</sub> activations of {patch_types_res} at position {pos_str}'
fig.update_layout(
    # Centered title.
    title={"text": title, "x": 0.5, "xanchor": "center"},
    # Left axis: NSP accuracy.
    yaxis=dict(title="NSP Accuracy", range=(-5, 105)),
    # Right axis: JS divergence, overlaid on the left axis.
    yaxis2=dict(title="JS Divergence", overlaying="y", side="right", range=[-0.1, 2]),
    legend=dict(title="Metric"),
)

fig.show()

100%|██████████| 5/5 [00:41<00:00,  8.28s/it]
