### Engineering

In [1]:
import copy
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F

In [2]:
class Checkpoint(torch.autograd.Function):
    """Autograd function that evaluates ``function`` without recording a graph.

    The forward pass runs ``function(input)`` under ``torch.no_grad()``.
    The backward pass does not differentiate that result. Instead it
    back-propagates through the ``(output, input_leaf)`` pair stored in
    ``recomputed`` and hands the gradient accumulated on ``input_leaf`` back
    to the caller as the gradient of ``input``.
    """

    @staticmethod
    def forward(ctx, phony, recomputed, function, input):
        """Run ``function(input)`` with gradient tracking disabled.

        Args:
            ctx: autograd context; receives ``recomputed``, ``function`` and
                ``input`` as attributes.
            phony: not read by this method.
                NOTE(review): presumably a dummy tensor that requires grad so
                that ``backward`` is always reached -- confirm with the caller.
            recomputed: object that ``backward`` unpacks as
                ``(output, input_leaf)``.
            function: callable applied to ``input``.
            input: argument passed to ``function``.

        Returns:
            Whatever ``function(input)`` returns.
        """
        ctx.recomputed = recomputed
        ctx.function = function
        ctx.input = input

        with torch.no_grad():
            output = function(input)
        return output

    @staticmethod
    def backward(ctx, grad_input):
        """Back-propagate through the recomputed graph.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_input: gradient with respect to the output of ``forward``
                (the parameter name is kept for compatibility).

        Returns:
            A 4-tuple matching the arguments of ``forward``: ``None`` for
            ``phony``, ``recomputed`` and ``function``, and the gradient
            accumulated on ``input_leaf`` (or ``None`` when it does not
            require grad) for ``input``.
        """
        output, input_leaf = ctx.recomputed

        # input_leaf must stay the very tensor the recomputed graph was built
        # on: detaching it here would yield a fresh tensor whose .grad is
        # always None, silently dropping the input gradient.
        with torch.enable_grad():
            torch.autograd.backward(output, grad_input)

        input_grad = input_leaf.grad if input_leaf.requires_grad else None
        return (None, None, None, input_grad)

In [None]:
#include <iostream>
using namespace std;

In [None]:
// Print each element of xs on its own line, prefixed with "value: ".
// xs is not declared in this cell; each element is copied into i.
for (auto i: xs) {
    std::cout << "value: " << i << std::endl;
}

In [3]:
class HostUpdatedInterrupt(Exception):
    """Exception that carries a ``skip_sync`` flag for whoever catches it."""

    def __init__(self, skip_sync):
        # The flag is stored as a plain attribute; the base-class initializer
        # is not invoked here.
        self.skip_sync = skip_sync

In [None]:
monitor
reassign if nodes leave
new communication ring if nodes join

In [None]:
class ModelStateHandler:
    """Keep a restorable snapshot of a model's ``state_dict``.

    The wrapped model is stored in ``value``; a snapshot is taken as soon as
    the handler is constructed.
    """

    def __init__(self, model):
        """Wrap ``model`` and snapshot its current state."""
        self.value = model
        self.save()

    def save(self):
        """Snapshot the model's current state."""
        # state_dict() hands back references to the live parameter tensors,
        # so the snapshot has to be deep-copied or in-place updates to the
        # model would silently rewrite it and make restore() a no-op.
        self._model_state_dict = copy.deepcopy(self.value.state_dict())

    def restore(self):
        """Load the most recent snapshot back into the model."""
        self.value.load_state_dict(self._model_state_dict)

    def sync(self):
        """Call ``broadcast_parameters`` on the wrapped model.

        ``broadcast_parameters`` is not defined in this notebook's visible
        cells and must be provided before ``sync`` is used.
        """
        broadcast_parameters(self.value)

In [5]:
def get_handler(v):
    """Return a handler wrapping ``v``, or ``None`` when no entry matches.

    ``handler_registry`` is read from the enclosing namespace and iterated as
    ``(handler_type, handler_cls)`` pairs; the first pair whose
    ``handler_type`` matches ``v`` wins.
    """
    for handler_type, handler_cls in handler_registry:
        # Match on the registered value type; handler_cls is only the wrapper
        # to instantiate, values are never instances of it.
        if isinstance(v, handler_type):
            return handler_cls(v)
    return None

In [4]:
def get_handlers(states):
    """Split ``states`` into handler-wrapped entries and plain leftovers.

    Args:
        states: a mapping, or an iterable of ``(key, value)`` pairs.

    Returns:
        ``(handlers, remainders)``: ``handlers`` maps each key to the object
        returned by ``get_handler`` for its value; ``remainders`` maps the
        keys for which ``get_handler`` returned ``None`` to the raw value.
    """
    handlers = {}
    remainders = {}

    # Iterating a mapping directly yields keys only, so use items() when it
    # is available; plain iterables of pairs are consumed as they are.
    pairs = states.items() if hasattr(states, "items") else states
    for k, v in pairs:
        handler = get_handler(v)
        if handler is None:
            remainders[k] = v
        else:
            # Keep the handler itself, not the raw value it wraps.
            handlers[k] = handler

    return handlers, remainders

In [None]:
def compute_total_memory(model):
    """Return the number of bytes taken up by the parameters of ``model``.

    Each parameter contributes ``element_size() * numel()`` bytes. Only what
    ``model.parameters()`` yields is counted.

    Args:
        model: object exposing ``parameters()`` that yields tensors.

    Returns:
        Total size in bytes as an ``int`` (0 for a model without parameters).
    """
    return sum(
        param.element_size() * param.numel()
        for param in model.parameters()
    )

In [6]:
import socketserver

In [None]:
# Start a threaded TCP server bound to (MASTER_HOST, MASTER_PORT) that hands
# requests to EchoRequestHandler. None of these three names is defined in the
# visible cells, so they must exist before this cell runs.
# serve_forever() blocks this cell until the server is shut down; the `with`
# block closes the server socket on exit.
with socketserver.ThreadingTCPServer(
    (MASTER_HOST, MASTER_PORT),
    EchoRequestHandler
) as server:
    server.serve_forever()

In [7]:
import copy

In [8]:
class ModelStateHandler:
    """Hold a model together with a deep-copied snapshot of its state."""

    def __init__(self, model):
        self.set_value(model)

    def set_value(self, value):
        """Point the handler at ``value`` and snapshot it right away."""
        self.value = value
        self.save()

    def save(self):
        """Store an independent copy of the model's current ``state_dict``."""
        live_state = self.value.state_dict()
        self._model_state_dict = copy.deepcopy(live_state)

    def restore(self):
        """Load the stored snapshot back into the model."""
        snapshot = self._model_state_dict
        self.value.load_state_dict(snapshot)

    def sync(self):
        """Pass the wrapped model to ``broadcast_parameters``."""
        broadcast_parameters(self.value)

In [9]:
def get_handler(v):
    """Look up and instantiate the handler registered for the type of ``v``.

    ``handler_registry`` (taken from the enclosing namespace) is scanned in
    order as ``(handler_type, handler_cls)`` pairs.

    Returns:
        ``handler_cls(v)`` for the first pair whose ``handler_type`` matches
        ``v``, or ``None`` when nothing matches.
    """
    for handler_type, handler_cls in handler_registry:
        # The value is tested against the registered type; testing against
        # handler_cls would never match an unwrapped value.
        if isinstance(v, handler_type):
            return handler_cls(v)
    return None

In [10]:
def get_handlers(states):
    """Partition ``states`` into handled and unhandled entries.

    Args:
        states: a mapping, or an iterable of ``(key, value)`` pairs.

    Returns:
        A ``(handlers, remainders)`` tuple of dicts. ``handlers[key]`` is the
        object ``get_handler`` produced for the value; ``remainders[key]`` is
        the untouched value when ``get_handler`` returned ``None``.
    """
    handlers = {}
    remainders = {}

    # A mapping iterates over its keys alone, so go through items() when the
    # argument provides it.
    pairs = states.items() if hasattr(states, "items") else states
    for k, v in pairs:
        handler = get_handler(v)
        if handler is None:
            remainders[k] = v
        else:
            # Store the created handler rather than the original value.
            handlers[k] = handler
    return handlers, remainders

In [None]:
# Rank of the current process as reported by torch.distributed.
rank = torch.distributed.get_rank()
# Remains None unless a later cell assigns a subgroup handle to it.
group = None
# Ranks that are passed to new_group() in a later cell.
ranks = [0, 1, 3, 6]

In [None]:
# torch.distributed.new_group() has to be entered by every process of the
# job, including those that will not be members of the new group; calling it
# only on member ranks can leave the job hanging.
subgroup = torch.distributed.new_group(ranks)
# Only members keep the handle; other processes leave `group` untouched.
if rank in ranks:
    group = subgroup

In [None]:
# Only processes that hold a group handle take part in the broadcast.
# `x` is not defined in the visible cells.
# NOTE(review): src=0 is presumably the global rank of the sender -- confirm
# against the torch.distributed.broadcast documentation.
if group is not None:
    torch.distributed.broadcast(x, src=0, group=group)

In [13]:
import threading

In [14]:
# Thread-local container: attributes set on it are kept separately per thread.
data = threading.local()

In [15]:
def run_worker():
    """Thread entry point: pass the module-level ``data`` to ``print_and_modify``.

    Both ``data`` and ``print_and_modify`` are resolved from the enclosing
    namespace when the function is called.
    """
    print_and_modify(data)

In [None]:
# Build a thread whose target is run_worker; it is not started in this cell.
thread = threading.Thread(target=run_worker)

In [17]:
def run_worker(event):
    """Print a notice, block until ``event`` is set, then print a second notice.

    Args:
        event: object exposing a blocking ``wait()`` method.
    """
    print("waiting")
    # Execution pauses here until wait() returns.
    event.wait()
    print("received")

In [16]:
# A new Event starts unset, so wait() on it blocks until set() is called.
event = threading.Event()

In [None]:
# Build a thread that will call run_worker(event); it is not started here.
worker_thread = threading.Thread(
    target=run_worker,
    args=(event,)
)

### MechInterp

In [None]:
step 1: head 2
step 2: logit difference directions
step 3: einsum

In [18]:
def print_shape(module, input):
    """Forward pre-hook that prints the shape of each positional input.

    ``nn.Module.register_forward_pre_hook`` passes the positional arguments
    of the forward call as a tuple, so each element is inspected in turn.
    A bare tensor is still accepted.

    Args:
        module: the module about to run; unused.
        input: tuple (or list) of forward arguments, or a single object with
            a ``shape`` attribute.

    Returns:
        None, so the module's input is left unmodified.
    """
    args = input if isinstance(input, (tuple, list)) else (input,)
    for arg in args:
        shape = getattr(arg, "shape", None)
        # Arguments without a shape (e.g. plain Python values) are skipped.
        if shape is not None:
            print(shape)

In [None]:
# Attach print_shape as a forward pre-hook on the block at index 1, so it is
# called before that block's forward pass. `model` comes from a cell not shown.
model.blocks[1].register_forward_pre_hook(print_shape)

In [None]:
# NOTE(review): presumably removes the hooks previously attached to `model`
# (TransformerLens-style API) -- confirm against the model class in use.
model.reset_hooks()

In [None]:
# Key spelled "hook_embed", matching the lookup used elsewhere in this
# notebook; "hook_enmbed" was a typo.
cache["hook_embed"]

In [None]:
# Convert `text` into the model's token representation. Both `model` and
# `text` come from cells not shown here.
tokens = model.to_tokens(text)

In [20]:
# Module-level `data` starts out as None.
# NOTE(review): this rebinds the name `data` that an earlier cell bound to
# threading.local().
data = None

In [19]:
def extract_neuron(activations, hook):
    """Hook that records one neuron's activations in the module-level ``data``.

    Args:
        activations: tensor indexed as ``[:, :, neuron_idx]``.
            Assumes a layout of (batch, position, neuron) -- TODO confirm.
        hook: hook object supplied by the caller; unused.

    Returns:
        ``activations`` unchanged, so the forward pass is not altered.
    """
    # Without this declaration the assignment below would create a local
    # variable and the captured slice would be discarded on return.
    global data
    data = activations[:, :, neuron_idx]
    return activations

In [None]:
# Name of the hook point for the layer selected by `layer_idx`, which must be
# defined before this cell runs.
hook_name = f"blocks.{layer_idx}.mlp.hook_post"

In [None]:
# Temporary forward hooks are passed through run_with_hooks(); the
# run_with_cache() call used before does not take a fwd_hooks argument.
# The hook point is addressed by `hook_name`, the name built in the previous
# cell (`hook` was never defined).
model.run_with_hooks(
    tokens,
    fwd_hooks=[(hook_name, extract_neuron)],
)

In [None]:
# Index of the largest entry along the last dimension of `data`.
arg_neuron = torch.argmax(data, dim=-1)

In [None]:
x@W_Q

x = embed + pos_embed + sum(12 heads)


[embed + pos_embed + sum(12 heads)] @ W_Q

embed @ W_Q + pos_embed @ W_Q + sum(12 heads) @ W_Q

In [None]:
embed + pos_embed + sum(12 heads)

In [None]:
# Apply the model's `embed` component to `tokens` and display the result.
# NOTE(review): presumably the token-embedding lookup only -- confirm.
model.embed(tokens)

In [None]:
(head_1 + head_2 + head_3) + bias

In [None]:
# Run the model on `tokens`; the first returned value is discarded and the
# second is kept as `cache`.
_, cache = model.run_with_cache(
    tokens,
)

In [None]:
# `final_residual_name` is not defined in the visible cells.
final_residual_stream = cache[final_residual_name]
# Keep only the last index along dim 1.
# Assumes a layout of (batch, position, d_model) -- TODO confirm.
last_token_final_residual_stream = final_residual_stream[:, -1, :]

In [None]:
# apply_ln_to_stack() is provided by the activation cache returned from
# run_with_cache(), not by the model, because it reuses the layer-norm scale
# factors recorded during that forward pass.
scaled_last_token_final_residual_stream = cache.apply_ln_to_stack(
    last_token_final_residual_stream,
    layer=-1,
    pos_slice=-1
)

In [None]:
# The residual-stream direction that raises a token's logit is that token's
# column of the unembedding matrix W_U; the embedding matrix W_E maps tokens
# into the residual stream and is not indexed by token on its last axis.
# NOTE(review): assumes W_U is laid out as (d_model, d_vocab) -- confirm.
W_U = model.W_U
correct_residual_direction = W_U[:, correct_token]
incorrect_residual_direction = W_U[:, incorrect_token]

In [None]:
# Direction along which the correct token's score exceeds the incorrect one's.
logit_diff_direction = correct_residual_direction - incorrect_residual_direction

In [21]:
from einops import einsum

In [None]:
# TODO: the einsum pattern (last argument) is still an empty placeholder.
# NOTE(review): if the operands are (batch, d_model) and (d_model,), the
# pattern would be "batch d_model, d_model -> batch" -- confirm the shapes.
logit_difference = einsum(
    scaled_last_token_final_residual_stream,
    logit_diff_direction,
    ""
)

In [None]:
# Re-tokenize `text`; this rebinds the `tokens` name used by earlier cells.
tokens = model.to_tokens(text)

In [None]:
# Run the model on `tokens`, keeping only the second returned value (`cache`).
_, cache = model.run_with_cache(tokens)

In [None]:
# The inputs to the layer are the token embedding, the positional embedding
# and the outputs of the previous layer's heads, so the second component is
# the positional embedding rather than an "unembed" activation.
embed = cache["hook_embed"]
pos_embed = cache["hook_pos_embed"]
head_outputs = cache["result", layer_idx-1]
# NOTE(review): torch.cat along dim 0 needs all three tensors to agree on
# every other dimension -- confirm the shapes (the per-head results may need
# reshaping so that heads sit on dim 0).
input_components = torch.cat([embed, pos_embed, head_outputs], dim=0)

In [None]:
# Query weights of the head selected by (layer_idx, head_idx).
W_Q = model.W_Q[layer_idx, head_idx]
# TODO: the einsum pattern (last argument) is still an empty placeholder.
query_components = einsum(input_components, W_Q, "")

In [None]:
# Key weights of the head selected by (layer_idx, head_idx).
W_K = model.W_K[layer_idx, head_idx]
# Keys are projected with W_K; the earlier W_Q here was a copy-paste slip
# that left W_K unused.
# TODO: the einsum pattern (last argument) is still an empty placeholder.
key_components = einsum(input_components, W_K, "")

In [None]:
# Combine every query component with every key component.
# TODO: the einsum pattern (last argument) is still an empty placeholder.
decomposed_scores = einsum(
    query_components, key_components,
    ""
)

In [None]:
# Hook names in this notebook follow the "hook_<name>" form (hook_embed,
# hook_post), so the attention pattern is stored under "hook_pattern".
cache["blocks.1.attn.hook_pattern"]

In [22]:
def split_model(model, balances, devices):
    """Split the direct children of ``model`` into consecutive partitions.

    Children are taken in ``named_children()`` order. The i-th partition
    receives ``balances[i]`` children, keeps their original names, and is
    moved to ``devices[i]``.

    Args:
        model: module whose direct children are partitioned.
        balances: positive number of children for each partition; must sum
            to the number of direct children of ``model``.
        devices: one device per partition (anything ``Module.to`` accepts).

    Returns:
        List of ``nn.Sequential`` partitions, one per entry in ``balances``.

    Raises:
        ValueError: if ``balances`` contains a non-positive entry, does not
            sum to the number of children, or ``devices`` is too short.
    """
    children = list(model.named_children())

    if any(balance <= 0 for balance in balances):
        raise ValueError("every entry of balances must be positive")
    if sum(balances) != len(children):
        raise ValueError(
            "balances must sum to the number of child layers "
            f"({sum(balances)} != {len(children)})"
        )
    if len(devices) < len(balances):
        raise ValueError("devices must provide one device per partition")

    partitions = []
    layers = OrderedDict()
    partition_idx = 0

    for name, layer in children:
        layers[name] = layer

        if len(layers) == balances[partition_idx]:
            # nn.Sequential only accepts named layers as an OrderedDict.
            partition = nn.Sequential(layers)
            partition.to(devices[partition_idx])
            partitions.append(partition)

            # Start collecting the next partition.
            layers = OrderedDict()
            partition_idx += 1
    return partitions

In [None]:
class ShortCutProjection(nn.Module):
    """Module holding a linear projection followed by a ReLU in ``net``.

    ``stride`` is accepted but not used, and no ``forward`` method is
    defined on this class.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        projection = nn.Linear(in_channels, out_channels)
        activation = nn.ReLU()
        self.net = nn.Sequential(projection, activation)

In [None]:
# Tokenize both prompts with the same model. `clean_prompt` and
# `corrupted_prompt` are not defined in the visible cells.
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

The neuron matrix, $W_{\text{neur}}$, represents the weights between the attention layer and the MLP (ReLU) layer in our transformer model. It tells us how to get from a weighted sum of initial embeddings (the output of the attention layer) to the inputs to the ReLU function.

To calculate $W_{\text{neur}}$, we simply multiply the relevant weight matrices:
$W_{\text{neur}} = W_E W_V W_O W_{\text{in}}$

Where:

- $W_E$ is the embedding matrix
- $W_V$ and $W_O$ are the value and output matrices for each attention head
- $W_{\text{in}}$ is the input matrix for the MLP

So if we had:

- $W_E \in \mathbb{R}^{100 \times 512}$ (embedding dim = 512)
- 4 attention heads
- For each head, $W_V \in \mathbb{R}^{512 \times 64}$ and $W_O \in \mathbb{R}^{64 \times 512}$
- $W_{\text{in}} \in \mathbb{R}^{512 \times 2048}$ (MLP dim = 2048)

Then $W_{\text{neur}}$ would be:

$W_{\text{neur}} \in \mathbb{R}^{4 \times 100 \times 2048}$

And we could calculate it as:

In Python (PyTorch):
W_E = torch.rand(100, 512)    # Embedding matrix
W_V = torch.rand(512, 64)     # Value matrix for one head
W_O = torch.rand(64, 512)     # Output matrix for one head (head dim back to embedding dim)

W_in = torch.rand(512, 2048)  # Input matrix for MLP (embedding dim to MLP dim)

# Stack value/output matrices for 4 heads
W_V = torch.stack([W_V]*4)
W_O = torch.stack([W_O]*4)

# The matmul chain broadcasts over the leading head dimension, giving a
# result of shape (4, 100, 2048).
W_neur = W_E @ W_V @ W_O @ W_in
So $W_{\text{neur}}$ tells us how to transform the output of the attention layer (a weighted sum of 4 embedding vectors) into the 2048-dimensional input for the ReLU activation in the MLP.

Does this help explain the neuron matrix? Let me know if you have any other questions!