In [605]:
import os
import sys
from pathlib import Path
from safetensors.torch import load_model
import json

import torch as t

import einops
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import requests

from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
#from jaxtyping import Float, Int

from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy



# Path setup
#project_root = Path(__file__).parent.parent
#sys.path.append(str(project_root))

# Imports from the project
from config.sae.models import SAEConfig
from models.sparsified import SparsifiedGPT
from models.gpt import GPT

device = t.device("mps" if t.backends.mps.is_available() else "cuda" if t.cuda.is_available() else "cpu")

In [627]:
# Loading SparsifiedGPT
device = torch.device("cpu")
checkpoint_dir = Path('/Volumes/MacMini/gpt-circuits') / "checkpoints"
gpt_dir = checkpoint_dir / "shakespeare_64x4"
sae_dir = checkpoint_dir / "standard.shakespeare_64x4"
    
# Load GPT model
print("Loading GPT model...")
gpt = GPT.load(gpt_dir, device=device)
    
# Load SAE config
print("Loading SAE configuration...")
sae_config_dir = sae_dir / "sae.json"
with open(sae_config_dir, "r") as f:
    meta = json.load(f)
config = SAEConfig(**meta)
config.gpt_config = gpt.config
    
# Create model using saved config
print("Creating SparsifiedGPT model...")
model = SparsifiedGPT(config)
model.gpt = gpt
    
# Load SAE weights
print("Loading SAE weights...")
for layer_name, module in model.saes.items():
    weights_path = os.path.join(sae_dir, f"sae.{layer_name}.safetensors")
    load_model(module, weights_path, device=device.type)

Loading GPT model...
Loading SAE configuration...
Creating SparsifiedGPT model...
Loading SAE weights...


In [628]:
#sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..",)))
#print("\n".join(sys.path))
# %%
current_dir = os.path.dirname(os.path.abspath("spar_sae_circuit_sandbox.ipynb"))
model_dir = os.path.join(current_dir, '..') # Assuming it's one level up
#toy_model_dir = os.path.join(current_dir, '..', 'llm_from_scratch/LLM_from_scratch/')

sys.path.append(model_dir)
#sys.path.append(toy_model_dir)

from config.gpt.training import options
from config.sae.models import sae_options
from models.gpt import GPT
from models.sparsified import SparsifiedGPT
from data.tokenizers import ASCIITokenizer, TikTokenTokenizer

#from utils import generate
c_name = 'standardx8.shakespeare_64x4'
name = 'standard.shakespeare_64x4'
#name = 'shakespeare_64x4'
config = sae_options[c_name]

model = SparsifiedGPT(config)
model_path = os.path.join("../checkpoints", name)
model = model.load(model_path, device=config.device)

tokenizer = ASCIITokenizer() if "shake" in name else TikTokenTokenizer()

In [539]:
def generate(model, tokenizer, prompt, max_length=50, temperature=0.7) -> str:
    """
    Generate text from a prompt using the model
    """
    tokens = tokenizer.encode(prompt)
    tokens = torch.Tensor(tokens).long().unsqueeze(0)
    
    for _ in range(max_length):
        logits = model(tokens).logits[0][-1]
        probs = torch.softmax(logits / temperature, dim=-1)
        #next_token = torch.multinomial(probs, num_samples=1)
        next_token = torch.argmax(probs, keepdim=True)
        
        tokens = torch.cat([tokens.squeeze(0), next_token], dim=-1).unsqueeze(0)
        
    #return tokenizer.decode_sequence(tokens[0].tolist())
    return tokens

In [609]:
prompt = "His name is Licio, born in Mantua."
print(len(prompt))
output = generate(model, tokenizer, prompt, max_length=1)
print(output)

34
tensor([[ 72, 105, 115,  32, 110,  97, 109, 101,  32, 105, 115,  32,  76, 105,
          99, 105, 111,  44,  32,  98, 111, 114, 110,  32, 105, 110,  32,  77,
          97, 110, 116, 117,  97,  46,  10]])


In [611]:
## New line character is 10 in AsciiTokenizer
## Space character is 32 in AsciiTokenizer

In [614]:
random_prompt = "\n"
tokens = tokenizer.encode(random_prompt)
tokens = torch.Tensor(tokens).long().unsqueeze(0)
tokens

tensor([[10]])

In [615]:
output = model(tokens)

In [616]:
#with model.use_saes():
    #output_sae = model(tokens)

In [414]:
vars(output).keys()

dict_keys(['logits', 'cross_entropy_loss', 'activations', 'ce_loss_increases', 'compound_ce_loss_increase', 'sae_loss_components', 'feature_magnitudes', 'reconstructed_activations'])

In [415]:
output.activations[0].shape

torch.Size([1, 1, 64])

In [416]:
output.feature_magnitudes[0].shape

torch.Size([1, 1, 512])

In [382]:
#feature layers
feat_layer0 = output.feature_magnitudes[0].squeeze(0)
feat_layer1 = output.feature_magnitudes[1].squeeze(0)
feat_layer2 = output.feature_magnitudes[2].squeeze(0)
feat_layer3 = output.feature_magnitudes[3].squeeze(0)
feat_layer4 = output.feature_magnitudes[4].squeeze(0)

#minimum value a feature can be considered "active"
feat_threshold = 0

In [391]:
torch.where(feat_layer0 > feat_threshold)[1].shape

torch.Size([7])

In [593]:
# Load the Shakespeare validation data
file_path = '/Volumes/MacMini/gpt-circuits/data/shakespeare/val_000000.npy'
#with open(file_path) as file:
val_input = np.load(file_path)

In [629]:
#Identify the locations of the output logit of interest
#
new_line_indices = np.where(val_input == 10)[0]
left_cut = 0
right_cut = 0
nl_prompt_list = []
dl_prompt_list = []

for i, idx in enumerate(new_line_indices):
    right_cut = idx
    if right_cut == left_cut:
        continue
    elif right_cut - left_cut == 1:
        left_cut = right_cut
        continue
    else:
        #Grab the sequence between newline characters plus the next two characters to check if its a double line
        token_sequence = val_input[left_cut+1:right_cut+2]
        if token_sequence[-1] == 10:
            dl_prompt_list.append(tokenizer.decode_sequence(token_sequence[:-2]))
        else:
            #if len(token_sequence) >= 10:
            nl_prompt_list.append(tokenizer.decode_sequence(token_sequence[:-2]))
        left_cut = right_cut
    

In [630]:
sample_nl_prompt_list = []
for prompt in nl_prompt_list:
    if len(prompt) == 16:
        #print(prompt)
        #print(len(tokenizer.encode(prompt)))
        sample_nl_prompt_list.append(prompt)

sample_dl_prompt_list = []
for prompt in dl_prompt_list:
    if len(prompt) == 16:
        print(prompt)
        print(len(tokenizer.encode(prompt)))
        sample_dl_prompt_list.append(prompt)

Then show it me.
16
'Tis with cares.
16
Call them forth.
16
countenance her.
16
How now, Grumio!
16
a puppet of her.
16
And what of him?
16
Is't so, indeed.
16
Roundly replied.
16
Who shall begin?
16
die a dry death.
16
Heaviness in me.
16
To the syllable.
16
Done. The wager?
16
beyond credit,--
16
how you take it!
16


In [634]:
model_output = []
for prompt in sample_dl_prompt_list:
    output = generate(model, tokenizer, prompt, max_length=1)
    model_output.append(output)

In [635]:
model_output

[tensor([[ 84, 104, 101, 110,  32, 115, 104, 111, 119,  32, 105, 116,  32, 109,
          101,  46,  10]]),
 tensor([[ 39,  84, 105, 115,  32, 119, 105, 116, 104,  32,  99,  97, 114, 101,
          115,  46,  10]]),
 tensor([[ 67,  97, 108, 108,  32, 116, 104, 101, 109,  32, 102, 111, 114, 116,
          104,  46,  10]]),
 tensor([[ 99, 111, 117, 110, 116, 101, 110,  97, 110,  99, 101,  32, 104, 101,
          114,  46,  10]]),
 tensor([[ 72, 111, 119,  32, 110, 111, 119,  44,  32,  71, 114, 117, 109, 105,
          111,  33,  32]]),
 tensor([[ 97,  32, 112, 117, 112, 112, 101, 116,  32, 111, 102,  32, 104, 101,
          114,  46,  10]]),
 tensor([[ 65, 110, 100,  32, 119, 104,  97, 116,  32, 111, 102,  32, 104, 105,
          109,  63,  10]]),
 tensor([[ 73, 115,  39, 116,  32, 115, 111,  44,  32, 105, 110, 100, 101, 101,
          100,  46,  10]]),
 tensor([[ 82, 111, 117, 110, 100, 108, 121,  32, 114, 101, 112, 108, 105, 101,
          100,  46,  10]]),
 tensor([[ 87, 104, 111,  32

In [637]:
count = 0
output_list = []
for output in model_output:
    if output[0][-1] == 10:# and output[0][-2]:
        count += 1
        output_list.append(output[0][:-1])
        print(output[0][:-1])

print(count/len(model_output))

tensor([ 84, 104, 101, 110,  32, 115, 104, 111, 119,  32, 105, 116,  32, 109,
        101,  46])
tensor([ 39,  84, 105, 115,  32, 119, 105, 116, 104,  32,  99,  97, 114, 101,
        115,  46])
tensor([ 67,  97, 108, 108,  32, 116, 104, 101, 109,  32, 102, 111, 114, 116,
        104,  46])
tensor([ 99, 111, 117, 110, 116, 101, 110,  97, 110,  99, 101,  32, 104, 101,
        114,  46])
tensor([ 97,  32, 112, 117, 112, 112, 101, 116,  32, 111, 102,  32, 104, 101,
        114,  46])
tensor([ 65, 110, 100,  32, 119, 104,  97, 116,  32, 111, 102,  32, 104, 105,
        109,  63])
tensor([ 73, 115,  39, 116,  32, 115, 111,  44,  32, 105, 110, 100, 101, 101,
        100,  46])
tensor([ 82, 111, 117, 110, 100, 108, 121,  32, 114, 101, 112, 108, 105, 101,
        100,  46])
tensor([ 87, 104, 111,  32, 115, 104,  97, 108, 108,  32,  98, 101, 103, 105,
        110,  63])
tensor([100, 105, 101,  32,  97,  32, 100, 114, 121,  32, 100, 101,  97, 116,
        104,  46])
tensor([ 72, 101,  97, 118, 10

In [642]:
output_list.to('cpu')

AttributeError: 'list' object has no attribute 'to'

In [639]:
torch.save(output_list, 'double_newline_prompts.pt')

In [643]:
double_newline_prompts = torch.load('double_newline_prompts.pt', weights_only=True)

In [644]:
double_newline_prompts

[tensor([ 84, 104, 101, 110,  32, 115, 104, 111, 119,  32, 105, 116,  32, 109,
         101,  46]),
 tensor([ 39,  84, 105, 115,  32, 119, 105, 116, 104,  32,  99,  97, 114, 101,
         115,  46]),
 tensor([ 67,  97, 108, 108,  32, 116, 104, 101, 109,  32, 102, 111, 114, 116,
         104,  46]),
 tensor([ 99, 111, 117, 110, 116, 101, 110,  97, 110,  99, 101,  32, 104, 101,
         114,  46]),
 tensor([ 97,  32, 112, 117, 112, 112, 101, 116,  32, 111, 102,  32, 104, 101,
         114,  46]),
 tensor([ 65, 110, 100,  32, 119, 104,  97, 116,  32, 111, 102,  32, 104, 105,
         109,  63]),
 tensor([ 73, 115,  39, 116,  32, 115, 111,  44,  32, 105, 110, 100, 101, 101,
         100,  46]),
 tensor([ 82, 111, 117, 110, 100, 108, 121,  32, 114, 101, 112, 108, 105, 101,
         100,  46]),
 tensor([ 87, 104, 111,  32, 115, 104,  97, 108, 108,  32,  98, 101, 103, 105,
         110,  63]),
 tensor([100, 105, 101,  32,  97,  32, 100, 114, 121,  32, 100, 101,  97, 116,
         104,  46]),
