### Engineering

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
import torch.distributed as dist
import torch.distributed.rpc as rpc

In [3]:
class ParallelContext:
    def init_rpc_worker(self, host, port):
        if self.pipeline_parallel_size > 1:
            rank = self.get_local_rank(ParallelMode.GLOBAL)
            world_size = self.get_world_size(ParallelMode.GLOBAL)
            
            init_method = f"rpc://{host}:{port}"
            options = rpc.RpcBackendOptions(
                init_method=init_method
            )
            
            if torch.cuda.is_available():
                ranks = self.get_ranks_in_group(ParallelMode.GLOBAL)
                
                worker_mapping = {
                    rank: WORKER_NAME.format(rank)
                    for rank in ranks
                }
                
                for other in ranks:
                    if rank == other:
                        continue
                    options.set_device_map(worker_mapping[other], {rank: other})
            
            rpc.init_rpc_worker(
                name=WORKER_NAME.format(rank),
                rank=rank,
                world_size=world_size,
                options=options
            )

In [None]:
for group in groups:
    dist.barrier()
    dist.destroy_process_group(group=group)

In [None]:
mine_dir = linear_probe[.., 2]
theirs_dir = linear_probe[..., 1]

In [None]:
mine_vs_theirs = mine_dir - theirs_dir

In [None]:
_, cache = model.run_with_cache(board_history)

In [4]:
from transformer_lens.utils import get_act_name

In [6]:
hook_name = get_act_name("attn_out", 6)

In [None]:
pattern = cache[hook_name][2, 20, :]

In [7]:
from einops import einsum

In [None]:
einsum(pattern, mine_vs_theirs)

In [None]:
step 1: prediction position
step 2: head
step 3: attn pattern
step 4: important component
step 5: path patching
step 6: repeat

In [None]:
monitor
reassign worker


In [8]:
from torch.distributed.rpc import RRef

In [None]:
class Agent:
    def __init__(self):
        self.id = RRef(self)

In [9]:
from einops import rearrange

In [None]:
def compute_loss(model, xb, yb):
    logits = model(xb)
    logits = rearrange(
        logits,
        "batch_size seq_len vocab_size -> (batch_size seq_len) vocab_size"
    )
    yb = rearrange(
        yb,
        "batch_size seq_len -> (batch_size seq_len)"
    )
    loss = F.cross_entropy(logits, yb)

In [10]:
def calculate_discounted_return_each_timestep(rewards, discount_factor):
    discounted_returns = []
    
    for i in range(len(rewards)):
        discounted_return = 0
        
        for k, reward in enumerate(rewards[i:]):
            discounted_return += (discount_factor**k)*reward
        
        discounted_returns.append(discounted_return)
    
    return discounted_returns

In [11]:
rewards = torch.tensor([1, 2, 3, 4, 5])

In [12]:
calculate_discounted_return_each_timestep(rewards, discount_factor=0.99)

[tensor(14.6045), tensor(13.7419), tensor(11.8605), tensor(8.9500), tensor(5.)]

In [13]:
from torch.distributions import Categorical

In [14]:
class Agent(nn.Module):
    def __init__(self, n_observations, n_actions, n_hidden):
        self.actor = nn.Sequential([
            nn.Linear(n_observations, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, n_actions)
        ])
        self.critic = nn.Sequential([
            nn.Linear(n_observations, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, n_hidden),
            nn.Tanh(),
            nn.Linear(n_hidden, 1),
            nn.Sigmoid()
        ])
    
    def get_action_and_value(self, observations):
        logits = self.actor(observations)
        dist = Categorical(logits=logits)
        action = dist.sample()
        log_prob = dist.log_prob()
        entropy = dist.entropy()
        critic_value = self.critic(observations)
        return action, log_prob, entropy, critic_value

In [15]:
3 // 5

0

In [16]:
8 // 5

1

In [4]:
import torch
from torch import nn

In [5]:
shape = (2, 4,)

In [7]:
nn.Linear(*shape)

Linear(in_features=2, out_features=4, bias=True)

In [79]:
QUEUE = []

In [89]:
import time

In [90]:
class TriggerBackward2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, key, input):
        ctx.key = key
        return input
    
    @staticmethod
    def backward(ctx, grad_input):
        time.sleep(2)
        QUEUE.append(ctx.key)
        return (None, grad_input)

In [91]:
class TriggerBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, key, input):
        ctx.key = key
        return input
    
    @staticmethod
    def backward(ctx, grad_input):
        time.sleep(3)
        QUEUE.append(ctx.key)
        return (None, grad_input)

In [92]:
f1 = nn.Linear(2, 4)

In [93]:
x = torch.randn(4, 2, requires_grad=True)

In [94]:
output = f1(TriggerBackward.apply(1, x))

In [95]:
output = TriggerBackward2.apply(2, output)

In [96]:
output.sum().backward()

In [97]:
QUEUE

[2, 1, 2, 1]