<a href="https://colab.research.google.com/github/puneeshkhanna/Tensor-Parallelism/blob/master/tensor_parallelism_ffn_layers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Link - https://medium.com/@puneesh.khanna83/understanding-tensor-parallelism-to-fit-larger-models-on-multiple-devices-d4da1821d41b


In [None]:
import numpy
import torch
import torch.nn as nn

In [None]:
# Input of shape (batch size, sequence length i.e. number of words, embedding dimension i.e. hidden size of each word)
input = torch.randn(size=(1, 5, 10), dtype=torch.float32)

In [None]:
embedding_dim = input.size(dim=2)
embedding_dim

10

## Output of back-to-back linear layers

Transformer architectures have back-to-back linear layers in the MLP block: the hidden dimension first goes from h to 4h and then back from 4h to h. For example, with h = 4096:

```
(mlp): MLP(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=False)          
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=False)
        )
```
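To see why splitting these layers helps fit larger models, here is a quick back-of-the-envelope count for the printout above (h = 4096, 4h = 16384, no bias). The 2 bytes per parameter assumes fp16/bf16 weights; that precision is an assumption for illustration, not something the printout states.

```python
# Parameter and memory arithmetic for the MLP printed above (h = 4096, 4h = 16384).
h = 4096
ffn = 4 * h  # 16384

params_h_to_4h = h * ffn   # weight of dense_h_to_4h (bias=False)
params_4h_to_h = ffn * h   # weight of dense_4h_to_h (bias=False)
total_params = params_h_to_4h + params_4h_to_h

bytes_per_param = 2  # assumption: fp16/bf16 weights
total_mib = total_params * bytes_per_param / 2**20

n_devices = 2
per_device_mib = total_mib / n_devices  # each device holds 1/n of both matrices

print(total_params)    # 134217728
print(total_mib)       # 256.0
print(per_device_mib)  # 128.0
```

With 2-way tensor parallelism each device stores only half of both weight matrices, and the saving grows linearly with the number of devices.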

This notebook uses the following smaller dimensions (h = 10):
```
(mlp): MLP(
          (dense_h_to_4h): Linear(in_features=10, out_features=40, bias=False)
          (dense_4h_to_h): Linear(in_features=40, out_features=10, bias=False)
        )
```

In [None]:
linear_h_to_4h = nn.Linear(in_features=embedding_dim, out_features=embedding_dim*4, bias=False)

# Input of shape [1,5,10] * W.T of shape [10, 40]
output = linear_h_to_4h(input)

# input of dimension (1,5,10) gets transformed to dimension (1,5,40)
output.shape

torch.Size([1, 5, 40])
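The `* W.T` comments reflect how `nn.Linear` computes its output: without bias it is `x @ weight.T`, where `weight` is stored with shape `[out_features, in_features]`. A quick standalone check (the seed is an arbitrary choice for reproducibility):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
x = torch.randn(1, 5, 10)
layer = nn.Linear(in_features=10, out_features=40, bias=False)

print(layer.weight.shape)  # torch.Size([40, 10]) -> [out_features, in_features]

manual = x @ layer.weight.T  # [1,5,10] @ [10,40] -> [1,5,40]
print(manual.shape)  # torch.Size([1, 5, 40])
print(torch.allclose(layer(x), manual, atol=1e-6))  # True
```

This weight layout is what makes the splits below work: slicing rows of `weight` selects output features, slicing columns selects input features.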

In [None]:
linear_4h_to_h = nn.Linear(in_features=embedding_dim*4, out_features=embedding_dim, bias=False)

# h_to_4h output of shape [1,5,40] * W.T of shape [40,10]
final_output = linear_4h_to_h(output)

# h_to_4h output of dimension (1,5,40) gets transformed back to dimension (1,5,10)
final_output.shape

torch.Size([1, 5, 10])

## Back-to-back linear layer outputs with tensor parallelism

Assuming 2 devices, the code below shows how the weights of both linear layers are split between them.

Note that all of the code below actually runs on a single device; the headings and comments indicate which blocks would run on the first device and which on the second.

In [None]:
n_devices = 2
weight_parallel = (embedding_dim * 4) // n_devices  # 4h features split evenly across devices
weight_parallel

20
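The same split generalizes to any device count that divides 4h. The sketch below is a single-process simulation; the helper name `tp_mlp_forward` is hypothetical, not a library API. It slices the h_to_4h weight by rows (out_features) and the 4h_to_h weight by columns (in_features) with `torch.chunk`, runs each shard independently, and sums the partial results in place of an all-reduce.

```python
import torch
import torch.nn as nn

def tp_mlp_forward(x, w_h_to_4h, w_4h_to_h, n_devices):
    """Hypothetical helper: simulate n-way tensor parallelism for a bias-free MLP.

    w_h_to_4h has shape [4h, h] and w_4h_to_h has shape [h, 4h] (nn.Linear layout).
    """
    # Column-parallel: split the first weight along out_features (dim 0).
    w1_shards = torch.chunk(w_h_to_4h, n_devices, dim=0)
    # Row-parallel: split the second weight along in_features (dim 1).
    w2_shards = torch.chunk(w_4h_to_h, n_devices, dim=1)

    partials = []
    for w1, w2 in zip(w1_shards, w2_shards):  # one iteration per simulated device
        y_local = x @ w1.T                # [..., h] -> [..., 4h / n]
        partials.append(y_local @ w2.T)   # [..., 4h / n] -> [..., h], a partial sum
    return torch.stack(partials).sum(dim=0)  # stands in for the all-reduce

torch.manual_seed(0)  # arbitrary seed
h = 10
x = torch.randn(1, 5, h)
fc1 = nn.Linear(h, 4 * h, bias=False)
fc2 = nn.Linear(4 * h, h, bias=False)
reference = fc2(fc1(x))

with torch.no_grad():
    for n in (1, 2, 4, 8):  # each divides 4h = 40
        out = tp_mlp_forward(x, fc1.weight, fc2.weight, n)
        print(n, torch.allclose(reference, out, atol=1e-5))
```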

### First device - h_to_4h logic

In [None]:
# First device: weights of the h_to_4h shard have shape [out_features, in_features] = [20, 10]
linear_h_to_4h_parallel = nn.Linear(in_features=embedding_dim, out_features=weight_parallel, bias=False)

# Set this layer's weights to the first 20 rows of the original weights
linear_h_to_4h_parallel.weight.data = linear_h_to_4h.weight[:weight_parallel,:]

# Input [1,5,10] * W.T [10,20]. Splitting the rows of W means splitting the columns of W.T,
# which is why this is called a column-parallel linear layer.
output_parallel_1 = linear_h_to_4h_parallel(input)

output_parallel_1.shape

torch.Size([1, 5, 20])

### Second device - h_to_4h logic

In [None]:
# Second device: weights of the h_to_4h shard have shape [out_features, in_features] = [20, 10]
linear_h_to_4h_parallel = nn.Linear(in_features=embedding_dim, out_features=weight_parallel, bias=False)

# Set this layer's weights to the last 20 rows of the original weights
linear_h_to_4h_parallel.weight.data = linear_h_to_4h.weight[weight_parallel:,:]

# Input [1,5,10] * W.T [10,20]. Splitting the rows of W means splitting the columns of W.T,
# which is why this is called a column-parallel linear layer.
output_parallel_2 = linear_h_to_4h_parallel(input)

output_parallel_2.shape

torch.Size([1, 5, 20])
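Column parallelism means each device holds a disjoint slice of the h_to_4h output features. A minimal standalone check, re-creating the layer with an arbitrary seed: concatenating the two shards along the last dimension (what an all-gather would do across real devices) recovers the unsplit output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # arbitrary seed
h, n_devices = 10, 2
x = torch.randn(1, 5, h)
fc1 = nn.Linear(h, 4 * h, bias=False)
shard = (4 * h) // n_devices  # 20 output features per device

with torch.no_grad():
    full = fc1(x)                           # [1, 5, 40]
    out_dev0 = x @ fc1.weight[:shard, :].T  # first 20 rows of W -> [1, 5, 20]
    out_dev1 = x @ fc1.weight[shard:, :].T  # last 20 rows of W  -> [1, 5, 20]
    gathered = torch.cat([out_dev0, out_dev1], dim=-1)  # simulated all-gather

print(gathered.shape)  # torch.Size([1, 5, 40])
print(torch.allclose(full, gathered, atol=1e-6))
# Each shard is a contiguous slice of the full output's feature dimension.
print(torch.allclose(full[..., :shard], out_dev0, atol=1e-6))
```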

In [None]:
# Concatenating output_parallel_1 and output_parallel_2 along the last dimension (an all-gather
# across real devices) would reproduce the full output of the h_to_4h linear layer. Since the
# 4h_to_h layer comes next, each device can instead keep working on its own shard and skip that communication.
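This notebook leaves out the activation that sits between the two layers in a real transformer MLP. Because activations such as GeLU act elementwise, each device can apply it to its own shard without any communication, and the result still matches the unsplit computation. A sketch assuming GeLU (via `torch.nn.functional.gelu`) as the activation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # arbitrary seed
h, n_devices = 10, 2
x = torch.randn(1, 5, h)
fc1 = nn.Linear(h, 4 * h, bias=False)
fc2 = nn.Linear(4 * h, h, bias=False)
shard = (4 * h) // n_devices

with torch.no_grad():
    reference = fc2(F.gelu(fc1(x)))  # unsplit MLP with activation

    partial_sum = torch.zeros_like(reference)
    for rank in range(n_devices):
        cols = slice(rank * shard, (rank + 1) * shard)
        y_local = F.gelu(x @ fc1.weight[cols, :].T)     # activation applied per shard
        partial_sum += y_local @ fc2.weight[:, cols].T  # row-parallel partial output

print(torch.allclose(reference, partial_sum, atol=1e-5))
```

The order of the splits matters here: if the first layer were split row-wise instead, each device would hold a partial sum, and GeLU(a + b) is not GeLU(a) + GeLU(b), so the devices would have to synchronize before the activation.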

### First device - 4h_to_h logic

In [None]:
# First device weights of 4h_to_h linear layer will be [out_features, in_features] = [10, 20]
linear_4h_to_h_parallel = nn.Linear(in_features=weight_parallel, out_features=embedding_dim, bias=False)

# Set weights of this layer to the first 20 columns of the original weights
linear_4h_to_h_parallel.weight.data = linear_4h_to_h.weight[:,:weight_parallel]

# output_parallel_1 [1,5,20] * W.T [20,10]. Splitting the columns of W means splitting the rows of W.T,
# which is why this is called a row-parallel linear layer.
final_output_parallel_1 = linear_4h_to_h_parallel(output_parallel_1)

final_output_parallel_1.shape

torch.Size([1, 5, 10])

### Second device - 4h_to_h logic

In [None]:
# Second device weights of 4h_to_h linear layer will be [out_features, in_features] = [10, 20]
linear_4h_to_h_parallel = nn.Linear(in_features=weight_parallel, out_features=embedding_dim, bias=False)

# Set weights of this layer to the last 20 columns of the original weights
linear_4h_to_h_parallel.weight.data = linear_4h_to_h.weight[:,weight_parallel:]

# output_parallel_2 [1,5,20] * W.T [20,10]. Splitting the columns of W means splitting the rows of W.T,
# which is why this is called a row-parallel linear layer.
final_output_parallel_2 = linear_4h_to_h_parallel(output_parallel_2)

final_output_parallel_2.shape

torch.Size([1, 5, 10])
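The reason the two row-parallel outputs can simply be added is a block-matrix identity. Let $X$ be the input, $A = W_{h \to 4h}^{T}$ (shape $h \times 4h$) and $B = W_{4h \to h}^{T}$ (shape $4h \times h$), with $A$ split by columns and $B$ split by rows across the two devices:

$$
A = \begin{bmatrix} A_1 & A_2 \end{bmatrix}, \qquad
B = \begin{bmatrix} B_1 \\ B_2 \end{bmatrix}, \qquad
A_i \in \mathbb{R}^{h \times 2h},\; B_i \in \mathbb{R}^{2h \times h}
$$

$$
XAB = \begin{bmatrix} XA_1 & XA_2 \end{bmatrix} \begin{bmatrix} B_1 \\ B_2 \end{bmatrix} = XA_1B_1 + XA_2B_2
$$

Here $XA_1B_1$ corresponds to `final_output_parallel_1` and $XA_2B_2$ to `final_output_parallel_2`; with $h = 10$, each $A_i$ is $10 \times 20$ and each $B_i$ is $20 \times 10$.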

### Sum the tensor-parallel outputs of 4h_to_h with torch.add (an all-reduce operation on two real devices)

In [None]:
# With real tensor parallelism across 2 devices, this addition is an all-reduce operation
final_output_tp = torch.add(final_output_parallel_1, final_output_parallel_2)

final_output_tp.shape

torch.Size([1, 5, 10])
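On real hardware this sum is an all-reduce: every device contributes its partial tensor and every device receives the same total, so each can continue to the next layer with the full [1, 5, 10] output. A single-process simulation of that contract; `simulated_all_reduce` is a hypothetical helper, not a PyTorch API (the real collective is `torch.distributed.all_reduce`):

```python
import torch

def simulated_all_reduce(per_device_tensors):
    """Hypothetical helper: mimic an all-reduce (SUM) over a list of per-device tensors.

    Returns one tensor per device, each holding the element-wise sum of all inputs.
    """
    total = torch.stack(per_device_tensors).sum(dim=0)
    return [total.clone() for _ in per_device_tensors]  # every device gets its own copy

torch.manual_seed(0)  # arbitrary seed
partial_dev0 = torch.randn(1, 5, 10)  # stands in for final_output_parallel_1
partial_dev1 = torch.randn(1, 5, 10)  # stands in for final_output_parallel_2

reduced = simulated_all_reduce([partial_dev0, partial_dev1])
print(len(reduced))  # 2
print(torch.equal(reduced[0], reduced[1]))  # True: all devices hold the same result
print(torch.allclose(reduced[0], partial_dev0 + partial_dev1))  # True
```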


## Compare final_output and final_output_tp using allclose

In [None]:
print(torch.allclose(final_output, final_output_tp, rtol=1e-05, atol=1e-05))

True
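`torch.allclose` is used here rather than `torch.equal` because the tensor-parallel path adds the same products in a different order (two partial sums, then one final addition), and floating-point addition is not associative, so the two results may differ in the last bits. A minimal illustration with Python floats:

```python
# Floating-point addition is not associative: grouping changes the rounding.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)

print(left)           # 0.6000000000000001
print(right)          # 0.6
print(left == right)  # False
```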


In [None]:
final_output

tensor([[[ 0.0230, -0.1114, -0.1713,  0.5962, -0.1553, -0.1313, -0.2674,
          -0.2797,  0.0750, -0.1309],
         [-0.2180, -0.1899, -0.3012,  0.3890, -0.1467, -0.1274, -0.1041,
           0.0927, -0.2869, -0.1906],
         [ 0.0226,  0.1911,  0.1427, -0.6180, -0.0171,  0.3005,  0.3173,
           0.2533,  0.0814,  0.2457],
         [ 0.5145,  0.0557,  0.3507,  0.1115,  0.2716, -0.3608,  0.1050,
          -0.2071, -0.0866,  0.1415],
         [-0.0225,  0.0401, -0.4115,  0.1852, -0.2485,  0.1331, -0.4189,
          -0.1517, -0.0531,  0.2920]]], grad_fn=<UnsafeViewBackward0>)

In [None]:
final_output_tp

tensor([[[ 0.0230, -0.1114, -0.1713,  0.5962, -0.1553, -0.1313, -0.2674,
          -0.2797,  0.0750, -0.1309],
         [-0.2180, -0.1899, -0.3012,  0.3890, -0.1467, -0.1274, -0.1041,
           0.0927, -0.2869, -0.1906],
         [ 0.0226,  0.1911,  0.1427, -0.6180, -0.0171,  0.3005,  0.3173,
           0.2533,  0.0814,  0.2457],
         [ 0.5145,  0.0557,  0.3507,  0.1115,  0.2716, -0.3608,  0.1050,
          -0.2071, -0.0866,  0.1415],
         [-0.0225,  0.0401, -0.4115,  0.1852, -0.2485,  0.1331, -0.4189,
          -0.1517, -0.0531,  0.2920]]], grad_fn=<AddBackward0>)