In [None]:
# Standard library
import os

# Third-party
import torch
import torch.nn as nn
import einops
import transformer_lens
from datasets import load_dataset
from dotenv import load_dotenv

# Pick up variables from a local .env file (if present) so that secrets such as
# HF_TOKEN never have to be hard-coded in the notebook.
load_dotenv()

# Hugging Face access token; None when HF_TOKEN is not set in the environment.
hf_token = os.environ.get('HF_TOKEN')

In [None]:
# Dataset used in this notebook, fetched through `datasets.load_dataset`.
banana_bonanza = load_dataset(path="stetef/Banana-Bonanza")

In [None]:
class SAE(nn.Module):
    """Autoencoder with a two-layer ReLU encoder and a two-layer decoder.

    The name suggests "sparse autoencoder"; the class itself applies no
    sparsity constraint, so any such penalty has to come from the training
    loss.

    Args:
        hidden_dim: Width of the intermediate layer in both encoder and decoder.
        latent_dim: Width of the encoding produced by ``encode``.
        input_dim: Size of the last axis of the inputs and of the reconstruction.
    """

    def __init__(self, hidden_dim, latent_dim, input_dim):
        super().__init__()
        self.hidden_dim, self.latent_dim, self.input_dim = hidden_dim, latent_dim, input_dim

        # Encoder: input_dim -> hidden_dim -> latent_dim, ReLU after each layer.
        self.encoder_layer1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.relu1 = nn.ReLU()
        self.encoder_layer2 = nn.Linear(in_features=hidden_dim, out_features=latent_dim)
        self.relu2 = nn.ReLU()

        # Decoder: latent_dim -> hidden_dim -> input_dim, linear output layer.
        self.decoder_layer1 = nn.Linear(in_features=latent_dim, out_features=hidden_dim)
        self.relu3 = nn.ReLU()
        self.decoder_layer2 = nn.Linear(in_features=hidden_dim, out_features=input_dim)

    def encode(self, data):
        """Map inputs (last axis ``input_dim``) to non-negative latent codes."""
        pre_hidden = self.encoder_layer1(data)
        pre_latent = self.encoder_layer2(self.relu1(pre_hidden))
        # The final ReLU keeps every latent activation >= 0.
        return self.relu2(pre_latent)

    def decode(self, latent_space):
        """Map latent codes (last axis ``latent_dim``) back to input space."""
        pre_hidden = self.decoder_layer1(latent_space)
        # No activation on the output layer, so reconstructions are unbounded.
        return self.decoder_layer2(self.relu3(pre_hidden))

    def forward(self, data):
        """Encode then decode ``data``.

        Returns:
            A ``(reconstruction, encoding)`` tuple.
        """
        latent = self.encode(data)
        return self.decode(latent), latent

In [None]:
def loss(model, data, beta, sparsity_param):
    """Reconstruction error plus a KL-divergence sparsity penalty.

    The penalty is the classic sparse-autoencoder term
    ``sum_j KL(rho || rho_hat_j)``, where ``rho`` is the target mean
    activation and ``rho_hat_j`` is the mean activation of latent unit ``j``
    over the batch, both treated as Bernoulli parameters.

    Args:
        model: Callable returning ``(reconstruction, encoding)`` for ``data``.
        data: Input tensor with features on the last axis.
        beta: Weight of the sparsity penalty relative to the reconstruction term.
        sparsity_param: Target mean activation ``rho`` in (0, 1); a Python
            float or a tensor broadcastable to the latent dimension.

    Returns:
        Scalar tensor ``l2_error + beta * sparsity_loss``.
    """
    # Call the module itself (not .forward) so that registered hooks run.
    batch_reconstruction, batch_encoding = model(data)

    # Squared error summed over features, averaged over every leading axis.
    reconstruction_error = (batch_reconstruction - data).pow(2)
    l2_error = reconstruction_error.sum(dim=-1).mean()

    # Mean activation of each latent unit across all samples in the batch.
    latent_dim = batch_encoding.shape[-1]
    rho_hat = batch_encoding.reshape(-1, latent_dim).mean(dim=0)

    # Keep both Bernoulli parameters strictly inside (0, 1) so the logs stay
    # finite. NOTE(review): with ReLU encodings a unit whose mean activation
    # exceeds 1 is clamped, which saturates its penalty and zeroes its gradient.
    eps = torch.finfo(batch_encoding.dtype).eps
    rho = torch.as_tensor(
        sparsity_param, dtype=batch_encoding.dtype, device=batch_encoding.device
    ).clamp(eps, 1 - eps)
    rho_hat = rho_hat.clamp(eps, 1 - eps)

    # KL(rho || rho_hat) for Bernoulli variables, summed over latent units.
    sparsity_loss = (
        rho * torch.log(rho / rho_hat)
        + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat))
    ).sum()

    return l2_error + beta * sparsity_loss

In [None]:
# Language model whose activations are analysed, loaded through TransformerLens.
model_checkpoint = "meta-llama/Llama-3.2-1B"
model = transformer_lens.HookedTransformer.from_pretrained(model_name=model_checkpoint)

In [None]:
# hyperparams


In [None]:
# SAE.__init__ has three required arguments, so a bare SAE() raises TypeError.
# NOTE(review): the sizes below are placeholders. input_dim assumes the SAE is
# fed residual-stream activations of `model` (width model.cfg.d_model); confirm
# the hook point and tune hidden/latent widths before training.
sae_input_dim = model.cfg.d_model
sae = SAE(
    hidden_dim=2 * sae_input_dim,
    latent_dim=4 * sae_input_dim,
    input_dim=sae_input_dim,
)