In [1]:
from babi_loader import BabiDataset, pad_collate
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.utils.data import DataLoader


def interpret_indexed_tensor(var):
    qa= dset.QA
    if len(var.size()) == 3:
        # var -> n x #sen x #token
        for n, sentences in enumerate(var):
            for i, sentence in enumerate(sentences):
                s = ' '.join([qa.IVOCAB[elem.data[0]] for elem in sentence if qa.IVOCAB[elem.data[0]] != '<EOS>' or qa.IVOCAB[elem.data[0]] !=  '<PAD>'] )
                print(f'{n}th of batch, {i}th sentence, {s}')
    elif len(var.size()) == 2:
        # var -> n x #token
        for n, sentence in enumerate(var):
            s = ' '.join([qa.IVOCAB[elem.data[0]] for elem in sentence if qa.IVOCAB[elem.data[0]] != '<EOS>' or qa.IVOCAB[elem.data[0]] !=  '<PAD>'] )
            print(f'{n}th of batch, {s}')
    elif len(var.size()) == 1:
        # var -> n (one token per batch)
        for n, token in enumerate(var):
            s = qa.IVOCAB[token.data[0]]
            print(f'{n}th of batch, {s}')



In [2]:

from babi_loader import BabiDataset, pad_collate
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.utils.data import DataLoader

import numpy as np

def position_encoding(embedded_sentence):
    
    sentence_length = embedded_sentence.size()[2]
    embedding_length = embedded_sentence.size()[3]
    shape = (embedding_length, sentence_length)
    l = np.empty(shape)

    for word_index in range(sentence_length):
        for e_index in range(embedding_length):
            l[e_index][word_index]=(1 - word_index/(sentence_length-1)) - (e_index/(embedding_length-1)) * (1 - 2*word_index/(sentence_length-1))
    l=l.T
    l = torch.FloatTensor(l)
    l = l.unsqueeze(0) # for #batch
    l = l.unsqueeze(1) # for #sen
    print("embedded_sentence.size() = ",embedded_sentence.size())
    print("before ",(l.size()))
    l = l.expand_as(embedded_sentence)
    print("after ",(l.size()))
    weighted = embedded_sentence * Variable(l)
    var = torch.sum(weighted, dim=2).squeeze(2)
    print("return size", var.size())
    return torch.sum(weighted, dim=2).squeeze(2) # sum with tokens



class AttentionGRUCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(AttentionGRUCell, self).__init__()
        self.hidden_size = hidden_size
        self.Wr = nn.Linear(input_size, hidden_size)
        init.xavier_normal(self.Wr.state_dict()['weight'])
        self.Ur = nn.Linear(hidden_size, hidden_size)
        init.xavier_normal(self.Ur.state_dict()['weight'])
        self.W = nn.Linear(input_size, hidden_size)
        init.xavier_normal(self.W.state_dict()['weight'])
        self.U = nn.Linear(hidden_size, hidden_size)
        init.xavier_normal(self.U.state_dict()['weight'])

    def forward(self, fact, C, g):
        '''
        fact.size() -> (#batch, #hidden = #embedding)
        c.size() -> (#hidden, ) -> (#batch, #hidden = #embedding)
        r.size() -> (#batch, #hidden = #embedding)
        h_tilda.size() -> (#batch, #hidden = #embedding)
        g.size() -> (#batch, )
        '''

        r = F.sigmoid(self.Wr(fact) + self.Ur(C))
        h_tilda = F.tanh(self.W(fact) + r * self.U(C))
        g = g.unsqueeze(1).expand_as(h_tilda)
        h = g * h_tilda + (1 - g) * C
        return h

class AttentionGRU(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(AttentionGRU, self).__init__()
        self.hidden_size = hidden_size
        self.AGRUCell = AttentionGRUCell(input_size, hidden_size)

    def forward(self, facts, G):
        '''
        facts.size() -> (#batch, #sentence, #hidden = #embedding)
        fact.size() -> (#batch, #hidden = #embedding)
        G.size() -> (#batch, #sentence)
        g.size() -> (#batch, )
        C.size() -> (#batch, #hidden)
        '''
        batch_num, sen_num, embedding_size = facts.size()
        C = Variable(torch.zeros(self.hidden_size))
        for sid in range(sen_num):
            fact = facts[:, sid, :]
            g = G[:, sid]
            if sid == 0:
                C = C.unsqueeze(0).expand_as(fact)
            C = self.AGRUCell(fact, C, g)
        return C

class EpisodicMemory(nn.Module):
    def __init__(self, hidden_size):
        super(EpisodicMemory, self).__init__()
        self.AGRU = AttentionGRU(hidden_size, hidden_size)
        self.z1 = nn.Linear(4 * hidden_size, hidden_size)
        self.z2 = nn.Linear(hidden_size, 1)
        print("hi.....")
        self.next_mem = nn.Linear(3 * hidden_size, hidden_size)
        init.xavier_normal(self.z1.state_dict()['weight'])
        init.xavier_normal(self.z2.state_dict()['weight'])
        init.xavier_normal(self.next_mem.state_dict()['weight'])

    def make_interaction(self, facts, questions, prevM):
        '''
        facts.size() -> (#batch, #sentence, #hidden = #embedding)
        questions.size() -> (#batch, 1, #hidden)
        prevM.size() -> (#batch, #sentence = 1, #hidden = #embedding)
        z.size() -> (#batch, #sentence, 4 x #embedding)
        G.size() -> (#batch, #sentence)
        '''
        batch_num, sen_num, embedding_size = facts.size()
        questions = questions.expand_as(facts)
        prevM = prevM.expand_as(facts)

        z = torch.cat([
            facts * questions,
            facts * prevM,
            torch.abs(facts - questions),
            torch.abs(facts - prevM)
        ], dim=2)

        z = z.view(-1, 4 * embedding_size)

        G = F.tanh(self.z1(z))
        G = self.z2(G)
        G = G.view(batch_num, -1)
        G = F.softmax(G)

        return G

    def forward(self, facts, questions, prevM):
        '''
        facts.size() -> (#batch, #sentence, #hidden = #embedding)
        questions.size() -> (#batch, #sentence = 1, #hidden)
        prevM.size() -> (#batch, #sentence = 1, #hidden = #embedding)
        G.size() -> (#batch, #sentence)
        C.size() -> (#batch, #hidden)
        concat.size() -> (#batch, 3 x #embedding)
        '''
        G = self.make_interaction(facts, questions, prevM)
        value, index = torch.max(G, dim=1)
        print("attentions= ",G)
        
        print("focus = ")
        print((index[0]))
        #print("focus = ",interpret_indexed_tensor(contexts))
        #print("focus = ",(value[0]))

        C = self.AGRU(facts, G)
        concat = torch.cat([prevM.squeeze(1), C, questions.squeeze(1)], dim=1)
        next_mem = F.relu(self.next_mem(concat))
        next_mem = next_mem.unsqueeze(1)
        return next_mem


class QuestionModule(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(QuestionModule, self).__init__()
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    def forward(self, questions, word_embedding):
        '''
        questions.size() -> (#batch, #token)
        word_embedding() -> (#batch, #token, #embedding)
        gru() -> (1, #batch, #hidden)
        '''
        questions = word_embedding(questions)
        _, questions = self.gru(questions)
        questions = questions.transpose(0, 1)
        return questions

class InputModule(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(InputModule, self).__init__()
        self.hidden_size = hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, bidirectional=True, batch_first=True)
        for name, param in self.gru.state_dict().items():
            if 'weight' in name: init.xavier_normal(param)
        self.dropout = nn.Dropout(0.1)

    def forward(self, contexts, word_embedding):
        '''
        contexts.size() -> (#batch, #sentence, #token)
        word_embedding() -> (#batch, #sentence x #token, #embedding)
        position_encoding() -> (#batch, #sentence, #embedding)
        facts.size() -> (#batch, #sentence, #hidden = #embedding)
        '''
        batch_num, sen_num, token_num = contexts.size()

        contexts = contexts.view(batch_num, -1)
        contexts = word_embedding(contexts)

        contexts = contexts.view(batch_num, sen_num, token_num, -1)
        contexts = position_encoding(contexts)
        #print("contexts size = ",contexts.size())
        #contexts = self.dropout(contexts)

        h0 = Variable(torch.zeros(2, batch_num, self.hidden_size))
        facts, hdn = self.gru(contexts, h0)
        facts = facts[:, :, :hidden_size] + facts[:, :, hidden_size:]
        #print("facts size = ",facts.size())
        return facts

class AnswerModule(nn.Module):
    def __init__(self, vocab_size, hidden_size):
        super(AnswerModule, self).__init__()
        self.z = nn.Linear(2 * hidden_size, vocab_size)
        init.xavier_normal(self.z.state_dict()['weight'])
        self.dropout = nn.Dropout(0.1)

    def forward(self, M, questions):
        M = self.dropout(M)
        concat = torch.cat([M, questions], dim=2).squeeze(1)
        z = self.z(concat)
        return z

class DMNPlus(nn.Module):
    def __init__(self, hidden_size, vocab_size, num_hop=3, qa=None):
        super(DMNPlus, self).__init__()
        self.num_hop = num_hop
        self.qa = qa
        self.word_embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0, sparse=True)
        init.uniform(self.word_embedding.state_dict()['weight'], a=-(3**0.5), b=3**0.5)
        self.criterion = nn.CrossEntropyLoss(size_average=False)

        self.input_module = InputModule(vocab_size, hidden_size)
        self.question_module = QuestionModule(vocab_size, hidden_size)
        self.memory = EpisodicMemory(hidden_size)
        self.answer_module = AnswerModule(vocab_size, hidden_size)

    def forward(self, contexts, questions):
        '''
        contexts.size() -> (#batch, #sentence, #token) -> (#batch, #sentence, #hidden = #embedding)
        questions.size() -> (#batch, #token) -> (#batch, 1, #hidden)
        '''
        facts = self.input_module(contexts, self.word_embedding)
        questions = self.question_module(questions, self.word_embedding)
        M = questions
        for hop in range(self.num_hop):
            M = self.memory(facts, questions, M)
        preds = self.answer_module(M, questions)
        return preds

   
    def get_loss(self, contexts, questions, targets):
        output = self.forward(contexts, questions)
        loss = self.criterion(output, targets)
        reg_loss = 0
        for param in self.parameters():
            reg_loss += 0.001 * torch.sum(param * param)
        preds = F.softmax(output)
        _, pred_ids = torch.max(preds, dim=1)
        s = self.qa.IVOCAB[pred_ids.data[0]]
        print("\npredicted answer - ", s)
        corrects = (pred_ids.data == answers.data)
        acc = torch.mean(corrects.float())
        return loss + reg_loss, acc



In [3]:

if __name__ == '__main__':
    task_id = 2
    dset = BabiDataset(task_id)
    vocab_size = len(dset.QA.VOCAB)
    hidden_size = 80
    model = DMNPlus(hidden_size, vocab_size, num_hop=3, qa=dset.QA)
    best_acc = 0
    optim = torch.optim.Adam(model.parameters())
    pretrained_dict_name = "task"+str(task_id)+".pth"
    pretrained_dict = torch.load(pretrained_dict_name, map_location='cpu')
    model_dict = model.state_dict()
    #print(pretrained_dict.keys())
    #print("\n\n\n")
    #print(model_dict.keys())
    dset.set_mode('test')
    test_loader = DataLoader(
        dset, batch_size=1, shuffle=False, collate_fn=pad_collate
    )
    test_acc = 0
    cnt = 0

    for batch_idx, data in enumerate(test_loader):
        contexts, questions, answers = data
        batch_size = contexts.size()[0]
        contexts = Variable(contexts.long())
        questions = Variable(questions.long())
        answers = Variable(answers)
        
        print("contexts -> ")
        interpret_indexed_tensor(contexts)
        print("\n\n questions -> ")

        interpret_indexed_tensor(questions)
        print("\n\n answers -> ")

        interpret_indexed_tensor(answers)


        
        model.load_state_dict(pretrained_dict)
        
        _, acc = model.get_loss(contexts, questions, answers)
        #print(acc)
        break


hi.....
contexts -> 
0th of batch, 0th sentence, mary got the milk there . <EOS> <PAD>
0th of batch, 1th sentence, john moved to the bedroom . <EOS> <PAD>
0th of batch, 2th sentence, sandra went back to the kitchen . <EOS>
0th of batch, 3th sentence, mary travelled to the hallway . <EOS> <PAD>


 questions -> 
0th of batch, where is the milk <EOS>


 answers -> 
0th of batch, hallway
embedded_sentence.size() =  torch.Size([1, 4, 8, 80])
before  torch.Size([1, 1, 8, 80])
after  torch.Size([1, 4, 8, 80])
return size torch.Size([1, 4, 80])
attentions=  Variable containing:
 6.6195e-06  9.4987e-04  1.4478e-04  9.9890e-01
[torch.FloatTensor of size 1x4]

focus = 
Variable containing:
 3
[torch.LongTensor of size 1]

attentions=  Variable containing:
 0.0000  0.0013  0.0003  0.9984
[torch.FloatTensor of size 1x4]

focus = 
Variable containing:
 3
[torch.LongTensor of size 1]

attentions=  Variable containing:
 0.0017  0.0102  0.0016  0.9865
[torch.FloatTensor of size 1x4]

focus = 
Variable 



In [4]:
print(pretrained_dict.keys())

odict_keys(['word_embedding.weight', 'input_module.gru.weight_ih_l0', 'input_module.gru.weight_hh_l0', 'input_module.gru.bias_ih_l0', 'input_module.gru.bias_hh_l0', 'input_module.gru.weight_ih_l0_reverse', 'input_module.gru.weight_hh_l0_reverse', 'input_module.gru.bias_ih_l0_reverse', 'input_module.gru.bias_hh_l0_reverse', 'question_module.gru.weight_ih_l0', 'question_module.gru.weight_hh_l0', 'question_module.gru.bias_ih_l0', 'question_module.gru.bias_hh_l0', 'memory.AGRU.AGRUCell.Wr.weight', 'memory.AGRU.AGRUCell.Wr.bias', 'memory.AGRU.AGRUCell.Ur.weight', 'memory.AGRU.AGRUCell.Ur.bias', 'memory.AGRU.AGRUCell.W.weight', 'memory.AGRU.AGRUCell.W.bias', 'memory.AGRU.AGRUCell.U.weight', 'memory.AGRU.AGRUCell.U.bias', 'memory.z1.weight', 'memory.z1.bias', 'memory.z2.weight', 'memory.z2.bias', 'memory.next_mem.weight', 'memory.next_mem.bias', 'answer_module.z.weight', 'answer_module.z.bias'])
