|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating layers</h1>|
|<h2>Lecture:</h2>|<h1><b>The Logit Lens</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, GPT2Tokenizer

# request SVG as the output format for inline matplotlib figures
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# pretrained 'gpt2' tokenizer and the matching causal language model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

# switch to evaluation mode; kept as the last expression so the cell shows its return value
model.eval()

# Text and activations

In [None]:
# sentence to push through the model
text = 'The way you do anything is the way you do everything'

# token ids returned as a PyTorch tensor ('pt'); row 0 holds the sequence
tokens = tokenizer.encode(text,return_tensors='pt')

# number of tokens in that sequence
numTokens = tokens.shape[1]

# forward pass without gradient tracking, requesting the hidden states as well
with torch.no_grad():
  output = model(input_ids=tokens,output_hidden_states=True)

In [None]:
# how many hidden-state tensors were returned, and the shape of one of them (index 3)
len(output.hidden_states), output.hidden_states[3].shape

# Illustration of logit-lens procedure and outcome

In [None]:
# reference:
# https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens

In [None]:
# project one layer's activations onto the vocabulary

# the unembedding matrix (tied to the initial embedding in GPT2), detached from autograd
unembedding = model.lm_head.weight.detach()

# hidden state at index 3, first (and only) sequence in the batch
activations = output.hidden_states[3][0]

# raw logits: activations times the transposed unembedding matrix
logits = torch.matmul(activations,unembedding.T)

# check the shape
logits.shape

In [None]:
# softmax and plot

# token position to inspect: the prediction made at this position is compared
# against the token that actually follows it in the text (position tokenIdx+1)
tokenIdx = 3
assert 0 <= tokenIdx < tokens.shape[1]-1, 'tokenIdx needs a following token in the text'

# log-softmax over the vocabulary (last dimension)
lsm_outs = F.log_softmax(logits,dim=-1)

# log-probabilities for the chosen position, as a numpy array for plotting
tokenLogProbs = lsm_outs[tokenIdx].numpy()

# max-softmax is the next prediction
# (use the tensor's own argmax and convert to a plain int for decoding/indexing)
predictedToken = lsm_outs[tokenIdx].argmax().item()
print(f'Next token in the text is "{tokenizer.decode([tokens[0,tokenIdx+1].item()])}"')
print(f'Predicted next token at this layer is "{tokenizer.decode([predictedToken])}"')

# show softmax for one token: red marker = predicted token, black dots = all vocab entries
plt.figure(figsize=(8,4))
plt.plot(predictedToken,tokenLogProbs[predictedToken],'ro',markersize=8)
plt.plot(tokenLogProbs,'k.',alpha=.3)
plt.gca().set(xlabel='Token index',ylabel='Log-softmax prob',xlim=[-10,tokenizer.vocab_size+9],
              title=f'Log-softmax logits for the token following "{tokenizer.decode([tokens[0,tokenIdx].item()])}"')
plt.show()

# Logit-lens over all layers

In [None]:
# initialize an empty list (receives one list of decoded predictions per layer)
all_token_predictions = []

# initialize softmax probs (one row per layer, one column per token position)
softmax_probs = np.zeros((model.config.n_layer,numTokens))

# NOTE(review): the last entry of output.hidden_states has presumably already passed
# through the model's final layer norm whereas the earlier entries have not -- confirm
# whether intermediate activations should be normalized before unembedding.

for layeri,acts in enumerate(output.hidden_states[1:]): # [1:] to skip embedding layer

  # calculate the logits
  # (loop-specific name so the notebook-level `logits` from the single-layer
  #  illustration is not overwritten when this cell runs)
  layerLogits = acts[0] @ unembedding.t()

  # find predicted next tokens
  # note: we don't need to softmax b/c it doesn't affect argmax
  predictedNextToks = layerLogits.argmax(dim=-1)

  # but here take softmax for subsequent visualization
  sm = F.softmax(layerLogits,dim=-1) # [tokens, vocab]

  # probability assigned to the predicted token at every position (one gather, no Python loop)
  softmax_probs[layeri,:] = sm.gather(1,predictedNextToks.unsqueeze(1)).squeeze(1).numpy()

  # get the text predictions for all tokens in the text
  all_token_predictions.append([tokenizer.decode([tokId]) for tokId in predictedNextToks.tolist()])


In [None]:
# decoded predictions collected in the loop above: one inner list per layer
all_token_predictions

In [None]:
# stitch the per-token predictions of the first and the last block back into strings
firstBlockText = ''.join(all_token_predictions[0])
finalBlockText = ''.join(all_token_predictions[-1])

# show them next to the original sentence
print('Original text:\n', text, '\n')
print('Predictions at first transformer block:\n', firstBlockText, '\n')
print('Predictions at final transformer block:\n', finalBlockText)

# Visualization

In [None]:
fig,ax = plt.subplots(1,figsize=(10,5))

# original text (separated into a list of decoded tokens)
target = [tokenizer.decode([t]) for t in tokens[0].tolist()]

# layout: rows start at y=1 and step down by 1/numLayers; row labels sit in the left margin
numLayers = model.config.n_layer
labelX = -.07

# loop over layers
for layeri,layerToks in enumerate(all_token_predictions):

  # y-axis coordinate for this layer
  yCoord = 1-layeri/numLayers

  # print the layer number in the left margin
  ax.text(labelX,yCoord,f'Layer {layeri+1}:',ha='right')

  # loop over the predicted tokens in this layer;
  # the box color encodes the softmax probability of the predicted token
  for xi,tok in enumerate(layerToks):
    ax.text(xi/numTokens,yCoord,tok,ha='center',
            bbox=dict(boxstyle='round,pad=0.3', facecolor=mpl.cm.Reds(softmax_probs[layeri,xi]), edgecolor='none',alpha=.4))

ax.axis('off')

# finally, draw the target tokens at the bottom
# (row position computed explicitly instead of relying on the loop variable
#  still being defined after the for-loop)
targetY = 1-(len(all_token_predictions)-1)/numLayers-.1
ax.text(labelX,targetY,'Target:',ha='right',fontweight='bold')

# target[1:]: the prediction made at position i is for the token at position i+1
for xi,tok in enumerate(target[1:]):
  ax.text(xi/numTokens,targetY,tok,ha='center',fontweight='bold')

plt.show()