### Engineering

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [3]:
import torch.distributed as dist
import torch.distributed.rpc as rpc

In [5]:
class DataParallelGroup:
    """Attach gradient-averaging hooks to a module for data-parallel training.

    Args:
        module: Object exposing ``parameters()``; every parameter with
            ``requires_grad`` set receives a gradient hook.
        parallel_context: Object exposing ``data_parallel_size`` and
            ``get_group(mode)``.

    Note:
        ``ParallelMode`` is referenced by the hook but is not defined in the
        visible cells of this notebook; it must be in scope before a backward
        pass triggers the hook.
    """

    def __init__(self, module, parallel_context):
        self.module = module
        self.parallel_context = parallel_context

    def parallelize(self):
        """Register the averaging hook on trainable parameters and return the module.

        Hooks are only registered when the data-parallel size is greater
        than one; otherwise the module is returned untouched.
        """
        module = self.module

        # The size lives on the parallel context, not on this wrapper.
        if self.parallel_context.data_parallel_size > 1:
            for p in module.parameters():
                if p.requires_grad:
                    # Pass the bound method so the hook can reach the context.
                    p.register_hook(self._avg_grad_hook)

        return module

    def _avg_grad_hook(self, grad):
        """Return ``grad`` divided by the data-parallel size and summed over the group."""
        data_parallel_size = self.parallel_context.data_parallel_size
        process_group = self.parallel_context.get_group(ParallelMode.DATA)

        # Build a new tensor instead of dividing in place: the register_hook
        # documentation states that a hook should not modify its argument.
        grad = grad / data_parallel_size

        # Dividing first and then summing yields the mean over the group.
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=process_group)

        return grad

In [None]:
class Checkpoint(torch.autograd.Function):
    """Autograd function that runs ``function`` without recording a graph.

    The forward pass evaluates ``function(input)`` under ``torch.no_grad()``.
    The backward pass pops an ``(output, input_leaf)`` pair from
    ``recomputed``, back-propagates the incoming gradient through that
    ``output``, and hands ``input_leaf.grad`` back as the gradient of
    ``input``.
    """

    @staticmethod
    def forward(ctx, phony, recomputed, function, input):
        """Evaluate ``function(input)`` with gradient tracking disabled.

        Args:
            ctx: Autograd context; stores ``recomputed`` and ``function``.
            phony: Not read by this method.
                NOTE(review): presumably a dummy tensor that requires grad so
                autograd records this node — confirm against the caller.
            recomputed: Container whose ``pop()`` returns an
                ``(output, input_leaf)`` pair during backward.
            function: Callable applied to ``input``.
            input: Argument forwarded to ``function``.
        """
        ctx.recomputed = recomputed
        ctx.function = function

        with torch.no_grad():
            output = function(input)
        return output

    @staticmethod
    def backward(ctx, grad_input):
        """Back-propagate through the recomputed output.

        ``grad_input`` is the gradient with respect to this function's output
        (the parameter name is kept from the original code).

        Returns:
            One entry per ``forward`` argument after ``ctx``: ``None`` for
            ``phony``, ``recomputed`` and ``function``, and
            ``input_leaf.grad`` for ``input``.
        """
        output, input_leaf = ctx.recomputed.pop()

        # torch.autograd.backward raises on a tensor that does not require
        # grad, so skip it in that case; input_leaf.grad is then returned
        # as-is (None unless something else populated it).
        if output.requires_grad:
            with torch.enable_grad():
                torch.autograd.backward(output, grad_input)

        # autograd expects exactly one gradient per forward input; the
        # original bare `return` supplied a single None for four inputs.
        return None, None, None, input_leaf.grad

In [None]:
# step 1: register
# step 2: user
# step 3: pytorch trigger
# step 4: run

In [6]:
# Total number of ranks assumed by the rank-grouping cells that follow.
world_size = 16

In [7]:
# Tensor- and pipeline-parallel degrees consumed by the rank-grouping loop below.
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 4

In [8]:
# Integer quotient of the world size by the pipeline-parallel degree; the
# rank-grouping loop below uses it as the width of each contiguous rank block.
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size

In [9]:
# Starts empty. NOTE(review): no visible cell appends to this list — the
# rank-grouping loop that follows only prints each rank list.
data_parallel_groups = []

In [10]:
# Walk the rank space in consecutive blocks of
# num_pipeline_model_parallel_groups ranks (one block per pipeline index).
# Inside each block, print the ranks reached from every tensor-parallel
# offset when stepping by tensor_model_parallel_size.
for stage_idx in range(pipeline_model_parallel_size):
    block_start = stage_idx * num_pipeline_model_parallel_groups
    block_stop = block_start + num_pipeline_model_parallel_groups

    for tp_offset in range(tensor_model_parallel_size):
        ranks = list(range(block_start + tp_offset, block_stop, tensor_model_parallel_size))
        print(ranks)

[0, 2]
[1, 3]
[4, 6]
[5, 7]
[8, 10]
[9, 11]
[12, 14]
[13, 15]


In [None]:
# Pull weight tensors off `model`, which is not defined in the visible cells.
# NOTE(review): the two indices look like (layer, head), matching the
# `model.W_Q[layer_idx, head_idx]` usage later in this notebook — confirm.
W_E = model.W_E
W_Q = model.W_Q[1, 4]
W_K = model.W_K[1, 4]
W_O = model.W_O[0, 7]
W_V = model.W_V[0, 7]

In [None]:
# Matrix product, left to right, of W_E, W_V, W_O and W_K.
# NOTE(review): operand shapes are not visible here — confirm they chain in this order.
K = W_E @ W_V @ W_O @ W_K

In [None]:
# Matrix product of W_E with W_Q.
Q = W_E @ W_Q

In [None]:
# Convert `text` with the model's to_tokens method; neither `model` nor `text`
# is defined in the visible cells.
tokens = model.to_tokens(text)

In [None]:
# Run the model on `tokens`; the first returned value is discarded and the
# second is kept as `cache`.
_, cache = model.run_with_cache(
    tokens
)

In [None]:
# Rebind W_Q and W_K to the entries selected by layer_idx and head_idx.
# NOTE(review): neither index is defined in the visible cells.
W_Q = model.W_Q[layer_idx, head_idx]
W_K = model.W_K[layer_idx, head_idx]

In [None]:
# Combine three cache entries into one tensor along a new leading dimension.
# torch.tensor() cannot build a tensor from a list of multi-element tensors,
# so torch.stack is used instead.
# NOTE(review): torch.stack requires all three entries to share a shape — confirm.
components = torch.stack([
    cache["embed"],
    cache["pos_embed"],
    cache["result", layer_idx-1]
])

In [None]:
# The original `from einops` was an incomplete statement (syntax error).
# Import the package itself; callers can use `einops.<name>` for whatever
# function is needed.
import einops

In [None]:
decomposed_Q = 

In [11]:
import torch
from torch import nn
import torch.nn.functional as F

In [12]:
import torch.distributed as dist
import torch.distributed.rpc as rpc

In [None]:
// Holds a single integer and answers whether it is strictly positive.
class Number {
    public:
        int value;

        // Returns true when `value` is greater than zero, false otherwise.
        bool isLargerThanZero() const {
            return value > 0;
        }
};

In [None]:
# Select index 2 and index 1 along the last axis of `linear_probe`, which is
# not defined in the visible cells.
mine_dir = linear_probe[..., 2]
their_dir = linear_probe[..., 1]

In [None]:
# step 1: wrap
# step 2: user
# step 3: pytorch trigger
# step 4: create job
# step 5: execute

In [16]:
# Partition and microbatch counts used by the schedule-building cell below.
n_partitions = 5
n_microbatches = 3

In [17]:
# Number of clock cycles needed for every microbatch to pass every partition.
n_clock_cycles = n_partitions + n_microbatches - 1

# One task list per clock cycle. At a given clock index, the active partitions
# are those in [max(clock + 1 - n_microbatches, 0), min(clock + 1, n_partitions)),
# and each one handles microbatch `clock - partition`.
schedules = [
    [
        (clock_idx - partition_idx, partition_idx)
        for partition_idx in range(
            max(clock_idx + 1 - n_microbatches, 0),
            min(clock_idx + 1, n_partitions),
        )
    ]
    for clock_idx in range(n_clock_cycles)
]

In [18]:
schedules

[[(0, 0)],
 [(1, 0), (0, 1)],
 [(2, 0), (1, 1), (0, 2)],
 [(2, 1), (1, 2), (0, 3)],
 [(2, 2), (1, 3), (0, 4)],
 [(2, 3), (1, 4)],
 [(2, 4)]]

In [20]:
# Lookup table mapping each pair (first, second), for first in 0..4 and
# second in 0..1, to (first, second + 1). Keys are inserted grouped by
# `first`, then by `second`.
OUTPUT_METADATA = {
    (first, second): (first, second + 1)
    for first in range(5)
    for second in range(2)
}

In [21]:
OUTPUT_METADATA[(2, 0)]

(2, 1)

In [23]:
# Rebinds world_size (previously 16) for the pipeline-group example that follows.
world_size = 8

In [24]:
# Pipeline-parallel degree used to derive the number of groups below.
pipeline_parallel_size = 2

In [25]:
# Integer quotient of the world size by the pipeline-parallel degree.
num_pipeline_parallel_groups = world_size // pipeline_parallel_size

In [26]:
num_pipeline_parallel_groups

4

In [31]:
# One rank list per group: group g holds the ranks g, g + stride, g + 2*stride, ...
# below world_size, where the stride is num_pipeline_parallel_groups.
groups = [
    list(range(first_rank, world_size, num_pipeline_parallel_groups))
    for first_rank in range(num_pipeline_parallel_groups)
]

In [33]:
groups

[[0, 4], [1, 5], [2, 6], [3, 7]]

In [36]:
def find_index(number, lst):
    """Return the index of the first sublist of ``lst`` that contains ``number``.

    Args:
        number: Value to search for.
        lst: Iterable of containers supporting the ``in`` operator.

    Returns:
        The zero-based position of the first matching sublist, or ``None``
        when no sublist contains ``number``.
    """
    for position, members in enumerate(lst):
        if number in members:
            return position
    return None

In [38]:
find_index(5, groups)

1