In [4]:
## reference: https://www.youtube.com/watch?v=CNuI8OWsppg

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import csv
import random
import re
import os
import unicodedata
import codecs
import itertools

In [5]:
CUDA = torch.cuda.is_available()
device = torch.device("cuda" if CUDA else "cpu")

In [6]:
device

device(type='cpu')

# 1. Data Preprocessing

In [7]:
lines_filepath = os.path.join("cornell movie-dialogs corpus","movie_lines.txt")
conv_filepath = os.path.join("cornell movie-dialogs corpus","movie_conversations.txt")

In [12]:
# visualize some lines

with open(lines_filepath, "r") as file:
    lines = file.readlines()
for line in lines[:8]:
    print(line.strip())

L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!
L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!
L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.
L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?
L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.
L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow
L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.
L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No


In [18]:
## splitting each values into different fields in dictionary
line_field = ['lineID','characterID','movieID','character','text']
lines ={}
with open(lines_filepath,'r', encoding ='iso-8859-1') as f:
    for line in f:
        values = line.split(" +++$+++ ")
        lineObj ={}
        for i, field in enumerate(line_field):
            lineObj[field]=values[i]
        lines[lineObj['lineID']] = lineObj

In [20]:
# Group fields of lines from loadlines into conversations based on movie_conversations.txt

conv_fields = ['characeter1ID','character2ID','movieID','utteranceIDs']
conversations = []
with open(conv_filepath,'r', encoding='iso-8859-1') as f:
    for line in f:
        values = line.split(' +++$+++ ')
        convObj = {}
        for i, field in enumerate(conv_fields):
            convObj[field]=values[i]
        
        #convert string into list 
        lineIds = eval(convObj['utteranceIDs'])
        #Reassemble lines
        convObj['lines']=[]
        for lineId in lineIds:
            convObj['lines'].append(lines[lineId])
        conversations.append(convObj)

In [22]:
# extract qa pairs
qa_pairs =[]
for conversation in conversations:
    for i in range(len(conversation['lines'])-1):
        inputLine = conversation['lines'][i]['text'].strip()
        targetLine = conversation['lines'][i+1]['text'].strip()
        if inputLine and targetLine:
            qa_pairs.append([inputLine, targetLine])

In [30]:
# Define path to new file
datafile = os.path.join("cornell movie-dialogs corpus","formatted_movie_lines.txt")
delimiter ='\t'

delimiter = str(codecs.decode(delimiter,"unicode_escape"))

#write new csv file
print("\n writing newly formatted file")
with open(datafile,'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter = delimiter)
    for pairs in qa_pairs:
        writer.writerow(pairs)
print("Done writing to file")


 writing newly formatted file
Done writing to file


In [31]:
# visualize some lines
datafile = os.path.join("cornell movie-dialogs corpus","formatted_movie_lines.txt")
with open(datafile ,'rb') as file:
    lines= file.readlines()
for line in lines[:8]:
    print(line)

b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\r\r\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\r\r\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\r\r\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\r\r\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\r\r\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\r\r\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\tSeems like she could get a date easy enough...\r\r\n"
b'Why?\tU

In [32]:
PAD_token =0
SOS_token=1
EOS_token=2

class Vocabulary:
    def __init__(self, name):
        self.name = name
        self.word2index ={}
        self.word2count ={}
        self.index2word = {PAD_token : "PAD", SOS_token : "SOS", EOS_token : "EOS"}
        self.num_words = 3
        
    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)
    
    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words +=1
        else:
            self.word2count[word] +=1
            
    #remove words below a certain threshold
    def trim(self, min_count):
        keep_words = []
        for k, v in self.word2count.items():
            if v>= min_count:
                keep_words.append(k)
            
        print('keep words {} / {} = {:.4f}'.format(len(keep_words), len(self.word2index), len(keep_words)/len(self.word2index)) )
        
        # reinitialize dictionaries
        self.word2index ={}
        self.word2count ={}
        self.index2word = {PAD_token : "PAD", SOS_token : "SOS", EOS_token : "EOS"}
        self.num_words = 3
        
        for word in keep_words:
            self.addWord(word)

In [38]:
# turn unicode to ascii
def unicodeToAscii(s):
    return ''.join(c for c in unicodedata.normalize('NFD',s) if unicodedata.category(c) != 'Mn')

In [39]:
# lowercase, trim white space, lines..etc and remove non-letter characters
def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+" , r" ", s)
    s = re.sub(r"\s+",r" ",s).strip()
    return s
        

In [40]:
normalizeString("aaa13aaa!s's    dd?")

'aaa aaa !s s dd ?'

In [41]:
datafile = os.path.join("cornell movie-dialogs corpus","formatted_movie_lines.txt")
print("Reading and procssing file..pls wait")
lines = open(datafile, encoding='utf-8').read().strip().split('\n')
pairs = [[normalizeString(s) for s in pairs.split('\t')] for pairs in lines]
print("Done processing")
voc = Vocabulary("cornell movie-dialogs corpus")

Reading and procssing file..pls wait
Done processing


In [43]:
len(pairs)

442563

In [44]:
# returns true if both the sentences have words < max length
MAX_LENGTH = 10
def filterPair(p):
    return len(p[0].split()) < MAX_LENGTH and len(p[1].split()) <MAX_LENGTH

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

In [49]:
pairs = [pair for pair in pairs if len(pair)>1]
print("There are {} pairs/conversations in the dataset".format(len(pairs)))
pairs = filterPairs(pairs)
print("After filtering {} pairs/conversations in the dataset".format(len(pairs)))

There are 64271 pairs/conversations in the dataset
After filtering 64271 pairs/conversations in the dataset


In [50]:
# Loop through each pairs and add them as part of vocabulary
for pair in pairs:
    voc.addSentence(pair[0])
    voc.addSentence(pair[1])
print('Counted words: ', voc.num_words)
for pair in pairs[:10]:
    print(pair)


Counted words:  18008
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [51]:
MIN_COUNT = 3 # minimum word count for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    voc.trim(MIN_COUNT)
    # filter out pairs with trimmed words
    keep_pairs = []
    for pair in pairs:
        input_sentence = pair[0]
        output_sentence = pair[1]
        keep_input = True
        keep_output = True
        # check input sequence
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input = False
                break
        #check output sequence
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output = False
                break
        
        if keep_input and keep_output:
            keep_pairs.append(pair)
        
    print('Trimmed from {} pairs to {}, {:.4f} of total'.format(len(pairs), len(keep_pairs), len(keep_pairs)/len(pairs)))
    return keep_pairs

# Trim voc and pairs
pairs = trimRareWords(voc, pairs, MIN_COUNT)


keep words 7823 / 18005 = 0.4345
Trimmed from 64271 pairs to 53165, 0.8272 of total


In [52]:
def indexesFromSentence(voc, sentence):
    return [voc.word2index[word] for word in sentence.split(' ')] + [EOS_token]

In [53]:
indexesFromSentence(voc, pairs[1][0])

[7, 8, 9, 10, 4, 11, 12, 13, 2]

In [56]:
# define some sample for testing
inp = []
out =[]
for pair in pairs[:10]:
    inp.append(pair[0])
    out.append(pair[1])
print(inp)
print(len(inp))
indexes = [indexesFromSentence(voc, sentence) for sentence in inp]
indexes

['there .', 'you have my word . as a gentleman', 'hi .', 'have fun tonight ?', 'well no . . .', 'then that s all you had to say .', 'but', 'do you listen to this crap ?', 'what good stuff ?', 'wow']
10


[[3, 4, 2],
 [7, 8, 9, 10, 4, 11, 12, 13, 2],
 [16, 4, 2],
 [8, 31, 22, 6, 2],
 [33, 34, 4, 4, 4, 2],
 [35, 36, 37, 38, 7, 39, 40, 41, 4, 2],
 [42, 2],
 [47, 7, 48, 40, 45, 49, 6, 2],
 [50, 51, 52, 6, 2],
 [58, 2]]

In [57]:
def zeroPadding(l, fillvalue =0):
    return list(itertools.zip_longest(*l, fillvalue = fillvalue))

In [58]:
leng = [len(ind) for ind in indexes]
max(leng)

10

In [59]:
# test the function
test_result = zeroPadding(indexes)
print(len(test_result))
test_result


10


[(3, 7, 16, 8, 33, 35, 42, 47, 50, 58),
 (4, 8, 4, 31, 34, 36, 2, 7, 51, 2),
 (2, 9, 2, 22, 4, 37, 0, 48, 52, 0),
 (0, 10, 0, 6, 4, 38, 0, 40, 6, 0),
 (0, 4, 0, 2, 4, 7, 0, 45, 2, 0),
 (0, 11, 0, 0, 2, 39, 0, 49, 0, 0),
 (0, 12, 0, 0, 0, 40, 0, 6, 0, 0),
 (0, 13, 0, 0, 0, 41, 0, 2, 0, 0),
 (0, 2, 0, 0, 0, 4, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 2, 0, 0, 0, 0)]

In [60]:
def binaryMatrix(l, value =0):
    m=[]
    for i, seq in enumerate(l):
        m.append([])
        for token in seq:
            if token == PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

In [61]:
binary_result = binaryMatrix(test_result)
binary_result


[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [1, 1, 1, 1, 1, 1, 0, 1, 1, 0],
 [0, 1, 0, 1, 1, 1, 0, 1, 1, 0],
 [0, 1, 0, 1, 1, 1, 0, 1, 1, 0],
 [0, 1, 0, 0, 1, 1, 0, 1, 0, 0],
 [0, 1, 0, 0, 0, 1, 0, 1, 0, 0],
 [0, 1, 0, 0, 0, 1, 0, 1, 0, 0],
 [0, 1, 0, 0, 0, 1, 0, 0, 0, 0],
 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]]

In [62]:
# returns padded input sequence tensor and as well as a tensor of lengths for each of the sequence in the batch
def inputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    padVar = torch.LongTensor(padList)
    return padVar, lengths

In [64]:
# returns padded target sequence tensor, padding mask, and max target length
def outputVar(l, voc):
    indexes_batch = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max([len(indexes) for indexes in indexes_batch])
    padList = zeroPadding(indexes_batch)
    mask = binaryMatrix(padList)
    mask = torch.ByteTensor(mask)
    padVar = torch.LongTensor(padList)
    return padVar, mask, max_target_len

In [65]:
## returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    pair_batch.sort(key = lambda x : len(x[0].split(" ")), reverse = True)
    input_batch, output_batch = [],[]
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len =outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len

In [67]:
# example for validation
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input variable:")
print(input_variable)
print("lengths:",lengths)
print("target variable:")
print(target_variable)
print("mask:",mask )
print("max_target_len:", max_target_len)

input variable:
tensor([[ 242,  289,   23,   60,  536],
        [ 188,   83,    4,   53,  709],
        [  36,  219,    4,  598, 3381],
        [ 199,   27,    4,  135,    2],
        [   4,  158, 1308, 4189,    0],
        [ 716,  480,    6,    4,    0],
        [ 228,    4,    2,    2,    0],
        [   4,    2,    0,    0,    0],
        [   2,    0,    0,    0,    0]])
lengths: tensor([9, 8, 7, 7, 4])
target variable:
tensor([[  25,  124,    4,  571,    7],
        [  24,    9,    4,  367,   18],
        [   4, 1014,    4,   59,   36],
        [   2, 1215,    4,   83,    6],
        [   0, 1014,    4,   96,    2],
        [   0,    4,    4,    4,    0],
        [   0,    4,    2,    2,    0],
        [   0,    4,    0,    0,    0],
        [   0,    2,    0,    0,    0]])
mask: tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1],
        [0, 1, 1, 1, 0],
        [0, 1, 1, 1, 0],
        [0, 1, 0, 0, 0],
      

# 2 Building model

In [68]:
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers = 1, dropout =0):
        super(EncoderRNN).__init__()
        self.n_layers =n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout(0 if n_layers ==1 else dropout), bidirectional = True)
        
    def forward(self, input_seq, input_lengths, hidden = None):
        #input_seq = shape(max_length, batch_size)
        #input_lengths = list of sentence lengths 
        # hidden = shape(n_layers * num_directions, batch_size, hidden_size)
        
        embedded = self.embedding(input_seq)
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs,_ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        outputs = outputs[:,:,:self.hidden_size] + outputs[:,:,self.hidden_size:]
        return outputs, hidden

In [69]:
class Attn(torch.nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        self.hidden_size = hidden_size
    
    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim = 2)
    
    def forward(self, hidden, encoder_outputs):
        attn_energies = self.dot_score(hidden, encoder_outputs)
        attn_energies = attn_energies.t()
        return F.softmax(attn_energies, dim=1).unsqueeze(1)
    

In [77]:
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers = 1, dropout =0.1):
        super(LuongAttnDecoderRNN, self).__init__()
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout
        
        #Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers==1 else dropout))
        self.concat = nn.Linear(hidden_size*2 , hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        
        self.attn = Attn(attn_model, hidden_size)
    
    def forward(self, input_step, last_hidden, encoder_outputs):
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        
        rnn_output, hidden = self.gru(embedded, last_hidden)
        attn_weights = self.attn(rnn_output, encoder_outputs)
        
        context = attn_weights.bmm(encoder_outputs.transpose(0,1))
        
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context),1)
        concat_output = torch.tanh(self.concat(concat_input))
        
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        
        return output, hidden
    
    

In [78]:
def maskNLLLoss(decoder_out, target, mask):
    nTotal = mask.sum() # how many elements should we consider
    target = target.view(-1,1)
    
    gathered_tensor = torch.gather(decoder_out, 1, target)
    crossEntropy = -torch.log(gathered_tensor)
    loss = crossEntropy.masked_select(mask)
    loss = loss.mean()
    loss = loss.to(device)
    return loss, nTotal.item()


In [79]:
#Visualing 1 iteration
small_batch_size =5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, length, target_variable, mask, max_target_len = batches

print("input variable:")
print(input_variable)
print("lengths:",lengths)
print("target variable:")
print(target_variable)
print("mask:",mask )
print("max_target_len:", max_target_len)

# define parameters
hidden_size = 500
encoder_n_layers =2
decoder_n_layers = 2
dropout = 0.1
attn_model = 'dot'
embedding = nn.Embedding(voc.num_words,hidden_size)

# Define the encoder and decoder
encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout)
decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words, decoder_n_layers, dropout)
encoder = encoder.to(device)
decoder = decoder.to(device)

encoder.train()
decoder.train()

#Initialize optimizers
encoder_optimizer = optim.Adam(encoder.parameters(), lr=0.0001)
decoder_optimizer = optim.Adam(encoder.parameters(), lr =0.0001)
encoder_optimizer.zero_grad()
decoder_optimizer.zero_grad()

input_variable = input_variable.to(device)
lengths = lengths.to(device)
target_variable = target_variable.to(device)
mask = mask.to(device)

loss =0
print_losses =[]
n_totals=0

encoder_outputs, encoder_hidden = encoder(input_variable, lengths)
print("Encoder output shape:", encoder_outputs.shape)
print("Last Encoder Hidden Shape: ", encoder_hidden.shape)

decoder_input = torch.LongTensor([[SOS_token for _ in range(small_batch_size)]])
decoder_input = decoder_input.to(device)
print("Initial Decoder Input shape:", decoder_input.shape)
print(decoder_input)

decoder_hidden = encoder_hidden[:decoder.n_layers]
print("Initial Decoder hidden state shape:", decoder_hidden.shape)
print("\n")
print("---------------------------------------------------------------------------------")
print("Now Let's look what is happening in every timestep of the GRU!")
print("----------------------------------------------------------------------------------")
print("\n")

#Assume we are using teacher forcing

for t in range(max_target_len):
    decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
    print("Decoder output shape:", decoder_output.shape)
    print("Decoder hidden shape:", decoder_hidden.shape)
    
    decoder_input = target_variable[t].view(-1,1)
    print("The target variable at the current timestamp before reshaping:", target_variable[t])
    print("The target variable at the current timestamp shape before reshaping:", target_variable[t].shape)
    print("The Decoder input shape(reshape the target variable):", decoder_input.shape)
    
    print("The mask at the current timestamp:", mask[t])
    print("The mast at the current timestamp shape:", mask[t].shape)
    mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
    print("Mask loss:", mask_loss)
    print("Total:",nTotal)
    
    loss +=mask_loss
    print_losses.append(mask_loss.item() *nTotal)
    print(print_losses)
    
    n_totals += nTotal
    print(n_totals)
    encoder.optimizer.step()
    decoder.optimizer.step()
    returned_loss= sum(print_losses)/n_totals
    print("REturned loss:", returned_loss)
    print("\n")
    print("---------------------------------Done one Timestep------------------------------")
    print("\n")
    

input variable:
tensor([[  76,    7,   25,   77,   53],
        [ 271,  488,   89,   80, 5813],
        [ 117,  158,  553,   47, 1750],
        [ 935,  123,   76,   36,    4],
        [ 115, 2498,    4,    6,    2],
        [  76,    4,    2,    2,    0],
        [7214,    2,    0,    0,    0],
        [   6,    0,    0,    0,    0],
        [   2,    0,    0,    0,    0]])
lengths: tensor([9, 8, 7, 7, 4])
target variable:
tensor([[  34,  124,  553, 1596, 5813],
        [   4, 1643,   50,    4,    4],
        [   2,   25,    6,    2,    2],
        [   0,  112,    2,    0,    0],
        [   0,  197,    0,    0,    0],
        [   0,  117,    0,    0,    0],
        [   0,   24,    0,    0,    0],
        [   0,   75,    0,    0,    0],
        [   0,   45,    0,    0,    0],
        [   0,    2,    0,    0,    0]])
mask: tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [0, 1, 1, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [0, 1, 0

AttributeError: cannot assign module before Module.__init__() call

In [80]:
## training module
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding, encoder_optimizer, 
         decoder_optimizer, batch_size, clip, max_length = MAX_LENGTH):
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    
    input_variable = input_variable.to(device)
    lengths = lengths.to(device)
    target_variable =target_variable.to(device)
    mask = mask.to(device)
    
    loss = 0
    print_losses = []
    n_totals = 0
    
    encoder_outputs, encoder_hidden = encoder(input_variable, lengths)
    
    decoder_input = torch.LongTensor([[SOS_token for _ in range(batch_size)]])
    decoder_input = decoder_input.to(device)
    
    decoder_hidden = encoder_hidden[:decoder.n_layers]
    
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            decoder_input = target_variable[t].view(-1,1)
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss +=mask_loss
            print_losses.append(mask_loss.item()*nTotal)
            n_totals += nTotal
            
    else:
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
            _,topi = decoder_output.topk(1)
            decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            decoder_input = decoder_input.to(device)
            
            mask_loss, nTotal = maskNLLLoss(decoder_output, target_variable[t], mask[t])
            loss += mask_loss
            print_losses.append(mask_loss.item() *nTotal)
            n_totals += nTotal
            
    loss.backward()
    
    _=torch.nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _= torch.nn.utils.clip_grad_norm(decoder.parameters(), clip)
    
    encoder_optimizer.step()
    decoder_optimizer.step()
    
    return sum(print_losses) / n_totals
            