### Engineering

In [24]:
import torch
from torch import nn
import torch.nn.functional as F

In [25]:
import torch.distributed as dist
import torch.distributed.rpc as rpc

In [4]:
def compute_cross_entropy(logits, targets):
    """Mean negative log-likelihood of ``targets`` under ``logits``.

    ``logits`` is log-softmaxed over its last dimension; entry ``i`` of
    ``targets`` then picks one log-probability out of row ``i``.
    """
    num_rows = targets.shape[0]
    row_ids = torch.arange(num_rows)
    picked_log_probs = F.log_softmax(logits, dim=-1)[row_ids, targets]
    return picked_log_probs.mean().neg()

In [8]:
import numpy as np
from gymnasium.spaces import Discrete, Box

In [None]:
class ShowerEnv:
    """Skeleton environment: one bounded observation, three discrete actions.

    Only ``reset`` does any work so far; ``step`` takes no action argument
    and is still a placeholder.
    """

    def __init__(self):
        # Three discrete actions and a single observation bounded to [0, 100].
        self.action_space = Discrete(3)
        lower_bound, upper_bound = np.array([0]), np.array([100])
        self.observation_space = Box(low=lower_bound, high=upper_bound)

    def reset(self):
        """Restore the starting temperature and shower length; return the temperature."""
        self.temperature, self.shower_length = 20, 60
        return self.temperature

    def step(self):
        """Placeholder: no transition logic is implemented yet."""
        return None

In [9]:
def discount_reward(rewards, discount_factor):
    """Scale ``rewards`` by ``discount_factor``.

    Every element is multiplied by the same factor; no per-step exponent
    or cumulative sum is applied.
    """
    scaled_rewards = rewards * discount_factor
    return scaled_rewards

In [10]:
rewards = torch.tensor([1, 2, 3, 4])

In [11]:
discount_reward(rewards, 0.99)

tensor([0.9900, 1.9800, 2.9700, 3.9600])

In [None]:
step 1: group dispatching
step 2: local capacity constraint

In [None]:
step 1: determine global rank
step 2: 
step 3: parallelize embedding, head, mlp, layer norm
step 4: 

In [None]:
job selector > spawn initial workers > job monitor

In [None]:
synchronization, handshake, job queue

In [None]:
dist.broadcast(x, src=0, async_op=True)

In [None]:
# Start the broadcast from rank 0 with async_op=True and keep what the call
# returns, so it can be waited on afterwards.
work_handler = dist.broadcast(x, src=0, async_op=True)

In [None]:
# Wait on the handle kept from the async broadcast call.
work_handler.wait()

In [None]:
ready, running, failed, succeed, cooldown, blacklisted 

In [15]:
class Broadcast(torch.autograd.Function):
    """Identity in the forward pass, all-reduce of the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, input):
        """Return ``input`` unchanged."""
        return input

    @staticmethod
    def backward(ctx, grad_input):
        """All-reduce the incoming gradient across processes and return it.

        ``dist.all_reduce`` works in place and, when called synchronously,
        returns ``None``. The reduced tensor itself has to be returned,
        otherwise autograd receives no gradient for ``input``.

        Args:
            grad_input: gradient of the loss with respect to the output of
                ``forward`` (autograd passes it positionally).

        Returns:
            A tensor holding the all-reduced gradient.
        """
        # Reduce a copy so the tensor handed in by autograd is not modified
        # in place.
        reduced_grad = grad_input.clone()
        dist.all_reduce(reduced_grad)
        return reduced_grad

In [14]:
class Gather(torch.autograd.Function):
    """All-gather along dim 0 in forward; keep this rank's gradient slice in backward."""

    @staticmethod
    def forward(ctx, input):
        """Collect ``input`` from every process and concatenate along dim 0.

        Returns:
            A tensor whose first dimension is ``world_size`` times that of
            ``input``.
        """
        world_size = dist.get_world_size()
        local = input.contiguous()
        # all_gather overwrites every receive buffer, so allocate them
        # uninitialised instead of drawing random numbers (which did useless
        # work and advanced the global RNG state).
        buffers = [torch.empty_like(local) for _ in range(world_size)]
        dist.all_gather(buffers, local)
        return torch.cat(buffers, dim=0)

    @staticmethod
    def backward(ctx, grad_input):
        """Return the dim-0 chunk of the gradient that belongs to this rank.

        Args:
            grad_input: gradient with respect to the gathered output
                (autograd passes it positionally).
        """
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        # Split the same way forward concatenated: world_size chunks on dim 0.
        chunks = torch.chunk(grad_input, chunks=world_size)
        return chunks[rank].contiguous()

In [12]:
class ColumnParallelLinear(nn.Module):
    """Linear layer whose output features are split across ``world_size`` processes.

    Each process owns ``output_size // world_size`` rows of the weight and
    the matching bias entries. ``forward`` computes the local slice of the
    output and gathers the slices of all processes along the feature
    (last) dimension.

    Args:
        input_size: number of input features.
        output_size: total number of output features over all processes.
        world_size: number of processes the output features are split over.

    Raises:
        ValueError: if ``world_size`` is not positive or does not divide
            ``output_size``.
    """

    def __init__(self, input_size, output_size, world_size):
        super().__init__()
        if world_size <= 0 or output_size % world_size != 0:
            raise ValueError(
                f"output_size ({output_size}) must be divisible by "
                f"a positive world_size ({world_size})"
            )

        # Keep the sizes on the module: the parameter shapes below read them.
        self.input_size = input_size
        self.output_size = output_size
        self.world_size = world_size
        self.per_partition = output_size // world_size

        self.weight = nn.Parameter(torch.randn(
            self.per_partition,
            self.input_size
        ))
        # F.linear expects one bias value per output feature.
        self.bias = nn.Parameter(torch.randn(self.per_partition))

    def forward(self, input):
        """Apply the local weight slice and gather all slices on the last dim."""
        input_parallel = Broadcast.apply(input)
        output_parallel = F.linear(
            input_parallel,
            self.weight,
            self.bias
        )
        # Gather concatenates along dim 0, but the per-process slices are
        # slices of the feature dimension: move features to the front for
        # the gather and back to the end afterwards.
        features_first = output_parallel.movedim(-1, 0).contiguous()
        outputs = Gather.apply(features_first).movedim(0, -1)
        return outputs

In [16]:
ranks = [0, 1, 3, 6]

In [None]:
rank = dist.get_rank()
process_group = None

In [None]:
# dist.new_group is a collective call: every process of the job has to enter
# it, including processes that will not be members of the new group. It also
# takes the list of member ranks (`ranks`), not this process's own rank.
_new_group = dist.new_group(ranks=ranks)
# Only members keep a handle; every other process stays at None.
process_group = _new_group if rank in ranks else None

In [None]:
if process_group is not None:
    # dist.broadcast needs a tensor to send/receive into and the global rank
    # of the sender; a bare Python int is not accepted.
    # NOTE(review): a CPU tensor is used here; with the NCCL backend the
    # payload would have to live on the GPU — confirm the backend.
    payload = torch.tensor([0])
    # The sender must be a member of the group; use its first listed rank.
    dist.broadcast(payload, src=ranks[0], group=process_group)

In [None]:
message passing, shared memory, file system

In [17]:
world_size = 16

In [18]:
tensor_parallel_size = 2
pipeline_parallel_size = 4

In [20]:
num_pipeline_parallel_groups = world_size // pipeline_parallel_size

In [23]:
# Print rank groups: the ranks are walked in consecutive blocks of
# `num_pipeline_parallel_groups` ranks, one block per value of i.
# NOTE(review): this looks like the data-parallel group layout used in
# Megatron-style setups — confirm the intended meaning of these groups.
# Beware: the loop rebinds the notebook-level name `ranks`.
for i in range(pipeline_parallel_size):
    # Half-open rank interval [start_rank, end_rank) covered by block i.
    start_rank = i*num_pipeline_parallel_groups
    end_rank = (i+1)*num_pipeline_parallel_groups
    
    for j in range(tensor_parallel_size):
        # Start at offset j inside the block and step by tensor_parallel_size.
        ranks = list(range(
            start_rank+j,
            end_rank,
            tensor_parallel_size
        ))
        
        print(ranks)

[0, 2]
[1, 3]
[4, 6]
[5, 7]
[8, 10]
[9, 11]
[12, 14]
[13, 15]


In [None]:
elastic driver, torchstate, hostdiscovery,

In [None]:
inputs = inputs.view(-1)

In [None]:
probs = F.softmax(switch(inputs), dim=-1)

In [None]:
_, idxs = torch.max(probs, dim=-1)

step 1: tokens are split into G groups that dispatch to experts in parallel
step 2: calculate local capacity
step 3: 

In [None]:
q, k, v, output projection

In [29]:
from einops import rearrange, einsum

In [None]:
einsum(x, y, "batch dim, batch dim ->")

In [31]:
from typing import overload

In [32]:
@overload
def getitem(x: str) -> str:
    """Typing-only signature: a ``str`` argument yields a ``str``."""
    ...

In [33]:
from typing import List

In [34]:
@overload
def getitem(x: List[int]) -> int:
    """Typing-only signature: a ``List[int]`` argument yields an ``int``."""
    ...

In [None]:
isinstance()

In [35]:
import torch.multiprocessing as mp

In [None]:
def run_worker(rank, world_size):
    """Initialise RPC for this process under the name ``AGENT_NAME.format(rank)``."""
    agent_name = AGENT_NAME.format(rank)
    rpc.init_rpc(name=agent_name, rank=rank, world_size=world_size)

In [None]:
# Start one worker process per rank.
# `target=` has to be passed by keyword: the first positional parameter of
# Process is `group`, which only accepts None, so passing the function
# positionally fails before anything is started.
workers = []
for worker_rank in range(world_size):  # own name, so the notebook-level `rank` is not overwritten
    p = mp.Process(target=run_worker, args=(worker_rank, world_size))
    p.start()
    workers.append(p)

# Wait for every worker to finish so none is left running unattended.
for p in workers:
    p.join()

In [None]:
expert, router, loss

In [None]:
W_U = model.W_U
logit_diff_dir = W_U[:, 0] - W_U[:, 1]

In [None]:
_, cache = model.run_with_cache(tokens)

In [None]:
# Concatenate cached activations into a single tensor (torch.cat joins along
# dim 0 when no dim is given).
# TODO: the third key is still an empty string — fill in the intended
# activation name before running this cell.
input_components = torch.cat([
    cache["embed"],
    cache["pos_embed"],
    cache[""]
])

In [None]:
softmax

In [37]:
layer_idx, head_idx = 2, 0

In [None]:
W_OV = model.W_V[layer_idx, head_idx] @ model.W_O[layer_idx, head_idx]

In [38]:
from transformer_lens.utils import get_act_name

In [None]:
pre_final_ln_name = get_act_name("resid_post", 2)
head20_pre_ln_name = get_act_name("resid_pre", 2)
head_20_post_ln_name = get_act_name("normalized", 2, "ln1")

In [None]:
_, cache = model.run_with_cache(tokens)

In [None]:
# Activations right before the final layer norm.
pre_final_ln_acts = cache[pre_final_ln_name]
# Activations right after it. The original indexed the cache with
# `post_final_ln_acts` itself, which is not defined at this point
# (NameError). "ln_final.hook_normalized" is the TransformerLens hook name
# for the output of the final layer norm.
post_final_ln_acts = cache["ln_final.hook_normalized"]

In [40]:
head_idx, layer_idx = 6, 9

In [None]:
W_E = model.W_E
W_U = model.W_U

In [None]:
# Output and value projections of the selected head.
W_O = model.W_O[layer_idx, head_idx]
W_V = model.W_V[layer_idx, head_idx]

# The original expression stopped at `W_E @ W_` (NameError). Completed here as
# embed -> value -> output -> unembed.
# NOTE(review): presumably the intended full OV circuit — confirm. The result
# is vocab x vocab and may be too large to materialise for a big vocabulary.
full_OV_circuit = W_E @ W_V @ W_O @ W_U

In [None]:
from torch im