### Engineering

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
class f(torch.autograd.Function):
    """Identity in the forward pass, all-reduce of the gradient in the backward pass.

    ColumnLinearParallel applies this to its input before the local matmul.
    """

    @staticmethod
    def forward(ctx, tensor):
        # Hand the input straight through.
        return tensor

    @staticmethod
    def backward(ctx, grad):
        # Sum the gradient across all ranks of the default process group (in place).
        torch.distributed.all_reduce(grad)
        return grad

In [None]:
class g(torch.autograd.Function):
    """All-gather along the last dim in forward; keep this rank's slice in backward.

    Forward concatenates every rank's tensor along dim -1 in rank order, so the
    output's last dimension is ``world_size`` times the input's. Backward splits
    the incoming gradient into ``world_size`` equal chunks along dim -1 and
    returns the chunk belonging to the current rank.
    """

    @staticmethod
    def forward(ctx, input):
        world_size = torch.distributed.get_world_size()
        # Collectives expect contiguous buffers.
        input = input.contiguous()
        input_list = [torch.empty_like(input) for _ in range(world_size)]
        # all_gather lives in torch.distributed, not in the top-level torch namespace.
        torch.distributed.all_gather(input_list, input)
        return torch.cat(input_list, dim=-1)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        dim_size_per_partition = grad_output.shape[-1] // world_size
        grad_chunks = torch.split(
            grad_output,
            split_size_or_sections=dim_size_per_partition,
            dim=-1
        )
        # split() yields strided views; give autograd a compact tensor.
        return grad_chunks[rank].contiguous()

In [None]:
class ColumnLinearParallel(nn.Module):
    """Linear layer whose output features are sharded across ranks (column parallelism).

    Each rank holds ``output_size // world_size`` rows of the weight matrix.
    The forward pass runs in three steps:

    1. ``f``: identity forward, gradient all-reduce in backward.
    2. A local ``F.linear`` on this rank's shard.
    3. ``g``: all-gather along the last dimension.

    Every rank therefore returns the full output.

    Args:
        input_size: number of input features.
        output_size: total output features across all ranks; must be divisible
            by ``world_size``.
        world_size: number of shards. ``None`` queries
            ``torch.distributed.get_world_size()``. It should match the size of
            the group used by ``g``.

    Raises:
        ValueError: if ``output_size`` is not divisible by ``world_size``.
    """

    def __init__(self, input_size, output_size, world_size=None):
        super().__init__()
        if world_size is None:
            world_size = torch.distributed.get_world_size()
        if output_size % world_size != 0:
            # Otherwise the trailing output columns would be silently dropped.
            raise ValueError(
                f"output_size ({output_size}) must be divisible by world_size ({world_size})"
            )
        self.input_size = input_size
        self.output_size_per_partrition = output_size // world_size

        self.weight = nn.Parameter(torch.empty(
            self.output_size_per_partrition,
            self.input_size
        ))
        self.bias = nn.Parameter(torch.empty(
            self.output_size_per_partrition
        ))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight/bias the way ``nn.Linear`` does (torch.empty is garbage)."""
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        bound = 1 / self.input_size ** 0.5 if self.input_size > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        input_parallel = f.apply(input)
        output_parallel = F.linear(
            input_parallel,
            self.weight,
            self.bias
        )
        # Gather the projected shards, not the raw input.
        outputs = g.apply(output_parallel)
        return outputs

In [None]:
tensor_model_parallel_size: int = 2  # number of consecutive ranks per tensor-parallel group

In [None]:
num_tensor_model_parallel_groups: int = 8  # how many rank groups the loops below enumerate

In [None]:
for i in range(num_tensor_model_parallel_groups):
    # Group i owns the contiguous block [i * size, (i + 1) * size).
    ranks = [i * tensor_model_parallel_size + offset for offset in range(tensor_model_parallel_size)]
    print(ranks)

[0, 1]
[2, 3]
[4, 5]
[6, 7]
[8, 9]
[10, 11]
[12, 13]
[14, 15]


In [None]:
import os

In [None]:
class MPU:
    """Model-parallel unit: initializes torch.distributed and the tensor-parallel layout.

    Args:
        tensor_model_parallel_size: ranks per tensor-parallel group; must be a
            positive divisor of the world size.
        master_addr: rendezvous host, written to ``MASTER_ADDR``.
        master_port: rendezvous port, written to ``MASTER_PORT``.
        backend: torch.distributed backend name (e.g. ``"nccl"`` or ``"gloo"``).

    When the default process group is not yet initialized, the rank and world
    size are read from the ``RANK`` and ``WORLD_SIZE`` environment variables.

    Raises:
        ValueError: if ``tensor_model_parallel_size`` does not divide the world size.
    """

    def __init__(
        self,
        tensor_model_parallel_size,
        master_addr,
        master_port,
        backend
    ):
        if not torch.distributed.is_initialized():
            self._initialize_distributed(
                master_addr=master_addr,
                master_port=master_port,
                backend=backend
            )

        world_size = torch.distributed.get_world_size()
        if tensor_model_parallel_size <= 0 or world_size % tensor_model_parallel_size != 0:
            raise ValueError(
                f"tensor_model_parallel_size ({tensor_model_parallel_size}) "
                f"must divide world_size ({world_size})"
            )
        self.tensor_model_parallel_size = tensor_model_parallel_size
        # Keep the group count instead of discarding it.
        self.num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size

    def _initialize_distributed(self, master_addr, master_port, backend):
        """Create the default process group from env vars plus the rendezvous address."""
        if torch.distributed.is_initialized():
            return
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        os.environ["MASTER_ADDR"] = str(master_addr)
        os.environ["MASTER_PORT"] = str(master_port)

        torch.distributed.init_process_group(
            backend=backend,
            rank=rank,
            world_size=world_size
        )

In [None]:
def is_grad_enabled(inputs):
    """Return True only when autograd is on globally *and* ``inputs`` tracks gradients."""
    if not torch.is_grad_enabled():
        return False
    return inputs.requires_grad

In [None]:
def _broadcast(inputs):
    return inputs.clone()

In [None]:
class Broadcast(torch.autograd.Function):
    """Copy in the forward pass; sum the gradient across ranks in the backward pass."""

    @staticmethod
    def forward(ctx, inputs):
        return _broadcast(inputs)

    @staticmethod
    def backward(ctx, grad_outputs):
        # The only value available here is the incoming gradient. All-reduce
        # it (in place, sum over the default group) and hand it back as the
        # gradient for `inputs`.
        torch.distributed.all_reduce(grad_outputs)
        return grad_outputs

In [None]:
def broadcast_with_forward_and_backward(inputs):
    """Copy ``inputs``, routing through ``Broadcast`` only when a graph is being recorded."""
    # Without autograd there is no backward pass to hook, so the plain copy is enough.
    return Broadcast.apply(inputs) if is_grad_enabled(inputs) else _broadcast(inputs)

In [None]:
def create_continuous_memory(memory_size):
    """Allocate an uninitialized, contiguous 1-D float16 buffer of ``memory_size`` bytes.

    Args:
        memory_size: requested size in bytes; rounded down to a whole number
            of float16 elements.

    Returns:
        A ``torch.float16`` tensor with ``memory_size // 2`` elements.
    """
    # float16 is 2 bytes (16 bits), not 4.
    FP16_SIZE = torch.finfo(torch.float16).bits // 8
    n_numbers = memory_size // FP16_SIZE
    # Allocate n_numbers elements; the dtype keyword is `dtype`, not `type`.
    return torch.empty(n_numbers, dtype=torch.float16)

In [None]:
Micro-batch order: 1 → 2 → …

In [None]:
import os

In [None]:
class MPU:
    """Initializes torch.distributed for this process and pins it to a GPU.

    Args:
        master_addr: rendezvous host, written to ``MASTER_ADDR``.
        master_port: rendezvous port, written to ``MASTER_PORT``.
        backend: torch.distributed backend name (e.g. ``"nccl"`` or ``"gloo"``).

    When the default process group is not yet initialized, the rank and world
    size are read from the ``RANK`` and ``WORLD_SIZE`` environment variables.
    """

    def __init__(self, master_addr, master_port, backend):
        if not torch.distributed.is_initialized():
            self._initialize_distributed(
                master_addr,
                master_port,
                backend
            )

    def process_to_gpu(self, rank):
        """Bind this process to CUDA device ``rank % device_count``; no-op without CUDA."""
        n_devices = torch.cuda.device_count()

        if n_devices > 0:
            torch.cuda.set_device(rank % n_devices)

    def _initialize_distributed(self, master_addr, master_port, backend):
        """Read rank/world size, set the rendezvous env vars, pick a GPU, create the group."""
        if torch.distributed.is_initialized():
            return
        # Both values must be ints; os.environ[...] fails loudly if they are missing.
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        os.environ["MASTER_ADDR"] = str(master_addr)
        os.environ["MASTER_PORT"] = str(master_port)

        self.process_to_gpu(rank)

        torch.distributed.init_process_group(
            backend=backend,
            rank=rank,
            world_size=world_size
        )

In [None]:
forward pass, backward pass, recomputation

In [None]:
import time

In [None]:
def profile_times(model, batch):
    """Time one forward + backward pass of every layer in ``model``.

    Args:
        model: a sized iterable of callables (e.g. ``nn.Sequential``).
        batch: iterable of input tensors; every layer is called on each element.

    Returns:
        A list with one entry per layer. ``records[i]`` is a one-element list
        holding the seconds layer ``i`` spent on its forward calls over
        ``batch`` plus the backward pass over the outputs that require grad.

    NOTE(review): every layer receives the *original* batch items, not the
    previous layer's outputs. This only works when all layers accept the same
    input shape; confirm whether chaining was intended.
    """
    layers = list(model)
    # One bucket per layer (range(model) raised TypeError on a module).
    records = [[] for _ in layers]

    for i, layer in enumerate(layers):
        start_time = time.perf_counter()
        outputs = [layer(x) for x in batch]
        outputs_with_grad = [y for y in outputs if y.requires_grad]

        if outputs_with_grad:
            # Use each output as its own upstream gradient. This is a cheap way
            # to drive a full backward pass without a loss function.
            torch.autograd.backward(outputs_with_grad, outputs_with_grad)

        records[i].append(time.perf_counter() - start_time)

    return records

In [None]:
class f(torch.autograd.Function):
    """Copy in the forward pass, all-reduce of the gradient in the backward pass.

    ColumnParallelLinear applies this to its input before the local matmul.
    """

    @staticmethod
    def forward(ctx, tensor):
        # Return a fresh tensor rather than the input object itself.
        copied = tensor.clone()
        return copied

    @staticmethod
    def backward(ctx, grad):
        # In-place sum across all ranks of the default process group.
        torch.distributed.all_reduce(grad)
        return grad

In [None]:
class g(torch.autograd.Function):
    """All-gather along the last dim in forward; keep this rank's slice in backward.

    Forward concatenates every rank's tensor along dim -1 in rank order.
    Backward splits the incoming gradient into ``world_size`` equal chunks
    along dim -1 and returns the chunk for the current rank.
    """

    @staticmethod
    def forward(ctx, input):
        world_size = torch.distributed.get_world_size()
        # Collectives expect contiguous buffers.
        input = input.contiguous()
        inputs = [torch.empty_like(input) for _ in range(world_size)]
        torch.distributed.all_gather(inputs, input)
        # Return the gathered result, not the local shard.
        return torch.cat(inputs, dim=-1)

    @staticmethod
    def backward(ctx, grad_output):
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

        dim_size_per_partition = grad_output.shape[-1] // world_size
        grad_chunks = torch.split(grad_output, dim_size_per_partition, dim=-1)
        # split() yields strided views; give autograd a compact tensor.
        return grad_chunks[rank].contiguous()

In [None]:
class ColumnParallelLinear(nn.Module):
    """Linear layer whose output features are sharded across ``world_size`` ranks.

    Each rank holds ``output_size // world_size`` rows of the weight matrix.
    The forward pass runs ``f`` (copy / gradient all-reduce), then a local
    ``F.linear``, then ``g`` (all-gather along the last dimension).

    Args:
        input_size: number of input features.
        output_size: total output features across all ranks; must be divisible
            by ``world_size``.
        world_size: number of shards.

    Raises:
        ValueError: if ``output_size`` is not divisible by ``world_size``.
    """

    def __init__(self, input_size, output_size, world_size):
        super().__init__()
        if output_size % world_size != 0:
            # Otherwise the trailing output columns would be silently dropped.
            raise ValueError(
                f"output_size ({output_size}) must be divisible by world_size ({world_size})"
            )
        self.input_size = input_size
        # Derive the shard size from output_size (the undefined attribute name raised NameError).
        self.output_size_per_partrition = output_size // world_size

        self.weight = nn.Parameter(torch.empty(
            self.output_size_per_partrition,
            self.input_size
        ))
        self.bias = nn.Parameter(torch.empty(
            self.output_size_per_partrition
        ))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize weight/bias the way ``nn.Linear`` does (torch.empty is garbage)."""
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)
        bound = 1 / self.input_size ** 0.5 if self.input_size > 0 else 0
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        input_parallel = f.apply(input)
        output_parallel = F.linear(input_parallel, self.weight, self.bias)
        output = g.apply(output_parallel)
        return output

In [None]:
Partition $j$ runs on device $j$.
$F_{m,n}$ must be completed before $F_{m+1,n}$.
$B_{m,n}$ must be completed before $B_{m-1,n}$.

In [None]:
def is_grad_enabled(inputs):
    """Decide whether ``inputs`` takes part in the autograd graph being recorded."""
    grad_mode_on = torch.is_grad_enabled()
    return grad_mode_on and inputs.requires_grad

In [None]:
def _broadcast(inputs):
    return inputs.clone()

In [None]:
def _reduce(inputs):
    torch.distributed.all_reduce(inputs)

In [None]:
class Broadcast(torch.autograd.Function):
    """Copy in the forward pass; sum the gradient across ranks in the backward pass."""

    @staticmethod
    def forward(ctx, inputs):
        return _broadcast(inputs)

    @staticmethod
    def backward(ctx, grad_outputs):
        # _reduce all-reduces the tensor in place. Return the tensor itself
        # rather than _reduce's return value, so the gradient is never None.
        _reduce(grad_outputs)
        return grad_outputs

In [None]:
def broadcast_with_forward_and_backward(inputs):
    """Copy ``inputs``; go through the autograd Function only if gradients are tracked."""
    if not is_grad_enabled(inputs):
        # No graph is being recorded, so a plain copy suffices.
        return _broadcast(inputs)
    return Broadcast.apply(inputs)

In [None]:
new → ready → running → blocked → terminated

In [None]:
# This process's rank in the default process group (the group must already be initialized).
rank = torch.distributed.get_rank(group=None)

In [None]:
# Point-to-point transfer from rank 69 to rank 42. isend/irecv are
# non-blocking: they return a work handle that must be waited on before the
# send buffer is reused or the receive buffer is read.
if rank == 69:
    request = torch.distributed.isend(x, dst=42)
    request.wait()
elif rank == 42:
    request = torch.distributed.irecv(tensor_will_be_received_data, src=69)
    request.wait()

In [None]:
for i in range(num_tensor_model_parallel_groups):
    # Contiguous block of ranks owned by tensor-parallel group i.
    ranks = [tensor_model_parallel_size * i + offset for offset in range(tensor_model_parallel_size)]

In [None]:
for in in

In [None]:
Step 1: read the local variables (rank, world size).
Step 2: set up communication (master address and port).
Step 3: set the device.
Step 4: initialize the parallel groups.

In [None]:
- Clock cycle 1: $F_{1,1}$
- Clock cycle 2: $F_{1,2}$, $F_{2,1}$
- Clock cycle 3: $F_{1,3}$, $F_{2,2}$, $F_{3,1}$

### ML Engineering

### MechInterp

In [None]:
data = dict()  # neuron index -> recorded activations, filled by the forward hook

In [None]:
neuron_idx: int = 69  # MLP neuron whose activation is recorded

In [None]:
batch_idx: int = 0  # which prompt in the batch to read activations from

In [None]:
# Tokenize the prompt `text` with `model` (both defined elsewhere; presumably a
# TransformerLens HookedTransformer — confirm).
tokens = model.to_tokens(text)

In [None]:
def extract_neuron_activation(activations, hook):
    """Forward hook: record one neuron's activations for one batch element.

    Stores ``activations[batch_idx, :, neuron_idx]`` into ``data[neuron_idx]``
    (``data``, ``batch_idx`` and ``neuron_idx`` are module-level globals).
    It returns ``activations`` unchanged, so the forward pass is unaffected.
    ``hook`` is required by the hook signature but unused.
    """
    # detach().clone(): a plain slice is a view that would keep the whole
    # activation tensor (and its autograd graph) alive after the forward pass.
    data[neuron_idx] = activations[batch_idx, :, neuron_idx].detach().clone()
    return activations

In [None]:
layer_idx: int = 3  # transformer block whose MLP output is hooked

In [None]:
hook_name = f"blocks.{layer_idx}.mlp.hook_post"

In [None]:
# Run `tokens` through the model with the recording hook attached at `hook_name`.
# The return value is not stored; only the hook's side effect on `data` is used.
model.run_with_hooks(
    tokens,
    fwd_hooks=[(hook_name, extract_neuron_activation)]
)

In [None]:
# Flattened index of the largest recorded activation. This is a sequence position
# if hook_post is (batch, pos, d_mlp) — confirm.
token_with_highest_activation = data[neuron_idx].argmax()

In [None]:
# Re-tokenize `text` for the attention-pattern experiment below.
tokens = model.to_tokens(text)

In [None]:
seq_len = tokens.size(-1)  # length of the last (position) axis

In [None]:
target_pattern = torch.zeros((seq_len, seq_len))  # (query_pos, key_pos) mask, built below

In [None]:
# Mark every query position as attending to itself.
target_pattern.fill_diagonal_(1.)

In [None]:
# Exclude the first query position from the target.
target_pattern[0] = 0.

In [None]:
def compute_score(attention_pattern):
    """Fraction of attention mass that falls on positions marked in ``target_pattern``.

    ``target_pattern`` is a module-level mask; the score is
    ``sum(pattern * target) / sum(pattern)``.
    """
    matched_mass = (attention_pattern * target_pattern).sum()
    total_mass = attention_pattern.sum()
    return matched_mass / total_mass

In [None]:
# Tokenize `text` again before caching activations for the head-scoring loop.
tokens = model.to_tokens(text)

In [None]:
# One forward pass that also caches every hook activation; the first return
# value is discarded and only `cache` is used below.
_, cache = model.run_with_cache(tokens)

In [None]:
data = torch.zeros((n_layers, n_heads))  # one score per (layer, head)

In [None]:
# Score every attention head against `target_pattern`, using the cached patterns.
# The loop variables are module-level names, so they stay bound after the loop.
for layer_idx in range(n_layers):
    hook_name = f"blocks.{layer_idx}.attn.hook_pattern"
    layer_attention_pattern = cache[hook_name]
    for head_idx in range(n_heads):
        # First batch element, one head; presumably (query_pos, key_pos) — confirm.
        attention_pattern = layer_attention_pattern[0, head_idx, :]
        score = compute_score(attention_pattern)
        data[layer_idx][head_idx] = score

In [None]:
correct_token = model.to_single_token(" John")

In [None]:
incorrect_token = model.to_single_token(" Mary")

In [None]:
# Token ids of the clean prompt (source of the activations that get patched in).
clean_tokens = model.to_tokens(clean_prompt)

In [None]:
# Token ids of the corrupted prompt (the run that receives patched activations).
corrupted_tokens = model.to_tokens(corrupted_prompt)

In [None]:
# Cache every activation of the clean run; the patching hook looks these up by hook name.
_, clean_activations = model.run_with_cache(clean_tokens)

In [None]:
def patch_residual(
    corrupted_activations,
    hook,
    position,
    clean_activations
):
    """Hook: overwrite one sequence position with the cached clean activation.

    Reads ``clean_activations[hook.name]`` and copies its ``[:, position, :]``
    slice into ``corrupted_activations``. Indexing assumes position is axis 1.
    The tensor is modified in place and returned.
    """
    clean_slice = clean_activations[hook.name][:, position, :]
    corrupted_activations[:, position, :] = clean_slice
    return corrupted_activations

In [None]:
# Number of positions to patch. Take it from the clean prompt that the patching
# loop reads from; `tokens` holds the earlier attention-pattern prompt.
n_tokens = clean_tokens.shape[-1]

In [None]:
from functools import partial

In [None]:
from transformer_lens import utils

In [None]:
def compute_logit_diff(logits, correct_token, incorrect_token, position=-1):
    """Logit difference between two candidate tokens at one sequence position.

    Args:
        logits: model output, assumed (batch, pos, d_vocab) as returned by a
            TransformerLens model — confirm.
        correct_token: vocabulary id of the expected answer.
        incorrect_token: vocabulary id of the distractor.
        position: sequence position whose prediction is scored (default: last).

    Returns:
        Tensor of shape (batch,): ``logit[correct] - logit[incorrect]``.
    """
    # Token ids index the vocabulary axis, not the position axis.
    position_logits = logits[:, position, :]
    return position_logits[:, correct_token] - position_logits[:, incorrect_token]

In [None]:
# Patch the clean residual stream into the corrupted run, one (layer, position) at a time.
for layer in range(n_layers):
    for position in range(n_tokens):
        # Bind both extra hook arguments by keyword: a positional argument may
        # not follow a keyword argument (that was a SyntaxError).
        hook_func = partial(patch_residual, position=position, clean_activations=clean_activations)
        hook_name = utils.get_act_name("resid_pre", layer)

        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(hook_name, hook_func)]
        )
        # TODO: score `patched_logits` (e.g. with a logit-difference metric) and store it.
        

In [None]:
Step 1: tokenize the prompt.
Step 2: tokenize the observation and append it to the prompt.
Step 3: predict.
Step 4: execute.

In [None]:
Step 1: tokenize the prompt.
Step 2: tokenize the observation and append it to the prompt.
Step 3: predict.
Step 4: execute.

In [None]:
import math

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
class SelfAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_head)) @ v``.

    Args:
        d_head: per-head feature size, used for the ``1 / sqrt(d_head)`` scaling.

    ``forward(q, k, v, mask=None)`` works as follows:

    - ``q``, ``k``, ``v``: tensors whose last two dims are (seq, d_head);
      leading batch/head dims are broadcast by ``matmul``.
    - ``mask``: optional tensor broadcastable to the score shape
      (..., q_seq, k_seq). Positions where ``mask == 0`` get zero weight.
    - Returns ``(output, attention_weights)``.
    """

    def __init__(self, d_head):
        super().__init__()
        self.d_head = d_head

    def forward(self, q, k, v, mask=None):
        # Transpose the last two dims, so any number of leading dims works.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_head)

        if mask is not None:
            # masked_fill is out-of-place, so keep its result. Fill with the
            # dtype's most negative value, not 1e-9 (which is ~0 and masks
            # nothing). Masked logits then get zero weight, and fully masked
            # rows stay finite instead of turning into NaN.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_weights, v)
        return output, attention_weights

In [None]:
x.repeat(3, 2)  # tile 3x along dim -2 and 2x along dim -1

In [None]:
# Circularly shift x by one step along dimension 1 (torch.roll's keyword is `dims`, not `dim`).
torch.roll(x, shifts=1, dims=1)

In [None]:
import torch.distributed.rpc as rpc

In [None]:
rpc.remote("worker_1", create_tensor)

In [None]:
# He/Kaiming normal initialization of layer1's weight, in place, with default settings.
nn.init.kaiming_normal_(tensor=layer1.weight)

In [None]:
x.split(3)  # chunks of 3 along dim 0 (last chunk may be smaller)

In [None]:
# The clean and corrupted runs must come from different prompts, or patching
# measures nothing. to_tokens expects text, not token ids.
clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

In [None]:
correct_token = model.to_single_token(" John")
incorrect_token = model.to_single_token(" Mary")

In [None]:
# Cache every activation of the clean run; patch_component looks them up by hook name.
_, clean_activations = model.run_with_cache(
    clean_tokens
)

In [None]:
# Number of positions to patch. Take it from the clean prompt that the patching
# loop reads from, not from the unrelated `tokens` of an earlier experiment.
n_tokens = clean_tokens.shape[-1]

In [None]:
data = torch.zeros((n_layers, n_tokens))  # patching result per (layer, position)

In [None]:
from transformer_lens.utils import get_act_name

In [None]:
def patch_component(
    corrupted_activations,
    hook,
    position,
    clean_activations
):
    """Hook: copy the clean activation at ``position`` into the corrupted run.

    The clean tensor is looked up as ``clean_activations[hook.name]``, and
    axis 1 is treated as the position axis. The corrupted tensor is modified
    in place and returned.
    """
    cached = clean_activations[hook.name]
    corrupted_activations[:, position, :] = cached[:, position, :]
    return corrupted_activations

In [None]:
def logit_diff(logits, correct_token, incorrect_token, position=-1):
    """Logit difference between two candidate tokens at one sequence position.

    Args:
        logits: model output, assumed (batch, pos, d_vocab) — confirm.
        correct_token: vocabulary id of the expected answer.
        incorrect_token: vocabulary id of the distractor.
        position: sequence position whose prediction is scored (default: last).

    Returns:
        Tensor of shape (batch,): ``logit[correct] - logit[incorrect]``.
    """
    position_logits = logits[:, position, :]
    correct_logits = position_logits[:, correct_token]
    incorrect_logits = position_logits[:, incorrect_token]
    # Subtract logits from logits (not the raw token id).
    return correct_logits - incorrect_logits

In [None]:
# Activation patching: for every (layer, position), patch the clean residual
# stream into the corrupted run and record the resulting logit difference.
for layer_idx in range(n_layers):
    for position in range(n_tokens):
        hook_func = partial(
            patch_component,
            position=position,
            clean_activations=clean_activations
        )
        hook_name = get_act_name("resid_pre", layer_idx)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(hook_name, hook_func)]
        )
        # Use a distinct name: rebinding `logit_diff` would shadow the metric
        # function and break the next iteration.
        patched_logit_diff = logit_diff(patched_logits, correct_token, incorrect_token)
        # Reduce over the batch so a scalar fits the (layer, position) cell.
        data[layer_idx, position] = patched_logit_diff.mean().item()

In [None]:
# For each sample i, pick sm_pred[i, y_train[i]] (presumably the predicted
# probability of the true class — confirm sm_pred is (N, C)).
target_preds = sm_pred[torch.arange(y_train.size(0)), y_train]