In [17]:
import os
import sys

import torch as t

import einops
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests

from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
#from jaxtyping import Float, Int

from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

# Choose the compute device in priority order: Apple MPS, then CUDA, then CPU.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")

In [16]:
def generate(model, tokenizer, prompt, max_length=50, temperature=0.7) -> str:
    """
    Generate text from a prompt by autoregressively sampling from the model.

    Args:
        model: Callable that takes a ``(1, seq)`` integer tensor of token ids and
            returns a sequence whose first element is the logits, indexed as
            ``[batch, position, vocab]``.
        tokenizer: Object exposing ``encode(str)`` and ``decode_sequence(list)``.
        prompt: Text to condition the generation on.
        max_length: Number of new tokens to append to the prompt.
        temperature: Softmax temperature used for sampling. A value of exactly
            ``0`` selects the most likely token at every step (greedy decoding)
            instead of dividing by zero.

    Returns:
        The decoded text of the prompt tokens followed by the generated tokens.
    """
    # The notebook imports torch as `t`; the bare name `torch` is not bound here.
    # Build an integer tensor directly rather than round-tripping through float32.
    tokens = t.as_tensor(tokenizer.encode(prompt), dtype=t.long).unsqueeze(0)

    # Best effort: place the prompt on the same device as the model's weights.
    # Models that expose no parameters (e.g. plain callables) are left untouched.
    try:
        tokens = tokens.to(next(model.parameters()).device)
    except (AttributeError, StopIteration, TypeError):
        pass

    # Sampling needs no gradients; avoid building an autograd graph per step.
    with t.no_grad():
        for _ in range(max_length):
            # Keep only the logits for the last position of the sequence.
            logits = model(tokens)[0][:, -1, :]
            if temperature == 0:
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                probs = t.softmax(logits / temperature, dim=-1)
                next_token = t.multinomial(probs, num_samples=1)
            tokens = t.cat([tokens, next_token], dim=-1)
    return tokenizer.decode_sequence(tokens[0].tolist())

In [60]:
# Make the directory one level above the notebook's working directory importable,
# so the project-local packages imported below (config, models, data) resolve.
current_dir = os.path.dirname(os.path.abspath("spar_sae_circuit_sandbox.ipynb"))
model_dir = os.path.join(current_dir, '..') # Assuming it's one level up
#toy_model_dir = os.path.join(current_dir, '..', 'llm_from_scratch/LLM_from_scratch/')

# Guarded so that re-running this cell does not keep appending duplicate entries.
if model_dir not in sys.path:
    sys.path.append(model_dir)
#sys.path.append(toy_model_dir)

# These imports must stay below the sys.path update above.
from config.gpt.training import options
from config.sae.models import sae_options
from models.gpt import GPT
from models.sparsified import SparsifiedGPT
from data.tokenizers import ASCIITokenizer, TikTokenTokenizer

#from utils import generate
# NOTE(review): the config key ('standardx8...') and the checkpoint directory
# ('standard...') differ — confirm this pairing is intentional.
c_name = 'standardx8.shakespeare_64x4'
name = 'standard.shakespeare_64x4'
#name = 'shakespeare_64x4'
config = sae_options[c_name]

# NOTE(review): if `load` builds a new model rather than filling this instance,
# the object constructed on the next line is discarded — confirm.
model = SparsifiedGPT(config)
model_path = os.path.join("../checkpoints", name)
model = model.load(model_path, device=config.device)

# Checkpoint names containing "shake" use the ASCII tokenizer; all others use TikToken.
tokenizer = ASCIITokenizer() if "shake" in name else TikTokenTokenizer()

#print(generate(model, tokenizer, "Today I thought,", max_length=100))

In [61]:
#t.set_grad_enabled(False)
#hf_model = 'davidquarel/standard.shakespeare_64x4'
#gpt2: HookedSAETransformer = HookedSAETransformer.from_pretrained(repo_id=hf_model, device=device)

In [62]:
# Inspect the SAE modules held by the loaded model (echoed as the cell output).
model.saes

ModuleDict(
  (0): StandardSAE()
  (1): StandardSAE()
  (2): StandardSAE()
  (3): StandardSAE()
  (4): StandardSAE()
)