In [1]:
import itertools
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import re
import time
from tqdm.notebook import tqdm
import pandas as pd
import nltk
from nltk.tokenize import word_tokenize
import unidecode
import codecs
import pickle
import string
import random
from tqdm.notebook import tqdm
from transformers import AdamW
from nltk.tokenize.treebank import TreebankWordDetokenizer
from torchtext.data.metrics import bleu_score
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, OneCycleLR, CosineAnnealingLR
import torch
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import numpy as np
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import math
# nltk.download('punkt')
# sentence_tokenizer  =  nltk.data.load('tokenizers/punkt/english.pickle')

In [2]:
def set_seed(seed):
    """
    Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's legacy global generator and PyTorch
    (CPU and all CUDA devices), and switches cuDNN to deterministic mode.

    Args:
        seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device, not only the current one; this call is
    # silently ignored when CUDA is not available.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning can select different kernels from run to run, so it
    # has to be off for results to be repeatable.
    torch.backends.cudnn.benchmark = False


set_seed(seed = 38)

In [3]:
# Borrowed from https://github.com/hisiter97/Spelling_Correction_Vietnamese/blob/master/dataset/add_noise.py
class SynthesizeData(object):
    """
    Utility class to create artificial misspelled words.
    Args:
        vocab_path: path to the vocab file. The vocab file is expected to be a set of words separated by ' ', with no newline character.
    """

    def __init__(self, vocab_path=""):

        # self.vocab = open(vocab_path, 'r', encoding = 'utf-8').read().split()
        self.tokenizer = word_tokenize
        self.word_couples = [['sương', 'xương'], ['sĩ', 'sỹ'], ['sẽ', 'sẻ'], ['sã', 'sả'], ['sả', 'xả'], ['sẽ', 'sẻ'],
                             ['mùi', 'muồi'],
                             ['chỉnh', 'chỉn'], ['sữa', 'sửa'], ['chuẩn', 'chẩn'], ['lẻ', 'lẽ'], ['chẳng', 'chẵng'],
                             ['cổ', 'cỗ'],
                             ['sát', 'xát'], ['cập', 'cặp'], ['truyện', 'chuyện'], ['xá', 'sá'], ['giả', 'dả'],
                             ['đỡ', 'đở'],
                             ['giữ', 'dữ'], ['giã', 'dã'], ['xảo', 'sảo'], ['kiểm', 'kiễm'], ['cuộc', 'cục'],
                             ['dạng', 'dạn'],
                             ['tản', 'tảng'], ['ngành', 'nghành'], ['nghề', 'ngề'], ['nổ', 'nỗ'], ['rảnh', 'rãnh'],
                             ['sẵn', 'sẳn'],
                             ['sáng', 'xán'], ['xuất', 'suất'], ['suôn', 'suông'], ['sử', 'xử'], ['sắc', 'xắc'],
                             ['chữa', 'chửa'],
                             ['thắn', 'thắng'], ['dỡ', 'dở'], ['trải', 'trãi'], ['trao', 'trau'], ['trung', 'chung'],
                             ['thăm', 'tham'],
                             ['sét', 'xét'], ['dục', 'giục'], ['tả', 'tã'], ['sông', 'xông'], ['sáo', 'xáo'],
                             ['sang', 'xang'],
                             ['ngã', 'ngả'], ['xuống', 'suống'], ['xuồng', 'suồng']]

        self.vn_alphabet = ['a', 'ă', 'â', 'b', 'c', 'd', 'đ', 'e', 'ê', 'g', 'h', 'i', 'k', 'l', 'm', 'n', 'o', 'ô',
                            'ơ', 'p', 'q', 'r', 's', 't', 'u', 'ư', 'v', 'x', 'y']
        self.alphabet_len = len(self.vn_alphabet)
        self.char_couples = [['i', 'y'], ['s', 'x'], ['gi', 'd'],
                             ['ă', 'â'], ['ch', 'tr'], ['ng', 'n'],
                             ['nh', 'n'], ['ngh', 'ng'], ['ục', 'uộc'], ['o', 'u'],
                             ['ă', 'a'], ['o', 'ô'], ['ả', 'ã'], ['ổ', 'ỗ'], ['ủ', 'ũ'], ['ễ', 'ể'],
                             ['e', 'ê'], ['à', 'ờ'], ['ằ', 'à'], ['ẩn', 'uẩn'], ['ẽ', 'ẻ'], ['ùi', 'uồi'], ['ă', 'â'],
                             ['ở', 'ỡ'], ['ỹ', 'ỷ'], ['ỉ', 'ĩ'], ['ị', 'ỵ'],
                             ['ấ', 'á'], ['n', 'l'], ['qu', 'w'], ['ph', 'f'], ['d', 'z'], ['c', 'k'], ['qu', 'q'],
                             ['i', 'j'], ['gi', 'j'],
                             ]

        self.teencode_dict = {'mình': ['mk', 'mik', 'mjk'], 'vô': ['zô', 'zo', 'vo'], 'vậy': ['zậy', 'z', 'zay', 'za'],
                              'phải': ['fải', 'fai', ], 'biết': ['bit', 'biet'],
                              'rồi': ['rùi', 'ròi', 'r'], 'bây': ['bi', 'bay'], 'giờ': ['h', ],
                              'không': ['k', 'ko', 'khong', 'hk', 'hong', 'hông', '0', 'kg', 'kh', ],
                              'đi': ['di', 'dj', ], 'gì': ['j', ], 'em': ['e', ], 'được': ['dc', 'đc', ], 'tao': ['t'],
                              'tôi': ['t'], 'chồng': ['ck'], 'vợ': ['vk']

                              }

        # self.typo={"ă":"aw","â":"aa","á":"as","à":"af","ả":"ar","ã":"ax","ạ":"aj","ắ":"aws","ổ":"oor","ỗ":"oox","ộ":"ooj","ơ":"ow",
        #           "ằ":"awf","ẳ":"awr","ẵ":"awx","ặ":"awj","ó":"os","ò":"of","ỏ":"or","õ":"ox","ọ":"oj","ô":"oo","ố":"oos","ồ":"oof",
        #           "ớ":"ows","ờ":"owf","ở":"owr","ỡ":"owx","ợ":"owj","é":"es","è":"ef","ẻ":"er","ẽ":"ex","ẹ":"ej","ê":"ee","ế":"ees","ề":"eef",
        #           "ể":"eer","ễ":"eex","ệ":"eej","ú":"us","ù":"uf","ủ":"ur","ũ":"ux","ụ":"uj","ư":"uw","ứ":"uws","ừ":"uwf","ử":"uwr","ữ":"uwx",
        #           "ự":"uwj","í":"is","ì":"if","ỉ":"ir","ị":"ij","ĩ":"ix","ý":"ys","ỳ":"yf","ỷ":"yr","ỵ":"yj","đ":"dd",
        #           "Ă":"Aw","Â":"Aa","Á":"As","À":"Af","Ả":"Ar","Ã":"Ax","Ạ":"Aj","Ắ":"Aws","Ổ":"Oor","Ỗ":"Oox","Ộ":"Ooj","Ơ":"Ow",
        #           "Ằ":"AWF","Ẳ":"Awr","Ẵ":"Awx","Ặ":"Awj","Ó":"Os","Ò":"Of","Ỏ":"Or","Õ":"Ox","Ọ":"Oj","Ô":"Oo","Ố":"Oos","Ồ":"Oof",
        #           "Ớ":"Ows","Ờ":"Owf","Ở":"Owr","Ỡ":"Owx","Ợ":"Owj","É":"Es","È":"Ef","Ẻ":"Er","Ẽ":"Ex","Ẹ":"Ej","Ê":"Ee","Ế":"Ees","Ề":"Eef",
        #           "Ể":"Eer","Ễ":"Eex","Ệ":"Eej","Ú":"Us","Ù":"Uf","Ủ":"Ur","Ũ":"Ux","Ụ":"Uj","Ư":"Uw","Ứ":"Uws","Ừ":"Uwf","Ử":"Uwr","Ữ":"Uwx",
        #           "Ự":"Uwj","Í":"Is","Ì":"If","Ỉ":"Ir","Ị":"Ij","Ĩ":"Ix","Ý":"Ys","Ỳ":"Yf","Ỷ":"Yr","Ỵ":"Yj","Đ":"Dd"}
        self.typo = {"ă": ["aw", "a8"], "â": ["aa", "a6"], "á": ["as", "a1"], "à": ["af", "a2"], "ả": ["ar", "a3"],
                     "ã": ["ax", "a4"], "ạ": ["aj", "a5"], "ắ": ["aws", "ă1"], "ổ": ["oor", "ô3"], "ỗ": ["oox", "ô4"],
                     "ộ": ["ooj", "ô5"], "ơ": ["ow", "o7"],
                     "ằ": ["awf", "ă2"], "ẳ": ["awr", "ă3"], "ẵ": ["awx", "ă4"], "ặ": ["awj", "ă5"], "ó": ["os", "o1"],
                     "ò": ["of", "o2"], "ỏ": ["or", "o3"], "õ": ["ox", "o4"], "ọ": ["oj", "o5"], "ô": ["oo", "o6"],
                     "ố": ["oos", "ô1"], "ồ": ["oof", "ô2"],
                     "ớ": ["ows", "ơ1"], "ờ": ["owf", "ơ2"], "ở": ["owr", "ơ2"], "ỡ": ["owx", "ơ4"], "ợ": ["owj", "ơ5"],
                     "é": ["es", "e1"], "è": ["ef", "e2"], "ẻ": ["er", "e3"], "ẽ": ["ex", "e4"], "ẹ": ["ej", "e5"],
                     "ê": ["ee", "e6"], "ế": ["ees", "ê1"], "ề": ["eef", "ê2"],
                     "ể": ["eer", "ê3"], "ễ": ["eex", "ê3"], "ệ": ["eej", "ê5"], "ú": ["us", "u1"], "ù": ["uf", "u2"],
                     "ủ": ["ur", "u3"], "ũ": ["ux", "u4"], "ụ": ["uj", "u5"], "ư": ["uw", "u7"], "ứ": ["uws", "ư1"],
                     "ừ": ["uwf", "ư2"], "ử": ["uwr", "ư3"], "ữ": ["uwx", "ư4"],
                     "ự": ["uwj", "ư5"], "í": ["is", "i1"], "ì": ["if", "i2"], "ỉ": ["ir", "i3"], "ị": ["ij", "i5"],
                     "ĩ": ["ix", "i4"], "ý": ["ys", "y1"], "ỳ": ["yf", "y2"], "ỷ": ["yr", "y3"], "ỵ": ["yj", "y5"],
                     "đ": ["dd", "d9"],
                     "Ă": ["Aw", "A8"], "Â": ["Aa", "A6"], "Á": ["As", "A1"], "À": ["Af", "A2"], "Ả": ["Ar", "A3"],
                     "Ã": ["Ax", "A4"], "Ạ": ["Aj", "A5"], "Ắ": ["Aws", "Ă1"], "Ổ": ["Oor", "Ô3"], "Ỗ": ["Oox", "Ô4"],
                     "Ộ": ["Ooj", "Ô5"], "Ơ": ["Ow", "O7"],
                     "Ằ": ["AWF", "Ă2"], "Ẳ": ["Awr", "Ă3"], "Ẵ": ["Awx", "Ă4"], "Ặ": ["Awj", "Ă5"], "Ó": ["Os", "O1"],
                     "Ò": ["Of", "O2"], "Ỏ": ["Or", "O3"], "Õ": ["Ox", "O4"], "Ọ": ["Oj", "O5"], "Ô": ["Oo", "O6"],
                     "Ố": ["Oos", "Ô1"], "Ồ": ["Oof", "Ô2"],
                     "Ớ": ["Ows", "Ơ1"], "Ờ": ["Owf", "Ơ2"], "Ở": ["Owr", "Ơ3"], "Ỡ": ["Owx", "Ơ4"], "Ợ": ["Owj", "Ơ5"],
                     "É": ["Es", "E1"], "È": ["Ef", "E2"], "Ẻ": ["Er", "E3"], "Ẽ": ["Ex", "E4"], "Ẹ": ["Ej", "E5"],
                     "Ê": ["Ee", "E6"], "Ế": ["Ees", "Ê1"], "Ề": ["Eef", "Ê2"],
                     "Ể": ["Eer", "Ê3"], "Ễ": ["Eex", "Ê4"], "Ệ": ["Eej", "Ê5"], "Ú": ["Us", "U1"], "Ù": ["Uf", "U2"],
                     "Ủ": ["Ur", "U3"], "Ũ": ["Ux", "U4"], "Ụ": ["Uj", "U5"], "Ư": ["Uw", "U7"], "Ứ": ["Uws", "Ư1"],
                     "Ừ": ["Uwf", "Ư2"], "Ử": ["Uwr", "Ư3"], "Ữ": ["Uwx", "Ư4"],
                     "Ự": ["Uwj", "Ư5"], "Í": ["Is", "I1"], "Ì": ["If", "I2"], "Ỉ": ["Ir", "I3"], "Ị": ["Ij", "I5"],
                     "Ĩ": ["Ix", "I4"], "Ý": ["Ys", "Y1"], "Ỳ": ["Yf", "Y2"], "Ỷ": ["Yr", "Y3"], "Ỵ": ["Yj", "Y5"],
                     "Đ": ["Dd", "D9"]}
        self.all_word_candidates = self.get_all_word_candidates(self.word_couples)
        self.string_all_word_candidates = ' '.join(self.all_word_candidates)
        self.all_char_candidates = self.get_all_char_candidates()
        self.keyboardNeighbors = self.getKeyboardNeighbors()

    def replace_teencode(self, word):
        candidates = self.teencode_dict.get(word, None)
        if candidates is not None:
            chosen_one = 0
            if len(candidates) > 1:
                chosen_one = np.random.randint(0, len(candidates))
            return candidates[chosen_one]

    def getKeyboardNeighbors(self):
        keyboardNeighbors = {}
        keyboardNeighbors['a'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ă'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['â'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['á'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['à'] = "aáàảãăắằẳẵâấầẩẫ"
        keyboardNeighbors['ả'] = "aảã"
        keyboardNeighbors['ã'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ạ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ắ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ằ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ẳ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ặ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ẵ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ấ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ầ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ẩ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ẫ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['ậ'] = "aáàảãạăắằẳẵặâấầẩẫậ"
        keyboardNeighbors['b'] = "bh"
        keyboardNeighbors['c'] = "cgn"
        keyboardNeighbors['d'] = "đctơở"
        keyboardNeighbors['đ'] = "d"
        keyboardNeighbors['e'] = "eéèẻẽẹêếềểễệbpg"
        keyboardNeighbors['é'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['è'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ẻ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ẽ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ẹ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ê'] = "eéèẻẽẹêếềểễệá"
        keyboardNeighbors['ế'] = "eéèẻẽẹêếềểễệố"
        keyboardNeighbors['ề'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ể'] = "eéèẻẽẹêếềểễệôốồổỗộ"
        keyboardNeighbors['ễ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['ệ'] = "eéèẻẽẹêếềểễệ"
        keyboardNeighbors['g'] = "qgộ"
        keyboardNeighbors['h'] = "h"
        keyboardNeighbors['i'] = "iíìỉĩịat"
        keyboardNeighbors['í'] = "iíìỉĩị"
        keyboardNeighbors['ì'] = "iíìỉĩị"
        keyboardNeighbors['ỉ'] = "iíìỉĩị"
        keyboardNeighbors['ĩ'] = "iíìỉĩị"
        keyboardNeighbors['ị'] = "iíìỉĩịhự"
        keyboardNeighbors['k'] = "klh"
        keyboardNeighbors['l'] = "ljidđ"
        keyboardNeighbors['m'] = "mn"
        keyboardNeighbors['n'] = "mnedư"
        keyboardNeighbors['o'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ó'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ò'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ỏ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['õ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ọ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ô'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ố'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ồ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ổ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ộ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ỗ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ơ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ớ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ờ'] = "oóòỏọõôốồổỗộơớờởợỡà"
        keyboardNeighbors['ở'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ợ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        keyboardNeighbors['ỡ'] = "oóòỏọõôốồổỗộơớờởợỡ"
        # keyboardNeighbors['p'] = "op"
        # keyboardNeighbors['q'] = "qọ"
        # keyboardNeighbors['r'] = "rht"
        # keyboardNeighbors['s'] = "s"
        # keyboardNeighbors['t'] = "tp"
        keyboardNeighbors['u'] = "uúùủũụưứừữửựhiaạt"
        keyboardNeighbors['ú'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ù'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ủ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ũ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ụ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ư'] = "uúùủũụưứừữửựg"
        keyboardNeighbors['ứ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ừ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ử'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ữ'] = "uúùủũụưứừữửự"
        keyboardNeighbors['ự'] = "uúùủũụưứừữửựg"
        keyboardNeighbors['v'] = "v"
        keyboardNeighbors['x'] = "x"
        keyboardNeighbors['y'] = "yýỳỷỵỹụ"
        keyboardNeighbors['ý'] = "yýỳỷỵỹ"
        keyboardNeighbors['ỳ'] = "yýỳỷỵỹ"
        keyboardNeighbors['ỷ'] = "yýỳỷỵỹ"
        keyboardNeighbors['ỵ'] = "yýỳỷỵỹ"
        keyboardNeighbors['ỹ'] = "yýỳỷỵỹ"
        # keyboardNeighbors['w'] = "wv"
        # keyboardNeighbors['j'] = "jli"
        # keyboardNeighbors['z'] = "zs"
        # keyboardNeighbors['f'] = "ft"

        return keyboardNeighbors

    def replace_char_noaccent(self, text, onehot_label):

        # find index noise
        idx = np.random.randint(0, len(onehot_label))
        prevent_loop = 0
        while onehot_label[idx] == 1 or text[idx].isnumeric() or text[idx] in string.punctuation:
            idx = np.random.randint(0, len(onehot_label))
            prevent_loop += 1
            if prevent_loop > 10:
                return False, text, onehot_label

        index_noise = idx
        onehot_label[index_noise] = 1
        word_noise = text[index_noise]
        for id in range(0, len(word_noise)):
            char = word_noise[id]

            if char in self.keyboardNeighbors:
                neighbors = self.keyboardNeighbors[char]
                idx_neigh = np.random.randint(0, len(neighbors))
                replaced = neighbors[idx_neigh]
                word_noise = word_noise[: id] + replaced + word_noise[id + 1:]
                text[index_noise] = word_noise
                return True, text, onehot_label

        return False, text, onehot_label

    def replace_word_candidate(self, word):
        """
        Return a homophone word of the input word.
        """
        capital_flag = word[0].isupper()
        word = word.lower()
        if capital_flag and word in self.teencode_dict:
            return self.replace_teencode(word).capitalize()
        elif word in self.teencode_dict:
            return self.replace_teencode(word)

        for couple in self.word_couples:
            for i in range(2):
                if couple[i] == word:
                    if i == 0:
                        if capital_flag:
                            return couple[1].capitalize()
                        else:
                            return couple[1]
                    else:
                        if capital_flag:
                            return couple[0].capitalize()
                        else:
                            return couple[0]

    def replace_char_candidate(self, char):
        """
        return a homophone char/subword of the input char.
        """
        for couple in self.char_couples:
            for i in range(2):
                if couple[i] == char:
                    if i == 0:
                        return couple[1]
                    else:
                        return couple[0]

    def replace_char_candidate_typo(self, char):
        """
        return a homophone char/subword of the input char.
        """
        i = np.random.randint(0, 2)

        return self.typo[char][i]

    def get_all_char_candidates(self, ):

        all_char_candidates = []
        for couple in self.char_couples:
            all_char_candidates.extend(couple)
        return all_char_candidates

    def get_all_word_candidates(self, word_couples):

        all_word_candidates = []
        for couple in self.word_couples:
            all_word_candidates.extend(couple)
        return all_word_candidates

    def remove_diacritics(self, text, onehot_label):
        """
        Replace word which has diacritics with the same word without diacritics
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this
            function only choose the word that do not has onehot label == 1.
        return: a list of word tokens has one word that its diacritics was removed,
                a list of onehot label indicate the position of words that has been modified.
        """
        idx = np.random.randint(0, len(onehot_label))
        prevent_loop = 0
        while onehot_label[idx] == 1 or text[idx] == unidecode.unidecode(text[idx]) or text[idx] in string.punctuation:
            idx = np.random.randint(0, len(onehot_label))
            prevent_loop += 1
            if prevent_loop > 10:
                return False, text, onehot_label

        onehot_label[idx] = 1
        text[idx] = unidecode.unidecode(text[idx])
        return True, text, onehot_label

    def replace_with_random_letter(self, text, onehot_label):
        """
        Replace, add (or remove) a random letter in a random chosen word with a random letter
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this
            function only choose the word that do not has onehot label == 1.
        return: a list of word tokens has one word that has been modified,
                a list of onehot label indicate the position of words that has been modified.
        """
        idx = np.random.randint(0, len(onehot_label))
        prevent_loop = 0
        while onehot_label[idx] == 1 or text[idx].isnumeric() or text[idx] in string.punctuation:
            idx = np.random.randint(0, len(onehot_label))
            prevent_loop += 1
            if prevent_loop > 10:
                return False, text, onehot_label

        # replace, add or remove? 0 is replace, 1 is add, 2 is remove
        coin = np.random.choice([0, 1, 2])
        if coin == 0:
            chosen_letter = text[idx][np.random.randint(0, len(text[idx]))]
            replaced = self.vn_alphabet[np.random.randint(0, self.alphabet_len)]
            try:
                text[idx] = re.sub(chosen_letter, replaced, text[idx])
            except:
                return False, text, onehot_label
        elif coin == 1:
            chosen_letter = text[idx][np.random.randint(0, len(text[idx]))]
            replaced = chosen_letter + self.vn_alphabet[np.random.randint(0, self.alphabet_len)]
            try:
                text[idx] = re.sub(chosen_letter, replaced, text[idx])
            except:
                return False, text, onehot_label
        else:
            chosen_letter = text[idx][np.random.randint(0, len(text[idx]))]
            try:
                text[idx] = re.sub(chosen_letter, '', text[idx])
            except:
                return False, text, onehot_label

        onehot_label[idx] = 1
        return True, text, onehot_label

    def replace_with_homophone_word(self, text, onehot_label):
        """
        Replace a candidate word (if exist in the word_couple) with its homophone. if successful, return True, else False
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this
            function only choose the word that do not has onehot label == 1.
        return: True, text, onehot_label if successful replace, else False, text, onehot_label
        """
        # account for the case that the word in the text is upper case but its lowercase match the candidates list
        candidates = []
        for i in range(len(text)):
            if text[i].lower() in self.all_word_candidates or text[i].lower() in self.teencode_dict.keys():
                candidates.append((i, text[i]))

        if len(candidates) == 0:
            return False, text, onehot_label

        idx = np.random.randint(0, len(candidates))
        prevent_loop = 0
        while onehot_label[candidates[idx][0]] == 1:
            idx = np.random.choice(np.arange(0, len(candidates)))
            prevent_loop += 1
            if prevent_loop > 5:
                return False, text, onehot_label

        text[candidates[idx][0]] = self.replace_word_candidate(candidates[idx][1])
        onehot_label[candidates[idx][0]] = 1
        return True, text, onehot_label

    def replace_with_homophone_letter(self, text, onehot_label):
        """
        Replace a subword/letter with its homophones
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this
            function only choose the word that do not has onehot label == 1.
        return: True, text, onehot_label if successful replace, else False, None, None
        """
        candidates = []
        for i in range(len(text)):
            for char in self.all_char_candidates:
                if re.search(char, text[i]) is not None:
                    candidates.append((i, char))
                    break

        if len(candidates) == 0:

            return False, text, onehot_label
        else:
            idx = np.random.randint(0, len(candidates))
            prevent_loop = 0
            while onehot_label[candidates[idx][0]] == 1:
                idx = np.random.randint(0, len(candidates))
                prevent_loop += 1
                if prevent_loop > 5:
                    return False, text, onehot_label

            replaced = self.replace_char_candidate(candidates[idx][1])
            text[candidates[idx][0]] = re.sub(candidates[idx][1], replaced, text[candidates[idx][0]])

            onehot_label[candidates[idx][0]] = 1
            return True, text, onehot_label

    def replace_with_typo_letter(self, text, onehot_label):
        """
        Replace a subword/letter with its homophones
        Args:
            text: a list of word tokens
            onehot_label: onehot array indicate position of word that has already modify, so this
            function only choose the word that do not has onehot label == 1.
        return: True, text, onehot_label if successful replace, else False, None, None
        """
        # find index noise
        idx = np.random.randint(0, len(onehot_label))
        prevent_loop = 0
        while onehot_label[idx] == 1 or text[idx].isnumeric() or text[idx] in string.punctuation:
            idx = np.random.randint(0, len(onehot_label))
            prevent_loop += 1
            if prevent_loop > 10:
                return False, text, onehot_label

        index_noise = idx
        onehot_label[index_noise] = 1

        word_noise = text[index_noise]
        for j in range(0, len(word_noise)):
            char = word_noise[j]

            if char in self.typo:
                replaced = self.replace_char_candidate_typo(char)
                word_noise = word_noise[: j] + replaced + word_noise[j + 1:]
                text[index_noise] = word_noise
                return True, text, onehot_label
        return True, text, onehot_label

    def add_noise(self, sentence, percent_err=0.15, num_type_err=5):
        tokens = self.tokenizer(sentence)
        onehot_label = [0] * len(tokens)

        num_wrong = int(np.ceil(percent_err * len(tokens)))
        num_wrong = np.random.randint(1, num_wrong + 1)
        if np.random.rand() < 0.05:
            num_wrong = 0

        for i in range(0, num_wrong):
            err = np.random.randint(0, num_type_err + 1)

            if err == 0:
                _, tokens, onehot_label = self.replace_with_homophone_letter(tokens, onehot_label)
            elif err == 1:
                _, tokens, onehot_label = self.replace_with_typo_letter(tokens, onehot_label)
            elif err == 2:
                _, tokens, onehot_label = self.replace_with_homophone_word(tokens, onehot_label)
            elif err == 3:
                _, tokens, onehot_label = self.replace_with_random_letter(tokens, onehot_label)
            elif err == 4:
                _, tokens, onehot_label = self.remove_diacritics(tokens, onehot_label)
            elif err == 5:
                _, tokens, onehot_label = self.replace_char_noaccent(tokens, onehot_label)
            else:
                continue
            # print(tokens)
        return ' '.join(tokens)

In [4]:
# Noise generator instance; constructed with the default (empty) vocab_path.
synthesizer = SynthesizeData()
# data = []
# for line in content_all[0:3]:
#     tup = synthesizer.add_noise(line)
#     data.append(tup)
# print(data)

# Preprocessing

In [5]:
# PUNCT_TO_REMOVE = string.punctuation
# def remove_punctuation(text):
#     """custom function to remove the punctuation"""
#     return text.translate(str.maketrans('', '', PUNCT_TO_REMOVE))

# def remove_urls(text):
#     url_pattern = re.compile(r'https?://\S+|www\.\S+')
#     return url_pattern.sub(r'', text)

# def remove_emoji(string):
#     emoji_pattern = re.compile("["
#                            u"\U0001F600-\U0001F64F"  # emoticons
#                            u"\U0001F300-\U0001F5FF"  # symbols & pictographs
#                            u"\U0001F680-\U0001F6FF"  # transport & map symbols
#                            u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
#                            u"\U00002702-\U000027B0"
#                            u"\U000024C2-\U0001F251"
#                            "]+", flags=re.UNICODE)
#     return emoji_pattern.sub(r'', string)

# def clean_numbers(x):
#     if bool(re.search(r'\d', x)):
#         x = re.sub('[0-9]{5,}', '#####', x)
#         x = re.sub('[0-9]{4}', '####', x)
#         x = re.sub('[0-9]{3}', '###', x)
#         x = re.sub('[0-9]{2}', '##', x)
#     return x

# def preprocessing_data(row):
#     processed = re.sub(
#             r'[^aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ ]',
#             "", row)
#     return processed

# df.head()

# Params CONFIG


In [6]:
print("here")
# Length that SpellingDataset pads or truncates every encoded sequence to.
MAXLEN = 40
# NOTE(review): only referenced by the commented-out n-gram extraction code in
# this notebook; presumably the n-gram size of the pre-built .npy files — confirm.
NGRAM = 5
# Character inventory handed to Vocab. Vocab assigns ids by position in this
# string (offset by 3), so the order — including the digits interleaved between
# the letters — must not change once data or models depend on the ids.
alphabets = 'aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬ0bBcCdDđĐeEè1ÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈ2ĩĨíÍịỊjJkKlLmMnNoO3òÒỏỎõÕóÓọỌôÔồ4ỒổỔỗỖốỐộỘơƠờỜ5ởỞỡỠớỚợỢpP6qQrRsStTuUùÙủỦ7ũŨúÚụỤưƯừỪửỬữỮứỨựỰvVw8WxXyYỳỲỷỶ9ỹỸýÝỵỴzZ!"#$%&\'()*+,-./:;<=>?@[\]^_`{|}~ '
# Quick sanity output: size of the inventory, one sample character, full string.
print(len(alphabets))
print(alphabets[76])
print(alphabets)

In [7]:
# content_all = df['Content'].values.tolist()
# content_all[0:3]

In [8]:
# ##https://viblo.asia/p/ung-dung-machine-translation-vao-bai-toan-them-dau-cho-tieng-viet-khong-dau-aivivn-challenge-3-3P0lP4a8lox
# class CreateDataset():
#     def __init__(self, csv_path='../input/kpdl-data/train_remove_noise.csv', save_path="./list_ngrams.npy"):
#         self.csv_path = csv_path
#         self.alphabets_regex = '^[aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ ]'
#         self.save_path = save_path


#     def processing(self):
#         # read csv
#         df = pd.read_csv(self.csv_path)

#         # remove characters that out of vocab
#         df['Content'] = df['Content'].apply(self.preprocessing_data)

#         # extract phrases
#         phrases = itertools.chain.from_iterable(self.extract_phrases(text) for text in df['Content'])
#         phrases = [p.strip() for p in phrases if len(p.split()) > 1]

#         # gen ngrams
#         list_ngrams = []
#         for p in tqdm(phrases):
#             if not re.match(self.alphabets_regex, p.lower()):
#                 continue
#             if len(phrases) == 0:
#                 continue

#             for ngr in self.gen_ngrams(p, NGRAM):
#                 if len(" ".join(ngr)) < MAXLEN:
#                     list_ngrams.append(" ".join(ngr))
#         print("DONE extract ngrams, total ngrams: ", len(list_ngrams))
#         print(list_ngrams[0:30])

#         # save ngrams
#         self.save_ngrams(list_ngrams, save_path=self.save_path)

#         print("Done create dataset - ngrams")

#     def preprocessing_data(self, row):
#         processed = re.sub(
#             r'[^aAàÀảẢãÃáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0123456789!"#$%&''()*+,-./:;<=>?@[\]^_`{|}~ ]',
#             "", row)
#         return processed

#     def extract_phrases(self, text):
#         return re.findall(r'\w[\w ]+', text)

#     def gen_ngrams(self, text, n=5):
#         tokens = text.split()

#         if len(tokens) < n:
#             return [tokens]

#         return nltk.ngrams(text.split(), n)

#     def save_ngrams(self, list_ngrams, save_path='ngrams_list.npy'):
#         with open(save_path, 'wb') as f:
#             np.save(f, list_ngrams)
#         print("Saved dataset - ngrams")

# creater = CreateDataset()
# creater.processing()

In [9]:
# encode_string1
# encode_string1 = np.array(encode_string1)
# print(encode_string1.shape[0])

In [10]:
class Vocab():
    """
    Character-level vocabulary with three reserved ids.

    Ids 0, 1 and 2 are reserved for padding, start-of-sequence and
    end-of-sequence; the characters of `chars` get ids 3, 4, ... in order.
    """

    def __init__(self, chars):
        """
        Args:
            chars: sequence of single characters (typically a string); the
                position of a character determines its id.
        """
        self.pad = 0
        self.go = 1
        self.eos = 2

        self.chars = chars

        self.i2c = {i + 3: c for i, c in enumerate(chars)}
        self.c2i = {c: i + 3 for i, c in enumerate(chars)}

        self.i2c[0] = '<pad>'
        self.i2c[1] = '<sos>'
        self.i2c[2] = '<eos>'

    def encode(self, chars):
        """
        Convert text to ids, wrapped in the start and end markers.

        Raises:
            KeyError: if a character is not part of the vocabulary.
        """
        return [self.go] + [self.c2i[c] for c in chars] + [self.eos]

    def decode(self, ids):
        """
        Convert ids back to text.

        A leading start marker is dropped and everything from the first end
        marker onwards is cut off. Ids without an end marker are decoded in
        full, including any padding ids.

        Args:
            ids: iterable of integer-valued ids (list, numpy array, tensor).
        """
        # Normalise to plain ints so arrays and tensors work as well as lists.
        ids = [int(i) for i in ids]
        # Only a start marker in first position is skipped; a stray one later
        # in the sequence must not swallow the first real character.
        first = 1 if ids and ids[0] == self.go else 0
        last = ids.index(self.eos) if self.eos in ids else None
        return ''.join(self.i2c[i] for i in ids[first:last])

    def __len__(self):
        # Counts every assigned id, so the result also covers the largest id
        # when `chars` contains duplicate characters.
        return len(self.i2c)

    def batch_decode(self, arr):
        """Decode every id sequence in `arr` and return the list of texts."""
        return [self.decode(ids) for ids in arr]

    def __str__(self):
        # `chars` may be any sequence of characters; __str__ must return str.
        return self.chars if isinstance(self.chars, str) else ''.join(self.chars)
    
# Vocabulary built from the character inventory defined in the config cell.
vocab = Vocab(alphabets)
# Show both lookup tables (id -> char and char -> id) for inspection.
print(vocab.i2c)
print(vocab.c2i)
# Smoke test: encode a sample sentence and print the resulting ids
# (start marker, one id per character, end marker).
string1 = 'Tôi thích đi dạo'
encode_string1 = vocab.encode(string1)
print(encode_string1)

In [24]:
# Training targets, limited to the first 1,000,000 entries.
list_ngram_target = np.load('../input/5gram-nlp-project/list_5gram_nonum_train.npy')
list_ngram_target  = list_ngram_target [:1000000]
# Validation targets, limited to entries 5000..14999.
list_ngram_valid_target = np.load('../input/5gram-nlp-project/list_5gram_nonum_valid.npy')
list_ngram_valid_target = list_ngram_valid_target[5000:15000]
# Model inputs for training and validation.
# NOTE(review): SpellingDataset reads input and target with the same index, so
# these arrays are presumably aligned one-to-one with the sliced targets above
# (same length and order) — confirm against how the files were generated.
list_ngram_train = np.load('../input/5gram-nlp-project/train_normal_captions.npy')
list_ngram_valid = np.load('../input/5gram-nlp-project/valid_normal_captions.npy')
class SpellingDataset(Dataset):
    """Pairs of (input text, target text) encoded as fixed-length id tensors.

    Args:
        list_ngram: indexable collection of input strings.
        list_ngram_target: indexable collection of target strings, aligned
            index-by-index with ``list_ngram``.
        vocab: object exposing ``encode(str) -> list[int]`` (and optionally a
            ``pad`` attribute holding the padding id).
        maxlen: length every encoded sequence is padded / truncated to.
    """

    def __init__(self, list_ngram, list_ngram_target, vocab, maxlen):
        self.list_ngram = list_ngram
        self.list_ngram_target = list_ngram_target
        self.vocab = vocab
        self.max_len = maxlen

    def _to_fixed_length(self, token_ids):
        """Right-pad or truncate ``token_ids`` to ``max_len`` as an int64 tensor."""
        # Always build an int64 buffer so padded and truncated samples share
        # one dtype (padding with np.zeros used to silently produce float64).
        pad_id = getattr(self.vocab, 'pad', 0)
        fixed = np.full(self.max_len, pad_id, dtype=np.int64)
        kept = token_ids[:self.max_len]
        fixed[:len(kept)] = kept
        return torch.from_numpy(fixed)

    def __getitem__(self, index):
        """Return ``(input_ids, target_ids)``, each a 1-D tensor of length ``max_len``."""
        text_ids = self.vocab.encode(self.list_ngram[index])
        target_ids = self.vocab.encode(self.list_ngram_target[index])
        return self._to_fixed_length(text_ids), self._to_fixed_length(target_ids)

    def __len__(self):
        # The dataset length follows the input list; the target list is
        # expected to hold at least as many entries.
        return len(self.list_ngram)

# Wrap the loaded arrays in datasets / loaders.
ds_train = SpellingDataset(list_ngram_train, list_ngram_target, vocab, MAXLEN)
ds_valid = SpellingDataset(list_ngram_valid, list_ngram_valid_target,vocab, MAXLEN)
# ds_test = SpellingDataset(list_ngram_test, synthesizer, vocab, MAXLEN)
train_loader = DataLoader(ds_train, batch_size = 512 , shuffle=True)
# Validation runs one sample at a time because inference decodes a single sequence.
val_loader = DataLoader(ds_valid, batch_size = 1)
# test_loader = DataLoader(ds_test, batch_size = 200)
print(len(train_loader), len(val_loader))

# Peek at one batch from each loader as a sanity check.
text, target = next(iter(train_loader))
# The validation sample must come from val_loader (it was previously drawn
# from train_loader, so the "valid" preview actually showed training data).
valid_text, valid_target = next(iter(val_loader))
print(text[0])
print(target[0])
print("Text: ", vocab.decode(np.squeeze(text[0].detach().numpy()).tolist()))
print("Target: ", vocab.decode(np.squeeze(target[0].detach().numpy()).tolist()))
print("Text: ", vocab.decode(np.squeeze(valid_text[0].detach().numpy()).tolist()))
print("Target: ", vocab.decode(np.squeeze(valid_target[0].detach().numpy()).tolist()))
print(text, target)
print(text.size(), target.size())
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [12]:
# class MultiHeadAttention(nn.Module):
#     def __init__(self, model_size, num_heads):
#         super(MultiHeadAttention, self).__init__()
#         self.key_size = model_size //num_heads
#         self.heads = num_heads
#         self.wq = nn.Linear(model_size, model_size)
#         self.wk = nn.Linear(model_size, model_size)
#         self.wv = nn.Linear(model_size, model_size)
#         self.wo = nn.Linear(model_size, model_size)
        
    
#     def forward(self, query, key, value, mask = None):
#         # query len is seq_length
#         # model_size is embed size, just different way to call
#         # query shape (batch, query_len, model_size)
#         # value shape (batch, value_len, model_size)
#         query = self.wq(query)
#         key = self.wk(key)
#         value = self.wv(value)
        
#         # Originally, query has shape (batch, query_len, model_size)
#         # We need to reshape to (batch, query_len, h, key_size)
        
#         batch_size = query.shape[0]
#         query = query.reshape(batch_size, -1, self.heads, self.key_size)
#         # In order to compute matmul, the dimensions must be transposed to (batch_size, heads, query_len, key_size)
#         query = query.transpose(1, 2)

#         # Do the same for key and value
#         key = key.reshape(batch_size, -1, self.heads, self.key_size)
#         key = key.transpose(1, 2)
#         value = value.reshape(batch_size, -1, self.heads, self.key_size)
#         value = value.transpose(1, 2)
        
#         # query shape(batch_size, heads, query_len, key_size)
#         # value and key shape(batch_size, heads, value_len, key_size)
#         score = torch.matmul(query, key.transpose(2, 3))
#          # score will have shape of (batch_size, heads, query_len, value_len)
#         score = score / (torch.sqrt(torch.FloatTensor([self.key_size])).to(device))
#         if mask is not None:
#             score = score.masked_fill(mask==0, -1e10)
        
#         attention = torch.softmax(score, dim=-1)
#         # attention shape (batch_size, heads, query_len, value_len)
#         # value shape (batch_size, heads, value, key_size)
        
#         context = torch.matmul(attention, value)
#         # context shape (batch_size, heads, query_len, key_size)
        
#         context = context.transpose(1, 2).reshape(batch_size, -1, self.key_size * self.heads)
#         # context shape (batch_size, query_len, heads, key_size ) -> (batch_size, query_len, model_size)
        
#         x = self.wo(context)
#         # x shape (batch_size, query_len, model_size)
#         return x, attention
        

In [13]:
# class TransformerBlock(nn.Module):
#     def __init__(self, model_size, num_heads, dropout, forward_expansion):
#         super(TransformerBlock, self).__init__()
#         self.attention = MultiHeadAttention(model_size, num_heads)
#         self.norm1 = nn.LayerNorm(model_size)
#         self.norm2 = nn.LayerNorm(model_size)
    
#         self.feed_forward = nn.Sequential(nn.Linear(model_size, model_size* forward_expansion), nn.ReLU(), nn.Linear(model_size * forward_expansion, model_size))
#         self.dropout = nn.Dropout(dropout)
        
#     def forward(self, query, value, key, mask):
#         x, attention = self.attention(query, value, key, mask)
#         x = self.dropout(self.norm1(x) + query)
        
#         # x shape (batch_size, query_len, model_size)
#         forward = self.feed_forward(x)
#         out = self.dropout(self.norm2(forward) + x)
#         # out shape (batch_size, query_len, model_size)
#         return out, attention
        

In [14]:
# class Encoder(nn.Module):
#     def __init__(self, src_vocab_size, model_size, num_layers, num_heads, device, forward_expansion, dropout, pes, max_length=40):
#         super(Encoder, self).__init__()
#         self.model_size = model_size
#         self.device = device
#         self.word_embedding = nn.Embedding(src_vocab_size, model_size)
#         self.position_embedding = nn.Embedding(max_length, model_size)
        
#         self.layers = nn.ModuleList([TransformerBlock(model_size, num_heads, dropout, forward_expansion) for _ in range(num_layers)])
#         self.dropout = nn.Dropout(dropout)
    
#     def forward(self, x, mask):
#         #x = [batch size, src len]
#         #x_mask = [batch size, 1, 1, src len]
#         batch_size, seq_len = x.shape
#         pos = torch.arange(0, seq_len).unsqueeze(0).repeat(batch_size, 1).to(self.device)
#         # pos shape (batch_size, seq_len)
# #         print("X shape", x.size())
#         embed_out = self.word_embedding(x)
#         embed_out *= math.sqrt(self.model_size)
#         embed_out += pes[:seq_len, :]
#         x = self.dropout(embed_out)
        
#         # out shape (batch_size, seq_len, model_size)
#         for layer in self.layers:
#             x, _ = layer(x, x, x, mask)
        
#         return x

In [15]:
# class DecoderBlock(nn.Module):
#     def __init__(self, model_size, num_heads, device, forward_expansion, dropout, pes, max_length=40):
#         super(DecoderBlock, self).__init__()
    
#         self.attention = MultiHeadAttention(model_size, num_heads)
#         self.norm = nn.LayerNorm(model_size)
#         self.transformerBlock = TransformerBlock(model_size, num_heads, dropout,forward_expansion)
#         self.dropout = nn.Dropout(dropout)
    
#     def forward(self, x, value, key, src_mask, trg_mask):
#         #x = (batch_size, seq_len, model_size]
#         #enc_src = (batch_size, src_len, model_size)
#         #trg_mask = (batch size, 1, trg_len, trg_len)
#         #src_mask = (batch size, 1, 1, src_len)
        
#         trg, _ = self.attention(x, x, x,  trg_mask)
#         query = self.dropout(self.norm(trg) + x )
#         out, attention = self.transformerBlock(query, value, key, src_mask)
#         return out, attention

In [16]:
# class Decoder(nn.Module):
#     def __init__(self, target_vocab_size, model_size, num_layers, num_heads, device, forward_expansion, dropout, pes, max_length=40):
#         super(Decoder, self).__init__()
#         self.model_size = model_size
#         self.device = device
#         self.word_embedding = nn.Embedding(target_vocab_size, model_size)
#         self.position_embedding = nn.Embedding(max_length, model_size)
        
#         self.layers = nn.ModuleList([DecoderBlock(model_size, num_heads, device, forward_expansion, dropout, pes,) for _ in range(num_layers)])
#         self.dropout = nn.Dropout(dropout)
#         self.fc_out = nn.Linear(model_size, target_vocab_size)
        
#     def forward(self, x, enc_output, src_mask, target_mask):
#         batch_size, target_length = x.shape
# #         pos = torch.arange(0, target_length).unsqueeze(0).repeat(batch_size, 1).to(self.device)
#         embed_out = self.word_embedding(x)
#         embed_out *= math.sqrt(self.model_size)
#         embed_out += pes[:target_length, :]
#         x = self.dropout(embed_out)
        
#         for layer in self.layers:
#             x, attention = layer(x, enc_output, enc_output, src_mask, target_mask)
        
#         output = self.fc_out(x)
        
#         return output, attention
        

In [17]:
# class Seq2SeqTransformer(nn.Module):
#     def __init__(self, src_vocab_size, target_vocab_size, src_pad_idx, target_pad_idx, pes, device, model_size = 512, num_layers = 6, forward_expansion = 4, num_heads = 8, dropout = 0.1, max_length = 40):
#         super(Seq2SeqTransformer, self).__init__()
    
#         self.encoder = Encoder(src_vocab_size, model_size, num_layers, num_heads, device, forward_expansion, dropout, pes, max_length)
#         self.decoder = Decoder(target_vocab_size, model_size, num_layers, num_heads, device, forward_expansion, dropout, pes, max_length)
#         self.src_pad_idx = src_pad_idx
#         self.target_pad_idx = target_pad_idx
#         self.device = device
    
#     def make_src_mask(self, src):
        
#         #src = (batch_size, src_len)
        
#         src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

#         #src_mask = (batch_size, 1, 1, src_len)

#         return src_mask.to(self.device)
    
#     def make_target_mask(self, target):
        
#         batch_size, target_len = target.shape
#         target_pad_mask = (target != self.target_pad_idx).unsqueeze(1).unsqueeze(2)
# #         target_mask = torch.tril(torch.ones((target_len, target_len))).expand(batch_size, 1, target_len, target_len)
#         target_mask = torch.tril(torch.ones((target_len, target_len), device = self.device)).bool()
#         target_mask = target_pad_mask & target_mask
#         return target_mask.to(device)
    
#     def forward(self, src, target):
#         src_mask = self.make_src_mask(src)
#         target_mask = self.make_target_mask(target)
#         enc_src = self.encoder(src, src_mask)
#         out, attention = self.decoder(target, enc_src, src_mask, target_mask)
#         return out, attention

In [18]:
class PositionEncoder(nn.Module):
    """Adds fixed sinusoidal position encodings to a sequence-first tensor.

    Args:
        max_len: number of positions for which encodings are precomputed.
        emb_size: embedding dimension (odd sizes are supported).
        dropout: dropout probability applied after adding the encodings.
    """

    def __init__(self, max_len, emb_size, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, emb_size)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, emb_size, 2).float() * (-math.log(10000.0) / emb_size))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For an odd emb_size there is one fewer odd column than even columns,
        # so trim div_term to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[:emb_size // 2])
        # (max_len, emb_size) -> (max_len, 1, emb_size) so it broadcasts over
        # the batch dimension of a (seq_len, batch, emb_size) input.
        pe = pe.unsqueeze(0).transpose(0, 1)
        # Buffer: moves with the module across devices but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encodings for the first ``x.size(0)`` positions, then apply dropout.

        ``x`` is indexed along dim 0 as the sequence dimension; it must not be
        longer than ``max_len``.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerEncoder(nn.Module):
    """Token embedding + positional encoding + a stack of nn.TransformerEncoder layers.

    Args:
        input_size: size of the source vocabulary.
        emb_size: embedding / model dimension.
        hidden_size: feed-forward dimension inside each encoder layer.
        num_layer: number of stacked encoder layers.
        max_len: maximum sequence length (without the extra sos/eos slack).
    """

    def __init__(self, input_size, emb_size, hidden_size, num_layer, max_len=64):
        super().__init__()
        self.emb_size, self.hidden_size, self.num_layer = emb_size, hidden_size, num_layer
        # Embeddings are multiplied by sqrt(emb_size) before positions are added.
        self.scale = math.sqrt(emb_size)

        self.embedding = nn.Embedding(input_size, emb_size)
        # additional length for sos and eos
        self.pos_encoder = PositionEncoder(max_len + 10, emb_size)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=emb_size,
                nhead=8,
                dim_feedforward=hidden_size,
                dropout=0.1,
                activation='gelu',
            ),
            num_layers=num_layer,
            norm=nn.LayerNorm(emb_size),
        )

    def forward(self, src, src_mask):
        """Encode ``src`` token ids; ``src_mask`` is passed as the key-padding mask."""
        embedded = self.pos_encoder(self.embedding(src) * self.scale)
        return self.encoder(embedded, src_key_padding_mask=src_mask)

class TransformerDecoder(nn.Module):
    """Token embedding + positional encoding + nn.TransformerDecoder + output projection.

    Args:
        output_size: size of the target vocabulary (also the projection width).
        emb_size: embedding / model dimension.
        hidden_size: feed-forward dimension inside each decoder layer.
        num_layer: number of stacked decoder layers.
        max_len: maximum sequence length (without the extra sos/eos slack).
    """

    def __init__(self, output_size, emb_size, hidden_size, num_layer, max_len=64):
        super().__init__()
        self.emb_size, self.hidden_size, self.num_layer = emb_size, hidden_size, num_layer
        # Embeddings are multiplied by sqrt(emb_size) before positions are added.
        self.scale = math.sqrt(emb_size)

        self.embedding = nn.Embedding(output_size, emb_size)
        self.pos_encoder = PositionEncoder(max_len + 10, emb_size)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=emb_size,
                nhead=8,
                dim_feedforward=hidden_size,
                dropout=0.1,
                activation='gelu',
            ),
            num_layers=num_layer,
            norm=nn.LayerNorm(emb_size),
        )
        self.fc = nn.Linear(emb_size, output_size)

    def forward(self, trg, enc_output, sub_mask, mask):
        """Decode ``trg`` against ``enc_output`` and return per-token vocabulary logits.

        ``sub_mask`` is used as the target attention mask and ``mask`` as the
        target key-padding mask. No padding mask is passed for the encoder
        memory.
        """
        embedded = self.pos_encoder(self.embedding(trg) * self.scale)
        decoded = self.decoder(
            embedded,
            enc_output,
            tgt_mask=sub_mask,
            tgt_key_padding_mask=mask,
        )
        return self.fc(decoded)

class TransformerModel(nn.Module):
    """Encoder-decoder Transformer working on sequence-first token-id tensors.

    Args:
        input_size: source vocabulary size.
        emb_size: embedding / model dimension.
        hidden_size: feed-forward dimension of every layer.
        output_size: target vocabulary size.
        num_layer: number of encoder layers and of decoder layers.
        max_len: maximum sequence length handed to the positional encoders.
        pad_token, sos_token, eos_token: ids of the special tokens.
    """

    def __init__(self, input_size, emb_size, hidden_size, output_size,
                 num_layer, max_len, pad_token, sos_token, eos_token):
        super().__init__()
        self.encoder = TransformerEncoder(input_size, emb_size, hidden_size, num_layer, max_len)
        self.decoder = TransformerDecoder(output_size, emb_size, hidden_size, num_layer, max_len)
        self.pad_token, self.sos_token, self.eos_token = pad_token, sos_token, eos_token

        # Re-initialise every matrix-shaped weight of both halves.
        for half in (self.encoder, self.decoder):
            half.apply(self.initialize_weights)

    @staticmethod
    def initialize_weights(m):
        """Xavier-uniform init for any sub-module whose ``weight`` has more than one dim."""
        if not hasattr(m, 'weight'):
            return
        if m.weight.dim() > 1:
            nn.init.xavier_uniform_(m.weight.data)

    @staticmethod
    def generate_mask(src, pad_token):
        '''
        Build the padding mask for tensor src
        :param src: tensor with shape (max_src, b)
        :param pad_token: padding token
        :return: boolean mask with shape (b, max_src), True where src holds pad_token
        '''
        return src.t().eq(pad_token).to(src.device)

    @staticmethod
    def generate_submask(src):
        """Causal mask of shape (len, len): 0.0 on/below the diagonal, -inf above it."""
        length = src.size(0)
        blocked = torch.full((length, length), float('-inf'),
                             dtype=torch.float32, device=src.device)
        # Keep -inf strictly above the diagonal (future positions); zero elsewhere.
        return torch.triu(blocked, diagonal=1)

    def forward(self, src, trg):
        """Teacher-forced pass: returns the decoder logits for ``trg`` given ``src``."""
        src_pad_mask = self.generate_mask(src, self.pad_token)
        trg_pad_mask = self.generate_mask(trg, self.pad_token)
        causal_mask = self.generate_submask(trg)
        memory = self.encoder(src, src_pad_mask)
        return self.decoder(trg, memory, causal_mask, trg_pad_mask)

    def inference(self, src, max_len, device):
        """Greedy decoding of a single source sequence.

        Starts from ``sos_token`` and appends the arg-max token at each step
        until ``eos_token`` is produced or ``max_len`` tokens were generated.
        Returns the generated ids (without the leading sos) as a 1-D long tensor.
        """
        src_pad_mask = self.generate_mask(src, self.pad_token)
        memory = self.encoder(src, src_pad_mask)

        generated = [self.sos_token]
        for _ in range(max_len):
            # Re-feed everything generated so far as a (len, 1) column.
            trg = torch.tensor(generated, dtype=torch.long, device=device).unsqueeze(-1)
            trg_pad_mask = self.generate_mask(trg, self.pad_token)
            causal_mask = self.generate_submask(trg)
            logits = self.decoder(trg, memory, causal_mask, trg_pad_mask)
            # Only the prediction for the last position is new.
            next_token = logits.squeeze(1)[-1].argmax(dim=-1).item()
            generated.append(next_token)
            if next_token == self.eos_token:
                break
        return torch.tensor(generated[1:], dtype=torch.long, device=device)

In [19]:
class AverageMeter(object):
    """Tracks the latest value together with a running weighted average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, the weighted sum, the count and the average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
        
def asMinutes(s):
    """Format a duration given in seconds as '<minutes>m <seconds>s'."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, leftover_seconds)


def timeSince(since, percent):
    """Return '<elapsed> (remain <estimated remaining>)' for a running task.

    ``since`` is the start timestamp (as from ``time.time()``) and ``percent``
    the completed fraction of the work; the total duration is extrapolated
    linearly as elapsed / percent.
    """
    elapsed = time.time() - since
    remaining = elapsed / (percent) - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))

In [20]:
class Trainer(object):
    """Mixed-precision training / validation driver for the seq2seq model.

    Args:
        model: the network to train; must expose ``forward(src, trg)`` and
            ``inference(src, max_len, device)``.
        device: torch device the batches are moved to.
        config: settings object providing ``lr``, ``PAD_IDX``, ``num_epochs``,
            ``folder`` and the history lists ``epoch``, ``train_loss``,
            ``valid_acc`` (``max_len`` is used when present).
    """

    def __init__(self, model, device, config):
        self.config = config
        self.start_epoch = 1
        self.model = model
        self.device = device
        self.optimizer = AdamW(self.model.parameters(), lr = self.config.lr)
        # NOTE(review): the scheduler is sized from the module-level
        # `train_loader` with a fixed epochs=2, and scheduler.step() is never
        # called in this class. Building OneCycleLR also appears to reset the
        # optimizer's learning rate — confirm this is intended before relying
        # on config.lr.
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(self.optimizer, max_lr=0.001, steps_per_epoch=len(train_loader), epochs=2)
#         self.scheduler = config.SchedulerClass(self.optimizer, **self.config.scheduler_params)
        self.criterion = nn.CrossEntropyLoss(ignore_index = self.config.PAD_IDX).to(device)
        self.scaler = GradScaler()
        self.base_dir = f'{config.folder}'
        if not os.path.exists(self.base_dir):
            os.makedirs(self.base_dir)

    def train_loop(self, train_loader, validation_loader, train_resume=False):
        """Train for ``config.num_epochs`` epochs; validate and checkpoint after the last one.

        When ``train_resume`` is true the fixed checkpoint path below is loaded
        first and epoch numbering continues from it.
        """
        if train_resume:
            self.load_model('../input/model-transformer/model_transformer_50.pth')
        last_epoch = self.start_epoch + self.config.num_epochs - 1
        for e in range(self.start_epoch, self.start_epoch+self.config.num_epochs):
            t = time.time()
            calc_loss = self.train_one_epoch(self.device, train_loader, e)
            self.config.epoch.append(e)
            self.config.train_loss.append(calc_loss.avg)
            print(f'Train. Epoch: {e}, Loss: {calc_loss.avg:.5f}, time: {(time.time() - t):.5f}')

            # Validation decodes one sample at a time, so it only runs once,
            # after the final epoch.
            if e == last_epoch:
                t = time.time()
                predictions, targets, calc_loss = self.valid_one_epoch(self.device, validation_loader)
                acc_valid = self.accuracy_valid(predictions, targets)
                self.config.valid_acc.append(acc_valid)
                print(f'Val. Epoch: {e}, Loss: {calc_loss.avg:.5f}, Acc valid: {acc_valid:.5f}, time: {(time.time() - t):.5f}')
                state = {'epoch': e, 'state_dict': self.model.state_dict(),'optimizer': self.optimizer.state_dict(), 
                         'train_loss': self.config.train_loss, 'valid_acc': self.config.valid_acc}
                self.save_model(state, f'model_transformer_{e}.pth')

    def train_one_epoch(self, device, train_loader, e):
        """Run one teacher-forced training epoch and return the loss ``AverageMeter``."""
        self.model.train()
        calc_loss = AverageMeter()
        start = time.time()
        for batch_idx, (text, target) in tqdm(enumerate(train_loader)):
            batch_size = text.size(0)
            # The loader yields (batch, seq); the model expects (seq, batch).
            text = torch.transpose(text, 0, 1)
            target =  torch.transpose(target, 0, 1)
            text = text.to(device, dtype=torch.int64)
            target = target.to(device, dtype=torch.int64)
            self.optimizer.zero_grad()

            with autocast():
                # Decoder input is the target without its last token; the
                # labels are the target shifted left by one position.
                output = self.model(text, target[:-1])
                # Flatten to (seq * batch, vocab) logits and (seq * batch,) labels.
                output = output.contiguous().reshape(-1, output.shape[2])
                target = target[1:].reshape(-1)

                loss = self.criterion(output, target)

            loss_value = loss.item()

            # Scaled backward pass / optimizer step for mixed precision.
            self.scaler.scale(loss).backward()

            # Gradient clipping is currently disabled:
#             torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1)

            self.scaler.step(self.optimizer)
            self.scaler.update()
            calc_loss.update(loss_value, batch_size)
            if batch_idx %400==100:
                print('Epoch: [{0}][{1}/{2}] '
                  'Elapsed {remain:s} '
                  .format(
                   e, batch_idx, len(train_loader),
                   remain=timeSince(start, float(batch_idx+1)/len(train_loader)),
                   #lr=scheduler.get_lr()[0],
                   ))

        return calc_loss

    def valid_one_epoch(self, device, val_loader):
        """Greedy-decode every validation sample and compute the teacher-forced loss.

        Expects ``val_loader`` to yield batches of size 1. Uses the
        module-level ``vocab`` to turn ids back into text.

        Returns:
            (predictions, targets, loss_meter) where the first two are lists
            of decoded strings.
        """
        self.model.eval()
        predictions = []
        targets  =[]
        calc_loss = AverageMeter()
        # Decode budget: taken from the config when available, 40 otherwise.
        max_decode_len = getattr(self.config, 'max_len', 40)
        with torch.no_grad():
            for batch_idx, (text, target) in tqdm(enumerate(val_loader)): 
                batch_size = text.size(0)
                targets.append(vocab.decode(np.squeeze(target.detach().numpy()).tolist()))
                text = torch.transpose(text, 0, 1)
                target =  torch.transpose(target, 0, 1)
                text = text.to(device, dtype=torch.int64)
                target = target.to(device, dtype=torch.int64)
                # Use the trainer's own model rather than a module-level global.
                prediction = self.model.inference(text, max_decode_len, device)
                prediction = prediction.detach().cpu().numpy().tolist()
                prediction = vocab.decode(prediction)
                predictions.append(prediction)
                output = self.model(text, target[:-1])

                output = output.contiguous().reshape(-1, output.shape[2])
                target = target[1:].reshape(-1)

                loss = self.criterion(output, target)
                loss_value = loss.item()

                calc_loss.update(loss_value, batch_size)

        return predictions, targets, calc_loss

    def accuracy_valid(self, predictions, targets):
        """Fraction of predictions that exactly equal their target (0.0 when empty)."""
        n = len(predictions)
        if n == 0:
            # Avoid a ZeroDivisionError when there was nothing to validate.
            return 0.0
        acc = 0
        for i in range(n):
            if predictions[i]==targets[i]:
                acc+=1
        return acc/n

    def save_model(self, state, path):
        """Serialise the checkpoint dict ``state`` to ``path``."""
        torch.save(state, path)

    def load_model(self, path):
        """Restore model, optimizer and history from a checkpoint written by ``save_model``.

        The checkpoint is unpickled by ``torch.load`` — only load trusted files.
        """
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
        checkpoint = torch.load(path, map_location=self.device)
        self.start_epoch = checkpoint['epoch']+1
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.config.train_loss = checkpoint['train_loss']
        self.config.valid_acc = checkpoint['valid_acc']

In [21]:
class GlobalParametersTrain:
    """Hyper-parameters and run history for one training run."""

    lr =0.0001
    # NOTE(review): vocabulary sizes are hard-coded; presumably they should
    # equal len(vocab) — confirm.
    SRC_VOCAB_SIZE = 232
    TARGET_VOCAB_SIZE = 232
    num_epochs = 10
    model_dim = 256
    feed_forward_dim = 1024
    num_layers = 6
    max_len = 40
    PAD_IDX = 0
    SOS_IDX = 1
    EOS_IDX = 2 
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


    # Class-level defaults are kept so attribute access on the class itself
    # still works; instances get their own lists in __init__.
    train_loss = []
    valid_loss = []
    valid_acc = []
    epoch = []
    
    folder = './model'

    def __init__(self):
        # Mutable class attributes are shared by every instance, so appending
        # to them would leak history between runs when a new config is
        # created. Give each instance fresh lists instead.
        self.train_loss = []
        self.valid_loss = []
        self.valid_acc = []
        self.epoch = []

In [22]:
# Instantiate the settings, build the model on the configured device and
# start training from scratch (no checkpoint resume).
config = GlobalParametersTrain()
# model = Seq2SeqTransformer(SRC_VOCAB_SIZE, TARGET_VOCAB_SIZE, PAD_IDX, PAD_IDX, pes, device).to(device)
model = TransformerModel(
    input_size=config.SRC_VOCAB_SIZE,
    emb_size=config.model_dim,
    hidden_size=config.feed_forward_dim,
    output_size=config.TARGET_VOCAB_SIZE,
    num_layer=config.num_layers,
    max_len=config.max_len,
    pad_token=config.PAD_IDX,
    sos_token=config.SOS_IDX,
    eos_token=config.EOS_IDX,
).to(config.device)
# The model already lives on config.device, so it is handed over as-is.
training = Trainer(model, device=config.device, config=config)
training.train_loop(train_loader, val_loader, train_resume=False)
