|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 6:</h2>|<h1>Intervention (causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Interfering with attention</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge: Head and token patching in IOI</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Exercise 1: Model, tokens

In [None]:
# use the GPU when CUDA is available; otherwise stay on the CPU
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

# load the 'gpt2-large' weights, then move the model onto the chosen device
model = GPT2LMHeadModel.from_pretrained('gpt2-large')
model = model.to(device)

# the tokenizer is loaded from the base 'gpt2' checkpoint
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# switch to evaluation mode (last expression, so the module is displayed)
model.eval()

In [None]:
# two prompts that are identical except for the name that gives the gift
text_A = 'When Sam and Sally went to the park, Sam gave a gift to'
text_B = 'When Sam and Sally went to the park, Sally gave a gift to'

# first token id of each candidate next word (note the leading space)
target_A = tokenizer.encode(' Sam')[0]
target_B = tokenizer.encode(' Sally')[0]

# token ids as (batch, tokens) tensors on the model's device
tokensA = tokenizer.encode(text_A,return_tensors='pt').to(device)
tokensB = tokenizer.encode(text_B,return_tensors='pt').to(device)

# activations from run "A" are later patched into run "B" position by position,
# so the two prompts must tokenize to the same shape
assert tokensA.shape == tokensB.shape, (
  f'Prompts must tokenize to the same shape, got {tuple(tokensA.shape)} vs {tuple(tokensB.shape)}')

In [None]:
# sizes used throughout: token-tensor shape and model dimensions
nbatches,ntokens = tokensA.size()

nheads,nlayers,n_emb = model.config.n_head,model.config.n_layer,model.config.n_embd

# per-head width: the embedding dimension divided across the heads
head_dim = n_emb // nheads

In [None]:
# head activations from the clean model "A", keyed by layer label ('L0', 'L1', ...)
head_acts_A = {}

def hook4attn_acts(layer_number):
  """Return a forward-pre-hook that stores the hooked module's input for one layer.

  The hook writes into `head_acts_A` under the key f'L{layer_number}'.
  """
  key = f'L{layer_number}'

  def hook(module,args):
    # first positional input, viewed as (batch, token, head, head_dim) for indexing convenience
    head_acts_A[key] = args[0].view(nbatches,ntokens,nheads,head_dim)

  return hook


# attach one capture hook per transformer block; keep the handles so the hooks can be removed
handles = [
  model.transformer.h[layeri].attn.c_proj.register_forward_pre_hook(hook4attn_acts(layeri))
  for layeri in range(nlayers)
]

In [None]:
# Get "clean" data on texts (no patching).
# The capture hooks must only see this one pass: `finally` removes them even if
# the forward pass raises, so they can never overwrite head_acts_A afterwards.
try:
  with torch.no_grad():
    outA = model(tokensA)
finally:
  for h in handles: h.remove()

# now run tokensB without the hooks
with torch.no_grad():
  outB = model(tokensB)

In [None]:
# logits at the final token position of each clean run
lastLogits_A = outA.logits[0,-1]
lastLogits_B = outB.logits[0,-1]

# positive: the ' Sam' token has the larger logit; negative: the ' Sally' token does
logitDiff_A = lastLogits_A[target_A] - lastLogits_A[target_B]
logitDiff_B = lastLogits_B[target_A] - lastLogits_B[target_B]

print(f'Logit difference for text "A": {logitDiff_A:6.3f}')
print(f'Logit difference for text "B": {logitDiff_B:6.3f}')

In [None]:
# sanity check: the stored layer keys, and the size of one stored activation tensor
( head_acts_A.keys(), head_acts_A['L4'].size() )

# Exercise 2: Layer-, head-, and token-specific patching

In [None]:
# initializations: one patched logit difference per (layer, head)
logitDiffs = np.zeros((nlayers,nheads))


def hook4patching(layer_number,head_number):
  """Return a forward-pre-hook that patches one head at the final token.

  In the first positional input of the hooked module, the final-token slice of
  head `head_number` is replaced with the slice stored in `head_acts_A` under
  f'L{layer_number}' (the activations captured from the clean run on tokensA).
  Layer and head are bound when the hook is built, not when it is called.
  """
  def hook(module,args):

    # reshape to index heads; clone so the patch is written into a copy
    head_tensor = args[0].reshape(nbatches,ntokens,nheads,head_dim).clone()

    # patch (replace final token with those from tokensA)
    head_tensor[:,-1,head_number,:] = head_acts_A[f'L{layer_number}'][:,-1,head_number,:]

    # reshape back to (batch, tokens, embedding)
    head_tensor = head_tensor.reshape(nbatches,ntokens,n_emb)

    # return a tuple to replace the original positional inputs
    return (head_tensor,*args[1:])
  return hook


# loop over layers
for layeri in range(nlayers):

  # loop over heads
  for headi in range(nheads):

    # implant the hook for this layer and head
    handle = model.transformer.h[layeri].attn.c_proj.register_forward_pre_hook(hook4patching(layeri,headi))

    # forward pass with hook; `finally` removes the hook even if the pass fails,
    # so a stale patch cannot leak into later runs.
    # The result gets its own name so the clean `outB` baseline is not overwritten.
    try:
      with torch.no_grad(): outPatched = model(tokensB)
    finally:
      handle.remove()

    # now for the logit-difference test; .item() yields a Python float on CPU or GPU
    patched_logits = outPatched.logits[0,-1]
    logitDiffs[layeri,headi] = (patched_logits[target_A] - patched_logits[target_B]).item()


In [None]:
# visualization
fig,ax = plt.subplots(figsize=(11,4))

# plot the logit differences for the "clean" runs (no patching)
ax.axhline(logitDiff_A.item(),color='b',label='Clean "A"')
ax.axhline(logitDiff_B.item(),color='r',label='Clean "B"')

# then for the experiment results: one column of markers per layer, colored by layer index
# (max(...,1) keeps the color scaling valid if there is only one layer)
for i in range(nlayers):
  ax.plot(np.ones(nheads)*i,logitDiffs[i,:],'ko',markerfacecolor=mpl.cm.plasma(i/max(nlayers-1,1)),
          markersize=10,alpha=.5)

# the dividing line
ax.axhline(0,linestyle='--',color='gray',linewidth=.5)
ax.text(0,.1,'Prefer "Sam"',fontsize=12,va='bottom')
ax.text(0,-.1,'Prefer "Sally"',fontsize=12,va='top')

ax.set(xlabel='Transformer block',ylabel='Logit difference',title='Reversing logit bias towards target-Sam')
ax.legend()
plt.show()