<a href="https://colab.research.google.com/github/puneeshkhanna/Tensor-Parallelism/blob/master/tensor_parallelism_attention_layers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [25]:
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [26]:
# Fix the RNG seed so that a fresh "Restart & Run All" reproduces the same random input
# (and the same randomly initialised projection weights created further down).
torch.manual_seed(0)

# Input of dimensions (batch size, number of words or q_len, embedding dimension or hidden size of each word)
input = torch.randn(size=(1, 5, 32), dtype=torch.float32)

In [27]:
# Unpack the three input dimensions: batch size, sequence length and embedding width
bsz, q_len, hidden_size = input.shape
hidden_size

32

In [28]:
# Number of attention heads of multi head attention of each transformer block
num_heads = 4

# Every head works on an equal slice of the embedding, so hidden size must be an exact
# multiple of num heads; otherwise the floor division below would silently drop dimensions.
assert hidden_size % num_heads == 0, (
    f"hidden size {hidden_size} is not divisible by num heads {num_heads}"
)

# per head embedding dim
head_dim = hidden_size // num_heads

print(f"num heads is {num_heads}, hidden size is {hidden_size}, head dim is {head_dim}")

num heads is 4, hidden size is 32, head dim is 8


## Attention Layer output

Transformer architectures have an attention block with Q, K and V projection layers followed by an o_proj (dense) layer.

```
Attention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
```

In [29]:
# Query, key and value projections share one shape: hidden_size -> num_heads * head_dim, no bias.
# They are built in q, k, v order so the random initialisation sequence is unchanged.
q_proj, k_proj, v_proj = (
    nn.Linear(hidden_size, num_heads * head_dim, bias=False) for _ in range(3)
)

# Output (dense) projection maps the concatenated heads back to hidden_size.
o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

print("q_proj weights shape:", q_proj.weight.size())
print("o_proj weights shape:", o_proj.weight.size())

q_proj weights shape: torch.Size([32, 32])
o_proj weights shape: torch.Size([32, 32])


In [30]:
# Run the same input through the three projection layers.
query_states, key_states, value_states = (
    proj(input) for proj in (q_proj, k_proj, v_proj)
)

print("query_states after projections -> [batch size, q_len, hidden_size]: ", query_states.shape)

# Split the last axis into (num_heads, head_dim), then move the head axis in front of q_len.
query_states, key_states, value_states = (
    states.view(bsz, q_len, num_heads, head_dim).permute(0, 2, 1, 3)
    for states in (query_states, key_states, value_states)
)

print("\nquery_states after view and transpose -> [batch size, num heads, q_len, head_dim]:", query_states.shape)

query_states after projections -> [batch size, q_len, hidden_size]:  torch.Size([1, 5, 32])

query_states after view and transpose -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 4, 5, 8])


In [31]:
# Scaled dot-product scores: Q @ K^T divided by sqrt(head_dim), for all heads at once.
attn_weights = (query_states @ key_states.transpose(-2, -1)) / math.sqrt(head_dim)
# Softmax over the last axis, computed in float32 and cast back to the dtype of the queries.
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
print("attn matrix after QK.T and softmax -> [batch size, num heads, q_len, q_len]:", attn_weights.shape)

# Weighted sum of the value vectors using the attention probabilities.
attn_output = attn_weights @ value_states
print("\nattn output -> matmul of attn matrix [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]:", attn_output.shape)

attn matrix after QK.T and softmax -> [batch size, num heads, q_len, q_len]: torch.Size([1, 4, 5, 5])

attn output -> matmul of attn matrix [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 4, 5, 8])


In [32]:
# Move the head axis back next to head_dim and merge the heads:
# [batch size, num heads, q_len, head_dim] -> [batch size, q_len, hidden_size]
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, hidden_size)
print("attn output after view and reshape -> [batch size, q_len, hidden_size]:", attn_output.shape)

# Final dense projection; stored under its own name because it is the reference
# result that the tensor-parallel output is compared against later.
attn_output_non_tp = o_proj(attn_output)
# Report the shape of the projected tensor itself, not of the pre-projection attn_output.
print("\nattn output after final o_proj -> [batch size, q_len, hidden_size]:", attn_output_non_tp.shape)

attn output after view and reshape -> [batch size, q_len, hidden_size]: torch.Size([1, 5, 32])

attn output after final o_proj -> [batch size, q_len, hidden_size]: torch.Size([1, 5, 32])


## Attention layer output with Tensor parallelism

Assuming 2 devices, below code depicts how attention heads will be divided between the 2 devices.

Note that all of the code below is executed on a single device only; comments and section headings indicate which blocks would run on the first device and which on the second.

In [33]:
n_devices = 2

# Derive the per-device sizes from the full-size q_proj layer instead of from the current
# values of num_heads / hidden_size. This cell overwrites both names, so dividing them
# in place ("x = x // n_devices") would halve them again on every re-run of the cell.
total_heads = q_proj.out_features // head_dim
assert total_heads % n_devices == 0, (
    f"{total_heads} heads cannot be split evenly across {n_devices} devices"
)

# Heads handled by each device, and the width of that device's slice of the hidden dimension
num_heads = total_heads // n_devices
hidden_size = num_heads * head_dim
print(f"num heads changes to {num_heads}, hidden size changes to {hidden_size}")

num heads changes to 2, hidden size changes to 16


In [34]:
print("Dividing q_proj, k_proj, v_proj, o_proj weights based upon num_heads")

# q/k/v weights are cut along dim 0 (output features) in pieces of num_heads * head_dim rows,
# so each piece holds the rows belonging to one device's heads.
query_slices, key_slices, value_slices = (
    torch.split(proj.weight, num_heads * head_dim, dim=0)
    for proj in (q_proj, k_proj, v_proj)
)

# o_proj is cut along dim 1 (input features) so each piece consumes one device's heads.
o_proj_slices = torch.split(o_proj.weight, num_heads * head_dim, dim=1)

Dividing q_proj, k_proj, v_proj, o_proj weights based upon num_heads


In [35]:
# Full q_proj weight shape next to the shapes of its two per-device slices
print(
    "Original qkv proj weights without TP:",
    q_proj.weight.size(),
    "\tPer device qkv proj weights with TP:",
    query_slices[0].size(),
    query_slices[1].size(),
)

# Same comparison for the output projection
print(
    "\nOriginal o_proj weights without TP:",
    o_proj.weight.size(),
    "\tPer device o_proj weights with TP:",
    o_proj_slices[0].size(),
    o_proj_slices[1].size(),
)

Original qkv proj weights without TP: torch.Size([32, 32]) 	Per device qkv proj weights with TP: torch.Size([16, 32]) torch.Size([16, 32])

Original o_proj weights without TP: torch.Size([32, 32]) 	Per device o_proj weights with TP: torch.Size([32, 16]) torch.Size([32, 16])


### First device - attention output logic

In [36]:
# First device: project the full input with slice 0 of each of the q/k/v weights.
query_states, key_states, value_states = (
    F.linear(input, slices[0])
    for slices in (query_slices, key_slices, value_slices)
)

print("query_states after projections -> [batch size, q_len, hidden_size]: ", query_states.shape)

# Split the last axis into (num_heads, head_dim), then move the head axis in front of q_len.
query_states, key_states, value_states = (
    states.view(bsz, q_len, num_heads, head_dim).permute(0, 2, 1, 3)
    for states in (query_states, key_states, value_states)
)

print("\nquery_states after view and transpose -> [batch size, num heads, q_len, head_dim]:", query_states.shape)

query_states after projections -> [batch size, q_len, hidden_size]:  torch.Size([1, 5, 16])

query_states after view and transpose -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 2, 5, 8])


In [37]:
# Scaled dot-product scores for the heads held by the first device.
attn_weights = (query_states @ key_states.transpose(-2, -1)) / math.sqrt(head_dim)

# Softmax over the last axis, computed in float32 and cast back to the dtype of the queries.
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

print("attn matrix after QK.T and softmax -> [batch size, num heads, q_len, q_len]:", attn_weights.shape)

# Weighted sum of the value vectors using the attention probabilities.
attn_output = attn_weights @ value_states

print("attn output -> matmul of attn matrix [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]:", attn_output.shape)

attn matrix after QK.T and softmax -> [batch size, num heads, q_len, q_len]: torch.Size([1, 2, 5, 5])
attn output -> matmul of attn matrix [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 2, 5, 8])


In [38]:
# Undo the head split: put q_len back in front of the heads and flatten (num heads, head_dim) into one axis.
attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(bsz, q_len, hidden_size)

print("attn output after view and reshape -> [batch size, q_len, hidden_size]:", attn_output.shape)

# Multiply by slice 0 of the o_proj weight; the result is one of the two terms summed later into the TP output.
attn_output_1 = F.linear(attn_output, o_proj_slices[0])

print("\nattn output on 1st Tensor Parallel device after final o_proj -> [batch size, q_len, hidden_size]:", attn_output_1.shape)

attn output after view and reshape -> [batch size, q_len, hidden_size]: torch.Size([1, 5, 16])

attn output on 1st Tensor Parallel device after final o_proj -> [batch size, q_len, hidden_size]: torch.Size([1, 5, 32])


### Second device - attention output logic

In [39]:
# Second device: project the full input with slice 1 of each of the q/k/v weights.
query_states, key_states, value_states = (
    F.linear(input, slices[1])
    for slices in (query_slices, key_slices, value_slices)
)

print("query_states after projections -> [batch size, q_len, hidden_size]: ", query_states.shape)

# Split the last axis into (num_heads, head_dim), then move the head axis in front of q_len.
query_states, key_states, value_states = (
    states.view(bsz, q_len, num_heads, head_dim).permute(0, 2, 1, 3)
    for states in (query_states, key_states, value_states)
)

print("\nquery_states after view and transpose -> [batch size, num heads, q_len, head_dim]:", query_states.shape)

query_states after projections -> [batch size, q_len, hidden_size]:  torch.Size([1, 5, 16])

query_states after view and transpose -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 2, 5, 8])


In [40]:
# Scaled dot-product scores for the heads held by the second device.
attn_weights = (query_states @ key_states.transpose(-2, -1)) / math.sqrt(head_dim)

# Softmax over the last axis, computed in float32 and cast back to the dtype of the queries.
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)

print("attn matrix after QK.T and softmax -> [batch size, num heads, q_len, q_len]:", attn_weights.shape)

# Weighted sum of the value vectors using the attention probabilities.
attn_output = attn_weights @ value_states

print("\nattn output -> matmul of attn matrix [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]:", attn_output.shape)

attn matrix after QK.T and softmax -> [batch size, num heads, q_len, q_len]: torch.Size([1, 2, 5, 5])

attn output -> matmul of attn matrix [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 2, 5, 8])


In [41]:
# Undo the head split: put q_len back in front of the heads and flatten (num heads, head_dim) into one axis.
attn_output = attn_output.permute(0, 2, 1, 3).contiguous().view(bsz, q_len, hidden_size)

print("attn output after view and reshape -> [batch size, q_len, hidden_size]:", attn_output.shape)

# Multiply by slice 1 of the o_proj weight; the result is the other term summed later into the TP output.
attn_output_2 = F.linear(attn_output, o_proj_slices[1])

print("\nattn output on 2nd Tensor Parallel device after final o_proj -> [batch size, q_len, hidden_size]:", attn_output_2.shape)

attn output after view and reshape -> [batch size, q_len, hidden_size]: torch.Size([1, 5, 16])

attn output on 2nd Tensor Parallel device after final o_proj -> [batch size, q_len, hidden_size]: torch.Size([1, 5, 32])


### Add tensor parallel outputs of attn_output using torch.add; it will be an all reduce operation when using actual 2 devices

In [42]:
# Sum the two per-device o_proj results; with two real devices this addition would be an all-reduce
attn_output_tp = attn_output_1 + attn_output_2
attn_output_tp.size()

torch.Size([1, 5, 32])

## Compare attn_output_non_tp and attn_output_tp using allclose

In [43]:
# The summed tensor-parallel output should match the single-device output within tolerance.
outputs_match = torch.allclose(attn_output_non_tp, attn_output_tp, rtol=1e-05, atol=1e-05)
print(outputs_match)

# Printing alone would let a mismatch pass unnoticed during "Run All"; fail loudly instead.
assert outputs_match, "tensor-parallel attention output diverges from the non-TP reference"

True


In [44]:
# Display the attention output computed without tensor parallelism (result of the full o_proj)
attn_output_non_tp

tensor([[[ 5.3712e-02, -1.1881e-01,  1.5588e-01,  5.6367e-02,  3.2628e-02,
           7.5863e-02,  2.8828e-02,  7.3524e-02, -5.9104e-03,  5.6621e-02,
           9.0149e-02,  6.6041e-02,  6.0233e-02,  5.9664e-02,  7.0160e-02,
           1.8019e-01,  1.8662e-02,  3.4610e-02,  2.2844e-01, -1.4812e-01,
          -1.2527e-02, -1.0309e-01,  8.6435e-02, -2.1808e-01,  1.4488e-01,
          -1.4784e-01,  1.8542e-02, -7.8376e-02,  2.8780e-03, -2.3423e-01,
          -3.8625e-02,  1.1624e-01],
         [ 1.1160e-02, -1.0493e-01,  4.8735e-02,  1.5505e-01,  2.0439e-01,
           1.7308e-03, -1.0093e-01,  1.0066e-01, -1.3819e-01,  3.0193e-01,
          -1.6863e-02, -6.8167e-02,  1.3596e-01, -2.5812e-02,  1.3117e-01,
           2.2190e-01,  8.5406e-02,  1.7488e-01,  2.2445e-01, -2.7848e-02,
           1.9271e-02, -1.6363e-01,  2.3058e-02, -2.0312e-01,  1.2977e-01,
          -6.3400e-02,  1.4039e-01,  1.8939e-01,  7.7873e-02,  4.7048e-02,
          -1.1575e-01,  1.8809e-02],
         [ 2.2801e-01,  7.

In [45]:
# Display the tensor-parallel attention output (sum of attn_output_1 and attn_output_2)
attn_output_tp

tensor([[[ 5.3712e-02, -1.1881e-01,  1.5588e-01,  5.6367e-02,  3.2628e-02,
           7.5863e-02,  2.8828e-02,  7.3524e-02, -5.9104e-03,  5.6621e-02,
           9.0149e-02,  6.6041e-02,  6.0233e-02,  5.9664e-02,  7.0160e-02,
           1.8019e-01,  1.8662e-02,  3.4610e-02,  2.2844e-01, -1.4812e-01,
          -1.2527e-02, -1.0309e-01,  8.6435e-02, -2.1808e-01,  1.4488e-01,
          -1.4784e-01,  1.8542e-02, -7.8376e-02,  2.8780e-03, -2.3423e-01,
          -3.8625e-02,  1.1624e-01],
         [ 1.1160e-02, -1.0493e-01,  4.8735e-02,  1.5505e-01,  2.0439e-01,
           1.7308e-03, -1.0093e-01,  1.0066e-01, -1.3819e-01,  3.0193e-01,
          -1.6863e-02, -6.8167e-02,  1.3596e-01, -2.5812e-02,  1.3117e-01,
           2.2190e-01,  8.5406e-02,  1.7488e-01,  2.2445e-01, -2.7848e-02,
           1.9271e-02, -1.6363e-01,  2.3058e-02, -2.0312e-01,  1.2977e-01,
          -6.3400e-02,  1.4039e-01,  1.8939e-01,  7.7873e-02,  4.7048e-02,
          -1.1575e-01,  1.8809e-02],
         [ 2.2801e-01,  7.