In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [2]:
import torch.distributed as dist

In [4]:
class Reduce(torch.autograd.Function):
    """Sum a tensor across the tensor-parallel process group.

    Forward: every rank in the ``ParallelMode.TENSOR`` group contributes ``input`` and
    receives the element-wise sum. Backward: the incoming gradient is passed through
    unchanged (``None`` is returned for the non-tensor ``parallel_context`` argument).
    """

    @staticmethod
    def forward(ctx, input, parallel_context):
        group = parallel_context.get_group(ParallelMode.TENSOR)
        # `dist.all_reduce` reduces its argument IN PLACE and returns None (or a work
        # handle with async_op=True), so its return value must not be used as the result.
        # Reduce a copy so the tensor autograd saw as the input is not mutated.
        output = input.clone()
        dist.all_reduce(output, group=group)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

In [5]:
class ParallelEmbedding(nn.Module):
    """Embedding table whose rows (vocabulary ids) are sharded across the tensor-parallel group.

    Each rank stores ``num_embeddings // world_size`` consecutive rows. In ``forward`` it
    looks up only the ids inside its own range, zeroes the rows for ids owned by other
    ranks, and sums the partial results across ranks with ``Reduce``.
    """

    def __init__(self, num_embeddings, embedding_dim, parallel_context):
        super().__init__()
        # Stored because forward() needs it for the cross-rank reduction.
        self.parallel_context = parallel_context
        world_size = parallel_context.get_world_size(ParallelMode.TENSOR)

        # NOTE(review): integer division silently drops the last
        # `num_embeddings % world_size` ids; presumably the vocab is padded to a
        # multiple of world_size — confirm.
        num_embedding_per_partition = num_embeddings // world_size
        self.weight = nn.Parameter(torch.randn(
            num_embedding_per_partition,
            embedding_dim
        ))
        self.vocab_start_idx, self.vocab_end_idx = self._get_vocab_range(
            num_embedding_per_partition,
            parallel_context
        )

    def _get_vocab_range(self, num_embedding_per_partition, parallel_context):
        """Return the half-open id range ``[start, end)`` owned by this rank."""
        rank = parallel_context.get_local_rank(ParallelMode.TENSOR)
        start_idx = rank * num_embedding_per_partition
        end_idx = start_idx + num_embedding_per_partition
        return start_idx, end_idx

    def forward(self, input):
        # True where the id belongs to another rank's shard.
        input_mask = (input < self.vocab_start_idx) | (input >= self.vocab_end_idx)
        # Shift ids into this shard's local range; foreign ids are pointed at row 0 so
        # the lookup stays in bounds, and their rows are zeroed right after.
        masked_input = (input - self.vocab_start_idx).masked_fill(input_mask, 0)

        parallel_embeddings = F.embedding(masked_input, self.weight)
        parallel_embeddings = parallel_embeddings.masked_fill(input_mask.unsqueeze(-1), 0.)

        # Each id is non-zero on exactly one rank, so the cross-rank sum is the full lookup.
        embeddings = Reduce.apply(parallel_embeddings, self.parallel_context)
        return embeddings

In [None]:
# step 1: send the metadata
# step 2: send the data
# step 3: construct the receiving tensor from the metadata
# step 4: fill it with the received data

In [None]:
# step 1: map
# step 2: convert
# step 3: send

In [7]:
class Broadcast(torch.autograd.Function):
    """Identity in forward; sum of the gradient across ranks in backward.

    Wraps a replicated input right before it enters a tensor-parallel region (see
    ``ColumnParallelLinear``). Each rank only computes the gradient contribution of its
    own weight shard, so the full input gradient is the sum over ranks. If
    torch.distributed is unavailable or not initialized, backward is a pass-through.
    NOTE(review): reduces over the default process group, like ``Gather``.
    """

    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # The original `pass` returned None, which silently cut the gradient to the input.
        if not (dist.is_available() and dist.is_initialized()):
            return grad_output
        grad_input = grad_output.clone()  # all_reduce is in place; don't mutate grad_output
        dist.all_reduce(grad_input)
        return grad_input

In [8]:
class Gather(torch.autograd.Function):
    """All-gather per-rank shards and concatenate them along the LAST dimension.

    Forward: each rank holds a shard of shape ``(..., d)``; the result is
    ``(..., world_size * d)`` with shards ordered by rank. Backward: each rank keeps only
    the gradient slice matching its own shard. If torch.distributed is unavailable or not
    initialized (single process), both directions are the identity.
    """

    @staticmethod
    def _split_last_dim(tensor, world_size, rank):
        """Return the ``rank``-th of ``world_size`` equal chunks of ``tensor`` along dim -1."""
        chunk_size = tensor.shape[-1] // world_size
        return torch.split(tensor, chunk_size, dim=-1)[rank].contiguous()

    @staticmethod
    def forward(ctx, input):
        if not (dist.is_available() and dist.is_initialized()):
            return input
        world_size = dist.get_world_size()
        input = input.contiguous()  # all_gather requires contiguous tensors
        shards = [torch.empty_like(input) for _ in range(world_size)]
        dist.all_gather(shards, input)
        # Shards partition the feature (last) dim, so concatenate there, not along batch.
        return torch.cat(shards, dim=-1)

    @staticmethod
    def backward(ctx, grad_output):
        if not (dist.is_available() and dist.is_initialized()):
            return grad_output
        # Split along the same dim forward concatenated on (the original split dim 0
        # using a size computed from the last dim).
        return Gather._split_last_dim(grad_output, dist.get_world_size(), dist.get_rank())

In [6]:
class ColumnParallelLinear(nn.Module):
    """Linear layer whose output features are split column-wise across ranks.

    Each rank holds ``output_size // world_size`` rows of the weight matrix (and the
    matching bias entries). It computes its slice of the output, and ``Gather``
    concatenates the slices into the full ``output_size`` features.
    """

    def __init__(self, input_size, output_size, world_size):
        super().__init__()
        # NOTE(review): output_size is assumed divisible by world_size; otherwise the
        # trailing features are silently dropped.
        per_partition = output_size // world_size

        self.weight = nn.Parameter(torch.randn(
            per_partition,
            input_size
        ))
        self.bias = nn.Parameter(torch.randn(
            per_partition
        ))

    def forward(self, input):
        # The linear op must consume the Broadcast output. Otherwise Broadcast's backward
        # (which sums each rank's partial input gradient) never runs.
        input_parallel = Broadcast.apply(input)
        output_parallel = F.linear(input_parallel, self.weight, self.bias)
        outputs = Gather.apply(output_parallel)
        return outputs

In [None]:
# metadata: shape, requires_grad, dtype

In [None]:
# step 1: mask the targets
# step 2: gather the predicted logits
# step 3: all-reduce the predicted logits
# step 4: take the log
# step 5: TODO

In [None]:
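# step 1: partition the embedding weight across ranks
# step 2: mask input ids outside this rank's vocab range
# step 3: look up the local (parallel) embeddings
# step 4: all-reduce them into the full embeddings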

In [None]:
# Record every activation for the clean and the corrupted prompts; the logits are not needed.
clean_cache = model.run_with_cache(clean_tokens)[1]
corrupted_cache = model.run_with_cache(corrupted_tokens)[1]

In [10]:
def patch_head(acts, hook, clean_cache, corrupted_cache, target_head):
    """Hook that patches one attention head's ``z`` with its value from the corrupted run.

    Args:
        acts: activation tensor passed in by the hook point. Assumed to be TransformerLens
            ``hook_z`` layout ``(batch, pos, head_index, d_head)``, so heads are on dim 2.
        hook: the hook point; ``hook.layer()`` and ``hook.name`` are read.
        clean_cache / corrupted_cache: caches indexable by ``hook.name``.
        target_head: ``(layer_idx, head_idx)`` of the head to patch.

    Returns:
        ``acts`` with the target head overwritten in place, on the target layer. On any
        other layer, the clean cached activation.
    """
    target_layer_idx, target_head_idx = target_head

    if hook.layer() == target_layer_idx:
        # Heads live on dim 2; `acts[:, head]` would have patched a sequence position.
        acts[:, :, target_head_idx] = corrupted_cache[hook.name][:, :, target_head_idx]
    else:
        acts = clean_cache[hook.name]
    return acts

In [13]:
from itertools import product
from functools import partial

In [12]:
from transformer_lens.utils import get_act_name

In [None]:
# Activation patching, one head at a time: run the clean prompt while overwriting a single
# head's `z` with its corrupted-run value, and record the IOI metric per (layer, head).
results = torch.zeros(n_layers, n_heads)
combinations = product(range(n_layers), range(n_heads))

# No gradients are needed. It also stops each grad-tracked metric, written into `results`,
# from chaining every iteration's forward graph together in memory.
with torch.no_grad():
    for layer_idx, head_idx in combinations:
        hook_name = get_act_name("z", layer_idx)
        hook_func = partial(
            patch_head,
            clean_cache=clean_cache,
            corrupted_cache=corrupted_cache,
            target_head=(layer_idx, head_idx)
        )
        # run_with_hooks removes the hook after the forward pass. With the previous
        # add_hook + reset_hooks-at-loop-start pattern, the last patch hook stayed
        # attached and silently altered later forward passes (e.g. run_with_cache).
        patched_logits = model.run_with_hooks(
            clean_tokens,
            fwd_hooks=[(hook_name, hook_func)]
        )
        results[layer_idx, head_idx] = compute_ioi_metric(patched_logits)

In [None]:
# input embedding -> weight -> feature -> activations

In [None]:
cache = model.run_with_cache(board_history)[1]  # activations only; logits are discarded

In [14]:
# MLP neuron under study: layer 5, neuron 1393 of `mlp.hook_post`. It is a neuron index,
# not a head index, and the next cells read it as `neuron_idx`.
layer_idx, neuron_idx = 5, 1393

In [15]:
# Post-activation MLP hook for the selected layer.
hook_name = "blocks.{}.mlp.hook_post".format(layer_idx)

In [None]:
# Activations of one MLP neuron at every (batch, position).
# NOTE(review): hook_post is (batch, pos, d_mlp) in TransformerLens, so the neuron axis is
# the last one. `[:, neuron_idx]` would have selected a sequence position instead.
neuron_activations = cache[hook_name][..., neuron_idx]

In [16]:
# Switch focus to attention head 9 of layer 9.
layer_idx = 9
head_idx = 9

In [None]:
cache = model.run_with_cache(clean_tokens)[1]  # activations only; logits are discarded

In [17]:
# Resolves to the attention `z` hook of this layer, e.g. 'blocks.9.attn.hook_z'.
hook_name = get_act_name(name="z", layer=layer_idx)

In [18]:
# Display the resolved hook name as a sanity check.
hook_name

'blocks.9.attn.hook_z'

In [None]:
# This head's write into the residual stream: its z vectors times its slice of W_O.
# NOTE(review): TransformerLens stores hook_z as (batch, pos, head_index, d_head), so the
# head axis is dim 2. `[:, head_idx]` would have selected a sequence position instead.
W_O = model.W_O[layer_idx, head_idx]
output = cache[hook_name][:, :, head_idx] @ W_O

In [None]:
# Unembedding columns of the indirect-object (IO) and subject (S) tokens.
W_U = model.W_U
io_dir, s_dir = W_U[:, io_tokens], W_U[:, s_tokens]

In [19]:
from einops import einsum

In [None]:
# Project the head output onto each prompt's IO / S unembedding direction.
# einops.einsum requires an explicit pattern string as its last argument; without it the
# call fails. NOTE(review): assumes `output` is (batch, pos, d_model) and io_dir / s_dir
# are (d_model, batch), i.e. one token id per prompt — confirm.
projection_in_io_dir = einsum(
    output,
    io_dir,
    "batch pos d_model, d_model batch -> batch pos"
)
projection_in_s_dir = einsum(
    output,
    s_dir,
    "batch pos d_model, d_model batch -> batch pos"
)

In [None]:
# Attention pattern of the selected head. The pattern is (batch, head, query_pos, key_pos),
# so this is (batch, query_pos, key_pos).
attn_prob = cache["pattern", layer_idx][:, head_idx]
# Pair each prompt with its own END / IO / S positions. `[:, end_idxs, io_idxs]` would
# broadcast the index tensors against the whole batch and return a (batch, batch) cross
# product. NOTE(review): assumes the *_idxs hold one position per prompt (or a scalar).
batch_idxs = torch.arange(attn_prob.shape[0], device=attn_prob.device)
attn_from_end_to_io = attn_prob[batch_idxs, end_idxs, io_idxs]
attn_from_end_to_s = attn_prob[batch_idxs, end_idxs, s_idxs]

In [None]:
# Probe class 2 is labelled "mine" and class 1 "theirs". Their difference gives a
# mine-vs-theirs direction per board square.
# NOTE(review): assumes probe axes are (d_model, row, col, class) — confirm.
mine, theirs = linear_probe[..., 2], linear_probe[..., 1]

mine_vs_theirs = mine - theirs
row, col = 5, 4
extracted_direction = mine_vs_theirs[:, row, col]

In [None]:
cache = model.run_with_cache(board_history)[1]  # activations only; logits are discarded

In [None]:
# Board statistics for the examples where MLP neuron L5 N1393 fires most strongly.
# `hook_name` / `layer_idx` were repointed at the layer-9 attention head in earlier cells,
# so name this neuron's hook explicitly instead of reusing them.
neuron_hook_name = "blocks.5.mlp.hook_post"
# NOTE(review): hook_post is (batch, pos, d_mlp); the neuron axis is the last one
# (`[:, 1393]` would have selected a sequence position).
mlp_neurons = cache[neuron_hook_name][..., 1393]
thr = mlp_neurons.quantile(0.99)
top_neurons = mlp_neurons > thr  # mask over (batch, pos): top 1% activations

# Per square: fraction of those top examples in which the board value equals 1.
(board_states == 1)[top_neurons].float().mean(dim=0)

In [None]:
# OV weights of head 0.1 (the writer) and QK weights of head 1.2 (the reader).
W_O, W_V = model.W_O[0, 1], model.W_V[0, 1]
W_Q, W_K = model.W_Q[1, 2], model.W_K[1, 2]

In [20]:
from einops import rearrange

In [None]:
# Full OV and QK circuits. W_K's last two axes are (d_model, d_head); swapping them gives
# (d_head, d_model), so W_Q @ W_K^T maps query-side to key-side residual directions.
W_OV = W_V @ W_O
W_QK = W_Q @ W_K.transpose(-2, -1)

In [None]:
# Q-composition: head 0.1's OV output fed into head 1.2's QK circuit.
virtual_weight = torch.matmul(W_OV, W_QK)

In [None]:
# Kubernetes control-plane components: kube-scheduler, kube-apiserver, etcd, kube-controller-manager

In [None]:
# Gradient accumulation:
# step 1: normalize the loss (loss / n_epoch)
# step 2: compute gradients with respect to the normalized loss
# step 3: sum (accumulate) the gradients
# step 4: if current_epoch == n_epoch, update the weights; otherwise, go back to step 1

In [None]:
class ResidualLayerNorm(nn.Module):
    """Post-norm residual connection: ``LayerNorm(dropout(x) + residual)``.

    Args:
        d_model: size of the last dimension, normalised by the LayerNorm.
        dropout: drop probability applied to ``x`` before it is added to ``residual``.
    """

    def __init__(self, d_model, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, residual):
        sublayer_out = self.dropout(x)
        combined = residual + sublayer_out
        return self.norm(combined)

In [None]:
class ShowerEnv:
    """Placeholder for the ShowerEnv environment.

    TODO: implement. A bare ``class ShowerEnv`` header is a SyntaxError, so this stub keeps
    the cell (and Restart & Run All) working until the environment is written.
    """

In [None]:
def discount_reward(rewards, discount_factors)