## **7. Multi-head Attention**
1. Multi-head attention 및 self-attention 구현.
2. 각 과정에서 일어나는 연산과 input/output 형태 이해.

### **필요 패키지 import**

In [1]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

### **데이터 전처리**

In [2]:
# Token id used for padding, and the size of the toy vocabulary.
pad_id, vocab_size = 0, 100

# Ten variable-length (ragged) sequences of token ids.
data = [
  [62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54],
  [75, 51],
  [66, 88, 98, 47],
  [21, 39, 10, 64, 21],
  [98],
  [
    77, 65, 51, 77, 19, 15, 35, 19, 23, 97,
    50, 46, 53, 42, 45, 91, 66, 3, 43, 10,
  ],
  [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34],
  [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43],
]

In [3]:
def padding(data, pad_value=None):
  """Right-pad every sequence in ``data`` to the length of the longest one.

  Args:
    data: list of token-id lists whose lengths may differ. The input list
      itself is not modified.
    pad_value: token id appended to shorter sequences. When ``None`` (the
      default) the module-level ``pad_id`` is used.

  Returns:
    A tuple ``(padded, max_len)``: ``padded`` is a new list of lists, each of
    length ``max_len``; ``max_len`` is the length of the longest input
    sequence (0 when ``data`` is empty).
  """
  if pad_value is None:
    pad_value = pad_id

  # `default=0` lets an empty batch through instead of raising ValueError.
  max_len = max((len(seq) for seq in data), default=0)
  print(f"Maximum sequence length: {max_len}")

  # Build a fresh list instead of assigning into `data`, so the caller's list
  # (and any other reference to it) is left untouched.
  padded = [seq + [pad_value] * (max_len - len(seq)) for seq in tqdm(data)]

  return padded, max_len

In [4]:
# Replace the ragged lists with their padded version and remember the padded length.
data, max_len = padding(data=data)

100%|██████████| 10/10 [00:00<00:00, 34865.37it/s]

Maximum sequence length: 20





In [5]:
# Display the padded batch: every row now has length max_len, padded with pad_id.
data

[[62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75, 0, 0, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [75, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [21, 39, 10, 64, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [77,
  65,
  51,
  77,
  19,
  15,
  35,
  19,
  23,
  97,
  50,
  46,
  53,
  42,
  45,
  91,
  66,
  3,
  43,
  10],
 [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34, 0, 0, 0, 0],
 [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43, 0, 0, 0, 0, 0, 0]]

### **Hyperparameter 세팅 및 embedding**

In [6]:
# Hyperparameters of the attention layer:
#   d_model   -- hidden size of the model
#   num_heads -- number of attention heads
d_model, num_heads = 512, 8

In [7]:
# Shape legend -- B: batch size, L: maximum sequence length.
embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

batch = torch.LongTensor(data)  # (B, L) token ids
batch_emb = embedding(batch)  # (B, L, d_model) embedded tokens

In [8]:
# Show the embedded batch followed by its shape, one per line.
print(batch_emb, batch_emb.shape, sep="\n")

tensor([[[ 0.2826,  0.1404, -1.8806,  ...,  1.0316,  1.5623, -0.4378],
         [-0.6358, -1.4330,  0.0812,  ..., -0.2815, -0.8433, -0.6888],
         [-0.5335, -1.2960,  0.2290,  ..., -1.9809, -0.8674, -0.5929],
         ...,
         [-1.2422, -0.7566,  0.4933,  ...,  1.1519,  2.0426,  1.5392],
         [-1.2422, -0.7566,  0.4933,  ...,  1.1519,  2.0426,  1.5392],
         [-1.2422, -0.7566,  0.4933,  ...,  1.1519,  2.0426,  1.5392]],

        [[ 1.1400,  0.2889,  0.5718,  ...,  2.4864, -0.2878, -0.5118],
         [ 1.0549, -0.4372,  0.7949,  ...,  1.8875, -0.6376, -0.5406],
         [-0.3805, -0.2027, -0.8908,  ..., -0.0192, -2.5096, -1.4047],
         ...,
         [-1.2422, -0.7566,  0.4933,  ...,  1.1519,  2.0426,  1.5392],
         [-1.2422, -0.7566,  0.4933,  ...,  1.1519,  2.0426,  1.5392],
         [-1.2422, -0.7566,  0.4933,  ...,  1.1519,  2.0426,  1.5392]],

        [[ 1.4028, -1.4137, -0.7341,  ...,  0.1304, -1.5731, -0.9929],
         [-1.0511,  1.2357, -0.4694,  ..., -2

### **Linear transformation & 여러 head로 나누기**

Multi-head attention 내에서 쓰이는 linear transformation matrix들을 정의합니다.

In [9]:
# Separate learnable d_model -> d_model projections for queries, keys and values
# (created in the order w_q, w_k, w_v).
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))

In [10]:
# Output projection applied after the heads are concatenated again.
w_0 = nn.Linear(in_features=d_model, out_features=d_model)

In [11]:
# Project the same embeddings into queries, keys and values; each is (B, L, d_model).
q, k, v = w_q(batch_emb), w_k(batch_emb), w_v(batch_emb)

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([10, 20, 512])
torch.Size([10, 20, 512])
torch.Size([10, 20, 512])


q, k, v를 각각 `num_heads`개의 head로 나누어, 차원이 `d_k`(= `d_model // num_heads`)인 여러 vector로 만듭니다.

In [12]:
batch_size = q.size(0)
d_k = d_model // num_heads  # dimension handled by each head

# Split the last axis into heads: (B, L, d_model) -> (B, L, num_heads, d_k)
q, k, v = [t.view(batch_size, -1, num_heads, d_k) for t in (q, k, v)]

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])


In [13]:
# Put the head axis before the sequence axis so each head attends independently:
# (B, L, num_heads, d_k) -> (B, num_heads, L, d_k)
q, k, v = [t.transpose(1, 2) for t in (q, k, v)]

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])


### **Scaled dot-product self-attention 구현**

각 head에서 실행되는 self-attention 과정입니다.

In [19]:
# Scaled dot-product scores for every head, then a softmax over the key axis.
attn_scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (B, num_heads, L, L)
attn_dists = attn_scores.softmax(dim=-1)  # (B, num_heads, L, L)

print(attn_scores.shape, attn_dists.shape, sep="\n")

torch.Size([10, 8, 20, 20])
torch.Size([10, 8, 20, 20])


In [20]:

# Inspect the per-head value tensor before weighting it: (B, num_heads, L, d_k).
v.shape

torch.Size([10, 8, 20, 64])

In [18]:
# Weighted sum of the value vectors using each head's attention distribution.
attn_values = attn_dists @ v  # (B, num_heads, L, d_k)
print(attn_values.shape)

torch.Size([10, 8, 20, 64])


### **각 head의 결과물 병합**

각 head의 결과물을 concat하여 (B, L, d_model) 형태로 되돌리고, 같은 차원(`d_model`)으로 linear transformation합니다.

In [25]:
# Concatenate the heads back into one d_model-sized vector per position.
# This cell rebinds `attn_values`. Running it a second time would transpose the
# already-merged (B, L, d_model) tensor into (B, d_model, L) and then `view` it
# back to (B, L, d_model), silently scrambling the values. Refuse instead.
if attn_values.dim() != 4:
  raise RuntimeError(
    "attn_values has already been merged into (B, L, d_model); "
    "re-run the attention cell (attn_values = torch.matmul(attn_dists, v)) first."
  )

attn_values = attn_values.transpose(1, 2)  # (B, L, num_heads, d_k)
print(attn_values.shape)
attn_values = attn_values.contiguous().view(batch_size, -1, d_model)  # (B, L, d_model)

print(attn_values.shape)

torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 512])


In [26]:
# Final linear projection of the concatenated heads.
outputs = w_0(attn_values)  # (B, L, d_model)

print(outputs, outputs.shape, sep="\n")

tensor([[[ 0.0567, -0.0209,  0.0043,  ..., -0.0531,  0.1738,  0.1426],
         [ 0.1409,  0.0459,  0.0512,  ...,  0.0461, -0.0205, -0.0057],
         [ 0.0487, -0.0434,  0.1266,  ...,  0.0438,  0.0561, -0.1456],
         ...,
         [-0.1780,  0.0567, -0.1630,  ..., -0.0976, -0.0459, -0.0163],
         [-0.0830, -0.1623, -0.1158,  ..., -0.1331,  0.0338, -0.0714],
         [ 0.0985, -0.0279, -0.2172,  ...,  0.0260, -0.0861, -0.0418]],

        [[-0.0440, -0.1483, -0.1870,  ..., -0.0452,  0.2920,  0.1776],
         [ 0.2851,  0.0721, -0.0379,  ...,  0.1460, -0.0991, -0.0980],
         [-0.0392, -0.2991,  0.2465,  ...,  0.0970,  0.1270, -0.2169],
         ...,
         [-0.5207, -0.1474, -0.2334,  ...,  0.0681, -0.3703,  0.4393],
         [-0.3173, -0.5996,  0.0582,  ..., -0.4438,  0.1314, -0.0597],
         [ 0.0777,  0.2193, -0.1201,  ...,  0.1036, -0.0339,  0.1314]],

        [[ 0.0776, -0.1898, -0.1236,  ...,  0.0252,  0.1498,  0.1471],
         [ 0.3411,  0.0330, -0.0211,  ...,  0

### **전체 코드**

위의 과정을 모두 합쳐 하나의 Multi-head attention 모듈을 구현하겠습니다.

In [27]:
class MultiheadAttention(nn.Module):
  """Multi-head scaled dot-product attention as a single module.

  Args:
    d_model: hidden size of the inputs and of the output. When omitted, the
      notebook-level global ``d_model`` is used.
    num_heads: number of attention heads; must divide ``d_model``. When
      omitted, the notebook-level global ``num_heads`` is used.

  Raises:
    ValueError: if ``num_heads`` is not positive or does not divide ``d_model``.
  """

  def __init__(self, d_model=None, num_heads=None):
    super().__init__()

    # Fall back to the notebook's global hyperparameters so the original
    # zero-argument call `MultiheadAttention()` keeps working.
    if d_model is None:
      d_model = globals()["d_model"]
    if num_heads is None:
      num_heads = globals()["num_heads"]
    if num_heads <= 0 or d_model % num_heads != 0:
      raise ValueError(
        f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
      )

    self.d_model = d_model
    self.num_heads = num_heads
    # Per-head dimension, derived here instead of read from a global `d_k`.
    self.d_k = d_model // num_heads

    # Q, K, V learnable matrices
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)

    # Linear transformation for concatenated outputs
    self.w_0 = nn.Linear(d_model, d_model)

  def forward(self, q, k, v, mask=None):
    """Project, split into heads, attend, merge the heads and project again.

    Args:
      q: query tensor of shape (B, L_q, d_model).
      k: key tensor of shape (B, L_k, d_model).
      v: value tensor of shape (B, L_k, d_model).
      mask: optional 3-D boolean (or 0/1) tensor broadcastable to
        (B, L_q, L_k), e.g. a (B, 1, L_k) padding mask. Non-zero/True marks
        positions that may be attended to; zero/False positions are masked out.

    Returns:
      Tensor of shape (B, L_q, d_model).
    """
    batch_size = q.shape[0]

    q = self.w_q(q)  # (B, L, d_model)
    k = self.w_k(k)  # (B, L, d_model)
    v = self.w_v(v)  # (B, L, d_model)

    # (B, L, d_model) -> (B, L, num_heads, d_k) -> (B, num_heads, L, d_k)
    q = q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    k = k.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
    v = v.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    attn_values = self.self_attention(q, k, v, mask=mask)  # (B, num_heads, L, d_k)
    # (B, num_heads, L, d_k) -> (B, L, num_heads, d_k) -> (B, L, d_model)
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

    return self.w_0(attn_values)

  def self_attention(self, q, k, v, mask=None):
    """Scaled dot-product attention applied independently to every head.

    Args:
      q, k, v: tensors of shape (B, num_heads, L, d_k).
      mask: optional 3-D mask broadcastable to (B, L_q, L_k); see ``forward``.
        A head axis is inserted so the same mask applies to every head.

    Returns:
      Tensor of shape (B, num_heads, L_q, d_k).
    """
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # (B, num_heads, L, L)
    if mask is not None:
      # Use the most negative finite value instead of -inf so that a fully
      # masked row gives a uniform distribution rather than NaNs.
      attn_scores = attn_scores.masked_fill(
        mask.unsqueeze(1) == 0, torch.finfo(attn_scores.dtype).min
      )
    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L, L)

    attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)

    return attn_values

In [28]:
# Self-attention: the same embeddings act as queries, keys and values.
multihead_attn = MultiheadAttention()
outputs = multihead_attn(q=batch_emb, k=batch_emb, v=batch_emb)  # (B, L, d_model)

In [29]:
# Show the module output followed by its shape, one per line.
print(outputs, outputs.shape, sep="\n")

tensor([[[ 0.0448,  0.0268, -0.0891,  ..., -0.0003,  0.0919,  0.0454],
         [-0.0024, -0.0347, -0.0528,  ...,  0.0291,  0.0624,  0.0686],
         [ 0.0525,  0.0335, -0.0281,  ...,  0.0178,  0.0203,  0.0042],
         ...,
         [ 0.0149, -0.0504, -0.0462,  ...,  0.0506,  0.0314, -0.0355],
         [ 0.0149, -0.0504, -0.0462,  ...,  0.0506,  0.0314, -0.0355],
         [ 0.0149, -0.0504, -0.0462,  ...,  0.0506,  0.0314, -0.0355]],

        [[-0.1386,  0.0251, -0.1141,  ...,  0.0034, -0.2172, -0.0186],
         [-0.1430,  0.0432, -0.1258,  ...,  0.0148, -0.1913,  0.0207],
         [-0.1519,  0.0702, -0.0603,  ...,  0.0395, -0.1865,  0.0488],
         ...,
         [-0.1897,  0.0055, -0.0758,  ...,  0.0074, -0.1392, -0.0149],
         [-0.1897,  0.0055, -0.0758,  ...,  0.0074, -0.1392, -0.0149],
         [-0.1897,  0.0055, -0.0758,  ...,  0.0074, -0.1392, -0.0149]],

        [[-0.0919,  0.0349, -0.0090,  ...,  0.0178, -0.1924,  0.0756],
         [-0.0721,  0.0197,  0.0370,  ...,  0