In [1]:
# Testing the consequences of interventions on GPT-2, and how they match up against our SAEs
import torch

# Random stand-ins for a single SAE feature's encoder/decoder parameters.
# EMBED_DIM is the width of the hidden states these vectors are applied to.
EMBED_DIM = 768
SAE_SEED = 0

# A dedicated, seeded generator makes the placeholder parameters identical on
# every Restart & Run All, without touching torch's global RNG state.
_sae_rng = torch.Generator().manual_seed(SAE_SEED)

feature_encoder_weights = torch.randn(EMBED_DIM, generator=_sae_rng)
feature_encoder_bias = torch.randn(1, generator=_sae_rng)

feature_decoder_weights = torch.randn(EMBED_DIM, generator=_sae_rng)
feature_decoder_bias = torch.randn(1, generator=_sae_rng)


In [2]:
import torch
# Sanity check of batched matmul shapes: the leading batch dim is shared and the
# inner dims contract, so (2, 4, 5) @ (2, 5, 1) -> (2, 4, 1).
a = torch.randn(2, 4, 5)
b = torch.randn(2, 5, 1)
torch.matmul(a, b).shape

torch.Size([2, 4, 1])

In [3]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# Fetch the pre-trained GPT-2 tokenizer and language-model head by checkpoint name
MODEL_NAME = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)


# Define the hook function
def strengthen_sae_feature(module, input, output):
    """Forward hook: boost the SAE feature at the first position where it fires.

    The feature activation is computed for every token as
    ``hidden @ feature_encoder_weights - feature_encoder_bias``. For each batch
    item, the first position with a positive activation gets
    ``lambda_value * feature_decoder_weights`` added to its hidden state; items
    where the feature never fires are left unchanged.

    Args:
        module: the hooked module (unused).
        input: the module's inputs (unused).
        output: tuple whose first element is the hidden states, unpacked here
            as [batch, seq, embed]; any remaining elements pass through as-is.

    Returns:
        A tuple with the modified hidden states followed by ``output[1:]``.

    Side effects:
        Sets the globals ``first_activation_positions_global`` (shape [batch],
        -1 where the feature never fires) and ``store_output`` (the unmodified
        ``output``).
    """
    global first_activation_positions_global, store_output

    hidden = output[0]  # batch, seq, embed
    batch_size, seq_len, embed_dim = hidden.shape

    # The SAE parameters are module-level tensors; match them to the hidden
    # states' device and dtype so the hook also works when the model is not on CPU.
    encoder_weights = feature_encoder_weights.to(hidden)
    encoder_bias = feature_encoder_bias.to(hidden)
    decoder_weights = feature_decoder_weights.to(hidden)

    # batch, seq, embed  x  embed  ->  batch, seq
    feature_activation = (
        torch.einsum("bse,e->bs", hidden, encoder_weights) - encoder_bias
    )
    is_active = feature_activation > 0

    # argmax over a 0/1 tensor yields the first 1; it also yields 0 when there
    # is no 1 at all, so rows without any activation are overwritten with -1.
    first_activation_positions = is_active.float().argmax(dim=1)
    has_activation = is_active.any(dim=1)
    first_activation_positions[~has_activation] = -1

    # Store the results for the readout hook and for later inspection
    first_activation_positions_global = first_activation_positions
    store_output = output

    lambda_value = 1.0  # You can adjust this value as needed

    # batch, seq mask that is True only at each item's first activation position.
    # The -1 sentinel matches no position, so those rows stay all-False.
    positions = torch.arange(seq_len, device=hidden.device)
    mask = positions.unsqueeze(0) == first_activation_positions.unsqueeze(1)

    # Broadcast (batch, seq, 1) * (embed,) -> (batch, seq, embed)
    new_output = hidden + lambda_value * mask.unsqueeze(-1) * decoder_weights

    return (new_output,) + tuple(output[1:])


def return_consequent_layer(module, input, output):
    """Forward hook: read out the hidden state at each item's first-activation position.

    Reads the global ``first_activation_positions_global`` (shape [batch]) and
    stores, in the global ``consequent_embeddings`` (shape [batch, embed]), the
    hidden state of ``output[0]`` at that position for every batch item. Items
    whose position is -1 match no token and therefore get an all-zero row.

    Args:
        module: the hooked module (unused).
        input: the module's inputs (unused).
        output: tuple whose first element is the hidden states, unpacked here
            as [batch, seq, embed].

    Returns:
        None, so the module's output is left unmodified.
    """
    global consequent_embeddings

    hidden = output[0]
    _, seq_len, _ = hidden.shape

    # Build the position index on the hidden states' device so the comparison
    # does not fail when the model is not on CPU.
    positions = torch.arange(seq_len, device=hidden.device)
    targets = first_activation_positions_global.to(hidden.device)

    # batch, seq mask that is True only where the feature first activates
    mask = positions.unsqueeze(0) == targets.unsqueeze(1)

    # Zero out every other position, then sum over the sequence dimension to
    # select exactly one embedding per batch item.
    consequent_embeddings = (hidden * mask.unsqueeze(-1)).sum(dim=1)


# The feature is strengthened at the output of block `intervention_index`, and
# its downstream effect is read at the output of block `readout_index`.
intervention_index = 5
readout_index = 8

# Example usage
text = ["Hello, world!", "Hello, world!"]
inputs = tokenizer(text, return_tensors="pt")

# Register the hooks on the chosen blocks
intervention_hook = model.transformer.h[intervention_index].register_forward_hook(
    strengthen_sae_feature
)
readout_hook = model.transformer.h[readout_index].register_forward_hook(
    return_consequent_layer
)

try:
    # Forward pass; the readout hook stores its result under
    # consequent_embeddings, size: batch, embed
    with torch.no_grad():
        outputs = model(**inputs)
finally:
    # Always remove the hooks, even if the forward pass raises; otherwise they
    # stay attached to the model and stack up when this cell is re-run.
    intervention_hook.remove()
    readout_hook.remove()

# consequent_embeddings now holds the readout-block embeddings at the first activation position
print(f"shape of output embeddings: {consequent_embeddings.shape}")

shape of output embeddings: torch.Size([2, 768])


In [4]:
# Per batch item: first position where the feature activation was positive in the
# last hooked forward pass (-1 means it was never positive)
first_activation_positions_global

tensor([0, 0])

In [5]:
# One embedding per batch item, stored by return_consequent_layer for the
# first-activation position
consequent_embeddings

tensor([[ 0.0194, -0.5493,  0.9223,  ..., -1.5387, -1.3595, -1.2544],
        [ 0.0194, -0.5493,  0.9223,  ..., -1.5387, -1.3595, -1.2544]])

In [6]:
#next code we would need to add:
#filter out any batch elements where the SAE doesn't trigger
#compare to activations from a hook on the clean sequence
#subtract, compute comparisons, etc.!

In [7]:
# I suspect the most efficient way to go about this Jacobian computation is to modify the GPT-2 forward pass

In [16]:
# Inspect the list of transformer blocks that the hooks and compute_jacobian index into
model.transformer.h

ModuleList(
  (0-11): 12 x GPT2Block(
    (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (attn): GPT2SdpaAttention(
      (c_attn): Conv1D()
      (c_proj): Conv1D()
      (attn_dropout): Dropout(p=0.1, inplace=False)
      (resid_dropout): Dropout(p=0.1, inplace=False)
    )
    (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (mlp): GPT2MLP(
      (c_fc): Conv1D()
      (c_proj): Conv1D()
      (act): NewGELUActivation()
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
)

In [32]:
def compute_jacobian(model, j_activations, i, j, k):
    """
    Compute the Jacobian of layer k's activations with respect to layer j's activations at position i.

    ``j_activations`` is fed into ``model.transformer.h[j]`` and propagated
    through blocks ``j..k`` inclusive. Both the output and the input are
    restricted to token position ``i``.

    NOTE(review): if ``j_activations`` is meant to be the *output* of block j
    (e.g. captured by a forward hook on ``h[j]``), block j is applied a second
    time here and the range should start at ``j + 1`` -- confirm the intended
    convention.

    Args:
    - model: model exposing its blocks as ``model.transformer.h``; each block is
      called with the activations and its first return element is used
    - j_activations: activations of layer j (shape: [1, seq_len, hidden_size]);
      not modified by this function
    - i: token position
    - j: index of the input layer
    - k: index of the output layer

    Returns:
    - Jacobian matrix of shape [hidden_size, hidden_size]; entry [a, b] is
      d(output[i, a]) / d(input[i, b])

    Raises:
    - ValueError: if j_activations is not 3-D with a batch dimension of 1
    """
    if j_activations.dim() != 3 or j_activations.shape[0] != 1:
        raise ValueError(
            "j_activations must have shape [1, seq_len, hidden_size], got "
            f"{tuple(j_activations.shape)}"
        )

    # torch.autograd.functional.jacobian tracks gradients on its own copy of the
    # input, so the caller's tensor is detached rather than flipped to
    # requires_grad=True in place.
    base_activations = j_activations.detach()

    def forward_to_k(x):
        # Forward pass through blocks j..k, keeping only token position i
        activations = x
        for layer_idx in range(j, k + 1):
            activations = model.transformer.h[layer_idx](activations)[0]
        return activations[:, i, :]

    # Shape: [1, hidden_size, 1, seq_len, hidden_size]
    jacobian = torch.autograd.functional.jacobian(forward_to_k, base_activations)

    # Drop both singleton batch dims and keep only input token position i.
    # If pre-computing for every input position, return jacobian[0, :, 0] instead.
    return jacobian[0, :, 0, i, :]

In [33]:

# Random stand-in activations: a single sequence of 10 tokens, 768 features each
batch_size = 1
seq_len = 10
j_activations = torch.randn(batch_size, seq_len, 768)

# Jacobian for predicting layer 5 activations from layer 3 activations at position 2
i = 2
j = 3
k = 5
jacobian = compute_jacobian(model, j_activations, i=i, j=j, k=k)

print(f"Jacobian shape: {jacobian.shape}")

Jacobian shape: torch.Size([768, 768])
