In [7]:
import torch
import torch.nn as nn
import transformers_preprocess as data_preprocess
import pickle
from metrics import *
from torch import Tensor
import math
import numpy as np
from torch.nn import Transformer

In [8]:
# Inference device for the whole notebook; passed explicitly to Trainer below.
device: torch.device = torch.device("cpu")

In [9]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to token embeddings, then dropout.

    Input is sequence-first, shape (seq_len, batch, emb_size), matching the
    default ``batch_first=False`` layout of ``nn.Transformer``. The encoding
    table is stored as a (maxlen, 1, emb_size) buffer named ``pos_embedding``
    so it broadcasts over the batch dimension; that buffer is part of the
    state_dict, so its name and shape must stay stable for saved checkpoints.

    Args:
        emb_size: embedding width (odd widths are supported).
        dropout: dropout probability applied after adding the encodings.
        maxlen: longest sequence the table can encode.
    """

    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        # One inverse frequency per even column: 10000^(-2i/emb_size).
        den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        # For odd emb_size there is one fewer odd column than frequencies,
        # so only the first emb_size // 2 frequencies feed the cosines.
        pos_embedding[:, 1::2] = torch.cos(pos * den[: emb_size // 2])
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        """Return ``dropout(token_embedding + PE[:seq_len])``.

        Raises:
            ValueError: if the sequence is longer than the encoding table.
        """
        seq_len = token_embedding.size(0)
        if seq_len > self.pos_embedding.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding maxlen "
                f"{self.pos_embedding.size(0)}")
        return self.dropout(token_embedding + self.pos_embedding[:seq_len, :])

# Token-id -> embedding lookup, shared in form by the source and target sides.
class TokenEmbedding(nn.Module):
    """Embed integer token ids and scale the vectors by sqrt(emb_size).

    Args:
        vocab_size: number of rows in the embedding table.
        emb_size: embedding width; also the basis of the sqrt scaling factor.
    """

    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.emb_size = emb_size
        # Attribute name is a state_dict key ("embedding.weight"); keep it.
        self.embedding = nn.Embedding(vocab_size, emb_size)

    def forward(self, tokens: Tensor):
        """Return ``embedding(tokens) * sqrt(emb_size)``; ids are cast to long first."""
        token_ids = tokens.long()
        scale = math.sqrt(self.emb_size)
        return self.embedding(token_ids) * scale


class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder translation model built on ``torch.nn.Transformer``.

    Source and target ids are embedded (``TokenEmbedding``), given sinusoidal
    positions by one shared ``PositionalEncoding``, run through
    ``nn.Transformer`` and projected to target-vocabulary logits by
    ``generator``. All tensors are sequence-first (``nn.Transformer`` default
    ``batch_first=False``): ids are (seq_len, batch).
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        # Submodule attribute names are the state_dict key prefixes; renaming
        # any of them breaks loading previously saved checkpoints.
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)

        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """Encode ``src`` and decode the given ``trg`` in one pass.

        Returns logits of shape (tgt_len, batch, tgt_vocab_size).
        """
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        # Keyword arguments: nn.Transformer.forward has grown extra parameters
        # across PyTorch versions, so positional passing is fragile.
        outs = self.transformer(src_emb,
                                tgt_emb,
                                src_mask=src_mask,
                                tgt_mask=tgt_mask,
                                memory_mask=None,
                                src_key_padding_mask=src_padding_mask,
                                tgt_key_padding_mask=tgt_padding_mask,
                                memory_key_padding_mask=memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor, src_padding_mask=None):
        """Run only the encoder; returns the memory (src_len, batch, emb_size).

        ``src_padding_mask`` is an optional (batch, src_len) key-padding mask.
        """
        return self.transformer.encoder(
            self.positional_encoding(self.src_tok_emb(src)),
            mask=src_mask,
            src_key_padding_mask=src_padding_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor,
               tgt_padding_mask=None, memory_key_padding_mask=None):
        """Run only the decoder over ``tgt`` attending to ``memory``.

        Returns hidden states (tgt_len, batch, emb_size); apply ``generator``
        to obtain logits. The optional padding masks are forwarded unchanged.
        """
        return self.transformer.decoder(
            self.positional_encoding(self.tgt_tok_emb(tgt)),
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask)

In [10]:
class Trainer():
    """Inference harness for a trained Seq2SeqTransformer translation model.

    Nothing here trains (``self.optimizer`` stays ``None``). ``__init__`` loads
    the vocabularies and test split from ``<in>_to_<out>.pkl`` and the model
    weights from ``tf_models/<in>_to_<out>_tf.pth``. ``eval`` samples test
    pairs and prints greedy translations next to the references.
    """

    def __init__(self, config_dict, device) -> None:
        """Read config, load the pickled vocab/test data, build and load the model.

        Args:
            config_dict: dict providing at least 'hidden_size', 'batch_size',
                'from_lang', 'max_sentence_length', 'convert_eng_to_lang',
                'use_pkl_data', 'epoch_count', 'learning_rate', 'use_attention'.
            device: torch.device for the model and decoding tensors.
        """
        import pprint

        self.config_dict = config_dict
        self.hidden_size = config_dict['hidden_size']
        self.batch_size = config_dict['batch_size']
        self.from_lang_str = config_dict['from_lang']
        self.max_length = config_dict['max_sentence_length']
        self.convert_eng_to_lang = config_dict['convert_eng_to_lang']

        self.use_pkl_data = config_dict['use_pkl_data']

        self.indx_sent_func = data_preprocess.indexesFromSentence
        self.optimizer = None

        self.device = device

        self.epoch_count = config_dict['epoch_count']
        self.learning_rate = config_dict['learning_rate']

        pprint.pprint(config_dict)

        in_lang_name = "eng" if self.convert_eng_to_lang else self.from_lang_str
        out_lang_name = self.from_lang_str if self.convert_eng_to_lang else "eng"

        save_str = f"{in_lang_name}_to_{out_lang_name}.pkl"

        # NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
        with open(save_str, 'rb') as f:
            loads = pickle.load(f)

        self.in_lang = loads['input_lang']
        self.out_lang = loads['output_lang']
        self.in_ids = loads['test_input_ids']
        self.tgt_ids = loads['test_target_ids']
        self.pairs = loads['test_pairs']
        self.SOS_token = loads['SOS_token']
        self.EOS_token = loads['EOS_token']
        self.attention_flag = config_dict['use_attention']

        self.metrics = [CosineSimilarity(1), BLEUScore(1), METEORScore(1), ROUGEScore_custom(1)]

        SRC_VOCAB_SIZE = self.in_lang.n_words
        TGT_VOCAB_SIZE = self.out_lang.n_words
        EMB_SIZE = 512
        NHEAD = 8
        FFN_HID_DIM = 512
        NUM_ENCODER_LAYERS = 3
        NUM_DECODER_LAYERS = 3

        self.transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                              NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM).to(device)

        # Pick the checkpoint for the configured direction (previously always
        # fra_to_eng). Assumes other checkpoints follow the same
        # "<in>_to_<out>_tf.pth" naming -- TODO confirm. map_location lets a
        # GPU-saved checkpoint load on this device.
        model_path = f"tf_models/{in_lang_name}_to_{out_lang_name}_tf.pth"
        self.transformer.load_state_dict(torch.load(model_path, map_location=device))

    def generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) float causal mask on ``self.device``.

        Entries on and below the diagonal are 0.0 and entries above are -inf,
        so position t cannot attend to later positions.
        """
        return torch.triu(torch.full((sz, sz), float('-inf'), device=self.device), diagonal=1)

    @torch.no_grad()
    def greedy_decode(self, model, fsrc, fsrc_masks, fends, start_symbol):
        """Greedy (argmax) decoding of each source sentence, one at a time.

        Args:
            model: module exposing ``encode(src, src_mask)``,
                ``decode(ys, memory, tgt_mask)`` and ``generator(hidden)``.
            fsrc: sequence of 1-D token-id tensors, one per sentence.
            fsrc_masks: matching sequence of (len, len) source attention masks.
            fends: per-sentence step budgets; at most ``fends[i][0] - 1``
                tokens are generated for sentence ``i``.
            start_symbol: token id that seeds the decoder.

        Returns:
            list of (n_generated + 1, 1) long tensors, each starting with
            ``start_symbol`` and ending at ``self.EOS_token`` if one was produced.
        """
        outputs = []
        for src_ids, src_mask, budget in zip(fsrc, fsrc_masks, fends):
            # (len,) -> (len, 1): sequence-first with a batch of one.
            src = src_ids.unsqueeze(1).to(self.device)
            memory = model.encode(src, src_mask.to(self.device))
            ys = torch.full((1, 1), start_symbol, dtype=torch.long, device=self.device)
            for _ in range(int(budget[0]) - 1):
                # Bool mask: True (from -inf) marks future positions to hide.
                tgt_mask = self.generate_square_subsequent_mask(ys.size(0)).bool()
                out = model.decode(ys, memory, tgt_mask)
                # Only the newest position's hidden state is scored.
                prob = model.generator(out.transpose(0, 1)[:, -1])
                next_word = int(prob.argmax(dim=1).item())
                next_tok = torch.full((1, 1), next_word, dtype=ys.dtype, device=self.device)
                ys = torch.cat([ys, next_tok], dim=0)
                if next_word == self.EOS_token:
                    break
            outputs.append(ys)
        return outputs

    def _ids_to_sentence(self, ids, index2word, drop_first=False):
        """Join ``index2word`` lookups of ``ids`` up to the first EOS, optionally dropping word 0."""
        words = []
        for idx in ids:
            value = idx.item() if hasattr(idx, 'item') else int(idx)
            if value == self.EOS_token:
                break
            words.append(index2word[value])
        if drop_first:
            words = words[1:]
        return ' '.join(words)

    def string_from_X(self, decoder_outputs, type=0, is_input=False):
        """Convert per-sentence model outputs or id sequences into word strings.

        Every item is read up to (not including) the first ``self.EOS_token``.

        Args:
            decoder_outputs: iterable of per-sentence items.
            type: when not ``is_input``: 0 means each item is a score tensor
                (assumed (seq_len, vocab)) whose per-row argmax ids are used;
                any other value means each item is already a sequence of ids,
                and the first decoded word is dropped (the start symbol for
                ``greedy_decode`` outputs; presumably SOS for stored targets).
            is_input: items are source id sequences mapped with ``self.in_lang``.

        Returns:
            list of space-joined strings, one per item.
        """
        if is_input:
            return [self._ids_to_sentence(ids, self.in_lang.index2word)
                    for ids in decoder_outputs]
        if type == 0:
            # squeeze(-1) keeps a 1-D id vector even for length-1 outputs.
            return [self._ids_to_sentence(scores.topk(1)[1].squeeze(-1), self.out_lang.index2word)
                    for scores in decoder_outputs]
        return [self._ids_to_sentence(ids, self.out_lang.index2word, drop_first=True)
                for ids in decoder_outputs]

    # actual function to translate input sentence into target language
    def translate(self, model: torch.nn.Module, sentence_batch, extra_steps=10, long_threshold=20):
        """Greedy-translate each row of ``sentence_batch``.

        Each row is cut at its first ``self.EOS_token``; a row without EOS is
        used in full. Decoding may produce up to ``end + extra_steps - 1``
        tokens, where ``end`` is that cut position.

        Returns:
            (strings, long_flag): the translated strings, and whether any row's
            cut position exceeded ``long_threshold``.
        """
        model.eval()
        masks = []
        ends = []
        sentences = []
        long_flag = False
        for current_sentence in sentence_batch:
            eos_positions = (current_sentence == self.EOS_token).nonzero(as_tuple=True)[0]
            end = int(eos_positions[0]) if eos_positions.numel() > 0 else int(current_sentence.shape[0])
            ends.append([end + extra_steps])
            if end > long_threshold:
                long_flag = True
            sentences.append(current_sentence[:end])
            # All-False mask: every source position may attend to every other.
            masks.append(torch.zeros(end, end, dtype=torch.bool))

        tgt_tokens = self.greedy_decode(
            model, sentences, masks, ends, start_symbol=self.SOS_token)

        strings = self.string_from_X(tgt_tokens, type=1)
        return strings, long_flag

    def eval(self, in_i, tgt_i, max_tries=None):
        """Print translations of randomly sampled test pairs.

        Draws one pair at a time with numpy's global RNG and prints the input,
        the reference and the greedy translation. It stops once a source
        sentence longer than 20 tokens has been translated, or after
        ``max_tries`` draws when given. ``None`` keeps sampling until such a
        sentence appears.

        Args:
            in_i: indexable rows of source ids; column 0 is dropped before use
                (presumably a leading SOS -- TODO confirm).
            tgt_i: matching rows of target ids.
            max_tries: optional cap on the number of draws.
        """
        self.transformer.eval()
        found_long = False
        counter = 0
        while not found_long and (max_tries is None or counter < max_tries):
            choices = np.random.choice(len(in_i), 1, replace=False)
            counter += 1
            input_sentence_batch = in_i[choices]
            target_sentence_batch = tgt_i[choices]
            valid_str = input_sentence_batch[:, 1:]

            input_strings = self.string_from_X(valid_str, is_input=True)
            translated_strings, found_long = self.translate(self.transformer, valid_str)
            target_strings = self.string_from_X(target_sentence_batch, type=1)

            for i in range(len(input_strings)):
                print("Counter =", counter)
                print("Input >", input_strings[i])
                print("Target >", target_strings[i])
                print("Translated =", translated_strings[i])
                print('\n')


In [11]:
# Inference settings for the French -> English checkpoint
# (from_lang="fra", convert_eng_to_lang=False).
# rnn_type, loss_type and optimizer are not read by Trainer; hidden_size,
# epoch_count and learning_rate are only stored on it (apparently left over
# from the RNN training configuration).

# Seed the RNGs so the sampled evaluation pairs are reproducible:
# Trainer.eval draws test pairs with np.random.choice.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

config_dict = {
    "convert_eng_to_lang": False,
    "epoch_count": 20000,
    "learning_rate": 0.001,
    "max_sentence_length": 25,
    "hidden_size": 1024,
    "batch_size": 5,
    "from_lang": "fra",
    "use_pkl_data": True,
    "rnn_type": "gru",
    "loss_type": "nll",
    "optimizer": "adam",
    "use_attention": True
}
trainer = Trainer(config_dict, device)


{'batch_size': 5,
 'convert_eng_to_lang': False,
 'epoch_count': 20000,
 'from_lang': 'fra',
 'hidden_size': 1024,
 'learning_rate': 0.001,
 'loss_type': 'nll',
 'max_sentence_length': 25,
 'optimizer': 'adam',
 'rnn_type': 'gru',
 'use_attention': True,
 'use_pkl_data': True}


In [12]:
# Print greedy translations of randomly drawn test pairs (keeps sampling until a long source sentence is drawn).
trainer.eval(in_i=trainer.in_ids, tgt_i=trainer.tgt_ids)

Counter = 1
Input > es-tu certain que c'est sur  ?
Target > are you sure that's safe ?
Translated = are you sure that's safe ?


Counter = 2
Input > quant a moi, je veux rester en vie .
Target > i want to stay alive .
Translated = i want to stay alive .


Counter = 3
Input > elle le soigna jusqu'a ce qu'il recouvre la sante .
Target > she nursed him back to health .
Translated = she treated her best in health .


Counter = 4
Input > ce livre est tres petit .
Target > this book is very small .
Translated = this book is very small .


Counter = 5
Input > tom chuchota quelque chose a l’oreille de marie .
Target > tom whispered something into mary's ear .
Translated = tom whispered something to mary .


Counter = 6
Input > il a rappele a sa femme de le reveiller a 7 heures du matin .
Target > he reminded his wife to wake him up at 7:00 a .m .
Translated = he reminded his wife to wake him up at seven in the morning .


Counter = 7
Input > tom chuchota quelque chose a l’oreille de marie .
Ta

KeyboardInterrupt: 