In [3]:
# Set tokenizer and build vocabulary
from torchtext.datasets import Multi30k
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

en_tokenizer = get_tokenizer('spacy', language='en_core_web_md')
de_tokenizer = get_tokenizer('spacy', language='de_core_news_md')

def yield_tokens(data_iter, language: str):
    for text in data_iter:
        if language == 'en':
            yield en_tokenizer(text[0])
        elif language == 'de':
            yield de_tokenizer(text[1])

UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_tokens = ['<unk>', '<pad>', '<bos>', '<eos>']
en_vocab = build_vocab_from_iterator(yield_tokens(Multi30k(split='train', language_pair=('en', 'de')), 'en'), specials=special_tokens, special_first=True, min_freq=3)
de_vocab = build_vocab_from_iterator(yield_tokens(Multi30k(split='train', language_pair=('en', 'de')), 'de'), specials=special_tokens, special_first=True, min_freq=3)
en_vocab.set_default_index(UNK_IDX) # oov 일때 반환하는 토큰
de_vocab.set_default_index(UNK_IDX)

In [4]:
# Set preprocess pipeline
en_pipeline = lambda x: en_vocab(en_tokenizer(x))
de_pipeline = lambda x: de_vocab(de_tokenizer(x))
trasnform_pipeline = {'en': en_pipeline, 'de': de_pipeline}
print(en_pipeline('Several men in hard hats are operating a giant pulley system.'))
print(de_pipeline('Mehrere Männer mit Schutzhelmen bedienen ein Antriebsradsystem.'))

[166, 37, 8, 336, 288, 18, 1225, 4, 759, 4497, 2958, 6]
[85, 32, 11, 848, 2209, 16, 0, 5]


In [5]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence


def collate_func(batch, src_ln: str = 'en', tgt_ln: str = 'de', batch_first: bool = True):
    src_batch, tgt_batch = [], []
    for src, tgt in batch:
        src_ids = trasnform_pipeline[src_ln](src.rstrip('\n'))
        tgt_ids = trasnform_pipeline[tgt_ln](tgt.rstrip('\n'))
        src_ids = torch.cat((torch.tensor([BOS_IDX]), torch.tensor(src_ids), torch.tensor([EOS_IDX])))
        tgt_ids = torch.cat((torch.tensor([BOS_IDX]), torch.tensor(tgt_ids), torch.tensor([EOS_IDX])))
        src_batch.append(src_ids)
        tgt_batch.append(tgt_ids)

    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX, batch_first=batch_first)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX, batch_first=batch_first)
    return src_batch, tgt_batch



train_iter, valid_iter, test_iter = Multi30k(split=('train', 'valid', 'test'), language_pair=('en', 'de'))
train_loader = DataLoader(train_iter, batch_size=2, collate_fn=collate_func)
valid_loader = DataLoader(valid_iter, batch_size=1, collate_fn=collate_func)
test_loader = DataLoader(test_iter, batch_size=1, collate_fn=collate_func)

In [67]:
# Define Transformer
# I follow style of official Pytorch Transformer source code and adjust it simply
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import ModuleList
from torch.nn.init import xavier_uniform_
from torch.nn.modules import LayerNorm
from typing import Union, Callable, Optional, Any, Tuple


class Transformer(nn.Module):
    def __init__(self, d_model: int = 512, n_head: int = 8, num_encoder_layers: int = 6, num_decoder_layers: int = 6, \
        dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, \
        layer_norm_eps: float = 1e-5, batch_first: bool = True, norm_first: bool = False, device=None, dtype=None) -> None:
        
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(Transformer, self).__init__()
   
        encoder_layer = TransformerEncoderLayer(d_model, n_head, dim_feedforward, dropout, activation, layer_norm_eps, batch_first, norm_first, **factory_kwargs)
        encoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, n_head, dim_feedforward, dropout, activation, layer_norm_eps, batch_first, norm_first, **factory_kwargs)
        decoder_norm = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)

        self._reset_params()

        self.d_model = d_model
        self.n_head = n_head
        self.batch_first = batch_first
    
    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None, \
                memory_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, \
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:

        memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        output  = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
        return output
    
    def _reset_params(self):
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super(TransformerEncoder, self).__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
    
    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        output = src
        for mod in self.layers:
            output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        
        if self.norm is not None:
            output = self.norm(output)
        
        return output

class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None):
        super(TransformerDecoder, self).__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
    
    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, \
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        output = tgt
        for mod in self.layers:
            output = mod(output, memory, tgt_mask=tgt_mask, memory_mask=memory_mask, tgt_key_padding_mask=tgt_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask)
        
        if self.norm is not None:
            output = self.norm(output)
        
        return output

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model: int, n_head: int, dim_feedforward: int = 2048, dropout: float = 0.1, \
                activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,  
                device=None, dtype=None) -> None:

        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerEncoderLayer, self).__init__()
        
        # Define params
        self.self_attn = MultiheadAttention(d_model, n_head, dropout, batch_first, **factory_kwargs)
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.droput = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
        
        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = activation
    
    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        x = src
        if self.norm_first:
            x = x + self._self_attn_block(self.norm1(src), src_mask, src_key_padding_mask)
            x = x + self._feedforward_block(self.norm2(x))
        else:
            x = self.norm1(x + self._self_attn_block(x, src_mask, src_key_padding_mask))
            x = self.norm2(x + self._feedforward_block(x))
        return x

    def _self_attn_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]) -> Tensor:
        x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]
        return self.dropout1(x)
    
    def _feedforward_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.droput(self.activation(self.linear1(x))))
        return self.dropout2(x)

class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model: int, n_head: int, dim_feedforward: int = 2048, droput: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, \
                layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerDecoderLayer, self).__init__()

        # Define params
        self.self_attn = MultiheadAttention(d_model, n_head, droput, batch_first, **factory_kwargs)
        self.multihead_attn = MultiheadAttention(d_model, n_head, droput, batch_first, **factory_kwargs)
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.droput = nn.Dropout(droput)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
        
        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = nn.Dropout(droput)
        self.dropout2 = nn.Dropout(droput)
        self.dropout3 = nn.Dropout(droput)

        self.activation = activation
    
    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None, memory_mask: Optional[Tensor] = None, tgt_key_padding_mask: Optional[Tensor] = None, \
                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        x = tgt
        if self.norm_first:
            x = x + self._self_attn_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
            x = x + self._multihead_attn_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask)
            x = x + self._feedforward_block(self.norm3(x))
        else:
            x = self.norm1(x + self._self_attn_block(x, tgt_mask ,tgt_key_padding_mask))
            x = self.norm2(x + self._multihead_attn_block(x, memory, memory_mask, memory_key_padding_mask))
            x = self.norm3(x + self._feedforward_block(x))
        return x
    
    def _self_attn_block(self, x: Tensor, attn_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None) -> Tensor:
        x = self.self_attn(x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]
        return self.dropout1(x)

    def _multihead_attn_block(self, x: Tensor, mem: Tensor, attn_mask: Optional[Tensor] = None, key_padding_mask: Optional[Tensor] = None) -> Tensor:
        x = self.multihead_attn(x, mem, mem, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False)[0]
        return self.dropout2(x)
    
    def _feedforward_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.droput(self.activation(self.linear1(x))))
        return self.dropout3(x)


class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = int(embed_dim // num_heads)
        assert self.embed_dim == self.head_dim * num_heads

        self.wq = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)
        self.wk = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)
        self.wv = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False, **factory_kwargs)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, need_weights: bool = True, \
                attn_mask: Optional[Tensor] = None, average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:

        query, key, value = self.wq(query), self.wk(key), self.wv(value)
        query, key ,value = self._split_heads(query), self._split_heads(key), self._split_heads(value)
        attn_out = self._attention(query, key, value, key_padding_mask, attn_mask)
        attn_out = self.out_proj(attn_out)
        return attn_out
    
    def _split_heads(self, proj: Tensor) -> Tensor:
        if self.batch_first:    # (N, L, E)
            bs = proj.size(0)
            proj = proj.view(bs, -1, self.num_heads, self.head_dim)    # (N, L, H, E_Hi)
            proj = proj.transpose(1, 2) # (N, H, L, E_Hi)
        else:   # (L, N, E)
            bs = proj.size(1)
            proj = proj.view(-1, bs, self.num_heads, self.head_dim)    # (L, N, H, E_Hi)
        return proj
    
    def _attention(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, attn_mask: Optional[Tensor] = None) -> Tensor:
        if self.batch_first:
            score = torch.matmul(query, key.transpose(-1, -2))  # (N, H, L, L), score = Q * K^T
            score = score / math.sqrt(query.size(-1)) # score = Q * K^T divided by sqrt(E_Hi)
            print('score_shape', score.shape)
            if key_padding_mask is not None:
                score = score.masked_fill(key_padding_mask == 0, float('-inf'))
                print(score)
            if attn_mask is not None:
                score = score.masked_fill(attn_mask == 0, float('-inf'))
                print(score)

            softmax = F.softmax(score, dim=-1)
            attn_out = torch.matmul(softmax, value)   # (N, H, L, E_Hi)
            attn_out = attn_out.transpose(1, 2) # (N, L, H, E_Hi)
            attn_out = attn_out.contiguous().view(attn_out.size(0), -1, self.embed_dim)  # (N, L, E)
            return attn_out
        else:
            raise NotImplementedError()


def _get_clones(module, N):
    return ModuleList([copy.deepcopy(module) for i in range(N)])


In [68]:
import math

class PositionalEncoding(nn.Module):
    def __init__(self, emb_size: int = 512, dropout: float = 0.1, max_len: int = 5000) -> None:
        super(PositionalEncoding, self).__init__()
        pos = torch.arange(0, max_len).reshape(max_len, 1)
        val = torch.exp(-torch.arange(0, emb_size, 2) / emb_size * math.log(10000))
        pos_encoding = torch.zeros((max_len, emb_size))
        pos_encoding[:, 0::2] = torch.sin(pos * val)
        pos_encoding[:, 1::2] = torch.cos(pos * val)
        pos_encoding = pos_encoding.unsqueeze(0)    # batch first
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_encoding', pos_encoding)
    
    def forward(self, token_embedding: Tensor) -> Tensor:
        return self.dropout(token_embedding + self.pos_encoding[:, :token_embedding.size(1), :])

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int = 512) -> None:
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size
    
    def forward(self, tokens: Tensor) -> Tensor:
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

In [69]:
class TransformerWrapper(nn.Module):
    def __init__(self, emb_size: int, src_vocab_size: int, tgt_vocab_size: int, device=None, dtype=None):
        super().__init__()
        self.transformer = Transformer(d_model=emb_size, batch_first=True)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.pos_encoding = PositionalEncoding()
    
    def forward(self, src: Tensor, tgt: Tensor, src_mask: Optional[Tensor] = None, tgt_mask: Optional[Tensor] = None, src_padding_mask: Optional[Tensor] = None, tgt_padding_mask: Optional[Tensor] = None, memory_padding_mask: Optional[Tensor] = None):
        src_emb = self.pos_encoding(self.src_tok_emb(src))
        tgt_emb = self.pos_encoding(self.tgt_tok_emb(tgt))
        output = self.transformer(src_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask, src_key_padding_mask=src_padding_mask, tgt_key_padding_mask=tgt_padding_mask, memory_key_padding_mask=memory_padding_mask)
        return self.generator(output)

    def encode(self, src: Tensor, src_mask: Tensor) -> Tensor:
        src_emb = self.pos_encoding(self.src_tok_emb(src))
        return self.transformer.encoder(src_emb, src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor) -> Tensor:
        tgt_emb = self.pos_encoding(self.tgt_tok_emb(tgt))
        return self.transformer.decoder(tgt_emb, memory, tgt_mask)
        

In [70]:
import pytorch_lightning as pl
from torch import optim

class TransformerTrainer(pl.LightningModule):
    def __init__(self, model) -> None:
        super().__init__()
        self.transformer = model
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    
    def training_step(self, batch, batch_idx):
        src, tgt = batch    # (N, L)
        logits = self.transformer(src, tgt[:, :-1])    # (N, L, tgt_vacab_size)
        loss = self.loss_fn(logits.reshape(-1, logits.size(-1)), tgt[:, 1:].reshape(-1))
        self.log('train loss', loss)
        return loss
    
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [71]:
def generate_mask(src: Tensor, tgt: Tensor) -> Tuple[Tensor]:
    """
    Args:
        src: [N, L]
        tgt: [N, L]
    """
    src_seq_len = src.shape[-1] # L
    tgt_seq_len = tgt.shape[-1]

    """
    tgt mask
    1 0 0 0 0 
    1 1 0 0 0
    1 1 1 0 0
    1 1 1 1 0
    1 1 1 1 1
    """
    tgt_mask = (torch.triu(torch.ones(tgt_seq_len, tgt_seq_len)) == 1).transpose(0, 1).float()
    src_mask = torch.ones((src_seq_len, src_seq_len)).type(torch.bool)
    src_padding_mask = (src != PAD_IDX).float().unsqueeze(1).unsqueeze(2)
    tgt_padding_mask = (tgt != PAD_IDX).float().unsqueeze(1).unsqueeze(2)

    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask
    




In [78]:
model = TransformerWrapper(emb_size=512, src_vocab_size=len(en_vocab), tgt_vocab_size=len(de_vocab))
for src, tgt in train_loader:
    tgt_input = tgt[:, :-1]
    src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = generate_mask(src, tgt_input)
    print(src_mask.shape)
    print(src_padding_mask)
    print(src_padding_mask.shape)
    print(tgt_padding_mask.shape)
    logits = model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask, memory_padding_mask=src_padding_mask)
    print(logits)
    break

torch.Size([20, 20])
tensor([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
           0., 0., 0.]]],


        [[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1.]]]])
torch.Size([19, 19])
torch.Size([2, 1, 1, 20])
torch.Size([2, 1, 1, 19])
score_shape torch.Size([2, 8, 20, 20])
tensor([[[[ 3.4218e+02, -8.6976e+02,  3.6668e+02,  ...,        -inf,
                  -inf,        -inf],
          [-1.2197e+03,  9.5816e+01,  3.3452e+02,  ...,        -inf,
                  -inf,        -inf],
          [-3.9936e+02,  3.9866e+02,  1.0080e+02,  ...,        -inf,
                  -inf,        -inf],
          ...,
          [-8.2364e+02, -2.2045e+02,  5.7169e+02,  ...,        -inf,
                  -inf,        -inf],
          [-5.8334e+02, -4.1499e+02,  2.2163e+02,  ...,        -inf,
                  -inf,        -inf],
          [-4.4546e+02, -1.4030e+02,  6.2033e+02,  ...,        -inf,
                  -inf,        -inf

RuntimeError: The size of tensor a (19) must match the size of tensor b (20) at non-singleton dimension 3

In [55]:
a = torch.randn((5, 5))
print(a)
mask = generate_mask(a, a)
tgt_mask = mask[1]
a = a.masked_fill(tgt_mask == 0, float('-inf'))
print(a)

tensor([[-0.0789,  0.6511,  0.3624,  1.0269, -1.7361],
        [ 0.6671,  1.7341,  0.1811,  0.4002,  0.8695],
        [-0.6923, -0.9453,  1.4208,  0.9455,  0.0064],
        [-0.1542, -1.2689, -1.1903,  0.4219, -0.6572],
        [-1.2416, -1.8163, -2.2015, -1.5558, -0.2969]])
tensor([[-0.0789,    -inf,    -inf,    -inf,    -inf],
        [ 0.6671,  1.7341,    -inf,    -inf,    -inf],
        [-0.6923, -0.9453,  1.4208,    -inf,    -inf],
        [-0.1542, -1.2689, -1.1903,  0.4219,    -inf],
        [-1.2416, -1.8163, -2.2015, -1.5558, -0.2969]])


In [None]:
trainer = pl.Trainer(max_epochs=1)
model = TransformerWrapper(emb_size=512, src_vocab_size=len(en_vocab), tgt_vocab_size=len(de_vocab))
pl_model = TransformerTrainer(model=model)
trainer.fit(model=pl_model, train_dataloaders=train_loader)