In [45]:
%reload_ext autoreload
%autoreload 2

In [46]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from activation_store.collect import activation_store

import torch

## Load model

In [None]:
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",  # flex_attention  flash_attention_2 sdpa eager
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.paddding_side = "left"
tokenizer.truncation_side = "left"

## Load data and tokenize

In [48]:
N = 20
max_length = 256

imdb = load_dataset('wassname/imdb_dpo', split=f'test[:{N}]', keep_in_memory=False)


def proc(row):
    messages = [
        {"role":"user", "content": row['prompt'] },
        {"role":"assistant", "content": row['chosen'] }
    ]
    return tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False, return_dict=True, max_length=max_length)

ds2 = imdb.map(proc).with_format("torch")
new_cols = set(ds2.column_names) - set(imdb.column_names)
ds2 = ds2.select_columns(new_cols)
ds2

Dataset({
    features: ['attention_mask', 'input_ids'],
    num_rows: 20
})

## Data loader

In [None]:
from torch.utils.data import DataLoader
from transformers.data import DataCollatorForLanguageModeling
collate_fn = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
ds = DataLoader(ds2, batch_size=4, num_workers=0, collate_fn=collate_fn)
print(ds)


<torch.utils.data.dataloader.DataLoader object at 0x7089f82ccb30>


## Collect activations

In [None]:
# choose layers to cache
layers = [k for k,v in model.named_modules() if k.endswith('mlp.down_proj')]
layers

['model.layers.0.mlp.down_proj',
 'model.layers.1.mlp.down_proj',
 'model.layers.2.mlp.down_proj',
 'model.layers.3.mlp.down_proj',
 'model.layers.4.mlp.down_proj',
 'model.layers.5.mlp.down_proj',
 'model.layers.6.mlp.down_proj',
 'model.layers.7.mlp.down_proj',
 'model.layers.8.mlp.down_proj',
 'model.layers.9.mlp.down_proj',
 'model.layers.10.mlp.down_proj',
 'model.layers.11.mlp.down_proj',
 'model.layers.12.mlp.down_proj',
 'model.layers.13.mlp.down_proj',
 'model.layers.14.mlp.down_proj',
 'model.layers.15.mlp.down_proj',
 'model.layers.16.mlp.down_proj',
 'model.layers.17.mlp.down_proj',
 'model.layers.18.mlp.down_proj',
 'model.layers.19.mlp.down_proj',
 'model.layers.20.mlp.down_proj',
 'model.layers.21.mlp.down_proj',
 'model.layers.22.mlp.down_proj',
 'model.layers.23.mlp.down_proj']

In [None]:
f = activation_store(ds, model, layers=layers)
f

[32m2025-02-16 09:36:37.315[0m | [1mINFO    [0m | [36mactivation_store.collect[0m:[36mactivation_store[0m:[36m77[0m - [1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__fac086acb713a85e.parquet[0m


collecting activations:   0%|          | 0/5 [00:00<?, ?it/s]

You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__fac086acb713a85e.parquet')

In [57]:
from datasets import Dataset
ds_a = Dataset.from_parquet(str(f)).with_format("torch")
ds_a

Dataset({
    features: ['act-model.layers.0.mlp.down_proj', 'act-model.layers.1.mlp.down_proj', 'act-model.layers.2.mlp.down_proj', 'act-model.layers.3.mlp.down_proj', 'act-model.layers.4.mlp.down_proj', 'act-model.layers.5.mlp.down_proj', 'act-model.layers.6.mlp.down_proj', 'act-model.layers.7.mlp.down_proj', 'act-model.layers.8.mlp.down_proj', 'act-model.layers.9.mlp.down_proj', 'act-model.layers.10.mlp.down_proj', 'act-model.layers.11.mlp.down_proj', 'act-model.layers.12.mlp.down_proj', 'act-model.layers.13.mlp.down_proj', 'act-model.layers.14.mlp.down_proj', 'act-model.layers.15.mlp.down_proj', 'act-model.layers.16.mlp.down_proj', 'act-model.layers.17.mlp.down_proj', 'act-model.layers.18.mlp.down_proj', 'act-model.layers.19.mlp.down_proj', 'act-model.layers.20.mlp.down_proj', 'act-model.layers.21.mlp.down_proj', 'act-model.layers.22.mlp.down_proj', 'act-model.layers.23.mlp.down_proj', 'logits', 'hidden_states'],
    num_rows: 20
})

In [None]:
ds_a[0:2]['hidden_states'].shape # [batch, layers, tokens, hidden_states]

torch.Size([2, 25, 453, 896])

In [61]:
ds_a[0:2]['act-model.layers.0.mlp.down_proj'].shape

torch.Size([2, 453, 896])