In [1]:
import torch
import numpy as np
import pandas as pd
import sys
from src.model.Transformer import Transformer 

In [6]:
source = '../artifacts/source.txt'  
translation = '../artifacts/translation.txt'

START_TOKEN = ''
PADDING_TOKEN = ''
END_TOKEN = ''

source_vocabulary= [START_TOKEN,' ', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't', 'u', 'v', 'w', 'y',
                         PADDING_TOKEN, END_TOKEN]

translation_vocabulary = [START_TOKEN, ' ', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', PADDING_TOKEN, END_TOKEN]

In [7]:
index_to_translation = {k:v for k,v in enumerate(translation_vocabulary)}
translation_to_index = {v:k for k,v in enumerate(translation_vocabulary)}
index_to_source = {k:v for k,v in enumerate(source_vocabulary)}
source_to_index = {v:k for k,v in enumerate(source_vocabulary)}



In [8]:
print(index_to_translation)

{0: '', 1: ' ', 2: '-', 3: '0', 4: '1', 5: '2', 6: '3', 7: '4', 8: '5', 9: '6', 10: '7', 11: '8', 12: '9', 13: '', 14: ''}


In [9]:
with open(source, 'r') as file:

    source_sentences = file.readlines()

with open(translation, 'r') as file:
    translation_sentences = file.readlines()

# Limit Number of sentences
TOTAL_SENTENCES = 200000
source_sentences = source_sentences[:TOTAL_SENTENCES]
translation_sentences = translation_sentences[:TOTAL_SENTENCES]
source_sentences = [sentence.rstrip('\n').lower() for sentence in source_sentences]
translation_sentences = [sentence.rstrip('\n').lower() for sentence in translation_sentences]




In [10]:
print(translation_sentences)

['1992-07-07', '1970-07-20', '2014-08-06', '1986-02-20', '1989-11-29', '1980-06-19', '2000-08-03', '1978-09-27', '1976-09-18', '1993-05-25', '2000-05-17', '1979-06-04', '1999-11-24', '1989-03-07', '2021-10-11', '1974-12-08', '2022-11-21', '1992-04-09', '1997-03-27', '1978-01-03', '2008-11-08', '1980-04-07', '1988-06-17', '1971-04-10', '1988-05-07', '2022-04-14', '2022-11-24', '2010-03-28', '1970-03-10', '2020-10-24', '2017-01-21', '2011-08-29', '1979-09-01', '1975-05-19', '1992-05-27', '2017-11-09', '2001-03-29', '2009-10-22', '1982-07-28', '1998-04-19', '2008-05-03', '2014-07-30', '2013-08-14', '1982-07-22', '2017-03-10', '1981-09-11', '2013-05-03', '1999-12-31', '1980-01-17', '2001-10-23', '1998-01-03', '2021-10-22', '1972-03-30', '1978-11-14', '2023-02-21', '2014-12-22', '1978-02-14', '1982-05-20', '1999-02-19', '1978-06-26', '1987-07-01', '1972-08-31', '2008-06-17', '1974-03-31', '2023-10-02', '2019-10-01', '2012-08-18', '2011-09-29', '1989-11-16', '2006-09-05', '2011-02-21', '1995

In [11]:
max_sequence_length = 200

def is_valid_tokens(sentence, vocab):
    return all(token in vocab for token in set(sentence))

def is_valid_length(sentence, max_length):
    return len(sentence) < max_length - 1

valid_sentence_indices = [
    index
    for index, (trans_sentence, source_sentence) in enumerate(zip(translation_sentences, source_sentences))
    if is_valid_length(trans_sentence, max_sequence_length)
    and is_valid_length(source_sentence, max_sequence_length)
    and is_valid_tokens(trans_sentence, translation_vocabulary)
]

print(f"Number of sentences: {len(translation_sentences)}")
print(f"Number of valid sentences: {len(valid_sentence_indices)}")




Number of sentences: 12000
Number of valid sentences: 12000


In [12]:
translation_sentences = [translation_sentences[i] for i in valid_sentence_indices]
source_sentences = [source_sentences[i] for i in valid_sentence_indices]

In [13]:
print(len(translation_vocabulary))

15


In [14]:
import torch

d_model = 512
batch_size = 30
ffn_hidden = 2048
num_heads = 8
drop_prob = 0.1
num_layers = 1
max_sequence_length = 50
translation_vocab_size = len(translation_vocabulary)

transformer = Transformer(d_model,
                          ffn_hidden,
                          num_heads,
                          drop_prob,
                          num_layers,
                          max_sequence_length,
                          translation_vocab_size,
                          source_to_index,
                          translation_to_index,
                          START_TOKEN,
                          END_TOKEN,
                          PADDING_TOKEN)

In [15]:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):

    def __init__(self, source_sentences, translation_sentences):
        self.source_sentences = source_sentences
        self.translation_sentences = translation_sentences

    def __len__(self):
        return len(self.source_sentences)

    def __getitem__(self, idx):
        return self.source_sentences[idx], self.translation_sentences[idx]

In [16]:
dataset = TextDataset(source_sentences, translation_sentences)


In [17]:
train_loader = DataLoader(dataset, batch_size)
iterator = iter(train_loader)
print(iterator)

<torch.utils.data.dataloader._SingleProcessDataLoaderIter object at 0x1639b5710>


In [18]:
for batch_num, batch in enumerate(iterator):
    source_batch, translation_batch = batch

    print(batch)
    if batch_num > 3:
        break


[('7 jul 1992', '20.07.70', '8/6/14', 'thursday february 20 1986', 'wednesday november 29 1989', 'thursday june 19 1980', 'thursday august 3 2000', '27 sep 1978', '18 sep 1976', 'tuesday may 25 1993', 'wednesday may 17 2000', 'monday june 4 1979', '24 nov 1999', 'tuesday march 7 1989', '11 october 2021', '8 december 1974', 'monday november 21 2022', '9 04 92', '27 march 1997', 'tuesday january 3 1978', 'saturday november 8 2008', '07.04.80', '17.06.88', 'april 10 1971', '7 may 1988', '14.04.22', '24 nov 2022', '28 mar 2010', '10 mar 1970', 'saturday october 24 2020'), ('1992-07-07', '1970-07-20', '2014-08-06', '1986-02-20', '1989-11-29', '1980-06-19', '2000-08-03', '1978-09-27', '1976-09-18', '1993-05-25', '2000-05-17', '1979-06-04', '1999-11-24', '1989-03-07', '2021-10-11', '1974-12-08', '2022-11-21', '1992-04-09', '1997-03-27', '1978-01-03', '2008-11-08', '1980-04-07', '1988-06-17', '1971-04-10', '1988-05-07', '2022-04-14', '2022-11-24', '2010-03-28', '1970-03-10', '2020-10-24')]
[('

In [19]:
from torch import nn

criterian = nn.CrossEntropyLoss(ignore_index=translation_to_index[PADDING_TOKEN],
                                reduction='none')

# When computing the loss, we are ignoring cases when the label is the padding token
for params in transformer.parameters():
    if params.dim() > 1:
        nn.init.xavier_uniform_(params)

optim = torch.optim.Adam(transformer.parameters(), lr=1e-4)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [20]:
NEG_INFTY = -1e9

def create_masks(source_batch,translation_batch):
    num_sentences = len(source_batch)
    look_ahead_mask = torch.full([max_sequence_length, max_sequence_length] , True)
    look_ahead_mask = torch.triu(look_ahead_mask, diagonal=1)
    encoder_padding_mask = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
    decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

    for idx in range(num_sentences):
      source_sentence_length, translation_sentence_length = len(source_batch[idx]), len(translation_batch[idx])
      source_chars_to_padding_mask = np.arange(source_sentence_length + 1, max_sequence_length)
      translation_chars_to_padding_mask = np.arange(translation_sentence_length + 1, max_sequence_length)
      encoder_padding_mask[idx, :, source_chars_to_padding_mask] = True
      encoder_padding_mask[idx, source_chars_to_padding_mask, :] = True
      decoder_padding_mask_self_attention[idx, :, translation_chars_to_padding_mask] = True
      decoder_padding_mask_self_attention[idx, translation_chars_to_padding_mask, :] = True
      decoder_padding_mask_cross_attention[idx, :, source_chars_to_padding_mask] = True
      decoder_padding_mask_cross_attention[idx, translation_chars_to_padding_mask, :] = True

    encoder_self_attention_mask = torch.where(encoder_padding_mask, NEG_INFTY, 0)
    decoder_self_attention_mask =  torch.where(look_ahead_mask + decoder_padding_mask_self_attention, NEG_INFTY, 0)
    decoder_cross_attention_mask = torch.where(decoder_padding_mask_cross_attention, NEG_INFTY, 0)
    return encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask

In [21]:
transformer.train()
transformer.to(device)
total_loss = 0
num_epochs = 4

for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    iterator = iter(train_loader) # nedi
    for batch_num, batch in enumerate(iterator):
        transformer.train()
        source_batch, translation_batch = batch
        encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask = create_masks(source_batch, translation_batch)
        optim.zero_grad()
        translation_predictions = transformer(source_batch,
                                     translation_batch,
                                     encoder_self_attention_mask.to(device),
                                     decoder_self_attention_mask.to(device),
                                     decoder_cross_attention_mask.to(device),
                                     enc_start_token=False,
                                     enc_end_token=False,
                                     dec_start_token=True,
                                     dec_end_token=True)

        labels = transformer.decoder.sentence_embedding.batch_tokenize(translation_batch, start_token=False, end_token=True)
        loss = criterian(
            translation_predictions.view(-1, translation_vocab_size).to(device),
            labels.view(-1).to(device)
        ).to(device)
        valid_indicies = torch.where(labels.view(-1) == translation_to_index[PADDING_TOKEN], False, True)
        loss = loss.sum() / valid_indicies.sum()
        loss.backward()
        optim.step()
        #train_losses.append(loss.item())
        if batch_num % 100 == 0:
            print(f"Iteration {batch_num} : {loss.item()}")
            print(f"source: {source_batch[0]}")
            print(f"Translation: {translation_batch[0]}")
            translation_sentence_predicted = torch.argmax(translation_predictions[0], axis=1)
            predicted_sentence = ""
            for idx in translation_sentence_predicted:
              if idx == translation_to_index[END_TOKEN]:
                break
              predicted_sentence += index_to_translation[idx.item()]
            print(f"Prediction: {predicted_sentence}")


            transformer.eval()
            translation_sentence = ("",)
            source_sentence = ("saturday february 18 1995",)
            for word_counter in range(max_sequence_length):
                encoder_self_attention_mask, decoder_self_attention_mask, decoder_cross_attention_mask= create_masks(source_sentence, translation_sentence)
                predictions = transformer(source_sentence,
                                          translation_sentence,
                                          encoder_self_attention_mask.to(device),
                                          decoder_self_attention_mask.to(device),
                                          decoder_cross_attention_mask.to(device),
                                          enc_start_token=False,
                                          enc_end_token=False,
                                          dec_start_token=True,
                                          dec_end_token=False)
                next_token_prob_distribution = predictions[0][word_counter] # not actual probs
                next_token_index = torch.argmax(next_token_prob_distribution).item()
                next_token = index_to_translation[next_token_index]
                translation_sentence = (translation_sentence[0] + next_token, )
                if next_token == END_TOKEN:
                  break

            print(f"saturday february 18 1995 : {translation_sentence}")
            print("-------------------------------------------")

Epoch 0
Iteration 0 : 3.332223415374756
source: 7 jul 1992
Translation: 1992-07-07
Prediction:    11    6 1 - -3 11      16666616-  666666   6616
saturday february 18 1995 : ('000000----------000000000000---------------------1',)
-------------------------------------------
Iteration 100 : 1.350603461265564
source: 13 nov 1984
Translation: 1984-11-13
Prediction: 191--00-191111111101111111111111111211111111111111
saturday february 18 1995 : ('198-0-0-018301977100000111101111111111111111111111',)
-------------------------------------------
Iteration 200 : 1.0709675550460815
source: wednesday may 24 1972
Translation: 1972-05-24
Prediction: 1972-01-201112231211112111121211111111111111111111
saturday february 18 1995 : ('1971-11-119999999919999111111111111111111111111111',)
-------------------------------------------
Iteration 300 : 0.9650502800941467
source: saturday november 22 2014
Translation: 2014-11-22
Prediction: 2012-12-224528222122122222222222222221111212111111
saturday february 18 