In [1]:
import ast
import json
import math
import warnings
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.utils.data import Dataset

warnings.filterwarnings("ignore")

  return torch._C._cuda_getDeviceCount() > 0


Download the data from here: http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip

In [2]:
# Paths are relative to the notebook's working directory; the corpus archive
# must be unpacked there first.
corpus_movie_conv = 'cornell movie-dialogs corpus/movie_conversations.txt'
corpus_movie_lines = 'cornell movie-dialogs corpus/movie_lines.txt'
# Maximum number of tokens kept per utterance; also the padding length used by
# encode_question / encode_reply.
max_len = 25

In [3]:
# One conversation per line; fields are separated by " +++$+++ " and the last
# field is a list literal of line ids.
# The encoding is given explicitly so the result does not depend on the
# platform's default locale.
with open(corpus_movie_conv, 'r', encoding='iso-8859-1') as conv_file:
    conv = conv_file.readlines()

In [4]:
# One utterance per line; fields are separated by " +++$+++ ", first field is
# the line id and the last field is the spoken text.
# NOTE(review): the corpus is believed to contain bytes that are not valid
# UTF-8, so the platform default encoding can raise UnicodeDecodeError;
# iso-8859-1 decodes every byte value. Confirm against your copy of the data.
with open(corpus_movie_lines, 'r', encoding='iso-8859-1') as lines_file:
    lines = lines_file.readlines()

In [5]:
conv

["u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L194', 'L195', 'L196', 'L197']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L198', 'L199']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L200', 'L201', 'L202', 'L203']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L204', 'L205', 'L206']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L207', 'L208']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L271', 'L272', 'L273', 'L274', 'L275']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L276', 'L277']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L280', 'L281']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L363', 'L364']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L365', 'L366']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L367', 'L368']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L401', 'L402', 'L403']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L404', 'L405', 'L406', 'L407']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L575', 'L576']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L577', 'L578']\n",
 "u0 +++$+++ u2 +++$+++ m0 +++$+++ ['L662', 'L663']\n",
 "u0 +++$+++ u2 

In [6]:
# Map each line id (first field) to its raw utterance text (last field,
# trailing newline still attached). If an id appears twice the later line wins.
lines_dic = {
    fields[0]: fields[-1]
    for fields in (raw_line.split(" +++$+++ ") for raw_line in lines)
}

In [7]:
lines_dic

{'L1045': 'They do not!\n',
 'L1044': 'They do to!\n',
 'L985': 'I hope so.\n',
 'L984': 'She okay?\n',
 'L925': "Let's go.\n",
 'L924': 'Wow\n',
 'L872': "Okay -- you're gonna need to learn how to lie.\n",
 'L871': 'No\n',
 'L870': 'I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n',
 'L869': 'Like my fear of wearing pastels?\n',
 'L868': 'The "real you".\n',
 'L867': 'What good stuff?\n',
 'L866': "I figured you'd get to the good stuff eventually.\n",
 'L865': 'Thank God!  If I had to hear one more story about your coiffure...\n',
 'L864': "Me.  This endless ...blonde babble. I'm like, boring myself.\n",
 'L863': 'What crap?\n',
 'L862': 'do you listen to this crap?\n',
 'L861': 'No...\n',
 'L860': 'Then Guillermo says, "If you go any lighter, you\'re gonna look like an extra on 90210."\n',
 'L699': 'You always been this selfish?\n',
 'L698': 'But\n',
 'L697': "Then that's all you had to say.\n",
 'L696': 'Well, no...\n',
 'L695

In [8]:
# Translation table that deletes every character listed below. A raw string is
# used so the backslash is a literal backslash rather than an invalid escape.
_PUNCT_TABLE = str.maketrans('', '', r'''!()-[]{};:'"\,<>./?@#$%^&*_~''')


def remove_punc(string):
    """Return ``string`` lower-cased with the listed punctuation removed.

    Whitespace is preserved, so removing a standalone punctuation token can
    leave consecutive spaces behind. Characters not in the list (for example
    ``+`` and ``=``) are kept.
    """
    return string.translate(_PUNCT_TABLE).lower()

In [9]:
# Build (utterance, next utterance) training pairs from every conversation.
pairs = []
for con in conv:
    # The last field is a list literal such as "['L276', 'L277']".
    # ast.literal_eval only accepts literals, so text from the corpus file can
    # never be executed as code (unlike eval).
    ids = ast.literal_eval(con.split(" +++$+++ ")[-1].strip())
    # Pair every line with the line that follows it in the same conversation.
    for first_id, second_id in zip(ids, ids[1:]):
        first = remove_punc(lines_dic[first_id].strip())    # sentence 1
        second = remove_punc(lines_dic[second_id].strip())  # sentence 2
        # Tokenise on whitespace and keep at most max_len tokens per sentence.
        pairs.append([first.split()[:max_len], second.split()[:max_len]])

In [10]:
pairs[:1]

[[['can',
   'we',
   'make',
   'this',
   'quick',
   'roxanne',
   'korrine',
   'and',
   'andrew',
   'barrett',
   'are',
   'having',
   'an',
   'incredibly',
   'horrendous',
   'public',
   'break',
   'up',
   'on',
   'the',
   'quad',
   'again'],
  ['well',
   'i',
   'thought',
   'wed',
   'start',
   'with',
   'pronunciation',
   'if',
   'thats',
   'okay',
   'with',
   'you']]]

In [11]:
# Count every token of both sentences of every pair. Tokens are visited in
# corpus order, which fixes the insertion order of the Counter keys.
word_freq = Counter(
    token
    for pair in pairs
    for sentence in (pair[0], pair[1])
    for token in sentence
)

In [12]:
word_freq

Counter({'can': 14103,
         'we': 25912,
         'make': 5821,
         'this': 30502,
         'quick': 310,
         'roxanne': 1,
         'korrine': 1,
         'and': 52128,
         'andrew': 49,
         'barrett': 20,
         'are': 21713,
         'having': 1081,
         'an': 8827,
         'incredibly': 49,
         'horrendous': 4,
         'public': 306,
         'break': 799,
         'up': 14316,
         'on': 23908,
         'the': 120903,
         'quad': 2,
         'again': 2807,
         'well': 16263,
         'i': 137633,
         'thought': 4202,
         'wed': 541,
         'start': 1459,
         'with': 21394,
         'pronunciation': 2,
         'if': 16727,
         'thats': 14742,
         'okay': 5946,
         'you': 169693,
         'not': 26494,
         'hacking': 18,
         'gagging': 9,
         'spitting': 15,
         'part': 1260,
         'please': 3258,
         'then': 7532,
         'how': 14001,
         'bout': 393,
         'try

In [13]:
# Vocabulary: keep only words seen strictly more than min_word_freq times.
min_word_freq = 5
words = [word for word, freq in word_freq.items() if freq > min_word_freq]

# Regular words get ids 1..len(words); id 0 is reserved for padding.
word_map = {word: index for index, word in enumerate(words, start=1)}

# Special tokens take the next three free ids, in this order.
for special_token in ('<unk>', '<start>', '<end>'):
    word_map[special_token] = len(word_map) + 1
word_map['<pad>'] = 0

In [14]:
print("Total words are: {}".format(len(word_map)))


Total words are: 18243


In [15]:
word_map['<unk>'], word_map['<start>'], word_map['<end>'], word_map['<pad>']

(18240, 18241, 18242, 0)

In [16]:
# Persist the vocabulary so the training section can reload it from disk.
with open('WORDMAP_corpus.json', 'w') as wordmap_file:
    wordmap_file.write(json.dumps(word_map))

In [17]:
def encode_question(words, word_map, pad_to=None):
    """Encode a tokenised question as a list of vocabulary ids, right-padded.

    Args:
        words: list of tokens.
        word_map: token -> id mapping; must contain '<unk>' and '<pad>'.
        pad_to: length to pad to. When None (the default) the notebook-level
            ``max_len`` is used, which is the original behaviour.

    Returns:
        List of ids. Unknown tokens map to '<unk>'. Inputs longer than the
        pad length are returned unpadded and are NOT truncated.
    """
    # Read the global only when no explicit length is given, so the function
    # still works if the notebook-level name has been rebound.
    target_len = max_len if pad_to is None else pad_to
    unk_id = word_map['<unk>']
    encoded = [word_map.get(word, unk_id) for word in words]
    # A negative repeat count yields an empty list, i.e. no padding.
    encoded.extend([word_map['<pad>']] * (target_len - len(words)))
    return encoded

In [18]:
def encode_reply(words, word_map, pad_to=None):
    """Encode a tokenised reply as ``<start> ids... <end>`` plus right padding.

    Args:
        words: list of tokens.
        word_map: token -> id mapping; must contain '<start>', '<end>',
            '<unk>' and '<pad>'.
        pad_to: number of word slots to pad to. When None (the default) the
            notebook-level ``max_len`` is used, which is the original
            behaviour.

    Returns:
        List of ids of length ``pad_to + 2`` whenever ``len(words) <= pad_to``
        (the two extra slots hold '<start>' and '<end>'). Longer inputs are
        not truncated.
    """
    # Read the global only when no explicit length is given.
    target_len = max_len if pad_to is None else pad_to
    unk_id = word_map['<unk>']
    encoded = [word_map['<start>']]
    encoded.extend(word_map.get(word, unk_id) for word in words)
    encoded.append(word_map['<end>'])
    # Padding is computed from the word count only, not from the two markers.
    encoded.extend([word_map['<pad>']] * (target_len - len(words)))
    return encoded

In [19]:
# Encode every (question, reply) token pair into padded id sequences.
pairs_encoded = [
    [encode_question(pair[0], word_map), encode_reply(pair[1], word_map)]
    for pair in pairs
]

In [20]:
pairs_encoded[0]

[[1,
  2,
  3,
  4,
  5,
  18240,
  18240,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  18240,
  13,
  14,
  15,
  16,
  17,
  18240,
  18,
  0,
  0,
  0],
 [18241,
  19,
  20,
  21,
  22,
  23,
  24,
  18240,
  25,
  26,
  27,
  24,
  28,
  18242,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0]]

In [21]:
# Persist the encoded pairs; the Dataset class reads this file back.
with open('pairs_encoded.json', 'w') as pairs_file:
    pairs_file.write(json.dumps(pairs_encoded))

In [22]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over the encoded (question, reply) pairs.

    The base class is named by its full path so that re-running this cell does
    not make the class inherit from its own previous definition.

    Args:
        pairs_path: JSON file holding a list of ``[question_ids, reply_ids]``
            entries. Defaults to the file written earlier in the notebook.
    """

    def __init__(self, pairs_path='pairs_encoded.json'):
        # The context manager closes the file handle as soon as it is parsed.
        with open(pairs_path) as pairs_file:
            self.pairs = json.load(pairs_file)
        self.dataset_size = len(self.pairs)

    def __getitem__(self, i):
        """Return the i-th pair as two LongTensors ``(question, reply)``."""
        question_ids, reply_ids = self.pairs[i][0], self.pairs[i][1]
        return torch.LongTensor(question_ids), torch.LongTensor(reply_ids)

    def __len__(self):
        return self.dataset_size

In [23]:
# Shuffled mini-batches of 100 encoded pairs; pin_memory requests page-locked
# host buffers for the batches.
train_dataset = Dataset()
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=100,
    shuffle=True,
    pin_memory=True,
)

In [24]:
question, reply = next(iter(train_loader))

In [25]:
question.shape

torch.Size([100, 25])

In [26]:
reply.shape

torch.Size([100, 27])

In [27]:
# torch.triu keeps the upper triangle, the opposite of what a decoder needs.
# Transposing it gives the lower-triangular pattern in which position i can
# only see positions <= i.
size = 5
upper_triangle = torch.triu(torch.ones(size, size))
lower_triangle = upper_triangle.transpose(0, 1)
print(upper_triangle)               # opposite of the mask we need
print(lower_triangle)               # the mask that hides later words
print(lower_triangle.unsqueeze(0))  # same mask with a leading batch dimension

tensor([[1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1.],
        [0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 1.]])
tensor([[1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1.]])
tensor([[[1., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0.],
         [1., 1., 1., 0., 0.],
         [1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 1.]]])


In [28]:
question[0]

tensor([ 266,   99, 2993,  267,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0])

In [29]:
question[0]!=0 # where the word only present => True else other index is false

tensor([ True,  True,  True,  True, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False])

In [44]:
def create_masks(question, reply_input, reply_target):
    """Build the boolean attention and loss masks for one batch.

    Id 0 is treated as padding. Every mask is created on the device of the
    tensor it is derived from, so the function does not depend on any
    notebook-level ``device`` variable.

    Args:
        question: (batch_size, question_len) token ids.
        reply_input: (batch_size, reply_len) decoder input ids.
        reply_target: (batch_size, reply_len) decoder target ids.

    Returns:
        question_mask: (batch_size, 1, 1, question_len), True on real tokens.
        reply_input_mask: (batch_size, 1, reply_len, reply_len), True where a
            query position may attend: the key is not padding and is not later
            than the query.
        reply_target_mask: (batch_size, reply_len), True on real tokens.
    """

    def subsequent_mask(size):
        # Lower-triangular matrix: row i is True for columns 0..i.
        causal = torch.tril(torch.ones(size, size, device=reply_input.device)).bool()
        return causal.unsqueeze(0)  # (1, size, size)

    # Two singleton dims let the mask broadcast over heads and query positions.
    question_mask = (question != 0).unsqueeze(1).unsqueeze(1)

    # (batch_size, 1, reply_len) & (1, reply_len, reply_len)
    #   -> (batch_size, reply_len, reply_len)
    reply_input_mask = (reply_input != 0).unsqueeze(1)
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1))
    reply_input_mask = reply_input_mask.unsqueeze(1)  # add the heads dim

    reply_target_mask = reply_target != 0

    return question_mask, reply_input_mask, reply_target_mask

In [45]:
class Embeddings(nn.Module):
    """
    Implements embeddings of the words and adds their positional encodings.

    Two fixed sinusoidal tables are kept: ``pe`` indexed by token position and
    ``te`` indexed by layer ("timestep") number. Both are registered as
    non-persistent buffers: they follow the module through ``.to(device)``
    but are not written to ``state_dict``.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: embedding width.
        max_len: longest sequence the position table supports.
        num_layers: number of layer indices the timestep table supports.
    """
    def __init__(self, vocab_size, d_model, max_len = 50,num_layers=6):
        super(Embeddings, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(0.1)
        self.embed = nn.Embedding(vocab_size, d_model)
        # (1, max_len, d_model)
        self.register_buffer("pe", self.create_positinal_encoding(max_len, d_model), persistent=False)
        # (1, num_layers, d_model)
        self.register_buffer("te", self.create_positinal_encoding(num_layers, d_model), persistent=False)

    def create_positinal_encoding(self, max_len, d_model):
        """Return a (1, max_len, d_model) sinusoidal encoding table.

        Dimension pair (2k, 2k+1) holds sin and cos of the same angle
        ``pos / 10000 ** (2k / d_model)``. For an odd ``d_model`` the last
        dimension holds the sine of the final frequency only.
        """
        positions = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        even_dims = torch.arange(0, d_model, 2, dtype=torch.float)         # values 2k
        # even_dims already equals 2k, so it is divided by d_model directly.
        angles = positions / (10000.0 ** (even_dims / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return pe.unsqueeze(0)   # include the batch dimension

    def forward(self, embedding, layer_idx):
        """Add position and layer encodings, then apply dropout.

        Args:
            embedding: token ids of shape (batch_size, seq_len) when
                ``layer_idx == 0``; otherwise hidden states of shape
                (batch_size, seq_len, d_model).
            layer_idx: index into the layer ("timestep") table.

        Returns:
            Tensor of shape (batch_size, seq_len, d_model). The input tensor
            is never modified in place.
        """
        if layer_idx == 0:
            embedding = self.embed(embedding) * math.sqrt(self.d_model)
        # Both tables broadcast over the batch dimension.
        embedding = embedding + self.pe[:, :embedding.size(1)]
        # (1, d_model) -> (1, 1, d_model): the same vector for every position.
        embedding = embedding + self.te[:, layer_idx, :].unsqueeze(1)
        return self.dropout(embedding)

In [46]:
class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention split across ``heads`` parallel heads.

    ``d_model`` must be divisible by ``heads``; each head works on
    ``d_k = d_model // heads`` features.
    """

    def __init__(self, heads, d_model):
        super(MultiHeadAttention, self).__init__()
        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = nn.Dropout(0.1)
        # Input projections, followed by the projection applied to the
        # re-joined heads.
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.concat = nn.Linear(d_model, d_model)

    def _split_heads(self, projected, batch_size):
        """(batch, seq, d_model) -> (batch, heads, seq, d_k)."""
        return projected.view(batch_size, -1, self.heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask):
        """
        query, key and value: shape (batch_size, seq_len, d_model)
        mask: anything broadcastable to (batch_size, heads, query_len, key_len);
              positions where it equals 0 receive no attention weight.
        Returns a tensor of shape (batch_size, query_len, d_model).
        """
        batch_size = query.shape[0]

        q = self._split_heads(self.query(query), batch_size)
        k = self._split_heads(self.key(key), batch_size)
        v = self._split_heads(self.value(value), batch_size)

        # (batch, heads, q_len, d_k) x (batch, heads, d_k, k_len)
        #   -> (batch, heads, q_len, k_len)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        # A large negative score turns into a zero weight after the softmax.
        scores = scores.masked_fill(mask == 0, -1e9)
        weights = self.dropout(F.softmax(scores, dim=-1))

        # (batch, heads, q_len, k_len) x (batch, heads, k_len, d_k)
        #   -> (batch, heads, q_len, d_k)
        context = torch.matmul(weights, v)
        # Re-join the heads: (batch, q_len, heads * d_k)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.heads * self.d_k)
        return self.concat(context)
        

In [47]:
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, hidden_dim = 2048):
        super(FeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, hidden_dim)   # expand
        self.fc2 = nn.Linear(hidden_dim, d_model)   # project back
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``; the shape is preserved."""
        hidden = self.fc1(x)
        hidden = F.relu(hidden)
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

In [48]:
class EncoderLayer(nn.Module):
    """Self-attention followed by a feed-forward block.

    Each sub-layer is wrapped as ``layernorm(dropout(sublayer(x)) + x)``.
    A single LayerNorm instance (and a single Dropout) is shared by both
    sub-layers.
    """

    def __init__(self, d_model, heads):
        super(EncoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def _add_and_norm(self, sublayer_out, residual):
        """Dropout the sub-layer output, add the residual, then normalise."""
        return self.layernorm(self.dropout(sublayer_out) + residual)

    def forward(self, embeddings, mask):
        """Encode ``embeddings`` (batch, seq, d_model) under attention ``mask``."""
        attended = self.self_multihead(embeddings, embeddings, embeddings, mask)
        interacted = self._add_and_norm(attended, embeddings)
        return self._add_and_norm(self.feed_forward(interacted), interacted)

In [49]:
class DecoderLayer(nn.Module):
    """Masked self-attention, encoder-decoder attention, then feed-forward.

    Each sub-layer is wrapped as ``layernorm(dropout(sublayer(x)) + x)``.
    A single LayerNorm instance (and a single Dropout) is shared by all three
    sub-layers.
    """

    def __init__(self, d_model, heads):
        super(DecoderLayer, self).__init__()
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.src_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def _add_and_norm(self, sublayer_out, residual):
        """Dropout the sub-layer output, add the residual, then normalise."""
        return self.layernorm(self.dropout(sublayer_out) + residual)

    def forward(self, embeddings, encoded, src_mask, target_mask):
        """Decode one layer.

        ``embeddings`` are the target-side states, ``encoded`` the encoder
        output. ``target_mask`` is used for the self-attention and
        ``src_mask`` for the attention over ``encoded``.
        """
        self_attended = self.self_multihead(embeddings, embeddings, embeddings, target_mask)
        query = self._add_and_norm(self_attended, embeddings)

        # Queries come from the decoder; keys and values from the encoder.
        cross_attended = self.src_multihead(query, encoded, encoded, src_mask)
        interacted = self._add_and_norm(cross_attended, query)

        return self._add_and_norm(self.feed_forward(interacted), interacted)

In [50]:
class Transformer(nn.Module):
    """Encoder-decoder model that re-applies ONE encoder layer and ONE decoder
    layer ``num_layers`` times each.

    Before every pass the shared ``Embeddings`` module is called with the pass
    index, so the weights are shared across passes while each pass sees a
    different layer encoding.
    """

    def __init__(self, d_model, heads, num_layers, word_map):
        super(Transformer, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers
        self.vocab_size = len(word_map)
        # One embedding module serves both the source and the target side.
        self.embed = Embeddings(self.vocab_size, d_model, num_layers=num_layers)
        self.encoder = EncoderLayer(d_model, heads)
        self.decoder = DecoderLayer(d_model, heads)
        # Projects decoder states to vocabulary scores.
        self.logit = nn.Linear(d_model, self.vocab_size)

    def encode(self, src_embeddings, src_mask):
        """Run the source token ids through ``num_layers`` encoder passes."""
        hidden = src_embeddings
        for layer_idx in range(self.num_layers):
            hidden = self.encoder(self.embed(hidden, layer_idx), src_mask)
        return hidden

    def decode(self, tgt_embeddings, target_mask, src_embeddings, src_mask):
        """Run the target token ids through ``num_layers`` decoder passes,
        attending to the encoder output ``src_embeddings``."""
        hidden = tgt_embeddings
        for layer_idx in range(self.num_layers):
            hidden = self.decoder(self.embed(hidden, layer_idx), src_embeddings, src_mask, target_mask)
        return hidden

    def forward(self, src_words, src_mask, target_words, target_mask):
        """Return log-probabilities over the vocabulary for every target
        position (log-softmax taken over dimension 2)."""
        memory = self.encode(src_words, src_mask)
        decoded = self.decode(target_words, target_mask, memory, src_mask)
        return self.logit(decoded).log_softmax(dim=2)

In [51]:
class AdamWarmup:
    """Optimizer wrapper that sets the learning rate before every update.

    The rate is ``model_size ** -0.5 * min(step ** -0.5,
    step * warmup_steps ** -1.5)``: it rises linearly for ``warmup_steps``
    updates and then decays with the inverse square root of the step number.

    Args:
        model_size: value used for the ``model_size ** -0.5`` scale factor.
        warmup_steps: number of updates over which the rate rises.
        optimizer: wrapped optimizer; must expose ``param_groups`` and
            ``step()``.
    """

    def __init__(self, model_size, warmup_steps, optimizer):
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0

    def get_lr(self):
        """Return the learning rate for ``current_step``.

        Before the first ``step()`` the rate is 0.0, the limit of the warm-up
        ramp. Evaluating the formula at step 0 would raise ZeroDivisionError.
        """
        if self.current_step <= 0:
            return 0.0
        return self.model_size ** (-0.5) * min(self.current_step ** (-0.5), self.current_step * self.warmup_steps ** (-1.5))

    def step(self):
        """Advance the step counter, install the new rate, update the weights."""
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.lr = lr
        # Update weights
        self.optimizer.step()
        

In [52]:
class LossWithLS(nn.Module):
    """KL-divergence loss against label-smoothed targets, averaged over the
    positions selected by a mask.

    Args:
        size: vocabulary size (number of classes); must be greater than 1.
        smooth: probability mass spread evenly over the ``size - 1`` wrong
            classes; the true class receives ``1 - smooth``.
    """

    def __init__(self, size, smooth):
        super(LossWithLS, self).__init__()
        # reduction='none' keeps the element-wise losses. It is the
        # non-deprecated spelling of size_average=False, reduce=False.
        self.criterion = nn.KLDivLoss(reduction='none')
        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = size

    def forward(self, prediction, target, mask):
        """
        prediction of shape: (batch_size, max_words, vocab_size), log-probabilities
        target and mask of shape: (batch_size, max_words)

        Returns a scalar: the summed loss of the masked-in positions divided
        by the number of masked-in positions.
        """
        # reshape (rather than view) also accepts non-contiguous inputs such
        # as slices.
        prediction = prediction.reshape(-1, prediction.size(-1))   # (batch_size * max_words, vocab_size)
        target = target.reshape(-1)                                # (batch_size * max_words)
        mask = mask.float().reshape(-1)                            # (batch_size * max_words)

        # The smoothed target distribution is a constant: no gradient needed.
        with torch.no_grad():
            labels = torch.full_like(prediction, self.smooth / (self.size - 1))
            labels.scatter_(1, target.unsqueeze(1), self.confidence)

        loss = self.criterion(prediction, labels)    # (batch_size * max_words, vocab_size)
        loss = (loss.sum(1) * mask).sum() / mask.sum()
        return loss

In [53]:
# Model and training hyper-parameters.
d_model, heads, num_layers = 512, 8, 3
epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reload the vocabulary written earlier in the notebook.
with open('WORDMAP_corpus.json', 'r') as wordmap_file:
    word_map = json.load(wordmap_file)

# Build the model and move it to the selected device.
transformer = Transformer(
    d_model=d_model,
    heads=heads,
    num_layers=num_layers,
    word_map=word_map,
).to(device)

# Adam starts at lr=0; AdamWarmup overwrites the rate before every update.
adam_optimizer = torch.optim.Adam(
    transformer.parameters(),
    lr=0,
    betas=(0.9, 0.98),
    eps=1e-9,
)
transformer_optimizer = AdamWarmup(
    model_size=d_model,
    warmup_steps=4000,
    optimizer=adam_optimizer,
)

# Label-smoothed loss over the whole vocabulary.
criterion = LossWithLS(len(word_map), 0.1)

In [54]:
def train(train_loader, transformer, criterion, epoch):
    """Train ``transformer`` for one pass over ``train_loader``.

    Uses the notebook-level ``device``, ``create_masks`` and
    ``transformer_optimizer``. Every 100 batches the sample-weighted running
    mean of the loss is printed. Returns nothing.
    """
    transformer.train()
    running_loss = 0
    samples_seen = 0
    total_batches = len(train_loader)

    for batch_idx, (question, reply) in enumerate(train_loader):
        batch_samples = question.shape[0]
        question, reply = question.to(device), reply.to(device)

        # Teacher forcing: the decoder reads the reply without its last
        # position and is scored against the reply shifted left by one.
        #   Sentence: <start> I went home <end>
        #   Input:    <start> I went home
        #   Target:   I went home <end>
        reply_input, reply_target = reply[:, :-1], reply[:, 1:]

        question_mask, reply_input_mask, reply_target_mask = create_masks(question, reply_input, reply_target)

        # Forward pass and loss over the non-padding target positions.
        out = transformer(src_words=question, src_mask=question_mask, target_words=reply_input, target_mask=reply_input_mask)
        loss = criterion(out, reply_target, reply_target_mask)

        # Backward pass; the wrapper sets the learning rate and steps Adam.
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()

        running_loss += loss.item() * batch_samples
        samples_seen += batch_samples

        if batch_idx % 100 != 0:
            continue
        print("Epoch [{}][{}/{}]\tLoss: {:.3f}".format(epoch, batch_idx, total_batches, running_loss/samples_seen))
            

In [55]:
def evaluate(transformer, question, question_mask, max_len, word_map):
    """
    Performs greedy-search decoding with a batch size of 1.

    Args:
        transformer: model exposing ``encode``, ``decode`` and the ``logit``
            output projection.
        question: (1, question_len) tensor of token ids.
        question_mask: attention mask for ``question``.
        max_len: upper bound on the decoded length including '<start>'; at
            most ``max_len - 1`` words are generated.
        word_map: token -> id mapping with '<start>' and '<end>' entries.

    Returns:
        The generated words joined by single spaces, without the '<start>'
        and '<end>' markers.
    """
    rev_word_map = {v: k for k, v in word_map.items()}
    transformer.eval()
    start_token = word_map['<start>']
    end_token = word_map['<end>']
    run_device = question.device

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        encoded = transformer.encode(question, question_mask)
        words = torch.LongTensor([[start_token]]).to(run_device)  # (1, 1)

        for step in range(max_len - 1):
            # words is (1, current_len), so the causal mask must be sized by
            # dimension 1. Dimension 0 is the batch and is always 1.
            size = words.shape[1]
            target_mask = torch.tril(torch.ones(size, size, device=run_device)).bool()
            target_mask = target_mask.unsqueeze(0)  # (1, size, size)
            decoded = transformer.decode(words, target_mask, encoded, question_mask)
            # Score the vocabulary from the state of the last position only.
            predictions = transformer.logit(decoded[:, -1])  # (1, vocab_size)
            next_word_index = predictions.argmax(dim=1).item()
            if next_word_index == end_token:
                break
            next_word = torch.LongTensor([[next_word_index]]).to(run_device)
            words = torch.cat([words, next_word], dim=1)  # (1, step + 2)

    # Construct the sentence, dropping the leading <start> marker.
    sen_idx = [w for w in words.squeeze(0).tolist() if w != start_token]
    sentence = ' '.join(rev_word_map[w] for w in sen_idx)
    return sentence
    

In [None]:
# Train for `epochs` passes. After each pass the whole model object and its
# optimizer wrapper are pickled to checkpoint_<epoch>.pth.tar.
for epoch in range(epochs):
    train(train_loader, transformer, criterion, epoch)

    checkpoint_path = 'checkpoint_{}.pth.tar'.format(epoch)
    state = dict(epoch=epoch, transformer=transformer, transformer_optimizer=transformer_optimizer)
    torch.save(state, checkpoint_path)

In [None]:
# The training cell writes one file per epoch, named checkpoint_<epoch>.pth.tar,
# so load the file from the final epoch. map_location lets a checkpoint trained
# on a GPU be restored on a CPU-only machine.
# NOTE(review): recent PyTorch releases reportedly default torch.load to
# weights_only=True, which rejects pickled whole-model checkpoints like this
# one; pass weights_only=False there if loading fails. Confirm for your version.
checkpoint = torch.load('checkpoint_' + str(epochs - 1) + '.pth.tar', map_location=device)
transformer = checkpoint['transformer']

In [None]:
# Interactive chat: type 'quit' as the question to stop.
while True:
    user_text = input("Question: ")
    if user_text == 'quit':
        break

    # Kept in its own local name so the notebook-level `max_len` used by the
    # encoding helpers is not overwritten.
    try:
        reply_len = int(input("Maximum Reply Length: "))
    except ValueError:
        print("Please enter a whole number for the maximum reply length.")
        continue

    # Apply the same normalisation as the training data (lower-case, no
    # punctuation); otherwise most typed words would map to <unk>.
    enc_qus = [word_map.get(word, word_map['<unk>']) for word in remove_punc(user_text).split()]
    if not enc_qus:
        # Nothing left to encode (blank or punctuation-only input).
        continue

    question = torch.LongTensor(enc_qus).to(device).unsqueeze(0)        # (1, question_len)
    question_mask = (question!=0).to(device).unsqueeze(1).unsqueeze(1)  # (1, 1, 1, question_len)
    sentence = evaluate(transformer, question, question_mask, reply_len, word_map)
    print(sentence)