In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset
import torch

from activation_store.collect import activation_store


## Load model

In [3]:
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="eager",  # flex_attention  flash_attention_2 sdpa eager
)
tokenizer = AutoTokenizer.from_pretrained(model_name)


## Load data and tokenize

In [4]:
N = 10
max_length = 128

imdb = load_dataset('wassname/imdb_dpo', split=f'test[:{N}]', keep_in_memory=False)


def proc(row):
    messages = [
        {"role":"user", "content": row['prompt'] },
        {"role":"assistant", "content": row['chosen'] }
    ]
    return tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=False, return_dict=True, max_length=max_length)

ds2 = imdb.map(proc).with_format("torch")
new_cols = set(ds2.column_names) - set(imdb.column_names)
ds2 = ds2.select_columns(new_cols)
ds2

Dataset({
    features: ['attention_mask', 'input_ids'],
    num_rows: 10
})

## Data loader

In [5]:
from torch.utils.data import DataLoader
def collate_fn(examples):
    # Pad the batch to max length within this batch
    return tokenizer.pad(
        examples,
        padding=True,
        return_tensors="pt",
    )
ds = DataLoader(ds2, batch_size=4, num_workers=0, collate_fn=collate_fn)
print(ds)


<torch.utils.data.dataloader.DataLoader object at 0x7089fb69f6e0>


## Collect activations

In [6]:
# choose layers to cache
layers = [k for k,v in model.named_modules() if 'mlp.down_proj' in k]
layers

['model.layers.0.mlp.down_proj',
 'model.layers.1.mlp.down_proj',
 'model.layers.2.mlp.down_proj',
 'model.layers.3.mlp.down_proj',
 'model.layers.4.mlp.down_proj',
 'model.layers.5.mlp.down_proj',
 'model.layers.6.mlp.down_proj',
 'model.layers.7.mlp.down_proj',
 'model.layers.8.mlp.down_proj',
 'model.layers.9.mlp.down_proj',
 'model.layers.10.mlp.down_proj',
 'model.layers.11.mlp.down_proj',
 'model.layers.12.mlp.down_proj',
 'model.layers.13.mlp.down_proj',
 'model.layers.14.mlp.down_proj',
 'model.layers.15.mlp.down_proj',
 'model.layers.16.mlp.down_proj',
 'model.layers.17.mlp.down_proj',
 'model.layers.18.mlp.down_proj',
 'model.layers.19.mlp.down_proj',
 'model.layers.20.mlp.down_proj',
 'model.layers.21.mlp.down_proj',
 'model.layers.22.mlp.down_proj',
 'model.layers.23.mlp.down_proj']

In [7]:
f = activation_store(ds, model, layers=layers, writer_batch_size=10)
f

[32m2025-02-16 09:16:55.292[0m | [1mINFO    [0m | [36mactivation_store.collect[0m:[36mactivation_store[0m:[36m122[0m - [1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__4a18b59a7867ed48.parquet[0m


collecting activations:   0%|          | 0/3 [00:00<?, ?it/s]

You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


PosixPath('/media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__4a18b59a7867ed48.parquet')

In [8]:
# load
ds_a = Dataset.from_parquet(str(f)).with_format("torch")
ds_a

Generating train split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['attention_mask', 'act-model.layers.0.mlp.down_proj', 'act-model.layers.1.mlp.down_proj', 'act-model.layers.2.mlp.down_proj', 'act-model.layers.3.mlp.down_proj', 'act-model.layers.4.mlp.down_proj', 'act-model.layers.5.mlp.down_proj', 'act-model.layers.6.mlp.down_proj', 'act-model.layers.7.mlp.down_proj', 'act-model.layers.8.mlp.down_proj', 'act-model.layers.9.mlp.down_proj', 'act-model.layers.10.mlp.down_proj', 'act-model.layers.11.mlp.down_proj', 'act-model.layers.12.mlp.down_proj', 'act-model.layers.13.mlp.down_proj', 'act-model.layers.14.mlp.down_proj', 'act-model.layers.15.mlp.down_proj', 'act-model.layers.16.mlp.down_proj', 'act-model.layers.17.mlp.down_proj', 'act-model.layers.18.mlp.down_proj', 'act-model.layers.19.mlp.down_proj', 'act-model.layers.20.mlp.down_proj', 'act-model.layers.21.mlp.down_proj', 'act-model.layers.22.mlp.down_proj', 'act-model.layers.23.mlp.down_proj', 'logits', 'hidden_states'],
    num_rows: 10
})

In [10]:
ds_a.info

DatasetInfo(description='', citation='', homepage='', license='', features={'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'act-model.layers.0.mlp.down_proj': Sequence(feature=Sequence(feature=Value(dtype='float16', id=None), length=-1, id=None), length=-1, id=None), 'act-model.layers.1.mlp.down_proj': Sequence(feature=Sequence(feature=Value(dtype='float16', id=None), length=-1, id=None), length=-1, id=None), 'act-model.layers.2.mlp.down_proj': Sequence(feature=Sequence(feature=Value(dtype='float16', id=None), length=-1, id=None), length=-1, id=None), 'act-model.layers.3.mlp.down_proj': Sequence(feature=Sequence(feature=Value(dtype='float16', id=None), length=-1, id=None), length=-1, id=None), 'act-model.layers.4.mlp.down_proj': Sequence(feature=Sequence(feature=Value(dtype='float16', id=None), length=-1, id=None), length=-1, id=None), 'act-model.layers.5.mlp.down_proj': Sequence(feature=Sequence(feature=Value(dtype='float16', id=None), length=-1,

In [11]:
ds_a[0:2]['hidden_states'].shape

torch.Size([2, 25, 453, 896])

In [12]:
ds_a[0:2]['act-model.layers.0.mlp.down_proj'].shape

torch.Size([2, 453, 896])

In [9]:
1/0

ZeroDivisionError: division by zero

## With dtypes compression - wip

In [13]:
def float_to_int8(x: torch.Tensor) -> torch.Tensor:
    """Converts a floating point tensor to float16, then reinterprets as int16."""
    downcast = x.type(torch.float8_e4m3fn)
    # if not downcast.isfinite().all():
    #     raise ValueError("Cannot convert to 16 bit: values are not finite")

    return downcast.view(torch.int8)

def int8_to_float32(x: torch.Tensor) -> torch.Tensor:
    """Converts int16 to float16, then reinterprets as float32."""
    return x.view(torch.float8_e4m3fn).type(torch.float32)


x = torch.randn(2, 3, 4)
x2 = float_to_int8(x)
x3 = int8_to_float32(x2)
assert torch.isfinite(x3).all()
assert torch.allclose(x, x3, rtol=1e-1)
d = ((x-x3)/x).abs().mean()
print(f'lost {d:.2%}')

lost 2.02%


In [27]:
from activation_store.collect import default_postprocess_result
from datasets.features.features import cast_to_python_objects
# o = cast_to_python_objects(o, only_1d_for_numpy=True, optimize_list_casting=False)

def float8_postprocess_result(
    input, trace, output, model
):
    o = default_postprocess_result(input, trace, output, model)
    # o = cast_to_python_objects(o, only_1d_for_numpy=False, optimize_list_casting=False)

    for k, v in o.items():
        if k=='attention_mask':
            o[k] = v.to(torch.int8)
        if isinstance(v, torch.Tensor) and torch.is_floating_point(v):
            print(k, v.dtype, v.shape, 'to int8')
            o[k] = float_to_int8(v.float())
        else:
            print('no conv', k, type(v))
    # o = {k: float_to_int8(v) if isinstance(v, torch.Tensor) else v
        #   for k, v in o.items()}
    return o

In [28]:
from datasets.arrow_writer import OptimizedTypedSequence, _ArrayXDExtensionType
from datasets.features.features import Features, Array2D, Array3D, Array4D, Array5D

# manually build features
optimized_int_type_by_col = {
    "attention_mask": "int8",  # binary tensor
    "special_tokens_mask": "int8",
    "input_ids": "int32",  # typical vocab size: 0-50k (max ~500k, never > 1M)
    "token_type_ids": "int8",  # binary mask; some (XLNetModel) use an additional token represented by a 2
}

def build_schema(d):
    inferred_features = Features()
    cols = d.keys()
    for col in cols:
        x = d[col]
        if col in optimized_int_type_by_col:
            dtype = optimized_int_type_by_col[col]
            typed_sequence = OptimizedTypedSequence(x, col=col)
            inferred_features[col] = typed_sequence.get_inferred_type()
        else:
            if x.ndim == 1:
                inferred_features[col] = OptimizedTypedSequence(x, col=col)
                inferred_features[col] = typed_sequence.get_inferred_type()
            shape=(-1,)+x.shape[1:]
            dtype = 'int8' if x.dtype == torch.float32 else x.dtype
            if x.ndim == 2:
                cls = Array2D
            elif x.ndim == 3:
                cls = Array3D
            elif x.ndim == 4:
                cls = Array4D
            elif x.ndim == 5:
                cls = Array5D
            else:
                raise ValueError(f"Unsupported number of dimensions: {x.ndim}")
            inferred_features[col] = cls(dtype=dtype, shape=shape)
    return inferred_features.arrow_schema
    # Features.from_arrow_schema(schema)

d = ds_a[0:2]
schema = build_schema(d)
schema
Features.from_arrow_schema(schema)

{'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'act-model.layers.0.mlp.down_proj': Array3D(shape=(-1, 453, 896), dtype='int8', id=None),
 'act-model.layers.1.mlp.down_proj': Array3D(shape=(-1, 453, 896), dtype='int8', id=None),
 'act-model.layers.2.mlp.down_proj': Array3D(shape=(-1, 453, 896), dtype='int8', id=None),
 'act-model.layers.3.mlp.down_proj': Array3D(shape=(-1, 453, 896), dtype='int8', id=None),
 'act-model.layers.4.mlp.down_proj': Array3D(shape=(-1, 453, 896), dtype='int8', id=None),
 'act-model.layers.5.mlp.down_proj': Array3D(shape=(-1, 453, 896), dtype='int8', id=None),
 'act-model.layers.6.mlp.down_proj': Array3D(shape=(-1, 453, 896), dtype='int8', id=None),
 'act-model.layers.7.mlp.down_proj': Array3D(shape=(-1, 453, 896), dtype='int8', id=None),
 'act-model.layers.8.mlp.down_proj': Array3D(shape=(-1, 453, 896), dtype='int8', id=None),
 'act-model.layers.9.mlp.down_proj': Array3D(shape=(-1, 453, 896), dtype='int8', id=None),
 'a

In [29]:
f2 = activation_store(ds, model, layers=layers, writer_batch_size=10, 
                      schema=schema,
                        postprocess_result=float8_postprocess_result)
f2
ds_a2 = Dataset.from_parquet(str(f2)).with_format("torch")
ds_a2.info

[32m2025-02-16 09:25:08.798[0m | [1mINFO    [0m | [36mactivation_store.collect[0m:[36mactivation_store[0m:[36m152[0m - [1mcreating dataset /media/wassname/SGIronWolf/projects5/elk/cache_transformer_acts/outputs/.ds/ds__c6184d05bf03be61.parquet[0m


collecting activations:   0%|          | 0/3 [00:00<?, ?it/s]

no conv attention_mask <class 'torch.Tensor'>
act-model.layers.0.mlp.down_proj torch.float16 torch.Size([4, 453, 896]) to int8
act-model.layers.1.mlp.down_proj torch.float16 torch.Size([4, 453, 896]) to int8
act-model.layers.2.mlp.down_proj torch.float16 torch.Size([4, 453, 896]) to int8
act-model.layers.3.mlp.down_proj torch.float16 torch.Size([4, 453, 896]) to int8
act-model.layers.4.mlp.down_proj torch.float16 torch.Size([4, 453, 896]) to int8
act-model.layers.5.mlp.down_proj torch.float16 torch.Size([4, 453, 896]) to int8
act-model.layers.6.mlp.down_proj torch.float16 torch.Size([4, 453, 896]) to int8
act-model.layers.7.mlp.down_proj torch.float16 torch.Size([4, 453, 896]) to int8
act-model.layers.8.mlp.down_proj torch.float16 torch.Size([4, 453, 896]) to int8
act-model.layers.9.mlp.down_proj torch.float16 torch.Size([4, 453, 896]) to int8
act-model.layers.10.mlp.down_proj torch.float16 torch.Size([4, 453, 896]) to int8
act-model.layers.11.mlp.down_proj torch.float16 torch.Size([4,

ArrowTypeError: Could not convert tensor([[  48,   25,  -83,  ...,   31,   45,   41],
        [ -76, -100,  -94,  ...,   26,  -84, -117],
        [ -97,   26,   15,  ...,  -97, -107, -109],
        ...,
        [  44,  -94, -104,  ..., -110,   18,   27],
        [ -77,   26,  -77,  ..., -100,   33,   43],
        [ -98,   22, -111,  ..., -110,   14, -107]], dtype=torch.int8) with type Tensor: was not a sequence or recognized null for conversion to list type

In [None]:
f2 = activation_store(ds, model, layers=layers, writer_batch_size=10, postprocess_result=float8_postprocess_result)
f2

In [None]:
from datasets import Dataset
# load
ds_a2 = Dataset.from_parquet(str(f2)).with_format("torch")
for c in ds_a2.column_names[1:]:
    print(c)
    ds_a2[c] = int8_to_float32(ds_a2[0:-1][c])
# ds_a2 = int8_to_float32(ds_a)
ds_a2.info

In [None]:
d = ds_a2[:][c]
print(c)
d.shape

In [None]:
ds_a2.info

In [None]:
ds_a[0:2]['logits'].shape