Tutorial on using SAE on GPT-2 from scratch.

Takeaways:
- People will understand how to do mechnanistic interpretability, in an end-to-end manner.
- People can apply this same code here, to investigate other models by just changing the Huggingface model name.
- The necessary visualizations have been taken care of.
- People have the right tooling to do the job. The tooling is even more general-purpose in that it can be used to understand model inner-working, work with many deep learning paradigm (NLP, CV, RL, ES...).
- Library functionality:
  - Fast inference.
  - Can run remote inference.
  - Hook up and integrate well to other frameworks, programming langugages (e.g. llama.cpp...).

### Get the data

In [1]:
from pathlib import Path

import pandas as pd
import numpy as np
import torch
import json
from tqdm import tqdm

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from dawnet.model import ModelRunner

model_id = "openai-community/gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
)
model = model.eval()
runner = ModelRunner(model)
print(runner._model)



GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


In [3]:
type(model)

transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel

In [3]:
def split_texts(input_path, output_path):
    from langchain_text_splitters import RecursiveCharacterTextSplitter
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=750, chunk_overlap=20, length_function=lambda x: len(tokenizer.encode(x)), is_separator_regex=False
    )
    df = pd.read_parquet(input_path)
    print(df.shape)
    texts = [each['text'] for _, each in df.iterrows() if len(each['text']) > 150]
    print(len(texts))

    splitted_texts = []
    for text in tqdm(texts):
        splitted_texts += [each for each in text_splitter.split_text(text) if len(each) > 150]
    with open(output_path, 'w') as fo:
        json.dump(splitted_texts, fo)

In [5]:
stem = "train-00006-of-00041"
split_texts(f"/data/mech/data/wikipedia/{stem}.parquet", f"/data/mech/data/splitted/{stem}.json")

(156288, 4)
153291


100%|█████████████████████████████████████████████████████████████████████████████| 153291/153291 [14:38<00:00, 174.42it/s]


In [3]:
def get_intermediate(text_file, path, layers, batch_size=2000):
    with open(text_file) as fi:
        texts = json.load(fi)
    print(f"There are {len(texts)} text lines")
    try:
        handler1 = runner.cache_outputs(*layers)
        intermediates = {layer: [] for layer in layers}
        with torch.no_grad():
            for idx, text in enumerate(tqdm(texts)):
                text = "<|endoftext|>" + text
                input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
                runner(input_ids)
                for layer in layers:
                    intermediates[layer].append(runner._output[layer][0].cpu().squeeze().numpy()[1:])
                if idx >= batch_size and idx % batch_size == 0:
                    for layer in layers:
                        output_folder = Path(path) / layer
                        output_folder.mkdir(exist_ok=True, parents=True)
                        stem = Path(text_file).stem
                        np.save(str(output_folder / f"{stem}_{idx}.pth"), np.concatenate(intermediates[layer]))
                    intermediates = {layer: [] for layer in layers}
    finally:
        handler1.clear()

# get_intermediate(
#     "/data/mech/data/splitted/train-00001-of-00041.json",
#     path="/data/mech/data/layers",
#     layers=["transformer.h.10"]
# )

There are 287260 text lines


100%|████████████████████████████████████████████████████████████████████████████| 287260/287260 [8:28:18<00:00,  9.42it/s]


In [None]:
def get_autoregressive_intermediates(text_file, path, layers, batch_size=2000):
    with open(text_file) as fi:
        texts = json.load(fi)
    print(f"There are {len(texts)} text lines")
    try:
        handler1 = runner.cache_outputs(*layers)
        intermediates = {layer: [] for layer in layers}
        with torch.no_grad():
            for idx, text in enumerate(tqdm(texts)):
                text = "<|endoftext|>" + text
                input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
                runner(input_ids)
                for layer in layers:
                    intermediates[layer].append(runner._output[layer][0].cpu().squeeze().numpy()[1:])
                if idx >= batch_size and idx % batch_size == 0:
                    for layer in layers:
                        output_folder = Path(path) / layer
                        output_folder.mkdir(exist_ok=True, parents=True)
                        stem = Path(text_file).stem
                        np.save(str(output_folder / f"{stem}_{idx}.pth"), np.concatenate(intermediates[layer]))
                    intermediates = {layer: [] for layer in layers}
    finally:
        handler1.clear()

In [15]:
text_tss, intermediate_tss, output_tss = [], [], []

idx = 1001
while idx <= 5000:
    with torch.no_grad():
        # store the input text
        # store the intermediate layer
        # store the output logits
        text = "<|endoftext|>" + new_texts[idx]
        input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
        output = runner(input_ids)
        intermediate_ts = runner._output["transformer.h.10"][0].cpu().squeeze().numpy()
        output_ts = output["logits"].cpu().squeeze().numpy()

        text_tss.append(text)
        intermediate_tss.append(intermediate_ts)
        output_tss.append(output_ts)
        if idx % 20 == 0:
            with open(f"/data/mech/data/output/texts_{idx:07d}.json", "w") as fo:
                json.dump(text_tss, fo)
            it = np.concatenate(intermediate_tss)
            ot = np.concatenate(output_tss)
            np.save(f"/data/mech/data/output/it_{idx:07d}.npy", it)
            np.save(f"/data/mech/data/output/ot_{idx:07d}.npy", ot)
            text_tss, intermediate_tss, output_tss = [], [], []
        idx += 1

KeyboardInterrupt: 

In [43]:
intermediate = runner._output["transformer.h.10"][0].cpu().squeeze().numpy()

In [44]:
intermediate.shape

(379, 768)

In [45]:
output_ts = output["logits"].cpu().squeeze().numpy()

In [41]:
output['logits'].shape

torch.Size([1, 379, 50257])

In [26]:
new_texts = []
idx = 0
while len(new_texts) < 1000000:
    new_texts += text_splitter.split_text(texts[idx])
    idx += 1

IndexError: list index out of range