### Engineering

In [1]:
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
step 1: loss / (current_epoch / n_epoch)
step 2: calculate the gradients with respect to the normalized loss
step 3: accumulate the gradients
step 4: if current_epoch == n_epoch, update, otherwise, repeat

In [None]:
message passing, sharded memory, file system

In [None]:
(local_rank + 1)%local_world_Sizec

In [5]:
import torch.distributed as dist

In [7]:
class Broadcast(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input
    
    @staticmethod
    def backward(ctx, grad_input):
        dist.all_reduce(grad_input)
        return grad_input

In [6]:
class Gather(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        world_size = dist.get_world_size()
        inputs = [torch.zeros_like(input) for _ in range(world_size)]
        dist.all_gather(inputs, input)
        inputs = torch.cat(inputs)
        return inputs
    
    @staticmethod
    def backward(ctx, grad_input):
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        chunks = torch.split(
            grad_input,
            split_size_or_sections=grad_input.shape[-1]//world_size
        )
        return chunks[rank]

In [4]:
class ColumnParallelLinear(nn.Module):
    def __init__(self, input_size, output_size, world_size):
        super().__init__()
        out_per_partition = world_size // output_size
        
        self.weight = nn.Parameter(torch.randn(
            out_per_partition,
            input_size
        ))
        self.bias = nn.Parameter(torch.randn(
            out_per_partition
        ))
    
    def forward(self, input):
        input_parallel = Broadcast.apply(input)
        output_parallel = F.linear(input_parallel, self.weight, self.bias)
        outputs = Gather.apply(output_parallel)
        return outputs

In [None]:
class RowParallelLinear(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        world_size = dist.get_world_size()
        input_per_partrition = input_size // world_size
        self.weight = nn.Parameter(torch.randn(
            output_size,
            input_per_partrition
        ))
        
        self.bias = nn.Parameter(torch.randn(
            output_size
        ))
    
    def forward(self, input):
        input_parallel = Scatter.apply(input)
        output_parallel = F.linear()

In [None]:
fro

In [None]:
class _Scatter(Function):
    @staticmethod
    def forward(ctx: Any, input: torch.Tensor, dim: int, parallel_context: ParallelContext) -> torch.Tensor:
        ctx.dim = dim
        ctx.parallel_context = parallel_context
        
        world_size = parallel_context.get_world_size(parallel_mode)
        rank = parallel_context.get_local_rank(parallel_mode)

        if world_size == 1:
            return tensor

        assert tensor.size(dim) % world_size == 0

        tensor_list = torch.chunk(tensor, world_size, dim=dim)
        return tensor_list[rank]

        
        return scatter(input, dim=dim, parallel_context=parallel_context, parallel_mode=ParallelMode.TENSOR)

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]:
        dim = ctx.dim
        parallel_context = ctx.parallel_context

        return (
            all_gather(
                grad_output, dim=dim, async_op=False, parallel_context=parallel_context, parallel_mode=ParallelMode.TENSOR
            ),
            None,
            None,
        )

In [10]:
class RowParallelLinear(nn.Module):
    def __init__(
        self,
        in_features,
        out_features,
        parallel_context,
    ) -> None:
        super().__init__()
        in_per_partition = self._get_input_per_partition(in_features, parallel_context)

        self.in_features = in_features
        self.out_features = out_features
        self.parallel_context = parallel_context

        self.weight = nn.Parameter(torch.randn(out_features, in_per_partition))

        self.bias = nn.Parameter(torch.randn(out_features))

    def _get_input_per_partition(self, in_features, parallel_context):
        local_world_size = parallel_context.get_world_size(ParallelMode.TENSOR)
        return in_features // local_world_size

    def forward(self, input):
        input_parallel = scatter_tensor_1d(input, dim=-1, parallel_context=self.parallel_context)
        output_parallel = F.linear(input_parallel, self.weight)
        outputs = reduce_tensor_1d(output_parallel, parallel_context=self.parallel_context)

        return outputs + self.bias

In [None]:
class TestFruit:
    def setup_method(self):
        self.fruit = Fruit("x")
    
    def test_fruit(self):
        assert self.fruit == "x"
        
    def teardown_method(self):
        del self.fruit

In [11]:
world_size = 16

In [12]:
tensor_model_parallel_size = 2

In [13]:
pipeline_model_parallel_size = 4

In [14]:
data_parallel_groups = []

In [15]:
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size

In [16]:
for i in range(pipeline_model_parallel_size):
    start_rank = i*num_pipeline_model_parallel_groups
    end_rank = (i+1)*num_pipeline_model_parallel_groups
    
    for j in range(tensor_model_parallel_size):
        ranks = list(range(
            start_rank+j,
            end_rank,
            tensor_model_parallel_size
        ))
        
        print(ranks)

[0, 2]
[1, 3]
[4, 6]
[5, 7]
[8, 10]
[9, 11]
[12, 14]
[13, 15]


In [None]:
hostdiscovery, torchstate, 3 notif, 

In [None]:
step 1: var
step 2: global
step 3: parallel groups
step 4: set device

In [None]:
Scatter > All reduce > Identity > All-gather

In [None]:
_, cache = model.run_with_cache(tokens)

In [None]:
cache["out", 2]

In [18]:
seq_len, d_model = 4, 16

In [23]:
W_in = torch.zeros(d_model, seq_len)
W_in[0:4, :] = torch.diagonal()

In [24]:
W_in

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [None]:
step 1: prob[0] = sigmoid(logit0 - logit1)
step 2: logit0 = resid @ W_U[0], logit1 = resid @ W_U[1]
step 3: resid @ (W_U[0] - W_U[1])
step 4

In [25]:
d_model = 16

In [26]:
seq_len = 4

In [27]:
W_V = torch.zeros(seq_len, d_model)

In [28]:
W_V[torch.arange(4), torch.arange(4)] = 1.

In [29]:
W_V

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [30]:
W_O = torch.zeros(d_model, seq_len)

In [33]:
W_O[8:11, :] = 1.

In [34]:
W_O

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [66]:
d_head = 4

In [90]:
import torch

In [91]:
d_model, d_head

(16, 4)

Write down $W_V^1$ and $W_{\text {out }}^1$ for head 1 , such that the head copies dimensions 0-3 of its input to 8-11 in its output + there are four input tokens

In [92]:
W_V = torch.zeros(d_head, d_model)
W_V[torch.arange(4), torch.arange(4)] = 1.

In [93]:
W_O = torch.zeros(d_model, d_head)
W_O[7:11, :] = torch.eye(4)

In [94]:
W_V

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [95]:
W_O

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])

In [96]:
W_V.shape, W_O.shape

(torch.Size([4, 16]), torch.Size([16, 4]))

In [None]:
cache["attn", 1]

In [None]:
A@ 

In [None]:
_, clean_activations = model.run_with_cache(clean_tokens)
_, corrupted_activations = model.run_with_cache(corrupted_activations)

In [97]:
from itertools import product
from functools import partial

In [98]:
n_layers = 12
n_heads = 12

In [99]:
combinations = product(range(n_layers), range(n_heads))

In [103]:
def patch_head(
    activations,
    hook,
    clean_activations,
    corrupted_activations,
    target_head
):
    target_layer_idx, target_head_idx = target_head
    if hook.layer() == target_layer_idx:
        activations[:, head_idx] = corrupted_activations[hook.name][:, head_idx]
    else:
        activations = clean_activations[hook.name][:, head_idx]
    
    return activations

In [104]:
from transformer_lens.utils import get_act_name

In [None]:
results = torch.zeros(n_layers, n_heads)

for layer_idx, head_idx in combinations:
    model.reset_hooks()
    
    hook_name = get_act_name("z", layer_idx)
    hook_func = partial(
        patch_head,
        clean_activations=clean_activations,
        corrupted_activations=corrupted_activations,
        target_head=(layer_idx, head_idx)
    )
    
    patched_logits = model.run_with_hooks(clean_tokens)
    logit_diff = compute_ioi_metric(patched_logits)
    results[layer_idx, head_idx] = logit_diff

In [None]:
_, cache = model.run_with_cache(tokens)

In [None]:
W_in_acts = cache["post", layer_idx]
W_out = model.W_out[layer_idx]

In [105]:
from einops import einsum

In [None]:
einsum(W_in_acts, W_out)

In [None]:
A@W_OV

In [None]:
_, cache = model.run_with_cache(board_history)

In [None]:
cache[]

In [None]:
_, cache = model.run_with_cache(clean_tokens)

In [106]:
layer_idx, head_idx = 9, 9

In [None]:
attn_out = cache["z", layer_idx][:, :, head_idx]

W_U = model.W_U
io_dir = W_U[:, io_tokens]
s_dir = W_U[:, s_tokens]

In [None]:
projection_in_io_dir = (attn_out * io_dir).sum()
projection_in_s_dir = (attn_out * s_dir).sum()

In [None]:
attn_pattern = cache["pattern", layer_idx][:, head_idx]
attn_from_end_to_io = attn_pattern[:, end_idxs, io_idxs]
attn_from_end_to_s = attn_pattern[:, end_idxs, s_idxs]

In [None]:
_, cache = model.run_with_cache(past_moves)

In [None]:
top_neurons = cache["post", 2].std(dim=[0, 1]).argsort(descending=True)

In [None]:
W_out = model.W_out[2]

In [None]:
W_V = torxh.

In [None]:
A@x@W_OV@W_OV

In [112]:
class Scatter(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        chunks = torch.split(
            inputs,
            split_size_or_sections=input.shape[0] // world_size
        )
        return chunks[rank]

    @staticmethod
    def backward(ctx, grad_input):
        world_size = dist.get_world_size()
        grads = [torch.zeros_like(grad_input) for _ in range(world_size)]
        dist.all_gather(grads, grad_input)
        grads = torch.cat(grads)
        return grads

In [111]:
class Reduce(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        dist.all_reduce(input)
        return input
    
    @staticmethod
    def backward(ctx, grad_input):
        return grad_input

In [108]:
class RowParallelLinear(torch.autograd.Function):
    def __init__(self, input_size, output_size):
        super().__init__()
        world_size = dist.get_world_size()
        self.weight = nn.Parameter(torch.randn(
            output_size,
            input_size//world_size
        ))
        self.bias = nn.Parameter(torch.randn(
            output_size
        ))
    
    def forward(self, input):
        input_parallel = Scatter.apply(input)
        output_parallel = F.linear(input_parallel, self.weight)
        outputs = Reduce.apply(output_parallel)
        return outputs + self.bias

In [None]:
scatter > all-reduce > identity > all-gather

In [None]:
broadcast > gather > scatter > all-reduce

In [None]:
scatter > all-reduce > identty > all gather 

In [113]:
from torch import nn

In [114]:
embedding = nn.Embedding(4, 2)

In [116]:
embedding

Embedding(4, 2)

In [124]:
input = torch.tensor([[55]])

In [127]:
100 // 5

20

In [128]:
weight = torch.randn(4, 2)

In [129]:
chunks = torch.split(weight, 2)

In [130]:
chunks[0]

tensor([[-0.2077, -1.2928],
        [ 0.4192, -0.6931]])

In [131]:
chunks[1]

tensor([[ 0.5300,  0.5838],
        [-1.2980,  0.2849]])

In [133]:
weight

tensor([[-0.2077, -1.2928],
        [ 0.4192, -0.6931],
        [ 0.5300,  0.5838],
        [-1.2980,  0.2849]])