In [None]:
! pip uninstall -y transformer_lens
! pip install git+https://github.com/taufeeque9/TransformerLens/
! pip install git+https://github.com/minyoungg/vqtorch/
! pip install termcolor
! pip install -U accelerate
! pip install -U kaleido

In [1]:
def imshow(tensor, renderer=None, **kwargs):
    """Show ``tensor`` with ``px.imshow`` using a diverging colour scale.

    Defaults to the ``RdBu`` scale centred at 0.0. Both defaults can now be
    overridden through ``kwargs``; previously, passing either keyword raised a
    TypeError for a duplicated keyword argument. Extra ``kwargs`` go to
    ``px.imshow``, and ``renderer`` goes to ``Figure.show``.
    """
    # kwargs is a fresh dict per call, so setdefault does not leak into callers.
    kwargs.setdefault("color_continuous_midpoint", 0.0)
    kwargs.setdefault("color_continuous_scale", "RdBu")
    px.imshow(utils.to_numpy(tensor), **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Plot ``tensor`` as a plotly line chart (values on the y axis).

    Extra ``kwargs`` are forwarded to ``px.line``; ``renderer`` to ``show``.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Scatter-plot ``y`` against ``x`` with plotly.

    ``xaxis`` / ``yaxis`` / ``caxis`` become the axis and colour titles; other
    ``kwargs`` are forwarded to ``px.scatter``; ``renderer`` to ``show``.
    """
    x_np = utils.to_numpy(x)
    y_np = utils.to_numpy(y)
    axis_titles = dict(x=xaxis, y=yaxis, color=caxis)
    figure = px.scatter(y=y_np, x=x_np, labels=axis_titles, **kwargs)
    figure.show(renderer)

In [1]:
from termcolor import colored
import plotly.express as px
import transformers
import codebook_features
import torch
import evaluate
import numpy as np
import copy
import wandb
import json
import transformer_lens.utils as utils
from collections import namedtuple
from functools import partial
from torch.nn import functional as F

from transformers import GPTNeoXConfig, GPTNeoXForCausalLM, GPT2TokenizerFast, pipeline, set_seed
from torch.utils.data import IterableDataset
from codebook_features import models, run_clm, train_toy_model, trainer as cb_trainer
from codebook_features.utils import *
from codebook_features.toy_utils import *
import os

# Analysis notebook: turn autograd off globally so forward passes below do not
# track gradients.
torch.set_grad_enabled(mode=False)

def logits_to_pred(logits, k=5, tok=None):
    """Return the top-``k`` next-token predictions at the last position.

    Args:
        logits: tensor indexed as ``logits[batch, position, vocab]``; only the
            final position is used.
        k: predictions per batch element (capped at the vocabulary size).
        tok: tokenizer used to decode ids; defaults to the notebook-global
            ``tokenizer`` so existing calls are unchanged.

    Returns:
        A flat list of ``(token_string, probability)`` tuples, ``k`` per batch
        element in batch order. Probabilities come from a softmax over the full
        vocabulary. The previous version crashed for batch size > 1 because it
        called ``.item()`` on a batch-sized tensor.
    """
    if tok is None:
        tok = tokenizer
    probs = logits[:, -1, :].softmax(dim=-1)
    k = min(k, probs.shape[-1])
    sorted_probs, sorted_ids = torch.sort(probs, descending=True, dim=-1)
    preds = []
    for row_probs, row_ids in zip(sorted_probs[:, :k], sorted_ids[:, :k]):
        tokens = tok.convert_ids_to_tokens(row_ids.tolist())
        for token, prob in zip(tokens, row_probs):
            preds.append((tok.convert_tokens_to_string([token]), prob.item()))
    return preds


The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path="config", config_name="toy_main")


In [2]:
# Hyperparameters for the 100-state setup. vocab_size - 1 is used as the digit
# base for the automaton and as the BOS/EOS id, so 11 = 10 digits + end token.
# (Smaller model tried earlier: hidden_size=64, intermediate_size=256.)
hp = {
    "run_name": "cb_model_neox",
    "tags": [],
    "num_states": 100,
    "num_edges": 10,
    "seq_len": 128,
    "vocab_size": 11,
    "hidden_size": 128,
    "intermediate_size": 512,
    "num_hidden_layers": 4,
    "num_attention_heads": 4,
    "rotary_emb_base": 10000,
}

In [None]:
# Hyperparameters for the 4circle setup (4 states, 1 edge each); a tiny
# 1-layer / 1-head model. vocab_size - 1 = 2 is the digit base.
# (Alternative sizes tried earlier: hidden_size=64, intermediate_size=256.)
hp = {
    "run_name": "cb_model_neox",
    "tags": [],
    "num_states": 4,
    "num_edges": 1,
    "seq_len": 128,
    "vocab_size": 3,
    "hidden_size": 16,
    "intermediate_size": 64,
    "num_hidden_layers": 1,
    "num_attention_heads": 1,
    "rotary_emb_base": 10000,
}

## Checkpoint Selection and Dataset Generation

In [None]:
# Checkpoint: pretrained on 4circle.
base_path, checkpoint = "/data/outputs/2023-06-02/03-05-17/", "checkpoint-6500/"

In [7]:
# Checkpoint: pretrained on 100 states.
base_path, checkpoint = "/data/outputs/2023-06-02/03-38-51/", "checkpoint-9500/"

In [16]:
# Checkpoint: ft only attn, 100 states.
base_path, checkpoint = "/data/outputs/2023-06-07/21-33-13/", "checkpoint-18000/"

In [3]:
# Checkpoint: ft on multi k, 100 states.
base_path, checkpoint = "/data/outputs/2023-06-08/12-50-15/", "checkpoint-13000/"

In [None]:
# Checkpoint: ft-cb on 4circle.
base_path, checkpoint = "/data/outputs/2023-05-30/11-59-03/", "checkpoint-500/"

In [4]:
# Checkpoint: pretrained on 100 states (older run).
base_path, checkpoint = (
    "/data/codebook-features/codebook_features/outputs/2023-05-25/12-56-03/",
    "checkpoint-4500/",
)

In [4]:
# Checkpoint: ft on 100 states.
base_path, checkpoint = "/data/outputs/2023-06-02/06-12-08/", "checkpoint-15500/"

In [None]:
# Run 2023-05-20 13-27-44; empty checkpoint -> load from the output dir itself.
base_path, checkpoint = (
    "/data/codebook-features/codebook_features/outputs/2023-05-20/13-27-44/",
    "",
)

In [None]:
# Checkpoint: ft-cb on 100 states.
base_path, checkpoint = (
    "/data/codebook-features/codebook_features/outputs/2023-05-26/11-01-04/",
    "checkpoint-17500/",
)

In [4]:
# Load the tokenizer and automaton stored next to the selected run, then wrap
# the automaton in train/eval datasets. (A fresh graph could instead be built
# with train_toy_model.ToyGraph(N=hp["num_states"], edges=hp["num_edges"], seed=42).)
toy_dir = base_path + "toy"
tokenizer = train_toy_model.create_tokenizer(toy_dir, hp["vocab_size"])
automata = train_toy_model.ToyGraph.load(
    toy_dir + "/automata.npy",
    representation_base=hp["vocab_size"] - 1,
    seed=42,
)
dataset_kwargs = dict(tokenizer=tokenizer, seq_len=hp["seq_len"])
train_dataset = train_toy_model.ToyDataset(automata, **dataset_kwargs)
eval_dataset = train_toy_model.ToyDataset(automata, max_samples=2048, **dataset_kwargs)

Vocab:
{"0": 0, "1": 1, "2": 2, "3": 3, "4": 4, "5": 5, "6": 6, "7": 7, "8": 8, "9": 9, "<|endoftext|>": 10}


## Load Model

In [9]:
# Base (non-codebook) model loaded directly from the HF checkpoint directory.
from transformers import AutoModelForCausalLM

# Fall back to CPU instead of failing on `.to("cuda")` when no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(base_path + "output_toy/" + checkpoint)
model = model.to(device).eval()
config = None  # no explicit GPTNeoXConfig for the base model

In [None]:
# Sample one sequence, convert it to a state trajectory and report the
# automaton's transition accuracy for it.
generated_ids = model.generate(max_length=hp["seq_len"], do_sample=True)
gen_seq = tokenizer.decode(generated_ids[0])
traj = automata.seq_to_traj(gen_seq)
acc, _ = automata.transition_accuracy(traj)
for item in (gen_seq, traj, acc):
    print(item)

In [5]:
# Codebook model: build a GPT-NeoX skeleton from `hp`, then wrap it with
# codebooks via `models.wrap_codebook`, pointing it at the checkpoint dir.
device = "cpu"
neox_kwargs = dict(
    vocab_size=hp["vocab_size"],
    hidden_size=hp["hidden_size"],
    num_hidden_layers=hp["num_hidden_layers"],
    num_attention_heads=hp["num_attention_heads"],
    intermediate_size=hp["intermediate_size"],
    rotary_emb_base=hp["rotary_emb_base"],
    bos_token_id=hp["vocab_size"] - 1,
    eos_token_id=hp["vocab_size"] - 1,
    max_position_embeddings=hp["seq_len"],
)
config = GPTNeoXConfig(**neox_kwargs)
config.architectures = ["GPTNeoXForCausalLM"]
model = GPTNeoXForCausalLM(config=config).to(device).eval()
checkpoint_dir = base_path + f"output_toy/{checkpoint}"
orig_cb_model = models.wrap_codebook(model_or_path=model, pretrained_path=checkpoint_dir)
orig_cb_model = orig_cb_model.to(device).eval()

Time: 7.152557373046875e-06 5.364418029785156e-05 2.315725088119507


In [6]:
# TransformerLens-style hooked copy of the codebook model. All weight
# post-processing options (centering, LN / bias folding, refactoring) are off.
hooked_kwargs = {
    "center_unembed": False,
    "center_writing_weights": False,
    "fold_ln": False,
    "fold_value_biases": False,
    "refactor_factored_attn_matrices": False,
    "device": device,
}
cb_model = models.convert_to_hooked_model_for_toy(
    base_path + f"output_toy/{checkpoint}",
    orig_cb_model,
    config=config,
    hooked_kwargs=hooked_kwargs,
)
cb_model = cb_model.to(device).eval()

Time: 5.7220458984375e-06 2.7179718017578125e-05 1.4187428951263428


In [8]:
# Trainer configuration. Switch report_to to "all" to enable every installed
# logging integration. Options tried and disabled: no_cuda=True,
# weight_decay=1e-1, lr_scheduler_type="linear".
report_to = "none"
training_args = run_clm.TrainingArguments(
    output_dir="toy/output",
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    evaluation_strategy="steps",
    eval_steps=500,
    per_device_train_batch_size=512,
    per_device_eval_batch_size=512,
    learning_rate=1e-3,
    max_steps=20000,
    lr_scheduler_type="constant",
    warmup_ratio=0.1,
    logging_first_step=True,
    logging_steps=500,
    seed=42,
    train_model_params=True,
    model_lr_factor=1.0,
    report_to=report_to,
    dataloader_num_workers=8,
)

# Flat run config: hp entries followed by the argument dataclasses' fields.
cfg_dict = {
    **hp,
    "training_args": training_args.__dict__,
    "model_args": None if not config else config.__dict__,
}
model_args = run_clm.ModelArguments(model_name_or_path="toy/model")
data_args = run_clm.DataTrainingArguments(dataset_name="toy_graph", max_eval_samples=2048)

# Custom AdamW only for codebook models; (None, None) leaves optimizer and
# scheduler creation to the HF Trainer.
optimizers = (None, None)
if isinstance(model, models.CodebookModel):
    if not training_args.train_model_params:
        # Only codebook parameters are optimized.
        # NOTE(review): this path gets AdamW's default weight_decay, whereas the
        # grouped path below forces 0.0 for codebook params — confirm intended.
        params = model.get_codebook_params()
    else:
        codebook_group = {
            "params": model.get_codebook_params(),
            "lr": training_args.learning_rate,
            # Codebook regularization goes through `codebook_weight_decay`,
            # applied directly in the loss, so no optimizer weight decay here.
            "weight_decay": 0.0,
        }
        model_group = {
            "params": model.get_model_params(),
            "lr": training_args.model_lr_factor * training_args.learning_rate,
            "weight_decay": training_args.weight_decay,
        }
        params = [codebook_group, model_group]
    if len(params) > 0:
        optimizer = torch.optim.AdamW(params, training_args.learning_rate)
        optimizers = (optimizer, None)

# Extra Trainer callbacks; e.g. [cb_trainer.WandbCallback()] when report_to == "all".
callbacks = []

def preprocess_logits_for_metrics(logits, labels):
    """Reduce model logits to predicted token ids before the Trainer stores them.

    Some models return a tuple (e.g. logits plus past_key_values); the logits
    are always its first element. ``labels`` is unused but required by the
    Trainer hook signature.
    """
    scores = logits[0] if isinstance(logits, tuple) else logits
    return scores.argmax(dim=-1)

# Accuracy metric from the HF `evaluate` library, used by compute_metrics.
metric = evaluate.load(path="accuracy")
def compute_metrics(eval_preds):
    """Next-token accuracy for the HF Trainer.

    ``preds`` are already argmax token ids (``preprocess_logits_for_metrics``)
    with the same shape as ``labels``; the prediction at position ``t`` is
    scored against the label at ``t + 1``. Label positions equal to -100 (the
    HF ignore index, also used as padding when eval batches are concatenated)
    are excluded rather than counted as wrong predictions.
    """
    preds, labels = eval_preds
    labels = np.asarray(labels)[:, 1:].reshape(-1)
    preds = np.asarray(preds)[:, :-1].reshape(-1)
    keep = labels != -100
    return metric.compute(predictions=preds[keep], references=labels[keep])

# Toy-graph Trainer: the usual HF Trainer arguments plus the automaton and a
# generation length (the eval output reports transition-accuracy metrics).
trainer = train_toy_model.ToyModelTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    optimizers=optimizers,
    callbacks=callbacks,
    toy_graph=automata,
    gen_seq_len=hp["seq_len"],
)


In [None]:
# Sample a short sequence through the HF text-generation pipeline.
# (`pipeline` is already imported in the top import cell.)
# Use the notebook's `device` instead of hard-coding GPU 0, which would move a
# CPU-resident model (device = "cpu" in the codebook-model cell) onto the GPU.
generator = pipeline('text-generation', model=model, tokenizer=tokenizer, device=device)
gen_seq = generator("", max_length=50, do_sample=True, temperature=0.7)[0]['generated_text']
gen_seq

## Eval

In [None]:
# Evaluate the HF-wrapped codebook model (assign `cb_model` for the hooked one).
model = orig_cb_model

In [None]:
# Probe one prompt on the hooked model: utils.test_prompt prints the top-k
# next-token predictions and where `example_answer` ranks.
# Other prompts tried (notes as originally recorded):
#   "33"  # 3 4 6 1 2        "63"  # 9 8 5 7 4 6
#   "47"  # 0 1 2 3 4 5 8    "71"  # 1 3 4 5 8 0 2
#   "72"  # 8 9 4 5 7 3
example_prompt = "34"  # 4 5 7 1 6 8
example_answer = "49"
utils.test_prompt(
    example_prompt,
    example_answer,
    cb_model,
    prepend_bos=True,
    prepend_space_to_answer=False,
    top_k=10,
)

In [None]:
# Print `automata.nbrs(state)` for every state, formatted "<nbrs> - <state>".
# Iterate over `automata.N` states instead of a hard-coded 100, so this also
# works for the 4-state setup.
for state in range(automata.N):
    print(automata.nbrs(state), "-", state)

In [60]:
# Clean ("base") run: cache all activations for a single-state input.
cb_model.reset_hook_kwargs()
base_state = 636
base_text = automata.traj_to_str([base_state])
base_input = cb_model.to_tokens(base_text, prepend_bos=True).to(device)
base_logits, base_cache = cb_model.run_with_cache(base_input)


In [61]:
# Corrupted run: same procedure for a different state; this is the input that
# gets patched later.
corrupted_state = 336
corrupted_text = automata.traj_to_str([corrupted_state])
corrupted_input = cb_model.to_tokens(corrupted_text, prepend_bos=True).to(device)
corrupted_logits, corrupted_cache = cb_model.run_with_cache(corrupted_input)


In [62]:
# Metric settings: compare softmax probabilities (softmax=True) rather than raw
# logits, for answer tokens 3 vs 9 (pair tried earlier: 1 vs 0).
softmax = True
answer_tokens_both = torch.tensor([[3, 9]], device=cb_model.cfg.device)

In [63]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    """Score difference between two answer tokens at the final position.

    ``answer_tokens`` holds one ``[first, second]`` id pair per prompt. Returns
    ``score[first] - score[second]`` per prompt when ``per_prompt`` is True,
    otherwise its mean.
    """
    last_position = logits[:, -1, :]
    picked = torch.gather(last_position, -1, answer_tokens)
    diffs = picked[:, 0] - picked[:, 1]
    return diffs if per_prompt else diffs.mean()

def normalize_patched_logit_diff(patched_logit_diff, base_diff=None, corrupted_diff=None):
    """Rescale a patched metric onto the corrupted -> clean range.

    0 means no change from the corrupted run, negative means patching made it
    worse, 1 means clean performance is fully recovered, >1 means it beats clean.

    ``base_diff`` / ``corrupted_diff`` default to the notebook globals
    ``base_average_logit_diff`` / ``corrupted_average_logit_diff``, so existing
    calls behave as before; pass them explicitly to reuse with other runs.
    """
    if base_diff is None:
        base_diff = base_average_logit_diff
    if corrupted_diff is None:
        corrupted_diff = corrupted_average_logit_diff
    # Improvement over corrupted, divided by the total corrupted -> clean gap.
    return (patched_logit_diff - corrupted_diff) / (base_diff - corrupted_diff)



# Endpoints for normalization: the metric on the clean (base) and corrupted runs.
base_metric_input = base_logits.softmax(dim=-1) if softmax else base_logits
corrupted_metric_input = corrupted_logits.softmax(dim=-1) if softmax else corrupted_logits
base_average_logit_diff = logits_to_ave_logit_diff(base_metric_input, answer_tokens_both)
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_metric_input, answer_tokens_both)

print(base_average_logit_diff, corrupted_average_logit_diff)

tensor(0.9996, device='cuda:0') tensor(-0.5823, device='cuda:0')


In [70]:
# Codebook-id patching: for each attention-head codebook and each MLP codebook,
# rerun the corrupted input with that component's codes at `pos` taken from
# the clean cache, and record the normalized metric.
pos = [-1]  # e.g. list(range(17)) to patch every position
# The hook is identical for every component, so build it once
# (`partial` comes from the top import cell).
hook_fn = partial(patch_codebook_ids, pos=pos, cache=base_cache)
# Print top predictions for components that push the metric below this value.
# NOTE(review): with softmax=True the metric lies in [-1, 1], so -6 never fires.
print_below = -6

scores, logit_diffs, x_label = [], [], []
for i in range(cb_model.cfg.n_layers):
    for head in list(range(cb_model.cfg.n_heads)) + [None]:
        # Build the plot label in the same loop so it cannot drift from `scores`.
        if head is not None:
            cb_str = f'blocks.{i}.attn.codebook_layer.codebook.{head}.hook_codebook_ids'
            x_label.append(f"L{i}H{head}")
        else:
            cb_str = f'blocks.{i}.mlp.codebook_layer.hook_codebook_ids'
            x_label.append(f"L{i}MLP")
        cb_model.reset_codebook_metrics()
        patched_logits = cb_model.run_with_hooks(
            corrupted_input,
            fwd_hooks=[(cb_str, hook_fn)],
            return_type="logits",
        )
        metric_input = patched_logits.softmax(dim=-1) if softmax else patched_logits
        patched_logit_diff = logits_to_ave_logit_diff(metric_input, answer_tokens_both)
        logit_diffs.append(patched_logit_diff)
        scores.append(normalize_patched_logit_diff(patched_logit_diff).item())
        if patched_logit_diff < print_below:
            print("pred:", i, head)
            print(logits_to_pred(patched_logits, k=5))

print(scores)
print(logit_diffs)
line(scores, x=x_label, title=f"Logit Difference for Nth layers' codes for pos {pos}", labels={"y":"Normalized Logit Difference", "x": "Layer"})

[0.390800803899765, 0.0, 0.0, 0.7187427282333374, 0.0, 0.1705697625875473, 0.1652897596359253, 0.017624331638216972, 0.22878846526145935, 0.9348824620246887, 0.07006233185529709, 0.18100808560848236, 0.20456857979297638, 0.24130479991436005, 0.2563377916812897, 0.016519445925951004, -0.007724066264927387, -0.07203540951013565, -0.009037615731358528, 0.2594574987888336]
[tensor(0.0359, device='cuda:0'), tensor(-0.5823, device='cuda:0'), tensor(-0.5823, device='cuda:0'), tensor(0.5547, device='cuda:0'), tensor(-0.5823, device='cuda:0'), tensor(-0.3125, device='cuda:0'), tensor(-0.3208, device='cuda:0'), tensor(-0.5544, device='cuda:0'), tensor(-0.2204, device='cuda:0'), tensor(0.8966, device='cuda:0'), tensor(-0.4715, device='cuda:0'), tensor(-0.2960, device='cuda:0'), tensor(-0.2587, device='cuda:0'), tensor(-0.2006, device='cuda:0'), tensor(-0.1768, device='cuda:0'), tensor(-0.5562, device='cuda:0'), tensor(-0.5945, device='cuda:0'), tensor(-0.6962, device='cuda:0'), tensor(-0.5966, de

## Token Code Maps

In [9]:
# Dataset whose sampled token ids are kept (save_tokens=True), so they can be
# read back as `train_dataset_tkns.tokens` next to the codebook activations.
max_samples = 10 * 1024
model = orig_cb_model
model.enable_logging()
model.reset_codebook_metrics()
# Use the configured sequence length instead of a hard-coded 128 (equal for both
# hp configs) so this stays in sync with the model's max_position_embeddings.
train_dataset_tkns = train_toy_model.ToyDataset(
    automata,
    tokenizer=tokenizer,
    seq_len=hp["seq_len"],
    max_samples=max_samples,
    save_tokens=True,
)

In [11]:
# One evaluation pass over `train_dataset_tkns` with a hook that records every
# codebook's ids per batch, then stack them per codebook.
# Presumably 0 workers so that tokens recorded by the dataset (save_tokens=True)
# land in this process rather than in worker copies — confirm.
trainer.args.dataloader_num_workers = 0
# Was `trainer.args.report_on`, which is not a TrainingArguments field (a silent
# no-op); [] is what TrainingArguments normalizes "none" to.
trainer.args.report_to = []
codebook_acts = {}

def store_cb_activations(key, codebook_ids, codebook_acts=codebook_acts):
    """Hook: append one batch of codebook ids under codebook name ``key``."""
    assert len(codebook_ids.shape) == 3  # (bs, seq_len, k_codebook)
    codebook_acts.setdefault(key, []).append(codebook_ids)

model.set_hook_fn(store_cb_activations)
trainer.model = model

metrics = trainer.evaluate(eval_dataset=train_dataset_tkns)
print(metrics)

print(codebook_acts.keys())
# NOTE: alias, not a copy — the loop below overwrites codebook_acts' values too.
cb_acts = codebook_acts
num_codes = 10000
from tqdm import tqdm  # also used by later cells
# Keep only the first max_samples // batch_size batches per codebook; presumably
# later forward passes in evaluate() (e.g. generation for the transition
# metrics) also hit the hook.
for k, v in codebook_acts.items():
    v_len_to_get = max_samples // v[0].shape[0]
    cb_acts[k] = np.concatenate(v[:v_len_to_get], axis=0)

tokens = np.vstack(train_dataset_tkns.tokens)
print(tokens.shape)
print(cb_acts['layer1_attn_preproj_ccb0'].shape)
print(cb_acts['layer2_mlp'].shape)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:10 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


{'eval_loss': 1.204332709312439, 'eval_accuracy': 0.46034617987204723, 'eval_runtime': 32.5704, 'eval_samples_per_second': 314.396, 'eval_steps_per_second': 0.614, 'eval_transition_accuracy': 0.7519354838709678, 'eval_first_transition_accuracy': 0.98, 'eval_multicode_k': 1, 'eval_dead_code_fraction/layer0': 0.9267, 'eval_MSE/layer0': 29718.738983438758, 'eval_input_norm/layer0': 127.66250211015993, 'eval_output_norm/layer0': 10.971974760256575, 'eval_dead_code_fraction/layer1': 0.382, 'eval_MSE/layer1': 68.49746054299729, 'eval_input_norm/layer1': 6.285440554106894, 'eval_output_norm/layer1': 11.180642855833774, 'eval_dead_code_fraction/layer2': 0.1394, 'eval_MSE/layer2': 151.147979416076, 'eval_input_norm/layer2': 6.778055443413886, 'eval_output_norm/layer2': 14.170335556351365, 'eval_dead_code_fraction/layer3': 0.4526, 'eval_MSE/layer3': 184.61280731194486, 'eval_input_norm/layer3': 8.196509695231274, 'eval_output_norm/layer3': 15.474042212992469}
dict_keys(['layer0_attn_preproj_ccb0

In [10]:
import itertools  # explicit: not imported in the top import cell
from tqdm import tqdm

# Run the hooked model on every valid `repeat`-digit input, cache its logits and
# activations, and compute pairwise JS divergences between the inputs' outputs.
rev_automata = automata.reverse()

is_trigram = False
prefix_random_states_len = 10
plot_code_grp_distr = False

repeat = 3 if is_trigram else 2

js_divs = {}
all_state_info = {}

chars = [str(i) for i in range(automata.representation_base)]
all_valid_inputs = [
    candidate
    for candidate in map(''.join, itertools.product(chars, repeat=repeat))
    if valid_input(candidate, automata)
]
all_valid_inputs_tokens = tokenizer(all_valid_inputs, return_tensors="pt")["input_ids"].to(
    cb_model.cfg.device
)
if prefix_random_states_len > 0:
    # Prefix each input with a random walk generated on the reversed graph from
    # the input's first state, then flipped back into forward order.
    start_states = [s[0] for s in automata.seq_to_traj(all_valid_inputs)]
    random_state_prefix = rev_automata.generate_trajectories(prefix_random_states_len, start_states=start_states)
    random_state_prefix = random_state_prefix[:, ::-1]
    random_state_prefix = random_state_prefix.astype(int)
    random_state_prefix = [automata.traj_to_str(traj) for traj in random_state_prefix]
    random_state_prefix = tokenizer(random_state_prefix, return_tensors="pt")["input_ids"].to(
        cb_model.cfg.device
    )
    all_valid_inputs_tokens = torch.cat([random_state_prefix, all_valid_inputs_tokens], dim=1)
all_valid_inputs_tokens = F.pad(all_valid_inputs_tokens, (1, 0), value=tokenizer.bos_token_id)

# NOTE(review): the reference logits/caches below come from the bare inputs
# (no random prefix), while `all_states_tokens` (patched in the next cell)
# includes the prefix — confirm this mismatch is intended.
for input_str in tqdm(all_valid_inputs):  # renamed from `input` (shadowed the builtin)
    input_tensor = cb_model.to_tokens(input_str, prepend_bos=True).to(device)
    logits, cache = cb_model.run_with_cache(input_tensor)
    all_state_info[input_str] = (logits, cache)

for iter_a, input_a in enumerate(all_valid_inputs):
    for input_b in all_valid_inputs[iter_a + 1:]:
        js_divs[(input_a, input_b)] = JSD(all_state_info[input_a][0], all_state_info[input_b][0]).item()

# With fewer than two inputs there are no pairs to average.
avg_js_div = sum(js_divs.values()) / len(js_divs) if js_divs else float("nan")

if plot_code_grp_distr:
    # input_len follows `repeat` (was hard-coded 3, wrong for bigrams).
    code_groups_for_all_comps = {}
    for layer in tqdm(range(cb_model.cfg.n_layers)):
        for ccb_num in range(cb_model.cfg.n_heads):
            code_groups_for_all_comps[(layer, "attn", ccb_num)] = partition_input_on_codebook(
                cb_model=cb_model, automata=automata, layer=layer, cb_at="attn", ccb_num=ccb_num, input_len=repeat
            )
        code_groups_for_all_comps[(layer, "mlp", None)] = partition_input_on_codebook(
            cb_model=cb_model, automata=automata, layer=layer, cb_at="mlp", ccb_num=None, input_len=repeat
        )

all_states_tokens = all_valid_inputs_tokens
all_states_logits = torch.cat([v[0] for v in all_state_info.values()], dim=0)
all_states_cache = {k: v[1] for k, v in all_state_info.items()}

In [11]:
import plotly.graph_objects as go  # explicit: `go` is not imported in the top import cell

# For each patching scheme, copy input b's last-position codes from the listed
# components into every input's run, and measure JS divergence to b's own
# output, averaged over b and over the other inputs.
patchings_to_plot_orig = ["none", "l0_attn", "l1_mlp", "all_attn", "all_mlp", "all_attn_mlp"]
# "all" expands to every layer of the model (was hard-coded "l0,l1,l2,l3").
all_layers = ",".join(f"l{i}" for i in range(cb_model.cfg.n_layers))
patchings_to_plot = [s.replace("all", all_layers) for s in patchings_to_plot_orig]


def _expand_patching(patching):
    """Expand a patching string into parallel (cb_at, layer, head) lists.

    "attn" expands to one entry per attention head and "mlp" gets head None;
    each component is repeated for every layer named in ``patching``.
    """
    cb_at = patching.split("_")[1:]
    layers = get_layers_from_patching_str(patching)
    heads = [None] * len(cb_at)
    if "attn" in cb_at:
        attn_idx = cb_at.index("attn")
        del cb_at[attn_idx], heads[attn_idx]
        cb_at += ["attn"] * cb_model.cfg.n_heads
        heads += list(range(cb_model.cfg.n_heads))
    cb_at_rep = cb_at * len(layers)
    heads_rep = heads * len(layers)
    layers_rep = [layer for layer in layers for _ in cb_at]
    return cb_at_rep, layers_rep, heads_rep


# Parsing does not depend on b, so do it once per patching scheme.
patch_specs = {p: _expand_patching(p) for p in patchings_to_plot if p != "none"}

input_list = all_valid_inputs
n_inputs = len(input_list)
# Average over the other inputs (JSD(b, b) = 0). The old code divided by
# automata.N - 1 and automata.N, but the loop runs over inputs, not states; the
# final max-normalization makes this a scale-only change.
n_others = max(n_inputs - 1, 1)
js_div_sums = {k: 0.0 for k in ["none", *patchings_to_plot]}

for state_b, input_b in tqdm(enumerate(input_list), total=n_inputs):
    target_b = all_states_logits[state_b].unsqueeze(0)
    js_div_sums["none"] += JSD(all_states_logits, target_b, reduction="none").sum() / n_others
    cache_b = all_state_info[input_b][1]
    for patching, (cb_at_rep, layers_rep, heads_rep) in patch_specs.items():
        code = [
            cache_b[get_cb_layer_name(comp, layer, head)][0, -1, :]
            for comp, layer, head in zip(cb_at_rep, layers_rep, heads_rep)
        ]
        mod_logits, mod_cache = run_with_codes(
            all_states_tokens,
            cb_model,
            code,
            cb_at_rep,
            layers_rep,
            heads_rep,
            pos=[-1],
        )
        js_div_sums[patching] += JSD(mod_logits, target_b, reduction="none").sum() / n_others

mean_js_divs = [float(js_div_sums[k]) / max(n_inputs, 1) for k in patchings_to_plot]
peak = max(mean_js_divs)
# Guard against an all-zero result instead of dividing by zero.
js_divs = [v / peak if peak else 0.0 for v in mean_js_divs]

fig = go.Figure()
x_labels = [clean_patching_name(patching) for patching in patchings_to_plot_orig]
fig.add_trace(go.Bar(x=x_labels, y=js_divs))
fig.update_layout(title=f'k=4 Model: JS Div on Code Patching for {"Trigrams" if is_trigram else "Bigrams"}', xaxis_title='Code Patching Components', yaxis_title='Normalized JS Div')
fig.show()

686it [1:34:37,  8.28s/it]


In [17]:
import itertools  # explicit: not imported in the top import cell

# Bigram setting: pairwise JS divergence between the model's outputs for every
# valid 2-digit input, then `partition_input_on_codebook` for each component.
repeat = 2

chars = [str(i) for i in range(automata.representation_base)]
all_valid_inputs = [
    candidate
    for candidate in map(''.join, itertools.product(chars, repeat=repeat))
    if valid_input(candidate, automata)
]

js_divs = {}
state_logits = {}
for state_a in all_valid_inputs:
    a_input = cb_model.to_tokens(state_a, prepend_bos=True).to(device)
    a_logits = cb_model(a_input)
    state_logits[state_a] = a_logits

for iter_a, state_a in enumerate(all_valid_inputs):
    for state_b in all_valid_inputs[iter_a + 1:]:
        js_divs[(state_a, state_b)] = JSD(state_logits[state_a], state_logits[state_b]).item()

# With fewer than two inputs there are no pairs to average.
avg_js_div = sum(js_divs.values()) / len(js_divs) if js_divs else float("nan")

code_groups_for_all_comps = {}
for layer in tqdm(range(cb_model.cfg.n_layers)):
    for ccb_num in range(cb_model.cfg.n_heads):
        code_groups_for_all_comps[(layer, "attn", ccb_num)] = partition_input_on_codebook(
            cb_model=cb_model, automata=automata, layer=layer, cb_at="attn", ccb_num=ccb_num, input_len=repeat
        )
    code_groups_for_all_comps[(layer, "mlp", None)] = partition_input_on_codebook(
        cb_model=cb_model, automata=automata, layer=layer, cb_at="mlp", ccb_num=None, input_len=repeat
    )


In [22]:
# JS-divergence plot per codebook component (show_plot=False, image name prefix
# "k4_model"): every attention head of each layer, then that layer's MLP.
components = []
for layer in range(cb_model.cfg.n_layers):
    components += [(layer, "attn", head) for head in range(cb_model.cfg.n_heads)]
    components.append((layer, "mlp", None))
for layer, cb_at, ccb_num in components:
    plot_js_div(code_groups_for_all_comps, layer, cb_at, ccb_num, js_divs, show_plot=False, image_name_prefix="k4_model")

In [84]:
# Top-5 predictions for "00" before and after disabling a few codes in
# cb_model.all_codebooks[3][1] (disable_for_tkns=[-1]).
# Other interventions tried:
#   all_codebooks[1][1]: disable_codes=[6486]
#   all_codebooks[2][1]: disable_codes=[8406]
#   cb_model.set_hook_kwargs(idx=[1], disable_topk=1, ...)
#   all_codebooks[3][1]: disable_codes=[111]
# NOTE: the disabling is not reset at the end of this cell, and base_logits /
# base_cache end up holding the run with codes disabled.
cb_model.reset_hook_kwargs()
base_str = "00"
base_input = cb_model.to_tokens(base_str, prepend_bos=True)
print(base_input)
base_input = base_input.to(device)

for disable_codes in (None, [1185, 16, 111, 902]):
    if disable_codes is not None:
        cb_model.all_codebooks[3][1].set_hook_kwargs(disable_codes=disable_codes, keep_k_codes=False, disable_for_tkns=[-1])
    base_logits, base_cache = cb_model.run_with_cache(base_input)
    print(logits_to_pred(base_logits))


tensor([[10,  0,  0]], device='cuda:0')
[('6', 0.2214595377445221), ('9', 0.18503721058368683), ('0', 0.17624445259571075), ('4', 0.14047598838806152), ('8', 0.11677603423595428)]
[('6', 0.1748712658882141), ('4', 0.1526215821504593), ('0', 0.14920181035995483), ('9', 0.1374754011631012), ('8', 0.12292425334453583)]


In [76]:
# 1,1 1,2
# Token statistics for the codes that fired at the last position of
# `base_cache` in one codebook component (choose layer / cb_at / ccb_num below).
layer = 3
# ccb_num = 3   # only read when cb_at == "attn"
# cb_at = "attn"
cb_at = "mlp"
if cb_at == "attn":
    ccb = f"_preproj_ccb{ccb_num}"
    indices = base_cache[f'blocks.{layer}.attn.codebook_layer.codebook.{ccb_num}.hook_codebook_ids'][0,-1].tolist()
else:
    ccb = ""
    indices = base_cache[f'blocks.{layer}.mlp.codebook_layer.hook_codebook_ids'][0,-1].tolist()

cb_str = f"layer{layer}_{cb_at}{ccb}"

# `ft_tkns` caches features_to_tokens_fast results per codebook; create it on a
# fresh kernel instead of failing with NameError.
if "ft_tkns" not in globals():
    ft_tkns = {}
# The old check `(not cb_str in ft_tkns) or True` always recomputed; keep that
# behaviour but make it an explicit switch.
recompute_ft_tkns = True
if recompute_ft_tkns or cb_str not in ft_tkns:
    ft_tkns[cb_str] = features_to_tokens_fast(cb_str)


  0%|          | 0/10240 [00:00<?, ?it/s]

100%|██████████| 10240/10240 [00:01<00:00, 7586.29it/s]


In [None]:
# Examples for code 111 only. To inspect every code that fired in the cell
# above, use indices=indices with max_examples=50.
codes_to_show = [111]
print_ft_tkns(ft_tkns[cb_str], n=5, indices=codes_to_show, max_examples=500)

In [None]:
# start_state_activations for 33 on layer 0, attention codebook 2.
start_state_activations(
    33,
    cb_at="attn",
    layer=0,
    ccb_num=2,
)

In [None]:
# Same query for 63 across every attention codebook (MLP variant left disabled).
for layer in range(cb_model.cfg.n_layers):
    for head in range(cb_model.cfg.n_heads):
        activation_info = start_state_activations(63, cb_at="attn", layer=layer, ccb_num=head)
        print(layer, head, activation_info)
    # print(layer, start_state_activations(63, cb_at="mlp", layer=layer, ccb_num=None))


In [90]:
# start_state_activations for 63 on the layer-1 MLP codebook.
mlp_layer1_info = start_state_activations(63, cb_at="mlp", layer=1, ccb_num=None)
print(mlp_layer1_info)

code [4776]
[33]


In [None]:
# Generate a continuation of "33" with the listed codes forced at the given
# components/positions; `disable_other_comps` is forwarded to generate_with_codes.
# (`partial` was re-imported here but is unused; it is in the top import cell.)
disable_other_comps = True
code = [8268]
cb_at = ["mlp"]
layer_idx = [3]
head_idx = [1]
pos = [-1]
# Parallel lists must agree in length; zip alone would silently truncate.
assert len(code) == len(cb_at) == len(layer_idx) == len(head_idx) == len(pos), "intervention lists must have equal length"
list_of_arg_tuples = [CodeInfoTuple(*args) for args in zip(code, cb_at, layer_idx, head_idx, pos)]
print(generate_with_codes("<|endoftext|>33", list_of_arg_tuples=list_of_arg_tuples, disable_other_comps=disable_other_comps))
# Other prompts tried: "<|endoftext|>011", "<|endoftext|>10", "<|endoftext|>110".


In [None]:
# Sample up to 60 new tokens from a BOS-only prompt, stopping at EOS.
cb_model.generate(
    max_new_tokens=60,
    do_sample=True,
    stop_at_eos=True,
    prepend_bos=True,
)