In [None]:
import zstandard as zstd
import json
import os
import io
from tqdm import tqdm
from nnsight import LanguageModel
from dictionary_learning.buffer import ActivationBuffer
from dictionary_learning.dictionary import AutoEncoder
from dictionary_learning.training import trainSAE
import torch as t

In [None]:
# set up model
# Load the Pythia-70m (deduped) checkpoint through nnsight's LanguageModel on cuda:0.
model = LanguageModel('EleutherAI/pythia-70m-deduped', device_map='cuda:0')
# Module that later cells read activations from: `mlp.dense_4h_to_h` of the
# layer at index 2.
submodule = model.gpt_neox.layers[2].mlp.dense_4h_to_h

# Train SAEs

In [None]:
# set up data as a generator
data_path = '/share/data/datasets/pile/the-eye.eu/public/AI/pile/train/00.jsonl.zst' # this dataset is not available anymore on the-eye.eu

def generator(path=data_path):
    """Lazily yield the 'text' field of each JSON line in a zstd-compressed file.

    The file is opened when the generator is first advanced and is closed
    when the generator finishes or is closed, so no file handle is left open
    at notebook level.

    Args:
        path: Location of the `.jsonl.zst` file; defaults to `data_path`.

    Yields:
        The value stored under the 'text' key of each JSON record.
    """
    with open(path, 'rb') as compressed_file:
        reader = zstd.ZstdDecompressor().stream_reader(compressed_file)
        text_stream = io.TextIOWrapper(reader, encoding='utf-8')
        for line in text_stream:
            # Skip blank lines (e.g. a trailing newline) instead of letting
            # json.loads raise on an empty string.
            if not line.strip():
                continue
            yield json.loads(line)['text']

data = generator()

In [None]:
# Activation buffer built from the text generator `data`, the model and the
# chosen `submodule`; later cells consume it with `next(buffer)`,
# `buffer.tokenized_batch()` and `buffer.text_batch()`.
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    in_batch_size=64,  # presumably texts per model forward pass — TODO confirm against ActivationBuffer
    out_batch_size=4096,  # presumably activations returned per next(buffer) — TODO confirm
    n_ctxs=5e4,  # NOTE(review): float literal; confirm the buffer accepts a non-integer count
    device='cuda:0'
)

In [None]:
# Train a sparse autoencoder on items drawn from `buffer`; the result is
# bound to `ae`.
ae = trainSAE(
    buffer,
    activation_dim=512,  # presumably the width of the activations at `submodule` — TODO confirm
    dictionary_size = 8 * 512,  # NOTE(review): the autoencoder loaded in the "Load SAEs" cell is built with 4*512 instead
    steps=1000,
    lr = 1e-3,
    sparsity_penalty = 3e-4,
    entropy=False,
    resample_steps = 1000,  # NOTE(review): equal to `steps`; confirm whether resampling is meant to happen in this run
    log_steps = None,
    device='cuda:0'
)

# Load SAEs

In [None]:
# Load a previously trained autoencoder from disk. This rebinds `ae`,
# replacing any autoencoder produced by the training cell.
ae = AutoEncoder(512, 4*512)
# map_location makes the checkpoint loadable regardless of the device it was
# saved from.
ae.load_state_dict(t.load('autoencoders/reg0.0001_entFalse.pt', map_location='cpu'))
# Later cells call ae.encode on CUDA tensors, so the parameters have to live
# on the same device as the model ('cuda:0').
ae = ae.to('cuda:0')

In [None]:
# Pull the next item from the activation buffer; it is passed to ae.encode below.
acts = next(buffer)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Encode one batch of activations; this is analysis only, so no autograd
# graph is needed.
with t.no_grad():
    dict_acts = ae.encode(acts.cuda())
# Fraction of rows on which each dictionary feature is non-zero.
freqs = (dict_acts !=0).sum(dim=0) / dict_acts.shape[0]

# freqs are fractions in [0, 1], so the log-spaced bins only need to span
# 1e-4 .. 1 (the previous upper edge of 4096 left roughly half of the bins
# unreachable). Features that never fire (freq == 0) fall below the first
# bin edge and are not drawn.
plt.hist(freqs.cpu().numpy(), bins=np.logspace(-4, 0, 100))
plt.xscale('log')
plt.xlabel('firing frequency')
plt.ylabel('number of features')

plt.show()

In [None]:
# Indices of the dictionary features that are non-zero on at least one row of dict_acts.
t.nonzero((dict_acts != 0).any(dim=0))

In [None]:
# Display what the autoencoder's encoder returns for `acts` (moved to CUDA first).
ae.encode(acts.cuda())

In [None]:
from einops import rearrange
import torch as t

# Tokenized batch taken from the buffer; `inputs['input_ids']` is fed to the model below.
inputs = buffer.tokenized_batch()
# Run the model on that batch and save the output of `submodule`.
# NOTE(review): binding the context manager as `generator` rebinds the name of
# the data-generator function defined earlier; `data` was already created from
# it, so the buffer is unaffected.
with model.generate(max_new_tokens=1, pad_token_id=model.tokenizer.pad_token_id) as generator:
    with generator.invoke(inputs['input_ids'], scan=False) as invoker:
        hidden_states = submodule.output.save()
# Encode the saved submodule output with the autoencoder.
dictionary_activations = ae.encode(hidden_states.value)
# Merge the first two axes so every row is a single (b, n) position.
flattened_acts = rearrange(dictionary_activations, 'b n d -> (b n) d')
# Fraction of rows on which each feature (last axis) is non-zero.
freqs = (flattened_acts !=0).sum(dim=0) / flattened_acts.shape[0]

In [None]:
# Report every feature whose frequency lies strictly between 3e-3 and 1e-2.
for idx, freq in enumerate(freqs):
    if not (3e-3 < freq < 1e-2):
        continue
    print(f"feat {idx} freq: {freq}")

In [None]:
def list_decode(x, tokenizer=None):
    """Recursively decode token ids while preserving the nesting of the input.

    Args:
        x: A single int token id, or an arbitrarily nested sequence of them.
            Objects exposing `.tolist()` (e.g. torch tensors or numpy arrays)
            are converted to plain Python values first, so they can be passed
            directly instead of failing on their 0-d elements.
        tokenizer: Object with a `decode` method. Defaults to `model.tokenizer`,
            which is what the function used before this parameter existed.

    Returns:
        `tokenizer.decode(x)` for an int, otherwise a list with the same
        nesting as `x` holding the decoded value of every id.
    """
    if tokenizer is None:
        tokenizer = model.tokenizer
    # Normalise tensors / arrays / numpy scalars to Python ints and lists.
    if hasattr(x, 'tolist'):
        x = x.tolist()
    if isinstance(x, int):
        return tokenizer.decode(x)
    return [list_decode(y, tokenizer) for y in x]

In [None]:
# Find the k strongest activations of dictionary feature `feat` in the batch.
k, feat = 30, 400
# Slice out one feature and move it to the CPU (this rebinds `acts`).
acts = dictionary_activations[:, :, feat].cpu()
flattened_acts = rearrange(acts, 'b l -> (b l)')
# Flat indices of the k largest values, largest first.
topk_indices = flattened_acts.argsort(dim=0, descending=True)[:k]
# Turn each flat index back into a (row, position) pair.
batch_indices, token_indices = topk_indices // acts.shape[1], topk_indices % acts.shape[1]

In [None]:
from circuitsvis.activations import text_neuron_activations

# For each top activation, take the token prefix ending at the activating
# position and decode it to strings.
tokens = list_decode([
    inputs['input_ids'][row, : pos + 1].tolist()
    for row, pos in zip(batch_indices, token_indices)
])
# Matching activation prefixes, with two trailing singleton axes appended.
activations = [
    acts[row, : pos + 1, None, None]
    for row, pos in zip(batch_indices, token_indices)
]
text_neuron_activations(tokens, activations)

In [None]:
# Display the row index of each of the top-k activations.
batch_indices

In [None]:
# Decode and print the token ids of the batch entry at index 40.
# NOTE(review): `inputs[40].ids` presumably relies on the tokenizer output
# exposing per-row encodings with an `.ids` attribute — confirm; the index 40
# is hard-coded.
print(model.tokenizer.decode(inputs[40].ids))

In [None]:
import numpy as np
import torch
# NOTE(review): `dataset` and `tokenizer` are not assigned anywhere in the
# visible notebook — confirm they exist before running this cell. The cell
# also presumably expects `dictionary_activations` to be a flat 1-D tensor of
# length num_datapoints * 128 (as allocated in the following cell), not the
# 3-D tensor produced by the tracing cell above — verify.
k = 10
# Flat indices of the k largest activations, largest first.
found_indices = torch.argsort(dictionary_activations, descending=True)[:k]
num_datapoints = dictionary_activations.shape[0] // 128
# Map each flat index to (datapoint, position). The row count comes from the
# data instead of a hard-coded 64, and the indices are converted to Python
# values so numpy never receives (possibly CUDA) tensors.
datapoint_indices = [
    np.unravel_index(i, (num_datapoints, 128)) for i in found_indices.tolist()
]
text_list = []
full_text = []
token_list = []
full_token_list = []
for md, s_ind in datapoint_indices:
    md = int(md)
    s_ind = int(s_ind)
    # Full token sequence of the datapoint and its decoded text.
    full_tok = torch.tensor(dataset[md]["input_ids"])
    full_text.append(tokenizer.decode(full_tok))
    # Prefix ending at (and including) the activating position.
    tok = dataset[md]["input_ids"][:s_ind+1]
    text = tokenizer.decode(tok)
    text_list.append(text)
    token_list.append(tok)
    full_token_list.append(full_tok)
text_list, full_text, token_list, full_token_list

In [None]:
# Now we can use the model to get the activations
# NOTE(review): this cell reads like an unfinished template. `max_length`,
# `d_model`, `dataset` and `batched_dictionary_activations` are not assigned
# anywhere in the visible notebook (the lines that would produce the last one
# are commented out below) — confirm they are defined before running it.
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from einops import rearrange
import torch 
# num_features, d_model = autoencoder.encoder.shape # Fix this for shape purposes
# Only the number of texts in the buffer batch is used from here on.
texts = buffer.text_batch()
datapoints = len(texts)
batch_size = 64
# Preallocate zero-filled storage with one row / entry per (datapoint, position).
neuron_activations = torch.zeros((datapoints*max_length, d_model))
# This rebinds `dictionary_activations`, which earlier cells use for the encoder output.
dictionary_activations = torch.zeros((datapoints*max_length))

with torch.no_grad(), dataset.formatted_as("pt"):
    dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
    for i, batch in enumerate(tqdm(dl)):
        # Replace this with your residual stream stuff
        # _, cache = model.run_with_cache(batch.to(device))
        # batched_neuron_activations = rearrange(cache[cache_name], "b s n -> (b s) n" )

        # Replace with your projection to probe direction
        # batched_dictionary_activations = smaller_auto_encoder.encode(batched_neuron_activations)
        # Write this batch's values into its slice of the flat buffer.
        dictionary_activations[i*batch_size*max_length:(i+1)*batch_size*max_length] = batched_dictionary_activations.cpu()

In [None]:
import torch as t
def entropy(p):
    """Mean entropy of the rows of `p` after normalising each row to sum to 1.

    Rows are normalised along the last axis. NaN / infinite values that arise
    from zero entries (log(0)) or all-zero rows (0/0) are replaced via
    `nan_to_num`, so an all-zero row contributes an entropy of 0.

    Args:
        p: Tensor of non-negative weights; the last axis is the distribution axis.

    Returns:
        Scalar tensor holding the entropy averaged over all rows.
    """
    probs = p / p.sum(dim=-1, keepdim=True)
    safe_log = t.nan_to_num(t.log(probs))
    per_row = t.nan_to_num(-(probs * safe_log).sum(dim=-1))
    return per_row.mean()

In [None]:
# Try entropy() on a batch that contains an all-zero row next to a regular row.
x = t.Tensor([[0, 0,0], [1, 4, 2]])
entropy(x)

In [None]:
# Display the test tensor.
x

In [None]:
# Elementwise x * log(x); entries where x == 0 evaluate 0 * log(0), i.e. 0 * -inf.
x * x.log()

In [None]:
# Plain-Python check of the same product: 0 * -inf evaluates to nan.
0 * float("-inf")

In [None]:
import torch as t

def entropy(p):
    """Autograd-friendly mean entropy of the rows of `p`.

    Each row (last axis) is normalised to sum to one, with a small epsilon in
    the denominator and inside the logarithm so that neither a division by
    zero nor log(0) occurs. Rows whose sum is not positive contribute an
    entropy of exactly zero.

    Args:
        p: Tensor of non-negative weights; the last axis is the distribution axis.

    Returns:
        Scalar tensor: per-row entropies summed over the last axis, then
        averaged over the remaining entries.
    """
    eps = 1e-8

    # Row totals, kept as a trailing singleton axis so they broadcast.
    totals = p.sum(dim=-1, keepdim=True)
    has_mass = totals > 0

    # Out-of-place normalisation; eps guards against an all-zero row.
    probs = p / (totals + eps)

    # -q * log(q + eps): zero wherever q is zero, and finite everywhere.
    pointwise = -(probs * t.log(probs + eps))

    # Rows without any mass are forced to contribute nothing.
    pointwise = t.where(has_mass, pointwise, t.zeros_like(pointwise))

    return pointwise.sum(dim=-1).mean()

# Example usage:
# NOTE(review): `batch_size` and `vector_length` are not read by anything in
# this cell; assigning `batch_size` rebinds the value set by an earlier cell.
batch_size = 3
vector_length = 5
p = t.tensor(
    [
        [0.1, 0.2, 0.7, 0.0, 0.0],
        [0.0, 0.0, 0.0, 0.0, 0.0],  # All-zero vector
        [0.3, 0.3, 0.4, 0.0, 0.0],
    ],
    requires_grad=True,
)

# Forward pass, then backpropagate to confirm the gradients can be computed.
entropy_value = entropy(p)
entropy_value.backward()

print("Entropy:", entropy_value.item())
print("Gradients:", p.grad)