In [1]:
# Re-import edited modules (e.g. the local `clt` package) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

# Put the parent of the kernel's working directory (normally the repository
# root, one level above this notebook) at the front of sys.path so the local
# `clt` package is importable from source.
project_dir = os.path.abspath(os.pardir)
if project_dir not in sys.path:
    sys.path[:0] = [project_dir]

from clt.activation_generation.generator import ActivationConfig
from clt.nnsight.extractor import ActivationExtractorCLT

In [3]:
# Extraction settings for this notebook. Fields not passed here fall back to the
# ActivationConfig defaults (the logged summary below shows split "train" and
# skip None). Other available knobs include model_dtype, dataset_split,
# dataset_text_column, streaming, output_format, nnsight_*_kwargs and the
# remote_server_url / upload_* options.
cfg = ActivationConfig(
    # Model and the MLP hook points captured in every layer
    # ("{}" is presumably replaced by the layer index — TODO confirm).
    model_name="allenai/OLMo-2-0425-1B-Instruct",
    mlp_input_module_path_template="model.layers.{}.mlp.input",
    mlp_output_module_path_template="model.layers.{}.mlp.output",
    # Data source and tokenisation.
    dataset_path="allenai/olmo-mix-1124",
    context_size=4096,
    prepend_bos=True,
    inference_batch_size=8,
    # Output: token budget, storage dtype, location and chunking.
    target_total_tokens=1_000_000,
    activation_dtype="bfloat16",
    activation_dir=None,
    compression=None,
    chunk_token_threshold=1,  # any non-empty buffer meets this, i.e. one chunk per batch
    compute_norm_stats=True,
)

INFO:clt.config.data_config:ActivationConfig Summary:
  Model: allenai/OLMo-2-0425-1B-Instruct
  Dataset: allenai/olmo-mix-1124 (Split: train, Skip: None)
  Target Tokens: 1000000
  Chunk Threshold: 1
  Activation Dtype: bfloat16
  Output Dir: None


In [5]:
device = "cuda"  # extraction and the running norm stats below are placed on this device

# cfg attributes forwarded unchanged (same keyword name) to the extractor ...
_EXTRACTOR_FIELDS = (
    "model_name",
    "mlp_input_module_path_template",
    "mlp_output_module_path_template",
    "model_dtype",
    "context_size",
    "inference_batch_size",
    "exclude_special_tokens",
    "prepend_bos",
    "nnsight_tracer_kwargs",
    "nnsight_invoker_args",
)
# ... and to the dataset stream.
_DATASET_FIELDS = (
    "dataset_path",
    "dataset_split",
    "dataset_text_column",
    "dataset_skip",
    "streaming",
    "dataset_trust_remote_code",
    "cache_path",
)

extractor = ActivationExtractorCLT(
    device=device,
    **{field: getattr(cfg, field) for field in _EXTRACTOR_FIELDS},
)

# A lazy generator of (inputs_by_layer, targets_by_layer) batches: nothing is
# loaded and no forward pass runs until it is iterated.
stream = extractor.stream_activations(
    **{field: getattr(cfg, field) for field in _DATASET_FIELDS}
)
stream

<generator object ActivationExtractorCLT.stream_activations at 0xcec7200>

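# Generate and Save

The next cells first probe a single batch from `stream` to check the layer keys and tensor shapes, then define `generate_and_save`, which runs the full extraction loop and writes the chunks, the manifest, `metadata.json` and `norm_stats.json`.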

In [6]:
from __future__ import annotations

import os
import json
import queue
import random
import logging
import threading
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any, DefaultDict
from concurrent.futures import ThreadPoolExecutor, as_completed

import torch
import numpy as np
import h5py
from tqdm import tqdm
import requests
from urllib.parse import quote, urljoin

# ––– local imports (keep relative to package root) –––
from clt.nnsight.extractor import ActivationExtractorCLT  # noqa: E402
from clt.config.data_config import ActivationConfig  # noqa: E402

# --- Profiling Imports ---
import time  # Keep this one
from contextlib import contextmanager
from collections import defaultdict
import psutil

# Local application imports
# from clt.training.utils import torch_bfloat16_to_numpy_uint16 # Removed unused import

try:
    import GPUtil
except ImportError:
    GPUtil = None  # type: ignore
# --- End Profiling Imports ---

# Notebook-local logger, set to INFO level.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# One extracted batch: (inputs_by_layer, targets_by_layer), each a dict mapping a
# layer index to that layer's activation tensor.
ActivationBatch = Tuple[Dict[int, torch.Tensor], Dict[int, torch.Tensor]]

In [7]:
from clt.activation_generation.generator import _RunningStat

In [8]:
# Single-batch probe: pull the first non-empty batch from `stream`, buffer it
# per layer and feed the running norm stats, then stop. Nothing is written to
# disk here (chunk writing lives in `generate_and_save`). The globals created
# here (buf_inp, buf_tgt, batch_inp, layer_ids, stats, ...) are inspected by the
# cells below.
# NOTE: this advances the shared `stream` generator; a later run over the same
# `stream` starts after the probed batch.
tgt_tokens = cfg.target_total_tokens
chunk_tokens = cfg.chunk_token_threshold

# Collect manifest rows in-memory to avoid pre-allocation mismatch bugs.
# Each entry is (chunk_id, local_row).  For 1 M tokens this is only 8 MB.
manifest_rows: List[np.ndarray] = []

# Norm-stat structures
stats: Dict[int, Dict[str, _RunningStat]] = {}

g_row = 0
c_idx = 0
buf_inp: Dict[int, List[torch.Tensor]] = {}
buf_tgt: Dict[int, List[torch.Tensor]] = {}
layer_ids: Optional[List[int]] = None
d_model = -1
dtype_str = "unknown"

# `with` closes the progress bar even if extraction raises.
with tqdm(total=tgt_tokens or None, unit="tok", smoothing=0.2) as pbar:
    for batch_idx, (batch_inp, batch_tgt) in enumerate(stream):
        if not batch_inp:
            continue  # empty batch carries no layers; keep pulling

        # This is always the first non-empty batch: it fixes the layer set and hidden size.
        layer_ids = sorted(batch_inp.keys())
        d_model = batch_inp[layer_ids[0]].shape[-1]
        dtype_str = str(batch_inp[layer_ids[0]].dtype)
        for lid in layer_ids:
            buf_inp[lid] = []
            buf_tgt[lid] = []
            if cfg.compute_norm_stats:
                stats[lid] = {
                    "inputs": _RunningStat(d_model, device=device),
                    "targets": _RunningStat(d_model, device=device),
                }
        # f-string: print() does not apply %-formatting to extra arguments.
        print(f"Layers={len(layer_ids)} d_model={d_model} dtype={dtype_str}")

        n_tok_in_batch = batch_inp[layer_ids[0]].shape[0]

        for lid in layer_ids:
            if lid in batch_inp and lid in batch_tgt:
                inp = batch_inp[lid].detach()
                tgt = batch_tgt[lid].detach()
                buf_inp[lid].append(inp)
                buf_tgt[lid].append(tgt)
                if cfg.compute_norm_stats and lid in stats:
                    stats[lid]["inputs"].update(inp)
                    stats[lid]["targets"].update(tgt)
            else:
                print(
                    f"Layer {lid} expected but not found in current batch. Skipping accumulation for this layer."
                )

        if n_tok_in_batch > 0:
            g_row += n_tok_in_batch
            pbar.update(n_tok_in_batch)

        break  # probe only: stop after the first non-empty batch

  0%|          | 0/1000000 [00:00<?, ?tok/s]

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  2%|▏         | 15001/1000000 [00:11<12:23, 1325.58tok/s]

Layers=%d d_model=%d dtype=%s 16 2048 torch.float32


In [10]:
# Both buffers are initialised from the same layer_ids, so their keys should match.
buf_inp.keys(), buf_tgt.keys()

(dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]),
 dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]))

In [13]:
# First buffered input tensor for layer 0: (tokens_in_batch, d_model).
buf_inp[0][0].shape

torch.Size([15001, 2048])

In [14]:
# First buffered target tensor for layer 0: should match the input shape above.
buf_tgt[0][0].shape

torch.Size([15001, 2048])

In [None]:
# NOTE: `batch_inp` is the loop variable left over from the probe cell; this only
# works after that cell has run in the current kernel.
batch_inp[0].shape

torch.Size([15001, 2048])

: 

In [None]:
def generate_and_save(generator: Any = None, activation_stream: Any = None) -> None:
    """Run the full extraction loop and write chunks, manifest, metadata and norm stats.

    Notebook port of the generator's chunk-writing loop. Chunk writing, the
    manifest/output paths and the upload helpers still live on the generator
    object, so that object must be passed in explicitly (the previous body
    referenced an undefined ``self`` and crashed with NameError after the first
    batch had already been extracted).

    Args:
        generator: Object providing ``_write_chunk``, ``manifest_final``,
            ``out_dir``, ``storage_type``, ``cfg``, ``_upload_binary_file`` and
            ``_upload_json`` — presumably the generator class from
            ``clt.activation_generation.generator`` (TODO confirm).
        activation_stream: Iterable of ``(inputs_by_layer, targets_by_layer)``
            batches. Defaults to the notebook-global ``stream``. Pass a fresh
            stream if an earlier cell already pulled batches from it.

    Raises:
        ValueError: If ``generator`` is not supplied. Raised before any batch is
            pulled, so no extraction work is wasted.
    """
    if generator is None:
        raise ValueError(
            "generate_and_save() needs the generator object that provides _write_chunk, "
            "manifest_final, out_dir, storage_type and the upload helpers; "
            "pass it as generate_and_save(generator=...)."
        )
    if activation_stream is None:
        activation_stream = stream  # notebook-global created by the extractor cell

    tgt_tokens = cfg.target_total_tokens
    chunk_tokens = cfg.chunk_token_threshold

    # Collect manifest rows in-memory to avoid pre-allocation mismatch bugs.
    # Each entry is (chunk_id, local_row).  For 1 M tokens this is only 8 MB.
    manifest_rows: List[np.ndarray] = []

    # Per-layer running statistics for "inputs" and "targets".
    stats: Dict[int, Dict[str, _RunningStat]] = {}

    g_row = 0  # tokens accumulated so far across all chunks
    c_idx = 0  # index of the next chunk to write
    buf_inp: Dict[int, List[torch.Tensor]] = {}
    buf_tgt: Dict[int, List[torch.Tensor]] = {}
    layer_ids: Optional[List[int]] = None
    d_model = -1

    # `with` guarantees the progress bar is closed even if a batch raises.
    with tqdm(total=tgt_tokens or None, unit="tok", smoothing=0.2) as pbar:
        for batch_inp, batch_tgt in activation_stream:
            if not batch_inp:
                continue

            # The first non-empty batch fixes the layer set and hidden size.
            if layer_ids is None:
                layer_ids = sorted(batch_inp.keys())
                first = batch_inp[layer_ids[0]]
                d_model = first.shape[-1]
                for lid in layer_ids:
                    buf_inp[lid] = []
                    buf_tgt[lid] = []
                    if cfg.compute_norm_stats:
                        stats[lid] = {
                            "inputs": _RunningStat(d_model, device=device),
                            "targets": _RunningStat(d_model, device=device),
                        }
                print(f"Layers={len(layer_ids)} d_model={d_model} dtype={first.dtype}")

            n_tok_in_batch = 0
            if batch_inp.get(layer_ids[0]) is not None:
                n_tok_in_batch = batch_inp[layer_ids[0]].shape[0]

            for lid in layer_ids:
                if lid in batch_inp and lid in batch_tgt:
                    inp = batch_inp[lid].detach()
                    tgt = batch_tgt[lid].detach()
                    buf_inp[lid].append(inp)
                    buf_tgt[lid].append(tgt)
                    if cfg.compute_norm_stats and lid in stats:
                        stats[lid]["inputs"].update(inp)
                        stats[lid]["targets"].update(tgt)
                else:
                    print(
                        f"Layer {lid} expected but not found in current batch. "
                        "Skipping accumulation for this layer."
                    )

            if n_tok_in_batch > 0:
                g_row += n_tok_in_batch
                pbar.update(n_tok_in_batch)

            # Flush the buffers as one chunk once they hold enough rows.
            if buf_inp[layer_ids[0]]:
                cur_rows = sum(t.shape[0] for t in buf_inp[layer_ids[0]])
                if cur_rows >= chunk_tokens:
                    generator._write_chunk(
                        c_idx,
                        buf_inp,
                        buf_tgt,
                        layer_ids,
                        d_model,
                        cur_rows,
                        manifest_rows,
                        g_row - cur_rows,  # global row offset of this chunk
                    )
                    c_idx += 1
                    for lid in layer_ids:
                        buf_inp[lid].clear()
                        buf_tgt[lid].clear()

            # Check the budget *after* counting this batch. Checking at the top of
            # the loop made the stream extract one extra batch (a full forward
            # pass) that was then discarded.
            if tgt_tokens and g_row >= tgt_tokens:
                break

    # Flush final partial chunk
    if layer_ids and buf_inp[layer_ids[0]]:
        rows = sum(t.shape[0] for t in buf_inp[layer_ids[0]])
        generator._write_chunk(
            c_idx,
            buf_inp,
            buf_tgt,
            layer_ids,
            d_model,
            rows,
            manifest_rows,
            g_row - rows,
        )
        c_idx += 1

    if manifest_rows:
        manifest_arr = np.concatenate(manifest_rows, axis=0)
        manifest_arr.tofile(generator.manifest_final)
    else:
        print("Manifest_rows is empty, skipping manifest write.")

    # Upload final manifest if remote
    if generator.storage_type == "remote" and generator.manifest_final.exists():
        try:
            generator._upload_binary_file(generator.manifest_final, "manifest")
        except Exception as e:
            print(f"Failed to upload manifest.bin: {e}")

    # Write metadata JSON
    meta = {
        "model_name": cfg.model_name,
        "dataset": cfg.dataset_path,
        "split": cfg.dataset_split,
        "num_layers": len(layer_ids or []),
        "d_model": d_model,
        "dtype": cfg.activation_dtype,
        "total_tokens": g_row,
        "chunk_tokens": chunk_tokens,
        "created": time.strftime("%Y-%m-%d %H:%M:%S"),
    }
    meta_path = generator.out_dir / "metadata.json"
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)
    print("metadata.json written")

    if generator.storage_type == "remote" and generator.cfg.remote_server_url:
        try:
            generator._upload_json(meta_path, "metadata")
            print("metadata.json uploaded to server")
        except Exception as e:
            print(f"Failed to upload metadata.json: {e}")

    # Write norm_stats.json
    if cfg.compute_norm_stats and stats:
        norm: Dict[str, Any] = {}
        if layer_ids:
            for lid in layer_ids:
                if lid in stats:
                    m_in, s_in = stats[lid]["inputs"].finalize()
                    m_tg, s_tg = stats[lid]["targets"].finalize()
                    norm[str(lid)] = {
                        "inputs": {"mean": m_in.tolist(), "std": s_in.tolist()},
                        "targets": {"mean": m_tg.tolist(), "std": s_tg.tolist()},
                    }
                else:
                    print(f"Layer ID {lid} not found in stats dict during norm_stats finalization.")
        else:
            print("layer_ids is None, cannot write norm_stats.")

        if norm:
            norm_path = generator.out_dir / "norm_stats.json"
            with open(norm_path, "w") as f:
                json.dump(norm, f)
            print("norm_stats.json written")

            if generator.storage_type == "remote" and generator.cfg.remote_server_url:
                try:
                    generator._upload_json(norm_path, "norm_stats")
                    print("norm_stats.json uploaded to server")
                except Exception as e:
                    print(f"Failed to upload norm_stats.json: {e}")
        else:
            print("Norm stats computation was enabled, but no norm stats were generated.")

    # Restored shutdown step (was commented out): in remote mode, signal the
    # uploader with a None sentinel and wait for queued uploads to drain before
    # returning. No-op if the object has no uploader/queue attributes.
    upload_q = getattr(generator, "upload_q", None)
    if generator.storage_type == "remote" and getattr(generator, "uploader", None) and upload_q:
        upload_q.put(None)
        upload_q.join()
    print(f"Finished: {c_idx} chunks, {g_row:,} tokens")