In [1]:
import os
# from tqdm import tqdm

import json
import math
import numpy as np
from collections import Counter

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.utils.data
import torch.nn.functional as F

torch.cuda.get_device_name()

'GeForce 940MX'

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
max_len = 25

In [5]:
DATA_BASE_DIR = '../../Resources/data/chatbot/cornell_movie_dialogs_corpus/cornell movie-dialogs corpus/'

In [6]:
movie_conv = os.path.join(DATA_BASE_DIR, 'movie_conversations.txt')
movie_lines = os.path.join(DATA_BASE_DIR, 'movie_lines.txt')

In [7]:
with open(movie_conv, 'r') as f:
    conv = f.readlines()
f.close()

In [8]:
with open(movie_lines, 'r') as f:
    lines = f.readlines()
f.close()

In [9]:
lines_dict = {}
for line in lines:
    objects = line.split(' +++$+++ ')
    lines_dict[objects[0]] = objects[-1]

In [10]:
lines_dict

{'L1045': 'They do not!\n',
 'L1044': 'They do to!\n',
 'L985': 'I hope so.\n',
 'L984': 'She okay?\n',
 'L925': "Let's go.\n",
 'L924': 'Wow\n',
 'L872': "Okay -- you're gonna need to learn how to lie.\n",
 'L871': 'No\n',
 'L870': 'I\'m kidding.  You know how sometimes you just become this "persona"?  And you don\'t know how to quit?\n',
 'L869': 'Like my fear of wearing pastels?\n',
 'L868': 'The "real you".\n',
 'L867': 'What good stuff?\n',
 'L866': "I figured you'd get to the good stuff eventually.\n",
 'L865': 'Thank God!  If I had to hear one more story about your coiffure...\n',
 'L864': "Me.  This endless ...blonde babble. I'm like, boring myself.\n",
 'L863': 'What crap?\n',
 'L862': 'do you listen to this crap?\n',
 'L861': 'No...\n',
 'L860': 'Then Guillermo says, "If you go any lighter, you\'re gonna look like an extra on 90210."\n',
 'L699': 'You always been this selfish?\n',
 'L698': 'But\n',
 'L697': "Then that's all you had to say.\n",
 'L696': 'Well, no...\n',
 'L695

In [11]:
import string
def remove_punctuation(str_):
    return str_.translate(str.maketrans('', '', string.punctuation))

In [12]:
pairs = []
for con in conv:
    ids = eval(con.split('+++$+++ ')[-1])
    for i in range(len(ids)):
        qa_pairs = []
        
        if i == len(ids)-1:
            break
            
        first = remove_punctuation(lines_dict[ids[i]].strip())
        second = remove_punctuation(lines_dict[ids[i+1]].strip())
        
        qa_pairs.append(first.split()[:max_len])
        qa_pairs.append(second.split()[:max_len])
        pairs.append(qa_pairs)

In [13]:
word_freq = Counter()
for pair in pairs:
    word_freq.update(pair[0])
    word_freq.update(pair[1])

In [14]:
min_word_freq = 5
words = [w for w in word_freq.keys() if word_freq[w] > min_word_freq]
word_map = {k: v+1 for v, k in enumerate(words)}
word_map['<unk>'] = len(word_map) + 1
word_map['<start>'] = len(word_map) + 1
word_map['<end>'] = len(word_map) + 1
word_map['<pad>'] = 0

len(word_map)

20841

In [15]:
def encode_question(words, word_map):
    encoded = [word_map.get(word, word_map['<unk>']) for word in words] + \
              [word_map['<pad>']] * (max_len - len(words))
    return encoded

In [16]:
def encode_reply(words, word_map):
    encoded = [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in words] + \
              [word_map['<end>']] + [word_map['<pad>']] * (max_len - len(words))
    
    return encoded

In [17]:
encoded_pairs = []
for pair in pairs:
    question = encode_question(pair[0], word_map)
    answer = encode_reply(pair[1], word_map)
    encoded_pairs.append([question, answer])

In [18]:
class ConversationDataset(Dataset):
    def __init__(self, qa_pairs):
        self.pairs = qa_pairs
        self.dataset_size = len(self.pairs)
        
    def __getitem__(self, i):
        question = torch.LongTensor(self.pairs[i][0])
        reply = torch.LongTensor(self.pairs[i][1])
        
        return question, reply
    
    def __len__(self):
        return self.dataset_size

In [19]:
train_loader = torch.utils.data.DataLoader(ConversationDataset(encoded_pairs),
                                           batch_size=100,
                                           shuffle=True,
                                           pin_memory=True)

In [20]:
def create_masks(question, reply_input, reply_target):
    
    def subsequent_mask(size):
        mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        return mask.unsqueeze(0) ## Add dimension to 0th idx
    
    question_mask = question != 0
    question_mask = question_mask.to(device)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1) ## (batch_size, 1, 1, max_words)
    
    reply_input_mask = reply_input != 0
    reply_input_mask = reply_input_mask.unsqueeze(1) ## (batch_size, 1, max_words)
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data)
    reply_input_mask = reply_input_mask.unsqueeze(1) ## (batch_size, 1, max_words, max_words)
    reply_target_mask = reply_target!=0 ## (batch_size, max_words)
    
    return question_mask, reply_input_mask, reply_target_mask

In [21]:
class Embeddings(nn.Module):
    
    '''
        Calculate embeddings, then create positional encodings
    '''
    
    def __init__(self, vocab_size, d_model, max_len=50):
        super(Embeddings, self).__init__()
        
        self.d_model = d_model
        self.dropout = nn.Dropout(0.1)
        self.embed = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = self.create_positional_encoding(max_len, self.d_model)
        self.dropout = nn.Dropout(0.1)
        
        
    def create_positional_encoding(self, max_len, d_model):
        pe = torch.zeros(max_len, d_model).to(device)
        for pos in range(max_len): ## for each position of a word
            for i in range(0, d_model, 2): ## for each dimension of each position
                pe[pos, i] = math.sin(pos / (10000 ** ((2 * (i + 1)) / d_model )))
                pe[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))
                
        pe = pe.unsqueeze(0)
        return pe
                
        
    def forward(self, encoded_words):
        embedding = self.embed(encoded_words) * math.sqrt(self.d_model)
        embedding += self.positional_encoding[:, :embedding.size(1)]
        embedding = self.dropout(embedding)
        return embedding

In [22]:
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model):
        super(MultiHeadAttention, self).__init__()
        
        assert d_model % heads == 0
        
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = nn.Dropout(0.1)
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.concat = nn.Linear(d_model, d_model)
        
    
    def forward(self, query, key, value, mask):
        '''
            q, k, v of size (batch_size, max_len, 512)
            mask of size (batch_size, 1, 1, max_words)
        '''
        
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
        
        ## (batch_size, max_len, 512) --> (batch_size, max_len, h, d_k) --> (batch_size, h, max_len, d_k)
        query = query.view(query.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        key = key.view(key.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)  
        value = value.view(value.shape[0], -1, self.heads, self.d_k).permute(0, 2, 1, 3)
        
        ## (batch_size, h, max_len, d_k) matmul (batch_size, h, d_k, max_len) --> (batch_size, h, max_len, max_len)
        scores = torch.matmul(query, key.permute(0, 1, 3, 2)) / math.sqrt(query.size(-1))
        scores = scores.masked_fill(mask == 0, -1e9)
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)
        
        ## (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = torch.matmul(weights, value)
        
        ## (batch_size, h, max_len, max_len) matmul (batch_size, h, max_len, d_k) --> (batch_size, h, max_len, d_k)
        context = context.permute(0, 2, 1, 3).contiguous().view(context.shape[0], -1, self.heads * self.d_k)
        
        interacted = self.concat(context)
        
        return interacted

In [23]:
class FeedForward(nn.Module):
    def __init__(self, d_model, middle_dim=2048):
        super(FeedForward, self).__init__()
        
        self.fc1 = nn.Linear(d_model, middle_dim)
        self.fc2 = nn.Linear(middle_dim, d_model)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = self.fc2(self.dropout(out))
        
        return out

In [24]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, heads):
        super(EncoderLayer, self).__init__()
        
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)
        
        
    def forward(self, embeddings, mask):
        interacted = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, mask))
        interacted = self.layernorm(interacted + embeddings)
        feed_forward_opt = self.dropout(self.feed_forward(interacted))
        encoded = self.layernorm(feed_forward_opt + interacted)
        
        return encoded

In [25]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, heads):
        super(DecoderLayer, self).__init__()
        
        self.layernorm = nn.LayerNorm(d_model)
        self.self_multihead = MultiHeadAttention(heads, d_model)
        self.src_multihead = MultiHeadAttention(heads, d_model)
        self.feed_forward = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, embeddings, encoded, src_mask, target_mask):
        query = self.dropout(self.self_multihead(embeddings, embeddings, embeddings, target_mask))
        query = self.layernorm(query + embeddings)
        interacted = self.dropout(self.src_multihead(query, encoded, encoded, src_mask))
        interacted = self.layernorm(interacted + query)
        feed_forward_out = self.dropout(self.feed_forward(interacted))
        decoded = self.layernorm(feed_forward_out + interacted)
        return decoded

In [26]:
class Transformer(nn.Module):
    
    def __init__(self, d_model, heads, num_layers, word_map):
        super(Transformer, self).__init__()
        
        self.d_model = d_model
        self.vocab_size = len(word_map)
        self.embed = Embeddings(self.vocab_size, d_model)
        self.encoder = nn.ModuleList([EncoderLayer(d_model, heads) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, heads) for _ in range(num_layers)])
        self.logit = nn.Linear(d_model, self.vocab_size)
        
        
    def encode(self, src_words, src_mask):
        src_embeddings = self.embed(src_words)
        for layer in self.encoder:
            src_embeddings = layer(src_embeddings, src_mask)
            
        return src_embeddings
    
    
    def decode(self, target_words, target_mask, src_embeddings, src_mask):
        tgt_embeddings = self.embed(target_words)
        for layer in self.decoder:
            tgt_embeddings = layer(tgt_embeddings, src_embeddings, src_mask, target_mask)
        return tgt_embeddings
    
    
    def forward(self, src_words, src_mask, target_words, target_mask):
        encoded = self.encode(src_words, src_mask)
        decoded = self.decode(target_words, target_mask, encoded, src_mask)
        out = F.log_softmax(self.logit(decoded), dim = 2)
        return out
    
    

In [27]:
class AdamWarmup:
    
    def __init__(self, model_size, warmup_steps, optimizer):
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0
        
        
    def get_lr(self):
        return self.model_size ** (-0.5) * min(self.current_step ** (-0.5), self.current_step * self.warmup_steps ** (-1.5))
    
    
    def step(self):
        # Increment the number of steps each time we call the step function
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        # update the learning rate
        self.lr = lr
        self.optimizer.step()

In [28]:
class LossWithLS(nn.Module):

    def __init__(self, size, smooth):
        super(LossWithLS, self).__init__()
        self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
        self.confidence = 1.0 - smooth
        self.smooth = smooth
        self.size = size
        
    def forward(self, prediction, target, mask):
        """
        prediction of shape: (batch_size, max_words, vocab_size)
        target and mask of shape: (batch_size, max_words)
        """
        prediction = prediction.view(-1, prediction.size(-1))   # (batch_size * max_words, vocab_size)
        target = target.contiguous().view(-1)   # (batch_size * max_words)
        mask = mask.float()
        mask = mask.view(-1)       # (batch_size * max_words)
        labels = prediction.data.clone()
        labels.fill_(self.smooth / (self.size - 1))
        labels.scatter_(1, target.data.unsqueeze(1), self.confidence)
        loss = self.criterion(prediction, labels)    # (batch_size * max_words, vocab_size)
        loss = (loss.sum(1) * mask).sum() / mask.sum()
        return loss

In [29]:
d_model = 512
heads = 8
num_layers = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 10

    
transformer = Transformer(d_model = d_model, heads = heads, num_layers = num_layers, word_map = word_map)
transformer = transformer.to(device)
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(model_size = d_model, warmup_steps = 4000, optimizer = adam_optimizer)
criterion = LossWithLS(len(word_map), 0.1)



In [30]:
def train(train_loader, transformer, criterion, epochs):
    
    transformer.train()
    sum_loss = 0
    count = 0
    
    for i, (question, reply) in enumerate(train_loader):
        samples = question.shape[0]
        
        question = question.to(device)
        reply = reply.to(device)
        
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]
        
        question_mask, reply_input_mask, reply_target_mask = create_masks(question, reply_input, reply_target)
        
        ## Transformer outputs
        out = transformer(question, question_mask, reply_input, reply_input_mask)
        
        ## Compute loss
        loss = criterion(out, reply_target, reply_target_mask)
        
        ## Backprop
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()
        
        sum_loss += loss.item() * samples
        count += samples
        
        if i % 10 == 0:
            print("Epoch [{}][{}/{}]\tLoss: {:.3f}".format(epoch, i, len(train_loader), sum_loss/count))

In [None]:
for epoch in range(epochs):
    
    train(train_loader, transformer, criterion, epoch)
    
    state = {'epoch': epoch, 'transformer': transformer, 'transformer_optimizer': transformer_optimizer}
    torch.save(state, 'checkpoint_' + str(epoch) + '.pth.tar')

Epoch [0][0/2217]	Loss: 8.795
Epoch [0][10/2217]	Loss: 8.769
Epoch [0][20/2217]	Loss: 8.739
Epoch [0][30/2217]	Loss: 8.687
Epoch [0][40/2217]	Loss: 8.618
Epoch [0][50/2217]	Loss: 8.527
Epoch [0][60/2217]	Loss: 8.421
Epoch [0][70/2217]	Loss: 8.312
Epoch [0][80/2217]	Loss: 8.205
Epoch [0][90/2217]	Loss: 8.105
Epoch [0][100/2217]	Loss: 8.015
Epoch [0][110/2217]	Loss: 7.928
Epoch [0][120/2217]	Loss: 7.849
Epoch [0][130/2217]	Loss: 7.773
Epoch [0][140/2217]	Loss: 7.700
Epoch [0][150/2217]	Loss: 7.630
Epoch [0][160/2217]	Loss: 7.564
Epoch [0][170/2217]	Loss: 7.499
Epoch [0][180/2217]	Loss: 7.435
Epoch [0][190/2217]	Loss: 7.374
Epoch [0][200/2217]	Loss: 7.313
Epoch [0][210/2217]	Loss: 7.255
Epoch [0][220/2217]	Loss: 7.197
Epoch [0][230/2217]	Loss: 7.142
Epoch [0][240/2217]	Loss: 7.088
Epoch [0][250/2217]	Loss: 7.035
Epoch [0][260/2217]	Loss: 6.985
Epoch [0][270/2217]	Loss: 6.937
Epoch [0][280/2217]	Loss: 6.890
Epoch [0][290/2217]	Loss: 6.845
Epoch [0][300/2217]	Loss: 6.801
Epoch [0][310/2217]