diff --git a/evaluation.py b/evaluation.py index 6938795..84490a1 100644 --- a/evaluation.py +++ b/evaluation.py @@ -29,7 +29,7 @@ def decode(examples, model, args, verbose=False, **kwargs): try: hyp.code = model.transition_system.ast_to_surface_code(hyp.tree) - if args.lang == 'wikisql': + if args.lang == 'wikisql' and args.answer_prune: # try execute the code, if fails, skip this example! # if the execution returns null, also skip this example! detokenized_hyp_query = detokenize_query(hyp.code, example.meta, example.table) diff --git a/exp.py b/exp.py index 8a3ebc4..670554b 100644 --- a/exp.py +++ b/exp.py @@ -67,6 +67,9 @@ def init_arg_parser(): # wikisql arg_parser.add_argument('--column_att', choices=['dot_prod', 'affine'], default='affine') + arg_parser.add_argument('--answer_prune', dest='answer_prune', action='store_true') + arg_parser.add_argument('--no_answer_prune', dest='answer_prune', action='store_false') + arg_parser.set_defaults(answer_prune=True) # parent information switch and input feeding arg_parser.add_argument('--no_parent_production_embed', default=False, action='store_true')