In [None]:
import itertools
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
import time
from tqdm.notebook import tqdm
import pandas as pd
import nltk
from nltk.tokenize import word_tokenize
import unidecode
import codecs
import pickle
import string
import random
from tqdm.notebook import tqdm
from transformers import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR, CosineAnnealingLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Dataset
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import math
from collections import Counter
import gc
# nltk.download('punkt')
# sentence_tokenizer  =  nltk.data.load('tokenizers/punkt/english.pickle')

In [None]:
# Collect unreachable Python objects first so that any CUDA tensors they still
# hold are handed back to PyTorch's caching allocator; only then can
# empty_cache() release those blocks to the driver.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Scratch check: argmax over the last axis of a random (1, 3, 4) tensor
# collapses that axis and prints one index per row.
output = torch.randn(1, 3, 4)
print(output.argmax(dim=-1))

In [None]:
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: integer seed shared by Python's `random`, NumPy's global
            generator and PyTorch (CPU and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every GPU, not just the current one; silently ignored without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different kernels from run to run, which
    # defeats the deterministic flag above, so switch it off as well.
    torch.backends.cudnn.benchmark = False

set_seed(seed = 1)

In [None]:
#borrow from https://github.com/hisiter97/Spelling_Correction_Vietnamese/blob/master/dataset/add_noise.py
class SynthesizeData(object):
    """
    Uitils class to create artificial miss-spelled words
    Args:
        vocab_path: path to vocab file. Vocab file is expected to be a set of words, separate by ' ', no newline charactor.
    """

    def __init__(self, vocab_path=""):

        # self.vocab = open(vocab_path, 'r', encoding = 'utf-8').read().split()
        self.tokenizer = word_tokenize
        self.word_couples = [['sương', 'xương'], ['sĩ', 'sỹ'], ['sẽ', 'sẻ'], ['sã', 'sả'], ['sả', 'xả'], ['sẽ', 'sẻ'],
                             ['mùi', 'muồi'],
                             ['chỉnh', 'chỉn'], ['sữa', 'sửa'], ['chuẩn', 'chẩn'], ['lẻ', 'lẽ'], ['chẳng', 'chẵng'],
                             ['cổ', 'cỗ'],
                             ['sát', 'xát'], ['cập', 'cặp'], ['truyện', 'chuyện'], ['xá', 'sá'], ['giả', 'dả'],
                             ['đỡ', 'đở'],
                             ['giữ', 'dữ'], ['giã', 'dã'], ['xảo', 'sảo'], ['kiểm', 'kiễm'], ['cuộc', 'cục'],
                             ['dạng', 'dạn'],
                             ['tản', 'tảng'], ['ngành', 'nghành'], ['nghề', 'ngề'], ['nổ', 'nỗ'], ['rảnh', 'rãnh'],
                             ['sẵn', 'sẳn'],
                             ['sáng', 'xán'], ['xuất', 'suất'], ['suôn', 'suông'], ['sử', 'xử'], ['sắc', 'xắc'],
                             ['chữa', 'chửa'],
                             ['thắn', 'thắng'], ['dỡ', 'dở'], ['trải', 'trãi'], ['trao', 'trau'], ['trung', 'chung'],
                             ['thăm', 'tham'],
                             ['sét', 'xét'], ['dục', 'giục'], ['tả', 'tã'], ['sông', 'xông'], ['sáo', 'xáo'],
                             ['sang', 'xang'],
                             ['ngã', 'ngả'], ['xuống', 'suống'], ['xuồng', 'suồng']]

        self.vn_alphabet = ['a', 'ă', 'â', 'b', 'c', 'd', 'đ', 'e', 'ê', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'ô',
                            'ơ', 'p', 'q', 'r', 's', 't', 'u', 'ư', 'v', 'x', 'y']
        self.alphabet_len = len(self.vn_alphabet)
        self.char_couples = [['i', 'y'], ['s', 'x'], ['gi', 'd'],
                             ['ă', 'â'], ['ch', 'tr'], ['ng', 'n'],
                             ['nh', 'n'], ['ngh', 'ng'], ['ục', 'uộc'], ['o', 'u'],
                             ['ă', 'a'], ['o', 'ô'], ['ả', 'ã'], ['ổ', 'ỗ'], ['ủ', 'ũ'], ['ễ', 'ể'],
                             ['e', 'ê'], ['à', 'ờ'], ['ằ', 'à'], ['ẩn', 'uẩn'], ['ẽ', 'ẻ'], ['ùi', 'uồi'], ['ă', 'â'],
                             ['ở', 'ỡ'], ['ỹ', 'ỷ'], ['ỉ', 'ĩ'], ['ị', 'ỵ'],
                             ['ấ', 'á'], ['n', 'l'], ['qu', 'w'], ['ph', 'f'], ['d', 'z'], ['c', 'k'], ['qu', 'q'],
                             ['i', 'j'], ['gi', 'j'],
                             ]

        self.teencode_dict = {'mình': ['mk', 'mik', 'mjk'], 'vô': ['zô', 'zo', 'vo'], 'vậy': ['zậy', 'z', 'zay', 'za'],
                              'phải': ['fải', 'fai', ], 'biết': ['bit', 'biet'],
                              'rồi': ['rùi', 'ròi', 'r'], 'bây': ['bi', 'bay'], 'giờ': ['h', ],
                              'không': ['k', 'ko', 'khong', 'hk', 'hong', 'hông', '0', 'kg', 'kh', ],
                              'đi': ['di', 'dj', ], 'gì': ['j', ], 'em': ['e', ], 'được': ['dc', 'đc', ], 'tao': ['t'],
                              'tôi': ['t'], 'chồng': ['ck'], 'vợ': ['vk']

                              }

        # self.typo={"ă":"aw","â":"aa","á":"as","à":"af","ả":"ar","ã":"ax","ạ":"aj","ắ":"aws","ổ":"oor","ỗ":"oox","ộ":"ooj","ơ":"ow",
        #           "ằ":"awf","ẳ":"awr","ẵ":"awx","ặ":"awj","ó":"os","ò":"of","ỏ":"or","õ":"ox","ọ":"oj","ô":"oo","ố":"oos","ồ":"oof",
        #           "ớ":"ows","ờ":"owf","ở":"owr","ỡ":"owx","ợ":"owj","é":"es","è":"ef","ẻ":"er","ẽ":"ex","ẹ":"ej","ê":"ee","ế":"ees","ề":"eef",
        #           "ể":"eer","ễ":"eex","ệ":"eej","ú":"us","ù":"uf","ủ":"ur","ũ":"ux","ụ":"uj","ư":"uw","ứ":"uws","ừ":"uwf","ử":"uwr","ữ":"uwx",
        #           "ự":"uwj","í":"is","ì":"if","ỉ":"ir","ị":"ij","ĩ":"ix","ý":"ys","ỳ":"yf","ỷ":"yr","ỵ":"yj","đ":"dd",
        #           "Ă":"Aw","Â":"Aa","Á":"As","À":"Af","Ả":"Ar","Ã":"Ax","Ạ":"Aj","Ắ":"Aws","Ổ":"Oor","Ỗ":"Oox","Ộ":"Ooj","Ơ":"Ow",
        #           "Ằ":"AWF","Ẳ":"Awr","Ẵ":"Awx","Ặ":"Awj","Ó":"Os","Ò":"Of","Ỏ":"Or","Õ":"Ox","Ọ":"Oj","Ô":"Oo","Ố":"Oos","Ồ":"Oof",
        #           "Ớ":"Ows","Ờ":"Owf","Ở":"Owr","Ỡ":"Owx","Ợ":"Owj","É":"Es","È":"Ef","Ẻ":"Er","Ẽ":"Ex","Ẹ":"Ej","Ê":"Ee","Ế":"Ees","Ề":"Eef",
        #           "Ể":"Eer","Ễ":"Eex","Ệ":"Eej","Ú":"Us","Ù":"Uf","Ủ":"Ur","Ũ":"Ux","Ụ":"Uj","Ư":"Uw","Ứ":"Uws","Ừ":"Uwf","Ử":"Uwr","Ữ":"Uwx",
        #           "Ự":"Uwj","Í":"Is","Ì":"If","Ỉ":"Ir","Ị":"Ij","Ĩ":"Ix","Ý":"Ys","Ỳ":"Yf","Ỷ":"Yr","Ỵ":"Yj","Đ":"Dd"}
        self.typo = {"ă": ["aw", "a8"], "â": ["aa", "a6"], "á": ["as", "a1"], "à": ["af", "a2"], "ả": ["ar", "a3"],
                     "ã": ["ax", "a4"], "ạ": ["aj", "a5"], "ắ": ["aws", "ă1"], "ổ": ["oor", "ô3"], "ỗ": ["oox", "ô4"],
                     "ộ": ["ooj", "ô5"], "ơ": ["ow", "o7"],
                     "ằ": ["awf", "ă2"], "ẳ": ["awr", "ă3"], "ẵ": ["awx", "ă4"], "ặ": ["awj", "ă5"], "ó": ["os", "o1"],
                     "ò": ["of", "o2"], "ỏ": ["or", "o3"], "õ": ["ox", "o4"], "ọ": ["oj", "o5"], "ô": ["oo", "o6"],
                     "ố": ["oos", "ô1"], "ồ": ["oof", "ô2"],
                     "ớ": ["ows", "ơ1"], "ờ": ["owf", "ơ2"], "ở": ["owr", "ơ2"], "ỡ": ["owx", "ơ4"], "ợ": ["owj", "ơ5"],
                     "é": ["es", "e1"], "è": ["ef", "e2"], "ẻ": ["er", "e3"], "ẽ": ["ex", "e4"], "ẹ": ["ej", "e5"],
                     "ê": ["ee", "e6"], "ế": ["ees", "ê1"], "ề": ["eef", "ê2"],
                     "ể": ["eer", "ê3"], "ễ": ["eex", "ê3"], "ệ": ["eej", "ê5"], "ú": ["us", "u1"], "ù": ["uf", "u2"],
                     "ủ": ["ur", "u3"], "ũ": ["ux", "u4"], "ụ": ["uj", "u5"], "ư": ["uw", "u7"], "ứ": ["uws", "ư1"],
                     "ừ": ["uwf", "ư2"], "ử": ["uwr", "ư3"], "ữ": ["uwx", "ư4"],
                     "ự": ["uwj", "ư5"], "í": ["is", "i1"], "ì": ["if", "i2"], "ỉ": ["ir", "i3"], "ị": ["ij", "i5"],
                     "ĩ": ["ix", "i4"], "ý": ["ys", "y1"], "ỳ": ["yf", "y2"], "ỷ": ["yr", "y3"], "ỵ": ["yj", "y5"],
                     "đ": ["dd", "d9"],
                     "Ă": ["Aw", "A8"], "Â": ["Aa", "A6"], "Á": ["As", "A1"], "À": ["Af", "A2"], "Ả": ["Ar", "A3"],
                     "Ã": ["Ax", "A4"], "Ạ": ["Aj", "A5"], "Ắ": ["Aws", "Ă1"], "Ổ": ["Oor", "Ô3"], "Ỗ": ["Oox", "Ô4"],
                     "Ộ": ["Ooj", "Ô5"], "Ơ": ["Ow", "O7"],
                     "Ằ": ["AWF", "Ă2"], "Ẳ": ["Awr", "Ă3"], "Ẵ": ["Awx", "Ă4"], "Ặ": ["Awj", "Ă5"], "Ó": ["Os", "O1"],
                     "Ò": ["Of", "O2"], "Ỏ": ["Or", "O3"], "Õ": ["Ox", "O4"], "Ọ": ["Oj", "O5"], "Ô": ["Oo", "O6"],
                     "Ố": ["Oos", "Ô1"], "Ồ": ["Oof", "Ô2"],
                     "Ớ": ["Ows", "Ơ1"], "Ờ": ["Owf", "Ơ2"], "Ở": ["Owr", "Ơ3"], "Ỡ": ["Owx", "Ơ4"], "Ợ": ["Owj", "Ơ5"],
                     "É": ["Es", "E1"], "È": ["Ef", "E2"], "Ẻ": ["Er", "E3"], "Ẽ": ["Ex", "E4"], "Ẹ": ["Ej", "E5"],
                     "Ê": ["Ee", "E6"], "Ế": ["Ees", "Ê1"], "Ề": ["Eef", "Ê2"],
                     "Ể": ["Eer", "Ê3"], "Ễ": ["Eex", "Ê4"], "Ệ": ["Eej", "Ê5"], "Ú": ["Us", "U1"], "Ù": ["Uf", "U2"],
                     "Ủ": ["Ur", "U3"], "Ũ": ["Ux", "U4"], "Ụ": ["Uj", "U5"], "Ư": ["Uw", "U7"], "Ứ": ["Uws", "Ư1"],
                     "Ừ": ["Uwf", "Ư2"], "Ử": ["Uwr", "Ư3"], "Ữ": ["Uwx", "Ư4"],
                     "Ự": ["Uwj", "Ư5"], "Í": ["Is", "I1"], "Ì": ["If", "I2"], "Ỉ": ["Ir", "I3"], "Ị": ["Ij", "I5"],
                     "Ĩ": ["Ix", "I4"], "Ý": ["Ys", "Y1"], "Ỳ": ["Yf", "Y2"], "Ỷ": ["Yr", "Y3"], "Ỵ": ["Yj", "Y5"],
                     "Đ": ["Dd", "D9"]}
        self.all_word_candidates = self.get_all_word_candidates(self.word_couples)
        self.string_all_word_candidates = ' '.join(self.all_word_candidates)
        self.all_char_candidates = self.get_all_char_candidates()
        self.keyboardNeighbors = self.getKeyboardNeighbors()

    def replace_teencode(self, word):
        candidates = self.teencode_dict.get(word, None)
        if candidates is not None:
            chosen_one = 0
            if len(candidates) > 1:
                chosen_one = np.random.randint(0, len(candidates))
            return candidates[chosen_one]

    def getKeyboardNeighbors(self):
        keyboardNeighbors = {}
        keyboardNeighbors['a'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ă'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['â'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['á'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['à'] = "aáàảãăắằẳẵâấầẩẫ"
        keyboardNeighbors['ả'] = "aảã"
        keyboardNeighbors['ã'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ạ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ắ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ằ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ẳ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ặ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ẵ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ấ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ầ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ẩ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ẫ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ậ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['b'] = "bh"
        keyboardNeighbors['c'] = "cgn"
        keyboardNeighbors['d'] = "đctơở"
        keyboardNeighbors['đ'] = "d"
        keyboardNeighbors['e'] = "eéèẻẽẹêếềểễệbpg"
        keyboardNeighbors['é'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['è'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ẻ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ẽ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ẹ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ê'] = "eéèẻẽẹêếềểễệá"
        keyboardNeighbors['ế'] = "eéèẻẽẹêếềểễệố"
        keyboardNeighbors['ề'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ể'] = "eéèẻẽẹêếềểễệôốồổỗộ"
        keyboardNeighbors['ễ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ệ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['g'] = "qgộ"
        keyboardNeighbors['h'] = "h"
        keyboardNeighbors['i'] = "iíìỉĩịat"
        keyboardNeighbors['í'] = "iíìỉĩị"
        keyboardNeighbors['ì'] = "iíìỉĩị"
        keyboardNeighbors['ỉ'] = "iíìỉĩị"
        keyboardNeighbors['ĩ'] = "iíìỉĩị"
        keyboardNeighbors['ị'] = "iíìỉĩịhự"
        keyboardNeighbors['k'] = "klh"
        keyboardNeighbors['l'] = "ljidđ"
        keyboardNeighbors['m'] = "mn"
        keyboardNeighbors['n'] = "mnedư"
        keyboardNeighbors['o'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ó'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ò'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ỏ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['õ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ọ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ô'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ố'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ồ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ổ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ộ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ỗ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ơ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ớ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ờ'] = "oóòỏọõôốồổỗộơớờởợỡà"
        keyboardNeighbors['ở'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ợ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ỡ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        # keyboardNeighbors['p'] = "op"
        # keyboardNeighbors['q'] = "qọ"
        # keyboardNeighbors['r'] = "rht"
        # keyboardNeighbors['s'] = "s"
        # keyboardNeighbors['t'] = "tp"
        keyboardNeighbors['u'] = "uúùủũụưứừữửựhiaạt"
        keyboardNeighbors['ú'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ù'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ủ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ũ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ụ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ư'] = "uúùủũụưứừữửựg"
        keyboardNeighbors['ứ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ừ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ử'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ữ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ự'] = "uúùủũụưứừữửựg"
        keyboardNeighbors['v'] = "v"
        keyboardNeighbors['x'] = "x"
        keyboardNeighbors['y'] = "yýỳỷỵỹụ"
        keyboardNeighbors['ý'] = "yýỳỷỵỹ"
        keyboardNeighbors['ỳ'] = "yýỳỷỵỹ"
        keyboardNeighbors['ỷ'] = "yýỳỷỵỹ"
        keyboardNeighbors['ỵ'] = "yýỳỷỵỹ"
        keyboardNeighbors['ỹ'] = "yýỳỷỵỹ"
        # keyboardNeighbors['w'] = "wv"
        # keyboardNeighbors['j'] = "jli"
        # keyboardNeighbors['z'] = "zs"
        # keyboardNeighbors['f'] = "ft"

        return keyboardNeighbors

    def replace_char_noaccent(self, text, onehot_label):

        # find index noise
        idx = np.random.randint(0, len(onehot_label))
        prevent_loop = 0
        while onehot_label[idx] == 1 or text[idx].isnumeric() or text[idx] in string.punctuation:
            idx = np.random.randint(0, len(onehot_label))
            prevent_loop += 1
            if prevent_loop > 10:
                return False, text, onehot_label

        index_noise = idx
        onehot_label[index_noise] = 1
        word_noise = text[index_noise]
        for id in range(0, len(word_noise)):
            char = word_noise[id]

            if char in self.keyboardNeighbors:
                neighbors = self.keyboardNeighbors[char]
                idx_neigh = np.random.randint(0, len(neighbors))
                replaced = neighbors[idx_neigh]
                word_noise = word_noise[: id] + replaced + word_noise[id + 1:]
                text[index_noise] = word_noise
                return True, text, onehot_label

        return False, text, onehot_label

    def replace_word_candidate(self, word):
        """
        Return a homophone word of the input word.
        """
        capital_flag = word[0].isupper()
        word = word.lower()
        if capital_flag and word in self.teencode_dict:
            return self.replace_teencode(word).capitalize()
        elif word in self.teencode_dict:
            return self.replace_teencode(word)

        for couple in self.word_couples:
            for i in range(2):
                if couple[i] == word:
                    if i == 0:
                        if capital_flag:
                            return couple[1].capitalize()
                        else:
                            return couple[1]
                    else:
                        if capital_flag:
                            return couple[0].capitalize()
                        else:
                            return couple[0]

    def replace_char_candidate(self, char):
        """
        return a homophone char/subword of the input char.
        """
        for couple in self.char_couples:
            for i in range(2):
                if couple[i] == char:
                    if i == 0:
                        return couple[1]
                    else:
                        return couple[0]

    def replace_char_candidate_typo(self, char):
        """
        return a homophone char/subword of the input char.
        """
        i = np.random.randint(0, 2)

        return self.typo[char][i]

    def get_all_char_candidates(self, ):

        all_char_candidates = []
        for couple in self.char_couples:
            all_char_candidates.extend(couple)
        return all_char_candidates

    def get_all_word_candidates(self, word_couples):

        all_word_candidates = []
        for couple in self.word_couples:
            all_word_candidates.extend(couple)
        return all_word_candidates

    def remove_diacritics(self, text, onehot_label):
        """
        Replace word which has diacritics with the same word without diacritics
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this
            function only choose the word that do not has onehot label == 1.
        return: a list of word tokens has one word that its diacritics was removed,
                a list of onehot label indicate the position of words that has been modified.
        """
        idx = np.random.randint(0, len(onehot_label))
        prevent_loop = 0
        while onehot_label[idx] == 1 or text[idx] == unidecode.unidecode(text[idx]) or text[idx] in string.punctuation:
            idx = np.random.randint(0, len(onehot_label))
            prevent_loop += 1
            if prevent_loop > 10:
                return False, text, onehot_label

        onehot_label[idx] = 1
        text[idx] = unidecode.unidecode(text[idx])
        return True, text, onehot_label

    def replace_with_random_letter(self, text, onehot_label):
        """
        Replace, add (or remove) a random letter in a random chosen word with a random letter
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this
            function only choose the word that do not has onehot label == 1.
        return: a list of word tokens has one word that has been modified,
                a list of onehot label indicate the position of words that has been modified.
        """
        idx = np.random.randint(0, len(onehot_label))
        prevent_loop = 0
        while onehot_label[idx] == 1 or text[idx].isnumeric() or text[idx] in string.punctuation:
            idx = np.random.randint(0, len(onehot_label))
            prevent_loop += 1
            if prevent_loop > 10:
                return False, text, onehot_label

        # replace, add or remove? 0 is replace, 1 is add, 2 is remove
        coin = np.random.choice([0, 1, 2])
        if coin == 0:
            chosen_letter = text[idx][np.random.randint(0, len(text[idx]))]
            replaced = self.vn_alphabet[np.random.randint(0, self.alphabet_len)]
            try:
                text[idx] = re.sub(chosen_letter, replaced, text[idx])
            except:
                return False, text, onehot_label
        elif coin == 1:
            chosen_letter = text[idx][np.random.randint(0, len(text[idx]))]
            replaced = chosen_letter + self.vn_alphabet[np.random.randint(0, self.alphabet_len)]
            try:
                text[idx] = re.sub(chosen_letter, replaced, text[idx])
            except:
                return False, text, onehot_label
        else:
            chosen_letter = text[idx][np.random.randint(0, len(text[idx]))]
            try:
                text[idx] = re.sub(chosen_letter, '', text[idx])
            except:
                return False, text, onehot_label

        onehot_label[idx] = 1
        return True, text, onehot_label

    def replace_with_homophone_word(self, text, onehot_label):
        """
        Replace a candidate word (if exist in the word_couple) with its homophone. if successful, return True, else False
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this
            function only choose the word that do not has onehot label == 1.
        return: True, text, onehot_label if successful replace, else False, text, onehot_label
        """
        # account for the case that the word in the text is upper case but its lowercase match the candidates list
        candidates = []
        for i in range(len(text)):
            if text[i].lower() in self.all_word_candidates or text[i].lower() in self.teencode_dict.keys():
                candidates.append((i, text[i]))

        if len(candidates) == 0:
            return False, text, onehot_label

        idx = np.random.randint(0, len(candidates))
        prevent_loop = 0
        while onehot_label[candidates[idx][0]] == 1:
            idx = np.random.choice(np.arange(0, len(candidates)))
            prevent_loop += 1
            if prevent_loop > 5:
                return False, text, onehot_label

        text[candidates[idx][0]] = self.replace_word_candidate(candidates[idx][1])
        onehot_label[candidates[idx][0]] = 1
        return True, text, onehot_label

    def replace_with_homophone_letter(self, text, onehot_label):
        """
        Replace a subword/letter with its homophones
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this
            function only choose the word that do not has onehot label == 1.
        return: True, text, onehot_label if successful replace, else False, None, None
        """
        candidates = []
        for i in range(len(text)):
            for char in self.all_char_candidates:
                if re.search(char, text[i]) is not None:
                    candidates.append((i, char))
                    break

        if len(candidates) == 0:

            return False, text, onehot_label
        else:
            idx = np.random.randint(0, len(candidates))
            prevent_loop = 0
            while onehot_label[candidates[idx][0]] == 1:
                idx = np.random.randint(0, len(candidates))
                prevent_loop += 1
                if prevent_loop > 5:
                    return False, text, onehot_label

            replaced = self.replace_char_candidate(candidates[idx][1])
            text[candidates[idx][0]] = re.sub(candidates[idx][1], replaced, text[candidates[idx][0]])

            onehot_label[candidates[idx][0]] = 1
            return True, text, onehot_label

    def replace_with_typo_letter(self, text, onehot_label):
        """
        Replace a subword/letter with its homophones
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this
            function only choose the word that do not has onehot label == 1.
        return: True, text, onehot_label if successful replace, else False, None, None
        """
        # find index noise
        idx = np.random.randint(0, len(onehot_label))
        prevent_loop = 0
        while onehot_label[idx] == 1 or text[idx].isnumeric() or text[idx] in string.punctuation:
            idx = np.random.randint(0, len(onehot_label))
            prevent_loop += 1
            if prevent_loop > 10:
                return False, text, onehot_label

        index_noise = idx
        onehot_label[index_noise] = 1

        word_noise = text[index_noise]
        for j in range(0, len(word_noise)):
            char = word_noise[j]

            if char in self.typo:
                replaced = self.replace_char_candidate_typo(char)
                word_noise = word_noise[: j] + replaced + word_noise[j + 1:]
                text[index_noise] = word_noise
                return True, text, onehot_label
        return True, text, onehot_label

    def add_noise(self, sentence, percent_err=0.15, num_type_err=5):
        tokens = self.tokenizer(sentence)
        onehot_label = [0] * len(tokens)

        num_wrong = int(np.ceil(percent_err * len(tokens)))
        num_wrong = np.random.randint(1, num_wrong + 1)
        if np.random.rand() < 0.05:
            num_wrong = 0

        for i in range(0, num_wrong):
            err = np.random.randint(0, num_type_err + 1)

            if err == 0:
                _, tokens, onehot_label = self.replace_with_homophone_letter(tokens, onehot_label)
            elif err == 1:
                _, tokens, onehot_label = self.replace_with_typo_letter(tokens, onehot_label)
            elif err == 2:
                _, tokens, onehot_label = self.replace_with_homophone_word(tokens, onehot_label)
            elif err == 3:
                _, tokens, onehot_label = self.replace_with_random_letter(tokens, onehot_label)
            elif err == 4:
                _, tokens, onehot_label = self.remove_diacritics(tokens, onehot_label)
            elif err == 5:
                _, tokens, onehot_label = self.replace_char_noaccent(tokens, onehot_label)
            else:
                continue
        return ' '.join(tokens), onehot_label

In [None]:
# Instantiate the noise generator; the default vocab_path is fine because the
# constructor does not read it.
synthesizer = SynthesizeData()


In [None]:
# Fixed sequence length (in ids) that SpellingDataset pads / truncates to.
MAXLEN = 40
# Presumably the n-gram size of the "5gram" input files; not referenced in the
# cells visible here -- TODO confirm.
NGRAM = 5
# Character inventory for VocabChar. The ORDER is significant: a character's id
# is its position in this string plus 3, so do not reorder or deduplicate.
# The backslash is written as an explicit '\\' (the former '\]' was an invalid
# escape sequence that only worked by accident); the string value is unchanged.
alphabets = 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬ0bBcCdDđĐeEè1ÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈ2ĩĨíÍịỊjJkKlLmMnNoO3òÒỏỎõÕóÓọỌôÔồ4ỒổỔỗỖốỐộỘơƠờỜ5ởỞỡỠớỚợỢpP6qQrRsStTuUùÙủỦ7ũŨúÚụỤưƯừỪửỬữỮứỨựỰvVw8WxXyYỳỲỷỶ9ỹỸýÝỵỴzZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ '
print(len(alphabets))
print(alphabets[76])
print(alphabets)

In [None]:
class VocabChar():
    """
    Character-level vocabulary.

    Ids 0, 1 and 2 are reserved for padding, start-of-sequence and
    end-of-sequence; the i-th character of `chars` receives id i + 3.
    """

    def __init__(self, chars):
        self.pad, self.go, self.eos = 0, 1, 2
        self.chars = chars

        # Regular characters start right after the three reserved ids.
        self.i2c = dict(enumerate(chars, start=3))
        self.c2i = {char: idx for idx, char in enumerate(chars, start=3)}

        self.i2c.update({0: '<pad>', 1: '<sos>', 2: '<eos>'})

    def encode(self, chars):
        """Return [go, id of each character..., eos]; unknown characters raise KeyError."""
        ids = [self.go]
        ids.extend(self.c2i[char] for char in chars)
        ids.append(self.eos)
        return ids

    def decode(self, ids):
        """
        Turn a list of ids back into a string.

        The first id is skipped when the start id occurs anywhere in `ids`,
        and decoding stops before the first end id (if there is one).
        """
        start = 1 if self.go in ids else 0
        stop = ids.index(self.eos) if self.eos in ids else None
        return ''.join(self.i2c[token_id] for token_id in ids[start:stop])

    def __len__(self):
        # characters + the three reserved ids
        return 3 + len(self.c2i)

    def batch_decode(self, arr):
        """Decode every id list in `arr`."""
        return list(map(self.decode, arr))

    def __str__(self):
        return self.chars
    
# Character vocabulary over the `alphabets` string; ids follow its character order.
vocab_char = VocabChar(alphabets)

# Preprocessing

# Params CONFIG


In [None]:
class VocabWord(object):
    """
    Word-level vocabulary that assigns ids in insertion order.

    `encode`, `decode` and `__call__` use the fixed ids pad=0, go=1, eos=2 and
    look up the '<unk>' entry for unknown words, so the caller is expected to
    add the matching special tokens before any regular word.
    """

    def __init__(self):
        self.word2idx, self.idx2word = {}, {}
        self.idx = 0  # next id to hand out
        self.pad, self.go, self.eos = 0, 1, 2

    def add_word(self, word):
        """Register `word` under the next free id; known words are left alone."""
        if word in self.word2idx:
            return
        self.word2idx[word] = self.idx
        self.idx2word[self.idx] = word
        self.idx += 1

    def __len__(self):
        return len(self.word2idx)

    def encode(self, chars):
        """Tokenize the string `chars` and return [go, word ids..., eos]."""
        ids = [self.go]
        for token in word_tokenize(chars):
            ids.append(self.word2idx[token] if token in self.word2idx else self.word2idx['<unk>'])
        ids.append(self.eos)
        return ids

    def decode(self, ids):
        """
        Turn a list of ids back into a space-separated string.

        The first id is skipped when the start id occurs anywhere in `ids`,
        and decoding stops before the first end id (if there is one).
        """
        start = 1 if self.go in ids else 0
        stop = ids.index(self.eos) if self.eos in ids else None
        return ' '.join(self.idx2word[token_id] for token_id in ids[start:stop])

    def __call__(self, word):
        """Return the id of `word`, or the id of '<unk>' when it is unknown."""
        if word in self.word2idx:
            return self.word2idx[word]
        return self.word2idx['<unk>']

In [None]:
# Load the caption arrays used as model INPUT text (see SpellingDataset) and
# keep only a fixed-size prefix of each split.
train_normal_captions = np.load('../input/5gram-nlp-project/train_normal_captions.npy')
valid_normal_captions = np.load('../input/5gram-nlp-project/valid_normal_captions.npy')
train_normal_captions = train_normal_captions[0:200000]
valid_normal_captions = valid_normal_captions[0:2000]
# Module-level accumulator that build_vocab() appends every caption to.
captions = []
# Not used by any cell visible here.
onehot_labels = []
def build_vocab(all_clean_captions):
    """
    Build a VocabWord from an iterable of caption strings.

    Words are counted after `word_tokenize`; only those seen at least twice are
    kept. The special tokens are added first so they receive ids 0-3.

    Side effect: every caption is appended to the module-level `captions` list.
    """
    token_counts = Counter()
    for caption in tqdm(all_clean_captions):
        captions.append(caption)
        token_counts.update(word_tokenize(caption))

    vocab = VocabWord()
    for special in ('<pad>', '<start>', '<end>', '<unk>'):  # ids 0, 1, 2, 3
        vocab.add_word(special)
    for word, cnt in token_counts.items():
        if cnt >= 2:
            vocab.add_word(word)
    return vocab

vocab_word = build_vocab(train_normal_captions)
print(f"Length of vocab {len(vocab_word)}")

In [None]:
# print(train_captions[0:20])

In [None]:
# print(valid_captions[0:20])

In [None]:
# Load the arrays used as TARGET text by SpellingDataset, which pairs
# captions[i] with targets[i] by position.
list_ngram_target = np.load('../input/5gram-nlp-project/list_5gram_nonum_train.npy')
list_ngram_target = list_ngram_target[0:200000]
list_ngram_valid_target = np.load('../input/5gram-nlp-project/list_5gram_nonum_valid.npy')
# NOTE(review): valid_normal_captions is sliced [0:2000] but the validation
# targets are sliced [5000:7000]. If the two .npy files are row-aligned, the
# validation pairs no longer match -- confirm against the data and align the
# two slices if so.
list_ngram_valid_target = list_ngram_valid_target[5000:7000]

In [None]:
class SpellingDataset(Dataset):
    """
    Pairs an input text with its target text and returns fixed-length id tensors.

    Args:
        train_normal: indexable collection of input strings.
        list_ngram_target: indexable collection of target strings, aligned by
            index with `train_normal`; its length defines the dataset length.
        vocab_char: object whose `encode(str)` returns character ids.
        vocab_word: object whose `encode(str)` returns word ids.
        max_len: every returned sequence has exactly this many ids.
    """

    def __init__(self, train_normal, list_ngram_target, vocab_char, vocab_word, max_len):
        self.list_ngram_target = list_ngram_target
        self.train_normal = train_normal
        self.vocab_char = vocab_char
        self.vocab_word = vocab_word
        self.max_len = max_len

    def _pad_or_truncate(self, ids):
        """Return `ids` as an int64 array of length `max_len`, cut or padded with 0."""
        fixed = np.zeros(self.max_len, dtype=np.int64)
        clipped = ids[:self.max_len]
        fixed[:len(clipped)] = clipped
        return fixed

    def __getitem__(self, index):
        """
        Return (input char ids, input word ids, target word ids) for one sample.

        All three are int64 tensors of shape (max_len,). Previously a padded
        sequence came back as float64 (it was concatenated with `np.zeros`)
        while a truncated one was int64, so the dtype depended on the sample.
        """
        train_text = self.train_normal[index]
        train_target = self.list_ngram_target[index]

        text_char_ids = self._pad_or_truncate(self.vocab_char.encode(train_text))
        text_word_ids = self._pad_or_truncate(self.vocab_word.encode(train_text))
        target_word_ids = self._pad_or_truncate(self.vocab_word.encode(train_target))

        return (
            torch.from_numpy(text_char_ids),
            torch.from_numpy(text_word_ids),
            torch.from_numpy(target_word_ids),
        )

    def __len__(self):
        return len(self.list_ngram_target)

# Wrap the train / validation arrays in datasets that pad or truncate every
# sequence to MAXLEN ids.
ds_train = SpellingDataset(train_normal_captions, list_ngram_target, vocab_char, vocab_word, MAXLEN)
ds_valid = SpellingDataset(valid_normal_captions, list_ngram_valid_target, vocab_char, vocab_word, MAXLEN)
# ds_test = SpellingDataset(list_ngram_test, synthesizer, vocab, MAXLEN)
# Training batches are shuffled; validation is served one sample at a time in order.
train_loader = DataLoader(ds_train, batch_size = 512, shuffle=True)
val_loader = DataLoader(ds_valid, batch_size = 1)
# test_loader = DataLoader(ds_test, batch_size = 200)
# print(len(train_loader))
# Sanity check: pull one training batch and decode its first sample back to
# text (input as characters, input as words, target as words).
text_char, text_word, target_word= next(iter(train_loader))
print("Text: ", vocab_char.decode(np.squeeze(text_char[0].detach().numpy()).tolist()))
print("Text: ", vocab_word.decode(np.squeeze(text_word[0].detach().numpy()).tolist()))
print("Text: ", vocab_word.decode(np.squeeze(target_word[0].detach().numpy()).tolist()))
# valid_text, valid_oh_label= next(iter(val_loader))
# Number of batches per epoch, then the raw id tensors and batch shapes.
print(len(train_loader))
print(text_char[0])
print(text_word[0])
print(target_word[0])
print(text_char.size())
print(text_word.size())
print(target_word.size())
# print(train_oh_label.size())
# print(train_oh_label[0])
# print(train_oh_label[1])
# print(train_oh_label[2])
# print(valid_text.size(), valid_oh_label.size())

In [None]:
class PositionEncoder(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor, then apply dropout.

    The table is precomputed once for ``max_len`` positions and stored in the
    buffer ``pe`` with shape ``(max_len, 1, emb_size)`` so it broadcasts over
    the batch dimension of an input shaped ``(seq_len, batch, emb_size)``.

    Args:
        max_len: number of positions to precompute; inputs longer than this
            cannot be encoded (the addition in ``forward`` fails to broadcast).
        emb_size: size of the last (feature) dimension. Both even and odd
            sizes are supported.
        dropout: dropout probability applied after adding the encodings.
    """

    def __init__(self, max_len, emb_size, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, emb_size, 2).float() * (-math.log(10000.0) / emb_size))
        # One angle per (position, frequency) pair: (max_len, ceil(emb_size / 2)).
        angles = position * div_term

        pe = torch.zeros(max_len, emb_size)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd emb_size there is one cosine column fewer than sine
        # columns, so trim the angle table instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :emb_size // 2])
        # Stored as (max_len, 1, emb_size); a buffer follows .to(device) and is
        # saved in the state dict without being a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Return ``dropout(x + pe[:seq_len])`` for ``x`` shaped ``(seq_len, batch, emb_size)``."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerEncoder(nn.Module):
    """Token embedding + positional encoding + a stack of Transformer encoder layers.

    Args:
        input_size: number of rows in the embedding table (input vocabulary size).
        emb_size: embedding / model width (``d_model``); the layers use 8 heads.
        hidden_size: width of the feed-forward sub-layer.
        num_layer: number of stacked encoder layers.
        max_len: longest expected sequence; the positional table is built with
            ``max_len + 1`` positions (the original comment attributes the
            extra length to sos/eos tokens).
    """

    def __init__(self, input_size, emb_size, hidden_size, num_layer, max_len=7):
        super().__init__()
        self.emb_size, self.hidden_size, self.num_layer = emb_size, hidden_size, num_layer
        # Embeddings are multiplied by sqrt(emb_size) before positions are added.
        self.scale = math.sqrt(emb_size)

        self.embedding = nn.Embedding(input_size, emb_size)
        self.pos_encoder = PositionEncoder(max_len + 1, emb_size)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=emb_size, nhead=8,
                                       dim_feedforward=hidden_size,
                                       dropout=0.1, activation='gelu'),
            num_layers=num_layer,
            norm=nn.LayerNorm(emb_size),
        )

    def forward(self, src, src_mask):
        """Encode token ids ``src``; ``src_mask`` is passed as ``src_key_padding_mask``."""
        embedded = self.pos_encoder(self.embedding(src) * self.scale)
        return self.encoder(embedded, src_key_padding_mask=src_mask)

class TransformerEncoderHybrid(nn.Module):
    def __init__(self, emb_size, hidden_size, num_layer = 8, max_len=40):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_size, nhead=8,
                                                   dim_feedforward=hidden_size,
                                                   dropout=0.1, activation='gelu')
        encoder_norm = nn.LayerNorm(emb_size)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layer, norm=encoder_norm)

    def forward(self, src):
        output = self.encoder(src)
        return output

# class TransformerDecoder(nn.Module):
#     def __init__(self, output_size, emb_size, hidden_size, num_layer, max_len=64):
#         super().__init__()
#         self.emb_size = emb_size
#         self.hidden_size = hidden_size
#         self.num_layer = num_layer
#         self.scale = math.sqrt(emb_size)

#         self.embedding = nn.Embedding(output_size, emb_size)
#         self.pos_encoder = PositionEncoder(max_len + 10, emb_size)
#         decoder_layer = nn.TransformerDecoderLayer(d_model=emb_size, nhead=4,
#                                                    dim_feedforward=hidden_size,
#                                                    dropout=0.1, activation='gelu')
#         decoder_norm = nn.LayerNorm(emb_size)
#         self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layer, norm=decoder_norm)
#         self.fc = nn.Linear(emb_size, output_size)

#     def forward(self, trg, enc_output, sub_mask, mask):
# #         print(trg.size())
#         trg = self.embedding(trg) * self.scale
# #         print(trg.size())
#         trg = self.pos_encoder(trg)
# #         print(trg.size())
# #         print("Target", trg.size())
# #         print("Target sub mask", sub_mask.size())
# #         print("Target mask", mask.size())
#         output = self.decoder(trg, enc_output, tgt_mask=sub_mask, tgt_key_padding_mask=mask)
# #         print(output.size())
#         output = self.fc(output)
#         return output


In [None]:
class TransformerModel(nn.Module):
    """Hybrid character/word Transformer that scores an output-vocabulary id per position.

    Pipeline: the character ids go through ``TransformerEncoder`` (with a
    padding mask), the result is summed element-wise with an embedding of the
    word ids, the sum goes through ``TransformerEncoderHybrid`` and a linear
    layer projects every position to ``output_size`` logits.

    Inputs are sequence-first, i.e. ``(seq_len, batch)``. The character and
    word tensors must have shapes that can be added after embedding.

    Args:
        input_size: size of the character-side embedding table.
        emb_size: model width shared by both encoders and the word embedding.
        hidden_size: feed-forward width of the encoder layers.
        output_size: size of the word embedding table and of the output logits.
        num_layer: number of layers of the first (character) encoder.
        max_len: maximum sequence length, forwarded to the first encoder.
        pad_token: id treated as padding when building the key-padding mask.
        sos_token, eos_token: stored on the instance; not used by this class.
    """

    def __init__(self, input_size, emb_size, hidden_size, output_size,
                 num_layer, max_len, pad_token, sos_token, eos_token):
        super().__init__()
        self.encoder = TransformerEncoder(input_size, emb_size, hidden_size, num_layer, max_len)
        self.encoder2 = TransformerEncoderHybrid(emb_size, hidden_size)
        self.embedding_word = nn.Embedding(output_size, emb_size)
        self.pad_token = pad_token
        self.sos_token = sos_token
        self.eos_token = eos_token
        self.max_len = max_len
        self.fc = nn.Linear(emb_size, output_size)
        # Kept for callers that want probabilities; forward() returns raw logits
        # and inference() only needs the argmax, so neither applies it.
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def generate_mask(src, pad_token):
        '''
        Generate mask for tensor src
        :param src: tensor with shape (max_src, b)
        :param pad_token: padding token
        :return: boolean mask with shape (b, max_src), True where src equals pad_token
        '''
        # The comparison result already lives on src's device.
        return src.t() == pad_token

    @staticmethod
    def generate_submask(src):
        '''
        Generate a causal (look-ahead) mask for tensor src
        :param src: tensor whose first dimension is the sequence length sz
        :return: float mask with shape (sz, sz): 0.0 on and below the diagonal,
                 -inf above it, created on src's device
        '''
        sz = src.size(0)
        return torch.triu(torch.full((sz, sz), float('-inf'), device=src.device), diagonal=1)

    def _logits(self, char, word):
        """Shared encode path: return logits shaped like ``word`` plus a vocabulary axis."""
        pad_mask = self.generate_mask(char, self.pad_token)
        char_features = self.encoder(char, pad_mask)
        fused = char_features + self.embedding_word(word)
        return self.fc(self.encoder2(fused))

    def forward(self, train_char, train_word, target_word):
        """Return unnormalised logits for every position.

        :param train_char: character-side ids, sequence-first
        :param train_word: word-side ids, sequence-first
        :param target_word: accepted for signature compatibility; not used
        :return: logits with a trailing dimension of ``output_size``
        """
        # No softmax here: the result was never used, and the training loss
        # (nn.CrossEntropyLoss) expects raw logits.
        return self._logits(train_char, train_word)

    def inference(self, char, word, device):
        """Greedy-decode the FIRST sequence of the batch into text.

        :param char: character-side ids, sequence-first
        :param word: word-side ids, sequence-first
        :param device: accepted for signature compatibility; not used
        :return: ``vocab_word.decode(ids)`` for batch element 0

        NOTE(review): this relies on the notebook-level global ``vocab_word``
        rather than on state owned by the model.
        """
        with torch.no_grad():
            # (seq, batch, vocab) -> (batch, seq, vocab); argmax over the vocabulary.
            # Softmax is monotonic, so taking argmax on the logits gives the same ids.
            best_ids = torch.argmax(self._logits(char, word).transpose(0, 1), dim=-1)
        first_sequence = best_ids.cpu().numpy().tolist()[0]
        return vocab_word.decode(first_sequence)

In [None]:
class AverageMeter(object):
    """Keep the most recent value together with a running weighted average.

    Attributes:
        val: last value passed to ``update``.
        sum: weighted sum of all values seen since the last ``reset``.
        count: total weight seen since the last ``reset``.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch mean and its batch size)."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
        
def asMinutes(s):
    """Format a duration given in seconds as ``'<minutes>m <seconds>s'``.

    Minutes are the floor of ``s / 60``; the remaining seconds are truncated
    to an integer by the ``%d`` format.
    """
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)


def timeSince(since, percent):
    """Return ``'<elapsed> (remain <remaining>)'`` for a task started at ``since``.

    Args:
        since: start timestamp as returned by ``time.time()``.
        percent: fraction of the work completed so far (must be non-zero);
            the total duration is extrapolated as ``elapsed / percent``.
    """
    elapsed = time.time() - since
    estimated_total = elapsed / (percent)
    remaining = estimated_total - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [None]:
class Trainer(object):
    """Mixed-precision training / validation driver for the spelling model.

    Relies on notebook-level globals defined in earlier cells: ``train_loader``
    (training batches), ``vocab_word`` (decoding of target ids) and
    ``AverageMeter`` / ``timeSince`` (bookkeeping helpers).

    Args:
        model: module exposing ``forward(char, word, target)`` and
            ``inference(char, word, device)``.
        device: device the model and the batches are moved to.
        config: object providing ``lr``, ``num_epochs``, ``PAD_IDX``, ``folder``
            and the history lists ``epoch``, ``train_loss`` and ``valid_acc``.
    """

    def __init__(self, model, device, config):
        self.config = config
        self.start_epoch = 1
        self.model = model
        self.device = device
        self.model.to(self.device)
        self.optimizer = AdamW(self.model.parameters(), lr = self.config.lr)
        # One scheduler step is taken per optimizer step (see train_one_epoch),
        # so the cycle length must match the number of epochs actually trained.
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer, max_lr=0.001,
            steps_per_epoch=len(train_loader), epochs=self.config.num_epochs)
        self.criterion = nn.CrossEntropyLoss(ignore_index = self.config.PAD_IDX).to(device)
        self.scaler = GradScaler()
        self.base_dir = f'{config.folder}'
        os.makedirs(self.base_dir, exist_ok=True)

    def train_loop(self, train_resume=False):
        """Train for ``config.num_epochs`` epochs, validating and checkpointing after each.

        Epoch numbering starts at ``self.start_epoch`` (advanced by ``load_model``).
        ``train_resume`` is accepted for signature compatibility and is not used.
        """
        for e in range(self.start_epoch, self.start_epoch+self.config.num_epochs):
            t = time.time()
            calc_loss = self.train_one_epoch(self.device, e)
            self.config.epoch.append(e)
            self.config.train_loss.append(calc_loss.avg)
            print(f'Train. Epoch: {e}, Loss: {calc_loss.avg:.5f}, time: {(time.time() - t):.5f}')

            t = time.time()
            predictions, targets = self.valid_one_epoch(self.device, val_loader)
            acc_valid = self.accuracy_valid(predictions, targets)
            self.config.valid_acc.append(acc_valid)
            print(f'Val. Epoch: {e}, Acc valid: {acc_valid:.5f}, time: {(time.time() - t):.5f}')
            state = {'epoch': e, 'state_dict': self.model.state_dict(),'optimizer': self.optimizer.state_dict(),
                         'train_loss': self.config.train_loss, 'valid_acc': self.config.valid_acc}
            # NOTE(review): checkpoints are written to the current working
            # directory; self.base_dir is created in __init__ but not used here.
            self.save_model(state, f'model_transformer_{e}.pth')

    def train_one_epoch(self, device, e):
        """Run one pass over the global ``train_loader``.

        Returns:
            AverageMeter holding the batch-size-weighted mean training loss.
        """
        self.model.train()
        calc_loss = AverageMeter()
        start = time.time()
        num_batches = len(train_loader)
        for batch_idx, (text_char, text_word, target_word) in tqdm(enumerate(train_loader), total=num_batches):
            batch_size = text_char.size(0)
            # The loader yields (batch, seq); the model expects (seq, batch) int64 ids.
            text_char = torch.transpose(text_char, 0, 1).to(device, dtype=torch.int64)
            text_word = torch.transpose(text_word, 0, 1).to(device, dtype=torch.int64)
            target_word = torch.transpose(target_word, 0, 1).to(device, dtype=torch.int64)

            self.optimizer.zero_grad()
            with autocast():
                output = self.model(text_char, text_word, target_word[:-1])
                # Flatten to (seq * batch, vocab) vs (seq * batch,) for the loss.
                output = output.contiguous().reshape(-1, output.shape[2])
                target_word = target_word.contiguous().reshape(-1)
                loss = self.criterion(output, target_word)
            loss_value = loss.item()

            self.scaler.scale(loss).backward()

            scale_before = self.scaler.get_scale()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            # Advance the one-cycle schedule once per real optimizer step. When
            # GradScaler finds inf/NaN gradients it skips the step and lowers
            # the scale, so that case is left out. The bound check keeps a
            # repeated train_loop() call from stepping past the end of the cycle.
            if (self.scaler.get_scale() >= scale_before
                    and self.scheduler.last_epoch < self.scheduler.total_steps):
                self.scheduler.step()

            calc_loss.update(loss_value, batch_size)

            if batch_idx %400==100:
                print('Epoch: [{0}][{1}/{2}] '
                  'Elapsed {remain:s} '
                  .format(
                   e, batch_idx, num_batches,
                   remain=timeSince(start, float(batch_idx+1)/num_batches),
                   ))

        torch.cuda.empty_cache()
        gc.collect()

        return calc_loss

    def valid_one_epoch(self, device, val_loader):
        """Decode every sample of ``val_loader`` with ``self.model.inference``.

        ``model.inference`` decodes only the first sequence of a batch, so the
        loader is expected to use ``batch_size=1``.

        Returns:
            (predictions, targets): two parallel lists of decoded outputs.
        """
        self.model.eval()
        predictions = []
        targets = []
        with torch.no_grad():
            for text_char, text_word, target_word in tqdm(val_loader):
                targets.append(vocab_word.decode(np.squeeze(target_word.detach().numpy()).tolist()))
                text_char = torch.transpose(text_char, 0, 1).to(device, dtype=torch.int64)
                text_word = torch.transpose(text_word, 0, 1).to(device, dtype=torch.int64)
                # Use the trainer's own model rather than the notebook global.
                predictions.append(self.model.inference(text_char, text_word, device))

        return predictions, targets

    def accuracy_valid(self, predictions, targets):
        """Return the fraction of predictions that exactly equal their target.

        Returns 0.0 for an empty prediction list instead of dividing by zero.
        """
        total = len(predictions)
        if total == 0:
            return 0.0
        correct = sum(1 for predicted, expected in zip(predictions, targets) if predicted == expected)
        return correct / total

    def save_model(self, state, path):
        """Serialise the checkpoint dict ``state`` to ``path`` with ``torch.save``."""
        torch.save(state, path)

    def load_model(self, path):
        """Restore model, optimizer and history from a checkpoint written by ``train_loop``.

        Sets ``self.start_epoch`` to the epoch after the saved one. The
        scheduler state is not part of the checkpoint, so the schedule starts
        over after a resume.

        SECURITY: ``torch.load`` unpickles data; only load trusted checkpoints.
        """
        # map_location lets a checkpoint saved on GPU be restored on CPU too.
        checkpoint = torch.load(path, map_location=self.device)
        self.start_epoch = checkpoint['epoch']+1
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.config.train_loss = checkpoint['train_loss']
        # train_loop stores the history under 'valid_acc'; 'acc_valid' is only
        # tried as a fallback for checkpoints written with the other spelling.
        if 'valid_acc' in checkpoint:
            self.config.valid_acc = checkpoint['valid_acc']
        else:
            self.config.valid_acc = checkpoint['acc_valid']

In [None]:
class GlobalParametersTrain:
    """Hyper-parameters and per-run training history for the spelling model.

    Scalar settings are class attributes. The history lists are created per
    instance in ``__init__`` so that every ``GlobalParametersTrain()`` starts
    with empty histories instead of sharing (and accumulating into) lists
    owned by the class.
    """
    # Optimisation
    lr =0.0007
    num_epochs = 15

    # Vocabulary sizes, taken from the notebook-level vocab objects.
    vocab_char_size = len(vocab_char)
    vocab_word_size = len(vocab_word)

    # Model shape
    model_dim = 256
    feed_forward_dim = 1024
    num_layers = 4
    max_len = 40

    # Special token ids
    PAD_IDX = 0
    SOS_IDX = 1
    EOS_IDX = 2

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Directory name handed to the Trainer (created there if missing).
    folder = './model'

    def __init__(self):
        # Histories appended to by the Trainer, one entry per epoch.
        self.train_loss = []
        self.valid_loss = []
        self.valid_acc = []
        self.epoch = []

In [None]:
# Release cached GPU memory held by PyTorch's allocator before the model is built in the next cell.
torch.cuda.empty_cache()

In [None]:
# Instantiate the hyper-parameters, build the model from them, move it to the
# configured device and run the full training loop.
config = GlobalParametersTrain()
model = TransformerModel(
    input_size=config.vocab_char_size,
    emb_size=config.model_dim,
    hidden_size=config.feed_forward_dim,
    output_size=config.vocab_word_size,
    num_layer=config.num_layers,
    max_len=config.max_len,
    pad_token=config.PAD_IDX,
    sos_token=config.SOS_IDX,
    eos_token=config.EOS_IDX,
)
model = model.to(config.device)
training = Trainer(model, device=config.device, config=config)
training.train_loop(train_resume=False)
