In [77]:
import ast
import json
import math
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.utils.data import Dataset
# Use Apple's Metal backend when this PyTorch build and machine support it; otherwise fall back to
# the CPU so the notebook still runs instead of failing at the first tensor moved to the device.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

In [52]:
# Maximum number of tokens kept per utterance (longer ones are truncated, shorter ones padded).
max_len = 25

# Corpus inputs, relative to the notebook's working directory.
corpus_movie_lines = 'archive/movie_lines.txt'
corpus_movie_conv = 'archive/movie_conversations.txt'


In [53]:
# Read both corpus files fully into memory, one list entry per line (newlines are kept).
# The encoding is pinned on both reads so the result does not depend on the platform's default
# locale; iso-8859-1 maps every byte to a character, so decoding itself cannot fail.
# NOTE(review): assumes the conversations file is plain ASCII (only IDs are used from it) - confirm.
with open(corpus_movie_conv, 'r', encoding='iso-8859-1') as conv_file:
    conv = conv_file.readlines()

with open(corpus_movie_lines, 'r', encoding='iso-8859-1') as lines_file:
    lines = lines_file.readlines()

In [54]:
# Index each corpus row by its first " +++$+++ "-separated field, storing the last field as the
# value. The value still carries the trailing newline from readlines(); later rows with the same
# key overwrite earlier ones.
lines_dict = {
    fields[0]: fields[-1]
    for fields in (raw_line.split(" +++$+++ ") for raw_line in lines)
}

In [55]:
def remove_punc(string):
    """Delete a fixed set of ASCII punctuation characters from ``string`` and lowercase the rest.

    Args:
        string: the text to clean.

    Returns:
        The cleaned, lowercased text prefixed with a single space. The prefix is kept on purpose so
        the output stays identical to earlier runs; the callers in this notebook tokenise with
        ``.split()``, which ignores it.
    """
    # Raw string: the set contains a literal backslash, and a bare "\," in a normal string literal
    # is an invalid escape sequence that newer Python versions warn about.
    punctuations = r'''!()-[]{};:'"\,<>./?@#$%^&*'''
    # A single translate() pass replaces the per-character string concatenation loop.
    return " " + string.translate(str.maketrans('', '', punctuations)).lower()

In [56]:
# Build [question_tokens, reply_tokens] pairs from every two consecutive lines of each conversation.
pairs = []
for con in conv:
    # The last " +++$+++ " field is a Python list literal of line IDs. ast.literal_eval accepts
    # literals only, so - unlike eval() - text in the corpus file can never be executed as code.
    ids = ast.literal_eval(con.split(" +++$+++ ")[-1])

    # zip(ids, ids[1:]) walks adjacent (line, next line) IDs and yields nothing when a
    # conversation has fewer than two lines.
    for first_id, second_id in zip(ids, ids[1:]):
        first = remove_punc(lines_dict[first_id].strip())
        second = remove_punc(lines_dict[second_id].strip())
        # Whitespace-tokenise and keep at most max_len tokens per side.
        pairs.append([first.split()[:max_len], second.split()[:max_len]])

### Remove words that occur 5 times or fewer


In [57]:
# Count every token across both sides of every pair. Tokens are visited pair by pair, question
# side first, so the counter's key order is the order of first appearance in the corpus.
word_freq = Counter(
    token
    for question_tokens, reply_tokens in pairs
    for token in (*question_tokens, *reply_tokens)
)
    

In [58]:
min_word_freq = 5

# Keep tokens counted strictly more than min_word_freq times, in first-seen order.
words = [token for token, count in word_freq.items() if count > min_word_freq]

# Kept tokens get indices 1..len(words); the three special tokens take the next indices in the
# order listed, and '<pad>' is pinned to index 0.
word_map = {token: index for index, token in enumerate(words, start=1)}
for special_token in ('<unk>', '<start>', '<end>'):
    word_map[special_token] = len(word_map) + 1
word_map['<pad>'] = 0

In [59]:
# Write the token -> index mapping to disk as JSON.
with open('WORDMAP_corpus.json', 'w') as word_map_file:
    json.dump(word_map, word_map_file)

In [60]:
def encode_question(words,word_map):
    """Convert question tokens to vocabulary indices, right-padded with the '<pad>' index.

    Tokens absent from ``word_map`` are replaced by the '<unk>' index. Padding fills the list up to
    the module-level ``max_len``; a list already at or beyond that length gets no padding and is
    returned at its own length.
    """
    encoded = []
    for token in words:
        encoded.append(word_map.get(token, word_map['<unk>']))

    # A non-positive repeat count yields an empty list, so long inputs simply get no padding.
    encoded.extend([word_map['<pad>']] * (max_len - len(words)))
    return encoded

In [61]:
def encode_reply(words,word_map):
    """Convert reply tokens to indices wrapped in '<start>' / '<end>' and right-padded with '<pad>'.

    Tokens absent from ``word_map`` are replaced by the '<unk>' index. The padding count is
    ``max_len - len(words)`` (module-level ``max_len``), so a reply of at most ``max_len`` tokens
    comes out with ``max_len + 2`` entries.
    """
    encoded = [word_map['<start>']]
    for token in words:
        encoded.append(word_map.get(token, word_map['<unk>']))
    encoded.append(word_map['<end>'])

    # A non-positive repeat count yields an empty list, so long inputs simply get no padding.
    encoded.extend([word_map['<pad>']] * (max_len - len(words)))
    return encoded

In [62]:
## Testing
# encode_question(pairs[0][0],word_map)

In [63]:
# Turn every token pair into a pair of index lists: [encoded question, encoded reply].
pairs_encoded = [
    [encode_question(question_tokens, word_map), encode_reply(reply_tokens, word_map)]
    for question_tokens, reply_tokens in pairs
]
    


In [64]:
# Write the encoded pairs to disk as JSON; the Dataset class below reads this file back.
with open('pairs_encoded.json', 'w') as pairs_file:
    json.dump(pairs_encoded, pairs_file)
    

In [65]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset over encoded (question, reply) index lists stored in a JSON file.

    The base class is spelled out in full because this class reuses the name ``Dataset``:
    inheriting from the bare name would make the class subclass its own previous definition
    whenever the cell is re-run.

    Args:
        pairs_path: JSON file holding a list of ``[question_ids, reply_ids]`` entries.
            Defaults to the file written earlier in this notebook.
    """

    def __init__(self, pairs_path='pairs_encoded.json'):
        # The context manager closes the file as soon as parsing finishes.
        with open(pairs_path) as pairs_file:
            self.pairs = json.load(pairs_file)
        self.dataset_size = len(self.pairs)

    def __getitem__(self,i):
        """Return the ``i``-th pair as two int64 tensors: ``(question, reply)``."""
        question = torch.LongTensor(self.pairs[i][0])
        reply = torch.LongTensor(self.pairs[i][1])

        return question , reply

    def __len__(self):
        """Return the number of pairs loaded from the file."""
        return self.dataset_size
        

In [66]:
train_loader = torch.utils.data.DataLoader(Dataset(),
                                          batch_size = 100,
                                          shuffle = True,
                                          # Page-locked host memory only speeds up host -> CUDA
                                          # copies, so request it only when training on CUDA.
                                          pin_memory = (device.type == 'cuda'))

# Pull a single batch to sanity-check the tensor shapes.
question,reply = next(iter(train_loader))
question.shape

torch.Size([100, 25])

In [2]:
def create_masks(question,reply_input,reply_target):
    """Build boolean padding / look-ahead masks for a batch, treating index 0 as padding.

    Args:
        question: integer tensor of question indices.
        reply_input: integer tensor of reply indices that will be fed in as input.
        reply_target: integer tensor of reply indices used as the prediction target.

    Returns:
        A tuple ``(question_mask, reply_input_mask, reply_target_mask)``:
          * question_mask - ``question != 0`` moved to the module-level ``device``, with two
            singleton axes inserted after the first axis ((batch, 1, 1, len) for a (batch, len) input).
          * reply_input_mask - True where the key position is not padding AND is not after the
            query position ((batch, 1, len, len) for a (batch, len) input).
          * reply_target_mask - ``reply_target != 0``, same shape as ``reply_target``.
    """
    def subsequent_mask(size):
        # Ones on and below the diagonal: position i may look at positions 0..i only. The mask is
        # created on reply_input's device so the `&` below never mixes devices.
        mask = torch.ones(size, size, device=reply_input.device).tril().bool()
        return mask.unsqueeze(0)  # (1, size, size)

    question_mask = (question != 0).to(device)
    question_mask = question_mask.unsqueeze(1).unsqueeze(1)  # (batch_size, 1, 1, max_words)

    reply_input_mask = (reply_input != 0).unsqueeze(1)  # (batch_size, 1, max_words)
    # Broadcasting the padding mask against the triangular mask gives (batch_size, max_words, max_words)
    reply_input_mask = reply_input_mask & subsequent_mask(reply_input.size(-1))
    reply_input_mask = reply_input_mask.unsqueeze(1)  # add an axis after the batch axis
    reply_target_mask = reply_target != 0

    return question_mask,reply_input_mask,reply_target_mask

In [80]:
# Scratch check of the look-ahead pattern: transposing the upper triangle of an all-ones matrix
# leaves ones on and below the diagonal.
size = 5
mask = torch.ones(size, size).triu().t()


In [None]:
class Embedding(nn.Module):
    """Token embedding scaled by sqrt(d_model), plus a fixed sinusoidal position table, then dropout.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: width of each embedding vector.
        max_len: number of positions the positional table covers; longer inputs are rejected.
    """
    def __init__(self,vocab_size,d_model, max_len = 50):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size,d_model)
        # Registered as a non-persistent buffer: it follows the module through .to(...) like the
        # weights do, but stays out of state_dict because it is recomputed on construction.
        self.register_buffer('pe', self.create_positinal_encoding(max_len, self.d_model), persistent=False)
        self.dropout = nn.Dropout(0.1)

    def create_positinal_encoding(self, max_len, d_model):
        """Return the (1, max_len, d_model) float32 table of sin/cos position features (on CPU)."""
        # Filled as plain Python floats and converted once, instead of writing tensor elements
        # one at a time.
        table = [[0.0] * d_model for _ in range(max_len)]
        for pos in range(max_len):   # for each position of the word
            for i in range(0, d_model, 2):   # for each (even, odd) pair of dimensions
                # NOTE(review): the exponent uses 2*i/d_model while i already advances in steps of
                # 2, which differs from the usual i/d_model formulation; kept as written - confirm.
                table[pos][i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
                if i + 1 < d_model:   # an odd d_model has no partner column for the last sine
                    table[pos][i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))
        pe = torch.tensor(table, dtype=torch.float32)
        return pe.unsqueeze(0)   # leading axis of size 1 so the table broadcasts over the batch

    def forward(self, encoded_words):
        """Embed ``encoded_words`` (batch_size, max_words) into (batch_size, max_words, d_model).

        Raises:
            ValueError: if the sequence is longer than the positional table.
        """
        embeddings = self.embed(encoded_words) * math.sqrt(self.d_model) # (batch_size, max_words, d_model)
        seq_len = embeddings.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional encoding length {self.pe.size(1)}"
            )
        # Slice the first seq_len positions (not a single index); the size-1 batch axis broadcasts.
        embeddings = embeddings + self.pe[:, :seq_len].to(embeddings.device)
        embeddings = self.dropout(embeddings)
        return embeddings
        
        
    

In [1]:
class MultiHeadAttention(nn.Module):
    """Parameter container for multi-head attention; only the constructor exists so far.

    Args:
        head: number of attention heads; must divide ``d_model`` evenly.
        d_model: model width, split evenly across the heads.

    NOTE(review): the right-hand sides of ``query`` / ``key`` / ``value`` were blank in the
    notebook. They are filled in with ``d_model -> d_model`` linear projections, the conventional
    choice - confirm this is what was intended. No ``forward`` is defined yet.
    """
    def __init__(self,head,d_model):
        super(MultiHeadAttention,self).__init__()
        # The parameter is named `head`; the body previously referred to an undefined `heads`.
        assert d_model % head == 0
        self.d_k = d_model // head   # width handled by each head
        self.heads = head
        self.dropout = nn.Dropout(0.1)
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        
        

In [None]:
! git add .
