In [5]:
import torch
from torch import nn
from torch.nn import functional as F

In [1]:
from lib.vectorize import vectorize
ddict = vectorize(sent_size=50)
vocab_size = len(list(ddict['word2index']))

100%|██████████| 549367/549367 [00:05<00:00, 97335.70it/s] 
100%|██████████| 9842/9842 [00:00<00:00, 131429.22it/s]
100%|██████████| 9824/9824 [00:00<00:00, 135284.14it/s]


In [None]:
len(ddict['train_data'])

In [None]:
class MatchLSTM(nn.Module):
    def __init__(self, embedding_size, vocab_size, class_size):
        super(MatchLSTM, self).__init__()
        
        self.embedding_size = embedding_size    # E
        self.vocab_size     = vocab_size        # V
               
        self.embed = nn.Embedding(vocab_size, embedding_size)        #VxE -> #E
        
        self.premise_lstm    = nn.LSTMCell(embedding_size, embedding_size)
        self.hypothesis_lstm = nn.LSTMCell(embedding_size, embedding_size)
        self.match_lstm      = nn.LSTMCell(embedding_size, embedding_size)
        
        self.attend_premise    = nn.Linear(embedding_size, embedding_size, bias=False)
        self.attend_hypothesis = nn.Linear(embedding_size, embedding_size, bias=False)
        self.attend_state      = nn.Linear(embedding_size, embedding_size, bias=False)
        self.attend_match      = nn.Linear(embedding_size, embedding_size, bias=False)
        
        self.scale = nn.Linear(embedding_size, 1)
        self.classify = nn.Linear(embedding_size, class_size)
        
        self.print_sizes = True
        
    def initial_hidden_state(self):
        return Variable(torch.zeros([1, self.embedding_size]))
        
    def printsize(tensor):
        if self.print_sizes:
            print(tensor.size())
            
    def forward(self, premise, hypothesis):
        print('premise:{}'.format(premise.size()))
        print('hypothesis:{}'.format(hypothesis.size()))
        premise_emb    = self.embed(premise)                           #PlxH
        hypothesis_emb = self.embed(hypothesis)                        #HlXH
        
        print('premise_emb:{}'.format(premise_emb.size()))
        print('hypothesis_emb:{}'.format(hypothesis_emb.size()))
        
        hidden_state = self.initial_hidden_state()
        cell_state   = self.initial_hidden_state()
        print('hidden_state:{}'.format(hidden_state.size()))
        hypothesis_states = []                                              #HlxH
        for h in hypothesis_emb:
            hidden_state, cell_state = self.hypothesis_lstm(h.view([1, -1]), 
                                                            (hidden_state, cell_state))
            hypothesis_states.append(hidden_state)
                
        hidden_state = self.initial_hidden_state()
        cell_state   = self.initial_hidden_state()
        premise_states = []                                                 #PlxH
        for p in premise_emb:
            hidden_state, cell_state = self.hypothesis_lstm(p.view([1,-1]),
                                                            (hidden_state, cell_state))
            premise_states.append(hidden_state)
            
        premise_states = torch.stack(premise_states).squeeze(1)                     #PlXH        
        print('premise_states:{}'.format(premise_states.size()))

        hidden_state = self.initial_hidden_state()
        cell_state   = self.initial_hidden_state()
        for h in hypothesis_states:            
            hattn = self.attend_hypothesis(h)                               #1xH
            print('hattn:{}'.format(hattn.size()))
            
            pattn = self.attend_premise(premise_states)                     #PlxH
            print('pattn:{}'.format(pattn.size()))
            
            mattn = self.attend_match(hidden_state)                         #1xH
            print('mattn:{}'.format(mattn.size()))
            
            attn = F.softmax(self.scale( hattn.expand_as(pattn)             #PlxH -> scale ->PlX1
                                        + pattn
                                        + mattn.expand_as(pattn))
                            )                                   
            print('attn:{}'.format(attn.size()))
            
            attn = torch.mm(attn.t(), premise_states)                      #1xPl * PlxH -> 1xH 
            print('attn:{}'.format(attn.size()))
            
            attn_hidden_mat = torch.cat([attn, h])                         # HXH
            print('attn_hidden_mat:{}'.format(attn_hidden_mat.size()))
            
            hidden_state, cell_state = self.match_lstm(attn_hidden_mat, 
                                                       (hidden_state, cell_state))
            print('hidden_state:{}'.format(hidden_state.size()))
            
        attended_match_state = hidden_state
        return F.log_softmax(self.classify(attended_match_state))
    
        
        

In [3]:
def create_one_hot(length, index):
    a = np.zeros([length])
    a[index] = 1
    return a

In [None]:
from torch import optim
from torch.autograd import Variable
def train(epochs, model,  train_batches, print_every = 100):
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.1)
    for epoch in range(epochs+1):
        for batch, sample in enumerate(train_batches):
            premise = Variable(torch.LongTensor(sample[0]))
            hypothesis = Variable(torch.LongTensor(sample[1]))
            judgements = Variable(torch.LongTensor(sample[3]))
            
            optimizer.zero_grad()

            predictions = model(premise, hypothesis)
            loss = F.nll_loss(predictions, judgements)
            loss.backward()
            optimizer.step()

        if epoch % print_every == 0:
            #print([i for i in model.parameters()])
            print('epoch: {}\t\t -- loss: {}'.format(epoch, loss.data[0]))

In [None]:
sample = ddict['train_data'][0]
premise = sample[0]
hypothesis = sample[1]
judgement = sample[2]
judgements = sample[3]

In [6]:
import numpy as np
train_data = ddict['train_data']

"""
for i in range(50):
    for j in range(len(train_data[i])):
        print(train_data[i][j])
        print('--')
    print('==================')
"""
for sample in train_data[20:30]:
    for i in range(len(sample[0])):   #Premise
        sample[0][i] = create_one_hot(vocab_size, sample[0][i])
        
    for i in range(len(sample[1])):   #Hypotheses
        sample[1][i] = create_one_hot(vocab_size, sample[1][i])

    sample[2] = create_one_hot(3, sample[2])         #Judgement
    
    sample[0] = torch.Tensor(sample[0])
    sample[1] = torch.Tensor(sample[1])
    sample[2] = torch.Tensor(sample[2])

print(sample[0].size())
print(sample[1].size())
print(sample[2].size())


IndexError: arrays used as indices must be of integer (or boolean) type

In [None]:
model = MatchLSTM(30, vocab_size, 3)

In [None]:
train(10, model, ddict['train_data'])