# Multi-head Attention

1. multi-head attention 및 self-attention 구현
2. 각 과정에서 일어나는 연산과 input/output 형태 이해


## 필요 패키지 import

In [None]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

## 데이터 전처리

In [None]:
# Token id reserved for padding; every real token id in `data` is non-zero.
pad_id, vocab_size = 0, 100

# Ten variable-length token-id sequences (lengths range from 1 to 20).
data = [
  [62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54],
  [75, 51],
  [66, 88, 98, 47],
  [21, 39, 10, 64, 21],
  [98],
  [77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10],
  [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34],
  [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43],
]

In [None]:
def padding(data):
  """Right-pad every sequence in ``data`` with ``pad_id`` up to the longest length.

  The list is modified in place (each shorter row is replaced by a padded
  copy) and is also returned, so ``data, max_len = padding(data)`` works.
  Prints the maximum length and shows a tqdm progress bar while padding.

  Args:
    data: list of token-id lists.

  Returns:
    (data, max_len): the padded list and the length every row now has.
    An empty ``data`` yields ``([], 0)`` instead of raising ValueError.
  """
  # default=0 keeps an empty batch from raising ValueError inside max().
  max_len = max((len(seq) for seq in data), default=0)
  print(f"Maximum length: {max_len}")

  for i, seq in enumerate(tqdm(data)):
    if len(seq) < max_len:
      data[i] = seq + [pad_id] * (max_len - len(seq))

  return data, max_len

In [None]:
data, max_len = padding(data=data)  # pad all rows to the batch's longest sequence

100%|██████████| 10/10 [00:00<00:00, 49991.70it/s]

Maximum length: 20





In [None]:
# Display the padded batch (the last bare expression of a notebook cell is echoed).
data

[[62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75, 0, 0, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [75, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [21, 39, 10, 64, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [77,
  65,
  51,
  77,
  19,
  15,
  35,
  19,
  23,
  97,
  50,
  46,
  53,
  42,
  45,
  91,
  66,
  3,
  43,
  10],
 [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34, 0, 0, 0, 0],
 [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43, 0, 0, 0, 0, 0, 0]]

## Hyperparameter 세팅 및 embedding

In [None]:
# Model (hidden) size and number of attention heads.
# d_model must be divisible by num_heads so it can be split evenly across heads.
d_model, num_heads = 512, 8

In [None]:
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

# B: batch size, L: maximum sequence length (all rows share it after padding)
batch = torch.tensor(data, dtype=torch.long)  # (B, L) token ids
batch_emb = embedding(batch)  # (B, L, d_model)

In [None]:
print(batch_emb, batch_emb.shape, sep="\n")

tensor([[[ 0.8649, -1.7624, -0.1467,  ...,  0.1065, -0.5493, -0.9687],
         [-0.5280,  0.5591,  0.3295,  ...,  0.9948, -2.4267, -1.2439],
         [-0.9900,  1.9719,  0.3138,  ...,  0.0512, -1.3343,  1.2161],
         ...,
         [ 0.0243,  0.3623,  0.3413,  ..., -0.2686, -0.3003,  1.2229],
         [ 0.0243,  0.3623,  0.3413,  ..., -0.2686, -0.3003,  1.2229],
         [ 0.0243,  0.3623,  0.3413,  ..., -0.2686, -0.3003,  1.2229]],

        [[-0.7190,  0.5153, -1.0311,  ..., -1.0616,  1.0147,  0.2931],
         [ 1.0068,  0.0157, -0.0789,  ..., -0.1688,  0.6181,  0.2007],
         [-0.1904, -0.7859, -0.5010,  ..., -1.7233,  0.1630,  0.3240],
         ...,
         [ 0.0243,  0.3623,  0.3413,  ..., -0.2686, -0.3003,  1.2229],
         [ 0.0243,  0.3623,  0.3413,  ..., -0.2686, -0.3003,  1.2229],
         [ 0.0243,  0.3623,  0.3413,  ..., -0.2686, -0.3003,  1.2229]],

        [[-0.3793, -0.6177,  1.3900,  ...,  0.2882,  0.2737,  1.1043],
         [ 1.5289, -2.1897,  0.4747,  ..., -0

## Linear transformation & split into several heads

Define the learnable linear projections used in multi-head attention.

In [None]:
# Learnable query/key/value projections, each mapping d_model -> d_model
# (created in q, k, v order).
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))

In [None]:
# Output projection applied after the per-head results are concatenated.
w_0 = nn.Linear(in_features=d_model, out_features=d_model)

In [None]:
# Self-attention: queries, keys and values are all projected from the same
# embeddings; each result is (B, L, d_model).
q, k, v = w_q(batch_emb), w_k(batch_emb), w_v(batch_emb)

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([10, 20, 512])
torch.Size([10, 20, 512])
torch.Size([10, 20, 512])


Split `q`, `k` and `v` into `num_heads` heads, each of size `d_k = d_model // num_heads`.

In [None]:
batch_size = q.shape[0]
if d_model % num_heads != 0:
  raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
d_k = d_model // num_heads  # size of each head

# Split the last axis into heads, then move the head axis in front of the
# sequence axis. torch.matmul contracts over the last two axes, so without the
# transpose the attention in the next cells would mix heads (an 8x8 "attention"
# per token) instead of attending over sequence positions within each head.
q = q.view(batch_size, -1, num_heads, d_k).transpose(1, 2)  # (B, num_heads, L, d_k)
k = k.view(batch_size, -1, num_heads, d_k).transpose(1, 2)  # (B, num_heads, L, d_k)
v = v.view(batch_size, -1, num_heads, d_k).transpose(1, 2)  # (B, num_heads, L, d_k)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])


## Scaled-dot product self-attention

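This is scaled dot-product self-attention, computed independently for each head: $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$.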

In [None]:
# Attention weights: softmax over the last axis of (q @ k^T) / sqrt(d_k).
# NOTE: the matmul contracts over the last two axes, so q and k must be laid out
# as (B, num_heads, L, d_k) for the weights to range over sequence positions;
# with that layout attn_dists is (B, num_heads, L, L).
attn_scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
attn_dists = attn_scores.softmax(dim=-1)

print(attn_dists, attn_dists.shape, sep="\n")

tensor([[[[0.1318, 0.1550, 0.0935,  ..., 0.0809, 0.1775, 0.1748],
          [0.2706, 0.0917, 0.1166,  ..., 0.1883, 0.1077, 0.0368],
          [0.1419, 0.1564, 0.1201,  ..., 0.1719, 0.0831, 0.1145],
          ...,
          [0.0702, 0.2032, 0.1900,  ..., 0.0995, 0.1312, 0.0840],
          [0.1302, 0.1279, 0.1387,  ..., 0.1376, 0.1013, 0.0941],
          [0.1226, 0.0985, 0.1069,  ..., 0.1747, 0.1171, 0.1145]],

         [[0.1224, 0.1282, 0.1021,  ..., 0.1567, 0.2502, 0.0892],
          [0.0843, 0.1336, 0.1431,  ..., 0.0916, 0.1909, 0.1151],
          [0.1078, 0.1030, 0.0970,  ..., 0.1682, 0.1435, 0.1183],
          ...,
          [0.1431, 0.1403, 0.1357,  ..., 0.0652, 0.0978, 0.1330],
          [0.0899, 0.0785, 0.2660,  ..., 0.0821, 0.0997, 0.1342],
          [0.1331, 0.0730, 0.0848,  ..., 0.1372, 0.0731, 0.2091]],

         [[0.1150, 0.1244, 0.1434,  ..., 0.1598, 0.1230, 0.1073],
          [0.1735, 0.1504, 0.1092,  ..., 0.0756, 0.1283, 0.0850],
          [0.1940, 0.1704, 0.0866,  ..., 0

In [None]:
# Weighted sum of the value vectors under each attention distribution.
attn_values = attn_dists @ v

print(attn_values.shape)

torch.Size([10, 20, 8, 64])


## Merge each result

Concatenate the per-head results back into one `d_model`-dimensional vector per token.

Then apply the output linear transformation `w_0` (`d_model` → `d_model`).

In [None]:
# Undo the head split: swap axes 1 and 2, then flatten (num_heads, d_k) into d_model.
# Assumes attn_values is (B, num_heads, L, d_k), which gives (B, L, d_model).
# `.contiguous()` is needed because `view` generally cannot reshape the
# non-contiguous tensor that `transpose` returns.
attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, d_model)

print(attn_values.shape)

torch.Size([10, 20, 512])


In [None]:
# Final linear projection over the concatenated heads.
outputs = w_0(attn_values)

print(outputs, outputs.shape, sep="\n")

tensor([[[-4.2419e-02, -5.0413e-02, -7.5183e-03,  ..., -3.6984e-02,
          -5.2476e-03, -2.6158e-02],
         [ 5.7163e-02,  4.9516e-02,  1.2621e-02,  ...,  6.9614e-03,
          -1.6063e-01, -1.2941e-01],
         [ 1.3234e-01,  2.2052e-01, -2.5090e-02,  ..., -1.6492e-01,
           7.6562e-02,  1.3833e-01],
         ...,
         [ 1.2553e-01,  1.9933e-01, -4.3863e-02,  ..., -1.7124e-01,
           7.2022e-02,  1.5761e-01],
         [-6.7493e-02,  1.6571e-01, -1.8425e-01,  ...,  2.4885e-01,
           1.5054e-01,  5.1648e-02],
         [-3.9699e-02,  8.8506e-02,  1.0955e-01,  ..., -8.0166e-02,
           1.7501e-01, -1.9001e-01]],

        [[-1.0307e-01,  4.8448e-02,  1.4954e-01,  ...,  2.3246e-01,
           2.9636e-03,  1.3784e-01],
         [ 8.2978e-02,  2.1908e-02,  6.0375e-02,  ..., -2.0790e-01,
           1.2990e-01,  9.4210e-02],
         [ 1.9119e-01, -2.3776e-02, -1.0227e-01,  ..., -7.8632e-02,
           1.2237e-01,  1.1746e-01],
         ...,
         [ 9.8120e-02,  8

## The Whole Code

Put the code above together into a reusable multi-head attention module.

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Projects queries, keys and values with separate learnable linear layers,
  splits them into ``num_heads`` heads of size ``d_k = d_model // num_heads``,
  runs scaled dot-product attention per head, concatenates the heads and
  applies the output projection ``w_0``.

  Args:
    d_model: model (embedding) size, i.e. the last dimension of q/k/v inputs.
      Defaults to 512, the value used in this notebook's hyperparameter cell.
    num_heads: number of attention heads; must divide ``d_model``. Defaults to 8.

  Raises:
    ValueError: if ``d_model`` is not divisible by ``num_heads``.
  """

  def __init__(self, d_model=512, num_heads=8):
    super(MultiHeadAttention, self).__init__()
    if d_model % num_heads != 0:
      raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")

    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model // num_heads

    # Q, K, V learnable projection matrices
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)

    # Linear transformation for the concatenated head outputs
    self.w_0 = nn.Linear(d_model, d_model)

  def forward(self, q, k, v, mask=None):
    """Apply multi-head attention.

    Args:
      q: queries, (B, L_q, d_model).
      k: keys, (B, L_k, d_model).
      v: values, (B, L_k, d_model).
      mask: optional bool/0-1 tensor; positions that are False/0 are not
        attended to. A 2-D (B, L_k) key-padding mask or a 3-D (B, L_q, L_k) /
        (B, 1, L_k) mask gets the missing axes inserted; a 4-D mask must
        broadcast to (B, num_heads, L_q, L_k).

    Returns:
      Tensor of shape (B, L_q, d_model).
    """
    batch_size = q.shape[0]

    # Each input uses its own projection matrix.
    q = self.w_q(q)  # (B, L_q, d_model)
    k = self.w_k(k)  # (B, L_k, d_model)
    v = self.w_v(v)  # (B, L_k, d_model)

    # (B, L, d_model) -> (B, L, num_heads, d_k) -> (B, num_heads, L, d_k)
    q = q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    k = k.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    v = v.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    if mask is not None:
      if mask.dim() == 2:  # (B, L_k) -> (B, 1, 1, L_k)
        mask = mask[:, None, None, :]
      elif mask.dim() == 3:  # (B, L_q | 1, L_k) -> (B, 1, L_q | 1, L_k)
        mask = mask.unsqueeze(1)

    attn_values = self.self_attention(q, k, v, mask=mask)  # (B, num_heads, L_q, d_k)
    # (B, num_heads, L_q, d_k) -> (B, L_q, num_heads, d_k) -> (B, L_q, d_model)
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

    return self.w_0(attn_values)

  def self_attention(self, q, k, v, mask=None):
    """Scaled dot-product attention over the last two axes.

    Expects q: (B, num_heads, L_q, d_k) and k, v: (B, num_heads, L_k, d_k).
    ``mask`` (optional) must broadcast to (B, num_heads, L_q, L_k); entries
    that are False/0 are excluded from attention.

    Returns:
      Tensor of shape (B, num_heads, L_q, d_k).
    """
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.shape[-1])  # (B, num_heads, L_q, L_k)
    if mask is not None:
      # The dtype's most negative finite value (not -inf) keeps a fully masked
      # row from producing NaN in softmax; it becomes uniform instead.
      attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)
    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L_q, L_k)

    return torch.matmul(attn_dists, v)  # (B, num_heads, L_q, d_k)

In [None]:
# Self-attention: queries, keys and values all come from the same embeddings.
multihead_attn = MultiHeadAttention()
outputs = multihead_attn(q=batch_emb, k=batch_emb, v=batch_emb)  # (B, L, d_model)

In [None]:
print(outputs, outputs.shape, sep="\n")

tensor([[[-0.0953, -0.2668, -0.0293,  ...,  0.1291, -0.0275, -0.2341],
         [-0.5509,  0.0514,  0.0039,  ...,  0.0297,  0.0797,  0.1726],
         [-0.3558,  0.0457,  0.0024,  ...,  0.1884,  0.3044,  0.2683],
         ...,
         [-0.0263, -0.2028, -0.0705,  ...,  0.1075,  0.2533, -0.2356],
         [-0.0263, -0.2028, -0.0705,  ...,  0.1075,  0.2533, -0.2356],
         [-0.0263, -0.2028, -0.0705,  ...,  0.1075,  0.2533, -0.2356]],

        [[-0.1758,  0.0447, -0.1223,  ...,  0.0823,  0.2819,  0.0240],
         [ 0.0085, -0.3409, -0.1536,  ..., -0.0007,  0.3156,  0.1716],
         [-0.0534, -0.0326, -0.0608,  ...,  0.2224,  0.1670, -0.1612],
         ...,
         [-0.0351, -0.2495, -0.0925,  ...,  0.1341,  0.3069, -0.3234],
         [-0.0351, -0.2495, -0.0925,  ...,  0.1341,  0.3069, -0.3234],
         [-0.0351, -0.2495, -0.0925,  ...,  0.1341,  0.3069, -0.3234]],

        [[ 0.2885, -0.0145, -0.0887,  ...,  0.1108,  0.2228, -0.2776],
         [ 0.2202,  0.0567, -0.2858,  ...,  0