## **8. Masked Multi-head Attention**
1. Masked Multi-head Attention 구현.
2. Encoder-Decoder Attention 구현.
- Decoder 부분

In [1]:
# Mount Google Drive into the runtime at /content/drive (requires the
# google.colab package, i.e. a Colab environment).
# NOTE(review): no cell visible in this notebook reads from the mounted
# path -- confirm the mount is still needed.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### **필요 패키지 import**

In [10]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

### **데이터 전처리**

데이터의 값과 형태를 좀 더 명확하게 보기 위해 sample을 줄이겠습니다.

In [3]:
pad_id = 0  # token id used to fill padded positions
vocab_size = 100  # vocabulary size passed to nn.Embedding

# The actual values are worth inspecting,
# so the data is kept small to make it easy to read.
data = [
  [62, 13, 47, 39, 78, 33, 56, 13],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
  [66, 88, 98, 47],
  [77, 65, 51, 77, 19, 15, 35, 19, 23]
]

In [4]:
def padding(data, pad_value=None):
  """Right-pad every sequence in ``data`` to the length of the longest one.

  The input is not modified: a new list of new lists is returned, so running
  the calling cell again on the original data gives the same result.

  Args:
    data: iterable-of-sequences container (e.g. a list of lists of token
      ids). May be empty.
    pad_value: value appended to short sequences. When ``None`` (default),
      the notebook-level ``pad_id`` is used.

  Returns:
    A tuple ``(padded, max_len)`` where ``padded`` is a list of lists that
    all have length ``max_len``, and ``max_len`` is the length of the longest
    input sequence (0 when ``data`` is empty).
  """
  if pad_value is None:
    pad_value = pad_id  # fall back to the notebook-level padding id

  # default=0 keeps an empty batch from raising ValueError inside max().
  max_len = max((len(seq) for seq in data), default=0)
  print(f"Maximum sequence length: {max_len}")

  # Build fresh lists instead of assigning into data[i], so the caller's
  # sequences are left untouched.
  padded = [
    list(seq) + [pad_value] * (max_len - len(seq))
    for seq in tqdm(data)
  ]

  return padded, max_len

In [5]:
# Pad the source sequences; `data` is rebound to the padded lists and
# max_len holds the common sequence length.
data, max_len = padding(data)

100%|██████████| 5/5 [00:00<00:00, 12679.27it/s]

Maximum sequence length: 10





In [6]:
# Display the padded source sequences.
data

[[62, 13, 47, 39, 78, 33, 56, 13, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0],
 [77, 65, 51, 77, 19, 15, 35, 19, 23, 0]]

### **Hyperparameter 세팅 및 embedding**

In [7]:
# Seed PyTorch's global RNG before any layer is created, so the randomly
# initialised embedding / linear weights (and therefore every tensor printed
# below) are reproducible after a kernel restart.
SEED = 0
torch.manual_seed(SEED)

# Dimensions are kept small so the tensors are easy to inspect.
d_model = 8  # hidden size of the model
num_heads = 2  # number of attention heads
inf = 1e12  # large finite value; masked scores are set to -1 * inf before softmax

In [8]:
# Lookup table: one d_model-dimensional vector per token id.
embedding = nn.Embedding(vocab_size, d_model)

# B: batch size, L: maximum sequence length
batch = torch.tensor(data, dtype=torch.long)  # (B, L)
batch_emb = embedding(batch)  # (B, L, d_model)

In [11]:
# Inspect the embedded batch and its shape, annotated (B, L, d_model) where it is built.
print(batch_emb)
print(batch_emb.shape)

tensor([[[-1.9610e+00,  1.6843e-01,  2.4753e-02,  1.9252e+00, -1.1563e+00,
          -9.6380e-01, -7.5435e-01, -1.5953e-01],
         [-1.6101e-01, -3.4729e-01,  1.3322e-01,  2.2014e+00,  1.4778e-01,
          -1.7889e+00, -1.1637e+00, -5.3178e-01],
         [ 3.4615e-02,  3.4697e-01, -8.8172e-01,  7.1375e-01,  6.5656e-01,
           2.7853e-01,  1.5836e-01, -5.0664e-01],
         [-1.9087e-01,  7.1909e-01,  1.8414e+00,  1.5429e+00,  6.4760e-01,
          -4.7824e-01,  9.9462e-02,  2.3132e-01],
         [ 5.5992e-01,  2.6660e+00, -7.6602e-02,  2.3276e-01,  9.4208e-01,
          -1.3384e+00, -1.2717e-01, -1.4878e+00],
         [ 1.2408e+00,  5.1148e-01, -4.4437e-01,  2.1889e-02, -5.8662e-01,
           3.8692e-01, -6.3511e-01, -1.9990e+00],
         [-5.9062e-01,  3.6413e-01, -3.6650e-01, -5.3939e-02, -1.4821e+00,
          -1.0151e+00, -1.8241e+00, -7.1189e-01],
         [-1.6101e-01, -3.4729e-01,  1.3322e-01,  2.2014e+00,  1.4778e-01,
          -1.7889e+00, -1.1637e+00, -5.3178e-01],


### **Mask 구축**

`True`는 attention이 적용될 부분, `False`는 masking될 자리입니다.

In [12]:
# Mask out the pad tokens.
# True -> attention allowed, False -> masked
padding_mask = torch.ne(batch, pad_id).unsqueeze(dim=1)  # (B, 1, L)

print(padding_mask, padding_mask.shape, sep="\n")

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True,  True,  True,  True,  True, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True, False]]])
torch.Size([5, 1, 10])


In [13]:
# nopeak_mask ("no-peek"): a mask that stops a position from seeing later ones.
# Start from an all-True (1, L, L) tensor and keep only its lower triangle
# (tril = "triangle lower"), so entry (i, j) is True only when j <= i.
# The leading dimension of size 1 is there to make later broadcasting easy.
nopeak_mask = torch.ones(1, max_len, max_len, dtype=torch.bool).tril()  # (1, L, L)

print(nopeak_mask, nopeak_mask.shape, sep="\n")

tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]])
torch.Size([1, 10, 10])


In [14]:
# padding_mask is (B, 1, L) and nopeak_mask is (1, L, L).
# Broadcasting repeats nopeak_mask across the batch and padding_mask across
# the L query rows, bringing both to the same size; the result is True only
# where both masks are True and False everywhere else.
mask = torch.logical_and(padding_mask, nopeak_mask)  # (B, L, L)

print(mask, mask.shape, sep="\n")

tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  T

### **Linear transformation & 여러 head로 나누기**

In [15]:
# Query / key / value projections plus the output projection w_0, each
# mapping d_model -> d_model. The generator yields the layers in the order
# w_q, w_k, w_v, w_0, i.e. the same construction order as four separate calls.
w_q, w_k, w_v, w_0 = (nn.Linear(d_model, d_model) for _ in range(4))

In [16]:
batch_size = batch_emb.shape[0]
d_k = d_model // num_heads  # per-head dimension

# Project, split the last dimension into heads, then move the head axis
# in front of the sequence axis:
# (B, L, d_model) -> (B, L, num_heads, d_k) -> (B, num_heads, L, d_k)
q = w_q(batch_emb).view(batch_size, -1, num_heads, d_k).transpose(1, 2)
k = w_k(batch_emb).view(batch_size, -1, num_heads, d_k).transpose(1, 2)
v = w_v(batch_emb).view(batch_size, -1, num_heads, d_k).transpose(1, 2)

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])


### **Masking이 적용된 self-attention 구현**

In [18]:
# Scaled dot-product scores: every query against every key, divided by sqrt(d_k).
attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)

In [19]:
masks = mask.unsqueeze(1)  # (B, 1, L, L)
# Positions where the mask is False are filled with -1 * inf: a very large
# negative (but finite) number rather than 0, so softmax maps them to ~0.
# The out-of-place masked_fill leaves attn_scores untouched, so the raw
# scores stay available and the two names do not alias the same tensor.
masked_attn_scores = attn_scores.masked_fill(masks.logical_not(), -1 * inf)  # (B, num_heads, L, L)

print(masked_attn_scores)
print(masked_attn_scores.shape)

tensor([[[[ 5.9993e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 4.2600e-01,  4.3161e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 4.9480e-01,  5.1267e-01,  1.6712e-01, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 9.6691e-01,  1.1695e+00,  2.2937e-01,  8.9673e-01, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 4.5012e-01,  4.4531e-01,  2.0075e-01,  1.4426e-01,  1.1765e-01,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 7.1670e-02,  7.0278e-02,  8.7590e-04, -3.7849e-02,  3.2966e-02,
           -8.0640e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-3.0063e-01, -3.8789e-01, -1.6093e-01, -5.8082e-01, -7.1848e-02,
      

`-1* inf`로 masking된 부분은 softmax 후 0이 됩니다.

In [20]:
# Normalise over the key axis so each query's weights sum to 1.
attn_dists = masked_attn_scores.softmax(dim=-1)  # (B, num_heads, L, L)

print(attn_dists, attn_dists.shape, sep="\n")

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.4986, 0.5014, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.3651, 0.3717, 0.2631, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2751, 0.3369, 0.1316, 0.2564, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2365, 0.2354, 0.1843, 0.1742, 0.1696, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1771, 0.1768, 0.1650, 0.1587, 0.1704, 0.1521, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1343, 0.1231, 0.1545, 0.1015, 0.1689, 0.1449, 0.1728, 0.0000,
           0.0000, 0.0000],
          [0.1558, 0.1567, 0.1061, 0.0974, 0.1153, 0.0739, 0.1381, 0.1567,
           0.0000, 0.0000],
          [0.1186, 0.1103, 0.1353, 0.1070, 0.1267, 0.1470, 0.1448, 0.1103,
           0.0000, 0.0000],
          [0.1186, 0.1103, 0.1353, 0.1070, 0.1267, 0.1470, 0.1448, 0.1103

In [21]:
# Weighted sum of the value vectors using the attention distribution.
attn_values = attn_dists @ v  # (B, num_heads, L, d_k)

print(attn_values.shape)

torch.Size([5, 2, 10, 4])


### **전체 코드**

In [22]:
class MultiheadAttention(nn.Module):
  """Multi-head scaled dot-product attention.

  Usable for self-attention (q, k, v from the same tensor) and for
  encoder-decoder attention (q from the target, k and v from the source).

  Args:
    model_dim: hidden size of the inputs and outputs. When ``None``
      (default) the notebook-level ``d_model`` is used, so the existing
      ``MultiheadAttention()`` call keeps working.
    n_heads: number of attention heads. When ``None`` (default) the
      notebook-level ``num_heads`` is used.

  Raises:
    ValueError: if the hidden size is not divisible by the number of heads.
  """

  def __init__(self, model_dim=None, n_heads=None):
    super().__init__()

    # Resolve the sizes once and store them on the module; forward() no longer
    # relies on a global d_k that only exists if an earlier demo cell was run.
    self.d_model = d_model if model_dim is None else model_dim
    self.num_heads = num_heads if n_heads is None else n_heads
    if self.d_model % self.num_heads != 0:
      raise ValueError(
        f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})"
      )
    self.d_k = self.d_model // self.num_heads  # per-head dimension

    # Q, K, V learnable matrices
    self.w_q = nn.Linear(self.d_model, self.d_model)
    self.w_k = nn.Linear(self.d_model, self.d_model)
    self.w_v = nn.Linear(self.d_model, self.d_model)

    # Linear transformation for concatenated outputs
    self.w_0 = nn.Linear(self.d_model, self.d_model)

  def _split_heads(self, x, batch_size):
    """Reshape (B, L, d_model) into (B, num_heads, L, d_k)."""
    return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

  def forward(self, q, k, v, mask=None):
    """Apply multi-head attention.

    Args:
      q: query input, (B, L_q, d_model).
      k: key input, (B, L_k, d_model).
      v: value input, (B, L_k, d_model).
      mask: optional mask, (B, L_q, L_k) or (B, 1, L_k); truthy entries may
        be attended to, falsy entries are masked out.

    Returns:
      Tensor of shape (B, L_q, d_model).
    """
    batch_size = q.shape[0]

    q = self._split_heads(self.w_q(q), batch_size)  # (B, num_heads, L_q, d_k)
    k = self._split_heads(self.w_k(k), batch_size)  # (B, num_heads, L_k, d_k)
    v = self._split_heads(self.w_v(v), batch_size)  # (B, num_heads, L_k, d_k)

    attn_values = self.self_attention(q, k, v, mask=mask)  # (B, num_heads, L_q, d_k)
    # Put the head axis back next to d_k and merge: (B, L_q, num_heads, d_k) => (B, L_q, d_model)
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

    return self.w_0(attn_values)

  def self_attention(self, q, k, v, mask=None):
    """Scaled dot-product attention on tensors already split into heads.

    Args:
      q: (B, num_heads, L_q, d_k).
      k: (B, num_heads, L_k, d_k).
      v: (B, num_heads, L_k, d_k).
      mask: optional (B, L_q, L_k) or (B, 1, L_k) mask; a head axis is
        inserted here so it broadcasts over all heads.

    Returns:
      Tensor of shape (B, num_heads, L_q, d_k).
    """
    # Scale by the per-head dimension taken from the tensor itself.
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))  # (B, num_heads, L_q, L_k)

    if mask is not None:
      mask = mask.unsqueeze(1)  # (B, 1, L_q, L_k) or (B, 1, 1, L_k)
      # Out-of-place fill with the most negative finite value of the score
      # dtype: softmax sends these positions to 0, a fully masked row stays
      # finite (uniform) instead of NaN, and the constant cannot overflow
      # for reduced-precision dtypes.
      attn_scores = attn_scores.masked_fill(
        mask.logical_not(), torch.finfo(attn_scores.dtype).min
      )

    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L_q, L_k)

    attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L_q, d_k)

    return attn_values

In [23]:
# Instantiate the attention module defined above.
multihead_attn = MultiheadAttention()

# Self-attention: the same embeddings are passed as query, key and value;
# `mask` is the combined padding + no-peek mask built earlier.
outputs = multihead_attn(batch_emb, batch_emb, batch_emb, mask=mask)  # (B, L, d_model)

In [24]:
# Inspect the attention output and its shape, annotated (B, L, d_model) at the call site.
print(outputs)
print(outputs.shape)

tensor([[[-7.0922e-03, -1.8101e-01, -8.1648e-01, -1.8661e-01, -4.4591e-01,
          -5.5058e-01,  2.6041e-01,  6.1075e-01],
         [-2.3712e-01, -2.4034e-01, -6.8836e-01, -4.0657e-01, -4.8102e-01,
          -5.0028e-01,  2.3226e-01,  4.4302e-01],
         [-3.8950e-01, -1.6009e-01, -6.1487e-01, -3.0829e-01, -1.9821e-01,
          -4.1462e-01,  1.4663e-01,  2.9856e-01],
         [-4.3241e-01, -9.4246e-02, -5.8908e-01, -2.7046e-01, -2.5450e-01,
          -4.1177e-01,  2.5301e-01,  3.0859e-01],
         [-5.3860e-01, -7.5092e-02, -5.6061e-01, -3.2771e-01, -2.1212e-01,
          -3.7784e-01,  2.9600e-01,  3.3393e-01],
         [-5.0542e-01, -6.8979e-03, -4.8650e-01, -2.3255e-01, -1.9498e-01,
          -3.9364e-01,  2.9099e-01,  4.8673e-01],
         [-4.4824e-01, -9.3124e-03, -4.8479e-01, -2.1237e-01, -2.4692e-01,
          -4.2991e-01,  2.8655e-01,  5.7619e-01],
         [-4.4588e-01, -6.3881e-02, -4.7409e-01, -2.6464e-01, -2.3907e-01,
          -4.2752e-01,  2.4946e-01,  5.2912e-01],


### **Encoder-Decoder attention**

Query, key, value만 달라질 뿐 구현은 동일합니다.  
Decoder에 들어갈 batch만 별도 구현하겠습니다.

- seq2seq with attention에서 decoder hidden state와 전체 encoder hidden states를 곱해 주던 연산에 해당합니다.
- 즉, encoder 결과와의 연관성을 찾는 과정을 Transformer 방식으로 구현한 것입니다.

In [25]:
# Target data that will be fed to the decoder
trg_data = [
  [33, 11, 49, 10],
  [88, 34, 5, 29, 99, 45, 11, 25],
  [67, 25, 15, 90, 54, 4, 92, 10, 46, 20, 88 ,19],
  [16, 58, 91, 47, 12, 5, 8],
  [71, 63, 62, 7, 9, 11, 55, 91, 32, 48]
]

# Pad the target sequences to a common length; trg_max_len is that length.
trg_data, trg_max_len = padding(trg_data)

100%|██████████| 5/5 [00:00<00:00, 24328.91it/s]

Maximum sequence length: 12





In [26]:
# S_L: source maximum sequence length, T_L: target maximum sequence length
src_batch = batch  # (B, S_L) -- the source batch built earlier is reused as-is
trg_batch = torch.tensor(trg_data, dtype=torch.long)  # (B, T_L)

print(src_batch.shape, trg_batch.shape, sep="\n")

torch.Size([5, 10])
torch.Size([5, 12])


In [28]:
# The same embedding table is applied to both the source and target batches.
src_emb = embedding(src_batch)  # (B, S_L, d_model)
trg_emb = embedding(trg_batch)  # (B, T_L, d_model)

print(src_emb.shape)  # treated as the encoder output
print(trg_emb.shape)  # assumed to be the output of masked multi-head attention

torch.Size([5, 10, 8])
torch.Size([5, 12, 8])


`src_emb`를 encoder에서 나온 결과, 그리고 `trg_emb`를 masked multi-head attention 후 결과로 가정합니다.

In [29]:
# The query comes from the target; the key and value come from the source embedding.
batch_size = trg_emb.shape[0]
d_k = d_model // num_heads  # per-head dimension

# Project, split into heads, and move the head axis ahead of the sequence axis.
q = w_q(trg_emb).view(batch_size, -1, num_heads, d_k).transpose(1, 2)  # (B, num_heads, T_L, d_k)
k = w_k(src_emb).view(batch_size, -1, num_heads, d_k).transpose(1, 2)  # (B, num_heads, S_L, d_k)
v = w_v(src_emb).view(batch_size, -1, num_heads, d_k).transpose(1, 2)  # (B, num_heads, S_L, d_k)

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([5, 2, 12, 4])
torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])


In [30]:
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, T_L, S_L)

# The source batch is padded, so pad positions on the key axis must not
# receive attention weight. No no-peek mask is needed here: a target position
# may look at every real source token.
src_pad_mask = (src_batch != pad_id)[:, None, None, :]  # (B, 1, 1, S_L)
attn_scores = attn_scores.masked_fill(src_pad_mask.logical_not(), -1 * inf)  # (B, num_heads, T_L, S_L)

attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, T_L, S_L)

print(attn_dists.shape)

torch.Size([5, 2, 12, 10])


In [31]:
# Weighted sum of the source value vectors for every target position.
attn_values = attn_dists @ v  # (B, num_heads, T_L, d_k)

print(attn_values.shape)

torch.Size([5, 2, 12, 4])


Masked multi-head attention 후 나온 결과와 동일한 shape를 가지며 이후 layer에서 전체 연산도 동일하게 진행됩니다.