### Engineering

In [None]:
local rank, local world size, processgroup, ranks in group

step 1: determine a list of global rank in that group
step 2: if the process's global rank is in that list, init a distributed group
step 3: 

In [None]:
class GreedySharding:
    """Shard every parameter of a module across the ranks of a parallel context.

    Each parameter is flattened, split into ``world_size`` chunks, and replaced
    in place by the chunk owned by the current rank. Chunks are zero-padded so
    every rank ends up holding a shard with the same number of elements.

    Args:
        module: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        parallel_context: Object exposing ``get_world_size()`` and ``get_rank()``.
    """

    def __init__(self, module, parallel_context):
        self.module = module
        self.parallel_context = parallel_context

    def shard(self):
        """Shard all parameters of the wrapped module in place."""
        self._shard_parameters()

    def _shard_parameters(self):
        """Replace each parameter's data with the shard owned by this rank.

        Raises:
            AssertionError: If a parameter has already been sharded.
        """
        world_size = self.parallel_context.get_world_size()

        for p in self.module.parameters():
            assert not hasattr(p, "_is_sharded"), "parameter is already sharded"

            # With a single worker there is nothing to split: leave the
            # parameter (and its shape) untouched.
            if world_size > 1:
                p.data = self._get_shard(p.data)
                # Mark the parameter so a second pass trips the assertion above.
                p._is_sharded = True

    def _get_shard(self, data):
        """Return the current rank's flat, zero-padded shard of ``data``.

        Args:
            data: Tensor to split.

        Returns:
            A 1-D tensor that owns its storage and has as many elements as
            the first chunk.
        """
        world_size = self.parallel_context.get_world_size()
        rank = self.parallel_context.get_rank()

        chunks = list(data.flatten().chunk(world_size))
        # ``chunk`` can return fewer pieces than requested; top up with empty
        # chunks that share the dtype and device of ``data``.
        while len(chunks) < world_size:
            chunks.append(data.new_empty(0))

        # Clone so the shard is not a view that keeps the full tensor alive.
        shard = chunks[rank].clone()
        num_to_pad = chunks[0].numel() - shard.numel()
        if num_to_pad > 0:
            # Right-pad with zeros up to the size of the first chunk.
            shard = F.pad(shard, [0, num_to_pad])

        return shard

In [None]:
class Pipeline:
    """Drive micro-batches through model partitions according to a schedule.

    Args:
        batches: Indexable collection of micro-batches.
        partitions: Indexable collection of callables; ``partitions[j]`` is
            applied to a micro-batch on worker ``j``.
        devices: Passed to ``spawn_worker`` to create the worker queues.
        scheduler: Iterable of schedules, each an iterable of
            ``(microbatch_idx, partition_idx)`` pairs. Defaults to a fresh
            ``DetermisticScheduler()`` per pipeline.
    """

    def __init__(
        self,
        batches,
        partitions,
        devices,
        scheduler=None
    ):
        self.batches = batches
        self.partitions = partitions
        self.devices = devices
        # Build the default lazily: a default created in the signature is
        # instantiated once at definition time and shared by every Pipeline.
        self.scheduler = DetermisticScheduler() if scheduler is None else scheduler

    def fit(self):
        """Run every schedule produced by the scheduler on the spawned workers."""
        with spawn_worker(self.devices) as (in_queues, out_queues):
            for schedule in self.scheduler:
                self._compute(schedule, in_queues, out_queues)

    def _compute(self, schedule, in_queues, out_queues):
        """Submit one task per schedule entry, then collect the outputs.

        Args:
            schedule: Iterable of ``(microbatch_idx, partition_idx)`` pairs.
            in_queues: Per-partition queues that receive ``Task`` objects.
            out_queues: Per-partition queues from which outputs are read.
        """
        batches = self.batches
        partitions = self.partitions
        # The schedule is walked twice (submit, then collect); materialise it
        # so a one-shot iterator is not exhausted by the first pass.
        schedule = list(schedule)

        def task_func(microbatch, partition):
            # Bind the arguments now so each task keeps its own micro-batch
            # and partition instead of the loop's last values.
            def wrapper():
                return partition(microbatch)
            return wrapper

        for microbatch_idx, partition_idx in schedule:
            compute = task_func(batches[microbatch_idx], partitions[partition_idx])
            # presumably the worker invokes ``task.compute()`` with no
            # arguments — verify against Task / spawn_worker.
            in_queues[partition_idx].put(Task(compute=compute))

        for microbatch_idx, partition_idx in schedule:
            output = out_queues[partition_idx].get()
            # NOTE(review): this requires each batch object to expose a
            # one-argument ``put``; if batches are plain tensors this should
            # probably be ``batches[microbatch_idx] = output`` — confirm.
            batches[microbatch_idx].put(output)

In [None]:
step 1: determine a list of global ranks in that group
step 2: check whether the process's global rank is in that list
step 3: if yes, init a distributed group
step 4: determine the process's local rank
step 5: save

In [2]:
from torch.utils.data import Sampler

In [None]:
class EvenSampler(Sampler):
    def __init__(self, data):
        super().__init__()
        self.data = data
    
    def __iter__(self):
        return [x for x in range(0, len(self.data), 2)]

In [None]:
// Element-wise vector addition kernel: c[i] = a[i] + b[i] for i < total_elements.
// Only the x components of the built-in indices are used, so the index is
// one-dimensional; each thread handles at most one element.
__global__ void addTwoVectors(int* a, int* b, int* c, int total_elements) {
    // Global thread index = block offset + index of the thread inside its block.
    int gid = (blockIdx.x * blockDim.x) + threadIdx.x;

    // Threads whose index falls past the end of the vectors do nothing.
    if (gid < total_elements) {
        c[gid] = a[gid] + b[gid];
    }
}

In [None]:
step 1: normalize the loss, loss / n_steps
step 2: sum the grad
step 3: if epoch = n_steps, then update, otherwise, repeat step 1

In [None]:
step 1: gather the weights
step 2: do the backward pass
step 3: release the non-relevant weights
step 4: reduce-scatter

In [None]:
# Look up the "embed" and "pos_embed" entries of the activation cache.
# NOTE(review): in the visible cells `cache` is first assigned further down
# (model.run_with_cache), so this cell fails on a fresh Restart & Run All.
embed = cache["embed"]
pos_embed = cache["pos_embed"]

In [3]:
# Accumulator list; the following cells append cached activations to it.
components = []

In [None]:
# NOTE(review): this appends a two-element list as ONE entry, whereas the loop
# cell below appends entries one at a time — confirm the nesting is intended
# (otherwise `components.extend([...])`).
components.append([embed, pos_embed])

In [4]:
from transformer_lens.utils import get_act_name

In [None]:
# For layer indices 0-2, append the cache entries named by
# get_act_name("attn_out", layer) and get_act_name("mlp_out", layer), in that order.
# NOTE(review): the layer count 3 is hard-coded — presumably the model depth; confirm.
# `layer_idx` stays bound (to 2) after the loop and is read by later cells.
for layer_idx in range(3):
    mlp_name = get_act_name("mlp_out", layer_idx)
    attn_name = get_act_name("attn_out", layer_idx)
    components.append(cache[attn_name])
    components.append(cache[mlp_name])

In [None]:
# Select the W_O entry at (layer_idx, head_idx).
# NOTE(review): `model` and `head_idx` are not assigned in any visible cell, and
# `layer_idx` is whatever the preceding loop left behind.
W_O = model.W_O[layer_idx, head_idx]

In [6]:
import plotly.graph_objects as go

In [7]:
# Start from an empty figure; the next cell adds a trace to it.
fig = go.Figure()

In [None]:
# `go.Line` is a deprecated line-style object rather than a trace type; a line
# chart is a Scatter trace drawn in "lines" mode. No x/y data is supplied yet,
# so the trace is empty.
fig.add_trace(go.Scatter(mode="lines"))

In [None]:
# Convert `text` into model tokens.
# NOTE(review): neither `model` nor `text` is assigned in any visible cell.
tokens = model.to_tokens(text)

In [None]:
# Run the model on `tokens`, keeping only the activation cache (the first
# returned value is discarded).
_, cache = model.run_with_cache(tokens)

In [None]:
# Read the "embed" and "pos_embed" entries from the cache.
# NOTE(review): an identical cell appears earlier in the notebook; keep only one.
embed = cache["embed"]
pos_embed = cache["pos_embed"]

In [None]:
# Read the "result" entry for the layer just before `layer_idx`.
# NOTE(review): `layer_idx` is leftover loop state; with layer_idx == 0 the
# index becomes -1 — confirm how the cache treats a negative layer.
heads = cache["result", layer_idx-1]

In [None]:
# Concatenate the three activations along dimension 0.
# assumes embed, pos_embed and heads agree on every other dimension — TODO confirm
input_components = torch.cat((embed, pos_embed, heads), dim=0)

In [8]:
from einops import einsum

In [None]:
# Project the input components through this layer's W_Q.
# NOTE(review): einops.einsum expects the pattern string as its last positional
# argument; none is given here, so the call is unfinished — add the pattern once
# the axes of input_components and W_Q are pinned down.
W_Q = model.W_Q[layer_idx]
query_components = einsum(
    input_components,
    W_Q
)

In [None]:
# Project the input components through this layer's W_K.
# NOTE(review): the einops pattern string (last positional argument) is missing,
# so the call is unfinished.
W_K = model.W_K[layer_idx]
key_components = einsum(
    input_components,
    W_K
)

In [None]:
# Combine the query and key components into per-component scores.
# NOTE(review): the pattern string is empty — a placeholder that still has to be
# filled in with the intended input and output axes.
decomposed_scores = einsum(
    query_components,
    key_components,
    ""
)

In [None]:
# Re-run the model on `tokens` and keep only the activation cache.
# NOTE(review): repeats an earlier cell; `model` is not assigned in any visible cell.
_, cache = model.run_with_cache(tokens)

In [None]:
# Name of the "pattern" activation for layer `layer_idx`, used as a cache key in the next cell.
hook_name = get_act_name("pattern", layer_idx)

In [None]:
# Take index `head_idx` along dimension 1 of the cached pattern.
# assumes dimension 1 is the head axis — TODO confirm; `head_idx` is not assigned in any visible cell
attn_pattern = cache[hook_name][:, head_idx]

In [None]:
# Average the selected entries over dimension 0 and display the result.
# NOTE(review): `target_query_positions` is not assigned in any visible cell, and it
# indexes the FIRST dimension of attn_pattern — confirm that is the query axis.
attn_pattern[target_query_positions].mean(dim=0) 

In [None]:
step 1: prob(0) = sigmoid(logit0 - logit1)
step 2: logit0 = resid2 @ W_U[0], logit1 = resid2 @ W_U[1]
step 3:
logit0 - logit1 = resid2 @ W_U[0] - resid2 @ W_U[1]
logit0 - logit1 = resid2 @ (W_U[0] - W_U[1])
step 4:
resid2 = resid1 @ ln2 @ W_OV^{0, 2}
step 5:
ln2 @ W_OV^{0, 2} @ (W_U[0] - W_U[1])

In [None]:
# Shorthand handles for the model's W_E and W_U matrices.
W_E = model.W_E
W_U = model.W_U

In [None]:
# Chain W_E, the OV circuit and W_U with matrix products.
# NOTE(review): `OV_circuit` is not assigned in any visible cell.
full_OV_circuit = W_E @ OV_circuit @ W_U

In [None]:
# NOTE(review): unfinished expression — `model.` is a SyntaxError if this cell is
# run; complete it or delete the cell.
model.

In [None]:
deployment, configmap, service

In [None]:
step 1: determine the number of workers
step 2: flatten the params
step 3: chunks
step 4: add empty chunk if len(chunks) < world_size
step 5: get shard
step 6: pad

In [None]:
linear, layernorm, embedding, attention

In [9]:
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
class GreedySharding:
    """Shard every parameter of a module across the ranks of a parallel context.

    Each parameter is flattened, split into ``world_size`` chunks, and replaced
    in place by the chunk owned by the current rank. Chunks are zero-padded so
    every rank holds a shard with the same number of elements.

    Args:
        module: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        parallel_context: Object exposing ``get_world_size()`` and ``get_rank()``.
    """

    def __init__(self, module, parallel_context):
        self.module = module
        self.parallel_context = parallel_context

    def shard(self):
        """Shard all parameters of the wrapped module in place."""
        self._shard_params()

    def _shard_params(self):
        """Replace each parameter's data with the shard owned by this rank.

        Raises:
            AssertionError: If a parameter has already been sharded.
        """
        # Iterate the wrapped module, not whatever global `model` happens to exist.
        for param in self.module.parameters():
            # A parameter must not be sharded twice; the flag is set below.
            assert not hasattr(param, "_is_sharded"), "parameter is already sharded"

            # The shard is an independent copy, so rebinding ``param.data``
            # drops this parameter's reference to the full-size tensor and
            # lets it be reclaimed once nothing else refers to it.
            param.data = self._get_shard(param.data)
            param._is_sharded = True

    def _get_shard(self, data):
        """Return the current rank's flat, zero-padded shard of ``data``.

        Args:
            data: Tensor to split.

        Returns:
            A 1-D tensor that owns its storage and has as many elements as
            the first chunk.
        """
        world_size = self.parallel_context.get_world_size()
        rank = self.parallel_context.get_rank()

        # ``chunk`` returns a tuple; convert to a list so it can be extended.
        chunks = list(data.flatten().chunk(world_size))

        # ``chunk`` can return fewer pieces than requested; top up with empty
        # chunks that share the dtype and device of ``data``.
        while len(chunks) < world_size:
            chunks.append(data.new_empty(0))

        # Clone so the shard is not a view into the original storage.
        shard = chunks[rank].clone()
        num_to_pad = chunks[0].numel() - shard.numel()

        if num_to_pad > 0:
            # ``F.pad`` takes (left, right) amounts for the last dimension.
            shard = F.pad(shard, [0, num_to_pad])

        return shard

In [10]:
# Number of consecutive ranks per group in the loop cell below.
# NOTE(review): "paralell" is misspelled; renaming means updating the loop cell too.
tensor_model_paralell_size = 2

In [11]:
# Number of groups enumerated by the loop cell below.
num_tensor_model_parallel_groups = 8

In [12]:
# Print the contiguous range of ranks that makes up each group.
for i in range(num_tensor_model_parallel_groups):
    first_rank = i * tensor_model_paralell_size
    ranks = range(first_rank, first_rank + tensor_model_paralell_size)
    print(ranks)

range(0, 2)
range(2, 4)
range(4, 6)
range(6, 8)
range(8, 10)
range(10, 12)
range(12, 14)
range(14, 16)


In [13]:
from torch.utils.data import DataLoader, random_split

In [None]:
# Randomly split `dataset` into a 6-item train set and a 4-item test set.
# NOTE(review): `dataset` is not assigned in any visible cell; the integer lengths
# presumably have to add up to len(dataset) (10) — confirm. No generator/seed is
# passed, so the split changes from run to run.
train_set, test_set = random_split(dataset, lengths=[6, 4])

In [None]:
# Wrap both splits in loaders that yield batches of 2 items.
train_loader = DataLoader(train_set, batch_size=2)
test_loader = DataLoader(test_set, batch_size=2)

In [None]:
3, host, torchstate, elastic

In [None]:
step 1: partitioning
step 2: gather, do forward pass, release
step 3: gather, do backward pass, release
step 4: reduce-scatter
step 5: update

In [14]:
def probability_scores(image_embedding, text_embedding):
    """Softmax over cosine similarities between image and text embeddings.

    Both inputs are L2-normalised along their last dimension, every image
    embedding is compared with every text embedding, and the similarities of
    each image are turned into a probability distribution over the texts.

    Args:
        image_embedding: Tensor whose last dimension is the embedding dimension.
        text_embedding: 2-D tensor (it is transposed with ``.T``) whose last
            dimension matches that of ``image_embedding``.

    Returns:
        Tensor of probabilities; each slice along the last dimension sums to 1.
    """
    # `dim=-1` (keyword), not `dim-1`: the latter is a NameError at call time.
    image_norm = image_embedding.norm(dim=-1, keepdim=True)
    image_embedding = image_embedding / image_norm

    text_norm = text_embedding.norm(dim=-1, keepdim=True)
    text_embedding = text_embedding / text_norm

    # With unit-length rows the dot product is the cosine similarity.
    similarities = image_embedding @ text_embedding.T
    probs = F.softmax(
        similarities,
        dim=-1
    )
    return probs

In [None]:
requires_grad, dtype, shape

In [None]:
step 1: gather
step 2: register pre backward hook
step 3: do forward
step 4: register post backward hook
step 5: 

In [None]:
optic nerve > thalamus > visual cortex

In [15]:
import torch

In [40]:
# One-element random tensor; no seed is set, so the value differs between runs.
p = torch.randn(1)

In [41]:
# Derive a new tensor from `p` by adding 2.
p2 = p + 2

In [42]:
# Expand `p2` to its own shape, so the shape is unchanged.
p_tmp = p2.expand_as(p2)

In [43]:
# Display the expanded tensor.
p_tmp

tensor([2.5315])

In [44]:
# Inspect the autograd node behind `p_tmp`. The saved cell shows no output,
# which is consistent with grad_fn being None (`p` was created without
# requires_grad=True).
p_tmp.grad_fn

In [31]:
# Linear layer with one input feature and one output feature.
# NOTE(review): the execution counts here (31, 33) are lower than those of the
# cells above (40-44), so the notebook was run out of order.
linear = nn.Linear(1, 1)

In [33]:
# Display the first parameter of the layer (shown below with requires_grad=True).
list(linear.parameters())[0]

Parameter containing:
tensor([[-0.0319]], requires_grad=True)