In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from nltk.tokenize import word_tokenize
from collections import defaultdict
import torch.nn.init as init
import re
import os

In [2]:
def read_data(fname, word2idx, max_words, max_sentences):
    # stories[story_ind] = [[sentence1], [sentence2], ..., [sentenceN]]
    # questions[question_ind] = {'question': [question], 'answer': [answer], 'story_index': #, 'sentence_index': #}
    stories = dict()
    questions = dict()
    
    if len(word2idx) == 0:
        word2idx['<null>'] = 0

    
    if os.path.isfile(fname):
        with open(fname) as f:
            lines = f.readlines()
    else:
        raise Exception("[!] Data {file} not found".format(file=fname))

    for line in lines:
        words = line.split()
        max_words = max(max_words, len(words))
        
        # Determine whether the line indicates the start of a new story
        if words[0] == '1':
            story_ind = len(stories)
            sentence_ind = 0
            stories[story_ind] = []
        
        # Determine whether the line is a question or not
        if '?' in line:
            is_question = True
            question_ind = len(questions)
            questions[question_ind] = {'question': [], 'answer': [], 'story_index': story_ind, 'sentence_index': sentence_ind}
        else:
            is_question = False
            sentence_ind = len(stories[story_ind])
        
        # Parse and append the words to appropriate dictionary / Expand word2idx dictionary
        sentence_list = []
        for k in range(1, len(words)):
            w = words[k].lower()
            
            # Remove punctuation
            if ('.' in w) or ('?' in w):
                w = w[:-1]
            
            # Add new word to dictionary
            if w not in word2idx:
                word2idx[w] = len(word2idx)
            
            # Append sentence to story dict if not question
            if not is_question:
                sentence_list.append(w)
                
                if '.' in words[k]:
                    stories[story_ind].append(sentence_list)
                    break
            
            # Append sentence and answer to question dict if question
            else:
                sentence_list.append(w)
                
                if '?' in words[k]:
                    answer = words[k + 1].lower()
                    
                    if answer not in word2idx:
                        word2idx[answer] = len(word2idx)
                    
                    questions[question_ind]['question'].extend(sentence_list)
                    questions[question_ind]['answer'].append(answer)
                    break
        
        # Update max_sentences
        max_sentences = max(max_sentences, sentence_ind)
    
    
    
    # Convert the words into indices
    for idx, context in stories.items():
        for i in range(len(context)):
            temp = list(map(word2idx.get, context[i]))
            context[i] = temp
    
    for idx, value in questions.items():
        temp1 = list(map(word2idx.get, value['question']))
        temp2 = list(map(word2idx.get, value['answer']))
        
        value['question'] = temp1
        value['answer'] = temp2
    
    return stories, questions, max_words, max_sentences


def pad_data(stories, questions, max_words, max_sentences):

    # Pad the context into same size with '<null>'
    for idx, context in stories.items():
        for sentence in context:           
            while len(sentence) < max_words:
                sentence.append(0)
        while len(context) < max_sentences:
            context.append([0] * max_words)
    
    # Pad the question into same size with '<null>'
    for idx, value in questions.items():
        while len(value['question']) < max_words:
            value['question'].append(0)


def depad_data(stories, questions):

    for idx, context in stories.items():
        for i in range(len(context)):
            if 0 in context[i]:
                if context[i][0] == 0:
                    temp = context[:i]
                    context = temp
                    break
                else:
                    index = context[i].index(0)
                    context[i] = context[i][:index]

    for idx, value in questions.items():
        if 0 in value['question']:
            index = value['question'].index(0)
            value['question'] = value['question'][:index]

In [3]:
# stories[story_ind] = [[sentence1], [sentence2], ..., [sentenceN]]
# questions[question_ind] = {'question': [question], 'answer': [answer], 'story_index': #, 'sentence_index': #}
data_path = "./data/Memory_End_to-End_Network_Project/bAbi-tasks/QA_bAbI_tasks/en/"
word2idx = {}
max_words = 0
max_sentences = 0
train_fname = data_path + "qa1_single-supporting-fact_train.txt"
test_fname = data_path + "qa1_single-supporting-fact_test.txt"

train_stories, train_questions, max_words, max_sentences = \
    read_data(train_fname, word2idx, max_words, max_sentences)
test_stories, test_questions, max_words, max_sentences = \
    read_data(test_fname, word2idx, max_words, max_sentences)
    
pad_data(train_stories, train_questions, max_words, max_sentences)
pad_data(test_stories, test_questions, max_words, max_sentences)

idx2word = dict(zip(word2idx.values(), word2idx.keys()))

nwords = len(word2idx)

config = defaultdict()
config['vocab_size'] = nwords
config['max_words'] = max_words
config['embedding_size'] = 30
config['n_hops'] = 3
# config['mem_size'] = 50
config['batch_size'] = 32
config['nepoch'] = 100
config['anneal_epoch'] = 25  # anneal the learning rate every <anneal_epoch> epochs
config['babi_task'] = 1  # index of bAbI task for the network to learn
config['init_lr'] = 0.01
config['anneal_rate'] = 0.5
config['init_mean'] = 0.0
config['init_std'] = 0.1
config['max_grad_norm'] = 40  # clip gradients to this norm
config['lin_start'] = False  # True for linear start training, False for otherwise
config['is_test'] = False


In [4]:
# def position_encoding(sentence_size, embedding_dim):
#     encoding = np.ones((embedding_dim, sentence_size), dtype=np.float32)
#     ls = sentence_size + 1
#     le = embedding_dim + 1
#     for i in range(1, le):
#         for j in range(1, ls):
#             encoding[i-1, j-1] = (i - (embedding_dim+1)/2) * (j - (sentence_size+1)/2)
#     encoding = 1 + 4 * encoding / embedding_dim / sentence_size
#     # Make position encoding of time words identity to avoid modifying them
#     encoding[:, -1] = 1.0
#     return np.transpose(encoding)

In [5]:
class AttrProxy(object):
    """
    Translates index lookups into attribute lookups.
    To implement some trick which able to use list of nn.Module in a nn.Module
    see https://discuss.pytorch.org/t/list-of-nn-module-in-a-nn-module/219/2
    """
    def __init__(self, module, prefix):
        self.module = module
        self.prefix = prefix

    def __getitem__(self, i):
        return getattr(self.module, self.prefix + str(i))

In [17]:
class E2EMN(nn.Module):
    def __init__(self, config):
        super(E2EMN, self).__init__()
        
        self.vocab_size = config['vocab_size']
        self.embedding_size = config['embedding_size']
        self.n_hops = config['n_hops']
        self.batch_size = config['batch_size']
        self.nepoch = config['nepoch']
        self.anneal_epoch = config['anneal_epoch']
        self.babi_task = config['babi_task']
        self.init_lr = config['init_lr']
        self.anneal_rate = config['anneal_rate']
        self.init_mean = config['init_mean']
        self.init_std = config['init_std']
        self.max_grad_norm = config['max_grad_norm']
        self.lin_start = config['lin_start']
        self.is_test = config['is_test']
        self.sentence_size = config['max_words']
        
        for hop in range(self.n_hops+1):
            C = nn.Embedding(self.vocab_size, self.embedding_size, padding_idx=0)
            C.weight.data.normal_(0, 0.1)
            self.add_module('C_{}'.format(hop), C)
        self.C = AttrProxy(self, "C_")
        
        self.encoder = nn.Embedding(self.vocab_size, self.embedding_size)
        self.softmax = nn.Softmax(dim=0)            
        self.encoding = Variable(torch.FloatTensor(
            (self.sentence_size, self.embedding_size)), requires_grad=True)
                                                  
            
    def forward(self, story, query):
        
        u = []
        query_embed = self.C[0](query)
        
        u.append(query_embed.unsqueeze(0).sum(dim=1))  # u (1, embedding_size)

        for hop in range(self.n_hops):
            
            embed_A = self.C[hop](story)  # story_size, sentence_size, embedding_size
            embed_A = embed_A.view()
            print(embed_A)
            embed_A = self.encoding(embed_A)
            print(embed_A.shape)
            m_A = torch.sum(embed_A, 1)  # story_size, embedding_size
            prob = self.softmax(m_A @ u[-1].t())
            
            embed_C = self.C[hop](story)  # story_size, sentence_size, embedding_size
            m_C = torch.sum(embed_C, 1)  # story_size, embedding_size
            
            o_k = torch.sum(m_C*prob, 0).unsqueeze(0)  # 1, embedding_size
            
            u_k = u[-1] + o_k
            
        a_hat = self.C[self.n_hops].weight @ u[-1].t()  # vocab_size, 1
        return a_hat.t(), self.softmax(a_hat).t()

In [18]:
mem_network = E2EMN(config)

---

In [19]:
for (story, query, answer) in making_data_set(train_stories, train_questions):
    break

In [24]:
embed_A = mem_network.encoder(story)

In [30]:
embed_A.data.numpy()

array([[[ -7.54053831e-01,  -4.90354300e-01,   3.67807090e-01,
           1.88888717e+00,   1.31979418e+00,  -4.67547089e-01,
           1.57792962e+00,  -1.18782997e+00,   5.67408860e-01,
          -3.37423861e-01,   1.14038646e+00,   1.50130916e+00,
          -9.51272488e-01,   9.33103681e-01,   9.21080470e-01,
           1.26116204e+00,  -2.43288890e-01,  -2.59721041e-01,
           2.84962088e-01,   9.28212523e-01,   1.76599705e+00,
           5.34330234e-02,  -9.47482586e-02,   1.00768852e+00,
           5.00619531e-01,  -8.10726643e-01,  -5.69559276e-01,
          -1.37073264e-01,  -2.05137148e-01,  -4.20069471e-02],
        [ -1.45136309e+00,   8.78992200e-01,   8.56548786e-01,
           2.21644759e-01,   6.13842547e-01,  -3.68373394e-01,
          -1.07878304e+00,  -1.16971470e-01,  -7.10955501e-01,
          -1.57969505e-01,   4.52605546e-01,  -6.66734397e-01,
           5.13251126e-01,  -3.20417657e-02,  -5.75110614e-01,
           1.62349179e-01,  -5.08150220e-01,   4.98368

In [31]:
mem_network.encoding(embed_A.data.numpy())

TypeError: 'Variable' object is not callable

---

In [12]:
len(train_stories), len(train_questions), len(test_stories), len(test_questions)

(200, 1000, 200, 1000)

In [13]:
def data_tensor(data):
    return Variable(torch.LongTensor(data))

In [14]:
def making_data_set(stories, questions, mode='single_task'):
    total_set_number = len(questions.keys())
#     data = []
    for k in range(total_set_number):
        q = questions[k]
        story = stories[q['story_index']][q['sentence_index']-1 : q['sentence_index']+1]
        query = q['question']
        answer = q['answer']
        yield (data_tensor(story), data_tensor(query), data_tensor(answer))
#         data.append((data_tensor(story), data_tensor(query), data_tensor(answer)))
#     return data

In [43]:
optimizer = torch.optim.Adam(mem_network.parameters(), lr=0.01)
loss_F = nn.CrossEntropyLoss()

In [None]:
from torch.utils.data import DataLoader
train_loader = DataLoader(train_data,
                           batch_size=config['batch_size'],
                           num_workers=1,
                           shuffle=True)

In [96]:
def train_single_epoch(train_stories, train_questions):
    for (story, query, answer) in making_data_set(train_stories, train_questions):
        optimizer.zero_grad()
        a_hat, probs = mem_network.forward(story, query)
        
        loss = loss_F(a_hat, answer)
        loss.backward()
        
        optimizer.step()
        
    return loss.data[0]

In [97]:
def evaluate(stories, questions, model):
    acc = 0
    len_data_set = len(questions)
    for (story, query, answer) in making_data_set(stories, questions):
        a_hat, pred_probs = model.forward(story, query)
        pred = pred_probs.data.max(1)[1]
        acc += pred.eq(answer.data).sum()
    
    return acc / len_data_set

In [98]:
for epoch in range(config['nepoch']):
    loss = train_single_epoch(train_stories, train_questions)
    
    if (epoch+1) % 10 == 0:
        train_acc = evaluate(train_stories, train_questions, mem_network)
        test_acc = evaluate(test_stories, test_questions, mem_network)
        print('#{0}: loss: {1:.6f} | train_acc: {2} | test_acc: {3}'.format(
                                                            epoch+1, loss, train_acc, test_acc))


#10: loss: 15.244619 | train_acc: 0.164 | test_acc: 0.162
#20: loss: 24.808598 | train_acc: 0.155 | test_acc: 0.171
#30: loss: 3.690599 | train_acc: 0.151 | test_acc: 0.165
#40: loss: 6.681912 | train_acc: 0.173 | test_acc: 0.169
#50: loss: 2.834771 | train_acc: 0.167 | test_acc: 0.172
#60: loss: 15.150473 | train_acc: 0.166 | test_acc: 0.182
#70: loss: 9.215902 | train_acc: 0.163 | test_acc: 0.177
#80: loss: 1.771870 | train_acc: 0.177 | test_acc: 0.175
#90: loss: 26.715284 | train_acc: 0.163 | test_acc: 0.176
#100: loss: 22.503906 | train_acc: 0.166 | test_acc: 0.182
