### Engineering

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
import copy

In [3]:
class ParameterSharding:
    """Partition optimizer param groups across ranks, balancing element counts.

    Each parameter is assigned greedily to the rank that currently owns the
    fewest elements. Every rank receives a shallow copy of every param group
    (so hyper-parameters such as ``lr`` are preserved), with ``"params"``
    restricted to the parameters that rank owns.
    """

    def __init__(self, param_groups, parallel_context):
        # param_groups: list of dicts with a "params" list (optimizer format).
        # parallel_context: any object exposing get_world_size().
        self.param_groups = param_groups
        self.parallel_context = parallel_context

    def shard(self):
        """Assign every parameter to a rank.

        Returns:
            A list with one entry per rank. Entry ``r`` is a list of param
            groups (one per input group, in order) whose ``"params"`` hold
            only the parameters assigned to rank ``r``. The input groups are
            not modified.
        """
        world_size = self.parallel_context.get_world_size()
        partition_params = [[] for _ in range(world_size)]
        # Running total of elements owned by each rank, across all groups.
        sizes = [0] * world_size

        for param_group in self.param_groups:
            params_per_rank = [[] for _ in range(world_size)]

            # Greedy balancing: give each param to the least-loaded rank.
            for p in param_group["params"]:
                next_rank = sizes.index(min(sizes))
                params_per_rank[next_rank].append(p)
                sizes[next_rank] += p.numel()

            # One shallow copy of the group per rank, so every rank keeps the
            # group's hyper-parameters; append so earlier groups are kept.
            for rank, rank_params in enumerate(params_per_rank):
                param_group_rank = copy.copy(param_group)
                param_group_rank["params"] = rank_params
                partition_params[rank].append(param_group_rank)

        return partition_params

In [None]:
# Scratch: the rank before local_rank in a ring (rank 0 wraps to the last rank).
# NOTE(review): `world_Size` casing looks like a typo for `world_size`; neither
# name is defined in the visible notebook.
(local_rank-1)%world_Size

In [None]:
deployments, services, secrets

In [None]:
kube-proxy, kubelet, container runtime

In [4]:
import torch.distributed as dist

In [None]:
# Destroy every process group in `groups` (not defined in the visible notebook),
# calling a barrier before each destruction.
# NOTE(review): dist.barrier() with no group argument synchronises the default
# group; confirm that is the intended sync point before destroying each subgroup.
for group in groups:
    dist.barrier()
    dist.destroy_process_group(group)

In [9]:
class Scatter(torch.autograd.Function):
    """Keep this rank's slice of the input's last dimension.

    Forward splits ``input`` into ``world_size`` equal slices along dim -1 and
    returns slice ``rank``. Backward all-gathers every rank's gradient slice and
    concatenates them along dim -1, so the gradient has ``input``'s full shape.
    """

    @staticmethod
    def _local_chunk(input, world_size, rank):
        """Return slice ``rank`` of ``world_size`` equal slices of ``input`` along dim -1.

        Raises:
            ValueError: if the last dimension is not divisible by ``world_size``
                (an uneven split would silently drop the remainder columns).
        """
        last_dim = input.shape[-1]
        if last_dim % world_size != 0:
            raise ValueError(
                f"last dimension ({last_dim}) is not divisible by world size ({world_size})"
            )
        chunk_size = last_dim // world_size
        return input.narrow(-1, rank * chunk_size, chunk_size)

    @staticmethod
    def forward(ctx, input):
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        # The chunk size is derived from the last dimension, so the split must
        # be along the last dimension too (torch.split defaults to dim 0).
        return Scatter._local_chunk(input, world_size, rank)

    @staticmethod
    def backward(ctx, grad_input):
        # `grad_input` is the gradient w.r.t. this rank's output slice.
        world_size = dist.get_world_size()
        # Collectives require contiguous tensors; the slice gradient may not be.
        grad_input = grad_input.contiguous()
        grads = [torch.empty_like(grad_input) for _ in range(world_size)]
        dist.all_gather(grads, grad_input)
        # Reassemble along the same dimension the forward pass split on.
        return torch.cat(grads, dim=-1)

In [8]:
class Reduce(torch.autograd.Function):
    """Sum a tensor across all ranks in forward; pass gradients through unchanged.

    The backward is the identity because d(sum_r x_r)/dx_r = I: every rank's
    partial output receives the incoming gradient as-is.
    """

    @staticmethod
    def forward(ctx, input):
        # all_reduce works in place. Reduce a copy so the caller's tensor is left
        # unchanged; autograd does not allow a Function to modify its inputs in
        # place unless they are declared with ctx.mark_dirty.
        output = input.clone()
        dist.all_reduce(output)
        return output

    @staticmethod
    def backward(ctx, grad_input):
        # `grad_input` is the gradient w.r.t. the reduced output.
        return grad_input

In [5]:
class RowParallelLinear(nn.Module):
    """Linear layer whose weight is split along the input dimension across ranks.

    Each rank holds an (output_size, input_size / world_size) weight slice. The
    forward pass keeps this rank's slice of the input's last dim (Scatter),
    computes a partial product, and sums the partials across ranks (Reduce).
    The bias is added once, after the reduction.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        world_size = dist.get_world_size()
        if input_size % world_size != 0:
            raise ValueError(
                f"input_size ({input_size}) must be divisible by world size ({world_size})"
            )
        inp_per_partition = input_size // world_size

        self.weight = nn.Parameter(torch.randn(
            output_size,
            inp_per_partition
        ))
        # The bias is applied after the cross-rank sum, so it must be identical
        # on every rank. Independent torch.randn draws would differ per rank
        # unless all ranks share an RNG seed; zeros are deterministic everywhere.
        self.bias = nn.Parameter(torch.zeros(
            output_size
        ))

    def forward(self, input):
        input_parallel = Scatter.apply(input)
        output_parallel = F.linear(input_parallel, self.weight)
        outputs = Reduce.apply(output_parallel)
        return outputs + self.bias

In [None]:
# Scratch: the rank after local_rank in a ring of local_world_size ranks
# (the last rank wraps to 0).
(local_rank+1)%local_world_size

In [None]:
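step 1: keep an fp32 master copy of the weights and an fp16 copy for compute
step 2: run the forward and backward passes in fp16 (scale the loss so small fp16 gradients do not underflow)
step 3: cast the grads to fp32 (and unscale them)
step 4: update the fp32 master weights, then copy them back into the fp16 weights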

In [None]:
# Scratch: start offset of this rank's chunk when data is split into
# consecutive partitions of partition_size (presumably one per rank).
rank*partition_size

step 1: prepare a clean prompt and a corrupted prompt
step 2: record the intermediate activations of the clean prompt and of the corrupted prompt
step 3: run the clean prompt again
step 4: patch in the corrupted activations for the target component (e.g. 4.2) and all components in between; patch every other component from the clean run

In [10]:
from itertools import product

In [11]:
# Model size used for the (head, layer) grid below.
n_layers = 12
n_heads = 12

In [12]:
# Every (head_idx, layer_idx) pair. `product` returns a one-shot iterator, so it
# is exhausted once consumed (e.g. by the list() call below).
# NOTE(review): this orders pairs as (head, layer), the reverse of the
# (layer, head) order used for activation patching later in the notebook.
combinations = product(range(n_heads), range(n_layers))

In [14]:
# Materialise the pairs for display; this consumes `combinations`.
list(combinations)

[(0, 0),
 (0, 1),
 (0, 2),
 (0, 3),
 (0, 4),
 (0, 5),
 (0, 6),
 (0, 7),
 (0, 8),
 (0, 9),
 (0, 10),
 (0, 11),
 (1, 0),
 (1, 1),
 (1, 2),
 (1, 3),
 (1, 4),
 (1, 5),
 (1, 6),
 (1, 7),
 (1, 8),
 (1, 9),
 (1, 10),
 (1, 11),
 (2, 0),
 (2, 1),
 (2, 2),
 (2, 3),
 (2, 4),
 (2, 5),
 (2, 6),
 (2, 7),
 (2, 8),
 (2, 9),
 (2, 10),
 (2, 11),
 (3, 0),
 (3, 1),
 (3, 2),
 (3, 3),
 (3, 4),
 (3, 5),
 (3, 6),
 (3, 7),
 (3, 8),
 (3, 9),
 (3, 10),
 (3, 11),
 (4, 0),
 (4, 1),
 (4, 2),
 (4, 3),
 (4, 4),
 (4, 5),
 (4, 6),
 (4, 7),
 (4, 8),
 (4, 9),
 (4, 10),
 (4, 11),
 (5, 0),
 (5, 1),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 5),
 (5, 6),
 (5, 7),
 (5, 8),
 (5, 9),
 (5, 10),
 (5, 11),
 (6, 0),
 (6, 1),
 (6, 2),
 (6, 3),
 (6, 4),
 (6, 5),
 (6, 6),
 (6, 7),
 (6, 8),
 (6, 9),
 (6, 10),
 (6, 11),
 (7, 0),
 (7, 1),
 (7, 2),
 (7, 3),
 (7, 4),
 (7, 5),
 (7, 6),
 (7, 7),
 (7, 8),
 (7, 9),
 (7, 10),
 (7, 11),
 (8, 0),
 (8, 1),
 (8, 2),
 (8, 3),
 (8, 4),
 (8, 5),
 (8, 6),
 (8, 7),
 (8, 8),
 (8, 9),
 (8, 10),
 (8, 11),
 (9, 0),
 

In [15]:
# Toy sizes: d_head (head dimension) and d_model (residual width).
d_head, d_model = 4, 16

In [20]:
# Value matrix that reads residual dims 0..3 into the head's 4 dims:
# W_V[i, i] = 1 for i < 4, zeros elsewhere.
W_V = torch.zeros(d_head, d_model)
W_V[torch.arange(4), torch.arange(4)] = 1.

In [21]:
W_V

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [None]:
W_O = torch.zeros(d_model, d_head)
# Write the head's d_head output dims into residual dims 8 .. 8 + d_head - 1.
# The row slice must hold exactly d_head rows to receive torch.eye(d_head);
# a 3-row slice such as 8:11 cannot receive a 4x4 identity.
W_O[8:8 + d_head] = torch.eye(d_head)

In [None]:
# Keep the handle returned by each registration so a hook can be removed later
# with handle.remove(). Without the handles, re-running this cell stacks
# duplicate hooks on model.ln_f with no easy way to detach them.
forward_hook_handles = [model.ln_f.register_forward_hook(hook_func) for hook_func in hooks]

In [None]:
step 1: diverse distribution
step 2: record all the intermediate neuron activations
step 3: color
step 4: 

In [None]:
register > cache (SRAM) > main memory (DRAM) > disk > external disk

In [None]:
elasticdriver, torchstate, hostdiscovery, 3 notifs

In [None]:
ready, running, finished, failed, blacklisted, cool

In [None]:
step 1: determine the list of global ranks in that group
step 2: check whether the process's global rank is in that list
step 3: initialize the parallel group
step 4: compute the local rank
step 5: save

In [None]:
# Token embedding matrix of `model` (defined outside the visible notebook).
W_E = model.W_E

In [None]:
def compute_consine_similarity(x, y, eps=1e-8):
    """Cosine similarity between two tensors treated as flat vectors.

    Args:
        x, y: tensors with the same number of elements; both are flattened.
        eps: lower bound on the norm product, so a zero vector yields 0
            instead of NaN.

    Returns:
        0-dim tensor: <x, y> / (||x|| * ||y||).

    The inputs are not modified (out-of-place ops only).
    """
    x_flat = x.flatten()
    y_flat = y.flatten()
    denom = (x_flat.norm() * y_flat.norm()).clamp_min(eps)
    return (x_flat @ y_flat) / denom

In [None]:
step 1: split weight
step 2: determine vocab_start_idx, end_idx

In [23]:
from copy import deepcopy

In [None]:
class ModelStateHandler:
    """Snapshot a model's state dict and roll the model back to it later."""

    def __init__(self, model):
        # No snapshot until commit() is called.
        self._saved_state_dict = None
        self.set_value(model)

    def set_value(self, value):
        """Point the handler at `value`, the model to snapshot/restore."""
        self.model = value

    def commit(self):
        """Save a deep copy of the model's current state dict.

        state_dict() returns references to the live parameter/buffer tensors,
        so it is deep-copied; otherwise later in-place updates would also
        change the snapshot.
        """
        self._saved_state_dict = deepcopy(self.model.state_dict())

    def restore(self):
        """Load the last committed snapshot back into the model.

        Raises:
            RuntimeError: if commit() has not been called yet.
        """
        if self._saved_state_dict is None:
            raise RuntimeError("restore() called before commit()")
        self.model.load_state_dict(self._saved_state_dict)

In [None]:
step 1: initialize partitioned weight
step 2: determine 

In [None]:
lazy loading, data prefetch, memory mapping

In [None]:
step 1: b

In [None]:
reduce, scatter, gather, broadcast

In [None]:
clock cycle 1: F(0, 0)
clock cycle 2: F(1, 0), F(0, 1)
clock cycle 3: F(2, 0), F(1, 1)
clock cycle 4: F(2, 1)

In [None]:
step 1: set environment variables
step 2: init global distributed group
step 3: initialize parallel groups
step 4: set device
step 5: set seed

In [None]:
# Positional embedding matrix of `model`.
W_pos = model.W_pos

In [None]:
# W_pos[:, 0] and W_pos[:, 1] are 1-D when W_pos is 2-D (assumed (n_ctx, d_model)
# — TODO confirm). torch.cosine_similarity defaults to dim=1, which is out of
# range for 1-D inputs, so compare along dim 0.
torch.cosine_similarity(W_pos[:, 0], W_pos[:, 1], dim=0)

In [24]:
# Fresh all-zero value matrix of shape (d_head, d_model).
W_V = torch.zeros(d_head, d_model)

In [25]:
# Only the first 3 diagonal entries are set, so the head reads residual dims
# 0..2 and its last row stays all zeros (see the displayed output).
W_V[torch.arange(3), torch.arange(3)] = 1.

In [26]:
W_V

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [29]:
# Output matrix that writes the head's 4 dims into residual dims 7..10
# (rows 7:11 of W_O hold an identity block).
W_O = torch.zeros(d_model, d_head)
W_O[7:11, :] = torch.eye(4)

In [30]:
W_O

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [None]:
step 1: early heads detect S2 token. Write s2 is 
step 2: middle head 

In [None]:
# Cache activations for `board_history`; the first return value is discarded.
_, cache = model.run_with_cache(board_history)

In [None]:
# Index 1393 along axis 1 of the cached `hook_name` activation.
# NOTE(review): presumably MLP neuron 1393 — confirm the axis layout of hook_name.
mlp_neurons = cache[hook_name][:, 1393]

In [None]:
# 99th-percentile activation value.
# NOTE(review): `threashold` is a misspelling of `threshold` (also used in the next cell).
threashold = mlp_neurons.quantile(0.99)

In [None]:
# Boolean mask of entries above the 99th percentile (roughly the top 1%).
top_neurons = mlp_neurons > threashold

In [None]:
# Fraction of `board_states` entries equal to 2, restricted by the mask along axis 1.
# NOTE(review): confirm the mask's axis lines up with axis 1 of board_states.
(board_states == 2)[:, top_neurons].float().mean()

In [None]:
# Cached entry for key ("result", 1) — presumably per-head attention outputs of layer 1.
cache["result", 1]

In [None]:
# Register every function in `hooks` as a forward pre-hook on model.ln_f and keep
# the returned handles so individual hooks can be removed later.
handles = []
for hook_func in hooks:
    handles.append(model.ln_f.register_forward_pre_hook(hook_func))

In [None]:
# Detach only the second registered hook.
handles[1].remove()

In [None]:
step 1: q = x @ W_Q
step 2: x = embed + pos_embed + sum(12 heads)
step 3: q = [embed + pos_embed + sum(12 heads)] @ W_Q
step 4: q = embed @ W_Q + pos_embed @ W_Q + ...

In [None]:
# Scratch: attention pattern A applied to x, then projected through the OV circuit W_OV.
A @ x @ W_OV

In [None]:
# Tokenise `text` (defined outside the visible notebook).
tokens = model.to_tokens(text)

In [None]:
# Cache activations for `tokens`; the first return value is discarded.
_, cache = model.run_with_cache(tokens)

In [None]:
# Decompose the residual stream feeding the next layer into its parts: token
# embedding, positional embedding, and each individual head of the previous
# layer, stacked along a new leading "component" axis.
# torch.tensor cannot build a tensor from a list of tensors; torch.stack can.
# The per-head result is assumed (batch, pos, head_index, d_model) — unbind the
# head axis so every component is (batch, pos, d_model).
# NOTE(review): cache["result", ...] only exists when the model caches per-head
# results (TransformerLens `use_attn_result`) — confirm.
input_components = torch.stack([
    cache["embed"],
    cache["pos_embed"],
    *cache["result", prev_layer_idx].unbind(dim=-2),
])

In [32]:
from einops import einsum

In [None]:
# Project each residual component through one head's query weights, giving that
# component's contribution to the query. einops.einsum needs the pattern string
# as its last argument.
W_Q = model.W_Q[layer_idx, head_idx]
query_components = einsum(
    input_components,
    W_Q,
    "component batch pos d_model, d_model d_head -> component batch pos d_head",
)

In [None]:
# Squared L2 norm over the last axis: one magnitude per remaining index
# (component, and any batch/position axes).
query_contributions = query_components.pow(2).sum(dim=-1)

In [None]:
# Cache activations for `tokens` again; the first return value is discarded.
_, cache = model.run_with_cache(tokens)

In [None]:
target_tokens = tokens[1:]

W_U = model.W_U
W_U_target_tokens = W_U[:, target_tokens]

embed = cache["emed"][:-1]

In [None]:
embed_attributions = einsum(
    embed,
    W_U_target_tokens,
)

In [None]:
# NOTE(review): `{"q"}` is a set literal, not a dict — this cell looks unfinished;
# a complete `scores` dict of score matrices is built in a later cell.
scores = {
    "q"
}

In [None]:
# Per-layer, per-head attention weight tensors of `model`.
W_Q, W_K = model.W_Q, model.W_K
W_V, W_O = model.W_V, model.W_O

In [33]:
from einops import rearrange

In [None]:
# QK circuit W_Q @ W_K^T (the rearrange swaps W_K's last two axes) and OV circuit
# W_V @ W_O, computed for every layer and head at once via batched matmul.
W_QK = torch.matmul(
    W_Q,
    rearrange(W_K, "... d_model d_head -> ... d_head d_model")
)
W_OV = W_V @ W_O

In [None]:
def compute_composition_scores(W_A, W_b):
    """Composition score(s) ||W_A @ W_b||_F / (||W_A||_F * ||W_b||_F).

    Works on single matrices or on batches: Frobenius norms are taken over the
    last two axes, so leading axes (e.g. layer, head) broadcast.

    Returns:
        Tensor of scores with the broadcast leading shape (0-dim for plain matrices).
    """
    product_norm = torch.linalg.matrix_norm(W_A @ W_b)
    return product_norm / (torch.linalg.matrix_norm(W_A) * torch.linalg.matrix_norm(W_b))

In [None]:
step 1: weight
step 2: 

In [None]:
# One (n_heads x n_heads) score matrix per key "q", "k", "v", indexed [head i, head j];
# presumably Q-/K-/V-composition scores, meant to be filled by a later loop.
scores = {
    "q": torch.zeros(n_heads, n_heads),
    "k": torch.zeros(n_heads, n_heads),
    "v": torch.zeros(n_heads, n_heads)
}

In [34]:
from einops import rearrange

In [None]:
# Per-layer, per-head attention weight tensors of `model`.
W_Q, W_K = model.W_Q, model.W_K
W_O, W_V = model.W_O, model.W_V

In [None]:
# QK circuit W_Q @ W_K^T; the rearrange swaps W_K's last two axes.
W_QK = torch.matmul(
    W_Q,
    rearrange(W_K, "... d_model d_head -> ... d_head d_model")
)

In [None]:
# OV circuit W_V @ W_O for every layer and head.
W_OV = W_V @ W_O

In [None]:
def compute_composition_score(W_A, W_B):
    """Composition score ||W_A @ W_B||_F / (||W_A||_F * ||W_B||_F).

    The Frobenius norm is the square root of the sum of squared entries,
    taken over the whole tensor.
    """
    def frobenius(matrix):
        return matrix.pow(2).sum().sqrt()

    numerator = frobenius(W_A @ W_B)
    denominator = frobenius(W_A) * frobenius(W_B)
    return numerator / denominator

In [None]:
# Composition scores between every layer-0 head i (writing via its OV circuit)
# and every layer-1 head j (reading via its query, key or value).
# Store each score in its own [i, j] cell rather than overwriting the whole
# matrix, and use compute_composition_score, which is defined in the notebook.
for i in range(n_heads):
    for j in range(n_heads):
        scores["q"][i, j] = compute_composition_score(W_OV[0, i], W_QK[1, j])
        scores["k"][i, j] = compute_composition_score(W_OV[0, i], W_QK[1, j].T)
        scores["v"][i, j] = compute_composition_score(W_OV[0, i], W_OV[1, j])

In [None]:
step 1: head0 = L0H00(pre_resid)
step 2: mid_resid = pre_resid + head0
step 3: mlp0 = MLP0(mid_resid)
step 4: resid0 = mid_resid + mlp0
step 5: mlp1 = MLP1(resid0)
step 6: resid1 = resid0 + mlp1

In [None]:
step 1: diverse distribution
step 2: record the attention pattern of the target head
step 3: determine the query position
step 4: average
step 5: plot

In [None]:
# Scratch: attention pattern A applied to x, then two OV circuits in sequence.
A@x@W_OV@W_OV

In [36]:
# Corrupted IOI-style prompts: every name is replaced by a placeholder letter.
corrupted_prompt = [
    "When X and Y went to the shops, Z gave the bag to",
    "When K and V went to the park, H gave the ball to"
]

In [None]:
# NOTE(review): `clean_prompts` is not defined in the visible notebook.
clean_tokens = model.to_tokens(clean_prompts)
corrupted_tokens = model.to_tokens(corrupted_prompt)

In [None]:
# Answer tokens: correct names ("Mary Tom") vs incorrect names ("John James").
# NOTE(review): check whether to_tokens prepends a BOS token here — these ids are
# used to index the vocabulary axis of the logits.
correct_tokens = model.to_tokens("Mary Tom")
incorrect_tokens = model.to_tokens("John James")

In [None]:
# Logits for the clean and corrupted prompts (caches discarded).
clean_logits, _ = model.run_with_cache(clean_tokens)
corrupted_logits, _ = model.run_with_cache(corrupted_tokens)

In [37]:
def compute_avg_logit_diff(logits, correct_tokens, incorrect_tokens):
    """Mean logit difference between correct and incorrect answers at the final position.

    Each prompt is paired with its own answer: for prompt ``b`` the difference
    is ``logits[b, -1, correct[b]] - logits[b, -1, incorrect[b]]``, and the
    result is the mean over prompts. (Plain fancy indexing with an id vector
    would instead score every prompt against every answer.)

    Args:
        logits: indexed as ``logits[:, -1, :]`` — assumed (batch, pos, d_vocab).
        correct_tokens: one token id per prompt (``batch`` elements in any shape,
            e.g. ``(batch,)``), or a single id shared by every prompt.
        incorrect_tokens: same format as ``correct_tokens``.

    Returns:
        0-dim tensor with the mean logit difference.

    Raises:
        ValueError: if a token argument has neither 1 nor ``batch`` elements.
    """
    final_logits = logits[:, -1, :]
    batch = final_logits.shape[0]

    def _answer_logits(token_ids):
        # One id per prompt; a single id is broadcast to every prompt.
        ids = torch.as_tensor(token_ids, device=final_logits.device).reshape(-1).long()
        if ids.numel() == 1:
            ids = ids.expand(batch)
        elif ids.numel() != batch:
            raise ValueError(
                f"expected 1 or {batch} answer token ids, got {ids.numel()}"
            )
        # gather picks final_logits[b, ids[b]] for each prompt b.
        return final_logits.gather(-1, ids.unsqueeze(-1)).squeeze(-1)

    return (_answer_logits(correct_tokens) - _answer_logits(incorrect_tokens)).mean()

In [None]:
# Baseline logit differences: the clean run is the upper reference and the
# corrupted run the lower reference that compute_ioi_metric normalises between.
clean_logit_diff = compute_avg_logit_diff(clean_logits, correct_tokens, incorrect_tokens)
corrupted_logit_diff = compute_avg_logit_diff(corrupted_logits, correct_tokens, incorrect_tokens)

In [None]:
def compute_ioi_metric(patched_metric):
    """Normalised logit-difference recovery for a patched run.

    Args:
        patched_metric: logits from the patched forward pass, in the format
            expected by compute_avg_logit_diff.

    Returns:
        (patched - corrupted) / (clean - corrupted) logit difference: 0 when the
        patched run matches the corrupted baseline, 1 when it matches the clean one.
        Reads the globals correct_tokens, incorrect_tokens, clean_logit_diff and
        corrupted_logit_diff.
    """
    patched_logit_diff = compute_avg_logit_diff(patched_metric, correct_tokens, incorrect_tokens)
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

In [None]:
step 1: prob[0] = sigmoid(logit0 - logit1)
step 2: logit0 = resid2 @ W_U[0], logit1 = resid2 @ W_U[1]
step 3: logit0 - logit1 = resid2 @ (W_U[0] - W_U[1])
step 4: resid2 = resid1 @ ln1 @ W_OV^{2, 0}
step 5: target_direction = resid1 @ ln1 @ W_OV @ (W_U[0] - W_U[1])

In [None]:
# Tokenise `text` (defined outside the visible notebook).
tokens = model.to_tokens(text)

In [None]:
# Forward pass with no hooks supplied; the return value is used as logits below.
logits = model.run_with_hooks(tokens)

In [None]:
log_probs = F.log_softmax(logits[:, -1, :], dim=-1)

In [None]:
target_tokens = tokens[1:]

In [None]:
predicted_log_probs = -torch.gather(log_probs, dim=1, index=target_tokens)

In [None]:
# Scratch: first vocab id of this rank's slice when the vocabulary is split into
# consecutive partitions of partition_vocab_size (presumably one per rank).
local_rank * partition_vocab_size

In [None]:
step 1: f1
step 2: 

In [None]:
broadcast, gather

In [None]:
# Sub-diagonal (offset=-1) of the last two axes: entries [..., i+1, i].
# If attn_weights is (..., query, key), this is how much each position attends
# to the immediately preceding position — TODO confirm the axis layout.
weights = attn_weights.diagonal(dim1=-2, dim2=-1, offset=-1)

In [None]:
broadcast > gather > scatter > all-reduce

In [39]:
from einops import rearrange

In [None]:
# Per-layer, per-head attention weight tensors of `model`.
W_O = model.W_O
W_V = model.W_V
W_K = model.W_K
W_Q = model.W_Q


# QK circuit W_Q @ W_K^T (the rearrange swaps W_K's last two axes) and OV circuit
# W_V @ W_O, computed for every layer and head at once via batched matmul.
W_QK = torch.matmul(
    W_Q,
    rearrange(W_K, "... d_model d_head -> ... d_head d_model")
)
W_OV = W_V @ W_O

In [None]:
# One (n_heads x n_heads) matrix per composition type, indexed
# [layer-0 head i, layer-1 head j] by the loop that fills it.
scores = {
    "Q": torch.zeros(n_heads, n_heads),
    "K": torch.zeros(n_heads, n_heads),
    "V": torch.zeros(n_heads, n_heads)
}

In [38]:
def compute_composition_score(W_A, W_B):
    """Composition score ||W_A @ W_B||_F / (||W_A||_F * ||W_B||_F).

    Each Frobenius norm needs the square root of the summed squares. Without it
    the function returns the square of the score, which disagrees with the
    earlier definition of this same function in the notebook.
    """
    W_AB_norm = (W_A @ W_B).pow(2).sum().sqrt()
    W_A_norm = W_A.pow(2).sum().sqrt()
    W_B_norm = W_B.pow(2).sum().sqrt()
    return W_AB_norm / (W_A_norm * W_B_norm)

In [None]:
# Composition scores between every layer-0 head i (writing via its OV circuit)
# and every layer-1 head j (reading via its query, key or value).
for i in range(n_heads):
    for j in range(n_heads):
        scores["Q"][i][j] = compute_composition_score(W_OV[0, i], W_QK[1, j])
        scores["K"][i][j] = compute_composition_score(W_OV[0, i], W_QK[1, j].T)
        # V-composition pairs writer head i with reader head j (not head i again).
        scores["V"][i][j] = compute_composition_score(W_OV[0, i], W_OV[1, j])

In [None]:
step 1: init partitioned weight
step 2: mask input
step 3: reduce
step 4:

In [None]:
step 1: head0 = L0H00(pre_resid)
step 2: resid1 = pre_resid + head0

In [None]:
step 1: wait
step 2: get
step 3: construct
step 4: put
step 5: wait
step 6: get
step 7: put

In [None]:
Scatter > Reduce > Identity > Gather

In [None]:
# Cache activations for the clean and corrupted token batches; logits discarded.
_, clean_cache = model.run_with_cache(clean_tokens)
_, corrupted_cache = model.run_with_cache(corrupted_tokens)

In [40]:
def patch_head(
    head_vector,
    hook,
    clean_activations,
    corrupted_activations,
    target_head
):
    """Hook that swaps one attention head's activation for its corrupted-run value.

    Args:
        head_vector: activation at this hook; assumed (batch, pos, head_index, d_head),
            the TransformerLens hook_z layout — TODO confirm for other hooks.
        hook: hook point exposing ``layer()`` and ``name``.
        clean_activations: mapping from hook name to the clean-run activation.
        corrupted_activations: mapping from hook name to the corrupted-run activation.
        target_head: (layer_idx, head_idx) of the head to patch.

    Returns:
        The activation to use downstream. On the target layer, only the target
        head is replaced (in place); on other layers, the clean activation for
        this hook is returned.
    """
    target_layer_idx, target_head_idx = target_head
    if hook.layer() == target_layer_idx:
        # Index the head axis (third), not the position axis (second).
        head_vector[:, :, target_head_idx] = corrupted_activations[hook.name][:, :, target_head_idx]
    else:
        # Return this hook's clean activation, not the whole cache object.
        head_vector = clean_activations[hook.name]

    return head_vector

In [41]:
from itertools import product

In [None]:
# All (layer_idx, head_idx) pairs. `product` returns a one-shot iterator, so this
# cell must be re-run before every pass over `combinations`.
combinations = product(range(n_layers), range(n_heads))

In [42]:
from transformer_lens.utils import get_act_name

In [43]:
from functools import partial

In [None]:
def resid_to_logit(cache):
    """Turn the last layer's residual stream at one position into logits.

    Reads ``resid_post`` of layer ``n_layers - 1`` from ``cache``, keeps index -1
    of axis 1 (presumably the final token position), then applies the model's
    final LayerNorm followed by its unembedding.
    """
    final_resid_name = get_act_name("resid_post", n_layers - 1)
    last_position_resid = cache[final_resid_name][:, -1, :]
    normalized = model.ln_final(last_position_resid)
    return model.unembed(normalized)

In [None]:
# Patching metric for each (layer, head), filled in by the patching loop.
results = torch.zeros(n_layers, n_heads)

In [None]:
# Activation patching over every attention head: for each (layer, head), run the
# clean prompts with that head's z activation replaced by its corrupted-run value
# and record the normalised IOI metric.
# A fresh `product` is used instead of the one-shot `combinations` iterator, so
# re-running this cell patches every head again.
with torch.no_grad():  # no backward pass is needed; avoid building a graph per run
    for layer_idx, head_idx in product(range(n_layers), range(n_heads)):
        model.reset_hooks()
        hook_name = get_act_name("z", layer_idx)
        hook_func = partial(
            patch_head,
            clean_activations=clean_cache,
            corrupted_activations=corrupted_cache,
            target_head=(layer_idx, head_idx),
        )

        # run_with_hooks is the TransformerLens entry point that accepts
        # `fwd_hooks`; it attaches the hook for this pass only and returns logits.
        patched_logits = model.run_with_hooks(
            clean_tokens,
            fwd_hooks=[(hook_name, hook_func)],
        )

        results[layer_idx, head_idx] = compute_ioi_metric(patched_logits)

In [None]:
step 1: convert input tokens to fourier basis
step 2: do addition using trig identities
step 3: map back to logits

In [None]:
# OV circuit of layer 0 head 1 (d_model x d_model).
W_OV = model.W_V[0, 1] @ model.W_O[0, 1]
# QK circuit of layer 1 head 2. W_Q and W_K are both (d_model, d_head) per head,
# so W_K must be transposed: W_QK = W_Q @ W_K^T (d_model x d_model).
W_QK = model.W_Q[1, 2] @ model.W_K[1, 2].T

In [None]:
# Virtual weight: head 0.1's OV circuit followed by head 1.2's QK matrix.
virtual_weight = W_OV @ W_QK

In [None]:
# Cache activations for `tokens`; the first return value is discarded.
_, cache = model.run_with_cache(tokens)

In [None]:
layer_idx = 2

In [None]:
# Cached activation under key ("post", 2) — presumably the MLP's
# post-nonlinearity activations for that layer; confirm.
inp_acts = cache["post", layer_idx]

In [None]:
# Project those activations through layer 2's W_out (presumably the MLP output weights).
outputs = inp_acts @ model.W_out[layer_idx]

In [None]:
class ResidualLayerNorm(nn.Module):
    """Add a dropout-regularised sublayer output to its residual input, then LayerNorm.

    Args:
        d_model: size of the last (feature) dimension to normalise.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, d_model, dropout):
        super().__init__()
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, residual):
        """Return LayerNorm(dropout(x) + residual); x and residual share a shape."""
        return self.layer_norm(self.dropout(x) + residual)

In [None]:
step 1: determine global rank
step 2: initialize global group
step 3: initialize parallel groups
step 4: set device
step 5: set seed