In [1]:
import logging
import h5py
import torch
import torch._utils

# Compatibility shim: install torch._utils._rebuild_tensor_v2 only when the
# installed PyTorch does not already provide it.
# NOTE(review): presumably needed so that checkpoints pickled by a newer
# PyTorch can be unpickled by this older install -- confirm.
if not hasattr(torch._utils, '_rebuild_tensor_v2'):
    def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
        """Rebuild a tensor with the legacy helper, then restore autograd metadata.

        The tensor itself is produced by torch._utils._rebuild_tensor; the
        requires_grad flag and backward hooks are then set on the result.
        """
        rebuilt = torch._utils._rebuild_tensor(storage, storage_offset, size, stride)
        rebuilt.requires_grad = requires_grad
        rebuilt._backward_hooks = backward_hooks
        return rebuilt
    torch._utils._rebuild_tensor_v2 = _rebuild_tensor_v2

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import time
import numpy as np
import sys
from models.baseline_snli import encoder
from models.baseline_snli import binary_label_atten
import argparse
from models.snli_data import snli_data
from models.snli_data import w2v
from random import shuffle
from models.baseline_snli import SeqAttnMatch

In [22]:
# ---------------------------------------------------------------------------
# Run configuration. Everything a reader may want to tune lives in this cell.
# ---------------------------------------------------------------------------

# --- input data -------------------------------------------------------------
# All preprocessed HDF5 inputs share one directory; change it in one place.
data_dir = '/homes/rpujari/scratch/raj_qa/preprocess/data/'
train_file = data_dir + 'multirc_bin-train.hdf5'
# NOTE(review): the dev split points at an 'st_bin-test' file while the test
# split points at a 'multirc_bin-val' file -- confirm this pairing is intended.
dev_file = data_dir + 'st_bin-test.hdf5'
test_file = data_dir + 'multirc_bin-val.hdf5'
w2v_file = data_dir + 'glove.hdf5'

# --- logging / output -------------------------------------------------------
# log_dir is concatenated directly with file names, so it must end with '/'.
log_dir = '/homes/rpujari/scratch/raj_qa/parikh_nli/trained_model/'
log_fname = 'snli_bin_pred.log'
model_path = '/homes/rpujari/scratch/raj_qa/parikh_nli/trained_model/'

# --- hardware ---------------------------------------------------------------
gpu_id = 1

# --- model / optimisation hyper-parameters ----------------------------------
embedding_size = 300
hidden_size = 300
para_init = 0.1
epoch = 250
dev_interval = 1
display_interval = 1000
optimizer = 'Adagrad'
Adagrad_init = 0.
lr = 0.05
weight_decay = 1e-5
max_grad_norm = 5
max_length = -1  # negative means "no limit"; a later cell replaces it with 9999

# --- saved checkpoint -------------------------------------------------------
# Every checkpoint artefact shares the same directory and run tag; to evaluate
# a different checkpoint only checkpoint_tag (or checkpoint_dir) has to change.
checkpoint_dir = '/homes/rpujari/scratch/raj_qa/preprocess/saved_models/'
checkpoint_tag = 'snli_bin_1_epoch-241_dev-acc-0.781'
checkpoint_prefix = checkpoint_dir + checkpoint_tag

trained_encoder = checkpoint_prefix + '_input-encoder.pt'
trained_attn = checkpoint_prefix + '_inter-atten.pt'
seq_attn = checkpoint_prefix + '_seq-atten.pt'
input_optimizer = checkpoint_prefix + '_input-optimizer.pt'
inter_atten_optimizer = checkpoint_prefix + '_inter-atten-optimizer.pt'
seq_atten_optimizer = checkpoint_prefix + '_seq-atten-optimizer.pt'

# --- run mode ---------------------------------------------------------------
test_mode = True
resume = False
train_entailment = True
train_contradiction = True

In [20]:
# A negative max_length means "no limit"; use a large sentinel instead.
if max_length < 0:
    max_length = 9999

# ---------------------------------------------------------------------------
# Logger setup
# ---------------------------------------------------------------------------
logger_name = "mylog"
logger = logging.getLogger(logger_name)
logger.setLevel(logging.INFO)

# logging.getLogger() hands back the same logger object on every call, so
# re-running this cell used to stack one more file handler and one more stream
# handler each time and every message was emitted once per accumulated handler.
# Detach (and close) whatever is already attached before adding fresh handlers.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()

# file handler
fh = logging.FileHandler(log_dir + log_fname)
fh.setLevel(logging.INFO)
logger.addHandler(fh)

# stream handler
console = logging.StreamHandler()
console.setLevel(logging.INFO)
logger.addHandler(console)

torch.cuda.set_device(gpu_id)

# ---------------------------------------------------------------------------
# Data loading: each split exposes .batches, .id_batches and .size
# ---------------------------------------------------------------------------

# load train data
logger.info('loading data...')
train_data = snli_data(train_file, max_length)
train_batches = train_data.batches
train_id_batches = train_data.id_batches
logger.info('train size # sent ' + str(train_data.size))

# load dev data
logger.info('loading data...')
dev_data = snli_data(dev_file, max_length)
dev_batches = dev_data.batches
dev_id_batches = dev_data.id_batches
logger.info('dev size # sent ' + str(dev_data.size))

# load test data
logger.info('loading data...')
test_data = snli_data(test_file, max_length)
test_batches = test_data.batches
test_id_batches = test_data.id_batches
logger.info('test size # sent ' + str(test_data.size))


loading data...
loading data...
loading data...
loading data...
loading data...
train size # sent 113762
train size # sent 113762
train size # sent 113762
train size # sent 113762
train size # sent 113762
loading data...
loading data...
loading data...
loading data...
loading data...
dev size # sent 9802
dev size # sent 9802
dev size # sent 9802
dev size # sent 9802
dev size # sent 9802
loading data...
loading data...
loading data...
loading data...
loading data...
test size # sent 23285
test size # sent 23285
test size # sent 23285
test size # sent 23285
test size # sent 23285


In [4]:
# Read the pretrained word vectors used to initialise the embedding table.
logger.info('loading input embeddings...')
word_vecs = w2v(w2v_file).word_vecs

# Build the encoder; its embedding table has one row per pretrained vector.
vocab_size = word_vecs.size()[0]
input_encoder = encoder(vocab_size, embedding_size, hidden_size, para_init)

# Copy the pretrained vectors in and freeze them (no gradient for embeddings).
embedding_weight = input_encoder.embedding.weight
embedding_weight.data.copy_(word_vecs)
embedding_weight.requires_grad = False

# Attention components (constructed in this order so that any random
# initialisation draws happen in the same sequence as before).
seq_atten = SeqAttnMatch(hidden_size, para_init)
inter_atten = binary_label_atten(hidden_size, 2, para_init)

# Move every module to the configured GPU. The final call is left as the
# cell's last expression, so the notebook echoes the seq_atten module.
input_encoder.cuda(gpu_id)
inter_atten.cuda(gpu_id)
seq_atten.cuda(gpu_id)

loading input embeddings...


SeqAttnMatch (
  (linear): Linear (300 -> 300)
)

In [23]:
logger.info('loading trained model.')


def _onto_configured_gpu(storage, location):
    """map_location hook for torch.load: place every storage on GPU `gpu_id`.

    The previous hard-coded mapping {'cuda:0': 'cuda:1'} silently assumed both
    the device the checkpoint was saved from and the target device; it stopped
    matching as soon as gpu_id was changed in the configuration cell.
    """
    return storage.cuda(gpu_id)


# Restore the saved parameters into the already-constructed modules.
input_encoder.load_state_dict(torch.load(trained_encoder, map_location=_onto_configured_gpu))
inter_atten.load_state_dict(torch.load(trained_attn, map_location=_onto_configured_gpu))
seq_atten.load_state_dict(torch.load(seq_attn, map_location=_onto_configured_gpu))

loading trained model.
loading trained model.
loading trained model.
loading trained model.
loading trained model.


In [9]:
# Build an id -> token lookup from the word dictionary file. Only lines that
# split into exactly two whitespace-separated fields (token, then integer id)
# contribute an entry; every other line is ignored.
with open('/homes/rpujari/scratch/parikh_nli/preprocess/decomp-attn/data/snli.word.dict', 'r') as infile:
    vocab_dict = {}
    for entry in infile.read().split('\n'):
        fields = entry.split()
        if len(fields) != 2:
            continue
        token, token_id = fields
        vocab_dict[int(token_id)] = token

In [26]:
# ---------------------------------------------------------------------------
# Evaluate the loaded model on the test batches.
#
# For every example the model produces two binary decisions (entailment yes/no
# and contradiction yes/no). They are combined into one class id written to
# test_out.txt (0 = entailment, 1 = neutral, 2 = contradiction); examples whose
# decisions disagree with the labels are written to test_err.txt.
# ---------------------------------------------------------------------------
debug = False

# Put all modules into evaluation mode.
input_encoder.eval()
seq_atten.eval()
inter_atten.eval()

tot_corr = 0.0
tot_eg = 0.0

# *_c: correctly judged examples, bucketed by the class that was predicted
ent_c = 0.
contr_c = 0.
neu_c = 0.

# *_pr: number of examples predicted as each class
ent_pr = 0.
contr_pr = 0.
neu_pr = 0.

# *_tot: number of examples whose labels denote each class
ent_tot = 0.
contr_tot = 0.
neu_tot = 0.

have_ques = test_data.have_ques
# NOTE(review): the dataset's flag is immediately overridden, which disables
# the question-attention path below -- confirm this override is intentional.
have_ques = 0


def _decode_row(batch, row):
    """Return row `row` of an id batch as a space-joined string of tokens.

    Ids are shifted by +1 before the vocab_dict lookup, as in the original
    debugging code.
    """
    ids = batch.data[row]
    return ' '.join(vocab_dict[ids[idx] + 1] for idx in range(ids.size()[0]))


# The context manager guarantees both files are closed even if the loop raises.
with open(log_dir + 'test_out.txt', 'w') as out_file, \
        open(log_dir + 'test_err.txt', 'w') as err_file:

    for i in range(len(test_batches)):
        test_src_batch, test_tgt_batch, test_ques_batch, test_lbl_batch = test_batches[i]
        test_src_ids, test_targ_ids = test_id_batches[i]

        test_src_batch = Variable(test_src_batch.cuda(gpu_id))
        test_tgt_batch = Variable(test_tgt_batch.cuda(gpu_id))
        test_ques_batch = Variable(test_ques_batch.cuda(gpu_id))
        test_lbl_batch = Variable(test_lbl_batch.cuda(gpu_id))

        test_src_linear, test_tgt_linear, test_ques_linear = input_encoder(
            test_src_batch, test_tgt_batch, test_ques_batch)

        if have_ques == 1:
            # All-zero byte mask over the first two dimensions of the question encoding.
            test_ques_mask = Variable(torch.from_numpy(np.zeros(test_ques_linear.data.shape[:2])).byte().cuda(gpu_id))
            test_src_linear = seq_atten.forward(test_src_linear, test_ques_linear, test_ques_mask)
            test_tgt_linear = seq_atten.forward(test_tgt_linear, test_ques_linear, test_ques_mask)

        ent_prob, contr_prob = inter_atten(test_src_linear, test_tgt_linear)

        ent_probs = F.softmax(ent_prob)
        contr_probs = F.softmax(contr_prob)

        # Per example: highest probability and its index for each binary head.
        ent_prob, ent_pred = ent_probs.data.max(dim=1)
        contr_prob, contr_pred = contr_probs.data.max(dim=1)

        tot_eg += test_lbl_batch.data.size()[0]

        for eg_num in range(len(ent_pred)):
            # Hoist the per-example values that are consulted repeatedly below.
            lbl_ent = test_lbl_batch.data[eg_num][0]
            lbl_contr = test_lbl_batch.data[eg_num][1]
            e_pred = ent_pred[eg_num]
            c_pred = contr_pred[eg_num]
            e_prob = ent_prob[eg_num]
            c_prob = contr_prob[eg_num]

            if debug:
                sent = _decode_row(test_src_batch, eg_num)
                t_sent = _decode_row(test_tgt_batch, eg_num)
                q_sent = _decode_row(test_ques_batch, eg_num)
                print(q_sent)
                print(sent)
                print(t_sent)
                print(e_pred, c_pred, lbl_ent, lbl_contr)
                print('\n')

            # Tally the labelled class: (1, 0) entailment, (0, 1) contradiction, (0, 0) neutral.
            if lbl_ent == 1 and lbl_contr == 0:
                ent_tot += 1
            elif lbl_ent == 0 and lbl_contr == 1:
                contr_tot += 1
            elif lbl_ent == 0 and lbl_contr == 0:
                neu_tot += 1

            # Combine the two binary decisions into one class id and confidence.
            # When both heads fire, the more confident head wins.
            if e_pred == 1 and c_pred == 0:
                m_id, m_prob = 0, e_prob
            elif e_pred == 0 and c_pred == 1:
                m_id, m_prob = 2, c_prob
            elif e_pred == 0 and c_pred == 0:
                m_id, m_prob = 1, 1.0
            elif e_prob > c_prob:
                m_id, m_prob = 0, e_prob
            else:
                m_id, m_prob = 2, c_prob

            if m_id == 0:
                ent_pr += 1
            elif m_id == 2:
                contr_pr += 1
            else:
                neu_pr += 1

            # A label value of -1 acts as a wildcard: that head always counts as matching.
            ent_ok = (e_pred == lbl_ent) or (lbl_ent == -1)
            contr_ok = (c_pred == lbl_contr) or (lbl_contr == -1)

            if ent_ok and contr_ok:
                tot_corr += 1.0
                if m_id == 0:
                    ent_c += 1
                elif m_id == 2:
                    contr_c += 1
                else:
                    neu_c += 1
            else:
                # One tab-separated record per error: ids, both predictions, both labels.
                # (Previously the labels were appended without a separator and as the
                # repr of a whole Variable rather than as two values.)
                err_file.write('\t'.join([
                    str(test_src_ids[eg_num]),
                    str(test_targ_ids[eg_num]),
                    str(e_pred),
                    str(c_pred),
                    str(lbl_ent),
                    str(lbl_contr),
                ]) + '\n')

            out_file.write(str(test_src_ids[eg_num]) + '\t' + str(test_targ_ids[eg_num]) + '\t' + str(m_prob) + '\t' + str(m_id) + '\n')

# Summary: predicted counts, then *_c / *_tot per class, then *_c / *_pr per class.
# The +0.01 in each denominator keeps the division defined when a count is zero.
print(ent_pr, contr_pr, neu_pr, round(ent_c/(ent_tot + 0.01), 4),\
      round(contr_c/(contr_tot  + 0.01), 4), round(neu_c/(neu_tot + 0.01), 4), \
     round((ent_c/(ent_pr + 0.01)), 4), round((contr_c/(contr_pr + 0.01)), 4), round((neu_c/(neu_pr + 0.01)), 4))
print('Accuracy: '+ str(tot_corr / tot_eg))

(12599.0, 7174.0, 3513.0, 0.0, 0.0, 0.1509, 0.0, 0.0, 1.0)
Accuracy: 0.150863179593
