In [1]:
# Required installations for transformers and datasets
# !pip install transformers datasets
# !pip install keras huggingface_hub
# !pip install tensorflow
# !pip install python-dotenv
# !pip install zstandard
#!pip install bitsandbytes

In [2]:
import os
import time
from dotenv import load_dotenv
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from huggingface_hub import login
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from transformers import BitsAndBytesConfig

In [3]:
import os
FILE_NAMES = []
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        if filename.startswith("act"): # skip model checkpoint
            FILE_NAMES.append(os.path.join(dirname, filename))

In [4]:
FILE_NAMES = sorted(FILE_NAMES)

In [5]:
# scale_factor = 34.12206415510119 # at 1.6mil tokens
# scale_factor = 34.128712991170886 # at 10.6mil tokens
scale_factor = 11.888623072966611 # 10mil but with <begin> token removed

class SparseAutoencoder(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(SparseAutoencoder, self).__init__()
        # Encoder
        self.encoder = nn.Linear(input_dim, hidden_dim)
        # Decoder
        self.decoder = nn.Linear(hidden_dim, input_dim)
        
    def forward(self, x):
        encoded = torch.relu(self.encoder(x))
        # encoded = torch.nn.LeakyReLU(0.01)(self.encoder(x))
        decoded = self.decoder(encoded)
        return decoded, encoded

from torch.utils.data import Dataset, DataLoader
class ActivationDataset(Dataset):
    def __init__(self, data_dir, batch_size, f_type, test_fraction=0.01, scale_factor=1.0, seed=42):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seed = seed

        if f_type in ["train", "test", "all"]:
            self.f_type = f_type
        else:
            raise ValueError("f_type must be 'train' or 'test' or 'all'")
        
        if not 0 <= test_fraction <= 1:
            raise ValueError("test_fraction must be between 0 and 1")
        self.test_fraction = test_fraction

        self.scale_factor = scale_factor
        self.file_names = FILE_NAMES
        
        split_idx = int(len(self.file_names) * (1 - test_fraction))
        if f_type == "train":
            self.file_names = self.file_names[:split_idx]
        elif f_type == "test":
            self.file_names = self.file_names[split_idx:]
        else: # all
            pass

        print(f"Loaded {len(self.file_names)} batches for {f_type} set")

    def __len__(self):
        return len(self.file_names)
    
    def __getitem__(self, idx):
        activations = np.load(self.file_names[idx])
        if self.f_type == "all":
            sent_idx = activations[:, -3]
            token_idx = activations[:, -2] 
            token = activations[:, -1]
        # remove last 3 columns (sent_idx, token_idx, and token)
        activations = activations[:, :-3]
        # normalize activations
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        activations = torch.tensor(activations, dtype=torch.float32, device=device)
        # print("Activation Range Before Normalization:", torch.min(activations).item(), torch.max(activations).item())
        activations = activations / self.scale_factor * np.sqrt(activations.shape[1])
        # print("Activation Range After Normalization:", torch.min(activations).item(), torch.max(activations).item())

        if self.f_type == "train":
            # Set seed for reproducibility
            np.random.seed(self.seed)
            # random subsample 8192 examples
            indices = torch.randperm(activations.shape[0], device=activations.device)[:self.batch_size]
            activations = activations[indices]
        
        if self.f_type == "all":
            return activations, sent_idx, token_idx, token
        else:
            return activations

In [6]:
data_dir = "activations_data"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
input_dim = 3072  
hidden_dim = 65536

model = SparseAutoencoder(input_dim, hidden_dim).to(device)
# model.load_state_dict(torch.load("models/sparse_autoencoder_496.3666.pth"))
checkpoint = torch.load("/kaggle/input/checkpoint65k_sae/pytorch/default/1/checkpoint")
print(checkpoint["hyper_parameters"])
model.load_state_dict(checkpoint['state_dict'])

criterion = nn.MSELoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
l1_lambda = 0.01  # Regularization strength for sparsity

  checkpoint = torch.load("/kaggle/input/checkpoint65k_sae/pytorch/default/1/checkpoint")


{'input_dim': 3072, 'hidden_dim': 65536, 'l1_lambda': 0.00597965, 'lr': 2.5011e-05}


In [7]:
# Set model to evaluation mode
model.eval()
os.makedirs("sparse_latent_vectors", exist_ok=True)

dataset = ActivationDataset(
    data_dir, 
    batch_size=0, # not subsampled
    f_type="all", 
    test_fraction=1.0, # not used if type=all
    scale_factor=scale_factor, 
    seed=42 # not used
)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False) # take 1 batch at a time

# Extract and save latent vectors
batch_skip = 4 # 20GB limit on kaggle output
num_batches = 1
batch_size = 4096  # Size we can fit in VRAM
num_minibatches = 19  # 81920/8192 = 10 minibatches per batch
with torch.no_grad():
    for idx, batch_data in enumerate(data_loader):
        if idx < batch_skip:
            continue
        if idx >= batch_skip+num_batches :
            break
        batch, sent_idx, token_idx, token = batch_data
        sent_idx = sent_idx.to(device)
        token_idx = token_idx.to(device)
        token = token.to(device)
        batch = batch.squeeze(0)  # Remove batch dimension of 1
        
        # Process minibatches and save immediately
        for i in range(num_minibatches):
            start_idx = i * batch_size
            end_idx = (i + 1) * batch_size
            
            # Get minibatch slice
            minibatch = batch[start_idx:end_idx]
            _, encoded = model(minibatch)
            
            # Stack with metadata
            # Reshape metadata tensors to match batch size
            sent_idx_batch = sent_idx[:,start_idx:end_idx].T
            token_idx_batch = token_idx[:,start_idx:end_idx].T
            token_batch = token[:,start_idx:end_idx].T
            
            output_vectors = torch.cat((encoded, sent_idx_batch, token_idx_batch, token_batch), dim=1)
            
            # Save each minibatch immediately as a PyTorch tensor
            torch.save(output_vectors, f"sparse_latent_vectors/latent_vectors_batch_{idx}_minibatch_{i}.pt")
            # output_saved = torch.load(f"sparse_latent_vectors/latent_vectors_batch_{idx}_minibatch_{i}.pt")
            # output_vectors = output_vectors.to(torch.float16)
            # print(f"Data saved is near equal: {torch.allclose(output_vectors[:,:-3], output_saved[:,:-3], atol=1e-1)}")
            print(f"Saved minibatch {i+1} of {num_minibatches} for batch {idx}")



Loaded 129 batches for all set
Saved minibatch 1 of 19 for batch 4
Saved minibatch 2 of 19 for batch 4
Saved minibatch 3 of 19 for batch 4
Saved minibatch 4 of 19 for batch 4
Saved minibatch 5 of 19 for batch 4
Saved minibatch 6 of 19 for batch 4
Saved minibatch 7 of 19 for batch 4
Saved minibatch 8 of 19 for batch 4
Saved minibatch 9 of 19 for batch 4
Saved minibatch 10 of 19 for batch 4
Saved minibatch 11 of 19 for batch 4
Saved minibatch 12 of 19 for batch 4
Saved minibatch 13 of 19 for batch 4
Saved minibatch 14 of 19 for batch 4
Saved minibatch 15 of 19 for batch 4
Saved minibatch 16 of 19 for batch 4
Saved minibatch 17 of 19 for batch 4
Saved minibatch 18 of 19 for batch 4
Saved minibatch 19 of 19 for batch 4
