In [1]:
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import csv
import random
import re
import os
import unicodedata
import codecs
import itertools
import bz2

In [2]:
device=torch.device("cpu")

In [3]:
lines_filepath=os.path.join("cornell movie-dialogs corpus","movie_lines.txt")
conv_filepath=os.path.join("cornell movie-dialogs corpus","movie_conversations.txt")

In [4]:
with open(lines_filepath,"r") as file:
    lines=file.readlines()
for line in lines[:8]:
    print(line.strip())

L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!
L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!
L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.
L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?
L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.
L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow
L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.
L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No


In [5]:
line_fields=["lineID","characterID","movieID","character","text"]
lines={}
with open(lines_filepath,"r",encoding="iso-8859-1") as f:
    for line in f:
        values=line.split(" +++$+++ ")
        lineobj={}
        for i,field in enumerate(line_fields):
            lineobj[field]=values[i]
        lines[lineobj["lineID"]]=lineobj    

In [6]:
lines

{'L1045': {'lineID': 'L1045',
  'characterID': 'u0',
  'movieID': 'm0',
  'character': 'BIANCA',
  'text': 'They do not!\n'},
 'L1044': {'lineID': 'L1044',
  'characterID': 'u2',
  'movieID': 'm0',
  'character': 'CAMERON',
  'text': 'They do to!\n'},
 'L985': {'lineID': 'L985',
  'characterID': 'u0',
  'movieID': 'm0',
  'character': 'BIANCA',
  'text': 'I hope so.\n'},
 'L984': {'lineID': 'L984',
  'characterID': 'u2',
  'movieID': 'm0',
  'character': 'CAMERON',
  'text': 'She okay?\n'},
 'L925': {'lineID': 'L925',
  'characterID': 'u0',
  'movieID': 'm0',
  'character': 'BIANCA',
  'text': "Let's go.\n"},
 'L924': {'lineID': 'L924',
  'characterID': 'u2',
  'movieID': 'm0',
  'character': 'CAMERON',
  'text': 'Wow\n'},
 'L872': {'lineID': 'L872',
  'characterID': 'u0',
  'movieID': 'm0',
  'character': 'BIANCA',
  'text': "Okay -- you're gonna need to learn how to lie.\n"},
 'L871': {'lineID': 'L871',
  'characterID': 'u2',
  'movieID': 'm0',
  'character': 'CAMERON',
  'text': 'No

In [7]:
with open(conv_filepath,'r',encoding="iso-8859-1") as f:
    s_lines=f.readlines()
s_lines

["u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L198', 'L199']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L200', 'L201', 'L202', 'L203']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L204', 'L205', 'L206']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L207', 'L208']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L271', 'L272', 'L273', 'L274', 'L275']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L276', 'L277']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L280', 'L281']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L363', 'L364']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L365', 'L366']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L367', 'L368']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L401', 'L402', 'L403']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L404', 'L405', 'L406', 'L407']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L575', 'L576']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L577', 'L578']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L662', 'L663']\n",
 "u0 +++$+++ u2 

In [8]:
conv_fields=["characterID","character2ID","movieID","utteranceIDs"]
conversations=[]
with open(conv_filepath,'r',encoding="iso-8859-1") as f:
    for line in f:
        values=line.split(" +++$+++ ")
        convobj={}
        for i,field in enumerate(conv_fields):
            convobj[field]=values[i]
        lineIds=eval(convobj["utteranceIDs"])
        convobj["lines"]=[]
        for lineId in lineIds:
            convobj["lines"].append(lines[lineId])
        conversations.append(convobj)    

In [9]:
conversations

[{'characterID': 'u0',
  'character2ID': 'u2',
  'movieID': 'm0',
  'utteranceIDs': "['L194', 'L195', 'L196', 'L197']\n",
  'lines': [{'lineID': 'L194',
    'characterID': 'u0',
    'movieID': 'm0',
    'character': 'BIANCA',
    'text': 'Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\n'},
   {'lineID': 'L195',
    'characterID': 'u2',
    'movieID': 'm0',
    'character': 'CAMERON',
    'text': "Well, I thought we'd start with pronunciation, if that's okay with you.\n"},
   {'lineID': 'L196',
    'characterID': 'u0',
    'movieID': 'm0',
    'character': 'BIANCA',
    'text': 'Not the hacking and gagging and spitting part.  Please.\n'},
   {'lineID': 'L197',
    'characterID': 'u2',
    'movieID': 'm0',
    'character': 'CAMERON',
    'text': "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"}]},
 {'characterID': 'u0',
  'character2ID': 'u2',
  'movieID': 'm0',
  'utteranc

In [10]:
qa_pairs=[]
for conversation in conversations:
    for i in range(len(conversation["lines"])-1):
        inputLine=conversation["lines"][i]["text"].strip()
        targetLine=conversation["lines"][i+1]["text"].strip()
        if inputLine and targetLine:
            qa_pairs.append([inputLine,targetLine])

In [11]:
qa_pairs

[['Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.',
  "Well, I thought we'd start with pronunciation, if that's okay with you."],
 ["Well, I thought we'd start with pronunciation, if that's okay with you.",
  'Not the hacking and gagging and spitting part.  Please.'],
 ['Not the hacking and gagging and spitting part.  Please.',
  "Okay... then how 'bout we try out some French cuisine.  Saturday?  Night?"],
 ["You're asking me out.  That's so cute. What's your name again?",
  'Forget it.'],
 ["No, no, it's my fault -- we didn't have a proper introduction ---",
  'Cameron.'],
 ['Cameron.',
  "The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does."],
 ["The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.",
  'Seems like she could get a date easy enough...'],
 [

In [12]:
datafile=os.path.join("cornell movie-dialogs corpus","fomatted_movie_lines.txt")
delimiter='\t'
delimiter=str(codecs.decode(delimiter,"unicode_escape"))
print("\nwriting newly formatted file")
with open(datafile,'w',encoding="utf-8") as outputfile:
    writer=csv.writer(outputfile,delimiter=delimiter)
    for pair in qa_pairs:
        writer.writerow(pair)
print("done writing to file")


writing newly formatted file
done writing to file


In [87]:
datafile=os.path.join("cornell movie-dialogs corpus","fomatted_movie_lines.txt")
with open(datafile,'rb') as file:
    lines=file.readlines()
for line in lines[:8]:
    print(line)

b'can we make this quick ? roxanne korrine and andrew barrett are having an incredibly horrendous public break up on the quad . again .\twell i thought we d start with pronunciation if that s okay with you .\r\r\n'
b'well i thought we d start with pronunciation if that s okay with you .\tnot the hacking and gagging and spitting part . please .\r\r\n'
b'not the hacking and gagging and spitting part . please .\tokay . . . then how bout we try out some french cuisine . saturday ? night ?\r\r\n'
b'you re asking me out . that s so cute . what s your name again ?\tforget it .\r\r\n'
b'no no it s my fault we didn t have a proper introduction\tcameron .\r\r\n'
b'cameron .\tthe thing is cameron i m at the mercy of a particularly hideous breed of loser . my sister . i can t date until she does .\r\r\n'
b'the thing is cameron i m at the mercy of a particularly hideous breed of loser . my sister . i can t date until she does .\tseems like she could get a date easy enough . . .\r\r\n'
b'why ?\tunso

In [14]:
PAD_token=0
SOS_token=1
EOS_token=2

class vocabulary:
    
    def __init__(self,name):
        self.name=name
        self.word2index={}
        self.word2count={}
        self.index2word={PAD_token:"PAD",SOS_token:"SOS",EOS_token:"EOS"}
        self.num_words=3
    
    def addsentence(self,sentence):
        for word in sentence.split(' '):
            self.addword(word)
            
    def addword(self,word):
        if word not in self.word2index:
            self.word2index[word]=self.num_words
            self.word2count[word]=1
            self.index2word[self.num_words]=word
            self.num_words+=1
        else:
            self.word2count[word]+=1
            
    def trim(self,min_word):
        keep_words=[]
        for k,v in self.word2count.items():
            if v>=min_word:
                keep_words.append(k)
        print("keep_words {} / {}={:.4f}".format((len(keep_words)),self.num_words,len(keep_words)/self.num_words))
        self.word2index={}
        self.word2count={}
        self.index2word={PAD_token:"PAD",SOS_token:"SOS",EOS_token:"EOS"}
        self.num_words=3
        for word in keep_words:
            self.addword(word)
            
    
            
            

In [15]:
def unicodeToAscii(s):
    return ''.join(c for c in unicodedata.normalize('NFD',s) if unicodedata.category(c)!='Mn')

In [16]:
unicodeToAscii("Maître Renard, par l’odeur alléché")

'Maitre Renard, par l’odeur alleche'

In [17]:
def normalisestring(s):
    s=unicodeToAscii(s.lower().strip())
    s=re.sub(r"([.!?])",r" \1",s)
    s=re.sub(r"([^a-zA-Z.!?]+)",r" ",s)
    s=re.sub(r"\s+",r" ",s).strip()
    
    return s
    

In [61]:
datafile=os.path.join("cornell movie-dialogs corpus","fomatted_movie_lines.txt")
delimiter='\t'
delimiter=str(codecs.decode(delimiter,"unicode_escape"))
print("\nwriting newly formatted file")
with open(datafile,'w',encoding="utf-8") as outputfile:
    writer=csv.writer(outputfile,delimiter=delimiter)
    for pair in qa_pairs:
        pair[0]=normalisestring(pair[0])
        pair[1]=normalisestring(pair[1])
        writer.writerow(pair)
print("done writing to file")


writing newly formatted file
done writing to file


In [88]:
datafile=os.path.join("cornell movie-dialogs corpus","fomatted_movie_lines.txt")
print("Reading and processing file.....please wait")
lines=open(datafile,encoding="utf-8").read().strip().split('\n')
pairs=[[normalisestring(s) for s in pair.split('\t')] for pair in lines]
print("Done reading")
voc=vocabulary("cornell movie-dialogs corpus")

Reading and processing file.....please wait
Done reading


In [62]:
MAX_LENGTH=10
def filterpair(p):
    return (len(p[0].split())<MAX_LENGTH and len(p[1].split())<MAX_LENGTH)
def filterpairs(pairs):
    return [pair for pair in pairs if filterpair(pair)] 

In [63]:
pairs=[pair for pair in pairs if len(pair)>1]
print("There are {} pairs/conversations in the dataset".format(len(pairs)))
pairs=filterpairs(pairs)
print("After filtering there are {} pairs/conversations in the dataset".format(len(pairs)))

There are 53165 pairs/conversations in the dataset
After filtering there are 53165 pairs/conversations in the dataset


In [64]:
for pair in pairs:
    voc.addsentence(pair[0])
    voc.addsentence(pair[1])
print("counted words:",voc.num_words)
for pair in pairs[:10]:
    print(pair)

counted words: 7826
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']
['wow', 'let s go .']


In [65]:
MIN_COUNT=3

def trimrarewords(voc,pairs,MIN_COUNT):
    voc.trim(MIN_COUNT)
    keep_pairs=[]
    for pair in pairs:
        input_sentence=pair[0]
        output_sentence=pair[1]
        keep_input=True
        keep_output=True
        for word in input_sentence.split(' '):
            if word not in voc.word2index:
                keep_input=False
                break
        for word in output_sentence.split(' '):
            if word not in voc.word2index:
                keep_output=False
                break
        if keep_input and keep_output:
            keep_pairs.append(pair)
    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs),len(keep_pairs),len(keep_pairs)/len(pairs)))
    return keep_pairs

In [66]:
pairs=trimrarewords(voc,pairs,MIN_COUNT)

keep_words 7411 / 7826=0.9470
Trimmed from 53165 pairs to 52848, 0.9940 of total


In [67]:
pairs[0][1]

'where ?'

In [68]:
def indexesfromsentence(voc,sentence):
    return [voc.word2index[word] for word in sentence.split(' ')]+[EOS_token]

In [69]:
#for pair in pairs:
pairs[0][0]
print(indexesfromsentence(voc,pairs[3][0]))

[8, 31, 22, 6, 2]


In [70]:
pairs[4][0]

'well no . . .'

In [71]:
pairs[1][0]

'you have my word . as a gentleman'

In [72]:
a=[[1,2,3],[2,1]]
a.append([1])
print(a)

[[1, 2, 3], [2, 1], [1]]


In [73]:
inp=[]
out=[]
for pair in pairs[:10]:
    inp.append(pair[0])
    out.append(pair[1])
print(len(inp))
indexes=[indexesfromsentence(voc,sentence) for sentence in inp]
indexes

10


[[3, 4, 2],
 [7, 8, 9, 10, 4, 11, 12, 13, 2],
 [16, 4, 2],
 [8, 31, 22, 6, 2],
 [33, 34, 4, 4, 4, 2],
 [35, 36, 37, 38, 7, 39, 40, 41, 4, 2],
 [42, 2],
 [47, 7, 48, 40, 45, 49, 6, 2],
 [50, 51, 52, 6, 2],
 [58, 2]]

In [74]:
a=[1,2,3,4,5]
b=['a','s','d']
print(list(zip(a,b)))
list(itertools.zip_longest(a,b))

[(1, 'a'), (2, 's'), (3, 'd')]


[(1, 'a'), (2, 's'), (3, 'd'), (4, None), (5, None)]

In [75]:
list(itertools.zip_longest(*indexes,fillvalue=0))

[(3, 7, 16, 8, 33, 35, 42, 47, 50, 58),
 (4, 8, 4, 31, 34, 36, 2, 7, 51, 2),
 (2, 9, 2, 22, 4, 37, 0, 48, 52, 0),
 (0, 10, 0, 6, 4, 38, 0, 40, 6, 0),
 (0, 4, 0, 2, 4, 7, 0, 45, 2, 0),
 (0, 11, 0, 0, 2, 39, 0, 49, 0, 0),
 (0, 12, 0, 0, 0, 40, 0, 6, 0, 0),
 (0, 13, 0, 0, 0, 41, 0, 2, 0, 0),
 (0, 2, 0, 0, 0, 4, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 2, 0, 0, 0, 0)]

In [76]:
def zeropadding(l,fillvalue=0):
    return list(itertools.zip_longest(*l,fillvalue=fillvalue))

In [77]:
leng=[len(index) for index in indexes]
max(leng)

10

In [78]:
test_result=zeropadding(indexes)
print(len(test_result))
test_result

10


[(3, 7, 16, 8, 33, 35, 42, 47, 50, 58),
 (4, 8, 4, 31, 34, 36, 2, 7, 51, 2),
 (2, 9, 2, 22, 4, 37, 0, 48, 52, 0),
 (0, 10, 0, 6, 4, 38, 0, 40, 6, 0),
 (0, 4, 0, 2, 4, 7, 0, 45, 2, 0),
 (0, 11, 0, 0, 2, 39, 0, 49, 0, 0),
 (0, 12, 0, 0, 0, 40, 0, 6, 0, 0),
 (0, 13, 0, 0, 0, 41, 0, 2, 0, 0),
 (0, 2, 0, 0, 0, 4, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 2, 0, 0, 0, 0)]

In [79]:
def binaryMatrix(l,value=0):
    m=[]
    for i,seq in enumerate(l):
        m.append([])
        for token in seq:
            if token==PAD_token:
                m[i].append(0)
            else:
                m[i].append(1)
    return m

In [37]:
a=[[1,2,3],[2,3]]
a[1].append(2)
a

[[1, 2, 3], [2, 3, 2]]

In [38]:
def inputVar(l,voc):
    indexes_batch=[indexesfromsentence(voc,sentence) for sentence in l]
    lengths=[len(indexes) for indexes in indexes_batch]
    padList=zeropadding(indexes_batch)
    padVar=torch.LongTensor(padList)
    
    return padVar,lengths

In [39]:
def outputVar(l,voc):
    indexes_batch=[indexesfromsentence(voc,sentence) for sentence in l]
    max_target_length=max([len(indexes) for indexes in indexes_batch])
    padList=zeropadding(indexes_batch)
    mask=binaryMatrix(padList)
    mask=torch.ByteTensor(mask)
    padVar=torch.LongTensor(padList)
    
    return padVar,mask,max_target_length

In [40]:
def batchTraindata(voc,pair_batch):
    pair_batch.sort(key=lambda x:len(x[0].split(' ')),reverse=True)
    input_batch,output_batch=[],[]
    for pair in pair_batch:
        input_batch.append(pair[0])
        output_batch.append(pair[1])
    inp,lengths=inputVar(input_batch,voc)
    output,mask,max_target_length=outputVar(output_batch,voc)
    
    return inp,lengths,output,mask,max_target_length

In [41]:
small_batch_size=5
batches=batchTraindata(voc,[random.choice(pairs) for _ in range(small_batch_size)])
input_variable,lengths,target_variable,mask,max_target_length=batches
print("input_variable:")
p=torch.nn.utils.rnn.pack_padded_sequence(input_variable,lengths)
print(p)
print("lengths:")
print(lengths)
print("target_variable:")
print(target_variable)
print("mask:")
print(mask)
print("max_target_length:")
print(max_target_length)
print(torch.nn.utils.rnn.pad_packed_sequence(p))
print(target_variable[max_target_length-1])

input_variable:
PackedSequence(data=tensor([ 387,  101,   42,  167,   50,   25,   37,    7,   56,    6,   67,   67,
          24,  827,    2,  107,   18,    9,    4,  350,  301,  185,    2,   40,
         516,  706,  380,    4,    6, 2840,    2,    2,    6,    2]), batch_sizes=tensor([5, 5, 5, 4, 4, 3, 3, 3, 1, 1]))
lengths:
[10, 8, 8, 5, 3]
target_variable:
tensor([[ 318,  101,   25,   50,   70],
        [   4,   37,  200,   37, 2647],
        [   2, 2359, 1828,  123,   37],
        [   0,    4,  111,  177,  159],
        [   0,    2,  266,    6,    4],
        [   0,    0,   95,    2,    2],
        [   0,    0,    4,    0,    0],
        [   0,    0,    2,    0,    0]])
mask:
tensor([[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [0, 1, 1, 1, 1],
        [0, 1, 1, 1, 1],
        [0, 0, 1, 1, 1],
        [0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0]], dtype=torch.uint8)
max_target_length:
8
(tensor([[ 387,  101,   42,  167,   50],
        [  25,   37,    7,   5

In [42]:
class EncoderRNN(nn.Module):
    def __init__(self,hidden_size,embedding,n_layers=1,dropout=0):
        super(EncoderRNN,self).__init__()
        self.n_layers=n_layers
        self.hidden_size=hidden_size
        self.embedding=embedding
        self.gru=nn.GRU(hidden_size,hidden_size,n_layers,dropout=(0 if n_layers==1 else dropout),bidirectional=True)
    def forward(self,input_seq,input_lengths,hidden=None):
        embedded=self.embedding(input_seq)
        packed=torch.nn.utils.rnn.pack_padded_sequence(embedded,input_lengths)
        outputs,hidden=self.gru(packed,hidden)
        outputs,_=torch.nn.utils.rnn.pad_packed_sequence(outputs)
        outputs=outputs[:,:,:self.hidden_size]+outputs[:,:,self.hidden_size:]
        
        return outputs,hidden
        

In [43]:
class Attn(torch.nn.Module):
    def __init__(self,method,hidden_size):
        super(Attn,self).__init__()
        self.method=method
        self.hidden_size=hidden_size
    def dot_score(self,hidden,encoder_outputs):
        return torch.sum(hidden*encoder_outputs,dim=2)
    def forward(self,hidden,encoder_outputs):
        attn_energies=self.dot_score(hidden,encoder_outputs)
        attn_energies=attn_energies.t()
        return F.softmax(attn_energies,dim=1).unsqueeze(1)

In [44]:
class LongAttnDecoderRNN(nn.Module):
    def __init__(self,attn_model,embedding,hidden_size,output_size,n_layers=1,dropout=0.1):
        super(LongAttnDecoderRNN,self).__init__()
        self.attn_model=attn_model
        self.hidden_size=hidden_size
        self.output_size=output_size
        self.n_layers=n_layers
        self.dropout=dropout
        
        self.embedding=embedding
        self.embedding_dropout=nn.Dropout(dropout)
        self.gru=nn.GRU(hidden_size,hidden_size,n_layers,dropout=(0 if n_layers==1 else dropout))
        self.concat=nn.Linear(2*hidden_size,hidden_size)
        self.out=nn.Linear(hidden_size,output_size)
        
        self.attn=Attn(attn_model,hidden_size)
    def forward(self,input_step,last_hidden,encoder_outputs):
        embedded=self.embedding(input_step)
        embedded=self.embedding_dropout(embedded)
        rnn_output,hidden=self.gru(embedded,last_hidden)
        attn_weights=self.attn(rnn_output,encoder_outputs)
        context=attn_weights.bmm(encoder_outputs.transpose(0,1))
        rnn_output=rnn_output.squeeze(0)
        context=context.squeeze(1)
        concat_input=torch.cat((rnn_output,context),1)
        concat_output=torch.tanh(self.concat(concat_input))
        output=self.out(concat_output)
        output=F.softmax(output,dim=1)
        return output,hidden
        

In [45]:
def maskNLLLoss(decoder_output,target,mask):
    nTotal=mask.sum()
    target=target.view(-1,1)
   
    gathered_tensor=torch.gather(decoder_output,1,target)
    crossEntropy=-torch.log(gathered_tensor)
    loss=crossEntropy.masked_select(mask)
    loss=torch.mean(loss)
    loss=loss.to(device)
    return loss,nTotal.item()

In [46]:
small_batch_size=5
batches=batchTraindata(voc,[random.choice(pairs) for _ in range(small_batch_size)])
input_variable,lengths,target_variable,mask,max_target_length=batches
print("input_variable shape:")
print(input_variable.shape)
lengths=torch.tensor(lengths)
print("lengths shape:")
print(lengths.shape)
print("target_variable shape:")
print(target_variable.shape)
print("mask shape:")
print(mask.shape)
print("max_target_length:")
print(max_target_length)
hidden_size=500
encoder_n_layers=2
decoder_n_layers=2
dropout=0.1
attn_model='dot'
embedding=nn.Embedding(voc.num_words,hidden_size)
encoder=EncoderRNN(hidden_size,embedding,encoder_n_layers,dropout)
decoder=LongAttnDecoderRNN(attn_model,embedding,hidden_size,voc.num_words,decoder_n_layers,dropout)
encoder=encoder.to(device)
decoder=decoder.to(device)
encoder.train()
decoder.train()
encoder_optimizer=optim.Adam(encoder.parameters(),lr=0.0001)
decoder_optimizer=optim.Adam(decoder.parameters(),lr=0.0001)
input_variable=input_variable.to(device)
lengths=lengths.to(device)
target_variable=target_variable.to(device)
mask=mask.to(device)
loss=0
print_losses=[]
n_totals=0

encoder_outputs,encoder_hidden=encoder(input_variable,lengths)
print("Encoder output's shape:",encoder_outputs.shape)
print("Last encoder hidden shape:",encoder_hidden.shape)
decoder_input=torch.LongTensor([[SOS_token for _ in range(small_batch_size)]])
decoder_input=decoder_input.to(device)
print("Initial decoder input shape",decoder_input.shape)
print(decoder_input)
decoder_hidden=encoder_hidden[:decoder.n_layers]
print("Initial decoder hidden state shape",decoder_hidden.shape,"\n")
print("----------------------------------------------------------------")
print("Now let's look what's happening in every timestep of the GRU")
print("----------------------------------------------------------------")
print("\n")

for t in range(max_target_length):
    decoder_output,decoder_hidden=decoder(decoder_input,decoder_hidden,encoder_outputs)
    print("Decoder output shape:",decoder_output.shape)
    print("Decoder hidden shape:",decoder_hidden.shape)
    
    decoder_input=target_variable[t].view(1,-1)
    print("The target variable at the current timestep before reshaping",target_variable[t])
    print("The target variable at the current timestep shape after reshaping",target_variable[t].shape)
    print("The decoder input shape (reshape the target variable)",decoder_input.shape)
    
    print("The mask at the current timestep",mask[t])
    print("The mask at the current timestep shape",mask[t].shape)
    mask_loss,nTotal=maskNLLLoss(decoder_output,target_variable[t],mask[t])
    print("Mask loss:",mask_loss)
    print("Total:",nTotal)
    loss+=mask_loss
    print_losses.append(mask_loss.item()*nTotal)
    n_totals+=nTotal
    print(n_totals)
    encoder_optimizer.step()
    encoder_optimizer.step()
    returned_loss=sum(print_losses)/n_totals
    print("Returned_loss:",returned_loss)
    print("\n")
    print('-------------------------------DONE ONE TIMESTEP--------------------------------')
    print("\n")

input_variable shape:
torch.Size([9, 5])
lengths shape:
torch.Size([5])
target_variable shape:
torch.Size([9, 5])
mask shape:
torch.Size([9, 5])
max_target_length:
9
Encoder output's shape: torch.Size([9, 5, 500])
Last encoder hidden shape: torch.Size([4, 5, 500])
Initial decoder input shape torch.Size([1, 5])
tensor([[1, 1, 1, 1, 1]])
Initial decoder hidden state shape torch.Size([2, 5, 500]) 

----------------------------------------------------------------
Now let's look what's happening in every timestep of the GRU
----------------------------------------------------------------


Decoder output shape: torch.Size([5, 7826])
Decoder hidden shape: torch.Size([2, 5, 500])
The target variable at the current timestep before reshaping tensor([158, 167,  25,  56,  39])
The target variable at the current timestep shape after reshaping torch.Size([5])
The decoder input shape (reshape the target variable) torch.Size([1, 5])
The mask at the current timestep tensor([1, 1, 1, 1, 1], dtype=torch

In [47]:
def train(input_variable,lengths,target_variable,mask,max_target_len,encoder,decoder,embedding,encoder_optimizer,
          decoder_optimizer,batch_size,clip,max_length=MAX_LENGTH):
    
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    
    input_variable=input_variable.to(device)
    lengths=torch.tensor(lengths).to(device)
    target_variable=target_variable.to(device)
    mask=mask.to(device)
      
    loss=0
    print_losses=[]
    n_totals=0
    
    encoder_outputs,encoder_hidden=encoder(input_variable,lengths)
    
    decoder_input=torch.tensor([[SOS_token for _ in range(batch_size)]])
    decoder_input=decoder_input.to(device)
    
    decoder_hidden=encoder_hidden[:decoder.n_layers]
    
    use_teacher_forcing=True if random.random() > teacher_forcing_ratio else False
    
    if use_teacher_forcing:
        for t in range(max_target_len):
            decoder_output,decoder_hidden=decoder(decoder_input,decoder_hidden,encoder_outputs)
            
            decoder_input=target_variable[t].view(1,-1)
            
            mask_loss,nTotal=maskNLLLoss(decoder_output,target_variable[t],mask[t])
            loss+=mask_loss
            print_losses.append(mask_loss.item()*nTotal)
            n_totals+=nTotal
           
            
    else:
        for t in range(max_target_len):
            decoder_output,decoder_hidden=decoder(decoder_input,decoder_hidden,encoder_outputs)
            _,topi=decoder_output.topk(1)
            decoder_input=torch.LongTensor([[topi[i][0] for i in range(batch_size)]])
            
            mask_loss,nTotal=maskNLLLoss(decoder_output,target_variable[t],mask[t])
            loss+=mask_loss
            print_losses.append(mask_loss.item()*nTotal)
            n_totals+=nTotal
    
    loss.backward()
    
    torch.nn.utils.clip_grad_norm_(encoder.parameters(),clip)
    torch.nn.utils.clip_grad_norm_(decoder.parameters(),clip)

    encoder_optimizer.step()
    decoder_optimizer.step()
    
    return sum(print_losses)/n_totals

In [48]:
def trainIters(model_name,voc,pairs,encoder,decoder,encoder_optimizer,decoder_optimizer,embedding,encoder_n_layers,
               decoder_n_layers,save_dir,n_iteration,branch_size,print_every,save_every,clip,corpus_name,loadFilename):
    training_batches=[batchTraindata(voc,[random.choice(pairs) for _ in range(batch_size)]) for _ in range(n_iteration)]
    print('Initialising.........')
    start_iteration=1
    print_loss=0
    if loadFilename:
        start_iteration=checkpoint['iteration']+1
    print('training....')
    for iteration in range(start_iteration,n_iteration+1):
        training_batch=training_batches[iteration-1]
        input_variable,lengths,target_variable,mask,max_target_len=training_batch
        
        loss=train(input_variable,lengths,target_variable,mask,max_target_len,encoder,decoder,embedding,encoder_optimizer,decoder_optimizer,
                  batch_size,clip)
        print_loss+=loss
        if iteration % print_every == 0:
            print_avg_loss=print_loss/print_every
            print("Iteration:{},percent complete:{},Average loss:{}".format(iteration,iteration*100/n_iteration,print_avg_loss))
            print_loss=0
        if iteration % save_every == 0:
            directory=os.path.join(save_dir,model_name,corpus_name,'{}-{}_{}'.format(encoder_n_layers,decoder_n_layers,hidden_size))
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({'iteration':iteration,'en':encoder.state_dict(),'de':decoder.state_dict(),'en_opt':encoder_optimizer.state_dict(),
                        'de_opt':decoder_optimizer.state_dict(),'loss':loss,'voc_dict':voc.__dict__,'embedding':embedding.state_dict()},
                       os.path.join(directory,'{}_{}.tar'.format(iteration,'checkpoint')))

In [49]:
class GreedySearchDecoder(nn.Module):
    def __init__(self,encoder,decoder):
        super(GreedySearchDecoder,self).__init__()
        self.encoder=encoder
        self.decoder=decoder
    def forward(self,input_seq,input_length,max_length):
        encoder_outputs,encoder_hidden=self.encoder(input_seq,input_length)
        decoder_hidden=encoder_hidden[:decoder.n_layers]
        decoder_input=torch.ones(1,1,device=device,dtype=torch.long)*SOS_token
        all_tokens=torch.zeros([0],device=device,dtype=torch.long)
        all_scores=torch.zeros([0],device=device,dtype=torch.long)
        for _ in range(max_length):
            decoder_output,decoder_hidden=self.decoder(decoder_input,decoder_hidden,encoder_outputs)
            decoder_score,decoder_input=torch.max(decoder_output,dim=1)
            all_scores=torch.cat((all_scores,torch.LongTensor([decoder_score])),dim=0)
            all_tokens=torch.cat((all_tokens,torch.LongTensor([decoder_input])),dim=0)
            decoder_input=torch.unsqueeze(decoder_input,0)
            
        return all_tokens,all_scores

In [85]:
def evaluate(encoder,decoder,searcher,voc,sentence,max_length=MAX_LENGTH):
    indexes_batch=[indexesfromsentence(voc,sentence)]
    lengths=torch.tensor([len(indexes) for indexes in indexes_batch])
    input_batch=torch.LongTensor(indexes_batch).transpose(0,1)
    input_batch=input_batch.to(device)
    lengths=lengths.to(device)
    tokens,scores=searcher(input_batch,lengths,max_length)
    decoded_words=[voc.index2word[token.item()] for token in tokens]
    return decoded_words

def evaluateInput(encoder,decoder,searcher,voc):
    input_sentence=''
    p=[]
    n_iteration=68002
    i=0
    while(1):
        try:
            input_sentence=input('> ')
            if (input_sentence=='q' or input_sentence=='quit'):
                break
            input_sentence=normalisestring(input_sentence)
            output_words=evaluate(encoder,decoder,searcher,voc,input_sentence)
            output_words[:]=[x for x in output_words if not(x=='EOS' or x=='PAD')]
            print('BOT:',' '.join(output_words))
            s=input('> ')
            if s!='y' and s!='q':
                p.append([])
                s=normalisestring(s)
                p[i].append(input_sentence)
                p[i].append(s)
                n_iteration+=1
                i+=1
        
        except KeyError:
            print("Error:Encountered unknown word.")    
    
    if p!=[] and input_sentence=='q':
        print_every=1
        save_every=1
                
        datafile=os.path.join("cornell movie-dialogs corpus","fomatted_movie_lines.txt")
        delimiter='\t'
        delimiter=str(codecs.decode(delimiter,"unicode_escape"))
        print("\nwriting newly formatted file")
        with open(datafile,'a',encoding="utf-8") as outputfile:
            writer=csv.writer(outputfile,delimiter=delimiter)        
            for pair in p:
                if pair not in pairs:
                    print(1)
                    writer.writerow(pair)
        
        

In [51]:
model_name='cb_model'
attn_model='dot'
hidden_size=500
encoder_n_layers=2
decoder_n_layers=2
dropout=0.2
batch_size=100

clip=50

teacher_forcing_ratio=1.0
learning_rate=0.00005
decoder_learning_ratio=5.0

loadFilename=os.path.join("cornell movie-dialogs corpus","cb_model","trained_data","2-2_500","68002_checkpoint.tar")
checkpoint_iter=4000

if loadFilename:
    checkpoint=torch.load(loadFilename)
    encoder_sd=checkpoint['en']
    decoder_sd=checkpoint['de']
    encoder_optimizer_sd=checkpoint['en_opt']
    decoder_optimizer_sd=checkpoint['de_opt']
    embedding_sd=checkpoint['embedding']
    voc.__dicy__=checkpoint['voc_dict']
print("building and encoder.....")
embedding=nn.Embedding(voc.num_words,hidden_size)
if loadFilename:
    embedding.load_state_dict(embedding_sd)
    encoder=EncoderRNN(hidden_size,embedding,encoder_n_layers,dropout)
    decoder=LongAttnDecoderRNN(attn_model,embedding,hidden_size,voc.num_words,decoder_n_layers,dropout)
if loadFilename:
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)
encoder=encoder.to(device)
decoder=decoder.to(device)
print('models built and ready to go')    

building and encoder.....
models built and ready to go


In [52]:
clip=50
teacher_forcing_ratio=1.0
learning_rate=0.000001
decoder_learning_ratio=5.0
n_iteration=70000
print_every=1
save_every=500

encoder.train()
decoder.train()

print("building optimizers")
encoder_optimizer=optim.Adam(encoder.parameters(),lr=learning_rate)
decoder_optimizer=optim.Adam(decoder.parameters(),lr=learning_rate*decoder_learning_ratio)
if loadFilename:
    encoder_optimizer.load_state_dict(encoder_optimizer_sd)
    decoder_optimizer.load_state_dict(decoder_optimizer_sd)
print("start training!")
trainIters(model_name,voc,pairs,encoder,decoder,encoder_optimizer,decoder_optimizer,embedding,encoder_n_layers,decoder_n_layers,
           "cornell movie-dialogs corpus",n_iteration,batch_size,print_every,save_every,clip,"trained_data",loadFilename)

building optimizers
start training!
Initialising.........
training....
Iteration:68003,percent complete:97.14714285714285,Average loss:0.5253319806535913
Iteration:68004,percent complete:97.14857142857143,Average loss:0.6492800640119255
Iteration:68005,percent complete:97.15,Average loss:0.7667343276538259
Iteration:68006,percent complete:97.15142857142857,Average loss:0.7289563839363048
Iteration:68007,percent complete:97.15285714285714,Average loss:0.673308609688282
Iteration:68008,percent complete:97.15428571428572,Average loss:0.7066553838467864
Iteration:68009,percent complete:97.15571428571428,Average loss:0.53169547160858
Iteration:68010,percent complete:97.15714285714286,Average loss:0.5888424184559157
Iteration:68011,percent complete:97.15857142857143,Average loss:0.791895119381128
Iteration:68012,percent complete:97.16,Average loss:0.5770385979644714
Iteration:68013,percent complete:97.16142857142857,Average loss:0.5527709856726447
Iteration:68014,percent complete:97.16285714

Iteration:68104,percent complete:97.29142857142857,Average loss:0.6750761571816143
Iteration:68105,percent complete:97.29285714285714,Average loss:0.6857410745477028
Iteration:68106,percent complete:97.29428571428572,Average loss:0.6136808112336725
Iteration:68107,percent complete:97.29571428571428,Average loss:0.7025610860866336
Iteration:68108,percent complete:97.29714285714286,Average loss:0.7341735674861627
Iteration:68109,percent complete:97.29857142857144,Average loss:0.6625813633696107
Iteration:68110,percent complete:97.3,Average loss:0.6534375479944001
Iteration:68111,percent complete:97.30142857142857,Average loss:0.5838686695370154
Iteration:68112,percent complete:97.30285714285715,Average loss:0.8655714453336043
Iteration:68113,percent complete:97.30428571428571,Average loss:0.7639919498172782
Iteration:68114,percent complete:97.30571428571429,Average loss:0.6695708819627761
Iteration:68115,percent complete:97.30714285714286,Average loss:0.6339790760061871
Iteration:68116,p

Iteration:68205,percent complete:97.43571428571428,Average loss:0.4933242314324087
Iteration:68206,percent complete:97.43714285714286,Average loss:0.6903391256005771
Iteration:68207,percent complete:97.43857142857142,Average loss:0.6136312465938484
Iteration:68208,percent complete:97.44,Average loss:0.7859762368930711
Iteration:68209,percent complete:97.44142857142857,Average loss:0.853838411039693
Iteration:68210,percent complete:97.44285714285714,Average loss:0.6517127556046424
Iteration:68211,percent complete:97.44428571428571,Average loss:0.7525118680223024
Iteration:68212,percent complete:97.44571428571429,Average loss:0.47345100290992964
Iteration:68213,percent complete:97.44714285714285,Average loss:0.5860189446677012
Iteration:68214,percent complete:97.44857142857143,Average loss:0.5813079839035616
Iteration:68215,percent complete:97.45,Average loss:0.5451571707988402
Iteration:68216,percent complete:97.45142857142856,Average loss:0.710606629358969
Iteration:68217,percent compl

Iteration:68306,percent complete:97.58,Average loss:0.6849474967209002
Iteration:68307,percent complete:97.58142857142857,Average loss:0.6143617029540623
Iteration:68308,percent complete:97.58285714285714,Average loss:0.6122709116497176
Iteration:68309,percent complete:97.58428571428571,Average loss:0.5730776848970716
Iteration:68310,percent complete:97.58571428571429,Average loss:0.7846397983583484
Iteration:68311,percent complete:97.58714285714285,Average loss:0.7159961347713921
Iteration:68312,percent complete:97.58857142857143,Average loss:0.5962939040321779
Iteration:68313,percent complete:97.59,Average loss:0.6155071343893219
Iteration:68314,percent complete:97.59142857142857,Average loss:0.7746937788276641
Iteration:68315,percent complete:97.59285714285714,Average loss:0.5462268625501714
Iteration:68316,percent complete:97.59428571428572,Average loss:0.6994559684010334
Iteration:68317,percent complete:97.59571428571428,Average loss:0.7112564826021743
Iteration:68318,percent comp

Iteration:68408,percent complete:97.72571428571429,Average loss:0.5563898034110691
Iteration:68409,percent complete:97.72714285714285,Average loss:0.5202840171149945
Iteration:68410,percent complete:97.72857142857143,Average loss:0.6175995329380035
Iteration:68411,percent complete:97.73,Average loss:0.6450764719440509
Iteration:68412,percent complete:97.73142857142857,Average loss:0.6204952620583839
Iteration:68413,percent complete:97.73285714285714,Average loss:0.6599774501659583
Iteration:68414,percent complete:97.73428571428572,Average loss:0.6687594136595726
Iteration:68415,percent complete:97.73571428571428,Average loss:0.9770547939125205
Iteration:68416,percent complete:97.73714285714286,Average loss:0.7157494944305143
Iteration:68417,percent complete:97.73857142857143,Average loss:0.7460093011806725
Iteration:68418,percent complete:97.74,Average loss:0.5705531123322207
Iteration:68419,percent complete:97.74142857142857,Average loss:0.7125077315000893
Iteration:68420,percent comp

Iteration:68509,percent complete:97.87,Average loss:0.6029170880263502
Iteration:68510,percent complete:97.87142857142857,Average loss:0.4196837528729949
Iteration:68511,percent complete:97.87285714285714,Average loss:0.8480771856867456
Iteration:68512,percent complete:97.87428571428572,Average loss:0.5871305223765011
Iteration:68513,percent complete:97.87571428571428,Average loss:0.5657917228479361
Iteration:68514,percent complete:97.87714285714286,Average loss:0.8041743138946005
Iteration:68515,percent complete:97.87857142857143,Average loss:0.6150079030915485
Iteration:68516,percent complete:97.88,Average loss:0.6115471615595518
Iteration:68517,percent complete:97.88142857142857,Average loss:0.6345889958543033
Iteration:68518,percent complete:97.88285714285715,Average loss:0.9498162369539098
Iteration:68519,percent complete:97.88428571428571,Average loss:0.6789981439734132
Iteration:68520,percent complete:97.88571428571429,Average loss:0.617032577372586
Iteration:68521,percent compl

Iteration:68610,percent complete:98.01428571428572,Average loss:0.5887912775755998
Iteration:68611,percent complete:98.01571428571428,Average loss:0.44244314189745915
Iteration:68612,percent complete:98.01714285714286,Average loss:0.649852416062491
Iteration:68613,percent complete:98.01857142857143,Average loss:0.7058894857418777
Iteration:68614,percent complete:98.02,Average loss:0.7327672017384167
Iteration:68615,percent complete:98.02142857142857,Average loss:0.8552253834333695
Iteration:68616,percent complete:98.02285714285715,Average loss:0.26032554448055695
Iteration:68617,percent complete:98.02428571428571,Average loss:0.6688477279983949
Iteration:68618,percent complete:98.02571428571429,Average loss:0.6262401044252691
Iteration:68619,percent complete:98.02714285714286,Average loss:0.4731404256419772
Iteration:68620,percent complete:98.02857142857142,Average loss:0.5498025857561097
Iteration:68621,percent complete:98.03,Average loss:0.6932169246272399
Iteration:68622,percent com

Iteration:68711,percent complete:98.15857142857143,Average loss:0.5350697056629992
Iteration:68712,percent complete:98.16,Average loss:0.6176754551233146
Iteration:68713,percent complete:98.16142857142857,Average loss:0.9458225700272921
Iteration:68714,percent complete:98.16285714285715,Average loss:0.7620374695216487
Iteration:68715,percent complete:98.16428571428571,Average loss:0.6385347169196728
Iteration:68716,percent complete:98.16571428571429,Average loss:0.843738478817253
Iteration:68717,percent complete:98.16714285714286,Average loss:0.6739778545518917
Iteration:68718,percent complete:98.16857142857143,Average loss:0.735914486270055
Iteration:68719,percent complete:98.17,Average loss:0.4342994483425703
Iteration:68720,percent complete:98.17142857142858,Average loss:0.9022706030761893
Iteration:68721,percent complete:98.17285714285714,Average loss:0.6536182921718467
Iteration:68722,percent complete:98.17428571428572,Average loss:0.635807620476725
Iteration:68723,percent complet

Iteration:68812,percent complete:98.30285714285715,Average loss:0.5700913701258915
Iteration:68813,percent complete:98.30428571428571,Average loss:0.5228372948220024
Iteration:68814,percent complete:98.30571428571429,Average loss:0.5763659690987439
Iteration:68815,percent complete:98.30714285714286,Average loss:0.8881738180132924
Iteration:68816,percent complete:98.30857142857143,Average loss:0.5986700715892986
Iteration:68817,percent complete:98.31,Average loss:0.5658132157812361
Iteration:68818,percent complete:98.31142857142858,Average loss:0.7121120777058109
Iteration:68819,percent complete:98.31285714285714,Average loss:0.4509946334306087
Iteration:68820,percent complete:98.31428571428572,Average loss:0.8223580064753166
Iteration:68821,percent complete:98.31571428571428,Average loss:0.5609740508174536
Iteration:68822,percent complete:98.31714285714285,Average loss:0.6303486424303034
Iteration:68823,percent complete:98.31857142857143,Average loss:0.6656622342438764
Iteration:68824,

Iteration:68913,percent complete:98.44714285714285,Average loss:0.6142204934841062
Iteration:68914,percent complete:98.44857142857143,Average loss:0.5220025457282763
Iteration:68915,percent complete:98.45,Average loss:0.41334355503161124
Iteration:68916,percent complete:98.45142857142856,Average loss:0.45427260679712933
Iteration:68917,percent complete:98.45285714285714,Average loss:0.6413910194765009
Iteration:68918,percent complete:98.45428571428572,Average loss:0.7683616093169885
Iteration:68919,percent complete:98.45571428571428,Average loss:0.709874031656823
Iteration:68920,percent complete:98.45714285714286,Average loss:0.6107642626850655
Iteration:68921,percent complete:98.45857142857143,Average loss:0.6121769365688792
Iteration:68922,percent complete:98.46,Average loss:0.6433852807867866
Iteration:68923,percent complete:98.46142857142857,Average loss:0.6753869095990895
Iteration:68924,percent complete:98.46285714285715,Average loss:0.8285533134456848
Iteration:68925,percent com

Iteration:69014,percent complete:98.59142857142857,Average loss:0.5972696358449017
Iteration:69015,percent complete:98.59285714285714,Average loss:0.5898091521699692
Iteration:69016,percent complete:98.59428571428572,Average loss:0.6120232128833206
Iteration:69017,percent complete:98.59571428571428,Average loss:0.8205698703480239
Iteration:69018,percent complete:98.59714285714286,Average loss:0.6466406610944578
Iteration:69019,percent complete:98.59857142857143,Average loss:0.4979928131787682
Iteration:69020,percent complete:98.6,Average loss:0.5175068836277864
Iteration:69021,percent complete:98.60142857142857,Average loss:0.49628390708178477
Iteration:69022,percent complete:98.60285714285715,Average loss:0.665028016012248
Iteration:69023,percent complete:98.60428571428571,Average loss:0.4489534180390754
Iteration:69024,percent complete:98.60571428571428,Average loss:0.5096598927942172
Iteration:69025,percent complete:98.60714285714286,Average loss:0.7579299032463875
Iteration:69026,p

Iteration:69115,percent complete:98.73571428571428,Average loss:0.8541625620342616
Iteration:69116,percent complete:98.73714285714286,Average loss:0.47027139844772897
Iteration:69117,percent complete:98.73857142857143,Average loss:0.6191042739475889
Iteration:69118,percent complete:98.74,Average loss:0.6623444265814744
Iteration:69119,percent complete:98.74142857142857,Average loss:0.5269204949979898
Iteration:69120,percent complete:98.74285714285715,Average loss:0.623915298750466
Iteration:69121,percent complete:98.74428571428571,Average loss:0.5914155830613138
Iteration:69122,percent complete:98.74571428571429,Average loss:0.3321792156830508
Iteration:69123,percent complete:98.74714285714286,Average loss:0.5449007088426402
Iteration:69124,percent complete:98.74857142857142,Average loss:0.45455703620952886
Iteration:69125,percent complete:98.75,Average loss:0.5430883572978757
Iteration:69126,percent complete:98.75142857142858,Average loss:0.5670372510545751
Iteration:69127,percent com

Iteration:69216,percent complete:98.88,Average loss:0.6874473720372464
Iteration:69217,percent complete:98.88142857142857,Average loss:0.5701726525771518
Iteration:69218,percent complete:98.88285714285715,Average loss:0.6806962014909766
Iteration:69219,percent complete:98.88428571428571,Average loss:0.5899253707522383
Iteration:69220,percent complete:98.88571428571429,Average loss:0.6165875217037988
Iteration:69221,percent complete:98.88714285714286,Average loss:0.5800942235620193
Iteration:69222,percent complete:98.88857142857142,Average loss:0.777694219979001
Iteration:69223,percent complete:98.89,Average loss:0.44376655493200556
Iteration:69224,percent complete:98.89142857142858,Average loss:0.5502406549179106
Iteration:69225,percent complete:98.89285714285714,Average loss:0.7548662087624279
Iteration:69226,percent complete:98.89428571428572,Average loss:0.7662445274108216
Iteration:69227,percent complete:98.89571428571429,Average loss:0.6120543272707237
Iteration:69228,percent comp

Iteration:69318,percent complete:99.02571428571429,Average loss:0.677132872273389
Iteration:69319,percent complete:99.02714285714286,Average loss:0.730427229692709
Iteration:69320,percent complete:99.02857142857142,Average loss:0.5501692658366232
Iteration:69321,percent complete:99.03,Average loss:0.5346209533074323
Iteration:69322,percent complete:99.03142857142858,Average loss:0.6212754561165332
Iteration:69323,percent complete:99.03285714285714,Average loss:0.7787978718967205
Iteration:69324,percent complete:99.03428571428572,Average loss:0.5955282856981242
Iteration:69325,percent complete:99.03571428571429,Average loss:0.7727922988753464
Iteration:69326,percent complete:99.03714285714285,Average loss:0.7971089409003335
Iteration:69327,percent complete:99.03857142857143,Average loss:0.7258935016846866
Iteration:69328,percent complete:99.04,Average loss:0.6818591350058335
Iteration:69329,percent complete:99.04142857142857,Average loss:0.6219174904812231
Iteration:69330,percent comple

Iteration:69419,percent complete:99.17,Average loss:0.8061248352629394
Iteration:69420,percent complete:99.17142857142858,Average loss:0.8301606848292548
Iteration:69421,percent complete:99.17285714285714,Average loss:0.5496069389923165
Iteration:69422,percent complete:99.17428571428572,Average loss:0.5626426420210784
Iteration:69423,percent complete:99.17571428571429,Average loss:0.5744396817718542
Iteration:69424,percent complete:99.17714285714285,Average loss:0.6567924372578957
Iteration:69425,percent complete:99.17857142857143,Average loss:0.6858901611409305
Iteration:69426,percent complete:99.18,Average loss:0.5569570485153809
Iteration:69427,percent complete:99.18142857142857,Average loss:0.6950292135134521
Iteration:69428,percent complete:99.18285714285715,Average loss:0.5076011570990573
Iteration:69429,percent complete:99.18428571428572,Average loss:0.719737423367861
Iteration:69430,percent complete:99.18571428571428,Average loss:0.6903534320433551
Iteration:69431,percent compl

Iteration:69520,percent complete:99.31428571428572,Average loss:0.8180652730279799
Iteration:69521,percent complete:99.31571428571428,Average loss:0.5021302135863577
Iteration:69522,percent complete:99.31714285714285,Average loss:0.4638181573040905
Iteration:69523,percent complete:99.31857142857143,Average loss:0.4811746359554328
Iteration:69524,percent complete:99.32,Average loss:0.6680561976947138
Iteration:69525,percent complete:99.32142857142857,Average loss:0.4929583278803018
Iteration:69526,percent complete:99.32285714285715,Average loss:0.5106755803633285
Iteration:69527,percent complete:99.32428571428571,Average loss:0.5954514235017762
Iteration:69528,percent complete:99.32571428571428,Average loss:0.3977073654900666
Iteration:69529,percent complete:99.32714285714286,Average loss:0.5782283315080995
Iteration:69530,percent complete:99.32857142857142,Average loss:0.7929580565730129
Iteration:69531,percent complete:99.33,Average loss:0.6136643648256263
Iteration:69532,percent comp

Iteration:69621,percent complete:99.45857142857143,Average loss:0.7523975361670766
Iteration:69622,percent complete:99.46,Average loss:0.6245099282064358
Iteration:69623,percent complete:99.46142857142857,Average loss:0.5974859560816741
Iteration:69624,percent complete:99.46285714285715,Average loss:0.5276749584683799
Iteration:69625,percent complete:99.46428571428571,Average loss:0.7132908927526062
Iteration:69626,percent complete:99.46571428571428,Average loss:0.6719170062545764
Iteration:69627,percent complete:99.46714285714286,Average loss:0.7198206874772102
Iteration:69628,percent complete:99.46857142857142,Average loss:0.8030835100436545
Iteration:69629,percent complete:99.47,Average loss:0.6031013774742251
Iteration:69630,percent complete:99.47142857142858,Average loss:0.6218296977340197
Iteration:69631,percent complete:99.47285714285714,Average loss:0.5806551910429126
Iteration:69632,percent complete:99.47428571428571,Average loss:0.46520740010978645
Iteration:69633,percent com

Iteration:69722,percent complete:99.60285714285715,Average loss:0.667976825492036
Iteration:69723,percent complete:99.60428571428571,Average loss:0.5324942232849442
Iteration:69724,percent complete:99.60571428571428,Average loss:0.5588484656685562
Iteration:69725,percent complete:99.60714285714286,Average loss:0.8311221366022508
Iteration:69726,percent complete:99.60857142857142,Average loss:0.738508011778734
Iteration:69727,percent complete:99.61,Average loss:0.6072145057028359
Iteration:69728,percent complete:99.61142857142858,Average loss:0.4978230900497666
Iteration:69729,percent complete:99.61285714285714,Average loss:0.6497326066229967
Iteration:69730,percent complete:99.61428571428571,Average loss:0.6181410312483256
Iteration:69731,percent complete:99.61571428571429,Average loss:0.5149769487899619
Iteration:69732,percent complete:99.61714285714285,Average loss:0.6483798072590091
Iteration:69733,percent complete:99.61857142857143,Average loss:0.5001093610786856
Iteration:69734,pe

Iteration:69823,percent complete:99.74714285714286,Average loss:0.6995296819070912
Iteration:69824,percent complete:99.74857142857142,Average loss:0.6655766167248606
Iteration:69825,percent complete:99.75,Average loss:0.4746787089440558
Iteration:69826,percent complete:99.75142857142858,Average loss:0.4941132902915095
Iteration:69827,percent complete:99.75285714285714,Average loss:0.6703977239294312
Iteration:69828,percent complete:99.75428571428571,Average loss:0.6717340068695243
Iteration:69829,percent complete:99.75571428571429,Average loss:0.8282499164463245
Iteration:69830,percent complete:99.75714285714285,Average loss:0.7711152286222916
Iteration:69831,percent complete:99.75857142857143,Average loss:0.4309780694568119
Iteration:69832,percent complete:99.76,Average loss:0.5608944653976159
Iteration:69833,percent complete:99.76142857142857,Average loss:0.7357392139913301
Iteration:69834,percent complete:99.76285714285714,Average loss:0.5693178419377419
Iteration:69835,percent comp

Iteration:69925,percent complete:99.89285714285714,Average loss:0.7922850087055793
Iteration:69926,percent complete:99.89428571428572,Average loss:0.6425549583254613
Iteration:69927,percent complete:99.89571428571429,Average loss:0.7196284870119536
Iteration:69928,percent complete:99.89714285714285,Average loss:0.6166487774218466
Iteration:69929,percent complete:99.89857142857143,Average loss:0.47605529745669345
Iteration:69930,percent complete:99.9,Average loss:0.7378231651096974
Iteration:69931,percent complete:99.90142857142857,Average loss:0.47726296766833914
Iteration:69932,percent complete:99.90285714285714,Average loss:0.7476487657483164
Iteration:69933,percent complete:99.90428571428572,Average loss:0.6713513436830706
Iteration:69934,percent complete:99.90571428571428,Average loss:0.6703354289723261
Iteration:69935,percent complete:99.90714285714286,Average loss:0.559247943068423
Iteration:69936,percent complete:99.90857142857143,Average loss:0.4750333888039581
Iteration:69937,

In [86]:
encoder.eval()

decoder.eval()
searcher=GreedySearchDecoder(encoder,decoder)
evaluateInput(encoder,decoder,searcher,voc)

> hey
BOT: hey how are you ?
> hey,how are you?
> i'm fine,what about you?
BOT: my need lately .
> i'm fine too,thanks for your concern
> q

writing newly formatted file
1
1


In [90]:
pair=["hey",normalisestring("hey,how are you?")]
c=1 if pair in pairs else 0
c

1

In [54]:
p

PackedSequence(data=tensor([ 387,  101,   42,  167,   50,   25,   37,    7,   56,    6,   67,   67,
          24,  827,    2,  107,   18,    9,    4,  350,  301,  185,    2,   40,
         516,  706,  380,    4,    6, 2840,    2,    2,    6,    2]), batch_sizes=tensor([5, 5, 5, 4, 4, 3, 3, 3, 1, 1]))

In [55]:
d=os.path.join("ubuntu_dialogs.tgz")
with open(d,"r",encoding="iso-8859-1") as f:
    lines=f.readlines()
    

FileNotFoundError: [Errno 2] No such file or directory: 'ubuntu_dialogs.tgz'

In [None]:
p=[['aadfd'],['adadg']]
p[0].append('fa')
p

In [None]:
p=[]
#p.append(['gda'])
print(p)

In [None]:
a=1 if "word" in voc.index2word.values() else 0
a

In [None]:
voc.word2index['bot']

In [None]:
voc.num_words


In [None]:
e=encoder.embedding
e(torch.tensor([7826]))

In [103]:
for i in range(29500,68000,500):
    filename=str(i)+("_checkpoint.tar")
    loadFilename=os.path.join("cornell movie-dialogs corpus","cb_model","trained_data","2-2_500",filename)

    if loadFilename:
        checkpoint=torch.load(loadFilename)
        loss=checkpoint['loss']
        print(i,":",loss)

29500 : 0.7274426093897204
30000 : 0.845932053761524
30500 : 0.8566680630841175
31000 : 0.6788223092314254
31500 : 0.9030907842632685
32000 : 0.5473081622758129
32500 : 0.7442630892411916
33000 : 0.903321305878545
33500 : 0.5677482092585893
34000 : 0.6860555289588014
34500 : 0.7951816880647868
35000 : 0.7411859672012181
35500 : 0.663747705381919
36000 : 0.5944213349529897
36500 : 0.6040708467674752
37000 : 0.684938199131284
37500 : 0.8195819232606608
38000 : 0.9361440103126711
38500 : 0.6763073693840734
39000 : 0.5972005088264362
39500 : 0.7699805187032773
40000 : 0.947154623880161
40500 : 0.48398274379151485
41000 : 0.8053304493945168
41500 : 0.745600637230664
42000 : 0.8528446105505132
42500 : 0.7110764380866513
43000 : 0.6633855633462552
43500 : 0.7715817067399621
44000 : 0.7679918485735527
44500 : 0.5940668119848529
45000 : 0.7814096654054435
45500 : 0.7602656885936473
46000 : 0.7048637589315966
46500 : 0.5814745712908724
47000 : 0.7955107401883047
47500 : 0.725797190054043
48000 :

In [100]:
for i in range(0,10,2):
    print(str(i)+('weeee'))

0weeee
2weeee
4weeee
6weeee
8weeee
