In [1]:
# Required installations for transformers and datasets
# !pip install transformers datasets
# !pip install keras huggingface_hub
# !pip install tensorflow
# !pip install python-dotenv
# !pip install zstandard
#!pip install bitsandbytes

In [2]:
import os
import time
from dotenv import load_dotenv
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from huggingface_hub import login
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from transformers import BitsAndBytesConfig

In [14]:
import os
def find_activation_files(root='/kaggle/input', prefix="act"):
    """Recursively collect files under ``root`` whose name starts with ``prefix``.

    Args:
        root: Directory to walk. A missing directory yields an empty list,
            because ``os.walk`` silently ignores it.
        prefix: Filename prefix (not path prefix) a file must have to be kept.

    Returns:
        A list of full paths, sorted lexicographically. ``os.walk`` returns
        entries in arbitrary filesystem order; sorting makes any index-based
        train/test split of this list reproducible across runs and machines.
        Note the order is plain string order, so "act_10" sorts before "act_2".
    """
    matches = [
        os.path.join(dirname, filename)
        for dirname, _, filenames in os.walk(root)
        for filename in filenames
        if filename.startswith(prefix)
    ]
    matches.sort()
    return matches


# Module-level list consumed by ActivationDataset.
FILE_NAMES = find_activation_files()

In [10]:
# Earlier values of the normalization constant, kept for reference:
# scale_factor = 34.12206415510119 # at 1.6mil tokens
# scale_factor = 34.128712991170886 # at 10.6mil tokens
# Value currently in use. It is passed to ActivationDataset, which divides the
# loaded activations by it before multiplying by sqrt(number of columns).
scale_factor = 11.888623072966611 # 10mil but with <begin> token removed

class SparseAutoencoder(nn.Module):
    """Autoencoder with one linear encoder, a ReLU, and one linear decoder.

    Args:
        input_dim: Width of the vectors being reconstructed.
        hidden_dim: Number of latent features produced by the encoder.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # The attribute names are also the state_dict key prefixes
        # ("encoder.*", "decoder.*"), so they must stay as they are.
        self.encoder = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.decoder = nn.Linear(in_features=hidden_dim, out_features=input_dim)

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for the input ``x``.

        ``latent`` is the ReLU of the encoder output; ``reconstruction`` is the
        decoder applied to ``latent``.
        """
        latent = self.encoder(x).relu()
        reconstruction = self.decoder(latent)
        return reconstruction, latent

from torch.utils.data import Dataset, DataLoader
class ActivationDataset(Dataset):
    """Dataset that serves one saved activation file per item.

    The list of files comes from the module-level ``FILE_NAMES``; ``data_dir``
    is stored but not used to locate files. Each file is loaded with
    ``np.load`` and treated as a 2-D array whose last three columns are
    ``sent_idx``, ``token_idx`` and ``token``; the remaining columns are the
    activations.

    Args:
        data_dir: Stored on the instance; not otherwise used by this class.
        batch_size: Number of rows randomly kept per file in "train" mode.
        f_type: "train" (leading files, subsampled), "test" (trailing files,
            all rows) or "all" (every file, all rows, plus the index columns).
        test_fraction: Fraction of files assigned to the "test" split.
        scale_factor: Activations are divided by this value and multiplied by
            ``sqrt(number of activation columns)``.
        seed: Base seed for the "train" row subsample.

    Raises:
        ValueError: If ``f_type`` is not one of the three modes, or
            ``test_fraction`` is outside ``[0, 1]``.
    """

    def __init__(self, data_dir, batch_size, f_type, test_fraction=0.01, scale_factor=1.0, seed=42):
        if f_type not in ("train", "test", "all"):
            raise ValueError("f_type must be 'train' or 'test' or 'all'")
        if not 0 <= test_fraction <= 1:
            raise ValueError("test_fraction must be between 0 and 1")

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seed = seed
        self.f_type = f_type
        self.test_fraction = test_fraction
        self.scale_factor = scale_factor

        # Copy so this instance never aliases (or mutates) the global list.
        all_files = list(FILE_NAMES)
        split_idx = int(len(all_files) * (1 - test_fraction))
        if f_type == "train":
            self.file_names = all_files[:split_idx]
        elif f_type == "test":
            self.file_names = all_files[split_idx:]
        else:  # all
            self.file_names = all_files

        print(f"Loaded {len(self.file_names)} batches for {f_type} set")

    def __len__(self):
        """Number of files in the selected split."""
        return len(self.file_names)

    def __getitem__(self, idx):
        """Load file ``idx`` and return its normalized activations.

        Returns:
            In "train" mode, a tensor holding ``batch_size`` randomly chosen
            rows (fewer if the file is smaller). In "test" mode, all rows.
            In "all" mode, the tuple
            ``(activations, sent_idx, token_idx, token)`` where the last three
            are the raw NumPy columns.
        """
        activations = np.load(self.file_names[idx])
        if self.f_type == "all":
            sent_idx = activations[:, -3]
            token_idx = activations[:, -2]
            token = activations[:, -1]
        # remove last 3 columns (sent_idx, token_idx, and token)
        activations = activations[:, :-3]
        # normalize activations
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        activations = torch.tensor(activations, dtype=torch.float32, device=device)
        activations = activations / self.scale_factor * np.sqrt(activations.shape[1])

        if self.f_type == "train":
            # torch.randperm ignores NumPy's RNG, so seeding np.random had no
            # effect on the subsample. Use a dedicated torch generator instead;
            # mixing in idx gives each file its own (but repeatable) selection.
            generator = torch.Generator()
            generator.manual_seed(self.seed + int(idx))
            # Random subsample of batch_size rows; drawn on CPU so the result
            # does not depend on which device is available.
            indices = torch.randperm(activations.shape[0], generator=generator)[:self.batch_size]
            activations = activations[indices.to(activations.device)]

        if self.f_type == "all":
            return activations, sent_idx, token_idx, token
        else:
            return activations

In [None]:
# Evaluation setup: build the SAE, restore trained weights, define the loss.
data_dir = "activations_data"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_dim = 3072
hidden_dim = 65536

model = SparseAutoencoder(input_dim, hidden_dim).to(device)
# model.load_state_dict(torch.load("models/sparse_autoencoder_496.3666.pth"))
# map_location remaps the stored tensors onto `device`; without it, a
# checkpoint saved from a GPU cannot be deserialized on a CPU-only kernel even
# though `device` above already falls back to CPU.
checkpoint = torch.load(
    "/kaggle/input/checkpoint65k_sae/pytorch/default/1/checkpoint",
    map_location=device,
)
print(checkpoint["hyper_parameters"])
model.load_state_dict(checkpoint['state_dict'])

criterion = nn.MSELoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# NOTE(review): l1_lambda is not referenced by the evaluation cells visible in
# this notebook — confirm whether the reported loss should be weighted by it.
l1_lambda = 0.01  # Regularization strength for sparsity

In [40]:
# Evaluate the SAE on a random subsample of every activation file.
#
# The "test" split is not used here: it returns whole files (batches of ~81k
# rows), which was too big for VRAM. Instead the "train" mode is used with
# test_fraction=0.0, so every file is kept and each one is subsampled down to
# `batch_size` rows.
test_dataset = ActivationDataset(
    data_dir,
    batch_size=4096,
    f_type="train",
    test_fraction=0.0,  # keep every file in the split
    scale_factor=scale_factor,
    seed=123,  # different seed than in actual training
)
# Each dataset item is already a full 2-D batch, so load one item at a time.
data_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Set model to evaluation mode
model.eval()

# Running sums of the per-batch metrics (averaged in the next cell).
total_loss = 0
total_mse_loss = 0
total_l1_loss = 0
num_batches = 0
# True for every latent feature that fired on at least one row of any batch.
global_active_mask = torch.zeros(hidden_dim, dtype=torch.bool, device=device)

# No gradients are needed for evaluation; disabling autograd avoids keeping the
# activations of the 65k-wide hidden layer alive for a backward pass.
with torch.no_grad():
    # The decoder weights do not change during evaluation, so the per-feature
    # column norms are computed once instead of once per batch.
    decoder_weight_norms = torch.norm(model.decoder.weight, p=2, dim=0)  # Shape: [hidden_dim]

    for batch in data_loader:
        batch = batch.to(device)  # Shape: [1, rows, input_dim]

        outputs, encoded = model(batch)  # encoded shape: [1, rows, hidden_dim]

        # Features that are non-zero on at least one row of this batch
        # (ReLU output, so "> 0" and "!= 0" select the same entries).
        batch_active = torch.any(encoded > 0, dim=1).squeeze(0)  # Shape: [hidden_dim]
        global_active_mask |= batch_active
        active_features = batch_active.sum().item()
        total_features = encoded.shape[2]  # Total number of latent features (hidden_dim)
        percent_active_features = active_features / total_features
        print(f"Percent Active Features: {percent_active_features * 100:.2f}%")

        mse_loss = criterion(outputs, batch)
        # Latent activations weighted by the norm of their decoder column,
        # averaged over rows and features.
        l1_terms = encoded * decoder_weight_norms.unsqueeze(0)
        l1_loss = torch.mean(l1_terms)
        loss = mse_loss + l1_loss

        total_loss += loss.item()
        total_mse_loss += mse_loss.item()
        total_l1_loss += l1_loss.item()

        # Variance is taken over every element of the batch.
        explained_variance = 1 - mse_loss / torch.var(batch)
        # Print batch-level metrics
        print(f"MSE Loss: {mse_loss.item():.4f}, L1 Loss: {l1_loss.item():.4f}, Explained Var: {explained_variance.item():.4f}")
        num_batches += 1

Loaded 129 batches for train set
Percent Active Features: 100.00%
MSE Loss: 0.0750, L1 Loss: 0.0122, Explained Var: 0.9234
Percent Active Features: 100.00%
MSE Loss: 0.0758, L1 Loss: 0.0123, Explained Var: 0.9239
Percent Active Features: 100.00%
MSE Loss: 0.0758, L1 Loss: 0.0123, Explained Var: 0.9239
Percent Active Features: 100.00%
MSE Loss: 0.0761, L1 Loss: 0.0124, Explained Var: 0.9241
Percent Active Features: 100.00%
MSE Loss: 0.0769, L1 Loss: 0.0125, Explained Var: 0.9236
Percent Active Features: 100.00%
MSE Loss: 0.0767, L1 Loss: 0.0125, Explained Var: 0.9240
Percent Active Features: 100.00%
MSE Loss: 0.0770, L1 Loss: 0.0125, Explained Var: 0.9228
Percent Active Features: 100.00%
MSE Loss: 0.0759, L1 Loss: 0.0123, Explained Var: 0.9241
Percent Active Features: 100.00%
MSE Loss: 0.0760, L1 Loss: 0.0123, Explained Var: 0.9237
Percent Active Features: 100.00%
MSE Loss: 0.0755, L1 Loss: 0.0123, Explained Var: 0.9246
Percent Active Features: 100.00%
MSE Loss: 0.0757, L1 Loss: 0.0124,

In [41]:
# Report the per-batch averages accumulated by the evaluation loop.
mean_total, mean_mse, mean_l1 = (
    running_sum / num_batches
    for running_sum in (total_loss, total_mse_loss, total_l1_loss)
)
print(f"Total Test Loss: {mean_total:.4f}")
print(f"Total MSE Loss: {mean_mse:.4f}")
print(f"Total L1 Loss: {mean_l1:.4f}")

# global_active_mask marks features that fired at least once in any batch;
# "sparsity" here is the share of features that never fired.
total_features = global_active_mask.numel()
active_features = global_active_mask.sum().item()
active_fraction = active_features / total_features
global_sparsity = (1 - active_fraction) * 100
print(f"Global Sparsity Across All Batches: {global_sparsity:.2f}%")
print(f"Percent of Active Features: {active_fraction * 100:.2f}%")

Total Test Loss: 0.0883
Total MSE Loss: 0.0759
Total L1 Loss: 0.0124
Global Sparsity Across All Batches: 0.00%
Percent of Active Features: 100.00%
