In [36]:
"""
Sample from a trained model
"""
import os
import pickle
from contextlib import nullcontext
import torch
import torch.nn as nn 
import torch.nn.functional as F
import tiktoken
from model import GPTConfig, GPT
import numpy as np
import wandb
import time

1. Load the model checkpoint.
2. Evaluate the model on N contexts from the dataset. This should likely be done after loading train.bin and val.bin, as is done in get_batch.
3. After evaluation, obtain the activations of the Transformer's MLP hidden layer (after the GELU) and save them somewhere.

I trained a 1-layer LM on the character-level Shakespeare dataset. It achieved a training loss of 1.796 and a validation loss of 1.920. The block size was 64, the batch size was 12 and n_embd was 128, so n_ffwd was 4 × 128 = 512.

With a block size of 64 and a batch size of 12, the number of tokens processed in each training step was 768. I trained for 2000 iterations, so the total number of tokens was 768 × 2000 ≈ 1.54M.

In comparison, the Anthropic paper trained its transformer on 100B tokens and collected a dataset of 10B activation vectors to train the autoencoder (by sampling the activation vectors of 250 tokens in each of 40 million contexts). Of these, they used around 8.2B activation vectors for training the autoencoder. They trained for 1 million steps with a batch size of 8192 activation vectors, each of length 512.

For this work, I will ignore the validation dataset and just work with the training data (for now). I will choose around 1e5 contexts and sample 6 activation vectors from each, giving a dataset of 6e5 activation vectors. (I don't have a concrete reason for choosing this number of contexts, except that my dataset is already quite small — only 1.54M tokens, while Anthropic had 100B — so I might not be able to choose too many independent data points.)

So I think that all of the activations in the autoencoder dataset can be saved in one torch tensor or numpy array. 

In [10]:
# Load the transformer's training split (memory-mapped uint16 token ids) and set the
# globals that get_batch reads.
dataset = 'shakespeare_char'
data_dir = os.path.join('data', dataset)
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), mode='r', dtype=np.uint16)
val_data = None  # validation split intentionally not loaded for now
batch_size = 6000  # contexts per get_batch() call
block_size = 12    # tokens per context
device = 'cpu'
device_type = 'cpu' if 'cuda' not in device else 'cuda'
def get_batch(split):
    """Sample `batch_size` random windows of `block_size` tokens plus their next-token targets.

    The logic follows nanoGPT's train.py get_batch. `split == 'train'` reads train_data; any
    other value reads val_data, which this notebook leaves as None, so only 'train' is usable.

    Returns:
        (x, y): int64 tensors of shape (batch_size, block_size) on `device`, where
        y[:, j] is the token that follows x[:, j] in the data.
    """
    source = train_data if split == 'train' else val_data
    starts = torch.randint(len(source) - block_size, (batch_size,))
    windows = [torch.from_numpy(source[s:s + block_size].astype(np.int64)) for s in starts]
    shifted = [torch.from_numpy(source[s + 1:s + 1 + block_size].astype(np.int64)) for s in starts]
    x, y = torch.stack(windows), torch.stack(shifted)
    if device_type != 'cuda':
        return x.to(device), y.to(device)
    # pinned host memory lets the host->GPU copy run asynchronously (non_blocking=True)
    return x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)

In [11]:
## load the pre-trained transformer model
out_dir = 'out-shakespeare-char'  # directory holding the nanoGPT training checkpoint
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
# Checkpoints saved from a torch.compile()-wrapped model store their keys under an
# '_orig_mod.' prefix; strip it so the keys match the plain GPT module. For a checkpoint
# saved without compilation (none of these keys were observed here) this loop is a no-op.
unwanted_prefix = '_orig_mod.'
for k in list(state_dict):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

model.eval()
model.to(device)
# Compilation is an optional speed-up (requires PyTorch 2.0). It is applied only after
# loading, so this flag has no effect on loading itself. The name compile_model avoids
# shadowing Python's builtin compile().
compile_model = False
if compile_model:
    model = torch.compile(model)

number of parameters: 0.21M


In [4]:
# ## get contexts
# seed = 1337
# dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
# torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
# torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
# device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
# ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

The contexts we obtain have shape (b, t) = (6000, 12), i.e. (batch_size, block_size).


Assume b is even. That's okay.

At some point, replace torch.randint in the definition of ix with something based on torch.randperm, so that distinct contexts are picked each time. This would help ensure that no data is repeated during training.

For now I can leave the definition of ix as it is just to get some baselines. 

In [46]:
# Split train_data into consecutive, non-overlapping chunks of block_size tokens:
#   chunk k = train_data[k*block_size : (k+1)*block_size]
# Leftover tokens at the end (fewer than block_size) do not form a chunk.
n_chunks = divmod(len(train_data), block_size)[0]
chunks_permuted = torch.randperm(n_chunks)  # chunk indices in random order

In [44]:
class count_chunks:
    """Simple accumulator: starts at 0 and adds `b` on every increment() call.

    Judging by its name, it is intended to track how many chunks have been used so far.
    """

    def __init__(self):
        # running total, updated only by increment()
        self.count = 0

    def increment(self, b):
        """Add `b` (e.g. the number of chunks consumed by the latest batch) to the total."""
        self.count += b

In [51]:
torch.randint(0, 1000, (3,))  # scratch check: 3 random integers drawn uniformly from [0, 1000)

tensor([281, 464, 727])

In [45]:
# Scratch check of the chunk counter. This cell previously instantiated `Count()`, a name
# no cell in this notebook defines (presumably left over from a deleted cell); the class
# is `count_chunks`.
counter = count_chunks()
counter.increment(2)
counter.count

2

In [12]:
def _sample_activations(n_contexts, n, t, data_source):
    """Return (n_contexts * n, n_ffwd) activations from random windows of `data_source`.

    Draws `n_contexts` random windows of `t` tokens, runs them through the global `model`
    via `model.get_gelu_acts`, and keeps `n` distinct token positions from each window.
    """
    ix = torch.randint(len(data_source) - t, (n_contexts,))
    contexts = torch.stack([torch.from_numpy(data_source[i:i+t].astype(np.int64)) for i in ix])  # (n_contexts, t)
    # These activations are training *data* for the SAE: don't record the transformer's
    # autograd graph (it wastes memory and would let SAE backward passes reach the model).
    with torch.no_grad():
        activations = model.get_gelu_acts(contexts)  # (n_contexts, t, n_ffwd)
    # randperm(t)[:n] samples positions without replacement, so no activation vector is
    # duplicated within a context (this is what the n <= t requirement is for).
    picked = [activations[i, torch.randperm(t)[:n], :] for i in range(n_contexts)]
    return torch.stack(picked).reshape(-1, activations.shape[-1])  # (n_contexts * n, n_ffwd)


def initial_data(b, seed=0, n=256, t=1024, train_data=train_data):
    """Build the first buffer of SAE training data.

    Samples `b` random contexts of `t` tokens from `train_data`, keeps the activations of
    `n` distinct token positions per context, and returns all b*n vectors in random order,
    shape (b*n, n_ffwd). Reseeds the global torch RNG with `seed`.
    """
    assert n <= t, "Number of tokens chosen must not exceed context window length"

    torch.manual_seed(seed)
    data = _sample_activations(b, n, t, train_data)

    # randomly shuffle all activation vectors and return
    return data[torch.randperm(data.shape[0])]


def refill_data(data, seed=0, b=100, n=256, t=1024):
    """Replace the already-used first half of a shuffled activation buffer with fresh data.

    Drops data[:N//2] (N = len(data)), appends the activations of n distinct positions from
    each of b//2 new contexts, and returns the reshuffled buffer. When N == b*n and b is even,
    the buffer size is unchanged. Reseeds the global torch RNG with `seed`, so pass a
    different seed on every call, or the same contexts will be drawn again.
    """
    assert n <= t, "Number of tokens chosen must not exceed context window length"

    torch.manual_seed(seed)
    kept = data[data.shape[0] // 2:]  # the buffer is shuffled, so this is a random half
    fresh = _sample_activations(b // 2, n, t, train_data)  # (n * (b//2), n_ffwd)
    buffer = torch.cat((kept, fresh))
    # shuffle using the real buffer size: n*b is only right when len(data) == n*b and b is even
    return buffer[torch.randperm(buffer.shape[0])]

In [37]:
class AutoEncoder(nn.Module):
    """Sparse autoencoder over activation vectors.

    The decoder bias is subtracted from the input before encoding:
        f     = ReLU(enc(acts - dec.bias))      # (..., m), non-negative features
        x_hat = dec(f)                          # (..., n)
        loss  = MSE(x_hat, acts) + lam * sum(|f|)

    Args:
        n: input/output dimension (d_MLP in this notebook).
        m: number of features.
        lam: coefficient of the L1 sparsity penalty.
    """

    def __init__(self, n, m, lam=0.003):
        super().__init__()
        self.enc = nn.Linear(n, m)
        self.relu = nn.ReLU()
        self.dec = nn.Linear(m, n)
        self.lam = lam  # coefficient of L_1 loss

    def forward(self, acts):
        """Encode and reconstruct `acts` of shape (..., n).

        Returns:
            (loss, out), where out contains:
              'recons_loss': mean squared error, averaged over all elements;
              'l1loss': sum of |f| over the whole batch (it grows with batch size,
                        unlike the averaged MSE term, so `lam` is batch-size dependent);
              'loss': recons_loss + lam * l1loss;
              'recons_acts': reconstruction of shape (..., n);
              'f': feature activations of shape (..., m).
        """
        x = acts - self.dec.bias  # (.., n)
        f = self.relu(self.enc(x))  # (.., m)
        x = self.dec(f)  # (.., n)
        recons_loss = F.mse_loss(x, acts)  # scalar
        # Same value as F.l1_loss(f, zeros, reduction='sum'), but without allocating a
        # torch.zeros target, which lived on the CPU in the default dtype regardless of
        # f's device/dtype (breaking GPU or float64 use).
        l1loss = f.abs().sum()  # scalar
        loss = recons_loss + self.lam * l1loss  # scalar
        out = {'recons_loss': recons_loss, 'l1loss': l1loss,
               'loss': loss, 'recons_acts': x, 'f': f}
        return loss, out

In [38]:
block_size = 12  # length of context window
n_tokens = block_size // 4  # number of token positions sampled from each context
n_buffer_contexts = 100  # number of contexts in a buffer
data = initial_data(b=n_buffer_contexts, n=n_tokens, t=block_size)
print(data.shape)
# Use a different seed than initial_data: both functions reseed the global RNG, so reusing
# seed=0 restarts the same random stream and would (very likely) redraw contexts that are
# already in the buffer. Keep the tensor itself (not its .shape) so later cells can use it.
data = refill_data(data, seed=1, b=n_buffer_contexts, n=n_tokens, t=block_size)
print(data.shape)

torch.Size([300, 512])


In [None]:
n_contexts = divmod(len(train_data), block_size)[0]  # total number of non-overlapping block_size contexts on which we will train



In [14]:
# Train the sparse autoencoder on pre-collected activations and log metrics to wandb.
# NOTE(review): `sae_data` is not created by any cell in this notebook (it presumably leaked
# from a deleted cell). It must be an (N, d_mlp) activation tensor with
# N >= n_batches * sae_batch_size; the buffer built by initial_data()/refill_data() (`data`,
# 300 vectors) is far smaller than that.
wandb.init(project=f'sae-{dataset}',
           name=f'sae_{dataset}_{time.time()}')

torch.manual_seed(0)
d_mlp = sae_data.shape[1]
n_features = 1024

sae = AutoEncoder(d_mlp, n_features, lam=1e-3)

batch_size = 600     # read by get_batch() for the periodic NLL evaluation below
sae_batch_size = 10  # activation vectors per SAE optimisation step
n_batches = 3600
log_interval = 100
assert len(sae_data) >= n_batches * sae_batch_size, \
    f"sae_data has {len(sae_data)} vectors; need {n_batches * sae_batch_size}"

optimizer = torch.optim.Adam(sae.parameters(), lr=3e-4)
try:
    for batch in range(n_batches):
        curr_batch = sae_data[batch * sae_batch_size:(batch + 1) * sae_batch_size]
        loss, out = sae(curr_batch)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        # TODO: remove gradient information parallel to the decoder columns
        optimizer.step()
        # TODO: normalize decoder columns (sae.dec.weight) to unit norm
        # TODO: normalize the reconstruction loss
        # TODO: also log decoder column lengths

        if batch % log_interval == 0:
            # Evaluation only: build no autograd graph, and hand wandb plain Python floats
            # (a tensor with grad_fn cannot be converted to numpy for logging).
            with torch.no_grad():
                xs, ys = get_batch('train')
                _, reconstructed_nll_loss = model.reconstructed_loss(sae, xs, ys)
                l0_norm = torch.count_nonzero(out['f'], dim=-1).float().mean()
            print(f"batch: {batch}/{n_batches}, recons loss: {out['recons_loss'].item():.2f}, l1_loss: {out['l1loss'].item():.2f}, total_loss = {loss.item():.2f}")
            wandb.log({"recons_loss": out['recons_loss'].item(),
                       "l1_loss": out['l1loss'].item(),
                       "total_loss": loss.item(),
                       "l0_norm": l0_norm.item(),
                       'nll_loss': reconstructed_nll_loss.item(),
                       })
finally:
    # close the run even when training is interrupted (e.g. KeyboardInterrupt)
    wandb.finish()

In [15]:
# Sanity check: LM loss on one batch for the unmodified model vs. model.reconstructed_loss
# with a freshly initialised (untrained) SAE.
# NOTE(review): reconstructed_loss presumably substitutes the SAE reconstruction for the
# MLP activations -- confirm in model.py.
torch.manual_seed(0)
d_mlp = sae_data.shape[1]
n_features = 1024

sae = AutoEncoder(d_mlp, n_features, lam=8e-4)
curr_batch = sae_data[0:10]
loss, out = sae(curr_batch)

batch_size = 600  # read by get_batch()
xs, ys = get_batch('train')
with torch.no_grad():  # pure evaluation: no autograd graph needed
    print(model(xs, ys)[1])
    recons_nll = model.reconstructed_loss(sae, xs, ys)[1]
recons_nll

tensor(1.8860, grad_fn=<NllLossBackward0>)


tensor(5.2012, grad_fn=<NllLossBackward0>)

In [150]:
torch.manual_seed(0)
n_features = 1024
d_mlp = sae_data.shape[1]
curr_batch = sae_data[0:10]

sae = AutoEncoder(d_mlp, n_features, lam=8e-4)

# sae.dec is nn.Linear(m, n), so its weight has shape (n, m) = (512, 1024) and column j
# is the decoder direction of feature j. Goal: make every decoder column unit norm.
W_dec = sae.dec.weight
print(W_dec[0])
print(torch.linalg.vector_norm(W_dec, dim=0))                      # column norms as initialised
print(torch.linalg.vector_norm(F.normalize(W_dec, dim=0), dim=0))  # all 1 after normalising
print(W_dec[0])  # unchanged: F.normalize returns a new tensor, it does not modify the weight

tensor([-0.0249,  0.0285, -0.0200,  ..., -0.0105,  0.0050, -0.0005],
       grad_fn=<SelectBackward0>)
tensor([0.4084, 0.4176, 0.4176,  ..., 0.4198, 0.3988, 0.4185],
       grad_fn=<LinalgVectorNormBackward0>)
tensor([1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],
       grad_fn=<LinalgVectorNormBackward0>)
tensor([-0.0249,  0.0285, -0.0200,  ..., -0.0105,  0.0050, -0.0005],
       grad_fn=<SelectBackward0>)


In [None]:
@torch.no_grad()
def estimate_percentage_mlp_loss(eval_iters=20, splits=('train', 'val'), ctx=None):
    """Estimate the transformer's mean LM loss per split, averaged over `eval_iters` batches.

    NOTE(review): despite its name, this currently returns only the plain model loss (it
    calls model(X, Y)); nothing MLP- or SAE-specific is computed yet.

    Args:
        eval_iters: number of get_batch() draws averaged per split (previously read from an
            undefined global).
        splits: splits to evaluate. A split whose data is None (val_data is not loaded in
            this notebook) is skipped instead of crashing inside get_batch.
        ctx: context manager wrapped around each forward pass (e.g. an autocast context);
            defaults to a no-op nullcontext() (previously an undefined global).

    Returns:
        dict mapping split name -> 0-d tensor holding the mean loss.
    """
    ctx = nullcontext() if ctx is None else ctx
    was_training = model.training
    out = {}
    model.eval()
    try:
        for split in splits:
            split_data = train_data if split == 'train' else val_data
            if split_data is None:
                continue
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                with ctx:
                    logits, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the caller's mode instead of forcing train mode: this notebook keeps
        # the transformer in eval mode.
        model.train(was_training)
    return out

batch: 0/3600, recons loss: 0.10, l1_loss: 698.85, total_loss = 0.80
batch: 100/3600, recons loss: 0.06, l1_loss: 2.37, total_loss = 0.07
batch: 200/3600, recons loss: 0.07, l1_loss: 25.37, total_loss = 0.10
batch: 300/3600, recons loss: 0.09, l1_loss: 1.56, total_loss = 0.09
batch: 400/3600, recons loss: 0.07, l1_loss: 0.06, total_loss = 0.07
batch: 500/3600, recons loss: 0.07, l1_loss: 0.00, total_loss = 0.07
batch: 600/3600, recons loss: 0.08, l1_loss: 8.53, total_loss = 0.09
batch: 700/3600, recons loss: 0.11, l1_loss: 1.87, total_loss = 0.11
batch: 800/3600, recons loss: 0.07, l1_loss: 1.37, total_loss = 0.07
batch: 900/3600, recons loss: 0.07, l1_loss: 0.02, total_loss = 0.07
batch: 1000/3600, recons loss: 0.09, l1_loss: 0.22, total_loss = 0.09
batch: 1100/3600, recons loss: 0.07, l1_loss: 0.12, total_loss = 0.07
batch: 1200/3600, recons loss: 0.11, l1_loss: 0.33, total_loss = 0.11


KeyboardInterrupt: 

wandb: Network error (ReadTimeout), entering retry loop.
