In [84]:
import torch
import torch.nn as nn
import einops
import transformer_lens
from datasets import load_dataset
from dotenv import load_dotenv

load_dotenv()

import os
hf_token = os.getenv('HF_TOKEN')

In [85]:
# data
banana_bonanza = load_dataset("stetef/Banana-Bonanza")

In [86]:
class SAE(nn.Module):
    def __init__(self, hidden_dim, latent_dim, input_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.input_dim = input_dim

        self.encoder_layer1 = nn.Linear(input_dim, hidden_dim, bias=True)
        self.relu1 = nn.ReLU()
        self.encoder_layer2 = nn.Linear(hidden_dim, latent_dim, bias=True)
        self.relu2 = nn.ReLU()

        self.decoder_layer1 = nn.Linear(latent_dim, hidden_dim, bias=True)
        self.relu3 = nn.ReLU()
        self.decoder_layer2 = nn.Linear(hidden_dim, input_dim, bias=True)

    def encode(self, data):
        hidden_layer = self.relu1(self.encoder_layer1(data))
        return self.relu2(self.encoder_layer2(hidden_layer))
    
    def decode(self, latent_space):
        hidden_layer = self.relu3(self.decoder_layer1(latent_space))
        return self.decoder_layer2(hidden_layer)
    
    def forward(self, data):
        encoding = self.encode(data)
        reconstruction = self.decode(encoding)
        return reconstruction, encoding

In [87]:
def loss(model, data, beta, sparsity_param):
    batch_reconstruction, batch_encoding = model.forward(data)
    reconstruction_error = (batch_reconstruction - data).pow(2)
    l2_error = einops.reduce(reconstruction_error, 'batch_size input_dim -> batch_size', 'sum').mean()

    # l1_error = batch_encoding.sum()
    kl_lossifier = nn.KLDivLoss(reduction='sum')  # should have batch size
    sparsity_loss = kl_lossifier(sparsity_param, batch_encoding)
    return l2_error +  beta * sparsity_loss

In [88]:
# model to analyze
model_checkpoint = "meta-llama/Llama-3.2-1B"
model = transformer_lens.HookedTransformer.from_pretrained(model_checkpoint)

KeyboardInterrupt: 

# run using transformer-lens, check that it works

In [None]:
# hyperparams

# take dataset
# https://huggingface.co/datasets/stetef/Banana-Bonanza
questions = banana_bonanza["train"]['Question']
llama32_tokens = model.to_tokens(questions[0])

# print(llama32_tokens.shape)
# new_token = 29182
# print(torch.cat([llama32_tokens, torch.tensor([new_token]).unsqueeze(0).to("mps")], dim=1))
print(questions[0])
for i in range(150):
    # print(llama32_tokens.shape)
    # print(llama32_tokens.device)
    llama32_logits, _ = model.run_with_cache(llama32_tokens, remove_batch_dim=True)
    logits = llama32_logits[0][-1]
    softmax = nn.Softmax(dim=-1)
    normalized_distribution = softmax(logits)
    next_token = torch.argmax(normalized_distribution).item()
    llama32_tokens = torch.cat([llama32_tokens, torch.tensor([next_token]).unsqueeze(0).to("mps")], dim=1)

    if next_token == 128001:
        break

    print(model.to_string(next_token), end='')

# run using transformer lens with hooks to train SAE, print loss from time to time
# replace transformer activations with trained reconstructed SAE activations
# override transformer activations with a particular feature turned on


If 5 workers can complete a job in 12 days, how many days would it take 3 workers to complete the same job?
 A) 10 B) 12 C) 15 D) 18 E) 20
A. 10
B. 12
C. 15
D. 18
E. 20
Answer: B

In [97]:
print(model)

sae = SAE()



HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-15): 16 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_att

TypeError: SAE.__init__() missing 3 required positional arguments: 'hidden_dim', 'latent_dim', and 'input_dim'