In [None]:
import os
import sys
os.chdir('/home/msun415/induction/')
import pickle
from src.config import METHOD, DATASET, GRAMMAR
import importlib
from src.examples import *
from src.draw.color import to_hex, CMAP
from src.draw.graph import draw_graph
from src.config import RADIUS
from argparse import ArgumentParser
import pickle
from src.grammar.common import get_args
from src.grammar.ednce import *
from src.draw.graph import *
from src.api.get_motifs import *
from src.algo.utils import *
from src.algo.common import *
from src.grammar.common import *
from src.grammar.utils import *
from src.algo.ednce import terminate, dfs
from src.model import graph_regression, transformer_regression
from argparse import ArgumentParser
from networkx.algorithms.isomorphism import DiGraphMatcher
import pdb

In [None]:
def load_data(args):
    """Load the graph data selected by the module-level ``DATASET`` setting.

    ``args`` is forwarded only to the loaders that accept it (ckt, enas, bn,
    ast, mol); the other datasets are built without it.

    Raises:
        NotImplementedError: if ``DATASET`` is not one of the known names.
    """
    # Each entry is wrapped in a lambda so that only the selected loader's
    # name is looked up and called.
    loaders = {
        "cora": lambda: load_cora(),
        "test": lambda: create_test_graph(1),
        "debug": lambda: debug(),
        "house": lambda: create_house_graph(),
        "ckt": lambda: load_ckt(args),
        "enas": lambda: load_enas(args),
        "bn": lambda: load_bn(args),
        "ast": lambda: load_ast(args),
        "mol": lambda: load_mols(args),
    }
    loader = loaders.get(DATASET)
    if loader is None:
        raise NotImplementedError
    return loader()
    
def get_args(argv=None):
    """Build the argument parser and parse ``argv`` into a namespace.

    This definition rebinds the ``get_args`` name imported from
    ``src.grammar.common`` at the top of the notebook.

    Args:
        argv: list of command-line style tokens to parse. When ``None``
            (the default) the notebook's fixed configuration is used:
            ``--task learn --ambiguous-file cache/api_bn_ednce/ambig_1.json``.
            An explicit empty list is parsed as "no arguments".

    Returns:
        argparse.Namespace with the parsed options.
    """
    if argv is None:
        # Default used by the notebook; sys.argv is deliberately not consulted.
        argv = [
            '--task', 'learn',
            '--ambiguous-file', 'cache/api_bn_ednce/ambig_1.json'
        ]
    parser = ArgumentParser()
    # global args
    parser.add_argument("--visualize", dest="global_visualize", action='store_true')
    parser.add_argument("--cache", dest="global_cache", action='store_true')
    parser.add_argument("--num_threads", dest="global_num_threads", type=int)
    parser.add_argument("--num_procs", dest="global_num_procs", type=int)
    # hparams
    parser.add_argument("--scheme", choices=['one','zero'], help='whether to index from 0 or 1', default='zero')
    # ablations
    parser.add_argument("--ablate_tree", action='store_true')
    parser.add_argument("--ablate_merge", action='store_true')
    parser.add_argument("--ablate_root", action='store_true')
    # task params
    parser.add_argument("--task", nargs='+', choices=["learn","generate","prediction"])
    # NOTE: --seed has no type, so it is stored as a string when given.
    parser.add_argument("--seed")
    parser.add_argument("--grammar_ckpt")
    # mol dataset args
    parser.add_argument(
        "--mol-dataset",
        choices=["ptc","hopv","polymers_117", "isocyanates", "chain_extenders", "acrylates"],
    )
    parser.add_argument(
        "--num-data-samples", type=int
    )
    parser.add_argument("--ambiguous-file", help='if given and exists, load data from this file to learn grammar; if given and not exist, save ambiguous data to this file after learn grammar')
    parser.add_argument("--num_samples", default=10000, type=int, help='how much to generate')
    return parser.parse_args(list(argv))

In [None]:
# Parse the notebook's fixed arguments and load the dataset selected by DATASET.
args = get_args()
g = load_data(args)
# setup() supplies the cache iteration and cache path handed to init_grammar.
cache_iter, cache_path = setup()
# NOTE(review): `iter` rebinds the builtin of the same name in the notebook
# namespace for every later cell.
g, grammar, anno, iter = init_grammar(g, cache_iter, cache_path, EDNCEGrammar)
# NOTE(review): the disabled lines below are the only visible place that would
# bind `model`, yet later cells index `model[j]`. On a fresh kernel `model`
# must therefore come from somewhere else (a star import or leftover kernel
# state) -- confirm before relying on Restart & Run All.
# grammar, model, anno, g = terminate(g, grammar, anno, iter)
# for j, m in enumerate(model):
#     pre = get_prefix(m.id)
#     # draw_tree(m, os.path.join(IMG_DIR, f"model_{iter}_{pre}.png"))
#     model[j] = EDNCEModel(dfs(anno, m.id))

In [None]:
# Replay each selected model's derivation through the grammar and collect the
# resulting graphs; later cells look targets up in `graphs` by index.
graphs = []
for j in range(1):
    # Rule stored on each node, visiting the model's node sequence in reverse.
    deriv = [model[j].graph[n].attrs['rule'] for n in model[j].seq[::-1]]
    # Identity mapping over rule indices, passed as derive's second argument.
    t2r = {idx: idx for idx in range(len(grammar.rules))}
    deriv_g = grammar.derive(deriv, t2r)
    # The image path is fixed, so each iteration overwrites the same file.
    draw_graph(deriv_g, '/home/msun415/test.png')
    graphs.append(deriv_g)

In [None]:
def find_indices(graphs, query):
    """Return the positions in ``graphs`` whose graph is isomorphic to ``query``.

    Isomorphism is tested with ``nx.is_isomorphic`` and no node or edge
    matcher is supplied.
    """
    return [
        position
        for position, candidate in enumerate(graphs)
        if nx.is_isomorphic(candidate, query)
    ]


def node_match(d1, d2):
    """Tell whether two node-attribute dicts carry the same ``label``.

    A dict without a ``"label"`` key is treated as if its label were ``"#"``.
    """
    missing = "#"
    first_label = d1.get("label", missing)
    second_label = d2.get("label", missing)
    return first_label == second_label


def find_partial(graphs, query):
    """Return the indices of graphs that contain ``query`` as a subgraph.

    A graph qualifies when ``DiGraphMatcher(graph, query, node_match=node_match)``
    yields at least one subgraph isomorphism. ``query`` may be a (possibly
    disconnected) directed graph.

    Args:
        graphs: sequence of candidate host graphs.
        query: the pattern graph to look for.

    Returns:
        List of positions in ``graphs`` that match, in ascending order.
    """
    ans = []
    for i, host in enumerate(graphs):
        # A host with fewer nodes or edges than the query cannot contain it.
        if len(query) > len(host):
            continue
        if len(query.edges) > len(host.edges):
            continue
        gm = DiGraphMatcher(host, query, node_match=node_match)
        # Only existence matters, so stop at the first isomorphism found
        # instead of enumerating all of them.
        has_match = any(True for _ in gm.subgraph_isomorphisms_iter())
        if not has_match:
            # Skip just this host. The previous `break` here abandoned every
            # remaining candidate after the first miss.
            continue
        ans.append(i)
    return ans


def worker(shared_queue, found, lock):
    """Process derivation-search work items until the queue is seen empty.

    Each work item is ``(interm, deriv, poss)``: an intermediate graph, the
    list of rule indices applied so far, and the index into ``graphs`` of the
    target being searched for.

    * If ``interm`` has no nonterminal nodes left and is isomorphic to
      ``graphs[poss]`` (labels compared with ``node_match``), ``deriv`` is
      appended to ``found[poss]``.
    * Otherwise every grammar rule whose ``nt`` equals a nonterminal node's
      label is applied to a deep copy of ``interm``. The result is re-queued
      with the rule index appended only if it is connected when viewed as an
      undirected graph and its terminal-labelled part is found inside
      ``graphs[poss]`` by ``find_partial``.

    Queue access and updates to ``found`` happen while holding ``lock``.
    """
    while True:
        with lock:
            if shared_queue.empty():
                print("process done")
                return
            print(f"len(interms): {shared_queue.qsize()}")
            work_item = shared_queue.get()
        interm, deriv, poss = work_item
        print(f"deriv: {deriv}")
        nts = grammar.search_nts(interm, NONTERMS)
        # Fully terminal graph: record the derivation if it reproduces the target.
        if len(nts) == 0 and nx.is_isomorphic(interm, graphs[poss], node_match=node_match):
            with lock:
                found[poss].append(deriv)
                print(f"found {deriv} graph {poss}, count: {len(found[poss])}")
        for nt in nts:
            for rule_idx, rule in enumerate(grammar.rules):
                nt_label = interm.nodes[nt]['label']
                if rule.nt != nt_label:
                    continue
                # Expand a copy so sibling expansions start from the same graph.
                expanded = rule(deepcopy(interm), nt)
                if not nx.is_connected(nx.Graph(expanded)):
                    continue
                terminal_nodes = [x for x in expanded if expanded.nodes[x]['label'] in TERMS]
                terminal_part = copy_graph(expanded, terminal_nodes)
                # Prune branches whose terminal part cannot be found in the target.
                if find_partial([graphs[poss]], terminal_part):
                    with lock:
                        shared_queue.put((expanded, deriv + [rule_idx], poss))


# Number of worker processes to launch for the derivation search.
NUM_PROCS = 50
# One search is seeded per target graph: the third field of each work item is
# used by `worker` to index `graphs`, so it must stay below len(graphs).
N = len(graphs)
manager = mp.Manager()
shared_queue = manager.Queue()
found = manager.list()
# Every search starts from a single node labelled 'black'.
# NOTE(review): this rebinds `g`, which earlier held the graph returned by
# init_grammar.
g = nx.DiGraph()
g.add_node('0', label='black')
# Seed by graph count, not by process count. Seeding range(NUM_PROCS) created
# work items whose target index had no matching entry in `graphs`.
for j in range(N):
    shared_queue.put((deepcopy(g), [], j))
    found.append(manager.list())
lock = manager.Lock()
processes = []
for _ in range(NUM_PROCS):
    p = mp.Process(target=worker, args=(shared_queue, found, lock))
    p.start()
    processes.append(p)
# Block until every worker has seen an empty queue and exited.
for p in processes:
    p.join()

In [None]:
# Copy each graph's derivations out of the manager proxies into plain lists.
all_derivs = [list(per_graph) for per_graph in found]
# For every graph and every ordered pair of distinct derivations (keep, other),
# record the rules that `other` uses and `keep` does not.
sets_of_sets = [
    [
        set(other) - set(keep)
        for keep_idx, keep in enumerate(derivs)
        for other_idx, other in enumerate(derivs)
        if other_idx != keep_idx
    ]
    for derivs in all_derivs
]

In [None]:
# To inspect which graphs have more than one derivation, evaluate
# [len(deriv) > 1 for deriv in all_derivs] in its own cell; the bare
# np.argwhere(...) expression that used to sit here was computed and discarded.

# Candidate elimination sets: choose one rule-difference set from every graph
# that has any, and take the union of the choices.
candidate_pools = [sets for sets in sets_of_sets if sets]
poss_elims = []
seen_elims = set()
for chosen in product(*candidate_pools):
    # product() yields a single empty tuple when there are no pools.
    # set().union(*chosen) then gives the empty set, where set.union(*chosen)
    # would raise TypeError for lack of arguments.
    elim = set().union(*chosen)
    key = frozenset(elim)
    if key not in seen_elims:
        seen_elims.add(key)
        poss_elims.append(elim)

# Keep only minimal candidates: scanning from smallest to largest, drop any
# set that contains an earlier surviving set.
poss_elims = sorted(poss_elims, key=len)
for i in range(len(poss_elims)):
    if poss_elims[i] is None:
        continue
    for j in range(i+1, len(poss_elims)):
        if poss_elims[j] is None:
            continue
        if poss_elims[i] <= poss_elims[j]:
            poss_elims[j] = None

min_poss_elims = [elim for elim in poss_elims if elim is not None]

# Choose the candidate that leaves the fewest graphs with every derivation
# touching an eliminated rule. A graph with no derivations at all counts as
# affected, because all() of an empty sequence is True. Ties keep the first
# candidate seen.
best_e = None
best_counter = None
for e in min_poss_elims:
    counter = [
        i
        for i, derivs in enumerate(all_derivs)
        if all(bool(set(deriv) & e) for deriv in derivs)
    ]
    if best_counter is None or len(counter) < len(best_counter):
        best_counter = counter
        best_e = e



In [None]:
# Map every surviving rule index to its position once the rules in `best_e`
# are removed, so the remaining ids are contiguous.
surviving_rules = [idx for idx in range(len(grammar.rules)) if idx not in best_e]
rmap = {old_id: new_id for new_id, old_id in enumerate(surviving_rules)}

# Walk the models back to front so popping does not shift indices still to be
# visited. Models listed in `best_counter` are dropped; the rest have the rule
# id on every node rewritten through `rmap`.
for i in reversed(range(len(model))):
    if i in best_counter:
        model.pop(i)
        continue
    for n in model[i].graph:
        node = model[i].graph[n]
        node.attrs['rule'] = rmap[node.attrs['rule']]

