In [1]:
import os, sys, time, json, sqlite3, re, random, numpy as np, pickle as cp, glob, operator
from copy import deepcopy
from collections import OrderedDict
import subprocess32 as subprocess
from functools import partial
np.random.seed(1307)
sys.path.append('..')
from pycparser import parse_file, c_ast
from pycparser.plyparser import ParseError

from util.helpers import get_rev_dict, remove_non_ascii, make_dir_if_not_exists as mkdir
from util.helpers import get_curr_time_string, remove_non_ascii
from util.helpers import isolate_line, fetch_line, extract_line_number, get_lines, recompose_program
from util.ast_helpers import get_subtree_list, get_linearized_ast, get_ast
from concurrent.futures import ThreadPoolExecutor
from pprint import pprint
#from util.c_tokenizer import C_Tokenizer
#tokenize = C_Tokenizer().tokenize

Using TensorFlow backend.


In [2]:
dataset = '../data/dataset.db'
with sqlite3.connect(dataset) as conn:
    c = conn.cursor()

In [3]:
c.execute("SELECT name FROM sqlite_master WHERE type='table';")
print(c.fetchall())

[('test_runs',), ('orgsource',), ('test_run_summary',), ('programs',), ('problems',)]


In [4]:
destination = '../data/network_inputs/bugloc-%s/' % time.strftime("%d-%m")
mkdir(destination)
print(destination)

../data/network_inputs/bugloc-17-09/


In [5]:
eval_set = np.load('../data/eval_set.npy', allow_pickle=True).item()
eval_dict = {}
for problem_id in eval_set:
    for program_id, row in eval_set[problem_id].items():
        eval_dict[program_id] = row

eval_set_program_ids = set(eval_dict.keys())
print ('len(eval_set_program_ids):', len(eval_set_program_ids))

len(eval_set_program_ids): 2167


In [6]:
query='''SELECT p.program_id, subtree_list_ast_without_leaves, test_id, t.verdict FROM 
        programs p INNER JOIN orgsource o ON o.program_id = p.program_id
        INNER JOIN problems q ON o.problem_id = q.problem_id 
        INNER JOIN test_runs t ON t.program_id = p.program_id
        INNER JOIN test_run_summary trs ON trs.program_id = p.program_id
        WHERE trs.verdict<>"ALL_FAIL" AND q.problem_id=?;'''

problem_ids = [str(row[0]) for row in c.execute('SELECT DISTINCT problem_id FROM orgsource;')]
print('len(problem_ids)', len(problem_ids), problem_ids[0])

len(problem_ids) 29 3024


In [7]:
max_programs_per_test_case_result = 700

test_wise_counts = {}
eval_test_wise_counts = {}
all_data = {}
all_eval_data = {}
test_id_to_problem_id_map = {}

with sqlite3.connect(dataset) as conn:
    c = conn.cursor()

for problem_id in problem_ids:
        test_wise_counts[problem_id] = {}
        eval_test_wise_counts[problem_id] = {}
        this_data_len = 0
        this_eval_data = []
        all_data[problem_id] = {}
        for row in c.execute(query, (problem_id,)):
            program_id = row[0]
            subtree_list_ast = json.loads(row[1])
            test_id = str(row[2])
            verdict = row[3]
            
            if test_id in test_id_to_problem_id_map:
                assert test_id_to_problem_id_map[test_id] == problem_id
            test_id_to_problem_id_map[test_id] = problem_id
            
            if test_id not in test_wise_counts[problem_id]:
                test_wise_counts[problem_id][test_id] = {0:[], 1:[]}
                eval_test_wise_counts[problem_id][test_id] = {0:[], 1:[]}
                all_data[problem_id][test_id] = {0:[], 1:[]}

            if program_id in eval_set_program_ids:
                this_eval_data += [(program_id, subtree_list_ast, test_id, verdict)]
                eval_test_wise_counts[problem_id][test_id][verdict] += [program_id]
            elif len(test_wise_counts[problem_id][test_id][verdict]) < max_programs_per_test_case_result:
                all_data[problem_id][test_id][verdict] += [(program_id, subtree_list_ast, test_id, verdict)]
                test_wise_counts[problem_id][test_id][verdict] += [program_id]
                this_data_len += 1
                
        all_eval_data[problem_id] = this_eval_data
        print('problem_id: %6s' % problem_id, '#examples:', this_data_len, '#eval_examples:', len(this_eval_data))

total = 0
for problem_id in all_data:
    for test_id in all_data[problem_id]:
        for verdict in all_data[problem_id][test_id]:
            total += len(all_data[problem_id][test_id][verdict])
print('total', total)

c.close()
conn.close()

problem_id:   3024 #examples: 7658 #eval_examples: 1848
problem_id:   3035 #examples: 10138 #eval_examples: 730
problem_id:   3018 #examples: 7983 #eval_examples: 828
problem_id:   3003 #examples: 9602 #eval_examples: 104
problem_id:   3026 #examples: 5136 #eval_examples: 390
problem_id:   3019 #examples: 4590 #eval_examples: 516
problem_id:   3032 #examples: 6569 #eval_examples: 357
problem_id:   3002 #examples: 6409 #eval_examples: 490
problem_id:   3008 #examples: 10155 #eval_examples: 927
problem_id:   3043 #examples: 8475 #eval_examples: 972
problem_id:   3016 #examples: 9536 #eval_examples: 510
problem_id:   3037 #examples: 4752 #eval_examples: 801
problem_id:   3047 #examples: 7460 #eval_examples: 280
problem_id:   3007 #examples: 3930 #eval_examples: 390
problem_id:   3009 #examples: 6589 #eval_examples: 609
problem_id:   3049 #examples: 11763 #eval_examples: 2790
problem_id:   3030 #examples: 7830 #eval_examples: 90
problem_id:   3055 #examples: 12387 #eval_examples: 1170
prob

In [8]:
for problem_id in all_data:
    print ('problem_id:', problem_id)
    for test_id in test_wise_counts[problem_id]:
        print (test_id.replace('.txt', '').replace('IN_', ''), 
        'fail:', len(test_wise_counts[problem_id][test_id][0]),
        'pass:', len(test_wise_counts[problem_id][test_id][1]))

problem_id: 3024
40013 fail: 273 pass: 700
40024 fail: 553 pass: 455
40078 fail: 250 pass: 700
40224 fail: 273 pass: 700
40297 fail: 546 pass: 462
40331 fail: 248 pass: 700
40394 fail: 199 pass: 700
40410 fail: 199 pass: 700
problem_id: 3035
40021 fail: 527 pass: 529
40082 fail: 392 pass: 664
40139 fail: 390 pass: 666
40205 fail: 382 pass: 674
40212 fail: 393 pass: 663
40237 fail: 394 pass: 662
40294 fail: 248 pass: 700
40312 fail: 224 pass: 700
40352 fail: 265 pass: 700
40383 fail: 265 pass: 700
problem_id: 3018
40049 fail: 431 pass: 456
40108 fail: 312 pass: 575
40274 fail: 360 pass: 527
40334 fail: 308 pass: 579
40335 fail: 313 pass: 574
40426 fail: 311 pass: 576
40444 fail: 332 pass: 555
40447 fail: 342 pass: 545
40449 fail: 459 pass: 428
problem_id: 3003
40006 fail: 700 pass: 531
40015 fail: 700 pass: 279
40141 fail: 231 pass: 700
40156 fail: 391 pass: 700
40362 fail: 686 pass: 700
40409 fail: 700 pass: 693
40453 fail: 700 pass: 491
40460 fail: 700 pass: 700
problem_id: 3026
40002

In [9]:
for problem_id in all_eval_data:
    print ('problem_id:', problem_id)
    for test_id in eval_test_wise_counts[problem_id]:
        print (test_id.replace('.txt', '').replace('IN_', ''), 'fail:', 
               len(eval_test_wise_counts[problem_id][test_id][0]),'pass:', 
               len(eval_test_wise_counts[problem_id][test_id][1]))

problem_id: 3024
40013 fail: 36 pass: 195
40024 fail: 202 pass: 29
40078 fail: 30 pass: 201
40224 fail: 36 pass: 195
40297 fail: 197 pass: 34
40331 fail: 30 pass: 201
40394 fail: 27 pass: 204
40410 fail: 27 pass: 204
problem_id: 3035
40021 fail: 28 pass: 45
40082 fail: 26 pass: 47
40139 fail: 28 pass: 45
40205 fail: 27 pass: 46
40212 fail: 33 pass: 40
40237 fail: 33 pass: 40
40294 fail: 13 pass: 60
40312 fail: 11 pass: 62
40352 fail: 9 pass: 64
40383 fail: 9 pass: 64
problem_id: 3018
40049 fail: 63 pass: 29
40108 fail: 9 pass: 83
40274 fail: 28 pass: 64
40334 fail: 24 pass: 68
40335 fail: 8 pass: 84
40426 fail: 24 pass: 68
40444 fail: 10 pass: 82
40447 fail: 21 pass: 71
40449 fail: 63 pass: 29
problem_id: 3003
40006 fail: 3 pass: 10
40015 fail: 7 pass: 6
40141 fail: 1 pass: 12
40156 fail: 3 pass: 10
40362 fail: 1 pass: 12
40409 fail: 2 pass: 11
40453 fail: 7 pass: 6
40460 fail: 1 pass: 12
problem_id: 3026
40002 fail: 12 pass: 53
40134 fail: 27 pass: 38
40218 fail: 16 pass: 49
40344 fai

## Upsample data for balancing

In [10]:
for problem_id in all_data:
    for test_id in all_data[problem_id]:
        if len(all_data[problem_id][test_id][0]) < len(all_data[problem_id][test_id][1]):
            small, target_len = all_data[problem_id][test_id][0], len(all_data[problem_id][test_id][1])
        elif len(all_data[problem_id][test_id][0]) > len(all_data[problem_id][test_id][1]):
            small, target_len = all_data[problem_id][test_id][1], len(all_data[problem_id][test_id][0])
        else:
            continue
        np.random.seed(int(test_id.replace('.txt', '').replace('IN_', '')))
        upsamples = np.random.choice(len(small), target_len-len(small))
        for idx in upsamples:
            small.append(small[idx])
        assert len(all_data[problem_id][test_id][0]) == len(all_data[problem_id][test_id][1])

In [11]:
np.random.seed(1307)

In [12]:
for problem_id in all_data:
    print('problem_id:', problem_id)
    for test_id in all_data[problem_id]:
        print (test_id.replace('.txt', '').replace('IN_', ''), 
               'fail:', len(all_data[problem_id][test_id][0]), 
               'pass:', len(all_data[problem_id][test_id][1]))
        assert len(all_data[problem_id][test_id][0]) == len(all_data[problem_id][test_id][1])

problem_id: 3024
40013 fail: 700 pass: 700
40024 fail: 553 pass: 553
40078 fail: 700 pass: 700
40224 fail: 700 pass: 700
40297 fail: 546 pass: 546
40331 fail: 700 pass: 700
40394 fail: 700 pass: 700
40410 fail: 700 pass: 700
problem_id: 3035
40021 fail: 529 pass: 529
40082 fail: 664 pass: 664
40139 fail: 666 pass: 666
40205 fail: 674 pass: 674
40212 fail: 663 pass: 663
40237 fail: 662 pass: 662
40294 fail: 700 pass: 700
40312 fail: 700 pass: 700
40352 fail: 700 pass: 700
40383 fail: 700 pass: 700
problem_id: 3018
40049 fail: 456 pass: 456
40108 fail: 575 pass: 575
40274 fail: 527 pass: 527
40334 fail: 579 pass: 579
40335 fail: 574 pass: 574
40426 fail: 576 pass: 576
40444 fail: 555 pass: 555
40447 fail: 545 pass: 545
40449 fail: 459 pass: 459
problem_id: 3003
40006 fail: 700 pass: 700
40015 fail: 700 pass: 700
40141 fail: 700 pass: 700
40156 fail: 700 pass: 700
40362 fail: 700 pass: 700
40409 fail: 700 pass: 700
40453 fail: 700 pass: 700
40460 fail: 700 pass: 700
problem_id: 3026
40002

In [13]:
new_all_data = {}
for problem_id in all_data:
    new_all_data[problem_id] = []
    for test_id in all_data[problem_id]:
        new_all_data[problem_id] += all_data[problem_id][test_id][0] + all_data[problem_id][test_id][1]
del all_data
all_data = new_all_data

total = 0
for problem_id in all_data:
    total += len(all_data[problem_id])

In [22]:
subtrees_per_program, nodes_per_subtree = [], []
max_subtrees_per_program, max_nodes_per_subtree = 0, 0
d = all_data.copy()
d.update(all_eval_data)
for problem_id, rows in d.items():
    for program_id, subtree_list_ast, test_id, verdict in rows:
        subtrees_per_program.append(len(subtree_list_ast))
        nodes_per_subtree.append(max([len(subtree) for subtree, coord in subtree_list_ast]))
        max_subtrees_per_program = max(max_subtrees_per_program, len(subtree_list_ast))
        max_nodes_per_subtree = max(max_nodes_per_subtree, *[len(subtree) for subtree, coord in subtree_list_ast])
        
print ('max_subtrees_per_program:', max_subtrees_per_program, 'max_nodes_per_subtree:', max_nodes_per_subtree)

assert len(subtrees_per_program) == len(nodes_per_subtree)

subtrees_per_program = list(sorted(subtrees_per_program))
nodes_per_subtree = list(sorted(nodes_per_subtree))

for percentile in [80,90,95,98,99]:
    i = int( percentile * len(nodes_per_subtree) / 100.0 )
    print ('percentile:', percentile, 'max_subtrees:',  subtrees_per_program[i], 'max_nodes:', nodes_per_subtree[i])

max_subtrees_per_program: 149 max_nodes_per_subtree: 21
percentile: 80 max_subtrees: 78 max_nodes: 13
percentile: 90 max_subtrees: 90 max_nodes: 15
percentile: 95 max_subtrees: 110 max_nodes: 17
percentile: 98 max_subtrees: 126 max_nodes: 18
percentile: 99 max_subtrees: 133 max_nodes: 19


In [24]:
# Discard last percentile programs

max_trees = 149
max_nodes = 21
selected_data = {}
selected_eval_data = {}
for problem_id, rows in all_data.items():
    selected_data[problem_id] = []
    for program_id, subtree_list_ast, test_id, verdict in rows:
        if len(subtree_list_ast) <= max_trees and max([len(subtree) for subtree, coord in subtree_list_ast]) <= max_nodes:
            selected_data[problem_id].append((program_id, subtree_list_ast, test_id, verdict))

for problem_id, rows in all_eval_data.items():
    selected_eval_data[problem_id] = []
    for program_id, subtree_list_ast, test_id, verdict in rows:
        if len(subtree_list_ast) <= max_trees and max([len(subtree) for subtree, coord in subtree_list_ast]) <= max_nodes:
            selected_eval_data[problem_id].append((program_id, subtree_list_ast, test_id, verdict))

new_total = 0
for problem_id in selected_data:
    new_total += len(selected_data[problem_id])
print ('old-total', total, 'new_total:', new_total, 'difference:', total-new_total)

del all_data, all_eval_data

old-total 261174 new_total: 256297 difference: 4877


In [25]:
def get_id_map(ast, program_id=None):
    '''shuffles ids before assigning them indices using 
    program_id as randomness seed if program_id is not None'''
    
    ids = []
    for subtree, coord in ast:
        for node in subtree:
            if '_<id>_' in node and '@' in node:
                org_id = node.split('_<id>_')[1].split('@')[0]
                if org_id not in ids:
                    ids.append(org_id)
                    
    if program_id is not None:
        random.seed(program_id)
        random.shuffle(ids)

    id_map = {}
    for id_ in ids:
        id_map[id_] = len(id_map)
    return id_map

def normalize_ids(ast, id_map):
    new_ast = []
    for subtree, coord in ast:
        new_subtree = []
        for node in subtree:
            if '_<id>_' in node and '@' in node:
                org_id = node.split('_<id>_')[1].split('@')[0]
                new_subtree.append(node.replace('_<id>_' + org_id + '@', '_<id>_' + str(id_map[org_id]) + '@'))
            else:
                new_subtree.append(node)
        assert len(new_subtree) == len(subtree)
        new_ast.append((new_subtree, coord))
    return new_ast

normalized_data, normalized_eval_data = {}, {}

for problem_id in selected_data:
    normalized_data[problem_id] = []
    remaining = len(selected_data[problem_id]) - 1
    for program_id, subtree_list_ast, test_id, verdict in selected_data[problem_id]:
        if remaining % 250 == 0:
            print ('%s | remaining:%6d     \r' % (problem_id, remaining))
        remaining -= 1
        id_map = get_id_map(subtree_list_ast, program_id)
        norm_subtree_list_ast = normalize_ids(subtree_list_ast, id_map)
        normalized_data[problem_id] += [(program_id, id_map, norm_subtree_list_ast, test_id, verdict)]
    print()
del selected_data
        
for problem_id in selected_eval_data:
    normalized_eval_data[problem_id] = []
    for program_id, subtree_list_ast, test_id, verdict in selected_eval_data[problem_id]:
        id_map = get_id_map(subtree_list_ast, program_id)
        norm_subtree_list_ast = normalize_ids(subtree_list_ast, id_map)
        normalized_eval_data[problem_id] += [(program_id, id_map, norm_subtree_list_ast, test_id, verdict)]
        
del selected_eval_data

3024 | remaining: 10500     
3024 | remaining: 10250     
3024 | remaining: 10000     
3024 | remaining:  9750     
3024 | remaining:  9500     
3024 | remaining:  9250     
3024 | remaining:  9000     
3024 | remaining:  8750     
3024 | remaining:  8500     
3024 | remaining:  8250     
3024 | remaining:  8000     
3024 | remaining:  7750     
3024 | remaining:  7500     
3024 | remaining:  7250     
3024 | remaining:  7000     
3024 | remaining:  6750     
3024 | remaining:  6500     
3024 | remaining:  6250     
3024 | remaining:  6000     
3024 | remaining:  5750     
3024 | remaining:  5500     
3024 | remaining:  5250     
3024 | remaining:  5000     
3024 | remaining:  4750     
3024 | remaining:  4500     
3024 | remaining:  4250     
3024 | remaining:  4000     
3024 | remaining:  3750     
3024 | remaining:  3500     
3024 | remaining:  3250     
3024 | remaining:  3000     
3024 | remaining:  2750     
3024 | remaining:  2500     
3024 | remaining:  2250     
3024 | remaini

3002 | remaining:  6500     
3002 | remaining:  6250     
3002 | remaining:  6000     
3002 | remaining:  5750     
3002 | remaining:  5500     
3002 | remaining:  5250     
3002 | remaining:  5000     
3002 | remaining:  4750     
3002 | remaining:  4500     
3002 | remaining:  4250     
3002 | remaining:  4000     
3002 | remaining:  3750     
3002 | remaining:  3500     
3002 | remaining:  3250     
3002 | remaining:  3000     
3002 | remaining:  2750     
3002 | remaining:  2500     
3002 | remaining:  2250     
3002 | remaining:  2000     
3002 | remaining:  1750     
3002 | remaining:  1500     
3002 | remaining:  1250     
3002 | remaining:  1000     
3002 | remaining:   750     
3002 | remaining:   500     
3002 | remaining:   250     
3002 | remaining:     0     

3008 | remaining: 12000     
3008 | remaining: 11750     
3008 | remaining: 11500     
3008 | remaining: 11250     
3008 | remaining: 11000     
3008 | remaining: 10750     
3008 | remaining: 10500     
3008 | remain

3049 | remaining: 12000     
3049 | remaining: 11750     
3049 | remaining: 11500     
3049 | remaining: 11250     
3049 | remaining: 11000     
3049 | remaining: 10750     
3049 | remaining: 10500     
3049 | remaining: 10250     
3049 | remaining: 10000     
3049 | remaining:  9750     
3049 | remaining:  9500     
3049 | remaining:  9250     
3049 | remaining:  9000     
3049 | remaining:  8750     
3049 | remaining:  8500     
3049 | remaining:  8250     
3049 | remaining:  8000     
3049 | remaining:  7750     
3049 | remaining:  7500     
3049 | remaining:  7250     
3049 | remaining:  7000     
3049 | remaining:  6750     
3049 | remaining:  6500     
3049 | remaining:  6250     
3049 | remaining:  6000     
3049 | remaining:  5750     
3049 | remaining:  5500     
3049 | remaining:  5250     
3049 | remaining:  5000     
3049 | remaining:  4750     
3049 | remaining:  4500     
3049 | remaining:  4250     
3049 | remaining:  4000     
3049 | remaining:  3750     
3049 | remaini

3010 | remaining:  2500     
3010 | remaining:  2250     
3010 | remaining:  2000     
3010 | remaining:  1750     
3010 | remaining:  1500     
3010 | remaining:  1250     
3010 | remaining:  1000     
3010 | remaining:   750     
3010 | remaining:   500     
3010 | remaining:   250     
3010 | remaining:     0     

3025 | remaining:  7000     
3025 | remaining:  6750     
3025 | remaining:  6500     
3025 | remaining:  6250     
3025 | remaining:  6000     
3025 | remaining:  5750     
3025 | remaining:  5500     
3025 | remaining:  5250     
3025 | remaining:  5000     
3025 | remaining:  4750     
3025 | remaining:  4500     
3025 | remaining:  4250     
3025 | remaining:  4000     
3025 | remaining:  3750     
3025 | remaining:  3500     
3025 | remaining:  3250     
3025 | remaining:  3000     
3025 | remaining:  2750     
3025 | remaining:  2500     
3025 | remaining:  2250     
3025 | remaining:  2000     
3025 | remaining:  1750     
3025 | remaining:  1500     
3025 | remain

## Data partition

In [26]:
np.random.seed(1189)
vdp = 0.05
training_data, validation_data = {}, {}

for problem_id in normalized_data:
    total_size = len(normalized_data[problem_id])
    valid_data_size = int(vdp * total_size)
    np.random.shuffle(normalized_data[problem_id])
    validation_data[problem_id] = normalized_data[problem_id][:valid_data_size]
    training_data[problem_id] = normalized_data[problem_id][valid_data_size:]
    
test_data = normalized_eval_data
    
    
all_tests = set()
problem_id_dict = {}
program_ids_dict = {}
all_program_ids = set()
for problem_id in normalized_data:
    for program_id, id_map, norm_subtree_list_ast, test_id, verdict in (normalized_data[problem_id] + normalized_eval_data[problem_id]):
        assert problem_id == test_id_to_problem_id_map[test_id]
        all_tests.add(problem_id+test_id)
        all_program_ids.add(program_id)
        
test_dict = {}
for test_id in all_tests:
    test_dict[test_id] = len(test_dict)

print ('all tests:', len(test_dict))
print ('all programs:', len(all_program_ids))

all_problem_ids = sorted(normalized_data.keys())
for problem_id in all_problem_ids:
    if problem_id not in problem_id_dict:
        problem_id_dict[problem_id] = len(problem_id_dict)

print ('all problems:', len(problem_id_dict))

max_subtrees_per_program, max_nodes_per_subtree = 0, 0
for problem_id in normalized_data:
    for program_id, id_map, norm_subtree_list_ast, test_id, verdict in (normalized_data[problem_id] + normalized_eval_data[problem_id]):
        max_subtrees_per_program = max(max_subtrees_per_program, len(norm_subtree_list_ast))
        max_nodes_per_subtree = max(max_nodes_per_subtree, *[len(subtree) for subtree, coord in norm_subtree_list_ast])
        
info = {'max_subtrees_per_program':max_subtrees_per_program, 'max_nodes_per_subtree':max_nodes_per_subtree}
pprint (info)

all tests: 231
all programs: 29993
all problems: 29
{'max_nodes_per_subtree': 21, 'max_subtrees_per_program': 149}


## Generate and save subtree_list AST dataset

In [27]:
def build_dictionary_for_subtree_list_ast(data, tl_dict={}):

    def build_dict(subtree_list_ast, dict_ref):
        for subtree, coords in subtree_list_ast:
            for token in subtree:
                token = token.strip()
                if token not in dict_ref:
                    dict_ref[token] = len(dict_ref)

    tl_dict['_pad_'] = 0
    tl_dict['_eos_'] = 1

    for problem_id, rows in data.items():
        for _, _, subtree_list_ast, _, _ in rows:
            build_dict(subtree_list_ast, tl_dict)

    print ('dictionary size:', len(tl_dict))
    assert len(tl_dict) > 50
    return tl_dict

In [28]:
tl_dict = build_dictionary_for_subtree_list_ast(normalized_data)
tl_dict = build_dictionary_for_subtree_list_ast(normalized_eval_data, tl_dict)
rev_tl_dict = get_rev_dict(tl_dict)

dictionary size: 1146
dictionary size: 1210


In [33]:
def save_dictionaries(destination, all_dicts):
    if type(all_dicts) == dict:
        all_dicts = (all_dicts, get_rev_dict(all_dicts))
    else:
        assert type(all_dicts) == tuple
    np.save(os.path.join(destination, 'all_dicts.npy'), all_dicts)

def load_dictionaries(destination):
    all_dicts = np.load(os.path.join(destination, 'all_dicts.npy')).item()
    assert type(all_dicts) == tuple, type(all_dicts)
    return all_dicts
    
def save_data(destination, fold_data, all_dicts):
    for key in fold_data.keys():
        print ('%s-data size:' % key, len(fold_data[key]))
        with open(os.path.join(destination, ('examples-%s.pkl' % key)), 'wb') as f:
            cp.dump(fold_data[key], f)
    save_dictionaries(destination, all_dicts)

## Generate data with buggy subtrees for incorrect programs for which we could find corresponding correct programs with small diff 

In [30]:
def get_buggy_lines(program_id):
    global eval_dict
    if program_id in eval_dict:
        return set(eval_dict[program_id][3])
    else:
        return set([0])

In [31]:
def vectorize_subtree_list_ast(_tl_dict, subtree_list_ast, buggy_lines):
    vec_ast = []
    buggy_subtrees = {line:set() for line in buggy_lines}
    found = 0
    for idx, (subtree, coord) in enumerate(subtree_list_ast):
        line, char = map(int, coord.split(':'))
        if line in buggy_lines: 
            buggy_subtrees[line].add(idx)
            found += 1
        
        vec_subtree = []
        for token in subtree:
            vec_subtree.append(_tl_dict[token])
        
        vec_ast.append(vec_subtree)

    if not (len(buggy_lines)==0 or found > 0):
        raise ValueError('buggy_lines:%d, found %d' % (len(buggy_lines), found))
    return vec_ast, buggy_subtrees

In [35]:
this_fold = {'train':[], 'test':[], 'validation':[], 'eval':[]}
total_processed, exceptions = 0, 0

erratic_ids = []
ast_vec_cache = {}
def vectorize(tl_dict, subtree_list_ast, program_id):
    if program_id not in ast_vec_cache:
        buggy_lines = get_buggy_lines(program_id)
        vec_ast, buggy_subtrees = vectorize_subtree_list_ast(tl_dict, subtree_list_ast, buggy_lines)
        ast_vec_cache[program_id] = (vec_ast, buggy_subtrees)        
    return ast_vec_cache[program_id]

for name, each in zip(['train', 'test', 'validation'], [training_data, test_data, validation_data]):
    for problem_id, rows in each.items():
        for program_id, id_map, norm_subtree_list_ast, test_id, verdict in rows:
            try:
                total_processed += 1
                vec_ast, buggy_subtrees = vectorize(tl_dict, norm_subtree_list_ast, program_id)
                assert problem_id == test_id_to_problem_id_map[test_id]
                this_fold[name] += [(problem_id_dict[problem_id], program_id, test_dict[problem_id+test_id], vec_ast, verdict, buggy_subtrees)]
                if program_id in eval_set_program_ids and not verdict:
                    this_fold['eval'] += [(problem_id_dict[problem_id], program_id, test_dict[problem_id+test_id], vec_ast, verdict, buggy_subtrees)]
            except Exception as e:
                erratic_ids.append(program_id)
                exceptions += 1

print ('\ntotal_processed:', total_processed, 'exceptions:', exceptions)
save_data(destination, this_fold, (tl_dict, test_dict, problem_id_dict, info, {}))
np.save(os.path.join(destination, 'errs_in_finding_buggy_subtrees.npy'), erratic_ids)
print ('saved at:', destination)
print (len(set(erratic_ids)), set(erratic_ids))


total_processed: 273617 exceptions: 0
train-data size: 243495
test-data size: 17320
validation-data size: 12802
eval-data size: 7558
saved at: ../data/network_inputs/bugloc-17-09/
0 set()
