In [1]:
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

In [2]:
class HookedTransformer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, num_layers):
        super().__init__()
        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, activation="relu"
        )
        self.transformer = TransformerEncoder(encoder_layer, num_layers)
        self.mlp_activations = []

    def forward(self, src):
        return self.transformer(src)

    def get_mlp_activations(self, module, input, output):
        print(f"Hook called for module: {module}")
        self.mlp_activations.append(output.detach())

In [4]:
def test_hooked_transformer():
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Initialize the transformer
    d_model = 64
    nhead = 1
    dim_feedforward = 512
    num_layers = 1
    transformer = HookedTransformer(d_model, nhead, dim_feedforward, num_layers)

    # Register the hook
    hooks = []
    for name, module in transformer.named_modules():
        if (
            "linear2" in name
        ):  # This is the second linear layer in the feed-forward network
            print(f"Registering forward hook for module: {name}")
            hook = module.register_forward_hook(transformer.get_mlp_activations)
            hooks.append(hook)

    # Create input tensor
    batch_size = 2
    seq_length = 10
    input_tensor = torch.randn(seq_length, batch_size, d_model)

    # Forward pass
    output = transformer(input_tensor)

    # Check if MLP activations were collected
    assert len(transformer.mlp_activations) > 0, "No MLP activations were collected"

    # Check the shape of collected activations
    expected_shape = (seq_length, batch_size, d_model)  # Corrected shape
    actual_shape = transformer.mlp_activations[0].shape
    assert (
        actual_shape == expected_shape
    ), f"Expected shape {expected_shape}, but got {actual_shape}"

    # We can't check for non-negative values as we're capturing before ReLU
    # But we can check if the activations are not all zero
    assert not torch.all(
        transformer.mlp_activations[0] == 0
    ), "All MLP activations are zero, which is unlikely"

    print("All tests passed successfully!")

    # Remove the hooks
    for hook in hooks:
        hook.remove()

In [5]:
test_hooked_transformer()

Registering forward hook for module: transformer.layers.0.linear2
Hook called for module: Linear(in_features=512, out_features=64, bias=True)
All tests passed successfully!




In [6]:
d_model = 64
nhead = 1
dim_feedforward = 512
num_layers = 1
transformer = HookedTransformer(d_model, nhead, dim_feedforward, num_layers)

for name, module in transformer.named_modules():
    print(f'{name} : {module}')

 : HookedTransformer(
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=512, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=512, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
)
transformer : TransformerEncoder(
  (layers): ModuleList(
    (0): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
      )
      (linear1): Linear(in_features=64, out_

In [7]:
class CustomTransformerEncoderLayer(TransformerEncoderLayer):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super().__init__(d_model, nhead, dim_feedforward, dropout, activation)
        self.activation_fn = nn.ReLU()  # Explicitly define ReLU as a module

    def _ff_block(self, x):
        x = self.linear1(x)
        x = self.activation_fn(x)  # Use the module version of ReLU
        x = self.dropout1(x)
        x = self.linear2(x)
        return x

class HookedTransformer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, num_layers):
        super().__init__()
        encoder_layer = CustomTransformerEncoderLayer(d_model, nhead, dim_feedforward, activation='relu')
        self.transformer = TransformerEncoder(encoder_layer, num_layers)
        self.mlp_activations = []

    def forward(self, src):
        return self.transformer(src)

    def get_mlp_activations(self, module, input, output):
        print(f"Hook called for module: {module}")
        self.mlp_activations.append(output.detach())

In [8]:
def test_hooked_transformer():
    # Set random seed for reproducibility
    torch.manual_seed(42)

    # Initialize the transformer
    d_model = 64
    nhead = 4
    dim_feedforward = 512
    num_layers = 1
    transformer = HookedTransformer(d_model, nhead, dim_feedforward, num_layers)

    # Register the hook
    hooks = []
    for name, module in transformer.named_modules():
        if isinstance(module, nn.ReLU):
            print(f"Registering forward hook for module: {name}")
            hook = module.register_forward_hook(transformer.get_mlp_activations)
            hooks.append(hook)

    # Create input tensor
    batch_size = 2
    seq_length = 10
    input_tensor = torch.randn(seq_length, batch_size, d_model)

    # Forward pass
    output = transformer(input_tensor)

    # Check if MLP activations were collected
    assert len(transformer.mlp_activations) > 0, "No MLP activations were collected"

    # Check the shape of collected activations
    expected_shape = (seq_length, batch_size, dim_feedforward)
    actual_shape = transformer.mlp_activations[0].shape
    assert actual_shape == expected_shape, f"Expected shape {expected_shape}, but got {actual_shape}"

    # Check if activations are non-negative (due to ReLU)
    assert torch.all(transformer.mlp_activations[0] >= 0), "ReLU activations should be non-negative"

    # Check if some activations are positive (not all zeros)
    assert torch.any(transformer.mlp_activations[0] > 0), "All ReLU activations are zero, which is unlikely"

    print("All tests passed successfully!")

    # Remove the hooks
    for hook in hooks:
        hook.remove()

if __name__ == "__main__":
    test_hooked_transformer()

Registering forward hook for module: transformer.layers.0.activation_fn
Hook called for module: ReLU()
All tests passed successfully!
