### Engineering

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [5]:
def is_grad_enabled(input):
    """Tell whether gradients should be tracked for ``input``.

    Args:
        input: Object exposing a ``requires_grad`` attribute (e.g. a tensor).

    Returns:
        ``False`` when autograd is globally disabled (for example inside
        ``torch.no_grad()``); otherwise the value of ``input.requires_grad``.
    """
    # ``torch.is_grad_enabled`` must be *called*: the bare function object is
    # always truthy, which would make the global grad mode irrelevant.
    return torch.is_grad_enabled() and input.requires_grad

In [7]:
def _broadcast(input):
    return input

In [8]:
def _reduce(input):
    """All-reduce ``input`` in place across ``parallel_group`` and return it.

    When the group's world size is exactly 1 the collective is skipped and
    ``input`` is returned as-is.

    NOTE(review): ``parallel_group`` is read from the enclosing module scope;
    it is not defined in the visible part of this file - confirm it exists.
    """
    if torch.distributed.get_world_size(group=parallel_group) != 1:
        # In-place collective: ``input`` holds the reduced values afterwards.
        torch.distributed.all_reduce(input, group=parallel_group)

    return input

In [4]:
class Broadcast(torch.autograd.Function):
    """Autograd function pairing ``_broadcast`` (forward) with ``_reduce`` (backward)."""

    @staticmethod
    def forward(ctx, input):
        """Forward pass: hand ``input`` to ``_broadcast`` and return its result."""
        output = _broadcast(input)
        return output

    @staticmethod
    def backward(ctx, grad_input):
        """Backward pass: hand the incoming gradient to ``_reduce`` and return its result."""
        reduced_grad = _reduce(grad_input)
        return reduced_grad

In [2]:
def broadcast_with_forward_and_backward(input):
    """Broadcast ``input``, going through autograd only when it is needed.

    If ``is_grad_enabled(input)`` is truthy the call is routed through
    ``Broadcast.apply``; otherwise ``_broadcast`` is called directly.
    """
    if not is_grad_enabled(input):
        return _broadcast(input)

    return Broadcast.apply(input)

In [13]:
class f(torch.autograd.Function):
    """Identity in the forward pass; all-reduce of the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, input):
        """Return ``input`` unchanged."""
        output = input
        return output

    @staticmethod
    def backward(ctx, grad_input):
        """All-reduce ``grad_input`` in place over the default process group and return it."""
        torch.distributed.all_reduce(grad_input)
        return grad_input

In [12]:
class g(torch.autograd.Function):
    """All-gather along the last dimension forward; keep this rank's slice backward."""

    @staticmethod
    def forward(ctx, input):
        """Gather ``input`` from every rank and concatenate the pieces on the last dim."""
        n_ranks = torch.distributed.get_world_size()
        # One receive buffer per rank, each shaped like the local tensor.
        gathered = [torch.empty_like(input) for _ in range(n_ranks)]
        torch.distributed.all_gather(gathered, input)
        return torch.cat(gathered, dim=-1)

    @staticmethod
    def backward(ctx, grad_input):
        """Split the gradient on the last dim and return the chunk at this rank's index."""
        n_ranks = torch.distributed.get_world_size()
        my_rank = torch.distributed.get_rank()

        # Width of each chunk along the last dimension.
        chunk_width = grad_input.shape[-1] // n_ranks

        return torch.split(grad_input, chunk_width, dim=-1)[my_rank]

In [9]:
class ColumnParallellLinear(nn.Module):
    """Linear layer whose output features are partitioned across ranks.

    Each rank owns a ``(output_size // world_size, input_size)`` weight and a
    matching bias slice. The per-rank outputs are gathered by ``g`` so the
    module returns the concatenation along the last dimension.

    Args:
        input_size: Size of the last dimension of the input.
        output_size: Total number of output features; each rank gets
            ``output_size // world_size`` of them (integer division).

    NOTE(review): ``weight`` and ``bias`` are created with ``torch.empty`` and
    are not initialised here - the caller must initialise or load them.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        world_size = torch.distributed.get_world_size()

        self.input_size = input_size
        self.output_size_per_partrition = output_size // world_size
        self.weight = nn.Parameter(torch.empty(
            self.output_size_per_partrition,
            self.input_size
        ))
        self.bias = nn.Parameter(torch.empty(
            self.output_size_per_partrition
        ))

    def forward(self, input):
        """Apply the local linear slice and gather the results from all ranks."""
        # ``f`` is the identity in forward and all-reduces the gradient in
        # backward, so its *output* must be what feeds the linear layer;
        # otherwise it never becomes part of the autograd graph.
        input_parallel = f.apply(input)
        output_parallel = F.linear(input_parallel, self.weight, self.bias)
        outputs = g.apply(output_parallel)
        return outputs

In [None]:
start = torch.cuda.Event(enable_timing=True)

In [None]:
end = torch.cuda.Event(enable_timing=True)

In [None]:
start.record()

In [None]:
hardshit()

In [None]:
end.record()

In [None]:
torch.cuda.synchronize()

In [None]:
elapsed_time = start.

In [None]:
file system, memory sharing, message passing, 

In [17]:
class StartDependency(torch.autograd.Function):
    """Pass ``input`` through and emit an extra one-element "phony" tensor."""

    @staticmethod
    def forward(ctx, input):
        """Return ``(input, phony)`` where ``phony`` is a zero tensor of shape ``(1,)``
        created on ``input``'s device with ``requires_grad=False``."""
        marker = torch.zeros(1, device=input.device, requires_grad=False)
        return input, marker

    @staticmethod
    def backward(ctx, grad_input, grad_phony):
        """Return the gradient of ``input`` unchanged; ``grad_phony`` is ignored."""
        return grad_input

In [16]:
class EndDependency(torch.autograd.Function):
    """Pass ``input`` through while taking ``phony`` as an extra autograd input."""

    @staticmethod
    def forward(ctx, input, phony):
        """Return ``input`` unchanged; ``phony`` is accepted but not used in the result."""
        return input

    @staticmethod
    def backward(ctx, grad_input):
        """Return one gradient per forward input: ``grad_input`` for ``input``
        and ``None`` for ``phony``."""
        # autograd requires as many return values as forward() had inputs;
        # returning a single value here raises "incorrect number of gradients".
        return grad_input, None

In [None]:
def create_dependency(start_batch, end_batch):
    """Link two batches through a phony tensor.

    ``start_batch`` goes through ``StartDependency``, which also yields a phony
    tensor; that phony tensor is then fed, together with ``end_batch``, into
    ``EndDependency``. Both resulting batches are returned in the same order.
    """
    linked_start, phony = StartDependency.apply(start_batch)
    linked_end = EndDependency.apply(end_batch, phony)
    return linked_start, linked_end

In [None]:
int main() {
    // Print the integers 0 through 5 back to back, with no separator.
    for (int i = 0; i <= 5; i++) {
        std::cout << i;
    }
}

In [18]:
import torch
from torch import nn
import torch.nn.functional as F

In [19]:
def compute_forward_pass_using_data_parallelism(model, input, device_ids, output_id):
    """Run one data-parallel forward pass of ``model`` over ``device_ids``.

    Args:
        model: Module to replicate.
        input: Input to split across the devices.
        device_ids: Devices to run replicas on.
        output_id: Device on which the per-replica outputs are gathered.

    Returns:
        The outputs of all replicas gathered on ``output_id``.
    """
    device_ids = list(device_ids)
    inputs = nn.parallel.scatter(input, device_ids)

    # scatter can produce fewer chunks than devices (small batch), and
    # parallel_apply needs exactly one replica per chunk, so only replicate
    # onto the devices that actually received data.
    models = nn.parallel.replicate(model, device_ids[:len(inputs)])

    logit = nn.parallel.parallel_apply(models, inputs)
    logits = nn.parallel.gather(logit, output_id)

    return logits

In [20]:
from torch.utils.data import Dataset

In [None]:
class CachedDataset(Dataset):
    """Dataset that lazily loads a file and packs selected items into a flat cache.

    Args:
        filename: Path handed to ``torch.load`` on the first ``prefetch`` that
            needs data.

    NOTE(review): the loaded object is assumed to be tensor-like (indexable,
    with a ``dtype``) - confirm against the files actually used.
    """

    def __init__(self, filename):
        super().__init__()
        self.filename = filename
        self.data = None
        # Flat 1-D buffer holding the most recently prefetched items.
        self.cache = None
        # Maps a cached item index to its (offset, n_elements) span in ``cache``.
        self.cached_index = {}

    def prefetch(self, idxs):
        """Copy the items at ``idxs`` into ``self.cache``, flattened back to back.

        Does nothing when every requested index is already cached. Otherwise
        the cache is rebuilt to hold exactly the requested items.
        """
        idxs = list(idxs)
        if all(i in self.cached_index for i in idxs):
            return

        # ``is None`` rather than truthiness: bool() of a multi-element tensor raises.
        if self.data is None:
            # NOTE: torch.load unpickles the file; only load trusted files.
            self.data = torch.load(self.filename)

        sizes = [self.data[i].numel() for i in idxs]
        cache = torch.zeros(sum(sizes), dtype=self.data.dtype)
        cached_index = {}

        offset = 0
        for i, n_elements in zip(idxs, sizes):
            cache[offset:offset + n_elements] = self.data[i].reshape(-1)
            cached_index[i] = (offset, n_elements)
            offset += n_elements

        # Publish both together so the index always describes the current cache.
        self.cache = cache
        self.cached_index = cached_index

In [22]:
from queue import Queue
from contextlib import contextmanager
from threading import Thread

In [23]:
def wait_and_execute(device, in_queue, out_queue):
    """Worker loop: run callables taken from ``in_queue`` and report on ``out_queue``.

    Each task is called with no arguments. On success ``(result, True)`` is
    put on ``out_queue``; if the task raises anything, ``(None, False)`` is
    put instead. The loop never exits on its own. ``device`` is not used in
    the body.
    """
    while True:
        job = in_queue.get()

        try:
            result = job()
        except BaseException:
            # Report the failure and keep serving; the exception is dropped.
            out_queue.put((None, False))
        else:
            out_queue.put((result, True))

In [None]:
@contextmanager
def spawn_workers(devices):
    """Start one worker thread per distinct device and yield their queues.

    Args:
        devices: Iterable of hashable device identifiers; repeats are allowed
            and share the worker created for the first occurrence.

    Yields:
        ``(in_queues, out_queues)``: two lists aligned with ``devices``.
        Entry ``k`` of each list is the queue pair of the worker serving
        ``devices[k]``.

    The workers are daemon threads running ``wait_and_execute``; nothing is
    done to stop them when the context exits.
    """
    in_queues = []
    out_queues = []
    workers = {}

    for device in devices:
        try:
            # Reuse the queues of a worker already started for this device.
            # (Unpack into the singular names: rebinding the result lists
            # here would break the appends below.)
            in_queue, out_queue = workers[device]
        except KeyError:
            in_queue = Queue()
            out_queue = Queue()
            workers[device] = (in_queue, out_queue)

            thread = Thread(
                target=wait_and_execute,
                args=(device, in_queue, out_queue),
                daemon=True
            )
            thread.start()

        in_queues.append(in_queue)
        out_queues.append(out_queue)

    yield (in_queues, out_queues)

In [None]:
broadcast, scatter, reduce, gather

In [None]:
clock cycle 1: backward(m, n)
clock cycle 2: backward(m, n-1), backward(m-1, n)
clock cycle 3: backward(m, n-2), backward(m-1, n-1), backward(m-2, n)

In [24]:
from transformers import Trainer

In [None]:
class DistillationTrainer(Trainer):
    """``Trainer`` subclass that additionally holds a teacher model.

    Args:
        teacher_model: Stored on the instance as ``self.teacher_model``.
        *args: Positional arguments forwarded to ``Trainer.__init__``.
        **kwargs: Keyword arguments forwarded to ``Trainer.__init__``.
    """

    def __init__(self, teacher_model, *args, **kwargs):
        # Forward everything else so the base Trainer can actually be
        # configured (model, args, datasets, ...).
        super().__init__(*args, **kwargs)
        # Keep the teacher; previously the argument was silently discarded.
        self.teacher_model = teacher_model

In [None]:
p, n, i, d