In [2]:
#!/usr/bin/env python
# coding: utf-8

from __future__ import print_function, division

import torch
import torchtext
import argparse
import time
import seq2seq
import subprocess
import re


from regexDFAEquals import regex_equiv_from_raw, unprocess_regex, regex_equiv
from seq2seq.optim import Optimizer
from seq2seq.models import EncoderRNN, DecoderRNN,Seq2seq
from seq2seq.loss import NLLLoss
from seq2seq.util.checkpoint import Checkpoint
from seq2seq.dataset import SourceField, TargetField
from seq2seq.evaluator import Predictor
from seq2seq.util.string_preprocess import get_set_num,pad_tensor, count_star, decode_tensor_input, decode_tensor_target
from seq2seq.util.regex_operation import pos_membership_test, neg_membership_test, preprocess_regex, regex_equal,\
regex_inclusion,valid_regex, or_exception



# Command-line options: paths to the train/test TSV files and to the
# checkpoint location handed to Checkpoint.get_latest_checkpoint.
parser = argparse.ArgumentParser()
for option_name, option_help in (
    ('train_path', 'path to train data'),
    ('test_path', 'path to test data'),
    ('checkpoint', 'path to checkpoint'),
):
    parser.add_argument('--' + option_name, action='store',
                        dest=option_name, help=option_help)
opt = parser.parse_args()

# Resolve and load the most recent checkpoint, then keep handles on the
# vocabularies stored with it (used later to map ids back to tokens).
latest_check_point = Checkpoint.get_latest_checkpoint(opt.checkpoint)
checkpoint = Checkpoint.load(latest_check_point)
input_vocab, output_vocab = checkpoint.input_vocab, checkpoint.output_vocab

# Show which checkpoint was picked and the token -> id tables.
print(latest_check_point)
print(input_vocab.stoi)
print(output_vocab.stoi)

../supervised/bidirectional_attn/checkpoints/2020_07_21_16_09_09
defaultdict(<bound method Vocab._default_unk_index of <torchtext.vocab.Vocab object at 0x7f0fa6f50588>>, {'<unk>': 0, '<pad>': 1, '0': 2, '2': 3, '1': 4, '3': 5, '<sep>': 6})
defaultdict(<bound method Vocab._default_unk_index of <torchtext.vocab.Vocab object at 0x7f104031cb70>>, {'<unk>': 0, '<pad>': 1, '0': 2, '1': 3, '3': 4, '2': 5, '|': 6, '(': 7, ')': 8, '*': 9, '<eos>': 10, '<sos>': 11, '[0-3]*': 12, '[0-3]': 13})


In [3]:
# Pull the trained network and its optimizer out of the loaded checkpoint.
model, optimizer = checkpoint.model, checkpoint.optimizer

# NLL loss over the output vocabulary: uniform per-token weights, with the
# id of '<pad>' passed as the second argument.
pad = output_vocab.stoi['<pad>']
weight = torch.ones(len(output_vocab))
loss = NLLLoss(weight, pad)

batch_size = 1

# Print the module tree of the loaded model.
print(model)

Seq2seq(
  (encoder): EncoderRNN(
    (input_dropout): Dropout(p=0.25, inplace=False)
    (embedding): Embedding(7, 128)
    (rnn): GRU(128, 128, num_layers=2, batch_first=True, dropout=0.25, bidirectional=True)
  )
  (decoder): DecoderRNN(
    (input_dropout): Dropout(p=0.25, inplace=False)
    (rnn): GRU(256, 256, num_layers=2, batch_first=True, dropout=0.2)
    (embedding): Embedding(14, 256)
    (attention): Attention(
      (linear_out): Linear(in_features=512, out_features=256, bias=True)
    )
    (out): Linear(in_features=256, out_features=14, bias=True)
  )
)


In [5]:
# Build the test-set iterator (one example per batch, original file order).
train_file = opt.train_path
test_file = opt.test_path
src = SourceField()
tgt = TargetField()

# Both files are TSV with two columns: source sequence, target sequence.
train = torchtext.data.TabularDataset(
    path=train_file, format='tsv',
    fields=[('src', src), ('tgt', tgt)]
)
test_data = torchtext.data.TabularDataset(
    path=test_file, format='tsv',
    fields=[('src', src), ('tgt', tgt)]
)

src.build_vocab(train, max_size=500)
tgt.build_vocab(train, max_size=500)

# The evaluation cell decodes ids with the checkpoint vocabularies
# (input_vocab / output_vocab), so the fields must numericalise with those
# same tables. A vocabulary rebuilt from the train file only matches when the
# file is identical to the one used for training; pin the checkpoint ones.
src.vocab = input_vocab
tgt.vocab = output_vocab

# Always pass a torch.device: the integer -1 is a deprecated torchtext
# convention for "CPU" and made `device` a different type on each branch.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
batch_iterator = torchtext.data.BucketIterator(
    dataset=test_data, batch_size=1,
    sort=False, sort_within_batch=False, sort_key=lambda x: len(x.src),
    device=device, repeat=False, shuffle=False, train=False)



In [6]:
# Evaluate the loaded model on the test iterator (one example per batch).
#
# Two things are measured:
#   * token-level accuracy: match / total over non-padding target positions;
#   * per-example regex outcome, bucketed by count_star(target_regex):
#     invalid prediction, exact string match, match according to regex_equal,
#     or consistency with the positive/negative example strings.
# Predictions that fall in none of the "hit" categories (and are not invalid)
# are written to '<checkpoint>_error_analysis.txt'.
model.eval()
loss.reset()
start = time.time()

match = 0
total = 0
num_samples = 0

# Counter names for one star-count bucket (order is the printed order).
STAT_KEYS = ('cnt', 'hit', 'string_equal', 'dfa_equal', 'membership_equal', 'invalid_regex')
# Created before the file is opened so the final print never hits an
# undefined name; starts with buckets 0..3 and grows on demand below.
statistics = [dict.fromkeys(STAT_KEYS, 0) for _ in range(4)]

# A single no_grad context covers the whole loop (the original nested a
# second, redundant one around the forward pass).
with torch.no_grad():
    with open('{}_error_analysis.txt'.format(opt.checkpoint), 'w') as fw:
        for batch in batch_iterator:
            num_samples += 1
            input_variables, input_lengths = getattr(batch, seq2seq.src_field_name)
            target_variables = getattr(batch, seq2seq.tgt_field_name)

            softmax_list, _, other = model(input_variables, input_lengths)

            # Decoded ids / tokens of the first (only) sequence in the batch.
            length = other['length'][0]
            tgt_id_seq = [other['sequence'][di][0].data[0] for di in range(length)]
            tgt_seq = [output_vocab.itos[tok] for tok in tgt_id_seq]

            # --- token-level accuracy ---
            # The first target position is skipped (presumably '<sos>' --
            # confirm against TargetField).
            target_var = target_variables.view(-1)[1:]
            non_padding = target_var.ne(pad).type(torch.LongTensor)
            predict_var = torch.stack(tgt_id_seq)

            # Prediction and target can differ in length; bring both (and the
            # mask) to a common length before comparing position by position.
            max_len = max(len(predict_var), len(target_var))
            padded_predict_var = pad_tensor(predict_var, max_len, output_vocab)
            padded_target_var = pad_tensor(target_var, max_len, output_vocab)
            padded_non_padding = pad_tensor(non_padding, max_len, output_vocab).type(torch.bool)
            correct = padded_predict_var.eq(padded_target_var).masked_select(padded_non_padding).sum().item()

            match += correct
            total += non_padding.sum().item()

            # --- regex-level comparison ---
            # The last decoded token is dropped (presumably '<eos>' -- confirm).
            predict_regex = ' '.join(tgt_seq[:-1])
            target_regex = decode_tensor_target(target_variables, output_vocab)
            target_regex, predict_regex = preprocess_regex(target_regex, predict_regex)

            # Rebuild the raw input string from its token ids. As in the
            # original loop, only the last row of the batch is kept; with
            # batch_size=1 that is the only row.
            for row in input_variables:
                res = ''.join(input_vocab.itos[tok] for tok in row).replace('<pad>', '')

            # First half of the '<sep>'-separated strings are used as the
            # positive examples, second half as the negative examples.
            inputs = res.split('<sep>')
            half = len(inputs) // 2
            pos_input = inputs[:half]
            neg_input = inputs[half:]

            star_cnt = count_star(target_regex)
            # Grow the bucket list instead of raising IndexError when a target
            # has more stars than the four buckets allocated up front.
            while star_cnt >= len(statistics):
                statistics.append(dict.fromkeys(STAT_KEYS, 0))
            bucket = statistics[star_cnt]
            bucket['cnt'] += 1

            if not valid_regex(predict_regex) or not or_exception(predict_regex):
                bucket['invalid_regex'] += 1
            elif target_regex == predict_regex:
                bucket['hit'] += 1
                bucket['string_equal'] += 1
            elif regex_equal(target_regex, predict_regex):
                bucket['hit'] += 1
                bucket['dfa_equal'] += 1
            elif pos_membership_test(predict_regex, pos_input) and neg_membership_test(predict_regex, neg_input):
                bucket['hit'] += 1
                bucket['membership_equal'] += 1
            else:
                # Valid regex that matched by none of the criteria: log it.
                fw.write('pos_input : ' + ' '.join(pos_input) + '\n')
                fw.write('neg_input : ' + ' '.join(neg_input) + '\n')
                fw.write('target_regex : ' + target_regex + '\n')
                fw.write('predict_regex : ' + predict_regex + '\n\n')

            # Running token accuracy; NaN until a non-padding token is seen.
            if total == 0:
                accuracy = float('nan')
            else:
                accuracy = match / total

            if num_samples % 100 == 0:
                print("Iterations: ", num_samples, ", " , "{0:.3f}".format(accuracy) + '\n')

end = time.time()
print(statistics)
print(end-start)



Iterations:  100 ,  0.894

Iterations:  200 ,  0.888

Iterations:  300 ,  0.881

Iterations:  400 ,  0.891

Iterations:  500 ,  0.904

Iterations:  600 ,  0.904

Iterations:  700 ,  0.904

Iterations:  800 ,  0.913

Iterations:  900 ,  0.913

Iterations:  1000 ,  0.913

Iterations:  1100 ,  0.915

Iterations:  1200 ,  0.913

Iterations:  1300 ,  0.912

Iterations:  1400 ,  0.911

Iterations:  1500 ,  0.912

Iterations:  1600 ,  0.912

Iterations:  1700 ,  0.913

Iterations:  1800 ,  0.915

Iterations:  1900 ,  0.914

Iterations:  2000 ,  0.914

Iterations:  2100 ,  0.904

Iterations:  2200 ,  0.891

Iterations:  2300 ,  0.881

Iterations:  2400 ,  0.873

Iterations:  2500 ,  0.866

Iterations:  2600 ,  0.858

Iterations:  2700 ,  0.849

Iterations:  2800 ,  0.841

Iterations:  2900 ,  0.836

Iterations:  3000 ,  0.831

Iterations:  3100 ,  0.824

Iterations:  3200 ,  0.819

Iterations:  3300 ,  0.814

Iterations:  3400 ,  0.810

Iterations:  3500 ,  0.806

Iterations:  3600 ,  0.803

I