In [1]:
from collections import Counter
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import math
import torch.nn.functional as F

In [791]:
CUDA_LAUNCH_BLOCKING=1
TORCH_USE_CUDA_DSA=1

In [2]:
corpus_movie_conv = 'data/cornell movie-dialogs corpus/movie_conversations.txt'
corpus_movie_lines = 'data/cornell movie-dialogs corpus/movie_lines.txt'
max_len = 25

In [3]:
with open(corpus_movie_conv, 'r') as c:
    conv = c.readlines()

In [4]:
with open(corpus_movie_lines, 'r') as l:
    lines = l.readlines()

In [5]:
lines_dict = {}
for line in lines:
    objects = line.split(" +++$+++ ")
    lines_dict[objects[0]] = objects[-1]


In [832]:
lines_dict['L1045']

'They do not!\n'

In [6]:
# def remove_punc(string):
#     punctuations = '''!()-[]{};:'"\,<>./?@#$%^&*_~'''
#     no_punctuation_string = ""
#     for char in string:
#         if char not in punctuations:
#             no_punctuation_string = no_punctuation_string + char
#     return no_punctuation_string.lower()
def remove_punc(string):
    punctuations = '''!()-[]{};:'"\,<>./?@#$%^&*_~'''
    no_punct = ""
    for char in string:
        if char not in punctuations:
            no_punct = no_punct + char  # space is also a character
    return no_punct.lower()

In [834]:
ele = remove_punc(lines_dict['L1045'])
ele

'they do not\n'

In [7]:
pairs = []
for con in conv:
    
    ids = eval(con.split(" +++$+++ ")[-1])
    
    for i in range(len(ids)):
        qa_pairs = []

        if i==len(ids)-1:
            break
        
        first = remove_punc(lines_dict[ids[i]].strip())      
        second = remove_punc(lines_dict[ids[i+1]].strip())
        qa_pairs.append(first.split()[:max_len])
        qa_pairs.append(second.split()[:max_len])
        pairs.append(qa_pairs)

        # qa_pairs.append(remove_punc(lines_dict[(ids[i])].strip()).split()[:max_len])
        # qa_pairs.append(remove_punc(lines_dict[(ids[i+1])].strip()).split()[:max_len])
        # pairs.append(qa_pairs)

# pairs
len(pairs)

221616

In [813]:
# len(pairs)

In [729]:
# word_count = Counter()
# for pair in pairs:
#     for word in pair[0]:
#         word_count[word] += 1
#     for word in pair[1]:
#         word_count[word] += 1

In [730]:
# word_count

In [8]:
word_freq = Counter()
for pair in pairs:
    word_freq.update(pair[0])
    word_freq.update(pair[1])

word_freq

Counter({'can': 14103,
         'we': 25912,
         'make': 5821,
         'this': 30502,
         'quick': 310,
         'roxanne': 1,
         'korrine': 1,
         'and': 52128,
         'andrew': 49,
         'barrett': 20,
         'are': 21713,
         'having': 1081,
         'an': 8827,
         'incredibly': 49,
         'horrendous': 4,
         'public': 306,
         'break': 799,
         'up': 14316,
         'on': 23908,
         'the': 120903,
         'quad': 2,
         'again': 2807,
         'well': 16263,
         'i': 137633,
         'thought': 4202,
         'wed': 541,
         'start': 1459,
         'with': 21394,
         'pronunciation': 2,
         'if': 16727,
         'thats': 14742,
         'okay': 5946,
         'you': 169693,
         'not': 26494,
         'hacking': 18,
         'gagging': 9,
         'spitting': 15,
         'part': 1260,
         'please': 3258,
         'then': 7532,
         'how': 14001,
         'bout': 393,
         'try

In [9]:
min_word_freq = 5
words = [word for word in word_freq.keys() if (word_freq[word] > min_word_freq)]
word_map = {k: v+1 for v, k in enumerate(words)}
word_map['<unk>'] = len(word_map)+1
word_map['<start>'] = len(word_map)+1
word_map['<end>'] = len(word_map)+1
word_map['<pad>'] = 0


In [10]:
print("Total words are: {}".format(len(word_map)))

Total words are: 18243


In [733]:
# len(word_map)

In [11]:
with open('data/WORDMAP_corpus.json', 'w') as j:
    json.dump(word_map, j)

In [12]:
def encode_question(words, word_map):
    enc_c = [word_map.get(word, word_map['<unk>']) for word in words] + [word_map['<pad>']]*(max_len-len(words))
    # enc_c = torch.LongTensor(enc_c)
    return enc_c

In [13]:
def encode_reply(words, word_map):
    enc_c = [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in words] + [word_map['<end>']] + [word_map['<pad>']]*(max_len-len(words))
    # enc_c = torch.LongTensor(enc_c)
    return enc_c

In [737]:
# pairs[0][0]

In [738]:
# encode_question(pairs[0][0], word_map)

In [14]:
pairs_encoded = []
for pair in pairs:
    question = encode_question(pair[0], word_map)
    reply = encode_reply(pair[1], word_map)
    pairs_encoded.append([question, reply])

In [740]:
# pairs_encoded

In [15]:
with open('data/pairs_encoded.json', 'w') as w:
    json.dump(pairs_encoded, w)

In [16]:
class Dataset(Dataset):
    def __init__(self):
        self.pairs = json.load(open('data/pairs_encoded.json', 'r'))
        self.dataset_size = len(self.pairs)
        
    def __len__(self):
        return self.dataset_size
    
    def __getitem__(self, index):
        question = torch.LongTensor(self.pairs[index][0])
        reply = torch.LongTensor(self.pairs[index][1])
        return question, reply
        

In [19]:
train_loader = torch.utils.data.DataLoader(Dataset(),
                                            batch_size=100,
                                            shuffle=True,
                                            pin_memory=True)

In [18]:
train_loader.__sizeof__()

24

In [744]:
# question, reply = next(iter(train_loader))

In [745]:
# question.shape

In [846]:
# set device
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [20]:
def create_masks(question, reply_input, reply_target):

    def subsequent_mask(size):
        "Mask out subsequent positions."
        mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        return mask.unsqueeze(0)
    
    question_mask = question!=0
    question_mask = question_mask.to(device)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1) # (batch_size, 1, 1, max_words)

    reply_input_mask = reply_input!=0
    reply_input_mask = reply_input_mask.unsqueeze(1) # (batch_size, 1, max_words)
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1)).type_as(reply_input_mask.data) # (batch_size, max_words, max_words)
    reply_input_mask = reply_input_mask.unsqueeze(1) # (batch_size, 1, max_words, max_words)
    reply_target_mask = reply_target!=0

    return question_mask, reply_input_mask, reply_target_mask



In [748]:
# How subsequent_mask works
# size = 5
# t = torch.ones(size, size)
# t_triu = torch.triu(t)
# t_triu.T # transpose

In [749]:
# question[0] !=0

In [21]:
class Embeddings(nn.Module):
    def __init__(self, vocab_size, d_model, max_len = 50):
        super(Embeddings, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(0.1)
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pe = self.create_positional_encoding(max_len, self.d_model)
        self.dropout = nn.Dropout(0.1)

    def create_positional_encoding(self, max_len, d_model):
        pe = torch.zeros(max_len, d_model).to(device)
        for pos in range(max_len): # for each position of the word in the sentence
            for i in range(0, d_model, 2): # for each dimension of each position in the word embedding
                pe[pos, i] = math.sin(pos/(10000**((2*i)/d_model)))
                pe[pos, i+1] = math.cos(pos/(10000**((2*(i+1))/d_model)))
        pe = pe.unsqueeze(0) # include batch dimension (1, max_len, d_model)
        return pe

    # include forward function when using nn.Module    
    def forward(self, encoded_words):
        embeddings = self.embed(encoded_words) * math.sqrt(self.d_model) # (batch_size, max_words, d_model)
        # max_words = embeddings.size(1)
        embeddings += self.pe[:, :embeddings.size(1)] # pe will automatically be expanded to match the batch size of embeddings matrix (1, max_words, d_model)
        embeddings = self.dropout(embeddings)
        return embeddings

In [22]:
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model):
        super(MultiHeadAttention, self).__init__()
        assert d_model % heads == 0
        self.d_k = d_model // heads
        self.heads = heads
        self.dropout = nn.Dropout(0.1)
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.concat = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        """ 
        query: (batch_size, max_words, d_model); d_model = 512
        key: (batch_size, max_words, d_model)
        value: (batch_size, max_words, d_model)
        """
        batch_size = query.size(0)
        
        # linear layer
        query = self.query(query) # (batch_size, max_words, d_model)
        key = self.key(key)       # (batch_size, max_words, d_model)
        value = self.value(value) # (batch_size, max_words, d_model)
        
        # split into heads
        # (batch_size, max_words, d_model) -> (batch_size, max_words, heads, d_k) -> (batch_size, heads, max_words, d_k)
        query = query.view(batch_size, -1, self.heads, self.d_k)
        key = key.view(batch_size, -1, self.heads, self.d_k)
        value = value.view(batch_size, -1, self.heads, self.d_k)
        
        # transpose to get dimensions (batch_size, heads, max_words, d_k)
        query = query.transpose(1,2)
        key = key.transpose(1,2)
        value = value.transpose(1,2)
        
        # (batch_size, heads, max_words, d_k) dot (batch_size, heads, d_k, max_words) -> (batch_size, heads, max_words, max_words)

        # calculate scores
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores.masked_fill(mask==0, -1e9) # (batch_size, heads, max_words, max_words)
        weights = F.softmax(scores, dim=-1) # (batch_size, heads, max_words, max_words)
        weights = self.dropout(weights)
        
        # (batch_size, heads, max_words, max_words) dot (batch_size, heads, max_words, d_k) -> (batch_size, heads, max_words, d_k)
        context = torch.matmul(weights, value) # (batch_size, heads, max_words, d_k)

        # transpose to get dimensions (batch_size, max_words, heads, d_k)
        # combine last two dimensions to concatenate all heads together
        # (batch_size, 8, max_words, 64) -> (batch_size, max_words, 8, d_k) (batch_size, max_words, 8*64)
        context = context.transpose(1,2).contiguous().view(batch_size, -1, self.heads*self.d_k)
        output = self.concat(context) # (batch_size, max_words, d_model)
        return output

In [752]:
# # showing how mask works
# a = torch.randn(2, 2)
# mask = torch.tensor([[1, 0], [0, 1]])

In [753]:
# a

In [754]:
# mask

In [755]:
# a.masked_fill(mask==0, -1e9)

In [23]:
class FeedForward(nn.Module):
    def __init__(self, d_model, middle_dim = 2048):
        super(FeedForward, self).__init__()

        self.fc1 = nn.Linear(d_model, middle_dim)
        self.fc2 = nn.Linear(middle_dim, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        out = self.fc1(x)
        out = F.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return out
    

In [24]:
class EncoderLayer(nn.Module):

    def __init__(self, d_model, heads):
        super(EncoderLayer, self).__init__()

        self.layer_norm = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, mask):
        self_attn = self.attn(embeddings, embeddings, embeddings, mask) # (query, key, value, mask)
        self_attn = self.dropout(self_attn)
        self_attn = self.layer_norm(self_attn + embeddings)
        ff = self.ff(self_attn)
        ff = self.dropout(ff)
        encoder_out = self.layer_norm(self_attn + ff)
        return encoder_out


In [25]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, heads):
        super(DecoderLayer, self).__init__()

        self.layer_norm = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, embeddings, encoder_out, src_mask, target_mask):
        query = self.attn(embeddings, embeddings, embeddings, target_mask) # (query, key, value, mask)
        query = self.dropout(query)
        query = self.layer_norm(query + embeddings)
        src_attn = self.attn(query, encoder_out, encoder_out, src_mask) # (query, key, value, mask)
        src_attn = self.dropout(src_attn)
        src_attn = self.layer_norm(src_attn + query)
        ff = self.ff(src_attn)
        ff = self.dropout(ff)
        decoder_out = self.layer_norm(src_attn + ff)
        return decoder_out


In [26]:
class Transformer(nn.Module):
    def __init__(self, d_model, heads, num_layers, word_map):
        super(Transformer, self).__init__()

        self.d_model = d_model
        self.vocab_size = len(word_map)
        self.embed = Embeddings(self.vocab_size, d_model)
        self.encoder = nn.ModuleList([EncoderLayer(d_model, heads) for _ in range(num_layers)])
        self.decoder = nn.ModuleList([DecoderLayer(d_model, heads) for _ in range(num_layers)])
        self.logit = nn.Linear(d_model, self.vocab_size)

    def encode(self, src_words, src_mask): # src_words aka question
        src_embeddings = self.embed(src_words)
        for layer in self.encoder:
            src_embeddings = layer(src_embeddings, src_mask)
        return src_embeddings

    def decode(self, tgt_words, tgt_mask, src_embedding, src_mask, ): # tgt_words aka reply
        tgt_embeddings = self.embed(tgt_words)
        for layer in self.decoder:
            tgt_embeddings = layer(tgt_embeddings, src_embedding, src_mask, tgt_mask)
        return tgt_embeddings
    
    def forward(self, src_words, src_mask, tgt_words, tgt_mask):
        encoded = self.encode(src_words, src_mask)
        decoded = self.decode(tgt_words, tgt_mask, encoded, src_mask)
        logits = self.logit(decoded)
        out = F.log_softmax(logits, dim=-1) # include manually because KLDivLoss requires log_softmax unlike cross_entropy
        return out

            

In [27]:
class AdamWarmup:

    def __init__(self, model_size, warmup_steps, optimizer):
        self.model_size = model_size
        self.warmup_steps = warmup_steps
        self.optimizer = optimizer
        self.current_step = 0
        self.lr = 0

    def get_lr(self):
        return self.model_size**(-0.5) * min(self.current_step**(-0.5), self.current_step*self.warmup_steps**(-1.5))
    
    def step(self):
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        self.lr = lr
        # update weights
        self.optimizer.step()

In [28]:
class LossWithLS(nn.Module):
    def __init__(self, size, smoothing):
        super(LossWithLS, self).__init__()
        self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size

    def forward(self, prediction, target, mask):
        """ 
        prediction: (batch_size, max_words, vocab_size)
        target: (batch_size, max_words)
        mask: (batch_size, max_words)
        """
        prediction = prediction.view(-1, prediction.size(-1)) # (batch_size*max_words, vocab_size)
        target = target.contiguous().view(-1) # (batch_size*max_words)
        mask = mask.float()
        mask = mask.view(-1) # (batch_size*max_words)
        labels = prediction.data.clone()
        labels.fill_(self.smoothing / (self.size - 1))
        labels.scatter_(1, target.data.unsqueeze(1), self.confidence)
        loss = self.criterion(prediction, labels)             # (batch_size*max_words, vocab_size)
        loss = (loss.sum(dim=1) * mask).sum() / mask.sum()
        return loss

        

In [762]:
# batch_size = 3
# max_words = 5
# vocab_size = 7
# loss = torch.randn(batch_size * max_words, vocab_size)

In [763]:
# loss.shape

In [764]:
# loss

In [765]:
# loss.sum(1)

In [766]:
# mask = torch.randn(15)
# mask.shape
# mask.float()

In [767]:
# (loss.sum(1) * mask.float()).sum() / mask.sum()

In [768]:
# batch_size = 3
# max_words = 5
# vocab_size = 7
# prediction = torch.randn(batch_size, max_words, vocab_size)

In [769]:
# prediction

In [770]:
# prediction = prediction.view(-1, prediction.shape[-1])

In [771]:
# prediction.shape

In [772]:
# target = torch.LongTensor(batch_size * max_words).random_(0, vocab_size)

In [773]:
# target

In [774]:
# mask = target!=0

In [775]:
# labels = prediction.data.clone()
# labels.shape

In [776]:
# labels.fill_(0.3 / 3 - 1)

In [777]:
# labels.scatter(1, target.data.unsqueeze(1), 0.7)

In [778]:
torch.cuda.is_available()

True

In [779]:
CUDA_LAUNCH_BLOCKING=1

In [29]:
# Define Model, Optimizer, Loss
d_model = 512
heads = 8
num_layers = 3
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 10
CUDA_LAUNCH_BLOCKING=1
TORCH_USE_CUDA_DSA=1
with open('data/WORDMAP_corpus.json', 'r') as j:
    word_map = json.load(j)

transformer = Transformer(d_model=d_model, heads = heads, num_layers = num_layers, word_map = word_map).to(device)
adam_optimizer = torch.optim.Adam(transformer.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
transformer_optimizer = AdamWarmup(d_model, 4000, adam_optimizer)
criterion = LossWithLS(len(word_map), 0.1)



In [30]:
# Train
def train(train_loader, transformer, criterion, epoch):
    transformer.train()
    total_loss = 0
    count = 0
    for i, (question, reply) in enumerate(train_loader):
        samples = question.shape[0]

        # Move to GPU, if available
        question = question.to(device)
        reply = reply.to(device)
        
        # Sentence: <start> I went home . <end>
        # reply_input: <start> I went home .
        # reply_target: I went home . <end>

        # Prepare Target Data
        reply_input = reply[:, :-1]
        reply_target = reply[:, 1:]

        # create masks and add dimensions
        question_mask, reply_input_mask, reply_target_mask = create_masks(question, reply_input, reply_target)

        # Run the Transformer model and get outputs
        out = transformer(question, question_mask, reply_input, reply_input_mask)

        # Calculate loss
        loss = criterion(out, reply_target, reply_target_mask)

        # Backprop
        transformer_optimizer.optimizer.zero_grad()
        loss.backward()
        transformer_optimizer.step()

        total_loss += loss.item() * samples
        count += samples
        
        if i % 100 == 0:
            print("Epoch {} | Batch {}/{} | Loss {}".format(epoch+1, i, len(train_loader), total_loss/count))
    return total_loss / len(train_loader)




In [31]:
# Evaluate
def evaluate(transformer, question, question_mask, max_len, word_map):
    """
    Performs Greedy Decoding with a batch size of 1
    """

    rev_word_map = {v: k for k, v in word_map.items()}
    transformer.eval()
    start_token = word_map['<start>']

    # Encode
    encoded = transformer.encode(question, question_mask)
    words = torch.LongTensor([[start_token]]).to(device) # (batch_size, 1)

    # Decoding starts here
    for step in range(max_len-1):
        size = words.shape[0]

        # create target mask
        tgt_mask = torch.triu(torch.ones(size, size)).transpose(0, 1).type(dtype=torch.uint8)
        tgt_mask = tgt_mask.to(device).unsqueeze(0).unsqueeze(0)

        # decode
        # decoded shape: (batch_size, max_words, d_model) = (1, 1, vocab_size)
        decoded = transformer.decode(words, tgt_mask, encoded, question_mask)

        # predict next word using logits and softmax
        # predictions shape: (max_words, vocab_size) = (1, vocab_size)
        predictions = transformer.logit(decoded[:, -1])
        _, next_word = torch.max(predictions, dim=1) #(1, 1)
        next_word = next_word.item()

        if next_word == word_map['<end>']:
            break
    words = torch.cat([words, torch.LongTensor([[next_word]].to(device))], dim=1) #(1, step+2)

    #(1,5) -> (5)
    words = words.squeeze(0)  # (1, step+1) -> (step+1)
    words = words.tolist()

    # concat words to form sentence
    sentence_idx = [w for w in words if w not in {word_map['<start>']}]
    sentence = ' '.join([rev_word_map[sentence_idx[k]] for k in range(len(sentence_idx))])

    return sentence

In [33]:
for epoch in range(epochs):
    train(train_loader, transformer, criterion, epoch)
    state = {'epoch': epoch, 'transformer': transformer, 'transformer_optimizer': transformer_optimizer}
    torch.save(state, 'checkpoint/checkpoint_' + str(epoch) + '.tar')

Epoch 1 | Batch 0/2217 | Loss 4.394916534423828
Epoch 1 | Batch 100/2217 | Loss 4.433915964447626
Epoch 1 | Batch 200/2217 | Loss 4.4352311305145715
Epoch 1 | Batch 300/2217 | Loss 4.42975094310469
Epoch 1 | Batch 400/2217 | Loss 4.428952117215963
Epoch 1 | Batch 500/2217 | Loss 4.4260578326836315
Epoch 1 | Batch 600/2217 | Loss 4.421902313010268
Epoch 1 | Batch 700/2217 | Loss 4.4195305080114515
Epoch 1 | Batch 800/2217 | Loss 4.418576418534944
Epoch 1 | Batch 900/2217 | Loss 4.418793742849877
Epoch 1 | Batch 1000/2217 | Loss 4.418258470731539
Epoch 1 | Batch 1100/2217 | Loss 4.417297190909598
Epoch 1 | Batch 1200/2217 | Loss 4.415683094408987
Epoch 1 | Batch 1300/2217 | Loss 4.414350040136714
Epoch 1 | Batch 1400/2217 | Loss 4.412398444508587
Epoch 1 | Batch 1500/2217 | Loss 4.409784390082922
Epoch 1 | Batch 1600/2217 | Loss 4.408918334870395
Epoch 1 | Batch 1700/2217 | Loss 4.40789369384658
Epoch 1 | Batch 1800/2217 | Loss 4.4075559116217375
Epoch 1 | Batch 1900/2217 | Loss 4.406056

In [35]:
checkpoint = torch.load('./checkpoint/checkpoint_9.tar')
model = checkpoint['transformer']

In [36]:
while(1):
    question = input("Question: ")
    if question == 'quit':
        break
    max_len = input("Enter Max Words to be generated: ")
    enc_question = [word_map.get(word, word_map['<unk>']) for word in question.split()]
    question = torch.LongTensor(enc_question).to(device).unsqueeze(0)
    question_mask = (question!=0).to(device).unsqueeze(1).unsqueeze(1)
    sentence = evaluate(model, question, question_mask, int(max_len), word_map)
    print(sentence)

AttributeError: 'list' object has no attribute 'to'