In [1]:
import os
os.environ["config"] = "/home/msun415/induction/src/config/ckt.yaml"
import sys
os.chdir('/home/msun415/induction/')
from src.examples import *
from src.draw.graph import draw_graph
from argparse import ArgumentParser
import pickle
import heapq
from src.grammar.ednce import *
from src.draw.graph import *
from src.api.get_motifs import *
from src.algo.utils import *
from src.algo.common import *
from src.grammar.common import *
from src.grammar.utils import *
from src.algo.ednce import *
from src.model import graph_regression, transformer_regression
from argparse import ArgumentParser
from networkx.algorithms.isomorphism import DiGraphMatcher
import pdb

In [2]:
def load_data(args):
    if DATASET == "cora":
        g = load_cora()
    elif DATASET == "test":
        g = create_test_graph(1)
    elif DATASET == "debug":
        g = debug()
    elif DATASET == "house":
        g = create_house_graph()
    elif DATASET == "ckt":
        g = load_ckt(args)
    elif DATASET == "enas":
        g = load_enas(args)
    elif DATASET == "bn":
        g = load_bn(args)
    elif DATASET == "ast":
        g = load_ast(args)
    elif DATASET == "mol":
        g = load_mols(args)
    else:
        raise NotImplementedError
    return g
    
def get_args():
    parser = ArgumentParser()
    # global args
    parser.add_argument("--visualize", dest="global_visualize", action='store_true')
    parser.add_argument("--cache", dest="global_cache", action='store_true')    
    parser.add_argument("--num_threads", dest="global_num_threads", type=int)
    parser.add_argument("--num_procs", dest="global_num_procs", type=int)    
    # hparams
    parser.add_argument("--scheme", choices=['one','zero'], help='whether to index from 0 or 1', default='zero')    
    # ablations
    parser.add_argument("--ablate_tree", action='store_true') 
    parser.add_argument("--ablate_merge", action='store_true') 
    parser.add_argument("--ablate_root", action='store_true') 
    # task params
    parser.add_argument("--task", nargs='+', choices=["learn","generate","prediction"])
    parser.add_argument("--seed")
    parser.add_argument("--grammar_ckpt")
    # mol dataset args
    parser.add_argument(
        "--mol-dataset",
        choices=["ptc","hopv","polymers_117", "isocyanates", "chain_extenders", "acrylates"],
    )
    parser.add_argument(
        "--num-data-samples", type=int
    )
    parser.add_argument("--ambiguous-file", help='if given and exists, load data from this file to learn grammar; if given and not exist, save ambiguous data to this file after learn grammar')
    parser.add_argument("--num_samples", default=10000, type=int, help='how much to generate')
    return parser.parse_args([
        '--task', 'learn',
        '--ambiguous-file', 'cache/api_ckt_ednce/ambig_1.json'
    ])

In [3]:
args = get_args()
g = load_data(args)
orig = deepcopy(g)
cache_iter, cache_path = setup()
print(cache_path)
g, grammar, anno, iter = init_grammar(g, cache_iter, cache_path, EDNCEGrammar)
grammar, model, anno, g = terminate(g, grammar, anno, iter)
# for j, m in enumerate(model):
#     pre = get_prefix(m.id)
#     # draw_tree(m, os.path.join(IMG_DIR, f"model_{iter}_{pre}.png"))
#     model[j] = EDNCEModel(dfs(anno, m.id))
for j, m in enumerate(model):
    pre = get_prefix(m.id)
    # draw_tree(m, os.path.join(IMG_DIR, f"model_{iter}_{pre}.png"))
    model[j] = EDNCEModel(dfs(anno, m.id))
graphs = [copy_graph(orig, orig.comps[get_prefix(m.seq[0])]) for m in model]

loading graphs:   0%|          | 0/5000 [00:00<?, ?it/s]

loading graphs: 100%|██████████| 5000/5000 [00:01<00:00, 3402.34it/s]


cache/api_ckt_ednce/466.pkl


In [4]:
%matplotlib inline
g = copy_graph(orig, orig.comps[0])
fig = draw_graph(g)
fig

In [None]:
def worker_single(stack, grammar, graph, init_hash, mem, lock):
    while True:
        with lock:
            if len(stack) == 0:
                if init_hash in mem and mem[init_hash] != 0:
                    print("process done")                    
                break
            else:
                print(len(stack))
                cur, val = stack.pop(-1)
            if val in mem:
                if mem[val] != 0:
                    continue
            else:
                mem[val] = 0
        nts = grammar.search_nts(cur, NONTERMS)
        if len(nts) == 0:
            if nx.is_isomorphic(cur, graph, node_match=node_match):
                with lock: 
                    mem[val] = [[]]
            else:
                with lock:
                    mem[val] = []
            continue # done        
        done = True
        res = []
        for j, nt in enumerate(nts):
            for i, rule in enumerate(grammar.rules):                      
                if rule is None:
                    continue
                nt_label = cur.nodes[nt]['label']
                if rule.nt == nt_label:
                    c = rule(cur, nt)
                    if not nx.is_connected(nx.Graph(c)):
                        continue
                    if not nx.is_directed_acyclic_graph(c):
                        continue
                    exist = find_partial([graph], c)
                    if not exist:
                        continue
                    hash_val = wl_hash(c)
                    with lock:
                        if hash_val not in mem:
                            if done:
                                stack.append((cur, val))
                                done = False
                            stack.append((c, hash_val))                            
                        else:
                            if mem[hash_val] == 0:
                                if done:                                
                                    stack.append((cur, val))                            
                                    done = False
                            else:
                                for seq in mem[hash_val]: # res
                                    res.append([i]+deepcopy(seq))
        with lock:
            if done:        
                mem[val] = res 

In [5]:
def set_to_key(S):
    return tuple(sorted(list(S)))

def hitting(elim_sets, beam_width=100):
    H = []
    heapq.heappush(H, (0, set())) # priority=-len
    for S in elim_sets: # beam search
        # priority queue
        H_copy = list(H)
        H = []
        for _, h in H_copy: # copy
            for s in S:
                if len(H) < beam_width:
                    h_ = h|set([s])
                    heapq.heappush(H, (-len(h_), h_))
                else:
                    val_, h_ = heapq.heappop(H)
                    h__ = h|set([s])
                    if -val_ <= len(h_):
                        heapq.heappush(H, (val_, h_))
                    else:
                        heapq.heappush(H, (-len(h__), h__))
    while len(H)>1:
        heapq.heappop(H)
    l, ans = heapq.heappop(H)
    print(f"Input: {elim_sets}, Output: {ans}")
    return ans


def try_disambiguate(grammar, derivs):
    counts = {}
    for deriv in derivs:
        key = set_to_key(deriv)
        if key not in counts:
            counts[key] = []
        counts[key].append(deriv)
    elim_rule_sets = set()
    for key in counts:
        if len(counts[key]) == 1:
            keep_deriv = counts[key][0]
            elim_sets = set()
            for deriv in derivs:
                if deriv == keep_deriv:
                    continue
                elim_sets.add(set_to_key(set(deriv)-set(keep_deriv)))
            elim_rule_set = hitting(elim_sets)
            elim_rule_sets.add(set_to_key(elim_rule_set))
    # elim rule set
    min_num_remove = len(grammar.rules)
    min_remove = None
    for rule_set in elim_rule_sets:
        num_remove = 0
        for r in rule_set:
            if grammar.rules[r] is not None:
                num_remove += 1
        if num_remove < min_num_remove:
            min_num_remove = num_remove
            min_remove = rule_set
    if min_remove is None:
        min_remove = hitting(counts.keys())            
    return min_remove



In [6]:
all_derivs = {}
index_graphs = [(i, graph) for i, graph in enumerate(graphs)]
sorted_graphs = sorted(index_graphs, key=lambda x:-len(x[1]))
for (index, graph) in sorted_graphs:
    manager = mp.Manager()
    stack = manager.list()
    mem = manager.dict()
    lock = manager.Lock()        
    g = nx.DiGraph()
    g.add_node('0', label='black')    
    init_hash = wl_hash(g)
    stack.append((deepcopy(g), init_hash))
    processes = []
    for _ in range(NUM_PROCS):
        p = mp.Process(target=worker_single, args=(stack, grammar, graph, init_hash, mem, lock))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()    
    derivs = mem[init_hash]
    all_derivs[index] = derivs
    rule_set = try_disambiguate(grammar, derivs)
    n = len(list(filter(None, grammar.rules)))
    print(f"removing rules {rule_set}: {n}->{n-len(rule_set)}")
    for r in rule_set:        
        grammar.rules[r] = None

1
9
9
11
14
18
23
29
36
44
53
63
74
86
97
108
125
138
154
168
184
200
199
214
215
232
231
230
250
251
252
253
267
268
270
272
287
286
285
284
283
289
293
307
309
312
313
316
315
318
321
332
334
335
335
337
336
337
340
339
338
350
349
348
351
350
351
354
353
356
355
356
355
354
356
358
357
357
357
367
366
366
365
364
363
362
369
368
367
368
367
368
367
370
369
369
371
373
374
375
376
375
379
382
392
391
390
389
390
389
390
391
392
391
395
395
398
397
399
399
398
397
399
398
398
403
402
402
405
407
406
408
440
440
448
460
1195
1211
1272
1279
1349
1372
1421
1443
1460
1491
1529
1543
1569
1582
1595
1618
1630
1636
1713
1783
1822
1823
1863
1894
1911
1920
1984
2008
2020
2072
2093
2110
2109
2113
2154
2162
2172
2174
2177
2183
2185
2208
2217
2225
2225
2244
2262
2266
2289
2293
2295
2299
2336
2339
2344
2350
2355
2357
2366
2375
2385
2389
2400
2401
2428
2432
2451
2479
2488
2517
2528
2530
2551
2558
2558
2567
2566
2570
2573
2592
2592
2603
2637
2659
2687
2692
2698
2719
2736
2735
2760
2802
2846
2846
2857

In [None]:
sk = set_to_key
S = [sk({1,2,3}), sk({2,3,4,5}), sk({3,6,7}), sk({5,6,8}), sk({8,9,10})]
hitting(set(S))