In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [2]:
from lib.vectorize import vectorize
ddict = vectorize(sent_size=50)
vocab_size = len(list(ddict['word2index']))

100%|██████████| 549367/549367 [00:05<00:00, 94687.19it/s] 
100%|██████████| 9842/9842 [00:00<00:00, 90806.65it/s]
100%|██████████| 9824/9824 [00:00<00:00, 66129.20it/s]


In [3]:
len(ddict['train_data'])

549367

In [120]:
class MatchLSTM(nn.Module):
    def __init__(self, embedding_size, vocab_size, class_size):
        super(MatchLSTM, self).__init__()
        
        self.embedding_size = embedding_size    # E
        self.vocab_size     = vocab_size        # V
               
        self.embed = nn.Embedding(vocab_size, embedding_size)        #VxE -> #E
        
        self.premise_lstm    = nn.LSTMCell(embedding_size, embedding_size)
        self.hypothesis_lstm = nn.LSTMCell(embedding_size, embedding_size)
        self.match_lstm      = nn.LSTMCell(embedding_size, embedding_size)
        
        self.attend_premise    = nn.Linear(embedding_size, embedding_size, bias=False)
        self.attend_hypothesis = nn.Linear(embedding_size, embedding_size, bias=False)
        self.attend_state      = nn.Linear(embedding_size, embedding_size, bias=False)
        self.attend_match      = nn.Linear(embedding_size, embedding_size, bias=False)
        
        self.scale = nn.Linear(embedding_size, 1)
        self.merge_attention = nn.Linear(2*embedding_size, embedding_size)
        self.classify = nn.Linear(embedding_size, class_size)
        
        self.print_sizes = True
        
    def initial_hidden_state(self):
        return Variable(torch.zeros([1, self.embedding_size]))
        
    def printsize(tensor):
        if self.print_sizes:
            print(tensor.size())
            
    def forward(self, premise, hypothesis):
        #print('premise:{}'.format(premise.size()))
        #print('hypothesis:{}'.format(hypothesis.size()))
        premise_emb    = self.embed(premise)                           #PlxH
        hypothesis_emb = self.embed(hypothesis)                        #HlXH
        
        #print('premise_emb:{}'.format(premise_emb.size()))
        #print('hypothesis_emb:{}'.format(hypothesis_emb.size()))
        
        hidden_state = self.initial_hidden_state()
        cell_state   = self.initial_hidden_state()
        #print('hidden_state:{}'.format(hidden_state.size()))
        hypothesis_states = []                                              #HlxH
        for h in hypothesis_emb:
            hidden_state, cell_state = self.hypothesis_lstm(h.view([1, -1]), 
                                                            (hidden_state, cell_state))
            hypothesis_states.append(hidden_state)
                
        hidden_state = self.initial_hidden_state()
        cell_state   = self.initial_hidden_state()
        premise_states = []                                                 #PlxH
        for p in premise_emb:
            hidden_state, cell_state = self.hypothesis_lstm(p.view([1,-1]),
                                                            (hidden_state, cell_state))
            premise_states.append(hidden_state)
            
        premise_states = torch.stack(premise_states).squeeze(1)                     #PlXH        
        #print('premise_states:{}'.format(premise_states.size()))

        hidden_state = self.initial_hidden_state()
        cell_state   = self.initial_hidden_state()
        for h in hypothesis_states:            
            hattn = self.attend_hypothesis(h)                               #1xH
            #print('hattn:{}'.format(hattn.size()))
            
            pattn = self.attend_premise(premise_states)                     #PlxH
            #print('pattn:{}'.format(pattn.size()))
            
            mattn = self.attend_match(hidden_state)                         #1xH
            #print('mattn:{}'.format(mattn.size()))
            
            attn = F.softmax(self.scale( hattn.expand_as(pattn)             #PlxH -> scale ->PlX1
                                        + pattn
                                        + mattn.expand_as(pattn))
                            )                                   
            #print('attn:{}'.format(attn.size()))
            
            attn = torch.mm(attn.t(), premise_states)                      #1xPl * PlxH -> 1xH 
            #print('attn:{}'.format(attn.size()))
            
            attn_hidden_mat = self.merge_attention(torch.cat([attn, h], 1))                         # HXH
            #print('attn_hidden_mat:{}'.format(attn_hidden_mat.size()))
            
            #print('hidden_state:{}'.format(hidden_state.size()))

            hidden_state, cell_state = self.match_lstm(attn_hidden_mat, 
                                                       (hidden_state, cell_state))
            
            #print('hidden_state:{}'.format(hidden_state.size()))
        attended_match_state = hidden_state
        return F.log_softmax(self.classify(attended_match_state))
    
        
        

In [25]:
import numpy as np
def create_one_hot(length, index):
    a = np.zeros([length])
    a[index] = 1
    return a

In [123]:
from torch import optim
from torch.autograd import Variable
def train(epochs, model,  train_batches, print_every = 100):
    model.train()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.1)
    for epoch in range(epochs+1):
        for batch, sample in enumerate(train_batches):
            premise = Variable(torch.LongTensor(sample[0]))
            hypothesis = Variable(torch.LongTensor(sample[1]))
            #judgements = Variable(torch.from_numpy(
            #    create_one_hot(3, sample[2])).long().view([1,-1]))
            judgements = Variable(torch.LongTensor([sample[2]]))
            
            optimizer.zero_grad()

            predictions = model(premise, hypothesis)
            #print(predictions.size(), judgements.size())
            loss = F.nll_loss(predictions, judgements)
            loss.backward()
            optimizer.step()

            if batch % print_every == 0:
                #print([i for i in model.parameters()])
                print('epoch: {}\tbatch: {}\t -- loss: {}'.format(epoch, batch, loss.data[0]))

In [109]:
sample = ddict['train_data'][0]
premise = sample[0]
hypothesis = sample[1]
judgement = sample[2]
judgements = sample[3]

In [121]:
model = MatchLSTM(3, vocab_size, 3)

In [None]:
train(10, model, ddict['train_data'])

epoch: 0	batch: 0	 -- loss: 1.1436717510223389
epoch: 0	batch: 100	 -- loss: 1.0857354402542114
epoch: 0	batch: 200	 -- loss: 1.0953764915466309
epoch: 0	batch: 300	 -- loss: 1.1129271984100342
epoch: 0	batch: 400	 -- loss: 1.114245057106018
epoch: 0	batch: 500	 -- loss: 1.1073307991027832
epoch: 0	batch: 600	 -- loss: 1.1390821933746338
epoch: 0	batch: 700	 -- loss: 1.1181352138519287
epoch: 0	batch: 800	 -- loss: 1.0912469625473022
epoch: 0	batch: 900	 -- loss: 1.124155879020691
epoch: 0	batch: 1000	 -- loss: 1.0970820188522339
epoch: 0	batch: 1100	 -- loss: 1.1396733522415161
epoch: 0	batch: 1200	 -- loss: 1.1130518913269043
epoch: 0	batch: 1300	 -- loss: 1.0947747230529785
epoch: 0	batch: 1400	 -- loss: 1.1371382474899292
epoch: 0	batch: 1500	 -- loss: 1.0769662857055664
epoch: 0	batch: 1600	 -- loss: 1.1039656400680542
epoch: 0	batch: 1700	 -- loss: 1.1066042184829712
epoch: 0	batch: 1800	 -- loss: 1.0910228490829468
epoch: 0	batch: 1900	 -- loss: 1.1416176557540894
epoch: 0	batch