### Prepare the data and the preprocessing pipeline

In [None]:
# Set tokenizer and build vocabulary
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# spaCy-backed tokenizers, one per language.
# NOTE(review): presumably both spaCy models named below must be installed
# locally before these calls succeed -- confirm for your environment.
en_tokenizer = get_tokenizer(
    'spacy',
    language='en_core_web_md',
)
de_tokenizer = get_tokenizer(
    'spacy',
    language='de_core_news_md',
)

def yield_tokens(data_iter, language: str):
    """Yield one token sequence per sentence pair for the requested language.

    Args:
        data_iter: Iterable of sentence pairs. Element 0 of each pair is
            tokenized when ``language`` is ``'en'`` and element 1 when it is
            ``'de'``.
        language: Either ``'en'`` or ``'de'``.

    Yields:
        Whatever the matching tokenizer returns for each selected sentence.

    Raises:
        ValueError: If ``language`` is neither ``'en'`` nor ``'de'``. Because
            this is a generator, the error surfaces on first iteration. An
            unsupported value used to yield nothing at all, silently.
    """
    if language not in ('en', 'de'):
        raise ValueError(
            f"Unsupported language {language!r}; expected 'en' or 'de'."
        )

    # Resolve the tokenizer and the tuple position once rather than per item.
    if language == 'en':
        tokenizer, position = en_tokenizer, 0
    else:
        tokenizer, position = de_tokenizer, 1

    for text in data_iter:
        yield tokenizer(text[position])

# Reserved token ids. `special_tokens` is listed in the same order, and the
# specials are placed first (special_first=True) when each vocabulary is built.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>']

# Each vocabulary gets its own fresh pass over the training split; tokens seen
# fewer than `min_freq` times are left out.
en_vocab = build_vocab_from_iterator(
    yield_tokens(Multi30k(split='train', language_pair=('en', 'de')), 'en'),
    specials=special_tokens,
    special_first=True,
    min_freq=3,
)
de_vocab = build_vocab_from_iterator(
    yield_tokens(Multi30k(split='train', language_pair=('en', 'de')), 'de'),
    specials=special_tokens,
    special_first=True,
    min_freq=3,
)

# Index returned for out-of-vocabulary tokens.
en_vocab.set_default_index(UNK_IDX)
de_vocab.set_default_index(UNK_IDX)

In [2]:
# Preprocessing pipelines: raw sentence -> tokens -> vocabulary ids.
def en_pipeline(x):
    """Tokenize an English sentence and look its tokens up in `en_vocab`."""
    tokens = en_tokenizer(x)
    return en_vocab(tokens)


def de_pipeline(x):
    """Tokenize a German sentence and look its tokens up in `de_vocab`."""
    tokens = de_tokenizer(x)
    return de_vocab(tokens)


# The (misspelled) name is kept as-is because other cells look it up by this name.
trasnform_pipeline = {'en': en_pipeline, 'de': de_pipeline}

# Quick sanity check of both pipelines.
print(en_pipeline('Several men in hard hats are operating a giant pulley system.'))
print(de_pipeline('Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem.'))

[166, 37, 8, 336, 288, 18, 1225, 4, 759, 4497, 2958, 6]
[85, 32, 11, 848, 2209, 16, 0, 5]


In [3]:
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.backward_compatibility import worker_init_fn


def collate_func(batch, src_ln: str = 'de', tgt_ln: str = 'en', batch_first: bool = True):
    """Convert a batch of raw sentence pairs into two padded id tensors.

    Each sentence has its trailing newlines stripped, is run through the
    matching entry of ``trasnform_pipeline``, and is wrapped in
    ``BOS_IDX`` / ``EOS_IDX``. Sequences are then padded with ``PAD_IDX`` to
    the longest sequence on their side of the batch.

    Args:
        batch: Iterable of ``(src, tgt)`` string pairs.
        src_ln: Key of ``trasnform_pipeline`` used for the source sentences.
        tgt_ln: Key of ``trasnform_pipeline`` used for the target sentences.
        batch_first: Forwarded to ``pad_sequence``; when True the batch
            dimension comes first.

    Returns:
        ``(src_batch, tgt_batch)``, both ``torch.long`` tensors.
    """
    def encode(sentence, language):
        ids = trasnform_pipeline[language](sentence.rstrip('\n'))
        # Explicit dtype: torch.tensor([]) would otherwise infer a float dtype
        # for a sentence that produces no tokens, breaking the id tensor.
        return torch.tensor([BOS_IDX, *ids, EOS_IDX], dtype=torch.long)

    src_batch, tgt_batch = [], []
    for src, tgt in batch:
        src_batch.append(encode(src, src_ln))
        tgt_batch.append(encode(tgt, tgt_ln))

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=batch_first)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX, batch_first=batch_first)

    return src_batch, tgt_batch



# Load all three Multi30k splits as (de, en) pairs, matching collate_func's
# default src_ln='de' / tgt_ln='en'.
train_iter, valid_iter, test_iter = Multi30k(
    root='./data',
    split=('train', 'valid', 'test'),
    language_pair=('de', 'en'),
)

# Each split is materialized into a list before being handed to DataLoader.
# Training uses shuffled batches of 128; validation and test go one pair at a time.
train_loader = DataLoader(
    list(train_iter),
    batch_size=128,
    collate_fn=collate_func,
    num_workers=8,
    shuffle=True,
)
valid_loader = DataLoader(
    list(valid_iter),
    batch_size=1,
    collate_fn=collate_func,
)
test_loader = DataLoader(
    list(test_iter),
    batch_size=1,
    collate_fn=collate_func,
)

In [4]:
import math
import torch.nn as nn
from torch import Tensor


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to batch-first embeddings.

    The table is computed once in the constructor and stored as the buffer
    ``pos_encoding`` with shape ``[1, max_len, emb_size]``: even feature
    columns hold sines, odd feature columns hold cosines.

    Args:
        emb_size: Size of the embedding (last) dimension. Odd sizes are
            supported; the final column is then a sine column.
        dropout: Dropout probability applied after the encoding is added.
        max_len: Number of positions precomputed in the table.
        device: Device the table is moved to (``None`` leaves it where it was
            created).
    """

    def __init__(self, emb_size: int = 512, dropout: float = 0.1, max_len: int = 5000, device=None) -> None:
        super().__init__()
        pos = torch.arange(0, max_len).reshape(max_len, 1)
        val = torch.exp(-torch.arange(0, emb_size, 2) / emb_size * math.log(10000))
        angles = pos * val    # [max_len, ceil(emb_size / 2)]

        pos_encoding = torch.zeros((max_len, emb_size))
        pos_encoding[:, 0::2] = torch.sin(angles)
        # With an odd emb_size there is one fewer odd column than even columns,
        # so only the first emb_size // 2 cosine columns are used.
        pos_encoding[:, 1::2] = torch.cos(angles[:, :emb_size // 2])
        pos_encoding = pos_encoding.unsqueeze(0).to(device)    # batch first

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_encoding', pos_encoding)

    def forward(self, token_embedding: Tensor) -> Tensor:
        """Add the encodings for the first ``token_embedding.size(1)`` positions, then apply dropout."""
        seq_len = token_embedding.size(1)
        return self.dropout(token_embedding + self.pos_encoding[:, :seq_len, :])

class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is multiplied by ``sqrt(emb_size)``.

    Args:
        vocab_size: Number of rows in the embedding table.
        emb_size: Width of each embedding vector.
        device: Device passed to ``nn.Embedding``.
    """

    def __init__(self, vocab_size: int, emb_size: int = 512, device=None) -> None:
        super().__init__()
        self.emb_size = emb_size
        self.embedding = nn.Embedding(vocab_size, emb_size, device=device)

    def forward(self, tokens: Tensor) -> Tensor:
        """Look up `tokens` (cast to long first) and scale the vectors."""
        scale = math.sqrt(self.emb_size)
        vectors = self.embedding(tokens.long())
        return vectors * scale

In [5]:
from typing import Optional
from torch.nn.init import xavier_uniform_

from transformer import Transformer


class TransformerWrapper(nn.Module):
    """Sequence-to-sequence model built around the project's `Transformer`.

    Wraps the transformer with source/target token embeddings, a shared
    positional encoding, and a linear layer projecting decoder output to
    target-vocabulary logits.

    Args:
        emb_size: Model width, used for embeddings and as ``d_model``.
        src_vocab_size: Number of source-side token ids.
        tgt_vocab_size: Number of target-side token ids (logit width).
        device: Device handed to every submodule.
        dtype: Accepted but not used anywhere in this class.
    """

    def __init__(self, emb_size: int, src_vocab_size: int, tgt_vocab_size: int, device=None, dtype=None):
        super().__init__()
        # Submodules are created in a fixed order; keep it stable so parameter
        # ordering (and random initialization order) does not change.
        self.transformer = Transformer(
            d_model=emb_size,
            num_encoder_layers=3,
            num_decoder_layers=3,
            n_head=8,
            dim_feedforward=1024,
            batch_first=True,
            device=device,
        )
        self.generator = nn.Linear(emb_size, tgt_vocab_size, device=device)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size, device=device)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size, device=device)
        self.pos_encoding = PositionalEncoding(emb_size=emb_size, device=device)

        self._reset_params()

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None):
        """Embed both sequences, run the transformer, and return target-vocabulary logits."""
        embedded_src = self.pos_encoding(self.src_tok_emb(src))
        embedded_tgt = self.pos_encoding(self.tgt_tok_emb(tgt))
        hidden = self.transformer(
            embedded_src,
            embedded_tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
        )
        return self.generator(hidden)

    def encode(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """Embed `src` and return the output of the transformer's encoder."""
        embedded_src = self.pos_encoding(self.src_tok_emb(src))
        encoder = self.transformer.encoder
        return encoder(embedded_src, src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor) -> Tensor:
        """Embed `tgt` and return the output of the transformer's decoder (no projection to logits)."""
        embedded_tgt = self.pos_encoding(self.tgt_tok_emb(tgt))
        decoder = self.transformer.decoder
        return decoder(embedded_tgt, memory, tgt_mask)

    def _reset_params(self):
        """Re-initialize every parameter with more than one dimension using Xavier-uniform."""
        for param in self.parameters():
            if param.dim() <= 1:
                # Biases and other 1-D parameters keep their default init.
                continue
            xavier_uniform_(param)

In [6]:
from typing import Tuple

def generate_mask(src: Tensor, tgt: Tensor, device=None) -> Tuple[Tensor, Tensor]:
    """Build boolean source and target masks for a batch of token ids.

    Entries are True at non-padding positions. The target mask additionally
    applies a lower-triangular pattern so that row ``i`` is True only for
    columns ``j <= i``, e.g. for a length-5 target without padding::

        1 0 0 0 0
        1 1 0 0 0
        1 1 1 0 0
        1 1 1 1 0
        1 1 1 1 1

    Args:
        src: Source token ids, shape ``[N, L]``.
        tgt: Target token ids, shape ``[N, L]``.
        device: Device on which the triangular mask is created. Defaults to
            ``tgt.device`` so the mask can always be combined with the padding
            mask derived from ``tgt``.

    Returns:
        ``(src_mask, tgt_mask)`` where ``src_mask`` has shape
        ``[N, 1, 1, L_src]`` and ``tgt_mask`` has shape ``[N, 1, L_tgt, L_tgt]``.
    """
    if device is None:
        # A CPU-created mask cannot be AND-ed with a mask living on another device.
        device = tgt.device

    tgt_seq_len = tgt.shape[-1]

    src_mask = (src != PAD_IDX).unsqueeze(1).unsqueeze(2)
    tgt_padding_mask = (tgt != PAD_IDX).unsqueeze(1).unsqueeze(2)
    tgt_seq_mask = torch.tril(torch.ones((tgt_seq_len, tgt_seq_len), device=device)).bool()

    # [L, L] & [N, 1, 1, L] broadcasts to [N, 1, L, L].
    tgt_mask = tgt_seq_mask & tgt_padding_mask

    return src_mask, tgt_mask

In [7]:
from torch.optim import Adam

# Train on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# German is the source language and English the target.
model = TransformerWrapper(
    emb_size=512,
    src_vocab_size=len(de_vocab),
    tgt_vocab_size=len(en_vocab),
    device=device,
)
# Padding positions do not contribute to the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = Adam(model.parameters(), lr=0.0005)

EPOCH = 10    # number of passes over the training set
GCLIP = 1     # max gradient norm used for clipping
for e in range(EPOCH):
    # ---- training ----
    epoch_loss = 0
    model.train()
    for src, tgt in train_loader:
        src, tgt = src.to(device), tgt.to(device)

        # The decoder input drops the last token and the expected output drops
        # the first, so position t is scored against token t + 1.
        tgt_in = tgt[:, :-1]
        tgt_out = tgt[:, 1:]

        src_mask, tgt_mask = generate_mask(src, tgt_in, device)

        # The per-batch argmax over the logits was never used, so it is not
        # computed here.
        logits = model(src, tgt_in, src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=src_mask)

        optimizer.zero_grad()
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
        epoch_loss += loss.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), GCLIP)
        optimizer.step()

    # ---- validation (no gradient tracking) ----
    model.eval()
    with torch.no_grad():
        valid_loss = 0
        for src, tgt in valid_loader:
            src, tgt = src.to(device), tgt.to(device)
            tgt_in = tgt[:, :-1]
            tgt_out = tgt[:, 1:]

            src_mask, tgt_mask = generate_mask(src, tgt_in, device)
            logits = model(src, tgt_in, src_mask=src_mask, tgt_mask=tgt_mask, memory_mask=src_mask)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
            valid_loss += loss.item()

    # Mean per-batch loss; perplexity is its exponential.
    avg_train_loss = epoch_loss / len(train_loader)
    avg_valid_loss = valid_loss / len(valid_loader)
    print(f"Epoch {e} => Train Loss {avg_train_loss} Train PPL {math.exp(avg_train_loss)}")
    print(f"Epoch {e} => Vaid Loss {avg_valid_loss} Valid PPL {math.exp(avg_valid_loss)}")

# Only the weights from the final epoch are saved.
torch.save(model.state_dict(), 'model.pt')


Epoch 0 => Train Loss 4.15803920855081 Train PPL 63.9460147938143
Epoch 0 => Vaid Loss 3.2433015697101166 Valid PPL 25.61816242635035
Epoch 1 => Train Loss 3.0796020839707965 Train PPL 21.749746101467203
Epoch 1 => Vaid Loss 2.6508969977880135 Valid PPL 14.166740482640463
Epoch 2 => Train Loss 2.438406428576566 Train PPL 11.454772192912776
Epoch 2 => Vaid Loss 2.082813606827099 Valid PPL 8.027022055638872
Epoch 3 => Train Loss 1.9594577886984736 Train PPL 7.095478773176655
Epoch 3 => Vaid Loss 1.8156690316998512 Valid PPL 6.14518612676874
Epoch 4 => Train Loss 1.6804517712362013 Train PPL 5.367980522606029
Epoch 4 => Vaid Loss 1.6757644856646217 Valid PPL 5.342878141926892
Epoch 5 => Train Loss 1.4915613697488928 Train PPL 4.444028877079208
Epoch 5 => Vaid Loss 1.6182349434074683 Valid PPL 5.044179192820308
Epoch 6 => Train Loss 1.352497707904698 Train PPL 3.867072295321762
Epoch 6 => Vaid Loss 1.600491096113149 Valid PPL 4.955465436739591
Epoch 7 => Train Loss 1.2380533780295417 Train

### Load the trained Transformer and test it

In [7]:
# Run on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Rebuild the architecture with the same sizes used for training.
model = TransformerWrapper(emb_size=512, src_vocab_size=len(de_vocab), tgt_vocab_size=len(en_vocab), device=device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a CUDA run can still be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load('model.pt', map_location=device))
model.to(device)

TransformerWrapper(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (wq): Linear(in_features=512, out_features=512, bias=True)
            (wk): Linear(in_features=512, out_features=512, bias=True)
            (wv): Linear(in_features=512, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=512, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (linear1): Linear(in_features=512, out_features=1024, bias=True)
          (droput): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=1024, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
 

In [20]:
def translate(model, src: str, src_ln: str = 'de', max_len: int = 50, device=None):
    """Greedily decode a translation of a single sentence.

    The sentence is encoded once; the decoder is then re-run on the growing
    prefix, appending the highest-scoring next token each step until
    ``EOS_IDX`` is produced or ``max_len`` tokens have been generated.

    Args:
        model: Object exposing ``eval``, ``encode``, ``decode`` and
            ``generator`` as used below.
        src: Raw source sentence; trailing newlines are stripped.
        src_ln: Key of ``trasnform_pipeline`` for the source language. ``'de'``
            selects the English vocabulary for output, anything else the
            German one.
        max_len: Maximum number of tokens to generate.
        device: Device for the created tensors. When ``None``, the device of
            the model's first parameter is used (if it has any) so inputs and
            weights end up on the same device.

    Returns:
        The generated tokens joined by single spaces, without ``<bos>`` and
        without the terminating ``<eos>``.
    """
    model.eval()
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device

    tgt_vocab = (en_vocab if src_ln == 'de' else de_vocab).get_itos()

    src_tokens = trasnform_pipeline[src_ln](src.rstrip('\n'))
    # Explicit dtype keeps the ids integral even when the sentence has no tokens.
    src_ids = torch.tensor([BOS_IDX, *src_tokens, EOS_IDX], dtype=torch.long, device=device).unsqueeze(0)
    src_mask = (src_ids != PAD_IDX).unsqueeze(1).unsqueeze(2)

    pred_ids = [BOS_IDX]
    with torch.no_grad():
        memory = model.encode(src_ids, src_mask)

        for _ in range(max_len):
            pred_tensor = torch.tensor(pred_ids, dtype=torch.long, device=device).unsqueeze(0)
            cur_len = pred_tensor.shape[-1]
            # Lower-triangular mask combined with the padding mask.
            tgt_seq_mask = torch.tril(torch.ones((cur_len, cur_len), device=device)).bool()
            tgt_padding_mask = (pred_tensor != PAD_IDX).unsqueeze(1).unsqueeze(2)
            tgt_mask = tgt_seq_mask & tgt_padding_mask

            logits = model.generator(model.decode(pred_tensor, memory, tgt_mask))

            # Only the prediction at the last position extends the sequence.
            pred_id = logits.argmax(-1)[:, -1].item()
            pred_ids.append(pred_id)

            if pred_id == EOS_IDX:
                break

    # Drop <bos>; drop the final token only if it really is <eos>. When the
    # loop stops at max_len the last token is real output and is kept.
    out_ids = pred_ids[1:]
    if out_ids and out_ids[-1] == EOS_IDX:
        out_ids = out_ids[:-1]

    return ' '.join(tgt_vocab[token_id] for token_id in out_ids)

# Translate a few German sentences of increasing length as a smoke test.
for sentence in (
    'Zwei Männer spielen Fußball',
    'Auf dem Rasen spielen zwei Männer Fußball.',
    'Zwei Männer spielen an einem regnerischen Tag auf dem Rasen Fußball.',
):
    print(translate(model, sentence, device=device))

Two men are playing soccer .
The two men are playing soccer on the grass .
Two men are playing soccer on a busy day .
