In [1]:
from babi_loader import BabiDataset, pad_collate
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.utils.data import DataLoader



# --- configuration ---------------------------------------------------------
task_id = 1      # passed straight to BabiDataset below
batch_size = 2   # mini-batch size requested from the DataLoader

# --- data ------------------------------------------------------------------
dset_train = BabiDataset(task_id)
train_loader = DataLoader(
    dset_train,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=pad_collate,
)

# Keep only the first mini-batch: the rest of the notebook walks this single
# batch through every module, so the loop is left after one iteration.
for batch_idx, data in enumerate(train_loader):
    contexts, questions, answers = data
    # Rebinds the global batch_size to the size of the batch actually received.
    batch_size = contexts.size(0)
    contexts, questions = Variable(contexts.long()), Variable(questions.long())
    answers = Variable(answers)
    break



def _token_id(elem):
    """Return the vocabulary index stored in a single-element tensor/Variable."""
    data = getattr(elem, 'data', elem)
    # Iterating a tensor yields 0-dim tensors on PyTorch >= 0.4; those reject
    # `[0]` indexing and must be read with `.item()`. Older releases yield
    # 1-element tensors that have no `.item()` but accept `[0]`.
    return data.item() if hasattr(data, 'item') else data[0]


def interpret_indexed_tensor(var):
    """Print the words behind a tensor of vocabulary indices.

    Indices are looked up in ``dset_train.QA.IVOCAB`` (index -> word).

    Args:
        var: tensor/Variable of token indices with one of these layouts:
            3-D -> n x #sen x #token (one line per sentence),
            2-D -> n x #token        (one line per batch item),
            1-D -> n                 (one token per batch item).
            Any other rank prints nothing.

    Returns:
        None. The decoded text is written to stdout, so callers should not
        wrap this call in ``print``.
    """
    ivocab = dset_train.QA.IVOCAB

    def decode(tokens):
        # Space-joined words for one row of token indices.
        return ' '.join([ivocab[_token_id(tok)] for tok in tokens])

    rank = len(var.size())
    if rank == 3:
        # var -> n x #sen x #token
        for n, sentences in enumerate(var):
            for i, sentence in enumerate(sentences):
                print(f'{n}th of batch, {i}th sentence, {decode(sentence)}')
    elif rank == 2:
        # var -> n x #token
        for n, sentence in enumerate(var):
            print(f'{n}th of batch, {decode(sentence)}')
    elif rank == 1:
        # var -> n (one token per batch)
        for n, token in enumerate(var):
            print(f'{n}th of batch, {ivocab[_token_id(token)]}')






In [2]:
#Demo
# Decode the first mini-batch back into words to sanity-check the loader
# output: contexts are 3-D, questions 2-D and answers 1-D index tensors
# (see the branches of interpret_indexed_tensor).
print("contexts -> ")
interpret_indexed_tensor(contexts)
print("\n\n questions -> ")

interpret_indexed_tensor(questions)
print("\n\n answers -> ")

interpret_indexed_tensor(answers)


contexts -> 
0th of batch, 0th sentence, mary moved to the bathroom . <EOS> <PAD>
0th of batch, 1th sentence, john went to the hallway . <EOS> <PAD>
0th of batch, 2th sentence, <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
0th of batch, 3th sentence, <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
1th of batch, 0th sentence, mary moved to the bathroom . <EOS> <PAD>
1th of batch, 1th sentence, john went to the hallway . <EOS> <PAD>
1th of batch, 2th sentence, daniel went back to the hallway . <EOS>
1th of batch, 3th sentence, sandra moved to the garden . <EOS> <PAD>


 questions -> 
0th of batch, where is mary <EOS>
1th of batch, where is daniel <EOS>


 answers -> 
0th of batch, bathroom
1th of batch, hallway


In [3]:
import sys, os

# Stream that was active when this cell ran; enablePrint() restores it.
var_sysout = sys.stdout

# The os.devnull handle currently installed as sys.stdout, or None while
# printing is enabled. Tracked so the handle can be closed again.
_devnull_sink = None


# Disable
def blockPrint():
    """Silence ``print`` by pointing ``sys.stdout`` at ``os.devnull``.

    Calling this repeatedly reuses the already-open handle instead of opening
    (and leaking) a new file object each time. Calls do not nest: a single
    ``enablePrint()`` restores output no matter how often this was called.
    """
    global _devnull_sink
    if _devnull_sink is None or _devnull_sink.closed:
        _devnull_sink = open(os.devnull, 'w')
    sys.stdout = _devnull_sink


# Restore
def enablePrint():
    """Restore ``sys.stdout`` to ``var_sysout`` and close the devnull handle.

    Safe to call when printing is already enabled.
    """
    global _devnull_sink
    sys.stdout = var_sysout
    if _devnull_sink is not None:
        _devnull_sink.close()
        _devnull_sink = None


In [4]:
#Demo
# blockPrint()/enablePrint() are used later to silence the shape-tracing
# prints of the individual modules; this cell shows the toggle in action.

print ('This will print')

blockPrint()

# Written while stdout points at os.devnull, so it never reaches the notebook.
print ("This won't")

enablePrint()

print ("This will too")


This will print
This will too


In [5]:
# Show the word -> index vocabulary of the dataset; index 0 is <PAD> and
# index 1 is <EOS> in the output below.
qa= dset_train.QA
qa.VOCAB

{'.': 7,
 '<EOS>': 1,
 '<PAD>': 0,
 'back': 14,
 'bathroom': 6,
 'bedroom': 20,
 'daniel': 13,
 'garden': 16,
 'hallway': 10,
 'is': 12,
 'john': 8,
 'journeyed': 18,
 'kitchen': 21,
 'mary': 2,
 'moved': 3,
 'office': 17,
 'sandra': 15,
 'the': 5,
 'to': 4,
 'travelled': 19,
 'went': 9,
 'where': 11}

In [6]:
# Number of distinct tokens; sizes the embedding table and the answer logits.
vocab_size = len(dset_train.QA.VOCAB)
# Shared width: used both as the embedding dimension and as the hidden size
# of every GRU / linear layer in the cells below.
hidden_size = 80

In [7]:
import numpy as np

def position_encoding(embedded_sentence):
    """Collapse each sentence into one vector using positional weights.

    Every token embedding is multiplied element-wise by a weight that depends
    on the token position j and the embedding dimension d,

        l[j, d] = (1 - j/(J-1)) - (d/(D-1)) * (1 - 2*j/(J-1)),

    with J = #token and D = #embedding (both 0-indexed), and the weighted
    tokens of a sentence are summed.

    Args:
        embedded_sentence: Variable of size
            (#batch, #sentence, #token, #embedding).

    Returns:
        Variable of size (#batch, #sentence, #embedding).
    """
    sentence_length = embedded_sentence.size()[2]
    embedding_length = embedded_sentence.size()[3]

    # Relative positions in [0, 1]. max(..., 1) keeps a one-token sentence or
    # a 1-d embedding from dividing by zero; the only index is then 0, so the
    # relative position is 0 either way.
    word_pos = np.arange(sentence_length, dtype=np.float64) / max(sentence_length - 1, 1)
    emb_pos = np.arange(embedding_length, dtype=np.float64) / max(embedding_length - 1, 1)

    # Broadcast to the full (#token, #embedding) weight table in one step
    # instead of filling it cell by cell in Python.
    l = (1 - word_pos)[:, None] - emb_pos[None, :] * (1 - 2 * word_pos)[:, None]
    l = torch.from_numpy(l).float()
    l = l.unsqueeze(0) # for #batch
    l = l.unsqueeze(1) # for #sen
    print("embedded_sentence.size() = ",embedded_sentence.size())
    print("before ",(l.size()))
    l = l.expand_as(embedded_sentence)
    print("after ",(l.size()))
    weighted = embedded_sentence * Variable(l)
    # Sum over the token axis; computed once and reused for both the trace
    # print and the return value. squeeze(2) is kept from the original and
    # leaves the result unchanged whenever #embedding > 1.
    var = torch.sum(weighted, dim=2).squeeze(2)
    print("return size", var.size())
    return var



In [8]:
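# Input fusion layer: converts the position-encoded sentence vectors f into bidirectional facts (forward + backward GRU outputs summed)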
def input_module(contexts, word_embedding):
    """Turn indexed context sentences into one fact vector per sentence.

    Steps: embed every token, collapse each sentence with
    ``position_encoding``, apply dropout, run a bidirectional GRU over the
    sentences and add the forward and backward outputs. Shapes are printed
    along the way.

    The GRU and the dropout layer are created inside this function, so each
    call works with a fresh, randomly initialised GRU.

    Args:
        contexts: Variable of token indices, (#batch, #sentence, #token).
        word_embedding: embedding layer applied to the token indices; its
            output width is expected to equal the global ``hidden_size``.

    Returns:
        Variable of size (#batch, #sentence, hidden_size).
    """
    gru = nn.GRU(hidden_size, hidden_size, bidirectional=True, batch_first=True)
    for key, tensor in gru.state_dict().items():
        if 'weight' not in key:
            continue  # biases keep their default initialisation
        init.xavier_normal(tensor)

    dropout = nn.Dropout(0.1)

    print("context dimension before word embedding",contexts.size())

    batch_num, sen_num, token_num = contexts.size()
    # The embedding is applied to a 2-D view: all tokens of a story in one row.
    embedded = word_embedding(contexts.view(batch_num, -1))
    print("context dimension after word embedding",embedded.size())

    # Restore the sentence/token axes; the trailing axis is the embedding.
    embedded = embedded.view(batch_num, sen_num, token_num, -1)
    print("context dimension ",embedded.size())

    encoded = dropout(position_encoding(embedded))
    print("contexts dimension after position_encoding = ", (encoded.size()))

    # One initial state per direction.
    h0 = Variable(torch.zeros(2, batch_num, hidden_size))
    print("h0 size = ", (h0.size()))

    outputs, hdn = gru(encoded, h0)
    print("facts size = ", (outputs.size()))
    print("hdn size = ",(hdn.size()))

    # The last axis holds both directions side by side; fuse them by addition.
    first_half = outputs[:, :, :hidden_size]
    second_half = outputs[:, :, hidden_size:]
    facts = first_half + second_half

    print("final fact (context) size returned from input module", facts.size())
    return facts

In [9]:
#Demo
# Build the shared embedding table (row 0 is the padding index) and run the
# first batch of contexts through the input module.

word_embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0, sparse=True)

facts=input_module(contexts,word_embedding)

context dimension before word embedding torch.Size([2, 4, 8])
context dimension after word embedding torch.Size([2, 32, 80])
context dimension  torch.Size([2, 4, 8, 80])
embedded_sentence.size() =  torch.Size([2, 4, 8, 80])
before  torch.Size([1, 1, 8, 80])
after  torch.Size([2, 4, 8, 80])
return size torch.Size([2, 4, 80])
contexts dimension after position_encoding =  torch.Size([2, 4, 80])
h0 size =  torch.Size([2, 2, 80])
facts size =  torch.Size([2, 4, 160])
hdn size =  torch.Size([2, 2, 80])
final fact (context) size returned from input module torch.Size([2, 4, 80])


In [10]:
def question_module(questions, word_embedding):
    '''
    Encode each question as the final hidden state of a GRU.

    questions.size() -> (#batch, #token)
    word_embedding() -> (#batch, #token, #embedding)
    gru() -> (1, #batch, #hidden)
    returns -> (#batch, 1, #hidden)

    The GRU is created inside the function, so every call uses a freshly
    initialised one. Shapes are printed after each step.
    '''
    gru = nn.GRU(hidden_size, hidden_size, batch_first=True)

    print("before ",(questions.size()))
    embedded = word_embedding(questions)
    print("after word embedding ",(embedded.size()))

    # Only the final hidden state is used; the per-token outputs are dropped.
    final_hidden = gru(embedded)[1]
    print("after gru ",(final_hidden.size()))

    # Put the batch axis first.
    encoded = final_hidden.transpose(0, 1)
    print("after transpose ",(encoded.size()))

    return encoded



In [11]:
#Demo
# Encode the first batch of questions; reuses the embedding table created for
# the input-module demo.
ques = question_module(questions, word_embedding)

before  torch.Size([2, 4])
after word embedding  torch.Size([2, 4, 80])
after gru  torch.Size([1, 2, 80])
after transpose  torch.Size([2, 1, 80])


In [12]:
# Two-layer scorer used by make_interaction (defined below) to turn the
# 4 * hidden_size interaction features of a fact into one scalar score.
# The layers live at module level, so every make_interaction call shares them.

z1 = nn.Linear(4 * hidden_size, hidden_size)
z2 = nn.Linear(hidden_size, 1)
# Xavier-initialise the weights in place; the biases keep their defaults.
init.xavier_normal(z1.state_dict()['weight'])
init.xavier_normal(z2.state_dict()['weight'])

def make_interaction(facts, questions, prevM):
    '''
    Return the attention weights G over the facts, computed from the facts,
    the question and the previous memory.

    facts.size() -> (#batch, #sentence, #hidden = #embedding)
    questions.size() -> (#batch, 1, #hidden)
    prevM.size() -> (#batch, #sentence = 1, #hidden = #embedding)
    z.size() -> (#batch, #sentence, 4 x #embedding)
    G.size() -> (#batch, #sentence), each row sums to 1

    Uses the module-level layers z1 and z2 as the scoring network and prints
    the intermediate shapes.
    '''

    batch_num, sen_num, embedding_size = facts.size()
    # Repeat the single question / memory vector along the sentence axis.
    questions = questions.expand_as(facts)
    prevM = prevM.expand_as(facts)

    z = torch.cat([
        facts * questions,
        facts * prevM,
        torch.abs(facts - questions),
        torch.abs(facts - prevM)
    ], dim=2)
    print("z.size after concatenation ",z.size())

    # Score every (batch, sentence) pair independently with the 2-layer net.
    z = z.view(-1, 4 * embedding_size)
    print("z.size after view ",z.size())
    G = torch.tanh(z1(z))
    G = z2(G)
    print("G.size = ",G.size())
    G = G.view(batch_num, -1)
    print("G.size = ",G.size())

    # Normalise over the sentences of each batch item. The axis is explicit:
    # the implicit choice is deprecated and only coincides with dim=1 because
    # G happens to be 2-D.
    G = F.softmax(G, dim=1)

    return G



In [13]:
#DEMO
#calculating attention

# No memory has been computed yet, so the question doubles as the previous memory.
prevM=ques
G=make_interaction(facts, ques, prevM)
# Highest attention weight per batch item and the sentence index holding it.
value, index = torch.max(G, dim=1)
print(G)
print(value)
print(index)
# NOTE(review): interpret_indexed_tensor prints its result itself and returns
# None, so the outer print() only adds the trailing "None" line.
# NOTE(review): index[0] (the top sentence of batch item 0) is used for every
# batch item here; confirm whether per-item indices were intended.
print(interpret_indexed_tensor(contexts[:, index[0], :]))

z.size after concatenation  torch.Size([2, 4, 320])
z.size after view  torch.Size([8, 320])
G.size =  torch.Size([8, 1])
G.size =  torch.Size([2, 4])
Variable containing:
 0.1932  0.2221  0.3125  0.2721
 0.1375  0.1383  0.3666  0.3575
[torch.FloatTensor of size 2x4]

Variable containing:
 0.3125
 0.3666
[torch.FloatTensor of size 2]

Variable containing:
 2
 2
[torch.LongTensor of size 2]

0th of batch, 0th sentence, <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD> <PAD>
1th of batch, 0th sentence, daniel went back to the hallway . <EOS>
None




In [14]:
def AGRUCell(fact, C, g):
    """One step of the attention-based GRU.

    The reset gate and the candidate state are computed as in a GRU; the
    attention weight ``g`` is then used to blend the candidate with the
    previous state: ``h = g * h_tilda + (1 - g) * C``.

    All four linear layers are created inside the function, so every call
    uses freshly initialised weights.

    Args:
        fact: current fact, (#batch, hidden_size).
        C: previous hidden state h_(i-1), same size as ``fact``.
        g: attention weight per batch item, (#batch,).

    Returns:
        The new hidden state, same size as ``C``.
    """
    def xavier_linear(in_features):
        # Linear layer into hidden_size with Xavier-normal weights.
        layer = nn.Linear(in_features, hidden_size)
        init.xavier_normal(layer.state_dict()['weight'])
        return layer

    input_size = hidden_size
    # Creation order (Wr, Ur, W, U) is kept so the random stream is consumed
    # in the same sequence as before.
    Wr = xavier_linear(input_size)
    Ur = xavier_linear(hidden_size)
    W = xavier_linear(input_size)
    U = xavier_linear(hidden_size)

    reset_gate = F.sigmoid(Wr(fact) + Ur(C))
    candidate = F.tanh(W(fact) + reset_gate * U(C))

    print("g.size = ", g.size())
    # Broadcast the per-item weight across the hidden dimension.
    gate = g.unsqueeze(1).expand_as(candidate)
    print("g.size = ", gate.size())

    return gate * candidate + (1 - gate) * C

# Attention-GRU forward pass
def AGRU(facts, G):
    """Run the attention GRU over the facts and return its hidden state.

    Args:
        facts: (#batch, #sentence, #hidden).
        G: attention weights, (#batch, #sentence).

    Returns:
        The hidden state after the processed step(s), (#batch, #hidden).
        Only the first sentence is processed because of the demo ``break``;
        with zero sentences the un-expanded zero state of size (hidden_size,)
        is returned.
    """
    batch_num, sen_num, embedding_size = facts.size()
    # Previous hidden state of the GRU, initially zero.
    state = Variable(torch.zeros(hidden_size))
    for step in range(sen_num):
        # All batch items are advanced together, one sentence per step.
        current_fact = facts[:, step, :]
        print("fact.size() = ",current_fact.size())
        gate = G[:, step]
        print("g.size = ", gate.size())
        if step == 0:
            # Give the zero state a batch axis matching the fact.
            state = state.unsqueeze(0).expand_as(current_fact)
            print("C.size = ",state.size())
        state = AGRUCell(current_fact, state, gate)
        break #running only one time for demonstration purpose
    return state #final hidden state of AGRU



In [15]:
#Demo
#calculating last hidden state of attention gru
# Uses the facts from the input-module demo and the attention weights G from
# the attention demo.
C= AGRU(facts, G)  
print(C.size())

fact.size() =  torch.Size([2, 80])
g.size =  torch.Size([2])
C.size =  torch.Size([2, 80])
g.size =  torch.Size([2])
g.size =  torch.Size([2, 80])
torch.Size([2, 80])


In [16]:
def memory(facts, questions, prevM):
    """Compute the next episodic memory from the facts, question and previous memory.

    The attention weights and the attention-GRU state are computed with their
    shape prints silenced; the previous memory, that state and the question
    are then concatenated and passed through a ReLU-activated linear layer.
    The linear layer is created inside the function, so every call uses
    freshly initialised weights.

    Args:
        facts: (#batch, #sentence, #hidden).
        questions: (#batch, 1, #hidden).
        prevM: previous memory, (#batch, 1, #hidden).

    Returns:
        Next memory, (#batch, 1, #hidden).
    """
    next_mem_layer = nn.Linear(3 * hidden_size, hidden_size)
    init.xavier_normal(next_mem_layer.state_dict()['weight'])

    blockPrint()
    try:
        G = make_interaction(facts, questions, prevM)
        C = AGRU(facts, G)
    finally:
        # Always restore stdout; otherwise an exception above would leave the
        # notebook silently printing into os.devnull.
        enablePrint()

    concat = torch.cat([prevM.squeeze(1), C, questions.squeeze(1)], dim=1)
    print("concat.size() = ",concat.size())
    next_mem = F.relu(next_mem_layer(concat))
    # Restore the singleton sentence axis so the result can be fed back in as prevM.
    next_mem = next_mem.unsqueeze(1)
    print("next_mem.size() = ",next_mem.size())
    return next_mem



In [17]:
#Demo
#calculating next memory based on prevM, facts and questions
# prevM is still the question encoding assigned in the attention demo.
next_mem=memory(facts, ques, prevM)

concat.size() =  torch.Size([2, 240])
next_mem.size() =  torch.Size([2, 1, 80])




In [18]:
def answer_module(M, questions):
    """Project the final memory and the question onto vocabulary scores.

    Dropout is applied to the memory, which is then concatenated with the
    question and fed through a linear layer. The layer is created inside the
    function, so every call uses freshly initialised weights.

    Args:
        M: final memory, (#batch, 1, #hidden).
        questions: encoded question, (#batch, 1, #hidden).

    Returns:
        Unnormalised scores of size (#batch, vocab_size).
    """
    # Trace the shape of the argument itself rather than of a notebook global
    # that may not exist or may belong to a different batch.
    print(M.size())
    print(questions.size())
    projection = nn.Linear(2 * hidden_size, vocab_size)
    init.xavier_normal(projection.state_dict()['weight'])
    dropout = nn.Dropout(0.1)
    M = dropout(M)
    # Join along the feature axis, then drop the singleton sentence axis.
    concat = torch.cat([M, questions], dim=2).squeeze(1)
    z = projection(concat)
    print(z.size())
    return z

In [19]:
#Demo
#calculating final answer based on question and last memory
# The module's own shape prints are silenced; only the result size is shown.
blockPrint()
ans=answer_module(next_mem, ques)
enablePrint()
# One score per vocabulary entry for every batch item.
print(ans.size())
print("vocab_size = ",vocab_size)

torch.Size([2, 22])
vocab_size =  22


In [20]:
class DMNPlus(nn.Module):
    """Ties the notebook's input, question, memory and answer functions together.

    Only the word embedding is a registered sub-module with parameters; the
    GRUs and linear layers are created inside the helper functions on every
    call, so ``self.parameters()`` yields the embedding weight only.
    """

    def __init__(self, hidden_size, vocab_size, num_hop=3, qa=None):
        """
        Args:
            hidden_size: embedding width.
            vocab_size: number of rows in the embedding table.
            num_hop: number of memory updates performed in ``forward``.
            qa: stored on the instance as ``self.qa``; not used by the
                methods of this class.
        """
        super(DMNPlus, self).__init__()
        self.num_hop = num_hop
        self.qa = qa
        self.word_embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0, sparse=True)
        # NOTE(review): this fills the whole table, including the padding row
        # (index 0) - confirm that a non-zero <PAD> embedding is intended.
        init.uniform(self.word_embedding.state_dict()['weight'], a=-(3**0.5), b=3**0.5)
        # Summed (not averaged) cross-entropy over the batch.
        self.criterion = nn.CrossEntropyLoss(size_average=False)

    def forward(self, contexts, questions):
        '''
        contexts.size() -> (#batch, #sentence, #token) -> (#batch, #sentence, #hidden = #embedding)
        questions.size() -> (#batch, #token) -> (#batch, 1, #hidden)
        returns -> (#batch, vocab_size) unnormalised answer scores
        '''
        facts = input_module(contexts, self.word_embedding)
        questions = question_module(questions, self.word_embedding)
        # The first memory is the question itself.
        M = questions
        for hop in range(self.num_hop):
            M = memory(facts, questions, M)
        preds = answer_module(M, questions)
        return preds

    def get_loss(self, contexts, questions, targets):
        """Return ``(loss, accuracy)`` for one batch.

        Args:
            contexts: (#batch, #sentence, #token) token indices.
            questions: (#batch, #token) token indices.
            targets: (#batch,) index of the correct answer word.

        Returns:
            Tuple of the summed cross-entropy plus an L2 penalty over
            ``self.parameters()``, and the fraction of batch items whose
            highest-scoring word equals the target.
        """
        blockPrint()
        try:
            output = self.forward(contexts, questions)
        finally:
            # Restore stdout even if the forward pass fails.
            enablePrint()
        print("actual answers size = ", targets.size())
        print("predicted answers size = ", output.size())

        loss = self.criterion(output, targets)
        print("loss size = ", loss.size())
        print("loss  = ", loss)

        reg_loss = 0
        # Prints the bound method object, not the parameters themselves.
        print("self.parameters() = ",self.parameters)
        for param in self.parameters():
            print("param.size()=",param.size())
            print("(param * param).size()=",(param * param).size())
            #regularisation
            #param is embedding weight matrix
            reg_loss += 0.001 * torch.sum(param * param)
        # Normalise over the vocabulary axis of the (#batch, vocab) scores.
        preds = F.softmax(output, dim=1)
        _, pred_ids = torch.max(preds, dim=1)
        print(pred_ids.size())
        # Compare with the targets passed in, not with a notebook-level variable.
        corrects = (pred_ids.data == targets.data)
        acc = torch.mean(corrects.float())
        return loss + reg_loss, acc



In [21]:
# Build the full model and evaluate it on the first batch.
model = DMNPlus(hidden_size, vocab_size, num_hop=3, qa=dset_train.QA)
# get_loss returns a (loss, accuracy) tuple, so `loss` holds both values.
loss=model.get_loss(contexts, questions, answers)

concat.size() =  torch.Size([2, 240])
next_mem.size() =  torch.Size([2, 1, 80])
concat.size() =  torch.Size([2, 240])
next_mem.size() =  torch.Size([2, 1, 80])
concat.size() =  torch.Size([2, 240])
next_mem.size() =  torch.Size([2, 1, 80])
torch.Size([2, 1, 80])
torch.Size([2, 1, 80])
torch.Size([2, 22])
actual answers size =  torch.Size([2])
predicted answers size =  torch.Size([2, 22])
loss size =  torch.Size([1])
loss  =  Variable containing:
 6.2372
[torch.FloatTensor of size 1]

self.parameters() =  <bound method Module.parameters of DMNPlus(
  (word_embedding): Embedding(22, 80, padding_idx=0, sparse=True)
  (criterion): CrossEntropyLoss(
  )
)>
param.size()= torch.Size([22, 80])
(param * param).size()= torch.Size([22, 80])
torch.Size([2])


