<a href="https://colab.research.google.com/github/puneeshkhanna/Tensor-Parallelism/blob/master/tensor_parallelism_attention_layers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [45]:
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [69]:
# Seed torch's global RNG so the random input below (and any later draws from the same
# RNG) are identical on every Restart & Run All.
torch.manual_seed(0)

# Input of dimensions (batch size, no of words, embedding dimension or hidden size of each word)
# NOTE: the name `input` shadows the Python builtin; it is kept because later cells refer to it.
input = torch.randn(size=(1, 5, 32), dtype=torch.float32)

In [70]:
# Unpack the three input dimensions: batch size, sequence length, hidden size.
bsz, q_len, hidden_size = input.shape
hidden_size

32

In [71]:
# Number of attention heads; the hidden dimension is divided evenly between them.
num_heads = 4

# Floor division would silently drop a remainder, and the later reshape back to
# `hidden_size` would then fail with a less obvious error, so check it here.
assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"

head_dim = hidden_size // num_heads
head_dim

8

## Attention Layer output

Transformer architectures have an attention block with Q, K and V projection layers followed by an o_proj (dense) layer.

```
Attention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
```

In [72]:
# Q, K and V projections: hidden_size -> num_heads * head_dim, without bias.
# The layers are still created in q, k, v order.
q_proj, k_proj, v_proj = (
    nn.Linear(hidden_size, num_heads * head_dim, bias=False) for _ in range(3)
)

# Output projection mapping the concatenated heads back to hidden_size.
o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

In [73]:
# Project the input with each of q/k/v, split the last dimension into
# (num_heads, head_dim) and move the head axis in front of the sequence axis,
# giving tensors of shape (bsz, num_heads, q_len, head_dim).
query_states, key_states, value_states = (
    proj(input).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    for proj in (q_proj, k_proj, v_proj)
)

query_states.shape

torch.Size([1, 4, 5, 8])

In [74]:
# Scaled dot-product scores: Q @ K^T divided by sqrt(head_dim).
attn_weights = (query_states @ key_states.transpose(-2, -1)) / math.sqrt(head_dim)

# Softmax over the key axis, computed in float32 and cast back to the query dtype.
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

# Weighted sum of the value vectors, one result per head.
attn_output = attn_weights @ value_states

attn_output.shape

torch.Size([1, 4, 5, 8])

In [75]:
# Move the head axis back next to head_dim and merge the two into hidden_size:
# (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, hidden_size).
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, hidden_size)

# Reference (non tensor-parallel) result of the attention block.
attn_output_non_tp = o_proj(attn_output)

attn_output_non_tp.shape

torch.Size([1, 5, 32])

## Attention layer output with Tensor parallelism

Assuming 2 devices, the code below shows how the weights are divided between the 2 devices.

Note that all of the code below is executed on a single device only; the section headings indicate which blocks would be executed on the first device and which on the second.

In [76]:
n_devices = 2

# Derive the per-device sizes from the unsharded layer and input instead of from the
# current values of `num_heads` / `hidden_size`, so that re-running this cell does not
# halve them a second time.
total_heads = q_proj.out_features // head_dim
total_hidden_size = input.size(-1)

assert total_heads % n_devices == 0, "number of heads must be divisible by n_devices"
assert total_hidden_size % n_devices == 0, "hidden size must be divisible by n_devices"

# From here on `num_heads` and `hidden_size` hold the per-device values.
num_heads = total_heads // n_devices
hidden_size = total_hidden_size // n_devices
print(num_heads, hidden_size)

2 16


In [77]:
# Split the q/k/v weights along dim 0 (output features) into chunks of
# num_heads * head_dim rows, i.e. one chunk of heads per device.
query_slices, key_slices, value_slices = (
    torch.split(proj.weight, num_heads * head_dim, dim=0)
    for proj in (q_proj, k_proj, v_proj)
)

# The output projection is split along dim 1 (input features) to match.
o_proj_slices = torch.split(o_proj.weight, num_heads * head_dim, dim=1)

In [78]:
# Full q_proj weight shape followed by the shapes of its first two slices.
print(*(w.shape for w in (q_proj.weight, query_slices[0], query_slices[1])))

# Same for o_proj, whose slices are taken along the other dimension.
print(*(w.shape for w in (o_proj.weight, o_proj_slices[0], o_proj_slices[1])))

torch.Size([32, 32]) torch.Size([16, 32]) torch.Size([16, 32])
torch.Size([32, 32]) torch.Size([32, 16]) torch.Size([32, 16])


### First device - attention output logic

In [79]:
# First device: project the full input with slice 0 of each q/k/v weight, then
# reshape to (bsz, num_heads, q_len, head_dim) using the per-device num_heads.
query_states, key_states, value_states = (
    F.linear(input, shards[0]).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    for shards in (query_slices, key_slices, value_slices)
)

query_states.shape

torch.Size([1, 2, 5, 8])

In [80]:
# Scaled dot-product attention over the heads held by the first device.
attn_weights = (query_states @ key_states.transpose(-2, -1)) / math.sqrt(head_dim)

# Softmax over the key axis in float32, cast back to the query dtype.
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

attn_output = attn_weights @ value_states

attn_output.shape

torch.Size([1, 2, 5, 8])

In [81]:
# Merge this device's heads: (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, hidden_size),
# where hidden_size is the per-device value at this point.
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, hidden_size)

# Partial output from the first device, using slice 0 of the o_proj weight.
attn_output_1 = F.linear(attn_output, o_proj_slices[0])

### Second device - attention output logic

In [82]:
# Second device: same steps as the first device, but with slice 1 of each q/k/v weight.
query_states, key_states, value_states = (
    F.linear(input, shards[1]).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    for shards in (query_slices, key_slices, value_slices)
)

In [83]:
# Scaled dot-product attention over the heads held by the second device.
attn_weights = (query_states @ key_states.transpose(-2, -1)) / math.sqrt(head_dim)

# Softmax over the key axis in float32, cast back to the query dtype.
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

attn_output = attn_weights @ value_states

In [84]:
# Merge this device's heads back into a (bsz, q_len, hidden_size) tensor,
# where hidden_size is the per-device value at this point.
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, hidden_size)

# Partial output from the second device, using slice 1 of the o_proj weight.
attn_output_2 = F.linear(attn_output, o_proj_slices[1])

### Add tensor parallel outputs of attn_output using torch.add; it will be an all reduce operation when using actual 2 devices

In [85]:
# Sum the two partial outputs element-wise. With two real devices this sum
# would be performed by an all-reduce operation.
attn_output_tp = attn_output_1 + attn_output_2
attn_output_tp.shape

torch.Size([1, 5, 32])

## Compare attn_output_non_tp and attn_output_tp using allclose

In [89]:
# Check that the tensor-parallel result matches the single-device result within tolerance.
# The assertion makes a mismatch stop a full notebook run instead of only printing False.
outputs_match = torch.allclose(attn_output_non_tp, attn_output_tp, rtol=1e-05, atol=1e-05)
assert outputs_match, "tensor-parallel attention output differs from the non-parallel output"
print(outputs_match)

True


In [90]:
# Display the reference (non tensor-parallel) output for visual comparison.
attn_output_non_tp

tensor([[[ 0.4250, -0.0481,  0.1116, -0.1460, -0.3481, -0.1828, -0.0176,
           0.1101, -0.0424,  0.0251,  0.0632,  0.0177, -0.1264,  0.0059,
           0.1619, -0.2412,  0.1473, -0.0929,  0.1656, -0.0315, -0.1305,
          -0.0577,  0.2283, -0.3891, -0.2213, -0.2125, -0.0287, -0.0260,
          -0.2675, -0.0463, -0.0470, -0.0351],
         [ 0.3329, -0.1350,  0.1774, -0.1266, -0.3457, -0.1331,  0.0591,
           0.1060, -0.1679,  0.0124,  0.1157,  0.0533, -0.1115, -0.0017,
           0.1772, -0.2066,  0.1702, -0.1749,  0.2329,  0.0469, -0.1127,
          -0.2003,  0.3471, -0.4370, -0.1605, -0.2053, -0.0667, -0.0485,
          -0.2776,  0.0339, -0.0089, -0.0604],
         [ 0.3094, -0.1319,  0.0992, -0.0496, -0.3021, -0.1464,  0.0567,
           0.0253,  0.0236, -0.0088,  0.0708, -0.0748, -0.2009, -0.0122,
           0.0542, -0.2259,  0.2555, -0.1405,  0.1925, -0.0574, -0.1862,
          -0.0678,  0.3356, -0.4476, -0.1952, -0.2065, -0.0017,  0.0199,
          -0.2879,  0.0167, -0

In [91]:
# Display the tensor-parallel output (sum of both devices' partial outputs) for visual comparison.
attn_output_tp

tensor([[[ 0.4250, -0.0481,  0.1116, -0.1460, -0.3481, -0.1828, -0.0176,
           0.1101, -0.0424,  0.0251,  0.0632,  0.0177, -0.1264,  0.0059,
           0.1619, -0.2412,  0.1473, -0.0929,  0.1656, -0.0315, -0.1305,
          -0.0577,  0.2283, -0.3891, -0.2213, -0.2125, -0.0287, -0.0260,
          -0.2675, -0.0463, -0.0470, -0.0351],
         [ 0.3329, -0.1350,  0.1774, -0.1266, -0.3457, -0.1331,  0.0591,
           0.1060, -0.1679,  0.0124,  0.1157,  0.0533, -0.1115, -0.0017,
           0.1772, -0.2066,  0.1702, -0.1749,  0.2329,  0.0469, -0.1127,
          -0.2003,  0.3471, -0.4370, -0.1605, -0.2053, -0.0667, -0.0485,
          -0.2776,  0.0339, -0.0089, -0.0604],
         [ 0.3094, -0.1319,  0.0992, -0.0496, -0.3021, -0.1464,  0.0567,
           0.0253,  0.0236, -0.0088,  0.0708, -0.0748, -0.2009, -0.0122,
           0.0542, -0.2259,  0.2555, -0.1405,  0.1925, -0.0574, -0.1862,
          -0.0678,  0.3356, -0.4476, -0.1952, -0.2065, -0.0017,  0.0199,
          -0.2879,  0.0167, -0