In [1]:
import os
import torch
import numpy as np

from torch import nn
from torch.nn import functional as F

from data import Multi30kDataset, download_multi30k, make_cache

# 1.Parameters

In [2]:
# Root directory for datasets; the Multi30k files are expected under f"{DATA_DIR}/Multi30k".
# NOTE(review): this is an absolute, machine-specific path — adjust it when running elsewhere.
DATA_DIR = "/home/pervinco/Datasets"

# Language codes handed to Multi30kDataset as source_language / target_language.
SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'

# Training hyper-parameters. Only BATCH_SIZE is consumed in the visible part of
# the notebook (by DATASET.get_iter); the others are presumably used by the
# training loop further down.
EPOCHS = 100
BATCH_SIZE = 1
LEARNING_RATE = 1e-5
WEIGHT_DECAY = 5e-9
ADAM_EPS = 5e-9
SCHEDULER_FACTOR = 0.9
SCHEDULER_PATIENCE = 10
WARM_UP_STEP = 100

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Loader worker count: 0 when BATCH_SIZE <= 1, otherwise the smallest of the CPU
# count, the batch size and 8. os.cpu_count() can return None when the count is
# undeterminable, which would make min() raise TypeError, so fall back to 1.
NUM_WORKERS = min(os.cpu_count() or 1, BATCH_SIZE if BATCH_SIZE > 1 else 0, 8)

# Model dimensions. D_MODEL is the embedding width used by EmbeddingLayer and
# MAX_SEQ_LEN is passed to Multi30kDataset(max_seq_len=...).
D_MODEL = 512
NUM_HEADS = 8
NUM_LAYERS = 6
FFN_DIM = 2048
MAX_SEQ_LEN = 256
DROP_PROB = 0.1

# 2.Dataset

In [3]:
# Fetch the Multi30k data into DATA_DIR and build its cache directory.
# Both helpers come from the local `data` module; judging by the captured output
# they report "already exist" when the target directory is present.
download_multi30k(DATA_DIR)
make_cache(f"{DATA_DIR}/Multi30k")

/home/pervinco/Datasets/Multi30k is already exist.
/home/pervinco/Datasets/Multi30k/cache is already exist.


In [4]:
# Build the dataset wrapper; tokens seen fewer than 2 times are presumably
# excluded from the vocabularies (vocab_min_freq=2) — confirm in data.py.
DATASET = Multi30kDataset(
    data_dir=f"{DATA_DIR}/Multi30k",
    source_language=SRC_LANGUAGE,
    target_language=TGT_LANGUAGE,
    max_seq_len=MAX_SEQ_LEN,
    vocab_min_freq=2,
)

# One iterable per split, all sharing the same batch size and worker count.
train_iter, valid_iter, test_iter = DATASET.get_iter(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# Vocabulary sizes determine the number of rows in each embedding table.
src_vocab_size = len(DATASET.src_vocab)
trg_vocab_size = len(DATASET.trg_vocab)
print(src_vocab_size, trg_vocab_size)

6274 8041


In [5]:
# Pull one batch from the training iterator and keep its first sentence pair
# as a running example for the cells below.
src_batch, tgt_batch = next(iter(train_iter))
src_sample, tgt_sample = src_batch[0], tgt_batch[0]

# Each sample is a 1-D tensor of token ids (see the printed shapes).
print(src_sample.shape, tgt_sample.shape)
print(src_sample)
print(tgt_sample)

torch.Size([15]) torch.Size([14])
tensor([  2,   6,  12,   7,   4, 387,  24,  10, 268,  11,  38,   8, 123,   5,
          3])
tensor([   2,    5,   12,    7,    6,  422,   43,  216,    9,   42, 1534,  136,
           4,    3])


# 3.Embedding, Positional Encoding

In [6]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first tensor.

    The table is precomputed once and stored as the buffer ``pe`` with shape
    (max_len, 1, d_model); that layout is kept so existing state_dicts load.

    Args:
        d_model: width of each encoding vector (odd values are supported).
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, max_len=256):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encodings of positions 0..seq_len-1.

        Args:
            x: tensor of shape (batch, seq_len, d_model), as produced by
                EmbeddingLayer in this notebook.

        Raises:
            ValueError: if seq_len exceeds the precomputed ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds the maximum supported length {self.pe.size(0)}"
            )
        # Slice along the position axis (dim 1 of x), not the batch axis, then
        # move the singleton axis first so it broadcasts over the batch.
        return x + self.pe[:seq_len].transpose(0, 1)

class EmbeddingLayer(nn.Module):
    """Token embedding scaled by sqrt(d_model), followed by positional encoding.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: width of each embedding vector.
        max_len: number of positions handed to PositionalEncoding; the default
            of 256 matches the value that was previously implied.
    """

    def __init__(self, vocab_size, d_model, max_len=256):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len)

    def forward(self, x):
        """Embed token ids ``x`` of shape (batch, seq_len) into (batch, seq_len, d_model)."""
        # x holds token ids, so x.size(-1) is the sequence length. The scale
        # factor must come from the embedding width instead.
        scale = self.embedding.embedding_dim ** 0.5
        x = self.embedding(x) * scale
        return self.positional_encoding(x)


In [7]:
## Each token of the sample sentence is embedded into a D_MODEL-dimensional vector.

src_embedder = EmbeddingLayer(src_vocab_size, D_MODEL)
src_embed = src_embedder(src_sample.unsqueeze(0))

print(src_embed.shape)

torch.Size([1, 15, 512])


In [8]:
# Same embedding step for the target sentence, using the target vocabulary size.
trg_embedder = EmbeddingLayer(trg_vocab_size, D_MODEL)
trg_embed = trg_embedder(tgt_sample.unsqueeze(0))

print(trg_embed.shape)

torch.Size([1, 14, 512])


# 4.Masking

In [9]:
def make_pad_mask(query, key, pad_idx=1):
    """
    Padding mask.
        query: (n_batch, query_seq_len) token ids
        key: (n_batch, key_seq_len) token ids
        pad_idx: id of the padding token

    Returns a boolean tensor of shape (n_batch, 1, query_seq_len, key_seq_len)
    that is True only where both the query token and the key token are not padding.
    """
    # True at non-padding key positions, laid out along the last axis.
    key_is_token = (key != pad_idx)[:, None, None, :]      # (n_batch, 1, 1, key_seq_len)

    # True at non-padding query positions, laid out along the third axis.
    query_is_token = (query != pad_idx)[:, None, :, None]  # (n_batch, 1, query_seq_len, 1)

    # Broadcasting expands both operands to (n_batch, 1, query_seq_len, key_seq_len).
    mask = key_is_token & query_is_token
    mask.requires_grad = False

    return mask


def make_subsequent_mask(query, key):
    """
    Look-ahead mask.
        query : (batch_size, query_seq_len)
        key : (batch_size, key_seq_len)

    Returns a boolean tensor of shape (query_seq_len, key_seq_len) on
    ``query``'s device. Entry [i, j] is True when j <= i, i.e. the lower
    triangle *including* the diagonal, so a position may see itself and
    everything before it.
    """
    query_seq_len, key_seq_len = query.size(1), key.size(1)

    # Compare index grids directly on the target device instead of building the
    # triangle in NumPy and copying it over.
    row_idx = torch.arange(query_seq_len, device=query.device).unsqueeze(1)  # (query_seq_len, 1)
    col_idx = torch.arange(key_seq_len, device=query.device).unsqueeze(0)    # (1, key_seq_len)
    mask = col_idx <= row_idx

    return mask


def make_src_mask(src):
    """Padding mask for the source attending to itself (query = key = src)."""
    return make_pad_mask(src, src)

def make_tgt_mask(tgt):
    """Target self-attention mask: padding mask AND look-ahead mask."""
    padding = make_pad_mask(tgt, tgt)
    look_ahead = make_subsequent_mask(tgt, tgt)
    # The 2-D look-ahead mask broadcasts against the 4-D padding mask.
    return padding & look_ahead

def make_src_tgt_mask(src, tgt):
    """Padding mask with the target as query and the source as key."""
    return make_pad_mask(tgt, src)

In [12]:
# Padding mask for the single source sample (a batch dimension is added first).
src_mask = make_src_mask(src_sample.unsqueeze(0))
print(f"Source Padding Mask: {src_mask.shape} \n{src_mask}")

Source Padding Mask: torch.Size([1, 1, 15, 15]) 
tensor([[[[True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, 

In [13]:
# Add a batch dimension and drop the last token; presumably this is the
# decoder input for teacher forcing — confirm against the training loop.
trg_x = tgt_sample.unsqueeze(0)[:, :-1]

In [14]:
# Combined padding + look-ahead mask for the shifted target sequence.
tgt_mask = make_tgt_mask(trg_x)
print(f"Target Mask: {tgt_mask.shape} \n{tgt_mask}")

Target Mask: torch.Size([1, 1, 13, 13]) 
tensor([[[[ True, False, False, False, False, False, False, False, False, False,
           False, False, False],
          [ True,  True, False, False, False, False, False, False, False, False,
           False, False, False],
          [ True,  True,  True, False, False, False, False, False, False, False,
           False, False, False],
          [ True,  True,  True,  True, False, False, False, False, False, False,
           False, False, False],
          [ True,  True,  True,  True,  True, False, False, False, False, False,
           False, False, False],
          [ True,  True,  True,  True,  True,  True, False, False, False, False,
           False, False, False],
          [ True,  True,  True,  True,  True,  True,  True, False, False, False,
           False, False, False],
          [ True,  True,  True,  True,  True,  True,  True,  True, False, False,
           False, False, False],
          [ True,  True,  True,  True,  True,  

In [16]:
# Mask with target positions as queries (rows) and source positions as keys (columns).
src_tgt_mask = make_src_tgt_mask(src_sample.unsqueeze(0), trg_x)
print(f"Source-Target Mask: {src_tgt_mask.shape} \n{src_tgt_mask}")

Source-Target Mask: torch.Size([1, 1, 13, 15]) 
tensor([[[[True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True],
          [True, True, True, True, True, True, T