### Engineering

In [3]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.distributed.rpc as rpc

In [6]:
from einops import rearrange

In [7]:
class Reduce(torch.autograd.Function):
    """All-reduce a tensor over the tensor-parallel group.

    Forward: the tensor is all-reduced in place (default reduction op) across
    the group returned by ``parallel_context.get_group(ParallelMode.TENSOR)``
    and then returned.
    Backward: the incoming gradient is handed back unchanged.

    NOTE(review): ``ParallelMode`` is not defined in the visible cells, and
    ``Reduce`` is defined again in later cells of this notebook.
    """

    @staticmethod
    def forward(ctx, input, parallel_context):
        tensor_parallel_group = parallel_context.get_group(ParallelMode.TENSOR)
        # `all_reduce` writes its result into `input` itself.
        dist.all_reduce(input, group=tensor_parallel_group)
        return input

    @staticmethod
    def backward(ctx, grad_input):
        # One slot per forward argument: the tensor's gradient passes through,
        # and `parallel_context` receives no gradient.
        upstream_grad = grad_input
        return upstream_grad, None

In [8]:
class _VocabParallelCrossEntropy(torch.autograd.Function):
    """Cross-entropy over logits whose vocabulary axis is split across ranks.

    Each rank holds one contiguous slice of the vocabulary. The per-token loss
    ``log(sum_j exp(logit_j)) - logit_target`` is assembled with three
    all-reduces over the tensor-parallel group: the row maximum (for numerical
    stability), the target logit (non-zero only on the rank that owns the
    target id), and the sum of exponentials.

    Call as ``_VocabParallelCrossEntropy.apply(parallel_logits, targets,
    parallel_context)``.
    """

    @staticmethod
    def forward(ctx, parallel_logits, targets, parallel_context):
        """Compute the per-token loss.

        Args:
            parallel_logits: local logits, ``(batch_size, seq_len, vocab_partition)``.
            targets: global target token ids, ``(batch_size, seq_len)``.
            parallel_context: provides ``get_group`` / ``get_local_rank`` for
                ``ParallelMode.TENSOR``.

        Returns:
            Loss tensor of shape ``(batch_size * seq_len,)``.
        """
        group = parallel_context.get_group(ParallelMode.TENSOR)
        rank = parallel_context.get_local_rank(ParallelMode.TENSOR)

        # Vocabulary ids owned by this rank: [vocab_start_idx, vocab_end_idx).
        partition_size = parallel_logits.shape[-1]
        vocab_start_idx = rank * partition_size
        vocab_end_idx = vocab_start_idx + partition_size

        # Targets owned by another rank are redirected to local index 0 so the
        # gather below stays in bounds; their contribution is zeroed afterwards.
        target_mask = (targets < vocab_start_idx) | (targets >= vocab_end_idx)
        masked_targets = targets.clone() - vocab_start_idx
        masked_targets[target_mask] = 0

        masked_targets_1d = rearrange(
            masked_targets,
            "batch_size seq_len -> (batch_size seq_len)"
        )
        target_mask_1d = rearrange(
            target_mask,
            "batch_size seq_len -> (batch_size seq_len)"
        )
        logits_2d = rearrange(
            parallel_logits,
            "batch_size seq_len vocab_size -> (batch_size seq_len) vocab_size"
        )

        # Subtract the global row maximum so exp() cannot overflow. The shift
        # cancels in the final loss.
        logits_max = logits_2d.max(dim=-1).values
        dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=group)
        logits_2d = logits_2d - logits_max.unsqueeze(-1)

        row_idx = torch.arange(logits_2d.shape[0], device=logits_2d.device)
        predicted_logits = logits_2d[row_idx, masked_targets_1d]
        predicted_logits = predicted_logits.masked_fill(target_mask_1d, 0.0)
        dist.all_reduce(predicted_logits, group=group)

        exp_logits = torch.exp(logits_2d)
        sum_exp_logits = exp_logits.sum(dim=-1)
        dist.all_reduce(sum_exp_logits, group=group)

        loss = sum_exp_logits.log() - predicted_logits

        # Keep what backward needs: the local softmax slice and target bookkeeping.
        softmax = exp_logits / sum_exp_logits.unsqueeze(-1)
        ctx.save_for_backward(softmax, target_mask_1d, masked_targets_1d)
        ctx.logits_shape = parallel_logits.shape
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """Gradient w.r.t. the local logits: ``softmax - one_hot(target)``."""
        softmax, target_mask_1d, masked_targets_1d = ctx.saved_tensors

        grad_logits = softmax.clone()
        row_idx = torch.arange(grad_logits.shape[0], device=grad_logits.device)
        # Only rows whose target lives on this rank get the -1 at the target column.
        owns_target = (~target_mask_1d).to(grad_logits.dtype)
        grad_logits[row_idx, masked_targets_1d] -= owns_target
        grad_logits = grad_logits * grad_output.unsqueeze(-1)

        # `targets` and `parallel_context` take no gradient.
        return grad_logits.reshape(ctx.logits_shape), None, None

In [5]:
class VocabParallelCrossEntropy(nn.Module):
    """Module wrapper around ``_VocabParallelCrossEntropy``.

    Args:
        parallel_context: object forwarded to ``_VocabParallelCrossEntropy``
            so it can locate the tensor-parallel group and local rank.
    """

    def __init__(self, parallel_context):
        super().__init__()
        self.parallel_context = parallel_context

    def forward(self, logits, targets):
        """Return the loss computed from local ``logits`` and global ``targets``."""
        # The stored parallel context must be passed along, and the result must
        # be returned (the previous version discarded it).
        loss = _VocabParallelCrossEntropy.apply(
            logits, targets, self.parallel_context
        )
        return loss

In [10]:
class Reduce(torch.autograd.Function):
    """In-place all-reduce over the tensor-parallel group with a pass-through gradient.

    NOTE(review): this cell repeats an earlier ``Reduce`` definition verbatim,
    and ``ParallelMode`` is not defined in the visible cells.
    """

    @staticmethod
    def forward(ctx, input, parallel_context):
        # The reduction happens in place, so the same tensor object is returned.
        dist.all_reduce(
            input,
            group=parallel_context.get_group(ParallelMode.TENSOR),
        )
        return input

    @staticmethod
    def backward(ctx, grad_input):
        # Gradient for `input` is forwarded untouched; `parallel_context` has none.
        no_grad_for_context = None
        return grad_input, no_grad_for_context

In [9]:
class ParallelEmbedding(nn.Module):
    """Embedding table whose rows (vocabulary ids) are split across ranks.

    Each rank stores ``num_embeddings // world_size`` rows and looks up only
    the token ids that fall inside its own range. Lookups for ids owned by
    another rank are zeroed, so all-reducing the partial results yields the
    full embedding on every rank.

    Args:
        num_embeddings: size of the full vocabulary.
        embedding_dim: width of each embedding vector.
        parallel_context: provides world size, local rank and (through
            ``Reduce``) the process group for ``ParallelMode.TENSOR``.

    NOTE(review): ``num_embeddings`` is floor-divided by the world size, so any
    remainder rows are not allocated on any rank.
    """

    def __init__(self, num_embeddings, embedding_dim, parallel_context):
        super().__init__()

        # Kept on the module because forward() needs it for the all-reduce.
        self.parallel_context = parallel_context

        world_size = parallel_context.get_world_size(ParallelMode.TENSOR)
        per_partition = num_embeddings // world_size

        self.weight = nn.Parameter(torch.randn(
            per_partition,
            embedding_dim
        ))
        self.vocab_start_idx, self.vocab_end_idx = self._get_vocab_range(
            per_partition,
            parallel_context
        )

    def _get_vocab_range(self, per_partition, parallel_context):
        """Return the half-open ``[start, end)`` id range owned by this rank."""
        rank = parallel_context.get_local_rank(ParallelMode.TENSOR)
        start_idx = rank*per_partition
        end_idx = start_idx+per_partition
        return start_idx, end_idx

    def forward(self, input):
        """Embed global token ids ``input`` and return the all-reduced result."""
        # True where the id belongs to another rank's slice of the vocabulary.
        input_mask = (input < self.vocab_start_idx) | (input >= self.vocab_end_idx)
        masked_input = input.clone() - self.vocab_start_idx
        # Point foreign ids at row 0 so the lookup stays in bounds.
        masked_input[input_mask] = 0

        parallel_embeddings = F.embedding(masked_input, self.weight)
        # Zero the positions this rank does not own (select by mask, not by id).
        parallel_embeddings[input_mask] = 0.0

        embeddings = Reduce.apply(parallel_embeddings, self.parallel_context)

        return embeddings

In [14]:
class Scatter(torch.autograd.Function):
    """Keep this rank's slice of the last dimension; gather slices in backward.

    Forward splits ``input`` along its last dimension into chunks of size
    ``input.shape[-1] // world_size`` and returns the chunk at index ``rank``.
    Backward all-gathers the per-rank gradient slices and concatenates them
    along the last dimension.
    """

    @staticmethod
    def forward(ctx, input):
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # The chunk size is derived from the last dimension, so the split must
        # run along that same dimension (torch.split defaults to dim 0).
        chunks = torch.split(
            input,
            split_size_or_sections=input.shape[-1]//world_size,
            dim=-1
        )
        # A last-dim slice is a strided view; return a standalone tensor.
        return chunks[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_input):
        world_size = dist.get_world_size()

        # Incoming gradients may be expanded/strided views; copy into a
        # contiguous buffer before handing them to the collective.
        grad_input = grad_input.contiguous()
        grads = [torch.empty_like(grad_input) for _ in range(world_size)]
        dist.all_gather(grads, grad_input)
        # Mirror the forward split: stitch the slices back along the last dim.
        return torch.cat(grads, dim=-1)

In [13]:
class Reduce(torch.autograd.Function):
    """In-place all-reduce with a pass-through gradient.

    ``Reduce.apply(tensor)`` reduces over the default process group.
    ``Reduce.apply(tensor, parallel_context)`` reduces over
    ``parallel_context.get_group(ParallelMode.TENSOR)`` instead, so this
    definition also serves the callers in this notebook that pass a parallel
    context.
    """

    @staticmethod
    def forward(ctx, input, parallel_context=None):
        # group=None selects the default process group.
        group = None
        if parallel_context is not None:
            group = parallel_context.get_group(ParallelMode.TENSOR)
        # `all_reduce` writes its result into `input` itself.
        dist.all_reduce(input, group=group)
        return input

    @staticmethod
    def backward(ctx, grad_input):
        # Gradient passes through unchanged; the trailing None covers the
        # optional `parallel_context` argument.
        return (grad_input, None)

In [11]:
class RowParallelLinear(nn.Module):
    """Linear layer whose input features are split across ranks.

    Each rank holds a ``(output_size, input_size // world_size)`` weight slice,
    multiplies it with its slice of the input, and the partial products are
    all-reduced. The bias is added once, after the reduction.

    Args:
        input_size: total number of input features across all ranks.
        output_size: number of output features.
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        world_size = dist.get_world_size()
        inp_per_partition = input_size // world_size

        self.weight = nn.Parameter(torch.randn(
            output_size,
            inp_per_partition
        ))
        self.bias = nn.Parameter(torch.randn(
            output_size
        ))

    def forward(self, input):
        """Return ``input @ W.T + bias`` computed across the ranks."""
        input_parallel = Scatter.apply(input)
        parallel_output = F.linear(
            input_parallel, self.weight
        )
        # Reduce the partial matmul results, not the scattered input.
        outputs = Reduce.apply(parallel_output)
        return outputs + self.bias

In [None]:
broadcast > gather > scatter > all-reduce

In [None]:
# Run the model on `tokens`; keep the second value returned by
# `run_with_cache` as `cache` and discard the first.
# NOTE(review): `model` and `tokens` are not defined in the visible cells.
_, cache = model.run_with_cache(tokens)

In [15]:
# Layer index used to select entries from `cache` and `model.W_out` below.
layer_idx = 2

In [None]:
# Fetch the cached "post" entry and the `W_out` slice for `layer_idx`, then
# matrix-multiply them.
# NOTE(review): presumably post-activation MLP values times the MLP output
# weights -- confirm against the model's cache keys.
W_in_acts = cache["post", layer_idx]
W_out = model.W_out[layer_idx]

output = W_in_acts @ W_out

In [None]:
# Scratch call: invokes a method on the object behind `agent_rref` through its
# blocking `rpc_sync()` proxy.
# NOTE(review): `agent_rref` is not defined in the visible cells and the method
# name is a placeholder -- replace or remove before sharing.
agent_rref.rpc_sync().your_mom(69)

In [None]:
minimize communication
maximize storage
minimize flops

In [None]:
// Returns the constant 0.
int zero() {
    return 0;
}

In [None]:
// Stub, presumably meant to map a number to a color code.
// NOTE(review): the body is empty, so this non-void function never returns a
// value; using its result is undefined behavior. TODO: implement the mapping.
int numberToColor(int x) {
    
}

In [None]:
#include <iostream>
using namespace std;

In [None]:
// Minimal book record holding only a title.
// A class header takes no parentheses, and the definition must end with ';'.
class Book {
    public:
        std::string title;

        // Writes "x" to standard output whenever a Book is constructed.
        Book() {
            std::cout << "x";
        }
};

In [16]:
# Sizes of the (microbatch, partition) grid walked by the schedule loop below.
n_microbatches = 4
n_partitions = 3

In [17]:
# Number of ticks the schedule loop runs: at tick t it emits every pair with
# microbatch_idx + partition_idx == t, and the largest such sum is
# (n_microbatches - 1) + (n_partitions - 1).
n_clock_cycles = n_microbatches+n_partitions-1

In [23]:
# Print, for every clock tick, the (microbatch_idx, partition_idx) pairs that
# are active at that tick.
for clock_idx in range(n_clock_cycles):
    # Partition j is active at this tick only if microbatch (clock_idx - j)
    # exists, which bounds j from below and above.
    start_partition_idx = max(clock_idx + 1 - n_microbatches, 0)
    end_partition_idx = min(clock_idx + 1, n_partitions)

    xs = [
        (clock_idx - partition_idx, partition_idx)
        for partition_idx in range(start_partition_idx, end_partition_idx)
    ]
    print(xs)

[(0, 0)]
[(1, 0), (0, 1)]
[(2, 0), (1, 1), (0, 2)]
[(3, 0), (2, 1), (1, 2)]
[(3, 1), (2, 2)]
[(3, 2)]


In [24]:
class _P2P:
    def send(self):
        pass

In [25]:
def send(data, src_rank, dst_rank, parallel_context, parallel_mode):
    """Send ``data`` to ``dst_rank`` if the calling process is ``src_rank``.

    Args:
        data: payload handed to ``_P2P().send``.
        src_rank: local rank (within ``parallel_mode``) that should send.
        dst_rank: rank that should receive.
        parallel_context: provides ``get_local_rank(parallel_mode)``.
        parallel_mode: selects which parallel group the ranks refer to.

    Processes whose local rank differs from ``src_rank`` do nothing.
    """
    rank = parallel_context.get_local_rank(parallel_mode)
    if rank == src_rank:
        _P2P().send(data, dst_rank, parallel_context, parallel_mode)

In [26]:
# Grid sizes for the schedule loop below.
# NOTE(review): same values as an earlier cell in this notebook.
n_microbatches = 4
n_partitions = 3

In [27]:
# One tick per distinct value of microbatch_idx + partition_idx.
n_clock_cycles = n_microbatches+n_partitions-1

In [28]:
# For each tick, pair every active partition with the microbatch it handles
# (microbatch_idx + partition_idx == clock_idx) and print the pairs.
for clock_idx in range(n_clock_cycles):
    start_partition_idx = max(0, clock_idx - n_microbatches + 1)
    end_partition_idx = min(n_partitions, clock_idx + 1)

    active_partitions = range(start_partition_idx, end_partition_idx)
    xs = list(zip(
        (clock_idx - partition_idx for partition_idx in active_partitions),
        active_partitions,
    ))
    print(xs)

[(0, 0)]
[(1, 0), (0, 1)]
[(2, 0), (1, 1), (0, 2)]
[(3, 0), (2, 1), (1, 2)]
[(3, 1), (2, 2)]
[(3, 2)]


In [29]:
import threading

In [30]:
# Thread-local storage: attributes set on `data` are separate for each thread.
data = threading.local()

In [None]:
# Create (but do not start) a thread that will run `print_and_modify`.
# NOTE(review): `print_and_modify` is not defined in the visible cells.
thread = threading.Thread(
    target=print_and_modify
)

In [None]:
class TensorParallelGroupInitializer(ProcessGroupInitializer):
    """Builds the tensor-parallel process groups and reports this rank's group.

    Ranks are partitioned into consecutive blocks of ``tensor_parallel_size``
    (``[0 .. tp-1]``, ``[tp .. 2*tp-1]``, ...), one process group per block.

    Reads ``self.world_size``, ``self.tensor_parallel_size`` and ``self.rank``.
    NOTE(review): these attributes are presumably set by
    ``ProcessGroupInitializer``, which is not defined in the visible cells.
    """

    def init_dist_group(self):
        """Create every tensor-parallel group and describe the one holding ``self.rank``.

        Returns:
            dict with keys ``"local_rank"``, ``"local_world_size"``,
            ``"process_group"`` and ``"ranks_in_group"``. The values stay
            ``None`` if ``self.rank`` belongs to no group.
        """
        num_tensor_parallel_groups = self.world_size // self.tensor_parallel_size
        process_group = None
        local_rank = None
        local_world_size = None
        ranks_in_group = None

        for i in range(num_tensor_parallel_groups):
            ranks = list(range(
                i*self.tensor_parallel_size,
                (i+1)*self.tensor_parallel_size
            ))
            # new_group() is collective: every process must call it for every
            # group, in the same order, even for groups it does not belong to.
            # So the loop must not exit early once this rank's group is found.
            group = dist.new_group(ranks)

            if self.rank in ranks:
                local_rank = ranks.index(self.rank)
                local_world_size = len(ranks)
                process_group = group
                ranks_in_group = ranks

        return {
            "local_rank": local_rank,
            "local_world_size": local_world_size,
            "process_group": process_group,
            "ranks_in_group": ranks_in_group,
        }

In [31]:
class _P2P:
    def _recv_metadata(self, src_rank, parallel_context, parallel_mode):
        group = parallel_context.get_group(parallel_mode)
        
        dtype = torch.tensor(0)
        dist.recv(dtype, src=src_rank, group=group)
        dtype = DTYPE_TO_ID[dtype]
        
        shape = torch.tensor(0)
        dist.recv(shape, src=src_rank, group=group)
        
        requires_grad = torch.tensor(0)
        dist.recv(requires_grad, src=src_rank, group=group)
        requires_grad = True if requires_grad == 1 else False
        
        return dtype, shape, requires_grad
    
    def recv(self, src_rank, parallel_context, parallel_mode):
        group = parallel_mode.get_group(parallel_mode)
        dtype, shape, requires_grad = self._recv_metadata(src_rank, parallel_context, parallel_mode)
        data = torch.zeros(shape, requires_grad=requires_grad, dtype=dtype)
        dist.recv(data, src=src_rank, group=group)
        return data

In [None]:
def recv(src_rank, dst_rank, parallel_context, parallel_mode):
    """Receive a tensor from ``src_rank`` if the calling process is ``dst_rank``.

    Args:
        src_rank: rank that sends.
        dst_rank: local rank (within ``parallel_mode``) that should receive.
        parallel_context: provides ``get_local_rank(parallel_mode)``.
        parallel_mode: selects which parallel group the ranks refer to.

    Returns:
        The value returned by ``_P2P().recv`` on the destination rank, and
        ``None`` on every other rank.
    """
    rank = parallel_context.get_local_rank(parallel_mode)
    # `dst_rank` is a single rank, so compare with == (`in` needs a container).
    if rank == dst_rank:
        return _P2P().recv(src_rank, parallel_context, parallel_mode)
    return None

In [None]:
step 1: script
step 2: run
step 3: compare
step 4: notify

In [None]:
job selector > worker threads > pool watcher

In [None]:
step 1: determine global rank
step 2: resize embedding
step 3: parallelize embedding, linear, attn, layer norm
step 4: resize lm_head

In [None]:
received data
jobs
handshake

In [None]:
boundary, entity, interactor

In [32]:
from torch.multiprocessing import Process

In [None]:
# Launch one worker process per rank.
# NOTE(review): `say_hello` is not defined in the visible cells.
processes = []

for rank in range(3):
    p = Process(target=say_hello, args=(rank,))
    p.start()
    processes.append(p)

# Wait for every worker to exit before the cell finishes.
for p in processes:
    p.join()

In [None]:
ready
begin forward
finished forward
finished backward
finished batch

In [35]:
class _P2P:
    def _send_metadata(self, dst_rank, parallel_context, parallel_mode):
        group = parallel_context.get_group(parallel_mode)
        dtype = torch.tensor(DTYPE_TO_ID[data.dtype])
        dist.send(dtype, dst=dst_rank, group=group)
        
        requires_grad = torch.tensor(
            1 if data.requires_grad == True else 0
        )
        dist.send(requires_grad, dst=dst_rank, group=group)
        
        shape = torch.tensor(data.shape.to_list())
        dist.send(requires_grad, dst=dst_rank, group=group)
    
    def send(self, data, dst_rank, parallel_context, parallel_mode):
        group = parallel_context.get_group(parallel_mode)
        self._send_metadata(data, dst_rank, parallel_context, parallel_mode)
        dist.send(data, dst=dst_rank, group=group)

In [36]:
def send(data, src_rank, dst_rank, parallel_context, parallel_mode):
    """Send ``data`` to ``dst_rank`` when the calling process is ``src_rank``.

    Args:
        data: payload handed to ``_P2P().send``.
        src_rank: local rank (within ``parallel_mode``) that should send.
        dst_rank: rank that should receive.
        parallel_context: provides ``get_local_rank(parallel_mode)``.
        parallel_mode: selects which parallel group the ranks refer to.

    Processes whose local rank differs from ``src_rank`` do nothing.
    """
    rank = parallel_context.get_local_rank(parallel_mode)

    if src_rank == rank:
        # `send` is an instance method, so it needs a `_P2P` instance.
        _P2P().send(data, dst_rank, parallel_context, parallel_mode)

In [None]:
step 1: initialize partitioned weight
step 2: mask targets
step 3: calculate local embedding
step 4: calculate global embedding

In [37]:
import threading

In [38]:
# Mutex held by `run()` for the duration of its `print_numbers()` call.
lock = threading.Lock()

In [39]:
def run():
    """Call ``print_numbers()`` while holding the module-level ``lock``.

    NOTE(review): ``print_numbers`` is not defined in the visible cells.
    """
    lock.acquire()
    try:
        print_numbers()
    finally:
        # Release even if print_numbers() raises, as `with lock:` would.
        lock.release()

In [None]:
# Run `run()` on a separate thread.
# NOTE(review): the thread is started but never joined in this cell.
thread = threading.Thread(target=run)
thread.start()

In [None]:
step 1: mark a point in the forward pass's computation graph
step 2: retrieve function, input
step 3: recompute activations, put them into a shared memory queue

In [None]:
pool watcher, worker threads, job selector

In [None]:
synchronization
received data queue
handshake

In [None]:
scatter > reduce > identity > gather

In [None]:
net brackets
no negative

In [None]:
# Project the final residual stream onto the difference between entries 0 and
# 1 of `W_U`.
# NOTE(review): presumably a logit-difference direction -- confirm the layout
# of `W_U`. `final_residal_stream` and `W_U` are not defined in the visible
# cells, and the first name looks like a misspelling of `final_residual_stream`.
final_residal_stream @ (W_U[0]-W_U[1])