# 15-Attention is all you need

[1.Define Vocab & Dataset](#1)
 - [1-1.Build Vocab(Tokenizer)](#1-1)
 - [1-2.Build Torch Dataset & DataLoader](#1-2)

[2.Word Embedding, Positional Encoding](#2)

[3.Transformer Embedding](#3)

[4.Multi Head Self-Attention](#4)

[5.Feed Forward Layer](#5)

[6.LayerNorm & Residual Connection](#6)

[7.Encoder](#7)

[8.Decoder](#8)

[9.Transformer](#9)

In [1]:
import os
import math
import copy
import torch
import numpy as np

from tqdm import tqdm

from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

from torchtext import transforms
from torchtext.datasets import Multi30k
from torchtext.data.metrics import bleu_score
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from utils.data_utils import load_pickle, download_multi30k, make_cache

In [2]:
DATA_DIR = "/home/pervinco/Datasets"
SAVE_DIR = "./runs/Transformer"

EPOCHS = 100
BATCH_SIZE = 32
LEARNING_RATE = 1e-5
WEIGHT_DECAY = 5e-9
ADAM_EPS = 5e-9
SCHEDULER_FACTOR = 0.9
SCHEDULER_PATIENCE = 10
WARM_UP_STEP = 100

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_WORKERS = min([os.cpu_count(), BATCH_SIZE if BATCH_SIZE > 1 else 0, 8])

D_MODEL = 512
NUM_HEADS = 8
NUM_LAYERS = 6
FFN_DIM = 2048
MAX_SEQ_LEN = 256
DROP_PROB = 0.1

SRC_LANGUAGE = 'de'
TGT_LANGUAGE = 'en'
UNK_IDX, PAD_IDX, SOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<sos>', '<eos>']

<a id="1"></a>
## 1.Define Vocab & Dataset

<a id="1-1"></a>
### 1-1.Build Vocab(Tokenizer)

In [3]:
token_transform = {}
vocab_transform = {}
token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='de_core_news_sm')
token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')

def yield_tokens(data_iter, language):
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}

    for data_sample in data_iter:
        yield token_transform[language](data_sample[language_index[language]])

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(train_iter, ln),
                                                    min_freq=1,
                                                    specials=special_symbols,
                                                    special_first=True)

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    vocab_transform[ln].set_default_index(UNK_IDX)

<a id="1-2"></a>
### 1-2.Build Torch Dataset & DataLoader

In [4]:
class Multi30kDataset:
    UNK, UNK_IDX = "<unk>", 0
    PAD, PAD_IDX = "<pad>", 1
    SOS, SOS_IDX = "<sos>", 2
    EOS, EOS_IDX = "<eos>", 3
    SPECIALS = {UNK : UNK_IDX, PAD : PAD_IDX, SOS : SOS_IDX, EOS : EOS_IDX}

    URL = "https://github.com/multi30k/dataset/raw/master/data/task1/raw"
    FILES = ["test_2016_flickr.de.gz",
             "test_2016_flickr.en.gz",
             "train.de.gz",
             "train.en.gz",
             "val.de.gz",
             "val.en.gz"]
    

    def __init__(self, data_dir, source_language="en", target_language="de", max_seq_len=256, vocab_min_freq=2):
        self.data_dir = data_dir

        self.max_seq_len = max_seq_len
        self.vocab_min_freq = vocab_min_freq
        self.source_language = source_language
        self.target_language = target_language

        ## 데이터 파일 로드.
        self.train = load_pickle(f"{data_dir}/cache/train.pkl")
        self.valid = load_pickle(f"{data_dir}/cache/val.pkl")
        self.test = load_pickle(f"{data_dir}/cache/test.pkl")

        ## tokenizer 정의.
        if self.source_language == "en":
            self.source_tokenizer = get_tokenizer("spacy", "en_core_web_sm")
            self.target_tokenizer = get_tokenizer("spacy", "de_core_news_sm")
        else:
            self.source_tokenizer = get_tokenizer("spacy", "de_core_news_sm")
            self.target_tokenizer = get_tokenizer("spacy", "en_core_web_sm")

        self.src_vocab, self.trg_vocab = self.get_vocab(self.train)
        self.src_transform = self.get_transform(self.src_vocab)
        self.trg_transform = self.get_transform(self.trg_vocab)


    def yield_tokens(self, train_dataset, is_src):
        for text_pair in train_dataset:
            if is_src:
                yield [str(token) for token in self.source_tokenizer(text_pair[0])]
            else:
                yield [str(token) for token in self.target_tokenizer(text_pair[1])]


    def get_vocab(self, train_dataset):
        src_vocab_pickle = f"{self.data_dir}/cache/vocab_{self.source_language}.pkl"
        trg_vocab_pickle = f"{self.data_dir}/cache/vocab_{self.target_language}.pkl"

        if os.path.exists(src_vocab_pickle) and os.path.exists(trg_vocab_pickle):
            src_vocab = load_pickle(src_vocab_pickle)
            trg_vocab = load_pickle(trg_vocab_pickle)
        else:
            src_vocab = build_vocab_from_iterator(self.yield_tokens(train_dataset, True), min_freq=self.vocab_min_freq, specials=self.SPECIALS.keys())
            src_vocab.set_default_index(self.UNK_IDX)

            trg_vocab = build_vocab_from_iterator(self.yield_tokens(train_dataset, False), min_freq=self.vocab_min_freq, specials=self.SPECIALS.keys())
            trg_vocab.set_default_index(self.UNK_IDX)
            
        return src_vocab, trg_vocab
    

    def get_transform(self, vocab):
        return transforms.Sequential(transforms.VocabTransform(vocab),
                                     transforms.Truncate(self.max_seq_len-2),
                                     transforms.AddToken(token=self.SOS_IDX, begin=True),
                                     transforms.AddToken(token=self.EOS_IDX, begin=False),
                                     transforms.ToTensor(padding_value=self.PAD_IDX))


    def collate_fn(self, pairs):
        src = [self.source_tokenizer(pair[0]) for pair in pairs]
        trg = [self.target_tokenizer(pair[1]) for pair in pairs]
        batch_src = self.src_transform(src)
        batch_trg = self.trg_transform(trg)

        return (batch_src, batch_trg)
    

    def get_iter(self, batch_size, num_workers):
        train_iter = DataLoader(self.train, collate_fn=self.collate_fn, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        valid_iter = DataLoader(self.valid, collate_fn=self.collate_fn, batch_size=batch_size, num_workers=num_workers)
        test_iter = DataLoader(self.test, collate_fn=self.collate_fn, batch_size=batch_size, num_workers=num_workers)

        return train_iter, valid_iter, test_iter
    
    
    def translate(self, model, src_sentence: str, decode_func):
        model.eval()
        src = self.src_transform([self.source_tokenizer(src_sentence)]).view(1, -1)
        num_tokens = src.shape[1]
        trg_tokens = decode_func(model, src, max_len=num_tokens + 5, start_symbol=self.SOS_IDX, end_symbol=self.EOS_IDX).flatten().cpu().numpy()
        trg_sentence = " ".join(self.trg_vocab.lookup_tokens(trg_tokens))

        return trg_sentence


In [5]:
download_multi30k(DATA_DIR)
make_cache(f"{DATA_DIR}/Multi30k")

DATASET = Multi30kDataset(data_dir=f"{DATA_DIR}/Multi30k", source_language=SRC_LANGUAGE,  target_language=TGT_LANGUAGE,  max_seq_len=MAX_SEQ_LEN, vocab_min_freq=2)
train_iter, valid_iter, test_iter = DATASET.get_iter(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

/home/pervinco/Datasets/Multi30k is already exist.
/home/pervinco/Datasets/Multi30k/cache is already exist.


In [6]:
sample_src, sample_tgt = None, None
for src, trg in train_iter:
    print(src.shape)
    print(trg.shape)
    
    for s in src.numpy():
        print(s)

    sample_src = src
    sample_tgt = trg
    break

torch.Size([32, 24])
torch.Size([32, 23])
[   2    6   33    7    4   26   11   23 1567   11   96 5741   45    4
  425   14  515    9 1132  161   13   44    5    3]
[   2   48  390  128   10 1010  119   70    8   74 1372    5    3    1
    1    1    1    1    1    1    1    1    1    1]
[   2   19   53   17   55   70    4  425 1887    5    3    1    1    1
    1    1    1    1    1    1    1    1    1    1]
[   2    6   12   38 1427   15   27  861 2625   18    8  485    5    3
    1    1    1    1    1    1    1    1    1    1]
[  2   6  23  34 206 121   8  90   3   1   1   1   1   1   1   1   1   1
   1   1   1   1   1   1]
[   2    6   60   35   10  171    4 5780 1180  102   56    4   60   26
 1917    5    3    1    1    1    1    1    1    1]
[   2    6   16   21    4   96 4101  238    4  157   41    4 3243  362
    5    3    1    1    1    1    1    1    1    1]
[   2   19   37   17  294   28 5033    9    4  940    5    3    1    1
    1    1    1    1    1    1    1    1    1    1

<a id="2"></a>
## 2.Word Embedding, Positional Encoding

In [7]:
## Word Embedding
class WordEmbedding(nn.Module):
    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model) ## 임베딩을 학습 과정에서 조정하는 임베딩 층.
        self.d_model = d_model

    def forward(self, x):
        ## 임베딩 벡터에서 값이 너무 큰 원소를 억제하기 위해 math.sqrt(self.d_model)를 곱한다.
        out = self.embedding(x) * math.sqrt(self.d_model)

        return out

In [8]:
word_embedding_layer = WordEmbedding(D_MODEL, len(DATASET.src_vocab))

embedded_sample_src = word_embedding_layer(sample_src)
print(embedded_sample_src.shape)

torch.Size([32, 24, 512])


In [9]:
## Positional Encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len=256, device=torch.device("cpu")):
        super().__init__()

        ## positional encoding 결과를 담을 텐서. -> [max_seq_len, d_model]
        encoding = torch.zeros(max_seq_len, d_model, requires_grad=False)

        ## 0부터 max_seq_len-1까지의 연속된 정수값을 포함하는 1차원 텐서. -> [max_seq_len, 1]
        position = torch.arange(0, max_seq_len).float().unsqueeze(1)
        
        ## 10000**(2i/d_model) == e**(log(10000) * 2i / d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) 

        ## i가 홀수일 때는 cos함수, j가 짝수일 때는 sin함수.
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)

        ## [1, max_seq_len, d_model]
        self.encoding = encoding.unsqueeze(0).to(device) 

    def forward(self, x):
        ## x : Word Embedded Vector
        _, seq_len, _ = x.size() ## [batch_size, max_seq_len, d_model]
        pos_embed = self.encoding[:, :seq_len, :] ## slicing [1, seq_len, d_model]

        out = x + pos_embed ## word embedded vector에 positional encoding값을 더함.

        return out

In [10]:
positional_encoding_layer = PositionalEncoding(D_MODEL, MAX_SEQ_LEN)

pe_embedded = positional_encoding_layer(embedded_sample_src)
print(pe_embedded.shape)

torch.Size([32, 24, 512])


<a id="3"></a>
## 3.Transformer Embedding

In [11]:
class TransformerEmbedding(nn.Module):
    def __init__(self, word_embedding_layer, positional_encoding_layer, drop_prob=0):
        super().__init__()
        self.embedding = nn.Sequential(word_embedding_layer, positional_encoding_layer)
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        out = x
        out = self.embedding(out)
        out = self.dropout(out)

        return out

In [12]:
transformer_embedding_layer = TransformerEmbedding(word_embedding_layer, positional_encoding_layer)
embedded_input_matrix = transformer_embedding_layer(sample_src)

print(embedded_input_matrix.shape)

torch.Size([32, 24, 512])


<a id="4"></a>

## 4.Multi Head Self-Attention

In [13]:
class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, d_model, num_heads, drop_prob=0):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads

        ## Embedded vector를 Q, K, V 행렬로 변환하기 위한 Layers
        self.q_fc = nn.Linear(d_model, d_model) # (d_model, d_model)
        self.k_fc = nn.Linear(d_model, d_model) # (d_model, d_model)
        self.v_fc = nn.Linear(d_model, d_model) # (d_model, d_model)

        ## output layer
        self.out_fc = nn.Linear(d_model, d_model) # (d_model, d_model)

        self.dropout = nn.Dropout(p=drop_prob)


    def calculate_attention(self, query, key, value, mask):
        ## Q, K, V: (batch_size, num_heads, seq_len, d_k)
        ## Padding Mask : (batch_size, seq_len, seq_len)
        d_k = key.shape[-1]
        attention_score = torch.matmul(query, key.transpose(-2, -1)) # Attention Score. Q x K^T, (n_batch, num_heads, seq_len, seq_len)
        attention_score = attention_score / math.sqrt(d_k) ## Scaling

        if mask is not None:
            attention_score = attention_score.masked_fill(mask==0, -1e9)
            
        attention_prob = F.softmax(attention_score, dim=-1) ## Attention Distribution. (n_batch, num_heads, seq_len, seq_len)
        attention_prob = self.dropout(attention_prob)
        out = torch.matmul(attention_prob, value) ## Attention Value. (n_batch, num_heads, seq_len, d_k)

        return out


    def forward(self, query, key, value, mask=None):
        ## (query, key, value)는 동일한 input tensor : (n_batch, seq_len, d_model)
        ## transform(x, fc)가 Q, K, V로 변환.
        ## mask: (n_batch, seq_len, seq_len)
        # return value: (n_batch, num_heads, seq_len, d_k)
        n_batch = query.size(0)

        def transform(x, fc): # (n_batch, seq_len, d_model)
            out = fc(x)       # (n_batch, seq_len, d_model)

            # (n_batch, seq_len, num_heads, d_k)
            out = out.view(n_batch, -1, self.num_heads, self.d_model // self.num_heads)
            out = out.transpose(1, 2) # (n_batch, num_heads, seq_len, d_k)

            return out

        query = transform(query, self.q_fc) # (n_batch, num_heads, seq_len, d_k)
        key = transform(key, self.k_fc)       # (n_batch, num_heads, seq_len, d_k)
        value = transform(value, self.v_fc) # (n_batch, num_heads, seq_len, d_k)

        out = self.calculate_attention(query, key, value, mask) # (n_batch, num_heads, seq_len, d_k)
        out = out.transpose(1, 2) # (n_batch, seq_len, num_heads, d_k)
        out = out.contiguous().view(n_batch, -1, self.d_model) # (n_batch, seq_len, d_model)
        out = self.out_fc(out) # (n_batch, seq_len, d_model)

        return out


<a id="5"></a>
## 5.Feed Forward Layer

In [14]:
class FeedForwardLayer(nn.Module):
    def __init__(self, d_embed, d_ff, drop_prob=0):
        super().__init__()
        self.fc1 = nn.Linear(d_embed, d_ff)   # (d_embed, d_ff)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)
        self.fc2 = nn.Linear(d_ff, d_embed) # (d_ff, d_embed)

    def forward(self, x):
        out = x
        out = self.fc1(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)

        return out

<a id="6"></a>
## 6.LayerNorm & Residual Connection

In [15]:
encoder_norm = nn.LayerNorm(D_MODEL, eps=1e-5)
decoder_norm = nn.LayerNorm(D_MODEL, eps=1e-5)

In [16]:
class ResidualConnectionLayer(nn.Module):
    def __init__(self, norm, drop_prob=0):
        super().__init__()
        self.norm = norm
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x, sub_layer): ## 입력 텐서와 sublayer를 입력 받는 것에 주목.
        out = x
        out = self.norm(out)
        out = sub_layer(out)
        out = self.dropout(out)
        out = out + x

        # out = x
        # out = sub_layer(out)
        # out = self.dropout(out)
        # out = self.norm(x + out)
        
        return out

<a id="7"></a>
## 7.Encoder

In [17]:
class EncoderBlock(nn.Module):
    def __init__(self, self_attention, ffn, norm, drop_prob=0):
        """
        self_attention : Self-Attention Layer 객체
        ffn : FeedForward Layer 객체
        norm : LayerNorm 겍체
        """
        super().__init__()
        self.self_attention = self_attention
        ## 여기서 copy.deepcopy()는 하나의 layer 객체를 서로 다른 객체로 복사하기 위함.
        self.residual1 = ResidualConnectionLayer(copy.deepcopy(norm), drop_prob)

        self.ffn = ffn
        self.residual2 = ResidualConnectionLayer(copy.deepcopy(norm), drop_prob)

    def forward(self, src, src_mask):
        ## src : [batch_size, max_seq_len, d_model]
        out = src

        ## Self-Attention이므로 query, key, value가 모두 out으로 동일.
        ## 이 때, lambda가 적용되는 이유는 순환의 목적이 아니라 reisdual 객체 내부에서 필요한 시점에 호출하기 위함.
        out = self.residual1(out, lambda out : self.self_attention(query=out, key=out, value=out, mask=src_mask))
        out = self.residual2(out, self.ffn)

        return out

In [18]:
class Encoder(nn.Module):
    def __init__(self, encoder_block, num_layer, norm):
        super().__init__()
        self.num_layer = num_layer
        self.layers = nn.ModuleList([copy.deepcopy(encoder_block) for _ in range(self.num_layer)])
        self.norm = norm

    def forward(self, src, src_mask):
        out = src
        for layer in self.layers:
            out = layer(out, src_mask)
            
        out = self.norm(out)

        return out


<a id="8"></a>
## 8.Decoder

In [19]:
class DecoderBlock(nn.Module):
    def __init__(self, self_attention, cross_attention, ffn, norm, drop_prob=0):
        super().__init__()
        """
        self_attention : Masked Multi Head Self Attention 객체
        cross_attention : Encoder - Decoder Multi Head Self Attention 객체
        ffn : FeedForward Layer 객체
        norm : LayerNorm 객체
        """
        self.self_attention = self_attention
        self.residual1 = ResidualConnectionLayer(copy.deepcopy(norm), drop_prob)
        self.cross_attention = cross_attention
        self.residual2 = ResidualConnectionLayer(copy.deepcopy(norm), drop_prob)
        self.ffn = ffn
        self.residual3 = ResidualConnectionLayer(copy.deepcopy(norm), drop_prob)

    def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
        out = tgt
        out = self.residual1(out, lambda out: self.self_attention(query=out, key=out, value=out, mask=tgt_mask))
        out = self.residual2(out, lambda out: self.cross_attention(query=out, key=encoder_out, value=encoder_out, mask=src_tgt_mask))
        out = self.residual3(out, self.ffn)

        return out

In [20]:
class Decoder(nn.Module):
    def __init__(self, decoder_block, num_layer, norm):
        super().__init__()
        self.num_layer = num_layer
        self.layers = nn.ModuleList([copy.deepcopy(decoder_block) for _ in range(self.num_layer)])
        self.norm = norm

    def forward(self, trg, encoder_out, trg_mask, src_trg_mask):
        out = trg
        for layer in self.layers:
            out = layer(out, encoder_out, trg_mask, src_trg_mask)
        out = self.norm(out)

        return out

<a id="9"></a>
## 9.Transformer

In [21]:
class Transformer(nn.Module):
    def __init__(self, src_embed, trg_embed, encoder, decoder, generator):
        super().__init__()
        self.src_embed = src_embed
        self.trg_embed = trg_embed
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator


    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)


    def decode(self, trg, encoder_out, trg_mask, src_trg_mask):
        return self.decoder(self.trg_embed(trg), encoder_out, trg_mask, src_trg_mask)


    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        src_trg_mask = self.make_src_trg_mask(src, trg)
        encoder_out = self.encode(src, src_mask)
        decoder_out = self.decode(trg, encoder_out, trg_mask, src_trg_mask)
        out = self.generator(decoder_out)
        out = F.log_softmax(out, dim=-1)

        return out, decoder_out


    def make_src_mask(self, src):
        pad_mask = self.make_pad_mask(src, src)
        
        return pad_mask


    def make_trg_mask(self, trg):
        pad_mask = self.make_pad_mask(trg, trg)
        seq_mask = self.make_subsequent_mask(trg, trg)

        return pad_mask & seq_mask


    def make_src_trg_mask(self, src, trg):
        pad_mask = self.make_pad_mask(trg, src)

        return pad_mask


    def make_pad_mask(self, query, key, pad_idx=1):
        """
        Padding Mask
            query: (n_batch, query_seq_len)
            key: (n_batch, key_seq_len)
        """

        query_seq_len, key_seq_len = query.size(1), key.size(1) ## query_seq_len, key_seq_len

        ## ne : pad_idx가 아닌 원소들을 찾아 True/False인 텐서를 만든다.
        key_mask = key.ne(pad_idx).unsqueeze(1).unsqueeze(2)  # (n_batch, 1, 1, key_seq_len)

        ## 세번째 차원을 쿼리 시퀀스 길이에 맞춰 반복해서 쌓는다.
        key_mask = key_mask.repeat(1, 1, query_seq_len, 1)    # (n_batch, 1, query_seq_len, key_seq_len)

        ## ne : pad_idx가 아닌 원소들을 찾아 True/False인 텐서를 만든다.
        query_mask = query.ne(pad_idx).unsqueeze(1).unsqueeze(3)  # (n_batch, 1, query_seq_len, 1)
        query_mask = query_mask.repeat(1, 1, 1, key_seq_len)  # (n_batch, 1, query_seq_len, key_seq_len)

        mask = key_mask & query_mask ## 두 행렬에서 True인 원소만 True로.
        mask.requires_grad = False

        return mask


    def make_subsequent_mask(self, query, key):
        """
        Look-Ahead Mask
            query : (batch_size, query_seq_len)
            key : (batch_size, key_seq_len)
        """
        query_seq_len, key_seq_len = query.size(1), key.size(1)

        ## shape이 query_seq_len, key_seq_len인 lower-triangular matrix를 만든다. k는 주대각원소의 위쪽에 있는 원소값.
        tril = np.tril(np.ones((query_seq_len, key_seq_len)), k=0).astype('uint8') # lower triangle without diagonal
        mask = torch.tensor(tril, dtype=torch.bool, requires_grad=False, device=query.device) ## boolean type의 텐서로 변환.

        return mask


In [22]:
def build_model(src_vocab_size,
                trg_vocab_size, 
                max_seq_len=256,
                d_model=512, 
                num_layer=6, 
                num_heads=8, 
                d_ff=2048, 
                norm_eps=1e-5,
                drop_prob=0.1, 
                device=torch.device("cpu")):

    ## Word Embedding.
    src_token_embed = WordEmbedding(d_model=d_model, vocab_size=src_vocab_size)
    trg_token_embed = WordEmbedding(d_model=d_model, vocab_size=trg_vocab_size)

    ## Positional Encoding.
    src_pos_embedd = PositionalEncoding(d_model=d_model, max_seq_len=max_seq_len, device=device)
    trg_pos_embedd = PositionalEncoding(d_model=d_model, max_seq_len=max_seq_len, device=device)

    ## Word Embedding + Positional Encoding.
    trg_embed = TransformerEmbedding(word_embedding_layer=trg_token_embed, positional_encoding_layer=trg_pos_embedd, drop_prob=drop_prob)
    src_embed = TransformerEmbedding(word_embedding_layer=src_token_embed, positional_encoding_layer=src_pos_embedd, drop_prob=drop_prob)

    ## Multi-Head Self Attention.
    encoder_attention = MultiHeadAttentionLayer(d_model=d_model, num_heads=num_heads, drop_prob=drop_prob)
    decoder_attention = MultiHeadAttentionLayer(d_model=d_model, num_heads=num_heads, drop_prob=drop_prob)
    
    ## Position-Wise FeedForward.
    encoder_position_ff = FeedForwardLayer(d_model, d_ff, drop_prob=drop_prob)
    decoder_position_ff = FeedForwardLayer(d_model, d_ff, drop_prob=drop_prob)

    ## Add & Norm.
    encoder_norm = nn.LayerNorm(d_model, eps=norm_eps)
    decoder_norm = nn.LayerNorm(d_model, eps=norm_eps)

    ## Encoder Block.
    encoder_block = EncoderBlock(self_attention=encoder_attention,
                                 ffn=encoder_position_ff,
                                 norm=encoder_norm,
                                 drop_prob=drop_prob)
    
    ## Decoder Block
    decoder_block = DecoderBlock(self_attention=decoder_attention,
                                 cross_attention=decoder_attention,
                                 ffn=decoder_position_ff,
                                 norm=decoder_norm,
                                 drop_prob=drop_prob)

    ## Encoder(Encoder Block * Num_layers)
    encoder = Encoder(encoder_block=encoder_block, num_layer=num_layer, norm=encoder_norm)

    ## Decoder(Decoder Block * Num_layers)
    decoder = Decoder(decoder_block=decoder_block, num_layer=num_layer, norm=decoder_norm)

    ## Output Layer.
    generator = nn.Linear(d_model, trg_vocab_size)

    model = Transformer(src_embed=src_embed,
                        trg_embed=trg_embed,
                        encoder=encoder,
                        decoder=decoder,
                        generator=generator).to(device)
    
    model.device = device

    return model

In [23]:
def initialize_weights(model):
    if hasattr(model, 'weight') and model.weight.dim() > 1:
        nn.init.kaiming_uniform_(model.weight.data)

In [24]:
model = build_model(len(DATASET.src_vocab), len(DATASET.trg_vocab), device=DEVICE, drop_prob=DROP_PROB)
model.apply(initialize_weights)

optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, eps=ADAM_EPS)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, verbose=True, factor=SCHEDULER_FACTOR, patience=SCHEDULER_PATIENCE)
criterion = nn.CrossEntropyLoss(ignore_index=DATASET.PAD_IDX)

In [25]:
def train(model, data_loader, optimizer, criterion):
    model.train()
    epoch_loss = 0
    for (src, trg) in tqdm(data_loader, desc="train", leave=False):
        src = src.to(DEVICE)
        trg = trg.to(DEVICE)
        trg_x = trg[:, :-1]
        trg_y = trg[:, 1:]

        optimizer.zero_grad()

        output, _ = model(src, trg_x)

        y_hat = output.contiguous().view(-1, output.shape[-1])
        y_gt = trg_y.contiguous().view(-1)
        loss = criterion(y_hat, y_gt)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(list(data_loader))

In [26]:
def get_bleu_score(output, gt, vocab, specials, max_n=4):
    def itos(x):
        x = list(x.cpu().numpy())
        tokens = vocab.lookup_tokens(x)
        tokens = list(filter(lambda x: x not in {"", " ", "."} and x not in list(specials.keys()), tokens))
        return tokens

    pred = [out.max(dim=1)[1] for out in output]
    pred_str = list(map(itos, pred))
    gt_str = list(map(lambda x: [itos(x)], gt))

    return  bleu_score(pred_str, gt_str, max_n=max_n) * 100.0


def greedy_decode(model, src, max_len, start_symbol, end_symbol):
    src = src.to(model.device)
    src_mask = model.make_src_mask(src).to(model.device)
    memory = model.encode(src, src_mask)

    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(model.device)
    for i in range(max_len-1):
        memory = memory.to(model.device)
        trg_mask = model.make_trg_mask(ys).to(model.device)
        src_trg_mask = model.make_src_trg_mask(src, ys).to(model.device)
        out = model.decode(ys, memory, trg_mask, src_trg_mask)
        prob = model.generator(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
        if next_word == end_symbol:
            break
        
    return ys


def evaluate(model, data_loader, criterion):
    model.eval()
    epoch_loss = 0

    total_bleu = []
    with torch.no_grad():
        for (src, trg) in tqdm(data_loader, desc="eval", leave=False):
            src = src.to(DEVICE)
            trg = trg.to(DEVICE)
            trg_x = trg[:, :-1]
            trg_y = trg[:, 1:]

            output, _ = model(src, trg_x)

            y_hat = output.contiguous().view(-1, output.shape[-1])
            y_gt = trg_y.contiguous().view(-1)
            loss = criterion(y_hat, y_gt)

            epoch_loss += loss.item()
            score = get_bleu_score(output, trg_y, DATASET.trg_vocab, DATASET.SPECIALS)
            total_bleu.append(score)

    loss_avr = epoch_loss / len(list(data_loader))
    bleu_score = sum(total_bleu) / len(total_bleu)

    return loss_avr, bleu_score

In [27]:
print("\nTrain Start")
if not os.path.isdir(SAVE_DIR):
    os.makedirs(SAVE_DIR, exist_ok=True)

min_val_loss = 0
for epoch in range(EPOCHS):
    train_loss = train(model, train_iter, optimizer, criterion)
    valid_loss, bleu_scores  = evaluate(model, valid_iter, criterion)

    if epoch == 0:
        min_val_loss = valid_loss

    if epoch > 1:
        if valid_loss < min_val_loss:
            min_val_loss = valid_loss
            ckpt = f"{SAVE_DIR}/{epoch:04}.pt"
            torch.save({'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'train_loss': train_loss,
                        'val_loss' : valid_loss}, ckpt)

    if epoch > WARM_UP_STEP:
        scheduler.step(valid_loss)

    print(f"Epoch : {epoch + 1} | train_loss: {train_loss:.5f} valid_loss: {valid_loss:.5f}, bleu_scores: {bleu_scores:.5f}")
    print("Predict : ", DATASET.translate(model, "A little girl climbing into a wooden playhouse .", greedy_decode))
    print(f"Answer : Ein kleines Mädchen klettert in ein Spielhaus aus Holz . \n")

test_loss, bleu_scores = evaluate(model, test_iter, criterion)
print(f"test_loss: {test_loss:.5f}")
print(f"bleu_scores: {bleu_scores:.5f}")


Train Start


                                                        

Epoch : 1 | train_loss: 5.42834 valid_loss: 4.53360, bleu_scores: 4.31376
Predict :  <sos> Ein Mädchen in einem <unk> . <eos>
Answer : Ein kleines Mädchen klettert in ein Spielhaus aus Holz . 



                                                        

Epoch : 2 | train_loss: 4.46157 valid_loss: 4.10693, bleu_scores: 6.67109
Predict :  <sos> Ein Mädchen mit einem Mädchen . <eos>
Answer : Ein kleines Mädchen klettert in ein Spielhaus aus Holz . 



train:  61%|██████    | 549/907 [00:13<00:08, 44.60it/s]