|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 6:</h2>|<h1>Intervention (causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Editing hidden states</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge: Measure and correct BERT's bias</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Exercise 1: Hook the BERT model

In [None]:
from transformers import BertTokenizer, BertForMaskedLM

# Checkpoint name defined once so the tokenizer and the model always match.
model_name = 'bert-large-uncased'

# Load BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForMaskedLM.from_pretrained(model_name)

# Run on the GPU when one is available.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# eval() switches off training-mode layers such as dropout.
# Assigning the result, instead of ending the cell with a bare model.eval(),
# keeps Jupyter from printing the full module repr as the cell output.
model = model.to(device).eval()

In [None]:
# Shared state read by the forward hook each time it fires; later cells
# retarget the edit simply by reassigning these globals.

# Blend weights applied as (original activation, replacement vector).
mixture = [.1,.9]

# A layer index far past any real encoder layer, so the hook stays inactive.
layer2replace = 40000

# Placeholder replacement vector (overwritten before any real edit).
hs_vector2replace = torch.zeros(model.config.hidden_size)

# hooking functions
def implant_hook(layer_number):
  """Build a forward hook that blends a replacement vector into one token's hidden state.

  The hook reads the module-level globals ``layer2replace``,
  ``hs_vector2replace``, ``mixture`` and ``maskTarget_idx`` at call time.
  Later cells can therefore retarget the edit without re-registering hooks.

  Args:
    layer_number: index of the encoder layer this hook is attached to.

  Returns:
    A ``hook(module, input, output)`` callable for ``register_forward_hook``.
    When ``layer_number == layer2replace`` it overwrites
    ``hidden[0, maskTarget_idx, :]`` in place with
    ``mixture[0]*hidden[...] + mixture[1]*hs_vector2replace``.
    It returns the output in the same structure it received (a tuple or a
    bare tensor). For any other layer the output is returned untouched.
  """
  def hook(module, input, output):

    # only change this layer if there's a matching variable value
    if layer_number != layer2replace:
      return output

    # The layer may return a tuple (hidden, *extras) or a bare hidden-state
    # tensor. Handle both so a tensor is never unpacked along its batch
    # dimension and its type is never changed to a tuple.
    hidden = output[0] if isinstance(output, tuple) else output

    # Match the replacement to the activations' device and dtype. The initial
    # placeholder is a CPU tensor, while the model may run on a GPU.
    new_vec = hs_vector2replace.to(device=hidden.device, dtype=hidden.dtype)

    # Mix the old and the new in place; `output` therefore already carries
    # the edit and can be returned as-is.
    hidden[0,maskTarget_idx,:] = mixture[0]*hidden[0,maskTarget_idx,:] + mixture[1]*new_vec

    print(f'Replaced layer {layer_number:2}')

    return output
  return hook


# loop over layers and do surgery

# Detach hooks left over from an earlier run of this cell. Without this, each
# re-run stacks another hook on every layer and the edit is applied repeatedly.
# (RemovableHandle.remove is safe to call on an already-removed handle.)
for stale in globals().get('handles', []):
  stale.remove()

handles = []
for layeri in range(model.config.num_hidden_layers):
  h = model.bert.encoder.layer[layeri].register_forward_hook(implant_hook(layeri))
  handles.append(h)

# Exercise 2: Test for a gender bias in BERT

In [None]:
# Candidate pronouns compared at the masked position.
target_words = [ 'he','she','they' ]

# All four sentences share one template; only the pronoun slot differs.
sentence_template = 'The engineer informed the client that {} would need more time.'

# the three pronoun sentences
tokens_he   = tokenizer(sentence_template.format('he'),   return_tensors='pt')
tokens_she  = tokenizer(sentence_template.format('she'),  return_tensors='pt')
tokens_they = tokenizer(sentence_template.format('they'), return_tensors='pt')

# the masked sentence: the pronoun slot holds the tokenizer's mask token
tokens_mask = tokenizer(sentence_template.format(tokenizer.mask_token), return_tensors='pt')

In [None]:
# the mask index (.item() requires exactly one mask token in the sentence)
maskTarget_idx = torch.where(tokens_mask['input_ids'][0] == tokenizer.mask_token_id)[0].item()

# Token indices of target words.
# Each target must be a single wordpiece because the analysis reads one logit
# per word. A word that split into sub-tokens would otherwise be silently
# represented by its first piece only.
targets_idx = []
for t in target_words:
  word_ids = tokenizer.encode(t, add_special_tokens=False)
  assert len(word_ids) == 1, f'"{t}" is not a single token: {tokenizer.convert_ids_to_tokens(word_ids)}'
  targets_idx.append(word_ids[0])

# Later cells read out_he/out_she/out_they at maskTarget_idx. That assumes each
# pronoun sentence has its pronoun exactly where the masked sentence has [MASK].
# This check assumes target_words is ordered he, she, they.
for toks, tid in zip((tokens_he, tokens_she, tokens_they), targets_idx):
  assert toks['input_ids'][0, maskTarget_idx].item() == tid, 'pronoun not aligned with mask position'

# print out the tokens
for t in tokens_mask['input_ids'][0]:
  print(f'{t:5}: "{tokenizer.decode(t)}"')

print(f'\nThe mask is in token index {maskTarget_idx}\n')
for t in targets_idx:
  print(f'Target "{tokenizer.decode(t)}" is index {t}')

In [None]:
# Disarm the hook: 40000 is far past any encoder layer index, so these are
# clean baseline runs. Resetting here also makes the cell safe to re-run
# after the later exercises.
layer2replace = 40000

# forward-pass the four versions (he / she / they / [MASK]), keeping all hidden states
with torch.no_grad():
  out_he, out_she, out_they, out_mask = (
    model(**toks.to(device), output_hidden_states=True)
    for toks in (tokens_he, tokens_she, tokens_they, tokens_mask)
  )

In [None]:
# Grab and visualize the log-softmax at the pronoun position of each unmasked
# sentence. Top row: log-softmax; bottom row: softmax probability.
# One loop replaces three copy-pasted panels. Bar positions follow
# target_words instead of a hard-coded 3.

fig,axs = plt.subplots(2,3,figsize=(10,5))
xpos = range(len(target_words))

sentence_outputs = [('he',out_he), ('she',out_she), ('they',out_they)]
for col,(word,out_word) in enumerate(sentence_outputs):
  logsm = F.log_softmax(out_word.logits[0,maskTarget_idx,:],dim=-1).detach().cpu()
  axs[0,col].bar(xpos,logsm[targets_idx])
  axs[1,col].bar(xpos,torch.exp(logsm[targets_idx]))
  axs[0,col].set(xticks=xpos,xticklabels=target_words,ylabel='Log-softmax',title=f'Probs. in "{word}" sentence')
  axs[1,col].set(xticks=xpos,xticklabels=target_words,xlabel='Target words',ylabel='Softmax prob')

plt.tight_layout()
plt.show()

In [None]:
# log-softmax over the vocabulary at the [MASK] position of the masked sentence
logsm = F.log_softmax(out_mask.logits[0,maskTarget_idx,:],dim=-1).detach().cpu()

fig,axs = plt.subplots(1,2,figsize=(10,3.5))

# left panel: log-softmax of the target words; right panel: the same values exponentiated
panels = [
  (logsm[targets_idx],            'Log-softmax',   'Log-softmax for masked word'),
  (torch.exp(logsm[targets_idx]), 'Softmax prob.', 'Softmax probability for masked word'),
]
for ax,(vals,ylab,ttl) in zip(axs,panels):
  ax.bar(range(3),vals)
  ax.set(xticks=range(3),xticklabels=target_words,xlabel='Target words',ylabel=ylab,title=ttl)

# figure title: the masked sentence with its first and last (special) tokens dropped
fig.suptitle(tokenizer.decode(tokens_mask['input_ids'][0,1:-1]),fontweight='bold')

plt.tight_layout()
plt.show()

# Exercise 3: Edit in an anti-bias?

In [None]:
# get s/he/they activation from one hidden state

# Pin the blend weights here. Exercise 4 later sets mixture=[.5,.5], so
# without this line a re-run of this cell would silently use a different mix.
mixture = [.1,.9]

# The +1 skips hidden_states[0] (the embedding output in HF models), so this
# is the output of encoder layer `layer2replace` at the pronoun position.
layer2replace = 10
hs_vector2replace = out_she.hidden_states[layer2replace+1][0,maskTarget_idx,:]

with torch.no_grad():
  out_mask_replace = model(**tokens_mask.to(device),output_hidden_states=True)

In [None]:
# Log-softmax at the [MASK] position: clean run vs. run with the edited hidden state.
logsm_orig = F.log_softmax(out_mask.logits[0,maskTarget_idx,:],dim=-1).detach().cpu()
logsm_repl = F.log_softmax(out_mask_replace.logits[0,maskTarget_idx,:],dim=-1).detach().cpu()

fig,axs = plt.subplots(1,2,figsize=(10,3.5))

# Left panel: log-softmax; right panel: exponentiated probabilities.
# Within each panel the bars are shifted by +/-.2 so the two conditions sit side by side.
panel_specs = [
  (lambda v: v, 'Log-softmax',   'Log-softmax for masked word'),
  (torch.exp,   'Softmax prob.', 'Softmax probability for masked word'),
]
for ax,(transform,ylab,ttl) in zip(axs,panel_specs):
  ax.bar(np.arange(3)-.2,transform(logsm_orig[targets_idx]),width=.5,label='Original')
  ax.bar(np.arange(3)+.2,transform(logsm_repl[targets_idx]),width=.5,label='Modified')
  ax.legend()
  ax.set(xticks=range(3),xticklabels=target_words,xlabel='Target words',ylabel=ylab,title=ttl)

fig.suptitle(tokenizer.decode(tokens_mask['input_ids'][0,1:-1]),fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# Bias = log P(he) - log P(she) at the mask; positive values favor "he".
he_idx, she_idx = targets_idx[0], targets_idx[1]
bias_orig = logsm_orig[he_idx] - logsm_orig[she_idx]
bias_repl = logsm_repl[he_idx] - logsm_repl[she_idx]

for label,score in (('original',bias_orig), ('modified',bias_repl)):
  print(f'Bias (he-she) in {label} model: {score:.3f}')

# Exercise 4: Laminar profile of anti-bias impact

In [None]:
mixture = [0.5, 0.5]  # equal weight on the original activation and the replacement vector

In [None]:
# Sweep the replacement over every encoder layer and record the he-she bias for each.
bias_scores = torch.zeros(model.config.num_hidden_layers)

try:
  for layer2replace in range(model.config.num_hidden_layers):

    # vector to replace (from "she" sentence); the +1 skips the embedding-output entry
    hs_vector2replace = out_she.hidden_states[layer2replace+1][0,maskTarget_idx,:]

    # forward-pass with hook to replace
    with torch.no_grad():
      out_mask_replace = model(**tokens_mask.to(device),output_hidden_states=True)

    # calculate the log-sm probabilities
    logsm_repl = F.log_softmax(out_mask_replace.logits[0,maskTarget_idx,:],dim=-1).detach().cpu()

    # calculate the bias towards "he"
    bias_scores[layer2replace] = logsm_repl[targets_idx[0]] - logsm_repl[targets_idx[1]]
finally:
  # The loop variable is the global the hook reads. Left as-is, the last layer
  # would stay armed and silently edit every later forward pass, even after an
  # interrupted sweep.
  layer2replace = 40000

In [None]:
fig,ax = plt.subplots(figsize=(8,3))
ax.plot(bias_scores,'ko-',markerfacecolor=[.7,.9,.7],markersize=10,linewidth=.5)
# dashed zero line: no preference between "he" and "she"
ax.axhline(0,linestyle='--',zorder=-3,color='gray')
ax.set(xlabel='Layer of replacement',ylabel='Bias score')
plt.show()

In [None]:
# Detach every forward hook registered on the encoder layers so the model runs
# unmodified again. Then empty the list so `handles` no longer holds dead
# handles; this also makes re-running the cell a harmless no-op.
for h in handles:
  h.remove()
handles.clear()