|<h2>Course:</h2>|<h1><b><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></b></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Identifying latent factors</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge: Laminar profile of autoencoder sparsity</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, GPT2Tokenizer

# Exercise 1: Import text, tokenize, get MLP activations

In [None]:
# load the pretrained 'gpt2' checkpoint and its matching tokenizer
model     = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# compute device: the first GPU when CUDA is available, otherwise the CPU
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# switch to eval mode, then move the model to the device
# (the last expression is left bare so the notebook prints the model summary)
model.eval()
model.to(device)

In [None]:
# storage for hooked activations, keyed by 'mlp_<layer_number>'
activations = {}

def implant_hook(layer_number):
  """Build a forward hook that records a layer's output in `activations`.

  Args:
    layer_number: index used to form the dictionary key 'mlp_<layer_number>'.

  Returns:
    A callable with the (module, input, output) forward-hook signature that
    stores a detached copy of `output` in the module-level `activations` dict.
  """
  key = f'mlp_{layer_number}'

  def store_output(_module, _inputs, output):
    # detach so the stored tensor is cut from the autograd graph
    activations[key] = output.detach()

  return store_output

# register one forward hook per transformer block, on that block's mlp.c_fc module,
# so each forward pass stores the c_fc output under 'mlp_<layer index>'
n_layers = model.config.n_layer
for layer_idx in range(n_layers):
  fc_module = model.transformer.h[layer_idx].mlp.c_fc
  fc_module.register_forward_hook(implant_hook(layer_idx))

In [None]:
import requests

# download the raw HTML of the Wikipedia article
# - timeout: a stalled connection should not hang the notebook forever
# - raise_for_status: fail loudly on an HTTP error instead of tokenizing an error page
response = requests.get('https://en.wikipedia.org/wiki/Light-emitting_diode', timeout=30)
response.raise_for_status()
text = response.text
print(f'Full website contains {len(text):,} characters')

# keep only the span from the 'mw-body-content' marker up to the References anchor.
# str.find returns -1 when a marker is missing; slicing with -1 would silently
# keep only the last character (start) or drop the last character (stop),
# so each cut is applied only when its marker is actually found.
start = text.find('mw-body-content')
if start != -1:
  text = text[start:]

stop = text.find('id="References"')
if stop != -1:
  text = text[:stop]

print(f'Selected text contains {len(text):,} characters')

In [None]:
# use the end-of-sequence token as the padding token
# NOTE(review): nothing in this cell pads; presumably set to silence tokenizer warnings — confirm
tokenizer.pad_token = tokenizer.eos_token

# tokenize the text
# (the whole selected article string is encoded in one call; `tokens` is reused by later cells)
tokens = tokenizer.encode(text)
print(f'There are {len(tokens):,} tokens')

In [None]:
# unique token ids (u) and how often each occurs (c)
u,c = np.unique(tokens,return_counts=True)

# sort both arrays from most to least frequent (negating the counts gives descending order)
sidx = np.argsort(-c)
u = u[sidx]
c = c[sidx]

# show at most 30 entries; short texts may have fewer unique tokens than that
ntop = min(30,len(u))
print(f'Top {ntop} most common tokens:')
for token_id,count in zip(u[:ntop],c[:ntop]):
  print(f'  {count:4} appearances of "{tokenizer.decode(token_id)}"')

In [None]:
# arrange the first 10*1024 tokens as a (10 x 1024) tensor: 10 sequences of 1024 tokens
# (requires at least 10*1024 tokens, otherwise the reshape fails)
n_batches,seq_len = 10,1024
batches = torch.tensor(tokens[:n_batches*seq_len])
batches = batches.reshape(n_batches,seq_len).to(device)

# bare expression so the notebook displays the resulting shape
batches.shape

In [None]:
# Run the model once on all batches without tracking gradients.
# The model output itself is discarded: the forward hooks registered on the
# MLP layers fill the `activations` dictionary as a side effect of this call.
with torch.no_grad():
  _ = model(batches)

# Exercise 2: Create an autoencoder class in a function

In [None]:
def createTheSAE(num_latent, input_dim=None):
  """Create a k-sparse autoencoder with tied weights and move it to `device`.

  Args:
    num_latent: number of units in the latent layer. The k-sparsity level is
      set to num_latent//3 (the largest third of latent activations per
      sample is kept, the rest are zeroed).
    input_dim: number of input features. When omitted, it is read from the
      number of columns of the module-level data matrix `X`, so existing
      calls of the form createTheSAE(n) keep working.

  Returns:
    A SparseAE (nn.Module) instance on the module-level `device`.
  """

  # create the class
  class SparseAE(nn.Module):
    """Autoencoder with ReLU latents, top-k sparsification, and tied decoder weights."""

    def __init__(self, input_dim, latent_dim, k=None, sparsity_weight=1, decor_weight=.0005):
      super().__init__()

      # the only weight matrix: forward() reuses its transpose as the decoder
      # (tied weights), so there is no separate decoder layer
      self.encoder = nn.Linear(input_dim, latent_dim, bias=False)

      # scaling factors for the two optional penalty terms
      self.sparsity_weight = sparsity_weight
      self.decor_weight = decor_weight

      # k-sparse parameter defaults to 50% of input
      self.k = input_dim//2 if k is None else k

    def forward(self, x):
      """Return (reconstruction, sparsified latent activations) for input x."""

      # forward pass to the latent layer
      latent = F.relu(self.encoder(x))

      # "k-sparsify": force sparsity by zeroing out small activations.
      # topk returns values sorted in descending order, so the last one is the
      # kth-largest value per sample and serves as the keep-threshold.
      topk_vals = torch.topk(latent,self.k,dim=-1)[0]
      thresh = topk_vals[...,-1:]
      mask = (latent >= thresh).float() # mask is 0's and 1's
      latent_sparse = latent * mask

      # decode via tied weights: F.linear(z,W.t()) computes z @ W,
      # mapping (N x latent) back to (N x input)
      y = F.linear(latent_sparse, self.encoder.weight.t())

      return y,latent_sparse

    def sparsity_loss(self, z):
      """L1 penalty: weighted mean absolute latent activation."""
      return self.sparsity_weight * torch.mean(torch.abs(z))

    def decorrelation_loss(self, estLatent):
      """Penalty on inter-latent covariance: weighted sum of squared off-diagonal covariances.

      Expects estLatent as (samples x latents); it is transposed because
      torch.cov treats rows as variables.
      """
      cov = torch.cov(estLatent.T)
      off_diag = cov - torch.diag(torch.diag(cov))
      return self.decor_weight * torch.sum(off_diag**2)

  # backward-compatible fallback: infer the input width from the global data matrix
  if input_dim is None:
    input_dim = X.shape[1]

  # create an instance of the autoencoder
  ae = SparseAE(input_dim=input_dim, k=num_latent//3, latent_dim=num_latent)
  ae = ae.to(device)
  return ae

In [None]:
### quick check of the SAE on one layer's data

# take the hooked activations stored under 'mlp_3' and flatten every leading
# dimension, keeping the last (hidden) dimension: the result is (N x hidden)
layer_acts = activations['mlp_3']
nhidden = layer_acts.shape[-1]
X = layer_acts.reshape(-1,nhidden)

# build an SAE whose latent layer is twice as wide as the input
n_latent = 2*X.shape[1]
aemodel = createTheSAE(n_latent).to(device)

# run the data through the untrained model; the bare call lets the notebook
# display the returned (reconstruction, sparse latent) pair
aemodel(X)

# Exercise 3: Train layer-specific SAEs

In [None]:
# takes ~3 mins for gpt2-small on GPU
# takes ~20 mins for gpt2-large on GPU

def _minmax(v):
  """Min-max scale v to [0,1]; return all zeros when v is constant (avoids 0/0)."""
  vrange = v.max() - v.min()
  if vrange == 0:
    return np.zeros_like(v,dtype=float)
  return (v-v.min()) / vrange


# initialize the per-layer results
latentDensity = np.zeros(model.config.n_layer)     # mean % of nonzero activations per latent component
densityActivation = np.zeros(model.config.n_layer) # mean of (scaled density x scaled activation)
finalloss = np.zeros(model.config.n_layer)         # total loss (MSE + sparsity) at the last epoch

n_epochs = 75

# the loss function has no state, so one instance serves all layers
mse_loss = nn.MSELoss().to(device)


## loop over layers
for layeri in range(model.config.n_layer):

  # get the activations from this layer, drop the first token of every sequence,
  # and reshape to (N x hidden). The slice must be applied to the data itself,
  # not only to the shape lookup, for the first token to actually be removed.
  acts = activations[f'mlp_{layeri}'][:,1:,:]
  nhidden = acts.shape[-1]
  X = acts.reshape(-1,nhidden)

  # create an SAE model for this layer (using a 2x expansion for the latent layer)
  # note: createTheSAE reads the input width from the global X assigned just above
  aemodel = createTheSAE(X.shape[1]*2)
  aemodel = aemodel.to(device)

  ### train the model (full-batch: all tokens in every step)
  optimizer = optim.Adam(aemodel.parameters(), lr=.0001)

  for epoch in range(n_epochs):

    # forward pass
    optimizer.zero_grad()
    x_pred,latent = aemodel(X)

    # backprop on reconstruction error plus the L1 sparsity penalty
    loss = mse_loss(x_pred,X) + aemodel.sparsity_loss(latent)
    loss.backward()
    optimizer.step()

  # final loss after all epochs (value from the last forward pass, before the last update)
  finalloss[layeri] = loss.item()


  ## final run to get latent activations
  with torch.no_grad():
    aeout,latent = aemodel(X)

  # back to CPU and convert to numpy
  latent = latent.cpu().numpy()


  ### latent layer characteristics
  # number of nonzero activations per latent component (column)
  nonzero = latent!=0
  counts = nonzero.sum(axis=0)

  # density is the percent of nonzero activations per latent component
  densityPerComponent = 100 * counts / latent.shape[0]
  latentDensity[layeri] = densityPerComponent.mean()

  # token-averaged activation magnitude, excluding zeros:
  # zeros add nothing to the sum, so dividing by the nonzero count gives the mean
  # over nonzero entries; components that never activate are set to 0
  abssum = np.abs(latent).sum(axis=0,dtype=np.float64)
  nonzeroAct = np.divide(abssum,counts,out=np.zeros_like(abssum),where=counts>0)

  # average nonzero activations per latent component (after minmax scaling)
  dpc = _minmax(densityPerComponent)
  nza = _minmax(nonzeroAct)
  densityActivation[layeri] = np.mean( dpc*nza )

  print(f'Finished layer {layeri+1:2}/{model.config.n_layer}')

# Exercise 4: Visualize SAE laminar profiles

In [None]:
# plot the three per-layer summaries side by side (x-axis is the 0-based layer index)
_,axs = plt.subplots(1,3,figsize=(12,3.5))

# final training loss; finalloss stores MSE plus the sparsity penalty,
# so the axis label names the combined loss rather than MSE alone
axs[0].plot(finalloss,'ks',markerfacecolor=[.7,.7,.9],markersize=9)
axs[0].set(xlabel='Layer',ylabel='Loss (MSE + sparsity)',title='Final SAE loss per layer')

# average percent of nonzero activations per latent component
axs[1].plot(latentDensity,'ks',markerfacecolor=[.7,.9,.7],markersize=9)
axs[1].set(xlabel='Layer',ylabel='Density (% nonzero activations)',title='Latent density per layer')

# mean product of min-max-scaled density and min-max-scaled nonzero activation
axs[2].plot(densityActivation,'ks',markerfacecolor=[.9,.7,.7],markersize=9)
axs[2].set(xlabel='Layer',ylabel='Density $\\times$ activation',title='(Density $\\times$ activation) per layer')

plt.tight_layout()
plt.show()