|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 6:</h2>|<h1>Intervention (causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Modifying MLP</h1>|
|<h2>Lecture:</h2>|<h1><b>Statistical criteria to lesion MLP neurons</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.gridspec import GridSpec

import scipy.stats as stats
from statsmodels.stats.multitest import fdrcorrection

import torch
import torch.nn.functional as F

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [None]:
from transformers import BertForMaskedLM, BertTokenizer

# one pre-trained checkpoint shared by the tokenizer and the masked-LM model
checkpoint = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(checkpoint)
model = BertForMaskedLM.from_pretrained(checkpoint)

# evaluation mode (e.g., no dropout); the returned module is shown as the cell output
model.eval()

In [None]:
# width of the MLP expansion layer (nn.Linear.out_features == weight.shape[0]);
# read from layer 4, assuming all layers share one intermediate size -- TODO confirm if the model changes
mlp4 = model.bert.encoder.layer[4].intermediate
nneurons = mlp4.dense.out_features
nneurons, mlp4

# Part 1: Get MLP t-values for him-vs-her

In [None]:
# dictionary to store the mlp activations, keyed by f'L{layer_number}'
mlp_values = {}

def implant_hook(layer_number):
  """Build a forward hook that records a module's output in `mlp_values`.

  Parameters
  ----------
  layer_number : int
    Only used to build the dictionary key f'L{layer_number}'.

  Returns
  -------
  callable
    A hook with the (module, input, output) signature expected by
    `register_forward_hook`. It returns None, so the module output is not modified.
  """
  def hook(module, input, output):
    # detach from the computational graph and copy to host memory:
    # .cpu() is a no-op for CPU tensors but lets .numpy() work when the model is on a GPU
    mlp_values[f'L{layer_number}'] = output.detach().cpu().numpy()
  return hook


# surgery ;) -- record the output of the MLP expansion projection (intermediate.dense),
# i.e. the pre-GELU values that the histogram below labels 'Pre-gelu'
whichlayer = 9
dense_layer = model.bert.encoder.layer[whichlayer].intermediate.dense
handle = dense_layer.register_forward_hook(implant_hook(whichlayer))

# Create the him/her sentences

In [None]:
# generated by Claude.ai
# "him" versions; each sentence contains exactly one " him"
him_texts = [
    "I saw him at the market.",
    "She gave him the book.",
    "They asked him for advice.",
    "We invited him to dinner.",
    "The dog followed him home.",
    "They asked him to join.",
    "He saw him at the park yesterday.",
    "Did you give him your address?",
    "I haven't seen him in ages.",
    "I told him the truth.",
    "They congratulated him on his success.",
    "She recognized him immediately.",
    "The teacher praised him for his work.",
    "I met him last summer.",
    "The child hugged him tightly.",
    "They warned him about the danger.",
    "She drove him to the airport.",
    "We waited for him for hours.",
    "The cat scratched him accidentally.",
    "They surprised him with a gift.",
    "She called him on the phone.",
    "The jury found him not guilty.",
    "I remembered him from school.",
    "They elected him as president.",
    "She forgave him for his mistake.",
    "The police questioned him yesterday.",
    "I helped him with his homework.",
    "They spotted him in the crowd.",
    "She visited him in the hospital.",
    "The manager promoted him last week.",
    "I trusted him completely.",
    "They respected him for his honesty.",
    "She taught him how to swim.",
    "The bird attacked him suddenly.",
    "I greeted him warmly.",
    "They supported him through difficult times.",
    "She ignored him at the party.",
    "The judge sentenced him to community service.",
    "I photographed him during the event.",
    "They believed him despite the evidence.",
    "She surprised him on his birthday.",
    "The guard stopped him at the entrance.",
    "I missed him terribly.",
    "They watched him leave the building.",
    "She accompanied him to the concert.",
    "The crowd cheered him enthusiastically.",
    "I described him to the police.",
    "They thanked him for his help.",
    "She admired him for his courage.",
    "The committee nominated him for the award.",
    "I married him last spring.",
    "They informed him about the changes.",
    "She introduced him to the parents.",
    "The author based the character on him.",
]

## same sentences but with "her"
# Only the word "him" is swapped; other words (including "his") stay as they are,
# so sentence i and sentence i+N form a minimal pair.
her_texts = [text.replace(' him', ' her') for text in him_texts]

sentences = him_texts + her_texts

# indices of him/her sentences (first half "him", second half "her")
nhalf = len(sentences) // 2
him_sentences = np.arange(nhalf)
her_sentences = np.arange(nhalf, len(sentences))

print(f'There are {len(sentences)} sentences.')

In [None]:
# identify the target tokens (first -- presumably only -- vocabulary id of each pronoun)
target_token_him, target_token_her = (
    tokenizer.encode(word, add_special_tokens=False)[0] for word in ('him', 'her')
)
print(f'The target token indices are {target_token_him} and {target_token_her}\n')

# tokenize the whole batch at once, padding every sentence to the longest one
tokens = tokenizer(sentences, padding=True, return_tensors='pt')

In [None]:
# display the tokenized batch (padded input_ids plus the tokenizer's other outputs)
tokens

# Forward pass and get the activations

In [None]:
# Run the sentences through BERT; the recording hook fills mlp_values as a side effect.
# try/finally guarantees the hook is detached even if the forward pass raises,
# so it is never left attached to the model by accident.
try:
  with torch.no_grad():
    model(**tokens)
finally:
  handle.remove()

# (n_sentences, n_tokens, n_neurons)
mlp_values[f'L{whichlayer}'].shape

In [None]:
# histogram bin edges
binedges = np.linspace(-10, 10, 101)

# values captured by the hook (pre-GELU) and the same values passed through GELU
pre_gelu = mlp_values[f'L{whichlayer}'].flatten()
post_gelu = F.gelu(torch.tensor(pre_gelu))

plt.figure(figsize=(10, 4))
for vals, lbl in ((pre_gelu, 'Pre-gelu'), (post_gelu, 'Post-gelu')):
  y, _ = np.histogram(vals, binedges)
  plt.plot(binedges[:-1], y, linewidth=2, label=lbl)

plt.legend()
plt.gca().set(xlim=binedges[[0, -1]], xlabel='Activation value', ylabel='Count', yscale='log')
plt.show()

In [None]:
# loop through sentences to get the activation at the target ("him"/"her") token
# NOTE: relies on `tokens` still holding the Part 1 sentence batch (Part 2 rebinds it).

layer_acts = mlp_values[f'L{whichlayer}']
acts = np.zeros((len(sentences), layer_acts.shape[2]))

target_ids = [target_token_him, target_token_her]
for senti in range(len(sentences)):

  # find the position of either of the target tokens
  targidx = np.where(np.isin(tokens['input_ids'][senti].numpy(), target_ids))[0]

  # each sentence must contain exactly one occurrence of the target
  # (see other code files for multiple target words per sentence).
  # Fail with a clear message instead of numpy's opaque broadcasting error.
  if targidx.size != 1:
    raise ValueError(f'Sentence {senti} has {targidx.size} target tokens: "{sentences[senti]}"')

  # then get the activation
  acts[senti, :] = layer_acts[senti, targidx[0], :]

acts.shape

In [None]:
# paired t-test per neuron on the him-minus-her differences (sentence i of the
# "him" half is paired with sentence i of the "her" half), then find significant
# neurons via FDR (Benjamini-Hochberg correction for multiple comparisons)
tres = stats.ttest_1samp(acts[him_sentences,:] - acts[her_sentences,:], popmean=0, axis=0)

# fdrcorrection returns (reject, pvals_corrected); `reject` already IS the boolean
# significance mask, so use it directly rather than comparing p-values against it.
issig = fdrcorrection(tres.pvalue, alpha=.05)[0]

# positive t: larger activation for "him"; negative t: larger for "her"
himNeurons = issig & (tres.statistic>0)
herNeurons = issig & (tres.statistic<0)

plt.figure(figsize=(10,5))
plt.plot(np.where(~issig)[0],tres.statistic[~issig],'rx',alpha=.4,label='Non-sig.')
plt.plot(np.where(himNeurons)[0],tres.statistic[himNeurons],'ks',markerfacecolor=[.7,.9,.7,.5],label='him > her')
plt.plot(np.where(herNeurons)[0],tres.statistic[herNeurons],'ks',markerfacecolor=[.9,.7,.7,.5],label='her > him')

plt.legend()
plt.gca().set(xlim=[-1,nneurons],xlabel='MLP expansion neurons',ylabel='T-value',
              title=f'{himNeurons.sum()} "him" neurons and {herNeurons.sum()} "her" neurons')
plt.show()

# Part 2: Confirm BERT can make accurate predictions

In [None]:
# one sentence in three versions: clean, "her" masked, "him" masked
clean_text = 'Robert helped Lucy with her project, and she thanked him for his hard work.'
texts = [
  clean_text,
  clean_text.replace('with her', 'with [MASK]'),
  clean_text.replace('thanked him', 'thanked [MASK]'),
]

# tokenize without padding, so the three texts must tokenize to the same length.
# NOTE: this rebinds `tokens`, replacing the Part 1 sentence batch.
tokens = tokenizer(texts, return_tensors='pt')
tokens

In [None]:
# position of the [MASK] token in each masked text (row 1: "her" mask, row 2: "him" mask);
# .item() raises unless exactly one [MASK] is found
mask_idx_her, mask_idx_him = (
  torch.nonzero(tokens['input_ids'][row] == tokenizer.mask_token_id).item()
  for row in (1, 2)
)

print(f'Masks are at indices {mask_idx_her} and {mask_idx_him}')

In [None]:
# clean forward pass (the recording hook was removed in Part 1)
with torch.no_grad():
  out = model(**tokens)

# (n_texts, n_tokens, vocabulary)
logits = out['logits'].detach().numpy()
logits.shape

In [None]:
# target_logits[text, position (0 = "her" mask, 1 = "him" mask), token (0 = her, 1 = him)]
positions = [mask_idx_her, mask_idx_him]
token_ids = [target_token_her, target_token_him]
target_logits = logits[:3][:, positions][:, :, token_ids].astype(np.float64)

target_logits

In [None]:
fig, ax = plt.subplots(figsize=(10,4))
barcolors = [[.9,.7,.9],[.7,.9,.9]]  # her token, him token

for i in range(3):
  # left pair: logits at the "her" mask position; right pair: at the "him" mask position
  ax.bar(np.array([-.1,.1])+i*1.5,    target_logits[i,0,:], width=.2, facecolor=barcolors, edgecolor='k')
  ax.bar(np.array([-.1,.1])+i*1.5+.5, target_logits[i,1,:], width=.2, facecolor=barcolors, edgecolor='k')

# create the bar labels
# '\\;' (escaped backslash) yields the same mathtext '\;' space; a bare '\;' is an
# invalid escape sequence (SyntaxWarning since Python 3.12)
basetxt = 'her :: him     her :: him\n----------------+----------------\nher token   ||   him token'
xticklabels = [ basetxt + '\n\n|__________$\\bf{Clean}$_________|',
                basetxt + '\n\n|_______$\\bf{HER\\; mask}$_______|',
                basetxt + '\n\n|_______$\\bf{HIM\\; mask}$_______|' ]

ax.set(xticks=np.arange(.25,3.5,1.5), xticklabels=xticklabels, ylabel='Logits',
       title='Target-token logits at each position')
plt.show()

# Part 3: Manipulate "him" and "her" neurons in MLP

In [None]:
def implant_hook(layer_number):
  """Return a forward hook that zeroes the selected MLP neurons at the mask positions.

  NOTE: this re-defines (shadows) the recording `implant_hook` from Part 1.
  `layer_number` is accepted for signature compatibility but is not used.
  The hook reads the globals `mask_idx_her`, `mask_idx_him`, `herNeurons` and
  `himNeurons` each time it runs, edits the module output in place and returns it.
  It assumes batch row 1 is the "her"-masked text and row 2 the "him"-masked text.
  """
  def ablation_hook(module, input, output):
    # (batch row, token position, neuron mask) to lesion, applied in this order
    lesions = ((1, mask_idx_her, herNeurons),
               (2, mask_idx_him, himNeurons))
    for row, pos, neurons in lesions:
      output[row, pos, neurons] = 0
    return output
  return ablation_hook

# question: should we hook into intermediate or intermediate.dense?
mlp_module = model.bert.encoder.layer[whichlayer].intermediate
handle = mlp_module.register_forward_hook(implant_hook(whichlayer))


In [None]:
# forward pass with the ablation hook attached; try/finally guarantees the hook is
# removed even if the forward pass raises, so later runs are never silently ablated
try:
  with torch.no_grad():
    out = model(**tokens)
finally:
  handle.remove()

logitsZero = out.logits.detach().numpy()
logitsZero.shape

In [None]:
# same layout as target_logits, but from the ablated run:
# target_logitsZ[text, position (0 = "her" mask, 1 = "him" mask), token (0 = her, 1 = him)]
positions = [mask_idx_her, mask_idx_him]
token_ids = [target_token_her, target_token_him]
target_logitsZ = logitsZero[:3][:, positions][:, :, token_ids].astype(np.float64)

target_logitsZ

In [None]:
fig = plt.figure(figsize=(13,5))
gs = GridSpec(1,3,figure=fig)
ax0 = fig.add_subplot(gs[:2])   # left two-thirds: change in logits
ax1 = fig.add_subplot(gs[2])    # right third: clean vs. ablated logits

# positive values mean the logit decreased after ablation
deltaLogits = target_logits - target_logitsZ
barcolors = [[.9,.7,.9],[.7,.9,.9]]  # her token, him token

for i in range(3):
  ax0.bar(np.array([-.1,.1])+i*1.5,    deltaLogits[i,0,:], width=.2, facecolor=barcolors, edgecolor='k')
  ax0.bar(np.array([-.1,.1])+i*1.5+.5, deltaLogits[i,1,:], width=.2, facecolor=barcolors, edgecolor='k')

ax0.axhline(0,color='k',linewidth=.2)
# raw strings: '\D' is an invalid escape sequence in a normal string literal
# (DeprecationWarning, and SyntaxWarning since Python 3.12)
ax0.set(xticks=np.arange(.25,3.5,1.5), xticklabels=xticklabels, ylabel=r'$\Delta$ logits',
        title=rf'$\Delta$ from clean (ablation in layer {whichlayer})')

ax1.plot(target_logits.flatten(),target_logitsZ.flatten(),'ko',markerfacecolor=[.7,.9,.9,.6],markersize=10)
ax1.set(xlabel='Clean logits', ylabel='Ablation logits', title='Clean vs. ablated logits')
ax1.grid(linewidth=.4)

plt.tight_layout()
plt.show()