In [1]:
"""
Sample from a trained model
"""
import os
from contextlib import nullcontext
import torch
import torch.nn as nn 
import torch.nn.functional as F
from model import GPTConfig, GPT
import numpy as np
import wandb
import time

In [2]:
# Other things to consider:
# Should I save activations as float16 instead of float32 to save space? This should be okay since we don't need that much precision I think.

In [3]:
# Things to do:
# 1. figure out CPU/GPU issues
# 2. Possible memory issues with activations
# 3. Feature density histograms
# 4. Neuron resampling

In [4]:
# load transformer training dataset and define get_batch
dataset = 'shakespeare_char'
data_dir = os.path.join('data', dataset)
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
batch_size = 6000
block_size = 12
val_data = None # not loading val_data for now
device = 'cpu'
device_type = 'cuda' if 'cuda' in device else 'cpu'
def get_batch(split): # not modifying this function from nanoGPT train.py but will always just pass split='train'
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [5]:
## load the pre-trained transformer model 
out_dir = 'out-shakespeare-char' # ignored if init_from is not 'resume'
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
compile = False # TODO: Don't know why I needed to set compile to False before loading the model..
# TODO: I dont know why the next 4 lines are needed. state_dict does not seem to have any keys with unwanted_prefix.
unwanted_prefix = '_orig_mod.' 
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

model.eval()
model.to(device)
if compile:
    model = torch.compile(model) # requires PyTorch 2.0 (optional)

number of parameters: 0.21M


In [6]:
## get contexts
seed = 1337
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

In [7]:
def initial_data(b, seed=0, n=256, t=1024, train_data=train_data):
    # get b contexts, n < t tokens 
    # returns b*n activation vectors
    assert n <= t, "Number of tokens chosen must not exceed context window length"

    torch.manual_seed(seed)
    ix = torch.randint(len(train_data) - block_size, (b,))
    contexts = torch.stack([torch.from_numpy((train_data[i:i+block_size]).astype(np.int64)) for i in ix]) # (b, t)
    activations = model.get_gelu_acts(contexts) # (b, t, n_ffwd)
    
    # sample n tokens from each context and flatten the batch and token dimension
    data = torch.stack([activations[i, torch.randint(t, (n,)), :] for i in range(b)]).view(-1, activations.shape[-1]) #(b*n, n_ffwd)

    # randomly shuffle all activation vectors and return
    return data[torch.randperm(b*n)] 

def refill_data(data, seed=0, b=100, n=256, t=1024):
    # remove the first N//2 contexts as they have already been used 
    # fill new contexts and shuffle again
    torch.manual_seed(seed)
    N, n_ffwd = data.shape # N = b*n/2
    assert N == b * n, "there is some issue with shape of data"
    data = data[N//2:] # remove the first half of activation vectors 
    ix = torch.randint(len(train_data) - block_size, (b//2,)) # pick new b//2 contexts
    contexts = torch.stack([torch.from_numpy((train_data[i:i+block_size]).astype(np.int64)) for i in ix]) # (b//2, t)
    activations = model.get_gelu_acts(contexts) # (b//2, t, n_ffwd)

    # sample n tokens from each context and flatten the batch and token dimension
    new_data = torch.stack([activations[i, torch.randint(t, (n,)), :] for i in range(b//2)]).view(-1, n_ffwd) # (b//2 * n, n_ffwd)
    data = torch.cat((data, new_data)) 
    return data[torch.randperm(n * b)]

In [8]:
class AutoEncoder(nn.Module):
    def __init__(self, n, m, lam=0.003):
        # for us, n will be d_MLP and m will be the number of features
        super().__init__()
        self.enc = nn.Linear(n, m)
        self.relu = nn.ReLU()
        self.dec = nn.Linear(m, n)
        self.lam = lam # coefficient of L_1 loss

    def forward(self, acts):
        # acts is of shape (b, n) where b = batch_size, n = d_MLP
        x = acts - self.dec.bias # (b, n), dtype =
        f = self.relu(self.enc(x)) # (b, m)
        x = self.dec(f) # (b, n)
        mseloss = F.mse_loss(x, acts) # scalar
        l1loss = F.l1_loss(f, torch.zeros(f.shape), reduction='sum') # scalar
        loss = mseloss + self.lam * l1loss # scalar
        out = {'mse_loss': mseloss, 'l1loss': l1loss, 
                'loss': loss, 'recons_acts': x, 'f': f}
        return loss, out

In [9]:
ex = torch.randint(0, 3, (3, 4))
print(ex)
print(torch.count_nonzero(ex, dim=-1))
print(torch.mean(torch.count_nonzero(ex, dim=-1), dtype=torch.float32))

tensor([[1, 1, 2, 1],
        [2, 1, 2, 2],
        [0, 2, 2, 0]])
tensor([4, 4, 2])
tensor(3.3333)


In [10]:
def update_grad(grad, weight):
    # remove gradient information parallel to weight vectors
    
    # compute projection of gradient onto weight
    # recall proj_b a = (a.\hat{b}) \hat{b} is the projection of a onto b

    unit_w = F.normalize(weight, dim=0) # \hat{b}
    proj = torch.sum(grad * unit_w, dim=0) * unit_w 

    return grad - proj

In [11]:
wandb_log = True
batch_size = 10
n_steps = 3000
block_size = 12 # length of context window
n_tokens = block_size//4 # number of tokens from each context
contexts_in_buffer = 50 * batch_size # number of contexts in buffer

# let it be an even multiple of batch size so that after an integer number of steps, buffer is exactly half-used
assert contexts_in_buffer % (2*batch_size) == 0, "adjust contexts_in_buffer so that it is an even multiple of batch_size"

# There are contexts_in_buffer * n_tokens activations in the buffer
refill_interval = int(contexts_in_buffer * n_tokens/(2*batch_size))

# load initial data
data = initial_data(b=contexts_in_buffer, n=n_tokens, t=block_size) 
print(f"data has shape {tuple(data.shape)}")

torch.manual_seed(0)
n_features = 1024 # change this to 4096 for owt
d_mlp = data.shape[-1] # MLP activation dimension
sae = AutoEncoder(d_mlp, n_features, lam=1e-3)
optimizer = torch.optim.Adam(sae.parameters(), lr=3e-4)
batch = 0

if wandb_log:
    wandb.init(project=f'sae-{dataset}', name=f'sae_{dataset}_{time.time()}')

for i in range(n_steps):    
    if i > 0 and i % refill_interval == 0:
        print(f'updating data buffer after {i} steps')
        data = refill_data(data, seed=i, b=contexts_in_buffer, n=n_tokens, t=block_size)
        batch = 0

    curr_batch = data[batch * batch_size: (batch + 1) * batch_size]
    loss, out = sae(curr_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()

    # remove gradient information parallel to the decoder columns
    sae.dec.weight.grad = update_grad(sae.dec.weight.grad, sae.dec.weight)
    optimizer.step()

    # normalize decoder columns
    sae.dec.weight = nn.Parameter(F.normalize(sae.dec.weight, dim=0))

    batch += 1

    if i % 100 == 0:
        
        xs, ys = get_batch('train')
        reconstructed_nll_loss = model.reconstructed_loss(sae, xs, ys)
        
        print(f"batch: {i}/{n_steps}, mse loss: {out['mse_loss'].item():.2f}, l1_loss: {out['l1loss'].item():.2f}, \
              total_loss = {loss.item():.2f}, nll loss: {reconstructed_nll_loss}")

        if wandb_log:
            wandb.log({'losses/mse_loss': out['mse_loss'].item(),
                    'losses/l1_loss': out['l1loss'].item(),
                    'losses/total_loss': loss.item(),
                    'losses/nll_loss': reconstructed_nll_loss,
                    'debug/l0_norm': torch.mean(torch.count_nonzero(out['f'], dim=-1), dtype=torch.float32),
                    'debug/mean_dictionary_vector_length': torch.mean(torch.linalg.vector_norm(sae.dec.weight, dim=0)),
                    })
    
    # if i > 0 and i % 1000 == 0: # plot the feature density histograms
    #     # pick 10 tokens each from 1000 contexts and count, for each autoencoder neuron, the number of tokens on which its output > 0. 
    #     raise NotImplementedError
    #     #TODO: compute feature density histograms

    # if i > 0 and i % 25000 == 0:
    # TODO: resample neurons

if wandb_log:
    wandb.finish()

data has shape (1500, 512)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mshehper[0m. Use [1m`wandb login --relogin`[0m to force relogin


batch: 0/3000, mse loss: 0.09, l1_loss: 664.95,               total_loss = 0.76, nll loss: 4.876166343688965
updating data buffer after 75 steps
batch: 100/3000, mse loss: 0.05, l1_loss: 8.36,               total_loss = 0.06, nll loss: 10.940332412719727
updating data buffer after 150 steps
batch: 200/3000, mse loss: 0.10, l1_loss: 2.06,               total_loss = 0.10, nll loss: 11.48123836517334
updating data buffer after 225 steps
updating data buffer after 300 steps
batch: 300/3000, mse loss: 0.07, l1_loss: 0.11,               total_loss = 0.07, nll loss: 11.974457740783691
updating data buffer after 375 steps
batch: 400/3000, mse loss: 0.08, l1_loss: 1.76,               total_loss = 0.08, nll loss: 10.865640640258789
updating data buffer after 450 steps
batch: 500/3000, mse loss: 0.08, l1_loss: 2.99,               total_loss = 0.09, nll loss: 9.930020332336426
updating data buffer after 525 steps
updating data buffer after 600 steps
batch: 600/3000, mse loss: 0.10, l1_loss: 0.06, 



0,1
debug/l0_norm,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
debug/mean_dictionary_vector_length,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,▇▂▇▅▅▆█▂▅▅▄▅▄▅▆▆▅▅▃█▄▃▅▅▄▁▆▄▇▅
losses/nll_loss,▂▇██▇▆▅▄▃▂▂▂▁▁▂▁▂▂▁▂▁▁▁▂▂▂▂▂▁▁
losses/total_loss,█▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▂▁

0,1
debug/l0_norm,0.7
debug/mean_dictionary_vector_length,1.0
losses/l1_loss,0.16536
losses/mse_loss,0.0738
losses/nll_loss,4.65763
losses/total_loss,0.07396


In [19]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
import numpy as np

# Sample data
data = [np.random.normal(0, 1, 100) + i for i in range(10)]

fig, ax = plt.subplots()
bins = np.linspace(-5, 15, 30)

def animate(i):
    ax.clear()
    ax.hist(data[i], bins=bins, color='blue', alpha=0.7)
    ax.set_title(f"Histogram at Step {i}")

ani = FuncAnimation(fig, animate, frames=len(data), interval=500)

# Convert animation to HTML5 video and display
html_video = HTML(ani.to_html5_video())

# Close the figure to prevent displaying the static image
plt.close(fig)

# Display the HTML video
html_video


In [26]:
# use this to avoid showing matplotlib figures in output; change the settings by matplotlib inline
#%matplotlib agg

In [33]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import io

wandb.init(project='trying-tables')

for i in range(5):
    # Sample data for the histogram
    data = np.random.randn(1000)

    # Create a figure and axis for the plot
    fig, ax = plt.subplots()

    # Create a histogram plot
    ax.hist(data, bins=30)
    ax.set_title('Histogram')

    # Save the plot to a buffer
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)

    # Convert buffer to PIL Image
    image = Image.open(buf)

    wandb.log({"example_rgb_image": wandb.Image(image),
               "step": i*5})

wandb.finish()





0,1
step,▁▃▅▆█

0,1
step,20


In [None]:
# This is good. But I want to do is something of the following form:

# for step in range of steps:
# ----- # perform training 
# ----- if step = 10000 or something:
# ----- ----- # TODO: how to calculate the feature density? how much data to use, etc?
              # I THINK I can take 10 million contexts, pick 10 tokens in each context and calculate feature activations on these 100 million tokens
              # TODO: should I choose different 10 million contexts every time during training?
              # perhaps I need to log the top 10 tokens (and a context of 4 tokens on each side) for each alive feature
                # perhaps there should be a widget that you can scroll through to see the top 10 tokens (and their contexts) for each 
# ----- ----- TECHNICALLY EASY: log the number of features in the high density cluster; TODO: It's hard to define a cutoff
# ----- ----- TECHNICALLY EASY: log the minimimum feature density of the high density cluster; TODO: It's hard to define a cutoff
# ----- ----- TECHNICALLY EASY: log the number of features with density above 1%; if it's too high, increase the L1 coefficient 
# ----- ----- TECHNICALLY EASY: plotting the histogram
# ----- ----- TECHNICALLY EASY: log the number of alive autoencoder neurons (i.e. those in the high density cluster + those in the ultralow density cluster)
# ----- ----- TECHNICALLY EASY: log the minimum feature density amongst alive autoencoder neurons
# ----- ----- 

In [35]:
import wandb
import numpy as np

# Initialize a wandb run
wandb.init(project='trying-tables')

for i in range(5):
    # Example array
    array = np.random.rand(10, 5)  # A 10x5 array

    # Convert array to wandb.Table
    table = wandb.Table(data=array, columns=[f"Col{i}" for i in range(array.shape[1])])

    # Log the table
    wandb.log({"my_array_table": table})

# Finish the run
wandb.finish()




In [None]:
# For manual inspection, perhaps I could use a table 

In [None]:
# Autoencoder neurons are going to be of three kinds.
# 1. In ultralow density cluster
# 2. In high density cluster
# 3. dead

# Ideally we want to minimize the number of neurons that are dead or are in the ultralow density cluster. That's perhaps where neuron resampling comes in?

In [None]:
# Feature Density Histograms: Specific metrics from these histograms include:
# The number of alive features outside of the ultralow density cluster
# The minimum feature density at which we see a significant number of non-ultralow-density-cluster features.
# The number of features with density above 1%. A significant number of features above this level seems to correspond to an L1 coefficient that is too low.

#### Comments

1. Load the model checkpoint.
2. Evaluate the model on a bunch of (N) contexts from the dataset. This should likely be done after loading train.bin and val.bin
as is done in get_batch. 
3. After evaluation, obtain the activations of the linear layer in Transformer and save them somewhere.

I trained a 1 layer LM on Shakespeare dataset. It achieved training loss of 1.796 and validation loss of 1.920. block size was 64, batch size was 12, n_embd was 128, so n_ffwd was 512. 

With block size of 64 and batch size of 12, the number of tokens processed in each training step were 768. I trained for 2000 iterations so the total number of tokens was ~1.54M.

Now the Anthropic paper had trained on 100B tokens and collected a dataset of 10B activation vectors to train the autoencoder (by sampling activation vectors for 250 tokens each in 40 million contexts). Out of this, they used around 8.2B activation vectors for training the autoencoder. They trained for 1 million steps with batch size of 8192 (activation vectors where each vector is of length 512). 

For this work, I will ignore the validation dataset and just work with the training data (for now). I will, for now, choose around 1e5 contexts and sample 6 activation vectors to obtain a datset of 6e5 activation vectors. (I dont have a concrete reason for choosing this number of contexts contexts: just that my dataset is already pretty small, i.e. of 1.54M tokens only while Anthropic had 100B tokens so I might not be able to choose too many independent data points.)

So I think that all of the activations in the autoencoder dataset can be saved in one torch tensor or numpy array. 