### Engineering

In [1]:
import socketserver

In [None]:
# Start a threaded TCP server that dispatches each connection to EchoRequestHandler.
# serve_forever() blocks this cell until server.shutdown() is called from another thread.
with socketserver.ThreadingTCPServer(
    (MASTER_HOST, MASTER_PORT),
    EchoRequestHandler
) as server:
    server.serve_forever()

In [2]:
import threading

In [3]:
event = threading.Event()

In [None]:
def run_worker():
    """Print a status line, block on the shared ``event``, then print another.

    Intended as a thread target; it only returns after some other thread
    calls ``event.set()``.
    """
    print("waiting")
    _ = event.wait()  # blocks until the event flag is set
    print("received")

In [None]:
# Thread wrapper around run_worker; it is created here but not started.
worker_thread = threading.Thread(target=run_worker)

In [None]:
detach input
enable
save

In [5]:
from socketserver import StreamRequestHandler

In [None]:
class EchoRequestHandler(StreamRequestHandler):
    """Request handler that prints the first line received on each connection."""

    def handle(self):
        # `rfile` is a binary file-like object (not a method); readline() returns
        # the raw bytes of one line including its newline, or b"" at EOF.
        print(self.rfile.readline())

In [6]:
import torch
from torch import nn
import torch.nn.functional as F

In [7]:
class Checkpoint(torch.autograd.Function):
    """Run ``function(input)`` without recording a graph; backprop through a recomputation.

    Forward evaluates ``function`` under ``no_grad`` so no intermediate
    activations are kept. Backward pops an ``(output, input_leaf)`` pair from
    the shared ``recomputed`` stack, where ``output`` was rebuilt from
    ``input_leaf`` with autograd enabled. It then backpropagates
    ``grad_output`` through that rebuilt graph and returns ``input_leaf.grad``
    as the gradient for ``input``.

    NOTE(review): assumes the companion Recompute step has already pushed the
    pair onto ``recomputed`` by the time this backward runs.
    """

    @staticmethod
    def forward(ctx, phony, recomputed, function, input):
        # `phony` only orders autograd nodes; it carries no data.
        ctx.recomputed = recomputed

        with torch.no_grad():
            output = function(input)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Use the leaf the recomputation was built from. Detaching again here would
        # create a new leaf unconnected to `output`, and its .grad would stay None.
        output, input_leaf = ctx.recomputed.pop()

        with torch.enable_grad():
            torch.autograd.backward(output, grad_output)

        grad_input = input_leaf.grad if input_leaf.requires_grad else None

        # Autograd requires a tuple with one entry per forward input:
        # (phony, recomputed, function, input).
        return None, None, None, grad_input

In [None]:
save
sync
restore
reset

In [8]:
from torch.multiprocessing import Process

In [None]:
# Start one process per rank, keep the handles, then wait for all of them so the
# cell does not return while children are still running (no orphaned processes).
processes = []
for _ in range(world_size):
    process = Process(target=say_hello)
    process.start()
    processes.append(process)

for process in processes:
    process.join()

In [None]:
class Recompute(torch.autograd.Function):
    """Rebuild ``function(input)`` with autograd enabled during the backward pass.

    Forward only records what it needs and returns ``phony`` unchanged, which
    places this node in the backward graph without touching the data path.
    Backward detaches ``input`` into a fresh leaf and re-runs ``function`` on it
    with grad enabled. It pushes ``(output, input_leaf)`` onto the shared
    ``recomputed`` stack for the matching Checkpoint backward to consume.
    """

    @staticmethod
    def forward(ctx, phony, recomputed, function, input):
        ctx.recomputed = recomputed
        ctx.function = function
        # save_for_backward also detects in-place modification of `input` before backward.
        ctx.save_for_backward(input)

        return phony

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # New leaf so the recomputed graph ends here and its .grad can be read later.
        input_leaf = input.detach().requires_grad_(input.requires_grad)

        with torch.enable_grad():
            output = ctx.function(input_leaf)

        ctx.recomputed.append((output, input_leaf))

        # No gradient flows through this node; one entry per forward input:
        # (phony, recomputed, function, input).
        return None, None, None, None

In [9]:
class VocabParallelEmbedding(nn.Module):
    """Embedding table whose vocabulary rows are split evenly across ranks.

    Rank ``r`` owns global rows ``[r * n, (r + 1) * n)`` with
    ``n = num_embeddings // world_size``. In forward, tokens owned by other
    ranks look up a zero vector locally. The partial results are then summed
    with ``torch.distributed.all_reduce``, so each token's vector comes from
    its owning rank.

    Requires an initialised ``torch.distributed`` process group.

    Raises:
        ValueError: if ``num_embeddings`` is not divisible by the world size
            (the remainder rows would otherwise be silently dropped).
    """

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        world_size = torch.distributed.get_world_size()
        if num_embeddings % world_size != 0:
            raise ValueError(
                f"num_embeddings ({num_embeddings}) must be divisible by "
                f"the world size ({world_size})"
            )

        self.num_embeddings_per_partrition = num_embeddings // world_size
        self.embedding_dim = embedding_dim

        self.weight = nn.Parameter(torch.randn(
            self.num_embeddings_per_partrition,
            self.embedding_dim
        ))

        self.vocab_start_idx, self.vocab_end_idx = self.get_vocab_range(
            self.num_embeddings_per_partrition
        )

    def get_vocab_range(self, num_embeddings_per_partrition):
        """Return ``(start, end)`` of this rank's vocabulary slice; ``end`` is exclusive."""
        rank = torch.distributed.get_rank()
        start_idx = rank * num_embeddings_per_partrition
        end_idx = start_idx + num_embeddings_per_partrition
        return start_idx, end_idx

    def forward(self, tokens):
        # True where the token belongs to another rank (end index is exclusive).
        masked = (tokens < self.vocab_start_idx) | (tokens >= self.vocab_end_idx)

        # Shift to local row ids; foreign tokens temporarily point at row 0
        # (a valid index) and their vectors are zeroed right after the lookup.
        local_tokens = tokens - self.vocab_start_idx
        local_tokens[masked] = 0

        embeddings = F.embedding(local_tokens, self.weight)
        embeddings[masked] = 0.0

        # Sum partial lookups across ranks. NOTE(review): this in-place collective is
        # not tracked by autograd, so backward treats it as identity — confirm intended.
        torch.distributed.all_reduce(embeddings)

        return embeddings

In [10]:
from socketserver import StreamRequestHandler

In [11]:
class EchoRequestHandler(StreamRequestHandler):
    """Placeholder request handler whose ``handle`` does nothing.

    NOTE(review): this cell redefines ``EchoRequestHandler`` and shadows the
    earlier definition in the notebook namespace.
    """

    def handle(self):
        return None

In [None]:
monitor the nodes
if a node fails, reassign its work
if a new node joins, form a new communication ring

In [12]:
import threading

In [13]:
data = threading.local()

In [None]:
# Hand the thread-local `data` object to a new (not yet started) thread;
# attributes on a threading.local are per-thread, so the thread gets its own view.
thread = threading.Thread(target=print_and_modify, args=(data,))

In [17]:
class StartBatch(torch.autograd.Function):
    """Identity on ``input`` that additionally emits an empty "phony" tensor.

    The phony output holds no data. Feeding it into another autograd node
    adds an edge to the backward graph without changing any values.
    """

    @staticmethod
    def forward(ctx, input):
        empty_phony = torch.empty(0, requires_grad=False, device=input.device)
        outputs = (input, empty_phony)
        return outputs

    @staticmethod
    def backward(ctx, grad_output, grad_phony):
        # The phony output carries no data, so its gradient is discarded.
        return grad_output

In [18]:
class EndDependency(torch.autograd.Function):
    """Identity on ``input`` that also takes ``phony`` as an autograd input.

    ``phony`` is ignored in the computation; taking it as an input links this
    node to whichever node produced ``phony``.
    """

    @staticmethod
    def forward(ctx, input, phony):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # Pass the gradient straight through; `phony` receives none.
        grads = (grad_output, None)
        return grads

In [14]:
def create_dependency(start_batch, end_batch):
    """Link two tensors in the autograd graph through an empty phony tensor.

    ``start_batch`` is routed through ``StartBatch``, which also emits a phony
    tensor. That phony is fed into ``EndDependency`` together with
    ``end_batch``. Both returned tensors hold the same values as the inputs.
    Because the phony is an input of the EndDependency node, autograd will not
    run the StartBatch node's backward until EndDependency's backward has
    produced the phony's gradient.

    Returns:
        ``(start_batch, end_batch)`` — the rewired tensors to use downstream.
    """
    start_batch, phony = StartBatch.apply(start_batch)
    end_batch = EndDependency.apply(end_batch, phony)
    return start_batch, end_batch

In [19]:
n_microbatches = 4

In [20]:
n_partritions = 3

In [21]:
n_clock_cycles = n_microbatches+n_partritions-1

In [22]:
n_clock_cycles

6

In [23]:
# For each clock tick, list the (microbatch_idx, partrition_idx) pairs that can
# run concurrently: partition p works on microbatch (clock_idx - p).
for clock_idx in range(n_clock_cycles):
    start_partrition = max(0, clock_idx + 1 - n_microbatches)
    end_partrition = min(n_partritions, clock_idx + 1)

    tasks = [
        (clock_idx - partrition_idx, partrition_idx)
        for partrition_idx in range(start_partrition, end_partrition)
    ]

    print(tasks)

[(0, 0)]
[(1, 0), (0, 1)]
[(2, 0), (1, 1), (0, 2)]
[(3, 0), (2, 1), (1, 2)]
[(3, 1), (2, 2)]
[(3, 2)]


In [None]:
load, save, move

In [None]:
main thread > worker thread > task > cuda stream

In [None]:
service: receive
manager: handle
client: send

In [None]:
service, manager, client

In [None]:
load, save, move

In [None]:
tokens = model.to_tokens(text)

In [None]:
seq_len = tokens.shape[-1]

In [None]:
target_pattern = torch.zeros(seq_len, seq_len)

In [None]:
# Row i gets a 1 in column i-1. Start at 1 so row 0 is not written;
# with index -1 it would wrap around to the last column.
positions = torch.arange(1, seq_len)
target_pattern[positions, positions - 1] = 1

In [None]:
# Row 0 has no previous position, so it is left as all zeros.
target_pattern[0] = 0

In [None]:
step 1: create a socket server
step 2: socket client requests a connection
step 3: accept
step 4: TODO

In [None]:
Checkpoint.forward() > Recompute.forward() > Recompute.backward() > Checkpoint.backward()

In [None]:
step 1: bind
step 2: connect
step 3: accept
step 4: exchange data (send / recv)
step 5: close

### MLE

In [24]:
import great_expectations as ge

In [None]:
ge.dataset.PandasDataset

In [25]:
from evidently.metric_preset import DataDriftPreset

In [26]:
from evidently.report import Report

In [None]:
# Evidently report configured with the DataDriftPreset metric bundle.
report = Report(metrics=[DataDriftPreset()])

In [28]:
from pydantic import BaseModel
from pydantic import root_validator

In [None]:
class User(BaseModel):
    """Model holding two password fields that must match.

    NOTE(review): passwords are typed as ``int``; ``str`` is more usual — confirm.
    """

    password1: int
    password2: int

    # skip_on_failure=True: only compare once both fields validated individually.
    # pydantic v2 also requires it for a (deprecated) post root_validator.
    @root_validator(skip_on_failure=True)
    def validate_password(cls, values):
        """Raise ``ValueError`` if the two passwords differ; otherwise return ``values``."""
        if values.get("password1") != values.get("password2"):
            raise ValueError("passwords do not match")
        return values

In [None]:
tokens = tokenizer(text, return_tensors="pt")

In [29]:
def print_shape(module, input):
    """Forward pre-hook that prints the shape of each tensor positional argument.

    PyTorch passes a module's positional inputs to a forward pre-hook as a
    tuple, so shapes are read from the tuple's elements, not from the tuple
    itself. Non-tensor arguments are skipped. Returns None, so the inputs are
    left unchanged.
    """
    for arg in input:
        if isinstance(arg, torch.Tensor):
            print(arg.shape)

In [None]:
# Keep the handle: the hook fires on every forward pass until
# print_shape_handle.remove() is called.
print_shape_handle = model.blocks[1].register_forward_pre_hook(print_shape)

In [None]:
step 1: two prompts
step 2: record the activations
step 3: iteratively replace activations in the corrupted-prompt run with the corresponding activations from the clean prompt
step 4: TODO

In [None]:
_, cache = model.run_with_cache(tokens)

In [30]:
from transformer_lens.utils import get_act_name

In [31]:
hook_name = get_act_name("pattern", 0, "attn")

In [None]:
attention_patterns = cache[hook_name]

In [None]:
str_tokens = model.to_str_tokens(tokens)

In [32]:
import circuitsvis as cv

In [None]:
# NOTE(review): cache[hook_name] presumably has shape (batch, n_heads, query_pos, key_pos)
# because run_with_cache was called on batched `tokens` — confirm. Select the first
# prompt and use attention_patterns (plural), which renders one view per head.
cv.attention.attention_patterns(
    tokens=str_tokens,
    attention=attention_patterns[0]
)

In [None]:
tokens = model.to_tokens(text)

In [None]:
# Shift along the position axis (last dim), not the batch axis.
# NOTE(review): assumes `tokens` is (batch, pos), as returned by model.to_tokens — confirm.
target_tokens = tokens[:, 1:]

In [None]:
_, cache = model.run_with_cache(tokens)

In [None]:
# The unembedding matrix is a model weight; the activation cache is keyed by hook
# names, so it is read from the model rather than from `cache`.
W_U = model.W_U

In [None]:
W_U_correct_tokens = W_U[:, target_tokens]

In [None]:
# NOTE(review): if `final_residual_stream` is (..., d_model) and `unembed` is
# (d_model, d_vocab), the operands must be swapped — confirm shapes.
logits = torch.matmul(unembed, final_residual_stream)


In [None]:
query, key, value, attention pattern, output

In [None]:
torch.roll(x, shifts=1, dims=0)

In [33]:
from einops import einsum

In [None]:
# Transpose: every output axis must appear in the input pattern ("hb" did not).
output = einsum(a, "h w -> w h")

In [None]:
text encoder, image encoder, projection head, contrastive loss

In [34]:
futs = []

In [35]:
import torch.distributed.rpc as rpc

In [None]:
# RRef.rpc_async() returns a proxy whose method calls run on the RRef's owner and
# immediately return a Future, so all observers' `run` calls proceed concurrently.
for ob_rref in obj_rrefs:
    futs.append(ob_rref.rpc_async().run())

In [None]:
# Block until every pending RPC future has completed (errors are re-raised here).
for pending_future in futs:
    pending_future.wait()

In [None]:
name, id, WorkerInfo

In [None]:
nn.Flatten(start_dim=2, end_dim=-1)

In [None]:
ob_rref.owner()

In [None]:
# torch.roll takes `dims` (plural); `dim` raises a TypeError.
torch.roll(x, shifts=1, dims=0)

In [None]:
func = nn.Flatten(start_dim=2, end_dim=3)

In [None]:
func(batch)

In [None]:
population activity, motor imagery, stimulus, conditioned

In [None]:
stimulate, record, memory, process

In [None]:
import jax

In [None]:
vmapped_func = jax.vmap(multiply)

In [None]:
def summary_batch(article_batches, max_context_length):
    """Tokenize each batch of articles, truncating to ``max_context_length`` tokens.

    Args:
        article_batches: iterable of batches, each passed as-is to ``tokenizer``.
        max_context_length: maximum number of tokens kept per article.

    Returns:
        A list with one tokenizer output per batch, in input order.

    NOTE(review): relies on a notebook-global ``tokenizer`` and assumes it is a
    Hugging Face-style tokenizer accepting ``truncation``/``max_length`` — confirm.
    """
    tokenized_batches = []
    for batch in article_batches:
        tokenized_batches.append(
            tokenizer(batch, truncation=True, max_length=max_context_length)
        )
    return tokenized_batches

In [None]:
class Agent(nn.Module):
    """Actor-critic network with separate MLPs for the policy and the value estimate."""

    def __init__(self, n_observations, n_actions, n_hidden):
        super().__init__()
        # Policy: unnormalised logits over the `n_actions` discrete actions.
        self.actor_network = nn.Sequential(
            nn.Linear(n_observations, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, n_actions)
        )

        # Value: one scalar per observation. nn.Sigmoid accepts no `dim` argument.
        # NOTE(review): no squashing activation is applied, since value targets are
        # generally not confined to (0, 1); add one only if they are.
        self.critic_network = nn.Sequential(
            nn.Linear(n_observations, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, 1)
        )

    def get_action_and_value(self, observations, action=None):
        """Sample (or score) actions and evaluate the critic.

        Args:
            observations: input features, last dim ``n_observations``.
            action: optional actions to score. If None, actions are sampled
                from the current policy.

        Returns:
            ``(action, log_prob, entropy, critic_value)``. ``log_prob`` is the
            log-probability of each chosen action, ``entropy`` is the policy
            entropy, and ``critic_value`` has a trailing dimension of size 1.
        """
        logits = self.actor_network(observations)
        distribution = torch.distributions.Categorical(logits=logits)
        if action is None:
            action = distribution.sample()
        critic_value = self.critic_network(observations)

        return action, distribution.log_prob(action), distribution.entropy(), critic_value