In [20]:
import torch
import unidecode
import itertools
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
import time
from tqdm.notebook import tqdm
import nltk
from nltk.tokenize import word_tokenize
import pandas as pd
import pickle
import string
import random
from tqdm.notebook import tqdm
from nltk.tokenize.treebank import TreebankWordDetokenizer
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import matplotlib.ticker as ticker
import math
import gc
# nltk.download('punkt')
# sentence_tokenizer  =  nltk.data.load('tokenizers/punkt/english.pickle')

In [21]:
def set_seed(seed):
    """
    Seed every random number generator used in this notebook so a re-run reproduces results.

    Args:
        seed: integer seed shared by Python's `random`, NumPy's legacy global RNG and PyTorch.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU explicitly, not only the current device.
    torch.cuda.manual_seed_all(seed)
    # Force deterministic cuDNN kernels. Benchmark mode auto-tunes algorithms
    # non-deterministically and may have been switched on by an earlier cell in
    # the same kernel session, so reset it here as well.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(seed = 1)

In [22]:
# Borrowed from https://github.com/hisiter97/Spelling_Correction_Vietnamese/blob/master/dataset/add_noise.py
class SynthesizeData(object):
    """
    Utility class that creates artificially misspelled Vietnamese sentences.

    Every noise method takes a list of word tokens ``text`` and a parallel list
    ``onehot_label`` (1 = token already modified, 0 = untouched). It mutates both
    in place and returns ``(success, text, onehot_label)``. A token is marked in
    ``onehot_label`` only when a modification step was applied to it.

    Args:
        vocab_path: path to a vocab file. Currently unused (vocab loading is
            disabled); kept for interface compatibility.
    """

    def __init__(self, vocab_path=""):
        # self.vocab = open(vocab_path, 'r', encoding = 'utf-8').read().split()
        self.tokenizer = word_tokenize
        # Pairs of commonly confused whole words; either member may replace the other.
        self.word_couples = [['sương', 'xương'], ['sĩ', 'sỹ'], ['sẽ', 'sẻ'], ['sã', 'sả'], ['sả', 'xả'],
                             ['mùi', 'muồi'],
                             ['chỉnh', 'chỉn'], ['sữa', 'sửa'], ['chuẩn', 'chẩn'], ['lẻ', 'lẽ'], ['chẳng', 'chẵng'],
                             ['cổ', 'cỗ'],
                             ['sát', 'xát'], ['cập', 'cặp'], ['truyện', 'chuyện'], ['xá', 'sá'], ['giả', 'dả'],
                             ['đỡ', 'đở'],
                             ['giữ', 'dữ'], ['giã', 'dã'], ['xảo', 'sảo'], ['kiểm', 'kiễm'], ['cuộc', 'cục'],
                             ['dạng', 'dạn'],
                             ['tản', 'tảng'], ['ngành', 'nghành'], ['nghề', 'ngề'], ['nổ', 'nỗ'], ['rảnh', 'rãnh'],
                             ['sẵn', 'sẳn'],
                             ['sáng', 'xán'], ['xuất', 'suất'], ['suôn', 'suông'], ['sử', 'xử'], ['sắc', 'xắc'],
                             ['chữa', 'chửa'],
                             ['thắn', 'thắng'], ['dỡ', 'dở'], ['trải', 'trãi'], ['trao', 'trau'], ['trung', 'chung'],
                             ['thăm', 'tham'],
                             ['sét', 'xét'], ['dục', 'giục'], ['tả', 'tã'], ['sông', 'xông'], ['sáo', 'xáo'],
                             ['sang', 'xang'],
                             ['ngã', 'ngả'], ['xuống', 'suống'], ['xuồng', 'suồng']]

        self.vn_alphabet = ['a', 'ă', 'â', 'b', 'c', 'd', 'đ', 'e', 'ê', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'ô',
                            'ơ', 'p', 'q', 'r', 's', 't', 'u', 'ư', 'v', 'x', 'y']
        self.alphabet_len = len(self.vn_alphabet)
        # Pairs of commonly confused letters / sub-words. Order matters: candidate
        # search and replacement both use the first matching entry.
        self.char_couples = [['i', 'y'], ['s', 'x'], ['gi', 'd'],
                             ['ă', 'â'], ['ch', 'tr'], ['ng', 'n'],
                             ['nh', 'n'], ['ngh', 'ng'], ['ục', 'uộc'], ['o', 'u'],
                             ['ă', 'a'], ['o', 'ô'], ['ả', 'ã'], ['ổ', 'ỗ'], ['ủ', 'ũ'], ['ễ', 'ể'],
                             ['e', 'ê'], ['à', 'ờ'], ['ằ', 'à'], ['ẩn', 'uẩn'], ['ẽ', 'ẻ'], ['ùi', 'uồi'],
                             ['ở', 'ỡ'], ['ỹ', 'ỷ'], ['ỉ', 'ĩ'], ['ị', 'ỵ'],
                             ['ấ', 'á'], ['n', 'l'], ['qu', 'w'], ['ph', 'f'], ['d', 'z'], ['c', 'k'], ['qu', 'q'],
                             ['i', 'j'], ['gi', 'j'],
                             ]

        # Lower-case word -> internet-slang ("teencode") spellings.
        self.teencode_dict = {'mình': ['mk', 'mik', 'mjk'], 'vô': ['zô', 'zo', 'vo'], 'vậy': ['zậy', 'z', 'zay', 'za'],
                              'phải': ['fải', 'fai', ], 'biết': ['bit', 'biet'],
                              'rồi': ['rùi', 'ròi', 'r'], 'bây': ['bi', 'bay'], 'giờ': ['h', ],
                              'không': ['k', 'ko', 'khong', 'hk', 'hong', 'hông', '0', 'kg', 'kh', ],
                              'đi': ['di', 'dj', ], 'gì': ['j', ], 'em': ['e', ], 'được': ['dc', 'đc', ], 'tao': ['t'],
                              'tôi': ['t'], 'chồng': ['ck'], 'vợ': ['vk']
                              }

        # Accented char -> [Telex spelling, VNI-style spelling], i.e. what an
        # un-converted keystroke sequence looks like. Tone digits: 1 sắc, 2 huyền,
        # 3 hỏi, 4 ngã, 5 nặng.
        self.typo = {"ă": ["aw", "a8"], "â": ["aa", "a6"], "á": ["as", "a1"], "à": ["af", "a2"], "ả": ["ar", "a3"],
                     "ã": ["ax", "a4"], "ạ": ["aj", "a5"], "ắ": ["aws", "ă1"], "ổ": ["oor", "ô3"], "ỗ": ["oox", "ô4"],
                     "ộ": ["ooj", "ô5"], "ơ": ["ow", "o7"],
                     "ằ": ["awf", "ă2"], "ẳ": ["awr", "ă3"], "ẵ": ["awx", "ă4"], "ặ": ["awj", "ă5"], "ó": ["os", "o1"],
                     "ấ": ["aas", "â1"], "ầ": ["aaf", "â2"], "ẩ": ["aar", "â3"], "ẫ": ["aax", "â4"], "ậ": ["aaj", "â5"],
                     "ò": ["of", "o2"], "ỏ": ["or", "o3"], "õ": ["ox", "o4"], "ọ": ["oj", "o5"], "ô": ["oo", "o6"],
                     "ố": ["oos", "ô1"], "ồ": ["oof", "ô2"],
                     "ớ": ["ows", "ơ1"], "ờ": ["owf", "ơ2"], "ở": ["owr", "ơ3"], "ỡ": ["owx", "ơ4"], "ợ": ["owj", "ơ5"],
                     "é": ["es", "e1"], "è": ["ef", "e2"], "ẻ": ["er", "e3"], "ẽ": ["ex", "e4"], "ẹ": ["ej", "e5"],
                     "ê": ["ee", "e6"], "ế": ["ees", "ê1"], "ề": ["eef", "ê2"],
                     "ể": ["eer", "ê3"], "ễ": ["eex", "ê4"], "ệ": ["eej", "ê5"], "ú": ["us", "u1"], "ù": ["uf", "u2"],
                     "ủ": ["ur", "u3"], "ũ": ["ux", "u4"], "ụ": ["uj", "u5"], "ư": ["uw", "u7"], "ứ": ["uws", "ư1"],
                     "ừ": ["uwf", "ư2"], "ử": ["uwr", "ư3"], "ữ": ["uwx", "ư4"],
                     "ự": ["uwj", "ư5"], "í": ["is", "i1"], "ì": ["if", "i2"], "ỉ": ["ir", "i3"], "ị": ["ij", "i5"],
                     "ĩ": ["ix", "i4"], "ý": ["ys", "y1"], "ỳ": ["yf", "y2"], "ỷ": ["yr", "y3"], "ỵ": ["yj", "y5"],
                     "ỹ": ["yx", "y4"],
                     "đ": ["dd", "d9"],
                     "Ă": ["Aw", "A8"], "Â": ["Aa", "A6"], "Á": ["As", "A1"], "À": ["Af", "A2"], "Ả": ["Ar", "A3"],
                     "Ã": ["Ax", "A4"], "Ạ": ["Aj", "A5"], "Ắ": ["Aws", "Ă1"], "Ổ": ["Oor", "Ô3"], "Ỗ": ["Oox", "Ô4"],
                     "Ộ": ["Ooj", "Ô5"], "Ơ": ["Ow", "O7"],
                     "Ằ": ["Awf", "Ă2"], "Ẳ": ["Awr", "Ă3"], "Ẵ": ["Awx", "Ă4"], "Ặ": ["Awj", "Ă5"], "Ó": ["Os", "O1"],
                     "Ấ": ["Aas", "Â1"], "Ầ": ["Aaf", "Â2"], "Ẩ": ["Aar", "Â3"], "Ẫ": ["Aax", "Â4"], "Ậ": ["Aaj", "Â5"],
                     "Ò": ["Of", "O2"], "Ỏ": ["Or", "O3"], "Õ": ["Ox", "O4"], "Ọ": ["Oj", "O5"], "Ô": ["Oo", "O6"],
                     "Ố": ["Oos", "Ô1"], "Ồ": ["Oof", "Ô2"],
                     "Ớ": ["Ows", "Ơ1"], "Ờ": ["Owf", "Ơ2"], "Ở": ["Owr", "Ơ3"], "Ỡ": ["Owx", "Ơ4"], "Ợ": ["Owj", "Ơ5"],
                     "É": ["Es", "E1"], "È": ["Ef", "E2"], "Ẻ": ["Er", "E3"], "Ẽ": ["Ex", "E4"], "Ẹ": ["Ej", "E5"],
                     "Ê": ["Ee", "E6"], "Ế": ["Ees", "Ê1"], "Ề": ["Eef", "Ê2"],
                     "Ể": ["Eer", "Ê3"], "Ễ": ["Eex", "Ê4"], "Ệ": ["Eej", "Ê5"], "Ú": ["Us", "U1"], "Ù": ["Uf", "U2"],
                     "Ủ": ["Ur", "U3"], "Ũ": ["Ux", "U4"], "Ụ": ["Uj", "U5"], "Ư": ["Uw", "U7"], "Ứ": ["Uws", "Ư1"],
                     "Ừ": ["Uwf", "Ư2"], "Ử": ["Uwr", "Ư3"], "Ữ": ["Uwx", "Ư4"],
                     "Ự": ["Uwj", "Ư5"], "Í": ["Is", "I1"], "Ì": ["If", "I2"], "Ỉ": ["Ir", "I3"], "Ị": ["Ij", "I5"],
                     "Ĩ": ["Ix", "I4"], "Ý": ["Ys", "Y1"], "Ỳ": ["Yf", "Y2"], "Ỷ": ["Yr", "Y3"], "Ỵ": ["Yj", "Y5"],
                     "Ỹ": ["Yx", "Y4"],
                     "Đ": ["Dd", "D9"]}
        self.all_word_candidates = self.get_all_word_candidates(self.word_couples)
        self.string_all_word_candidates = ' '.join(self.all_word_candidates)
        self.all_char_candidates = self.get_all_char_candidates()
        self.keyboardNeighbors = self.getKeyboardNeighbors()

    @staticmethod
    def _is_punctuation(token):
        """True when every character of `token` is ASCII punctuation (also True for '')."""
        # A substring test (`token in string.punctuation`) misses tokens such as
        # "``" or "''" that word_tokenize emits for quotes.
        return all(ch in string.punctuation for ch in token)

    def _is_number_or_punctuation(self, token):
        """True for tokens that character-level noise should not touch."""
        return token.isnumeric() or self._is_punctuation(token)

    def _pick_index(self, text, onehot_label, is_ineligible, max_retries=10):
        """Draw a random index whose token is unmodified and not `is_ineligible`; None after `max_retries` failed re-draws."""
        idx = np.random.randint(0, len(onehot_label))
        retries = 0
        while onehot_label[idx] == 1 or is_ineligible(text[idx]):
            idx = np.random.randint(0, len(onehot_label))
            retries += 1
            if retries > max_retries:
                return None
        return idx

    def replace_teencode(self, word):
        """Return a random teencode spelling of the lower-case `word`, or None if it has none."""
        candidates = self.teencode_dict.get(word, None)
        if candidates is None:
            return None
        # Only consume randomness when there is an actual choice to make.
        chosen_one = np.random.randint(0, len(candidates)) if len(candidates) > 1 else 0
        return candidates[chosen_one]

    def getKeyboardNeighbors(self):
        """Return a dict mapping a character to the string of characters it may be mistyped as."""
        keyboardNeighbors = {}
        keyboardNeighbors['a'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ă'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['â'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['á'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['à'] = "aáàảãăắằẳẵâấầẩẫ"
        keyboardNeighbors['ả'] = "aảã"
        keyboardNeighbors['ã'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ạ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ắ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ằ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ẳ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ặ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ẵ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ấ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ầ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ẩ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ẫ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ậ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['b'] = "bh"
        keyboardNeighbors['c'] = "cgn"
        keyboardNeighbors['d'] = "đctơở"
        keyboardNeighbors['đ'] = "d"
        keyboardNeighbors['e'] = "eéèẻẽẹêếềểễệbpg"
        keyboardNeighbors['é'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['è'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ẻ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ẽ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ẹ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ê'] = "eéèẻẽẹêếềểễệá"
        keyboardNeighbors['ế'] = "eéèẻẽẹêếềểễệố"
        keyboardNeighbors['ề'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ể'] = "eéèẻẽẹêếềểễệôốồổỗộ"
        keyboardNeighbors['ễ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ệ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['g'] = "qgộ"
        keyboardNeighbors['h'] = "h"
        keyboardNeighbors['i'] = "iíìỉĩịat"
        keyboardNeighbors['í'] = "iíìỉĩị"
        keyboardNeighbors['ì'] = "iíìỉĩị"
        keyboardNeighbors['ỉ'] = "iíìỉĩị"
        keyboardNeighbors['ĩ'] = "iíìỉĩị"
        keyboardNeighbors['ị'] = "iíìỉĩịhự"
        keyboardNeighbors['k'] = "klh"
        keyboardNeighbors['l'] = "ljidđ"
        keyboardNeighbors['m'] = "mn"
        keyboardNeighbors['n'] = "mnedư"
        keyboardNeighbors['o'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ó'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ò'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ỏ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['õ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ọ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ô'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ố'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ồ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ổ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ộ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ỗ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ơ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ớ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ờ'] = "oóòỏọõôốồổỗộơớờởợỡà"
        keyboardNeighbors['ở'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ợ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ỡ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['u'] = "uúùủũụưứừữửựhiaạt"
        keyboardNeighbors['ú'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ù'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ủ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ũ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ụ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ư'] = "uúùủũụưứừữửựg"
        keyboardNeighbors['ứ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ừ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ử'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ữ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ự'] = "uúùủũụưứừữửựg"
        keyboardNeighbors['v'] = "v"
        keyboardNeighbors['x'] = "x"
        keyboardNeighbors['y'] = "yýỳỷỵỹụ"
        keyboardNeighbors['ý'] = "yýỳỷỵỹ"
        keyboardNeighbors['ỳ'] = "yýỳỷỵỹ"
        keyboardNeighbors['ỷ'] = "yýỳỷỵỹ"
        keyboardNeighbors['ỵ'] = "yýỳỷỵỹ"
        keyboardNeighbors['ỹ'] = "yýỳỷỵỹ"

        return keyboardNeighbors

    def replace_char_noaccent(self, text, onehot_label):
        """
        Replace the first character (of a random eligible token) that has keyboard
        neighbours with a random neighbour from ``self.keyboardNeighbors``.

        Returns:
            (True, text, onehot_label) on success; (False, text, onehot_label) when no
            eligible token was found or the token has no character with neighbours.
        """
        idx = self._pick_index(text, onehot_label, self._is_number_or_punctuation)
        if idx is None:
            return False, text, onehot_label

        word = text[idx]
        for pos, char in enumerate(word):
            neighbors = self.keyboardNeighbors.get(char)
            if neighbors is not None:
                replaced = neighbors[np.random.randint(0, len(neighbors))]
                text[idx] = word[:pos] + replaced + word[pos + 1:]
                onehot_label[idx] = 1
                return True, text, onehot_label

        return False, text, onehot_label

    def replace_word_candidate(self, word):
        """
        Return a teencode or homophone replacement for `word`, preserving an initial capital.

        Teencode entries take precedence over ``word_couples``. A word found in
        neither table is returned unchanged.
        """
        capital_flag = word[0].isupper()
        lowered = word.lower()
        if lowered in self.teencode_dict:
            replacement = self.replace_teencode(lowered)
        else:
            replacement = None
            for first, second in self.word_couples:
                if lowered == first:
                    replacement = second
                    break
                if lowered == second:
                    replacement = first
                    break
            if replacement is None:
                return word
        return replacement.capitalize() if capital_flag else replacement

    def replace_char_candidate(self, char):
        """Return the partner of `char` in the first matching ``char_couples`` pair, or None."""
        for first, second in self.char_couples:
            if char == first:
                return second
            if char == second:
                return first
        return None

    def replace_char_candidate_typo(self, char):
        """Return the Telex or VNI-style typing of accented `char`, chosen at random."""
        i = np.random.randint(0, 2)
        return self.typo[char][i]

    def get_all_char_candidates(self):
        """Flatten ``char_couples`` into a single list, preserving order."""
        all_char_candidates = []
        for couple in self.char_couples:
            all_char_candidates.extend(couple)
        return all_char_candidates

    def get_all_word_candidates(self, word_couples):
        """Flatten the given `word_couples` into a single list, preserving order."""
        all_word_candidates = []
        for couple in word_couples:
            all_word_candidates.extend(couple)
        return all_word_candidates

    def remove_diacritics(self, text, onehot_label):
        """
        Replace a random word that has diacritics with its unidecode (ASCII) form.

        Args:
            text: a list of word tokens.
            onehot_label: list marking tokens already modified (1); only tokens
                marked 0 are chosen.
        Returns:
            (True, text, onehot_label) on success, (False, text, onehot_label) if no
            suitable token was found within the retry budget.
        """
        idx = self._pick_index(
            text, onehot_label,
            lambda token: token == unidecode.unidecode(token) or self._is_punctuation(token))
        if idx is None:
            return False, text, onehot_label

        onehot_label[idx] = 1
        text[idx] = unidecode.unidecode(text[idx])
        return True, text, onehot_label

    def replace_with_random_letter(self, text, onehot_label):
        """
        Pick a random letter of a random word and replace it, insert a random letter
        after it, or delete it. Every occurrence of that letter in the word is affected.

        Args:
            text: a list of word tokens.
            onehot_label: list marking tokens already modified (1); only tokens
                marked 0 are chosen.
        Returns:
            (True, text, onehot_label) on success, (False, text, onehot_label) if no
            suitable token was found within the retry budget.
        """
        idx = self._pick_index(text, onehot_label, self._is_number_or_punctuation)
        if idx is None:
            return False, text, onehot_label

        word = text[idx]
        # 0 = replace, 1 = insert a letter after, 2 = remove
        coin = np.random.choice([0, 1, 2])
        chosen_letter = word[np.random.randint(0, len(word))]
        if coin == 0:
            new = self.vn_alphabet[np.random.randint(0, self.alphabet_len)]
        elif coin == 1:
            new = chosen_letter + self.vn_alphabet[np.random.randint(0, self.alphabet_len)]
        else:
            new = ''
        # Literal replacement: the letter must not be interpreted as a regex
        # (e.g. '.' would otherwise match every character of the word).
        text[idx] = word.replace(chosen_letter, new)

        onehot_label[idx] = 1
        return True, text, onehot_label

    def replace_with_homophone_word(self, text, onehot_label):
        """
        Replace a random word found in ``word_couples`` or ``teencode_dict`` (case-insensitively)
        with its homophone / teencode form.

        Returns:
            (True, text, onehot_label) on success, else (False, text, onehot_label).
        """
        candidates = [(i, token) for i, token in enumerate(text)
                      if token.lower() in self.all_word_candidates or token.lower() in self.teencode_dict]
        if not candidates:
            return False, text, onehot_label

        pick = np.random.randint(0, len(candidates))
        retries = 0
        while onehot_label[candidates[pick][0]] == 1:
            pick = np.random.randint(0, len(candidates))
            retries += 1
            if retries > 5:
                return False, text, onehot_label

        pos, word = candidates[pick]
        text[pos] = self.replace_word_candidate(word)
        onehot_label[pos] = 1
        return True, text, onehot_label

    def replace_with_homophone_letter(self, text, onehot_label):
        """
        Replace a letter/sub-word from ``char_couples`` inside a random word with its partner.

        For each token, only the first matching entry of ``all_char_candidates`` is considered.

        Returns:
            (True, text, onehot_label) on success, else (False, text, onehot_label).
        """
        candidates = []
        for i, token in enumerate(text):
            for sub in self.all_char_candidates:
                if sub in token:
                    candidates.append((i, sub))
                    break

        if not candidates:
            return False, text, onehot_label

        pick = np.random.randint(0, len(candidates))
        retries = 0
        while onehot_label[candidates[pick][0]] == 1:
            pick = np.random.randint(0, len(candidates))
            retries += 1
            if retries > 5:
                return False, text, onehot_label

        pos, sub = candidates[pick]
        text[pos] = text[pos].replace(sub, self.replace_char_candidate(sub))
        onehot_label[pos] = 1
        return True, text, onehot_label

    def replace_with_typo_letter(self, text, onehot_label):
        """
        Replace the first accented character of a random word with its raw Telex/VNI
        keystrokes (see ``self.typo``).

        Returns:
            (True, text, onehot_label) on success; (False, text, onehot_label) when no
            eligible token was found or the token has no character in ``self.typo``.
        """
        idx = self._pick_index(text, onehot_label, self._is_number_or_punctuation)
        if idx is None:
            return False, text, onehot_label

        word = text[idx]
        for pos, char in enumerate(word):
            if char in self.typo:
                replaced = self.replace_char_candidate_typo(char)
                text[idx] = word[:pos] + replaced + word[pos + 1:]
                onehot_label[idx] = 1
                return True, text, onehot_label
        return False, text, onehot_label

    def add_noise(self, sentence, percent_err=0.3, num_type_err=5):
        """
        Tokenize `sentence`, corrupt some of its tokens and join them back with single spaces.

        Args:
            sentence: clean input string.
            percent_err: the number of corruptions is drawn uniformly from
                1..ceil(percent_err * n_tokens); with probability 0.05 nothing is corrupted.
            num_type_err: each corruption draws an error type uniformly from
                0..num_type_err (inclusive). Types 0-5 are implemented; larger values are no-ops.
        Returns:
            the noisy sentence ('' when the sentence has no tokens).
        """
        tokens = self.tokenizer(sentence)
        onehot_label = [0] * len(tokens)

        max_wrong = int(np.ceil(percent_err * len(tokens)))
        # Guard: np.random.randint(1, 1) raises for empty input or percent_err == 0.
        num_wrong = np.random.randint(1, max_wrong + 1) if max_wrong > 0 else 0
        if np.random.rand() < 0.05:
            num_wrong = 0

        noise_fns = (
            self.replace_with_homophone_letter,  # 0
            self.replace_with_typo_letter,       # 1
            self.replace_with_homophone_word,    # 2
            self.replace_with_random_letter,     # 3
            self.remove_diacritics,              # 4
            self.replace_char_noaccent,          # 5
        )
        for _ in range(num_wrong):
            err = np.random.randint(0, num_type_err + 1)
            if err < len(noise_fns):
                _, tokens, onehot_label = noise_fns[err](tokens, onehot_label)
        return ' '.join(tokens)

# Preprocessing

# Params CONFIG


In [23]:
# Length (in ids, including <sos>/<eos>) that every encoded sample is padded/truncated to.
MAXLEN = 40
# n-gram size of the samples (matches the `5gram` data files used by this notebook).
NGRAM = 5
# Character inventory of the vocabulary. A character's position determines its id in
# `Vocab` (id = position + 3), so do not reorder this string once a model is trained on it.
# `\\]` is an explicit backslash followed by `]`; the former `\]` was an invalid escape
# sequence (DeprecationWarning / SyntaxWarning) that happened to produce the same string.
alphabets = 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬ0bBcCdDđĐeEè1ÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈ2ĩĨíÍịỊjJkKlLmMnNoO3òÒỏỎõÕóÓọỌôÔồ4ỒổỔỗỖốỐộỘơƠờỜ5ởỞỡỠớỚợỢpP6qQrRsStTuUùÙủỦ7ũŨúÚụỤưƯừỪửỬữỮứỨựỰvVw8WxXyYỳỲỷỶ9ỹỸýÝỵỴzZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '
assert len(set(alphabets)) == len(alphabets), "alphabets must not contain duplicate characters"
print(len(alphabets))

In [24]:
class Vocab():
    """
    Character-level vocabulary.

    Ids 0, 1 and 2 are reserved for <pad>, <sos> and <eos>. The character at
    position ``k`` of ``chars`` gets id ``k + 3``.
    """

    def __init__(self, chars):
        self.pad = 0
        self.go = 1
        self.eos = 2

        self.chars = chars
        self.i2c = {i + 3: c for i, c in enumerate(chars)}
        self.c2i = {c: i + 3 for i, c in enumerate(chars)}

        self.i2c[0] = '<pad>'
        self.i2c[1] = '<sos>'
        self.i2c[2] = '<eos>'

    def encode(self, chars):
        """Return ``[<sos>] + ids(chars) + [<eos>]``; raises KeyError for a character not in the vocab."""
        return [self.go] + [self.c2i[c] for c in chars] + [self.eos]

    def decode(self, ids):
        """
        Convert ids back to a string.

        A leading <sos> is dropped and decoding stops at the first <eos>. ``ids`` may be
        any iterable of integer-valued numbers (list, NumPy array, 1-D tensor, floats like 5.0).
        """
        ids = [int(i) for i in ids]
        # Only strip <sos> when it actually is the first id.
        first = 1 if ids and ids[0] == self.go else 0
        last = ids.index(self.eos) if self.eos in ids else None
        return ''.join(self.i2c[i] for i in ids[first:last])

    def __len__(self):
        return len(self.c2i) + 3

    def batch_decode(self, arr):
        """Decode each id sequence in `arr`; returns a list of strings."""
        return [self.decode(ids) for ids in arr]

    def __str__(self):
        return self.chars
    
vocab = Vocab(alphabets)
# c2i is a dict, so a duplicated character would silently shrink the id space that
# len(vocab) reports while i2c still holds ids up to len(alphabets) + 2.
assert len(vocab) == len(alphabets) + 3, "alphabets contains duplicate characters"

In [25]:
DATA_DIR = '../input/5gram-nlp-project'
TRAIN_LIMIT = 1000000             # number of training targets kept
VALID_SLICE = slice(5000, 15000)  # validation targets kept

list_ngram_target = np.load(os.path.join(DATA_DIR, 'list_5gram_nonum_train.npy'))[:TRAIN_LIMIT]
list_ngram_valid_target = np.load(os.path.join(DATA_DIR, 'list_5gram_nonum_valid.npy'))[VALID_SLICE]
list_ngram_train = np.load(os.path.join(DATA_DIR, 'train_normal_captions.npy'))
list_ngram_valid = np.load(os.path.join(DATA_DIR, 'valid_normal_captions.npy'))

# SpellingDataset takes its length from the inputs and indexes inputs and targets with the
# same index, so every input row needs a target row at the same position.
# NOTE(review): this only guards against IndexError; it cannot prove that input row i is the
# noisy version of (sliced) target row i — confirm how the *_normal_captions files were built.
assert len(list_ngram_train) <= len(list_ngram_target), (len(list_ngram_train), len(list_ngram_target))
assert len(list_ngram_valid) <= len(list_ngram_valid_target), (len(list_ngram_valid), len(list_ngram_valid_target))
class SpellingDataset(Dataset):
    """
    Pairs an input (noisy) string with its target string and encodes both as
    fixed-length id tensors.

    Args:
        list_ngram: sequence of input strings.
        list_ngram_target: sequence of target strings, aligned by index with `list_ngram`.
        vocab: object exposing ``encode(str) -> list[int]`` and an integer ``pad`` id (e.g. ``Vocab``).
        maxlen: length every encoded sequence is padded or truncated to.
    """

    def __init__(self, list_ngram, list_ngram_target, vocab, maxlen):
        self.list_ngram = list_ngram
        self.list_ngram_target = list_ngram_target
        self.vocab = vocab
        self.max_len = maxlen

    def _encode_fixed(self, text):
        """Encode `text`, then truncate or right-pad with the pad id to exactly `max_len` int64 ids."""
        ids = self.vocab.encode(text)[:self.max_len]
        # NOTE: truncation drops the trailing <eos> of sequences longer than max_len.
        ids = ids + [self.vocab.pad] * (self.max_len - len(ids))
        # Always int64: padding with np.zeros used to yield float64 tensors for short
        # samples, which cannot be used as embedding indices.
        return torch.tensor(ids, dtype=torch.int64)

    def __getitem__(self, index):
        """Return ``(input_ids, target_ids)``, each an int64 tensor of shape ``(max_len,)``."""
        tensor_text = self._encode_fixed(self.list_ngram[index])
        tensor_target = self._encode_fixed(self.list_ngram_target[index])
        return tensor_text, tensor_target

    def __len__(self):
        return len(self.list_ngram)

ds_train = SpellingDataset(list_ngram_train, list_ngram_target, vocab, MAXLEN)
ds_valid = SpellingDataset(list_ngram_valid, list_ngram_valid_target, vocab, MAXLEN)
train_loader = DataLoader(ds_train, batch_size = 512, shuffle=True)
val_loader = DataLoader(ds_valid, batch_size = 1)
print(len(train_loader), len(val_loader))

# Sanity check: decode one training pair and one validation pair back to text.
text, target = next(iter(train_loader))
valid_text, valid_target = next(iter(val_loader))  # previously drawn from train_loader by mistake
print("Text: ", vocab.decode(text[0].tolist()))
print("Target: ", vocab.decode(target[0].tolist()))
print("Text: ", vocab.decode(valid_text[0].tolist()))
print("Target: ", vocab.decode(valid_target[0].tolist()))
print(text.size(), target.size())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [26]:
class Encoder(nn.Module):
    """Bidirectional LSTM encoder over token ids.

    Args:
        input_size: source vocabulary size (rows of the embedding table).
        embedding_size: token embedding dimension.
        hidden_size: per-direction LSTM hidden size; also the size of every
            tensor this module returns.
        num_layers: number of stacked bidirectional LSTM layers.
    """

    def __init__(self, input_size, embedding_size, hidden_size, num_layers):
        super(Encoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(input_size, embedding_size)
        # Sequence-first LSTM (batch_first=False): input is (seq_len, batch, embedding_size).
        self.lstm = nn.LSTM(embedding_size, hidden_size, num_layers, bidirectional=True)
        # Projects the concatenated [forward; backward] final state down to hidden_size.
        # Shared by the hidden and the cell state, so both must use the same order.
        self.fc = nn.Linear(hidden_size * 2, hidden_size)

    def forward(self, x):
        """Encode a batch of source sequences.

        Args:
            x: token ids, shape (seq_len, batch).

        Returns:
            outputs: (seq_len, batch, hidden_size), the sum of the forward and
                backward LSTM outputs at each position.
            hidden_state: (batch, hidden_size), projected final hidden state.
            cell_state: (batch, hidden_size), projected final cell state.
        """
        embedding = self.embedding(x)  # (seq_len, batch, embedding_size)
        outputs, (hidden_state, cell_state) = self.lstm(embedding)
        # outputs: (seq_len, batch, 2 * hidden_size).
        # h_n / c_n: (num_layers * 2, batch, hidden_size), laid out layer by layer
        # with forward before backward, so [-2] / [-1] are the top layer's
        # forward / backward final states.
        hidden_state = self.fc(torch.cat((hidden_state[-2], hidden_state[-1]), dim=1))
        cell_state = self.fc(torch.cat((cell_state[-2], cell_state[-1]), dim=1))

        # Merge the two directions so both contribute to the attention memory:
        # forward half + backward half -> (seq_len, batch, hidden_size).
        half = self.hidden_size
        outputs = outputs[:, :, :half] + outputs[:, :, half:]

        return outputs, hidden_state, cell_state

In [27]:
# class Attention(nn.Module):
#     def __init__(self, enc_hidden_dim, dec_hidden_dim):
#         super(Attention, self).__init__()
#         self.attn = nn.Linear((enc_hidden_dim*2) + dec_hidden_dim, dec_hidden_dim)
#         self.V = nn.Linear(dec_hidden_dim, 1, bias = False)
    
#     def forward(self, hidden, encoder_outputs):
#         batch_size = encoder_outputs.shape[1]
#         seq_len = encoder_outputs.shape[0]
        
#         hidden = hidden.unsqueeze(1).repeat(1, seq_len, 1)
        
#         encoder_outputs = encoder_outputs.permute(1, 0, 2)
# #         print("Hidden after repeat", hidden.size())
# #         print("Encoder outputs now", encoder_outputs.size())
#         # [batch_size, seq_len, 2 * enc_hidden_dim] output
#         # [batch_size, seq_len, dec_hidden_dim] hidden
#         energy = torch.tanh(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
#         # energy = [batch, seq_len, dec_hidden_dim]
# #         print("Energy shape", energy.size())
        
#         attention = self.V(energy).squeeze(dim = 2)
#         # attention = [batch_size, seq_len]
# #         print("Attention shape", attention.size())
        
#         attention_weights = F.softmax(attention, dim = 1)
        
#         return attention_weights
        

In [28]:
class Attention(nn.Module):
    """Dot-product attention of a decoder state over encoder outputs.

    ``attn`` and ``V`` are not used by ``forward``. They are kept so the
    module's ``state_dict`` keys stay compatible with checkpoints saved with
    this class.

    Args:
        enc_hidden_dim: encoder hidden size (only used to size ``attn``).
        dec_hidden_dim: decoder hidden size.
    """

    def __init__(self, enc_hidden_dim, dec_hidden_dim):
        super(Attention, self).__init__()
        # Unused by forward(); retained for state_dict/checkpoint compatibility.
        self.attn = nn.Linear((enc_hidden_dim * 2) + dec_hidden_dim, dec_hidden_dim)
        self.V = nn.Linear(dec_hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Return softmax-normalised attention weights over source positions.

        Args:
            hidden: decoder state, shape (batch, dim).
            encoder_outputs: shape (seq_len, batch, dim); ``dim`` must match
                ``hidden``'s last dimension.

        Returns:
            Tensor of shape (batch, seq_len) whose rows sum to 1.
        """
        # (1, batch, dim) broadcasts against (seq_len, batch, dim); summing the
        # product over dim gives one dot-product score per (position, batch).
        scores = torch.sum(hidden.unsqueeze(0) * encoder_outputs, dim=2)  # (seq_len, batch)
        return F.softmax(scores.t(), dim=1)  # (batch, seq_len)
        

In [29]:
# class Decoder(nn.Module):
#     def __init__(self, input_size, embedding_size, hidden_size, output_size, num_layers, dropout, attention):
#         super(Decoder, self).__init__()
#         self.dropout = nn.Dropout(dropout)
#         self.hidden_size = hidden_size
#         self.num_layers = num_layers
#         self.attention = attention
        

#         self.embedding = nn.Embedding(input_size, embedding_size)
#         self.lstm = nn.LSTM(input_size = hidden_size*2  + embedding_size, hidden_size = hidden_size, num_layers = num_layers, dropout = dropout)
        
#         self.fc_out = nn.Linear((hidden_size*2) + hidden_size + embedding_size, output_size)
        
#         self.fc = nn.Linear(hidden_size, output_size)
    
#     def forward(self, x, hidden_state, cell_state, encoder_outputs):        
#         x = x.unsqueeze(0)
# #         print("X shape", x.size())

#         embedding = self.dropout(self.embedding(x))
# #         print("Decoder hidden state shape", hidden_state.size())
# #         print("Decoder cell state shape", cell_state.size())
        
# #         print("Embedding decoder shape", embedding.size())
#         # embedding shape (batch_size, 1, embedding_size)
        
#         a = self.attention(hidden_state, encoder_outputs)
#         a = a.unsqueeze(1)
        
#         # a shape (batch_size, 1, seq_length)
#         encoder_outputs = encoder_outputs.permute(1, 0, 2)
#         weighted = torch.bmm(a, encoder_outputs)
# #         print("Weighted shape", weighted.size())
#         # weighted shape (batch_size, 1, hidden_dim*2)
#         weighted = weighted.permute(1, 0, 2)
#         lstm_input = torch.cat((embedding, weighted), dim=2)
# #         print("LSTM input shape", lstm_input.size())
#         # lstm_input shape = [batch_size, 1, enc_hidden_dim * 2 + embed_dim]
        

#         output, (hidden_state, cell_state) = self.lstm(lstm_input, (hidden_state.unsqueeze(0), cell_state.unsqueeze(0)))
        
#         embedding = embedding.squeeze(0)
#         output = output.squeeze(0)
#         weighted = weighted.squeeze(0)
        
# #         print("Embedding shape", embedding.size())
# #         print("Output shape", output.size())
# #         print("Weighted", weighted.size())
        
#         prediction = self.fc_out(torch.cat((output, weighted, embedding), dim = 1))
        
# #         print("Predictions shape ", prediction.size())
#         # (batch_size, output_dim(223))

#         return prediction, hidden_state.squeeze(0), cell_state.squeeze(0)

In [30]:
class Decoder(nn.Module):
    """Single-step LSTM decoder with attention over encoder outputs.

    Each call consumes one input token per batch element and returns logits
    for the next token.

    Args:
        input_size: target vocabulary size (rows of the embedding table).
        embedding_size: token embedding dimension.
        hidden_size: LSTM hidden size. The attention context is concatenated
            with the embedding to form an LSTM input of size
            ``hidden_size + embedding_size``, so the encoder outputs' last
            dimension must equal ``hidden_size``.
        output_size: number of output logits (vocabulary size).
        num_layers: passed to ``nn.LSTM``. NOTE: ``forward`` turns the 2-D
            state into a single-layer state with ``unsqueeze(0)``, so only
            ``num_layers == 1`` works as written.
        attention: module called as ``attention(hidden_state, encoder_outputs)``
            returning weights of shape (batch, seq_len).
    """

    def __init__(self, input_size, embedding_size, hidden_size, output_size, num_layers, attention):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.attention = attention

        self.embedding = nn.Embedding(input_size, embedding_size)
        self.lstm = nn.LSTM(input_size=hidden_size + embedding_size, hidden_size=hidden_size, num_layers=num_layers)
        # Reads [lstm output; attention context; embedding].
        self.fc_out = nn.Linear(hidden_size + hidden_size + embedding_size, output_size)
        # Not used by forward(); kept so state_dict keys match saved checkpoints.
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden_state, cell_state, encoder_outputs):
        """Run one decoding step.

        Args:
            x: current input token ids, shape (batch,).
            hidden_state: shape (batch, hidden_size).
            cell_state: shape (batch, hidden_size).
            encoder_outputs: shape (seq_len, batch, hidden_size).

        Returns:
            (prediction, hidden_state, cell_state, weights):
            prediction is (batch, output_size) logits; the new states are
            (batch, hidden_size); weights are the attention weights used for
            this step, shape (batch, seq_len).
        """
        # A length-1 sequence for the sequence-first LSTM: (1, batch, embedding_size).
        embedding = self.embedding(x.unsqueeze(0))

        # Attention is queried with the state from *before* this step: (batch, 1, seq_len).
        a = self.attention(hidden_state, encoder_outputs).unsqueeze(1)

        # Context: (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, 1, hidden),
        # then back to sequence-first (1, batch, hidden).
        weighted = torch.bmm(a, encoder_outputs.permute(1, 0, 2)).permute(1, 0, 2)

        lstm_input = torch.cat((embedding, weighted), dim=2)  # (1, batch, embedding_size + hidden_size)
        output, (hidden_state, cell_state) = self.lstm(
            lstm_input, (hidden_state.unsqueeze(0), cell_state.unsqueeze(0))
        )

        # Drop the length-1 sequence axis so everything is (batch, ...).
        embedding = embedding.squeeze(0)
        output = output.squeeze(0)
        weighted = weighted.squeeze(0)

        prediction = self.fc_out(torch.cat((output, weighted, embedding), dim=1))
        return prediction, hidden_state.squeeze(0), cell_state.squeeze(0), a.squeeze(1)

In [31]:
class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that unrolls the decoder over the target sequence.

    Args:
        encoder: module mapping source ids (src_len, batch) to
            ``(encoder_outputs, hidden_state, cell_state)``.
        decoder: module called as ``decoder(x, hidden, cell, encoder_outputs)``
            returning ``(logits, hidden, cell, attention_weights)``.
        target_vocab_size: size of the logits dimension.
    """

    def __init__(self, encoder, decoder, target_vocab_size):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.target_vocab_size = target_vocab_size

    def forward(self, source, target, teacher_force_ratio=0.5):
        """Decode ``target_len - 1`` steps with optional teacher forcing.

        Args:
            source: source ids, shape (src_len, batch).
            target: target ids, shape (target_len, batch). ``target[0]`` is only
                fed as the first decoder input (presumably the SOS token).
            teacher_force_ratio: probability, drawn independently at every
                step, of feeding the gold token instead of the model's greedy
                prediction. 0 means never.

        Returns:
            Logits of shape (target_len, batch, target_vocab_size). Row 0 is
            left as zeros because position 0 is never predicted.
        """
        target_len, batch_size = target.shape[0], target.shape[1]
        # Allocate on the inputs' device rather than a notebook-global `device`.
        outputs = torch.zeros(target_len, batch_size, self.target_vocab_size, device=target.device)

        encoder_outputs, hidden_state, cell_state = self.encoder(source)

        x = target[0]
        for t in range(1, target_len):
            output, hidden_state, cell_state, _ = self.decoder(x, hidden_state, cell_state, encoder_outputs)
            outputs[t] = output

            # Teacher forcing: with probability teacher_force_ratio feed the true
            # next token, otherwise feed the model's own greedy guess.
            if torch.rand(1).item() < teacher_force_ratio:
                x = target[t]
            else:
                x = output.argmax(1)

        return outputs

In [32]:
class AverageMeter(object):
    """Keep the most recent value plus a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every tracked statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new observations and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def asMinutes(s):
    """Format a duration in seconds as '<m>m <s>s', truncating both parts to integers."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - 60 * whole_minutes
    return '%dm %ds' % (whole_minutes, leftover_seconds)


def timeSince(since, percent):
    """Report the elapsed time since ``since`` and a linear estimate of the time left.

    ``percent`` is the completed fraction of the work. The total duration is
    extrapolated as ``elapsed / percent``.
    """
    elapsed = time.time() - since
    remaining = elapsed / percent - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))


In [33]:
class Trainer(object):
    """Build the attention seq2seq spelling model and run training / validation.

    Args:
        device: torch device that the model and every batch are moved to.
        config: hyper-parameter object such as ``GlobalParametersTrain``.
            Read: the model sizes, ``lr``, ``decoder_learning_ratio``,
            ``num_epochs``, ``folder`` and (with fallbacks) ``max_len`` /
            ``SOS_IDX`` / ``EOS_IDX``. Its ``epoch``, ``train_loss``,
            ``valid_loss`` and ``valid_acc`` lists are appended to.
    """

    # Checkpoint loaded by train_loop(..., train_resume=True).
    DEFAULT_RESUME_PATH = '../input/nlp-model/model_attention_20.pth'

    def __init__(self, device, config):
        self.config = config
        self.start_epoch = 1
        self.device = device
        self.attention = Attention(config.hidden_size, config.hidden_size)
        self.encoder = Encoder(config.input_size_encoder, config.encoder_embedding_size,
                               config.hidden_size, config.encoder_num_layers).to(device)
        self.decoder = Decoder(config.input_size_decoder, config.decoder_embedding_size,
                               config.hidden_size, config.output_size,
                               config.decoder_num_layers, self.attention).to(device)
        self.model = Seq2Seq(self.encoder, self.decoder, config.output_size).to(device)

        self.encoder_optimizer = optim.Adam(self.encoder.parameters(), lr=config.lr)
        # The decoder (which owns the attention module) trains with a scaled learning rate.
        self.decoder_optimizer = optim.Adam(self.decoder.parameters(),
                                            lr=config.lr * config.decoder_learning_ratio)
        # Use the trainer's own device rather than a notebook-global `device`.
        self.criterion = nn.CrossEntropyLoss().to(self.device)

        self.vocab = Vocab(alphabets)
        self.base_dir = f'{config.folder}'
        os.makedirs(self.base_dir, exist_ok=True)

    def train_loop(self, train_loader, validation_loader, train_resume=False):
        """Train for ``config.num_epochs`` epochs, then validate and checkpoint once.

        Args:
            train_loader: yields ``(text, target)`` id batches, each (batch, seq_len).
            validation_loader: same format. Batch size 1 is assumed by
                ``valid_one_epoch`` / ``correct_sentence``.
            train_resume: False to start fresh, True to resume from
                ``DEFAULT_RESUME_PATH``, or a checkpoint path to resume from.
        """
        if train_resume:
            if isinstance(train_resume, (str, os.PathLike)):
                path = train_resume
            else:
                path = self.DEFAULT_RESUME_PATH
            self.load_model(path)

        last_epoch = self.start_epoch + self.config.num_epochs - 1
        for e in range(self.start_epoch, last_epoch + 1):
            t = time.time()
            calc_loss = self.train_one_epoch(self.device, train_loader, e)
            self.config.epoch.append(e)
            self.config.train_loss.append(calc_loss.avg)
            print(f'Train. Epoch: {e}, Loss: {calc_loss.avg:.5f}, time: {(time.time() - t):.5f}')

            # Validation (greedy decoding, one sentence at a time) is slow, so it
            # only runs after the final epoch.
            if e == last_epoch:
                t = time.time()
                predictions, targets, calc_loss = self.valid_one_epoch(self.device, validation_loader)
                acc_valid = self.accuracy(predictions, targets)
                self.config.valid_loss.append(calc_loss.avg)
                self.config.valid_acc.append(acc_valid)
                print(f'Val. Epoch: {e}, Loss: {calc_loss.avg:.5f}, Acc valid: {acc_valid:.5f}, time: {(time.time() - t):.5f}')
                state = {
                    'epoch': e,
                    'state_dict': self.model.state_dict(),
                    'encoder_optimizer': self.encoder_optimizer.state_dict(),
                    'decoder_optimizer': self.decoder_optimizer.state_dict(),
                    'train_loss': self.config.train_loss,
                    'valid_loss': self.config.valid_loss,
                    'valid_acc': self.config.valid_acc,
                }
                # Save into config.folder (created in __init__) rather than the CWD.
                self.save_model(state, os.path.join(self.base_dir, f'model_attention_{e}.pth'))

    def _sequence_loss(self, output, target):
        """Cross-entropy between sequence-first logits and targets, skipping step 0.

        ``Seq2Seq.forward`` never predicts position 0: it feeds ``target[0]`` as
        the first decoder input and leaves those logits at zero. Including that
        position would only add a constant, uninformative term to the average.

        Args:
            output: logits, shape (seq_len, batch, vocab).
            target: token ids, shape (seq_len, batch).
        """
        # Flatten both tensors batch-major so logits and labels stay aligned.
        logits = output[1:].transpose(0, 1).reshape(-1, output.shape[-1])
        gold = target[1:].transpose(0, 1).reshape(-1)
        return self.criterion(logits, gold)

    def train_one_epoch(self, device, train_loader, epoch):
        """Run one optimisation pass over ``train_loader``.

        Returns:
            AverageMeter of the per-batch loss, weighted by batch size.
        """
        self.model.train()
        calc_loss = AverageMeter()
        start = time.time()
        n_batches = len(train_loader)
        for batch_idx, (text, target) in enumerate(tqdm(train_loader)):
            batch_size = text.size(0)
            # The model is sequence-first: (batch, seq_len) -> (seq_len, batch).
            text = text.transpose(0, 1).to(device, dtype=torch.int64)
            target = target.transpose(0, 1).to(device, dtype=torch.int64)

            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()

            output = self.model(text, target)  # (seq_len, batch, vocab)
            loss = self._sequence_loss(output, target)
            loss_value = loss.item()
            loss.backward()

            # Clip so exploding gradients stay in a healthy range.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1)
            self.encoder_optimizer.step()
            self.decoder_optimizer.step()
            calc_loss.update(loss_value, batch_size)

            if batch_idx % 400 == 100:
                print('Epoch: [{0}][{1}/{2}] '
                      'Elapsed {remain:s} '
                      .format(epoch, batch_idx, n_batches,
                              remain=timeSince(start, float(batch_idx + 1) / n_batches)))

        torch.cuda.empty_cache()
        gc.collect()
        return calc_loss

    def valid_one_epoch(self, device, val_loader):
        """Greedy-decode every validation sentence and compute the non-teacher-forced loss.

        Assumes ``val_loader`` has batch size 1: ``correct_sentence`` decodes a
        single sequence, and each target is decoded into one entry.

        Returns:
            (predictions, targets, calc_loss): decoded predictions, decoded
            targets, and an AverageMeter of the loss.
        """
        self.model.eval()
        calc_loss = AverageMeter()
        predictions = []
        targets = []
        with torch.no_grad():
            for text, target in tqdm(val_loader):
                batch_size = text.size(0)
                text = text.transpose(0, 1).to(device, dtype=torch.int64)
                predictions.extend(self.correct_sentence(self.model, text, device))
                targets.append(self.vocab.decode(np.squeeze(target.detach().numpy()).tolist()))

                target = target.transpose(0, 1).to(device, dtype=torch.int64)
                # teacher_force_ratio=0: the decoder always consumes its own predictions.
                output = self.model(text, target, 0)
                loss = self._sequence_loss(output, target)
                calc_loss.update(loss.item(), batch_size)

        return predictions, targets, calc_loss

    def correct_sentence(self, model, text, device):
        """Greedy-decode a single source sequence.

        Args:
            model: Seq2Seq model whose ``encoder`` / ``decoder`` are used.
            text: source ids on ``device``, shape (seq_len, 1). Batch size 1 is assumed.
            device: device for the decoder inputs.

        Returns:
            One-element list holding ``self.vocab.decode(tokens)``. ``tokens``
            starts with the SOS id and ends with EOS if EOS was generated
            within ``config.max_len`` steps.
        """
        sos_idx = getattr(self.config, 'SOS_IDX', 1)
        eos_idx = getattr(self.config, 'EOS_IDX', 2)
        max_steps = getattr(self.config, 'max_len', 40)

        model.eval()
        with torch.no_grad():
            encoder_outputs, hidden_state, cell_state = model.encoder(text)
            tokens = [sos_idx]
            next_input = torch.tensor([sos_idx], dtype=torch.long, device=device)
            for _ in range(max_steps):
                output, hidden_state, cell_state, _ = model.decoder(
                    next_input, hidden_state, cell_state, encoder_outputs)
                # softmax is monotonic, so the argmax of the logits is the top token.
                next_input = output.argmax(dim=1)
                tokens.append(int(next_input.item()))
                if tokens[-1] == eos_idx:
                    break
        return [self.vocab.decode(tokens)]

    def accuracy(self, predictions, targets):
        """Fraction of exact matches between ``predictions`` and ``targets``.

        Returns 0.0 for empty input.

        Raises:
            ValueError: if the two sequences differ in length.
        """
        if len(predictions) != len(targets):
            raise ValueError(f'{len(predictions)} predictions vs {len(targets)} targets')
        if not predictions:
            return 0.0
        matches = sum(1 for pred, gold in zip(predictions, targets) if pred == gold)
        return matches / len(predictions)

    def save_model(self, state, path):
        """Serialise a checkpoint dict with ``torch.save``."""
        torch.save(state, path)

    def load_model(self, path):
        """Restore model, optimizers and history from a checkpoint written by ``train_loop``.

        Training then resumes at the epoch after the saved one.
        """
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
        checkpoint = torch.load(path, map_location=self.device)
        self.start_epoch = checkpoint['epoch'] + 1
        self.model.load_state_dict(checkpoint['state_dict'])
        self.encoder_optimizer.load_state_dict(checkpoint['encoder_optimizer'])
        self.decoder_optimizer.load_state_dict(checkpoint['decoder_optimizer'])
        self.config.train_loss = checkpoint['train_loss']
        self.config.valid_acc = checkpoint['valid_acc']
        # Checkpoints written without a validation-loss history lack this key.
        if 'valid_loss' in checkpoint:
            self.config.valid_loss = checkpoint['valid_loss']

In [34]:
class GlobalParametersTrain:
    """Hyper-parameters and run history for the attention seq2seq spelling model.

    Scalar settings are class attributes. The history lists (``train_loss``,
    ``valid_loss``, ``valid_acc``, ``epoch``) are created per instance in
    ``__init__``. As class attributes they were a single shared list, so a
    trainer's appends leaked into every config and accumulated across
    notebook re-runs.
    """
    lr = 0.0005
    # 0.0005 is the best
    SRC_VOCAB_SIZE = 232
    TARGET_VOCAB_SIZE = 232
    num_epochs = 10
    max_len = 40
    PAD_IDX = 0
    SOS_IDX = 1
    EOS_IDX = 2
    encoder_embedding_size = 128
    decoder_embedding_size = 128
    # +3 presumably reserves the PAD/SOS/EOS ids (0, 1, 2) — confirm against Vocab.
    input_size_encoder = len(alphabets) + 3
    input_size_decoder = len(alphabets) + 3
    output_size = len(alphabets) + 3
    hidden_size = 1024  # Needs to be the same for both RNN's
    encoder_num_layers = 1
    decoder_num_layers = 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    decoder_learning_ratio = 5.0

    folder = './model'

    def __init__(self):
        # Fresh, per-instance history lists (see class docstring).
        self.train_loss = []
        self.valid_loss = []
        self.valid_acc = []
        self.epoch = []

In [None]:
# Build a trainer from the default hyper-parameters and train from scratch.
config = GlobalParametersTrain()
training = Trainer(device=config.device, config=config)
training.train_loop(train_loader, val_loader, train_resume=False)
