### MechInterp

### Engineering

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
import copy

In [None]:
class ParameterSharding:
    """Greedily split optimizer parameter groups across the ranks of a parallel context.

    Args:
        param_groups: iterable of dict-like parameter groups; each group must
            provide a ``"params"`` entry holding objects that expose ``numel()``.
        parallel_context: object exposing ``get_world_size()``.
    """

    def __init__(self, param_groups, parallel_context):
        self.param_groups = param_groups
        self.parallel_context = parallel_context

    def shard(self):
        """Assign every parameter to the rank that currently holds the fewest elements.

        Returns:
            A list with one entry per rank. Entry ``r`` is a list of shallow
            copies of the original parameter groups (same keys, same order)
            whose ``"params"`` entry contains only the parameters assigned to
            rank ``r``. A rank that receives nothing from a group still gets a
            copy of that group with an empty ``"params"`` list.
        """
        world_size = self.parallel_context.get_world_size()
        partitioned_params = [[] for _ in range(world_size)]
        # Running element count per rank, shared across all groups so the
        # balance is global rather than per group.
        sizes = [0] * world_size

        for param_group in self.param_groups:
            param_lists = [[] for _ in range(world_size)]

            for p in param_group["params"]:
                # Ties are resolved in favour of the lowest rank.
                next_rank = sizes.index(min(sizes))
                param_lists[next_rank].append(p)
                sizes[next_rank] += p.numel()

            for rank, params in enumerate(param_lists):
                # Shallow copy: hyper-parameters are shared, only "params" is replaced,
                # so the caller's original group is left untouched.
                param_group_rank = copy.copy(param_group)
                param_group_rank["params"] = params

                partitioned_params[rank].append(param_group_rank)

        return partitioned_params

In [None]:
embedding, linear, attention, layer norm

In [None]:
step 1: determine the global rank of the current process
step 2: resize embedding layer
step 3: parallelize embedding, linear, attention, layer norm
step 4: resize vocab space

In [3]:
def wait_stream(source_stream, target_stream):
    """Make ``source_stream`` wait until the work queued on ``target_stream`` is done.

    Args:
        source_stream: the waiting side. If it is a ``torch.cuda.Stream`` the
            wait is enqueued on that stream; anything else is treated as the
            host (CPU) side.
        target_stream: the stream to wait for. If it is not a
            ``torch.cuda.Stream`` the function does nothing.

    Returns:
        None.
    """
    if not isinstance(target_stream, torch.cuda.Stream):
        # Nothing to wait for when the target is not a CUDA stream.
        return

    if isinstance(source_stream, torch.cuda.Stream):
        # GPU waits GPU
        source_stream.wait_stream(target_stream)
    else:
        # CPU waits GPU: block the host until the target stream has drained.
        target_stream.synchronize()

In [4]:
# class Copy(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, prev_stream, next_stream, x):
#         ctx.prev_stream = prev_stream
#         ctx.next_stream = next_stream
        
#         wait_stream(
#             source_stream=next_stream,
#             target_stream=prev_stream
#         )
        
#         return x

#     @staticmethod
#     def backward(ctx, grad):
#         prev_stream = ctx.prev_stream
#         next_stream = ctx.next_stream
        
#         wait_stream(
#             source_stream=prev_stream,
#             target_stream=next_stream
#         )
        
#         return x

In [5]:
class Copy(torch.autograd.Function):
    """Autograd function that copies a tensor between the devices of two streams.

    ``forward`` moves ``input`` to ``next_stream.device``; ``backward`` moves
    the incoming gradient back to ``prev_stream.device``. The two stream
    arguments are not differentiable, so ``backward`` returns ``None`` for them.
    """

    @staticmethod
    def forward(ctx, prev_stream, next_stream, input):
        """Copy ``input`` to the device of ``next_stream``.

        Args:
            ctx: autograd context; the two streams are stashed on it for ``backward``.
            prev_stream: stream on the source side of the copy.
            next_stream: stream on the destination side of the copy.
            input: tensor to move.

        Returns:
            The tensor produced by ``input.to(next_stream.device)``.
        """
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream

        compute_stream = torch.cuda.default_stream(next_stream.device)

        with torch.cuda.stream(prev_stream), torch.cuda.stream(next_stream):
            moved_input = input.to(next_stream.device)
            # Register the streams that touch each tensor so its memory is not
            # handed back to the allocator while those streams may still use it.
            input.record_stream(prev_stream)
            moved_input.record_stream(compute_stream)

        return moved_input

    @staticmethod
    def backward(ctx, grad_input):
        """Copy the gradient back to the device of ``prev_stream``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_input: gradient with respect to the output of ``forward``.

        Returns:
            ``(None, None, moved_grad_input)`` - one entry per ``forward``
            argument; the two streams receive no gradient.
        """
        prev_stream = ctx.prev_stream
        next_stream = ctx.next_stream

        compute_stream = torch.cuda.default_stream(prev_stream.device)

        with torch.cuda.stream(prev_stream), torch.cuda.stream(next_stream):
            moved_grad_input = grad_input.to(prev_stream.device)

            grad_input.record_stream(next_stream)
            moved_grad_input.record_stream(compute_stream)

        return None, None, moved_grad_input

In [None]:
output2 = embed + pos_embed + attn00 + attn01 + mlp0 + attn10 + attn11 + mlp1

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention, residual norm, feed-forward, residual norm.

    Relies on ``MultiHeadAttention``, ``ResidualLayerNorm`` and
    ``PositionWiseFeedForward`` being defined elsewhere; their exact contracts
    are not visible here.

    Args:
        d_model: model width, forwarded to every sub-module.
        n_heads: number of attention heads, forwarded to ``MultiHeadAttention``.
        d_ff: hidden width, forwarded to ``PositionWiseFeedForward``.
        dropout: dropout rate, forwarded to the norm and feed-forward sub-modules.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout):
        super().__init__()
        self.mha = MultiHeadAttention(
            d_model=d_model,
            n_heads=n_heads
        )
        self.norm1 = ResidualLayerNorm(
            d_model=d_model, dropout=dropout
        )
        self.mlp = PositionWiseFeedForward(
            d_model=d_model, d_ff=d_ff, dropout=dropout
        )
        self.norm2 = ResidualLayerNorm(
            d_model=d_model, dropout=dropout
        )

    def forward(self, embeddings):
        """Run the block on ``embeddings``.

        Returns:
            A ``(output, attn_weights)`` tuple: the result of the second
            residual norm and the attention weights returned by ``self.mha``.
        """
        # Self-attention: queries, keys and values all come from the same input.
        attn_output, attn_weights = self.mha(
            pre_q=embeddings,
            pre_k=embeddings,
            pre_v=embeddings
        )
        norm1 = self.norm1(attn_output, residual=embeddings)
        # The feed-forward sub-layer must be *applied* to the normalised
        # attention output; passing the module object itself would fail in norm2.
        mlp = self.mlp(norm1)
        norm2 = self.norm2(mlp, residual=norm1)

        return norm2, attn_weights

In [6]:
from einops import rearrange

In [None]:
def compute_loss(model, xb, yb):
    """Cross-entropy loss of ``model(xb)`` against the integer targets ``yb``.

    The logits are flattened to ``(N, n_classes)`` and the targets to ``(N,)``
    before calling ``F.cross_entropy``. The last axis of the logits is the
    class axis that ``F.cross_entropy`` scores against (not an embedding
    width), and every leading axis is treated as a batch axis.

    Args:
        model: callable mapping ``xb`` to logits of shape ``(..., n_classes)``.
        xb: model input, passed through unchanged.
        yb: integer class targets whose shape equals the leading axes of the logits.

    Returns:
        Scalar tensor holding the mean cross-entropy over all positions.
    """
    logits = model(xb)

    # Collapse all leading axes (e.g. batch and sequence) into a single one.
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_targets = yb.reshape(-1)

    return F.cross_entropy(flat_logits, flat_targets)

In [7]:
def probability_scores(image_embedding, text_embedding, logit_scale=1.0):
    """Softmax over cosine similarities between image and text embeddings.

    Both inputs are L2-normalised along their last axis, the image embeddings
    are matrix-multiplied with the transposed text embeddings, and a softmax
    is taken over the last axis of the result. Neither input is modified.

    Args:
        image_embedding: tensor whose last axis is the embedding axis.
        text_embedding: tensor whose last axis is the embedding axis.
        logit_scale: factor applied to the similarities before the softmax.
            The default ``1.0`` leaves them unscaled.

    Returns:
        Tensor of probabilities; each row along the last axis sums to 1.
    """
    # F.normalize clamps the norm from below, so an all-zero embedding yields
    # zero similarities (uniform probabilities) instead of NaNs from 0 / 0.
    image_embedding = F.normalize(image_embedding, dim=-1)
    text_embedding = F.normalize(text_embedding, dim=-1)

    similarities = logit_scale * (image_embedding @ text_embedding.T)
    probs = F.softmax(similarities, dim=-1)

    return probs

In [None]:
import 

step 1: elasticdriver spots a change in the worker nodes
step 2: it sends a hostupdatedrequest to the notification service of the coordinating worker
step 3: if there are changes in node, the notification service passes the request to the notification manager
step 4:

In [None]:
class Copy

In [None]:
class TestFruit:
    def setup_method(self):
        self.fruit = Fruit(name="banana")
    
    def test_init(self):
        assert self.fruit.name == "banana"
    
    def teardown_method(self):
        del self.fruit

In [8]:
from torch.utils.data import Sampler

In [None]:
class EvenSampler(Sampler):
    def __init__(self, data):
        super().__init__()
        self.data = data
    
    def __iter__(self):
        return [x for x in range(0, len(self.data), 2)]

step 1: normalize the loss with current_epoch / n_epoch
step 2: calculate the gradients with respect to the normalized loss
step 3: accumulate the gradients
step 4: if current_epoch == n_epoch, update, otherwise, repeat step 1

In [None]:
elasticdriver, torchstate, 3notification, hostdiscovery

In [9]:
class Copy(torch.autograd.Function):
    """Differentiable device-to-device copy driven by a pair of streams.

    The forward pass moves the tensor to ``next_stream.device``; the backward
    pass moves the gradient to ``prev_stream.device``. The stream arguments
    receive no gradient.
    """

    @staticmethod
    def forward(ctx, prev_stream, next_stream, input):
        """Move ``input`` to ``next_stream.device`` and return the moved tensor."""
        # Stash both streams so backward can mirror the transfer.
        ctx.prev_stream, ctx.next_stream = prev_stream, next_stream

        consumer_stream = torch.cuda.default_stream(next_stream.device)

        with torch.cuda.stream(prev_stream):
            with torch.cuda.stream(next_stream):
                output = input.to(next_stream.device)

                # Register each tensor with the stream that touches it.
                input.record_stream(prev_stream)
                output.record_stream(consumer_stream)

        return output

    @staticmethod
    def backward(ctx, grad_input):
        """Move ``grad_input`` to ``prev_stream.device``.

        Returns ``(None, None, grad)``: one slot per forward argument, with
        ``None`` for the two streams.
        """
        src_stream, dst_stream = ctx.next_stream, ctx.prev_stream

        consumer_stream = torch.cuda.default_stream(dst_stream.device)

        with torch.cuda.stream(dst_stream):
            with torch.cuda.stream(src_stream):
                grad_output = grad_input.to(dst_stream.device)

                grad_input.record_stream(src_stream)
                grad_output.record_stream(consumer_stream)

        return None, None, grad_output

In [None]:
#!/bin/bash

In [10]:
class TestFruit:
    def setup_method(self):
        self.fruit = Fruit("x")
    
    def test_fruit(self):
        assert self.fruit.name == "x"
    
    def teardown_method(self):
        del self.fruit

In [11]:
import pytest

In [12]:
# ``parametrize`` needs the argument names AND the list of value tuples;
# with only the names pytest cannot generate any test case.
# NOTE(review): the cases assume ``square(x)`` returns ``x * x`` - confirm
# against the definition of ``square``, which is not visible here.
@pytest.mark.parametrize(
    "test_input, expected",
    [
        (0, 0),
        (1, 1),
        (2, 4),
        (-3, 9),
    ],
)
def test_square(test_input, expected):
    """``square`` maps each ``test_input`` to its ``expected`` value."""
    assert square(test_input) == expected

In [None]:
step 1: record all the intermediate activations
step 2: analyze attention patterns
step 3: spot induction heads
step 4: decompose the attention scores of the induction heads
step 5: identify the q-k pair that produces the induction characteristic
step 6: trace backward
step 7: construct the full circuit

In [None]:
step 1: prob[0] = sigmoid(logit0 - logit1)
step 2: logit0 = resid2 @ W_U[0], logit1 = resid2 @ W_U[1]
step 3: logit0 - logit1 = resid2 @ (W_U[0] - W_U[1])
step 4: resid2 = resid1 @ ln1 @ W_OV^{1, 0}
step 5: resid1 @ ln1 @ W_OV^{1, 0} @ (W_U[0] - W_U[1])

In [None]:
_, cache = model.run_with_cache(board_history)

In [None]:
mlp_acts = cache[hook_name]

In [13]:
layer_idx = 5
neuron_idx = 1393

In [None]:
neuron_acts = mlp_acts[:, neuron_idx]

In [None]:
threashold = neuron_acts.quantile(0.99)

In [None]:
move_idxs = neuron_acts > threashold

In [None]:
(board_states == 1)[:, :, move_idxs]

In [None]:
W_V = torch.zeros(seq_len, d_model)

In [15]:
corrupted_prompt = "X told Y: 'Persistence is all you need.' Z replied back to "

In [None]:
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

In [16]:
layer_idx, head_idx = 6, 9

In [None]:
W_pos = model.W_pos

In [None]:
pos_one = W_pos[:, 0]
pos_two = W_pos[:, 1]

In [None]:
# Normalise out of place: ``pos_one`` / ``pos_two`` are slices (views) of
# ``W_pos``, so an in-place ``/=`` would rescale the model's own weights.
pos_one = pos_one / pos_one.norm(dim=-1, keepdim=True)
pos_two = pos_two / pos_two.norm(dim=-1, keepdim=True)

In [None]:
similarity = 

In [None]:
_, cache = model.run_with_cache(board_history)

In [17]:
hook_name = "blocks.5.mlp.hook_post"

In [None]:
mlp_acts = cache[hook_name]

In [None]:
neuron_acts = mlp_acts[:, 1393]

In [None]:
A^1 @ x @ W_OV^1 

In [None]:
_, cache = model.run_with_cache(past_moves)

In [None]:
mlp_acts = cache[hook_name]

In [None]:
top_neurons = mlp_acts.std(dim=[0, 1]).argsort(descending=True)[:10]

In [18]:
layer_idx = 4

In [19]:
from einops import einsum

In [None]:
def calculate_consine_similarity(neuron_idx, feature):
    """Cosine similarity between a neuron's output weights and ``feature``.

    Reads ``model.W_out[layer_idx, neuron_idx, :]`` using the notebook-level
    ``model`` and ``layer_idx``.

    Args:
        neuron_idx: index of the neuron along the second axis of ``model.W_out``.
        feature: direction to compare against; its last axis must match the
            last axis of ``model.W_out``.

    Returns:
        The dot product of the two L2-normalised vectors, reduced over the
        last axis. Neither ``model.W_out`` nor ``feature`` is modified.
    """
    W_out = model.W_out[layer_idx, neuron_idx, :]

    # Divide out of place: ``W_out`` is a slice of the model's weights and
    # ``feature`` belongs to the caller, so ``/=`` would silently rescale both.
    W_out = W_out / W_out.norm(dim=-1, keepdim=True)
    feature = feature / feature.norm(dim=-1, keepdim=True)

    return (W_out * feature).sum(dim=-1)

In [None]:
# Cosine similarity between each top neuron's output weights and the blank direction.
# A comprehension keeps its loop variable local, so the notebook-level
# ``neuron_idx`` is not overwritten as a side effect of building the list.
heatmap_blanks = [
    calculate_consine_similarity(top_neuron_idx, blank_dir)
    for top_neuron_idx in top_neurons
]

In [None]:
pos_embed + embed + sum(12 heads in layer 0)

In [None]:
step 1: prob[0] = sigmoid(logit0 - logit1)
step 2: logit0 = resid @ W_U[0], logit1 = resid @ W_U[1]
step 3: logit0 - logit1 = resid @ (W_U[0] - W_U[1])

In [None]:
A@x@W_OV 

In [None]:
diagonal = x.diagonal(dim1=-2, dim2=-1)

In [20]:
corrupted_prompt = "X told Y: 'Persistence is all you need.' Z replied back to "

In [None]:
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

In [None]:
_, clean_cache = model.run_with_cache(clean_tokens)
_, corrupted_cache = model.run_with_cache(corrupted_tokens)

In [21]:
from transformer_lens.utils import get_act_name

In [22]:
layer_idx, head_idx = 6, 9

In [23]:
hook_name = get_act_name("result", layer_idx)

In [None]:
sender_acts = corrupted_cache[hook_name][:, head_idx]

In [24]:
def patch_corrupted_sender_activations(activations, hook):
    """Forward hook: overwrite the sender head's output with its corrupted-run value.

    Only the slice belonging to head ``head_idx`` is replaced; every other
    head keeps the activation of the current run. Uses the notebook-level
    ``corrupted_cache`` and ``head_idx``.

    Args:
        activations: activation tensor at the hooked point.
            NOTE(review): assumed to be laid out as (batch, pos, head, d_model),
            which is what the per-head ``result`` hook is documented to
            produce - confirm for the installed TransformerLens version.
        hook: the hook point being run; ``hook.name`` is used to look up the
            matching entry in ``corrupted_cache``.

    Returns:
        ``activations`` with the sender head's slice patched in place.
    """
    # The clean and corrupted runs must have matching batch and sequence
    # lengths for this assignment to line up.
    activations[:, :, head_idx] = corrupted_cache[hook.name][:, :, head_idx]
    return activations

In [None]:
# ``run_with_cache`` has no ``fwd_hooks`` parameter (extra keyword arguments
# are handed on to the model's forward pass), so the patching hook is
# registered explicitly, the clean prompt is run and cached with the patch
# active, and the hook is removed again even if the run fails.
model.add_hook(hook_name, patch_corrupted_sender_activations)
try:
    _, cache = model.run_with_cache(clean_tokens)
finally:
    model.reset_hooks()

In [None]:
receiver_acts = cache[receiver_hook_name]

In [None]:
def patch_receiver_acts()

step 1: a clean prompt, a corrupted prompt
step 2: record all the intermediate activations of the clean prompt and corrupted prompt
step 3: choose a sender component, and a receiver component
step 4: run the corrupted prompt and record the activations of the sender
step 5: run the clean prompt and patch the corrupted sender activations
step 6: record the receiver activations from step 5
step 7: run the clean prompt again and patch the receiver activations from step 6

In [None]:
A^{1}@x@W_OV

In [27]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple neural network
class SimpleNN(nn.Module):
    """Minimal network: a single ``nn.Linear(10, 10)`` layer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        projected = self.linear(x)
        return projected

# Create a model and an optimizer.
# Momentum is enabled so the optimizer actually keeps per-parameter state:
# plain SGD stores no ``momentum_buffer``, which would leave nothing to compare.
model = SimpleNN()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Step the optimizer once to initialize its internal states
input_tensor = torch.randn(5, 10)
output = model(input_tensor)
loss = output.sum()
loss.backward()
optimizer.step()

# Check if the optimizer states reference the same memory as the model parameters
for group in optimizer.param_groups:
    for param in group['params']:
        param_memory = param.data_ptr()
        momentum_buffer = optimizer.state[param].get('momentum_buffer')
        # Compare address with address; the buffer is absent (None) when the
        # optimizer keeps no state for this parameter.
        optimizer_memory = None if momentum_buffer is None else momentum_buffer.data_ptr()

        print(f"Param memory: {param_memory}, Optimizer state memory: {optimizer_memory}")
        if param_memory == optimizer_memory:
            print("Memory references are the same.")
        else:
            print("Memory references are different.")


In [28]:
optimizer_memory