### MLE

### Engineering

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
from torch.multiprocessing import Process

In [4]:
def init_communication(rank, world_size, config):
    """Initialise the torch.distributed process group for one rank.

    Args:
        rank: Forwarded as the ``rank`` keyword of ``init_process_group``.
        world_size: Forwarded as the ``world_size`` keyword of
            ``init_process_group``.
        config: Iterable unpacked into the leading positional arguments of
            ``init_process_group`` (presumably backend and init method --
            TODO confirm against the caller).
    """
    torch.distributed.init_process_group(
        *config,
        rank=rank,
        world_size=world_size,
    )

In [None]:
# Launch one process per rank. The loop bound must match the `world_size`
# handed to every rank; a hard-coded count would start too few or too many
# ranks whenever `world_size` is not 4.
# NOTE(review): `world_size` and `config` are not defined in the visible cells.
processes = []
for rank in range(world_size):
    process = Process(
        target=init_communication,
        args=(rank, world_size, config),
    )
    process.start()
    processes.append(process)

# Keep the handles and wait for every rank so no child process is left behind.
for process in processes:
    process.join()

In [None]:
#include <iostream>

In [None]:
// Prints the address of a local variable rather than its value.
int main() {
    int x = 1;
    // &x is a pointer to x, so this streams the address, not 1.
    std::cout << &x;
    return 0;
}

In [None]:
// Declares a read-only local constant.
int main() {
    // const: pi cannot be assigned to after this initialisation.
    const float pi = 3.14;
    return 0;
}

In [None]:
// Declares and initialises a local int; the value is not used afterwards.
int main() {
    int file_size = 100;
    return 0;
}

In [None]:
// Prints "hello" and "world" on two lines.
int main() {
    // std::endl writes a newline and flushes the stream.
    std::cout << "hello" << std::endl;
    // No newline is written after "world".
    std::cout << "world";
    return 0;
}

In [None]:
// age_type becomes another name for int.
typedef int age_type;

In [5]:
# Number of micro-batches used by the pipeline schedule computed in the cells below.
n_microbatches = 4

In [6]:
# Number of pipeline partitions used by the schedule.
# NOTE(review): "partritions" is a misspelling of "partitions"; the name is
# left as-is because other cells reference it.
n_partritions = 3

In [7]:
# Number of clock cycles the schedule loop iterates over
# (4 + 3 - 1 = 6 with the values above).
n_clock_cycles = n_microbatches + n_partritions - 1

In [8]:
# Pipeline schedule: for each clock cycle, print the (micro-batch, partition)
# pairs that run in that cycle. Every pair in a cycle satisfies
# micro-batch index + partition index == clock index.
for clock_idx in range(n_clock_cycles):
    # Half-open range of partitions that have a micro-batch to work on now.
    start_partrition = max(clock_idx + 1 - n_microbatches, 0)
    end_partrition = min(clock_idx + 1, n_partritions)

    tasks = [
        (clock_idx - partrition_idx, partrition_idx)
        for partrition_idx in range(start_partrition, end_partrition)
    ]

    print(tasks)

[(0, 0)]
[(1, 0), (0, 1)]
[(2, 0), (1, 1), (0, 2)]
[(3, 0), (2, 1), (1, 2)]
[(3, 1), (2, 2)]
[(3, 2)]


In [None]:
main thread > worker thread > task > cuda stream

In [11]:
def is_grad_enabled(input):
    """Report whether gradients would be tracked for ``input`` right now.

    Returns ``False`` when autograd is globally disabled; otherwise returns
    ``input.requires_grad``.
    """
    if not torch.is_grad_enabled():
        return False
    return input.requires_grad

In [12]:
def _broadcast(input, src=0):
    """Broadcast ``input`` in place from rank ``src`` and return it.

    Args:
        input: Tensor to broadcast. On the source rank it is the data sent;
            on every other rank it is overwritten with the received data.
        src: Rank that owns the data. Defaults to 0.
            NOTE(review): confirm that rank 0 belongs to ``parallel_group``.

    Returns:
        ``input`` itself on every path, so callers can use the result
        whether or not a collective actually ran.
    """
    world_size = torch.distributed.get_world_size()
    if world_size == 1:
        # A single process has nobody to broadcast to.
        return input

    # NOTE(review): `parallel_group` is not defined in the visible cells.
    torch.distributed.broadcast(input, src=src, group=parallel_group)
    # The collective works in place and returns nothing useful, so hand the
    # tensor back explicitly instead of falling through with None.
    return input

In [None]:
def _reduce(input):
    """All-reduce ``input`` in place across ``parallel_group`` and return it.

    With a world size of 1 the tensor is returned untouched.
    """
    if torch.distributed.get_world_size() != 1:
        torch.distributed.all_reduce(input, group=parallel_group)
    return input

In [10]:
class Broadcast(torch.autograd.Function):
    """Autograd function that calls ``_broadcast`` forward and ``_reduce`` backward."""

    @staticmethod
    def forward(ctx, input):
        # The forward result is whatever `_broadcast` returns for this tensor.
        return _broadcast(input)
    
    @staticmethod
    def backward(ctx, grad_input):
        # The incoming gradient is passed through `_reduce` and its result is
        # returned as the gradient for `input`.
        return _reduce(grad_input)

In [9]:
def broadcast_with_forward_and_backward(input):
    """Broadcast ``input``, going through autograd when gradients are tracked.

    Args:
        input: Tensor to broadcast.

    Returns:
        The output of ``Broadcast.apply`` when gradients are tracked for
        ``input`` (so the backward pass of ``Broadcast`` runs); otherwise
        ``input`` after the plain in-place broadcast.
    """
    if is_grad_enabled(input):
        # Return the autograd output, not `input`: returning `input` would
        # drop the Broadcast node from the graph and skip its backward.
        # NOTE(review): relies on `_broadcast` returning its tensor on every
        # path -- verify.
        return Broadcast.apply(input)

    # No graph to build; the broadcast updates `input` in place.
    _broadcast(input)
    return input

In [13]:
class StartDependency(torch.autograd.Function):
    """Pass ``input`` through and emit an extra one-element "phony" tensor.

    The phony output carries no data of interest; it exists so that another
    autograd function can take it as an input and thereby depend on this one.
    """

    @staticmethod
    def forward(ctx, input):
        # Zeros instead of random values: the contents are never read, and a
        # random draw would advance the global RNG as a side effect.
        phony = torch.zeros(1, requires_grad=False, device=input.device)
        return input, phony

    @staticmethod
    def backward(ctx, grad_input, grad_phony):
        # forward takes one tensor, so exactly one gradient is returned; the
        # gradient arriving for the phony output is ignored.
        return grad_input

In [14]:
class EndDependency(torch.autograd.Function):
    """Pass ``input`` through while also taking ``phony`` as an autograd input.

    ``phony`` is not used in the computation; taking it as an input makes
    this node depend on whatever produced it.
    """

    @staticmethod
    def forward(ctx, input, phony):
        return input

    @staticmethod
    def backward(ctx, grad_input):
        # One gradient per forward input: the real gradient for `input` and
        # None for `phony`. Returning a single value here makes autograd
        # raise "returned an incorrect number of gradients".
        return grad_input, None

In [15]:
def create_dependency(start_batch, end_batch):
    """Make ``end_batch`` depend on ``start_batch`` in the autograd graph.

    Args:
        start_batch: Tensor passed through ``StartDependency``.
        end_batch: Tensor passed through ``EndDependency`` together with the
            phony tensor produced for ``start_batch``.

    Returns:
        Tuple ``(start_batch, end_batch)`` holding the outputs of the two
        autograd functions.
    """
    # Custom autograd functions are invoked through `.apply`; calling the
    # class itself does not run `forward`.
    start_batch, phony = StartDependency.apply(start_batch)
    # Join the phony onto `end_batch` (not `start_batch`) so the dependency
    # links the two different batches.
    end_batch = EndDependency.apply(end_batch, phony)
    return start_batch, end_batch

In [None]:
clock cycle 1: B_{m, n}
clock cycle 2: B_{m-1, n}
clock cycle 3: B_{m-2, n}

In [None]:
// age becomes another name for int.
typedef int age;

In [None]:
clock cycle 1: Backward(4, 3), Recompute(3, 3)
clock cycle 2: Backward(3, 3), Recompute(2, 3)
clock cycle 3: Backward(2, 3), Recompute(1, 3)

In [16]:
from torch.utils.data import DataLoader

In [None]:
class CachedDataset(DataLoader):
    # NOTE(review): unfinished draft. The last statement of `prefetch` is cut
    # off mid-assignment, so this cell does not parse. A later cell defines
    # another `CachedDataset` (subclassing `Dataset`) that would shadow it.
    def __init__(self, filename):
        # NOTE(review): the base class here is DataLoader, yet no arguments
        # are passed to its constructor -- confirm intent.
        super().__init__()
        self.filename = filename
        self.cached_idxs = {}
        self.data = None
    
    def prefetch(self, idxs):
        # Nothing to do when every requested index is already in `cached_idxs`.
        if all([idx in self.cached_idxs for idx in idxs]):
            return
        
        # Load the backing file the first time data is needed.
        if not self.data:
            self.data = torch.load(self.filename)
        
        # NOTE(review): `ixs` is not defined here (the parameter is `idxs`).
        self.total_elements = sum([self.data[i].numel() for i in ixs])
        # NOTE(review): bare `total_elements` is not defined; the line above
        # assigns `self.total_elements`.
        self.cache = torch.zeros(total_elements, dtype=self.data.dtype)
        # NOTE(review): `self.cache_index` is never assigned; `__init__`
        # creates `self.cached_idxs`.
        self.cache_index.clear()
        
        offset = 0
        for i in idxs:
            n_elements = 

In [None]:
broadcast, scatter, reduce, gather

In [19]:
from contextlib import contextmanager
from queue import Queue
from threading import Thread

In [None]:
def wait_and_execute(device, in_queue, out_queue):
    """Run tasks taken from ``in_queue`` and report each outcome on ``out_queue``.

    Args:
        device: Accepted for the caller's bookkeeping; not used in the body.
        in_queue: Queue of zero-argument callables. A ``None`` item asks the
            loop to stop.
        out_queue: Queue receiving one ``[payload, ok]`` list per executed
            task: ``[result, True]`` on success, ``[exception, False]`` when
            the task raised.
    """
    while True:
        task = in_queue.get()

        # Shutdown sentinel: without it the loop (and its thread) can never end.
        if task is None:
            break

        try:
            output = task()
        except Exception as error:
            # Pass the exception on instead of discarding it, so the consumer
            # can see why the task failed.
            out_queue.put([error, False])
            continue

        out_queue.put([output, True])

In [None]:
@contextmanager
def spawn_workers(devices):
    """Start one worker thread per distinct device and yield their queues.

    Args:
        devices: Iterable of hashable device identifiers. A device that
            appears more than once shares a single worker and queue pair.

    Yields:
        Tuple ``(in_queues, out_queues)``: two lists aligned with
        ``devices``, holding the task queue and result queue for each entry.
    """
    in_queues = []
    out_queues = []
    workers = {}

    for device in devices:
        try:
            in_queue, out_queue = workers[device]
        except KeyError:
            # First time this device is seen: create its queues and worker.
            in_queue = Queue()
            out_queue = Queue()
            workers[device] = in_queue, out_queue

            # Daemon thread, so a worker blocked on an empty queue cannot
            # keep the interpreter alive.
            thread = Thread(
                target=wait_and_execute,
                args=(device, in_queue, out_queue),
                daemon=True,
            )
            thread.start()

        in_queues.append(in_queue)
        out_queues.append(out_queue)

    try:
        yield (in_queues, out_queues)
    finally:
        # Ask each worker to stop by sending None once per distinct device.
        # A worker that does not treat None specially just reports it as a
        # failed task and keeps waiting.
        for in_queue, _ in workers.values():
            in_queue.put(None)

In [None]:
broadcast, scatter, reduce, gather

In [20]:
from torch.utils.data import Dataset

In [None]:
class CachedDataset(Dataset):
    """Dataset wrapper that copies selected items into one flat cache tensor.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, filename):
        super().__init__()
        # Path given to `torch.load` the first time data is needed.
        self.filename = filename
        # Maps a cached item index to its `(offset, length)` slice in `cache`.
        self.cached_index = {}
        # Loaded lazily by `prefetch`; None until then.
        self.data = None
        # Empty dict until the first prefetch, then a 1-D tensor.
        self.cache = {}

    def prefetch(self, idxs):
        """Copy the items at ``idxs`` into a freshly allocated flat cache.

        Does nothing when every requested index is already cached.
        Otherwise the cache is rebuilt from scratch to hold exactly the
        requested items, flattened and laid out back to back in the order
        given.

        Args:
            idxs: Iterable of indices into the loaded data.
        """
        # Materialise once so a generator can be scanned more than once.
        idxs = list(idxs)
        if all(i in self.cached_index for i in idxs):
            return

        # `is None`, not truthiness: `not tensor` raises for a tensor with
        # more than one element.
        if self.data is None:
            # NOTE(review): torch.load unpickles the file; only load files
            # from a trusted source.
            self.data = torch.load(self.filename)

        lengths = [self.data[i].numel() for i in idxs]
        cache = torch.zeros(sum(lengths), dtype=self.data.dtype)
        cached_index = {}

        offset = 0
        for i, length in zip(idxs, lengths):
            # Flatten so items of any shape fit the 1-D slice.
            cache[offset:offset + length] = self.data[i].reshape(-1)
            cached_index[i] = (offset, length)
            offset += length

        # Publish both together so the index always describes the cache.
        self.cache = cache
        self.cached_index = cached_index

In [None]:
clock cycle 1: backward(m, n)
clock cycle 2: backward(m, n-1), backward(m-1, n)
clock cycle 3: backward(m, n-2), backward(m-1, n-1), backward(m-2, n)

### AI

In [None]:
# NOTE(review): scratch call -- `x` is not defined in the visible cells and no
# repeat sizes are passed; confirm before running.
x.repeat()

In [None]:
# Shift the elements of `x` by one position along dimension 1.
# NOTE(review): `x` is not defined in the visible cells.
torch.roll(x, shifts=1, dim=1)

In [None]:
# Receive a tensor into `x`; no source rank is specified.
# NOTE(review): `x` is not defined in the visible cells -- presumably scratch.
torch.distributed.recv(x)

In [None]:
class Function(torch.autograd.Function):
    """Identity in the forward pass; the backward pass returns the saved input.

    The gradient reported for ``input`` is the input tensor itself; the
    incoming gradient is ignored.
    """

    @staticmethod
    def forward(ctx, input):
        # save_for_backward is the supported way to keep an input tensor for
        # the backward pass; a plain attribute on ctx bypasses autograd's
        # bookkeeping for saved tensors.
        ctx.save_for_backward(input)
        return input

    @staticmethod
    def backward(ctx, grad_input):
        (saved_input,) = ctx.saved_tensors
        return saved_input

In [None]:
# Tokenise the prompt; these tokens feed the clean run below.
# NOTE(review): `model` and `prompt` are not defined in the visible cells.
tokens = model.to_tokens(prompt)

In [None]:
# Prompt used for the corrupted run.
corrupted_prompt = "A told B: 'Persistence is all you need.' C replied back to "

In [None]:
# Tokenise the corrupted prompt the same way as the clean one.
corrupted_tokens = model.to_tokens(corrupted_prompt)

In [None]:
# Look up the single token for "John".
# presumably the index later passed to `compute_logit_difference` -- verify
target_token = model.to_single_token("John")

In [None]:
# Run the model on the clean tokens; the first returned value is discarded
# and the second is kept as the clean activations.
_, clean_activations = model.run_with_cache(tokens)

In [None]:
# Same as the clean run, but on the corrupted tokens.
_, corrupted_activations = model.run_with_cache(corrupted_tokens)

In [21]:
from transformer_lens import utils

In [22]:
# Head 6 of layer 9: `layer_idx` selects the hook point and `head_idx` the
# slice that gets patched in the cells below.
head_idx, layer_idx = 6, 9

In [23]:
# Name of the hook point for the "attn" activation in layer `layer_idx`.
hook_name = utils.get_act_name("attn", layer_idx)

In [None]:
# Take head `head_idx` from the corrupted run's activation at `hook_name`;
# the indexing treats dimension 1 as the head dimension.
corrupted_head_activations = corrupted_activations[hook_name][:, head_idx, :, :]

In [25]:
def patch_corrupted_head_activation(activations, hook):
    """Overwrite one head's slice of ``activations`` with the corrupted values.

    Args:
        activations: Tensor indexed as ``[:, head, :, :]``; modified in place.
        hook: Unused; present to match the ``(activations, hook)`` hook
            signature.

    Returns:
        The same ``activations`` tensor after the in-place overwrite.
    """
    # Reads the module-level `head_idx` and `corrupted_head_activations`.
    activations[:, head_idx, :, :] = corrupted_head_activations
    return activations

In [None]:
# Re-run the clean tokens with the corrupted head patched in, caching the
# activations of that patched run.
# `run_with_hooks` returns only the model output, so the hook is attached for
# the duration of a `run_with_cache` call instead. `fwd_hooks` is a list of
# (hook point name, hook function) pairs.
with model.hooks(fwd_hooks=[(hook_name, patch_corrupted_head_activation)]):
    _, patched_activations = model.run_with_cache(tokens)

In [None]:
# Activation of the receiver hook point taken from the patched run.
# NOTE(review): `receiver_hook_name` is not defined in the visible cells.
corrupted_receiver_activations = patched_activations[receiver_hook_name]

In [26]:
def patch_corrupted_receiver_activations(activations, hook):
    """Replace the hooked activation with ``corrupted_receiver_activations``.

    Both arguments are ignored; the module-level tensor is returned as-is.
    """
    return corrupted_receiver_activations

In [None]:
# `add_hook` takes the hook point name first and the hook function second;
# the receiver hook point is the one the stored activation was read from.
# NOTE(review): `receiver_hook_name` is not defined in the visible cells.
model.add_hook(receiver_hook_name, patch_corrupted_receiver_activations)

In [None]:
# Calling the model returns the logits tensor alone by default, so there is
# nothing to unpack (unpacking would split the tensor along its first dimension).
patched_logits = model(tokens)

In [None]:
# Remove the patching hooks so this run is the unmodified baseline.
model.reset_hooks()
# Calling the model returns the logits tensor alone by default, so there is
# nothing to unpack.
clean_logits = model(tokens)

In [None]:
def compute_logit_difference(clean_logits, corrupted_logits, target_token):
    """Difference in the target token's final-position logit between two runs.

    Args:
        clean_logits: 3-D tensor indexed as ``[batch, position, token]``.
        corrupted_logits: Tensor with the same layout as ``clean_logits``.
        target_token: Index along the last dimension to compare.

    Returns:
        1-D tensor with one value per batch element: the corrupted logit
        minus the clean logit at the last position.
    """
    # Index the last dimension directly. Chaining `[:, -1, :][target_token]`
    # would apply the token index to the batch dimension instead.
    clean_target_logit = clean_logits[:, -1, target_token]
    corrupted_target_logit = corrupted_logits[:, -1, target_token]
    return corrupted_target_logit - clean_target_logit

In [None]:
class ShortcutProjection(nn.Module):
    """Projection for a residual shortcut: 1x1 convolution then batch norm.

    NOTE(review): the original cell had an empty ``__init__`` body. This is
    the standard ResNet projection shortcut implied by the constructor
    arguments -- confirm it matches the intended design.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels the shortcut must produce.
        stride: Stride of the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply the convolution, then batch normalisation."""
        return self.bn(self.conv(x))

In [27]:
def calculate_discounted_return_an_episode(rewards, discount_factor):
    """Sum an episode's rewards, each weighted by ``discount_factor ** step``.

    Args:
        rewards: Rewards in the order they were received; step 0 comes first.
        discount_factor: Base of the per-step weight.

    Returns:
        The discounted sum, or ``0`` for an empty episode.
    """
    return sum(
        (discount_factor ** step) * reward
        for step, reward in enumerate(rewards)
    )