|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 6:</h2>|<h1>Intervention (causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Interfering with attention</h1>|
|<h2>Lecture:</h2>|<h1><b>Head ablation and token prediction</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import torch.nn.functional as F

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
# pretrained 'gpt2' checkpoint: the language model and its matching tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# switch to evaluation mode; as the last expression of the cell,
# a notebook also displays the value returned by eval()
model.eval()

In [None]:
# some useful variables, read from the model config:
# number of attention heads and embedding width...
nheads,n_emb = model.config.n_head,model.config.n_embd

# ...and the number of embedding dimensions per head
head_dim = n_emb // nheads

In [None]:
def hook4attn(module,input):
  """Forward pre-hook that zeroes out one attention head.

  The head to silence is read from the global `head2ablate` each time the hook
  fires, so it can be changed between forward passes without re-registering
  the hook. The number of heads is read from the global `nheads`; all other
  sizes are taken from the incoming tensor, so any batch size and sequence
  length work.

  Args:
    module: the hooked module (unused; required by the hook signature).
    input: tuple of positional inputs to the module. The last axis of
      `input[0]` is split into `nheads` equal, consecutive chunks (one per head).

  Returns:
    `input` unchanged when `head2ablate` is not a valid head index; otherwise
    a new tuple whose first element is a copy of `input[0]` with the selected
    head's chunk set to zero. The remaining elements are passed through.
  """

  # if out of range, do nothing
  if head2ablate not in range(nheads):
    return input

  acts = input[0]

  # split the last axis into (head, within-head) so we can index heads;
  # clone so the tensor produced upstream is not modified in place
  head_tensor = acts.reshape(*acts.shape[:-1],nheads,-1).clone()

  # then replace
  head_tensor[...,head2ablate,:] = 0
  print(f'Zeroed out H{head2ablate}')

  # reshape back and return a tuple matching the original
  return (head_tensor.reshape(acts.shape),*input[1:])

# Attach the hook to the input of the attention output projection (c_proj) of
# transformer block index 5. Keep the returned handle so the hook can be removed
# later, and first remove a handle left over from a previous run of this cell
# so that duplicate hooks do not pile up on the same module.
if 'hook_handle' in globals():
  hook_handle.remove()
hook_handle = model.transformer.h[5].attn.c_proj.register_forward_pre_hook(hook4attn)
hook_handle

In [None]:
# tokenize the prompt into a 2D tensor and keep its two sizes
tokens = tokenizer.encode('Berlin is the capital of',return_tensors='pt')
nbatches,ntokens = tokens.shape

# list every position of the first sequence with its vocab index and decoded text
for i,tok in enumerate(tokens[0]):
  print(f'Token position {i:2} is index {tok} and is "{tokenizer.decode(tok)}"')

In [None]:
# target and semantically related nontarget
# (only the first token id of each encoded word is kept)
nontarget_idx,target_idx = (tokenizer.encode(word)[0] for word in (' France',' Germany'))

# display the two token ids
nontarget_idx,target_idx

In [None]:
# an index outside the valid head range, so the ablation hook leaves the activations alone
head2ablate = 100000

# clean (non-ablated) forward pass
with torch.no_grad():
  out = model(tokens)

# log-softmax over the vocab for the final token position of the first sequence
final_logits = out.logits[0,-1,:]
logsm_clean = F.log_softmax(final_logits,dim=-1).detach().numpy()

In [None]:
fig,ax = plt.subplots(figsize=(10,4))

# every vocab element's log-softmax value as a faint dot
ax.plot(logsm_clean,'k.',markersize=2,alpha=.3)

# highlight the target and nontarget tokens
ax.plot(target_idx,logsm_clean[target_idx],'gs',label='Germany')
ax.plot(nontarget_idx,logsm_clean[nontarget_idx],'ro',label='France')

# the most likely next token goes in the title
top_token = tokenizer.decode(np.argmax(logsm_clean))

# make the graph look pretty :D
ax.set(xlabel='Vocab elements',ylabel='Log softmax',xlim=[0,model.config.vocab_size])
ax.set_title(f'Predicted next token is "{top_token}"',fontweight='bold')
ax.legend()

plt.show()

# Zero-out each attention head

In [None]:
# one row per ablated head; columns are
# [target log-prob, nontarget log-prob, index of the predicted next token]
resultsZero = np.zeros((nheads,3))

try:
  # loop over heads; the ablation hook reads the global `head2ablate`,
  # so the loop variable itself selects which head is zeroed on each pass
  for head2ablate in range(nheads):

    # forward pass
    with torch.no_grad():
      out = model(tokens)

    # log-softmax over the vocab at the final token position
    logsm = F.log_softmax(out.logits[0,-1,:],dim=-1).detach().numpy()

    # log-probs for target and nontarget
    resultsZero[head2ablate,0] = logsm[target_idx]
    resultsZero[head2ablate,1] = logsm[nontarget_idx]

    # and the predicted next token
    resultsZero[head2ablate,2] = np.argmax(logsm)

finally:
  # switch ablation back off (out-of-range index), otherwise every later
  # forward pass would silently keep the last head zeroed
  head2ablate = 100000

In [None]:
fig,axs = plt.subplots(1,2,figsize=(10,4))

# resultsZero has one row per head, so the bar positions must span the heads
# (not the layers; the two counts only happen to coincide for some models)
heads = range(nheads)

# left panel: change in the target's log-softmax relative to the clean run
# (raw strings keep the LaTeX backslash from being read as an escape sequence)
axs[0].bar(heads,resultsZero[:,0] - logsm_clean[target_idx],color=[.7,.7,.9],edgecolor='k')
axs[0].axhline(0,color='gray')
axs[0].set(xlabel='Head',ylabel=r'$\Delta$ logit from clean',ylim=[-.4,.4],
           xticks=heads,title=r'$\Delta$ in log-prob. for target word')

# right panel: the same for the nontarget word
axs[1].bar(heads,resultsZero[:,1] - logsm_clean[nontarget_idx],color=[.9,.7,.7],edgecolor='k')
axs[1].axhline(0,color='gray')
axs[1].set(xlabel='Head',ylabel=r'$\Delta$ logit from clean',ylim=[-.4,.4],
           xticks=heads,title=r'$\Delta$ in log-prob. for non-target word')

plt.tight_layout()
plt.show()

In [None]:
# report the predicted next token for each ablated head
# (third column of resultsZero holds the token index, stored as a float)
for i,token_id in enumerate(resultsZero[:,2]):
  print(f'When ablating head {i:2}, the predicted new token was "{tokenizer.decode(int(token_id))}"')