Commit f0ec265: seq2tree exps
pcyin committed Jan 3, 2017
1 parent 5111155 commit f0ec265
Showing 5 changed files with 437 additions and 14 deletions.

15 changes: 11 additions & 4 deletions code_gen.py

@@ -93,8 +93,11 @@
 decode_parser.add_argument('-type', default='test_data')
 
 # evaluation operation
+evaluate_parser.add_argument('-mode', default='self')
 evaluate_parser.add_argument('-input', default='decode_results.bin')
 evaluate_parser.add_argument('-type', default='test_data')
+evaluate_parser.add_argument('-seq2tree_sample_file', default='model.sample')
+evaluate_parser.add_argument('-seq2tree_id_file', default='test.id.txt')
 
 # misc
 parser.add_argument('-ifttt_test_split', default='data/ifff.test_data.gold.id')
@@ -175,11 +178,15 @@
     serialize_to_file(decode_results, args.saveto)
 
 if args.operation == 'evaluate':
-    decode_results_file = args.input
     dataset = eval(args.type)
-    decode_results = deserialize_from_file(decode_results_file)
-
-    evaluate_decode_results(dataset, decode_results)
+    if config.mode == 'self':
+        decode_results_file = args.input
+        decode_results = deserialize_from_file(decode_results_file)
+
+        evaluate_decode_results(dataset, decode_results)
+    elif config.mode == 'seq2tree':
+        from evaluation import evaluate_seq2tree_sample_file
+        evaluate_seq2tree_sample_file(config.seq2tree_sample_file, config.seq2tree_id_file, dataset)
 
 if args.operation == 'interactive':
     from dataset import canonicalize_query, query_to_data
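For reference, a hypothetical invocation of the new evaluation path (the flag names come from the argparse definitions above; that the operation is passed as a subcommand, and that parsed flags are mirrored onto the global config object as the config.mode reads suggest, are both assumptions):

    python code_gen.py evaluate -mode seq2tree -type test_data \
        -seq2tree_sample_file model.sample -seq2tree_id_file test.id.txt

The default -mode self keeps the old behavior of scoring a pickled decode_results.bin via evaluate_decode_results, so existing workflows are unaffected.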

13 changes: 7 additions & 6 deletions dataset.py

@@ -373,9 +373,9 @@ def parse_django_dataset():
     parse_trees = [e['parse_tree'] for e in data]
 
     # apply unary closures
-    unary_closures = get_top_unary_closures(parse_trees, k=0, freq=UNARY_CUTOFF_FREQ)
-    for i, parse_tree in enumerate(parse_trees):
-        apply_unary_closures(parse_tree, unary_closures)
+    # unary_closures = get_top_unary_closures(parse_trees, k=0, freq=UNARY_CUTOFF_FREQ)
+    # for i, parse_tree in enumerate(parse_trees):
+    #     apply_unary_closures(parse_tree, unary_closures)
 
     # build the grammar
     grammar = get_grammar(parse_trees)
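For context: a unary closure fuses a chain of single-child grammar productions (e.g. expr -> Call) into one production so the tree decoder takes fewer steps; this commit turns that preprocessing off for the seq2tree experiments. A toy sketch of the collapsing idea over an assumed Node type, not the repo's get_top_unary_closures/apply_unary_closures implementation:

    class Node(object):
        def __init__(self, label, children=None):
            self.label = label
            self.children = children or []

    def collapse_unary_chains(root):
        # walk down while the current node has exactly one child,
        # fusing labels such as expr -> Call into the symbol 'expr@Call'
        labels, node = [root.label], root
        while len(node.children) == 1:
            node = node.children[0]
            labels.append(node.label)
        return Node('@'.join(labels), [collapse_unary_chains(c) for c in node.children])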
@@ -390,7 +390,7 @@
     # grammar, all_parse_trees = extract_grammar(code_file)
 
     annot_tokens = list(chain(*[e['query_tokens'] for e in data]))
-    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=5) # gen_vocab(annot_tokens, vocab_size=5980)
+    annot_vocab = gen_vocab(annot_tokens, vocab_size=5000, freq_cutoff=3) # gen_vocab(annot_tokens, vocab_size=5980)
 
     terminal_token_seq = []
     empty_actions_count = 0
@@ -428,7 +428,7 @@ def get_terminal_tokens(_terminal_str):
         assert len(terminal_token) > 0
         terminal_token_seq.append(terminal_token)
 
-    terminal_vocab = gen_vocab(terminal_token_seq, vocab_size=5000, freq_cutoff=5)
+    terminal_vocab = gen_vocab(terminal_token_seq, vocab_size=5000, freq_cutoff=3)
     assert '_STR:0_' in terminal_vocab
 
     train_data = DataSet(annot_vocab, terminal_vocab, grammar, 'train_data')
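Both vocabularies move from freq_cutoff=5 to freq_cutoff=3, so rarer tokens now presumably enter the vocabulary instead of being collapsed to the unknown-word entry. gen_vocab itself is not part of this diff; a minimal sketch of the presumed cutoff semantics over a flat token stream (signature and details are assumptions, ignoring special entries such as unk and the _STR:k_ placeholders):

    from collections import Counter

    def gen_vocab_sketch(tokens, vocab_size, freq_cutoff):
        # keep the most frequent tokens, dropping any seen fewer than
        # freq_cutoff times, capped at vocab_size entries
        counts = Counter(tokens)
        return [t for t, c in counts.most_common(vocab_size) if c >= freq_cutoff]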
@@ -550,7 +550,8 @@ def get_terminal_tokens(_terminal_str):
     test_data.init_data_matrices()
 
     serialize_to_file((train_data, dev_data, test_data),
-                      'data/django.cleaned.dataset.freq5.par_info.refact.space_only.unary_closure.freq{UNARY_CUTOFF_FREQ}.order_by_ulink_len.bin'.format(UNARY_CUTOFF_FREQ=UNARY_CUTOFF_FREQ))
+                      'data/django.cleaned.dataset.freq3.par_info.refact.space_only.order_by_ulink_len.bin')
+                      # 'data/django.cleaned.dataset.freq5.par_info.refact.space_only.unary_closure.freq{UNARY_CUTOFF_FREQ}.order_by_ulink_len.bin'.format(UNARY_CUTOFF_FREQ=UNARY_CUTOFF_FREQ))
 
     return train_data, dev_data, test_data
 
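The pickled tuple can then be reloaded in one call; a usage sketch with the deserialize_from_file helper seen in code_gen.py above (that it is the inverse of serialize_to_file is an assumption):

    train_data, dev_data, test_data = deserialize_from_file(
        'data/django.cleaned.dataset.freq3.par_info.refact.space_only.order_by_ulink_len.bin')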

81 changes: 81 additions & 0 deletions evaluation.py

@@ -245,6 +245,87 @@ def evaluate_decode_results(dataset, decode_results, verbose=True):
     return cum_bleu, cum_acc
 
 
+def evaluate_seq2tree_sample_file(sample_file, id_file, dataset):
+    from lang.py.parse import tokenize_code, de_canonicalize_code
+    import ast, astor
+    import traceback
+    from lang.py.seq2tree_exp import seq2tree_repr_to_ast_tree, merge_broken_value_nodes
+    from lang.py.parse import decode_tree_to_python_ast
+
+    f_sample = open(sample_file)
+    line_id_to_raw_id = OrderedDict()
+    raw_id_to_eid = OrderedDict()
+    for i, line in enumerate(open(id_file)):
+        raw_id = int(line.strip())
+        line_id_to_raw_id[i] = raw_id
+
+    for eid in range(len(dataset.examples)):
+        raw_id_to_eid[dataset.examples[eid].raw_id] = eid
+
+    cum_bleu = 0.0
+    cum_acc = 0.0
+    sm = SmoothingFunction()
+    convert_error_num = 0
+
+    for i in range(len(line_id_to_raw_id)):
+        print 'working on %d' % i
+        # each sample occupies three lines: reference repr, predicted repr, separator
+        ref_repr = f_sample.readline().strip()
+        predict_repr = f_sample.readline().strip()
+        # rewrite unknown-token placeholders so the repr parses
+        predict_repr = predict_repr.replace('<U>', 'str{}{unk}').replace('( )', '( str{}{unk} )')
+        f_sample.readline()
+
+        try:
+            parse_tree = seq2tree_repr_to_ast_tree(predict_repr)
+            merge_broken_value_nodes(parse_tree)
+        except:
+            print 'error when converting:'
+            print predict_repr
+            convert_error_num += 1
+            continue
+
+        raw_id = line_id_to_raw_id[i]
+        eid = raw_id_to_eid[raw_id]
+        example = dataset.examples[eid]
+
+        ref_code = example.code
+        ref_ast_tree = ast.parse(ref_code).body[0]
+        refer_source = astor.to_source(ref_ast_tree).strip()
+        refer_tokens = tokenize_code(refer_source)
+
+        try:
+            ast_tree = decode_tree_to_python_ast(parse_tree)
+            code = astor.to_source(ast_tree).strip()
+        except:
+            print "Exception in converting tree to code:"
+            print '-' * 60
+            print 'line id: %d' % i
+            traceback.print_exc(file=sys.stdout)
+            print '-' * 60
+            # 'code' is undefined past this point, so skip scoring this example
+            continue
+
+        if config.data_type == 'django':
+            ref_code_for_bleu = example.meta_data['raw_code']
+            pred_code_for_bleu = de_canonicalize_code(code, example.meta_data['raw_code'])
+            # convert canonicalized code to raw code
+            for literal, place_holder in example.meta_data['str_map'].iteritems():
+                pred_code_for_bleu = pred_code_for_bleu.replace('\'' + place_holder + '\'', literal)
+        elif config.data_type == 'hs':
+            ref_code_for_bleu = ref_code
+            pred_code_for_bleu = code
+
+        # we apply Ling Wang's trick when evaluating BLEU scores
+        refer_tokens_for_bleu = tokenize_for_bleu_eval(ref_code_for_bleu)
+        pred_tokens_for_bleu = tokenize_for_bleu_eval(pred_code_for_bleu)
+
+        ngram_weights = [0.25] * min(4, len(refer_tokens_for_bleu))
+        bleu_score = sentence_bleu([refer_tokens_for_bleu], pred_tokens_for_bleu, weights=ngram_weights,
+                                   smoothing_function=sm.method3)
+        cum_bleu += bleu_score
+
+    cum_bleu /= len(line_id_to_raw_id)
+    logging.info('num. errors when converting repr to tree: %d', convert_error_num)
+    logging.info('sentence level bleu: %f', cum_bleu)
+
+
 def evaluate_ifttt_results(dataset, decode_results, verbose=True):
     assert dataset.count == len(decode_results)
 
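From the reading loop above, the sample file carries three lines per example (reference tree repr, predicted tree repr, separator), with <U> unknown-token placeholders in predictions rewritten before parsing, and the id file holds one raw example id per line mapping sample order back to dataset.examples. A purely illustrative pair of files; the actual repr syntax is not shown in this commit:

    model.sample (three lines per sample):
        ( root ( expr ( Call ... ) ) )
        ( root ( expr ( Call ... ) ) )

        ...

    test.id.txt (one raw id per line):
        1023
        7

Note also that ngram_weights is truncated to min(4, len(refer_tokens_for_bleu)) orders, so references shorter than four tokens are not scored on n-gram orders they cannot contain.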
(diffs for the remaining two of the five changed files were not loaded on this page)