In [None]:
# IPython autoreload: re-import edited modules (mode 2 = all modules) before each cell runs,
# so edits to sae_bench helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Standard-library and third-party imports
from typing import Optional

import einops
import torch
import torch.nn as nn
from beartype import beartype
from jaxtyping import Float, Int, jaxtyped
from tqdm import tqdm

# NOTE(review): PORT is not referenced anywhere else in this notebook; likely a leftover.
PORT = 8000

import sae_bench.sae_bench_utils.dataset_info as dataset_info
import sae_bench.sae_bench_utils.dataset_utils as dataset_utils

# This notebook is mostly inference (activation collection), so autograd is disabled
# globally here; probe training below turns gradients back on explicitly.
torch.set_grad_enabled(False)

In [None]:
# Pick an accelerator: Apple MPS first, then CUDA, falling back to CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

In [None]:
from sae_lens import SAE
from transformer_lens import HookedTransformer

# model_dtype is reused below for the meaned activations and for the probes.
model_dtype = torch.bfloat16

model = HookedTransformer.from_pretrained(
    "pythia-70m-deduped",
    dtype=model_dtype,
    device=device,
)

In [None]:
# SAE.from_pretrained returns three things:
#   - sae: the loaded SAE;
#   - cfg_dict: whatever config was stored in the HF repo. This is not the same object as
#     sae.cfg, but the SAE config can be derived from it (useful e.g. for instantiating
#     an activation store);
#   - sparsity: the feature sparsities stored in the repo, returned for convenience.
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release="sae_bench_pythia70m_sweep_topk_ctx128_0730",
    sae_id="blocks.4.hook_resid_post__trainer_10",
    device=device,
)

print(sae.cfg)

# Tokenize to the SAE's context length and read activations at the SAE's layer.
context_length = sae.cfg.context_size
layer = sae.cfg.hook_layer
batch_size = 128  # examples per LLM forward pass when collecting activations

In [None]:
train_set_size = 4000
test_set_size = 1000

random_seed = 42
# random_seed is only passed to get_multi_label_train_test_data below. Torch's global
# RNG separately drives probe-data sampling/shuffling (torch.randperm) and probe weight
# initialisation, so seed it too to make Restart & Run All reproducible.
torch.manual_seed(random_seed)

dataset_name = "bias_in_bios"
train_df, test_df = dataset_utils.load_huggingface_dataset(dataset_name)
train_data, test_data = dataset_utils.get_multi_label_train_test_data(
    train_df, test_df, dataset_name, train_set_size, test_set_size, random_seed
)

In [None]:
# Restrict both splits to the chosen class labels (presumably filter_dataset keeps only
# these keys — confirm).
# NOTE(review): `utils` is not imported anywhere in this notebook, so this cell raises
# NameError on a fresh kernel. TODO: import the module that provides filter_dataset
# (possibly sae_bench.sae_bench_utils.dataset_utils — confirm).
chosen_classes = ["0", "1", "2"]

train_data = utils.filter_dataset(train_data, chosen_classes)
test_data = utils.filter_dataset(test_data, chosen_classes)

In [None]:
# Tokenize each split to the SAE's context length on `device`.
# NOTE(review): `utils` is not imported anywhere in this notebook (NameError on a fresh
# kernel). TODO: import the module that provides tokenize_data — confirm which one.
train_data = utils.tokenize_data(train_data, model.tokenizer, context_length, device)
test_data = utils.tokenize_data(test_data, model.tokenizer, context_length, device)

In [None]:
# Look at one class entry to see which tokenized fields are available.
first_key = next(iter(train_data))
print(first_key)
print(train_data[first_key].keys())

In [None]:
# Number of tokenized examples for the first class.
print(len(train_data[first_key]["input_ids"]))

In [None]:
# Capture activations at the exact hook point the SAE was trained on instead of
# rebuilding the name by hand (which would silently assume hook_resid_post).
hook_name = sae.cfg.hook_name
assert hook_name.startswith(f"blocks.{layer}."), (
    f"SAE hook {hook_name!r} does not belong to layer {layer}"
)


@jaxtyped(typechecker=beartype)
@torch.no_grad
def get_all_llm_activations(
    tokenized_inputs_dict: dict[
        str, dict[str, Int[torch.Tensor, "dataset_size seq_len"]]
    ],
    model: HookedTransformer,
    batch_size: int,
    hook_name: str,
) -> dict[str, Float[torch.Tensor, "batch_size seq_len d_model"]]:
    """Run `model` over every tokenized example and cache activations at `hook_name`.

    VERY IMPORTANT NOTE: We zero out masked token activations in this function.
    Later, we ignore zeroed activations.

    Args:
        tokenized_inputs_dict: class name -> dict holding at least "input_ids" and
            "attention_mask", each an integer tensor of shape (dataset_size, seq_len).
        model: TransformerLens model to run.
        batch_size: examples per forward pass. The last batch of a class may be
            smaller, so no example is dropped.
        hook_name: name of the hook point whose output is captured.

    Returns:
        class name -> activations of shape (dataset_size, seq_len, d_model), with
        every position whose attention_mask is 0 set to zero.

    Raises:
        ValueError: if batch_size is not positive or a class has no examples.
        RuntimeError: if the hook did not fire during a forward pass.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    all_classes_acts_BLD = {}

    for class_name, tokenized_inputs in tokenized_inputs_dict.items():
        input_ids = tokenized_inputs["input_ids"]
        attention_mask = tokenized_inputs["attention_mask"]
        all_acts_BLD = []

        # Stride by batch_size so the trailing partial batch is kept; iterating
        # range(len // batch_size) would drop up to batch_size - 1 examples per class.
        for start in tqdm(range(0, len(input_ids), batch_size)):
            tokens_BL = input_ids[start : start + batch_size]
            attention_mask_BL = attention_mask[start : start + batch_size]

            acts_BLD = None

            def activation_hook(resid_BLD: torch.Tensor, hook):
                nonlocal acts_BLD
                acts_BLD = resid_BLD

            # NOTE(review): attention_mask is not passed to the model; it is only used
            # below to zero padded positions. That matches a masked forward pass only if
            # padding is on the right of each sequence — confirm in tokenize_data.
            model.run_with_hooks(
                tokens_BL, return_type=None, fwd_hooks=[(hook_name, activation_hook)]
            )

            if acts_BLD is None:
                raise RuntimeError(f"Hook {hook_name!r} did not fire")

            # Zero padded positions so later code can recognise them as all-zero rows.
            acts_BLD = acts_BLD * attention_mask_BL[:, :, None]
            all_acts_BLD.append(acts_BLD)

        if not all_acts_BLD:
            raise ValueError(f"Class {class_name!r} has no examples")

        all_classes_acts_BLD[class_name] = torch.cat(all_acts_BLD, dim=0)

    return all_classes_acts_BLD


# Cache residual-stream activations (padding zeroed) for both splits.
all_train_acts_BLD = get_all_llm_activations(
    tokenized_inputs_dict=train_data,
    model=model,
    batch_size=batch_size,
    hook_name=hook_name,
)
all_test_acts_BLD = get_all_llm_activations(
    tokenized_inputs_dict=test_data,
    model=model,
    batch_size=batch_size,
    hook_name=hook_name,
)

In [None]:
@jaxtyped(typechecker=beartype)
def create_meaned_model_activations(
    all_llm_activations_BLD: dict[
        str, Float[torch.Tensor, "batch_size seq_len d_model"]
    ],
    dtype: torch.dtype,
) -> dict[str, Float[torch.Tensor, "batch_size d_model"]]:
    """Mean-pool each sequence's activations over its non-masked positions.

    VERY IMPORTANT NOTE: We assume that the activations have been zeroed out for
    masked tokens. A position counts as real iff any of its d_model components is
    nonzero.

    Args:
        all_llm_activations_BLD: class name -> (batch, seq_len, d_model) activations
            with masked positions set to zero.
        dtype: dtype used for the per-sequence token counts.

    Returns:
        class name -> (batch, d_model) mean over real positions. A sequence with no
        real positions yields a zero vector instead of NaN.
    """
    all_llm_activations_BD = {}
    for class_name, acts_BLD in all_llm_activations_BLD.items():
        # Test each component rather than the sum over d_model: a real token whose
        # components happen to cancel to exactly 0 must not be treated as padding.
        nonzero_acts_BL = (acts_BLD != 0.0).any(dim=-1).to(dtype=dtype)
        nonzero_acts_B = einops.reduce(nonzero_acts_BL, "B L -> B", "sum")
        # Avoid 0 / 0 = NaN for fully masked sequences; their summed activations are 0.
        nonzero_acts_B = nonzero_acts_B.clamp(min=1.0)

        meaned_acts_BD = (
            einops.reduce(acts_BLD, "B L D -> B D", "sum") / nonzero_acts_B[:, None]
        )
        all_llm_activations_BD[class_name] = meaned_acts_BD

    return all_llm_activations_BD

In [None]:
@jaxtyped(typechecker=beartype)
@torch.no_grad
def get_sae_meaned_activations(
    all_llm_activations_BLD: dict[
        str, Float[torch.Tensor, "batch_size seq_len d_model"]
    ],
    sae: SAE,
    sae_batch_size: int,
    dtype: torch.dtype,
) -> dict[str, Float[torch.Tensor, "batch_size d_sae"]]:
    """Encode activations with `sae` and mean-pool features over non-masked positions.

    VERY IMPORTANT NOTE: We assume that the activations have been zeroed out for
    masked tokens. A position counts as real iff any of its d_model components is
    nonzero.

    Args:
        all_llm_activations_BLD: class name -> (batch, seq_len, d_model) activations
            with masked positions set to zero.
        sae: SAE whose `encode` maps activations to features.
        sae_batch_size: sequences encoded per step. The last step may be smaller, so
            no sequence is dropped.
        dtype: dtype of the token counts and of the returned features.

    Returns:
        class name -> (batch, d_sae) mean SAE features over real positions; a
        sequence with no real positions yields zeros instead of NaN.

    Raises:
        ValueError: if sae_batch_size is not positive or a class has no sequences.
    """
    if sae_batch_size <= 0:
        raise ValueError(f"sae_batch_size must be positive, got {sae_batch_size}")

    all_sae_activations_BF = {}
    for class_name in tqdm(all_llm_activations_BLD):
        all_acts_BLD = all_llm_activations_BLD[class_name]
        all_acts_BF = []

        # Stride by sae_batch_size so the trailing partial batch is not dropped.
        for start in range(0, len(all_acts_BLD), sae_batch_size):
            acts_BLD = all_acts_BLD[start : start + sae_batch_size]
            acts_BLF = sae.encode(acts_BLD)

            # Per-component test, so real tokens whose components cancel are kept.
            nonzero_acts_BL = (acts_BLD != 0.0).any(dim=-1).to(dtype=dtype)
            # Clamp to avoid 0 / 0 = NaN for fully masked sequences.
            nonzero_acts_B = einops.reduce(nonzero_acts_BL, "B L -> B", "sum").clamp(
                min=1.0
            )

            # Encoding a zeroed (masked) position is not guaranteed to give zero
            # features, so re-apply the mask after encoding.
            acts_BLF = acts_BLF * nonzero_acts_BL[:, :, None]
            acts_BF = (
                einops.reduce(acts_BLF, "B L F -> B F", "sum") / nonzero_acts_B[:, None]
            )
            all_acts_BF.append(acts_BF.to(dtype=dtype))

        if not all_acts_BF:
            raise ValueError(f"Class {class_name!r} has no sequences")

        all_sae_activations_BF[class_name] = torch.cat(all_acts_BF, dim=0)

    return all_sae_activations_BF

In [None]:
@jaxtyped(typechecker=beartype)
def prepare_probe_data(
    all_activations: dict[str, Float[torch.Tensor, "num_datapoints d_model"]],
    class_idx: str,
    batch_size: int,
    select_top_k: int | None = None,  # experimental feature
    generator: torch.Generator | None = None,
) -> tuple[
    list[Float[torch.Tensor, "batch_size d_model"]],
    list[Int[torch.Tensor, "batch_size"]],
]:
    """Build balanced, shuffled, batched one-vs-rest probe data for `class_idx`.

    Positives are all examples of `class_idx`. Negatives are an equally sized random
    sample from the union of every other class. Labels are
    dataset_info.POSITIVE_CLASS_LABEL / dataset_info.NEGATIVE_CLASS_LABEL. Only full
    batches are returned; a trailing remainder of fewer than `batch_size` examples is
    dropped.

    Args:
        all_activations: class name -> (num_datapoints, d_model) activations.
        class_idx: key of the positive class.
        batch_size: examples per returned batch.
        select_top_k: experimental; if set, keep only the `select_top_k` dimensions
            whose mean differs most (in absolute value) between the positive class and
            all negatives, zeroing every other dimension.
        generator: optional CPU torch.Generator used for negative sampling and
            shuffling; defaults to torch's global RNG.

    Returns:
        (list of activation batches, list of matching label batches).

    Raises:
        ValueError: if there is no other class, or the other classes together have
            fewer examples than `class_idx`.
    """
    positive_acts_BD = all_activations[class_idx]
    device = positive_acts_BD.device
    num_positive = len(positive_acts_BD)

    # Pool every other class into a single negative set.
    negative_pool = [
        acts for idx, acts in all_activations.items() if idx != class_idx
    ]
    if not negative_pool:
        raise ValueError(
            f"Need at least one class besides {class_idx!r} to sample negatives from"
        )
    negative_acts = torch.cat(negative_pool)
    if len(negative_acts) < num_positive:
        raise ValueError(
            f"Only {len(negative_acts)} negative examples available for "
            f"{num_positive} positives of class {class_idx!r}"
        )

    # Balance the classes by drawing exactly num_positive random negatives.
    indices = torch.randperm(len(negative_acts), generator=generator)[:num_positive]
    selected_negative_acts_BD = negative_acts[indices]

    assert selected_negative_acts_BD.shape == positive_acts_BD.shape

    # Experimental: keep only the top-k dimensions that differ most between the
    # positive class and the negatives, zeroing the rest (k-sparse probing).
    if select_top_k is not None:
        positive_distribution_D = positive_acts_BD.mean(dim=0)
        # Uses the mean over *all* negatives, not only the sampled ones.
        negative_distribution_D = negative_acts.mean(dim=0)
        distribution_diff_D = (positive_distribution_D - negative_distribution_D).abs()
        top_k_indices_D = torch.argsort(distribution_diff_D, descending=True)[
            :select_top_k
        ]

        # True marks a dimension to zero out (everything outside the top k).
        drop_mask_D = torch.ones(
            distribution_diff_D.shape[0],
            dtype=torch.bool,
            device=device,
        )
        drop_mask_D[top_k_indices_D] = False

        masked_positive_acts_BD = positive_acts_BD.clone()
        masked_negative_acts_BD = selected_negative_acts_BD.clone()

        masked_positive_acts_BD[:, drop_mask_D] = 0.0
        masked_negative_acts_BD[:, drop_mask_D] = 0.0
    else:
        masked_positive_acts_BD = positive_acts_BD
        masked_negative_acts_BD = selected_negative_acts_BD

    # Positives first, then negatives; shuffled together below.
    combined_acts = torch.cat([masked_positive_acts_BD, masked_negative_acts_BD])

    combined_labels = torch.empty(len(combined_acts), dtype=torch.int, device=device)
    combined_labels[:num_positive] = dataset_info.POSITIVE_CLASS_LABEL
    combined_labels[num_positive:] = dataset_info.NEGATIVE_CLASS_LABEL

    shuffle_indices = torch.randperm(len(combined_acts), generator=generator)
    shuffled_acts = combined_acts[shuffle_indices]
    shuffled_labels = combined_labels[shuffle_indices]

    # Split into full batches only.
    num_batches = len(shuffled_acts) // batch_size

    batched_acts = [
        shuffled_acts[i * batch_size : (i + 1) * batch_size] for i in range(num_batches)
    ]
    batched_labels = [
        shuffled_labels[i * batch_size : (i + 1) * batch_size]
        for i in range(num_batches)
    ]

    return batched_acts, batched_labels

In [None]:
# Probe model and training
class Probe(nn.Module):
    """Linear probe producing one logit per activation vector."""

    def __init__(self, activation_dim: int, dtype: torch.dtype):
        super().__init__()
        # Single output unit with bias; the sigmoid lives in BCEWithLogitsLoss.
        self.net = nn.Linear(
            in_features=activation_dim, out_features=1, bias=True, dtype=dtype
        )

    def forward(self, x):
        logits_1 = self.net(x)
        # Drop the trailing size-1 output axis: (..., 1) -> (...).
        return logits_1.squeeze(dim=-1)


def test_probe(
    input_batches: list[Float[torch.Tensor, "batch_size d_model"]],
    label_batches: list[Int[torch.Tensor, "batch_size"]],
    probe: Probe,
) -> float:
    """Return the probe's classification accuracy over all batches.

    A logit > 0 is predicted as label 1, otherwise label 0.
    NOTE(review): assumes dataset_info.POSITIVE_CLASS_LABEL == 1 and
    NEGATIVE_CLASS_LABEL == 0 — confirm.

    Args:
        input_batches: activation batches, each (batch, d_model).
        label_batches: matching label batches, each (batch,).
        probe: probe to evaluate.

    Returns:
        Fraction of examples classified correctly.

    Raises:
        ValueError: if there are no examples to evaluate.
    """
    num_correct = 0
    num_total = 0

    with torch.no_grad():
        for acts_BD, labels_B in zip(input_batches, label_batches):
            preds_B = (probe(acts_BD) > 0.0).long()
            num_correct += (preds_B == labels_B).sum().item()
            num_total += labels_B.numel()

    if num_total == 0:
        raise ValueError("test_probe needs at least one labelled example")

    return num_correct / num_total


def train_probe(
    train_input_batches: list[Float[torch.Tensor, "batch_size d_model"]],
    train_label_batches: list[Int[torch.Tensor, "batch_size"]],
    test_input_batches: list[Float[torch.Tensor, "batch_size d_model"]],
    test_label_batches: list[Int[torch.Tensor, "batch_size"]],
    dim: int,
    epochs: int,
    device: str,
    model_dtype: torch.dtype,
    lr: float,
    verbose: bool = False,
) -> tuple[Probe, float]:
    """Train a fresh linear probe with AdamW + BCE-with-logits and report test accuracy.

    Args:
        train_input_batches / train_label_batches: training activations and labels.
        test_input_batches / test_label_batches: held-out activations and labels.
        dim: activation dimension (probe input size).
        epochs: passes over the training batches.
        device: device the probe is moved to.
        model_dtype: dtype of the probe parameters and of the BCE targets.
        lr: AdamW learning rate.
        verbose: if True, print final loss and train/test accuracy.

    Returns:
        (trained probe, test accuracy after the last epoch).

    Raises:
        ValueError: if train_input_batches is empty.
    """
    if not train_input_batches:
        raise ValueError("train_probe needs at least one training batch")

    probe = Probe(dim, model_dtype).to(device)
    optimizer = torch.optim.AdamW(probe.parameters(), lr=lr)
    criterion = nn.BCEWithLogitsLoss()

    last_loss = None
    # Training needs autograd even if the caller disabled it globally
    # (e.g. via torch.set_grad_enabled(False)).
    with torch.enable_grad():
        for _ in range(epochs):
            for acts_BD, labels_B in zip(train_input_batches, train_label_batches):
                logits_B = probe(acts_BD)
                # BCEWithLogitsLoss needs float targets matching the logits' dtype.
                targets_B = labels_B.to(device=device, dtype=model_dtype)
                loss = criterion(logits_B, targets_B)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                last_loss = loss.detach()

    # Only the final-epoch accuracy is used, so evaluate once after training.
    test_accuracy = test_probe(test_input_batches, test_label_batches, probe)

    if verbose and last_loss is not None:
        train_accuracy = test_probe(train_input_batches, train_label_batches, probe)
        print(
            f"\nEpoch {epochs}/{epochs} Loss: {last_loss.item()}, train accuracy: {train_accuracy}, test accuracy: {test_accuracy}\n"
        )

    return probe, test_accuracy


def train_probe_on_activations(
    train_activations: dict[str, Float[torch.Tensor, "num_datapoints d_model"]],
    test_activations: dict[str, Float[torch.Tensor, "num_datapoints d_model"]],
    probe_batch_size: int,
    epochs: int,
    lr: float,
    model_dtype: torch.dtype,
    device: str,
    select_top_k: int | None = None,
) -> tuple[dict[str, Probe], dict[str, float]]:
    """Train one class-vs-rest probe per class and collect their test accuracies.

    For each class key in train_activations, probe data is built with
    prepare_probe_data (balanced positives vs. sampled negatives), then a probe is
    trained with train_probe.

    Args:
        train_activations / test_activations: class name -> (num_datapoints, dim).
        probe_batch_size: batch size for probe data.
        epochs: training epochs per probe.
        lr: learning rate.
        model_dtype: dtype of the probes.
        device: device the probes live on.
        select_top_k: optional k for experimental top-k dimension selection.

    Returns:
        (probes, test_accuracies), both keyed by class name.

    Raises:
        ValueError: if a class yields no full training batch.
    """
    probes, test_accuracies = {}, {}

    # Scope autograd to training so the caller's global grad mode is left untouched
    # (calling torch.set_grad_enabled(True) would leak into every later cell).
    with torch.enable_grad():
        for profession in train_activations:
            train_acts, train_labels = prepare_probe_data(
                train_activations, profession, probe_batch_size, select_top_k
            )

            test_acts, test_labels = prepare_probe_data(
                test_activations, profession, probe_batch_size, select_top_k
            )

            if not train_acts:
                raise ValueError(
                    f"Class {profession!r} produced no full training batch of size "
                    f"{probe_batch_size}"
                )

            activation_dim = train_acts[0].shape[1]

            print(f"activation dim: {activation_dim}")

            probe, test_accuracy = train_probe(
                train_acts,
                train_labels,
                test_acts,
                test_labels,
                epochs=epochs,
                dim=activation_dim,
                device=device,
                model_dtype=model_dtype,
                lr=lr,
                verbose=False,
            )

            print(f"Test accuracy for {profession}: {test_accuracy}")

            probes[profession] = probe
            test_accuracies[profession] = test_accuracy

    return probes, test_accuracies

In [None]:
# Mean-pool the cached LLM activations over non-padding positions, per example.
first_key = next(iter(all_train_acts_BLD))
print(all_train_acts_BLD[first_key].shape)
all_train_acts_BD = create_meaned_model_activations(
    all_llm_activations_BLD=all_train_acts_BLD, dtype=model_dtype
)
all_test_acts_BD = create_meaned_model_activations(
    all_llm_activations_BLD=all_test_acts_BLD, dtype=model_dtype
)
print(all_train_acts_BD[first_key].shape)

In [None]:
# Probes trained directly on mean-pooled LLM activations (compare with SAE probes below).
probe_batch_size = 128
epochs = 10
lr = 1e-3

probes, test_accuracies = train_probe_on_activations(
    train_activations=all_train_acts_BD,
    test_activations=all_test_acts_BD,
    probe_batch_size=probe_batch_size,
    epochs=epochs,
    lr=lr,
    model_dtype=model_dtype,
    device=device,
    select_top_k=None,
)

In [None]:
# Encode the cached LLM activations with the SAE and mean-pool the features per example.
sae_batch_size = 32
all_sae_train_acts_BF = get_sae_meaned_activations(
    all_llm_activations_BLD=all_train_acts_BLD,
    sae=sae,
    sae_batch_size=sae_batch_size,
    dtype=model_dtype,
)
all_sae_test_acts_BF = get_sae_meaned_activations(
    all_llm_activations_BLD=all_test_acts_BLD,
    sae=sae,
    sae_batch_size=sae_batch_size,
    dtype=model_dtype,
)

In [None]:
# Same probe setup as above, but on mean-pooled SAE features.
sae_probes, sae_test_accuracies = train_probe_on_activations(
    train_activations=all_sae_train_acts_BF,
    test_activations=all_sae_test_acts_BF,
    probe_batch_size=probe_batch_size,
    epochs=epochs,
    lr=lr,
    model_dtype=model_dtype,
    device=device,
    select_top_k=None,
)

In [None]:
# k-sparse probing: keep only the k SAE features whose mean activation differs most
# between the positive class and the negatives, then retrain the probes. Results are
# collected per k (previously they were discarded after printing).
topk_sae_test_accuracies: dict[int, dict[str, float]] = {}
for k in [1, 5, 10, 20, 50]:
    topk_sae_test_accuracies[k] = train_probe_on_activations(
        all_sae_train_acts_BF,
        all_sae_test_acts_BF,
        probe_batch_size,
        epochs,
        lr,
        model_dtype,
        device,
        select_top_k=k,
    )[1]

topk_sae_test_accuracies