|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating neurons and dimensions</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge: Activation histograms by token length</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
# run first to install, then restart the kernel
# (%pip installs into the kernel's own environment)
# %pip install -U datasets huggingface_hub fsspec

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

from datasets import load_dataset

# vector plots
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Exercise 1: Import the model and implant hooks

In [None]:
# model choice: 'EleutherAI/gpt-neo-125m' for exercises 1-6,
# switch to 'EleutherAI/gpt-neo-1.3B' for exercise 7
modelName = 'EleutherAI/gpt-neo-125m'

tokenizer = AutoTokenizer.from_pretrained(modelName)
model     = AutoModelForCausalLM.from_pretrained(modelName)

In [None]:
# run on the first GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = torch.device('cuda:0')
else:
  device = torch.device('cpu')

# move the model onto that device and switch it to evaluation mode
# (eval() turns off training-only behavior such as dropout)
model = model.to(device)
model.eval()

In [None]:
# hook function
activations = {}

def implant_hook(layer_number):
  """Return a forward hook that records layer `layer_number`'s MLP expansion activations.

  The hook re-applies the MLP's `c_fc` layer to the MLP's input, so the stored
  tensor is the raw `c_fc` output (no nonlinearity is applied here). It is
  detached from the autograd graph, moved to the CPU, and written into the
  global `activations` dict under the key 'mlp_<layer_number>', overwriting
  any earlier entry for that layer.
  """
  def hook(module, input, output):
    # store in the dictionary (detach so a grad-enabled forward pass can't keep the graph alive)
    activations[f'mlp_{layer_number}'] = module.c_fc(input[0]).detach().cpu()
  return hook


# remove hooks left over from a previous run of this cell; otherwise every
# re-run stacks another hook on each layer and c_fc is recomputed repeatedly
for handle in globals().get('hook_handles',[]):
  handle.remove()

# put hooks in all layers, keeping the handles so they can be removed later
hook_handles = []
for layeri in range(model.config.num_layers):
  hook_handles.append( model.transformer.h[layeri].mlp.register_forward_hook(implant_hook(layeri)) )

In [None]:
# number of MLP expansion neurons = number of rows of c_fc's weight
# (assumes c_fc is nn.Linear-style, i.e. weight is out_features X in_features)
nneurons = model.transformer.h[3].mlp.c_fc.weight.size(0)

# Exercise 2: Import and tokenize fineweb

In [None]:
# stream the fineweb training split and make an iterator over its documents
fineweb = load_dataset('HuggingFaceFW/fineweb', split='train', streaming=True)
fw_iterator = iter(fineweb)

# peek at the first 100 characters of the first five documents
for _ in range(5):
  example = next(fw_iterator)
  snippet = example['text'][:100]
  print('\n',snippet)

In [None]:
# how many tokens in total
desiredTokenCount = 8192

# collect per-document pieces in lists and concatenate once at the end
# (calling torch.cat / np.concatenate inside the loop re-copies everything each time)
tokenChunks  = []
lengthChunks = []
nCollected   = 0

# reinitialize iterator so this cell always starts from the first document
fw_iterator = iter(fineweb)


# keep importing data until we have enough
while nCollected<desiredTokenCount:

  # import the text
  text = next(fw_iterator)['text']

  # tokenize -> 1D tensor of token ids (dtype long)
  tokens = tokenizer.encode(text,return_tensors='pt')[0]

  # character count of each token when decoded on its own (integers)
  tokenLengths = np.array([len(tokenizer.decode(t)) for t in tokens.tolist()],dtype=int)

  # store the tokens and the lengths
  tokenChunks.append(tokens)
  lengthChunks.append(tokenLengths)
  nCollected += tokens.numel()


# stack everything, then trim the vectors to exactly desiredTokenCount
allTokens = torch.cat(tokenChunks)[:desiredTokenCount]
allTokenLengths = np.concatenate(lengthChunks)[:desiredTokenCount]

print(allTokens.shape)
print(allTokenLengths.shape)

In [None]:
# how many tokens have each character length, plus the median length
u,c = np.unique(allTokenLengths,return_counts=True)
medianTokLength = np.median(allTokenLengths)

# bar graph of the counts with a dashed line at the median
fig,ax = plt.subplots(figsize=(10,4))
ax.bar(u,c,color=[.7,.7,.9],edgecolor='k')
ax.axvline(medianTokLength,linestyle='--',color='k',linewidth=3,label='Median')
ax.legend()
ax.set(xlabel='Token character count',ylabel='Frequency',title='Distribution of token lengths')
plt.show()

In [None]:
# print a summary: how many tokens fall below, above, and exactly at the median length
nShort  = np.count_nonzero(allTokenLengths<medianTokLength)
nLong   = np.count_nonzero(allTokenLengths>medianTokLength)
nMedian = np.count_nonzero(allTokenLengths==medianTokLength)

print(f'There are {nShort:,} tokens shorter than the median.')
print(f'There are {nLong:,} tokens longer than the median.')
print(f'There are {nMedian:,} tokens equal to the median.')

# Exercise 3: Get activations

In [None]:
# get a batch of tokens: cut the flat token stream into (sequences X tokens)
print(allTokens.shape)
seqLength = 512  # tokens per sequence; the number of sequences follows from the token count
assert allTokens.numel() % seqLength == 0, 'total token count must be a multiple of seqLength'
batch = allTokens.reshape(-1,seqLength)
batch.shape,type(batch)

In [None]:
# forward pass the batch; the model output is discarded -- the hooks fill `activations`
# ~1 min on cpu for 125m
# 2 secs on gpu for 1.3B (lol)
with torch.set_grad_enabled(False):
  model(batch.to(device))

In [None]:
# which layers the hooks recorded (keys look like 'mlp_<layer>')
activations.keys()

In [None]:
# check shape -- should be batch X tokens X nneurons
activations['mlp_10'].size()

# Exercise 4: Activation distributions by median split

In [None]:
# extract and flatten activations -> (tokens X neurons); rows line up with
# allTokens/allTokenLengths because `batch` was a row-major reshape of allTokens
acts = activations['mlp_4'].reshape(-1,nneurons)

# activations by length split
binedges = torch.linspace(-8,5,51)
bincenters = (binedges[:-1]+binedges[1:])/2  # x-position of each density value
yS,_ = torch.histogram(acts[allTokenLengths<medianTokLength,:],bins=binedges,density=True)
yL,_ = torch.histogram(acts[allTokenLengths>medianTokLength,:],bins=binedges,density=True)
yM,_ = torch.histogram(acts[allTokenLengths==medianTokLength,:],bins=binedges,density=True)

# visualize (densities drawn at bin centers, not at the left bin edges)
fig,ax = plt.subplots(figsize=(10,5))
ax.plot(bincenters,yS,linewidth=2,label='Short tokens')
ax.plot(bincenters,yL,linewidth=2,label='Long tokens')
ax.plot(bincenters,yM,linewidth=2,label='Median tokens')

ax.set(xlim=binedges[[0,-1]],xlabel='Activations',ylabel='Density',
       title='Distribution of activations by token length')

ax.legend()
plt.show()

# Exercise 5: Activation-length correlations in one layer

In [None]:
# pull the layer-4 activations into a (tokens X neurons) numpy array
acts = activations['mlp_4'].reshape(-1,nneurons).numpy()

# z-score each neuron (column) across tokens, using the unbiased (n-1) std
neuronMeans = acts.mean(axis=0,keepdims=True)
neuronStds  = acts.std(axis=0,ddof=1,keepdims=True)
zacts = (acts-neuronMeans) / neuronStds

In [None]:
# confirm: any single neuron (here #600) should now have mean ~0 and std 1
neuron600 = zacts[:,600]
zacts.shape, neuron600.mean(), neuron600.std(ddof=1)

In [None]:
# z-score the token lengths (unbiased std, matching the activations)
lenMean = allTokenLengths.mean()
lenStd  = allTokenLengths.std(ddof=1)
zTokenLens = (allTokenLengths-lenMean) / lenStd

# confirm: mean ~0, std 1
zTokenLens.mean(), zTokenLens.std(ddof=1)

In [None]:
# confirm one correlation value: numpy's 2x2 correlation matrix of neuron 0 vs token length
np.corrcoef(np.vstack((acts[:,0],allTokenLengths)))

In [None]:
# covariance of standardized variables == Pearson correlation
# (should match the off-diagonal entry of the np.corrcoef matrix)
(zacts[:,0] @ zTokenLens) / (len(zTokenLens)-1)

In [None]:
# calculate all correlation coefficients at once:
# r_neuron = sum(zTokenLens * zacts[:,neuron]) / (n-1), which for every neuron
# together is one vector-matrix product (instead of a Python loop over neurons)
allCorrs = (zTokenLens @ zacts) / (len(zTokenLens)-1)

In [None]:
# and visualize! distribution of the layer-4 neuron/length correlations
fig,ax = plt.subplots(figsize=(8,4))
ax.hist(allCorrs,bins=100,color=[.7,.9,.7],linewidth=.5,edgecolor='gray')
ax.set(xlabel='Correlation coefficient',ylabel='Count',title='Histogram of all correlation coefficients')
plt.show()

# Exercise 6: Correlations in all layers

In [None]:
# correlate every neuron in every layer with token length -> (layers X neurons)
nLayers  = model.config.num_layers
allCorrs = np.zeros((nLayers,nneurons))

# loop over all the layers
for layeri in range(nLayers):

  # get and normalize the activations
  acts = activations[f'mlp_{layeri}'].reshape(-1,nneurons).numpy()
  zacts = (acts-acts.mean(axis=0,keepdims=True)) / np.std(acts,axis=0,ddof=1,keepdims=True)

  # Pearson r for all neurons at once: (z-scored lengths) @ (z-scored activations) / (n-1)
  # (replaces a per-neuron Python loop that summed ntokens products element by element)
  allCorrs[layeri,:] = (zTokenLens @ zacts) / (len(zTokenLens)-1)

In [None]:
# histograms of the correlation coefficients, one per layer
# numpy edges, because the data and np.histogram are numpy; note that r values
# outside [-.8,.8] are ignored and the density is normalized over in-range values only
rEdges = np.linspace(-.8,.8,81)
rHistCounts = np.zeros((model.config.num_layers,len(rEdges)-1))

# get histogram of each layer
for layeri in range(model.config.num_layers):
  rHistCounts[layeri,:],_ = np.histogram(allCorrs[layeri,:],bins=rEdges,density=True)

In [None]:
# and visualize
nLayers = model.config.num_layers
cmap = mpl.colormaps['plasma']
rCenters = (rEdges[:-1]+rEdges[1:])/2  # x-position of each density value

fig,axs = plt.subplots(1,2,figsize=(12,4))

# left: one histogram line per layer, colored by layer index
for layeri in range(nLayers):
  axs[0].plot(rCenters,rHistCounts[layeri,:],color=cmap(layeri/max(nLayers-1,1)),label=f'MLP h.{layeri}')

axs[0].legend()
axs[0].axvline(0,linestyle='--',color=[.7,.7,.7])
axs[0].set(xlabel='Correlation coefficients',ylabel='Density',xlim=(float(rEdges[0]),float(rEdges[-1])),
           title='Correlation histograms for each layer')

# colorbar for line color (layer number): one color bin per layer, centered on its index
# (boundaries at 0..nLayers-1 would give only nLayers-1 bins for nLayers lines)
norm = mpl.colors.BoundaryNorm(np.arange(nLayers+1)-.5, cmap.N)
sm = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
cbar = fig.colorbar(sm, ax=axs[0], pad=.01, ticks=np.arange(nLayers))
cbar.set_label('Layer')


# right: image of all histograms; each row is one layer, centered on its integer index
h = axs[1].imshow(rHistCounts,aspect='auto',vmin=0,vmax=4,origin='lower',
                  extent=[float(rEdges[0]),float(rEdges[-1]),-.5,nLayers-.5])
axs[1].set(xlabel='Correlation coefficient',ylabel='Transformer block',title='Image of all histograms')
fig.colorbar(h,ax=axs[1],pad=.01)

plt.tight_layout()
plt.show()