Implementing [Attention Is All You Need, Ashish Vaswani et al.](https://arxiv.org/pdf/1706.03762) from scratch, with PyTorch.

In [None]:
# update the version of the library 'datasets'

# !pip install --upgrade datasets
# !pip install tokenizers

In [None]:
import torch
from typing import Optional

## Implement Transformer

In [None]:
def positional_encoding(x: torch.Tensor):
    """
    Build the sinusoidal positional encoding used by the Transformer.

    x: batch, seq_len, d_model. Only its shape and device are read.
    assume d_model is even

    return: float32 tensor of shape (1, seq_len, d_model), created on x's
    device, that broadcasts over the batch dimension. For position `pos` and
    frequency index `i` the angle is pos / 10000^(2i / d_model). The sine
    terms fill the first d_model / 2 channels and the cosine terms the last
    d_model / 2 (concatenated rather than interleaved).

    Only the encoding is returned; x itself is NOT added to it.
    """
    _, seq_len, d_model = x.size()
    # (1, seq_len, 1) column of positions 0 .. seq_len - 1.
    positions = torch.arange(seq_len, dtype=torch.float32, device=x.device).unsqueeze(0).unsqueeze(-1)
    # (d_model / 2,) exponents 2i / d_model.
    exponents = torch.arange(0, d_model, 2, dtype=torch.float32, device=x.device) / d_model
    # Divide the position by the wavelength; do not raise the position to a power.
    angles = positions / (10000.0 ** exponents)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


In [None]:
def scaled_dot_product_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None):
    """
    Compute softmax(q k^T / sqrt(d)) v over the last two dimensions.

    q: ..., seq_len, d_model
    k: ..., seq_len, d_model
    v: ..., seq_len, d_model
    mask: optional tensor broadcastable to the score matrix. Wherever the
    mask equals 0 the score is replaced by -1e9 before the softmax, so that
    position receives (numerically) zero attention weight.
    """
    scale = q.size(-1) ** 0.5
    scores = (q @ k.transpose(-1, -2)) / scale

    if mask is None:
        weights = scores.softmax(dim=-1)
    else:
        blocked = mask == 0
        weights = scores.masked_fill(blocked, -1e9).softmax(dim=-1)

    return weights @ v

In [None]:
class MultiHeadAttenstion(torch.nn.Module):
    """
    Multi-head attention: project q/k/v, attend independently per head, then
    merge the heads and project back to d_model.

    d_model: width of the inputs and of the output.
    d_q, d_k, d_v: per-head widths of the query, key and value projections.
        d_q must equal d_k because queries are dotted with keys.
    num_heads: number of attention heads.
    dropout: dropout probability applied to each projection and to the output.
    """

    def __init__(self, d_model: int = 64, d_q: int = 16, d_v: int = 16,d_k: int = 16, num_heads: int = 4, dropout: float = 0.2):
        super().__init__()
        if d_q != d_k:
            raise ValueError("d_q must equal d_k so that queries can be dotted with keys")
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_k
        self.d_v = d_v
        self.d_q = d_q
        self.dropout = dropout

        self.w_q = torch.nn.Linear(self.d_model, self.d_q * self.num_heads)
        self.w_k = torch.nn.Linear(self.d_model, self.d_k * self.num_heads)
        self.w_v = torch.nn.Linear(self.d_model, self.d_v * self.num_heads)
        self.w_o = torch.nn.Linear(self.d_v * self.num_heads, self.d_model)

        self.dropout_q = torch.nn.Dropout(self.dropout)
        self.dropout_k = torch.nn.Dropout(self.dropout)
        self.dropout_v = torch.nn.Dropout(self.dropout)
        self.dropout_o = torch.nn.Dropout(self.dropout)

    def _split_heads(self, projected: torch.Tensor, d_head: int):
        """Reshape (batch, seq, heads * d_head) into (batch, heads, seq, d_head)."""
        batch_size = projected.size(0)
        # The head index lives in the LAST dimension after the projection, so
        # it has to be split off there and then moved in front of the sequence
        # axis. Viewing straight to (batch, heads, seq, d_head) would mix
        # sequence positions with heads.
        return projected.view(batch_size, -1, self.num_heads, d_head).transpose(1, 2)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        q: batch, q_len, d_model
        k: batch, kv_len, d_model
        v: batch, kv_len, d_model
        mask: optional; (batch, q_len, kv_len) or anything broadcastable to
            (batch, num_heads, q_len, kv_len). Zero entries are not attended.

        return: batch, q_len, d_model
        """
        batch_size = q.size(0)

        qw = self._split_heads(self.dropout_q(self.w_q(q)), self.d_q)
        kw = self._split_heads(self.dropout_k(self.w_k(k)), self.d_k)
        vw = self._split_heads(self.dropout_v(self.w_v(v)), self.d_v)

        if mask is not None and mask.dim() == 3:
            # Insert a head axis so a per-batch mask is shared by every head
            # instead of having its batch axis broadcast against the heads.
            mask = mask.unsqueeze(1)

        attended = scaled_dot_product_attention(qw, kw, vw, mask = mask)
        # (batch, heads, q_len, d_v) -> (batch, q_len, heads * d_v); the
        # transpose makes the tensor non-contiguous, hence .contiguous().
        merged = attended.transpose(1, 2).contiguous().view(batch_size, -1, self.d_v * self.num_heads)
        return self.dropout_o(self.w_o(merged))




In [None]:
class TransformerEncoderLayer(torch.nn.Module):
    """
    One encoder layer: self-attention followed by a position-wise
    feed-forward network, each wrapped in a residual connection and LayerNorm.

    d_model: width of the inputs and outputs.
    num_heads: number of attention heads; must divide d_model.
    d_ff: hidden width of the feed-forward network.
    dropout: dropout probability.

    Raises ValueError if d_model is not divisible by num_heads.
    """

    def __init__(self, d_model: int = 64 , num_heads: int = 4, d_ff: int = 128, dropout: float = 0.2):
        super().__init__()
        # Validate before allocating any parameters.
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.dropout = dropout

        # FFN(x) = max(0, x W1 + b1) W2 + b2. Without the ReLU the two linear
        # layers would collapse into a single affine map.
        # NOTE(review): the trailing LayerNorm is kept from the original; the
        # residual sum is normalised again by layer_norm_2 in forward().
        self.feedforward = torch.nn.Sequential(
            torch.nn.Linear(self.d_model, self.d_ff),
            torch.nn.ReLU(),
            torch.nn.Dropout(self.dropout),
            torch.nn.Linear(self.d_ff, self.d_model),
            torch.nn.LayerNorm(self.d_model)
        )

        self.layer_norm_1 = torch.nn.LayerNorm(self.d_model)
        self.layer_norm_2 = torch.nn.LayerNorm(self.d_model)

        head_dim = d_model // num_heads
        self.self_attention = MultiHeadAttenstion(d_model, head_dim, head_dim, head_dim, num_heads, dropout)

    def forward(self,x: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        x: batch, seq_len, d_model
        mask: optional attention mask forwarded to the self-attention;
            zero entries are not attended.

        return: batch, seq_len, d_model
        """
        attended = self.self_attention(x, x, x, mask)
        x2 = self.layer_norm_1(x + attended)

        transformed = self.feedforward(x2)
        return self.layer_norm_2(transformed + x2)

In [None]:
class TransformerDecoderLayer(torch.nn.Module):
    """
    One decoder layer: masked self-attention, cross-attention over the
    encoder output, and a position-wise feed-forward network. Each sub-layer
    is wrapped in a residual connection followed by LayerNorm.

    d_model: width of the inputs and outputs.
    num_heads: number of attention heads; must divide d_model.
    d_ff: hidden width of the feed-forward network.
    dropout: dropout probability.

    Raises ValueError if d_model is not divisible by num_heads.
    """

    def __init__(self, d_model: int = 64, num_heads: int = 4, d_ff: int = 128, dropout: float=0.2):
        super().__init__()
        # Validate before allocating any parameters.
        if d_model % num_heads != 0:
            raise ValueError("d_model must be divisible by num_heads")

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.dropout = dropout

        # FFN(x) = max(0, x W1 + b1) W2 + b2. Without the ReLU the two linear
        # layers would collapse into a single affine map.
        # NOTE(review): the trailing LayerNorm is kept from the original; the
        # residual sum is normalised again by layer_norm_3 in forward().
        self.feedforward = torch.nn.Sequential(
            torch.nn.Linear(self.d_model, self.d_ff),
            torch.nn.ReLU(),
            torch.nn.Dropout(self.dropout),
            torch.nn.Linear(self.d_ff, self.d_model),
            torch.nn.LayerNorm(self.d_model)
        )

        self.layer_norm_1 = torch.nn.LayerNorm(self.d_model)
        self.layer_norm_2 = torch.nn.LayerNorm(self.d_model)
        self.layer_norm_3 = torch.nn.LayerNorm(self.d_model)

        head_dim = d_model // num_heads
        self.masked_self_attention = MultiHeadAttenstion(d_model, head_dim, head_dim, head_dim, num_heads, dropout)
        self.cross_attention = MultiHeadAttenstion(d_model, head_dim, head_dim, head_dim, num_heads, dropout)

    def forward(self, x: torch.Tensor, encoding: torch.Tensor, mask: Optional[torch.Tensor] = None, memory_mask: Optional[torch.Tensor] = None):
        """
        x: batch, tgt_len, d_model. Decoder-side inputs.
        encoding: batch, src_len, d_model. Encoder output, used as key and value
            of the cross-attention.
        mask: optional mask for the self-attention over x (typically causal,
            possibly combined with target padding). Zero entries are not attended.
        memory_mask: optional mask for the cross-attention, broadcastable to
            (batch, num_heads, tgt_len, src_len); use it to hide padded encoder
            positions. None (the default) attends to every encoder position.

        return: batch, tgt_len, d_model
        """
        x1 = self.layer_norm_1(self.masked_self_attention(x, x, x, mask) + x)
        # use encoding as key and value
        x2 = self.layer_norm_2(self.cross_attention(x1, encoding, encoding, memory_mask) + x1)

        x_ff = self.feedforward(x2)
        return self.layer_norm_3(x_ff + x2)



In [None]:
class TransformerEncoder(torch.nn.Module):
    """
    Stack of `num_layers` TransformerEncoderLayer modules applied one after
    another; every layer receives the same attention mask.
    """

    def __init__(self, d_model: int = 64, num_heads: int = 4, d_ff: int = 128, dropout: float = 0.2, num_layers: int = 4):
        super().__init__()
        # Keep the hyper-parameters around for inspection.
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.drop_out = dropout
        self.num_layers = num_layers

        stack = []
        for _ in range(num_layers):
            stack.append(TransformerEncoderLayer(d_model, num_heads, d_ff, dropout))
        self.encoder_layers = torch.nn.ModuleList(stack)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor]=None):
        """
        x: input to the first layer.
        mask: optional attention mask handed unchanged to every layer.

        return: output of the last layer (x itself when there are no layers).
        """
        hidden = x
        for layer in self.encoder_layers:
            hidden = layer(hidden, mask)
        return hidden


In [None]:
class TransformerDecoder(torch.nn.Module):
    """
    Stack of `num_layers` TransformerDecoderLayer modules applied one after
    another. Builds the self-attention mask once and hands it to every layer.
    """

    def __init__(self, d_model: int = 64, num_heads: int = 4, d_ff: int = 128, dropout: float = 0.2, num_layers: int = 4):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.drop_out = dropout
        self.num_layers = num_layers

        self.decoder_layers = torch.nn.ModuleList([
            TransformerDecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])

    def forward(self, x:torch.Tensor, encoding: torch.Tensor, src_mask: Optional[torch.Tensor] = None, tgt_mask: Optional[torch.Tensor] = None, is_causal: bool = True):
        """
        x: batch, seq_len, d_model. Decoder-side inputs.
        encoding: encoder output, used by the cross-attention of every layer.
        src_mask: optional extra mask (e.g. padding) that is AND-ed with
            tgt_mask. Shape (batch, seq_len, seq_len) or
            (batch, num_heads, seq_len, seq_len).
        tgt_mask: optional mask for the self-attention over x, same shapes as
            src_mask. When None, a causal (lower-triangular) mask is built if
            is_causal is True, otherwise an all-ones mask.
        is_causal: only consulted when tgt_mask is None.

        In every mask, 1 means "attend" and 0 means "ignore".

        return: batch, seq_len, d_model
        Raises ValueError if src_mask is neither 3-D nor 4-D.
        """
        batch_size = x.size(0)
        seq_len = x.size(1)

        if tgt_mask is None:
            # Build on x's device so the mask can be compared against scores
            # that live on an accelerator.
            tgt_mask = torch.ones(batch_size, self.num_heads, seq_len, seq_len, device=x.device)
            if is_causal:
                # Keep the main diagonal: every position must be able to
                # attend to itself, otherwise the first row is fully masked.
                tgt_mask = torch.tril(tgt_mask)
        elif tgt_mask.dim() == 3:
            tgt_mask = tgt_mask.unsqueeze(1)

        if src_mask is None:
            mask = tgt_mask
        else:
            if src_mask.dim() == 3:
                src_mask = src_mask.unsqueeze(1)
            elif src_mask.dim() != 4:
                raise ValueError("The src mask should have shape of"
                " (Batchsize, num_sequence, num_sequences) or "
                "(batch_size, num_heads, num_sequence, num_sequecnes)")
            # A position is visible only when BOTH masks allow it. Adding the
            # masks would behave like OR and re-open causally hidden positions.
            src_mask = src_mask.to(tgt_mask.device)
            mask = torch.logical_and(src_mask != 0, tgt_mask != 0).to(tgt_mask.dtype)

        for decoder_layer in self.decoder_layers:
            x = decoder_layer(x, encoding, mask)
        return x

In [None]:
class Transfomer(torch.nn.Module):
    """
    Encoder-decoder Transformer operating on already-embedded sequences.

    num_encoder_layers / num_decoder_layers: depth of each stack.
    model_dim: width of the embeddings flowing through the model.
    num_heads: number of attention heads in every attention module.
    encoder_ff_dim / decoder_ff_dim: hidden width of the feed-forward
        networks in the encoder and decoder respectively.
    output_dim: size of the final projection (e.g. target vocabulary size).
    dropout: dropout probability.
    """

    def __init__(self, num_encoder_layers: int = 6, num_decoder_layers: int = 6, model_dim: int = 128, num_heads: int = 4, encoder_ff_dim: int = 128, decoder_ff_dim: int = 128, output_dim: int = 128, dropout:float = 0.2):

        super().__init__()
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.encoder_ff_dim = encoder_ff_dim
        self.decoder_ff_dim = decoder_ff_dim
        self.dropout = dropout
        self.output_dim = output_dim

        self.encoder = TransformerEncoder(model_dim, num_heads, encoder_ff_dim, dropout, num_encoder_layers)
        # The decoder gets its own feed-forward width, not the encoder's.
        self.decoder = TransformerDecoder(model_dim, num_heads, decoder_ff_dim, dropout, num_decoder_layers)
        self.linear = torch.nn.Linear(model_dim, output_dim)
        self.dropout_linear = torch.nn.Dropout(dropout)

    def forward(self,
                src: torch.Tensor,
                tgt: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                tgt_mask: Optional[torch.Tensor] = None
                ):
        """
            src: input vectors, to pass through the encoder
            tgt: target vectors, it will pass through decoder, with encoded srcs.
            src_mask: source mask. Used for masking [PAD] tokens in the encoder.
            tgt_mask: target_mask. Used for causal masking. If None is inputted,
            then automatically causal masking is applied.

            return: probabilities over output_dim for every target position
            (softmax over the last dimension).
        """

        encoder_memory = self.encoder(src, src_mask)
        # Pass tgt_mask by keyword: the decoder's third positional parameter
        # is src_mask, so a positional call would silently mislabel the mask.
        output = self.decoder(tgt, encoder_memory, tgt_mask=tgt_mask)
        # NOTE(review): dropout on the output logits is kept from the original.
        logit = self.dropout_linear(self.linear(output))
        return torch.softmax(logit, dim = -1)

### Test codes

## Implement Byte-Pair Tokenizer

In [None]:
#class BytePairTokenizer:


## Put everything together to build a Translator model

In [None]:
from tokenizers import Tokenizer
from tokenizers.models import BPE

class TransformerTranslator(torch.nn.Module):
    """
    Translation model: token embeddings plus positional encodings in front of
    a Transfomer, together with mask construction and greedy decoding.

    Special token ids (they follow the order of `special_tokens` handed to the
    BPE trainers in this file):

    [UNK]: ID 0. Unknown tokens
    [CLS]: ID 1, class tokens. Used for denoting the start of sequence.
    [EOS]: ID 2, End of sequences token.
    [SEP]: ID 3, Separator token. Used for separating different
     sentences of  texts merged in one sequence.
    [PAD]: ID 4, PADDING Token.
    [MASK]: ID 5, Masking Token. Used for masked generation.
    """

    UNK_ID = 0
    CLS_ID = 1
    EOS_ID = 2
    SEP_ID = 3
    PAD_ID = 4
    MASK_ID = 5

    def __init__(
        self,
        src_tokenizer: Tokenizer,
        tgt_tokenizer: Tokenizer,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        model_dim: int = 128,
        num_heads: int = 4,
        encoder_ff_dim: int = 128,
        decoder_ff_dim: int = 128,
        dropout:float = 0.2,
        seq_len: int = 128

    ):
        """
        src_tokenizer(Tokenizer): tokenizer for source language
        tgt_tokenizer(Tokenizer): tokenizer for target language
        num_encoder_layers(int): number of encoder layers
        num_decoder_layers(int): number of decoder layers
        model_dim(int): dimension of model
        num_heads(int): number of heads
        encoder_ff_dim(int): dimension of feedforward network in encoder
        decoder_ff_dim(int): dimension of feedforward network in decoder
        dropout(float): dropout rate
        seq_len(int): length that raw text is padded/truncated to, and the
            maximum number of tokens produced by generate()

        The tokenization result, i.e. encoding will be torch.Tensor of integer sequences, of the shape (batch_size, seq_len)
        """
        super().__init__()

        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.num_encoder_layers = num_encoder_layers
        self.num_decoder_layers = num_decoder_layers
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.encoder_ff_dim = encoder_ff_dim
        self.decoder_ff_dim = decoder_ff_dim
        self.seq_len = seq_len

        self.src_vocab_size = src_tokenizer.get_vocab_size()
        self.tgt_vocab_size = tgt_tokenizer.get_vocab_size()
        self.src_embedding = torch.nn.Embedding(num_embeddings=self.src_vocab_size,
                                                embedding_dim = model_dim)
        self.tgt_embedding = torch.nn.Embedding(num_embeddings=self.tgt_vocab_size,
                                                embedding_dim = model_dim)

        self.transformer = Transfomer(
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers = num_decoder_layers,
            model_dim = model_dim,
            num_heads = num_heads,
            encoder_ff_dim = encoder_ff_dim,
            decoder_ff_dim = decoder_ff_dim,
            output_dim = self.tgt_vocab_size,
            dropout = dropout
        )

    def forward(self, src: torch.Tensor, tgt:torch.Tensor, src_mask: Optional[torch.Tensor] = None, tgt_mask: Optional[torch.Tensor]=None, src_enc = None):
        """
        src(torch.Tensor): batch of source token sequences, Datatype is integer. shape: (batch_size, seq_len)
        tgt(torch.Tensor): batch of target token sequences,  Datatype is integer. shape: (batch_size, seq_len)
        src_mask(torch.Tensor): mask for source sequences. Usually used for masking [PAD] tokens. shape: (batch_size, seq_len, seq_len)
            or (batch_size, num_heads, seq_len, seq_len). Built from [PAD] positions when None.
        tgt_mask(torch.Tensor): mask for target sequences. Usually used for masking [PAD] tokens for target sequences and causal masking.
            Same shapes as src_mask. Built as (padding AND causal) when None.
        src_enc: optional pre-computed encoder output. When given, the encoder
            is skipped and src / src_mask are not used.

        return
            probs(torch.Tensor): batch of probabilities, shape: (batch_size, seq_len, tgt_vocab_size)
        """
        if tgt_mask is None:
            # A target position is visible only if it is not padding AND not in
            # the future, hence the product (adding the masks would act as OR).
            tgt_mask = self.get_padding_mask(tgt) * self.get_causal_mask(tgt)
        else:
            tgt_mask = self._with_head_axis(tgt_mask)

        tgt_emb = self.compute_tgt_embedding(tgt)

        if src_enc is None:
            if src_mask is None:
                # Automatically build padding mask for source sequence
                src_mask = self.get_padding_mask(src)
            else:
                src_mask = self._with_head_axis(src_mask)
            src_emb = self.compute_src_embedding(src)
            return self.transformer(src_emb, tgt_emb, src_mask, tgt_mask)

        # Encoder output supplied by the caller: run the decoder only, then
        # apply the same output head as Transfomer.forward so that this branch
        # also returns probabilities rather than raw hidden states.
        hidden = self.transformer.decoder(tgt_emb, src_enc, tgt_mask=tgt_mask)
        logits = self.transformer.dropout_linear(self.transformer.linear(hidden))
        return torch.softmax(logits, dim=-1)

    @staticmethod
    def _with_head_axis(mask: torch.Tensor):
        """Turn a (batch, len, len) mask into (batch, 1, len, len); leave other ranks untouched."""
        return mask.unsqueeze(1) if mask.dim() == 3 else mask

    def compute_src_embedding(self, src: torch.Tensor):
        """Embed source token ids (batch, seq_len) and add the positional encoding."""
        emb = self.src_embedding(src)
        # positional_encoding returns only the encoding, so it has to be ADDED
        # to the embedding; returning it alone would discard the tokens.
        return emb + positional_encoding(emb).to(emb.device)

    def compute_tgt_embedding(self, tgt: torch.Tensor):
        """Embed target token ids (batch, seq_len) and add the positional encoding."""
        emb = self.tgt_embedding(tgt)
        return emb + positional_encoding(emb).to(emb.device)

    def get_padding_mask(self, seqs :torch.Tensor):
        """
        Compute padding mask for masking [PAD] tokens in the input sequence.

        seqs(torch.Tensor): tokenized source sequences. shape of (batch, sequence_length)
        Then outputs padding mask of (batch, num_heads, sequence_length, sequence_legnth), where output[b, k, i, j] = 0 if
        seqs[b,i] = <pad> or seqs[b, j] = <pad>.

        """
        keep = (seqs != self.PAD_ID).to(torch.long)
        # Outer product via broadcasting; avoids an integer matmul.
        pairwise = keep.unsqueeze(-1) * keep.unsqueeze(-2)
        return pairwise.unsqueeze(1).expand(-1, self.num_heads, -1, -1)

    def get_causal_mask(self, seqs: torch.Tensor):
        """
        Construct causal mask. expect seqs to have shape of (batch, sequence_length).
        Then outputs mask of shape (batch, num_heads, sequence_length, sequence_length)
        where output[b, k, i, j] = 1 if j <= i and 0 otherwise.
        """
        batch_size, length = seqs.size(0), seqs.size(1)
        ones = torch.ones(batch_size, self.num_heads, length, length, device=seqs.device)
        # Keep the diagonal so every position can attend to itself.
        return torch.tril(ones)

    def _encode_text(self, tokenizer, texts):
        """Tokenize raw sentences and pad/truncate each one to self.seq_len ids."""
        rows = []
        for encoding in tokenizer.encode_batch(list(texts)):
            ids = list(encoding.ids)[: self.seq_len]
            rows.append(ids + [self.PAD_ID] * (self.seq_len - len(ids)))
        device = self.src_embedding.weight.device
        return torch.tensor(rows, dtype=torch.long, device=device).reshape(-1, self.seq_len)

    def _greedy_decode(self, src_tokens: torch.Tensor):
        """Greedily decode target token ids (batch, self.seq_len) for source token ids."""
        batch_size = src_tokens.size(0)
        device = src_tokens.device

        src_mask = self.get_padding_mask(src_tokens)
        memory = self.transformer.encoder(self.compute_src_embedding(src_tokens), src_mask)

        tgt = torch.full((batch_size, self.seq_len), self.PAD_ID, dtype=torch.long, device=device)
        tgt[:, 0] = self.CLS_ID
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for step in range(1, self.seq_len):
            prefix = tgt[:, :step]
            hidden = self.transformer.decoder(
                self.compute_tgt_embedding(prefix),
                memory,
                tgt_mask=self.get_causal_mask(prefix),
            )
            # Only the last position predicts the next token.
            next_tokens = self.transformer.linear(hidden[:, -1]).argmax(dim=-1)
            # Sequences that already emitted [EOS] keep being padded.
            next_tokens = next_tokens.masked_fill(finished, self.PAD_ID)
            tgt[:, step] = next_tokens
            finished |= next_tokens == self.EOS_ID
            if bool(finished.all()):
                break

        return tgt

    @torch.no_grad()
    def generate(self, src):
        """
        Greedily translate a batch of source sentences.

        src: either a torch.Tensor of source token ids with shape
            (batch_size, seq_len), or a sequence of raw source strings.

        return
            For tensor input: torch.Tensor of target token ids, shape
            (batch_size, self.seq_len), starting with [CLS] and padded with
            [PAD] after [EOS].
            For string input: list of decoded target strings (special tokens
            removed).
        """
        is_text = not isinstance(src, torch.Tensor)
        src_tokens = self._encode_text(self.src_tokenizer, src) if is_text else src

        was_training = self.training
        self.eval()  # decoding must not be perturbed by dropout
        try:
            tgt = self._greedy_decode(src_tokens)
        finally:
            self.train(was_training)

        if is_text:
            return self.tgt_tokenizer.decode_batch(tgt.tolist(), skip_special_tokens=True)
        return tgt



Model Usage

1. The dataloader outputs raw text data.
2. All methods except generate() take tokenized sequences as input.
3. The generate() method takes raw text as input and also outputs text data.

## Load Dataset

In [None]:
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

# Training split of the WMT14 German-English translation corpus from the Hugging Face hub.
dataset_hf = load_dataset(path="wmt/wmt14", name="de-en", split="train")

In [None]:
# Expose the corpus in torch format and iterate over it 32 samples at a time.
dataset = dataset_hf.with_format("torch")
dataloader = DataLoader(dataset, batch_size=32)

In [None]:
class WMTDataset(Dataset):
    """
    Map-style dataset of raw (English, German) sentence pairs from WMT14.

    src_tokenizer: tokenizer for the source (English) side; trained from the
        dataset with get_eng_tokenizer when None.
    tgt_tokenizer: tokenizer for the target (German) side; trained from the
        dataset with get_de_tokenizer when None.
    dataset_from_hub: dataset object exposing select_columns(...) and
        iter(batch_size) whose "translation" column holds {"en", "de"} dicts;
        downloaded with load_dataset("wmt/wmt14", "de-en", split=split) when None.
    split: split to download when dataset_from_hub is None.

    Items are raw text; the tokenizers are kept on the instance as
    `src_tokenizer` / `tgt_tokenizer` for whoever consumes the batches.
    """

    def __init__(self, src_tokenizer=None, tgt_tokenizer= None, dataset_from_hub = None, split = "train"):
        super().__init__()
        if dataset_from_hub is None:
            print("dataset is not fed. Download from the huggingface hub.")
            dataset_from_hub = load_dataset("wmt/wmt14", "de-en", split=split)
        if src_tokenizer is None:
            print("Source tokenizer is not fed. Build from scratch.")
            src_tokenizer = get_eng_tokenizer(dataset_from_hub)
        if tgt_tokenizer is None:
            print("Target tokenizer is not fed. Build it from the scratch.")
            tgt_tokenizer = get_de_tokenizer(dataset_from_hub)

        # Keep the tokenizers; previously they were resolved and then dropped.
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer

        texts = dataset_from_hub.select_columns(["translation"])

        eng_texts, de_texts = [], []
        # Read in chunks of 1000 rows rather than row by row.
        for batch in texts.iter(1000):
            for sample in batch['translation']:
                eng_texts.append(sample['en'])
                de_texts.append(sample['de'])

        self.eng_texts = eng_texts
        self.de_texts = de_texts

    def __len__(self):
        """Number of sentence pairs."""
        return len(self.eng_texts)

    def __getitem__(self, idx):
        """Return the idx-th pair as (english_text, german_text)."""
        return self.eng_texts[idx], self.de_texts[idx]


        

def _train_bpe_tokenizer(dataset_from_hub, language, vocab_size=1000, batch_size=1000):
    """
    Train a whitespace-pre-tokenized BPE tokenizer on one side of the
    translation pairs.

    dataset_from_hub: dataset exposing select_columns(...) and iter(batch_size)
        whose "translation" column holds dicts keyed by language code.
    language: key of the side to train on ('en' or 'de').
    vocab_size: target vocabulary size handed to the BPE trainer.
    batch_size: number of rows read from the dataset per chunk.

    The special tokens are registered in a fixed order, which gives them the
    ids 0..5: [UNK], [CLS], [EOS], [SEP], [PAD], [MASK].
    """
    def text_batches():
        tok_dataset = dataset_from_hub.select_columns(["translation"])
        for batch in tok_dataset.iter(batch_size):
            yield [sample[language] for sample in batch["translation"]]

    tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = BpeTrainer(vocab_size=vocab_size, special_tokens=[
        "[UNK]",
        "[CLS]",
        "[EOS]",
        "[SEP]",
        "[PAD]",
        "[MASK]"
    ])

    tokenizer.train_from_iterator(text_batches(), trainer)
    return tokenizer


def get_eng_tokenizer(dataset_from_hub, vocab_size=1000):
    """Train and return a BPE tokenizer on the English ('en') side of dataset_from_hub."""
    return _train_bpe_tokenizer(dataset_from_hub, 'en', vocab_size=vocab_size)


def get_de_tokenizer(dataset_from_hub, vocab_size=1000):
    """Train and return a BPE tokenizer on the German ('de') side of dataset_from_hub."""
    return _train_bpe_tokenizer(dataset_from_hub, 'de', vocab_size=vocab_size)



def eng_dataloader(batch_size = 1000):
    """
    Generator over the English side of the module-level `dataset_hf`.

    Yields one list of English sentences for every chunk of `batch_size`
    rows read from the "translation" column.
    """
    translations = dataset_hf.select_columns(["translation"])

    for chunk in translations.iter(batch_size):
        yield [pair['en'] for pair in chunk["translation"]]

def de_dataloader(batch_size = 1000):
    """
    Generator over the German side of the module-level `dataset_hf`.

    Yields one list of German sentences for every chunk of `batch_size`
    rows read from the "translation" column.
    """
    translations = dataset_hf.select_columns(["translation"])

    for chunk in translations.iter(batch_size):
        yield [pair['de'] for pair in chunk["translation"]]


## Tokenizer

In [None]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace

# Train a 1000-entry BPE vocabulary on the English sentences and persist it.
# The order of the special tokens fixes their ids (0 = [UNK] ... 5 = [MASK]).
en_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
en_tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(
    vocab_size=1000,
    special_tokens=["[UNK]", "[CLS]", "[EOS]", "[SEP]", "[PAD]", "[MASK]"],
)

en_tokenizer.train_from_iterator(eng_dataloader(), trainer)
en_tokenizer.save("en-tokenizer.json")

In [None]:
# Train a 1000-entry BPE vocabulary on the German sentences and persist it.
# The order of the special tokens fixes their ids (0 = [UNK] ... 5 = [MASK]).
de_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
de_tokenizer.pre_tokenizer = Whitespace()
trainer = BpeTrainer(
    vocab_size=1000,
    special_tokens=["[UNK]", "[CLS]", "[EOS]", "[SEP]", "[PAD]", "[MASK]"],
)

de_tokenizer.train_from_iterator(de_dataloader(), trainer)
de_tokenizer.save("de-tokenizer.json")

## Train model

In [None]:
class TransformerTranslatorTrainer:
    """
    Training / evaluation loop for a translator model that consumes token-id
    tensors and returns per-position probabilities.

    model: module called as model(src_ids, tgt_ids) that returns probabilities
        of shape (batch, tgt_len, vocab) and exposes a `seq_len` attribute.
    optimizer: optimizer over the model's parameters.
    src_tokenizer / tgt_tokenizer: tokenizers whose encode_batch(list_of_str)
        returns objects with an `.ids` list.
    train_dataloader / val_dataloader / test_dataloader: iterables of raw-text
        batches (see _unpack_batch for the accepted layouts).
    criterion: loss called as criterion(log_probs, labels) with log_probs of
        shape (batch * tgt_len, vocab) and labels of shape (batch * tgt_len,).
        Because the model outputs probabilities, an NLL-style loss such as
        torch.nn.NLLLoss(ignore_index=4) is the matching choice.
    """

    # Special-token ids, in the order they are registered with the BPE trainers.
    CLS_ID = 1
    EOS_ID = 2
    PAD_ID = 4

    def __init__(
        self,
        model,
        optimizer,
        src_tokenizer,
        tgt_tokenizer,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        criterion,
    ):
        self.model = model
        self.optimizer = optimizer
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.criterion = criterion

    @staticmethod
    def _unpack_batch(batch):
        """
        Split a dataloader batch into (source_texts, target_texts).

        Accepts either a (source_texts, target_texts) pair or a dict of the
        form {"translation": {"en": [...], "de": [...]}}, with English as the
        source side.
        NOTE(review): the dict layout is what a collated WMT batch is assumed
        to look like - confirm against the dataloader actually used.
        """
        if isinstance(batch, dict):
            pairs = batch["translation"]
            return pairs["en"], pairs["de"]
        src, tgt = batch
        return src, tgt

    def _encode(self, tokenizer, texts, length, add_boundaries):
        """
        Tokenize raw sentences into a (batch, length) LongTensor on the
        model's device, truncating and padding with [PAD] as needed. With
        add_boundaries every sequence is wrapped as [CLS] ... [EOS].
        """
        rows = []
        for encoding in tokenizer.encode_batch(list(texts)):
            ids = list(encoding.ids)
            if add_boundaries:
                ids = [self.CLS_ID] + ids[: length - 2] + [self.EOS_ID]
            else:
                ids = ids[:length]
            rows.append(ids + [self.PAD_ID] * (length - len(ids)))
        device = next(self.model.parameters()).device
        return torch.tensor(rows, dtype=torch.long, device=device).reshape(-1, length)

    def _teacher_forced_loss(self, src, tgt):
        """Loss of predicting each target token from the gold prefix before it."""
        seq_len = self.model.seq_len
        src_ids = self._encode(self.src_tokenizer, src, seq_len, add_boundaries=False)
        # One extra token so that input and labels both have length seq_len
        # after shifting by one position.
        tgt_ids = self._encode(self.tgt_tokenizer, tgt, seq_len + 1, add_boundaries=True)
        decoder_input, labels = tgt_ids[:, :-1], tgt_ids[:, 1:]

        probs = self.model(src_ids, decoder_input)
        # Clamp before the log so an exactly-zero probability cannot yield -inf.
        log_probs = torch.log(probs.clamp_min(1e-9))
        return self.criterion(
            log_probs.reshape(-1, log_probs.size(-1)),
            labels.reshape(-1),
        )

    def train_one_step(
        self, src, tgt, teacher_forcing: bool = True
    ):
        """
        compute loss, step optimizers, return training statistics.
        Use teacher forcing if teacher_forcing is True

        src: batch of raw source sentences (sequence of str).
        tgt: batch of raw target sentences (sequence of str).

        return: {"loss": float}
        Raises NotImplementedError when teacher_forcing is False.
        """
        if not teacher_forcing:
            raise NotImplementedError(
                "training without teacher forcing is not implemented"
            )

        self.model.train()
        self.optimizer.zero_grad()
        loss = self._teacher_forced_loss(src, tgt)
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}

    def test(
        self,
    ):
        """
        Do the model test and return test statistics.

        return: {"loss": mean teacher-forced loss over test_dataloader
        (nan when the dataloader is empty), "num_batches": int}
        """
        self.model.eval()
        total, num_batches = 0.0, 0
        with torch.no_grad():
            for batch in self.test_dataloader:
                src, tgt = self._unpack_batch(batch)
                total += self._teacher_forced_loss(src, tgt).item()
                num_batches += 1
        mean_loss = total / num_batches if num_batches else float("nan")
        return {"loss": mean_loss, "num_batches": num_batches}

    def train(self, num_epochs):
        """
        Run num_epochs passes over train_dataloader.

        return: list with the mean training loss of each epoch (nan for an
        epoch whose dataloader yielded no batches).
        """
        history = []
        for _ in range(num_epochs):
            losses = []
            for batch in self.train_dataloader:
                src, tgt = self._unpack_batch(batch)
                losses.append(self.train_one_step(src, tgt)["loss"])
            history.append(sum(losses) / len(losses) if losses else float("nan"))
        return history

## Test Model

## Inference