## **8. Masked Multi-head Attention**
1. Masked Multi-head Attention 구현.
2. Encoder-Decoder Attention 구현.

### **필요 패키지 import**

In [1]:
from torch import nn
from torch.nn import functional as F

from tqdm.auto import tqdm

import torch
import math

### **데이터 전처리**

데이터의 값과 형태를 좀 더 명확하게 보기 위해 sample을 줄이겠습니다.

In [2]:
# Token id reserved for padding, and the size of the toy vocabulary.
pad_id, vocab_size = 0, 100

# Five toy token-id sequences of different lengths (4 to 10 tokens).
data = [
    [62, 13, 47, 39, 78, 33, 56, 13],
    [60, 96, 51, 32, 90],
    [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
    [66, 88, 98, 47],
    [77, 65, 51, 77, 19, 15, 35, 19, 23],
]

In [3]:
def padding(data, pad_value=None):
    """Right-pad every sequence in ``data`` to the length of the longest one.

    Args:
        data: list of token-id sequences. It is NOT modified; a new outer
            list with new padded rows is returned.
        pad_value: token id appended as padding. When omitted, the
            module-level ``pad_id`` is used.

    Returns:
        ``(padded, max_len)``. Every row of ``padded`` has length ``max_len``.
        ``max_len`` is 0 for an empty ``data``.
    """
    if pad_value is None:
        pad_value = pad_id

    # `default=0` lets an empty batch return ([], 0) instead of raising
    # ValueError from max() on an empty sequence.
    max_len = max((len(seq) for seq in data), default=0)
    print(f"Maximum sequence length: {max_len}")

    padded = []
    for seq in tqdm(data):
        # Build a fresh row so the caller's lists are left untouched.
        padded.append(list(seq) + [pad_value] * (max_len - len(seq)))

    return padded, max_len

In [4]:
# Pad every source sequence to the longest length; padding() also returns that length (L).
data, max_len = padding(data)

Maximum sequence length: 10


  0%|          | 0/5 [00:00<?, ?it/s]

In [5]:
# Display the padded batch as the cell output.
data

[[62, 13, 47, 39, 78, 33, 56, 13, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0],
 [77, 65, 51, 77, 19, 15, 35, 19, 23, 0]]

### **Hyperparameter 세팅 및 embedding**

In [6]:
# d_model: model hidden size, num_heads: number of attention heads,
# inf: large constant; `-1 * inf` is the fill value for masked scores.
d_model, num_heads, inf = 8, 2, 1e12

In [7]:
# Token embedding table mapping each of the vocab_size ids to a d_model vector.
embedding = nn.Embedding(vocab_size, d_model)

# B: batch size, L: maximum sequence length
batch = torch.tensor(data, dtype=torch.long)  # (B, L) int64 token ids
batch_emb = embedding(batch)  # (B, L, d_model)

In [9]:
# Show the embedded batch followed by its shape, one per line.
print(batch_emb, batch_emb.shape, sep="\n")

tensor([[[-0.2273, -0.0983,  0.5328, -1.1753, -0.4987,  0.4173,  0.2607,
           0.0171],
         [-0.8935, -0.5912, -0.7333, -0.1794,  1.1164,  0.6642,  0.5094,
          -1.5917],
         [ 0.9191, -0.5637, -1.2143,  0.4283,  0.5112, -0.4976, -1.5387,
           0.8228],
         [-0.2822,  0.9250,  0.2539,  0.8478, -0.0862,  2.2553, -1.1267,
          -1.1796],
         [ 2.8048,  0.6999, -0.2310, -0.0963, -0.8707,  1.3508, -0.3853,
          -1.0851],
         [-1.3985,  0.6530,  1.0392, -0.2071,  0.0160, -0.6280,  1.1516,
           1.6433],
         [-0.4956,  1.4092, -0.7539,  1.9166,  1.3538, -0.0419,  1.0191,
          -0.5718],
         [-0.8935, -0.5912, -0.7333, -0.1794,  1.1164,  0.6642,  0.5094,
          -1.5917],
         [ 0.0491, -1.1867,  1.3131, -1.4396, -0.3594, -1.6283, -0.2207,
          -1.4034],
         [ 0.0491, -1.1867,  1.3131, -1.4396, -0.3594, -1.6283, -0.2207,
          -1.4034]],

        [[ 0.6030, -0.7226,  1.7333, -0.3555,  2.1559,  0.0866, -0.2

### **Mask 구축**

`True`는 attention이 적용될 부분, `False`는 masking될 자리입니다.

In [10]:
# True where a real token sits (attend), False at pad positions (masked).
# The inserted singleton axis lets the mask broadcast over query positions.
padding_mask = batch.ne(pad_id)[:, None, :]  # (B, 1, L)

print(padding_mask, padding_mask.shape, sep="\n")

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True,  True,  True,  True,  True, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True, False]]])
torch.Size([5, 1, 10])


In [12]:
# Causal ("no-peek") mask: query i may only attend to key positions <= i.
# tril keeps the lower triangle (diagonal included) of an all-True matrix.
nopeak_mask = torch.ones(1, max_len, max_len, dtype=torch.bool).tril()  # (1, L, L)

print(nopeak_mask, nopeak_mask.shape, sep="\n")

tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]])
torch.Size([1, 10, 10])


In [13]:
# Combine both masks; broadcasting (B, 1, L) with (1, L, L) gives (B, L, L).
mask = torch.logical_and(padding_mask, nopeak_mask)  # (B, L, L)

print(mask, mask.shape, sep="\n")

tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  T

### **Linear transformation & 여러 head로 나누기**

In [14]:
# Learnable projections for queries, keys and values, plus the output
# projection applied after concatenating the heads (created in q, k, v, o order).
w_q, w_k, w_v, w_o = (nn.Linear(d_model, d_model) for _ in range(4))

In [15]:
batch_size = batch_emb.shape[0]
d_k = d_model // num_heads  # hidden size of each head

# Project, split d_model into num_heads heads of size d_k, then move the head
# axis in front of the sequence axis:
# (B, L, d_model) -> (B, L, num_heads, d_k) -> (B, num_heads, L, d_k)
q = w_q(batch_emb).view(batch_size, -1, num_heads, d_k).transpose(1, 2)
k = w_k(batch_emb).view(batch_size, -1, num_heads, d_k).transpose(1, 2)
v = w_v(batch_emb).view(batch_size, -1, num_heads, d_k).transpose(1, 2)

print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])


### **Masking이 적용된 self-attention 구현**

In [16]:
attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)  # scaled dot products, (B, num_heads, L, L)

In [17]:
masks = mask.unsqueeze(1)  # (B, 1, L, L): the same mask is broadcast over every head

# Out-of-place fill, so `attn_scores` keeps the raw scores. Positions where the
# mask is False get -1 * inf (= -1e12), which becomes 0 after softmax.
masked_attn_scores = attn_scores.masked_fill(~masks, -1 * inf)  # (B, num_heads, L, L)

print(masked_attn_scores)
print(masked_attn_scores.shape)

tensor([[[[ 3.2544e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-6.3768e-02, -2.0633e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-2.7262e-01, -2.7068e-01,  3.1844e-01, -1.0000e+12, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-2.2799e-01, -2.8245e-01,  6.4515e-01, -4.7439e-02, -1.0000e+12,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 1.4887e-01,  8.8036e-02,  5.1274e-01, -2.9637e-01,  2.7026e-02,
           -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [ 4.0239e-01,  5.6376e-01, -4.9722e-01,  1.2941e-01, -9.9683e-02,
            1.2100e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
          [-2.2040e-01, -7.4122e-02,  2.8399e-01,  5.4994e-02, -3.6106e-01,
      

`-1 * inf`(즉 -1e12)로 채운 위치는 exp 값이 사실상 0이므로 softmax 후 0이 됩니다.

In [18]:
# Normalise over the key axis so each query row sums to 1.
attn_dists = masked_attn_scores.softmax(dim=-1)  # (B, num_heads, L, L)

print(attn_dists, attn_dists.shape, sep="\n")

tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.5356, 0.4644, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2626, 0.2631, 0.4743, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.1805, 0.1710, 0.4323, 0.2163, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2038, 0.1918, 0.2933, 0.1306, 0.1805, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.2126, 0.2499, 0.0865, 0.1618, 0.1287, 0.1605, 0.0000, 0.0000,
           0.0000, 0.0000],
          [0.0955, 0.1106, 0.1582, 0.1258, 0.0830, 0.1915, 0.2352, 0.0000,
           0.0000, 0.0000],
          [0.1165, 0.1010, 0.1629, 0.1461, 0.1262, 0.1270, 0.1193, 0.1010,
           0.0000, 0.0000],
          [0.1753, 0.1614, 0.0547, 0.1472, 0.1663, 0.0766, 0.0573, 0.1614,
           0.0000, 0.0000],
          [0.1753, 0.1614, 0.0547, 0.1472, 0.1663, 0.0766, 0.0573, 0.1614

In [19]:
# Weighted sum of value vectors for every query position.
attn_values = attn_dists @ v  # (B, num_heads, L, d_k)

print(attn_values.shape)

torch.Size([5, 2, 10, 4])


### **전체 코드**

In [22]:
class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The same module serves masked self-attention (q = k = v) and
    encoder-decoder attention (q from the decoder, k/v from the encoder).
    The query length L_q and the key/value length L_k may differ.

    Args:
        d_model: hidden size of the inputs and of the output.
        num_heads: number of attention heads; must divide ``d_model``.
            Defaults to 2, the value used throughout this notebook.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads=2):
        super(MultiheadAttention, self).__init__()

        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )

        # Stored on the instance so the module no longer depends on the
        # notebook-level globals `num_heads` / `d_k`.
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head hidden size

        # Q, K, V learnable matrices
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Linear transformation for concatenated outputs
        self.w_o = nn.Linear(d_model, d_model)

    def self_attention(self, q, k, v, mask=None):
        """Scaled dot-product attention over already-split heads.

        Args:
            q: (B, num_heads, L_q, d_k) queries.
            k, v: (B, num_heads, L_k, d_k) keys and values.
            mask: optional 3-D mask, True/non-zero = attend. A padding
                mask of shape (B, 1, L_k) or a full mask (B or 1, L_q, L_k).

        Returns:
            (B, num_heads, L_q, d_k) attention-weighted values.
        """
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # (B, num_heads, L_q, L_k)

        if mask is not None:
            # (B, 1, L_k) -> (B, 1, 1, L_k) or (B, L_q, L_k) -> (B, 1, L_q, L_k):
            # the same mask is broadcast over every head.
            mask = mask.unsqueeze(1)
            # Use the dtype's most negative finite value: masked positions get
            # ~0 weight after softmax, and unlike a fixed -1e12 it cannot
            # overflow in fp16. Out-of-place to avoid mutating in autograd graphs.
            attn_scores = attn_scores.masked_fill(
                mask.logical_not(), torch.finfo(attn_scores.dtype).min
            )

        attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L_q, L_k)

        attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L_q, d_k)

        return attn_values

    def forward(self, q, k, v, mask=None):
        """Project q/k/v, attend per head, merge heads and project the result.

        Args:
            q: (B, L_q, d_model) query inputs.
            k, v: (B, L_k, d_model) key/value inputs.
            mask: optional mask, see ``self_attention``.

        Returns:
            (B, L_q, d_model) attention output.
        """
        batch_size = q.shape[0]

        q = self.w_q(q)  # (B, L_q, d_model)
        k = self.w_k(k)  # (B, L_k, d_model)
        v = self.w_v(v)  # (B, L_k, d_model)

        # Split d_model into num_heads heads of size d_k, then put the head
        # axis before the sequence axis.
        q = q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # (B, num_heads, L_q, d_k)
        k = k.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # (B, num_heads, L_k, d_k)
        v = v.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # (B, num_heads, L_k, d_k)

        attn_values = self.self_attention(q, k, v, mask=mask)  # (B, num_heads, L_q, d_k)
        # (B, num_heads, L_q, d_k) => (B, L_q, num_heads, d_k) => (B, L_q, d_model)
        attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        return self.w_o(attn_values)


In [23]:
# Masked self-attention: queries, keys and values all come from batch_emb.
multihead_attn = MultiheadAttention(d_model)
outputs = multihead_attn(q=batch_emb, k=batch_emb, v=batch_emb, mask=mask)  # (B, L, d_model)

In [24]:
# Show the attention output followed by its shape, one per line.
print(outputs, outputs.shape, sep="\n")

tensor([[[-2.3269e-02,  3.1206e-01,  3.1238e-01, -1.7241e-02, -1.6370e-01,
          -4.7599e-02, -1.4976e-01, -2.0723e-01],
         [-1.0507e-01,  3.9004e-01,  3.4458e-01,  1.7233e-01, -2.7297e-01,
          -1.6310e-01, -2.3939e-01, -4.6597e-01],
         [-2.3351e-01,  5.1881e-01,  2.6438e-01,  1.4968e-01, -4.4715e-01,
          -3.4810e-02, -4.0974e-01, -5.1542e-01],
         [-3.2912e-01,  5.6972e-01,  2.3370e-01,  2.8644e-01, -2.8120e-01,
          -3.0914e-02, -2.9097e-01, -5.8074e-01],
         [-3.5269e-01,  6.9648e-01,  2.0319e-01,  1.7705e-01, -3.5971e-01,
           5.5281e-02, -3.5427e-01, -5.4995e-01],
         [-3.1193e-01,  5.6163e-01,  2.3772e-01,  2.7372e-01, -1.5466e-01,
          -3.2652e-03, -1.8894e-01, -4.9975e-01],
         [-2.5167e-01,  4.1133e-01,  2.4469e-01,  2.9654e-01, -3.2916e-01,
          -1.5856e-02, -3.2085e-01, -4.7524e-01],
         [-2.8118e-01,  5.0221e-01,  2.3823e-01,  3.0857e-01, -2.2174e-01,
          -1.9425e-02, -2.3244e-01, -5.0338e-01],


### **Encoder-Decoder attention**

Query는 decoder(masked multi-head attention의 출력)에서, key와 value는 encoder의 출력에서 온다는 점만 다를 뿐 구현은 동일합니다.  
Decoder에 들어갈 batch만 별도 구현하겠습니다.

In [29]:
# Five toy target (decoder-side) token-id sequences of different lengths.
trg_data = [
    [33, 11, 49, 10],
    [88, 34, 5, 29, 99, 45, 11, 25],
    [67, 25, 15, 90, 54, 4, 92, 10, 46, 20, 88, 19],
    [16, 58, 91, 47, 12, 5, 8],
    [71, 63, 62, 7, 9, 11, 55, 91, 32, 48],
]


In [30]:
# Pad the target sequences the same way; trg_max_len (T_L) is the longest target length.
trg_data, trg_max_len = padding(trg_data)

Maximum sequence length: 12


  0%|          | 0/5 [00:00<?, ?it/s]

In [31]:
# S_L: source maximum sequence length, T_L: target maximum sequence length
src_batch = batch  # (B, S_L): reuses the padded source batch built earlier
trg_batch = torch.tensor(trg_data, dtype=torch.long)  # (B, T_L)

print(src_batch.shape, trg_batch.shape, sep="\n")

torch.Size([5, 10])
torch.Size([5, 12])


In [32]:
# Both sides share the one embedding table, so d_w equals d_model here.
src_emb, trg_emb = embedding(src_batch), embedding(trg_batch)  # (B, S_L, d_w), (B, T_L, d_w)

print(src_emb.shape, trg_emb.shape, sep="\n")

torch.Size([5, 10, 8])
torch.Size([5, 12, 8])


`src_emb`를 encoder에서 나온 결과, 그리고 `trg_emb`를 masked multi-head attention 후 결과로 가정합니다.

In [33]:
# Read the batch size from the decoder input itself rather than from the
# stale `q` left over from the self-attention section above.
batch_size = trg_emb.shape[0]
d_k = d_model // num_heads  # hidden size of each head

In [34]:
# Queries come from the decoder side (trg_emb stands in for the masked
# multi-head attention output); keys and values come from the encoder side.
q, k, v = w_q(trg_emb), w_k(src_emb), w_v(src_emb)  # (B, T_L, d_model), (B, S_L, d_model), (B, S_L, d_model)

In [35]:
# Split d_model into num_heads heads of size d_k.
# q: (B, T_L, num_heads, d_k); k, v: (B, S_L, num_heads, d_k)
q, k, v = (t.view(batch_size, -1, num_heads, d_k) for t in (q, k, v))

In [36]:
# Move the head axis in front of the sequence axis.
# q: (B, num_heads, T_L, d_k); k, v: (B, num_heads, S_L, d_k)
q, k, v = [t.permute(0, 2, 1, 3) for t in (q, k, v)]

In [37]:
# Show the per-head query, key and value shapes, one per line.
print(q.shape, k.shape, v.shape, sep="\n")

torch.Size([5, 2, 12, 4])
torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])


In [38]:
# Source padding mask: decoder queries must not attend to encoder pad tokens.
# No causal mask here; every target position may look at the whole source.
src_mask = (src_batch != pad_id).unsqueeze(1).unsqueeze(1)  # (B, 1, 1, S_L)

attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, T_L, S_L)
attn_scores = attn_scores.masked_fill(~src_mask, -1 * inf)  # pad columns -> ~0 after softmax
attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, T_L, S_L)

print(attn_dists.shape)

torch.Size([5, 2, 12, 10])


In [39]:
# Each decoder position receives a weighted sum of encoder value vectors.
attn_values = attn_dists @ v  # (B, num_heads, T_L, d_k)

print(attn_values.shape)

torch.Size([5, 2, 12, 4])


Masked multi-head attention 후 나온 결과와 동일한 shape를 가지며 이후 layer에서 전체 연산도 동일하게 진행됩니다.