In [2]:
%pip install jaxtyping transformer_lens plotly-utils einops torch
%pip install git+https://github.com/callummcdougall/eindex.git

Collecting plotly-utils
  Using cached plotly_utils-0.0.3-py3-none-any.whl.metadata (796 bytes)
Using cached plotly_utils-0.0.3-py3-none-any.whl (3.1 kB)
Installing collected packages: plotly-utils
Successfully installed plotly-utils-0.0.3

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.
Collecting git+https://github.com/callummcdougall/eindex.git
  Cloning https://github.com/callummcdougall/eindex.git to /private/var/folders/3c/gk783fjn6g7fhss6pg10m_q40000gn/T/pip-req-build-hji5o24j
  Running command git clone --filter=blob:none --quiet https://github.com/callummcdougall/eindex.git /private/var/folders/3c/gk783fjn6g7fhss6pg10m_q40000gn/T/pip-req-build-hji5o24j
  Resolved https://github.com/callummcdougall/eindex.git to co

In [4]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
import torch as t
from torch import nn, Tensor
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from dataclasses import dataclass
import numpy as np
import einops
from jaxtyping import Float, Int
from typing import Optional, Callable, Union, List, Tuple
from functools import partial
from tqdm.notebook import tqdm
from dataclasses import dataclass
from rich import print as rprint
from rich.table import Table
from IPython.display import display, HTML
from pathlib import Path

# from plotly_utils import imshow, line, hist

# Pick the compute backend in order of preference: CUDA if available, then MPS, else CPU.
device = (
    "cuda"
    if t.cuda.is_available()
    else "mps"
    if t.backends.mps.is_available()
    else "cpu"
)

print("Using device:", device)

# True when this code is executed as the top-level module (as opposed to being imported).
MAIN = __name__ == "__main__"

Using device: mps


In [7]:
def linear_lr(step, steps):
    """Learning-rate multiplier that falls linearly from 1 (step 0) to 0 (step == steps)."""
    fraction_done = step / steps
    return 1 - fraction_done

def constant_lr(*_):
    """Learning-rate multiplier that ignores its arguments and is always 1.0."""
    multiplier = 1.0
    return multiplier

def cosine_decay_lr(step, steps):
    """
    Learning-rate multiplier following a quarter cosine wave.

    Returns 1 at `step == 0` and decays to 0 at `step == steps - 1`.

    With `steps <= 1` there is no schedule to decay over (and the formula would
    divide by zero), so the multiplier is simply 1.0.
    """
    if steps <= 1:
        return 1.0
    return np.cos(0.5 * np.pi * step / (steps - 1))


@dataclass
class Config:
    """Hyperparameters describing the shape of the toy `Model`."""
    # Several models (`n_instances` of them) are trained side by side in one loop,
    # which makes sweeps over sparsity / importance cheap. Think of this as an extra
    # batch-like dimension that is baked into the parameters themselves.
    n_instances: int
    # Size of the input / output feature vector of each instance.
    n_features: int = 5
    # Size of the bottleneck that the features are squeezed through.
    n_hidden: int = 2
    # Number of correlated / anticorrelated feature pairs (stored here; not read
    # by any code visible in this notebook).
    n_correlated_pairs: int = 0
    n_anticorrelated_pairs: int = 0


class Model(nn.Module):
    """
    Toy model that trains `cfg.n_instances` independent copies of the map
    x -> ReLU(W.T @ W @ x + b_final) in parallel.

    Attributes:
        W: weights shared by the down-projection and the up-projection.
        b_final: bias added before the final ReLU.
        feature_probability: per-(instance, feature) probability that a feature is
            present (non-zero) in a generated batch.
        importance: per-(instance, feature) weight applied to the squared error.
    """
    W: Float[Tensor, "n_instances n_hidden n_features"]
    b_final: Float[Tensor, "n_instances n_features"]
    # Our linear map is x -> ReLU(W.T @ W @ x + b_final)

    def __init__(
        self,
        cfg: Config,
        feature_probability: Optional[Union[float, Tensor]] = None,
        importance: Optional[Union[float, Tensor]] = None,
        device = device,
    ):
        '''
        Args:
            cfg: model dimensions.
            feature_probability: number or tensor broadcastable to
                (n_instances, n_features); defaults to 1 (features always present).
            importance: number or tensor broadcastable to
                (n_instances, n_features); defaults to 1 (all features equally weighted).
            device: device that parameters and the two tensors above are moved to.
        '''
        super().__init__()
        self.cfg = cfg

        if feature_probability is None: feature_probability = t.ones(())
        # Accept any plain Python number (e.g. `1` as well as `1.0`), not only floats
        if isinstance(feature_probability, (int, float)): feature_probability = t.tensor(float(feature_probability))
        self.feature_probability = feature_probability.to(device).broadcast_to((cfg.n_instances, cfg.n_features))
        if importance is None: importance = t.ones(())
        if isinstance(importance, (int, float)): importance = t.tensor(float(importance))
        self.importance = importance.to(device).broadcast_to((cfg.n_instances, cfg.n_features))

        self.W = nn.Parameter(nn.init.xavier_normal_(t.empty((cfg.n_instances, cfg.n_hidden, cfg.n_features))))
        self.b_final = nn.Parameter(t.zeros((cfg.n_instances, cfg.n_features)))
        self.to(device)


    def forward(
        self,
        features: Float[Tensor, "... instances features"]
    ) -> Float[Tensor, "... instances features"]:
        '''
        Projects the features down to the hidden space with W, back up with the same W,
        then adds `b_final` and applies a ReLU.
        '''
        hidden = einops.einsum(
           features, self.W,
           "... instances features, instances hidden features -> ... instances hidden"
        )
        out = einops.einsum(
            hidden, self.W,
            "... instances hidden, instances hidden features -> ... instances features"
        )
        return F.relu(out + self.b_final)


    def generate_batch(self, batch_size) -> Float[Tensor, "batch_size instances features"]:
        '''
        Generates a batch of data. We'll return to this function later when we apply correlations.

        Each feature value is drawn uniformly from [0, 1) and then kept with probability
        `self.feature_probability` (otherwise it is set to zero).
        '''
        # Generate the features, before randomly setting some to zero
        feat = t.rand((batch_size, self.cfg.n_instances, self.cfg.n_features), device=self.W.device)

        # Generate a random boolean array, which is 1 wherever we'll keep a feature, and zero where we'll set it to zero
        feat_seeds = t.rand((batch_size, self.cfg.n_instances, self.cfg.n_features), device=self.W.device)
        feat_is_present = feat_seeds <= self.feature_probability

        # Create our batch from the features, where we set some to zero
        batch = t.where(feat_is_present, feat, 0.0)

        return batch


    def calculate_loss(
        self,
        out: Float[Tensor, "batch instances features"],
        batch: Float[Tensor, "batch instances features"],
    ) -> Float[Tensor, ""]:
        '''
        Calculates the loss for a given batch, using this loss described in the Toy Models paper:

            https://transformer-circuits.pub/2022/toy_model/index.html#demonstrating-setup-loss

        Remember, `self.importance` will always have shape (n_instances, n_features).

        The importance-weighted squared error is averaged over the batch and feature
        dimensions, then summed over instances to give a scalar.
        '''
        error = self.importance * ((batch - out) ** 2)
        loss = einops.reduce(error, 'batch instances features -> instances', 'mean').sum()
        return loss


    def optimize(
        self,
        batch_size: int = 1024,
        steps: int = 10_000,
        log_freq: int = 100,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
    ):
        '''
        Optimizes the model using the given hyperparameters.

        Args:
            batch_size: number of samples generated per step.
            steps: number of optimizer steps.
            log_freq: the progress bar postfix is refreshed every `log_freq` steps
                (and on the final step).
            lr: base learning rate for Adam.
            lr_scale: function (step, steps) -> multiplier applied to `lr` at each step.
        '''
        optimizer = t.optim.Adam(list(self.parameters()), lr=lr)

        progress_bar = tqdm(range(steps))

        for step in progress_bar:

            # Update learning rate
            step_lr = lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group['lr'] = step_lr

            # Optimize
            optimizer.zero_grad()
            batch = self.generate_batch(batch_size)
            out = self(batch)
            loss = self.calculate_loss(out, batch)
            loss.backward()
            optimizer.step()

            # Display progress bar (loss is reported as the mean over instances)
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(loss=loss.item()/self.cfg.n_instances, lr=step_lr)

In [9]:
@dataclass
class AutoEncoderConfig:
    """Hyperparameters for `AutoEncoder`."""
    # Number of autoencoders trained in parallel
    n_instances: int
    # Width of the activations being reconstructed
    n_input_ae: int
    # Number of hidden units (learned features) in the autoencoder
    n_hidden_ae: int
    # Weight of the L1 sparsity penalty relative to the L2 reconstruction loss
    l1_coeff: float = 0.5
    # When True, no separate decoder is created; the transposed encoder is used instead
    tied_weights: bool = False
    # Added to the norm when normalizing decoder weights, to avoid division by zero
    weight_normalize_eps: float = 1e-8


class AutoEncoder(nn.Module):
    """
    A stack of `cfg.n_instances` sparse autoencoders trained in parallel.

    Each instance computes ReLU((h - b_dec) @ W_enc + b_enc) and reconstructs the
    input as acts @ W_dec + b_dec, where the decoder rows are kept at unit norm.
    With `cfg.tied_weights = True` no `W_dec` parameter exists and the normalized,
    transposed encoder is used as the decoder.
    """
    W_enc: Float[Tensor, "n_instances n_input_ae n_hidden_ae"]
    W_dec: Float[Tensor, "n_instances n_hidden_ae n_input_ae"]
    b_enc: Float[Tensor, "n_instances n_hidden_ae"]
    b_dec: Float[Tensor, "n_instances n_input_ae"]


    def __init__(self, cfg: AutoEncoderConfig):
        '''
        Initializes the two weights and biases according to the type signature above.

        If self.cfg.tied_weights = True, then we only create W_enc, not W_dec.
        '''
        super().__init__()
        self.cfg = cfg

        self.W_enc = nn.Parameter(nn.init.xavier_normal_(t.empty((cfg.n_instances, cfg.n_input_ae, cfg.n_hidden_ae))))
        if not cfg.tied_weights:
            self.W_dec = nn.Parameter(nn.init.xavier_normal_(t.empty((cfg.n_instances, cfg.n_hidden_ae, cfg.n_input_ae))))

        self.b_enc = nn.Parameter(t.zeros(cfg.n_instances, cfg.n_hidden_ae))
        self.b_dec = nn.Parameter(t.zeros(cfg.n_instances, cfg.n_input_ae))

        self.to(device)


    def normalize_and_return_W_dec(self) -> Float[Tensor, "n_instances n_hidden_ae n_input_ae"]:
        '''
        If self.cfg.tied_weights = True, we return the normalized & transposed encoder weights.
        If self.cfg.tied_weights = False, we normalize the decoder weights in-place, and return them.

        Normalization should be over the `n_input_ae` dimension, i.e. each feature should have a normalized decoder weight.
        '''
        if self.cfg.tied_weights:
            # After the transpose the shape is [n_instances, n_hidden_ae, n_input_ae], so the
            # `n_input_ae` axis is the last one, exactly as in the untied branch below.
            W_dec_tied = self.W_enc.transpose(-1, -2)
            return W_dec_tied / (W_dec_tied.norm(dim=-1, keepdim=True) + self.cfg.weight_normalize_eps)
        else:
            self.W_dec.data = self.W_dec.data / (self.W_dec.data.norm(dim=2, keepdim=True) + self.cfg.weight_normalize_eps)
            return self.W_dec


    def forward(self, h: Float[Tensor, "batch_size n_instances n_input_ae"]):
        '''
        Runs a forward pass on the autoencoder, and returns several outputs.

        Inputs:
            h: Float[Tensor, "batch_size n_instances n_input_ae"]
                hidden activations generated from a Model instance

        Returns:
            l1_loss: Float[Tensor, "batch_size n_instances"]
                L1 loss for each batch elem & each instance (sum over the `n_hidden_ae` dimension)
            l2_loss: Float[Tensor, "batch_size n_instances"]
                L2 loss for each batch elem & each instance (take mean over the `n_input_ae` dimension)
            loss: Float[Tensor, ""]
                Sum of L1 and L2 loss (with the former scaled by `self.cfg.l1_coeff`). We sum over the `n_instances`
                dimension but take mean over the batch dimension
            acts: Float[Tensor, "batch_size n_instances n_hidden_ae"]
                Activations of the autoencoder's hidden states (post-ReLU)
            h_reconstructed: Float[Tensor, "batch_size n_instances n_input_ae"]
                Reconstructed hidden states, i.e. the autoencoder's final output
        '''
        # Compute activations
        h_cent = h - self.b_dec
        acts = einops.einsum(
            h_cent, self.W_enc,
            "batch_size n_instances n_input_ae, n_instances n_input_ae n_hidden_ae -> batch_size n_instances n_hidden_ae"
        )
        acts = F.relu(acts + self.b_enc)

        # Compute reconstructed input
        h_reconstructed = einops.einsum(
            acts, self.normalize_and_return_W_dec(),
            "batch_size n_instances n_hidden_ae, n_instances n_hidden_ae n_input_ae -> batch_size n_instances n_input_ae"
        ) + self.b_dec

        # Compute loss, return values
        l2_loss = (h_reconstructed - h).pow(2).mean(-1) # shape [batch_size n_instances]
        l1_loss = acts.abs().sum(-1) # shape [batch_size n_instances]
        loss = (self.cfg.l1_coeff * l1_loss + l2_loss).mean(0).sum() # scalar

        return l1_loss, l2_loss, loss, acts, h_reconstructed


    def optimize(
        self,
        model: Model,
        batch_size: int = 1024,
        steps: int = 10_000,
        log_freq: int = 100,
        lr: float = 1e-3,
        lr_scale: Callable[[int, int], float] = constant_lr,
        neuron_resample_window: Optional[int] = None,
        dead_neuron_window: Optional[int] = None,
        neuron_resample_scale: float = 0.2,
    ):
        '''
        Optimizes the autoencoder using the given hyperparameters.

        The autoencoder is trained on the hidden state activations produced by 'model', and it
        learns to reconstruct the features which this model represents in superposition.

        Args:
            model: source of training data; batches come from `model.generate_batch` and are
                projected through `model.W` to get the hidden activations.
            batch_size, steps, lr, lr_scale: as in `Model.optimize`.
            log_freq: progress bar and `data_log` are updated every `log_freq` steps
                (and on the final step).
            neuron_resample_window: if given, dead neurons are resampled every this many steps.
            dead_neuron_window: must be given, and be smaller than `neuron_resample_window`,
                whenever resampling is enabled.
            neuron_resample_scale: forwarded to `resample_neurons`.

        Returns:
            data_log: dict of lists with one entry per logged step, under the keys
                "W_enc", "W_dec", "colors", "titles" and "frac_active".
        '''
        if neuron_resample_window is not None:
            assert (dead_neuron_window is not None) and (dead_neuron_window < neuron_resample_window)

        optimizer = t.optim.Adam(list(self.parameters()), lr=lr)
        frac_active_list = []
        progress_bar = tqdm(range(steps))

        # Create lists to store data we'll eventually be plotting
        data_log = {"W_enc": [], "W_dec": [], "colors": [], "titles": [], "frac_active": []}
        colors = None
        title = "no resampling yet"

        for step in progress_bar:

            # Update learning rate based on `lr_scale` function
            step_lr = lr * lr_scale(step, steps)
            for group in optimizer.param_groups:
                group['lr'] = step_lr

            # Get a batch of hidden activations from the model (for the training step & the neuron resampling)
            with t.inference_mode():
                features = model.generate_batch(batch_size)
                h = einops.einsum(features, model.W, "batch instances feats, instances hidden feats -> batch instances hidden")

            # Resample dead neurons
            if (neuron_resample_window is not None) and ((step + 1) % neuron_resample_window == 0):
                # Get the fraction of neurons active in the previous window
                # NOTE(review): the slice length is `neuron_resample_window`; `dead_neuron_window`
                # is only validated above and never used -- confirm which window is intended.
                frac_active_in_window = t.stack(frac_active_list[-neuron_resample_window:], dim=0)
                # Apply resampling
                colors, title = self.resample_neurons(h, frac_active_in_window, neuron_resample_scale)

            # Optimize
            l1_loss, l2_loss, loss, acts, _ = self.forward(h)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Calculate the mean sparsities over batch dim for each (instance, feature)
            frac_active = (acts.abs() > 1e-8).float().mean(0)
            frac_active_list.append(frac_active)

            # Display progress bar, and append new values for plotting
            if step % log_freq == 0 or (step + 1 == steps):
                progress_bar.set_postfix(l1_loss=self.cfg.l1_coeff * l1_loss.mean(0).sum().item(), l2_loss=l2_loss.mean(0).sum().item(), lr=step_lr)
                data_log["W_enc"].append(self.W_enc.detach().cpu().clone())
                data_log["W_dec"].append(self.normalize_and_return_W_dec().detach().cpu().clone())
                data_log["colors"].append(colors)
                data_log["titles"].append(f"Step {step}/{steps}: {title}")
                data_log["frac_active"].append(frac_active.detach().cpu().clone())

        return data_log


    @t.no_grad()
    def resample_neurons(
        self,
        h: Float[Tensor, "batch_size n_instances n_input_ae"],
        frac_active_in_window: Float[Tensor, "window n_instances n_hidden_ae"],
        neuron_resample_scale: float,
    ) -> Tuple[List[List[str]], str]:
        '''
        Resamples neurons that were never active in the window described by `frac_active_in_window`.

        Dead neurons get a fresh random unit-norm direction written into their encoder column
        (and decoder row, when the decoder is a separate parameter), and their encoder bias is
        reset to zero.

        NOTE(review): `h` and `neuron_resample_scale` are accepted but not used by this
        implementation.

        Returns:
            colors: per instance, a list with "red" for resampled neurons and "black" otherwise.
            title: human-readable summary of how many neurons were resampled.
        '''
        # Get a tensor of dead neurons
        dead_features_mask = frac_active_in_window.sum(0) < 1e-8 # shape [instances hidden_ae]
        n_dead = dead_features_mask.int().sum().item()

        # Get our random replacement values
        replacement_values = t.randn((n_dead, self.cfg.n_input_ae), device=self.W_enc.device) # shape [n_dead n_input_ae]
        replacement_values_normalized = replacement_values / (replacement_values.norm(dim=-1, keepdim=True) + 1e-8)

        # Change the corresponding values in W_enc, W_dec, and b_enc (note we transpose W_enc to return a view with correct shape)
        self.W_enc.data.transpose(-1, -2)[dead_features_mask] = replacement_values_normalized
        # With tied weights there is no separate W_dec parameter; the encoder update above covers it
        if not self.cfg.tied_weights:
            self.W_dec.data[dead_features_mask] = replacement_values_normalized
        self.b_enc.data[dead_features_mask] = 0.0

        # Return data for visualising the resampling process
        colors = [["red" if dead else "black" for dead in dead_neuron_mask_inst] for dead_neuron_mask_inst in dead_features_mask]
        title = f"resampling {n_dead}/{dead_features_mask.numel()} neurons (shown in red)"
        return colors, title

In [10]:
from transformer_lens import HookedTransformer, FactoredMatrix
from transformer_lens.hook_points import HookPoint

from transformer_lens.utils import (
    load_dataset,
    tokenize_and_concatenate,
    download_file_from_hf,
)

In [11]:
# Maps a run name to the numeric id used in the hosted file names
# (f"{version_id}_cfg.json" and f"{version_id}.pt").
VERSION_DICT = {"run1": 25, "run2": 47}


def load_autoencoder_from_huggingface(versions: Optional[List[str]] = None):
    '''
    Downloads one pretrained sparse autoencoder per entry of `versions` and packs them
    into a single `AutoEncoder` with one instance per version.

    Args:
        versions: keys of `VERSION_DICT` to load, in instance order.
            Defaults to ["run1", "run2"].

    Returns:
        An `AutoEncoder` with `n_instances == len(versions)` and the downloaded weights loaded.

    Raises:
        ValueError: if `versions` is empty.
        KeyError: if a version name is not in `VERSION_DICT`.
    '''
    if versions is None:
        versions = ["run1", "run2"]
    if len(versions) == 0:
        raise ValueError("`versions` must contain at least one run name")

    # Collect every version's tensor for each parameter name, then stack once at the end.
    # (Stacking pairwise as we go only gives the right shape for exactly two versions.)
    tensors_by_name = {}
    sae_data = None

    for version in versions:
        version_id = VERSION_DICT[version]
        # Load the data from huggingface (both metadata and state dict)
        sae_data = download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version_id}_cfg.json")
        new_state_dict = download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version_id}.pt", force_is_torch=True)
        for k, v in new_state_dict.items():
            tensors_by_name.setdefault(k, []).append(v)

    # New leading dimension = instance index
    state_dict = {k: t.stack(vs) for k, vs in tensors_by_name.items()}

    # Get data about the model dimensions (taken from the last config loaded), and use
    # that to initialize our model with one instance per version
    d_mlp = sae_data["d_mlp"]
    dict_mult = sae_data["dict_mult"]
    n_hidden_ae = d_mlp * dict_mult

    cfg = AutoEncoderConfig(
        n_instances = len(versions),
        n_input_ae = d_mlp,
        n_hidden_ae = n_hidden_ae,
    )

    # Initialize our model, and load in state dict
    autoencoder = AutoEncoder(cfg)
    autoencoder.load_state_dict(state_dict)

    return autoencoder


# Load the default pair of pretrained autoencoders ("run1" and "run2") as one 2-instance AutoEncoder.
autoencoder = load_autoencoder_from_huggingface()

25_cfg.json:   0%|          | 0.00/283 [00:00<?, ?B/s]

25.pt:   0%|          | 0.00/269M [00:00<?, ?B/s]

47_cfg.json:   0%|          | 0.00/309 [00:00<?, ?B/s]

47.pt:   0%|          | 0.00/269M [00:00<?, ?B/s]

In [12]:
# Load the pretrained "gelu-1l" HookedTransformer and move it to the selected device.
model = HookedTransformer.from_pretrained("gelu-1l").to(device)

# Print the module tree, which lists the hook points that can be cached.
print(model)

Loaded pretrained model gelu-1l into HookedTransformer
Moving model to device:  mps
HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
 

In [None]:
# Load the train split of the dataset.
data = load_dataset("NeelNanda/c4-code-20k", split="train")
# Tokenize with the model's tokenizer; max_length=128 presumably yields rows of 128 tokens -- TODO confirm.
tokenized_data = tokenize_and_concatenate(data, model.tokenizer, max_length=128)
# Shuffle; 42 is presumably the seed, making the order reproducible -- TODO confirm.
tokenized_data = tokenized_data.shuffle(42)
all_tokens = tokenized_data["tokens"]
print("Tokens shape: ", all_tokens.shape)

In [None]:
@t.inference_mode()
def highest_activating_tokens(
    tokens: Int[Tensor, "batch seq"],
    model: HookedTransformer,
    autoencoder: AutoEncoder,
    feature_idx: int,
    autoencoder_B: bool = False,
    k: int = 10,
) -> Tuple[Int[Tensor, "k 2"], Float[Tensor, "k"]]:
    '''
    Finds the `k` token positions in `tokens` on which one autoencoder feature scores highest.

    The score is the projection of the centred MLP post-activations onto the feature's
    encoder column; the encoder bias and ReLU are not applied.

    Args:
        tokens: batch of token ids to run through `model`.
        model: transformer whose "blocks.0.mlp.hook_post" activations are read.
        autoencoder: autoencoder providing `b_dec` and `W_enc`.
        feature_idx: index of the autoencoder feature to score.
        autoencoder_B: use instance 1 of the autoencoder if truthy, else instance 0.
        k: number of positions to return.

    Returns:
        A (k, 2) tensor of (batch, seq) positions and a (k,) tensor of their scores.
    '''
    _, n_positions = tokens.shape
    which_instance = 1 if autoencoder_B else 0
    hook_name = "blocks.0.mlp.hook_post"

    # Run the model once, caching only the MLP post-activations
    activation_cache = model.run_with_cache(tokens, names_filter=[hook_name])[1]
    mlp_post_flat = einops.rearrange(
        activation_cache[hook_name],
        "batch seq d_mlp -> (batch seq) d_mlp",
    )

    # Score every (batch, seq) position against the single requested feature
    encoder_column = autoencoder.W_enc[which_instance, :, feature_idx]
    centred = mlp_post_flat - autoencoder.b_dec[which_instance]
    scores = einops.einsum(
        centred, encoder_column,
        "batch_size n_input_ae, n_input_ae -> batch_size"
    )

    best = scores.topk(k)
    flat_positions = best.indices

    # Unflatten: position p in the (batch seq) axis is row p // seq, column p % seq
    positions = t.stack([flat_positions // n_positions, flat_positions % n_positions], dim=-1)
    return positions, best.values


def display_top_sequences(top_acts_indices, top_acts_values, tokens):
    '''
    Prints a rich table with one row per top activation, showing the activating token
    (highlighted) with a few tokens of surrounding context, and the activation value.

    Args:
        top_acts_indices: (k, 2) tensor of (batch, seq) positions into `tokens`.
        top_acts_values: (k,) tensor of activation values, aligned with the positions.
        tokens: the token ids that the positions refer to.

    Uses the notebook-level `model` to convert token ids to strings.
    '''
    # Local import so this cell does not depend on an edit to the notebook's import cell
    from rich.markup import escape

    table = Table("Sequence", "Activation", title="Tokens which most activate this feature")
    # Bound the context window by the tensor we were actually given, not a global
    n_positions = tokens.shape[1]
    for (batch_idx, seq_idx), value in zip(top_acts_indices, top_acts_values):
        batch_idx, seq_idx = int(batch_idx), int(seq_idx)
        # Get the sequence as a string (with some padding on either side of our sequence)
        seq = ""
        for i in range(max(seq_idx - 5, 0), min(seq_idx + 5, n_positions)):
            new_str_token = model.to_single_str_token(tokens[batch_idx, i].item()).replace("\n", "\\n")
            # Token text is data, not markup: escape it so square brackets are shown literally
            new_str_token = escape(new_str_token)
            # Highlight the token with the high activation
            if i == seq_idx: new_str_token = f"[b u dark_orange]{new_str_token}[/]"
            seq += new_str_token
        # Print the sequence, and the activation value
        table.add_row(seq, f'{float(value):.2f}')
    rprint(table)

# Work on the first 200 rows of the tokenized data.
tokens = all_tokens[:200]

# Kept for reference: the same inspection looped over features 5..9.
# for feature_idx in range(5,10):
#   top_acts_indices, top_acts_values = highest_activating_tokens(tokens, model, autoencoder, feature_idx=feature_idx, autoencoder_B=False)
#   display_top_sequences(top_acts_indices, top_acts_values, tokens)

# Inspect feature 7 of the first autoencoder instance.
top_acts_indices, top_acts_values = highest_activating_tokens(tokens, model, autoencoder, feature_idx=7, autoencoder_B=False)
display_top_sequences(top_acts_indices, top_acts_values, tokens)

In [None]:
# Decoder row for feature 7 of autoencoder instance 0.
W_dec_vector = autoencoder.W_dec[0, 7]

# Map that direction through model.W_out[0] and then model.W_U; the result is indexed by
# token id below, so it is presumably one score per vocabulary token -- TODO confirm shapes.
W_dec_logits = W_dec_vector @ model.W_out[0] @ model.W_U

# The 10 largest and the 10 smallest scores.
top_tokens = W_dec_logits.topk(10)
bottom_tokens = W_dec_logits.topk(10, largest=False)

# Build a plain-text report with one "(value) token" line per entry.
s = "Top tokens:\n"
for token, value in zip(top_tokens.indices, top_tokens.values):
    s += f"({value:.2f}) {model.to_single_str_token(token.item())}\n"
s += "\nBottom tokens:\n"
for token, value in zip(bottom_tokens.indices, bottom_tokens.values):
    s += f"({value:.2f}) {model.to_single_str_token(token.item())}\n"
rprint(s)

# Experimenting...

For each feature we want to:

1. Get the indices and values for the topK activations
    
    * Get post activations from a clean run
    
    * Compute activations (not from a fwd pass, but explicitly, by taking only the feature we want)

2. For each activation tuple (index and value), decode the token into a human-readable string, together with a few surrounding tokens of context and the activation value (indicating how strongly the feature fired)

We need to code up a solution that lets us do this for many features, iteratively.

Importantly, we want to store the results in a data structure that can be (1) easily parsed, and (2) easily serialized to JSON so that it can be fed to an LLM for deeper analysis.

In [None]:
import torch
import einops
from typing import List, Dict, NamedTuple
from rich.console import Console
from rich.table import Table

# Shared rich console used for the progress messages and result tables below.
console = Console()

class SeqActivation(NamedTuple):
    """One activating token position for a feature."""
    # Text of the tokens in a small window around the activating token
    context: str
    # Text of the activating token itself
    seq: str
    # Activation value of the feature at that token
    activation: float

class FeatureActivations(NamedTuple):
    """Strongest and weakest activating positions found for one autoencoder feature."""
    # Index of the autoencoder feature these results describe
    feature_idx: int
    # Highest-scoring positions (k entries, where k is chosen by the caller; 10 by default)
    top_10_activations: List[SeqActivation]
    # Lowest-scoring positions (same length as the list above)
    bottom_10_activations: List[SeqActivation]

@torch.inference_mode()
def _get_mlp_post_activations(tokens: torch.Tensor, model: HookedTransformer) -> torch.Tensor:
    """Runs `model` once on `tokens` and returns the "blocks.0.mlp.hook_post" activations, flattened to (batch*seq, d_mlp)."""
    cache = model.run_with_cache(tokens, names_filter=["blocks.0.mlp.hook_post"])[1]
    post = cache["blocks.0.mlp.hook_post"]
    return einops.rearrange(post, "batch seq d_mlp -> (batch seq) d_mlp")


@torch.inference_mode()
def _feature_activations_from_post(
    tokens: torch.Tensor,
    model: HookedTransformer,
    autoencoder: AutoEncoder,
    feature_idx: int,
    post_reshaped: torch.Tensor,
    k: int,
) -> FeatureActivations:
    """Scores one feature of autoencoder instance 0 against precomputed activations and collects its top/bottom k positions."""
    seq_len = tokens.shape[1]

    # Compute activations for this feature (projection only: no encoder bias, no ReLU)
    h_cent = post_reshaped - autoencoder.b_dec[0]
    acts = einops.einsum(
        h_cent, autoencoder.W_enc[0, :, feature_idx],
        "batch_size n_input_ae, n_input_ae -> batch_size"
    )

    # topk raises if k exceeds the number of positions, so clamp it
    k = min(k, acts.numel())
    top_acts_values, top_acts_indices = acts.topk(k)
    bottom_acts_values, bottom_acts_indices = acts.topk(k, largest=False)

    def create_seq_activations(values, indices):
        seq_activations = []
        # Convert to Python numbers once, instead of doing tensor arithmetic per element
        for flat_idx, value in zip(indices.tolist(), values.tolist()):
            batch_idx, seq_idx = divmod(flat_idx, seq_len)

            # Up to 5 tokens before and 4 after the activating token, clipped to the sequence
            window = range(max(seq_idx - 5, 0), min(seq_idx + 5, seq_len))
            pieces = [
                model.to_single_str_token(tokens[batch_idx, i].item()).replace("\n", "\\n")
                for i in window
            ]
            seq = pieces[seq_idx - window.start]

            seq_activations.append(SeqActivation(context="".join(pieces), seq=seq, activation=value))

        return seq_activations

    return FeatureActivations(
        feature_idx=feature_idx,
        top_10_activations=create_seq_activations(top_acts_values, top_acts_indices),
        bottom_10_activations=create_seq_activations(bottom_acts_values, bottom_acts_indices),
    )


@torch.inference_mode()
def get_feature_activations(
    tokens: torch.Tensor,
    model: HookedTransformer,
    autoencoder: AutoEncoder,
    feature_idx: int,
    k: int = 10
) -> FeatureActivations:
    """
    Returns the `k` highest- and `k` lowest-scoring token positions for one feature of
    autoencoder instance 0, each with its token text, surrounding context and score.

    Args:
        tokens: 2-D tensor of token ids to run through `model`.
        model: transformer whose "blocks.0.mlp.hook_post" activations are read.
        autoencoder: autoencoder providing `b_dec` and `W_enc`.
        feature_idx: index of the autoencoder feature to score.
        k: number of positions per list (clamped to the number of positions available).
    """
    post_reshaped = _get_mlp_post_activations(tokens, model)
    return _feature_activations_from_post(tokens, model, autoencoder, feature_idx, post_reshaped, k)


def collect_feature_activations(
    tokens: torch.Tensor,
    model: HookedTransformer,
    autoencoder: AutoEncoder,
    num_features: int
) -> List[FeatureActivations]:
    """
    Collects top/bottom activations for features 0 .. num_features - 1, printing progress
    after each one.

    The model forward pass does not depend on the feature, so it is run once here and
    reused for every feature.
    """
    post_reshaped = _get_mlp_post_activations(tokens, model)

    feature_activations = []
    for feature_idx in range(num_features):
        activations = _feature_activations_from_post(tokens, model, autoencoder, feature_idx, post_reshaped, 10)
        feature_activations.append(activations)
        console.print(f"Processed feature {feature_idx + 1}/{num_features}")
    return feature_activations

# Usage
tokens = all_tokens[:200]  # Adjust as needed: number of token rows to scan
num_features_to_analyze = 10  # Adjust as needed: features 0..9 are analysed

# One FeatureActivations record per analysed feature, in feature order.
feature_activations = collect_feature_activations(tokens, model, autoencoder, num_features_to_analyze)

# Display results: for every feature, one table of its top activations followed by
# one table of its bottom activations. The two tables differ only in their title,
# their rows and the colour of the activation column.
for feature in feature_activations:
    console.print(f"\nFeature {feature.feature_idx}:")

    table_specs = (
        ("Top", feature.top_10_activations, "green"),
        ("Bottom", feature.bottom_10_activations, "red"),
    )
    for which_end, rows, activation_style in table_specs:
        table = Table(title=f"{which_end} 10 Activations for Feature {feature.feature_idx}")
        table.add_column("Context", style="cyan")
        table.add_column("Sequence", style="magenta")
        table.add_column("Activation", justify="right", style=activation_style)

        for activation in rows:
            table.add_row(activation.context, activation.seq, f"{activation.activation:.4f}")

        console.print(table)