In [2]:
# Set tokenizer and build vocabulary
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# spaCy-backed tokenizers for the English and German sides of each sentence pair.
en_tokenizer, de_tokenizer = (
    get_tokenizer('spacy', language='en_core_web_md'),
    get_tokenizer('spacy', language='de_core_news_md'),
)

def yield_tokens(data_iter, language: str):
    """Tokenize one side of every (en, de) sentence pair in ``data_iter``.

    Args:
        data_iter: iterable of sentence pairs where index 0 is the English sentence
            and index 1 the German one (the ``language_pair=('en', 'de')`` order used
            when loading Multi30k in this notebook).
        language: ``'en'`` or ``'de'``.

    Returns:
        A lazy generator yielding one token list per pair.

    Raises:
        ValueError: if ``language`` is neither ``'en'`` nor ``'de'``. This is checked
            eagerly so a typo cannot silently produce an empty vocabulary.
    """
    if language == 'en':
        tokenizer, column = en_tokenizer, 0
    elif language == 'de':
        tokenizer, column = de_tokenizer, 1
    else:
        raise ValueError(f"language must be 'en' or 'de', got {language!r}")
    return (tokenizer(pair[column]) for pair in data_iter)

# Special-token indices. `special_first=True` below inserts `special_tokens` at the
# front of each vocabulary in list order, so these two definitions must agree.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>']
MIN_FREQ = 3  # minimum training-split frequency for a token to enter the vocabulary

en_vocab = build_vocab_from_iterator(yield_tokens(Multi30k(split='train', language_pair=('en', 'de')), 'en'), specials=special_tokens, special_first=True, min_freq=MIN_FREQ)
de_vocab = build_vocab_from_iterator(yield_tokens(Multi30k(split='train', language_pair=('en', 'de')), 'de'), specials=special_tokens, special_first=True, min_freq=MIN_FREQ)
en_vocab.set_default_index(UNK_IDX)  # index returned for out-of-vocabulary tokens
de_vocab.set_default_index(UNK_IDX)

# Guard the coupling between the index constants and the specials list.
expected_special_indices = [UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX]
assert en_vocab.lookup_indices(special_tokens) == expected_special_indices
assert de_vocab.lookup_indices(special_tokens) == expected_special_indices

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Preprocessing pipelines: raw sentence -> list of vocabulary indices.
def en_pipeline(text: str):
    """Tokenize an English sentence and map each token to its ``en_vocab`` index."""
    return en_vocab(en_tokenizer(text))


def de_pipeline(text: str):
    """Tokenize a German sentence and map each token to its ``de_vocab`` index."""
    return de_vocab(de_tokenizer(text))


print(en_pipeline('Several men in hard hats are operating a giant pulley system.'))
print(de_pipeline('Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem.'))

[165, 36, 7, 335, 287, 17, 1224, 4, 758, 4496, 2957, 5]
[84, 31, 10, 847, 2208, 15, 0, 4]


In [4]:
from torch.utils.data import DataLoader

train_iter = Multi30k(split='train', language_pair=('en', 'de'))
train_loader = DataLoader(train_iter, batch_size=2)
# Smoke-test the loader on a single batch; looping over the whole split here would
# only spend a full pass over the data and discard every batch.
src_batch, tgt_batch = next(iter(train_loader))
src_batch, tgt_batch

In [5]:
# Define Transformer
# Follows the structure of the official PyTorch nn.Transformer source code, simplified.
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules import LayerNorm
from torch.nn.parameter import Parameter
import pytorch_lightning as pl
from typing import Union, Callable, Optional, Any, Tuple


class TransformerWrapper(pl.LightningModule):
    """PyTorch Lightning wrapper that owns a Transformer as ``self.model``.

    The training/validation/optimizer hooks are still placeholders that only defer
    to ``pl.LightningModule``'s base implementations.
    TODO: implement them before handing this module to a ``pl.Trainer``.
    """

    def __init__(self, transformer):
        # nn.Module.__init__ (run via LightningModule) must happen before a submodule
        # is assigned; otherwise ``self.model = transformer`` raises AttributeError.
        super().__init__()
        self.model = transformer

    def forward(self, *args: Any, **kwargs: Any):
        """Run the wrapped model with the given arguments."""
        return self.model(*args, **kwargs)

    def training_step(self, *args: Any, **kwargs: Any):
        # TODO: placeholder - defers to the LightningModule base implementation.
        return super().training_step(*args, **kwargs)

    def validation_step(self, *args: Any, **kwargs: Any):
        # TODO: placeholder - defers to the LightningModule base implementation.
        return super().validation_step(*args, **kwargs)

    def configure_optimizers(self) -> Any:
        # TODO: placeholder - defers to the LightningModule base implementation.
        return super().configure_optimizers()


class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from the custom layers in this cell.

    Builds ``num_encoder_layers`` copies of a ``TransformerEncoderLayer`` and
    ``num_decoder_layers`` copies of a ``TransformerDecoderLayer``, each stack followed
    by a final LayerNorm. After construction every parameter with more than one
    dimension is re-initialised with Xavier-uniform.
    """

    def __init__(
        self,
        d_model: int = 512,
        n_head: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = False,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        # Positional arguments shared by the encoder and decoder layer constructors.
        layer_args = (d_model, n_head, dim_feedforward, dropout, activation,
                      layer_norm_eps, batch_first, norm_first)

        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(*layer_args, **factory),
            num_encoder_layers,
            LayerNorm(d_model, eps=layer_norm_eps, **factory),
        )
        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(*layer_args, **factory),
            num_decoder_layers,
            LayerNorm(d_model, eps=layer_norm_eps, **factory),
        )

        self._reset_params()

        self.d_model = d_model
        self.n_head = n_head
        self.batch_first = batch_first

    def forward(
        self,
        src: Tensor,
        tgt: Tensor,
        src_mask: Optional[Tensor] = None,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Encode ``src`` into a memory, then decode ``tgt`` while attending to it."""
        encoded = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        return self.decoder(
            tgt,
            encoded,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )

    def _reset_params(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for weight in filter(lambda p: p.dim() > 1, self.parameters()):
            xavier_uniform_(weight)


class TransformerEncoder(nn.Module):
    """Stack of ``num_layers`` deep-copied encoder layers with an optional final norm."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Feed ``src`` through every layer in order, then through ``norm`` if one is set."""
        hidden = src
        for layer in self.layers:
            hidden = layer(hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        return hidden if self.norm is None else self.norm(hidden)

class TransformerDecoder(nn.Module):
    """Stack of ``num_layers`` deep-copied decoder layers with an optional final norm."""

    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Feed ``tgt`` through every layer (each attending to ``memory``), then ``norm`` if set."""
        masks = dict(
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        hidden = tgt
        for layer in self.layers:
            hidden = layer(hidden, memory, **masks)
        return hidden if self.norm is None else self.norm(hidden)

class TransformerEncoderLayer(nn.Module):
    """One Transformer encoder layer: self-attention followed by a feed-forward network.

    Each sub-block is wrapped in a residual connection and a LayerNorm. With
    ``norm_first=True`` the norm is applied to the sub-block input (pre-LN); otherwise
    it is applied to the residual sum (post-LN).

    Args:
        d_model: model / embedding dimension.
        n_head: number of attention heads.
        dim_feedforward: hidden size of the feed-forward network.
        dropout: dropout probability passed to the attention module and used by the
            dropout layers in this block.
        activation: feed-forward activation; a callable, or the string ``"relu"`` /
            ``"gelu"`` (resolved to ``F.relu`` / ``F.gelu``).
        layer_norm_eps: epsilon for both LayerNorms.
        batch_first: forwarded to ``MultiheadAttention``.
        norm_first: pre-LN if True, post-LN otherwise.
        device, dtype: factory kwargs for the created parameters.

    Raises:
        ValueError: if ``activation`` is a string other than ``"relu"`` or ``"gelu"``.
    """

    def __init__(self, d_model: int, n_head: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-5,
                 batch_first: bool = False, norm_first: bool = False, device=None, dtype=None) -> None:
        # The signature advertises string activations; resolve them here, otherwise
        # ``self.activation(...)`` would fail at forward time with "str is not callable".
        if isinstance(activation, str):
            if activation not in ('relu', 'gelu'):
                raise ValueError(f"activation should be 'relu' or 'gelu', not {activation!r}")
            activation = getattr(F, activation)

        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        # Self-attention and feed-forward sub-layers.
        self.self_attn = MultiheadAttention(d_model, n_head, dropout, batch_first, **factory_kwargs)
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)  # inside the feed-forward network
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)  # on the self-attention output
        self.dropout2 = nn.Dropout(dropout)  # on the feed-forward output

        self.activation = activation

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Apply self-attention and feed-forward sub-blocks with residual connections."""
        x = src
        if self.norm_first:
            x = x + self._self_attn_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self._feedforward_block(self.norm2(x))
        else:
            x = self.norm1(x + self._self_attn_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._feedforward_block(x))
        return x

    def _self_attn_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        # MultiheadAttention.forward is declared to return (output, weights); keep the output.
        x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]
        return self.dropout1(x)

    def _feedforward_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

class TransformerDecoderLayer(nn.Module):
    """One Transformer decoder layer: self-attention, cross-attention over ``memory``, feed-forward.

    Each sub-block is wrapped in a residual connection and a LayerNorm. With
    ``norm_first=True`` the norm is applied to the sub-block input (pre-LN); otherwise
    it is applied to the residual sum (post-LN).

    Args:
        d_model: model / embedding dimension.
        n_head: number of attention heads.
        dim_feedforward: hidden size of the feed-forward network.
        droput: dropout probability (parameter name kept as spelled so existing
            keyword callers keep working).
        activation: feed-forward activation; a callable, or the string ``"relu"`` /
            ``"gelu"`` (resolved to ``F.relu`` / ``F.gelu``).
        layer_norm_eps: epsilon for the three LayerNorms.
        batch_first: forwarded to both ``MultiheadAttention`` modules.
        norm_first: pre-LN if True, post-LN otherwise.
        device, dtype: factory kwargs for the created parameters.

    Raises:
        ValueError: if ``activation`` is a string other than ``"relu"`` or ``"gelu"``.
    """

    def __init__(self, d_model: int, n_head: int, dim_feedforward: int = 2048, droput: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None):
        # The signature advertises string activations; resolve them here, otherwise
        # ``self.activation(...)`` would fail at forward time with "str is not callable".
        if isinstance(activation, str):
            if activation not in ('relu', 'gelu'):
                raise ValueError(f"activation should be 'relu' or 'gelu', not {activation!r}")
            activation = getattr(F, activation)

        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        dropout = droput

        # Self-attention, cross-attention and feed-forward sub-layers.
        self.self_attn = MultiheadAttention(d_model, n_head, dropout, batch_first, **factory_kwargs)
        self.multihead_attn = MultiheadAttention(d_model, n_head, dropout, batch_first, **factory_kwargs)
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)  # inside the feed-forward network
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)  # on the self-attention output
        self.dropout2 = nn.Dropout(dropout)  # on the cross-attention output
        self.dropout3 = nn.Dropout(dropout)  # on the feed-forward output

        self.activation = activation

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Apply self-attention, cross-attention over ``memory`` and feed-forward sub-blocks."""
        x = tgt
        if self.norm_first:
            x = x + self._self_attn_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
            x = x + self._multihead_attn_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)
            x = x + self._feedforward_block(self.norm3(x))
        else:
            x = self.norm1(x + self._self_attn_block(x, tgt_mask, tgt_key_padding_mask))
            x = self.norm2(x + self._multihead_attn_block(x, memory, memory_mask, memory_key_padding_mask))
            x = self.norm3(x + self._feedforward_block(x))
        return x

    def _self_attn_block(self, x: Tensor, attn_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None) -> Tensor:
        # MultiheadAttention.forward is declared to return (output, weights); keep the output.
        x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]
        return self.dropout1(x)

    def _multihead_attn_block(self, x: Tensor, mem: Tensor, attn_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None) -> Tensor:
        # Queries come from the decoder stream; keys and values from the encoder memory.
        x = self.multihead_attn(x, mem, mem, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]
        return self.dropout2(x)

    def _feedforward_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)


class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention with bias-free projections.

    A simplified re-implementation of ``torch.nn.MultiheadAttention`` with the same call
    signature and return convention ``(attn_output, attn_weights)``.

    Shapes (``E = embed_dim``, ``L`` = target length, ``S`` = source length, ``N`` = batch):
        * ``batch_first=True``:  query (N, L, E), key/value (N, S, E) -> output (N, L, E)
        * ``batch_first=False``: query (L, N, E), key/value (S, N, E) -> output (L, N, E)

    Masks follow the PyTorch convention: in a boolean mask ``True`` marks a position
    that may NOT be attended to; a float mask is added to the attention scores.
        * ``attn_mask``: (L, S), or (N * num_heads, L, S)
        * ``key_padding_mask``: (N, S)
    """

    def __init__(self, embed_dim, num_heads, dropout=0, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout  # dropout probability applied to the attention weights
        self.batch_first = batch_first
        self.head_dim = int(embed_dim // num_heads)
        assert self.embed_dim == self.head_dim * num_heads, "embed_dim must be divisible by num_heads"

        self.wq = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)
        self.wk = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)
        self.wv = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, need_weights: bool = True,
                attn_mask: Optional[Tensor] = None, average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
        """Attend from ``query`` to ``key``/``value``.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_weights`` is None when
            ``need_weights`` is False. Otherwise it is (N, L, S) when
            ``average_attn_weights`` is True (mean over heads), or (N, num_heads, L, S).
        """
        if not self.batch_first:
            # Work internally in batch-first layout: (T, N, E) -> (N, T, E).
            query, key, value = (t.transpose(0, 1) for t in (query, key, value))

        bs, tgt_len, _ = query.shape
        src_len = key.size(1)

        q = self._split_heads(self.wq(query))  # (N, H, L, E_h)
        k = self._split_heads(self.wk(key))    # (N, H, S, E_h)
        v = self._split_heads(self.wv(value))  # (N, H, S, E_h)

        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)  # (N, H, L, S)
        if attn_mask is not None:
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.reshape(1, 1, tgt_len, src_len)
            else:
                # (N * H, L, S) is batch-major: row index = batch * H + head.
                attn_mask = attn_mask.reshape(bs, self.num_heads, tgt_len, src_len)
            scores = self._apply_mask(scores, attn_mask)
        if key_padding_mask is not None:
            scores = self._apply_mask(scores, key_padding_mask.reshape(bs, 1, 1, src_len))

        weights = F.softmax(scores, dim=-1)
        weights = F.dropout(weights, p=self.dropout, training=self.training)

        attn_out = torch.matmul(weights, v)                                          # (N, H, L, E_h)
        attn_out = attn_out.transpose(1, 2).reshape(bs, tgt_len, self.embed_dim)      # (N, L, E)
        attn_out = self.out_proj(attn_out)

        if not self.batch_first:
            attn_out = attn_out.transpose(0, 1)  # back to (L, N, E)

        if not need_weights:
            return attn_out, None
        return attn_out, (weights.mean(dim=1) if average_attn_weights else weights)

    def _split_heads(self, proj: Tensor) -> Tensor:
        """Reshape a batch-first projection (N, T, E) into per-head form (N, H, T, E_h)."""
        bs = proj.size(0)
        return proj.reshape(bs, -1, self.num_heads, self.head_dim).transpose(1, 2)

    @staticmethod
    def _apply_mask(scores: Tensor, mask: Tensor) -> Tensor:
        """Apply a boolean (True = masked out) or additive float ``mask`` to ``scores``."""
        if mask.dtype == torch.bool:
            return scores.masked_fill(mask, float('-inf'))
        return scores + mask


def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


In [17]:
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to batch-first embeddings, then apply dropout.

    ``PE[pos, 2i] = sin(pos / 10000^(2i / emb_size))`` and
    ``PE[pos, 2i + 1] = cos(pos / 10000^(2i / emb_size))``.

    Args:
        emb_size: embedding dimension ``E``; odd sizes are supported.
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence length covered by the precomputed table.
    """

    def __init__(self, emb_size: int = 512, dropout: float = 0.1, max_len: int = 5000) -> None:
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(-torch.arange(0, emb_size, 2, dtype=torch.float) / emb_size * math.log(10000))
        angles = position * div_term  # (max_len, ceil(E / 2))
        pos_encoding = torch.zeros((max_len, emb_size))
        pos_encoding[:, 0::2] = torch.sin(angles)
        # For odd emb_size there is one fewer cosine column than sine column.
        pos_encoding[:, 1::2] = torch.cos(angles)[:, : emb_size // 2]
        self.dropout = nn.Dropout(dropout)
        # Buffer, not a Parameter: saved in state_dict and moved by .to(), never trained.
        self.register_buffer('pos_encoding', pos_encoding.unsqueeze(0))  # (1, max_len, E)

    def forward(self, token_embedding: Tensor) -> Tensor:
        """Add encodings to ``token_embedding`` of shape (batch, seq_len, E), seq_len <= max_len."""
        seq_len = token_embedding.size(1)
        return self.dropout(token_embedding + self.pos_encoding[:, :seq_len, :])

class TokenEmbedding(nn.Module):
    """Look up token embeddings and scale them by ``sqrt(emb_size)``."""

    def __init__(self, vocab_size: int, emb_size: int = 512) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor) -> Tensor:
        # Indices are cast to long before the embedding lookup.
        vectors = self.embedding(tokens.to(torch.long))
        scale = math.sqrt(self.emb_size)
        return vectors * scale

In [19]:
class Translator(nn.Module):
    """Sequence-to-sequence translation model built on the custom ``Transformer``.

    Source and target token ids are embedded (scaled by ``sqrt(emb_size)``), given
    sinusoidal positional encodings, run through the encoder-decoder Transformer, and
    projected to target-vocabulary logits.

    All tensors are batch-first: token ids are (batch, seq_len). This matches
    ``PositionalEncoding``, which adds its (1, max_len, E) table along dim 1.

    Args:
        emb_size: embedding / model dimension; must be divisible by the Transformer's
            default of 8 attention heads.
        src_vocab_size: source vocabulary size.
        tgt_vocab_size: target vocabulary size.
        device, dtype: currently unused; move the model with ``.to(...)`` instead.
    """

    def __init__(self, emb_size, src_vocab_size: int, tgt_vocab_size: int, device=None, dtype=None):
        super().__init__()
        # d_model must equal emb_size, and batch_first must be True so the attention
        # layout agrees with the batch-first embeddings and positional encoding.
        self.transformer = Transformer(d_model=emb_size, batch_first=True)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.pos_encoding = PositionalEncoding(emb_size)

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor, tgt_mask: Tensor, src_padding_mask: Tensor,
                tgt_padding_mask: Tensor, memory_padding_mask: Tensor):
        """Return target-vocabulary logits of shape (batch, tgt_len, tgt_vocab_size)."""
        src_emb = self.pos_encoding(self.src_tok_emb(src))
        tgt_emb = self.pos_encoding(self.tgt_tok_emb(tgt))
        output = self.transformer(src_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask,
                                  src_key_padding_mask=src_padding_mask, tgt_key_padding_mask=tgt_padding_mask,
                                  memory_key_padding_mask=memory_padding_mask)
        return self.generator(output)

    def encode(self, src: Tensor, src_mask: Tensor, src_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Encode source ids into the memory consumed by ``decode``."""
        src_emb = self.pos_encoding(self.src_tok_emb(src))
        return self.transformer.encoder(src_emb, src_mask, src_key_padding_mask=src_padding_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor,
               tgt_padding_mask: Optional[Tensor] = None,
               memory_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Decode target ids against ``memory``; returns hidden states (apply ``generator`` for logits)."""
        tgt_emb = self.pos_encoding(self.tgt_tok_emb(tgt))
        return self.transformer.decoder(tgt_emb, memory, tgt_mask,
                                        tgt_key_padding_mask=tgt_padding_mask,
                                        memory_key_padding_mask=memory_padding_mask)
        