|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 6:</h2>|<h1>Intervention (causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Editing hidden states</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge: Noisy and shuffled BERT predictions</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import torch
import torch.nn.functional as F

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Exercise 1: Hook the BERT model

In [None]:
from transformers import BertTokenizer, BertForMaskedLM

# Choose where to run: the first CUDA GPU if one is visible, otherwise the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Pretrained tokenizer and masked-language-model weights (same checkpoint for both)
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
model = BertForMaskedLM.from_pretrained('bert-large-uncased').to(device)

# Switch to evaluation mode; as the cell's last expression this also echoes the model
model.eval()

In [None]:
# Globals that the hooks look up each time they run (both are reassigned in later cells)
maskTarget_idx = 1       # sequence position whose activations get edited
layer2shuffle = 40_000   # a value no layer index matches, so nothing is shuffled yet


# hooking functions
def implant_hook(layer_number):
  """Build a forward hook that shuffles one token's activations in one layer.

  The returned hook compares `layer_number` against the global `layer2shuffle`
  every time it is called. On a match, it randomly permutes the feature
  dimension of the hidden-state vector at position `maskTarget_idx` (also a
  global) of the first sequence in the batch; otherwise the output passes
  through untouched.

  Args:
    layer_number: index of the layer this hook is attached to.

  Returns:
    A function with the `(module, input, output)` forward-hook signature whose
    return value replaces the layer's output.
  """
  def hook(module, input, output):

    # every layer except the targeted one passes its output through unchanged
    if layer_number != layer2shuffle:
      return output

    # the layer may hand back a tuple/list (hidden states first) or a bare tensor;
    # unpacking a bare tensor would split it along the batch axis, so check first
    is_sequence = isinstance(output,(tuple,list))
    hidden = output[0] if is_sequence else output

    # random permutation of the feature dimension, sized from the tensor itself
    # and created on the same device as the activations
    shuffleidx = torch.randperm(hidden.shape[-1],device=hidden.device)

    # the right-hand side is a shuffled copy, so the in-place write is safe
    hidden[0,maskTarget_idx,:] = hidden[0,maskTarget_idx,shuffleidx]

    print(f'Shuffled layer {layer_number}')

    # hand back the same kind of container that came in
    return (hidden,*output[1:]) if is_sequence else hidden
  return hook


# Attach one hook per encoder layer and keep the handles so the hooks can be removed later
handles = [
  model.bert.encoder.layer[blocki].register_forward_hook(implant_hook(blocki))
  for blocki in range(model.config.num_hidden_layers)
]

# Exercise 2: Tokens and masked predictions

In [None]:
# Tokenize a sentence containing one mask token (returned as PyTorch tensors)
tokens = tokenizer(f'Pay no attention to that man {tokenizer.mask_token} the curtain!',return_tensors='pt')
token_ids = tokens['input_ids'][0]

# Position of the mask token within the sequence
maskTarget_idx = (token_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].item()

# First token id of the word 'behind', treated as the correct answer
correct_target = tokenizer.encode('behind',add_special_tokens=False)[0]

# List every token id next to its decoded text
for tok in token_ids:
  print(f'{tok:5}: "{tokenizer.decode(tok)}"')

print(f'\nThe mask is in token index {maskTarget_idx}')

In [None]:
# Park the target layer on a value no hook matches, so this run is unmodified
# (this also makes the cell safe to re-execute after the later sweeps)
layer2shuffle = 40000

# Forward pass without gradient tracking, keeping every hidden state
with torch.no_grad():
  inputs = tokens.to(device)
  out_pure = model(**inputs,output_hidden_states=True)

print(f'There are {len(out_pure.hidden_states)} hidden states,\neach of size {list(out_pure.hidden_states[2].shape)}')

In [None]:
# Highest-scoring vocabulary id at every sequence position of the unmodified run
best_ids = torch.argmax(out_pure.logits[0].detach().cpu(),dim=-1)

# Show each input token next to the model's top prediction for that position
for input_id,best_id in zip(tokens['input_ids'][0],best_ids):
  actual = tokenizer.decode(input_id)
  predicted = tokenizer.decode(best_id)
  print(f'{actual:>12} predicted as "{predicted}"')

# Exercise 3: Masked predictions after random shuffling

In [None]:
# Column 0: log-softmax of the correct token; column 1: id of the top-ranked token
n_layers = model.config.num_hidden_layers
mask_prediction = np.zeros((n_layers,2))

for layeri in range(n_layers):

  # tell the hooks which layer to edit during this pass
  layer2shuffle = layeri

  # run the model with that single layer shuffled
  with torch.no_grad():
    out = model(**tokens.to(device))

  # log-softmax over the vocabulary at the mask position
  logsm = F.log_softmax(out.logits[0,maskTarget_idx,:].detach().cpu(),dim=-1)

  # store the correct token's score and the winning token id
  mask_prediction[layeri,0] = logsm[correct_target].item()
  mask_prediction[layeri,1] = logsm.argmax().item()


In [None]:
# Report the top-ranked mask prediction obtained when each layer was shuffled
for layeri,token_id in enumerate(mask_prediction[:,1]):
  print(f'Shuffling layer {layeri:2} led to prediction: "{tokenizer.decode(int(token_id))}"')

In [None]:
# Baseline: log-softmax of the correct token at the mask position in the unmodified run
logsm = F.log_softmax(out_pure.logits[0,maskTarget_idx,:].detach().cpu(),dim=-1)
pure_max_logit = logsm[correct_target]

# One marker per shuffled layer, with the baseline drawn behind as a horizontal line
fig,ax = plt.subplots(figsize=(8,3))
ax.plot(mask_prediction[:,0],'ks',markerfacecolor=[.9,.7,.7],markersize=10,label='Shuffled')
ax.axhline(pure_max_logit,color='b',zorder=-4,label='Unshuffled')

ax.set(xlabel='Hidden layer that was shuffled',ylabel='Correct token logit (log-sm)')
ax.legend()
plt.show()

# Exercise 4: Add noise instead

In [None]:
# Unregister every hook installed so far, so the next set of hooks starts from a clean model
for hook_handle in handles:
  hook_handle.remove()

In [None]:
# hooking functions
def implant_hook(layer_number,noise_scale=2):
  """Build a forward hook that adds Gaussian noise to one token in one layer.

  The returned hook compares `layer_number` against the global `layer2shuffle`
  every time it is called. On a match, it adds zero-mean Gaussian noise to the
  hidden-state vector at position `maskTarget_idx` (also a global) of the first
  sequence in the batch; otherwise the output passes through untouched.

  Args:
    layer_number: index of the layer this hook is attached to.
    noise_scale: noise standard deviation, expressed as a multiple of the
      standard deviation of the targeted activation vector (default 2).

  Returns:
    A function with the `(module, input, output)` forward-hook signature whose
    return value replaces the layer's output.
  """
  def hook(module, input, output):

    # every layer except the targeted one passes its output through unchanged
    if layer_number != layer2shuffle:
      return output

    # the layer may hand back a tuple/list (hidden states first) or a bare tensor;
    # unpacking a bare tensor would split it along the batch axis, so check first
    is_sequence = isinstance(output,(tuple,list))
    hidden = output[0] if is_sequence else output

    # noise with the same shape/dtype/device as the target vector, scaled to its spread
    acts = hidden[0,maskTarget_idx,:]
    noise = torch.randn_like(acts) * acts.std() * noise_scale

    # acts+noise is a new tensor, so writing it back in place is safe
    hidden[0,maskTarget_idx,:] = acts + noise

    # hand back the same kind of container that came in
    return (hidden,*output[1:]) if is_sequence else hidden
  return hook


# Attach one hook per encoder layer and keep the handles so the hooks can be removed later
handles = [
  model.bert.encoder.layer[blocki].register_forward_hook(implant_hook(blocki))
  for blocki in range(model.config.num_hidden_layers)
]

In [None]:
# Column 0: log-softmax of the correct token; column 1: id of the top-ranked token
mask_prediction = np.zeros((model.config.num_hidden_layers,2))

for layeri in range(model.config.num_hidden_layers):

  # layer that receives the noise (the hooks read this global on every call)
  layer2shuffle = layeri

  # forward pass
  with torch.no_grad(): out=model(**tokens.to(device))

  # log-softmax the logits and store the correct target's lsm
  logsm = F.log_softmax(out.logits[0,maskTarget_idx,:].detach().cpu(),dim=-1)
  mask_prediction[layeri,0] = logsm[correct_target]

  # predicted masked target
  maxidx = torch.argmax(logsm,dim=-1)
  mask_prediction[layeri,1] = maxidx


# report the top prediction per layer (this exercise adds noise; nothing is shuffled)
for layeri in range(model.config.num_hidden_layers):
  print(f'Adding noise to layer {layeri:2} led to prediction: "{tokenizer.decode(int(mask_prediction[layeri,1]))}"')

In [None]:
# Plot the correct token's log-softmax at the mask position for each noised layer.
# mask_prediction[:,0] holds those values; pure_max_logit is the same quantity
# from the unmodified run.
plt.figure(figsize=(8,3))
plt.plot(mask_prediction[:,0],'ks',markerfacecolor=[.9,.7,.7],markersize=10,label='Noise added')
plt.axhline(pure_max_logit,color='b',label='No noise')

# labels name the manipulation (noise) and the plotted quantity (not a max logit)
plt.gca().set(xlabel='Hidden layer that received noise',ylabel='Correct token logit (log-sm)')
plt.legend(loc='lower right')
plt.show()

# Exercise 5: Distributions of activations and noise

In [None]:
# Optional alternative (left disabled): run the model once more with the layer
# index set to a value no hook matches, requesting the hidden states again.
# # another model run with no shuffling, and with hidden states
# layer2shuffle = 98765
# with torch.no_grad():
#   out = model(**tokens.to(device),output_hidden_states=True)

# Reuse the hidden states of the unmodified forward pass (out_pure, from Exercise 2)
hs = out_pure.hidden_states

In [None]:
# Histograms of activation values per hidden state, compared against Gaussian noise
fig = plt.figure(figsize=(10,4))

# bins for the histogram
xvals4hist = np.linspace(-4,4,101)

# one color scale shared by the lines and the colorbar, so the two always agree
norm = mpl.colors.Normalize(vmin=0,vmax=len(hs))
sm = mpl.cm.ScalarMappable(cmap=mpl.cm.plasma,norm=norm)

# loop over hidden states, skipping index 0
for i in range(1,len(hs)):

  # get the vectorized data
  vdat = hs[i].detach().cpu().numpy().flatten()

  # calculate and draw the histogram, colored by position on the shared scale
  y,x = np.histogram(vdat,bins=xvals4hist)
  plt.plot(x[:-1],y,color=sm.to_rgba(i))

# and again for gaussian noise; vdat here is the last hidden state from the loop,
# so the noise has twice that state's standard deviation and the same element count
noise = np.random.randn(len(vdat)) * vdat.std()*2
y,x = np.histogram(noise,bins=xvals4hist)
plt.plot(x[:-1],y,'k',linewidth=3,label='Gaussian noise')

# colorbar for the lines
cbar = fig.colorbar(sm,ax=plt.gca(),pad=.01)
cbar.set_label(r'Transformer block')

# final touches
plt.legend()
plt.gca().set(xlabel='Activation values',ylabel='Count',xlim=xvals4hist[[0,-1]],
              ylim=[-10,None],title='Histograms of activation values')
plt.show()