|<h2>Course:</h2>|<h1><a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">A deep understanding of AI language model mechanisms</a></h1>|
|-|:-:|
|<h2>Part 5:</h2>|<h1>Observation (non-causal) mech interp</h1>|
|<h2>Section:</h2>|<h1>Investigating neurons and dimensions</h1>|
|<h2>Lecture:</h2>|<h1><b>CodeChallenge: Reproducibility of activation maximization</b></h1>|

<br>

<h5><b>Teacher:</b> Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h5>
<h5><b>Course URL:</b> <a href="https://udemy.com/course/dullms_x/?couponCode=202508" target="_blank">udemy.com/course/dullms_x/?couponCode=202508</a></h5>
<i>Using the code without the course may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Tokenizer

# vector plots: render inline matplotlib figures as SVG instead of raster images
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Exercise 1: A training function

In [None]:
# load GPT2 model and tokenizer
model = GPT2Model.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# store a copy of the original token-embedding matrix (one row per token, kept on the CPU).
# .clone() makes it a true copy: when the model is on the CPU, .detach().cpu()
# alone would return a view that shares storage with the live model weights.
embeddings = model.wte.weight.detach().cpu().clone()

# use GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# move the model to the GPU (falls back to CPU if no GPU is available)
model = model.to(device)

# freeze the model weights: later cells only optimize the input embeddings, so
# gradients w.r.t. the model parameters are never needed (saves compute and memory)
model.requires_grad_(False)
model.eval()

In [None]:
def trainingFunction():
  """Optimize a random input-embedding sequence to maximize one hidden-state dimension.

  Reads these notebook globals (defined in other cells): model, embeddings, device,
  seq_len, lr, n_steps, layer_idx, dim_idx, lambda_l2.

  The objective is the mean activation of dimension `dim_idx` in
  `hidden_states[layer_idx]` (averaged over token positions), minus an L2 penalty
  on the embeddings weighted by `lambda_l2`.

  Returns:
    Detached tensor of shape (1, seq_len, embeddings.shape[1]) on `device`
    holding the optimized embeddings.
  """

  # initialize random embeddings with the same std as the real token embeddings
  # (.item() hands the initializer a plain float std)
  optimized_embeddings = torch.randn((1, seq_len, embeddings.shape[1]), requires_grad=True, device=device)
  torch.nn.init.normal_(optimized_embeddings, mean=0, std=torch.std(embeddings).item())

  # create an optimizer; the input embeddings are the only trainable tensor
  optimizer = torch.optim.Adam([optimized_embeddings], lr=lr)


  # loop over training steps
  for _ in range(n_steps):

    # clear gradient
    optimizer.zero_grad()

    # patch embeddings directly into the model
    outputs = model(inputs_embeds=optimized_embeddings, output_hidden_states=True)
    allActivations = outputs.hidden_states[layer_idx]

    # extract the activations (averaged over tokens)
    neuron_activation = allActivations[0,:,dim_idx].mean()

    # calculate loss (maximize activation -> minimize its negative) and run gradient descent
    loss = -neuron_activation + lambda_l2 * torch.sum(optimized_embeddings**2)

    # accumulate gradients ONLY into the optimized embeddings. A plain backward()
    # would also fill .grad for every model parameter, and those gradients would
    # keep accumulating because optimizer.zero_grad() does not manage them.
    loss.backward(inputs=[optimized_embeddings])
    optimizer.step()

  # output the optimized embeddings (detached: no further gradient tracking needed)
  return optimized_embeddings.detach()

In [None]:
# Hyperparameters read as globals by trainingFunction

# target: number of optimized token positions, and the (layer, dimension) to maximize
seq_len = 5
layer_idx, dim_idx = 8, 91

# optimization settings
n_steps   = 300   # number of gradient steps
lr        = .001  # Adam learning rate
lambda_l2 = .01   # weight of the L2 penalty on the embeddings

In [None]:
### test it: one optimization run using the hyperparameters set in the previous cell
optimized_embeddings = trainingFunction()

In [None]:
# decode embeddings to closest tokens: for each optimized position, pick the
# vocabulary token whose original embedding has the highest cosine similarity
optimized_tokens = []

# move the optimized embeddings to the CPU once (the reference `embeddings` live there),
# instead of one device transfer per token
optimized_cpu = optimized_embeddings[0].detach().cpu()

for emb in optimized_cpu:

  # cosine similarity with embedding weights (one value per vocabulary row)
  similarities = F.cosine_similarity(emb.unsqueeze(0), embeddings)

  # find the max similarity, stored as a plain Python int token id
  maxtok = torch.argmax(similarities).item()
  optimized_tokens.append(maxtok)

tokenizer.decode(optimized_tokens)

# Exercise 2: Run it 10x

In [None]:
# number of experiment repetitions
numberRepeats = 10

# initialize a matrix for all optimized tokens (one row per repeat, one column per position)
all_optimized_tokens = np.zeros((numberRepeats,seq_len),dtype=int)


# start the loop!
for runi in range(numberRepeats):

  # call the training function (each call starts from a new random initialization)
  optimized_embeddings = trainingFunction()

  # decode embeddings to closest tokens; transfer to the CPU once per run
  optimized_cpu = optimized_embeddings[0].detach().cpu()

  for embi,emb in enumerate(optimized_cpu):

    # cosine similarity with embedding weights
    similarities = F.cosine_similarity(emb.unsqueeze(0), embeddings)

    # store the token id with the max similarity
    all_optimized_tokens[runi,embi] = torch.argmax(similarities).item()

  # status
  print(f'Finished repeat {runi+1} of {numberRepeats}')


In [None]:
# display the decoded token ids: one row per repeat, one column per token position
all_optimized_tokens

In [None]:
# tally how often each token id was produced across all repeats and positions
unitokens,counts = np.unique(all_optimized_tokens,return_counts=True)
n_total = all_optimized_tokens.size

print(f'{len(unitokens)}/{n_total} tokens are unique.\n')

for tok_id,n_hits in zip(unitokens,counts):
  tok_str = tokenizer.decode([tok_id])
  print(f'{n_hits:2} optimization for token "{tok_str}"')