In [1]:
from transformer_lens import HookedTransformer
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Name of the pretrained checkpoint handed to HookedTransformer.from_pretrained.
model_name = "gpt2"
model = HookedTransformer.from_pretrained(model_name, device = device)

  from .autonotebook import tqdm as notebook_tqdm


Loaded pretrained model gpt2 into HookedTransformer


In [2]:
from sparsifiers import SAE
from huggingface_hub import hf_hub_download

# Layer index used to build the checkpoint filenames and activation names below.
layer = 6
# e2e_name = "https://huggingface.co/apollo-research/e2e-saes-gpt2/blob/main/downstream_similar_ce_layer_6.pt"
# Checkpoint filenames inside the `model_id` repo. The commented-out lines are
# the "similar_l0" alternatives to the "similar_ce" files selected here.
downstream_name = f"downstream_similar_ce_layer_{layer}.pt"
# downstream_name = f"downstream_similar_l0_layer_{layer}.pt"
local_name = f"local_similar_ce_layer_{layer}.pt"
# local_name = f"local_similar_l0_layer_{layer}.pt"

# NOTE(review): `activation_name` is reassigned to f"blocks.{layer}.hook_resid_post"
# later in the notebook before anything reads it, so this value looks unused — confirm.
activation_name = f"transformer.h.{layer}"
# Hugging Face Hub repo id that load_sae passes to hf_hub_download.
model_id = "apollo-research/e2e-saes-gpt2"

In [3]:
def load_sae(model_id, hh_file_location):
    """Download an SAE checkpoint from the Hugging Face Hub and build an `SAE` from it.

    Args:
        model_id: Hub repo id that hosts the checkpoint.
        hh_file_location: Filename of the checkpoint inside that repo.

    Returns:
        An `SAE` sized from the checkpoint's `encoder.weight` with the
        checkpoint's parameters loaded into it. Checkpoint tensors are read
        onto the CPU; callers move the module with `.to(device)`.
    """
    checkpoint_path = hf_hub_download(repo_id=model_id, filename=hh_file_location)

    # The file comes from a remote repo, so restrict unpickling to plain
    # tensors/containers (weights_only=True) instead of arbitrary objects.
    # map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only host.
    # NOTE(review): assumes the checkpoint is a plain state dict of tensors — confirm.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)

    # encoder.weight is (n_dict_components, input_size); use it to size the SAE.
    n_dict_components, input_size = state_dict['encoder.weight'].shape
    sae = SAE(input_size=input_size, n_dict_components=n_dict_components)

    # The checkpoint names the encoder parameters `encoder.*`; they are loaded
    # under `encoder.0.*` so they land on the first sub-module of `sae.encoder`.
    key_map = {
        'encoder.weight': 'encoder.0.weight',
        'encoder.bias': 'encoder.0.bias',
    }
    renamed_state_dict = {key_map.get(key, key): value for key, value in state_dict.items()}

    sae.load_state_dict(renamed_state_dict)
    return sae

# Load the "local" and "downstream" SAEs for the configured layer and move them to `device`.
sae_local = load_sae(model_id, local_name).to(device)
sae_downstream = load_sae(model_id, downstream_name).to(device)

  sae_list = torch.load(ae_download_location)


In [4]:
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from einops import rearrange

def download_dataset(dataset_name, tokenizer, max_length=256, num_datapoints=None):
    """Load a text dataset, tokenize it, and keep fixed-length token sequences.

    Args:
        dataset_name: Name passed to `datasets.load_dataset`.
        tokenizer: Callable applied to a batch of `text` values; its result
            must contain `input_ids`.
        max_length: Number of tokens every returned example is truncated to.
        num_datapoints: If truthy, only the first `num_datapoints` rows of the
            train split are loaded; otherwise (None or 0) the full train split.

    Returns:
        The dataset restricted to examples with at least `max_length` tokens,
        each with `input_ids` cut to exactly `max_length` tokens.
    """
    split_text = f"train[:{num_datapoints}]" if num_datapoints else "train"

    tokenized = load_dataset(dataset_name, split=split_text).map(
        lambda x: tokenizer(x['text']),
        batched=True,
    )

    # `>=` rather than `>`: an example with exactly `max_length` tokens is
    # already the right length after truncation and should not be discarded.
    long_enough = tokenized.filter(
        lambda x: len(x['input_ids']) >= max_length
    )

    return long_enough.map(
        lambda x: {'input_ids': x['input_ids'][:max_length]}
    )

dataset_name = "stas/openwebtext-10k"
# Token length that download_dataset truncates every kept example to.
max_seq_length = 40
print(f"Downloading {dataset_name}")
dataset = download_dataset(dataset_name, tokenizer=model.tokenizer, max_length=max_seq_length, num_datapoints=None) # num_datapoints grabs all of them if None

Using the latest cached version of the dataset since stas/openwebtext-10k couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at /Users/wendysun/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/152771d7ae284673c3ad7ffdd9b3afc2741f1d00 (last modified on Wed Aug 28 11:02:24 2024).


Downloading stas/openwebtext-10k


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [5]:
# Run the dataset through the model and collect the SAE's dictionary (feature) activations for every token
def get_dictionary_activations(model, dataset, cache_name, max_seq_length, autoencoder, batch_size=32):
    """Encode one cached model activation with `autoencoder` for every token in `dataset`.

    Args:
        model: Model exposing `cfg.device` and `run_with_cache`.
        dataset: Dataset with an `input_ids` column; every row is expected to
            hold `max_seq_length` tokens.
        cache_name: Key of the activation to read from the run cache.
        max_seq_length: Tokens per row; used to size the output buffers.
        autoencoder: Module whose `encoder` maps activations to dictionary
            features; `encoder[0].weight` has one row per feature.
        batch_size: Number of rows per forward pass.

    Returns:
        `(dictionary_activations, token_list)`, both on the CPU:
        `dictionary_activations` has shape
        `(dataset.num_rows * max_seq_length, num_features)` and `token_list`
        holds the matching flattened int64 token ids.
    """
    device = model.cfg.device
    num_features = autoencoder.encoder[0].weight.shape[0]
    total_tokens = dataset.num_rows * max_seq_length

    # NOTE: this is a dense buffer in torch's default dtype; with a large
    # dictionary and many tokens it can exceed available host memory.
    dictionary_activations = torch.zeros((total_tokens, num_features))
    token_list = torch.zeros(total_tokens, dtype=torch.int64)

    # Running write position, so a short final batch is placed by its real
    # length rather than by an assumed `batch_size`.
    offset = 0
    with torch.no_grad(), dataset.formatted_as("pt"):
        dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
        for batch in tqdm(dl):
            batch = batch.to(device)
            flat_tokens = rearrange(batch, "b s -> (b s)")
            stop = offset + flat_tokens.shape[0]
            token_list[offset:stop] = flat_tokens

            # Only cache the one activation that is read below instead of
            # every hook point in the model.
            _, cache = model.run_with_cache(batch, names_filter=cache_name)
            internal_activations = cache[cache_name]

            batched_neuron_activations = rearrange(internal_activations, "b s n -> (b s) n")
            batched_dictionary_activations = autoencoder.encoder(batched_neuron_activations)
            dictionary_activations[offset:stop, :] = batched_dictionary_activations.cpu()
            offset = stop
    return dictionary_activations, token_list

# Rows per forward pass when collecting activations.
batch_size = 128
# Cache key passed to get_dictionary_activations as `cache_name`. This replaces
# the f"transformer.h.{layer}" value assigned to `activation_name` earlier.
activation_name = f"blocks.{layer}.hook_resid_post"

In [6]:
# Feature activations of the local SAE (and the flattened token ids) over the whole dataset.
local_dictionary_activations, local_token_list = get_dictionary_activations(model, dataset, activation_name, max_seq_length, sae_local, batch_size=batch_size)

100%|██████████| 79/79 [04:22<00:00,  3.32s/it]


: 

In [7]:
# Same pass for the downstream SAE; note this runs the model over the dataset a second time.
downstream_dictionary_activations, downstream_token_list = get_dictionary_activations(model, dataset, activation_name, max_seq_length, sae_downstream, batch_size=batch_size)

In [None]:
# True: compare decoder weights. False: compare (transposed) encoder weights.
do_decoder = True

if do_decoder:
    dec_local, dec_downstream = (
        sae_local.decoder.weight.data,
        sae_downstream.decoder.weight.data,
    )
else:
    # Transposed so the encoder matrices are normalised along the same axis
    # as the decoder matrices below.
    dec_local, dec_downstream = (
        sae_local.encoder[0].weight.data.T,
        sae_downstream.encoder[0].weight.data.T,
    )

# Scale every column to unit L2 norm so that dot products between columns
# are cosine similarities.
dec_local = dec_local.div(dec_local.norm(dim=0, keepdim=True))
dec_downstream = dec_downstream.div(dec_downstream.norm(dim=0, keepdim=True))

In [None]:
# Pairwise dot products between the columns of the two direction matrices
# (rows index dec_local columns, columns index dec_downstream columns).
cos_sim = dec_local.T @ dec_downstream

# Collect the (local, downstream) index pairs whose similarity exceeds `threshold`
threshold = 0.9
high_cos_sim_pairs = (cos_sim > threshold).nonzero()

In [None]:
# Sanity check: cos_sim should have shape (local features, downstream features)
cos_sim.shape

torch.Size([46080, 46080])

In [None]:
# Sanity check: indexing cos_sim with high_cos_sim_pairs should return only values > threshold
cos_sim[high_cos_sim_pairs[:, 0], high_cos_sim_pairs[:, 1]]

tensor([0.9144, 0.9669, 0.9729,  ..., 0.9210, 0.9125, 0.9211])

In [None]:
# Number of (local, downstream) feature pairs above the threshold
len(high_cos_sim_pairs)

7750

In [None]:
def find_activation_contexts(activations, token_list, top_k=5):
    """Return the token windows around the `top_k` strongest activations.

    Args:
        activations: 1-D tensor with one activation value per token position.
        token_list: 1-D tensor of token ids aligned with `activations`.
        top_k: Maximum number of positions to report.

    Returns:
        A list (strongest activation first) of token-id lists. Each window
        covers up to two tokens before and two after the position, and is
        shorter when the position is near either end of `token_list`.
    """
    # topk avoids sorting the whole tensor; clamp k so short inputs don't raise.
    k = min(top_k, activations.numel())
    top_indices = torch.topk(activations, k).indices.tolist()

    # Clamp the window start at 0: a negative start would be read as an
    # offset from the end of `token_list` and give an empty or wrong window.
    return [token_list[max(idx - 2, 0):idx + 3].tolist() for idx in top_indices]

In [None]:
# For every high-similarity pair, record the pair when the local feature fires
# on fewer positions than the downstream feature and more than 20% of the
# local feature's active positions are also active for the downstream feature.
subset_examples = []

for local_idx, downstream_idx in tqdm(high_cos_sim_pairs):
    local_activations = local_dictionary_activations[:, local_idx]
    downstream_activations = downstream_dictionary_activations[:, downstream_idx]

    # A position counts as active when its activation exceeds mean + std
    # of that feature's activations.
    local_active = local_activations > local_activations.mean() + local_activations.std()
    downstream_active = downstream_activations > downstream_activations.mean() + downstream_activations.std()

    if local_active.sum() < downstream_active.sum() and (local_active & downstream_active).sum() / local_active.sum() > 0.2:
        # Use the token lists returned alongside each activation matrix;
        # there is no notebook-level `token_list` variable.
        local_contexts = find_activation_contexts(local_activations, local_token_list)
        downstream_contexts = find_activation_contexts(downstream_activations, downstream_token_list)

        subset_examples.append({
            'local_idx': local_idx.item(),
            'downstream_idx': downstream_idx.item(),
            'local_contexts': local_contexts,
            'downstream_contexts': downstream_contexts
        })

Processing feature 9: 100%|██████████| 8/8 [00:27<00:00,  3.44s/it]
Processing feature 9: 100%|██████████| 8/8 [00:25<00:00,  3.24s/it]
  0%|          | 1/7750 [00:53<115:36:38, 53.71s/it]

torch.Size([40000])
torch.Size([40000])


Processing feature 12:  50%|█████     | 4/8 [00:16<00:16,  4.02s/it]
  0%|          | 1/7750 [01:09<150:12:38, 69.78s/it]


KeyboardInterrupt: 

In [None]:
# Report how many high-similarity pairs passed the subset check.
print(f"{len(subset_examples)} local feature that is a subset of downstream feature")

0 local feature that is a subset of downstream feature


In [None]:
# Print out the first 5 examples, with up to two decoded contexts per feature
for i, example in enumerate(subset_examples[:5]):
    print(f"\nExample {i+1}:")

    print(f"Local feature {example['local_idx']} contexts:")
    for context in example['local_contexts'][:2]:
        # to_string on a flat list of token ids already returns one decoded
        # string; joining it with " " would put a space between every character.
        print(model.to_string(context))

    print(f"Downstream feature {example['downstream_idx']} contexts:")
    for context in example['downstream_contexts'][:2]:
        print(model.to_string(context))