In [2]:
!pip install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cpu

Looking in indexes: https://download.pytorch.org/whl/cpu
[33mDEPRECATION: omegaconf 2.0.6 has a non-standard dependency specifier PyYAML>=5.1.*. pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of omegaconf or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0m

In [248]:
# https://github.com/hyunwoongko/transformer

In [3]:
import math
import torch
import pytorch_lightning as PL
from torch import nn
from torch import optim
from torch import Tensor
from torch.nn import functional as F

### Prepare dataset

In [223]:
pairs = [
    ("<bos> hello <eos>","<bos> hi , how are you ? <eos>"),
    ("<bos> how are you ? <eos>", "<bos> i am fine <eos>"),
    ("<bos> how is it doing ? <eos>", "<bos> i am good <eos>"),
]

def add_word(vocab: dict, word: str):
    if word not in vocab:
        vocab[word] = len(vocab)

def add_sent(vocab: dict, sent: str):
    for word in sent.split(" "):
        add_word(vocab, word)
vocab = {
    "<unk>": 0,
    "<bos>": 1,
    "<eos>": 2
}

for p in pairs:
    add_sent(vocab, p[0])
    add_sent(vocab, p[1])
vocab

{'<unk>': 0,
 '<bos>': 1,
 '<eos>': 2,
 'hello': 3,
 'hi': 4,
 ',': 5,
 'how': 6,
 'are': 7,
 'you': 8,
 '?': 9,
 'i': 10,
 'am': 11,
 'fine': 12,
 'is': 13,
 'it': 14,
 'doing': 15,
 'good': 16}

In [224]:
idx2word = {value: key for key, value in vocab.items()}
idx2word

{0: '<unk>',
 1: '<bos>',
 2: '<eos>',
 3: 'hello',
 4: 'hi',
 5: ',',
 6: 'how',
 7: 'are',
 8: 'you',
 9: '?',
 10: 'i',
 11: 'am',
 12: 'fine',
 13: 'is',
 14: 'it',
 15: 'doing',
 16: 'good'}

In [225]:
def to_tensor(vocab: dict, sent: str):
    words = sent.split(" ")
    ts = torch.ones(len(words), dtype=torch.long)
    for i, word in enumerate(words):
        ts[i] = vocab[word]
    return ts

In [226]:
src, tgt = pairs[2]
src = to_tensor(vocab, src)
tgt = to_tensor(vocab, tgt)

### Square mask

In [227]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz))) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]
    
    src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool)
    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)

    src_padding_mask = torch.ones(src_seq_len, dtype=torch.bool)
    tgt_padding_mask = torch.ones(tgt_seq_len, dtype=torch.bool)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [228]:
src.shape, tgt.shape

(torch.Size([7]), torch.Size([5]))

In [229]:
src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt)
src_mask.shape, tgt_mask.shape, src_padding_mask.shape, tgt_padding_mask.shape

(torch.Size([7, 7]), torch.Size([5, 5]), torch.Size([7]), torch.Size([5]))

### Seq2Seq Transformer

In [238]:
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        # This is for batch only, we test on unbatch
        # pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 emb_size: int,
                 vocab_size: int,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        
        self.transformer = nn.Transformer(
            d_model=emb_size, nhead=4,
            num_encoder_layers=4, 
            num_decoder_layers=4,
            dim_feedforward=256,
            dropout=dropout
        )
        
        self.token_emb = TokenEmbedding(vocab_size, emb_size)
        self.pos_encoding = PositionalEncoding(emb_size, dropout=dropout)
        self.generator = nn.Linear(emb_size, vocab_size)

    def forward(self, src: Tensor, tgt: Tensor, src_mask: Tensor, tgt_mask: Tensor):
        src_emb = self.pos_encoding(self.token_emb(src))
        tgt_emb = self.pos_encoding(self.token_emb(tgt))
        # outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask)
        outs = self.transformer(src_emb, tgt_emb)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        # return self.transformer.encoder(self.pos_encoding(self.token_emb(src)), src_mask)
        return self.transformer.encoder(self.pos_encoding(self.token_emb(src)))

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        # return self.transformer.decoder(self.pos_encoding(self.token_emb(tgt)), memory, tgt_mask)
        return self.transformer.decoder(self.pos_encoding(self.token_emb(tgt)), memory)

In [239]:
vocab_size = len(vocab)
emb_size = 128
my_model = Seq2SeqTransformer(emb_size, len(vocab))
optimizer = optim.Adam(my_model.parameters(), lr=0.001, eps=1e-09, betas=[0.9, 0.98])
criterion = nn.CrossEntropyLoss()

### Train the model

In [240]:
def train_epoch(model: nn.Module, optimizer: optim.Optimizer, criterion):
    cost = 0
    for p in pairs:
        src = to_tensor(vocab, p[0])
        tgt = to_tensor(vocab, p[1])
        tgt_inp = tgt[:-1]
        tgt_out = tgt[1:]
        src_mask, tgt_mask, _, _ = create_mask(src, tgt_inp)
        logits = model(src, tgt_inp, src_mask, tgt_mask)
        loss = criterion(logits, tgt_out)
        cost += loss
        optimizer.zero_grad()        
        loss.backward()
        optimizer.step()
    return cost / len(pairs)
    
for epoch in range(1, 100 + 1):
    cost = train_epoch(my_model, optimizer, criterion)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, cost={cost}")

Epoch 10, cost=1.2706842422485352
Epoch 20, cost=0.13766483962535858
Epoch 30, cost=0.038455624133348465
Epoch 40, cost=0.014489025808870792
Epoch 50, cost=0.01110044214874506
Epoch 60, cost=0.009455631487071514
Epoch 70, cost=0.005460354965180159
Epoch 80, cost=0.0043428209610283375
Epoch 90, cost=0.002672881120815873
Epoch 100, cost=0.0022111504804342985


### Greedy decoding

In [247]:
def greedy_decode(model: nn.Module, src: str, bos_token: int, eos_token: int, max_len: int = 20):
    src = to_tensor(vocab, src)
    tgt = torch.ones(1).fill_(bos_token)
    src_seq_len = src.shape[0]
    src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool)
    memory = model.encode(src, src_mask)

    for i in range(max_len - 1):
        tgt_seq_len = tgt.shape[0]
        tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
        decode_out = model.decode(tgt, memory, tgt_mask)
        logits = model.generator(decode_out)[-1, :]
        logits = logits.view(1, -1)
        _, next_word = torch.max(logits, dim=1)
        tgt = torch.cat([tgt, next_word], dim=0)
        if next_word.item() == eos_token:
            break
    sent = " ".join(idx2word[x.item()] for x in tgt)
    return sent

src = "<bos> how are you <eos>"
print(f"Input: {src}")
print(greedy_decode(my_model, src, vocab["<bos>"], vocab["<eos>"]))

Input: <bos> how are you <eos>
<bos> i am fine <eos>
