### Tensor Parallelism

##### Example 1

In [None]:
import torch

In [None]:
# Input matrix of shape (2, 4): rows are [0, 1, 2, 3] and [4, 5, 6, 7].
inputs = torch.arange(8).reshape(2, 4)

In [None]:
# Weight matrix of shape (4, 2): columns are [10, 11, 12, 13] and [14, 15, 16, 17].
weights = torch.stack([torch.arange(10, 14), torch.arange(14, 18)], dim=1)

In [None]:
# Unsharded reference product.
outputs = torch.matmul(inputs, weights)

In [None]:
outputs.size()

torch.Size([2, 2])

In [None]:
import torch

In [None]:
inputs.size(), weights.size()

(torch.Size([2, 4]), torch.Size([4, 2]))

In [None]:
# Reference result that the sharded computations are compared against.
outputs = torch.matmul(inputs, weights)

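Compute **the matrix multiplication** `inputs @ weights` using tensor parallelism with **a factor of 2**, i.e. split the work into two shards: first by splitting the weight columns, then by splitting the weight rows.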

In [None]:
def by_column_parallelism(inputs, weights):
    """Compute ``inputs @ weights`` with the weight columns split into two shards.

    The columns of ``weights`` are cut at the midpoint of its last dimension,
    ``inputs`` is multiplied by each shard separately, and the two partial
    results are concatenated along the last dimension.

    Args:
        inputs: Left operand of the product; used unsplit against both shards.
        weights: Right operand; sliced as a 2-D tensor along its columns.

    Returns:
        The concatenation of the two per-shard products along ``dim=-1``.
    """
    split_at = weights.shape[-1] // 2
    # Each shard keeps every row but only a contiguous range of columns.
    shards = (weights[:, :split_at], weights[:, split_at:])
    partials = [inputs @ shard for shard in shards]
    return torch.cat(partials, dim=-1)

In [None]:
column_output = by_column_parallelism(inputs=inputs, weights=weights)

In [None]:
# Element-wise comparison of the reference product with the column-sharded result.
# The reference tensor is named `outputs`; `output` is never defined in this notebook.
outputs == column_output

tensor([[True, True],
        [True, True]])

In [None]:
def by_row_parallelism(inputs, weights):
    """Compute ``inputs @ weights`` with the weight rows split into two shards.

    The contraction dimension (rows of ``weights``, last dimension of
    ``inputs``) is cut at its midpoint. Each half of ``inputs`` is multiplied
    by the matching half of ``weights`` and the two partial products, which
    both have the full output shape, are summed.

    Args:
        inputs: Left operand; sliced along its last dimension.
        weights: Right operand; sliced as a 2-D tensor along its rows.

    Returns:
        The element-wise sum of the two per-shard products.
    """
    # Split on the shared inner dimension, not on the number of output
    # columns, so both shards get an (approximately) equal share of the work.
    half = weights.shape[0] // 2
    x1, x2 = inputs[..., :half], inputs[..., half:]
    w1, w2 = weights[:half, :], weights[half:, :]
    out1 = x1 @ w1
    out2 = x2 @ w2
    # Each partial product covers only part of the inner sum; adding them
    # recovers the full matrix product.
    return out1 + out2

In [None]:
row_output = by_row_parallelism(inputs=inputs, weights=weights)

In [None]:
# Element-wise comparison of the reference product with the row-sharded result.
# The reference tensor is named `outputs`; `output` is never defined in this notebook.
outputs == row_output

tensor([[True, True],
        [True, True]])

### Transformer Block

##### Example 1

In [None]:
from torch import nn

In [None]:
class VocabParallelEmbedding(nn.Module):
    def forward(self, x):
        