### Engineering

In [74]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
from torch.utils.data import Dataset

In [3]:
class CachedDataset(Dataset):
    """Dataset wrapper that packs selected entries of a saved file into one flat buffer.

    Nothing is read at construction time; ``prefetch`` loads ``filename`` with
    ``torch.load`` and fills ``self.cache``.
    """

    def __init__(self, filename):
        self.filename = filename
        self.data = None
        # Flat 1-D buffer built by ``prefetch``; stays ``None`` until then.
        self.cache = None

    def prefetch(self, idxs=None):
        """Load the file and copy ``self.data[i]`` for each requested index into ``self.cache``.

        Args:
            idxs: Indices into the loaded object, packed in the given order.
                Defaults to every index, i.e. ``range(len(self.data))``.
        """
        self.data = torch.load(self.filename)
        # Materialise once: ``idxs`` is walked twice (size pass, then copy pass).
        idxs = list(range(len(self.data))) if idxs is None else list(idxs)

        total_elements = sum(self.data[i].numel() for i in idxs)
        # NOTE(review): ``self.data.dtype`` assumes the loaded object is a single
        # tensor (not e.g. a list of tensors) -- confirm against the saved format.
        self.cache = torch.zeros(total_elements, dtype=self.data.dtype)

        offset = 0
        for i in idxs:
            item = self.data[i]
            n_elements = item.numel()
            # Flatten the entry itself (the original flattened the element count);
            # ``reshape`` also copes with non-contiguous entries.
            self.cache[offset:offset + n_elements] = item.reshape(-1)
            offset += n_elements

In [None]:
self.cache[offset:offset+n_elements] = data[i].view(-1)

In [4]:
class ParallelContext:
    def init_rpc_workers(self):
        """Draft of the RPC worker setup used for pipeline parallelism.

        Does nothing unless ``self.pipeline_parallel_size > 1``. In the part
        visible here it only builds an init-method URL and, when CUDA is
        available, a rank -> worker-name mapping; neither value is used or
        returned yet.
        """
        if self.pipeline_parallel_size > 1:
            # NOTE(review): `host` is not defined in this method and nothing
            # follows the ':' in the URL -- looks like a port is still missing.
            init_method = f"tcp://{host}:"
            ranks = self.get_ranks_in_group(ParallelMode.PIPELINE)
            
            if torch.cuda.is_available():
                # Maps every rank of the pipeline group to its formatted worker name.
                rpc_worker_map = {
                    rank: WORKER_NAME.format(rank)
                    for rank in ranks
                }

In [None]:
linear, attention, layernorm, embedding

In [None]:
step 1: determine global rank
step 2: resize embedding size
step 3: resize unembedding size
step 4: parallelize linear layers, embeddings, attention, layer norm

In [None]:
step 1: package
step 2: invoke rpc
step 3: receive 
step 4: execute based on the rpc call's logic

In [None]:
pool watcher, worker thread, job selector

In [8]:
import torch.distributed as dist

In [9]:
class _P2P:
    """Point-to-point receive helper: metadata first, then the tensor payload."""

    def recv(self, src_rank, parallel_context, parallel_mode):
        """Receive one tensor from ``src_rank`` within the group of ``parallel_mode``.

        Returns:
            The received tensor, with dtype, shape and ``requires_grad`` taken
            from the metadata messages that precede it.
        """
        group = parallel_context.get_group(parallel_mode)

        dtype, requires_grad, shape = self._recv_metadata(src_rank, parallel_context, parallel_mode)

        # Allocate a plain buffer, fill it, and only then flag it for autograd,
        # so the receive does not write in place into a leaf that requires grad.
        data = torch.zeros(shape, dtype=dtype)
        dist.recv(data, src=src_rank, group=group)
        data.requires_grad_(requires_grad)

        return data

    def _recv_metadata(self, src_rank, parallel_context, parallel_mode):
        """Receive dtype id, requires_grad flag and shape, each as a 1-element tensor.

        Returns:
            ``(dtype, requires_grad, shape)`` where ``dtype`` is looked up in
            ``ID_TO_DTYPE``, ``requires_grad`` is a ``bool`` and ``shape`` is a
            tuple of ints.
        """
        group = parallel_context.get_group(parallel_mode)

        dtype = torch.zeros(1)
        dist.recv(dtype, src=src_rank, group=group)
        # Look up with a plain int: a tensor is not a usable key or index.
        dtype = ID_TO_DTYPE[int(dtype.item())]

        requires_grad = torch.zeros(1)
        dist.recv(requires_grad, src=src_rank, group=group)
        requires_grad = requires_grad.item() == 1

        # NOTE(review): a 1-element buffer can only describe a 1-D tensor; the
        # wire format is left untouched here because the sender is not visible.
        shape = torch.zeros(1)
        dist.recv(shape, src=src_rank, group=group)
        # torch.zeros needs a tuple of ints, not a float tensor.
        shape = tuple(int(dim) for dim in shape.tolist())

        return dtype, requires_grad, shape

In [None]:
def recv(src_rank, dst_rank, parallel_context, parallel_mode=ParallelMode.PIPELINE):
    """Receive a tensor sent by ``src_rank`` if the calling process is ``dst_rank``.

    Args:
        src_rank: Rank the tensor is sent from.
        dst_rank: Rank that should perform the receive.
        parallel_context: Object providing rank and process-group lookups.
        parallel_mode: Group to communicate in; defaults to the pipeline group.

    Returns:
        The received tensor on ``dst_rank``; ``None`` on every other rank.
    """
    # NOTE(review): the rank is always looked up in the PIPELINE group, even
    # when a different ``parallel_mode`` is passed -- confirm this is intended.
    rank = parallel_context.get_local_rank(ParallelMode.PIPELINE)
    if rank == dst_rank:
        # Receive *from the sender*; the original passed dst_rank here.
        return _P2P().recv(src_rank, parallel_context, parallel_mode)
    return None

In [None]:
maximize storage
minimize communication
minimize flops

In [None]:
step 1: mask targets
step 2: local_predicted_logits
step 3: global_predicted_logits
step 4: log(...)
step 5: loss = log(...) - global_predicted_logits

In [None]:
- monitor node changes
- 

In [None]:
rank*partition_size

In [None]:
start_idx+partition_size

In [14]:
class Broadcast(torch.autograd.Function):
    """Identity in the forward pass; all-reduces the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, input):
        # Nothing to compute or save: the tensor is passed through untouched.
        return input

    @staticmethod
    def backward(ctx, grad_input):
        # In-place all-reduce over the default process group.
        reduced = grad_input
        dist.all_reduce(reduced)
        return reduced

In [13]:
class Gather(torch.autograd.Function):
    """All-gathers along the last dim going forward; keeps the local slice of the gradient going back."""

    @staticmethod
    def forward(ctx, input):
        n_ranks = dist.get_world_size()

        # One receive buffer per rank, shaped like the local tensor.
        shards = []
        for _ in range(n_ranks):
            shards.append(torch.zeros_like(input))

        dist.all_gather(shards, input)
        return torch.cat(shards, dim=-1)

    @staticmethod
    def backward(ctx, grad_input):
        n_ranks = dist.get_world_size()
        my_rank = dist.get_rank()

        # Cut the gradient into pieces of equal width along the last dim
        # and hand back the piece that belongs to this rank.
        shard_width = grad_input.shape[-1] // n_ranks
        pieces = grad_input.split(shard_width, dim=-1)
        return pieces[my_rank]

In [10]:
class ColumnParallelLinear(nn.Module):
    """Linear layer holding ``output_size // world_size`` output features per process.

    The forward pass runs the local linear map and concatenates the per-process
    outputs through ``Gather``.
    """

    def __init__(self, input_size, output_size, world_size):
        super().__init__()

        shard_width = output_size // world_size

        # Weight is drawn before bias so the RNG consumption order is unchanged.
        weight_shard = torch.randn(shard_width, input_size)
        self.weight = nn.Parameter(weight_shard)

        bias_shard = torch.randn(shard_width)
        self.bias = nn.Parameter(bias_shard)

    def forward(self, input):
        local_out = F.linear(Broadcast.apply(input), self.weight, self.bias)
        return Gather.apply(local_out)

In [17]:
class Reduce(torch.autograd.Function):
    """All-reduces the input in the forward pass; passes the gradient through unchanged."""

    @staticmethod
    def forward(ctx, input, parallel_context, parallel_mode):
        """All-reduce ``input`` in place over the group of ``parallel_mode`` and return it."""
        group = parallel_context.get_group(parallel_mode)
        dist.all_reduce(input, group=group)
        return input

    @staticmethod
    def backward(ctx, grad_input):
        """Return one gradient per forward argument.

        Only ``input`` is differentiable; ``parallel_context`` and
        ``parallel_mode`` get ``None``. Returning a single value here makes
        autograd reject the function for an incorrect number of gradients.
        """
        return grad_input, None, None

In [18]:
class ParallelEmbedding(nn.Module):
    """Embedding table split along the vocabulary dimension across tensor-parallel ranks.

    Each rank stores ``num_embeddings // world_size`` rows and serves the token
    ids in ``[vocab_start_idx, vocab_end_idx)``; per-rank results are combined
    with ``Reduce``.
    """

    def __init__(self, num_embeddings, embedding_dim, parallel_context):
        super().__init__()
        world_size = parallel_context.get_world_size(ParallelMode.TENSOR)
        # NOTE(review): integer division drops the remainder when the vocabulary
        # is not divisible by world_size -- confirm it is padded upstream.
        num_embeddings_per_partition = num_embeddings // world_size

        self.weight = nn.Parameter(torch.randn(
            num_embeddings_per_partition,
            embedding_dim
        ))
        self.vocab_start_idx, self.vocab_end_idx = self._get_vocab_range(
            num_embeddings_per_partition,
            parallel_context.get_local_rank(ParallelMode.TENSOR)
        )
        self.parallel_context = parallel_context

    def _get_vocab_range(self, partition_size, rank):
        """Return the half-open ``(start, end)`` id range owned by ``rank``."""
        start_idx = partition_size * rank
        end_idx = start_idx + partition_size
        return start_idx, end_idx

    def forward(self, inputs):
        """Embed ``inputs`` (token ids) and combine the per-rank partial results.

        Ids outside this rank's range are looked up at local row 0 and their
        rows are then zeroed, so they contribute nothing to the reduction.
        """
        # True where the id belongs to another rank's shard.
        out_of_range = (inputs < self.vocab_start_idx) | (inputs >= self.vocab_end_idx)

        # Shift to local row indices; out-of-range ids are pointed at row 0
        # only to keep the lookup valid.
        masked_input = inputs - self.vocab_start_idx
        masked_input[out_of_range] = 0

        parallel_embeddings = F.embedding(masked_input, self.weight)
        # Zero with the boolean mask (the original indexed with the id tensor).
        parallel_embeddings[out_of_range] = 0.

        # Function.apply is called with positional arguments only.
        embeddings = Reduce.apply(
            parallel_embeddings,
            self.parallel_context,
            ParallelMode.TENSOR
        )

        return embeddings

In [None]:
step 1: sharding
step 2: mask targets
step 3: local_embeddings
step 4: global_embedding

In [None]:
W_E = model.W_E
open_embeddings = W_E[:, open_idx]
close_embeddings = W_E[:, close_idx]

In [19]:
layer_idx, head_idx = 0, 0

In [None]:
W_OV = model.W_V[layer_idx, head_idx] @ model.W_O[layer_idx, head_idx]

In [None]:
W_E = model.W_E
open_embedding = W_E[:, open_idx]
close_embedding = W_E[:, close_idx]

In [20]:
layer_idx, head_idx = 0, 0

In [None]:
W_OV = model.W_V[layer_idx, head_idx] @ model.W_O[layer_idx, head_idx]

In [None]:
open_embedding = open_embedding @ layer0_ln_coefs.T @ W_OV
close_embedding = close_embedding @ layer0_ln_coefs.T @ W_OV

In [None]:
similarity = torch.cosine_similarity(open_embedding, close_embedding)

In [21]:
def patch_sender_head_output(
    acts, hook,
    clean_cache, corrupted_cache,
    target_head
):
    """Hook that swaps in the corrupted output of one head and the clean cache elsewhere.

    On the layer named by ``target_head`` only the target head's slice of
    ``acts`` is overwritten (in place) with the corrupted activations; on any
    other layer the cached clean activations for this hook are returned.
    """
    layer, head = target_head

    # Guard clause: any other layer simply gets its clean activations.
    if hook.layer() != layer:
        return clean_cache[hook.name]

    corrupted = corrupted_cache[hook.name]
    acts[:, :, head] = corrupted[:, :, head]
    return acts

In [23]:
from functools import partial
from itertools import product
from transformer_lens.utils import get_act_name

In [None]:
_, clean_cache = model.run_with_cache(clean_tokens)
_, corrupted_cache = model.run_with_cache(corrupted_tokens)

In [None]:
receiver_heads = [(7, 3), (7, 9), (8, 6), (8, 10)]
receiver_layer_idxs = [7, 8]

In [None]:
sender_heads = list(product(
    range(max(receiver_layer_idxs)),
    range(model.cfg.n_heads)
))

In [29]:
receiver_names = [get_act_name("v", layer_idx) for layer_idx in [7, 8]]

In [30]:
receiver_names

['blocks.7.attn.hook_v', 'blocks.8.attn.hook_v']

In [None]:
def patch_receiver_head_input(acts, hook, )

In [None]:
# For every candidate sender head: patch that head's output and run the model
# on the clean tokens, keeping the resulting activation cache.
for layer_idx, head_idx in sender_heads:
    # Start each iteration without hooks left over from the previous one.
    model.reset_hooks()
    # Selects every hook whose name ends in "z".
    filter_head_output = lambda name: name.endswith("z")
    hook_func = partial(
        patch_sender_head_output,
        clean_cache=clean_cache,
        corrupted_cache=corrupted_cache,
        target_head=(layer_idx, head_idx)
    )
    
    model.add_hook(filter_head_output, hook_func)
    _, patched_cache = model.run_with_cache(clean_tokens)
    
    
    # NOTE(review): bare expression with no effect -- the receiver-patching
    # step appears to be unfinished; `patched_cache` is not used yet.
    hook_func

In [31]:
import torch.distributed.rpc as rpc

In [None]:
class ParallelContext:
    def init_rpc_worker(self, host, port):
        """Join the RPC group used for pipeline parallelism.

        Does nothing unless ``self.pipeline_parallel_size > 1``.

        Args:
            host: Address of the rendezvous endpoint.
            port: Port of the rendezvous endpoint.
        """
        if self.pipeline_parallel_size > 1:
            # "rpc://" is not a rendezvous scheme; tcp:// is the host:port form.
            init_method = f"tcp://{host}:{port}"
            rank = self.get_global_rank()
            world_size = self.get_world_size(ParallelMode.GLOBAL)

            # The TensorPipe options class is the one that provides set_device_map.
            options = rpc.TensorPipeRpcBackendOptions(
                init_method=init_method
            )

            if torch.cuda.is_available():
                ranks = self.get_ranks_in_group(ParallelMode.PIPELINE)

                rpc_worker_map = {
                    peer: WORKER_NAME.format(peer)
                    for peer in ranks
                }

                for other in ranks:
                    if other == rank:
                        continue

                    # NOTE(review): assumes each rank uses the CUDA device whose
                    # index equals its rank -- confirm for multi-node runs.
                    options.set_device_map(rpc_worker_map[other], {rank: other})

            rpc.init_rpc(
                name=WORKER_NAME.format(rank),
                rank=rank,
                world_size=world_size,
                rpc_backend_options=options
            )

In [None]:
step 1: qkv
step 2: split
step 3: self
step 4: 

In [None]:
step 1: size
step 2: split
step 3: x
step 4: gather

In [33]:
from contextlib import contextmanager
from queue import Queue
import threading

In [34]:
def run_worker(in_queue, out_queue):
    """Run tasks from ``in_queue`` and put each result on ``out_queue``.

    Blocks on ``in_queue.get()`` between tasks. A ``None`` task is treated as
    a shutdown request: the loop ends and the function returns, instead of the
    thread dying on ``None()``.

    Args:
        in_queue: Queue of zero-argument callables (or ``None`` to stop).
        out_queue: Queue receiving each task's return value, in order.
    """
    while True:
        task = in_queue.get()
        if task is None:
            break
        out_queue.put(task())

In [None]:
@contextmanager
def spawn_workers(devices):
    """Start one daemon worker thread per entry in ``devices``.

    Yields:
        ``(in_queues, out_queues)``: two lists, one queue pair per device in
        the same order as ``devices``. Tasks go on an in-queue; results come
        back on the matching out-queue.
    """
    in_queues = []
    out_queues = []

    # The device value itself is not used yet; it only fixes the worker count.
    for _ in devices:
        in_queue = Queue()
        out_queue = Queue()

        # run_worker takes its two queues as arguments; without ``args`` the
        # thread would fail immediately on the missing parameters.
        thread = threading.Thread(
            target=run_worker,
            args=(in_queue, out_queue),
            daemon=True
        )
        thread.start()

        in_queues.append(in_queue)
        out_queues.append(out_queue)

    yield in_queues, out_queues

In [None]:
start, while

In [None]:
weights = attn_weights.diagonal(dim1=-2, dim2=-1, offset=-1)

In [35]:
def print_shape(input, _):
    """Print the ``shape`` of the first argument; the second argument is ignored."""
    shape = input.shape
    print(shape)

In [None]:
model.blocks[1].register_forward_pre_hook(print_shape)

In [None]:
tokens = model.to_tokens(text)

In [None]:
hook_name = f"blocks.{layer_idx}.mlp.hook_post"

In [None]:
_, cache = model.run_with_cache(tokens)

In [None]:
top_tokens = cache[hook_name][:, neuron_idx].argmax(dim=-1)

In [None]:
W_OV = model.W_V[0, 1] @ model.W_O[0, 1]
W_QK = model.W_Q[1, 2] @ model.W_K[1, 2].T

virtual_weight = W_OV @ W_QK

In [None]:
softmax(x@W_Q@W_K.T@x.T) @ x @ W_V @ W_O

In [None]:
tokens = model.to_tokens(text)

In [None]:
logits = model(text)

In [None]:
log_probs = F.log_softmax(logits[:, -1, :], dim=-1)
predicted_log_probs = -log_probs[:, tokens[1:]]

In [None]:
induction_stripe = attn_weights.diagonal(
    dim1=-2,
    dim2=-1,
    offset=4-1
)

In [None]:
_, clean_cache = model.run_with_cache(clean_tokens)
_, corrupted_cache = model.run_with_cache(corrupted_tokens)

In [41]:
receiver_heads = [(7, 3), (7, 9), (8, 6), (8, 10)]
receiver_layer_idxs = [7, 8]

In [47]:
from itertools import product
from functools import partial

In [43]:
n_heads = 12

In [45]:
sender_heads = list(product(range(max(receiver_layer_idxs)), range(n_heads)))

In [54]:
def patch_sender_head_output(acts, hook, clean_cache, corrupted_cache, target_head):
    """Hook that swaps in the corrupted output of one head and the clean cache elsewhere.

    Args:
        acts: Activations at this hook; the head index is the third dimension.
        hook: Hook object providing ``layer()`` and ``name``.
        clean_cache: Mapping from hook name to clean-run activations.
        corrupted_cache: Mapping from hook name to corrupted-run activations.
        target_head: ``(layer_idx, head_idx)`` of the head to corrupt.

    Returns:
        ``acts`` with the target head's slice overwritten (in place) on the
        target layer; the cached clean activations on any other layer.
    """
    trg_layer_idx, trg_head_idx = target_head

    if hook.layer() == trg_layer_idx:
        # Index with the head taken from ``target_head``, not a global ``head_idx``.
        acts[:, :, trg_head_idx] = corrupted_cache[hook.name][:, :, trg_head_idx]
    else:
        acts = clean_cache[hook.name]
    return acts

In [36]:
def pach_receiver_head_input(acts, hook, new_value):
    # NOTE(review): unfinished stub. The subscript below is evaluated and
    # discarded; `hook` and `new_value` are unused and nothing is returned.
    # The name also looks like a typo for "patch_..." -- confirm before renaming.
    acts[:, :, ]

In [50]:
filter_sender_name = lambda x: x.endswith("z")

In [51]:
receiver_input_names = [get_act_name("v", layer_idx) for layer_idx in [7, 8]]

In [52]:
receiver_input_names

['blocks.7.attn.hook_v', 'blocks.8.attn.hook_v']

In [53]:
filter_receiver_name = lambda x: x in receiver_input_names

In [None]:
# Two-stage patching loop, one iteration per candidate sender head.
for layer_idx, head_idx in sender_heads:
    model.reset_hooks()
    
    # Stage 1: run on the clean tokens with the sender-head patch installed
    # on every hook accepted by filter_sender_name, and cache the activations.
    hook_func = partial(
        patch_sender_head_output,
        clean_cache=clean_cache,
        corrupted_cache=corrupted_cache,
        target_head=(layer_idx, head_idx)
    )
    
    model.add_hook(filter_sender_name, hook_func)
    _, patched_cache = model.run_with_cache(clean_tokens)
    
    # Stage 2: rerun on the clean tokens with a hook on the receiver inputs.
    model.reset_hooks()
    # NOTE(review): the assignment below has no right-hand side, so this cell
    # does not parse. Presumably the receiver-input hook built from
    # `patched_cache` belongs here -- TODO complete.
    hook_func = 
    patched_logits = model.run_with_hooks(
        clean_tokens,
        fwd_hooks=[(filter_receiver_name, hook_func)]
    )

In [None]:
def wait_stream(source_stream, target_stream):
    """Make ``source_stream`` wait for the work queued on ``target_stream``.

    Only acts when ``target_stream`` is a ``torch.cuda.Stream``:

    * if ``source_stream`` is one too, it waits on the target stream;
    * otherwise the target stream is synchronized, which blocks the caller.

    For any other ``target_stream`` the function does nothing.
    """
    if isinstance(target_stream, torch.cuda.Stream):
        if isinstance(source_stream, torch.cuda.Stream):
            source_stream.wait_stream(target_stream)
        else:
            # The Stream method is ``synchronize``; ``syncronous`` does not exist.
            target_stream.synchronize()

In [55]:
class Wait(torch.autograd.Function):
    """Autograd node that orders two streams around a tensor hand-off.

    Forward calls ``wait_stream`` with ``next_stream`` as source and
    ``prev_stream`` as target; backward calls it with the roles swapped. The
    tensor and its gradient pass through unchanged.
    """

    @staticmethod
    def forward(ctx, prev_stream, next_stream, output):
        # Kept on ctx so backward can apply the reverse ordering.
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream

        wait_stream(
            source_stream=next_stream,  # keyword was misspelled "souce_stream"
            target_stream=prev_stream
        )

        return output

    @staticmethod
    def backward(ctx, grad_input):
        prev_stream = ctx.prev_stream
        next_stream = ctx.next_stream

        wait_stream(
            source_stream=prev_stream,
            target_stream=next_stream
        )

        # One entry per forward argument: the two streams get no gradient.
        return None, None, grad_input

In [None]:
min(clock_idx+1, n_partitions)

In [None]:
stream1 = torch.cuda.Stream()
stream2 = torch.cuda.Stream()

In [None]:
# Issue the two reductions under different CUDA stream contexts.
# NOTE(review): nothing here makes either stream wait for the producers of
# `x` / `y` or for the consumers of the results -- confirm that is handled.
with torch.cuda.stream(stream1):
    x_mean = x.mean(dim=-1)
    
with torch.cuda.stream(stream2):
    y_mean = y.mean(dim=-1)

In [None]:
# Point-to-point example: rank 0 sends `x`, rank 1 receives it.
rank = dist.get_rank()

if rank == 0:
    dist.send(x, dst=1)
elif rank == 1:
    # The received values are written into this existing tensor.
    dist.recv(tensor_will_be_received_data, src=0)

In [56]:
class CachedDataset:
    """Packs selected entries of a saved file into one flat buffer.

    Nothing is read at construction time; ``prefetch`` loads ``filename`` with
    ``torch.load`` and fills ``self.cache``.
    """

    def __init__(self, filename):
        self.filename = filename
        self.data = None
        # Flat 1-D buffer built by ``prefetch``; stays ``None`` until then.
        self.cache = None

    def prefetch(self, idxs):
        """Load the file and copy ``self.data[i]`` for each ``i`` in ``idxs`` into ``self.cache``.

        Entries are flattened and stored back to back in the order given.
        """
        self.data = torch.load(self.filename)
        # Materialise once: ``idxs`` is walked twice (size pass, then copy pass).
        idxs = list(idxs)

        # Count elements; the original summed the tensors themselves.
        total_elements = sum(self.data[i].numel() for i in idxs)
        # NOTE(review): ``self.data.dtype`` assumes the loaded object is a single
        # tensor (not e.g. a list of tensors) -- confirm against the saved format.
        self.cache = torch.zeros(
            total_elements,
            dtype=self.data.dtype
        )

        offset = 0
        for i in idxs:
            item = self.data[i]
            num_elements = item.numel()
            # ``view`` has no ``dim`` keyword; ``reshape(-1)`` flattens and also
            # copes with non-contiguous entries.
            self.cache[offset:offset + num_elements] = item.reshape(-1)
            offset += num_elements

In [57]:
world_size = 16

In [58]:
tensor_model_parallel_size = 2

In [59]:
pipeline_model_parallel_size = 4

In [60]:
data_parallel_groups = []

In [61]:
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size

In [66]:
ranks = []

# Walk consecutive blocks of `num_pipeline_model_parallel_groups` ranks and,
# inside each block, print the ranks taken every `tensor_model_parallel_size`
# steps starting at offset j.
for i in range(pipeline_model_parallel_size):
    start_idx = num_pipeline_model_parallel_groups * i
    end_idx = start_idx + num_pipeline_model_parallel_groups

    for j in range(tensor_model_parallel_size):
        # Slicing the block's range with a stride picks the same ranks as
        # starting at start_idx + j and stepping by the stride.
        ranks = list(range(start_idx, end_idx)[j::tensor_model_parallel_size])

        print(ranks)

[0, 2]
[1, 3]
[4, 6]
[5, 7]
[8, 10]
[9, 11]
[12, 14]
[13, 15]


In [None]:
tokens = model.to_tokens(text)

In [None]:
embed = model.embed
mlp = model.blocks[0].mlp
ln2 = model.blocks[0].ln2

In [None]:
text_embeddings = embed(tokens)
resid_after_mlp0 = text_embeddings + mlp(ln2(text_embeddings))

In [67]:
from transformer_lens.utils import get_act_name

In [68]:
def get_k(input_tokens, layer_idx, head_idx):
    """Run the global ``model`` on ``input_tokens`` and return one head's cached "k" activations.

    The head is selected on the third dimension of the cached tensor for the
    hook named by ``get_act_name("k", layer_idx)``.
    """
    k_hook = get_act_name("k", layer_idx)
    cache = model.run_with_cache(input_tokens)[1]
    all_heads = cache[k_hook]
    return all_heads[:, :, head_idx]

In [None]:
k_open = get_k(all_open_tokens, layer_idx=0, head_idx=0)
k_close = get_k(all_close_tokens, layer_idx=0, head_idx=0)

In [None]:
k_avg = (k_open+k_close) / 2

In [69]:
def patch_k(acts, hook, new_k, head_idx):
    """Hook that overwrites one head's slice of ``acts`` with ``new_k``.

    Args:
        acts: Activations at the hook; the head index is the third dimension
            (the layout ``get_k`` slices with ``[:, :, head_idx]``).
        hook: Hook object; unused, present to match the hook call signature.
        new_k: Replacement values, broadcastable to ``acts[:, :, head_idx]``.
        head_idx: Index of the head to overwrite.

    Returns:
        ``acts``, modified in place.
    """
    acts[:, :, head_idx] = new_k
    return acts

In [None]:
# The layer to patch is selected through the hook name.
hook_name = get_act_name("k", layer_idx)
# patch_k takes (acts, hook, new_k, head_idx): binding a `layer_idx` keyword
# would raise TypeError as soon as the hook is called, so only the arguments
# it accepts are bound here.
hook_func = partial(
    patch_k,
    new_k=k_avg,
    head_idx=0
)

In [None]:
recompute, forward, backward

In [71]:
# Walk consecutive blocks of `num_pipeline_model_parallel_groups` ranks and,
# inside each block, print the ranks taken every `tensor_model_parallel_size`
# steps starting at offset j.
for i in range(pipeline_model_parallel_size):
    start_idx = num_pipeline_model_parallel_groups * i
    end_idx = start_idx + num_pipeline_model_parallel_groups

    for j in range(tensor_model_parallel_size):
        # Slicing the block's range with a stride picks the same ranks as
        # starting at start_idx + j and stepping by the stride.
        ranks = list(range(start_idx, end_idx)[j::tensor_model_parallel_size])

        print(ranks)

[0, 2]
[1, 3]
[4, 6]
[5, 7]
[8, 10]
[9, 11]
[12, 14]
[13, 15]


In [72]:
from typing import Callable

In [73]:
def foo(func: Callable[[int, int], str]) -> str:
    """Placeholder showing a ``Callable`` annotation; has no body and returns ``None``."""

In [75]:
import torch.distributed as dist
import torch.distributed.rpc as rpc

In [None]:
ionic, covalent

In [None]:
two uncertainty principles
quantization of angular momentum
quantization of action

In [None]:
gravity, 

In [None]:
- two uncertainty principles
- quantization of angular momentum
- quantization of action

In [None]:
strong, weak nuclear force, gravity, electromagnetic force

In [None]:
biocompatible, reliable recordings, neural plasticity

In [None]:
recording, send, memory, process

In [None]:
state, reward, done, 