In [None]:
!pip install --no-cache-dir -U "transformers>=4.51.0" accelerate datasets torch pandas tqdm nnsight huggingface_hub peft
!pip install --no-cache-dir typing-extensions --upgrade
!pip uninstall -y torchvision

In [2]:
import os
import json
import re
import random
import torch
import torch.nn as nn
from torch.utils.data import IterableDataset, DataLoader, get_worker_info
# from nnsight import LanguageModel
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM  # other option: Qwen/Qwen2.5-0.5B-Instruct

In [3]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
def load_model_and_tokenizer(model_name="meta-llama/Llama-3.2-1B-Instruct", attn_implementation="sdpa", mode="eval", **kwargs):
    """Load model and tokenizer with standard setup.

    Returns:
        tuple: (model, tokenizer, config_dict) where config_dict has num_layers, num_heads, head_dim
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        dtype="auto",
        device_map="auto",
        attn_implementation=attn_implementation,
        **kwargs
    )

    if mode == "eval":
        model.eval()
    else:
        model.train()

    num_heads = model.config.num_attention_heads
    head_dim = model.config.hidden_size // num_heads
    num_layers = model.config.num_hidden_layers

    config = {
        "num_layers": num_layers,
        "num_heads": num_heads,
        "head_dim": head_dim,
    }

    return model, tokenizer, config

In [5]:
def load_data(dataset_name="HuggingFaceFW/fineweb", split="train", streaming=True, first_k=int(1e5), buffer_frac=0.1, val_frac=0.05):
    ds_stream = load_dataset(
        dataset_name,
        split=split,
        streaming=streaming
    )
    ds_stream = ds_stream.shuffle(buffer_size=int(first_k * buffer_frac), seed=0)
    ds_train = ds_stream.take(first_k)
    ds_val = ds_stream.skip(first_k).take(int(first_k * val_frac))
    return ds_train, ds_val

In [6]:
def get_cropped_text_ids(dataset, tokenizer, prefix_ids, cropped_len=48):
    for item in dataset:
        text = item["text"]
        text_ids = tokenizer(
            text,
            return_tensors=None,
            add_special_tokens=False
        )["input_ids"]

        if len(text_ids) >= cropped_len:
            start = rng.randint(0, len(text_ids) - cropped_len)
            selected_ids = text_ids[start:start + cropped_len]
            yield prefix_ids + selected_ids

In [7]:
subject, tokenizer, config = load_model_and_tokenizer()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/54.5k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.47G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/189 [00:00<?, ?B/s]

In [8]:
for p in subject.parameters():
    p.requires_grad = False

In [12]:
ds_train, ds_val = load_data(first_k=100000)

Resolving data files:   0%|          | 0/27468 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/27468 [00:00<?, ?it/s]

In [9]:
nnsight_model = LanguageModel(subject, tokenizer)

In [13]:
INSTRUCT_PREFIX = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n
Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n
<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"""

prefix_ids = tokenizer(
    INSTRUCT_PREFIX,
    return_tensors=None,
    add_special_tokens=False
)["input_ids"]
len_prefix = len(prefix_ids)

cropped_len = 48

rng = random.Random(42)

In [None]:
text = INSTRUCT_PREFIX + next(iter(ds_val))["text"][:200]
enc = tokenizer(
    text,
    return_tensors=None,
    add_special_tokens=False
)
input_ids = enc["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
for a, b in zip(tokens, input_ids):
    print(a, "    ", b)

In [None]:
text = INSTRUCT_PREFIX + next(iter(ds_val))["text"][:20]
enc = tokenizer(
    text,
    return_tensors=None,
    add_special_tokens=False
)
input_ids = enc["input_ids"]
tokens = tokenizer.convert_ids_to_tokens(input_ids)
for a, b in zip(tokens, input_ids):
    print(a, "    ", b)

In [15]:
cropped_data_ids = get_cropped_text_ids(ds_val, tokenizer, prefix_ids)

In [16]:
li = next(iter(cropped_data_ids))
print(li)
tokens = tokenizer.convert_ids_to_tokens(li)
len(tokens)

[128000, 128006, 9125, 128007, 1432, 38766, 1303, 33025, 2696, 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25, 220, 1627, 10263, 220, 2366, 19, 1432, 128009, 128006, 882, 128007, 271, 527, 1274, 889, 1505, 5694, 19596, 14918, 4245, 311, 264, 26682, 12205, 315, 6677, 922, 2574, 12765, 304, 279, 3596, 8915, 1917, 13, 578, 22963, 1917, 11031, 779, 5043, 1606, 315, 279, 502, 2574, 12765, 1855, 323, 1475, 2046, 4028, 279, 1917, 627, 2181, 1587, 539, 1935, 25294]


78

In [None]:
tokens

In [18]:
tokenizer.decode(li[30:], skip_special_tokens=False)

' are people who find themselves floating mainly due to a shallow wealth of knowledge about things happening in the ever dynamic world. The soccer world moves so fast because of the new things happening each and every week across the world.\nIt does not take rocket'

In [34]:
enc = tokenizer(
    " X" * 16,
    return_tensors=None,
    add_special_tokens=True
)["input_ids"]
enc

[128000,
 1630,
 1630,
 1630,
 1630,
 1630,
 1630,
 1630,
 1630,
 1630,
 1630,
 1630,
 1630,
 1630,
 1630,
 1630,
 1630]

In [21]:
prompt = li
batch_ids = torch.tensor([prefix_ids, prefix_ids], device=subject.device)
with nnsight_model.trace(batch_ids) as tracer:
    resid = nnsight_model.model.layers[10].output[:].save()
print(resid.shape)

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


torch.Size([2, 30, 2048])


In [None]:
resid[0].float().cpu().numpy()[:10]

In [14]:
class CroppedTokenDataset(IterableDataset):
    def __init__(self, hf_dataset, tokenizer, prefix_ids, cropped_len=48, mode="train"):
        self.ds = hf_dataset
        self.tokenizer = tokenizer
        self.prefix = torch.tensor(prefix_ids, dtype=torch.long)
        self.cropped_len = cropped_len
        self.mode = mode

    def __iter__(self):
        info = get_worker_info()
        wid = 0 if info is None else info.id
        nw  = 1 if info is None else info.num_workers

        ds = self.ds if nw == 1 else self.ds.shard(num_shards=nw, index=wid)

        g = torch.Generator()
        g.manual_seed(0 if info is None else info.seed)

        for item in ds:
            text_ids = self.tokenizer(item["text"], return_tensors=None, add_special_tokens=False)["input_ids"]
            if len(text_ids) >= self.cropped_len:
                if self.mode == "train":
                    start = int(torch.randint(0, len(text_ids) - self.cropped_len + 1, (1,), generator=g).item())
                else:
                    start = 0

                cropped = torch.tensor(text_ids[start:start + self.cropped_len], dtype=torch.long)
                yield torch.cat([self.prefix, cropped], dim=0)

In [15]:
class Encoder(nn.Module):
    def __init__(self, d_in=2048, multiplier=8, top_k=16):
        super().__init__()
        self.top_k = top_k
        self.w_enc = nn.Linear(d_in, d_in * multiplier, bias=True)
        self.w_emb = nn.Linear(d_in * multiplier, d_in, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            W = torch.randn_like(self.w_enc.weight)
            W /= W.norm(dim=1, keepdim=True)
            self.w_enc.weight.copy_(W)
            self.w_enc.bias.zero_()
            self.w_emb.weight.copy_(self.w_enc.weight.T)

    def forward(self, x):  # (B, 16, d_in)
        y = self.w_enc(x)  # (B, 16, d_in*mult)

        idx = torch.topk(y, self.top_k, dim=-1).indices
        mask = torch.zeros_like(y, dtype=torch.bool)
        mask.scatter_(-1, idx, True)
        masked_y = y * mask.to(y.dtype)

        out = self.w_emb(masked_y)  # (B, 16, d_in)
        return out, idx

In [16]:
# def get_resid_stream_vector(layer, input_ids, prefix_ids, cropped_len):
#     with nnsight_model.trace(input_ids):
#         resid = nnsight_model.model.layers[layer].output[:]
#         start = len(prefix_ids) + cropped_len // 3
#         end = start + cropped_len // 3
#         out = resid[:, start:end, :].save()
#         return out

def get_resid_stream_vector(model, input_ids, layer, start, end, attention_mask=None):
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        return_dict=True,
        use_cache=False
    )
    resid = out.hidden_states[layer + 1]
    return resid[:, start:end, :]

In [17]:
def get_resid_stream_vector_efficient(model, input_ids, layer, start, end, attention_mask=None):
    saved = {}
    def hook(module, inp, out):
        saved["slice"] = out[:, start:end, :].detach()

    h = model.model.layers[layer].register_forward_hook(hook)
    try:
        with torch.inference_mode():
            model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=False,
                output_hidden_states=False,
                return_dict=False
            )
        return saved["slice"]
    finally:
        h.remove()

In [53]:
test = get_resid_stream_vector(subject, torch.tensor([[2040,3520]], device="cuda"),3,0,5 )

In [18]:
decoder_base, _, _ = load_model_and_tokenizer(mode="train")
lora_cfg = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj"]
)
decoder = get_peft_model(decoder_base, lora_cfg).train()

In [19]:
A = decoder.base_model.model.model.layers[0].self_attn.q_proj.lora_A["default"].weight.detach()
B = decoder.base_model.model.model.layers[0].self_attn.q_proj.lora_B["default"].weight.detach()

print("A mean/std:", A.mean().item(), A.std().item())
print("B mean/std:", B.mean().item(), B.std().item())
print("B all zero:", (B == 0).all().item())

A mean/std: -1.779676676960662e-05 0.012750617228448391
B mean/std: 0.0 0.0
B all zero: True


In [20]:
d_model = decoder.config.hidden_size
d_model_multiplier = 8
encoder = Encoder(d_in=d_model, multiplier=d_model_multiplier, top_k=16).to(decoder.device).to(torch.bfloat16)
optim = torch.optim.AdamW(
    list(encoder.parameters()) + list(decoder.parameters()),
    lr=1e-4,
    weight_decay=0.01
)
device = decoder.device
start_cropped_pos = len_prefix + cropped_len // 3
end_cropped_pos = start_cropped_pos + cropped_len // 3
layer = 8
batch_size=64
dummy = tokenizer(
    " X" * (cropped_len // 3),
    return_tensors="pt",
    add_special_tokens=False
)["input_ids"].expand(batch_size, -1).to(device)
patch_idx = torch.arange(16, device=device)

In [21]:
concepts_last_occ_by_seen_tokens = torch.full(
    (d_model * d_model_multiplier,),
    -1,
    dtype=torch.long,
    device=device
)
seen_tokens = 0
inactive_concepts_tracker = []

In [22]:
pcd_train_ds = CroppedTokenDataset(
    hf_dataset=ds_train,
    tokenizer=tokenizer,
    prefix_ids=prefix_ids,
    cropped_len=48,
    mode="train"
)

pcd_val_ds = CroppedTokenDataset(
    hf_dataset=ds_val,
    tokenizer=tokenizer,
    prefix_ids=prefix_ids,
    cropped_len=48,
    mode="val"
)

# 3) dataloader
train_loader = DataLoader(
    pcd_train_ds,
    batch_size=64,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
    drop_last=True
)

val_loader = DataLoader(
    pcd_val_ds,
    batch_size=64,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
    drop_last=True
)

In [23]:
patch_state = {"vecs": None}
def patch_resid_stream_hook(idx):
    def hook(module, inp, out):
        h = out.clone()
        h[:, idx, :] = patch_state["vecs"].to(h.dtype)
        return h
    return hook

In [24]:
def train_step(subject_model, batch, layer, start_pos, end_pos,
               include_aux_loss=True, update_last_occ=True, aux_thresh=2e5, eps_aux=1e-4, k_aux=250):

    with torch.no_grad():
        encoder_in = get_resid_stream_vector_efficient(
            subject_model, batch, layer, start_pos, end_pos
        )  # (B, 16, d_model)

    encoder_out, idx = encoder(encoder_in)
    suffix = batch[:, -16:]
    decoder_in = torch.cat([dummy, suffix], dim=-1)

    patch_state["vecs"] = encoder_out

    label_ids = decoder_in.clone()
    label_ids[:, :dummy.size(-1)] = -100

    out = decoder(
        input_ids=decoder_in,
        labels=label_ids,
        use_cache=False
    )
    ce_loss = out.loss

    recent_concepts = torch.unique(idx.reshape(-1))

    if update_last_occ:
        concepts_last_occ_by_seen_tokens[recent_concepts] = seen_tokens

    window_start = max(0, seen_tokens - aux_thresh)
    inactive = concepts_last_occ_by_seen_tokens < window_start

    num_inactive = inactive.sum().item()
    aux_loss = 0.0

    if include_aux_loss:

        W_inactive = encoder.w_enc.weight[inactive]  # (#inactive, d_model)
        num_for_aux = W_inactive.size(0)

        if num_for_aux > 0:
            x_flat = encoder_in.reshape(-1, encoder_in.size(-1))  # (B*16, d_model)
            dot = x_flat @ W_inactive.T  # (B*16, #inactive)

            k_eff = min(num_for_aux, k_aux)
            top_vals = torch.topk(dot, k_eff, dim=1).values

            aux_loss = -(eps_aux / k_eff) * top_vals.sum(dim=1).mean()

    return ce_loss + aux_loss, num_inactive

In [25]:
num_epochs = 10
patience = 5
curr_bad = 0
best_val = float("inf")

handle = decoder.base_model.model.model.embed_tokens.register_forward_hook(
    patch_resid_stream_hook(patch_idx)
)

every_n_steps = 200
inactive_concepts_n_steps = 100
total_inactive_concepts = 0
global_step = 0
count_steps = 0

stop_training = False
for epoch in range(num_epochs):

    pbar = tqdm(enumerate(train_loader, start=1), desc=f"epoch {epoch+1}/{num_epochs}")
    for step, train_batch in pbar:
        global_step += 1

        train_batch = train_batch.to(device, non_blocking=True)

        loss, num_inact_concepts = train_step(
            subject, train_batch, layer, start_cropped_pos, end_cropped_pos
        )
        count_steps += 1
        total_inactive_concepts += num_inact_concepts
        seen_tokens += train_batch.size(0) * (cropped_len // 3)

        optim.zero_grad(set_to_none=True)
        loss.backward()
        optim.step()
        pbar.set_postfix(loss=float(loss.item()))

        if global_step % inactive_concepts_n_steps == 0:
            inactive_concepts_tracker.append((seen_tokens, total_inactive_concepts / count_steps))
            total_inactive_concepts = 0
            count_steps = 0

        if step % every_n_steps == 0:

            encoder.eval()
            decoder.eval()

            total = 0.0
            n = 0
            with torch.no_grad():
                for val_batch in tqdm(val_loader, desc="val", leave=False):
                    val_batch = val_batch.to(device, non_blocking=True)
                    val_loss, _ = train_step(
                        subject, val_batch, layer, start_cropped_pos, end_cropped_pos,
                        include_aux_loss=False, update_last_occ=False)

                    total += val_loss.item()
                    n += 1

            val_mean = total / n

            if val_mean < best_val:
                best_val = val_mean
                curr_bad = 0
                torch.save(
                    {
                        "encoder": encoder.state_dict(),
                        "decoder": decoder.state_dict(),
                        "optim": optim.state_dict(),
                        "epoch": epoch,
                        "step": step,
                        "best_val": best_val,
                        "curr_bad": curr_bad,
                    },
                    "best_checkpoint.pt",
                )

            else:
                curr_bad += 1
                if curr_bad >= patience:
                    stop_training = True
                    break

            encoder.train()
            decoder.train()
            pbar.set_postfix(loss=float(loss.item()), val=float(val_mean), best_val=float(best_val))

    if stop_training:
        print("Stopping training...")
        break

handle.remove()

epoch 1/10: 3it [00:34, 11.36s/it, loss=7.27]


FileNotFoundError: Caught FileNotFoundError in DataLoader worker process 3.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
    data.append(next(self.dataset_iter))
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-3632968432.py", line 19, in __iter__
    for item in ds:
                ^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 2347, in __iter__
    yield from self._iter_pytorch()
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 2262, in _iter_pytorch
    for key, example in ex_iterable:
                        ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 1882, in __iter__
    for key, pa_table in self._iter_arrow():
                         ^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 1905, in _iter_arrow
    for key, pa_table in self.ex_iterable._iter_arrow():
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 499, in _iter_arrow
    for key, pa_table in iterator:
                         ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 151, in _convert_to_arrow
    for key, example in iterator:
                        ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 1745, in __iter__
    for key_example in islice(self.ex_iterable, self.n - ex_iterable_num_taken):
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 1558, in __iter__
    for x in self.ex_iterable:
             ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 325, in __iter__
    for key, pa_table in self.generate_tables_fn(**gen_kwags):
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/packaged_modules/parquet/parquet.py", line 87, in _generate_tables
    for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/utils/track.py", line 49, in __iter__
    for x in self.generator(*self.args):
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/utils/file_utils.py", line 1359, in _iter_from_urlpaths
    raise FileNotFoundError(urlpath)
FileNotFoundError: hf://datasets/HuggingFaceFW/fineweb@9bb295ddab0e05d785b879661af7260fed5140fc/data/CC-MAIN-2022-33/004_00048.parquet
