In [1]:
# ! python -m spacy download en_core_web_sm --quiet
# ! python -m spacy download de_core_news_sm --quiet

In [2]:
# references: 
# https://hussainwali.medium.com/transforming-your-text-data-with-pytorch-12ec1b1c9ae6
# https://github.com/bentrevett/pytorch-seq2seq/tree/main
# https://adeveloperdiary.com/data-science/deep-learning/nlp/machine-translation-recurrent-neural-network-pytorch/

In [3]:
import random

import datasets
import numpy as np
import spacy
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
import tqdm
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Multi30k English–German parallel corpus, split into train / validation / test
dataset = datasets.load_dataset("bentrevett/multi30k")
train_data = dataset["train"]
valid_data = dataset["validation"]
test_data = dataset["test"]

In [5]:
# spaCy pipelines used for tokenization (English first, then German)
en_nlp, de_nlp = (spacy.load(name) for name in ("en_core_web_sm", "de_core_news_sm"))

In [6]:
def tokenize_example(example, en_nlp, de_nlp, max_length, lower, sos_token, eos_token):
    """Tokenize one parallel English/German example.

    Args:
        example: mapping with the raw sentences under the "en" and "de" keys.
        en_nlp, de_nlp: spaCy pipelines; only their ``tokenizer`` is used.
        max_length: maximum number of tokens kept per sentence, counted before
            the ``sos_token``/``eos_token`` markers are added.
        lower: if True, lower-case every token; if False, keep the original case.
        sos_token, eos_token: markers prepended / appended to each token list.

    Returns:
        dict with the "en_tokens" and "de_tokens" lists.
    """

    def _tokenize(nlp, text):
        # Tokenize, truncate, optionally lower-case, then wrap with the markers.
        tokens = [token.text for token in nlp.tokenizer(text)][:max_length]
        if lower:
            tokens = [token.lower() for token in tokens]
        return [sos_token] + tokens + [eos_token]

    return {
        "en_tokens": _tokenize(en_nlp, example["en"]),
        "de_tokens": _tokenize(de_nlp, example["de"]),
    }

In [7]:
# preprocessing settings shared by every split
max_length = 1_000
lower = True
sos_token = "<sos>"
eos_token = "<eos>"

fn_kwargs = {"en_nlp": en_nlp, "de_nlp": de_nlp, "max_length": max_length,
             "lower": lower, "sos_token": sos_token, "eos_token": eos_token}

# tokenize train, validation and test (in that order) with identical settings
train_data, valid_data, test_data = (
    split.map(tokenize_example, fn_kwargs=fn_kwargs)
    for split in (train_data, valid_data, test_data)
)

Map: 100%|███████████████████████████████| 1000/1000 [00:00<00:00, 2139.40 examples/s]


In [8]:
# vocabulary settings
min_freq = 2
# upper bound on vocabulary size (special tokens included); tokens outside the
# most frequent ones fall back to <unk>
max_vocab_size = 1_000
unk_token = "<unk>"
pad_token = "<pad>"

special_tokens = [
    unk_token,
    pad_token,
    sos_token,
    eos_token,
]


def build_vocab(token_lists):
    """Build a torchtext vocab from tokenized sentences using the shared settings above."""
    return torchtext.vocab.build_vocab_from_iterator(
        token_lists,
        min_freq=min_freq,
        specials=special_tokens,
        max_tokens=max_vocab_size,
    )


en_vocab = build_vocab(train_data["en_tokens"])
de_vocab = build_vocab(train_data["de_tokens"])

In [9]:
# Both vocabularies are built with the same `specials` list, so their special
# indices must agree; the asserts make any divergence fail loudly.
assert en_vocab[unk_token] == de_vocab[unk_token]
assert en_vocab[pad_token] == de_vocab[pad_token]

unk_index, pad_index = en_vocab[unk_token], en_vocab[pad_token]

# out-of-vocabulary lookups resolve to unk_index in both vocabularies
en_vocab.set_default_index(unk_index)
de_vocab.set_default_index(unk_index)

In [10]:
def numericalize_example(example, en_vocab, de_vocab):
    """Map an example's token lists to vocabulary ids.

    Reads "en_tokens"/"de_tokens" and returns {"en_ids": [...], "de_ids": [...]}.
    """
    columns = (
        ("en_tokens", "en_ids", en_vocab),
        ("de_tokens", "de_ids", de_vocab),
    )
    return {out_key: vocab.lookup_indices(example[in_key]) for in_key, out_key, vocab in columns}

In [11]:
fn_kwargs = {"en_vocab": en_vocab, "de_vocab": de_vocab}

# convert tokens to ids for train, validation and test (in that order)
train_data, valid_data, test_data = (
    split.map(numericalize_example, fn_kwargs=fn_kwargs)
    for split in (train_data, valid_data, test_data)
)

Map: 100%|███████████████████████████████| 1000/1000 [00:00<00:00, 6448.20 examples/s]


In [21]:
# Inspect the first training example (raw text, tokens and ids).
# NOTE(review): the displayed output shows tensors because this cell was executed
# (In [21]) after the torch-formatting cell (In [14]); on a fresh top-to-bottom run
# en_ids/de_ids display as plain lists here.
train_data[0]

{'en_ids': tensor([  2,  16,  24,  15,  25, 778,  17,  57,  80, 202,   0,   5,   3]),
 'de_ids': tensor([  2,  18,  26, 253,  30,  84,  20,  88,   7,  15, 110,   0,   0,   4,
           3]),
 'en': 'Two young, White males are outside near many bushes.',
 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
 'en_tokens': ['<sos>',
  'two',
  'young',
  ',',
  'white',
  'males',
  'are',
  'outside',
  'near',
  'many',
  'bushes',
  '.',
  '<eos>'],
 'de_tokens': ['<sos>',
  'zwei',
  'junge',
  'weiße',
  'männer',
  'sind',
  'im',
  'freien',
  'in',
  'der',
  'nähe',
  'vieler',
  'büsche',
  '.',
  '<eos>']}

In [13]:
data_type = "torch"
format_columns = ["en_ids", "de_ids"]

# return the id columns as torch tensors; output_all_columns keeps the others too
train_data, valid_data, test_data = (
    split.with_format(type=data_type, columns=format_columns, output_all_columns=True)
    for split in (train_data, valid_data, test_data)
)

In [14]:
# Same example after with_format: en_ids / de_ids are now torch tensors.
train_data[0]

{'en_ids': tensor([  2,  16,  24,  15,  25, 778,  17,  57,  80, 202,   0,   5,   3]),
 'de_ids': tensor([  2,  18,  26, 253,  30,  84,  20,  88,   7,  15, 110,   0,   0,   4,
           3]),
 'en': 'Two young, White males are outside near many bushes.',
 'de': 'Zwei junge weiße Männer sind im Freien in der Nähe vieler Büsche.',
 'en_tokens': ['<sos>',
  'two',
  'young',
  ',',
  'white',
  'males',
  'are',
  'outside',
  'near',
  'many',
  'bushes',
  '.',
  '<eos>'],
 'de_tokens': ['<sos>',
  'zwei',
  'junge',
  'weiße',
  'männer',
  'sind',
  'im',
  'freien',
  'in',
  'der',
  'nähe',
  'vieler',
  'büsche',
  '.',
  '<eos>']}

In [15]:
def get_collate_fn(pad_index):
    """Build a DataLoader collate function that pads "en_ids" and "de_ids".

    The returned function stacks each column with `pad_sequence` (sequence-first,
    i.e. [max_seq_len, batch_sz]) using `pad_index` as filler, and returns only
    those two columns.
    """

    def _pad_column(examples, key):
        return nn.utils.rnn.pad_sequence(
            [example[key] for example in examples], padding_value=pad_index
        )

    def collate_fn(batch):
        en_padded = _pad_column(batch, "en_ids")
        de_padded = _pad_column(batch, "de_ids")
        return {"en_ids": en_padded, "de_ids": de_padded}

    return collate_fn

In [16]:
def get_data_loader(dataset, batch_size, pad_index, shuffle=False):
    """Wrap `dataset` in a DataLoader whose batches are padded with `pad_index`.

    Padding is done by the collate function from get_collate_fn; `shuffle`
    is passed straight through to the DataLoader.
    """
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=get_collate_fn(pad_index),
        shuffle=shuffle,
    )

In [17]:
batch_size = 128

# only the training loader is shuffled
train_dataloader = get_data_loader(train_data, batch_size, pad_index, shuffle=True)
valid_dataloader, test_dataloader = (
    get_data_loader(split, batch_size, pad_index) for split in (valid_data, test_data)
)

In [18]:
# Peek at a single batch; list(train_dataloader) would collate the entire epoch.
# Batches are sequence-first: [seq_len, batch_sz].
next(iter(train_dataloader))["en_ids"].shape

torch.Size([33, 128])

In [19]:
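# GRU encoder–decoder design follows Cho et al. (2014), "Learning Phrase Representations using RNN Encoder–Decoder for Statistical Machine Translation" (EMNLP); supplementary material: https://aclanthology.org/attachments/D14-1179.Attachment.pdf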

In [48]:
class Encoder(nn.Module):
    """GRU encoder for source token ids.

    Args:
        vocab_sz: number of rows in the embedding table (source vocabulary size).
        embedding_dim: size of each token embedding.
        hidden_dim: size of the GRU hidden state.
        n_layers: number of stacked GRU layers (default 1).
        p_dropout: dropout probability applied to the embeddings and, when
            n_layers > 1, between the stacked GRU layers.
    """

    def __init__(self, vocab_sz, embedding_dim, hidden_dim, n_layers=1, p_dropout=0.5):
        super().__init__()
        self.embedding = nn.Embedding(vocab_sz, embedding_dim)
        # nn.GRU applies `dropout` only between stacked layers (and warns if it is
        # non-zero with a single layer), so enable it only when n_layers > 1
        self.gru = nn.GRU(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            dropout=p_dropout if n_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, src):
        """Encode a batch of source sequences.

        Args:
            src: [seq_len, batch_sz] token ids (sequence-first).

        Returns:
            out: [seq_len, batch_sz, hidden_dim] GRU outputs for every time step.
            hidden: [n_layers, batch_sz, hidden_dim] final hidden state per layer.
        """
        embeddings = self.dropout(self.embedding(src))  # [seq_len, batch_sz, embedding_dim]
        out, hidden = self.gru(embeddings)
        return out, hidden

# wee test: output / hidden shapes of a tiny encoder
voc_sz = 2
b_sz = 3
seq_l = 4
n_lrs = 1
x = torch.randint(0, 2, (seq_l, b_sz))
# x is [seq_l, b_sz] = [4, 3]; embedding_dim=3 / hidden_dim=4 are arbitrary small sizes
enc = Encoder(voc_sz, embedding_dim=3, hidden_dim=4, n_layers=n_lrs)
o, h = enc(x)
assert list(o.shape) == [seq_l, b_sz, 4]
assert list(h.shape) == [n_lrs, b_sz, 4]

# Larger example. The demo_* names avoid clobbering the notebook-level
# `batch_size` used to build the dataloaders.
demo_vocab_size = 50
demo_embedding_dim = 16
demo_hidden_size = 32
demo_seq_len = 10
demo_batch_size = 4

# random input token ids, [seq_len, batch_sz]
demo_tokens = torch.randint(0, demo_vocab_size, (demo_seq_len, demo_batch_size))
demo_encoder = Encoder(demo_vocab_size, demo_embedding_dim, demo_hidden_size, n_layers=1)
hidden_states, final_hidden_state = demo_encoder(demo_tokens)
assert list(hidden_states.shape) == [demo_seq_len, demo_batch_size, demo_hidden_size]
assert list(final_hidden_state.shape) == [1, demo_batch_size, demo_hidden_size]

In [354]:
class DecoderOneStep(nn.Module):
    """Single decoding step.

    Embeds one target token per sequence, advances the GRU by one step and
    projects the GRU output to target-vocabulary logits.

    Args:
        input_output_dim: target vocabulary size (embedding rows and logit width).
            It is stored on the instance because Decoder reads it to size its output.
        embedding_dim: size of each token embedding.
        hidden_dim: size of the GRU hidden state.
        n_layers: number of stacked GRU layers (default 1).
        p_dropout: dropout probability on the embeddings and, when n_layers > 1,
            between the stacked GRU layers.
    """

    def __init__(self, input_output_dim, embedding_dim, hidden_dim, n_layers=1, p_dropout=0.5):
        super().__init__()
        self.input_output_dim = input_output_dim
        self.embedding = nn.Embedding(input_output_dim, embedding_dim)
        # nn.GRU's `dropout` only acts between stacked layers, so enable it only for n_layers > 1
        self.gru = nn.GRU(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            dropout=p_dropout if n_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(p_dropout)
        self.fc = nn.Linear(hidden_dim, input_output_dim)

    def forward(self, trg, hidden):
        """Run one decoding step.

        Args:
            trg: [batch_sz] token ids for the current step.
            hidden: [n_layers, batch_sz, hidden_dim] previous hidden state.

        Returns:
            logits: [batch_sz, input_output_dim].
            hidden: [n_layers, batch_sz, hidden_dim] updated hidden state.
        """
        # The GRU expects a time axis: [batch_sz] -> [1, batch_sz]
        trg = trg.unsqueeze(0)
        embeddings = self.dropout(self.embedding(trg))  # [1, batch_sz, embedding_dim]
        out, hidden = self.gru(embeddings, hidden)  # out: [1, batch_sz, hidden_dim]
        # Drop the time axis again before the projection: [batch_sz, hidden_dim]
        out = self.fc(out.squeeze(0))  # [batch_sz, input_output_dim]
        return out, hidden

# wee test: one decoding step for a batch of 4 tokens
io_dim = 5
emb_dim = 10
h_dim = 4
n_lrs = 1
d1s = DecoderOneStep(io_dim, emb_dim, h_dim, n_lrs)
single_step_input = torch.tensor([1, 2, 2, 1])  # [batch_sz] = [4]
step_batch_sz = single_step_input.shape[0]
# The initial hidden state is [n_layers, batch_sz, hidden_dim]; its batch size
# must match the step input's.
step_hidden = torch.randn(n_lrs, step_batch_sz, h_dim)
o1s, h1s = d1s(single_step_input, step_hidden)
assert list(o1s.shape) == [step_batch_sz, io_dim]
assert list(h1s.shape) == [n_lrs, step_batch_sz, h_dim]

In [378]:
class Decoder(nn.Module):
    """Unroll a one-step decoder over a target sequence, with optional teacher forcing.

    `one_step_decoder` must expose `input_output_dim` (the target vocabulary size).
    It must also be callable as `one_step_decoder(tokens, hidden) -> (logits, hidden)`.
    """

    def __init__(self, one_step_decoder):
        super().__init__()
        self.one_step_decoder = one_step_decoder

    def forward(self, trg, hidden, teacher_forcing_ratio=0.5):
        """Decode `trg` step by step.

        Args:
            trg: [seq_len, batch_sz] target token ids; trg[0] is the first input.
            hidden: initial hidden state passed to the first decoder step.
            teacher_forcing_ratio: per-step probability of feeding a ground-truth
                token instead of the argmax of the previous step's prediction.

        Returns:
            [seq_len, batch_sz, input_output_dim] logits, where predictions[t] is
            the output of decoder step t.
        """
        seq_len, batch_sz = trg.shape
        vocab_sz = self.one_step_decoder.input_output_dim
        # Allocate on the targets' device so that the per-step writes and the
        # downstream loss do not mix CPU and GPU tensors.
        predictions = torch.zeros(seq_len, batch_sz, vocab_sz, device=trg.device)
        input_ = trg[0, :]
        for t in range(seq_len):
            pred, hidden = self.one_step_decoder(input_, hidden)
            predictions[t] = pred
            # TODO(review): step 0 consumes trg[0], and under teacher forcing trg[t]
            # with t == 0 feeds trg[0] (<sos>) again at step 1. predictions[0] is
            # excluded from the training loss (y_pred[1:]), so the first step is
            # never trained. The conventional form iterates range(1, seq_len) and
            # leaves predictions[0] zero. Changing this also requires dropping the
            # translation[1:] slice used when printing translations.
            is_teacher_force = random.random() < teacher_forcing_ratio
            input_ = trg[t] if is_teacher_force else pred.argmax(1)
        return predictions

# wee test: unroll the one-step decoder `d1s` over a random target batch
dec = Decoder(d1s)
y = torch.randint(0, 2, (seq_l, b_sz))  # [seq_l, b_sz]
# The initial hidden state must match y's batch size (b_sz) and d1s's hidden
# size (h_dim); reusing an earlier `h` with a different batch size fails.
dec_hidden = torch.zeros(n_lrs, b_sz, h_dim)
odec = dec(y, dec_hidden)
assert list(odec.shape) == [seq_l, b_sz, io_dim]

In [379]:
class EncDec(nn.Module):
    """Sequence-to-sequence wrapper: the encoder's final hidden state seeds the decoder."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src, trg, teacher_forcing_ratio):
        """Encode `src`, then decode `trg` starting from the encoder's hidden state.

        Returns whatever `self.decoder` returns (per-step logits for `Decoder`).
        """
        _per_step_outputs, context = self.encoder(src)  # only the final hidden state is used
        return self.decoder(trg, context, teacher_forcing_ratio)

In [380]:
def init_model(en_vocab, de_vocab, pad_index):
    """Build the English->German GRU seq2seq model, its optimizer and its loss.

    Args:
        en_vocab: source (English) vocabulary; its size sets the encoder embedding rows.
        de_vocab: target (German) vocabulary; its size sets the decoder embedding rows
            and output width.
        pad_index: target id ignored by the loss, so padded positions do not contribute.

    Returns:
        (seq2seq, optimizer, criterion): the EncDec model, an Adam optimizer over its
        parameters (default hyper-parameters), and a CrossEntropyLoss.

    NOTE(review): the model stays on the default (CPU) device. The training loop
    does not move batches to a device either, so GPU training needs both changed.
    """
    # hyper-parameters
    input_dim = len(en_vocab)
    output_dim = len(de_vocab)
    encoder_embedding_dim = 256
    decoder_embedding_dim = 256
    # Encoder and decoder share hidden_dim / n_layers so that the encoder's
    # final hidden state can seed the decoder directly.
    hidden_dim = 512
    encoder_dropout = 0.5
    decoder_dropout = 0.5
    n_layers = 1

    # models
    encoder = Encoder(input_dim, encoder_embedding_dim, hidden_dim, n_layers, encoder_dropout)
    decoder_one_step = DecoderOneStep(output_dim, decoder_embedding_dim, hidden_dim, n_layers, decoder_dropout)
    decoder = Decoder(decoder_one_step)
    seq2seq = EncDec(encoder, decoder)

    # optimizer
    optimizer = optim.Adam(seq2seq.parameters())

    # loss function
    criterion = nn.CrossEntropyLoss(ignore_index=pad_index)

    return seq2seq, optimizer, criterion

In [381]:
def _seq2seq_loss(model, criterion, batch, teacher_forcing_ratio):
    """Forward one collated batch through `model` and return the token-level loss."""
    src = batch["en_ids"]  # [src_seq_len, batch_sz]
    trg = batch["de_ids"]  # [trg_seq_len, batch_sz]
    y_pred = model(src, trg, teacher_forcing_ratio=teacher_forcing_ratio)
    # y_pred: [trg_seq_len, batch_sz, trg_vocab_sz]
    trg_vocab_sz = y_pred.shape[-1]
    # Position 0 is dropped from both predictions and targets.
    # CrossEntropyLoss wants [N, C] logits and [N] class ids, so fold the time and
    # batch axes into one. Both tensors are flattened in the same (time, batch)
    # order, so row i of the logits lines up with element i of the targets.
    y_pred = y_pred[1:].reshape(-1, trg_vocab_sz)  # [(trg_seq_len - 1) * batch_sz, trg_vocab_sz]
    trg = trg[1:].reshape(-1)  # [(trg_seq_len - 1) * batch_sz]
    return criterion(y_pred, trg)


def _mean(values):
    """Arithmetic mean of `values`, or NaN for an empty list (e.g. an empty loader)."""
    return sum(values) / len(values) if values else float("nan")


def train(
    train_dataloader,
    val_dataloader,
    en_vocab,
    de_vocab,
    pad_index,
    n_epochs=3,
    clip=1,
    teacher_forcing_ratio=0.5,
):
    """Build a fresh model with init_model and train it, printing per-epoch losses.

    Args:
        train_dataloader, val_dataloader: loaders yielding {"en_ids", "de_ids"} batches.
        en_vocab, de_vocab, pad_index: forwarded to init_model.
        n_epochs: number of passes over the training data.
        clip: max global gradient norm passed to clip_grad_norm_.
        teacher_forcing_ratio: teacher-forcing probability used during training;
            validation always uses 0.0.

    Returns:
        The trained model.
    """
    model, optimizer, criterion = init_model(en_vocab, de_vocab, pad_index)
    for epoch in tqdm.tqdm(range(n_epochs)):
        # training
        model.train()
        training_losses = []
        for batch in train_dataloader:
            optimizer.zero_grad()
            loss = _seq2seq_loss(model, criterion, batch, teacher_forcing_ratio)
            loss.backward()
            # rescale gradients so their global norm is at most `clip`
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            training_losses.append(loss.item())
        print(f"epoch {epoch} average training loss: {_mean(training_losses)}")

        # validation: no teacher forcing, no gradient tracking
        model.eval()
        with torch.no_grad():
            validation_losses = [
                _seq2seq_loss(model, criterion, batch, teacher_forcing_ratio=0.0).item()
                for batch in val_dataloader
            ]
        print(f"epoch {epoch} average validation loss: {_mean(validation_losses)}")
    return model

In [382]:
# Train with `train`'s default settings, then display the resulting module tree.
model = train(
    train_dataloader,
    val_dataloader=valid_dataloader,
    en_vocab=en_vocab,
    de_vocab=de_vocab,
    pad_index=pad_index,
)
model

  0%|                                                           | 0/3 [00:00<?, ?it/s]

epoch 0 average training loss: 3.8569061798146116


 33%|████████████████▋                                 | 1/3 [04:49<09:39, 289.78s/it]

epoch 0 average validation loss: 3.817979872226715
epoch 1 average training loss: 3.149116719871891


 67%|█████████████████████████████████▎                | 2/3 [10:06<05:05, 305.83s/it]

epoch 1 average validation loss: 3.5690895318984985
epoch 2 average training loss: 2.8717597371155996


100%|██████████████████████████████████████████████████| 3/3 [13:46<00:00, 275.49s/it]

epoch 2 average validation loss: 3.358761876821518





EncDec(
  (encoder): Encoder(
    (embedding): Embedding(998, 256)
    (gru): GRU(256, 512)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (one_step_decoder): DecoderOneStep(
      (embedding): Embedding(998, 256)
      (gru): GRU(256, 512)
      (dropout): Dropout(p=0.5, inplace=False)
      (fc): Linear(in_features=512, out_features=998, bias=True)
    )
  )
)

In [383]:
def count_parameters(model):
    """Return the total number of elements in `model`'s trainable (requires_grad) parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Report the model size: trainable parameters only, with thousands separators.
print(f"The model has {count_parameters(model):,} trainable parameters")

The model has 3,388,390 trainable parameters


In [420]:
def translate_sentence(
    sentence,
    model,
    en_nlp,
    de_nlp,
    en_vocab,
    de_vocab,
    sos_token,
    eos_token,
    device,
    max_output_length=25,
    verbose=True,
):
    """Greedily translate one English sentence with a trained EncDec model.

    Args:
        sentence: raw English text.
        model: model exposing `encoder` and `decoder.one_step_decoder`.
        en_nlp: spaCy pipeline whose tokenizer splits `sentence`.
        de_nlp: unused; kept so existing call sites keep working.
        en_vocab, de_vocab: source / target vocabularies.
        sos_token, eos_token: sequence markers used during preprocessing.
        device: device on which input tensors are created (should match the model's).
        max_output_length: maximum number of decoding steps.
        verbose: if True, print the source tokens and the input / output ids.

    Returns:
        List of generated German tokens, excluding <sos> and the terminating <eos>.
    """
    model.eval()
    with torch.no_grad():
        src_tokens = [token.text for token in en_nlp.tokenizer(sentence)]
        if verbose:
            print(src_tokens)
        # Mirror the training preprocessing: lower-case, then wrap with <sos>/<eos>.
        src_tokens = [sos_token] + [token.lower() for token in src_tokens] + [eos_token]
        src_ids = en_vocab.lookup_indices(src_tokens)
        if verbose:
            print(f"in tokens: {src_ids}")
        # [src_seq_len, 1]: a batch containing one sentence
        src_tensor = torch.tensor(src_ids, dtype=torch.long, device=device).unsqueeze(-1)
        _, hidden = model.encoder(src_tensor)

        eos_index = de_vocab[eos_token]
        next_token = de_vocab.lookup_indices([sos_token])[0]
        output_ids = []
        for _ in range(max_output_length):
            # The one-step decoder takes a [batch_sz] tensor. Build it on `device`
            # so that it matches the encoder input and the model.
            step_input = torch.tensor([next_token], dtype=torch.long, device=device)
            logits, hidden = model.decoder.one_step_decoder(step_input, hidden)
            next_token = logits.argmax(1).item()
            if next_token == eos_index:
                break
            output_ids.append(next_token)
        if verbose:
            print(f"out tokens: {output_ids}")
    return de_vocab.lookup_tokens(output_ids)

In [422]:
# Translate a random test sentence and compare it with the reference translation.
test_sentences = test_data
rando_idx = np.random.randint(low=0, high=len(test_sentences))  # unseeded: a different sentence on each run
sentence = test_sentences[rando_idx]["en"]
expected_translation = test_sentences[rando_idx]["de"]
translation = translate_sentence(
    sentence=sentence,
    model=model,
    en_nlp=en_nlp,
    de_nlp=de_nlp,
    en_vocab=en_vocab,
    de_vocab=de_vocab,
    sos_token="<sos>",
    eos_token="<eos>",
    device="cpu",
)
print(f"\nsentence: {sentence}\n")
print(f"expected_translation: {expected_translation}\n")
# The first generated token is skipped. Decoder.forward's step-0 prediction
# (predictions[0]) is excluded from the training loss via y_pred[1:], so the
# first token produced right after <sos> is never trained.
print(f"actual translation: {' '.join(i for i in translation[1:])}")

['A', 'black', 'dog', 'and', 'a', 'brown', 'dog', 'with', 'a', 'ball', '.']
in tokens: [2, 4, 26, 35, 11, 4, 61, 35, 13, 4, 68, 5, 3]
out tokens: [5, 5, 114, 32, 11, 6, 293, 11, 6, 293, 4]

sentence: A black dog and a brown dog with a ball.

expected_translation: Ein schwarzer und ein brauner Hund mit einem Ball.

actual translation: ein schwarzer hund mit einem stock mit einem stock .
