In [None]:
! pip uninstall -y transformer_lens
! pip install git+https://github.com/taufeeque9/TransformerLens/
! pip install git+https://github.com/minyoungg/vqtorch/
! pip install termcolor
! pip install -U accelerate
! pip install -U kaleido

In [1]:
from termcolor import colored
import plotly.express as px
import transformers
import codebook_features
import torch
import evaluate
import numpy as np
import copy
import wandb
import json
import transformer_lens.utils as utils
from collections import namedtuple
from functools import partial
from torch.nn import functional as F
import itertools


from transformers import (
    GPTNeoXConfig,
    GPTNeoXForCausalLM,
    GPT2TokenizerFast,
    pipeline,
    set_seed,
)
from torch.utils.data import IterableDataset
from codebook_features import models, run_clm, train_toy_model, trainer as cb_trainer
from codebook_features.utils import *
from codebook_features.toy_utils import *
import os

torch.set_grad_enabled(False)

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path="config", config_name="toy_main")


<torch.autograd.grad_mode.set_grad_enabled at 0x7ff7906f3220>

In [2]:
# Hyperparameters
# Shared settings read by later cells to build the tokenizer, the toy datasets and the
# GPT-NeoX config. A later cell in this notebook rebinds `hp` for the "4circle" setup, so
# the values in effect depend on which of the two cells ran last.
hp = dict(
    run_name="cb_model_neox",
    tags=[],
    num_states=100,
    num_edges=10,
    seq_len=128,  # dataset sequence length; also used as max_position_embeddings
    vocab_size=11,  # the last id (vocab_size - 1) is used as the BOS/EOS token id
    #     hidden_size = 64,
    #     intermediate_size = 256,
    hidden_size=128,
    intermediate_size=512,
    num_hidden_layers=4,
    num_attention_heads=4,
    rotary_emb_base=10000,
)

In [None]:
# Hyperparameters for 4circle
# Running this cell replaces the `hp` dict defined by the previous hyperparameter cell;
# only run it when working with the checkpoint trained on the 4circle automaton.
hp = dict(
    run_name="cb_model_neox",
    tags=[],
    num_states=4,
    num_edges=1,
    seq_len=128,  # dataset sequence length; also used as max_position_embeddings
    vocab_size=3,  # the last id (vocab_size - 1) is used as the BOS/EOS token id
    #     hidden_size = 64,
    #     intermediate_size = 256,
    hidden_size=16,
    intermediate_size=64,
    num_hidden_layers=1,
    num_attention_heads=1,
    rotary_emb_base=10000,
)

## Dataset Generation

In [None]:
# pretrained on 4circle
base_path = "/data/outputs/2023-06-02/03-05-17/"
checkpoint = "checkpoint-6500/"


In [10]:
# pretrained on 100s
base_path = "/data/outputs/2023-06-02/03-38-51/"
checkpoint = "checkpoint-9500/"


In [4]:
# pretrained on 100s loki
base_path = "../outputs/2023-06-02/03-38-51/"
checkpoint = "checkpoint-9500/"


In [None]:
# pretrained on 100 states
base_path = "/data/codebook-features/codebook_features/outputs/2023-05-25/12-56-03/"
checkpoint = "checkpoint-4500/"


In [3]:
# ft on 100 states
base_path = "/data/outputs/2023-06-02/06-12-08/"
checkpoint = "checkpoint-15500/"


In [4]:
# ft on 100 states loki
base_path = "../outputs/2023-06-02/06-12-08/"
checkpoint = "checkpoint-15500/"


In [None]:
base_path = "/data/codebook-features/codebook_features/outputs/2023-05-20/13-27-44/"
checkpoint = ""


In [20]:
# ft-cb on 100 states
base_path = "/data/codebook-features/codebook_features/outputs/2023-05-26/11-01-04/"
checkpoint = "checkpoint-17500/"


In [4]:
# Build the tokenizer, load the automaton saved next to the selected run, and create the
# train / eval datasets from it.
# NOTE(review): `base_path` is whatever the last-executed checkpoint-selection cell set;
# several such cells exist above, so the run being analysed depends on execution order.
tokenizer = train_toy_model.create_tokenizer(base_path + "toy", hp["vocab_size"])
# automata = train_toy_model.ToyGraph(N=hp["num_states"], edges=hp["num_edges"], seed=42)
automata = train_toy_model.ToyGraph.load(
    base_path + "toy/automata.npy", representation_base=hp["vocab_size"] - 1, seed=42
)
train_dataset = train_toy_model.ToyDataset(
    automata, tokenizer=tokenizer, seq_len=hp["seq_len"]
)
# The eval split is capped at 2048 samples; the train split passes no max_samples.
eval_dataset = train_toy_model.ToyDataset(
    automata, tokenizer=tokenizer, seq_len=hp["seq_len"], max_samples=2048
)

Vocab:
{"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "<|endoftext|>": 10}


## Load Model

In [None]:
# base model
# Loads the plain (non-codebook) Hugging Face checkpoint selected by `base_path` /
# `checkpoint` and puts it in eval mode.
from transformers import AutoModelForCausalLM

# Choose the device explicitly. The previous `device = device` self-assignment only worked
# when another cell had already defined `device` in the running kernel, so this cell
# failed under Restart & Run All.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = base_path + "output_toy/" + checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path)
model = model.to(device).eval()
# No explicit config for the plain HF model; the hooked-model cell passes this value on.
config = None
base_model = model

In [23]:
# base model with hook
# Wraps the plain base model in a TransformerLens HookedTransformer, loading the weights
# with every weight-processing option switched off.
import transformer_lens
from transformer_lens import loading

hooked_kwargs = {
    "center_unembed": False,
    "center_writing_weights": False,
    "fold_ln": False,
    "fold_value_biases": False,
    "refactor_factored_attn_matrices": False,
    "device": device,
}
# `device` is removed again before the kwargs are forwarded to the state-dict loader.
hooked_kwargs.pop("device", None)

hooked_config = loading.convert_hf_model_config(model_path, config)
hooked_model = transformer_lens.HookedTransformer(hooked_config)
state_dict = models.convert_state_dict(base_model, hooked_model.cfg)  # type: ignore
hooked_model.load_and_process_state_dict(state_dict, **hooked_kwargs)

In [None]:
# Sanity check: sample one sequence from `model`, map it to an automaton trajectory and
# print the sequence, the trajectory and the transition accuracy reported by the automaton.
gen_seq = tokenizer.decode(model.generate(max_length=hp["seq_len"], do_sample=True)[0])
traj = automata.seq_to_traj(gen_seq)
# The second return value of transition_accuracy is not used here.
acc, _ = automata.transition_accuracy(traj)
print(gen_seq)
print(traj)
print(acc)


In [5]:
# cb model
# Builds a GPT-NeoX model from `hp`, then wraps it with codebooks; the pretrained path
# handed to wrap_codebook points at the selected checkpoint.
# NOTE(review): `device` is hard-coded to "cuda", so this cell needs a GPU.
device = "cuda"
config = GPTNeoXConfig(
    vocab_size=hp["vocab_size"],
    hidden_size=hp["hidden_size"],
    num_hidden_layers=hp["num_hidden_layers"],
    num_attention_heads=hp["num_attention_heads"],
    intermediate_size=hp["intermediate_size"],
    rotary_emb_base=hp["rotary_emb_base"],
    # The last vocabulary id doubles as both BOS and EOS.
    bos_token_id=hp["vocab_size"] - 1,
    eos_token_id=hp["vocab_size"] - 1,
    max_position_embeddings=hp["seq_len"],
)
config.architectures = ["GPTNeoXForCausalLM"]
model = GPTNeoXForCausalLM(config=config)
model = model.to(device).eval()
orig_cb_model = models.wrap_codebook(
    model_or_path=model, pretrained_path=base_path + f"output_toy/{checkpoint}"
)
orig_cb_model = orig_cb_model.to(device).eval()

Time: 6.198883056640625e-06 4.5299530029296875e-05 1.2808454036712646


In [6]:
# hooked model
# Converts the codebook model into its hooked counterpart (`cb_model`), which the
# analysis cells below use for run_with_cache / run_with_hooks.
# All weight-processing options are disabled.
hooked_kwargs = dict(
    center_unembed=False,
    center_writing_weights=False,
    fold_ln=False,
    fold_value_biases=False,
    refactor_factored_attn_matrices=False,
    device=device,
)
cb_model = models.convert_to_hooked_model_for_toy(
    base_path + f"output_toy/{checkpoint}",
    orig_cb_model,
    config=config,
    hooked_kwargs=hooked_kwargs,
)
cb_model = cb_model.to(device).eval()

Time: 8.106231689453125e-06 3.790855407714844e-05 1.4247303009033203


In [8]:
# Trainer configuration. In this notebook the trainer is later used for evaluation
# (see the `trainer.evaluate(...)` call further down); the training-related values are
# still required to construct the arguments object.
report_to = "none"
# report_to = "all"
training_args = run_clm.TrainingArguments(
    #     no_cuda=True,
    output_dir="toy/output",
    do_train=True,
    do_eval=True,
    evaluation_strategy="steps",
    per_device_train_batch_size=512,
    per_device_eval_batch_size=512,
    learning_rate=1e-3,
    #     weight_decay=1e-1,
    max_steps=20000,
    #     lr_scheduler_type="linear",
    lr_scheduler_type="constant",
    warmup_ratio=0.1,
    logging_first_step=True,
    logging_steps=500,
    eval_steps=500,
    overwrite_output_dir=True,
    seed=42,
    train_model_params=True,
    model_lr_factor=1.0,
    report_to=report_to,
    dataloader_num_workers=8,
)

# Flat record of the run configuration: `hp` entries plus the training/model arguments.
cfg_dict = {
    "training_args": training_args.__dict__,
    "model_args": config.__dict__ if config else None,
}
cfg_dict = {**hp, **cfg_dict}
model_args = run_clm.ModelArguments(model_name_or_path="toy/model")
data_args = run_clm.DataTrainingArguments(
    dataset_name="toy_graph", max_eval_samples=2048
)

# (optimizer, scheduler) pair handed to the trainer; (None, None) leaves both to the trainer.
optimizers = (None, None)
# A custom optimizer is only built when `model` is a CodebookModel.
if isinstance(model, models.CodebookModel):
    if training_args.train_model_params:
        # Two parameter groups: codebook params and the remaining model params, the
        # latter with a learning rate scaled by `model_lr_factor`.
        params = [
            {
                "params": model.get_codebook_params(),
                "lr": training_args.learning_rate,
                # weight decay for codebook params is used through
                # `codebook_weight_decay` param that is used directly
                # to compute regularized loss.
                "weight_decay": 0.0,
            },
            {
                "params": model.get_model_params(),
                "lr": training_args.model_lr_factor * training_args.learning_rate,
                "weight_decay": training_args.weight_decay,
            },
        ]
    else:
        params = model.get_codebook_params()
    if len(params) > 0:
        optimizer = torch.optim.AdamW(
            params,
            training_args.learning_rate,
        )
        optimizers = (optimizer, None)

callbacks = []
# if report_to == "all":
#     callbacks = [cb_trainer.WandbCallback()]

def preprocess_logits_for_metrics(logits, labels):
    """Reduce model logits to predicted token ids before metrics are computed.

    Args:
        logits: Either the logits tensor, or a tuple whose first element is the
            logits tensor (other entries, e.g. past_key_values, are ignored).
        labels: Unused; accepted because the trainer passes it.

    Returns:
        The argmax over the last dimension of the logits.
    """
    primary = logits[0] if isinstance(logits, tuple) else logits
    return primary.argmax(dim=-1)


metric = evaluate.load("accuracy")


def compute_metrics(eval_preds):
    """Compute next-token accuracy from already-argmaxed predictions.

    Args:
        eval_preds: A (predictions, labels) pair. Predictions have the same shape as
            the labels because preprocess_logits_for_metrics already took the argmax.

    Returns:
        The result of the module-level `metric.compute` call.
    """
    predictions, references = eval_preds
    # Position t predicts token t + 1: drop the first label and the last prediction so
    # the two flattened arrays line up.
    shifted_references = references[:, 1:].reshape(-1)
    aligned_predictions = predictions[:, :-1].reshape(-1)
    return metric.compute(
        predictions=aligned_predictions, references=shifted_references
    )


# Trainer bound to the current `model`, the automaton and the toy datasets. Later cells
# swap `trainer.model` and call `trainer.evaluate(...)` on it.
trainer = train_toy_model.ToyModelTrainer(
    model=model,
    toy_graph=automata,
    gen_seq_len=hp["seq_len"],
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    optimizers=optimizers,
    callbacks=callbacks,
)

In [None]:
from transformers import pipeline, set_seed

generator = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
gen_seq = generator("", max_length=50, do_sample=True, temperature=0.7)[0][
    "generated_text"
]

## Eval

In [None]:
# model = orig_cb_model
# model = cb_model
cb_model

In [None]:
example_answer = "34"
example_prompt = f"33"  # 3 4 6 1 2
# example_prompt = f"63" # 9 8 5 7 4 6
# example_prompt = f"47" # 0 1 2 3 4 5 8
# example_prompt = f"71" # 1 3 4 5 8 0 2
# example_prompt = f"72" # 8 9 4 5 7 3
# example_prompt = f"63" # 4 5 7 1 6 8
utils.test_prompt(
    example_prompt,
    example_answer,
    cb_model,
    prepend_bos=False,
    prepend_space_to_answer=False,
    top_k=10,
)

In [10]:
# Clean ("base") run: tokenize a single-state trajectory and cache every activation.
cb_model.reset_hook_kwargs()
# NOTE(review): 636 is larger than the `num_states` values in the hyperparameter cells
# (100 and 4) — confirm which automaton this state id refers to.
base_state = 636
base_input = cb_model.to_tokens(automata.traj_to_str([base_state]), prepend_bos=True)
base_input = base_input.to(device)
base_logits, base_cache = cb_model.run_with_cache(base_input)

In [11]:
# Corrupted run: same procedure as the base run, but starting from a different state.
# The patching sweep below runs on `corrupted_input` and patches in values from
# `base_cache`.
corrupted_state = 336
corrupted_input = cb_model.to_tokens(
    automata.traj_to_str([corrupted_state]), prepend_bos=True
)
corrupted_input = corrupted_input.to(device)
corrupted_logits, corrupted_cache = cb_model.run_with_cache(corrupted_input)

In [13]:
softmax = True
# answer_tokens_both = torch.tensor([[1, 0]], device=cb_model.cfg.device)
answer_tokens_both = torch.tensor([[3, 9]], device=cb_model.cfg.device)


In [14]:
base_average_logit_diff = logits_to_ave_logit_diff(
    base_logits.softmax(dim=-1) if softmax else base_logits, answer_tokens_both
)
corrupted_average_logit_diff = logits_to_ave_logit_diff(
    corrupted_logits.softmax(dim=-1) if softmax else corrupted_logits,
    answer_tokens_both,
)

print(base_average_logit_diff, corrupted_average_logit_diff)

tensor(1.0000, device='cuda:0') tensor(-0.3566, device='cuda:0')


In [None]:
# Compare the layer-3 MLP output of the unmodified run with the output obtained after
# setting hook kwargs (disable_topk=32, keep_k_codes=True) on the layer-3 MLP codebook.
# The cell displays the cosine similarity between the two outputs.
# base_cache[f'blocks.{i}.attn.codebook_layer.codebook.{0}.hook_codebook_ids']
base_output = base_cache["blocks.3.hook_mlp_out"]
cb_model.reset_hook_kwargs()
# Index [3][1]: layer 3, second entry (the MLP codebook, judging by how other cells
# pair all_codebooks[k][1] with blocks[k].mlp.codebook_layer).
cb_model.all_codebooks[3][1].set_hook_kwargs(disable_topk=32, keep_k_codes=True)
mod_logits, mod_cache = cb_model.run_with_cache(base_input)
mod_output = mod_cache["blocks.3.hook_mlp_out"]
torch.nn.functional.cosine_similarity(base_output, mod_output, dim=-1)

In [None]:
# Look up the codebook vectors for the code ids active at batch 0, position 1 of the
# layer-3 MLP codebook.
vs = base_cache["blocks.3.mlp.codebook_layer.hook_codebook_ids"][0, 1]
vs = cb_model.all_codebooks[3][1].codebook(vs)
# NOTE(review): the pairwise cosine-similarity matrix below is computed but never shown,
# because only the last expression of a cell is displayed.
torch.nn.functional.cosine_similarity(vs.unsqueeze(0), vs.unsqueeze(1), dim=-1)
# Displayed result: L2 distance between `base_output` and `mod_output`, both of which
# come from the previous cell.
torch.nn.functional.pairwise_distance(base_output, mod_output, p=2)

In [None]:
# Code patching sweep: for every layer and every codebook in it (one per attention head,
# plus the MLP codebook), run the corrupted input while a hook built from
# `patch_codebook_ids` and the base cache is attached to that codebook's id hook point,
# then record the resulting (normalized) logit difference.
from functools import partial

scores, logit_diffs = [], []
pos = [-1]
# pos = list(range(17))
for i in range(cb_model.cfg.n_layers):
    # `head is None` selects the MLP codebook of layer i.
    for head in list(range(cb_model.cfg.n_heads)) + [None]:
        hook_fn = partial(
            patch_codebook_ids, pos=pos, cache=base_cache
        )  # , code_idx=list(range(32)))
        cb_model.reset_codebook_metrics()
        if head is not None:
            cb_str = f"blocks.{i}.attn.codebook_layer.codebook.{head}.hook_codebook_ids"
        else:
            cb_str = f"blocks.{i}.mlp.codebook_layer.hook_codebook_ids"
        patched_logits = cb_model.run_with_hooks(
            corrupted_input, fwd_hooks=[(cb_str, hook_fn)], return_type="logits"
        )
        # When `softmax` is set the metric is computed on probabilities, not raw logits.
        patched_logit_diff = logits_to_ave_logit_diff(
            patched_logits.softmax(dim=-1) if softmax else patched_logits,
            answer_tokens_both,
        )
        logit_diffs.append(patched_logit_diff)
        #     print(patched_logit_diff)
        scores.append(
            normalize_patched_logit_diff(
                patched_logit_diff,
                corrupted_average_logit_diff,
                base_average_logit_diff,
            ).item()
        )
        # NOTE(review): with softmax=True the inputs are probabilities, so a difference
        # below -6 looks unreachable — presumably this threshold was meant for raw
        # logits; confirm against logits_to_ave_logit_diff.
        if patched_logit_diff < -6:
            print("pred:", i, head)
            print(logits_to_pred(patched_logits, k=5))
    #     scores.append(patched_logit_diff.item())

# scores.append(normalize_patched_logit_diff(corrupted_average_logit_diff).item())
print(scores)
print(logit_diffs)
# Tick labels follow the same (layer, head-or-MLP) order as the sweep above.
x_label = [
    f"L{l}H{h}" if h is not None else f"L{l}MLP"
    for l in range(cb_model.cfg.n_layers)
    for h in list(range(cb_model.cfg.n_heads)) + [None]
]
line(
    scores,
    x=x_label,
    title=f"Logit Difference for Nth layers' codes for pos {pos}",
    labels={"y": "Normalized Logit Difference", "x": "Layer"},
)

## NSP Eval

In [7]:
model = cb_model

In [32]:
# Settings and inputs for the next-state-prediction (NSP) patching experiment.
# `pos` lists the token positions to patch and `patch_types` the component outputs to
# patch; both are consumed by the patching loop in the next cell.
pos = [1, 2]
# pos = 2
# patch_types = "resid_pre"
# patch_types = "attn_out"
patch_types = ["attn_out", "mlp_out"]
# patch_types = ["mlp_out", "attn_1_head_1", "attn_1_head_2"]

# Run the model once on every state's token representation (BOS prepended by padding one
# token on the left) and keep logits and activation cache for all of them.
all_states = [automata.token_repr(state) for state in range(automata.N)]
all_states_tokens = tokenizer(all_states, return_tensors="pt")["input_ids"].to(
    cb_model.cfg.device
)
all_states_tokens = F.pad(all_states_tokens, (1, 0), value=tokenizer.bos_token_id)
all_states_logits, all_states_cache = cb_model.run_with_cache(all_states_tokens)

# state_as = [33]
# state_bs = [11]

# Source states (A, patched from) and target states (B, patched into): all states each.
state_as = list(range(automata.N))
state_bs = list(range(automata.N))
state_bs_str = [automata.token_repr(state) for state in state_bs]
state_bs_tokens = tokenizer(state_bs_str, return_tensors="pt")["input_ids"].to(
    cb_model.cfg.device
)
state_bs_tokens = F.pad(state_bs_tokens, (1, 0), value=tokenizer.bos_token_id)

In [94]:
import plotly.graph_objects as go

from torch.nn import functional as F
import pandas as pd
import re

# Accept a single patch type given as a bare string.
if isinstance(patch_types, str):
    patch_types = [patch_types]

# Number of (A, B) pairs with A != B: the full cross product minus one pair for every
# state that appears in both lists.
common_states_in_a_b = set(state_as).intersection(set(state_bs))
len_cross_prod = len(state_as) * len(state_bs) - len(common_states_in_a_b)
# One row per layer plus a final row for "all layers patched". The last axis of the NSP
# buffer holds [accuracy w.r.t. state A, accuracy w.r.t. state B].
nsp_patching_result = torch.zeros((model.cfg.n_layers + 1, len_cross_prod, 2), device=model.cfg.device)
kl_div_result = torch.zeros((model.cfg.n_layers + 1, len_cross_prod), device=model.cfg.device)


def remove_idx_from_tensor(tensor, idx):
    """Return a copy of `tensor` with entry `idx` along dimension 0 removed.

    Args:
        tensor: Tensor with at least one dimension.
        idx: Index of the entry to drop. Negative values count from the end, as in
            ordinary Python indexing. An index at or beyond the end removes nothing.

    Returns:
        A new tensor made of the entries before and after `idx`.

    Raises:
        IndexError: If a negative `idx` reaches before the start of the tensor.
    """
    if idx < 0:
        # Without this, tensor[:idx] + tensor[idx + 1:] duplicates entries: for
        # idx == -1 the second slice is tensor[0:], i.e. the whole tensor again.
        idx += tensor.shape[0]
        if idx < 0:
            raise IndexError(
                f"index out of range for tensor with {tensor.shape[0]} entries"
            )
    return torch.cat([tensor[:idx], tensor[idx + 1 :]])


# NSP patching loop. For every layer row (and one extra row that patches all layers at
# once), state A's cached activations are patched into runs on the other states B at
# positions `pos`. For every (A, B) pair the loop stores a KL-divergence value and the
# next-state accuracy measured against A and against B.
# NOTE(review): `tqdm` is not imported in any visible cell — presumably it arrives via
# one of the star imports at the top of the notebook.
# assert all([pt in ["attn_out", "mlp_out"] for pt in patch_types])
for layer in tqdm(range(model.cfg.n_layers + 1)):
    # `offset` is the write cursor into the flattened (A, B) pair axis for this row.
    offset = 0
    for state_a in state_as:
        # Slice every cached activation down to state A's row, keeping the batch dim.
        state_a_cache = {
            k: v[state_a].unsqueeze(0) for k, v in all_states_cache.items()
        }
        temp_hook_fn = partial(
            residual_stream_patching_hook, cache=state_a_cache, position=pos
        )
        if layer == model.cfg.n_layers:
            # Extra last row: attach the hook to the selected components of every layer.
            fwd_hooks = [
                (name, temp_hook_fn)
                for layer_idx in range(model.cfg.n_layers)
                for name in get_cb_layer_names(layer_idx, patch_types)
            ]
        else:
            fwd_hooks = [
                (name, temp_hook_fn) for name in get_cb_layer_names(layer, patch_types)
            ]
        # NOTE(review): `state_a_tokens` is not read again within this cell.
        state_a_tokens = model.to_tokens(str(state_a), prepend_bos=True)
        # Drop A from the batch of B states so a state is never patched into itself.
        try:
            a_index_in_b = state_bs.index(state_a)
            state_bs_wo_a = remove_idx_from_tensor(state_bs_tokens, a_index_in_b)
            state_bs_str_wo_a = (
                state_bs_str[:a_index_in_b] + state_bs_str[a_index_in_b + 1 :]
            )
        except ValueError:
            # A is not among the B states: use the full batch.
            state_bs_wo_a = state_bs_tokens
            state_bs_str_wo_a = state_bs_str
        
        # state b <- state a
        len_stats = len(state_bs_wo_a)
        mod_logits = model.run_with_hooks(state_bs_wo_a, fwd_hooks=fwd_hooks)
        # Reference: A's unpatched logits, repeated to match the patched batch size.
        gt_logits = (
            all_states_logits[state_a].unsqueeze(0).repeat(mod_logits.shape[0], 1, 1)
        )
        # NOTE(review): the result is written into a slice with one value per B state,
        # so `kl_div` presumably reduces over sequence positions — confirm.
        kl_div_batch = kl_div(gt_logits, mod_logits, reduction="none").sum(dim=-1)
        kl_div_result[layer, offset:offset+len_stats] = kl_div_batch
        nsp, _ = get_next_state_probs(
            state_bs_wo_a, model, automata, fwd_hooks=fwd_hooks
        )
        # Score the patched next-state probabilities against A and against each B.
        acc_a = correct_next_state_probs(state_a, nsp, automata, print_info="")
        acc_b = correct_next_state_probs(state_bs_str_wo_a, nsp, automata, print_info="")
        acc_a = torch.tensor(acc_a, device=model.cfg.device)
        acc_b = torch.tensor(acc_b, device=model.cfg.device)

        nsp_patching_result[layer, offset:offset+len_stats, 0] = acc_a
        nsp_patching_result[layer, offset:offset+len_stats, 1] = acc_b
        offset += len_stats
    # Every (A, B) slot of this row must have been filled exactly once.
    assert offset == len_cross_prod

# Convert the stored accuracies to a 0-100 scale for plotting.
nsp_patching_result *= 100

# Per-row summary over all (A, B) pairs: mean and standard deviation of both accuracies
# and of the KL divergence. The last row ("All") is the run that patched every layer.
df = pd.DataFrame(
    {
        "Layer": [str(i) for i in range(model.cfg.n_layers)] + ["All"],
        "NSP Acc Sa": nsp_patching_result[:, :, 0].mean(dim=1).tolist(),
        "NSP Acc Sb": nsp_patching_result[:, :, 1].mean(dim=1).tolist(),
        "KL Div": kl_div_result.mean(dim=1).tolist(),
        "NSP Acc Sa Std": nsp_patching_result[:, :, 0].std(dim=1).tolist(),
        "NSP Acc Sb Std": nsp_patching_result[:, :, 1].std(dim=1).tolist(),
        "KL Div Std": kl_div_result.std(dim=1).tolist(),
    }
)
# The accuracy std columns are halved, so the error bars drawn from them show half a
# standard deviation; the KL column is left at a full standard deviation.
df["NSP Acc Sa Std"] /= 2
df["NSP Acc Sb Std"] /= 2

# Build the human-readable component label used in the plot title, e.g. "Attn & MLP".
patch_type_labels = []
for patch_type in patch_types:
    # `\d+` (not `\d`) so multi-digit layer / head indices are matched in full rather
    # than truncated to their first digit or rejected as unknown.
    attn_head = re.match(r"attn_(\d+)_head_(\d+)", patch_type)
    mlp = re.match(r"mlp_(\d+)", patch_type)
    if patch_type == "attn_out":
        label = "Attn"
    elif patch_type == "mlp_out":
        label = "MLP"
    elif patch_type == "resid_pre":
        label = "Pre-Residual"
    elif patch_type == "resid_post":
        label = "Post-Residual"
    elif attn_head:
        label = f"Attn {attn_head.group(1)} Head {attn_head.group(2)}"
    elif mlp:
        label = f"MLP {mlp.group(1)}"
    else:
        raise ValueError(f"Unknown patch type: {patch_type}")
    patch_type_labels.append(label)
patch_types_res = " & ".join(patch_type_labels)

# Position label for the title: a single position, or "first-last" for a list.
if isinstance(pos, int):
    pos_str = str(pos)
elif len(pos) == 1:
    pos_str = str(pos[0])
else:
    # Use the last element as the range end; `pos[1]` was only right for two positions.
    pos_str = f"{pos[0]}-{pos[-1]}"

# Plot the summary: both NSP accuracies on the left axis and the KL divergence on a
# secondary right axis, each with error bars taken from the matching "Std" column.
# create a blank figure
fig = go.Figure()
fig.add_trace(go.Scatter(x=df["Layer"], y=df["NSP Acc Sa"], error_y=dict(type='data', array=df["NSP Acc Sa Std"]), name="NSP Accuracy (S<sub>A</sub>)"))
fig.add_trace(go.Scatter(x=df["Layer"], y=df["NSP Acc Sb"], error_y=dict(type='data', array=df["NSP Acc Sb Std"]), name="NSP Accuracy (S<sub>B</sub>)"))
# Create the plot using plotly express
# fig = px.line(df, x="Layer", y=df.columns[1:3]) # color="Metric"
title = f"NSP and KL Div Result. Patching S<sub>A</sub> &#8594; S<sub>B</sub> codes of {patch_types_res} at position {pos_str}"
fig.update_layout(title={"text": title, "x": 0.5, "xanchor": "center"})

# # Add a secondary y-axis for KL Divergence
# NOTE(review): the trace name hard-codes "Token 3" — confirm it still matches the
# position the KL divergence is measured at when `pos` changes.
fig.add_trace(
    go.Scatter(
        x=df["Layer"],
        y=df["KL Div"],
        error_y=dict(type='data', array=df["KL Div Std"]),
        name="Token 3 KL Div with S<sub>A</sub>",
        yaxis="y2",
    )
)

# Set the left y-axis title to NSP Accuracy
fig.update_layout(yaxis=dict(title="NSP Accuracy", range=(-3, 103)))

# Set the right y-axis title to KL Divergence
# The right axis range is fixed; values outside [-0.2, 6] fall outside the visible area.
fig.update_layout(
    yaxis2=dict(title="KL Divergence", overlaying="y", side="right", range=[-0.2, 6])
)
# fig.update_yaxes(range=['auto', 'auto'])
# Add a custom legend title
fig.update_layout(legend=dict(title="Metric"))

fig.show()

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [02:10<00:00, 26.11s/it]


In [8]:
tokenizer(["123456","12", "1234"], return_tensors="pt", padding=True)["input_ids"]

tensor([[ 1,  2,  3,  4,  5,  6],
        [ 1,  2, 10, 10, 10, 10],
        [ 1,  2,  3,  4, 10, 10]])

In [7]:
# Enumerate every valid n-gram input (trigrams when `is_trigram`, otherwise bigrams), run
# the hooked model on each one individually, and compute the JS divergence between the
# logits of every pair of inputs.
rev_automata = automata.reverse()

is_trigram = True
prefix_random_states_len = 0
plot_code_grp_distr = True

repeat = 3 if is_trigram else 2

js_divs = {}  # (input_a, input_b) -> JS divergence between their logits
all_state_info = {}  # input string -> (logits, cache)

chars = [str(i) for i in range(automata.representation_base)]
all_valid_inputs = [''.join(combination) for combination in itertools.product(chars, repeat=repeat) if valid_input(''.join(combination), automata)]
all_valid_inputs_tokens = tokenizer(all_valid_inputs, return_tensors="pt")["input_ids"].to(
    cb_model.cfg.device
)
# Optional: prepend a random state prefix, generated on the reversed automaton and then
# flipped so that it leads into each input's first state.
if prefix_random_states_len > 0:
    start_states = [s[0] for s in automata.seq_to_traj(all_valid_inputs)]
    random_state_prefix = rev_automata.generate_trajectories(prefix_random_states_len, start_states=start_states)
    random_state_prefix = random_state_prefix[:, ::-1]
    random_state_prefix = random_state_prefix.astype(int)
    random_state_prefix = [automata.traj_to_str(traj) for traj in random_state_prefix]
    random_state_prefix = tokenizer(random_state_prefix, return_tensors="pt")["input_ids"].to(
        cb_model.cfg.device
    )
    all_valid_inputs_tokens = torch.cat([random_state_prefix, all_valid_inputs_tokens], dim=1)
# Prepend BOS by padding one token on the left.
all_valid_inputs_tokens = F.pad(all_valid_inputs_tokens, (1, 0), value=tokenizer.bos_token_id)

# NOTE(review): the loop variable `input` shadows the Python builtin of the same name
# for the rest of the kernel session.
for input in tqdm(all_valid_inputs):
    input_tensor = cb_model.to_tokens(str(input), prepend_bos=True).to(device)
    logits, cache = cb_model.run_with_cache(input_tensor)
    all_state_info[input] = (logits, cache)

# Each unordered pair is visited once (input_b always comes after input_a in the list).
for iter_a, input_a in enumerate(all_valid_inputs):
    for input_b in all_valid_inputs[iter_a+1:]:
        js_divs[(input_a, input_b)] = JSD(all_state_info[input_a][0], all_state_info[input_b][0]).item()

avg_js_div = sum(js_divs.values()) / len(js_divs)

# Partition the inputs by the codes they activate, once per attention-head codebook and
# once per MLP codebook in every layer.
# NOTE(review): `input_len=3` is hard-coded here while `repeat` becomes 2 when
# `is_trigram` is False — confirm this is intended for the bigram case.
if plot_code_grp_distr:
    code_groups_for_all_comps = {}
    for layer in tqdm(range(cb_model.cfg.n_layers)):
        for ccb_num in range(cb_model.cfg.n_heads):
            code_groups_for_all_comps[(layer, "attn", ccb_num)] = partition_input_on_codebook(cb_model=cb_model, automata=automata, layer=layer, cb_at="attn", ccb_num=ccb_num, input_len=3)
        code_groups_for_all_comps[(layer, "mlp", None)] = partition_input_on_codebook(cb_model=cb_model, automata=automata, layer=layer, cb_at="mlp", ccb_num=None, input_len=3)

# These rebind names that the NSP cells also define. Here `all_states_cache` is keyed by
# input string, whereas the NSP patching loop indexes each cache value by state id.
all_states_logits = torch.cat([v[0] for v in all_state_info.values()], dim=0)
all_states_cache = {k: v[1] for k, v in all_state_info.items()}

100%|██████████| 686/686 [00:16<00:00, 41.25it/s]
  0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
if plot_code_grp_distr:
    code_groups_for_all_comps = {}
    for layer in tqdm(range(cb_model.cfg.n_layers)):
        for ccb_num in tqdm(range(cb_model.cfg.n_heads)):
            code_groups_for_all_comps[(layer, "attn", ccb_num)] = partition_input_on_codebook(cb_model=cb_model, automata=automata, layer=layer, cb_at="attn", ccb_num=ccb_num, input_len=3)
        code_groups_for_all_comps[(layer, "mlp", None)] = partition_input_on_codebook(cb_model=cb_model, automata=automata, layer=layer, cb_at="mlp", ccb_num=None, input_len=3)

In [None]:
plot_js_div(code_groups_for_all_comps, 3, "mlp", None, js_divs, show_plot=True, image_name_prefix=None)

In [14]:
# JS divergence under code patching. For each target input B, B's codes at the last
# position are forced into runs on all inputs for several component sets, and the JS
# divergence between the patched logits and B's own logits is accumulated. Entry "none"
# is the unpatched baseline.
patchings_to_plot_orig = ["none", "l0_attn", "l1_mlp", "all_attn", "all_mlp", "all_attn_mlp"]
# "all" is expanded to an explicit list of layers 0-3.
patchings_to_plot = [s.replace("all", "l0,l1,l2,l3") for s in patchings_to_plot_orig]

# NOTE(review): this rebinds `js_divs`, replacing the pairwise dict built in the earlier
# cell; by the end of this cell it is a plain list.
js_divs = {k: 0 for k in patchings_to_plot}
js_divs["none"] = 0

for state_b, input_b in tqdm(enumerate(all_valid_inputs)):
    # NOTE(review): sums are divided by automata.N - 1 while the loop runs over
    # `all_valid_inputs` — these only agree when there is one input per state; confirm.
    js_divs_w_b = JSD(all_states_logits, all_states_logits[state_b].unsqueeze(0), reduction="none").sum() / (automata.N - 1) # removing b as JSD(b,b) = 0
    js_divs["none"] += js_divs_w_b

    for patching in patchings_to_plot:
        if patching == "none":
            continue
        # Everything after the first "_"-separated field names components, e.g. ["attn", "mlp"].
        cb_at = patching.split("_")[1:]
        layers = get_layers_from_patching_str(patching)
        heads = [None] * len(cb_at)
        # Expand a single "attn" entry into one entry per attention head.
        if "attn" in cb_at:
            attn_idx = cb_at.index("attn")
            cb_at.pop(attn_idx), heads.pop(attn_idx)
            cb_at += ["attn"] * cb_model.cfg.n_heads
            heads += list(range(cb_model.cfg.n_heads))
        # Repeat the per-layer component list for every layer, keeping the three lists
        # aligned index by index.
        cb_at_rep = cb_at * len(layers)
        heads_rep = heads * len(layers)
        layers_rep = []
        for l in layers:
            layers_rep += [l] * len(cb_at)
        
        # B's cached values at batch 0, last position, for each selected codebook.
        cache_b = all_state_info[input_b][1]
        code = [cache_b[get_cb_layer_name(cb_at_rep[i], layers_rep[i], heads_rep[i])][0, -1, :] for i in range(len(cb_at_rep))]

        mod_logits, mod_cache = run_with_codes(
            all_valid_inputs_tokens,
            cb_model,
            code,
            cb_at_rep,
            layers_rep,
            heads_rep,
            pos=[-1],
        )
        js_divs_w_b = JSD(mod_logits, all_states_logits[state_b].unsqueeze(0), reduction="none").sum() / (automata.N - 1) # removing b as JSD(b,b) = 0
        js_divs[patching] += js_divs_w_b

# Average over targets, then scale so the largest bar equals 1.
js_divs = [js_divs[k].item() / automata.N for k in patchings_to_plot]
js_divs = [js_div / max(js_divs) for js_div in js_divs]

# plot js_divs using plotly
# `go` is plotly.graph_objects, imported in the NSP plotting cell.
fig = go.Figure()
x_labels = [clean_patching_name(patching) for patching in patchings_to_plot_orig]
fig.add_trace(go.Bar(x=x_labels, y=js_divs))
fig.update_layout(title=f'JS Div on Code Patching for {"Trigrams" if is_trigram else "Bigrams"}', xaxis_title='Code Patching Components', yaxis_title='Normalized JS Div')
fig.show()

100it [00:20,  4.91it/s]


In [16]:
torch.stack([torch.zeros(3,5), torch.zeros(3,5)], dim=0).shape

torch.Size([2, 3, 5])

In [19]:
for layer in range(cb_model.cfg.n_layers):
    for cb_at in ["attn", "mlp"]:
        for ccb_num in range(cb_model.cfg.n_heads) if cb_at == "attn" else [None]:
            plot_js_div(code_groups_for_all_comps, layer, cb_at, ccb_num, js_divs, show_plot=False, image_name_prefix="k1_model")

In [50]:
acc = first_transition_accuracy(cb_model, automata, fwd_hooks=None, prepend_bos=True)
print(acc)

100%|██████████| 100/100 [00:19<00:00,  5.09it/s]

0.9750000000000004





## Token Code Maps

In [27]:
model.reset_hook_kwargs()
model.set_hook_kwargs(idx=[2, 3], disable_topk=1, keep_k_codes=False)

In [8]:
max_samples = 10 * 1024
model = orig_cb_model
model.enable_logging()
model.reset_codebook_metrics()
train_dataset_tkns = train_toy_model.ToyDataset(
    automata,
    tokenizer=tokenizer,
    seq_len=128,
    max_samples=max_samples,
    save_tokens=True,
)

In [None]:
# Reconfigure the existing trainer for a single-process evaluation pass over `model`.
trainer.args.dataloader_num_workers = 0
# NOTE(review): the training arguments above are created with `report_to=...`;
# `report_on` looks like a typo that only adds a new, unused attribute — confirm.
trainer.args.report_on = "none"
trainer.model = model
# Accumulator filled by the hook below: key -> list of codebook-id batches.
codebook_acts = dict()


def store_cb_activations(key, codebook_ids, codebook_acts=codebook_acts):
    """Append one batch of codebook ids to the list stored under `key`.

    Args:
        key: Name under which the batch is stored.
        codebook_ids: Array-like whose `shape` has exactly three entries
            (bs, seq_len, k_codebook).
        codebook_acts: Dict to store into; defaults to the module-level accumulator
            bound when this function is defined.
    """
    n_dims = len(codebook_ids.shape)
    assert n_dims == 3  # (bs, seq_len, k_codebook)
    codebook_acts.setdefault(key, []).append(codebook_ids)

# Run one evaluation pass over the token-saving dataset while the hook records codebook
# ids, then stack the recorded batches per key.
if isinstance(model, models.CodebookModel):
    model.set_hook_fn(store_cb_activations)
trainer.model = model

metrics = trainer.evaluate(eval_dataset=train_dataset_tkns)
print(metrics)
# Tokens saved by the dataset during evaluation, stacked into one 2-D array.
tokens = np.vstack(train_dataset_tkns.tokens)

if isinstance(model, models.CodebookModel):
    print(codebook_acts.keys())
    # `cb_acts` is the same dict object as `codebook_acts`, so the loop below replaces
    # each list of batches with one concatenated array in place.
    cb_acts = codebook_acts
    # NOTE(review): `num_codes` is not used in this cell.
    num_codes = 10000
    from tqdm import tqdm
    for k, v in codebook_acts.items():
        # Keep only as many batches as are needed to cover `max_samples` rows, based on
        # the size of the first recorded batch.
        v_len_to_get = max_samples // v[0].shape[0]
        cb_acts[k] = np.concatenate(v[:v_len_to_get], axis=0)

    print(tokens.shape)
    # These two keys are hard-coded and must exist among the recorded hook keys.
    print(cb_acts["layer1_attn_preproj_ccb0"].shape)
    print(cb_acts["layer2_mlp"].shape)

In [10]:
{'eval_loss': 1.2997702360153198, 'eval_accuracy': 0.4505251906988189, 'eval_runtime': 38.5728, 'eval_samples_per_second': 265.472, 'eval_steps_per_second': 0.518, 'eval_transition_accuracy': 0.5843548387096774, 'eval_first_transition_accuracy': 0.95, 'eval_multicode_k': 1, 'eval_dead_code_fraction/layer0': 0.99624, 'eval_MSE/layer0': 220431.78924572692, 'eval_input_norm/layer0': 333.8144458924039, 'eval_output_norm/layer0': 12.890674110075121, 'eval_dead_code_fraction/layer1': 0.94042, 'eval_MSE/layer1': 132.1956132092747, 'eval_input_norm/layer1': 6.530500547282445, 'eval_output_norm/layer1': 13.093124667135442, 'eval_dead_code_fraction/layer2': 0.79874, 'eval_MSE/layer2': 349.0161221121163, 'eval_input_norm/layer2': 6.156233979892996, 'eval_output_norm/layer2': 17.864714929908395, 'eval_dead_code_fraction/layer3': 0.93764, 'eval_MSE/layer3': 349.25655640419984, 'eval_input_norm/layer3': 7.597045652759595, 'eval_output_norm/layer3': 16.90762827033095}

orig_cb_model.disable_codebooks()

In [None]:
d1 = {'eval_loss': 1.5225844383239746, 'eval_accuracy': 0.4184462659940945, 'eval_transition_accuracy': 0.5582258064516129, 'eval_first_transition_accuracy': 0.95,'eval_dead_code_fraction/layer0': 0.99698, 'eval_dead_code_fraction/layer1': 0.97624, 'eval_dead_code_fraction/layer2': 0.96486, 'eval_dead_code_fraction/layer3': 0.98734}
d2 = {'eval_loss': 1.29947030544281, 'eval_accuracy': 0.4502691313976378, 'eval_transition_accuracy': 0.6043548387096774, 'eval_first_transition_accuracy': 0.94, 'eval_dead_code_fraction/layer0': 0.99624, 'eval_dead_code_fraction/layer1': 0.93898, 'eval_dead_code_fraction/layer2': 0.78164, 'eval_dead_code_fraction/layer3': 0.93472}
d3 = {k: v for k, v in d2.items() if k in d1}
print(d3)

In [None]:
import torch.nn as nn


def remove_dead_codes_from_codebook(codebook):
    """Shrink a codebook in place so that it keeps only codes with a positive count.

    Replaces `codebook.codebook` with a new `nn.Embedding` holding the surviving rows
    (in their original order), and updates `codebook.counts` and
    `codebook._num_codes` to match.

    Args:
        codebook: Object exposing `counts` (one entry per code), `codebook`
            (an `nn.Embedding`) and `_num_codes`.
    """
    alive_codes = codebook.counts > 0
    # Use a plain int for the new size: passing the 0-dim tensor from `.sum()` would
    # leave `num_embeddings` as a tensor, inconsistent with the int kept in `_num_codes`.
    num_alive = int(alive_codes.sum())
    old_weight = codebook.codebook.weight.data

    new_embedding = nn.Embedding(num_alive, codebook.codebook.embedding_dim)
    new_embedding.weight.data.copy_(old_weight[alive_codes])
    # Module.to moves the parameters in place, so the return value is not needed.
    new_embedding.to(old_weight.device)
    codebook.codebook = new_embedding

    codebook.counts = codebook.counts[alive_codes]
    codebook._num_codes = num_alive


def remove_dead_codes(codebook_model=None, num_heads=4):
    """Prune dead codes from every attention-head and MLP codebook of a model.

    For each entry of ``codebook_model.all_codebooks``, element 0 is treated as
    the attention codebook layer (one sub-codebook per head, reached through
    ``.codebook[h]``) and element 1 as the MLP codebook layer; each one is
    passed to ``remove_dead_codes_from_codebook``.

    Args:
        codebook_model: model exposing ``all_codebooks``. Defaults to the
            notebook-level ``model`` so that the bare call ``remove_dead_codes()``
            keeps working.
        num_heads: number of attention sub-codebooks to prune per layer
            (defaults to 4, the previously hard-coded value).
    """
    if codebook_model is None:
        codebook_model = model
    for layer_codebooks in codebook_model.all_codebooks.values():
        attn_layer, mlp_layer = layer_codebooks[0], layer_codebooks[1]
        for head in range(num_heads):
            remove_dead_codes_from_codebook(attn_layer.codebook[head])
        remove_dead_codes_from_codebook(mlp_layer)


remove_dead_codes()

In [None]:
# Make cb_model use the codebook layer objects held by `model` (presumably the
# ones pruned by remove_dead_codes -- confirm), both in its `all_codebooks`
# registry and on the corresponding transformer block.
# The former inner `for h in range(4)` loop repeated these same two attention
# assignments four times without using `h`; one pass per layer is sufficient.
for k, v in model.all_codebooks.items():
    print("attn")
    cb_model.all_codebooks[k][0] = v[0]
    cb_model.model.blocks[k].attn.codebook_layer = v[0]

    print("mlp")
    cb_model.all_codebooks[k][1] = v[1]
    cb_model.model.blocks[k].mlp.codebook_layer = v[1]

In [None]:
# Dump, per layer, the size and usage counts of each of the four attention
# sub-codebooks, followed by the usage counts of the MLP codebook.
for layer_key, layer_codebooks in model.all_codebooks.items():
    print("attn")
    for head_num in range(4):
        head_codebook = layer_codebooks[0].codebook[head_num]
        print(layer_key, head_num, head_codebook.num_codes, head_codebook.counts)

    print("mlp")
    mlp_codebook = layer_codebooks[1]
    print(layer_key, mlp_codebook.counts)

In [34]:
# Partition length-3 inputs by codebook code ("mlp", index 1 -- presumably the
# layer; confirm against partition_input_on_codebook's signature).
trigram_partition = partition_input_on_codebook(
    cb_model, automata, "mlp", 1, None, input_len=3
)
# Codes whose bucket holds more than one input.
multi_input_codes = [
    code_id for code_id, bucket in trigram_partition.items() if len(bucket) > 1
]
num_multi_input = len(multi_input_codes)
num_partition_keys = len(trigram_partition)
print(num_multi_input, num_partition_keys, num_multi_input / num_partition_keys)

print("Unique Trigrams:", num_partition_keys - num_multi_input)

# Enumerate every 3-character string over the automaton's digit alphabet and
# keep the ones valid_input accepts.
chars = list(map(str, range(automata.representation_base)))
input_range = [
    trigram
    for trigram in map("".join, itertools.product(chars, repeat=3))
    if valid_input(trigram, automata)
]
print("Valid Trigrams:", len(input_range))

57 603 0.0945273631840796


In [None]:
# Display the length-3 input partition for the "mlp" codebook with index 2
# (presumably the layer index -- confirm); the result is shown, not stored.
partition_input_on_codebook(cb_model, automata, "mlp", 2, None, input_len=3)

In [61]:
# Length-3 input partition for the "mlp" codebook with index 3 (presumably the
# layer index -- confirm); consumed by the code_predicts cell below.
mlp3_partition = partition_input_on_codebook(cb_model, automata, "mlp", 3, None, input_len=3)

In [None]:
# Inspect the mapping returned by partition_input_on_codebook.
mlp3_partition

In [62]:
def _majority_next_token(trigram_inputs):
    """Tally next-token votes over ``trigram_inputs``; return (top token, share string).

    For each trigram, the leading ``automata.digits`` characters are parsed as
    a state, and every neighbour of that state whose token representation
    starts with the trigram's last character casts one vote for its second
    character. The share is the winning vote count divided by the number of
    trigrams, formatted as a percentage with two decimals.
    """
    votes = np.zeros(automata.representation_base)
    for trigram in trigram_inputs:
        source_state = int(trigram[: automata.digits])
        for nbr in automata.nbrs(source_state):
            nbr_repr = automata.token_repr(nbr)
            if nbr_repr[0] == trigram[-1]:
                votes[int(nbr_repr[1])] += 1
    winner = votes.argmax()
    return (winner, f'{100*votes[winner]/len(trigram_inputs):.2f}%')


# One (predicted next token, agreement share) pair per code in mlp3_partition.
code_predicts = {
    code_id: _majority_next_token(bucket) for code_id, bucket in mlp3_partition.items()
}

In [63]:
# Inspect the per-code (next token, agreement percentage) pairs built above.
code_predicts

{6111: (9, '100.00%'),
 3556: (9, '100.00%'),
 9199: (1, '100.00%'),
 9528: (8, '87.50%'),
 3028: (2, '100.00%'),
 4142: (6, '100.00%'),
 7597: (7, '100.00%'),
 5024: (4, '87.50%'),
 7743: (0, '100.00%'),
 4138: (4, '100.00%'),
 9703: (2, '100.00%'),
 9031: (5, '100.00%'),
 4272: (1, '100.00%'),
 9511: (6, '100.00%'),
 9562: (8, '100.00%'),
 547: (9, '100.00%'),
 2736: (2, '75.00%'),
 3433: (3, '100.00%'),
 9294: (0, '100.00%'),
 6500: (5, '100.00%'),
 6598: (0, '100.00%'),
 5987: (0, '50.00%'),
 7898: (2, '100.00%'),
 3262: (0, '100.00%'),
 5463: (9, '100.00%'),
 7506: (3, '100.00%'),
 763: (8, '100.00%'),
 6060: (3, '100.00%'),
 7614: (0, '100.00%'),
 5823: (2, '100.00%'),
 1283: (2, '100.00%'),
 5287: (4, '100.00%'),
 1722: (1, '100.00%'),
 1812: (7, '100.00%'),
 9414: (8, '100.00%'),
 8730: (3, '100.00%'),
 2979: (3, '100.00%'),
 9947: (8, '100.00%'),
 1208: (7, '100.00%'),
 5365: (6, '100.00%'),
 2012: (7, '100.00%'),
 5700: (2, '100.00%'),
 1768: (0, '100.00%'),
 8444: (1, '80.00

In [22]:

# Number of candidate code ids enumerated per codebook (previously an inline
# 10000; presumably the codebook size -- confirm against the model config).
NUM_CODES_PER_CODEBOOK = 10000

# For every cached entry whose key mentions "codebook", gather the code ids
# found at index [0, -1, :] (presumably batch 0, last position) across all
# cached inputs; every other id in range(NUM_CODES_PER_CODEBOOK) is disabled.
codes_to_disable = {}
for k in all_cache["0"].keys():
    if "codebook" not in k:
        continue
    enabled_codes = set()
    # Only the cache objects are needed. Iterating over values also avoids
    # rebinding the builtin `input` for the rest of the kernel session.
    for cache in all_cache.values():
        enabled_codes.update(cache[k][0, -1, :].tolist())
    codes_to_disable[k] = list(set(range(NUM_CODES_PER_CODEBOOK)) - enabled_codes)

In [24]:
# Start from a clean slate, then register the disabled codes on the matching
# attention (element 0, per head) or MLP (element 1) codebook of each layer.
orig_cb_model.reset_hook_kwargs()

for key, codes in codes_to_disable.items():
    layer, attn_or_mlp, head = get_codebook_info_from_hook_key(key)
    targets_attention = attn_or_mlp == "attn"
    target_codebook = orig_cb_model.all_codebooks[layer][0 if targets_attention else 1]
    hook_kwargs = {"disable_codes": codes}
    if targets_attention:
        # Attention codebooks additionally need to know which head to affect.
        hook_kwargs["head_idx"] = head
    target_codebook.set_hook_kwargs(**hook_kwargs)


In [94]:
# Tokenize the base prompt (with BOS), show the token ids, then run the model
# with caching and print its predictions.
base_str = "63"
raw_tokens = cb_model.to_tokens(base_str, prepend_bos=True)
print(raw_tokens)
base_input = raw_tokens.to(device)

base_logits, base_cache = cb_model.run_with_cache(base_input)
base_predictions = logits_to_pred(base_logits, tokenizer)
print(base_predictions)

tensor([[10,  6,  3]], device='cuda:0')
[('4', 0.5618906617164612), ('5', 0.13749220967292786), ('6', 0.09639862924814224), ('8', 0.07662083208560944), ('7', 0.07236792147159576)]


In [79]:
# Show automata.nbrs_to(63) (presumably the states with an edge into state 63
# -- confirm against the automaton implementation).
automata.nbrs_to(63)

array([ 3,  5, 26, 35, 43, 44, 46, 49, 63, 95])

In [113]:
# Tokenize a two-state prompt, query the next-state probabilities, and print
# what correct_next_state_probs reports for state 63.
prompt = "-25-63"
inp = cb_model.to_tokens(prompt, prepend_bos=True).to(device)
nsp = get_next_state_probs(inp, cb_model, automata, None, False)
correct_probs = correct_next_state_probs(63, nsp, automata)
print(correct_probs)

# Cache the activations of the same prompt for later comparison cells.
mod_logits, mod_cache = cb_model.run_with_cache(inp)


[0.5]


In [71]:
# A 32-state trajectory ending in state 63, written eight states per line
# (adjacent string literals are concatenated by the parser).
inp = (
    "-64-96-07-96-94-84-89-26"
    "-63-51-44-01-81-43-08-11"
    "-70-39-25-10-34-89-36-42"
    "-11-07-28-05-00-09-49-63"
)
# Last expression of the cell: display the token ids (with BOS prepended).
cb_model.to_tokens(inp, prepend_bos=True)

tensor([[10, 10,  6,  4,  9,  6,  0,  7,  9,  6,  9,  4,  8,  4,  8,  9,  2,  6,
          6,  3,  5,  1,  4,  4,  0,  1,  8,  1,  4,  3,  0,  8,  1,  1,  7,  0,
          3,  9,  2,  5,  1,  0,  3,  4,  8,  9,  3,  6,  4,  2,  1,  1,  0,  7,
          2,  8,  0,  5,  0,  0,  0,  9,  4,  9,  6,  3]], device='cuda:0')

In [66]:
# Inspect `nsp`, the value returned by get_next_state_probs.
nsp

(tensor([[46, 67, 69, 14, 34, 33, 36, 42, 23, 31]], device='cuda:0'),
 tensor([[0.1361, 0.1094, 0.1061, 0.0876, 0.0761, 0.0687, 0.0686, 0.0674, 0.0396,
          0.0345]], device='cuda:0'))

In [58]:
# First ten entries of base_logits at index [0, -1] (presumably batch item 0,
# final position, vocabulary axis -- confirm the tensor layout).
base_logits[0, -1, :10]

tensor([-4.0287, -2.7718, -2.8299, -1.6859,  2.8451,  1.4373,  1.0823,  0.7955,
         0.8526,  0.3121], device='cuda:0')

In [11]:
# Show automata.nbrs(63) (presumably the states reachable from state 63 --
# confirm against the automaton implementation).
automata.nbrs(63)


array([45, 51, 57, 63, 73, 82, 89, 94, 96, 98])

In [24]:
# Select which codebook to inspect.
layer = 1
# Sub-codebook index, read only when cb_at == "attn". It has to be defined,
# otherwise switching cb_at to "attn" raises NameError.
ccb_num = 0
# cb_at = "attn"
cb_at = "mlp"

# Name of the cache entry holding the code ids of the selected codebook.
if cb_at == "attn":
    ccb = f"_preproj_ccb{ccb_num}"
    hook_name = f"blocks.{layer}.attn.codebook_layer.codebook.{ccb_num}.hook_codebook_ids"
else:
    ccb = ""
    hook_name = f"blocks.{layer}.mlp.codebook_layer.hook_codebook_ids"

# Code ids at index [0, -1] of the base run (presumably the last position).
indices = base_cache[hook_name][0, -1].tolist()

cb_str = f"layer{layer}_{cb_at}{ccb}"

# The original condition ended in `or True`, i.e. the mapping was always
# rebuilt. Keep that behaviour, but make the switch explicit: set this to
# False to reuse an entry already stored in ft_tkns.
force_recompute = True
if force_recompute or cb_str not in ft_tkns:
    ft_tkns[cb_str] = features_to_tokens(cb_str, cb_acts, num_codes)

100%|██████████| 10240/10240 [00:00<00:00, 17830.20it/s]


In [None]:
# Print examples for the selected codebook entry ft_tkns[cb_str], restricted
# to the code ids in `indices` (taken from the base run's cache).
print_ft_tkns(
    ft_tkns[cb_str],
    tokens=tokens,
    tokenizer=tokenizer,
    n=5,
    indices=indices,
    max_examples=100,
)

In [None]:
# Bind the arguments shared by every query (first argument 63, the model and
# the automaton) so that each call below only states what varies.
activations_for_63 = partial(start_state_activations, 63, cb_model, automata)

# For each layer: one line per attention head, then one line for the MLP.
for layer in range(cb_model.cfg.n_layers):
    for head in range(cb_model.cfg.n_heads):
        attn_activations = activations_for_63(cb_at="attn", layer=layer, ccb_num=head)
        print(layer, head, attn_activations)
    mlp_activations = activations_for_63(cb_at="mlp", layer=layer, ccb_num=None)
    print(layer, mlp_activations)

In [None]:
# Clear previous hook settings, then configure the attention codebook of each
# of the four layers: heads 0 and 3 in layer 0, head 3 only elsewhere.
cb_model.reset_hook_kwargs()
for layer in range(4):
    heads_for_layer = [0, 3] if layer == 0 else [3]
    cb_model.all_codebooks[layer][0].set_hook_kwargs(
        head_idx=heads_for_layer,
        disable_topk=1,
        disable_for_tkns=[1],
        keep_k_codes=False,
    )

# Re-run the prompt "63" (no BOS) under these hooks and print the top 5 predictions.
mod_logits, mod_cache = cb_model.run_with_cache("63", prepend_bos=False)
top_predictions = logits_to_pred(mod_logits, tokenizer, k=5)
print(top_predictions)

In [None]:
from functools import partial

# print(generate_with_codes("<|endoftext|>00", [5], ["attn"], layer_idx=[2], head_idx=[3], pos=[-1]))
# Parallel lists: entry i of each list describes one code to apply.
disable_other_comps = True
code = [8268]
cb_at = ["mlp"]
layer_idx = [3]
head_idx = [1]
pos = [-1]

# Zip the parallel lists into CodeInfoTuple records, one per entry of `code`.
list_of_arg_tuples = [
    CodeInfoTuple(code_id, cb_at[i], layer_idx[i], head_idx[i], pos[i])
    for i, code_id in enumerate(code)
]

generated = generate_with_codes(
    "<|endoftext|>33",
    list_of_arg_tuples=list_of_arg_tuples,
    disable_other_comps=disable_other_comps,
)
print(generated)
# print(generate_with_codes("<|endoftext|>011", list_of_arg_tuples=list_of_arg_tuples, disable_other_comps=disable_other_comps))
# print(generate_with_codes("<|endoftext|>10", list_of_arg_tuples=list_of_arg_tuples, disable_other_comps=disable_other_comps))
# print(generate_with_codes("<|endoftext|>110", list_of_arg_tuples=list_of_arg_tuples, disable_other_comps=disable_other_comps))

In [11]:
def norm_distance(logits1, logits2, p=1):
    """Mean p-norm distance between the softmax distributions of two logit tensors.

    Both inputs are soft-maxed over their last dimension, the p-norm of the
    difference is taken over that same dimension, and the result is averaged
    over whatever dimensions remain.

    Args:
        logits1: tensor of logits.
        logits2: tensor of logits, compatible in shape with ``logits1``.
        p: order of the norm (default 1).

    Returns:
        The averaged distance as a Python float.
    """
    prob_gap = logits1.softmax(dim=-1) - logits2.softmax(dim=-1)
    per_distribution = prob_gap.norm(p=p, dim=-1)
    return per_distribution.mean().item()

In [12]:
# Tokenize the two prompts ("11" -> a, "63" -> b) with BOS and move them to `device`.
input_a, input_b = (
    cb_model.to_tokens(state_str, prepend_bos=True).to(device)
    for state_str in ("11", "63")
)

# Run both prompts with caching, a first and then b.
(logits_a, cache_a), (logits_b, cache_b) = (
    cb_model.run_with_cache(prompt_tokens) for prompt_tokens in (input_a, input_b)
)

# Baseline distances between the two unpatched runs.
jsd_ab = JSD(logits_a, logits_b)
print(f"JSD(a, b) = {jsd_ab}")
l1_ab = norm_distance(logits_a[0, -1, :], logits_b[0, -1, :], p=1)
print(f"L1(a, b) = {l1_ab}")

JSD(a, b) = 0.28124862909317017
L1(a, b) = 1.2098653316497803


In [22]:
# Components whose codes are copied from run b into run a: the four attention
# heads of layer 0. (Other combinations, e.g. mixing in MLP codebooks of
# layers 1-3, can be tried by editing the three parallel lists below.)
n_layers = 1
cb_at = ["attn" for _ in range(4 * n_layers)]
head = list(range(4)) * n_layers
layer = [0, 0, 0, 0]

# Codes of run b at index [0, -1, :] for each selected component.
code = [
    cache_b[get_cb_layer_name(component, layer[i], head[i])][0, -1, :]
    for i, component in enumerate(cb_at)
]

# Positions (into the parallel lists) of the components actually patched.
ind = range(len(cb_at))


def _take(seq):
    """Return the entries of ``seq`` at the positions listed in ``ind``."""
    return [seq[i] for i in ind]


mod_logits, mod_cache = run_with_codes(
    input_a,
    cb_model,
    _take(code),
    _take(cb_at),
    _take(layer),
    _take(head),
    pos=[-1],
)
print(logits_to_pred(mod_logits, tokenizer, k=5))

# Distances of the patched run to each of the two unpatched runs.
jsd_to_b = JSD(mod_logits, logits_b, pos=-1)
print(f"JSD(a <- b, b) = {jsd_to_b}")
jsd_to_a = JSD(mod_logits, logits_a, pos=-1)
print(f"JSD(a <- b, a) = {jsd_to_a}")

patched_last = mod_logits[0, -1, :]
print(f"L1(a <- b, b) = {norm_distance(patched_last, logits_b[0, -1, :], p=1)}")
print(f"L1(a <- b, a) = {norm_distance(patched_last, logits_a[0, -1, :], p=1)}")

[('4', 0.9099664688110352), ('5', 0.026785491034388542), ('1', 0.02551252394914627), ('7', 0.010723251849412918), ('9', 0.009922088123857975)]
JSD(a <- b, b) = 0.11190739274024963
JSD(a <- b, a) = 0.38374847173690796
L1(a <- b, b) = 0.7501115798950195
L1(a <- b, a) = 1.5745123624801636


In [None]:
# run_with_codes for every layer and plot the JSD and L1 distance

jsds = []
l1ds = []

# Components patched within a layer: its four attention heads and its MLP.
cb_at = ["attn"] * 4 + ["mlp"] * 1
head = [0, 1, 2, 3, None]
ind = range(len(cb_at))

for layer in range(cb_model.cfg.n_layers):
    # Codes of run b at index [0, -1, :] for every component of this layer.
    code = [
        cache_b[get_cb_layer_name(component, layer, head[i])][0, -1, :]
        for i, component in enumerate(cb_at)
    ]
    mod_logits, mod_cache = run_with_codes(
        input_a,
        cb_model,
        [code[i] for i in ind],
        [cb_at[i] for i in ind],
        [layer] * len(cb_at),
        [head[i] for i in ind],
        pos=[-1],
    )
    patched_last = mod_logits[0, -1, :]
    jsds.append(JSD(mod_logits, logits_b, pos=-1).item())
    l1ds.append(norm_distance(patched_last, logits_b[0, -1, :], p=1))

# NOTE(review): `go` (plotly.graph_objects) is not imported in the visible
# import cell; presumably it arrives through one of the star imports -- confirm.
layer_axis = list(range(cb_model.cfg.n_layers))

fig = go.Figure()
fig.add_trace(go.Scatter(x=layer_axis, y=jsds, mode="lines+markers"))
fig.update_layout(title="JSD(a <- b, b)", xaxis_title="Layer", yaxis_title="JSD")

# The L1 distance goes on a secondary y axis drawn on the right-hand side.
fig.add_trace(go.Scatter(x=layer_axis, y=l1ds, mode="lines+markers", yaxis="y2"))
fig.update_layout(yaxis2={"title": "L1 distance", "overlaying": "y", "side": "right"})
fig.show()

In [None]:
# Compare the patched run's cache against run b, then (after the separator
# line) against run a.
find_code_changes(mod_cache, cache_b)
print("****************")
find_code_changes(mod_cache, cache_a)