<a href="https://colab.research.google.com/github/puneeshkhanna/Tensor-Parallelism/blob/master/tensor_parallelism_attention_layers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Link - https://medium.com/@puneesh.khanna83/understanding-tensor-parallelism-to-fit-larger-models-on-multiple-devices-part-2-ee8a2ab7f017

In [13]:
import numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [14]:
# Random activations laid out as (batch size, number of tokens / q_len, hidden size per token).
# NOTE: the name shadows Python's built-in input(); it is kept because later cells refer to it.
input = torch.randn(1, 5, 32, dtype=torch.float32)

In [15]:
# Unpack the three axes of the input tensor and echo the hidden size as the cell result.
bsz, q_len, hidden_size = input.shape
hidden_size

32

In [16]:
# Number of attention heads of multi head attention of each transformer block
num_heads = 4

# hidden size must be divisible by num heads; fail early with a clear message instead of
# letting the floor division silently produce a head_dim that no longer tiles hidden_size.
assert hidden_size % num_heads == 0, (
    f"hidden_size ({hidden_size}) must be divisible by num_heads ({num_heads})"
)

# per head embedding dim
head_dim = hidden_size // num_heads

print(f"num heads is {num_heads}, hidden size is {hidden_size}, head dim is {head_dim}")

num heads is 4, hidden size is 32, head dim is 8


## Attention Layer output

Transformer architectures have an Attention block with Q, K and V projection layers followed by an o_proj (dense) layer.

```
Attention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
```

In this notebook, the dimensions below are used along with num_heads=4.
```
Attention(
          (q_proj): Linear(in_features=32, out_features=32, bias=False)
          (k_proj): Linear(in_features=32, out_features=32, bias=False)
          (v_proj): Linear(in_features=32, out_features=32, bias=False)
          (o_proj): Linear(in_features=32, out_features=32, bias=False)
        )
```

In [17]:
# Bias-free Q, K and V projections, created in q -> k -> v order.
q_proj, k_proj, v_proj = (
    nn.Linear(in_features=hidden_size, out_features=num_heads * head_dim, bias=False)
    for _ in range(3)
)

# Output projection mapping the concatenated heads back to the hidden size.
o_proj = nn.Linear(in_features=num_heads * head_dim, out_features=hidden_size, bias=False)

# nn.Linear keeps its weight as [out_features, in_features]; the transpose is the
# matrix that the activations are actually multiplied with.
print("q_proj, k_proj, v_proj weights.T shape:", q_proj.weight.T.shape)
print("o_proj weights.T shape:", o_proj.weight.T.shape)

q_proj, k_proj, v_proj weights.T shape: torch.Size([32, 32])
o_proj weights.T shape: torch.Size([32, 32])


In [18]:
# Run the same input through each of the three projection layers.
query_states, key_states, value_states = (
    proj(input) for proj in (q_proj, k_proj, v_proj)
)

print("query_states, key_states, value_states after projections -> [batch size, q_len, num heads*head_dim]: ", query_states.shape)

# Split the last axis into (num_heads, head_dim), then swap the sequence and head axes
# so that every head owns its own [q_len, head_dim] matrix.
query_states, key_states, value_states = (
    states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    for states in (query_states, key_states, value_states)
)

print("\nquery_states, key_states, value_states after view and transpose -> [batch size, num heads, q_len, head_dim]:", query_states.shape)

query_states, key_states, value_states after projections -> [batch size, q_len, num heads*head_dim]:  torch.Size([1, 5, 32])

query_states, key_states, value_states after view and transpose -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 4, 5, 8])


In [19]:
# Scaled dot-product scores per head: Q @ K^T / sqrt(head_dim).
attn_weights = (query_states @ key_states.transpose(-2, -1)) / math.sqrt(head_dim)
# Softmax over the key axis, computed in float32 and cast back to the query dtype.
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
print("attn scores after QK.T and softmax -> [batch size, num heads, q_len, q_len]:", attn_weights.shape)

# Weight the value vectors by the attention scores.
attn_output = attn_weights @ value_states
print("\nattn output -> matmul of attn scores [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]:", attn_output.shape)

attn scores after QK.T and softmax -> [batch size, num heads, q_len, q_len]: torch.Size([1, 4, 5, 5])

attn output -> matmul of attn scores [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 4, 5, 8])


In [20]:
# Move the head axis back next to head_dim and merge the two:
# [batch size, num heads, q_len, head_dim] -> [batch size, q_len, num heads*head_dim]
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, num_heads*head_dim)
print("attn output after view and reshape -> [batch size, q_len, num heads*head_dim]:", attn_output.shape)

# Final dense projection; this is the reference (non tensor-parallel) result.
attn_output_non_tp = o_proj(attn_output)
# Report the shape of the projected tensor itself, not of the pre-projection attn_output.
print("\nattn output after final o_proj -> [batch size, q_len, hidden_size]:", attn_output_non_tp.shape)

attn output after view and reshape -> [batch size, q_len, num heads*head_dim]: torch.Size([1, 5, 32])

attn output after final o_proj -> [batch size, q_len, hidden_size]: torch.Size([1, 5, 32])


## Attention layer output with Tensor parallelism

Assuming 2 devices, the code below shows how the attention heads are divided between the 2 devices.

Note that all of the code below is executed on a single device only; the section headings indicate which blocks would be executed on the first device and which on the second device.

In [21]:
n_devices = 2

# Derive the full head count from the q_proj layer instead of from num_heads itself,
# so that re-running this cell does not halve num_heads a second time.
total_heads = q_proj.out_features // head_dim
assert total_heads % n_devices == 0, (
    f"number of attention heads ({total_heads}) must be divisible by n_devices ({n_devices})"
)

# From here on num_heads means "heads per device".
num_heads = total_heads // n_devices
print(f"num heads changes to {num_heads} for each of the {n_devices} device")

num heads changes to 2 for each of the 2 device


In [22]:
print("Dividing q_proj, k_proj, v_proj, o_proj weights based upon num_heads")

# nn.Linear stores weight as [out_features, in_features]. Splitting q/k/v along dim 0
# cuts the output features into chunks of num_heads * head_dim, i.e. one group of heads per device.
query_slices, key_slices, value_slices = (
    torch.split(proj.weight, num_heads * head_dim, dim=0)
    for proj in (q_proj, k_proj, v_proj)
)

# o_proj is cut along dim 1 (its input features) so each chunk consumes exactly the
# heads produced by the matching q/k/v chunk.
o_proj_slices = torch.split(o_proj.weight, num_heads * head_dim, dim=1)

Dividing q_proj, k_proj, v_proj, o_proj weights based upon num_heads


In [23]:
# Compare the full transposed weight shapes with the per-device shard shapes.
print(
    "Original qkv proj weights transpose without TP:",
    q_proj.weight.t().shape,
    "\tPer device qkv proj weights transpose with TP (column divided):",
    query_slices[0].t().shape,
)

print(
    "\nOriginal o_proj weights transpose without TP:",
    o_proj.weight.t().shape,
    "\tPer device o_proj weights transpose with TP (row divided):",
    o_proj_slices[0].t().shape,
)

Original qkv proj weights transpose without TP: torch.Size([32, 32]) 	Per device qkv proj weights transpose with TP (column divided): torch.Size([32, 16])

Original o_proj weights transpose without TP: torch.Size([32, 32]) 	Per device o_proj weights transpose with TP (row divided): torch.Size([16, 32])


### First device - attention output logic

In [24]:
# First device: project the full input with shard 0 of each of the q/k/v weights.
query_states, key_states, value_states = (
    F.linear(input, slices[0])
    for slices in (query_slices, key_slices, value_slices)
)

print("query_states, key_states, value_states after projections -> [batch size, q_len, num heads*head_dim]: ", query_states.shape)

# Split the last axis into (num_heads, head_dim) and bring the head axis forward;
# num_heads is the per-device head count at this point.
query_states, key_states, value_states = (
    states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    for states in (query_states, key_states, value_states)
)

print("\nquery_states, key_states, value_states after view and transpose -> [batch size, num heads, q_len, head_dim]:", query_states.shape)

query_states, key_states, value_states after projections -> [batch size, q_len, num heads*head_dim]:  torch.Size([1, 5, 16])

query_states, key_states, value_states after view and transpose -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 2, 5, 8])


In [25]:
# First device: scaled dot-product scores for this device's heads only.
attn_weights = (query_states @ key_states.transpose(-2, -1)) / math.sqrt(head_dim)
# Softmax over the key axis in float32, then cast back to the query dtype.
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
print("attn scores after QK.T and softmax -> [batch size, num heads, q_len, q_len]:", attn_weights.shape)

# Attention-weighted sum of the value vectors.
attn_output = attn_weights @ value_states
print("\nattn output -> matmul of attn scores [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]:", attn_output.shape)

attn scores after QK.T and softmax -> [batch size, num heads, q_len, q_len]: torch.Size([1, 2, 5, 5])

attn output -> matmul of attn scores [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 2, 5, 8])


In [26]:
# First device: merge this device's heads back into one feature axis.
attn_output = (
    attn_output.transpose(1, 2)
    .contiguous()
    .reshape(bsz, q_len, num_heads * head_dim)
)
print("attn output after view and reshape -> [batch size, q_len, num heads*head_dim]:", attn_output.shape)

# Apply shard 0 of the o_proj weight; the result is a partial sum of the full output.
attn_output_1 = F.linear(attn_output, o_proj_slices[0])
print("\nattn output on 1st Tensor Parallel device after final o_proj -> [batch size, q_len, hidden_size]:", attn_output_1.shape)

attn output after view and reshape -> [batch size, q_len, num heads*head_dim]: torch.Size([1, 5, 16])

attn output on 1st Tensor Parallel device after final o_proj -> [batch size, q_len, hidden_size]: torch.Size([1, 5, 32])


### Second device - attention output logic

In [27]:
# Second device: project the full input with shard 1 of each of the q/k/v weights.
query_states, key_states, value_states = (
    F.linear(input, slices[1])
    for slices in (query_slices, key_slices, value_slices)
)
print("query_states, key_states, value_states after projections -> [batch size, q_len, num heads*head_dim]: ", query_states.shape)

# Split the last axis into (num_heads, head_dim) and bring the head axis forward.
query_states, key_states, value_states = (
    states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    for states in (query_states, key_states, value_states)
)
print("\nquery_states, key_states, value_states after view and transpose -> [batch size, num heads, q_len, head_dim]:", query_states.shape)

query_states, key_states, value_states after projections -> [batch size, q_len, num heads*head_dim]:  torch.Size([1, 5, 16])

query_states, key_states, value_states after view and transpose -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 2, 5, 8])


In [28]:
# Second device: scaled dot-product scores for this device's heads only.
attn_weights = (query_states @ key_states.transpose(-2, -1)) / math.sqrt(head_dim)
# Softmax over the key axis in float32, then cast back to the query dtype.
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
print("attn scores after QK.T and softmax -> [batch size, num heads, q_len, q_len]:", attn_weights.shape)

# Attention-weighted sum of the value vectors.
attn_output = attn_weights @ value_states
print("\nattn output -> matmul of attn scores [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]:", attn_output.shape)

attn scores after QK.T and softmax -> [batch size, num heads, q_len, q_len]: torch.Size([1, 2, 5, 5])

attn output -> matmul of attn scores [batch size, num heads, q_len, q_len] and value states [batch size, num heads, q_len, head_dim] -> [batch size, num heads, q_len, head_dim]: torch.Size([1, 2, 5, 8])


In [29]:
# Second device: merge this device's heads back into one feature axis.
attn_output = (
    attn_output.transpose(1, 2)
    .contiguous()
    .reshape(bsz, q_len, num_heads * head_dim)
)
print("attn output after view and reshape -> [batch size, q_len, num heads*head_dim]:", attn_output.shape)

# Apply shard 1 of the o_proj weight; the result is the other partial sum of the full output.
attn_output_2 = F.linear(attn_output, o_proj_slices[1])
print("\nattn output on 2nd Tensor Parallel device after final o_proj -> [batch size, q_len, hidden_size]:", attn_output_2.shape)

attn output after view and reshape -> [batch size, q_len, num heads*head_dim]: torch.Size([1, 5, 16])

attn output on 2nd Tensor Parallel device after final o_proj -> [batch size, q_len, hidden_size]: torch.Size([1, 5, 32])


### Add tensor parallel outputs of attn_output using torch.add; it will be an all reduce operation when using actual 2 devices

In [30]:
# In actual tensor parallelism between 2 devices, this will be all reduce operation
attn_output_tp = torch.add(attn_output_1, attn_output_2)
attn_output_tp.shape

torch.Size([1, 5, 32])

## Compare attn_output_non_tp and attn_output_tp using allclose

In [31]:
# Numerically compare the single-device result with the summed tensor-parallel result.
outputs_match = torch.allclose(attn_output_non_tp, attn_output_tp, rtol=1e-05, atol=1e-05)
print(outputs_match)

True


In [32]:
# Display the o_proj output computed without tensor parallelism, for visual comparison.
attn_output_non_tp

tensor([[[-0.0735,  0.1547, -0.0292,  0.0470, -0.0282,  0.0597,  0.1018,
           0.1307, -0.1112, -0.1699,  0.2083, -0.1083,  0.0874, -0.0449,
          -0.1183, -0.2595, -0.0655, -0.2183,  0.1065, -0.1037,  0.2390,
           0.0212,  0.0796, -0.0258,  0.1713,  0.0254,  0.1817, -0.1307,
          -0.1476, -0.0997, -0.0394, -0.0292],
         [-0.1030,  0.1681, -0.0376,  0.0200, -0.1233,  0.0678,  0.1338,
          -0.0023, -0.1671, -0.1551,  0.1568, -0.0524, -0.0250, -0.0711,
          -0.2904, -0.2652, -0.0289, -0.2211,  0.0744, -0.0724,  0.2301,
           0.0566,  0.0180, -0.0022,  0.0772,  0.1041,  0.1453, -0.1092,
          -0.1428,  0.0365,  0.0212, -0.0211],
         [-0.1740,  0.0882, -0.0542, -0.0338, -0.0964,  0.1371,  0.0987,
           0.0847, -0.2318, -0.1612,  0.1891, -0.0056, -0.0497, -0.1431,
          -0.2468, -0.3641,  0.0449, -0.1987,  0.0127, -0.1358,  0.3252,
           0.0855,  0.0307,  0.0189,  0.1379,  0.1158,  0.1468, -0.0483,
          -0.1450,  0.0827, -0

In [33]:
# Display the sum of the two per-device o_proj outputs, for visual comparison.
attn_output_tp

tensor([[[-0.0735,  0.1547, -0.0292,  0.0470, -0.0282,  0.0597,  0.1018,
           0.1307, -0.1112, -0.1699,  0.2083, -0.1083,  0.0874, -0.0449,
          -0.1183, -0.2595, -0.0655, -0.2183,  0.1065, -0.1037,  0.2390,
           0.0212,  0.0796, -0.0258,  0.1713,  0.0254,  0.1817, -0.1307,
          -0.1476, -0.0997, -0.0394, -0.0292],
         [-0.1030,  0.1681, -0.0376,  0.0200, -0.1233,  0.0678,  0.1338,
          -0.0023, -0.1671, -0.1551,  0.1568, -0.0524, -0.0250, -0.0711,
          -0.2904, -0.2652, -0.0289, -0.2211,  0.0744, -0.0724,  0.2301,
           0.0566,  0.0180, -0.0022,  0.0772,  0.1041,  0.1453, -0.1092,
          -0.1428,  0.0365,  0.0212, -0.0211],
         [-0.1740,  0.0882, -0.0542, -0.0338, -0.0964,  0.1371,  0.0987,
           0.0847, -0.2318, -0.1612,  0.1891, -0.0056, -0.0497, -0.1431,
          -0.2468, -0.3641,  0.0449, -0.1987,  0.0127, -0.1358,  0.3252,
           0.0855,  0.0307,  0.0189,  0.1379,  0.1158,  0.1468, -0.0483,
          -0.1450,  0.0827, -0