In [9]:
import math
from typing import Tuple

import torch
from torch import nn
# Width of each embedding vector; used as the default `d_model` of PositionalEncoding.
D_EMBEDDING = 512
# Default `max_len` of PositionalEncoding, i.e. how many positions get a
# precomputed encoding row.
# NOTE(review): the trailing "5000" is presumably the intended full-size value
# that was shrunk to 1 for this demo -- confirm before reusing the default.
MAX_TOKENS = 1  # 5000

In [10]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    A table of `max_len` encodings is precomputed once and stored in the
    non-trainable buffer `pe` with shape (max_len, 1, d_model). Even columns
    hold sines and odd columns hold cosines of the position scaled by
    geometrically decreasing frequencies.

    Args:
        d_model: width of each encoding vector (may be odd).
        dropout: dropout probability applied to the encoded output.
        max_len: number of positions to precompute; `forward` rejects inputs
            whose first dimension is longer than this.
    """

    def __init__(self, d_model=D_EMBEDDING, dropout=0.1, max_len=MAX_TOKENS):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)  # <1>
        self.d_model = d_model  # <2>
        self.max_len = max_len  # <3>
        pe = torch.zeros(max_len, d_model)  # <4>
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # <5>
        # For odd d_model there is one fewer odd column than even columns,
        # so trim the frequencies to match the cosine slice.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        # Insert a singleton middle axis: (max_len, d_model) -> (max_len, 1, d_model).
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return `dropout(x + pe[:x.size(0)])`.

        Args:
            x: tensor whose first dimension indexes positions and whose last
                dimension is `d_model`; it is broadcast against the
                (positions, 1, d_model) encoding table.

        Raises:
            ValueError: if `x.size(0)` exceeds `max_len`. Without this check a
                table with `max_len == 1` would silently broadcast the
                position-0 encoding onto every token.
        """
        seq_len = x.size(0)
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.max_len}")
        x = x + self.pe[:seq_len, :]  # <6>
        return self.dropout(x)

In [11]:
# Encoder built with the notebook-wide defaults, spelled out explicitly.
encode_position = PositionalEncoding(
    d_model=D_EMBEDDING, max_len=MAX_TOKENS)


In [12]:
import numpy as np
# Stand-in "embeddings": consecutive integers, one row per token.
token_grid = np.arange(D_EMBEDDING * MAX_TOKENS).reshape(
    (MAX_TOKENS, D_EMBEDDING))
X = torch.tensor(token_grid)
X_encoded = encode_position(X)

In [13]:
# The encoding table is stored as (max_len, 1, d_model), so adding it to the
# 2-D X broadcasts the result up to three dimensions.
X_encoded.shape

torch.Size([1, 1, 512])

In [21]:
from datasets import load_dataset  # <1>

# Fixed seed so the shuffled train/test split is identical on every run.
SPLIT_SEED = 42

# 'de-en' configuration of the opus_books dataset; the code below reads the
# 'en' and 'de' fields of each 'translation' record.
opus = load_dataset('opus_books', 'de-en')
# Hold out 10% of the 'train' split as a test set.
sents = opus['train'].train_test_split(test_size=.1, seed=SPLIT_SEED)

DEVICE = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'cpu')
SRC = 'en'  # <1>
TGT = 'de'  # <2>
SOS, EOS = '<s>', '</s>'
PAD, UNK, MASK = '<pad>', '<unk>', '<mask>'
# Order matters: it determines the ids the tokenizers assign to these tokens.
SPECIAL_TOKS = [SOS, PAD, EOS, UNK, MASK]
VOCAB_SIZE = 10_000

Found cached dataset opus_books (/home/hobs/.cache/huggingface/datasets/opus_books/de-en/1.0.0/e8f950a4f32dc39b7f9088908216cd2d7e21ac35f893d04d39eb594746af2daf)


  0%|          | 0/1 [00:00<?, ?it/s]

In [22]:
from tokenizers import ByteLevelBPETokenizer  # <3>


def train_bpe_tokenizer(texts, vocab_size=VOCAB_SIZE, min_frequency=2):
    """Train a byte-level BPE tokenizer on an iterable of strings.

    Args:
        texts: iterable of training sentences.
        vocab_size: target vocabulary size.
        min_frequency: minimum pair frequency passed to the trainer.

    Returns:
        The trained ByteLevelBPETokenizer, with SPECIAL_TOKS registered as
        special tokens.
    """
    tokenizer = ByteLevelBPETokenizer()
    tokenizer.train_from_iterator(
        texts, vocab_size=vocab_size, min_frequency=min_frequency,
        special_tokens=SPECIAL_TOKS)
    return tokenizer


# Read the translation column once and reuse it for both languages.
train_pairs = sents['train']['translation']

tokenize_src = train_bpe_tokenizer([pair[SRC] for pair in train_pairs])
PAD_IDX = tokenize_src.token_to_id(PAD)
tokenize_tgt = train_bpe_tokenizer([pair[TGT] for pair in train_pairs])
# Later code uses PAD_IDX as the default pad id for both source and target.
assert PAD_IDX == tokenize_tgt.token_to_id(PAD), (
    "source and target tokenizers disagree on the id of the PAD token")









In [24]:
from torch import Tensor
from typing import Optional
class CustomDecoderLayer(nn.TransformerDecoderLayer):
    """Decoder layer that also returns its encoder-decoder attention weights.

    The computation is self-attention, attention over `memory`, then the
    feed-forward block. Each step is followed by dropout, a residual add and
    a LayerNorm (normalization is always applied after the residual).
    """

    def forward(self, tgt: Tensor, memory: Tensor,
            tgt_mask: Optional[Tensor] = None,
            memory_mask: Optional[Tensor] = None,
            tgt_key_padding_mask: Optional[Tensor] = None,
            mem_key_padding_mask: Optional[Tensor] = None,
            memory_key_padding_mask: Optional[Tensor] = None,
            ) -> Tuple[Tensor, Tensor]:
        """Like decode but returns multi-head attention weights.

        Args:
            tgt: target sequence fed to self-attention.
            memory: encoder output attended to by the second attention block.
            tgt_mask: attention mask for self-attention.
            memory_mask: attention mask for the attention over `memory`.
            tgt_key_padding_mask: padding mask for the keys of `tgt`.
            mem_key_padding_mask: padding mask for the keys of `memory`.
            memory_key_padding_mask: alias of `mem_key_padding_mask` using the
                parent class's parameter name; pass only one of the two.

        Returns:
            A tuple `(output, attention_weights)` where `attention_weights`
            comes from the attention over `memory`.

        Raises:
            ValueError: if both memory padding mask arguments are given.
        """
        if memory_key_padding_mask is not None:
            if mem_key_padding_mask is not None:
                raise ValueError(
                    "pass only one of mem_key_padding_mask and "
                    "memory_key_padding_mask")
            mem_key_padding_mask = memory_key_padding_mask

        # Self-attention: only the attended values ([0]) are needed here.
        tgt2 = self.self_attn(
            tgt, tgt, tgt, attn_mask=tgt_mask,
            key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # Attention over the encoder memory: keep the weights this time.
        tgt2, attention_weights = self.multihead_attn(
            tgt, memory, memory,  # <1>
            attn_mask=memory_mask,
            key_padding_mask=mem_key_padding_mask,
            need_weights=True)
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        # Position-wise feed-forward block.
        tgt2 = self.linear2(
            self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt, attention_weights  # <2>

In [25]:
class CustomDecoder(nn.TransformerDecoder):
    """Decoder stack that records each layer's attention weights.

    Every layer must return a tuple `(output, attention_weights)`. After each
    forward pass the per-layer weights are available in
    `self.attention_weights`, in layer order.
    """

    def __init__(self, decoder_layer, num_layers, norm=None):
        super().__init__(
            decoder_layer, num_layers, norm)
        # Defined up front so the attribute exists before the first forward pass.
        self.attention_weights = []

    def forward(self,
            tgt: Tensor, memory: Tensor,
            tgt_mask: Optional[Tensor] = None,
            memory_mask: Optional[Tensor] = None,
            tgt_key_padding_mask: Optional[Tensor] = None,
            memory_key_padding_mask: Optional[Tensor] = None,
            tgt_is_causal: Optional[bool] = None,
            memory_is_causal: bool = False,
            ) -> Tensor:
        """Like TransformerDecoder but cache multi-head attention

        Args:
            tgt: target sequence passed to the first layer.
            memory: encoder output passed to every layer.
            tgt_mask: attention mask for the target sequence.
            memory_mask: attention mask for the memory sequence.
            tgt_key_padding_mask: padding mask for the target keys.
            memory_key_padding_mask: padding mask for the memory keys.
                nn.Transformer.forward always passes this keyword to its
                decoder, so it must be accepted here.
            tgt_is_causal: accepted because newer PyTorch versions pass it to
                the decoder; ignored, masking relies on `tgt_mask`.
            memory_is_causal: accepted for the same reason; ignored.

        Returns:
            The output of the last layer, normalized if `norm` was given.
        """
        self.attention_weights = []  # <1>
        output = tgt
        for mod in self.layers:
            # Masks are passed positionally so this works whether the layer
            # names its last mask `mem_key_padding_mask` or
            # `memory_key_padding_mask`.
            output, attention = mod(
                output, memory, tgt_mask, memory_mask,
                tgt_key_padding_mask, memory_key_padding_mask)
            self.attention_weights.append(attention) # <2>

        if self.norm is not None:
            output = self.norm(output)

        return output

In [28]:
from einops import rearrange  # <1>

In [29]:
# NOTE: a later cell defines TranslationTransformer again; when both cells are
# run, that later definition replaces this one.
class TranslationTransformer(nn.Transformer):  # <2>
    """Sequence-to-sequence Transformer over token ids.

    Wraps nn.Transformer with source/target embeddings, a shared positional
    encoding, a decoder that records attention weights, and a final linear
    projection onto the target vocabulary.
    """

    def __init__(self,
            device=DEVICE,
            src_vocab_size: int = VOCAB_SIZE,
            src_pad_idx: int = PAD_IDX,
            tgt_vocab_size: int = VOCAB_SIZE,
            tgt_pad_idx: int = PAD_IDX,
            max_sequence_length: int = 100,
            d_model: int = 512,
            nhead: int = 8,
            num_encoder_layers: int = 6,
            num_decoder_layers: int = 6,
            dim_feedforward: int = 2048,
            dropout: float = 0.1,
            activation: str = "relu"
        ):
        """Build the encoder, the custom decoder and the embedding layers.

        Args:
            device: device that the masks built in `forward` are moved to.
            src_vocab_size: number of rows in the source embedding table.
            src_pad_idx: token id treated as padding in the source.
            tgt_vocab_size: number of rows in the target embedding table and
                width of the output projection.
            tgt_pad_idx: token id treated as padding in the target.
            max_sequence_length: number of positions precomputed by the
                positional encoding.
            d_model, nhead, num_encoder_layers, num_decoder_layers,
            dim_feedforward, dropout, activation: passed to nn.Transformer
                and to the custom decoder layer.
        """
        decoder_layer = CustomDecoderLayer(
            d_model, nhead, dim_feedforward,  # <3>
            dropout, activation)
        decoder_norm = nn.LayerNorm(d_model)
        decoder = CustomDecoder(
            decoder_layer, num_decoder_layers,
            decoder_norm)  # <4>

        # `activation` is forwarded so the encoder uses the same activation
        # as the custom decoder layers.
        super().__init__(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout, activation=activation,
            custom_decoder=decoder)

        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx
        self.device = device

        self.src_emb = nn.Embedding(
            src_vocab_size, d_model)  # <5>
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)

        self.pos_enc = PositionalEncoding(
            d_model, dropout, max_sequence_length)  # <6>
        self.linear = nn.Linear(
            d_model, tgt_vocab_size)  # <7>

    def _make_key_padding_mask(self, t, pad_idx):
        """Return a boolean mask on `self.device` that is True where `t == pad_idx`."""
        mask = (t == pad_idx).to(self.device)
        return mask

    def prepare_src(self, src, src_pad_idx):
        """Embed a batch of source token ids.

        Args:
            src: token ids laid out as (N, S), batch first.
            src_pad_idx: token id treated as padding.

        Returns:
            `(src, src_key_padding_mask)`: the embedded, position-encoded
            source with the sequence axis first, and the (N, S) padding mask.
        """
        src_key_padding_mask = self._make_key_padding_mask(
            src, src_pad_idx)
        src = rearrange(src, 'N S -> S N')
        # Embeddings are scaled by sqrt(d_model) before positions are added.
        src = self.pos_enc(self.src_emb(src)
            * math.sqrt(self.d_model))
        return src, src_key_padding_mask

    def prepare_tgt(self, tgt, tgt_pad_idx):
        """Embed a batch of target token ids and build its masks.

        Args:
            tgt: token ids laid out as (N, T), batch first.
            tgt_pad_idx: token id treated as padding.

        Returns:
            `(tgt, tgt_key_padding_mask, tgt_mask)`: the embedded,
            position-encoded target with the sequence axis first, the (N, T)
            padding mask, and the square subsequent mask of size T.
        """
        tgt_key_padding_mask = self._make_key_padding_mask(
            tgt, tgt_pad_idx)
        tgt = rearrange(tgt, 'N T -> T N')
        tgt_mask = self.generate_square_subsequent_mask(
            tgt.shape[0]).to(self.device)
        tgt = self.pos_enc(self.tgt_emb(tgt)
            * math.sqrt(self.d_model))
        return tgt, tgt_key_padding_mask, tgt_mask

    def forward(self, src, tgt):
        """Map source and target token ids to target-vocabulary scores.

        Args:
            src: source token ids, (N, S).
            tgt: target token ids, (N, T).

        Returns:
            Tensor laid out as (N, T, tgt_vocab_size).
        """
        src, src_key_padding_mask = self.prepare_src(
            src, self.src_pad_idx)
        tgt, tgt_key_padding_mask, tgt_mask = self.prepare_tgt(
            tgt, self.tgt_pad_idx)
        # The decoder attends over the encoder output, so it reuses the
        # source padding mask.
        memory_key_padding_mask = src_key_padding_mask.clone()
        output = super().forward(
            src, tgt, tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask)
        output = rearrange(output, 'T N E -> N T E')
        return self.linear(output)

    def init_weights(self):
        """Apply Xavier-uniform init to every submodule `weight` with more than one dimension."""
        def _init_weights(m):
            if hasattr(m, 'weight') and m.weight.dim() > 1:
                nn.init.xavier_uniform_(m.weight.data)
        self.apply(_init_weights)  # <1>

In [30]:
class TranslationTransformer(nn.Transformer):
    """Sequence-to-sequence Transformer over token ids.

    Wraps nn.Transformer with source/target embeddings, a shared positional
    encoding, a decoder that records attention weights, and a final linear
    projection onto the target vocabulary.
    """

    def __init__(self,
            device=DEVICE,
            src_vocab_size: int = 10000,
            src_pad_idx: int = PAD_IDX,
            tgt_vocab_size: int  = 10000,
            tgt_pad_idx: int = PAD_IDX,
            max_sequence_length: int = 100,
            d_model: int = 512,
            nhead: int = 8,
            num_encoder_layers: int = 6,
            num_decoder_layers: int = 6,
            dim_feedforward: int = 2048,
            dropout: float = 0.1,
            activation: str = "relu"
            ):
        """Build the encoder, the custom decoder and the embedding layers.

        Args:
            device: device that the masks built in `forward` are moved to.
            src_vocab_size: number of rows in the source embedding table.
            src_pad_idx: token id treated as padding in the source.
            tgt_vocab_size: number of rows in the target embedding table and
                width of the output projection.
            tgt_pad_idx: token id treated as padding in the target.
            max_sequence_length: number of positions precomputed by the
                positional encoding.
            d_model, nhead, num_encoder_layers, num_decoder_layers,
            dim_feedforward, dropout, activation: passed to nn.Transformer
                and to the custom decoder layer.
        """
        decoder_layer = CustomDecoderLayer(
            d_model, nhead, dim_feedforward,
            dropout, activation)
        decoder_norm = nn.LayerNorm(d_model)
        decoder = CustomDecoder(
            decoder_layer, num_decoder_layers, decoder_norm)

        # `activation` is forwarded so the encoder uses the same activation
        # as the custom decoder layers.
        super().__init__(
            d_model=d_model, nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout, activation=activation,
            custom_decoder=decoder)

        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx
        self.device = device
        self.src_emb = nn.Embedding(src_vocab_size, d_model)
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_enc = PositionalEncoding(
            d_model, dropout, max_sequence_length)
        self.linear = nn.Linear(d_model, tgt_vocab_size)

    def init_weights(self):
        """Apply Xavier-uniform init to every submodule `weight` with more than one dimension."""
        def _init_weights(m):
            if hasattr(m, 'weight') and m.weight.dim() > 1:
                nn.init.xavier_uniform_(m.weight.data)
        self.apply(_init_weights)

    def _make_key_padding_mask(self, t, pad_idx=PAD_IDX):
        """Return a boolean mask on `self.device` that is True where `t == pad_idx`."""
        mask = (t == pad_idx).to(self.device)
        return mask

    def prepare_src(self, src, src_pad_idx):
        """Embed a batch of source token ids.

        Args:
            src: token ids laid out as (N, S), batch first.
            src_pad_idx: token id treated as padding.

        Returns:
            `(src, src_key_padding_mask)`: the embedded, position-encoded
            source with the sequence axis first, and the (N, S) padding mask.
        """
        src_key_padding_mask = self._make_key_padding_mask(
            src, src_pad_idx)
        src = rearrange(src, 'N S -> S N')
        # Embeddings are scaled by sqrt(d_model) before positions are added.
        src = self.pos_enc(self.src_emb(src)
            * math.sqrt(self.d_model))
        return src, src_key_padding_mask

    def prepare_tgt(self, tgt, tgt_pad_idx):
        """Embed a batch of target token ids and build its masks.

        Args:
            tgt: token ids laid out as (N, T), batch first.
            tgt_pad_idx: token id treated as padding.

        Returns:
            `(tgt, tgt_key_padding_mask, tgt_mask)`: the embedded,
            position-encoded target with the sequence axis first, the (N, T)
            padding mask, and the square subsequent mask of size T.
        """
        tgt_key_padding_mask = self._make_key_padding_mask(
            tgt, tgt_pad_idx)
        tgt = rearrange(tgt, 'N T -> T N')
        tgt_mask = self.generate_square_subsequent_mask(
            tgt.shape[0]).to(self.device)      # <1>
        tgt = self.pos_enc(self.tgt_emb(tgt)
            * math.sqrt(self.d_model))
        return tgt, tgt_key_padding_mask, tgt_mask

    def forward(self, src, tgt):
        """Map source and target token ids to target-vocabulary scores.

        Args:
            src: source token ids, (N, S).
            tgt: target token ids, (N, T).

        Returns:
            Tensor laid out as (N, T, tgt_vocab_size).
        """
        src, src_key_padding_mask = self.prepare_src(
            src, self.src_pad_idx)
        tgt, tgt_key_padding_mask, tgt_mask = self.prepare_tgt(
            tgt, self.tgt_pad_idx)
        # The decoder attends over the encoder output, so it reuses the
        # source padding mask.
        memory_key_padding_mask = src_key_padding_mask.clone()
        output = super().forward(
            src=src,
            tgt=tgt,
            src_mask=None,
            tgt_mask=tgt_mask,
            memory_mask=None,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            )
        output = rearrange(output, 'T N E -> N T E')
        return self.linear(output)
