In [4]:
import os
import pickle
# Directory holding the pickled program cache; '2.pkl' is the file loaded below.
hash_dir = '/ssd/msun415/program_cache-bb=1000-prods=2/'
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file,
# so only point this at cache files you produced yourself.
# The `with` block closes the handle deterministically instead of leaving it
# to the garbage collector.
with open(os.path.join(hash_dir, '2.pkl'), 'rb') as programs_file:
    programs = pickle.load(programs_file)

10:20:41 rdkit INFO: Enabling RDKit 2022.09.5 jupyter extensions


In [5]:
import networkx as nx
from synnet.utils.data_utils import binary_tree_to_skeleton, Program

def prog_to_binary_tree(prog):
    """Convert a program into a networkx DiGraph with parent -> child edges.

    Leaf nodes are created first, following the order of ``prog.entries``:
      * a tuple entry ``(rxn, pos)`` produces one leaf whose ``left`` flag is
        ``pos == 0``;
      * any other entry ``rxn`` produces two leaves, flagged ``left=False``
        and then ``left=True``.
    Afterwards one node is added per node of ``prog.rxn_tree`` (in iteration
    order). It gets an edge to the graph node of each of its ``rxn_tree``
    successors (flagged left when that successor's ``'child'`` attribute
    equals ``'left'``) followed by an edge to each leaf recorded for it.
    The ``left`` flag is written both on the edge and on the child node.

    Node ids are consecutive integers assigned in creation order.

    NOTE(review): this relies on every successor being visited before its
    parent when iterating ``prog.rxn_tree``, and on every reaction having at
    least one entry in ``prog.entries`` (otherwise a KeyError is raised) --
    confirm both against how ``Program`` builds these structures.

    Returns:
        The constructed ``nx.DiGraph``.
    """
    tree = nx.DiGraph()

    # Pass 1: create the leaves and remember which reaction owns them.
    leaves_of = {}
    next_id = 0
    for entry in prog.entries:
        if isinstance(entry, tuple):
            owner = entry[0]
            fresh = [(next_id, entry[1]==0)]
        else:
            owner = entry
            fresh = [(next_id, False), (next_id + 1, True)]
        for leaf, _ in fresh:
            tree.add_node(leaf)
        leaves_of.setdefault(owner, []).extend(fresh)
        next_id += len(fresh)

    # Pass 2: one internal node per reaction, wired to its children.
    graph_id = {}  # rxn_tree node -> node id in `tree`
    for rxn in prog.rxn_tree:
        parent = len(tree)
        tree.add_node(parent)
        for succ in list(prog.rxn_tree[rxn]):
            child = graph_id[succ]
            is_left = prog.rxn_tree.nodes[succ]['child'] == 'left'
            tree.add_edge(parent, child, left=is_left)
            tree.nodes[child]['left'] = is_left
        for leaf, is_left in leaves_of[rxn]:
            tree.add_edge(parent, leaf, left=is_left)
            tree.nodes[leaf]['left'] = is_left
        graph_id[rxn] = parent
    return tree
    

In [6]:
import heapq
import multiprocessing as mp
from multiprocessing.pool import ThreadPool
from tqdm import tqdm
from tdc import Oracle

# Scoring configuration: the heap keeps at most `max_size` entries, and
# `batch_size` bounds how many products are scored per batch.
max_size, batch_size = 128, 10_000

# Shared min-heap that accumulates scored entries.
heap = list()


# def scoring(prog):
#     oracle = Oracle(name="QED")     
#     prog.product_map.load()
#     score = 0.
#     prod_map = prog.product_map._product_map
#     e = 1 if d == 2 else 0
#     scores = [oracle(prod) for prod in list(prod_map[e])]
#     score = max(scores)
#     prog.product_map.unload()
#     print(score)
#     return score


def score_single(smi, d, i):
    """Score one SMILES string, or every string in a list, with the "QED" oracle.

    Args:
        smi: a single SMILES string, or a ``list`` of SMILES strings.
        d: opaque tag copied unchanged into every result tuple.
        i: opaque tag copied unchanged into every result tuple.

    Returns:
        ``(score, d, i)`` when ``smi`` is not a list; otherwise a list with one
        ``(score, d, i)`` tuple per element of ``smi``, in order.
    """
    oracle = Oracle(name="QED")
    if isinstance(smi, list):
        # Score each element individually. The previous loop iterated over the
        # elements but passed the whole list to the oracle on every iteration.
        return [(oracle(s), d, i) for s in smi]
    return (oracle(smi), d, i)

def score_list(lis):
    """Score a batch of ``(smi, d, i)`` triples with a single "QED" oracle.

    Args:
        lis: a ``list`` of ``(smi, d, i)`` triples; anything other than a list
            trips the assertion.

    Returns:
        A list of ``(score, d, i)`` tuples in the same order as ``lis``.
    """
    assert isinstance(lis, list)
    qed = Oracle(name="QED")
    return [(qed(smi), d, i) for smi, d, i in lis]
    
    


In [7]:
import random
# Seed the global `random` generator that re_eval's random.shuffle draws from.
random.seed(0)

# Re-evaluation settings (heap and max_size are reset here for this cell).
max_size = 128
heap = list()
batches_per_program = 1
fpath = '/home/msun415/SynTreeNet/scores-qed.txt'

def load_file(fpath):
    """Read a score file and return its non-empty records as bare strings.

    The first line of the file is skipped. Every remaining line has '('
    characters stripped from both ends and any trailing ')' / newline
    characters removed; lines that are empty after that are dropped.

    Args:
        fpath: path of the text file to read.

    Returns:
        A list of the cleaned line strings, in file order.
    """
    with open(fpath) as handle:
        body = handle.readlines()[1:]

    kept = []
    for raw in tqdm(body, desc="processing lines"):
        cleaned = raw.strip('(').rstrip(')\n')
        if cleaned:
            kept.append(cleaned)
    return kept
    
# lines = load_file(fpath)

def re_eval():
    """Re-score products of the listed programs and collect the best in ``heap``.

    For every string in the module-level ``lines`` (format ``"score, d, ind"``)
    this loads the product map of ``programs[d][ind]``, shuffles its products,
    keeps at most ``batches_per_program * batch_size`` of them, scores them
    with ``score_list`` across a 100-process pool, and pushes
    ``(score, d, ind, smi)`` tuples onto the module-level min-heap ``heap``.
    Once ``heap`` holds ``max_size`` entries, each push also pops the smallest
    entry, so only the largest ``max_size`` tuples are retained.

    Prints each input line and, after processing it, the whole heap.
    Returns ``None``.
    """
    # NOTE(review): `lines` is not bound anywhere visible (the `load_file`
    # call that would define it is commented out); it must exist before
    # this function is called.
    for line in lines:
        _old_score, d, ind = line.split(', ')
        print(line)
        d = int(d)
        ind = int(ind)
        prog = programs[d][ind]
        # NOTE(review): the product map is loaded but never released here;
        # confirm whether an explicit unload is needed to bound memory use.
        prog.product_map.load()
        prod_map = prog.product_map._product_map
        # NOTE(review): products are read from slot 1 when d == 2, else from
        # slot 0 -- the meaning of these slots is not visible here; confirm.
        e = 1 if d == 2 else 0
        smis = [(prod, d, ind) for prod in list(prod_map[e])]
        random.shuffle(smis)
        smis = smis[:batches_per_program*batch_size]  # only do first batch
        # Ceil-divide so that 100 consecutive chunks cover every item.
        inner_batch_size = (len(smis)+99)//100
        smi_lists = [smis[k*inner_batch_size:(k+1)*inner_batch_size] for k in range(100)]
        with mp.Pool(100) as p:
            # Pool.map preserves input order, so the flattened results line up
            # one-to-one with `smis`.
            scores = p.map(score_list, tqdm(smi_lists, desc="single batch"))
        res = [r for res_lis in scores for r in res_lis]
        # Unpack the SMILES out of each (prod, d, ind) triple: the heap entry
        # must carry the string itself, not the whole triple, because
        # lines_to_json later fingerprints and serialises that 4th field.
        for (score, prog_d, j), (smi, _, _) in zip(res, smis):
            entry = (score, prog_d, j, smi)
            if len(heap) < max_size:
                heapq.heappush(heap, entry)
            else:
                heapq.heappushpop(heap, entry)
        print(heap)


# Rebind the module-level `fpath` to the logP score file (earlier in this cell it named the QED score file).
fpath = '/home/msun415/SynTreeNet/scores-logp.txt'
def load_file2(fpath):
    """Read a score file and return its non-empty records as bare strings.

    This was a line-for-line copy of ``load_file``; it now delegates to that
    function so the parsing rules live in one place. The first line of the
    file is skipped, surrounding parentheses and the trailing newline are
    stripped from each remaining line, and empty lines are dropped.

    Args:
        fpath: path of the text file to read.

    Returns:
        A list of the cleaned line strings, in file order.
    """
    return load_file(fpath)

In [18]:
from synnet.encoding.fingerprints import fp_2048
import json

def lines_to_json(heap):
    """Build one record per ``(score, d, j, smi)`` entry of ``heap``.

    For each entry, ``d`` and ``j`` are cast to ``int`` and used to look up
    ``programs[d][j]``, which is converted with ``prog_to_binary_tree``. The
    node with in-degree 0 is taken as the root.

    Args:
        heap: iterable of 4-item entries ``(score, d, j, smi)``; the score is
            not used.

    Returns:
        A list of dicts with keys ``'bt'`` (``nx.tree_data`` of the tree from
        its root), ``'fp'`` (``fp_2048(smi)``) and ``'smi'``.
    """
    records = []
    for _score, depth, index, smi in heap:
        tree = prog_to_binary_tree(programs[int(depth)][int(index)])
        # The root is the only node nothing points to.
        root = next(node for node, indeg in tree.in_degree() if indeg == 0)
        records.append({
            'bt': nx.tree_data(tree, root),
            'fp': fp_2048(smi),
            'smi': smi,
        })
    return records

# json.dump(lines_to_json(heap), open('/home/msun415/SynTreeNet/indvs-qed.json', 'w+'))

In [19]:
# For each property, parse its score file into 4-field rows, convert them to
# JSON records and write them next to the score file.
# Note: this rebinds the module-level `heap` to the parsed rows.
for prop in ['drd2','jnk','logp']:
    scores_path = f'/home/msun415/SynTreeNet/scores-{prop}.txt'
    out_path = f'/home/msun415/SynTreeNet/indvs-{prop}.json'
    heap = [l.split(', ') for l in load_file2(scores_path)]
    # Build the records before opening the output so a failure in
    # lines_to_json does not leave a truncated file behind.
    records = lines_to_json(heap)
    # `with` guarantees the file is flushed and closed; the previous inline
    # open(...) handle was never closed explicitly.
    with open(out_path, 'w') as out_file:
        json.dump(records, out_file)
    print(out_path)

processing lines: 100%|██████████| 131/131 [00:00<00:00, 906689.48it/s]


/home/msun415/SynTreeNet/indvs-drd2.json


processing lines: 100%|██████████| 131/131 [00:00<00:00, 1445931.12it/s]


/home/msun415/SynTreeNet/indvs-jnk.json


processing lines: 100%|██████████| 131/131 [00:00<00:00, 1489034.75it/s]


/home/msun415/SynTreeNet/indvs-logp.json


In [14]:
from pathlib import Path
import os
import shutil

# Delete every run directory under `base` that contains no "*.ckpt" file
# anywhere beneath it. Prints "pass" for directories that are kept and the
# directory name for each one that is removed.
# WARNING: this is destructive -- removed directories cannot be recovered.
base = '/home/msun415/SynTreeNet/results/logs/gnn'
for version in os.listdir(base):
    version_dir = os.path.join(base, version)
    # os.listdir also returns regular files; shutil.rmtree raises on those,
    # so only directories are considered.
    if not os.path.isdir(version_dir):
        continue
    ckpts = Path(version_dir).rglob("*.ckpt")
    ckpts = list(ckpts)
    if ckpts:
        print("pass")
    else:
        print(version)
        shutil.rmtree(version_dir)

version_90
version_8
version_213
version_109
version_115
pass
version_103
version_52
version_91
pass
pass
version_60
pass
version_70
version_112
version_6
pass
pass
version_110
pass
version_49
pass
version_76
pass
version_82
version_33
version_102
pass
version_210
version_94
version_207
version_2
version_100
version_47
version_51
version_87
version_75
version_5
version_13
version_25
version_72
pass
pass
version_48
version_69
version_106
pass
version_79
version_92
version_63
version_208
pass
version_95
version_211
version_219
version_39
version_116
version_9
version_0
version_59
version_10
version_28
version_65
pass
version_117
version_209
pass
version_64
pass
pass
version_34
pass
version_50
pass
version_27
version_66
pass
version_35
version_78
version_81
version_41
pass
pass
pass
pass
version_71
version_104
pass
pass
version_29
version_55
version_7
version_212
pass
pass
version_222
version_12
version_77
version_30
pass
version_220
version_36
gnn
version_107
version_73
version_11
versio

'version_44'