## **7. Multi-head Attention**
1. Multi-head attention 및 self-attention 구현.
2. 각 과정에서 일어나는 연산과 input/output 형태 이해.
- 여기서는 Encoder의 self-attention 부분을 다룬다 (아래 단계별 구현에서는 padding token에 대한 mask를 적용하지 않는다)

In [None]:
# Mount Google Drive when running in Colab. Nothing in the visible part of this
# notebook reads from Drive, so outside Colab the cell is skipped instead of
# failing on the import.
try:
  from google.colab import drive
except ImportError:  # not running in Colab
  drive = None
else:
  drive.mount('/content/drive')

Mounted at /content/drive


### **필요 패키지 import**

In [12]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

### **데이터 전처리**

In [None]:
# Toy batch of token-id sequences with different lengths.
# Attention consumes the padded (B, L) tensor directly, so unlike an RNN
# there is no need for a packed-sequence object.
vocab_size = 100
pad_id = 0  # id reserved for padding; none of the real sequences below contain it

data = [
  [62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54],
  [75, 51],
  [66, 88, 98, 47],
  [21, 39, 10, 64, 21],
  [98],
  [77, 65, 51, 77, 19, 15, 35, 19, 23, 97,
   50, 46, 53, 42, 45, 91, 66, 3, 43, 10],
  [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34],
  [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43],
]

In [None]:
# Right-pad every sequence to the same length so the batch can be stacked into
# one (B, L) tensor. No explicit lengths are returned: padded positions can be
# recovered later as `token == pad id` (e.g. to build an attention mask).
def padding(data, pad_value=None):
  """Right-pad each sequence in ``data`` up to the longest sequence's length.

  Args:
    data: list of token-id sequences (lists of ints).
    pad_value: token used for padding; defaults to the notebook-level ``pad_id``.

  Returns:
    ``(padded, max_len)``: a NEW list of lists, each of length ``max_len``, and
    the length of the longest input sequence (0 when ``data`` is empty).
    Neither ``data`` nor its inner lists are modified.
  """
  if pad_value is None:
    pad_value = pad_id  # notebook-level pad token id
  # default=0 keeps an empty batch from raising ValueError in max().
  max_len = max(map(len, data), default=0)
  print(f"Maximum sequence length: {max_len}")

  # Build fresh lists instead of assigning into ``data`` so the caller's list is left untouched.
  padded = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in data]
  return padded, max_len

In [None]:
# Replace the ragged batch with its padded version and remember the common length.
data, max_len = padding(data=data)

100%|██████████| 10/10 [00:00<00:00, 2848.62it/s]

Maximum sequence length: 20





In [None]:
# Show the padded batch (Jupyter renders the last expression of the cell).
data

[[62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75, 0, 0, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [75, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [21, 39, 10, 64, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [77,
  65,
  51,
  77,
  19,
  15,
  35,
  19,
  23,
  97,
  50,
  46,
  53,
  42,
  45,
  91,
  66,
  3,
  43,
  10],
 [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34, 0, 0, 0, 0],
 [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43, 0, 0, 0, 0, 0, 0]]

### **Hyperparameter 세팅 및 embedding**

In [None]:
d_model = 512  # hidden size of the model
num_heads = 8  # number of attention heads

# d_model must split evenly across the heads (each head gets d_k = d_model // num_heads
# dimensions); fail fast here instead of with an opaque view() error later.
if d_model % num_heads != 0:
  raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")

In [None]:
embedding = nn.Embedding(vocab_size, d_model)

# Token ids as an int64 tensor. B = batch size, L = maximum (padded) sequence length.
batch = torch.tensor(data, dtype=torch.long)  # (B, L)
batch_emb = embedding(batch)  # (B, L, d_model): one d_model-dim vector per token

In [None]:
# Embedded batch followed by its shape, one per line.
print(batch_emb, batch_emb.shape, sep="\n")

tensor([[[-0.5375,  1.4314, -0.0752,  ..., -0.8263,  0.8717,  0.7618],
         [ 0.5835, -1.8638,  0.4799,  ...,  0.0206,  0.0927, -0.8794],
         [ 1.0635,  0.3856, -0.0417,  ..., -0.0522, -0.4558,  0.8001],
         ...,
         [-0.0265,  2.4331,  1.2265,  ..., -1.1198,  0.2593, -0.2225],
         [-0.0265,  2.4331,  1.2265,  ..., -1.1198,  0.2593, -0.2225],
         [-0.0265,  2.4331,  1.2265,  ..., -1.1198,  0.2593, -0.2225]],

        [[ 1.7256,  1.4002,  2.0348,  ...,  0.9457,  0.1526,  0.3471],
         [-2.3295,  0.6252,  0.7374,  ..., -0.6947, -0.6032, -1.8170],
         [-0.3749,  0.3878, -1.6112,  ...,  0.8582, -0.5425, -0.8985],
         ...,
         [-0.0265,  2.4331,  1.2265,  ..., -1.1198,  0.2593, -0.2225],
         [-0.0265,  2.4331,  1.2265,  ..., -1.1198,  0.2593, -0.2225],
         [-0.0265,  2.4331,  1.2265,  ..., -1.1198,  0.2593, -0.2225]],

        [[-1.0352, -1.7133,  1.2455,  ..., -0.3931, -0.0878, -1.2596],
         [ 1.0222,  0.0187, -0.3459,  ..., -0

### **Linear transformation & 여러 head로 나누기**

Multi-head attention 내에서 쓰이는 linear transformation matrix들을 정의합니다.

In [None]:
# Query / key / value projections: one (d_model -> d_model) layer each, created in q, k, v order.
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))

In [None]:
w_0 = nn.Linear(in_features=d_model, out_features=d_model)  # output projection after the heads are concatenated

In [13]:
# Project the same embeddings three ways; each result keeps the (B, L, d_model) shape.
q, k, v = w_q(batch_emb), w_k(batch_emb), w_v(batch_emb)

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([10, 20, 512])
torch.Size([10, 20, 512])
torch.Size([10, 20, 512])


q, k, v를 각각 `num_heads`개의 head로 차원 분할된 여러 vector로 만듭니다.

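- 이론적으로 multi-head attention은 input을 head마다 서로 다른 Wq, Wk, Wv로 linear transformation한 뒤 head별로 attention을 수행하고, 그 결과들을 concat한 후 다시 linear transformation(Wo)을 적용한다
- 구현에서는 (d_model, d_model) 크기의 Wq, Wk, Wv를 하나씩만 두고, 그 결과를 `num_heads`개의 d_k 차원 조각으로 나눈다. 이는 head별 (d_model, d_k) 행렬 `num_heads`개를 옆으로 이어 붙인 것과 같은 연산이다
- 실제 `Attention Is All You Need` 논문의 구현 예시도 하나의 Query vector를 차원(dim) 방향으로 쪼개서 진행한다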

In [14]:
batch_size, d_k = q.shape[0], d_model // num_heads  # d_k: per-head dimension

# Split the last dimension (d_model) into num_heads chunks of size d_k.
q, k, v = (t.view(batch_size, -1, num_heads, d_k) for t in (q, k, v))  # each (B, L, num_heads, d_k)

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])


In [15]:
# Move the head axis in front of the sequence axis so that each head
# runs self-attention on its own (L, d_k) matrix.
q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # each (B, num_heads, L, d_k)

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])


### **Scaled dot-product self-attention 구현**

각 head에서 실행되는 self-attention 과정입니다.

In [16]:
# Score of every query token against every key token of the same sequence,
# scaled by sqrt(d_k): how much weight each token should give to the others.
attn_scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (B, num_heads, L, L)
# Normalize each row (one query's scores over all keys), hence dim=-1.
attn_dists = attn_scores.softmax(dim=-1)  # (B, num_heads, L, L)

print(attn_dists, attn_dists.shape, sep="\n")

tensor([[[[0.0499, 0.0405, 0.0438,  ..., 0.0613, 0.0613, 0.0613],
          [0.0447, 0.0532, 0.0307,  ..., 0.0315, 0.0315, 0.0315],
          [0.0713, 0.0617, 0.0347,  ..., 0.0448, 0.0448, 0.0448],
          ...,
          [0.0584, 0.0449, 0.0351,  ..., 0.0216, 0.0216, 0.0216],
          [0.0584, 0.0449, 0.0351,  ..., 0.0216, 0.0216, 0.0216],
          [0.0584, 0.0449, 0.0351,  ..., 0.0216, 0.0216, 0.0216]],

         [[0.0438, 0.0642, 0.0674,  ..., 0.0452, 0.0452, 0.0452],
          [0.0508, 0.0594, 0.0724,  ..., 0.0437, 0.0437, 0.0437],
          [0.0830, 0.0343, 0.0587,  ..., 0.0384, 0.0384, 0.0384],
          ...,
          [0.0281, 0.0449, 0.0280,  ..., 0.0525, 0.0525, 0.0525],
          [0.0281, 0.0449, 0.0280,  ..., 0.0525, 0.0525, 0.0525],
          [0.0281, 0.0449, 0.0280,  ..., 0.0525, 0.0525, 0.0525]],

         [[0.0351, 0.0461, 0.0615,  ..., 0.0352, 0.0352, 0.0352],
          [0.0337, 0.0531, 0.0490,  ..., 0.0529, 0.0529, 0.0529],
          [0.0294, 0.0423, 0.0742,  ..., 0

In [17]:
# Weighted sum of the value vectors for every query position.
attn_values = attn_dists @ v  # (B, num_heads, L, d_k)

print(attn_values.shape)

torch.Size([10, 8, 20, 64])


### **각 head의 결과물 병합**

각 head의 결과물을 concat하고 동일 차원으로 linear transformation합니다.

In [18]:
# (B, num_heads, L, d_k) -> (B, L, num_heads, d_k) -> (B, L, d_model), i.e. concatenate the heads.
# contiguous() is required because view() cannot reshape the transposed, non-contiguous tensor.
attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, d_model)

print(attn_values.shape)

torch.Size([10, 20, 512])


In [19]:
# w_0 has shape (d_model, d_model). It is not a no-op: it combines the
# self-attention results of heads that each focused on different aspects.
outputs = w_0(attn_values)  # (B, L, d_model)

print(outputs, outputs.shape, sep="\n")

tensor([[[ 0.0390, -0.1240,  0.0406,  ...,  0.0479,  0.0073,  0.1110],
         [ 0.0392, -0.1011,  0.0391,  ...,  0.0395,  0.0235,  0.0970],
         [ 0.0521, -0.1119,  0.0252,  ...,  0.0495,  0.0509,  0.1050],
         ...,
         [ 0.0634, -0.1406,  0.0485,  ...,  0.0662,  0.0739,  0.0924],
         [ 0.0634, -0.1406,  0.0485,  ...,  0.0662,  0.0739,  0.0924],
         [ 0.0634, -0.1406,  0.0485,  ...,  0.0662,  0.0739,  0.0924]],

        [[ 0.3256, -0.1864,  0.1352,  ..., -0.0625, -0.0321, -0.1640],
         [ 0.2821, -0.1365,  0.0807,  ..., -0.0504,  0.0488, -0.1291],
         [ 0.2804, -0.1647,  0.1189,  ..., -0.0888, -0.0025, -0.0983],
         ...,
         [ 0.3385, -0.1528,  0.0872,  ..., -0.0841,  0.0037, -0.1372],
         [ 0.3385, -0.1528,  0.0872,  ..., -0.0841,  0.0037, -0.1372],
         [ 0.3385, -0.1528,  0.0872,  ..., -0.0841,  0.0037, -0.1372]],

        [[ 0.1550, -0.1998,  0.1456,  ..., -0.1398,  0.0941, -0.0777],
         [ 0.2585, -0.1684,  0.1452,  ..., -0

### **전체 코드**

위의 과정을 모두 합쳐 하나의 Multi-head attention 모듈을 구현하겠습니다.

In [20]:
class MultiheadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Projects q/k/v with one (d_model, d_model) linear layer each, splits the
  result into ``num_heads`` heads of size ``d_k = d_model // num_heads``,
  runs scaled dot-product attention per head, concatenates the heads and
  mixes them with a final (d_model, d_model) linear layer.

  Args:
    model_dim: hidden size; defaults to the notebook-level ``d_model``.
    n_heads: number of heads; defaults to the notebook-level ``num_heads``.

  Raises:
    ValueError: if the hidden size is not divisible by the number of heads.
  """

  def __init__(self, model_dim=None, n_heads=None):
    super().__init__()

    # Fall back to the notebook globals so ``MultiheadAttention()`` keeps working.
    # Sizes are stored on the instance instead of being read from globals at call
    # time, so they cannot go stale if another cell changes them.
    self.d_model = d_model if model_dim is None else model_dim
    self.num_heads = num_heads if n_heads is None else n_heads
    if self.d_model % self.num_heads != 0:
      raise ValueError(
          f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})")
    self.d_k = self.d_model // self.num_heads

    # Q, K, V learnable projections
    self.w_q = nn.Linear(self.d_model, self.d_model)
    self.w_k = nn.Linear(self.d_model, self.d_model)
    self.w_v = nn.Linear(self.d_model, self.d_model)

    # Linear transformation applied to the concatenated head outputs
    self.w_0 = nn.Linear(self.d_model, self.d_model)

  def forward(self, q, k, v, mask=None):
    """Apply multi-head attention.

    Args:
      q: (B, L_q, d_model) queries.
      k, v: (B, L_k, d_model) keys and values.
      mask: optional; nonzero/True marks positions that may be attended to.
        Accepts a key padding mask of shape (B, L_k) (e.g. ``batch != pad_id``),
        a (B, L_q, L_k) mask, or a 4-D mask broadcastable to
        (B, num_heads, L_q, L_k). ``None`` (default) attends to every position.

    Returns:
      (B, L_q, d_model) tensor.
    """
    batch_size = q.shape[0]

    # Linear projection, then split into heads: (B, L, d_model) -> (B, num_heads, L, d_k)
    q = self._split_heads(self.w_q(q), batch_size)
    k = self._split_heads(self.w_k(k), batch_size)
    v = self._split_heads(self.w_v(v), batch_size)

    # Insert the head (and query) axes so the mask broadcasts over (B, num_heads, L_q, L_k).
    if mask is not None:
      if mask.dim() == 2:
        mask = mask[:, None, None, :]  # (B, L_k) -> (B, 1, 1, L_k)
      elif mask.dim() == 3:
        mask = mask.unsqueeze(1)  # (B, L_q, L_k) -> (B, 1, L_q, L_k)

    attn_values = self.self_attention(q, k, v, mask)  # (B, num_heads, L_q, d_k)
    # (B, num_heads, L_q, d_k) -> (B, L_q, num_heads, d_k) -> (B, L_q, d_model)
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

    return self.w_0(attn_values)

  def _split_heads(self, x, batch_size):
    """Reshape (B, L, d_model) into (B, num_heads, L, d_k) so each head attends independently."""
    return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

  # scaled dot-product attention
  def self_attention(self, q, k, v, mask=None):
    """Compute softmax(q k^T / sqrt(d_k)) v for every head.

    Positions where ``mask == 0`` get the most negative finite score, so they
    receive zero weight without turning a fully masked row into NaN.
    """
    d_k = q.shape[-1]
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L_q, L_k)
    if mask is not None:
      attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)
    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L_q, L_k)

    return torch.matmul(attn_dists, v)  # (B, num_heads, L_q, d_k)

In [21]:
multihead_attn = MultiheadAttention()

# Self-attention: query, key and value all come from the same embeddings.
outputs = multihead_attn(q=batch_emb, k=batch_emb, v=batch_emb)  # (B, L, d_model)

In [22]:
# Module output followed by its shape: (batch_size, length, d_model).
print(outputs, outputs.shape, sep="\n")

tensor([[[ 0.0714, -0.0246, -0.1465,  ...,  0.1340,  0.2671, -0.0991],
         [-0.0478, -0.0892, -0.1207,  ...,  0.1031,  0.1928, -0.0121],
         [ 0.0668, -0.0034, -0.1259,  ...,  0.0757,  0.1786, -0.1021],
         ...,
         [ 0.0262,  0.0057, -0.1877,  ...,  0.1759,  0.2441, -0.0821],
         [ 0.0262,  0.0057, -0.1877,  ...,  0.1759,  0.2441, -0.0821],
         [ 0.0262,  0.0057, -0.1877,  ...,  0.1759,  0.2441, -0.0821]],

        [[ 0.3932, -0.1357, -0.3062,  ...,  0.1884,  0.1783, -0.1680],
         [ 0.4036, -0.1211, -0.2892,  ...,  0.2104,  0.1763, -0.1609],
         [ 0.4562, -0.1351, -0.2547,  ...,  0.1734,  0.2246, -0.1549],
         ...,
         [ 0.3332, -0.1106, -0.2841,  ...,  0.1888,  0.2551, -0.1363],
         [ 0.3332, -0.1106, -0.2841,  ...,  0.1888,  0.2551, -0.1363],
         [ 0.3332, -0.1106, -0.2841,  ...,  0.1888,  0.2551, -0.1363]],

        [[ 0.1764, -0.1598, -0.1772,  ...,  0.1703,  0.1059, -0.1156],
         [ 0.1458, -0.1457, -0.0733,  ...,  0