|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 6:</h2>|<h1>Intervention (causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>How to modify activations</h1>|
|<h2>Lecture:</h2>|<h1><b>Activation manipulation: Code implementations</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

# load the pretrained 'gpt2' language model and its matching tokenizer
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# put the model into evaluation mode (this notebook only runs forward passes, no training)
model.eval()

In [None]:
# some text to process; return_tensors='pt' asks the tokenizer for a PyTorch tensor
# (indexed below as tokens[0,4], i.e. first row, fifth position)
tokens = tokenizer.encode('I wish coffee would taste like chocolate.',return_tensors='pt')

# the token we'll modify: position 4 is the index every hook in this notebook edits.
# Left as the last expression so the notebook displays the decoded token.
tokenizer.decode(tokens[0,4])

# Hook demo 1 (doesn't work :P )

In [None]:
# initialize activations dictionary (the hook below writes into it under the key 'qkv')
activations = {}


def implant_hook(layer_number):
  """Return a forward hook that tries to zero part of Q in-place.

  This is the intentionally failing demo: the in-place write on `q` is
  expected to raise, so the lines after it are not reached.
  `layer_number` is accepted but not used inside the hook.
  """
  def hook(module, input, output):

    # -- GOAL: split into QKV then modify in-place
    # split the output into QKV (each is [B,S,E])
    q,k,v = output.split(model.config.hidden_size, dim=2)

    # zero-out the first 10 Q activations for token index 4
    q[:,4,:10] = 0
    # -- the above line crashes, b/c pytorch doesn't allow in-place editing
    #    NOTE(review): presumably because q is a view produced by split() on a
    #    tensor tracked by autograd -- confirm against the actual error message

    # Recombine the modified q with k and v
    QKV = torch.cat([q,k,v],dim=2)

    # store the activations
    # NOTE(review): this stores the hook's incoming `output`, not the recombined QKV
    activations['qkv'] = output

    # output the QKV matrix so it replaces the original
    return QKV

  return hook



# index of the transformer block whose c_attn module gets the hook
layer2modify = 3
# attach the hook and keep the returned handle so the hook can be removed later
hookHandle = model.transformer.h[layer2modify].attn.c_attn.register_forward_hook(implant_hook(layer2modify))

In [None]:
# push the tokens through the model; this forward pass invokes the hook registered above
# (for this demo the hook's in-place edit is expected to raise an error)
model(tokens)

In [None]:
# remove the hook so it no longer runs on later forward passes
hookHandle.remove()

# Hook demo 2 (works but might require careful indexing)

In [None]:
# container that the hook fills with the (edited) activations
activations = {}


def implant_hook(layer_number):
  """Build a forward hook that zeroes a slice of the hooked module's output in-place.

  The hook overwrites the first 10 features of token index 4 with zeros,
  directly on the tensor it receives, stores that same tensor in
  `activations['qkv']`, and returns it so it replaces the module output.
  `layer_number` is accepted but not used by the hook.
  """
  def _zero_in_place(module, inputs, output):

    # which token / how many leading features to overwrite
    # (reaching K or V would need offsets further along the last dimension)
    token_idx, n_features = 4, 10

    # direct write on the incoming tensor -- no split, no copy
    output[:, token_idx, :n_features] = 0

    # keep a reference to the edited tensor for later inspection
    activations['qkv'] = output

    # hand the same (now edited) tensor back to the model
    return output

  return _zero_in_place



# index of the transformer block whose c_attn module gets the hook
layer2modify = 3
# attach the hook and keep the returned handle so the hook can be removed later
hookHandle = model.transformer.h[layer2modify].attn.c_attn.register_forward_hook(implant_hook(layer2modify))

In [None]:
# push the tokens through the model (this invokes the hook registered above)
model(tokens)
# display the dictionary the hook filled in
activations

In [None]:
# first 10 features of every token; the hook zeroed these for token index 4
activations['qkv'][:,:,:10]

In [None]:
# remove the hook so it no longer runs on later forward passes
hookHandle.remove()

# Hook demo 3 (works b/c of out-of-place instead of in-place operations)

In [None]:
# container that the hook fills with the rebuilt activations
activations = {}


def implant_hook(layer_number):
  """Build a forward hook that edits a copy of Q and returns a rebuilt QKV tensor.

  The hook splits the module output into Q, K and V along the last dimension,
  zeroes the first 10 features of token index 4 on a clone of Q, concatenates
  the clone with the untouched K and V, stores the result in
  `activations['qkv']`, and returns it so it replaces the module output.
  `layer_number` is accepted but not used by the hook.
  """
  def _edit_q_copy(module, inputs, output):

    # -- GOAL: never write into the split pieces themselves
    # carve the output into three equal-width parts (each is [B,S,E])
    width = model.config.hidden_size
    queries, keys, values = torch.split(output, width, dim=2)

    # work on an independent copy of the queries, then zero token index 4
    queries_edited = queries.clone()
    queries_edited[:, 4, :10] = 0 # fixed (static) replacement value

    # stitch the edited queries back together with the original keys and values
    patched = torch.cat((queries_edited, keys, values), dim=2)

    # remember the rebuilt tensor for later inspection
    activations['qkv'] = patched

    # the returned tensor takes the place of the module's output
    return patched

  return _edit_q_copy



# index of the transformer block whose c_attn module gets the hook
layer2modify = 3
# attach the hook and keep the returned handle so the hook can be removed later
hookHandle = model.transformer.h[layer2modify].attn.c_attn.register_forward_hook(implant_hook(layer2modify))

In [None]:
# forward pass (this invokes the hook registered above)
model(tokens)
# display the dictionary the hook filled in
activations

In [None]:
# confirmation: all stored features for token index 4 of the first sequence;
# the hook set the first 10 of them to zero
activations['qkv'][0,4,:]

In [None]:
# remove the hook so it no longer runs on later forward passes
hookHandle.remove()

# Hook demo 4: cache all layers and manipulate only one

In [None]:
# container that the hooks fill, one entry per layer
activations = {}


def implant_hook(layer_number):
  """Build a forward hook that caches every layer but edits only layer 3.

  The hook stores the tensor it receives under `activations['qkv_<layer_number>']`
  and returns it. When `layer_number` is 3 it first zeroes, in-place, the
  first 10 features of token index 4.
  """
  # both values depend only on layer_number, so compute them once per hook
  edit_this_layer = (layer_number == 3)
  cache_key = f'qkv_{layer_number}'

  def _cache_and_maybe_zero(module, inputs, output):

    # only the target layer gets its activations overwritten
    if edit_this_layer:
      output[:, 4, :10] = 0

    # every layer is cached, edited or not
    activations[cache_key] = output

    # returned tensor replaces the module output (identical for non-target layers)
    return output

  return _cache_and_maybe_zero


# Register one hook per transformer block. Iterating over model.transformer.h
# (instead of a hard-coded range(12)) keeps this correct for models with a
# different number of layers.
handles = []
for layeri, transformer_block in enumerate(model.transformer.h):
  h = transformer_block.attn.c_attn.register_forward_hook(implant_hook(layeri))
  handles.append(h) # get all the handles for later removal

In [None]:
# display the list of hook handles collected by the registration loop
handles

In [None]:
# push through the model (each registered hook stores its layer's tensor)
model(tokens)
# display the keys written by the hooks (named 'qkv_<layer>')
activations.keys()

In [None]:
# confirm: print the first 5 stored values of token index 4 for layers 0-11;
# only the layer the hook edits (layer 3) should show zeros
for i in range(12):
  # detach() gives a tensor without autograd history for printing
  firstQs = activations[f'qkv_{i}'][0,4,:5].detach()
  print(f'Q acts from layer {i:2}:',firstQs)

In [None]:
# remove all handles so none of the hooks run on later forward passes
for h in handles:
  h.remove()

# Hook demo 5: Dynamic manipulation

In [None]:
# container that the hook fills with the (edited) activations
activations = {}
# replacement values; the hook reads this global each time it runs
q2replace = torch.zeros(10)

def implant_hook(layer_number):
  """Build a forward hook that overwrites a slice with the current `q2replace`.

  On every call the hook writes the module-level `q2replace` into the first
  10 features of token index 4 (in-place), stores the tensor in
  `activations['qkv']`, and returns it. Because `q2replace` is looked up when
  the hook runs, changing it alters the next forward pass without
  re-registering the hook. `layer_number` is accepted but not used.
  """
  # same region as output[:,4,:10], written as an explicit index tuple
  target_region = (slice(None), 4, slice(None, 10))

  def _overwrite_from_global(module, inputs, output):

    # value comes from a variable rather than a literal
    output[target_region] = q2replace

    activations['qkv'] = output
    return output
  return _overwrite_from_global

# index of the transformer block whose c_attn module gets the hook
layer2modify = 3
# attach the hook and keep the returned handle so the hook can be removed later
hookHandle = model.transformer.h[layer2modify].attn.c_attn.register_forward_hook(implant_hook(layer2modify))

In [None]:
# forward pass with the current q2replace
model(tokens)
# first 15 stored values of token index 4: the hook overwrote the first 10 with q2replace
activations['qkv'][0,4,:15].detach()

In [None]:
# make a new replacement: rebinding the global is enough, because the hook
# reads q2replace when it runs -- the hook does not need to be re-registered
q2replace = torch.linspace(-1,.7,10)

# forward pass again, then show the first 15 stored values of token index 4
model(tokens)
activations['qkv'][0,4,:15].detach()

In [None]:
# remove the hook so it no longer runs on later forward passes
hookHandle.remove()

# Hook demo 6: More flexibility with dictionary

In [None]:
# initialize activations dictionary (filled by the hooks, one 'qkv_<layer>' entry per layer)
activations = {}

# dictionary of replacements: layer index -> 1-D tensor of new values.
# The hooks read this dictionary when they run, so entries can be added or
# deleted between forward passes.
act_replacements = {
    3 : torch.zeros(8),
    8 : torch.arange(5),
    }


def implant_hook(layer_number):
  """Build a forward hook that applies the replacement registered for its layer.

  If `act_replacements` currently has an entry for `layer_number`, the hook
  writes that entry (in-place) into the leading features of token index 4;
  the number of features overwritten equals the length of the replacement.
  Every layer's tensor is stored in `activations['qkv_<layer_number>']` and
  returned so that it replaces the module output.
  """
  def hook(module, inputs, output):

    # modify the activation if this layer has a key in the dictionary
    # (membership is tested on the dict itself; no need to build a keys view)
    if layer_number in act_replacements:
      newdata = act_replacements[layer_number]
      output[:,4,:len(newdata)] = newdata

    # store the activations
    activations[f'qkv_{layer_number}'] = output

    # output the QKV matrix so it replaces the original (unchanged for non-target layers)
    return output

  return hook


# Register one hook per transformer block. Iterating over model.transformer.h
# (instead of a hard-coded range(12)) keeps this correct for models with a
# different number of layers.
handles = []
for layeri, transformer_block in enumerate(model.transformer.h):
  h = transformer_block.attn.c_attn.register_forward_hook(implant_hook(layeri))
  handles.append(h) # get all the handles for later removal

In [None]:
# confirm: run the model, then print the first 8 stored values of token index 4
# for layers 0-11; layers with an entry in act_replacements should show its values
model(tokens)
for i in range(12):
  # detach() gives a tensor without autograd history for printing
  firstQs = activations[f'qkv_{i}'][0,4,:8].detach()
  print(f'Q acts from layer {i:2}:\n ',firstQs)

In [None]:
# remove one replacement (layer 8 will no longer be edited)
del act_replacements[8]

# add another (layer 11 gets its first three values replaced)
act_replacements[11] = torch.tensor([1,2,3])

# display the updated dictionary
act_replacements

In [None]:
# try again: the hooks are still registered and read act_replacements when they
# run, so this pass uses the updated dictionary
model(tokens)
for i in range(12):
  # first 8 stored values of token index 4 for this layer
  firstQs = activations[f'qkv_{i}'][0,4,:8].detach()
  print(f'Q acts from layer {i:2}:\n ',firstQs)

In [None]:
# remove all handles so none of the hooks run on later forward passes
for h in handles:
  h.remove()

In [None]:
# list the replacements currently in the dictionary: k is the layer index (key),
# i is the replacement tensor (value)
for k,i in act_replacements.items():
  print(f'key {k} has values {i}')