In [2]:
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

import torch
import math

In [5]:
pad_id = 0
vocab_size = 100

data = [
  [62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54],
  [75, 51],
  [66, 88, 98, 47],
  [21, 39, 10, 64, 21],
  [98],
  [77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10],
  [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34],
  [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43]
]

def padding(data):
    max_len = len(max(data, key=len))
    print(f"Maximum sequence length: {max_len}")

    for i, seq in enumerate(tqdm(data)):
        if len(seq) < max_len:
            data[i] = seq + [pad_id] * (max_len - len(seq))
    return data, max_len

data, max_len = padding(data)

100%|██████████| 10/10 [00:00<00:00, 128266.18it/s]

Maximum sequence length: 20





### Hyper parameter setting embedding

In [6]:
d_model = 512
num_heads = 8

embedding = nn.Embedding(vocab_size, d_model)

batch = torch.LongTensor(data)
batch_emb = embedding(batch)

In [9]:
w_q = nn.Linear(d_model, d_model)
w_k = nn.Linear(d_model, d_model)
w_v = nn.Linear(d_model, d_model)

w_0 = nn.Linear(d_model, d_model)

q,k,v = w_q(batch_emb), w_k(batch_emb), w_v(batch_emb)

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([10, 20, 512])
torch.Size([10, 20, 512])
torch.Size([10, 20, 512])


In [8]:
batch_size = q.shape[0]
d_k = d_model // num_heads

q = q.view( batch_size, -1, num_heads, d_k )
k = k.view( batch_size, -1, num_heads, d_k )
v = v.view( batch_size, -1, num_heads, d_k )

print(q.shape)
print(k.shape)
print(v.shape)

torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])
torch.Size([10, 20, 8, 64])


In [20]:
attn_scores = torch.matmul(q, k.transpose(-2,-1)) / math.sqrt(d_k)
attn_dists = F.softmax(attn_scores, dim = -1)

print(attn_dists)
print(attn_dists.shape)

attn_values = torch.matmul(attn_dists, v)

print(attn_values.shape)

tensor([[[0.0302, 0.0287, 0.2438,  ..., 0.0294, 0.0294, 0.0294],
         [0.0433, 0.0731, 0.0598,  ..., 0.0109, 0.0109, 0.0109],
         [0.0843, 0.1672, 0.0237,  ..., 0.0216, 0.0216, 0.0216],
         ...,
         [0.0255, 0.1751, 0.1248,  ..., 0.0100, 0.0100, 0.0100],
         [0.0255, 0.1751, 0.1248,  ..., 0.0100, 0.0100, 0.0100],
         [0.0255, 0.1751, 0.1248,  ..., 0.0100, 0.0100, 0.0100]],

        [[0.0418, 0.0156, 0.0117,  ..., 0.0587, 0.0587, 0.0587],
         [0.0196, 0.0961, 0.0136,  ..., 0.0517, 0.0517, 0.0517],
         [0.0103, 0.0255, 0.1457,  ..., 0.0477, 0.0477, 0.0477],
         ...,
         [0.1555, 0.0249, 0.1721,  ..., 0.0307, 0.0307, 0.0307],
         [0.1555, 0.0249, 0.1721,  ..., 0.0307, 0.0307, 0.0307],
         [0.1555, 0.0249, 0.1721,  ..., 0.0307, 0.0307, 0.0307]],

        [[0.0275, 0.0317, 0.0618,  ..., 0.0801, 0.0801, 0.0801],
         [0.0656, 0.0071, 0.0175,  ..., 0.0688, 0.0688, 0.0688],
         [0.0924, 0.1049, 0.0383,  ..., 0.0283, 0.0283, 0.

## 각 head와 결과물 병합

In [24]:
attn_values = attn_values.transpose(1,2)
attn_values = attn_values.contiguous().view(batch_size, -1, d_model)

print(attn_values.shape)

outputs = w_0(attn_values)

print(outputs)
print(outputs.shape)

torch.Size([10, 20, 512])
tensor([[[-0.0024, -0.0981,  0.2376,  ..., -0.0716, -0.0008, -0.2553],
         [ 0.0360,  0.0455,  0.1208,  ...,  0.1956, -0.1243, -0.0389],
         [-0.0028,  0.0939, -0.1629,  ..., -0.1600,  0.1667, -0.3306],
         ...,
         [-0.1684, -0.1064,  0.2197,  ..., -0.1322,  0.0044, -0.1180],
         [-0.0708,  0.2256, -0.1024,  ..., -0.0662, -0.0648,  0.3130],
         [-0.0147, -0.0364,  0.0042,  ..., -0.0326, -0.1142, -0.2119]],

        [[ 0.1454, -0.2093,  0.0136,  ...,  0.2104, -0.0654, -0.3726],
         [ 0.0288,  0.4116,  0.5411,  ...,  0.1261,  0.3474,  0.0799],
         [-0.1704,  0.0796, -0.1691,  ..., -0.2964,  0.0923, -0.1506],
         ...,
         [-0.4135, -0.0215,  0.4416,  ...,  0.1586,  0.1487,  0.0887],
         [-0.1519,  0.2000,  0.0781,  ..., -0.2038,  0.1383,  0.3918],
         [-0.1004, -0.0034,  0.4002,  ..., -0.1959, -0.0601, -0.0217]],

        [[ 0.0098, -0.2113,  0.0832,  ...,  0.1265,  0.0413, -0.3621],
         [-0.0368, 

In [31]:
class MultiheadAttention(nn.Module):
    def __init__(self):
        super(MultiheadAttention, self).__init__()

        # Q, K, V learnable matrices
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        # Linear transformation for concatenated outputs
        self.w_0 = nn.Linear(d_model, d_model)

    def forward(self, q, k, v):
        batch_size = q.shape[0]

        q = self.w_q(q)  # (B, L, d_model)
        k = self.w_k(k)  # (B, L, d_model)
        v = self.w_v(v)  # (B, L, d_model)

        q = q.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
        k = k.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
        v = v.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)

        q = q.transpose(1, 2)  # (B, num_heads, L, d_k)
        k = k.transpose(1, 2)  # (B, num_heads, L, d_k)
        v = v.transpose(1, 2)  # (B, num_heads, L, d_k)

        attn_values = self.self_attention(q, k, v)  # (B, num_heads, L, d_k)
        attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, d_model)  # (B, L, num_heads, d_k) => (B, L, d_model)

        return self.w_0(attn_values)

    def self_attention(self, q, k, v):
        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)
        attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L, L)

        attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)

        return attn_values

tensor([[1, 1, 1],
        [2, 2, 2],
        [3, 3, 3],
        [4, 4, 4],
        [5, 5, 5],
        [6, 6, 6],
        [7, 7, 7],
        [8, 8, 8],
        [9, 9, 9]])