In [1]:
"""
Sample from a trained model
"""
import os
from contextlib import nullcontext
import torch
import torch.nn as nn 
import torch.nn.functional as F
from model import GPTConfig, GPT
import numpy as np
import wandb
import time

In [2]:
# Other things to consider:
# Should I save activations as float16 instead of float32 to save space? This should be okay since we don't need that much precision I think.

1. Load the model checkpoint.
2. Evaluate the model on a bunch of (N) contexts from the dataset. This should likely be done after loading train.bin and val.bin
as is done in get_batch. 
3. After evaluation, obtain the activations of the linear layer in Transformer and save them somewhere. 

I trained a 1 layer LM on Shakespeare dataset. It achieved training loss of 1.796 and validation loss of 1.920. block size was 64, batch size was 12, n_embd was 128, so n_ffwd was 512. 

With block size of 64 and batch size of 12, the number of tokens processed in each training step were 768. I trained for 2000 iterations so the total number of tokens was ~1.54M.

Now the Anthropic paper had trained on 100B tokens and collected a dataset of 10B activation vectors to train the autoencoder (by sampling activation vectors for 250 tokens each in 40 million contexts). Out of this, they used around 8.2B activation vectors for training the autoencoder. They trained for 1 million steps with batch size of 8192 (activation vectors where each vector is of length 512). 

For this work, I will ignore the validation dataset and just work with the training data (for now). I will, for now, choose around 1e5 contexts and sample 6 activation vectors to obtain a datset of 6e5 activation vectors. (I dont have a concrete reason for choosing this number of contexts contexts: just that my dataset is already pretty small, i.e. of 1.54M tokens only while Anthropic had 100B tokens so I might not be able to choose too many independent data points.)

So I think that all of the activations in the autoencoder dataset can be saved in one torch tensor or numpy array. 

In [3]:
# load transformer training dataset and define get_batch
dataset = 'shakespeare_char'
data_dir = os.path.join('data', dataset)
train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
batch_size = 6000
block_size = 12
val_data = None # not loading val_data for now
device = 'cpu'
device_type = 'cuda' if 'cuda' in device else 'cpu'
def get_batch(split): # not modifying this function from nanoGPT train.py but will always just pass split='train'
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
    if device_type == 'cuda':
        # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True)
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y

In [4]:
## load the pre-trained transformer model 
out_dir = 'out-shakespeare-char' # ignored if init_from is not 'resume'
ckpt_path = os.path.join(out_dir, 'ckpt.pt')
checkpoint = torch.load(ckpt_path, map_location=device)
gptconf = GPTConfig(**checkpoint['model_args'])
model = GPT(gptconf)
state_dict = checkpoint['model']
compile = False # TODO: Don't know why I needed to set compile to False before loading the model..
# TODO: I dont know why the next 4 lines are needed. state_dict does not seem to have any keys with unwanted_prefix.
unwanted_prefix = '_orig_mod.' 
for k,v in list(state_dict.items()):
    if k.startswith(unwanted_prefix):
        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)

model.eval()
model.to(device)
if compile:
    model = torch.compile(model) # requires PyTorch 2.0 (optional)

number of parameters: 0.21M


In [5]:
## get contexts
seed = 1337
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

In [6]:
def initial_data(b, seed=0, n=256, t=1024, train_data=train_data):
    # get b contexts, n < t tokens 
    # returns b*n activation vectors
    assert n <= t, "Number of tokens chosen must not exceed context window length"

    torch.manual_seed(seed)
    ix = torch.randint(len(train_data) - block_size, (b,))
    contexts = torch.stack([torch.from_numpy((train_data[i:i+block_size]).astype(np.int64)) for i in ix]) # (b, t)
    activations = model.get_gelu_acts(contexts) # (b, t, n_ffwd)
    
    # sample n tokens from each context and flatten the batch and token dimension
    data = torch.stack([activations[i, torch.randint(t, (n,)), :] for i in range(b)]).view(-1, activations.shape[-1]) #(b*n, n_ffwd)

    # randomly shuffle all activation vectors and return
    return data[torch.randperm(b*n)] 

def refill_data(data, seed=0, b=100, n=256, t=1024):
    # remove the first N//2 contexts as they have already been used 
    # fill new contexts and shuffle again
    torch.manual_seed(seed)
    N, n_ffwd = data.shape # N = b*n/2
    assert N == b * n, "there is some issue with shape of data"
    data = data[N//2:] # remove the first half of activation vectors 
    ix = torch.randint(len(train_data) - block_size, (b//2,)) # pick new b//2 contexts
    contexts = torch.stack([torch.from_numpy((train_data[i:i+block_size]).astype(np.int64)) for i in ix]) # (b//2, t)
    activations = model.get_gelu_acts(contexts) # (b//2, t, n_ffwd)

    # sample n tokens from each context and flatten the batch and token dimension
    new_data = torch.stack([activations[i, torch.randint(t, (n,)), :] for i in range(b//2)]).view(-1, n_ffwd) # (b//2 * n, n_ffwd)
    data = torch.cat((data, new_data)) 
    return data[torch.randperm(n * b)]

In [7]:
class AutoEncoder(nn.Module):
    def __init__(self, n, m, lam=0.003):
        # for us, n will be d_MLP and m will be the number of features
        super().__init__()
        self.enc = nn.Linear(n, m)
        self.relu = nn.ReLU()
        self.dec = nn.Linear(m, n)
        self.lam = lam # coefficient of L_1 loss

    def forward(self, acts):
        # acts is of shape (.., n) where .. are batch dimensions
        x = acts - self.dec.bias # (.., n)
        f = self.relu(self.enc(x)) # (.., m)
        x = self.dec(f) # (.., n)
        mseloss = F.mse_loss(x, acts) # scalar
        l1loss = F.l1_loss(f, torch.zeros(f.shape), reduction='sum') # scalar
        loss = mseloss + self.lam * l1loss # scalar
        out = {'mse_loss': mseloss, 'l1loss': l1loss, 
                'loss': loss, 'recons_acts': x, 'f': f}
        return loss, out

In [28]:
wandb_log = True
batch_size = 10
n_steps = 3600
block_size = 12 # length of context window
n_tokens = block_size//4 # number of tokens from each context
contexts_in_buffer = 50 * batch_size # number of contexts in buffer

# let it be an even multiple of batch size so that after an integer number of steps, buffer is exactly half-used
assert contexts_in_buffer % (2*batch_size) == 0, "adjust contexts_in_buffer so that it is an even multiple of batch_size"

# There are contexts_in_buffer * n_tokens activations in the buffer
refill_interval = int(contexts_in_buffer * n_tokens/(2*batch_size))

# load initial data
data = initial_data(b=contexts_in_buffer, n=n_tokens, t=block_size) 
print(f"data has shape {tuple(data.shape)}")

torch.manual_seed(0)
n_features = 1024 # change this to 4096 for owt
d_mlp = data.shape[-1] # MLP activation dimension
sae = AutoEncoder(d_mlp, n_features, lam=1e-3)
optimizer = torch.optim.Adam(sae.parameters(), lr=3e-4)
batch = 0

if wandb_log:
    wandb.init(project=f'sae-{dataset}', name=f'sae_{dataset}_{time.time()}')

for i in range(n_steps):    
    if i > 0 and i % refill_interval == 0:
        print(f'updating data buffer after {i} steps')
        data = refill_data(data, seed=i, b=contexts_in_buffer, n=n_tokens, t=block_size)
        batch = 0

    curr_batch = data[batch * batch_size: (batch + 1) * batch_size]
    loss, out = sae(curr_batch)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()

    #  TODO: remove gradient information parallel to the decoder columns
    optimizer.step()

    # normalize decoder columns
    # sae.dec.weight = nn.Parameter(F.normalize(sae.dec.weight, dim=0))

    batch += 1
    break

    if i % 100 == 0:
        
        xs, ys = get_batch('train')
        reconstructed_nll_loss = model.reconstructed_loss(sae, xs, ys)
        
        print(f"batch: {i}/{n_steps}, mse loss: {out['mse_loss'].item():.2f}, l1_loss: {out['l1loss'].item():.2f}, \
              total_loss = {loss.item():.2f}, nll loss: {reconstructed_nll_loss}")

        if wandb_log:
            wandb.log({'losses/mse_loss': out['mse_loss'].item(),
                    'losses/l1_loss': out['l1loss'].item(),
                    'losses/total_loss': loss.item(),
                    'losses/nll_loss': reconstructed_nll_loss,
                    'debug/l0_norm': torch.mean(torch.count_nonzero(out['f'], dim=-1), dtype=torch.float32),
                    'debug/dictionary_vector_ave_length': torch.mean(torch.linalg.vector_norm(sae.dec.weight, dim=0)),
                    })
        
        #TODO: compute feature density histograms

    # if i > 0 and i % 25000 == 0:
    # TODO: resample neurons

if wandb_log:
    wandb.finish()

data has shape (1500, 512)




In [34]:
sae.dec.weight, sae.dec.weight.grad

(Parameter containing:
 tensor([[-0.0252,  0.0288, -0.0203,  ..., -0.0102,  0.0053, -0.0002],
         [ 0.0047, -0.0002,  0.0252,  ..., -0.0066, -0.0299,  0.0206],
         [-0.0115, -0.0143,  0.0305,  ...,  0.0191, -0.0277, -0.0086],
         ...,
         [-0.0166, -0.0164, -0.0159,  ...,  0.0298,  0.0079, -0.0017],
         [-0.0210, -0.0062,  0.0021,  ..., -0.0260,  0.0216, -0.0265],
         [ 0.0096, -0.0087,  0.0194,  ...,  0.0227,  0.0022, -0.0062]],
        requires_grad=True),
 tensor([[ 4.2569e-05, -1.6896e-04,  4.4277e-06,  ..., -1.2405e-04,
          -3.4062e-06, -1.4735e-05],
         [ 5.4584e-05,  8.3731e-05,  1.1390e-05,  ...,  3.7421e-05,
           4.3136e-06,  4.2958e-05],
         [ 9.4619e-06, -3.2001e-05,  1.0430e-05,  ..., -3.3968e-05,
          -1.7464e-05,  3.5828e-06],
         ...,
         [ 9.5980e-06, -1.7804e-06, -2.0746e-05,  ...,  2.3471e-05,
           3.6533e-06, -9.6159e-06],
         [-6.2396e-06,  1.1661e-05, -2.6781e-05,  ...,  4.7548e-05,
     

In [39]:
a = torch.randint(9, (3, 4)) # weight
b = torch.randint(9, (3, 4)) # grad
a, b

(tensor([[1, 1, 8, 7],
         [3, 8, 5, 0],
         [5, 6, 1, 7]]),
 tensor([[8, 0, 4, 4],
         [5, 7, 6, 0],
         [2, 8, 8, 4]]))

In [42]:
a*b, torch.sum(a*b, dim=0)

(tensor([[ 8,  0, 32, 28],
         [15, 56, 30,  0],
         [10, 48,  8, 28]]),
 tensor([ 33, 104,  70,  56]))

In [65]:
c = b - torch.sum(b * F.normalize(a.to(dtype=torch.float16), dim=0), dim=0)

In [66]:
a*c, torch.sum(a*c, dim=0)

(tensor([[  2.4219, -10.3516, -27.0312, -11.5938],
         [ -1.7344, -26.8125,  -6.8945,  -0.0000],
         [-17.8906, -14.1094,   0.6211, -11.5938]], dtype=torch.float16),
 tensor([-17.2031, -51.2812, -33.3125, -23.1875], dtype=torch.float16))

In [53]:
# compute the component of grad in the direction of weight
c = b - b * (F.normalize(a.to(dtype=torch.float16), dim=0))