- Remove non alphanumeric characters for simple training

In [1]:
from transformer import Transformer # this is the transformer.py file
import torch
import numpy as np

In [2]:
english_file = 'Lang/output_eng.eng' # replace this path with appropriate one
kannada_file = 'Lang/output_ru.ru' # replace this path with appropriate one

# Generated this by filtering Appendix code

START_TOKEN = '<START>'
PADDING_TOKEN = '<PADDING>'
END_TOKEN = '<END>'

kannada_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
                      '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', 
                      'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
                      'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
                      'А', 'Б', 'В', 'Г', 'Д', 'Е', 'Ё', 'Ж', 'З', 'И', 'Й', 'К', 'Л', 'М', 'Н', 'О', 'П', 'Р', 'С', 'Т', 'У', 'Ф', 
                      'Х', 'Ц', 'Ч', 'Ш', 'Щ', 'Ъ', 'Ы', 'Ь', 'Э', 'Ю', 'Я', 'а', 'б', 'в', 'г', 'д', 'е', 'ё', 'ж', 'з', 'и', 'й', 
                      'к', 'л', 'м', 'н', 'о', 'п', 'р', 'с', 'т', 'у', 'ф', 'х', 'ц', 'ч', 'ш', 'щ', 'ъ', 'ы', 'ь', 'э', 'ю', 'я',
                      '{', '|', '}', '~', PADDING_TOKEN, END_TOKEN]

english_vocabulary = [START_TOKEN, ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', 
                        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
                        ':', '<', '=', '>', '?', '@',
                        '[', '\\', ']', '^', '_', '`', 
                        'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l',
                        'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 
                        'y', 'z', 
                        '{', '|', '}', '~', PADDING_TOKEN, END_TOKEN]

In [3]:
index_to_russian = {k:v for k,v in enumerate(kannada_vocabulary)}
russian_to_index = {v:k for k,v in enumerate(kannada_vocabulary)}
index_to_english = {k:v for k,v in enumerate(english_vocabulary)}
english_to_index = {v:k for k,v in enumerate(english_vocabulary)}

In [4]:
with open(english_file, 'r') as file:
    english_sentences = file.readlines()
with open(kannada_file, 'r') as file:
    russian_sentences = file.readlines()

# Limit Number of sentences
TOTAL_SENTENCES = 200000
english_sentences = english_sentences[:TOTAL_SENTENCES]
russian_sentences = russian_sentences[:TOTAL_SENTENCES]
english_sentences = [sentence.rstrip('\n').lower() for sentence in english_sentences]
russian_sentences = [sentence.rstrip('\n') for sentence in russian_sentences]

In [5]:
english_sentences[:10]

['e-mail: info@e-e-e.cz',
 'the company has been registered with the municipal court in prague, in section b, file 14857.',
 'news:',
 '22.02.09',
 'since january 2009 new office premises in the street v kolkovn 3, prague 1, have been opened. these premises are situated in a close vicinity of all our partner banks.',
 'about us news references contacts',
 'news',
 '02 / 2009',
 'new projects (22. 02.) new office premises have been opened (19. 02.)',
 'east export engineering a.s. (eee) has reassumed the previous success of czech companies and banks in the field of export financing and commenced its activity with the aim to further develop the support of financing of home and foreign schemes involving supplier participation of czech companies.']

In [6]:
russian_sentences[:10]

['email: info@e-e-e.cz',
 'фирма зарегистрирована в городском суде в г. праге, раздел б, вкладыш 14857.',
 'новости:',
 '22.02.09',
 'с января месяца 2009 г. открыты новые офисные помещения на улице v kolkovn 3 (в колковне, д. 3), прага-1. данные помещения находятся поблизости всех банков-партнеров.',
 'o нас новости референции контакт',
 'новости',
 '02 / 2009',
 'открыты новые офисные помещения (22. 02.)',
 'компания ао east export engineering (eee) исходит из успехов чешских фирм и банков в области экспортного финансирования; она поставила перед собой цель далее развивать поддержку финансирования чешских и зарубежных проектов с участием чешских фирм в качестве подрядчиков/поставщиков.']

In [7]:
import numpy as np
PERCENTILE = 97
print( f"{PERCENTILE}th percentile length Kannada: {np.percentile([len(x) for x in russian_sentences], PERCENTILE)}" )
print( f"{PERCENTILE}th percentile length English: {np.percentile([len(x) for x in english_sentences], PERCENTILE)}" )


97th percentile length Kannada: 276.0
97th percentile length English: 295.0


In [8]:
max_sequence_length = 200

def is_valid_tokens(sentence, vocab):
    for token in list(set(sentence)):
        if token not in vocab:
            return False
    return True

def is_valid_length(sentence, max_sequence_length):
    return len(list(sentence)) < (max_sequence_length - 1) # need to re-add the end token so leaving 1 space

valid_sentence_indicies = []
for index in range(len(russian_sentences)):
    kannada_sentence, english_sentence = russian_sentences[index], english_sentences[index]
    if is_valid_length(kannada_sentence, max_sequence_length) \
      and is_valid_length(english_sentence, max_sequence_length) \
      and is_valid_tokens(kannada_sentence, kannada_vocabulary):
        valid_sentence_indicies.append(index)

print(f"Number of sentences: {len(russian_sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indicies)}")

Number of sentences: 200000
Number of valid sentences: 180503


In [9]:
russian_sentences = [russian_sentences[i] for i in valid_sentence_indicies]
english_sentences = [english_sentences[i] for i in valid_sentence_indicies]

In [10]:
russian_sentences[:3]

['email: info@e-e-e.cz',
 'фирма зарегистрирована в городском суде в г. праге, раздел б, вкладыш 14857.',
 'новости:']

In [11]:
import torch

d_model = 512
batch_size = 30
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 200
kn_vocab_size = len(kannada_vocabulary)

transformer = Transformer(d_model, 
                          ffn_hidden,
                          num_heads, 
                          drop_prob, 
                          num_layers, 
                          max_sequence_length,
                          kn_vocab_size,
                          english_to_index,
                          russian_to_index,
                          START_TOKEN, 
                          END_TOKEN, 
                          PADDING_TOKEN)

In [12]:
transformer

Transformer(
  (encoder): Encoder(
    (sentence_embedding): SentenceEmbedding(
      (embedding): Embedding(71, 512)
      (position_encoder): PositionalEncoding()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (layers): SequentialEncoder(
      (0): EncoderLayer(
        (attention): MultiHeadAttention(
          (qkv_layer): Linear(in_features=512, out_features=1536, bias=True)
          (linear_layer): Linear(in_features=512, out_features=512, bias=True)
        )
        (norm1): LayerNormalization()
        (dropout1): Dropout(p=0.1, inplace=False)
        (ffn): PositionwiseFeedForward(
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (relu): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (norm2): LayerNormalization()
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (decoder): Decoder(
    (sentence_embedding):

In [13]:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):

    def __init__(self, english_sentences, russian_sentences):
        self.english_sentences = english_sentences
        self.russian_sentences = russian_sentences

    def __len__(self):
        return len(self.english_sentences)

    def __getitem__(self, idx):
        return self.english_sentences[idx], self.russian_sentences[idx]

In [14]:
dataset = TextDataset(english_sentences, russian_sentences)

In [15]:
len(dataset)

180503

In [16]:
dataset[1]

('the company has been registered with the municipal court in prague, in section b, file 14857.',
 'фирма зарегистрирована в городском суде в г. праге, раздел б, вкладыш 14857.')

In [17]:
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)

In [18]:
for batch_num, batch in enumerate(iterator):
    print(batch)
    if batch_num > 3:
        break

[('e-mail: info@e-e-e.cz', 'the company has been registered with the municipal court in prague, in section b, file 14857.', 'news:', '22.02.09', 'since january 2009 new office premises in the street v kolkovn 3, prague 1, have been opened. these premises are situated in a close vicinity of all our partner banks.', 'about us news references contacts', 'news', '02 / 2009', 'new projects (22. 02.) new office premises have been opened (19. 02.)', 'our team consists of highly experienced professionals who have already successfully implemented several schemes.', 'we are convinced that the co-operation with eee will be beneficial for you and you will thus appraise our long-term experience that we are ready to value for you.', 'references', 'corporate expert workers have taken part in the provision of financing of a construction scheme of gofra, a factory for the production of packaging paper and boxes in the moscow region.', 'for this scheme amounting to 17 million euros czech commercial bank

In [19]:
from torch import nn

criterian = nn.CrossEntropyLoss(ignore_index=russian_to_index[PADDING_TOKEN],
                                reduction='none')

# When computing the loss, we are ignoring cases when the label is the padding token
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [20]:
NEG_INFTY = -1e9

def create_masks(eng_batch, kn_batch):
    num_sentences = len(eng_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
      eng_sentence_length, kn_sentence_length = len(eng_batch[idx]), len(kn_batch[idx])
      eng_chars_to_padding_mask = np.arange(eng_sentence_length + 1, max_sequence_length)
      kn_chars_to_padding_mask = np.arange(kn_sentence_length + 1, max_sequence_length)
      encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
      encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, kn_chars_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, kn_chars_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, kn_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

Modify mask such that the padding tokens cannot look ahead.
In Encoder, tokens before it should be -1e9 while tokens after it should be -inf.
 

Note the target mask starts with 2 rows of non masked items: https://github.com/SamLynnEvans/Transformer/blob/master/Beam.py#L55


In [21]:
transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 10

for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader)
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        eng_batch, kn_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(eng_batch, kn_batch)
        optim.zero_grad()
        kn_predictions = transformer(eng_batch,
                                     kn_batch,
                                     encoder_self_attention_mask.to(device), 
                                     decoder_self_attention_mask.to(device), 
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)
        labels = transformer.decoder.sentence_embedding.batch_tokenize(kn_batch, start_token=False, end_token=True)
        loss = criterian(
            kn_predictions.view(-1, kn_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == russian_to_index[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        #train_losses.append(loss.item())
        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"English: {eng_batch[0]}")
            print(f"Kannada Translation: {kn_batch[0]}")
            kn_sentence_predicted = torch.argmax(kn_predictions[0], axis=1)
            predicted_sentence = ""
            for idx in kn_sentence_predicted:
              if idx == russian_to_index[END_TOKEN]:
                break
              predicted_sentence += index_to_russian[idx.item()]
            print(f"Kannada Prediction: {predicted_sentence}")


            transformer.eval()
            kn_sentence = ("",)
            eng_sentence = ("should we go to the mall?",)
            for word_counter in range(max_sequence_length):
                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, kn_sentence)
                predictions = transformer(eng_sentence,
                                          kn_sentence,
                                          encoder_self_attention_mask.to(device), 
                                          decoder_self_attention_mask.to(device), 
                                          decoder_cross_attention_mask.to(device),
                                          enc_start_token=False,
                                          enc_end_token=False,
                                          dec_start_token=True,
                                          dec_end_token=False)
                next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                next_token_index = torch.argmax(next_token_prob_distribution).item()
                next_token = index_to_russian[next_token_index]
                kn_sentence = (kn_sentence[0] + next_token, )
                if next_token == END_TOKEN:
                  break
            
            print(f"Evaluation translation (should we go to the mall?) : {kn_sentence}")
            print("-------------------------------------------")

Epoch 0
Iteration 0 : 5.82766580581665
English: e-mail: info@e-e-e.cz
Kannada Translation: email: info@e-e-e.cz
Kannada Prediction: qqqqььжqqУУУаа~q??ьЩьл
Evaluation translation (should we go to the mall?) : ('оооиииииииаааааиртаааааааааиаоааааттааа<END>',)
-------------------------------------------
Iteration 100 : 4.874773979187012
English: 2013-09-28
Kannada Translation: 2013-09-28
Kannada Prediction: 
Evaluation translation (should we go to the mall?) : ('   <END>',)
-------------------------------------------
Iteration 200 : 2.8694446086883545
English: keep a close watch on
Kannada Translation: keep a thing close
Kannada Prediction: aa  a         a
Evaluation translation (should we go to the mall?) : ('                                                                                                                                                                                                        ',)
-------------------------------------------
Iteration 300 : 2.96559476852417
En

KeyboardInterrupt: 

## Inference

In [None]:
transformer.eval()
def translate(eng_sentence):
  eng_sentence = (eng_sentence,)
  kn_sentence = ("",)
  for word_counter in range(max_sequence_length):
    encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(eng_sentence, kn_sentence)
    predictions = transformer(eng_sentence,
                              kn_sentence,
                              encoder_self_attention_mask.to(device), 
                              decoder_self_attention_mask.to(device), 
                              decoder_cross_attention_mask.to(device),
                              enc_start_token=False,
                              enc_end_token=False,
                              dec_start_token=True,
                              dec_end_token=False)
    next_token_prob_distribution = predictions[0][word_counter]
    next_token_index = torch.argmax(next_token_prob_distribution).item()
    next_token = index_to_russian[next_token_index]
    kn_sentence = (kn_sentence[0] + next_token, )
    if next_token == END_TOKEN:
      break
  return kn_sentence[0]

In [None]:
translation = translate("well, prince, so genoa is now no longer republics.")
print(translation)


In [None]:
translation = translate("you have liberated them, haven't you?")
print(translation)


In [None]:
translation = translate("the world is a large place with different people")
print(translation)

In [None]:
translation = translate("my name is ajay")
print(translation)

In [None]:
translation = translate("i cannot stand this smell")
print(translation)

In [None]:
translation = translate("noodles are the best")
print(translation)

In [None]:
translation = translate("why care about this?")
print(translation)

This translated pretty well : "What is the reason. Why" without punctuation.

In [None]:
translation = translate("this is the best thing ever")
print(translation)

The translation : "This is very unusual"

In [None]:
translation = translate("i am here")
print(translation)

Translation: "I have heard". 
This is why word based translator may perform better than character translator. This is actually very good at optimizing the objective of the current transformer even though the translation is off.

In [None]:
translation = translate("click this")
print(translation)

In [None]:
translation = translate("where is the mall?")
print(translation)

In [None]:
translation = translate("what should we do?")
print(translation)

This is correct; but it absolutely fumbles on the next one

In [None]:
translation = translate("today, what should we do")
print(translation)

In [None]:
translation = translate("why did they activate?")
print(translation)

In [None]:
translation = translate("why did they do this?")
print(translation)

That turned out well!

In [None]:
translation = translate("i am well.")
print(translation)

Translation: "I will give you something"

In [None]:
translation = translate("whats the word on the street?")
print(translation)

Overall, this model definately learned something. And you can use other languages instead of this russian language.