### Engineering

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
step 1: split a micro-batch
step 2: create cuda stream
step 3: pipe
step 4: gather the output

In [2]:
import torch.distributed as dist

In [5]:
class Reduce(torch.autograd.Function):
    """Sum a tensor across the tensor-parallel process group.

    Forward all-reduces ``input`` over the ``ParallelMode.TENSOR`` group.
    Backward returns the incoming gradient unchanged (identity).
    """

    @staticmethod
    def forward(ctx, input, parallel_context):
        """All-reduce ``input`` in place and return it.

        Args:
            input: this rank's partial tensor; it is overwritten with the sum.
            parallel_context: project object exposing ``get_group(mode)``.

        Returns:
            The reduced tensor (the same object as ``input``).
        """
        group = parallel_context.get_group(ParallelMode.TENSOR)
        # dist.all_reduce operates in place and returns None (or a work handle
        # when async_op=True), so return the tensor itself, not the call result.
        dist.all_reduce(input, group=group)
        return input

    @staticmethod
    def backward(ctx, grad_input):
        # One gradient per forward() input: identity for the tensor, None for
        # parallel_context, which is not a tensor.
        return (grad_input, None)

In [6]:
class ParallelEmbedding(nn.Module):
    """Vocabulary-parallel embedding table for tensor parallelism.

    Each rank of the ``ParallelMode.TENSOR`` group owns a contiguous slice of
    ``num_embeddings // world_size`` rows. Token ids outside the local slice
    yield zero vectors locally. The partial lookups are then summed across
    ranks with ``Reduce``.

    Args:
        num_embeddings: full vocabulary size; must be divisible by the
            tensor-parallel world size.
        embedding_dim: size of each embedding vector.
        parallel_context: project object exposing ``get_world_size``,
            ``get_local_rank`` and ``get_group``.

    Raises:
        ValueError: if ``num_embeddings`` is not divisible by the world size
            (otherwise trailing vocabulary rows would silently be dropped).
    """

    def __init__(self, num_embeddings, embedding_dim, parallel_context):
        super().__init__()
        # Kept so forward() can reach the tensor-parallel group.
        self.parallel_context = parallel_context
        world_size = parallel_context.get_world_size(ParallelMode.TENSOR)

        if num_embeddings % world_size != 0:
            raise ValueError(
                f"num_embeddings ({num_embeddings}) must be divisible by the "
                f"tensor-parallel world size ({world_size})"
            )
        num_embeddings_per_partition = num_embeddings // world_size
        self.vocab_start_idx, self.vocab_end_idx = self._get_vocab_range(
            num_embeddings_per_partition,
            parallel_context
        )
        self.weight = nn.Parameter(torch.randn(
            num_embeddings_per_partition,
            embedding_dim
        ))

    def _get_vocab_range(self, num_embeddings_per_partition, parallel_context):
        """Return the half-open [start, end) id range owned by this rank."""
        rank = parallel_context.get_local_rank(ParallelMode.TENSOR)
        start_idx = rank * num_embeddings_per_partition
        end_idx = start_idx + num_embeddings_per_partition
        return start_idx, end_idx

    def forward(self, input):
        """Look up ``input`` token ids and return their full embeddings."""
        # True where the token id belongs to another rank's slice.
        input_mask = (input < self.vocab_start_idx) | (input >= self.vocab_end_idx)
        # Shift ids into local row coordinates. Foreign ids point at row 0 as a
        # placeholder and are zeroed after the lookup.
        masked_input = input - self.vocab_start_idx
        masked_input[input_mask] = 0

        parallel_embedding = F.embedding(masked_input, self.weight)
        parallel_embedding[input_mask, :] = 0.

        # Sum the partial lookups so every rank holds the full embedding.
        return Reduce.apply(parallel_embedding, self.parallel_context)

In [8]:
class Copy(torch.autograd.Function):
    """Copy a tensor to another device on dedicated CUDA streams (GPipe-style).

    Forward moves ``input`` to ``next_stream.device``. Backward moves the
    incoming gradient back to ``prev_stream.device``. ``record_stream`` marks
    every stream that uses a tensor, so the CUDA caching allocator does not
    recycle its memory while queued work on that stream may still read it.
    """

    @staticmethod
    def forward(ctx, prev_stream, next_stream, input):
        """
        Args:
            prev_stream: CUDA stream on the source device.
            next_stream: CUDA stream on the destination device.
            input: tensor to copy.

        Returns:
            ``input`` copied to ``next_stream.device``.
        """
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream

        # Stream assumed to run subsequent compute on the destination device.
        compute_stream = torch.cuda.default_stream(next_stream.device)

        with torch.cuda.stream(prev_stream), torch.cuda.stream(next_stream):
            moved_input = input.to(next_stream.device)
            input.record_stream(prev_stream)
            # The copy is allocated while next_stream is current. Its later use
            # on the compute stream must be recorded explicitly.
            moved_input.record_stream(compute_stream)

        return moved_input

    @staticmethod
    def backward(ctx, grad_output):
        prev_stream = ctx.prev_stream
        next_stream = ctx.next_stream

        # Stream assumed to consume the gradient on the source device.
        compute_stream = torch.cuda.default_stream(prev_stream.device)

        with torch.cuda.stream(prev_stream), torch.cuda.stream(next_stream):
            moved_grad = grad_output.to(prev_stream.device)
            grad_output.record_stream(next_stream)
            moved_grad.record_stream(compute_stream)

        # forward() took (prev_stream, next_stream, input): autograd needs one
        # gradient per argument, with None for the two non-tensor streams.
        return None, None, moved_grad

In [None]:
_, cache = model.run_with_cache(clean_tokens)

In [9]:
layer_idx, head_idx = 0, 2

In [None]:
head_output = cache["z", layer_idx][:, head_idx]

In [None]:
output = head_output @ cache.W_O[layer_idx, head_idx]

In [None]:
handles = []
for hook_func in hooks:
    handles.append(model.ln_f.register_forward_pre_hook(hook_func))

In [None]:
# Tokenize `text` and cache every activation for the neuron analysis below.
tokens = model.to_tokens(text)

In [None]:
_, cache = model.run_with_cache(tokens)

In [10]:
layer_idx, neuron_idx = 3, 69

In [11]:
# Full hook name of the MLP post-activation for this layer.
hook_name = f"blocks.{layer_idx}.mlp.hook_post"

In [12]:
hook_name

'blocks.3.mlp.hook_post'

In [None]:
neurons = cache[hook_name]

In [None]:
# argmax over the last axis: assuming hook_post is [batch, pos, d_mlp], this
# gives the most active neuron at each (batch, position).
# NOTE(review): neuron_idx is unused here. If the goal is the position where
# neuron `neuron_idx` fires most, use `neurons[..., neuron_idx].argmax(dim=-1)`.
idx = torch.argmax(neurons, dim=-1)

In [None]:
dim, shape, dtype, requires_grad, type

In [None]:
# Layer norm over a trailing dimension of size 5, with an elementwise weight.
# NOTE(review): `input` and `weight` are not defined in the visible cells, so
# `input` resolves to Python's built-in function and this call fails. Bind a
# tensor whose last dim is 5 (and a weight of shape (5,)) before running.
F.layer_norm(input, normalized_shape=(5,), weight=weight)

In [None]:
process,
node enter

In [None]:
step 1: mask targets
step 2: compute predicted logits
step 3: sum predicted logits (all-reduce across ranks)
step 4: calculate log(sum(exp(logits)))
step 5: calculate the loss

In [None]:
# Cache all activations for the clean and corrupted prompts; the patching
# hook below copies activations between these two caches.
_, clean_cache = model.run_with_cache(clean_tokens)
_, corrupted_cache = model.run_with_cache(corrupted_tokens)

In [13]:
def patch_head(acts, hook, clean_cache, corrupted_cache, target_head):
    """Forward hook that overwrites one attention head with corrupted activations.

    Intended for TransformerLens ``hook_z`` hook points, whose activations are
    ``[batch, pos, head_index, d_head]``.

    Args:
        acts: activation tensor at the hook point.
        hook: the hook point; ``hook.layer()`` gives its layer index and
            ``hook.name`` its cache key.
        clean_cache: activations from the clean run, keyed by hook name.
        corrupted_cache: activations from the corrupted run, keyed by hook name.
        target_head: ``(layer_idx, head_idx)`` of the head to patch.

    Returns:
        At the target layer, ``acts`` with the target head's slice replaced
        in place by the corrupted run's values. At any other layer, the clean
        run's activations.
    """
    target_layer_idx, target_head_idx = target_head

    if target_layer_idx == hook.layer():
        # Head axis is dim 2; `[:, idx]` would select a sequence position.
        acts[:, :, target_head_idx] = corrupted_cache[hook.name][:, :, target_head_idx]
    else:
        acts = clean_cache[hook.name]
    return acts

In [14]:
from itertools import product

In [None]:
# Materialize as a list: `product` returns a one-shot iterator, so any second
# pass over it (e.g. re-running a loop cell) would silently see nothing.
combinations = list(product(range(n_layers), range(n_heads)))

In [16]:
from functools import partial
from transformer_lens.utils import get_act_name

In [None]:
def resid2logits(cache):
    """Recompute the final-position logits from a cached residual stream.

    Applies the model's final LayerNorm and unembedding to the last layer's
    ``resid_post`` at the last sequence position. Uses the notebook global
    ``model``.

    Args:
        cache: activation cache produced by ``model.run_with_cache``.

    Returns:
        Logits for the last position, assuming ``resid_post`` is
        ``[batch, pos, d_model]`` (one row per batch element).
    """
    # get_act_name formats the literal key "blocks.{layer}.hook_resid_post"
    # and does not resolve negative layer indices, so name the last layer.
    last_layer = model.cfg.n_layers - 1
    hook_name = get_act_name("resid_post", last_layer)
    resid = cache[hook_name][:, -1, :]
    return model.unembed(model.ln_final(resid))

In [None]:
# results[l, h]: IOI metric when head (l, h) is patched with corrupted z.
results = torch.zeros(n_layers, n_heads)

# Build a fresh product each run so re-executing this cell is idempotent
# (a previously consumed `product` iterator would yield nothing).
for layer_idx, head_idx in product(range(n_layers), range(n_heads)):
    hook_name = get_act_name("z", layer_idx)
    hook_func = partial(
        patch_head,
        clean_cache=clean_cache,
        corrupted_cache=corrupted_cache,
        target_head=(layer_idx, head_idx)
    )

    # run_with_cache forwards extra kwargs to model.forward, which has no
    # `fwd_hooks` argument; attach the patching hook with model.hooks().
    with model.hooks(fwd_hooks=[(hook_name, hook_func)]):
        _, cache = model.run_with_cache(clean_tokens)
    patched_logits = resid2logits(cache)
    results[layer_idx, head_idx] = compute_ioi_metric(patched_logits)

In [None]:
# Cache all activations for the inputs in `board_history`.
_, cache = model.run_with_cache(board_history)

In [17]:
# NOTE(review): 1393 is far larger than a typical head count; this looks like
# an MLP neuron index stored under the name `head_idx`. Confirm.
layer_idx, head_idx = 5, 1393

In [None]:
# NOTE(review): "out" goes through TransformerLens's act-name shorthand;
# confirm it resolves to the intended MLP hook (e.g. "post" or "mlp_out").
mlp_neurons = cache["out", layer_idx]

In [None]:
rank*partition_size

In [None]:
shape, dim, requires_grad, dtype

In [None]:
clock cycle 1: B(m, n)
clock cycle 2: B(m-1, n), B(m, n-1)
clock cycle 3: B(m-2, n), B(m-1, n-1), B(m, n-2)

In [None]:
# Iterating a dict yields only its keys; unpack (key, value) pairs via .items().
{x: y**2 for x, y in simple_dict.items()}

In [None]:
shape, dtype, requires_grad, dim

In [None]:
_, cache = model.run_with_cache(board_history)

In [None]:
# NOTE(review): `hook_name` was set in an earlier cell to a full hook name
# ("blocks.3.mlp.hook_post"). Pairing a full name with a layer index in a
# tuple key likely does not resolve to the intended activation; confirm
# (e.g. cache["post", layer_idx]). Also, `[:, head_idx]` indexes dim 1.
mlp_neurons = cache[hook_name, layer_idx][:, head_idx]

In [None]:
# Threshold at the 99th percentile: marks the top 1% of activation values.
thre = mlp_neurons.quantile(0.99)
top_neurons = mlp_neurons > thre

In [None]:
# Fraction of entries equal to 1 at the top-activating locations, averaged
# over the last axis. NOTE(review): the meaning of value 1 in board_states is
# not established in this notebook.
(board_states == 1)[top_neurons].float().mean(dim=-1)

In [None]:
# Token embeddings, the Q/K weights of head [1, 4], and the V/O weights of
# head [0, 7] (presumably [layer, head] indexing, as in TransformerLens).
W_E = model.W_E
W_Q = model.W_Q[1, 4]
W_K = model.W_K[1, 4]
W_O = model.W_O[0, 7]
W_V = model.W_V[0, 7]

In [None]:
# Queries read token embeddings directly. Keys read the embeddings after they
# pass through head [0, 7]'s OV circuit (W_V @ W_O), i.e. K-composition.
Q = W_E @ W_Q
K = W_E @ W_V @ W_O @ W_K

In [None]:
tokens = model.to_tokens(text)
_, cache = model.run_with_cache(tokens)

In [None]:
attn_prob = cache["attn", 9][:, 9][:, -1, 0]

In [None]:
W_pos = model.W_pos

In [None]:
torch.cosine_similarity(W_pos[:, 0], W_pos[:, 1])

In [None]:
step 1: convert input tokens to fourier basis
step 2: trig identities -> addition
step 3: convert the result

In [None]:
W_O = model.W_Q[0, 1]
W_V = model.W_V[0, 1]
W_Q = model.W_Q[1, 2]
W_K = model.W_K[1, 2]

In [None]:
W_OV = W_O @ W_V
W_QK = W_Q @ W_K.T

In [None]:
# One zero-initialised [n_heads, n_heads] matrix per composition type
# (Q, K, V); presumably indexed [earlier-layer head, later-layer head].
scores = {
    "Q": torch.zeros(n_heads, n_heads),
    "K": torch.zeros(n_heads, n_heads),
    "V": torch.zeros(n_heads, n_heads)
}

In [18]:
def compute_composition_score(W_A, W_B):
    """Composition score between two circuits (Elhage et al., 2021).

        score = ||W_A @ W_B||_F / (||W_A||_F * ||W_B||_F)

    Leading batch dimensions broadcast, so stacks of matrices give one score
    per matrix pair.

    Args:
        W_A: earlier circuit, shape [..., m, k].
        W_B: later circuit, shape [..., k, n].

    Returns:
        Tensor of scores with the broadcast batch shape (0-d for plain matrices).
    """
    # Frobenius norms, not squared sums: the squared form yields score**2.
    W_AB_norm = torch.linalg.matrix_norm(W_A @ W_B)
    W_A_norm = torch.linalg.matrix_norm(W_A)
    W_B_norm = torch.linalg.matrix_norm(W_B)
    return W_AB_norm / (W_A_norm * W_B_norm)

In [19]:
from einops import rearrange

In [None]:
# Stacked weights for every layer and head (TransformerLens layout
# [n_layers, n_heads, ...]).
W_O = model.W_O
W_V = model.W_V
W_Q = model.W_Q
W_K = model.W_K

# Full OV and QK circuits for every head: [..., d_model, d_model].
W_OV = W_V @ W_O
# W_K is [..., d_model, d_head]; transpose its last two axes to form Q @ K^T.
W_QK = W_Q @ rearrange(
    W_K, "... d_model d_head -> ... d_head d_model"
)

In [None]:
# Entry [i, j] of each score matrix: layer-0 head i composing into layer-1
# head j. Write element-wise into the preallocated matrices; assigning to
# scores["Q"] itself would replace the whole matrix with one scalar.
# no_grad: the model weights require grad, but only the values are needed.
with torch.no_grad():
    for i in range(n_heads):
        for j in range(n_heads):
            scores["Q"][i, j] = compute_composition_score(
                W_OV[0, i],
                W_QK[1, j]
            )
            scores["K"][i, j] = compute_composition_score(W_OV[0, i], W_QK[1, j].T)
            scores["V"][i, j] = compute_composition_score(W_OV[0, i], W_OV[1, j])

In [None]:
W_E @ W_OV^{0, 7} @ W_QK^{1, 4} @ W_E^T

In [None]:
_, cache = model.run_with_cache(past_moves)

In [None]:
# Rank neurons by the std of their activations over dims 0 and 1 (presumably
# batch and position), highest first. Uses the global `hook_name`.
top_neurons = cache[hook_name].std(dim=[0, 1]).argsort(descending=True)

In [20]:
from einops import einsum

In [21]:
def compute_consine_similarity(neuron_idx, feature, W_out=None):
    """Cosine similarity between an MLP neuron's output weights and a direction.

    Args:
        neuron_idx: index of the neuron within layer ``layer_idx`` (notebook
            global). Ignored when ``W_out`` is given.
        feature: direction vector in the neuron's output space. Not modified.
        W_out: optional output-weight vector of the neuron. Defaults to
            ``model.W_out[layer_idx, neuron_idx]``, read from the notebook
            globals ``model`` and ``layer_idx``.

    Returns:
        Cosine similarity, reduced over the last axis (0-d for 1-D inputs).
    """
    if W_out is None:
        W_out = model.W_out[layer_idx, neuron_idx]
    # Normalise out of place: `W_out /= ...` would rescale the model's weights
    # through the indexing view, and `feature /= ...` would mutate the caller's
    # tensor on every call.
    W_out = W_out / W_out.norm(dim=-1, keepdim=True)
    feature = feature / feature.norm(dim=-1, keepdim=True)
    return (W_out * feature).sum(dim=-1)

In [22]:
heatmap_blank = []

In [None]:
for neuron_idx in top_neurons:
    heatmap_blank.append(compute_consine_similarity(
        neuron_idx, blank_dir
    ))

In [None]:
W_O = model.W_O[0, 1]
W_V = model.W_V[0, 1]
W_Q = model.W_Q[1, 2]
W_K = model.W_K[1, 2]

W_OV = W_V @ W_O
# Transpose W_K ([d_model, d_head]). On a 2-D tensor, permute(-2, -1) is the
# identity permutation and would leave W_K untransposed.
W_QK = W_Q @ W_K.T

In [None]:
# Pairwise similarities within each modality.
text_similarities = text_embedding @ text_embedding.T
image_similarities = image_embedding @ image_embedding.T

# Soft targets: row-wise softmax of the averaged similarity matrices,
# scaled by the temperature.
targets = F.softmax(
    (text_similarities+image_similarities)/(2*temperature),
    dim=-1
)

In [None]:
class BatchNorm(nn.Module):
    """Batch normalization over dim 0 with running statistics.

    In training mode, normalizes with the current batch's mean and variance
    (computed without gradient) and moves the running statistics toward them.
    In eval mode, normalizes with the running statistics.

    Args:
        mom: interpolation weight of the new batch statistics when updating
            the running mean/variance (``running.lerp_(batch, mom)``).
        eps: added to the variance for numerical stability.
        num_features: size of the feature axis of a ``(batch, num_features)``
            input. The default of 1 also accepts ``(batch,)`` inputs.
    """

    def __init__(self, mom, eps, num_features=1):
        super().__init__()
        self.mom, self.eps = mom, eps
        self.adds = nn.Parameter(torch.zeros(num_features))
        self.mults = nn.Parameter(torch.ones(num_features))
        self.register_buffer("mean", torch.zeros(num_features))
        self.register_buffer("var", torch.ones(num_features))

    def update_stats(self, x):
        """Return the batch mean/var (keepdim over dim 0) and update running stats."""
        mean = x.mean(dim=0, keepdim=True)
        var = x.var(dim=0, keepdim=True)

        # Drop the kept batch axis so the in-place lerp_ matches the buffer shape.
        self.mean.lerp_(mean.squeeze(0), self.mom)
        self.var.lerp_(var.squeeze(0), self.mom)
        return mean, var

    def forward(self, x):
        if self.training:
            with torch.no_grad():
                mean, var = self.update_stats(x)
        else:
            mean, var = self.mean, self.var

        # Normalize by the standard deviation, not the variance.
        x = (x - mean) / (var + self.eps).sqrt()
        return self.adds + self.mults * x

In [None]:
# Token embeddings and the V/O weights of head [layer_idx, head_idx].
W_E = model.W_E
W_O = model.W_O[layer_idx, head_idx]
W_V = model.W_V[layer_idx, head_idx]

In [23]:
from transformer_lens import FactoredMatrix

In [None]:
# The OV circuit W_V @ W_O, represented through its two factors.
OV_circuit = FactoredMatrix(W_V, W_O)

In [None]:
# OV circuit expressed in token space: embeddings in, embeddings out
# (presumably [d_vocab, d_vocab]).
full_OV_circuit = W_E @ OV_circuit @ W_E.T

In [None]:
class ResidualLayerNorm(nn.Module):
    """Post-norm residual connection: ``LayerNorm(dropout(x) + residual)``.

    Args:
        d_model: feature size normalized by the LayerNorm.
        dropout: dropout probability applied to the sublayer output ``x``.
    """

    def __init__(self, d_model, dropout):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, residual):
        """Add the dropped-out sublayer output to ``residual`` and normalize."""
        return self.layer_norm(self.dropout(x) + residual)

In [None]:
step 1: determine global rank
step 2: initialize global dist group
step 3: initialize parallel groups
step 4: set device
step 5: set seed

In [None]:
step 1: mask targets
step 2: compute predicted logits
step 3: all-reduce
step 4: calculate log(sum(exp(logits)))
step 5: calculate the loss 

In [24]:
d_model = 16
d_head = 4

In [28]:
W_V = torch.zeros(d_head, d_model)
W_V[torch.arange(0, 4), torch.arange(0, 4)] = 1.

In [29]:
W_V

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
W_O = torch.zeros(d_model, d_head)
W_O[8:11, :] = torch.eye(3)

In [None]:
# Blocking RPC: RRef.rpc_sync() returns a proxy whose method calls run
# synchronously on the owner of the referenced object.
agent_rref.rpc_sync().init("hello")

In [31]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention around a pluggable attention function.

    Args:
        attention: callable ``attention(q, k, v) -> (output, weights)`` that
            takes tensors shaped ``(batch, n_heads, seq_len, d_head)`` and
            returns ``output`` in the same layout.
        d_model: model width; must be divisible by ``n_heads``.
        n_heads: number of attention heads.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, attention, d_model, n_heads):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.attention = attention
        self.d_head = d_model // n_heads
        self.d_model = d_model
        self.n_heads = n_heads

        self.to_q = nn.Linear(d_model, d_model)
        self.to_k = nn.Linear(d_model, d_model)
        self.to_v = nn.Linear(d_model, d_model)
        self.projection = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """(batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_head)."""
        batch_size, seq_len, _ = x.shape
        # Factor d_model as (n_heads, d_head), then move heads in front of
        # seq_len so the attention can matmul over the last two axes.
        return x.reshape(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)

    def concat_heads(self, x):
        """(batch, n_heads, seq_len, d_head) -> (batch, seq_len, d_model)."""
        batch_size, _, seq_len, _ = x.shape
        return x.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)

    def forward(self, pre_q, pre_k, pre_v):
        """Project inputs, attend per head, merge heads and project the result."""
        q = self.split_heads(self.to_q(pre_q))
        k = self.split_heads(self.to_k(pre_k))
        v = self.split_heads(self.to_v(pre_v))

        attn_output, _attn_weights = self.attention(q, k, v)
        output = self.concat_heads(attn_output)
        return self.projection(output)

In [None]:
def discount_reward(rewards, discount_factor):
    """Discounted return-to-go for each step of an episode.

    ``G_t = r_t + discount_factor * G_{t+1}``, accumulated backwards from the
    final reward.

    Args:
        rewards: sequence of per-step rewards in time order.
        discount_factor: weight on future rewards (typically in [0, 1]).

    Returns:
        List of returns aligned with ``rewards`` (empty for empty input).
    """
    discount_rewards = []
    running_return = 0.0

    # Walk backwards so each step reuses the return of the step after it.
    for reward in reversed(rewards):
        running_return = reward + discount_factor * running_return
        discount_rewards.append(running_return)

    discount_rewards.reverse()
    return discount_rewards

First, we write out the MLP function, $\operatorname{ReLU}\left(w^T W_{\text{in}}\right) W_{\text{out}}$. Next, we look at the attention contribution, which before the ReLU is $\sum_h\left(\alpha^h t_0+\left(1-\alpha^h\right) t_1\right)^T W_{\text{neur}}^h$. Here, $W_{\text{neur}}^h$ is the matrix that maps a weighted average of the initial embeddings to our neuron activations.

In [32]:
import torch

In [39]:
# Random 2x4 tensor used by the shape experiments below.
x = torch.randn(2, 4)

In [40]:
# Shape as a plain Python list.
list(x.size())

[2, 4]

In [38]:
# Number of dimensions (equivalent to x.dim()).
len(x.size())

2

In [41]:
# Zeros with x's shape. Unlike torch.zeros_like(x), this uses the default
# dtype and device rather than x's.
torch.zeros(size=list(x.size()))

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]])