Sure, let's break down the term $e W_{Q K}^{1.4}\left(x^{0.7}\right)^T$.

$e$ : This represents the token embedding of a particular token. For example, let's say the token is "cat". The embedding "e" will be a vector representation of this token, which encodes the semantic meaning of the word "cat".

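$W_{Q K}^{1.4}$ : This is the combined query–key matrix of head 4 in layer 1, $W_{QK}^{1.4} = W_Q^{1.4}\left(W_K^{1.4}\right)^T$ (a $d_{\text{model}} \times d_{\text{model}}$ matrix; biases ignored). It does not produce separate query and key vectors. Instead, it defines a bilinear form: for a query-side input $a$ and a key-side input $b$, the product $a W_{QK}^{1.4} b^T$ is the (unscaled, pre-softmax) attention score between them.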

$x^{0.7}$ : This represents the output of head 7 in layer 0 (the layer before layer 1) of the Transformer. It is the vector that this head writes into the residual stream at each position, i.e. that head's contribution to the representation of each token.

$e W_{Q K}^{1.4}$: This is the query side of the computation: the token embedding $e$ multiplied by $W_{QK}^{1.4}$. Equivalently, it is the query vector $e W_Q^{1.4}$ mapped back into residual-stream space by $\left(W_K^{1.4}\right)^T$. The result can be compared directly with any key-side input to head 4 in layer 1.

$\left(x^{0.7}\right)^T$ : This represents the transposed output of the 7th head in the 0th layer. Transposing this matrix allows us to do a matrix multiplication with the transformed token embedding.

Finally, the whole term $e W_{Q K}^{1.4}\left(x^{0.7}\right)^T$ gives one specific part of head 1.4's attention scores for the query token "cat": the part where the query is built from the token embedding and the keys are built from the output of head 0.7 in the previous layer. Because $x^{0.7}$ has one row per position, the result is one score per key (source) position. After scaling and softmax, these scores determine how much the query position attends to each earlier position. Because head 0.7 influences head 1.4 through the keys, this is a K-composition term.

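In the induction-head setting discussed in the question, head 0.7 plays the role of a previous-token head: it writes information about the preceding token into each position. The term then measures how strongly the current token's query matches keys that encode "the previous token was the same as me". This is how the induction head finds earlier occurrences of the current token and attends to the token that followed them, in order to predict the next token.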

### MechInterp

In [1]:
import streamlit as st

In [3]:
# Add a header to the Streamlit sidebar.
sidebar = st.sidebar
sidebar.header(1)

2023-07-09 04:22:04.424 
  command:

    streamlit run /Users/education/opt/anaconda3/envs/ml_engineering/lib/python3.8/site-packages/ipykernel_launcher.py [ARGUMENTS]


DeltaGenerator(_root_container=1, _parent=DeltaGenerator())

In [None]:
def adjust(n_layers=12, n_heads=12):
    """Render sidebar sliders for choosing a layer and an attention head.

    Args:
        n_layers: number of layers in the model. The default of 12 assumes a
            12-layer model such as GPT-2 small (TODO: pass model.cfg.n_layers).
        n_heads: number of heads per layer. The default of 12 is an assumption
            (TODO: pass model.cfg.n_heads).

    Returns:
        (layer_idx, head_idx) as currently selected in the sidebar.
    """
    # Indices are 0-based, so the largest valid value is n - 1. The previous
    # upper bound of 12 allowed an out-of-range index for a 12x12 model.
    layer_idx = st.sidebar.slider("layer_idx", 0, n_layers - 1, min(5, n_layers - 1))
    head_idx = st.sidebar.slider("head_idx", 0, n_heads - 1, min(4, n_heads - 1))
    return layer_idx, head_idx

In [None]:
# Read the currently selected (layer, head) pair from the sidebar sliders.
selection = adjust()
layer_idx, head_idx = selection

In [None]:
# Residual stream after the attention sub-layer: the residual entering the
# block plus the attention layer's output.
mid_resid = pre_resid + output_attention_layer

In [None]:
# Run the model once and keep the activation cache; the first return value
# (presumably the logits) is discarded.
_, cache = model.run_with_cache(tokens)

In [None]:
# Attention pattern of layer 0. In TransformerLens this tuple key presumably
# expands to "blocks.0.attn.hook_pattern" (the trailing "attn" is the layer
# type) — TODO confirm.
attention_pattern = cache["pattern", 0, "attn"]

In [4]:
import circuitsvis as cv

In [None]:
# String form of each token, used for labelling visualisations.
# HookedTransformer exposes to_str_tokens(); there is no `token_strs` attribute.
token_strs = model.to_str_tokens(tokens)

In [None]:
# Render every head's pattern. circuitsvis wants string tokens (not token ids).
# attention_patterns (plural) takes [n_heads, dest_pos, src_pos].
# squeeze(0) drops a size-1 batch dimension if one is present; this assumes a
# single-prompt batch — TODO confirm.
cv.attention.attention_patterns(
    tokens=model.to_str_tokens(tokens),
    attention=attention_pattern.squeeze(0),
)

In [5]:
def set_activation(activations, hook, head_idx=4):
    """Hook function that zero-ablates one attention head in place.

    Args:
        activations: hooked tensor whose axis 2 indexes heads. For a
            TransformerLens per-head hook such as hook_z this is presumably
            [batch, pos, head, d_head] — TODO confirm for the hook in use.
        hook: the hook point (unused; required by the hook-function signature).
        head_idx: head to zero along axis 2. Defaults to 4, the previously
            hard-coded head.

    Returns:
        The same tensor object, modified in place.
    """
    activations[:, :, head_idx, :] = 0
    return activations

In [None]:
# Run the model with the ablation hook attached and return only the loss.
# Changes from the original cell:
# - `return=` is a SyntaxError (reserved word); the keyword is `return_type`.
# - `run_with_cache` is for caching; forward hooks go to `run_with_hooks`.
# - The hook defined above is `set_activation` (singular).
loss = model.run_with_hooks(
    tokens,
    fwd_hooks=[(hook_name, set_activation)],
    return_type="loss",
)

In [None]:
# Display layer 0's attention activations. In TransformerLens "attn" is an
# alias for "pattern" — NOTE(review): confirm this is the same tensor as
# cache["pattern", 0].
cache["attn", 0]

In [None]:
# Token embeddings for `text`. The embedding layer consumes token ids, so the
# text is tokenized first; passing the raw string fails.
model.embed(model.to_tokens(text))

In [6]:
from itertools import product

In [None]:
# All (head, layer) index pairs. Materialised as a list: a bare
# itertools.product iterator is exhausted after one pass, so a second loop
# over it (e.g. re-running a later cell) would silently do nothing.
# NOTE(review): the tuples are (head, layer), while weights elsewhere in this
# notebook are indexed [layer, head] — confirm the intended order.
combinations = list(product(range(n_heads), range(n_layers)))

In [None]:
step 1: residual = embed(tokens) + pos_embed(tokens)
step 2: residual = blocks(residual)
step 3: residual = ln_final(residual)
step 4: logits = unembed(residual)

In [None]:
# Positional embedding matrix plus the query/key weights of the selected head.
W_pos = model.W_pos
W_Q, W_K = model.W_Q[layer_idx, head_idx], model.W_K[layer_idx, head_idx]

In [None]:
# Position-to-position attention scores of the selected head:
# (W_pos W_Q)(W_pos W_K)^T = W_pos W_Q W_K^T W_pos^T.
# W_Q and W_K are both [d_model, d_head], so W_K must be transposed;
# W_Q @ W_K is a shape mismatch.
pos_by_pos_scores = W_pos @ W_Q @ W_K.T @ W_pos.T

In [7]:
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
def mask_score(scores):
    """Apply a causal mask to attention scores.

    Entries where the key position (last dim) comes after the query position
    (second-to-last dim) are replaced with -1e9, so they get ~zero weight
    after softmax.

    Args:
        scores: tensor of shape (..., query_pos, key_pos).

    Returns:
        A new tensor with the same shape, dtype and device as ``scores``.
    """
    # Allowed entries are key_pos <= query_pos: the lower triangle including
    # the diagonal. `triu` kept the future positions and masked the past.
    mask = torch.ones_like(scores).tril().bool()
    # Match dtype/device so torch.where does not mix a CPU scalar into GPU work.
    neg_inf = torch.tensor(-1e9, dtype=scores.dtype, device=scores.device)
    return torch.where(mask, scores, neg_inf)

In [None]:
# Scale scores by 1/sqrt(d_head), as in standard attention, then apply the
# causal mask. d_head is presumably model.cfg.d_head — TODO confirm.
masked_pos_by_pos_score = mask_score(pos_by_pos_scores / (d_head**0.5))

In [None]:
# Softmax over the last axis (key positions) turns scores into an attention pattern.
pos_by_pos_pattern = F.softmax(masked_pos_by_pos_score, dim=-1)

In [None]:
# Re-run the model to get a fresh activation cache for `tokens`.
_, cache = model.run_with_cache(tokens)

In [None]:
# Residual stream at the end of the model. final_residual_name is presumably
# the last residual hook name (e.g. f"blocks.{n_layers - 1}.hook_resid_post")
# — TODO confirm.
final_residual_stream = cache[final_residual_name]

In [None]:
# Final residual vector at the last position. The slice belongs on the residual
# tensor (assumed [batch, pos, d_model]), not on the cache object.
last_token_final_residual_stream = final_residual_stream[:, -1, :]

In [None]:
# apply_ln_to_stack is a method of the ActivationCache, not the model. It
# applies the cached final LayerNorm scaling. pos_slice=-1 picks the scale
# at the last position, matching the already-sliced residual.
scaled_last_token_final_residual_stream = cache.apply_ln_to_stack(
    last_token_final_residual_stream,
    layer=-1,
    pos_slice=-1,
)

In [None]:
# Unembedding matrix (residual stream -> vocabulary logits).
W_U = model.W_U

In [None]:
# Unembedding direction of each answer token. In TransformerLens W_U is
# [d_model, d_vocab], so a token's direction is a column. `W_U[: tok]` instead
# sliced the first `tok` rows.
correct_residual_direction = W_U[:, correct_token]
incorrect_residual_direction = W_U[:, incorrect_token]

In [None]:
# Residual-space direction whose dot product with the (scaled) final residual
# gives logit(correct) - logit(incorrect), ignoring any unembedding bias.
logit_difference_direction = correct_residual_direction - incorrect_residual_direction

In [8]:
from einops import einsum

In [None]:
# Project each prompt's final residual onto the logit-difference direction.
# einops.einsum needs an explicit pattern; the empty string was invalid.
# Assumes the residual is [batch, d_model] and the direction is [d_model]
# (single correct/incorrect token ids) — TODO confirm shapes.
logit_difference = einsum(
    scaled_last_token_final_residual_stream,
    logit_difference_direction,
    "batch d_model, d_model -> batch",
)

In [None]:
# Tokenize the prompt into model token ids.
tokens = model.to_tokens(text)

In [None]:
# Cache all activations for the freshly tokenized prompt.
_, cache = model.run_with_cache(tokens)

In [None]:
# Token and positional embeddings, read from their cache hooks.
embed, pos_embed = cache["hook_embed"], cache["hook_pos_embed"]

In [None]:
# Per-head outputs of the previous layer.
# NOTE(review): in TransformerLens the "result" hook is only populated when
# per-head results are enabled (model.set_use_attn_result(True)) — confirm.
head_outputs = cache["result", prev_layer_idx]

In [None]:
# Gather every component that writes into the residual stream along a new
# "component" axis, giving [batch, pos, 2 + n_heads, d_model].
# The two embeddings are [batch, pos, d_model] while "result" is
# [batch, pos, n_heads, d_model] (TransformerLens shapes), so the embeddings
# get a singleton component axis first. Concatenating along dim=0 mixed
# components into the batch axis and fails on the rank mismatch.
input_components = torch.cat([
    embed.unsqueeze(-2), pos_embed.unsqueeze(-2), head_outputs
], dim=-2)

In [None]:
# Query weights of head `head_idx` in the next layer.
W_Q = model.W_Q[next_layer_idx, head_idx]

In [None]:
# Each component's contribution to the query vector. einops.einsum requires
# a pattern; "..." keeps whatever leading axes input_components has.
query_components = einsum(
    input_components, W_Q, "... d_model, d_model d_head -> ... d_head"
)

In [None]:
# Squared L2 norm of each component's contribution to the query.
query_contributions = (query_components ** 2).sum(dim=-1)

In [None]:
# Display layer 0's attention pattern, addressed by its full hook name.
cache["blocks.0.attn.hook_pattern"]

In [None]:
# Residual stream after the MLP sub-layer: MLP output plus the mid residual.
# NOTE(review): the name probably means "post_resid" (post-MLP residual), not positional.
pos_resid = mlp + mid_resid

In [None]:
result: the per-head outputs of each attention head, before they are summed
attn_out: the output of the whole attention layer (the sum of all heads' results plus the output bias)

### Engineering

In [9]:
import functools

In [10]:
def my_decorator(func):
    """Wrap *func* in a pass-through wrapper that keeps its metadata.

    Args:
        func: the callable to decorate.

    Returns:
        A wrapper that forwards all arguments to *func* and returns its result.
        The wrapper carries *func*'s ``__name__``, ``__doc__`` and ``__wrapped__``.
    """
    # functools.wraps must be called with the wrapped function. Bare
    # `@functools.wraps` turned `wrapper` into a partial, so the decorated
    # function no longer called `func`.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

In [None]:
def compute_total_memory(model):
    """Return the number of bytes held by *model*'s parameters.

    Args:
        model: object whose ``parameters()`` yields tensors
            (e.g. a ``torch.nn.Module``). Buffers are not counted.

    Returns:
        Sum of ``numel() * element_size()`` over all parameters, in bytes.
    """
    # The original used an undefined `param_siz` and never returned the total.
    # element_size() is the per-element byte size of the tensor's dtype.
    return sum(param.numel() * param.element_size() for param in model.parameters())

In [None]:
#include <iostream>

In [None]:
using namespace std;

In [None]:
// Minimal book record holding its title.
class Book {
public:
    std::string title;

    // Constructors have no return type, and the parameter needs a type.
    Book(const std::string& aTitle) : title(aTitle) {}
};  // a class definition must end with ';'

In [None]:
// Print "i = 0" .. "i = 9". The loop variable must be declared and
// initialised; the original relied on an undeclared `i`.
for (int i = 0; i < 10; i++) {
    std::cout << "i = " << i << std::endl;
}

In [None]:
// `age` is an alias for int (alias-declaration form of the typedef).
using age = int;

In [11]:
from torch.utils.data import Sampler

In [None]:
class EvenSampler(Sampler):
    """Sampler that yields the even indices of *data*: 0, 2, 4, ...

    Args:
        data: any sized collection; only ``len(data)`` is used.
    """

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(range(0, len(self.data), 2))

    def __len__(self):
        # DataLoader calls len(sampler) (e.g. for len(loader)); there are
        # ceil(len(data) / 2) even indices.
        return (len(self.data) + 1) // 2

In [12]:
def compute_forward_pass_using_data_parallelism(model, input, device_ids, output_id):
    """Run *model* on *input* split across devices and gather the outputs.

    Uses the same replicate / scatter / parallel_apply / gather primitives as
    ``torch.nn.DataParallel``.

    Args:
        model: module to replicate onto each device.
        input: tensor to split along its first dimension.
        device_ids: devices to run on.
        output_id: device that receives the gathered output.

    Returns:
        The per-device outputs concatenated on ``output_id``.
    """
    inputs = nn.parallel.scatter(input, device_ids)
    # scatter() can return fewer chunks than devices (e.g. batch size < number
    # of devices). parallel_apply needs exactly one replica per chunk.
    models = nn.parallel.replicate(model, device_ids[:len(inputs)])
    output_parallel = nn.parallel.parallel_apply(models, inputs)
    return nn.parallel.gather(output_parallel, target_device=output_id)

In [13]:
# Global ranks that should form the sub-group.
ranks = list((0, 1, 3, 6))

In [None]:
# Rank of this process in the default group (init_process_group must have run).
rank = torch.distributed.get_rank()

In [None]:
# ProcessGroup is not meant to be constructed directly; new_group() creates a
# sub-group. It is a collective, so every process in the job must call it,
# including ranks that are not members.
group = torch.distributed.new_group(ranks=[0, 1, 3, 6])

In [None]:
# Point-to-point send from rank 0. send() needs a destination (global) rank.
# NOTE(review): destination 1 is an assumption — confirm the intended receiver,
# which must post a matching torch.distributed.recv(x, src=0, group=group).
if rank == 0:
    torch.distributed.send(x, dst=1, group=group)

In [None]:
// Host-side pointers for the three arrays. In one declaration the base type
// appears once and each pointer name gets its own '*' ("innt" was a typo).
int *h_a, *h_b, *h_c;

In [None]:
// Size in bytes of an n-element int array.
size_t bytes = sizeof(int) * n;

In [None]:
h_a = (int*)malloc(bytes)
h_b = (int*)malloc(bytes)
h_c = (int*)malloc(bytes)

In [None]:
def wait_stream(source_stream, target_stream):
    """Make *source_stream* wait until work queued on *target_stream* finishes.

    Either argument may be a ``torch.cuda.Stream`` or a non-CUDA (CPU) marker.

    - CUDA source, CUDA target: the source stream enqueues a wait on the target.
    - CPU source, CUDA target: the host blocks until the target stream drains.
    - Non-CUDA target: nothing to wait for, so this returns immediately.
    """
    if not isinstance(target_stream, torch.cuda.Stream):
        return
    if isinstance(source_stream, torch.cuda.Stream):
        # GPU waits for GPU.
        source_stream.wait_stream(target_stream)
    else:
        # CPU waits for GPU. Stream has synchronize(); `syncronous()` does not exist.
        target_stream.synchronize()

In [15]:
class Wait(torch.autograd.Function):
    """Autograd identity that orders work between two streams.

    Forward: ``next_stream`` waits for ``prev_stream`` (via wait_stream).
    Backward: ``prev_stream`` waits for ``next_stream``.
    The tensor passes through unchanged; the stream arguments get no gradient.
    """

    @staticmethod
    def forward(ctx, prev_stream, next_stream, input):
        ctx.prev_stream, ctx.next_stream = prev_stream, next_stream
        wait_stream(source_stream=next_stream, target_stream=prev_stream)
        return input

    @staticmethod
    def backward(ctx, grad_input):
        wait_stream(source_stream=ctx.prev_stream, target_stream=ctx.next_stream)
        # One gradient per forward input: None for both streams, then the tensor's.
        return (None, None, grad_input)

In [16]:
import socketserver

In [None]:
# Bind the threaded TCP server and actually handle requests. The previous
# body was `pass`, which closed the socket immediately without serving
# anything. serve_forever() blocks until shutdown() is called.
with socketserver.ThreadingTCPServer(
    (MASTER_HOST, MASTER_PORT),
    EchoRequestHandler
) as server:
    server.serve_forever()

In [19]:
import copy

In [20]:
class ModelStateHandler:
    """Snapshot, restore and sync the state of a model.

    Attributes:
        value: the wrapped model (must provide state_dict / load_state_dict).
    """

    def __init__(self, model):
        self.value = model
        # No snapshot exists until save() is called.
        self._model_state_dict = None

    def restore(self):
        """Load the most recent snapshot taken by save() back into the model.

        Raises:
            RuntimeError: if save() has not been called yet.
        """
        if self._model_state_dict is None:
            raise RuntimeError("restore() called before save()")
        self.value.load_state_dict(self._model_state_dict)

    def save(self):
        """Snapshot the model's state_dict.

        The deep copy keeps later in-place parameter updates from changing
        the snapshot.
        """
        self._model_state_dict = copy.deepcopy(self.value.state_dict())

    def sync(self):
        """Broadcast the model via broadcast_parameters (defined elsewhere)."""
        broadcast_parameters(self.value)

In [21]:
def get_handler(v, registry=None):
    """Return a handler instance for *v*, or None if no registered type matches.

    Args:
        v: object to wrap.
        registry: iterable of ``(handled_type, handler_cls)`` pairs, searched in
            order. Defaults to the module-level ``handler_registry``.

    Returns:
        ``handler_cls(v)`` for the first pair whose handled_type matches *v*,
        else None.
    """
    if registry is None:
        registry = handler_registry
    for handler_type, handler_cls in registry:
        # Match against the type being handled. The original tested
        # isinstance(v, handler_cls), i.e. against the handler class itself,
        # so ordinary objects never matched.
        if isinstance(v, handler_type):
            return handler_cls(v)
    return None

In [22]:
def get_handlers(states):
    """Split *states* into entries with a registered handler and the rest.

    Args:
        states: mapping of name -> state object, or an iterable of
            ``(name, state)`` pairs.

    Returns:
        (handlers, remainders):
            handlers maps name -> handler instance from get_handler().
            remainders maps name -> original value for entries no handler claimed.
    """
    handlers = {}
    remainders = {}
    # Iterating a dict yields only keys, so use .items() when given a mapping.
    pairs = states.items() if hasattr(states, "items") else states
    for k, v in pairs:
        handler = get_handler(v)
        if handler is None:
            # Keep the original value under its name. Previously the value was
            # used as the key and None was stored.
            remainders[k] = v
        else:
            handlers[k] = handler
    return handlers, remainders

In [23]:
# Pipeline configuration: micro-batches per mini-batch and model partitions.
n_microbatches = 4
n_partritions = 3

In [24]:
# A fill-drain pipeline needs (m - 1) extra cycles beyond the n partitions.
n_clock_cycles = (n_microbatches - 1) + n_partritions

In [26]:
# Fill-drain pipeline schedule: at clock t, partition j processes micro-batch
# t - j, for every j where that micro-batch index is in range.
for clock_idx in range(n_clock_cycles):
    start_partrition = max(clock_idx + 1 - n_microbatches, 0)
    end_partrition = min(clock_idx + 1, n_partritions)
    tasks = [(clock_idx - j, j) for j in range(start_partrition, end_partrition)]
    print(tasks)

[(0, 0)]
[(1, 0), (0, 1)]
[(2, 0), (1, 1), (0, 2)]
[(3, 0), (2, 1), (1, 2)]
[(3, 1), (2, 2)]
[(3, 2)]


In [None]:
# Global rank of this process and the ranks that form the sub-group.
rank = torch.distributed.get_rank()
ranks = [0, 1, 3, 6]
# Default for processes outside `ranks`.
group = None

In [None]:
# new_group() is a collective: every process must call it, including ranks
# that are not members. Non-members receive a placeholder handle, so reset it
# to None to keep `group is not None` meaning "this rank is a member".
group = torch.distributed.new_group(ranks=ranks)
if rank not in ranks:
    group = None

In [None]:
# Copy rank 0's x to every member of the sub-group. send() takes `dst`, not
# `src`; a one-to-many copy identified by its source is a broadcast, which
# every group member must call.
if group is not None:
    torch.distributed.broadcast(x, src=0, group=group)

In [28]:
import socketserver

In [None]:
# Threaded TCP server on the master address. serve_forever() blocks until
# shutdown(), and leaving the `with` closes the listening socket.
server = socketserver.ThreadingTCPServer((MASTER_HOST, MASTER_PORT), EchoRequestHandler)
with server:
    server.serve_forever()

In [29]:
from typing import OrderedDict

In [None]:
def split_model(model, balances, devices):
    """Split *model*'s direct children into consecutive pipeline partitions.

    Args:
        model: module whose ``named_children()`` are consumed in order.
        balances: positive number of children for each partition.
        devices: device for each partition (at least one per partition).

    Returns:
        List of ``nn.Sequential`` partitions; partition i is moved to devices[i].

    Raises:
        ValueError: if a balance is not positive, the balances do not add up
            to the number of children, or there are fewer devices than
            partitions.
    """
    # nn.Sequential keeps child names only when given a single OrderedDict.
    from collections import OrderedDict

    children = list(model.named_children())
    if any(b <= 0 for b in balances):
        raise ValueError("every balance must be positive")
    if sum(balances) != len(children):
        raise ValueError(
            f"balances sum to {sum(balances)} but model has {len(children)} children"
        )
    if len(devices) < len(balances):
        raise ValueError("need one device per partition")

    partritions = []
    layers = OrderedDict()
    partrition_idx = 0
    for name, layer in children:
        layers[name] = layer
        if len(layers) == balances[partrition_idx]:
            # The draft called .to() with no device and never reset/advanced.
            partritions.append(nn.Sequential(layers).to(devices[partrition_idx]))
            layers = OrderedDict()
            partrition_idx += 1
    return partritions

In [None]:
step 1: ln1 = ln1(resid_pre)
step 2: head_outputs = head1(ln1) + head2(ln1)
step 3: mid_resid = head_outputs + resid_pre
step 4: ln2 = ln2(mid_resid)
step 5: mlp = mlp(ln2)
step 6: post_resid = mlp + mid_resid

In [30]:
def split_model(model, balances, devices):
    """Split *model*'s direct children into consecutive pipeline partitions.

    Args:
        model: module whose ``named_children()`` are consumed in order.
        balances: positive number of children for each partition.
        devices: device for each partition (at least one per partition).

    Returns:
        List of ``nn.Sequential`` partitions; partition i is moved to devices[i].

    Raises:
        ValueError: if a balance is not positive, the balances do not add up
            to the number of children, or there are fewer devices than
            partitions.
    """
    # nn.Sequential accepts a single OrderedDict of named modules. A plain
    # dict is treated as a (non-Module) positional argument and rejected.
    from collections import OrderedDict

    children = list(model.named_children())
    # Validate up front. Otherwise leftover layers were silently dropped, or
    # balances[partrition_idx] raised IndexError.
    if any(b <= 0 for b in balances):
        raise ValueError("every balance must be positive")
    if sum(balances) != len(children):
        raise ValueError(
            f"balances sum to {sum(balances)} but model has {len(children)} children"
        )
    if len(devices) < len(balances):
        raise ValueError("need one device per partition")

    partritions = []
    layers = OrderedDict()
    partrition_idx = 0
    for name, layer in children:
        layers[name] = layer
        if len(layers) == balances[partrition_idx]:
            partritions.append(nn.Sequential(layers).to(devices[partrition_idx]))
            layers = OrderedDict()
            partrition_idx += 1
    return partritions

In [31]:
import gymnasium as gym

In [None]:
class ShowerEnv(gym.Env):
    """Custom Gymnasium environment (work in progress).

    NOTE(review): the constructor previously ended in a truncated
    ``self.observa`` expression, which raised AttributeError on every
    instantiation. The spaces and dynamics are not defined yet.
    """

    def __init__(self):
        super().__init__()
        # TODO: define self.observation_space and self.action_space
        # (gymnasium.spaces.*) and implement reset()/step() before use.

In [None]:
# Freeze every parameter; selected parts are unfrozen afterwards.
for p in model.parameters():
    p.requires_grad_(False)

In [None]:
# Unfreeze transformer blocks 5 and later. Iterating a ModuleList yields the
# blocks themselves, so enumerate() is needed to get (index, block) pairs;
# unpacking a block directly fails.
for idx, block in enumerate(model.blocks):
    if idx >= 5:
        for param in block.parameters():
            param.requires_grad = True

In [None]:
# Also train the output head and the final layer norm (in that order).
# NOTE(review): TransformerLens models name the head `unembed`, not
# `lm_head` — confirm which model class is used here.
for attr_name in ("lm_head", "ln_final"):
    for param in getattr(model, attr_name).parameters():
        param.requires_grad = True