## Bookkeeping: download and extract the dataset


In [5]:
!wget http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip

--2021-02-25 11:35:57--  http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
Resolving www.cs.cornell.edu (www.cs.cornell.edu)... 132.236.207.36
Connecting to www.cs.cornell.edu (www.cs.cornell.edu)|132.236.207.36|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9916637 (9.5M) [application/zip]
Saving to: ‘cornell_movie_dialogs_corpus.zip.1’


2021-02-25 11:35:59 (11.8 MB/s) - ‘cornell_movie_dialogs_corpus.zip.1’ saved [9916637/9916637]



In [6]:
!unzip cornell_movie_dialogs_corpus.zip

Archive:  cornell_movie_dialogs_corpus.zip
replace cornell movie-dialogs corpus/.DS_Store? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: cornell movie-dialogs corpus/.DS_Store  
replace __MACOSX/cornell movie-dialogs corpus/._.DS_Store? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: __MACOSX/cornell movie-dialogs corpus/._.DS_Store  
replace cornell movie-dialogs corpus/chameleons.pdf? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: cornell movie-dialogs corpus/chameleons.pdf  
replace __MACOSX/cornell movie-dialogs corpus/._chameleons.pdf? [y]es, [n]o, [A]ll, [N]one, [r]ename: y
  inflating: __MACOSX/cornell movie-dialogs corpus/._chameleons.pdf  
replace cornell movie-dialogs corpus/movie_characters_metadata.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: cornell movie-dialogs corpus/movie_characters_metadata.txt  
  inflating: cornell movie-dialogs corpus/movie_conversations.txt  
  inflating: cornell movie-dialogs corpus/movie_lines.txt  
  inflating: co

## Importing Libraries

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import csv, re, random, os, unicodedata, codecs
from io import open
import math, itertools
import matplotlib.pyplot as plt
import numpy as np
from fastprogress import progress_bar

In [8]:
# Select the compute device once for the whole notebook:
# 'cuda' when a CUDA device is available, otherwise 'cpu'.
USE_CUDA = torch.cuda.is_available()
device = torch.device('cuda' if USE_CUDA else 'cpu')

## Reading dataset

In [9]:
# Name of the corpus; it is also the directory name used to build file paths below.
corpus_name = "cornell movie-dialogs corpus"
# Relative path of the corpus directory (identical to the corpus name here).
corpus = corpus_name

In [10]:
def printLines(datapath, n=10):
    """Print the first `n` raw lines of a file.

    The file is opened in binary mode, so every line is printed as a
    ``bytes`` repr (for example ``b'text\\n'``) with its line ending intact.

    Args:
        datapath: Path of the file to preview.
        n: Number of leading lines to print (non-negative); ``None`` prints
            every line.
    """
    with open(datapath, 'rb') as f:
        # Read lazily: stop after `n` lines instead of loading the whole
        # file into memory just to show a short preview.
        head = list(itertools.islice(f, n))
    for line in head:
        print(line)
# Preview the first 10 raw lines of the corpus lines file (printed as bytes reprs).
printLines(os.path.join(corpus, "movie_lines.txt"))


b'L1045 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ They do not!\n'
b'L1044 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ They do to!\n'
b'L985 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I hope so.\n'
b'L984 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ She okay?\n'
b"L925 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Let's go.\n"
b'L924 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ Wow\n'
b"L872 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Okay -- you're gonna need to learn how to lie.\n"
b'L871 +++$+++ u2 +++$+++ m0 +++$+++ CAMERON +++$+++ No\n'
b'L870 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n'
b'L869 +++$+++ u0 +++$+++ m0 +++$+++ BIANCA +++$+++ Like my fear of wearing pastels?\n'


In [11]:
# Parses every line of the file into a dict of named fields, keyed by line ID
def loadLines(fileName, fields):
    """Read a " +++$+++ "-separated file into a dict of records.

    Args:
        fileName: Path of the file to read (decoded as iso-8859-1).
        fields: Ordered field names; the i-th name is mapped to the i-th
            separated value of each line. Must contain 'lineID'.

    Returns:
        dict mapping each record's 'lineID' value to the record dict.
        Values are stored as read, so the last value on a line keeps its
        trailing newline.

    Raises:
        IndexError: if a line has fewer values than `fields` has names.
    """
    records = {}
    with open(fileName, 'r', encoding='iso-8859-1') as handle:
        for raw in handle:
            parts = raw.split(" +++$+++ ")
            # Pair each field name with the value at the same position
            record = {name: parts[pos] for pos, name in enumerate(fields)}
            records[record['lineID']] = record
    return records


# Groups fields of lines from `loadLines` into conversations based on *movie_conversations.txt*
def loadConversations(fileName, lines, fields):
    """Read the conversations file and attach the referenced line records.

    Args:
        fileName: Path of the conversations file (decoded as iso-8859-1);
            each line holds values separated by " +++$+++ ".
        lines: dict mapping a line ID (e.g. 'L1045') to its record, as
            returned by `loadLines`.
        fields: Ordered field names for the separated values. Must contain
            'utteranceIDs'.

    Returns:
        list of conversation dicts. Each has one entry per name in `fields`
        plus a "lines" entry: the list of records from `lines`, in the order
        their IDs appear in the 'utteranceIDs' value.

    Raises:
        IndexError: if a line has fewer values than `fields` has names.
        KeyError: if an utterance ID is missing from `lines`.
    """
    # Compiled once, outside the per-line loop (the pattern never changes).
    utterance_id_pattern = re.compile('L[0-9]+')
    conversations = []
    with open(fileName, 'r', encoding='iso-8859-1') as f:
        for line in f:
            values = line.split(" +++$+++ ")
            # Extract fields
            convObj = {field: values[i] for i, field in enumerate(fields)}
            # Convert string to list (convObj["utteranceIDs"] == "['L598485', 'L598486', ...]")
            lineIds = utterance_id_pattern.findall(convObj["utteranceIDs"])
            # Reassemble lines
            convObj["lines"] = [lines[lineId] for lineId in lineIds]
            conversations.append(convObj)
    return conversations


# Builds (utterance, reply) sentence pairs from consecutive lines of each conversation
def extractSentencePairs(conversations):
    """Collect consecutive-line pairs from every conversation.

    Args:
        conversations: iterable of dicts, each with a "lines" list whose
            items are dicts holding a "text" string.

    Returns:
        list of ``[input_text, target_text]`` lists with both texts stripped
        of surrounding whitespace. A pair is dropped when either text is
        empty after stripping. The last line of a conversation never appears
        as an input because nothing follows it.
    """
    qa_pairs = []
    for conversation in conversations:
        utterances = conversation["lines"]
        # Walk the lines two at a time: each line paired with its successor
        for current, following in zip(utterances, utterances[1:]):
            question = current["text"].strip()
            answer = following["text"].strip()
            if not question or not answer:
                continue
            qa_pairs.append([question, answer])
    return qa_pairs

In [12]:
# Path of the tab-separated pairs file this cell writes (inside the corpus directory)
datafile = os.path.join(corpus, "formatted_movie_lines.txt")

delimiter = '\t'
# Unescape the delimiter
# NOTE(review): '\t' is already a real tab character here, so this decode looks like a no-op — confirm.
delimiter = str(codecs.decode(delimiter, "unicode_escape"))

# Initialize lines dict, conversations list, and field ids
# (both containers are rebound by the load calls below)
lines = {}
conversations = []
MOVIE_LINES_FIELDS = ["lineID", "characterID", "movieID", "character", "text"]
MOVIE_CONVERSATIONS_FIELDS = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

# Load lines and process conversations
print("\nProcessing corpus...")
lines = loadLines(os.path.join(corpus, "movie_lines.txt"), MOVIE_LINES_FIELDS)
print("\nLoading conversations...")
conversations = loadConversations(os.path.join(corpus, "movie_conversations.txt"),
                                  lines, MOVIE_CONVERSATIONS_FIELDS)

# Write new csv file: one "input<TAB>target" row per sentence pair, UTF-8 encoded
print("\nWriting newly formatted file...")
with open(datafile, 'w', encoding='utf-8') as outputfile:
    writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
    for pair in extractSentencePairs(conversations):
        writer.writerow(pair)

# Print a sample of lines (first 10, as raw bytes) to check the written file
print("\nSample lines from file:")
printLines(datafile)


Processing corpus...

Loading conversations...

Writing newly formatted file...

Sample lines from file:
b"Can we make this quick?  Roxanne Korrine and Andrew Barrett are having an incredibly horrendous public break- up on the quad.  Again.\tWell, I thought we'd start with pronunciation, if that's okay with you.\n"
b"Well, I thought we'd start with pronunciation, if that's okay with you.\tNot the hacking and gagging and spitting part.  Please.\n"
b"Not the hacking and gagging and spitting part.  Please.\tOkay... then how 'bout we try out some French cuisine.  Saturday?  Night?\n"
b"You're asking me out.  That's so cute. What's your name again?\tForget it.\n"
b"No, no, it's my fault -- we didn't have a proper introduction ---\tCameron.\n"
b"Cameron.\tThe thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't date until she does.\n"
b"The thing is, Cameron -- I'm at the mercy of a particularly hideous breed of loser.  My sister.  I can't dat

In [13]:
PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token

class Voc:
    """Vocabulary mapping words to integer indexes and back.

    Indexes 0-2 are reserved for the PAD, SOS and EOS tokens; real words are
    numbered from 3 upwards in the order they are first added.

    Attributes are plain dicts/ints:
        name: label given at construction.
        trimmed: True once `trim` has run.
        word2index: word -> index.
        word2count: word -> number of times the word was added.
        index2word: index -> word (includes the three reserved tokens).
        num_words: total number of entries, reserved tokens included.
    """

    def __init__(self, name):
        self.name = name
        self.trimmed = False
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        """Add every single-space-separated word of `sentence` to the vocabulary."""
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Register `word`: assign the next free index if new, else bump its count."""
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    # Remove words below a certain count threshold
    def trim(self, min_count):
        """Drop words seen fewer than `min_count` times and re-index the rest.

        Runs at most once per instance; later calls return immediately.
        Kept words are re-added through `addWord`, so their indexes are
        reassigned from 3 upwards and their counts restart at 1.
        Prints how many words were kept.
        """
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = [word for word, count in self.word2count.items() if count >= min_count]

        total_words = len(self.word2index)
        # An empty vocabulary would otherwise divide by zero; report 0.0 instead.
        keep_ratio = len(keep_words) / total_words if total_words else 0.0
        print('keep_words {} / {} = {:.4f}'.format(
            len(keep_words), total_words, keep_ratio
        ))

        # Reinitialize dictionaries
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3 # Count default tokens

        for word in keep_words:
            self.addWord(word)

In [14]:
MAX_LENGTH = 10  # Maximum sentence length to consider

# Strip accents from a Unicode string, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Return `s` NFD-decomposed with all 'Mn' (nonspacing mark) characters removed.

    Accented letters therefore lose their accents; characters that do not
    decompose into a base letter plus marks are left unchanged.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)

# Lowercase, strip accents, and reduce the text to letters plus . ! ?
def normalizeString(s):
    """Normalize a sentence for vocabulary building.

    Steps, in order: lowercase and strip, remove accents via
    `unicodeToAscii`, put a space before each of ``. ! ?``, replace every run
    of other non-letter characters with one space, collapse whitespace, and
    strip the result.
    """
    text = unicodeToAscii(s.lower().strip())
    substitutions = (
        (r"([.!?])", r" \1"),        # detach sentence punctuation from the word
        (r"[^a-zA-Z.!?]+", r" "),    # anything else that is not a letter becomes a space
        (r"\s+", r" "),              # collapse whitespace runs
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()

# Read query/response pairs and return a voc object
def readVocs(datafile, corpus_name):
    """Load the tab-separated pairs file and normalize every sentence.

    Args:
        datafile: Path of a UTF-8 text file with one tab-separated pair per line.
        corpus_name: Name given to the new `Voc`.

    Returns:
        (voc, pairs): an empty `Voc` named `corpus_name`, and a list with one
        entry per file line, each a list of the normalized tab-separated parts.
    """
    print("Reading lines...")
    # Use a context manager so the file handle is closed deterministically
    # (previously the handle was opened and never closed explicitly).
    with open(datafile, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
    voc = Voc(corpus_name)
    return voc, pairs

# True only when both sentences of pair 'p' have fewer than MAX_LENGTH words
def filterPair(p):
    """Check that p[0] and p[1] each split into fewer than MAX_LENGTH words.

    Words are counted by splitting on single spaces. The limit is strict,
    leaving room for the EOS token appended later.
    """
    if len(p[0].split(' ')) >= MAX_LENGTH:
        return False
    return len(p[1].split(' ')) < MAX_LENGTH

# Keep only the pairs accepted by filterPair
def filterPairs(pairs):
    """Return a new list with the pairs for which `filterPair` is true, order preserved."""
    return list(filter(filterPair, pairs))

# Build the vocabulary and the length-filtered pair list from the formatted data file
def loadPrepareData(corpus, corpus_name, datafile, save_dir):
    """Read, filter and index the sentence pairs.

    Args:
        corpus: Not used by this function.
        corpus_name: Name passed on to `readVocs` for the vocabulary.
        datafile: Path of the tab-separated pairs file.
        save_dir: Not used by this function.

    Returns:
        (voc, pairs): the vocabulary populated with every word of the kept
        pairs, and the pairs that passed `filterPairs`.
    """
    print("Start preparing training data ...")
    voc, all_pairs = readVocs(datafile, corpus_name)
    print("Read {!s} sentence pairs".format(len(all_pairs)))
    short_pairs = filterPairs(all_pairs)
    print("Trimmed to {!s} sentence pairs".format(len(short_pairs)))
    print("Counting words...")
    for kept in short_pairs:
        # Both sides of the pair contribute words to the shared vocabulary
        voc.addSentence(kept[0])
        voc.addSentence(kept[1])
    print("Counted words:", voc.num_words)
    return voc, short_pairs


# Load/Assemble voc and pairs
# NOTE: save_dir is passed through to loadPrepareData, which does not use it.
save_dir = os.path.join("data", "save")
voc, pairs = loadPrepareData(corpus, corpus_name, datafile, save_dir)
# Print some pairs to validate (first 10 normalized [input, target] pairs)
print("\npairs:")
for pair in pairs[:10]:
    print(pair)

Start preparing training data ...
Reading lines...
Read 221282 sentence pairs
Trimmed to 64271 sentence pairs
Counting words...
Counted words: 18008

pairs:
['there .', 'where ?']
['you have my word . as a gentleman', 'you re sweet .']
['hi .', 'looks like things worked out tonight huh ?']
['you know chastity ?', 'i believe we share an art instructor']
['have fun tonight ?', 'tons']
['well no . . .', 'then that s all you had to say .']
['then that s all you had to say .', 'but']
['but', 'you always been this selfish ?']
['do you listen to this crap ?', 'what crap ?']
['what good stuff ?', 'the real you .']


In [15]:
MIN_COUNT = 3    # Minimum word count threshold for trimming

def trimRareWords(voc, pairs, MIN_COUNT):
    """Trim rare words from `voc` and drop every pair that uses one.

    Args:
        voc: Vocabulary object; `voc.trim(MIN_COUNT)` is called on it (so it
            is modified) and `voc.word2index` is then used for membership tests.
        pairs: list of sentence pairs; element 0 is the input sentence and
            element 1 the output sentence, both single-space-separated strings.
        MIN_COUNT: Minimum word count threshold forwarded to `voc.trim`.

    Returns:
        A new list with the pairs whose input and output sentences contain
        only words still present in `voc.word2index`. `pairs` is not modified.
    """
    # Trim words used under the MIN_COUNT from the voc
    voc.trim(MIN_COUNT)

    def _all_known(sentence):
        # True when every word of the sentence survived trimming
        return all(word in voc.word2index for word in sentence.split(' '))

    # Only keep pairs that do not contain trimmed word(s) in their input or output sentence
    keep_pairs = [pair for pair in pairs if _all_known(pair[0]) and _all_known(pair[1])]

    # Guard against an empty `pairs` list, which used to raise ZeroDivisionError
    kept_fraction = len(keep_pairs) / len(pairs) if len(pairs) else 0.0
    print("Trimmed from {} pairs to {}, {:.4f} of total".format(len(pairs), len(keep_pairs), kept_fraction))
    return keep_pairs


# Trim voc and pairs
# (voc is trimmed in place; `pairs` is rebound to the filtered list)
pairs = trimRareWords(voc, pairs, MIN_COUNT)

keep_words 7823 / 18005 = 0.4345
Trimmed from 64271 pairs to 53165, 0.8272 of total


In [16]:
def indexesFromSentence(voc, sentence):
    """Convert a single-space-separated sentence to word indexes, ending with EOS_token.

    Raises KeyError if a word is not in `voc.word2index`.
    """
    token_ids = [voc.word2index[word] for word in sentence.split(' ')]
    token_ids.append(EOS_token)
    return token_ids


def zeroPadding(l, fillvalue=PAD_token):
    """Transpose a batch of sequences, padding shorter ones with `fillvalue`.

    Given sequences of different lengths, returns a list of tuples where
    tuple t holds the t-th element of every sequence (or `fillvalue` once a
    sequence has run out), i.e. a (max_len x batch) layout.
    """
    padded_columns = itertools.zip_longest(*l, fillvalue=fillvalue)
    return list(padded_columns)

def binaryMatrix(l, value=PAD_token):
    """Build a 0/1 mask with the same shape as the nested sequence `l`.

    Args:
        l: iterable of token sequences.
        value: token treated as padding (defaults to PAD_token). The argument
            was previously ignored and PAD_token was always used; it is now
            honoured, and the default behaviour is unchanged.

    Returns:
        list of lists holding 0 where the token equals `value` and 1 elsewhere.
    """
    return [[0 if token == value else 1 for token in seq] for seq in l]

# Builds the padded input tensor for a batch and the per-sentence lengths
def inputVar(l, voc):
    """Encode a batch of input sentences.

    Returns:
        (padVar, lengths): `padVar` is a LongTensor built from the
        `zeroPadding` layout (one row per time step, one column per
        sentence); `lengths` is a tensor with each sentence's encoded length,
        EOS token included.
    """
    encoded = []
    seq_lengths = []
    for sentence in l:
        ids = indexesFromSentence(voc, sentence)
        encoded.append(ids)
        seq_lengths.append(len(ids))
    lengths = torch.tensor(seq_lengths)
    padVar = torch.LongTensor(zeroPadding(encoded))
    return padVar, lengths

# Builds the padded target tensor, its padding mask, and the longest target length
def outputVar(l, voc):
    """Encode a batch of target sentences.

    Returns:
        (padVar, mask, max_target_len): `padVar` is a LongTensor in the
        `zeroPadding` layout (time step x sentence), `mask` is a BoolTensor of
        the same shape built by `binaryMatrix` (False at padding), and
        `max_target_len` is the longest encoded length, EOS token included.
    """
    encoded = [indexesFromSentence(voc, sentence) for sentence in l]
    max_target_len = max(map(len, encoded))
    padded = zeroPadding(encoded)
    mask = torch.BoolTensor(binaryMatrix(padded))
    padVar = torch.LongTensor(padded)
    return padVar, mask, max_target_len

# Returns all items for a given batch of pairs
def batch2TrainData(voc, pair_batch):
    """Turn a batch of sentence pairs into padded training tensors.

    Pairs are ordered by the word count of their input sentence, longest
    first, before encoding.

    Args:
        voc: Vocabulary used to index the words.
        pair_batch: iterable of pairs; element 0 is the input sentence and
            element 1 the target sentence. It is no longer sorted in place:
            a sorted copy is used so the caller's list keeps its order.

    Returns:
        (inp, lengths, output, mask, max_target_len) as produced by
        `inputVar` and `outputVar`.
    """
    ordered = sorted(pair_batch, key=lambda pair: len(pair[0].split(" ")), reverse=True)
    input_batch = [pair[0] for pair in ordered]
    output_batch = [pair[1] for pair in ordered]
    inp, lengths = inputVar(input_batch, voc)
    output, mask, max_target_len = outputVar(output_batch, voc)
    return inp, lengths, output, mask, max_target_len


# Example for validation: build one small random batch and print its tensors
# (pairs are drawn with random.choice, so the printed values change between runs)
small_batch_size = 5
batches = batch2TrainData(voc, [random.choice(pairs) for _ in range(small_batch_size)])
input_variable, lengths, target_variable, mask, max_target_len = batches

print("input_variable:", input_variable)
print("lengths:", lengths)
print("target_variable:", target_variable)
print("mask:", mask)
print("max_target_len:", max_target_len)

input_variable: tensor([[  68,   38,   25,   25,   23],
        [   7,   76,  348,  718,    6],
        [ 218,  115,   64,  587,    2],
        [ 371,  115,    4,    4,    0],
        [  21, 6326,    2,    2,    0],
        [  56,  122,    0,    0,    0],
        [ 960, 7173,    0,    0,    0],
        [2484,    4,    0,    0,    0],
        [   6,    2,    0,    0,    0],
        [   2,    0,    0,    0,    0]])
lengths: tensor([10,  9,  5,  5,  3])
target_variable: tensor([[ 124,   25,    5,   50,  101],
        [  34,  200,   37,  326,  604],
        [ 572, 3818, 1550,    6,    7],
        [  66,   40,    6,    2,   53],
        [   2, 6326,    2,    0,  301],
        [   0,    4,    0,    0,  227],
        [   0,    2,    0,    0,    4],
        [   0,    0,    0,    0,    2]])
mask: tensor([[ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True],
        [ True,  True

## Model

In [17]:
class Encoder(nn.Module):
    """Transformer-style encoder: token + position embeddings, then a stack of EncoderLayer blocks.

    Args:
        vocab_size: Not used by this class (the embedding is passed in ready-made).
        embedding: Token embedding module mapping token ids to `embed_dims` vectors.
        embed_dims: Embedding / hidden size used throughout the encoder.
        pf_dims: Inner size of the point-wise projection in each layer.
        n_layers: Number of EncoderLayer blocks.
        n_heads: Number of attention heads per layer.
        dropout: Dropout probability.
        max_length: Number of positions the position embedding supports.
        device: Stored on the instance and forwarded to the layers. The
            encoder itself no longer pins tensors to it: tensors created in
            `forward` follow the input's device.
    """

    def __init__(self, vocab_size,
                 embedding,
                 embed_dims,
                 pf_dims,
                 n_layers,
                 n_heads,
                 dropout = 0.1,
                 max_length = 100,
                 device = 'cuda'
                 ):
        super().__init__()
        # save the parameters
        self.embed_dims = embed_dims
        self.device = device
        # token embedding (supplied by the caller)
        self.tok_embedder = embedding
        # Scaling factor for the token embedding. Kept as a plain Python float
        # rather than a tensor pinned to `device`, so construction does not
        # need CUDA and the module works on whatever device its inputs are on.
        self.scale = math.sqrt(embed_dims)
        # position embedding: embeds the token position into embed dims
        self.pos_embedder = nn.Embedding(max_length, embed_dims)
        # dropout applied to the combined embedding
        self.dropout = nn.Dropout(dropout)
        # stack of attention layers; hidden size equals the embedding size
        self.attn_layers = nn.ModuleList()
        for _ in range(n_layers):
            self.attn_layers.append(EncoderLayer(embed_dims, n_heads, dropout, pf_dims, device=device))

    def forward(self, src, src_mask):
        """Encode a batch of token ids.

        Args:
            src: token ids, [batch size, srclen].
            src_mask: mask forwarded unchanged to every layer.

        Returns:
            Tensor [batch size, srclen, embed_dims].
        """
        batch_size = src.shape[0]
        srclen = src.shape[1]
        # embed the source tokens and scale them
        tok_embedded = self.tok_embedder(src) * self.scale
        # position indexes 0..srclen-1 for every sequence, created on the input's device
        positions = torch.arange(0, srclen, device=src.device).unsqueeze(0).repeat(batch_size, 1)
        pos_embedded = self.pos_embedder(positions)
        # add the position embedding to the token embedding and apply dropout
        x = self.dropout(tok_embedded + pos_embedded)
        # pass the embedded values through the encoder layers
        for lyr in self.attn_layers:
            x = lyr(x, src_mask)
        return x


In [19]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and a point-wise projection.

    Each sub-layer is followed by dropout, a residual connection, and
    layer normalization.
    """

    def __init__(self,
                 hid_dims,
                 n_heads,
                 dropout,
                 pf_dims,
                 device = 'cuda'):
        super().__init__()
        # self-attention sub-layer and the norm applied after it
        self.attnlayer = MultiHeadAttentionLayer(hid_dims, n_heads, dropout, device = device)
        self.postattnnormlayer = nn.LayerNorm(hid_dims)
        # point-wise projection sub-layer and the norm applied after it
        self.pointwiselayer = PointWiseProjectionLayer(hid_dims, pf_dims, dropout)
        self.postpointwisenormlayer = nn.LayerNorm(hid_dims)
        # dropout shared by both residual branches
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, src_mask=None):
        """Apply the block to `src` ([batch_size, srclen, embed_dims]) and return the same shape."""
        # self-attention: queries, keys and values all come from src
        attended, _ = self.attnlayer(src, src, src, mask = src_mask)
        # residual connection + normalization
        normed = self.postattnnormlayer(src + self.dropout(attended))
        # point-wise projection, again with residual + normalization
        projected = self.pointwiselayer(normed)
        return self.postpointwisenormlayer(normed + self.dropout(projected))

In [21]:
class PointWiseProjectionLayer(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    Expands the last dimension from `hid_dims` to `out_dims` and projects it
    back to `hid_dims`.
    """

    def __init__(self, hid_dims, out_dims, dropout):
        super().__init__()
        self.init_prj = nn.Linear(hid_dims, out_dims)
        self.out_prj = nn.Linear(out_dims, hid_dims)
        self.dropout = nn.Dropout(dropout)

    def forward(self, lin_inp):
        """Map [batch_size, srclen, hid_dims] to a tensor of the same shape."""
        expanded = self.init_prj(lin_inp)
        # expanded: [batch_size, srclen, out_dims]
        activated = self.dropout(torch.relu(expanded))
        return self.out_prj(activated)


In [22]:
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        hid_dims: Size of the query/key/value vectors (the embedding size
            used in the encoder/decoder). Must be divisible by `n_heads`.
        n_heads: Number of attention heads.
        dropout: Dropout probability applied to the attended values.
        device: Stored on the instance. The layer no longer pins any tensor to
            it, so it can be built without CUDA and follows its inputs' device.

    Raises:
        ValueError: if `hid_dims` is not divisible by `n_heads`.
    """

    def __init__(self,
                 hid_dims,
                 n_heads,
                 dropout,
                 device = 'cuda'):
        super().__init__()
        if hid_dims % n_heads != 0:
            # Otherwise the per-head reshape in forward() fails or silently
            # mixes features across heads.
            raise ValueError(
                "hid_dims ({}) must be divisible by n_heads ({})".format(hid_dims, n_heads))
        self.n_heads = n_heads
        self.head_dims = hid_dims//n_heads
        self.hid_dims = hid_dims
        self.device = device
        # projection layers for the query, key and value
        self.q_prj = nn.Linear(hid_dims, hid_dims)
        self.k_prj = nn.Linear(hid_dims, hid_dims)
        self.v_prj = nn.Linear(hid_dims, hid_dims)
        # Scale applied to the raw attention scores for stabilization. A plain
        # float instead of a tensor pinned to `device`, so it works on any device.
        self.scale = math.sqrt(self.head_dims)
        # dropout for attention
        self.dropout = nn.Dropout(dropout)
        # output projection layer
        self.out_lyr = nn.Linear(hid_dims, hid_dims)

    def _split_heads(self, x, batch_size):
        # [batch_size, len, hid_dims] -> [batch_size, n_heads, len, head_dims]
        return x.view(batch_size, -1, self.n_heads, self.head_dims).permute(0, 2, 1, 3)

    def forward(self, query, key, value, mask=None):
        """Attend from `query` positions over `key`/`value` positions.

        Args:
            query: [batch size, query len, hid_dims]
            key: [batch size, key len, hid_dims]
            value: [batch size, key len, hid_dims]
            mask: optional tensor broadcastable to
                [batch size, n_heads, query len, key len]; positions where it
                equals 0 receive a score of -1e10 before the softmax.

        Returns:
            (out, attention): `out` is [batch size, query len, hid_dims];
            `attention` is the softmax weights
            [batch size, n_heads, query len, key len].
        """
        batch_size = query.shape[0]
        # project and split into heads
        Q = self._split_heads(self.q_prj(query), batch_size)
        K = self._split_heads(self.k_prj(key), batch_size)
        V = self._split_heads(self.v_prj(value), batch_size)

        # raw attention scores: [batch_size, n_heads, query len, key len]
        energy = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
        # very negative score at masked locations so softmax gives them ~0 weight
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(energy, dim = -1)
        # weighted sum of values (dropout is applied to the result, as before)
        value = self.dropout(torch.matmul(attention, V))
        # merge heads back: [batch_size, query len, hid_dims]
        value = value.permute(0, 2, 1, 3).contiguous()
        value = value.view(batch_size, -1, self.hid_dims)
        # apply the output projection layer
        out = self.out_lyr(value)
        return out, attention

In [23]:
class Decoder(nn.Module):
    """Transformer-style decoder producing a probability distribution over the vocabulary.

    Args:
        vocab_size: Size of the output distribution.
        embedding: Token embedding module mapping token ids to `embed_dims` vectors.
        embed_dims: Embedding / hidden size used throughout the decoder.
        pf_dims: Inner size of the point-wise projection in each layer.
        n_layers: Number of DecoderLayer blocks.
        n_heads: Number of attention heads per layer.
        dropout: Dropout probability.
        device: Stored on the instance and forwarded to the layers. The
            decoder itself no longer pins tensors to it: tensors created in
            `forward` follow the input's device.
        max_length: Number of positions the position embedding supports.
    """

    def __init__(self,
                 vocab_size,
                 embedding,
                 embed_dims,
                 pf_dims,
                 n_layers,
                 n_heads,
                 dropout=0.1,
                 device = 'cuda',
                 max_length = 100):
        super().__init__()
        self.hid_dims = embed_dims
        self.device = device
        # Plain float scale (not a tensor pinned to `device`) so the module
        # can be built without CUDA and used on any device.
        self.scale = math.sqrt(embed_dims)
        # token embedding (supplied by the caller) and position embedding
        self.tok_embedder = embedding
        self.pos_embedder = nn.Embedding(max_length, embed_dims)
        hid_dims = embed_dims
        # dropout on the combined embedding
        self.dropout = nn.Dropout(dropout)
        # decoder layers
        self.decoder_layers = nn.ModuleList()
        for _ in range(n_layers):
            self.decoder_layers.append(DecoderLayer(hid_dims,pf_dims,n_heads,dropout, device = device))
        # final output layer projecting to the vocabulary
        self.out_lyr = nn.Linear(hid_dims, vocab_size)

    def forward(self, trg, encoder_outputs, trg_mask, src_mask):
        """Decode a batch of target prefixes.

        Args:
            trg: target token ids, [batch_size, trglen].
            encoder_outputs: encoder output forwarded to every layer.
            trg_mask: mask for the decoder self-attention.
            src_mask: mask for the encoder-decoder attention.

        Returns:
            (probs, attention): `probs` is the softmax over the vocabulary,
            [batch_size, trglen, vocab_size]; `attention` is the
            encoder-decoder attention of the last layer, or None when the
            decoder has no layers.
        """
        batch_size = trg.shape[0]
        trglen = trg.shape[1]
        # embed and scale the target tokens: [batch_size, trglen, embed_dims]
        tok_embedded = self.tok_embedder(trg) * self.scale
        # position indexes 0..trglen-1 for every sequence, created on the input's device
        positions = torch.arange(0, trglen, device=trg.device).unsqueeze(0).repeat(batch_size, 1)
        pos_embedded = self.pos_embedder(positions)
        # combine both embeddings
        trg = self.dropout(tok_embedded + pos_embedded)
        # pass through the decoder blocks one by one; `attention` stays None
        # for a layer-less decoder instead of being an unbound name
        attention = None
        for lyr in self.decoder_layers:
            trg, attention = lyr(trg, encoder_outputs, trg_mask, src_mask)
        # trg: [batch_size, trglen, hid_dims]
        out = self.out_lyr(trg)
        return F.softmax(out, dim=-1), attention


In [24]:
class DecoderLayer(nn.Module):
    """One decoder block: self-attention, encoder attention, point-wise projection.

    Every sub-layer is followed by dropout, a residual connection, and
    layer normalization.
    """

    def __init__(self,
                 hid_dims,
                 pf_dims,
                 n_heads,
                 dropout,
                 device = 'cuda'):
        super().__init__()
        self.hid_dims = hid_dims
        self.device = device
        # dropout shared by the three residual branches
        self.dropout = nn.Dropout(dropout)
        # attention over the target sequence itself
        self.self_attn_lyr = MultiHeadAttentionLayer(hid_dims, n_heads, dropout, device = device)
        # attention from the target over the encoder outputs
        self.enc_attn_lyr = MultiHeadAttentionLayer(hid_dims, n_heads, dropout, device = device)
        # point-wise projection
        self.pf_lyr = PointWiseProjectionLayer(hid_dims, pf_dims, dropout)
        # one normalization layer per sub-layer
        self.postselfnormlyr = nn.LayerNorm(hid_dims)
        self.postencnormlyr = nn.LayerNorm(hid_dims)
        self.postpfnormlyr = nn.LayerNorm(hid_dims)

    def forward(self, trg, encoder_outputs, trg_mask, src_mask):
        """Apply the block to `trg` ([batch_size, trglen, hid_dims]).

        Returns:
            (out, attention): the transformed target and the attention
            weights of the encoder-attention sub-layer.
        """
        # 1) masked self-attention, residual, norm
        self_attended, _ = self.self_attn_lyr(trg, trg, trg, trg_mask)
        after_self = self.postselfnormlyr(trg + self.dropout(self_attended))
        # 2) attention over the encoder outputs, residual, norm
        enc_attended, attention = self.enc_attn_lyr(after_self, encoder_outputs, encoder_outputs, src_mask)
        after_enc = self.postencnormlyr(after_self + self.dropout(enc_attended))
        # 3) point-wise projection, residual, norm
        projected = self.pf_lyr(after_enc)
        out = self.postpfnormlyr(after_enc + self.dropout(projected))
        return out, attention



In [25]:
def create_src_mask(src, src_pad_idx=PAD_token, device='cuda'):
    """Build the padding mask for a batch of source token ids.

    Returns a boolean tensor that is True where `src` differs from
    `src_pad_idx`, with two singleton dimensions inserted after the batch
    dimension (for [batch, srclen] input: [batch, 1, 1, srclen]).
    `device` is accepted but not used.
    """
    not_pad = src != src_pad_idx
    return not_pad[:, None, None]

def create_trg_mask(trg, trg_pad_idx=PAD_token,device='cuda'):
    """Build the combined padding + causal mask for a batch of target token ids.

    Args:
        trg: target token ids, [batch_size, trglen].
        trg_pad_idx: id of the padding token.
        device: kept for backward compatibility but no longer used. The causal
            mask is created on `trg`'s own device, because it is AND-ed with a
            mask derived from `trg` and both must live on the same device
            (the old default 'cuda' failed on CPU-only machines).

    Returns:
        Boolean tensor [batch_size, 1, trglen, trglen]: entry (i, j) is True
        when j <= i and token j is not padding.
    """
    # pad_mask: [batch_size, 1, 1, trglen]
    pad_mask = (trg != trg_pad_idx).unsqueeze(1).unsqueeze(2)
    trglen = trg.shape[1]
    # tril_mask: [trglen, trglen], lower-triangular
    tril_mask = torch.tril(torch.ones(trglen, trglen, device=trg.device)).bool()
    # broadcast to [batch_size, 1, trglen, trglen]
    return pad_mask & tril_mask


## Training Utils

In [26]:
def maskNLLLoss(inp, target, mask):
    """Average negative log-likelihood over the unmasked batch entries.

    Args:
        inp: probabilities (not logits), indexed along dim 1 by token id;
            `train` passes the decoder's softmax output for one time step.
        target: token ids, one per row of `inp`.
        mask: boolean tensor selecting which rows count towards the loss.

    Returns:
        (loss, nTotal): the mean of -log(p[target]) over the selected rows,
        moved to the module-level `device`, and the number of selected rows
        as a Python number.
    """
    nTotal = mask.sum()
    target_probs = torch.gather(inp, 1, target.view(-1, 1)).squeeze(1)
    # Clamp away from zero: a softmax probability that underflows to 0 would
    # otherwise give log(0) = -inf and poison the loss and gradients.
    crossEntropy = -torch.log(target_probs.clamp_min(1e-12))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()

In [27]:
def train(input_variable, lengths, target_variable, mask, max_target_len, encoder, decoder, embedding,
          encoder_optimizer, decoder_optimizer, batch_size, clip, max_length=MAX_LENGTH):
    """Run one optimization step on a single batch and return its average loss.

    Reads the module-level names `device`, `SOS_token` and
    `teacher_forcing_ratio` (the latter must be defined before calling).

    Args:
        input_variable: input token ids in (srclen, batch) layout.
        lengths: not used by this function.
        target_variable: target token ids in (trglen, batch) layout.
        mask: boolean target padding mask in (trglen, batch) layout.
        max_target_len: number of decoding steps to run.
        encoder, decoder: the models to train.
        embedding: not used by this function.
        encoder_optimizer, decoder_optimizer: optimizers stepped once each.
        batch_size: number of sequences in the batch.
        clip: maximum gradient norm for clipping.
        max_length: not used by this function.

    Returns:
        Loss averaged over all unmasked target tokens of the batch.
    """
    # Zero gradients
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Switch to batch-first layout and move to the active device
    input_variable = input_variable.transpose(0,1).to(device)
    target_variable = target_variable.transpose(0,1).to(device)
    mask = mask.transpose(0,1).to(device)

    # Initialize variables
    loss = 0
    print_losses = []
    n_totals = 0

    # Forward pass through encoder
    src_mask = create_src_mask(input_variable)
    encoder_outputs = encoder(input_variable, src_mask)

    # Create initial decoder input (start with SOS tokens for each sentence)
    sos_input = torch.LongTensor([SOS_token]*batch_size).unsqueeze(1).to(device)
    decoder_input = sos_input
    # Determine if we are using teacher forcing this iteration
    use_teacher_forcing = random.random() < teacher_forcing_ratio

    # Forward batch of sequences through decoder one time step at a time.
    # At step t the decoder sees [SOS, tok_0, ..., tok_{t-1}] and its last
    # position predicts token t.
    for t in range(max_target_len):
        # Pass the active device explicitly: the helper's default is 'cuda',
        # which breaks on CPU-only machines.
        trg_mask = create_trg_mask(decoder_input, device=device)
        decoder_output, _ = decoder(
            decoder_input, encoder_outputs, trg_mask, src_mask
        )
        step_probs = decoder_output[:, -1, :]
        if use_teacher_forcing:
            # Teacher forcing: next input token is the current target
            next_token = target_variable[:, t:t+1]
        else:
            # No teacher forcing: next input token is the decoder's own most
            # likely prediction for this step. (The old branch concatenated a
            # 1-D target slice onto the 2-D input, which raised, and never fed
            # the prediction back.)
            next_token = step_probs.argmax(dim=-1, keepdim=True).detach()
        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        # Calculate and accumulate loss
        mask_loss, nTotal = maskNLLLoss(step_probs, target_variable[:, t], mask[:, t])
        loss += mask_loss
        print_losses.append(mask_loss.item() * nTotal)
        n_totals += nTotal

    # Perform backpropagation
    loss.backward()

    # Clip gradients: gradients are modified in place
    _ = nn.utils.clip_grad_norm_(encoder.parameters(), clip)
    _ = nn.utils.clip_grad_norm_(decoder.parameters(), clip)

    # Adjust model weights
    encoder_optimizer.step()
    decoder_optimizer.step()

    return sum(print_losses) / n_totals

In [28]:
def trainIters(model_name, voc, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, embedding, 
               encoder_n_layers, decoder_n_layers, save_dir, n_iteration, batch_size, print_every, save_every, clip, corpus_name, loadFilename):
    """Run the training loop, printing the running loss and saving checkpoints.

    One batch per iteration is sampled (with replacement) from ``pairs`` up
    front, then each batch is passed to ``train``. Every ``print_every``
    iterations the average of the losses returned by ``train`` is printed, and
    every ``save_every`` iterations a checkpoint is written to
    ``<save_dir>/<model_name>/<corpus_name>/<enc_layers>-<dec_layers>_<HID_DIM>/``.

    Args:
        model_name, corpus_name, save_dir: Path components of the checkpoint
            directory.
        voc: Vocabulary handed to ``batch2TrainData``; its ``__dict__`` is
            stored in every checkpoint.
        pairs: Sequence of training pairs to sample batches from.
        encoder, decoder, embedding: Modules whose ``state_dict()`` is saved.
        encoder_optimizer, decoder_optimizer: Optimizers forwarded to
            ``train`` and saved alongside the models.
        encoder_n_layers, decoder_n_layers: Only used to name the checkpoint
            directory.
        n_iteration: Index of the last iteration to run (1-based, inclusive).
        batch_size: Number of pairs drawn per batch.
        print_every, save_every: Reporting / checkpointing periods.
        clip: Gradient clipping threshold forwarded to ``train``.
        loadFilename: Path of a checkpoint to resume from, or a falsy value to
            start at iteration 1.

    Returns:
        None.
    """
    # Load batches for each iteration
    training_batches = [
        batch2TrainData(voc, [random.choice(pairs) for _ in range(batch_size)])
        for _ in range(n_iteration)
    ]

    # Initializations
    print('Initializing ...')
    start_iteration = 1
    print_loss = 0
    if loadFilename:
        # Read the resume point from the checkpoint file itself instead of
        # relying on a `checkpoint` variable existing in the notebook's global
        # namespace. Only the iteration counter is needed here, so tensors are
        # mapped to CPU to avoid touching accelerator memory.
        checkpoint = torch.load(loadFilename, map_location='cpu')
        start_iteration = checkpoint['iteration'] + 1

    # Training loop
    print("Training...")
    # When resuming, fewer than n_iteration steps remain; report that number
    # so the progress bar can actually reach 100%.
    remaining_iterations = max(n_iteration - start_iteration + 1, 0)
    for iteration in progress_bar(range(start_iteration, n_iteration + 1), total=remaining_iterations):
        training_batch = training_batches[iteration - 1]
        # Extract fields from batch
        input_variable, lengths, target_variable, mask, max_target_len = training_batch

        # Run a training iteration with batch
        loss = train(input_variable, lengths, target_variable, mask, max_target_len, encoder,
                     decoder, embedding, encoder_optimizer, decoder_optimizer, batch_size, clip)
        print_loss += loss

        # Print progress
        if iteration % print_every == 0:
            print_loss_avg = print_loss / print_every
            print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg))
            print_loss = 0

        # Save checkpoint
        if iteration % save_every == 0:
            directory = os.path.join(save_dir, model_name, corpus_name, '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, HID_DIM))
            # exist_ok avoids the check-then-create race of exists() + makedirs()
            os.makedirs(directory, exist_ok=True)
            torch.save({
                'iteration': iteration,
                'en': encoder.state_dict(),
                'de': decoder.state_dict(),
                'en_opt': encoder_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc_dict': voc.__dict__,
                'embedding': embedding.state_dict()
            }, os.path.join(directory, '{}_{}.tar'.format(iteration, 'checkpoint')))

## Building Model

In [29]:
# Configure models
model_name = 'attn_model'
dropout = 0.0
batch_size = 128
HID_DIM = 256
encoder_n_layers = 3
decoder_n_layers = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = dropout
DEC_DROPOUT = dropout
SRC_PAD_IDX = PAD_token
TRG_PAD_IDX = PAD_token

# Set checkpoint to load from; set to None if starting from scratch
loadFilename = None
checkpoint_iter = 4000
# Example (the directory layout mirrors the one trainIters writes to):
# loadFilename = os.path.join(save_dir, model_name, corpus_name,
#                             '{}-{}_{}'.format(encoder_n_layers, decoder_n_layers, HID_DIM),
#                             '{}_checkpoint.tar'.format(checkpoint_iter))

# Load model if a loadFilename is provided
if loadFilename:
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine
    checkpoint = torch.load(loadFilename, map_location=device)
    # Keys match the dictionary saved by trainIters
    encoder_sd = checkpoint['en']
    decoder_sd = checkpoint['de']
    embedding_sd = checkpoint['embedding']
    # Restore the vocabulary first: the layer sizes below depend on voc.num_words
    voc.__dict__ = checkpoint['voc_dict']

ENC_VOCAB_SIZE = voc.num_words
DEC_VOCAB_SIZE = voc.num_words

print('Building encoder and decoder ...')
# Word embedding layer handed to both the encoder and the decoder
embedding = nn.Embedding(voc.num_words, HID_DIM).to(device)

encoder = Encoder(ENC_VOCAB_SIZE, 
              embedding,
              HID_DIM,
              ENC_PF_DIM,
              encoder_n_layers, 
              ENC_HEADS,
              dropout = ENC_DROPOUT, 
              device = device).to(device)

decoder = Decoder(DEC_VOCAB_SIZE, 
              embedding,
              HID_DIM, 
              DEC_PF_DIM,
              decoder_n_layers,
              DEC_HEADS,
              dropout = DEC_DROPOUT, 
              device = device).to(device)

# Restore trained weights when resuming from a checkpoint
if loadFilename:
    embedding.load_state_dict(embedding_sd)
    encoder.load_state_dict(encoder_sd)
    decoder.load_state_dict(decoder_sd)

print('Models built and ready to go!')

Building encoder and decoder ...
Models built and ready to go!


## Training Model

In [30]:
# Training hyperparameters
clip = 1.0                      # gradient clipping threshold passed to trainIters
teacher_forcing_ratio = 1.0
learning_rate = 0.0001
decoder_learning_ratio = 5.0    # decoder learning rate = learning_rate * this factor
n_iteration = 10000
print_every = 100
save_every = 500

# Switch both networks to training mode before optimizing
encoder.train()
decoder.train()

# One Adam optimizer per network; the decoder uses a scaled learning rate
print('Building optimizers ...')
encoder_optimizer = optim.Adam(params=encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(
    params=decoder.parameters(),
    lr=learning_rate * decoder_learning_ratio,
)

# Kick off the training loop
print("Starting Training!")
trainIters(
    model_name=model_name,
    voc=voc,
    pairs=pairs,
    encoder=encoder,
    decoder=decoder,
    encoder_optimizer=encoder_optimizer,
    decoder_optimizer=decoder_optimizer,
    embedding=embedding,
    encoder_n_layers=encoder_n_layers,
    decoder_n_layers=decoder_n_layers,
    save_dir=save_dir,
    n_iteration=n_iteration,
    batch_size=batch_size,
    print_every=print_every,
    save_every=save_every,
    clip=clip,
    corpus_name=corpus_name,
    loadFilename=loadFilename,
)

Building optimizers ...
Starting Training!
Initializing ...
Training...


Iteration: 100; Percent complete: 1.0%; Average loss: 4.7390
Iteration: 200; Percent complete: 2.0%; Average loss: 3.9048
Iteration: 300; Percent complete: 3.0%; Average loss: 3.7177
Iteration: 400; Percent complete: 4.0%; Average loss: 3.6230
Iteration: 500; Percent complete: 5.0%; Average loss: 3.5028
Iteration: 600; Percent complete: 6.0%; Average loss: 3.4353
Iteration: 700; Percent complete: 7.0%; Average loss: 3.3620
Iteration: 800; Percent complete: 8.0%; Average loss: 3.3011
Iteration: 900; Percent complete: 9.0%; Average loss: 3.2455
Iteration: 1000; Percent complete: 10.0%; Average loss: 3.2157
Iteration: 1100; Percent complete: 11.0%; Average loss: 3.1525
Iteration: 1200; Percent complete: 12.0%; Average loss: 3.0729
Iteration: 1300; Percent complete: 13.0%; Average loss: 3.0360
Iteration: 1400; Percent complete: 14.0%; Average loss: 2.9803
Iteration: 1500; Percent complete: 15.0%; Average loss: 2.9339
Iteration: 1600; Percent complete: 16.0%; Average loss: 2.8932
Iteration:

## Evaluate model

In [31]:
def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
    """Decode a reply for one sentence and return it as a list of words.

    Args:
        encoder, decoder: Accepted for call-site symmetry; not used here.
        searcher: Callable invoked as ``searcher(input_batch, lengths, max_length)``
            that returns a pair whose first item is an iterable of token indexes.
        voc: Vocabulary passed to ``indexesFromSentence`` and providing
            ``index2word``.
        sentence: Input sentence (already normalized by the caller).
        max_length: Forwarded to ``searcher`` as the decoding length limit.

    Returns:
        list: ``voc.index2word[token]`` for every token the searcher produced.
    """
    # words -> indexes, wrapped in a list so the sentence forms a batch of one
    sentence_batch = [indexesFromSentence(voc, sentence)]

    # One length entry per sequence in the batch
    seq_lengths = torch.tensor([len(token_ids) for token_ids in sentence_batch]).to(device)

    # Index tensor built directly from the nested list, then moved to the target device
    input_batch = torch.LongTensor(sentence_batch).to(device)

    # Decode; the second value returned by the searcher is not used here
    tokens, scores = searcher(input_batch, seq_lengths, max_length)

    # indexes -> words
    return [voc.index2word[token] for token in tokens]


def evaluateInput(encoder, decoder, searcher, voc):
    """Interactive chat loop: read a line, print the model's reply, repeat.

    Typing ``q`` or ``quit`` ends the loop. A ``KeyError`` raised while
    handling a sentence is reported as an unknown-word error and the loop
    continues with the next prompt.

    Args:
        encoder, decoder, searcher, voc: Forwarded unchanged to ``evaluate``.
    """
    quit_commands = ('q', 'quit')
    hidden_tokens = ('EOS', 'PAD')

    while True:
        try:
            # Prompt the user
            input_sentence = input('> ')

            # Leave the loop on a quit command
            if input_sentence in quit_commands:
                break

            # Normalize, then decode a reply
            input_sentence = normalizeString(input_sentence)
            output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)

            # Drop the EOS / PAD markers before printing
            reply = ' '.join(word for word in output_words if word not in hidden_tokens)
            print('Bot:', reply)

        except KeyError:
            print("Error: Encountered unknown word.")


In [32]:
class GreedySearchDecoder(nn.Module):
    """Greedy decoder: repeatedly feeds the decoder its own best prediction.

    Decoding handles a single sequence at a time: the target tensor is built
    with a leading batch dimension of one and the predicted token is read
    with ``.item()``.
    """

    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length):
        """Greedily decode a response for ``input_seq``.

        Args:
            input_seq: Source token tensor passed to ``create_src_mask`` and
                to the encoder.
            input_length: Unused; kept so existing callers keep working.
            max_length: Maximum number of tokens to generate.

        Returns:
            tuple: ``(tokens, attention)`` where ``tokens`` is the list of
            predicted token indexes (without the leading SOS token, including
            EOS if it was produced) and ``attention`` is the second value
            returned by the last decoder call, or ``None`` when no decoding
            step ran (``max_length <= 0``).
        """
        self.encoder.eval()
        self.decoder.eval()

        # Initialize decoder input with SOS_token
        trg_indexes = [SOS_token]
        # Defined up front so the return below is valid even with zero steps
        attention = None

        # Inference only: keep the encoder pass and every decoder step out of
        # the autograd graph.
        with torch.no_grad():
            src_mask = create_src_mask(input_seq)
            encoder_outputs = self.encoder(input_seq, src_mask)

            for _ in range(max_length):
                # Re-feed everything generated so far as the decoder input
                trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
                trg_mask = create_trg_mask(trg_tensor)

                output, attention = self.decoder(trg_tensor, encoder_outputs, trg_mask, src_mask)

                # Highest-scoring token at the last position
                pred_token = output.argmax(2)[:, -1].item()
                trg_indexes.append(pred_token)

                if pred_token == EOS_token:
                    break

        return trg_indexes[1:], attention
        
        

In [33]:
# Wrap the trained encoder/decoder in the greedy search module
searcher = GreedySearchDecoder(encoder=encoder, decoder=decoder)

# Start the interactive chat loop (type 'q' or 'quit' to stop)
evaluateInput(encoder=encoder, decoder=decoder, searcher=searcher, voc=voc)

> Hi
Bot: hi . . . . . . . ?
> When will the world end?
Bot: soon as an even know me .
> what do we do when we feel sad?
Bot: very soon day .
> What is the square root of pi?
Error: Encountered unknown word.
> q
