In [45]:
%reload_ext autoreload
%autoreload 2

In [46]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from activation_store.collect import activation_store

import torch

## Load model

In [47]:
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",  # flex_attention  flash_attention_2 sdpa eager
)
tokenizer = AutoTokenizer.from_pretrained(model_name)


## Load data and tokenize

In [48]:
N = 20
max_length = 256

imdb = load_dataset('wassname/imdb_dpo', split=f'test[:{N}]', keep_in_memory=False)


def proc(row):
    messages = [
        {"role":"user", "content": row['prompt'] },
        {"role":"assistant", "content": row['chosen'] }
    ]
    return tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False, return_dict=True, max_length=max_length)

ds2 = imdb.map(proc).with_format("torch")
new_cols = set(ds2.column_names) - set(imdb.column_names)
ds2 = ds2.select_columns(new_cols)
ds2

Dataset({
    features: ['attention_mask', 'input_ids'],
    num_rows: 20
})

## Data loader

In [None]:
from torch.utils.data import DataLoader
def collate_fn(examples):
    # Pad the batch to max length within this batch
    return tokenizer.pad(
        examples,
        padding=True,
        return_tensors="pt",
        max_length=max_length,  
        truncation=True,
    )
ds = DataLoader(ds2, batch_size=4, num_workers=0, collate_fn=collate_fn)
print(ds)


<torch.utils.data.dataloader.DataLoader object at 0x7089f82ccb30>


## Collect activations

In [None]:
# choose layers to cache
layers = [k for k,v in model.named_modules() if k.endswith('mlp.down_proj')]
layers

['model.layers.0.mlp.down_proj',
 'model.layers.1.mlp.down_proj',
 'model.layers.2.mlp.down_proj',
 'model.layers.3.mlp.down_proj',
 'model.layers.4.mlp.down_proj',
 'model.layers.5.mlp.down_proj',
 'model.layers.6.mlp.down_proj',
 'model.layers.7.mlp.down_proj',
 'model.layers.8.mlp.down_proj',
 'model.layers.9.mlp.down_proj',
 'model.layers.10.mlp.down_proj',
 'model.layers.11.mlp.down_proj',
 'model.layers.12.mlp.down_proj',
 'model.layers.13.mlp.down_proj',
 'model.layers.14.mlp.down_proj',
 'model.layers.15.mlp.down_proj',
 'model.layers.16.mlp.down_proj',
 'model.layers.17.mlp.down_proj',
 'model.layers.18.mlp.down_proj',
 'model.layers.19.mlp.down_proj',
 'model.layers.20.mlp.down_proj',
 'model.layers.21.mlp.down_proj',
 'model.layers.22.mlp.down_proj',
 'model.layers.23.mlp.down_proj']

In [None]:
f = activation_store(ds, model, layers=layers)
f

[32m2025-02-16 09:36:37.315[0m | [1mINFO    [0m | [36mactivation_store.collect[0m:[36mactivation_store[0m:[36m77[0m - [1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__fac086acb713a85e.parquet[0m


collecting activations:   0%|          | 0/5 [00:00<?, ?it/s]

You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__fac086acb713a85e.parquet')

In [57]:
from datasets import Dataset
ds_a = Dataset.from_parquet(str(f)).with_format("torch")
ds_a

Dataset({
    features: ['act-model.layers.0.mlp.down_proj', 'act-model.layers.1.mlp.down_proj', 'act-model.layers.2.mlp.down_proj', 'act-model.layers.3.mlp.down_proj', 'act-model.layers.4.mlp.down_proj', 'act-model.layers.5.mlp.down_proj', 'act-model.layers.6.mlp.down_proj', 'act-model.layers.7.mlp.down_proj', 'act-model.layers.8.mlp.down_proj', 'act-model.layers.9.mlp.down_proj', 'act-model.layers.10.mlp.down_proj', 'act-model.layers.11.mlp.down_proj', 'act-model.layers.12.mlp.down_proj', 'act-model.layers.13.mlp.down_proj', 'act-model.layers.14.mlp.down_proj', 'act-model.layers.15.mlp.down_proj', 'act-model.layers.16.mlp.down_proj', 'act-model.layers.17.mlp.down_proj', 'act-model.layers.18.mlp.down_proj', 'act-model.layers.19.mlp.down_proj', 'act-model.layers.20.mlp.down_proj', 'act-model.layers.21.mlp.down_proj', 'act-model.layers.22.mlp.down_proj', 'act-model.layers.23.mlp.down_proj', 'logits', 'hidden_states'],
    num_rows: 20
})

In [None]:
ds_a[0:2]['hidden_states'].shape # [batch, layers, tokens, hidden_states]

torch.Size([2, 25, 453, 896])

In [61]:
ds_a[0:2]['act-model.layers.0.mlp.down_proj'].shape

torch.Size([2, 453, 896])

## Get supressed activations

In [64]:
from jaxtyping import Float, Int
from torch import Tensor
from einops import rearrange


def get_supressed_activations(
    hs: Float[Tensor, "l b t h"], w_out, w_inv
) -> Float[Tensor, "l b t h"]:
    """
    Novel experiment: Here we define a transform to isolate supressed activations, where we hypothesis that style/concepts/scratchpads and other internal only representations must be stored.

    See the following references for more information:

    - https://arxiv.org/pdf/2401.12181
        - > Suppression neurons that are similar, except decrease the probability of a group of related tokens

    - https://arxiv.org/html/2406.19384
        - > Previous work suggests that networks contain ensembles of “prediction" neurons, which act as probability promoters [66, 24, 32] and work in tandem with suppression neurons (Section 5.4).

    - https://arxiv.org/pdf/2401.12181
        > We find a striking pattern which is remarkably consistent across the different seeds: after about the halfway point in the model, prediction neurons become increasingly prevalent until the very end of the network where there is a sudden shift towards a much larger number of suppression neurons.
    """
    with torch.no_grad():
        # here we pass the hs through the last layer, take a diff, and then project it back to find which activation changes lead to supressed
        hs2 = rearrange(hs[:, :, -1:], "l b t h -> (l b t) h")
        hs_out2 = torch.nn.functional.linear(hs2, w_out)
        hs_out = rearrange(
            hs_out2, "(l b t) h -> l b t h", l=hs.shape[0], b=hs.shape[1], t=1
        )
        diffs = hs_out[:, :, :].diff(dim=0)
        diffs2 = rearrange(diffs, "l b t h -> (l b t) h")
        # W_inv = get_cache_inv(w_out)

        diffs_inv2 = torch.nn.functional.linear(diffs2.to(dtype=w_inv.dtype), w_inv)
        diffs_inv = rearrange(
            diffs_inv2, "(l b t) h -> l b t h", l=hs.shape[0] - 1, b=hs.shape[1], t=1
        ).to(w_out.dtype)
        # TODO just return this?
        eps = 1.0e-1
        supressed_mask = (diffs_inv < -eps).to(hs.dtype)
        # supressed_mask = repeat(supressed_mask, 'l b 1 h -> l b t h', t=hs.shape[2])
    supressed_act = hs[1:] * supressed_mask
    return supressed_act

In [None]:
from activation_store.collect import default_postprocess_result

Wo = model.get_output_embeddings().weight.detach().clone().cpu()
Wo_inv = torch.pinverse(Wo.clone().float())

@torch.no_grad()
def sup_postproc(input, trace, output, model):

    
    o = default_postprocess_result(input, trace, output, model)
    
    hs = o.pop('hidden_states')
    hs = rearrange(hs, "b l t h -> l b t h")
    hs_s = get_supressed_activations(hs, Wo.to(hs.dtype), Wo_inv.to(hs.dtype))
    hs_s = rearrange(hs_s, "l b t h -> b l t h")
    o['hidden_states_supressed'] = hs_s.half()
    
    return o


In [86]:
f2 = activation_store(ds, model, postprocess_result=sup_postproc)
f2

[32m2025-02-16 09:52:12.917[0m | [1mINFO    [0m | [36mactivation_store.collect[0m:[36mactivation_store[0m:[36m78[0m - [1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__115ab10dde7bd7a3.parquet[0m


collecting activations:   0%|          | 0/5 [00:00<?, ?it/s]

PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__115ab10dde7bd7a3.parquet')

In [None]:
ds_a2 = Dataset.from_parquet(str(f2)).with_format("torch")
ds_a2

Generating train split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['attention_mask', 'logits', 'hidden_states_supressed'],
    num_rows: 20
})

In [92]:
ds_a2.info
ds_a2[0:2]['hidden_states_supressed'].shape

torch.Size([2, 24, 453, 896])

In [93]:
ds_a2[0:2]['hidden_states_supressed']

tensor([[[[ 0.0000e+00,  0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
           -0.0000e+00, -0.0000e+00],
          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
           -0.0000e+00, -0.0000e+00],
          ...,
          [ 0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
            0.0000e+00,  0.0000e+00],
          [-0.0000e+00,  0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
            0.0000e+00, -0.0000e+00]],

         [[ 0.0000e+00, -0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
            0.0000e+00, -0.0000e+00],
          [-0.0000e+00, -0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
           -0.0000e+00, -0.0000e+00],
          [-0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
           -0.0000e+00,  0.0000e+00],
          ...,
     