In [2]:
import os
import torch
import numpy as np
from nnsight import LanguageModel
import datasets

import sys
sys.path.append('../')
from loading_utils import load_submodule_and_dictionary, DictionaryCfg

  from .autonotebook import tqdm as notebook_tqdm


## Tokenize 20 datasets of the Pile
`quanta-discovery/misc/create_pile_canonical.py`

## Evaluate model loss
`quanta-discovery/scripts/evaluate_pile_losses.py`

## Synthesize a dataset of low-loss tokens

In [3]:
# --- Model checkpoint -------------------------------------------------------
model_name = "pythia-70m-deduped"
step = 143_000
device = "cuda:0"

# --- Input / output locations ------------------------------------------------
# cache_dir is joined with model_name and f"step{step}" to find cached losses.
cache_dir = "/home/can/feature_clustering/cache/"
# Passed to datasets.load_from_disk when fetching token context.
pile_canonical = "/home/can/data/pile_test_tokenized_200k/"
output_dir = "/home/can/feature_clustering/results/"

# --- Token selection ---------------------------------------------------------
# Threshold is divided by the nats->bits factor before being compared to the
# cached losses, so it is presumably expressed in bits -- TODO confirm.
loss_threshold = 1e-4
# Stride applied to the indices of tokens that pass the threshold.
skip = 1
# Exact number of token indices that must remain after striding.
num_tokens = 10_000
block_len = 250
# Optional path to a saved boolean tensor; tokens where it is True are excluded.
# NOTE(review): this name shadows the builtin `filter`.
filter = None

verbose = True

In [4]:
# Locate the cached per-token losses for the configured model checkpoint.
particular_model_cache_dir = os.path.join(cache_dir, model_name, f"step{step}")
losses_cached = [f for f in os.listdir(particular_model_cache_dir) if f.endswith("losses.pt")]
if not losses_cached:
    # Fail with an actionable message instead of "max() arg is an empty sequence".
    raise FileNotFoundError(
        f"no '*losses.pt' files found in {particular_model_cache_dir}; "
        "compute the model losses before running this cell"
    )

# File names look like "<docs>_docs_<tokens>_tokens_losses.pt".
# Pick the file whose leading document count is largest.
max_i = max(range(len(losses_cached)), key=lambda i: int(losses_cached[i].split("_")[0]))
name_parts = losses_cached[max_i].split("_")
docs, tokens = int(name_parts[0]), int(name_parts[2])
# Load the file that was actually found rather than rebuilding its name from the parsed integers.
losses = torch.load(os.path.join(particular_model_cache_dir, losses_cached[max_i]))
c = 1 / np.log(2) # for nats to bits conversion

# Select tokens whose loss is below the threshold; dividing by `c` puts the
# threshold on the same scale as `losses` (assumes losses are stored in nats -- TODO confirm).
low_loss_mask = losses < (loss_threshold / c)
if filter:
    # NOTE(review): torch.load unpickles the file; only point `filter` at trusted files.
    criterias = torch.load(filter)
    # Drop every token flagged by the filter criteria.
    low_loss_mask = low_loss_mask & (~criterias)
token_idxs = low_loss_mask.nonzero().flatten()
token_idxs = token_idxs[::skip]
# Stored as a plain Python list of ints from here on.
token_idxs = token_idxs[:num_tokens].tolist()
assert len(token_idxs) == num_tokens, "not enough tokens meeting loss threshold (and filter) to sample from"

In [14]:
# Sanity check: largest selected token index (shown as the cell's output).
max(token_idxs)

1169481

In [2]:
# Build the model id and device from the config cell so they cannot drift
# from the values used to locate the cached losses.
model = LanguageModel(f"EleutherAI/{model_name}", device_map=device)

# Load the dictionary for a single submodule (only one is loaded here).
dictionary_dir = os.path.join("/share/projects/dictionary_circuits/autoencoders", model_name)
dictionary_size = 32768
submodule, dictionary = load_submodule_and_dictionary(
    model,
    submod_name='model.gpt_neox.layers.5.mlp.dense_4h_to_h',
    dict_cfg=DictionaryCfg(dictionary_dir, dictionary_size)
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [12]:
# Get context for token
dataset = datasets.load_from_disk(pile_canonical)

# Boolean-mask selection needs a tensor; as_tensor accepts either a Python
# list or an existing tensor.
token_idxs_tensor = torch.as_tensor(token_idxs)

# Running offset of the current document inside the global token index space.
# Assumes the selected indices count tokens over the documents in dataset
# order, with `tokens_len` tokens per document -- TODO confirm against the
# script that produced the cached losses.
total_tokens_in_past_docs = 0
for doc in dataset:
    doc_len = doc['tokens_len']
    doc_start = total_tokens_in_past_docs
    doc_end = doc_start + doc_len
    # Advance the offset up front so it stays correct for every document,
    # including those that contain no selected token.
    total_tokens_in_past_docs = doc_end

    # Selected token indices that fall inside this document's [start, end) span.
    in_doc = (token_idxs_tensor >= doc_start) & (token_idxs_tensor < doc_end)
    token_doc_idxs = token_idxs_tensor[in_doc]
    if token_doc_idxs.shape[0] > 0:
        print(f"Found {token_doc_idxs.shape[0]} tokens in doc {doc['doc_idx']}")
        # TODO: Do in NNsight -- run the forward pass inside an nnsight
        # tracing/invoke context so intermediate values can be saved.
        tok = torch.as_tensor(doc['input_ids']).to(device)
        logits = model(tok)
        # TODO: Cache all feature activations in forward pass

        # TODO: Cache all feature gradients in backward pass


{'text': 'Roman Catholic Diocese of Tambacounda\n\nThe Roman Catholic Diocese of Tambacounda () is a diocese located in the city of Tambacounda in the Ecclesiastical province of Dakar in Senegal.\n\nHistory\n August 13, 1970: Established as Apostolic Prefecture of Tambacounda from the Diocese of Kaolack and Diocese of Saint-Louis du Sénégal\n April 17, 1989: Promoted as Diocese of Tambacounda\n\nSpecial churches\n The cathedral is Cathédrale Marie Reine de l’Univers in Tambacounda, which is located in the Medina Coura neighborhood of the town.\n\nLeadership\n Bishops of Tambacounda (Roman rite)\n Bishop Jean-Noël Diouf (since 1989.04.17)\n Prefects Apostolic of Tambacounda (Roman rite) \n Fr. Clément Cailleau, C.S.Sp. (1970.08.13 – 1986.04.24)\n\nSee also\nRoman Catholicism in Senegal\n\nReferences\n\nExternal links\n GCatholic.org\n Catholic Hierarchy \n\nCategory:Roman Catholic dioceses in Senegal\nCategory:Tambacounda\nCategory:Christian organizations established in 1970\nCategory:R

## Save results for ~10K tokens
- feature activations
- gradient of the correct logit w.r.t. the feature

In [12]:
# Define metric as final token logit
def metric_fn(model):
    """Return, for each batch element, one logit taken at the last position.

    Reads ``model.embed_out.output`` and indexes it with three indices, so
    the value is treated as rank 3 (presumably batch x position x vocab --
    TODO confirm). For batch element ``i`` the selected entry is
    ``logits[i, -1, prompt_batch_final_tok[i]]``.

    Relies on the notebook-level name ``prompt_batch_final_tok``, which must
    be defined before this function is called and must supply one index per
    batch element.
    """
    logits = model.embed_out.output
    batch_size = logits.shape[0]
    # The notebook imports `torch` (there is no `t` alias), so use it directly.
    return logits[torch.arange(batch_size), -1, prompt_batch_final_tok]