In [1]:
# testing consequences of interventions on gpt2, and how they match up against our SAE's
import sys
from dataclasses import dataclass, field
from functools import partial
from typing import Dict, List

import datasets
import seaborn as sns
import torch
import torch.nn.functional as F
from einops import einsum
from safetensors.torch import load_file
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from sae.data import chunk_and_tokenize

sys.path.append("..")

%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [2]:
ckpt_path = "/home/sid/tensor-sae/checkpoints/pythia14m-all-layers-rp1t/pythia70m-all-layers-rp1t-sample_20240912_003009/layers.0_layers.1_layers.2_layers.3_layers.4_layers.5/sae-915.safetensors"
model_name = "EleutherAI/pythia-70m"

In [3]:
# to use jacrevd need eager implementation
model = AutoModelForCausalLM.from_pretrained(
    model_name, attn_implementation="eager"
).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name)

sae_ckpt = load_file(ckpt_path, device="cuda:0")

feature_encoder_weights = sae_ckpt.get("encoder.weight", sae_ckpt.get("weight"))
feature_encoder_bias = sae_ckpt.get("encoder.bias", sae_ckpt.get("bias"))
# legacy keys
feature_decoder_weights = sae_ckpt["decoder.weight"]
feature_decoder_bias = sae_ckpt["decoder.bias"]

In [4]:
seed = 42
seq_len = 64
num_samples = 100

dataset = datasets.load_dataset(
    "togethercomputer/RedPajama-Data-1T-Sample",
    split="train",
    trust_remote_code=True,
)

dataset = dataset.train_test_split(
    test_size=0.8, seed=seed
).get("test").select(range(num_samples))

tokenized = chunk_and_tokenize(dataset, tokenizer, max_seq_len=seq_len)

In [5]:
feature_encoder_bias.shape

torch.Size([24576])

In [6]:
@dataclass
class InterventionOutputs:
    activation_positions: torch.Tensor
    causal_embeddings: torch.Tensor
    v_j: torch.Tensor
    v_k: torch.Tensor
    is_valid: torch.Tensor

In [7]:
@dataclass
class FeatureStats:
    causality: List[float] = field(default_factory=list)
    cosine: List[float] = field(default_factory=list)
    error: List[float] = field(default_factory=list)
    feature_activation_strength: List[float] = field(default_factory=list)


@dataclass
class GlobalFeatureStatistics:
    feature_activation_rate: torch.Tensor
    total_active_features: float
    avg_active_features_per_token: float
    feature_dict: Dict[int, FeatureStats]


def compute_feature_statistics(
    model,
    tokenized,
    feature_encoder_weights,
    feature_encoder_bias,
    sae_top_k: int = 128,
    batch_size: int = 256,
    exclude_first_k_tokens: int = 4,
):
    # (N,)
    total_active_features = torch.zeros(
        feature_encoder_weights.shape[0], device=model.device
    )

    dataloader = torch.utils.data.DataLoader(
        tokenized, batch_size=batch_size, shuffle=False
    )

    for batch in tqdm(dataloader, desc="Processing batches"):
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = torch.ones_like(input_ids, device=model.device)

        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            hiddens = outputs.hidden_states

        stacked_hiddens = torch.cat(hiddens[1:], dim=-1)[:, exclude_first_k_tokens:, :]

        encoded_features = torch.einsum(
            "be,nse->nsb", feature_encoder_weights, stacked_hiddens
        )
        encoded_features = encoded_features + feature_encoder_bias.unsqueeze(
            0
        ).unsqueeze(0)

        k_th_strongest = (
            torch.topk(encoded_features, k=sae_top_k, dim=-1)
            .values[:, :, -1]
            .unsqueeze(-1)
        )

        batch_binary_mask = (encoded_features >= k_th_strongest).float()

        total_active_features += batch_binary_mask.sum(dim=(0, 1))

    seq_len = tokenized[0]["input_ids"].shape[0] - exclude_first_k_tokens
    feature_activation_rate = total_active_features / (len(tokenized) * seq_len)

    total_active = total_active_features.sum().item()
    avg_active_per_token = total_active / (len(tokenized) * seq_len)

    feature_dict = {i: FeatureStats() for i in range(feature_encoder_weights.shape[-1])}

    return GlobalFeatureStatistics(
        feature_activation_rate=feature_activation_rate,
        total_active_features=total_active,
        avg_active_features_per_token=avg_active_per_token,
        feature_dict=feature_dict,
    )

In [19]:
def perform_intervention(
    model: torch.nn.Module,
    batch: torch.Tensor,
    feature_activation_rate: torch.Tensor,
    intervention_index: int,
    readout_index: int,
    feature_encoder_weights: torch.Tensor,
    feature_encoder_bias: torch.Tensor,
    feature_decoder_weights: torch.Tensor,
    lambda_value: float = 1.0,
    num_tokens: int = 1,
    feature_top_k: int = None,
    sae_top_k: int = 128,
) -> InterventionOutputs:
    """
    Perform an intervention on a model's activations using Sparse Autoencoder (SAE) features.

    Args:
        model: The PyTorch model to intervene on.
        batch: Input tensor to the model.
        feature_activation_rate: the global feature activation rate statistic
        intervention_index: Index of the layer to intervene on.
        readout_index: Index of the layer to read out from.
        feature_encoder_weights: Weights of the SAE encoder.
        feature_encoder_bias: Bias of the SAE encoder.
        feature_decoder_weights: Weights of the SAE decoder.
        lambda_value: Strength of the intervention (default: 1.0).
        num_tokens: Number of tokens to intervene on (default: 1).
        feature_top_k: Index of the specific feature to intervene on.
        sae_top_k: Number of top SAE features to consider.

    Returns:
        the results of the intervention as an InterventionOutputs object
    """

    activation_positions = None
    causal_embeddings = None
    # j < k in layer idx
    v_j = None
    v_k = None
    is_valid = None

    global_top_feature_indices = torch.topk(
        feature_activation_rate, k=sae_top_k, dim=-1
    ).indices

    def strengthen_specific_features(module, input, output, layer_offset=0):
        nonlocal activation_positions, causal_embeddings, v_j, v_k, is_valid

        embed_dim = output[0].shape[-1]
        feature_encoder_segment = feature_encoder_weights[
            :,
            (intervention_index - layer_offset) * embed_dim : (
                intervention_index - layer_offset + 1
            )
            * embed_dim,
        ]
        feature_decoder_segment = feature_decoder_weights[
            :,
            (intervention_index - layer_offset) * embed_dim : (
                intervention_index - layer_offset + 1
            )
            * embed_dim,
        ]

        # Encode input activations
        feature_activation = (
            einsum(output[0], feature_encoder_segment.T, "b s e, e n -> b s n")
            - feature_encoder_bias
        )

        batch_size, seq_len, _ = output[0].shape
        causal_embeddings = output[0]

        # Create a mask for active features
        feature_mask = torch.zeros(
            batch_size,
            seq_len,
            feature_activation.shape[-1],
            device=output[0].device,
            dtype=torch.bool,
        )
        feature_mask.scatter_(-1, global_top_feature_indices[None, None, :], 1)
        feature_activation_global_masked = feature_activation * feature_mask

        # Compute top-k index based on global activation
        _, top_k_feature_index = torch.kthvalue(
            feature_activation_global_masked,
            k=feature_top_k,
            dim=-1,
        )

        # Get the decoder vectors for the specified feature index
        v_j = feature_decoder_segment[top_k_feature_index]

        # Select the tokens with the highest activation of the pinned feature
        _, activation_positions = torch.topk(
            feature_activation_global_masked, k=num_tokens, dim=-1
        )

        # Create a mask for the selected token positions
        token_mask = torch.zeros(
            batch_size,
            seq_len,
            feature_activation_global_masked.shape[-1],
            device=output[0].device,
        )
        token_mask.scatter_(-1, activation_positions, 1)

        new_output = output[0].clone()

        # Add lambda * v_j only to the selected token positions
        new_output += lambda_value * v_j * token_mask.any(-1, keepdim=True)
        new_outputs = [new_output] + list(output[1:])

        # Assign v_k
        intervention_decoder_segment = feature_decoder_weights[
            :,
            (readout_index - layer_offset) * embed_dim : (
                readout_index - layer_offset + 1
            )
            * embed_dim,
        ]
        v_k = intervention_decoder_segment[top_k_feature_index]

        # Check if the feature fires in any of the tokens
        is_valid = (
            (feature_activation_global_masked[:, :, top_k_feature_index] > 0)
            .any(dim=1)
            .bool()
        )

        return tuple(new_outputs)

    if "gpt" in model_name:
        intervention_hook = model.transformer.h[
            intervention_index
        ].register_forward_hook(
            partial(strengthen_specific_features, layer_offset=intervention_index)
        )
    else:
        intervention_hook = model.gpt_neox.layers[
            intervention_index
        ].register_forward_hook(
            partial(strengthen_specific_features, layer_offset=intervention_index)
        )

    # do interventions by running hooks
    with intervention_hook, torch.no_grad():
        model(**batch)

    return InterventionOutputs(
        activation_positions, causal_embeddings, v_j, v_k, is_valid
    )

In [20]:
# Usage
stats = compute_feature_statistics(
    model,
    tokenized,
    feature_encoder_weights,
    feature_encoder_bias,
    sae_top_k=128
)

print(f"Total number of active features: {stats.total_active_features}")
print(f"Average number of active features per token: {stats.avg_active_features_per_token:.2f}")

Processing batches: 100%|██████████| 7/7 [00:01<00:00,  4.12it/s]


Total number of active features: 13301770.0
Average number of active features per token: 128.00


In [21]:
# Example usage with specific feature indices
intervention_index = 4
readout_index = 5
text = ["Hello, world!", "Hello, world!"]
test_batch = tokenizer(text, return_tensors="pt").to("cuda")

feature_top_k = 10

intervention = perform_intervention(
    model=model,
    batch=test_batch,
    feature_activation_rate=stats.feature_activation_rate,
    intervention_index=intervention_index,
    readout_index=readout_index,
    feature_encoder_weights=feature_encoder_weights,
    feature_encoder_bias=feature_encoder_bias,
    feature_decoder_weights=feature_decoder_weights,
    lambda_value=1.0,
    num_tokens=1,
    feature_top_k=feature_top_k
)

torch.Size([2, 4, 1])


In [None]:
intervention.activation_positions

tensor([[[21350],
         [    0],
         [    0],
         [    1]],

        [[    0],
         [    0],
         [    0],
         [    1]]], device='cuda:0')