## Import Libraries

In [78]:
import pandas as pd
from tqdm import tqdm
import numpy as np
import warnings
import re
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import re
import random
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from nltk.translate.bleu_score import sentence_bleu

warnings.filterwarnings("ignore")
tqdm.pandas()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Convert m2 data to dataframe

In [79]:
#https://www.cl.cam.ac.uk/research/nl/bea2019st/data/corr_from_m2.py
def m2_to_df(m2_file_path,id=0):
    '''This function takes m2 file path as input and converts it to pandas dataframe'''

    m2 = open(m2_file_path).read().strip().split("\n\n")
    # Do not apply edits with these error types
    skip = {"noop", "UNK", "Um"}

    correct_sent_array = []
    incorrect_sent_array = []

    for sent in tqdm(m2):
        sent = sent.split("\n")
        incor_sent = sent[0].split()[1:] # Ignore "S "
        incorrect_sent_array.append(str(' '.join(incor_sent))) 
        cor_sent = incor_sent.copy()

        edits = sent[1:]
        offset = 0
        for edit in edits:
            edit = edit.split("|||")
            if edit[1] in skip: continue # Ignore certain edits
            coder = int(edit[-1])
            if coder != id: continue # Ignore other coders
            span = edit[0].split()[1:] # Ignore "A "
            start = int(span[0])
            end = int(span[1])
            cor = edit[2].split()
            cor_sent[start+offset:end+offset] = cor
            offset = offset-(end-start)+len(cor)
        correct_sent_array.append(str(' '.join(cor_sent)))

    df = pd.DataFrame()
    df["correct"] = correct_sent_array
    df["incorrect"] = incorrect_sent_array
    return df

In [80]:
m2_file = '/Users/saisrivishwanath/Documents/GrammarErrorCorrection/lang8.bea19/lang8.train.auto.bea19.m2'
df = m2_to_df(m2_file)

100%|█████████████████████████████| 1037561/1037561 [00:02<00:00, 353136.87it/s]


In [81]:
df.sample(5)

Unnamed: 0,correct,incorrect
299143,"In order to solve this problem ,","For solving this problem ,"
775957,There is a large lake called Loch Ness in the ...,There is a large lake called the Loch Ness in ...
399672,"And if I want to work abroad , the ability to ...","And if I want to work abroad , an ability to s..."
746981,It has been raining in my area since yesterday...,"Yesterday and Today , it is raining in my area..."
35404,Wow .,Wow .


In [82]:
df.to_csv('/Users/saisrivishwanath/Documents/GrammarErrorCorrection/git files/GrammarErrorCorrection/df.csv',index=False)

In [83]:
df.shape

(1037561, 2)

Character and word counts

In [84]:
df['correct_char_count'] = df['correct'].astype('str').apply(lambda x:len(x))
df['incorrect_char_count'] = df['incorrect'].astype('str').apply(lambda x:len(x))

In [85]:
df['correct_word_count'] = df['correct'].astype('str').apply(lambda x:len(x.split()))
df['incorrect_word_count'] = df['incorrect'].astype('str').apply(lambda x:len(x.split()))

In [86]:
df.sample(5)

Unnamed: 0,correct,incorrect,correct_char_count,incorrect_char_count,correct_word_count,incorrect_word_count
114440,My best regards .,My best regards .,17,17,4,4
687158,But this time I really regret messing up my ch...,But this time I really regret messing up my ch...,69,69,14,14
869104,I feel tried to treat it . .,I feel tried to treat it . .,28,28,8,8
239255,I was ordered by the company to improve my Eng...,I was ordered by company to improve my English .,52,48,11,10
198006,10 .,10 .,4,4,2,2


# Preprocessing

In [87]:
pd.DataFrame(df.isna().sum(),columns=['missing_count'])

Unnamed: 0,missing_count
correct,0
incorrect,0
correct_char_count,0
incorrect_char_count,0
correct_word_count,0
incorrect_word_count,0


In [88]:
df[df.isna().any(axis=1)]

Unnamed: 0,correct,incorrect,correct_char_count,incorrect_char_count,correct_word_count,incorrect_word_count


In [89]:
df = df.dropna().reset_index(drop=True)

In [90]:
df.shape

(1037561, 6)

In [91]:
df.sample(5)

Unnamed: 0,correct,incorrect,correct_char_count,incorrect_char_count,correct_word_count,incorrect_word_count
457861,The movie 's title is `` Green Zone `` .,Movie 's title is `` Green Zone `` .,40,36,10,9
952728,# twinglish,# twinglish,11,11,2,2
186064,I 'm looking for a language exchange friend .,I 'm finding language exchange friend . .,45,41,9,8
979921,I 'm planning to travel !,I 'm planning to travel !,25,25,6,6
412865,"So , she is a great mother , and I admire her .","so , she is a great mother , I admire she .",47,43,13,12


In [92]:
print(f"total number of duplicate pairs: {len(df[df['correct']==df['incorrect']])}")

total number of duplicate pairs: 539202


In [93]:
print(f"total number of duplicate pairs: {len(df[df['correct']==df['incorrect']])}")

total number of duplicate pairs: 539202


In [94]:
df[df['correct']==df['incorrect']].sample(10)

Unnamed: 0,correct,incorrect,correct_char_count,incorrect_char_count,correct_word_count,incorrect_word_count
333593,"Then , do you like to eat cup - ramyeons , or ...","Then , do you like to eat cup - ramyeons , or ...",62,62,16,16
1016241,I strolled to the kitchen and went up to the k...,I strolled to the kitchen and went up to the k...,53,53,12,12
747764,I have to work until 9 : 30 am .,I have to work until 9 : 30 am .,32,32,10,10
195733,He is talented .,He is talented .,16,16,4,4
502768,Handsome guy,Handsome guy,12,12,2,2
606427,"For beer garden , it was so much fun too ! !","For beer garden , it was so much fun too ! !",44,44,12,12
923940,Worship had been by screen without a real past...,Worship had been by screen without a real past...,50,50,10,10
387184,cherry blossoms,cherry blossoms,15,15,2,2
251119,"If you do it you will see the results , if you...","If you do it you will see the results , if you...",69,69,18,18
328793,"However , it 's too difficult for me to do the...","However , it 's too difficult for me to do the...",50,50,12,12


In [95]:
df = df[df['correct']!=df['incorrect']]

In [96]:
df.shape

(498359, 6)

In [97]:
df.sample(5)

Unnamed: 0,correct,incorrect,correct_char_count,incorrect_char_count,correct_word_count,incorrect_word_count
44297,But he only stayed there 5 weeks and said that...,But he said a 5 - weeks was too short .,65,39,15,11
864999,"Properly speaking , they refuse to be so .","Exactly speaking , they refuse to be so .",42,41,9,9
546778,"I have been bedridden for the past two days , ...","I lie down all the days long in these 2 days ,...",125,123,27,28
751003,I was able to see my friends and was pursued b...,I could meet my friends and was courted by Meg...,58,53,14,12
911014,We Japanese tend to not look into the eyes of ...,We Japanese tend not to look into eyes of othe...,83,85,18,18


In [98]:
print(f'total number of duplicates: {df.duplicated().sum()}')

total number of duplicates: 2021


In [99]:
df[df.duplicated(keep=False)].sort_values('correct')

Unnamed: 0,correct,incorrect,correct_char_count,incorrect_char_count,correct_word_count,incorrect_word_count
717380,"( I seriously want to escape , all the way , t...","( I seriously want to escape , all the way , t...",88,85,19,19
1027463,"( I seriously want to escape , all the way , t...","( I seriously want to escape , all the way , t...",88,85,19,19
802136,: - ),: - (,5,5,3,3
800389,: - ),: - (,5,5,3,3
161743,A : How much did it cost ?,A : How much does is cost ?,26,27,8,8
...,...,...,...,...,...,...
350828,to be continued . . .,to be continue . . .,21,20,6,6
17343,to be continued . . .,to be continue . . .,21,20,6,6
633236,to be continued . . .,to be continue . . .,21,20,6,6
767285,today was a bad day .,today is a bad day .,21,20,6,6


In [100]:
df = df.drop_duplicates().reset_index(drop=True)

In [101]:
df.shape

(496338, 6)

In [102]:
df.sample(5)

Unnamed: 0,correct,incorrect,correct_char_count,incorrect_char_count,correct_word_count,incorrect_word_count
430830,I ca n't believe time goes by so fast .,I ca n't believe time is so fast .,39,34,10,9
483161,"Now , I am free - lance .","Now , I am free job .",25,21,8,7
423637,It is not easy but I enjoy it .,It is not easy but enjoying .,31,29,9,7
283036,"A lot of drinking Beer , syotyu and kakutel .","A lot of drink a Beer , syotyu and kakutel .",45,44,10,11
380087,I 'm not good at Physics .,I 'm not good at physics .,26,26,7,7


Removing small texts(sentences whose length is less than 2)

In [103]:
df[df['correct_char_count']<2].shape

(27, 6)

In [104]:
df[df['correct_char_count']<2].sample(10)

Unnamed: 0,correct,incorrect,correct_char_count,incorrect_char_count,correct_word_count,incorrect_word_count
465170,.,life .,1,6,1,2
251783,.,out .,1,5,1,2
153480,.,For Tech support .,1,18,1,4
113363,.,M .,1,3,1,2
461060,.,had them .,1,10,1,3
10489,.,took some medicine .,1,20,1,4
398158,.,them .,1,6,1,2
439421,.,on face .,1,9,1,3
480724,.,or philosopher movie .,1,22,1,4
164090,?,' ? ?,1,5,1,3


In [105]:
df = df[df['correct_char_count']>2].reset_index(drop=True)

In [106]:
df.shape

(496299, 6)

Clean text

In [107]:
#https://www.analyticsvidhya.com/blog/2020/04/beginners-guide-exploratory-data-analysis-text-data/
contractions_dict = { "ain't": "are not","'s":" is","aren't": "are not",
                     "can't": "cannot","can't've": "cannot have",
                     "'cause": "because","could've": "could have","couldn't": "could not",
                     "couldn't've": "could not have", "didn't": "did not","doesn't": "does not",
                     "don't": "do not","hadn't": "had not","hadn't've": "had not have",
                     "hasn't": "has not","haven't": "have not","he'd": "he would",
                     "he'd've": "he would have","he'll": "he will", "he'll've": "he will have",
                     "how'd": "how did","how'd'y": "how do you","how'll": "how will",
                     "I'd": "I would", "I'd've": "I would have","I'll": "I will",
                     "I'll've": "I will have","I'm": "I am","I've": "I have", "isn't": "is not",
                     "it'd": "it would","it'd've": "it would have","it'll": "it will",
                     "it'll've": "it will have", "let's": "let us","ma'am": "madam",
                     "mayn't": "may not","might've": "might have","mightn't": "might not", 
                     "mightn't've": "might not have","must've": "must have","mustn't": "must not",
                     "mustn't've": "must not have", "needn't": "need not",
                     "needn't've": "need not have","o'clock": "of the clock","oughtn't": "ought not",
                     "oughtn't've": "ought not have","shan't": "shall not","sha'n't": "shall not",
                     "shan't've": "shall not have","she'd": "she would","she'd've": "she would have",
                     "she'll": "she will", "she'll've": "she will have","should've": "should have",
                     "shouldn't": "should not", "shouldn't've": "should not have","so've": "so have",
                     "that'd": "that would","that'd've": "that would have", "there'd": "there would",
                     "there'd've": "there would have", "they'd": "they would",
                     "they'd've": "they would have","they'll": "they will",
                     "they'll've": "they will have", "they're": "they are","they've": "they have",
                     "to've": "to have","wasn't": "was not","we'd": "we would",
                     "we'd've": "we would have","we'll": "we will","we'll've": "we will have",
                     "we're": "we are","we've": "we have", "weren't": "were not","what'll": "what will",
                     "what'll've": "what will have","what're": "what are", "what've": "what have",
                     "when've": "when have","where'd": "where did", "where've": "where have",
                     "who'll": "who will","who'll've": "who will have","who've": "who have",
                     "why've": "why have","will've": "will have","won't": "will not",
                     "won't've": "will not have", "would've": "would have","wouldn't": "would not",
                     "wouldn't've": "would not have","y'all": "you all", "y'all'd": "you all would",
                     "y'all'd've": "you all would have","y'all're": "you all are",
                     "y'all've": "you all have", "you'd": "you would","you'd've": "you would have",
                     "you'll": "you will","you'll've": "you will have", "you're": "you are",
                     "you've": "you have","n\'t":" not","\'re":" are","\'s": " is","\'d":" would",
                     "\'ll": " will","\'t":" not","\'ve": " have","\'m":" am"}


# Regular expression for finding contractions
contractions_re=re.compile('(%s)' % '|'.join(contractions_dict.keys()))

# Function for expanding contractions
def expand_contractions(text,contractions_dict=contractions_dict):
    def replace(match):
        return contractions_dict[match.group(0)]
    return contractions_re.sub(replace, text)

In [110]:
# https://stackoverflow.com/a/47091490/4084039
def clean(text):
    text = re.sub('\s*\<.*?\>\s', '', text)
    text = re.sub('\s*\(.*?\)\s', '', text)
    text = re.sub('\s*\[.*?\]\s', '', text)
    text = re.sub('\s*\{.*?\}\s', '', text)
    text = re.sub("[-+@#^/|*(){}$~<>=_%:;]","",text)
    text = text.replace("\\","")
    text = re.sub("\[","",text)
    text = re.sub("\]","",text)
    text = re.sub("\<","",text)
    text = re.sub("\>","",text)
    text = re.sub("\(","",text)
    text = re.sub("\)","",text)
    text = re.sub("[0-9]","",text)
    text = ' '.join(text.split())
    return text

In [34]:
df['correct'] = df['correct'].progress_apply(clean)
df['correct'] = df['correct'].progress_apply(expand_contractions)

100%|████████████████████████████████| 496299/496299 [00:05<00:00, 97551.77it/s]
100%|████████████████████████████████| 496299/496299 [00:06<00:00, 72454.82it/s]


In [35]:
df['incorrect'] = df['incorrect'].progress_apply(clean)
df['incorrect'] = df['incorrect'].progress_apply(expand_contractions)

100%|████████████████████████████████| 496299/496299 [00:04<00:00, 99390.99it/s]
100%|████████████████████████████████| 496299/496299 [00:06<00:00, 74261.03it/s]


In [36]:
df.sample(5)

Unnamed: 0,correct,incorrect,correct_char_count,incorrect_char_count,correct_word_count,incorrect_word_count
71816,I write about two or three mails a day .,I make about two or three mails in a day .,40,42,10,11
482897,"In the past when I was a kid , I wanted to go ...","in past when kid , I want go good school , but...",99,69,25,17
173839,The J league will cover only one more game bef...,The J league will over only one game before a ...,104,83,20,16
273983,One of my American friends gave me a new phras...,One of my American friends gave me a new word ...,73,71,16,16
187603,I feel bad . or I feel ill .,I feel bad .,28,12,9,4


In [38]:
df.to_csv('/Users/saisrivishwanath/Documents/GrammarErrorCorrection/git files/GrammarErrorCorrection/preprocessed_df.csv',index=False)

In [46]:
df.shape

(496299, 6)

In [39]:
final_df = df[df['incorrect_word_count'] <= 12]

In [40]:
final_df.to_csv('/Users/saisrivishwanath/Documents/GrammarErrorCorrection/git files/GrammarErrorCorrection/final_df.csv',index=False)

# Modelling

Language Vocabulary Builder for Seq2Seq Models

In [47]:
SOS_token = 0
EOS_token = 1

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2 

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

Converting Unicode Strings to ASCII and Normalizing Text

In [48]:
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalizeString(s):
    s = unicodeToAscii(s.strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z!?]+", r" ", s)
    return s.strip()

Read and Process Language Pairs from the data(CSV File)

In [52]:
def readLangs(lang1, lang2, filename, reverse=False):
    print("Reading lines...")

    df = pd.read_csv('preprocessed_df.csv')
    df = df.dropna(subset=['incorrect', 'correct'])
  
    pairs = list(zip(df['incorrect'], df['correct']))

    pairs = [[normalizeString(s) for s in pair] for pair in pairs]

    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2) 
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs


Filtering Sentence Pairs for Maximum Length Constraint

In [53]:
MAX_LENGTH = 12

def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and len(p[1].split(' ')) < MAX_LENGTH

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

Data Preparation

In [54]:
def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    pairs = pairs[0:50000]
    print("length of first 50000 pairs", len(pairs))
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

input_lang, output_lang, pairs = prepareData('eng', 'eng', True)
print(random.choice(pairs))

Reading lines...
Read 496255 sentence pairs
Trimmed to 266109 sentence pairs
Counting words...
length of first 50000 pairs 50000
Counted words:
eng 60326
eng 49505
['I felt I need to study !', 'I felt like I needed to study more !']


In [55]:
print(len(pairs))

50000


In [56]:
input_lang.word2index

{'And': 2,
 'he': 3,
 'took': 4,
 'in': 5,
 'my': 6,
 'favorite': 7,
 'subject': 8,
 'like': 9,
 'soccer': 10,
 'Actually': 11,
 'who': 12,
 'let': 13,
 'me': 14,
 'know': 15,
 'about': 16,
 'Lang': 17,
 'was': 18,
 'him': 19,
 'His': 20,
 'Kanji': 21,
 'is': 22,
 'ability': 23,
 'much': 24,
 'better': 25,
 'than': 26,
 'I': 27,
 'heard': 28,
 'a': 29,
 'sentence': 30,
 'last': 31,
 'night': 32,
 'when': 33,
 'watched': 34,
 'TV': 35,
 'When': 36,
 'you': 37,
 'go': 38,
 'uphill': 39,
 'hvae': 40,
 'to': 41,
 'bend': 42,
 'your': 43,
 'back': 44,
 'are': 45,
 'smoothly': 46,
 'have': 47,
 'be': 48,
 'more': 49,
 'modest': 50,
 'The': 51,
 'making': 52,
 'souvenir': 53,
 'hard': 54,
 'and': 55,
 'interesting': 56,
 'work': 57,
 'You': 58,
 'can': 59,
 'take': 60,
 'them': 61,
 'at': 62,
 'slot': 63,
 'machine': 64,
 'third': 65,
 'memory': 66,
 'the': 67,
 'house': 68,
 'we': 69,
 'lived': 70,
 'Do': 71,
 'not': 72,
 'why': 73,
 'liked': 74,
 'winter': 75,
 'Finland': 76,
 'hope': 77,
 

Seq2Seq Model

In [57]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)
        return output, hidden

In [58]:
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden  = self.forward_step(decoder_input, decoder_hidden)
            decoder_outputs.append(decoder_output)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        return decoder_outputs, decoder_hidden, None # We return `None` for consistency in the training loop

    def forward_step(self, input, hidden):
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.gru(output, hidden)
        output = self.out(output)
        return output, hidden

In [59]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions


    def forward_step(self, input, hidden, encoder_outputs):
        embedded =  self.dropout(self.embedding(input))

        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)

        return output, hidden, attn_weights

Loading Data for Seq2Seq Language Processing

In [60]:
def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

def get_dataloader(batch_size):
    input_lang, output_lang, pairs = prepareData('eng', 'eng', True)

    n = len(pairs)
    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

    for idx, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromSentence(input_lang, inp)
        tgt_ids = indexesFromSentence(output_lang, tgt)
        inp_ids.append(EOS_token)
        tgt_ids.append(EOS_token)
        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tgt_ids)] = tgt_ids

    train_data = TensorDataset(torch.LongTensor(input_ids).to(device),
                               torch.LongTensor(target_ids).to(device))

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
    return input_lang, output_lang, train_dataloader

batch_size = 32
input_lang, output_lang, train_dataloader = get_dataloader(batch_size)

Reading lines...
Read 496255 sentence pairs
Trimmed to 266109 sentence pairs
Counting words...
length of first 50000 pairs 50000
Counted words:
eng 60326
eng 49505


In [61]:
len(train_dataloader)

1563

Training and evaluating the model

In [62]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
          decoder_optimizer, criterion):

    total_loss = 0
    for data in dataloader:
        input_tensor, target_tensor = data

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [63]:
import time
import math

def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [64]:
def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,
               print_every=100, plot_every=100):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print('%s (%d %d%%) %.4f' % (timeSince(start, epoch / n_epochs),
                                        epoch, epoch / n_epochs * 100, print_loss_avg))

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    showPlot(plot_losses)

In [65]:
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib.ticker as ticker
import numpy as np

def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    # this locator puts ticks at regular intervals
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

In [66]:
def evaluate(encoder, decoder, sentence, input_lang, output_lang):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_words = []
        for idx in decoded_ids:
            if idx.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            decoded_words.append(output_lang.index2word[idx.item()])
    return decoded_words, decoder_attn

In [75]:
def evaluateRandomly(encoder, decoder, n=10):
    bleu_scores = []
    
    for i in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        reference = [pair[1].split(' ')]
        candidate = output_sentence.split(' ')
        score = sentence_bleu(reference, candidate)
        print("BLEU score", score)
        bleu_scores.append(score)
        print('')

    print("Average BLEU score:", sum(bleu_scores) / len(bleu_scores))

In [69]:
hidden_size = 128
encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)

train(train_dataloader, encoder, decoder, 80, print_every=5, plot_every=5)

32m 29s (- 487m 19s) (5 6%) 2.1439
56m 22s (- 394m 40s) (10 12%) 1.0295
93m 32s (- 405m 19s) (15 18%) 0.7254
117m 24s (- 352m 12s) (20 25%) 0.5863
140m 32s (- 309m 12s) (25 31%) 0.5014
163m 51s (- 273m 5s) (30 37%) 0.4419
187m 14s (- 240m 44s) (35 43%) 0.3978
211m 9s (- 211m 9s) (40 50%) 0.3629
233m 54s (- 181m 55s) (45 56%) 0.3344
258m 0s (- 154m 48s) (50 62%) 0.3114
282m 32s (- 128m 25s) (55 68%) 0.2910
307m 3s (- 102m 21s) (60 75%) 0.2742
330m 58s (- 76m 22s) (65 81%) 0.2593
354m 25s (- 50m 37s) (70 87%) 0.2463
377m 54s (- 25m 11s) (75 93%) 0.2348
401m 37s (- 0m 0s) (80 100%) 0.2244


In [76]:
encoder.eval()
decoder.eval()
evaluateRandomly(encoder, decoder)

> Do your best I am so next week
= I am going to do my best next week
< I will do the best this I am so next week <EOS>
BLEU score 8.190757052088229e-155

> It looks like more deep pink than Japanese one
= They are more deep pink than the Japanese ones
< They are more deep pink than the Japanese ones <EOS>
BLEU score 0.8801117367933934

> Anyway I jogged km it was not bad
= Anyway I jogged km It was not bad
< Anyway I jogged km It was not bad <EOS>
BLEU score 0.8633400213704505

> Please crrect my Poor English sentences
= Please correct my poor English sentences
< Please correct my poor English sentences <EOS>
BLEU score 0.8091067115702212

> because I had to sales analysis today
= because I had to do a sales analysis today
< because I had to do a sales analysis today <EOS>
BLEU score 0.8801117367933934

> SPI make me crazy
= SPI makes me crazy
< SPI makes me crazy <EOS>
BLEU score 0.668740304976422

> Californication Times like these
= Californication Times Like These
< Californicatio

In [71]:
def showAttention(input_sentence, output_words, attentions):
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attentions.cpu().numpy(), cmap='bone')
    fig.colorbar(cax)

    # Set up axes
    ax.set_xticklabels([''] + input_sentence.split(' ') +
                       ['<EOS>'], rotation=90)
    ax.set_yticklabels([''] + output_words)

    # Show label at every tick
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()


def evaluateAndShowAttention(input_sentence):
    output_words, attentions = evaluate(encoder, decoder, input_sentence, input_lang, output_lang)
    print('input =', input_sentence)
    print('output =', ' '.join(output_words))
    showAttention(input_sentence, output_words, attentions[0, :len(output_words), :])

evaluateAndShowAttention('I heard it was the popular book in America')

input = I heard it was the popular book in America
output = I heard it was a popular book in America <EOS>
