<a href="https://colab.research.google.com/github/puneeshkhanna/Tensor-Parallelism/blob/master/tensor_parallelism_attention_layers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [74]:
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [75]:
# Input of dimensions (batch size, no of words or q_len, embedding dimension or hidden size of each word)
input = torch.randn(size=(1, 5, 32), dtype=torch.float32)

In [76]:
bsz, q_len, hidden_size = input.size()
hidden_size

32

In [77]:
# Number of attention heads of multi head attention of each transformer block
num_heads = 4

# hidden size is divisible by num heads; per head embedding dim
head_dim = hidden_size // num_heads

print(f"num heads is {num_heads}, hidden size is {hidden_size}, head dim is {head_dim}")

num heads is 4, hidden size is 32, head dim is 8


## Attention Layer output

Tranformer architectures have Attention block with QKV layers followed by an o_proj (dense layer)

```
Attention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
```

In [78]:
q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)
v_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=False)

o_proj = nn.Linear(num_heads * head_dim, hidden_size, bias=False)

print("q_proj weights shape:", q_proj.weight.shape)
print("o_proj weights shape:", o_proj.weight.shape)

q_proj weights shape: torch.Size([32, 32])
o_proj weights shape: torch.Size([32, 32])


In [79]:
query_states = q_proj(input)
key_states = k_proj(input)
value_states = v_proj(input)

print("query_states after projections -> [batch size, q_len, num heads*head_dim]: ", query_states.shape)

query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

print("\nquery_states after view and transpose -> [batch size, num heads, q_len, head_dim]:", query_states.shape)

query_states after projections -> [batch size, q_len, num heads*head_dim]:  torch.Size([1, 5, 32])

query_states after view and transpose -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 4, 5, 8])


In [80]:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
print("attn matrix after QK.T and softmax -> [batch size, num heads, q_len, q_len]:", attn_weights.shape)

attn_output = torch.matmul(attn_weights, value_states)
print("\nattn output -> matmul of attn matrix [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]:", attn_output.shape)

attn matrix after QK.T and softmax -> [batch size, num heads, q_len, q_len]: torch.Size([1, 4, 5, 5])

attn output -> matmul of attn matrix [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 4, 5, 8])


In [81]:
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, num_heads*head_dim)
print("attn output after view and reshape -> [batch size, q_len, num heads*head_dim]:", attn_output.shape)

attn_output_non_tp = o_proj(attn_output)
print("\nattn output after final o_proj -> [batch size, q_len, hidden_size]:", attn_output.shape)

attn output after view and reshape -> [batch size, q_len, num heads*head_dim]: torch.Size([1, 5, 32])

attn output after final o_proj -> [batch size, q_len, hidden_size]: torch.Size([1, 5, 32])


## Attention layer output with Tensor parallelism

Assuming 2 devices, below code depicts how attention heads will be divided between the 2 devices.

Note that all the below code is executed on single device only with comments that which blocks will be executed on first device or second device.

In [82]:
n_devices = 2
num_heads = num_heads // n_devices
print(f"num heads changes to {num_heads}")

num heads changes to 2


In [83]:
print("Dividing q_proj, k_proj, v_proj, o_proj weights based upon num_heads")

query_slices = q_proj.weight.split(num_heads * head_dim, dim=0)
key_slices = k_proj.weight.split(num_heads * head_dim, dim=0)
value_slices = v_proj.weight.split(num_heads * head_dim, dim=0)

o_proj_slices = o_proj.weight.split(num_heads * head_dim, dim=1)

Dividing q_proj, k_proj, v_proj, o_proj weights based upon num_heads


In [84]:
print("Original qkv proj weights transpose without TP:",q_proj.weight.T.shape, "\tPer device qkv proj weights transpose with TP (column divided):", query_slices[0].T.shape)

print("\nOriginal o_proj weights transpose without TP:",o_proj.weight.T.shape, "\tPer device o_proj weights transpose with TP (row divided):", o_proj_slices[0].T.shape)

Original qkv proj weights transpose without TP: torch.Size([32, 32]) 	Per device qkv proj weights transpose with TP (column divided): torch.Size([32, 16])

Original o_proj weights transpose without TP: torch.Size([32, 32]) 	Per device o_proj weights transpose with TP (row divided): torch.Size([16, 32])


### First device - attention output logic

In [85]:
query_states = F.linear(input, query_slices[0])
key_states = F.linear(input, key_slices[0])
value_states = F.linear(input, value_slices[0])

print("query_states after projections -> [batch size, q_len, num heads*head_dim]: ", query_states.shape)

query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

print("\nquery_states after view and transpose -> [batch size, num heads, q_len, head_dim]:", query_states.shape)

query_states after projections -> [batch size, q_len, num heads*head_dim]:  torch.Size([1, 5, 16])

query_states after view and transpose -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 2, 5, 8])


In [86]:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
print("attn matrix after QK.T and softmax -> [batch size, num heads, q_len, q_len]:", attn_weights.shape)

attn_output = torch.matmul(attn_weights, value_states)
print("\nattn output -> matmul of attn matrix [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]:", attn_output.shape)

attn matrix after QK.T and softmax -> [batch size, num heads, q_len, q_len]: torch.Size([1, 2, 5, 5])

attn output -> matmul of attn matrix [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 2, 5, 8])


In [87]:
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, num_heads*head_dim)
print("attn output after view and reshape -> [batch size, q_len, num heads*head_dim]:", attn_output.shape)

attn_output_1 = F.linear(attn_output, o_proj_slices[0])
print("\nattn output on 1st Tensor Parallel device after final o_proj -> [batch size, q_len, hidden_size]:", attn_output_1.shape)

attn output after view and reshape -> [batch size, q_len, num heads*head_dim]: torch.Size([1, 5, 16])

attn output on 1st Tensor Parallel device after final o_proj -> [batch size, q_len, hidden_size]: torch.Size([1, 5, 32])


### Second device - attention output logic

In [88]:
query_states = F.linear(input, query_slices[1])
key_states = F.linear(input, key_slices[1])
value_states = F.linear(input, value_slices[1])
print("query_states after projections -> [batch size, q_len, num heads*head_dim]: ", query_states.shape)

query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
print("\nquery_states after view and transpose -> [batch size, num heads, q_len, head_dim]:", query_states.shape)

query_states after projections -> [batch size, q_len, num heads*head_dim]:  torch.Size([1, 5, 16])

query_states after view and transpose -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 2, 5, 8])


In [89]:
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
print("attn matrix after QK.T and softmax -> [batch size, num heads, q_len, q_len]:", attn_weights.shape)

attn_output = torch.matmul(attn_weights, value_states)
print("\nattn output -> matmul of attn matrix [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]:", attn_output.shape)

attn matrix after QK.T and softmax -> [batch size, num heads, q_len, q_len]: torch.Size([1, 2, 5, 5])

attn output -> matmul of attn matrix [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 2, 5, 8])


In [90]:
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, num_heads*head_dim)
print("attn output after view and reshape -> [batch size, q_len, num heads*head_dim]:", attn_output.shape)

attn_output_2 = F.linear(attn_output, o_proj_slices[1])
print("\nattn output on 2nd Tensor Parallel device after final o_proj -> [batch size, q_len, hidden_size]:", attn_output_2.shape)

attn output after view and reshape -> [batch size, q_len, num heads*head_dim]: torch.Size([1, 5, 16])

attn output on 2nd Tensor Parallel device after final o_proj -> [batch size, q_len, hidden_size]: torch.Size([1, 5, 32])


### Add tensor parallel outputs of attn_output using torch.add; it will be an all reduce operation when using actual 2 devices

In [91]:
# In actual tensor parallelism between 2 devices, this will be all reduce operation
attn_output_tp = torch.add(attn_output_1, attn_output_2)
attn_output_tp.shape

torch.Size([1, 5, 32])

## Compare final_output_non_tp and final_output_tp using allclose

In [92]:
print(torch.allclose(attn_output_non_tp, attn_output_tp, rtol=1e-05, atol=1e-05))

True


In [93]:
attn_output_non_tp

tensor([[[-3.5638e-01,  3.2600e-01, -3.2392e-01,  3.5811e-01,  8.3697e-02,
           3.7586e-01, -2.3411e-01, -3.1438e-01, -2.7816e-01,  1.7625e-01,
           6.6336e-02, -2.1609e-02, -7.9867e-02,  4.6139e-03,  3.9733e-02,
          -2.8644e-01, -3.2399e-02,  3.3861e-01, -1.2574e-01,  1.1643e-01,
          -1.6354e-01, -1.6879e-01, -7.3738e-03, -1.1563e-01, -3.7473e-01,
           1.0994e-01,  2.5477e-01,  2.8067e-01, -1.8519e-01, -2.6387e-01,
          -4.0888e-01,  4.1851e-02],
         [-3.3681e-01,  1.5425e-01, -3.6065e-01,  4.2199e-01,  2.5573e-02,
           4.9650e-01, -2.8206e-01, -2.9998e-01, -3.5022e-01,  1.3859e-01,
           8.8276e-03,  5.4824e-02, -1.7515e-01,  8.2978e-02,  8.6214e-02,
          -3.4301e-01, -5.9058e-02,  3.6043e-01, -1.6147e-01,  4.5518e-02,
          -1.6713e-01, -1.0388e-01, -3.5998e-02, -6.6658e-02, -3.3598e-01,
           1.5457e-01,  3.3041e-01,  1.4448e-01, -1.4736e-01, -1.7590e-01,
          -3.8984e-01,  1.5257e-02],
         [-3.5187e-01,  2.

In [94]:
attn_output_tp

tensor([[[-3.5638e-01,  3.2600e-01, -3.2392e-01,  3.5811e-01,  8.3697e-02,
           3.7586e-01, -2.3411e-01, -3.1438e-01, -2.7816e-01,  1.7625e-01,
           6.6336e-02, -2.1609e-02, -7.9867e-02,  4.6139e-03,  3.9733e-02,
          -2.8644e-01, -3.2399e-02,  3.3861e-01, -1.2574e-01,  1.1643e-01,
          -1.6354e-01, -1.6879e-01, -7.3738e-03, -1.1563e-01, -3.7473e-01,
           1.0994e-01,  2.5477e-01,  2.8067e-01, -1.8519e-01, -2.6387e-01,
          -4.0888e-01,  4.1851e-02],
         [-3.3681e-01,  1.5425e-01, -3.6065e-01,  4.2199e-01,  2.5573e-02,
           4.9650e-01, -2.8206e-01, -2.9998e-01, -3.5022e-01,  1.3859e-01,
           8.8276e-03,  5.4824e-02, -1.7515e-01,  8.2978e-02,  8.6214e-02,
          -3.4301e-01, -5.9058e-02,  3.6043e-01, -1.6147e-01,  4.5518e-02,
          -1.6713e-01, -1.0388e-01, -3.5998e-02, -6.6658e-02, -3.3598e-01,
           1.5457e-01,  3.3041e-01,  1.4448e-01, -1.4736e-01, -1.7590e-01,
          -3.8984e-01,  1.5257e-02],
         [-3.5187e-01,  2.