# 7. 실습2: Masked Multi-Head Attention

1. Masked Multi-Head Attention 구현
2. Encoder-Decoder Attention 구현

<br>

## 7.1 필요 패키지 import

In [None]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

<br>

## 7.2 데이터 전처리

- 데이터의 값과 형태를 좀 더 명확하게 보기 위해 sample을 줄여보자.

In [None]:
# Token id used for padding, and the vocabulary size for nn.Embedding.
pad_id, vocab_size = 0, 100

# Five toy token-id sequences of different lengths (4 to 10 tokens).
data = [
    [62, 13, 47, 39, 78, 33, 56, 13],
    [60, 96, 51, 32, 90],
    [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
    [66, 88, 98, 47],
    [77, 65, 51, 77, 19, 15, 35, 19, 23],
]

In [None]:
def padding(data, pad_value=None):
    """Right-pad every sequence in ``data`` to the length of the longest one.

    Args:
        data: List of token-id sequences whose lengths may differ.
        pad_value: Token id appended as padding. ``None`` (the default) uses
            the notebook-level ``pad_id``, so ``padding(data)`` behaves as before.

    Returns:
        ``(padded, max_len)``: ``padded`` is a NEW list of equal-length lists
        and ``max_len`` is the longest input length (0 when ``data`` is empty).
        The caller's ``data`` list is left unmodified.
    """
    if pad_value is None:
        pad_value = pad_id

    # default=0 avoids the ValueError that max() raises on an empty sequence.
    max_len = max((len(seq) for seq in data), default=0)

    # Build a fresh list instead of assigning into ``data`` so the caller's
    # list is not mutated as a hidden side effect.
    padded = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in tqdm(data)]

    return padded, max_len

In [None]:
data, max_len = padding(data)  # pad every sequence to the length of the longest one

100%|██████████| 5/5 [00:00<00:00, 3515.17it/s]


In [None]:
data  # every row now has max_len tokens, padded on the right with pad_id

[[62, 13, 47, 39, 78, 33, 56, 13, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0],
 [77, 65, 51, 77, 19, 15, 35, 19, 23, 0]]

In [None]:
max_len  # length of the longest sequence (10 here)

10

<br>

## 7.3 Hyperparameter 세팅 및 embedding

In [None]:
# d_model: the model's hidden size; num_heads: number of attention heads;
# inf: large finite value, used negated to fill masked attention scores.
d_model, num_heads, inf = 8, 2, 1e12

In [None]:
embedding = nn.Embedding(vocab_size, d_model)

# B: batch_size, L: maximum sequence length
batch = torch.tensor(data, dtype=torch.long)  # (B, L) int64 token ids
batch_emb = embedding(batch)  # (B, L, d_model)

In [None]:
batch_emb.shape  # (B, L, d_model) == (5, 10, 8)

torch.Size([5, 10, 8])

In [None]:
batch_emb  # one d_model vector per token; equal token ids (e.g. 13, or the pad id 0) give identical rows

tensor([[[ 0.3405, -0.0893, -1.8315, -0.5934,  0.2929, -0.0484, -1.7125,
           1.0601],
         [ 1.3382,  1.2600,  0.1263, -0.7021,  0.5817,  0.0095,  0.6655,
           2.0635],
         [ 0.4460, -1.0886, -1.3026,  0.9290,  0.2561, -0.0077,  0.4292,
          -0.0335],
         [ 1.6345, -1.1136,  0.9257,  0.1949,  0.9167, -1.0130,  0.2374,
          -0.1264],
         [-0.1171,  0.9863, -0.5977,  1.5313, -1.8583,  0.0425,  1.4645,
           1.0170],
         [-1.4498,  1.7848, -1.5639,  1.7929,  1.3079,  0.4944,  2.4067,
           0.9492],
         [ 0.8104,  0.7886, -0.9598, -1.8665,  0.5429, -1.2761, -1.0773,
           0.7063],
         [ 1.3382,  1.2600,  0.1263, -0.7021,  0.5817,  0.0095,  0.6655,
           2.0635],
         [-0.0064,  0.9030,  1.5812, -1.0435, -1.0133,  0.6658,  0.8902,
           0.2709],
         [-0.0064,  0.9030,  1.5812, -1.0435, -1.0133,  0.6658,  0.8902,
           0.2709]],

        [[-0.8422, -0.6688,  0.2231,  0.5043,  0.8684,  0.3534, -0.5

<br>

## 7.4 Mask 구축

- `True`: Attention이 적용될 부분
- `False`: masking 될 자리

In [None]:
padding_mask = batch.ne(pad_id).unsqueeze(-2)  # (B, L) -> (B, 1, L); True on real tokens

In [None]:
padding_mask.shape  # (B, 1, L)

torch.Size([5, 1, 10])

In [None]:
padding_mask  # False exactly where the token equals pad_id

tensor([[[ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True,  True,  True,  True,  True, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True, False, False, False, False, False, False]],

        [[ True,  True,  True,  True,  True,  True,  True,  True,  True, False]]])

In [None]:
# Start from an all-True (1, L, L) matrix; the causal pattern is cut out of it next.
nopeak_mask = torch.full((1, max_len, max_len), True, dtype=torch.bool)  # (1, L, L)
print(nopeak_mask.shape)
print(nopeak_mask)

torch.Size([1, 10, 10])
tensor([[[True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True],
         [True, True, True, True, True, True, True, True, True, True]]])


<br>

- `torch.tril()`: 행렬(또는 batch로 쌓인 행렬)의 마지막 두 차원에서 주대각선보다 위쪽(오른쪽 위)에 있는 원소를 모두 0(bool tensor라면 `False`)으로 만들어주는 함수. 그 결과 각 위치는 자기 자신과 그 이전 위치만 볼 수 있다.

In [None]:
# Keep the diagonal and everything below it: row i may attend to columns 0..i only.
nopeak_mask = nopeak_mask.tril()  # (1, L, L)
print(nopeak_mask.shape)
print(nopeak_mask)

torch.Size([1, 10, 10])
tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]])


In [None]:
# A position is attended only if it is a real token AND not in the future.
mask = torch.logical_and(padding_mask, nopeak_mask)  # (B, 1, L) & (1, L, L) -> (B, L, L)

print(mask.shape)
print(mask)

torch.Size([5, 10, 10])
tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False]],

        [[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
  

<br>

## 7.5 Linear Transformation & 여러 head로 나누기

In [None]:
# Query / key / value projections, created in that order, then the output projection.
w_q, w_k, w_v = (nn.Linear(d_model, d_model) for _ in range(3))

w_O = nn.Linear(d_model, d_model)

In [None]:
batch_size = batch_emb.shape[0]
d_k = d_model // num_heads  # 4

# Project the embeddings, split d_model into num_heads chunks of size d_k,
# and move the head axis in front of the sequence axis.
q, k, v = (
    proj(batch_emb)                          # (B, L, d_model)
    .view(batch_size, -1, num_heads, d_k)    # (B, L, num_heads, d_k)
    .transpose(1, 2)                         # (B, num_heads, L, d_k)
    for proj in (w_q, w_k, w_v)
)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])


<br>

## 7.6 Masking이 적용된 Self-Attention 구현

In [None]:
attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)  # scaled dot products, (B, num_heads, L, L)
attn_scores.shape

torch.Size([5, 2, 10, 10])

In [None]:
attn_scores[0][0]  # first sample, first head: an (L, L) score matrix, nothing masked yet

tensor([[ 0.0177, -0.3843, -0.1331, -0.0128, -0.8994, -0.8384,  0.2056, -0.3843,
         -0.4645, -0.4645],
        [ 0.0519,  0.1119, -0.2325,  0.2204, -0.7840, -0.7788,  0.4834,  0.1119,
          0.2174,  0.2174],
        [ 0.0346, -0.2122,  0.5610, -0.1651,  0.8309,  1.3624, -0.1920, -0.2122,
         -0.3669, -0.3669],
        [ 0.0295,  0.0966,  0.3417,  0.0498,  0.5527,  0.9089,  0.0563,  0.0966,
          0.0853,  0.0853],
        [ 0.2815, -0.2353,  0.4878, -0.3294,  0.4709,  1.3029,  0.1740, -0.2353,
         -0.3597, -0.3597],
        [-0.1603,  0.1456,  0.2527,  0.0909,  0.8318,  0.7549, -0.3529,  0.1456,
          0.1282,  0.1282],
        [-0.0660,  0.1164, -0.4327,  0.1997, -0.7921, -1.2069,  0.1887,  0.1164,
          0.2257,  0.2257],
        [ 0.0519,  0.1119, -0.2325,  0.2204, -0.7840, -0.7788,  0.4834,  0.1119,
          0.2174,  0.2174],
        [ 0.3630,  0.4276, -0.3010, -0.1066, -0.3365, -0.1556,  0.6001,  0.4276,
          0.6415,  0.6415],
        [ 0.3630,  

In [None]:
mask.shape  # (B, L, L): no head axis yet, so it cannot broadcast against (B, num_heads, L, L)

torch.Size([5, 10, 10])

In [None]:
masks = mask[:, None]  # insert a size-1 head axis: (B, 1, L, L), broadcasts over heads
masks.shape

torch.Size([5, 1, 10, 10])

In [None]:
# Out-of-place masked_fill: the in-place masked_fill_ would also overwrite
# attn_scores, so that variable would silently hold masked values afterwards.
# ~masks is True where attention is NOT allowed; those slots get -inf (= -1e12).
masked_attn_scores = attn_scores.masked_fill(~masks, -1 * inf) # (B, num_heads, L, L)

In [None]:
masked_attn_scores.shape  # unchanged: (B, num_heads, L, L)

torch.Size([5, 2, 10, 10])

In [None]:
masked_attn_scores[0][0]  # masked (future / <pad>) entries now hold -1e12

tensor([[ 1.7655e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12,
         -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
        [ 5.1875e-02,  1.1191e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12,
         -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
        [ 3.4631e-02, -2.1224e-01,  5.6099e-01, -1.0000e+12, -1.0000e+12,
         -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
        [ 2.9473e-02,  9.6629e-02,  3.4175e-01,  4.9794e-02, -1.0000e+12,
         -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
        [ 2.8152e-01, -2.3528e-01,  4.8782e-01, -3.2945e-01,  4.7086e-01,
         -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
        [-1.6030e-01,  1.4563e-01,  2.5267e-01,  9.0920e-02,  8.3182e-01,
          7.5491e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12],
        [-6.6023e-02,  1.1636e-01, -4.3267e-01,  1.9967e-01, -7.9209e-01,
         -1.2069e+00,  1.8870e-0

<br>

- `-1 * inf`(여기서 `inf = 1e12`, 즉 매우 큰 음수)로 masking 된 부분은 softmax 후 사실상 0이 된다. 진짜 `-inf` 대신 유한한 값을 쓰기 때문에, 한 행 전체가 masking 되더라도 NaN이 생기지 않는다.

In [None]:
attn_dists = masked_attn_scores.softmax(dim=-1)  # normalize over keys: (B, num_heads, L, L)

In [None]:
attn_dists.shape  # (B, num_heads, L, L)

torch.Size([5, 2, 10, 10])

In [None]:
attn_dists[0][0]  # each row sums to 1; masked columns received weight 0

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.4850, 0.5150, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2879, 0.2249, 0.4873, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2244, 0.2400, 0.3066, 0.2290, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.2185, 0.1303, 0.2686, 0.1186, 0.2640, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0966, 0.1312, 0.1460, 0.1242, 0.2606, 0.2413, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.1590, 0.1908, 0.1102, 0.2073, 0.0769, 0.0508, 0.2051, 0.0000, 0.0000,
         0.0000],
        [0.1339, 0.1422, 0.1008, 0.1585, 0.0580, 0.0583, 0.2061, 0.1422, 0.0000,
         0.0000],
        [0.1508, 0.1608, 0.0776, 0.0943, 0.0749, 0.0898, 0.1911, 0.1608, 0.0000,
         0.0000],
        [0.1508, 0.1608, 0.0776, 0.0943, 0.0749, 0.0898, 0.1911, 0.1608, 0.0000,
         0.0000]], grad_fn=<

In [None]:
attn_values = attn_dists @ v  # weighted sum of value vectors: (B, num_heads, L, d_k)
attn_values.shape

torch.Size([5, 2, 10, 4])

<br>

## 7.7 Masked Multi-Head Attention 전체 코드

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional mask.

    Serves as masked self-attention (q, k, v from the same sequence plus a
    no-peak mask) and as encoder-decoder attention (q from the decoder,
    k / v from the encoder plus a source padding mask).

    Args:
        hidden_size: Model hidden size. ``None`` falls back to the
            notebook-level ``d_model`` so ``MultiHeadAttention()`` keeps working.
        n_heads: Number of attention heads. ``None`` falls back to the
            notebook-level ``num_heads``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``n_heads``.
    """

    def __init__(self, hidden_size=None, n_heads=None):
        super().__init__()

        self.d_model = d_model if hidden_size is None else hidden_size
        self.num_heads = num_heads if n_heads is None else n_heads
        if self.d_model % self.num_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by num_heads ({self.num_heads})"
            )
        # Derived here instead of read from a global ``d_k`` so it can never
        # disagree with d_model / num_heads.
        self.d_k = self.d_model // self.num_heads

        # Created in the order q, k, v, O so a fixed seed gives a fixed init.
        self.w_q = nn.Linear(self.d_model, self.d_model)
        self.w_k = nn.Linear(self.d_model, self.d_model)
        self.w_v = nn.Linear(self.d_model, self.d_model)

        self.w_O = nn.Linear(self.d_model, self.d_model)

    def forward(self, q, k, v, mask=None):
        """Project inputs, attend per head, merge the heads and project back.

        Args:
            q: (B, T_q, d_model) query input.
            k: (B, T_k, d_model) key input.
            v: (B, T_k, d_model) value input.
            mask: Optional mask, True / nonzero = "may attend". E.g. a
                (B, T_q, T_k) no-peak mask or a (B, 1, T_k) padding mask.
                A 3-D mask gets a head axis added; a 4-D mask is used as is.

        Returns:
            (B, T_q, d_model) tensor.
        """
        batch_size = q.shape[0]

        q = self._split_heads(self.w_q(q), batch_size)  # (B, num_heads, T_q, d_k)
        k = self._split_heads(self.w_k(k), batch_size)  # (B, num_heads, T_k, d_k)
        v = self._split_heads(self.w_v(v), batch_size)  # (B, num_heads, T_k, d_k)

        attn_values = self.self_attention(q, k, v, mask=mask)  # (B, num_heads, T_q, d_k)
        attn_values = attn_values.transpose(1, 2)  # (B, T_q, num_heads, d_k)
        # transpose leaves the tensor non-contiguous; view needs contiguous memory.
        attn_values = attn_values.contiguous().view(batch_size, -1, self.d_model)  # (B, T_q, d_model)

        return self.w_O(attn_values)

    def _split_heads(self, x, batch_size):
        """Reshape (B, T, d_model) into (B, num_heads, T, d_k)."""
        return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def self_attention(self, q, k, v, mask=None):
        """Scaled dot-product attention on already-split heads.

        Args:
            q: (B, num_heads, T_q, d_k).
            k: (B, num_heads, T_k, d_k).
            v: (B, num_heads, T_k, d_k).
            mask: Optional mask as described in ``forward``.

        Returns:
            (B, num_heads, T_q, d_k) attention values.
        """
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)  # (B, num_heads, T_q, T_k)

        if mask is not None:
            if mask.dim() == 3:
                # (B, T_q, T_k) -> (B, 1, T_q, T_k)  or  (B, 1, T_k) -> (B, 1, 1, T_k);
                # the size-1 axes broadcast over heads (and over queries).
                mask = mask.unsqueeze(1)
            # Out-of-place fill with the dtype's most negative finite value:
            # masked keys get 0 weight after softmax, a fully masked row stays
            # finite (uniform, not NaN), and unlike a fixed -1e12 it does not
            # overflow to -inf in float16.
            attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)

        attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, T_q, T_k)

        return torch.matmul(attn_dists, v)  # (B, num_heads, T_q, d_k)

In [None]:
multihead_attn = MultiHeadAttention()

# Masked self-attention: query, key and value are all the same embeddings.
outputs = multihead_attn(q=batch_emb, k=batch_emb, v=batch_emb, mask=mask)  # (B, L, d_model)

In [None]:
outputs.shape  # (B, L, d_model): same shape as the input embeddings

torch.Size([5, 10, 8])

In [None]:
outputs  # per-token outputs of masked multi-head attention after w_O

tensor([[[ 0.8002, -0.8208,  0.0515, -0.2678,  0.0028,  0.5807,  0.5486,
           0.8211],
         [ 0.6026, -0.5331,  0.2481, -0.0188,  0.2464,  0.1090,  0.2184,
           0.4320],
         [ 0.5052, -0.4665,  0.1415, -0.2606,  0.2717,  0.1844,  0.3881,
           0.4864],
         [ 0.5385, -0.4053,  0.2601, -0.1694,  0.2779,  0.1148,  0.2781,
           0.4050],
         [ 0.4241, -0.3477,  0.2277, -0.2216,  0.2602,  0.0804,  0.2856,
           0.3582],
         [ 0.1695, -0.3260,  0.0330, -0.2908,  0.4310, -0.1071,  0.3268,
           0.2726],
         [ 0.1543, -0.3592,  0.0935, -0.1306,  0.5251, -0.2931,  0.2703,
           0.1455],
         [ 0.3072, -0.3524,  0.1917, -0.1002,  0.3646, -0.1523,  0.2485,
           0.1974],
         [ 0.3338, -0.3715,  0.2186, -0.0607,  0.3782, -0.1465,  0.1556,
           0.2322],
         [ 0.3338, -0.3715,  0.2186, -0.0607,  0.3782, -0.1465,  0.1556,
           0.2322]],

        [[ 0.3121, -0.3505,  0.1889, -0.1198,  0.3754, -0.1381,  0.0

<br>

## 7.8 Encoder-Decoder Attention

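- Query는 decoder 쪽 입력에서, Key와 Value는 encoder 출력에서 가져온다는 점만 다를 뿐 구현은 동일하다.
- 단, mask로는 no-peak mask가 아니라 source(encoder 입력)의 padding mask를 사용해 `<pad>` 위치를 attention 대상에서 제외해야 한다.
- Decoder에 들어갈 batch만 별도로 구성한다.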

In [None]:
# Target (decoder-side) token sequences of different lengths (4 to 12 tokens).
trg_data = [
    [33, 11, 49, 10],
    [88, 34, 5, 29, 99, 45, 11, 25],
    [67, 25, 15, 90, 54, 4, 92, 10, 46, 20, 88, 19],
    [16, 58, 91, 47, 12, 5, 8],
    [71, 63, 62, 7, 9, 11, 55, 91, 32, 48],
]

trg_data, trg_max_len = padding(trg_data)

100%|██████████| 5/5 [00:00<00:00, 32564.47it/s]


In [None]:
trg_data  # every target sequence padded with pad_id up to trg_max_len

[[33, 11, 49, 10, 0, 0, 0, 0, 0, 0, 0, 0],
 [88, 34, 5, 29, 99, 45, 11, 25, 0, 0, 0, 0],
 [67, 25, 15, 90, 54, 4, 92, 10, 46, 20, 88, 19],
 [16, 58, 91, 47, 12, 5, 8, 0, 0, 0, 0, 0],
 [71, 63, 62, 7, 9, 11, 55, 91, 32, 48, 0, 0]]

In [None]:
trg_max_len  # length of the longest target sequence (12 here)

12

In [None]:
# S_L: source maximum sequence length
src_batch = batch # (B, S_L)
src_batch.shape

torch.Size([5, 10])

In [None]:
# T_L: target maximum sequence length
trg_batch = torch.LongTensor(trg_data) # (B, T_L)
trg_batch.shape

torch.Size([5, 12])

In [None]:
# Both sides share the same embedding table.
src_emb, trg_emb = embedding(src_batch), embedding(trg_batch)  # (B, S_L, d_model), (B, T_L, d_model)

print(src_emb.shape)
print(trg_emb.shape)

torch.Size([5, 10, 8])
torch.Size([5, 12, 8])


<br>

- `src_emb`: encoder에서 나온 결과 (이 실습에서는 embedding 결과로 대신한다)
- `trg_emb`: masked multi-head attention의 결과 (이 실습에서는 embedding 결과로 대신한다)

In [None]:
batch_size = trg_emb.shape[0]
d_k = d_model // num_heads  # 4

# Queries come from the decoder side (trg_emb), keys/values from the encoder side (src_emb).
q = w_q(trg_emb).view(batch_size, -1, num_heads, d_k).transpose(1, 2)  # (B, num_heads, T_L, d_k)
k, v = (
    proj(src_emb)                            # (B, S_L, d_model)
    .view(batch_size, -1, num_heads, d_k)    # (B, S_L, num_heads, d_k)
    .transpose(1, 2)                         # (B, num_heads, S_L, d_k)
    for proj in (w_k, w_v)
)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([5, 2, 12, 4])
torch.Size([5, 2, 10, 4])
torch.Size([5, 2, 10, 4])


In [None]:
attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, T_L, S_L) == (5, 2, 12, 10)
attn_scores.shape

torch.Size([5, 2, 12, 10])

In [None]:
# Decoder queries must not attend to <pad> positions of the source sequence.
# (src_batch != pad_id) is True on real source tokens; two unsqueezes give
# (B, 1, 1, S_L), which broadcasts over the head and target-position axes.
src_padding_mask = (src_batch != pad_id).unsqueeze(1).unsqueeze(1)  # (B, 1, 1, S_L)
masked_attn_scores = attn_scores.masked_fill(~src_padding_mask, -1 * inf)  # (B, num_heads, T_L, S_L)

attn_dists = F.softmax(masked_attn_scores, dim=-1) # (B, num_heads, T_L, S_L)
attn_dists.shape

torch.Size([5, 2, 12, 10])

In [None]:
attn_values = attn_dists @ v  # weighted sum of encoder values: (B, num_heads, T_L, d_k)
attn_values.shape

torch.Size([5, 2, 12, 4])

<br>

- Encoder-Decoder Attention의 출력도 head들을 다시 합치고 `w_O`를 적용하면, Masked Multi-Head Attention 후 나온 결과와 동일한 shape `(B, T_L, d_model)`을 갖는다.
- 이후 layer에서 전체 연산도 동일하게 진행된다.