# 7. 실습 1: Multi-Head Attention

1. Multi-head attention 및 self-attention 구현
2. 각 과정에서 일어나는 연산과 input/output 형태 이해

<br>

## 7.1 필요 패키지 import

In [1]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

<br>

## 7.2 데이터 전처리

In [2]:
pad_id = 0
vocab_size = 100

data = [
  [62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54],
  [75, 51],
  [66, 88, 98, 47],
  [21, 39, 10, 64, 21],
  [98],
  [77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10],
  [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34],
  [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43]
]

In [3]:
def padding(data):
    max_len = len(max(data, key=len))
    print(f"Maximum sequence length: {max_len}")

    for i, seq in enumerate(data):
        if len(seq) < max_len:
            data[i] = seq + [pad_id] * (max_len - len(seq))

    return data, max_len

In [4]:
data, max_len = padding(data)

Maximum sequence length: 20


In [6]:
for d in data:
    print(d)

[62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75, 0, 0, 0, 0]
[60, 96, 51, 32, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[75, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[66, 88, 98, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[21, 39, 10, 64, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10]
[70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34, 0, 0, 0, 0]
[20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43, 0, 0, 0, 0, 0, 0]


<br>

## 7.3 Hyperparameter 세팅 및 embedding

In [7]:
d_model = 512 # model의 hidden size
num_heads = 8 # head의 개수

In [8]:
embedding = nn.Embedding(vocab_size, d_model)

# B: batch size, L: maximum sequence length
batch = torch.LongTensor(data) # (B, L)
batch_emb = embedding(batch) # (B, L, d_model)

In [10]:
print(batch_emb.shape)

torch.Size([10, 20, 512])


In [9]:
print(batch_emb)

tensor([[[ 1.6177,  0.0194,  0.2977,  ..., -1.6516, -1.2086,  0.5774],
         [-0.6715, -1.9382,  1.3285,  ..., -1.0405, -1.3163, -1.6637],
         [ 0.2595, -0.1710, -1.3617,  ...,  1.7346, -0.7475, -1.8132],
         ...,
         [-0.3729,  0.5898,  0.1876,  ..., -0.7958,  1.0197, -0.1817],
         [-0.3729,  0.5898,  0.1876,  ..., -0.7958,  1.0197, -0.1817],
         [-0.3729,  0.5898,  0.1876,  ..., -0.7958,  1.0197, -0.1817]],

        [[ 1.0453,  1.0690,  1.5396,  ..., -0.1327, -0.7455, -1.0294],
         [-2.0269, -0.3514, -1.1836,  ..., -0.5916, -0.1701,  1.1248],
         [ 2.6546,  0.8070, -0.8941,  ..., -0.0147, -1.4569, -0.0532],
         ...,
         [-0.3729,  0.5898,  0.1876,  ..., -0.7958,  1.0197, -0.1817],
         [-0.3729,  0.5898,  0.1876,  ..., -0.7958,  1.0197, -0.1817],
         [-0.3729,  0.5898,  0.1876,  ..., -0.7958,  1.0197, -0.1817]],

        [[ 1.7817,  0.3581, -2.8669,  ..., -0.0494, -0.2738,  1.5816],
         [-2.1258,  0.0092, -0.2425,  ...,  0

<br>

## 7.4 Linear transformation & 여러 head로 나누기

- Multi-Head Attention 내에서 쓰이는 linear transformation matrix들을 정의한다.

In [11]:
w_q = nn.Linear(d_model, d_model)
w_k = nn.Linear(d_model, d_model)
w_v = nn.Linear(d_model, d_model)

In [12]:
w_O = nn.Linear(d_model, d_model)

In [13]:
q = w_q(batch_emb) # (B, L, d_model)
k = w_q(batch_emb) # (B, L, d_model)
v = w_q(batch_emb) # (B, L, d_model)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([10, 20, 512])
torch.Size([10, 20, 512])
torch.Size([10, 20, 512])


<br>

- q, k, v를 `num_heads`개의 차원 분할된 여러 vector로 만든다.

In [15]:
batch_size = q.shape[0]
d_k = d_model // num_heads # 64

q = q.view(batch_size, -1, num_heads, d_k) # (B, L, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k) # (B, L, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k) # (B, L, num_heads, d_k)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])


In [16]:
q = q.transpose(1, 2)  # (B, num_heads, L, d_k)
k = k.transpose(1, 2)  # (B, num_heads, L, d_k)
v = v.transpose(1, 2)  # (B, num_heads, L, d_k)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])
torch.Size([10, 8, 20, 64])


<br>


## 7.5 Scaled dot-product self-attention 구현

- 각 head에서 실행되는 self-attention 과정이다.

In [17]:
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) # (B, num_heads, L, L)
attn_dists = F.softmax(attn_scores, dim=-1) # (B, num_heads, L, L)

In [18]:
print(attn_dists.shape)

torch.Size([10, 8, 20, 20])


In [19]:
print(attn_dists)

tensor([[[[0.4010, 0.0219, 0.0204,  ..., 0.0263, 0.0263, 0.0263],
          [0.0104, 0.3667, 0.0094,  ..., 0.0125, 0.0125, 0.0125],
          [0.0230, 0.0222, 0.4601,  ..., 0.0285, 0.0285, 0.0285],
          ...,
          [0.0118, 0.0118, 0.0113,  ..., 0.1939, 0.1939, 0.1939],
          [0.0118, 0.0118, 0.0113,  ..., 0.1939, 0.1939, 0.1939],
          [0.0118, 0.0118, 0.0113,  ..., 0.1939, 0.1939, 0.1939]],

         [[0.4642, 0.0305, 0.0217,  ..., 0.0229, 0.0229, 0.0229],
          [0.0299, 0.2538, 0.0260,  ..., 0.0211, 0.0211, 0.0211],
          [0.0198, 0.0241, 0.4565,  ..., 0.0314, 0.0314, 0.0314],
          ...,
          [0.0205, 0.0192, 0.0307,  ..., 0.1501, 0.1501, 0.1501],
          [0.0205, 0.0192, 0.0307,  ..., 0.1501, 0.1501, 0.1501],
          [0.0205, 0.0192, 0.0307,  ..., 0.1501, 0.1501, 0.1501]],

         [[0.3145, 0.0253, 0.0408,  ..., 0.0384, 0.0384, 0.0384],
          [0.0079, 0.3931, 0.0119,  ..., 0.0071, 0.0071, 0.0071],
          [0.0358, 0.0337, 0.3850,  ..., 0

In [20]:
attn_values = torch.matmul(attn_dists, v) # (B, num_heads, L, d_k)

In [21]:
print(attn_values.shape)

torch.Size([10, 8, 20, 64])


In [22]:
print(attn_values)

tensor([[[[-3.9832e-01, -2.5390e-01, -2.1970e-01,  ...,  1.0867e-01,
           -1.6720e-01, -3.6598e-01],
          [-3.1375e-01,  8.8513e-01,  4.4421e-01,  ...,  5.0939e-01,
           -3.9207e-02,  1.7287e-01],
          [ 4.8908e-01, -5.6930e-01,  2.0121e-01,  ..., -3.4793e-01,
            4.9373e-01,  2.2995e-01],
          ...,
          [ 8.2576e-01,  1.6975e-01, -3.3165e-01,  ...,  4.3440e-01,
            5.1049e-02, -1.8943e-01],
          [ 8.2576e-01,  1.6975e-01, -3.3165e-01,  ...,  4.3440e-01,
            5.1049e-02, -1.8943e-01],
          [ 8.2576e-01,  1.6975e-01, -3.3165e-01,  ...,  4.3440e-01,
            5.1049e-02, -1.8943e-01]],

         [[-2.1033e-02, -2.1804e-03,  2.8225e-01,  ..., -2.1492e-01,
            1.4894e-01, -2.4694e-01],
          [ 2.9235e-01, -8.8229e-02,  2.6196e-01,  ..., -1.9656e-01,
            9.4702e-02,  4.5453e-01],
          [-5.8237e-02,  3.8165e-02, -7.1660e-02,  ...,  2.7075e-02,
           -5.0420e-01,  4.1629e-01],
          ...,
     

<br>

## 7.6 각 head의 결과물 병합

- 각 head의 결과물을 concat하고 동일 차원으로 linear transformation한다.

In [23]:
print(attn_values.shape)

torch.Size([10, 8, 20, 64])


In [24]:
attn_values = attn_values.transpose(1, 2) # (B, L, num_heads, d_k)
print(attn_values.shape)

torch.Size([10, 20, 8, 64])


In [25]:
attn_values = attn_values.contiguous().view(batch_size, -1, d_model) # (B, L, d_model)
print(attn_values.shape)

torch.Size([10, 20, 512])


In [26]:
outputs = w_O(attn_values)

In [27]:
print(outputs.shape)

torch.Size([10, 20, 512])


In [28]:
print(outputs)

tensor([[[ 3.6645e-01,  1.3034e-02, -1.0524e-01,  ...,  1.0181e-01,
          -5.1604e-02, -1.3736e-01],
         [-1.7816e-01,  2.8100e-02, -4.7757e-02,  ...,  2.8880e-02,
           4.3996e-01, -2.5548e-01],
         [-1.0033e-01,  2.1689e-01,  2.8576e-02,  ...,  4.0960e-02,
           5.7728e-02, -4.4218e-01],
         ...,
         [ 2.2200e-01,  6.9033e-01, -2.9222e-01,  ...,  2.6137e-01,
          -9.4578e-02, -2.2674e-01],
         [ 2.2200e-01,  6.9033e-01, -2.9222e-01,  ...,  2.6137e-01,
          -9.4578e-02, -2.2674e-01],
         [ 2.2200e-01,  6.9033e-01, -2.9222e-01,  ...,  2.6137e-01,
          -9.4578e-02, -2.2674e-01]],

        [[ 4.8561e-02,  7.4417e-01, -1.2848e-01,  ...,  2.0169e-01,
          -2.2639e-01, -7.7707e-02],
         [ 2.2239e-01,  9.3600e-02,  1.5875e-01,  ...,  1.6421e-02,
           2.8445e-02,  1.7800e-01],
         [ 2.5356e-01,  3.0430e-01, -2.4435e-01,  ...,  2.2100e-01,
          -9.0367e-03, -6.6186e-02],
         ...,
         [ 2.5680e-01,  9

<br>

## 7.7 전체 코드

- 위의 과정을 모두 합쳐 하나의 Multi-Head Attention 모듈을 구현해본다.

In [29]:
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()

        # Q, K, V learnable matrices
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Linear transformation for concatenated outputs
        self.w_O = nn.Linear(d_model, d_model)

    def forward(self, q, k, v):
        batch_size = q.shape[0]

        q = self.w_q(q) # (B, L, d_model)
        k = self.w_k(k) # (B, L, d_model)
        v = self.w_v(v) # (B, L, d_model)

        # split heads
        q = q.view(batch_size, -1, num_heads, d_k) # (B, L, num_heads, d_k)
        k = k.view(batch_size, -1, num_heads, d_k) # (B, L, num_heads, d_k)
        v = v.view(batch_size, -1, num_heads, d_k) # (B, L, num_heads, d_k)

        q = q.transpose(1, 2) # (B, num_heads, L, d_k)
        k = k.transpose(1, 2) # (B, num_heads, L, d_k)
        v = v.transpose(1, 2) # (B, num_heads, L, d_k)

        attn_values = self.self_attention(q, k, v) # (B, num_heads, L, d_k)
        attn_values = attn_values.transpose(1, 2) # (B, L, num_heads, d_k)
        attn_values = attn_values.contiguous().view(batch_size, -1, d_model) # (B, L, d_model)

        return self.w_O(attn_values)

        
    def self_attention(self, q, k, v):
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) # (B, num_heads, L, L)
        attn_dists = F.softmax(attn_scores, dim=-1) # (B, num_heads, L, L)
        attn_values = torch.matmul(attn_dists, v) # (B, num_heads, L, d_k)
        return attn_values

In [30]:
multihead_attn = MultiHeadAttention()

outputs = multihead_attn(batch_emb, batch_emb, batch_emb) # (B, L, d_model)

In [31]:
print(outputs.shape)

torch.Size([10, 20, 512])


In [32]:
print(outputs)

tensor([[[-0.1779,  0.0911, -0.1022,  ..., -0.0150,  0.0550,  0.0552],
         [-0.0885,  0.0596, -0.1044,  ..., -0.0130,  0.0985,  0.1009],
         [-0.2141,  0.0480, -0.0466,  ..., -0.0477,  0.1206,  0.1355],
         ...,
         [-0.1886,  0.1106, -0.0647,  ..., -0.0424,  0.1168,  0.0838],
         [-0.1886,  0.1106, -0.0647,  ..., -0.0424,  0.1168,  0.0838],
         [-0.1886,  0.1106, -0.0647,  ..., -0.0424,  0.1168,  0.0838]],

        [[-0.3866, -0.0119,  0.1348,  ...,  0.0080,  0.4120, -0.1987],
         [-0.4221, -0.0697,  0.1829,  ..., -0.0048,  0.4192, -0.2076],
         [-0.3741, -0.0280,  0.1774,  ...,  0.0280,  0.4078, -0.1923],
         ...,
         [-0.5010,  0.0122,  0.1367,  ..., -0.0083,  0.4184, -0.1940],
         [-0.5010,  0.0122,  0.1367,  ..., -0.0083,  0.4184, -0.1940],
         [-0.5010,  0.0122,  0.1367,  ..., -0.0083,  0.4184, -0.1940]],

        [[-0.1113, -0.0803,  0.0343,  ...,  0.0360,  0.2206, -0.1678],
         [-0.1363, -0.0922,  0.0627,  ..., -0