In [None]:
import os
import pickle
hash_dir = '/ssd/msun415/program_cache-bb=1000-prods=2/'
# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only load caches that were produced by a trusted run of this project.
# The context manager closes the file handle as soon as loading finishes.
with open(os.path.join(hash_dir, '2.pkl'), 'rb') as cache_file:
    # Later cells index this object as programs[d][j].
    programs = pickle.load(cache_file)

In [None]:
import networkx as nx
from synnet.utils.data_utils import binary_tree_to_skeleton, Program

def prog_to_binary_tree(prog):
    """Build a directed graph with one node per entry leaf and per reaction of ``prog``.

    Leaves are created first, in the order of ``prog.entries``:

    * a tuple entry ``(rxn, k)`` adds one leaf under reaction ``rxn``; the leaf
      is flagged ``left`` when ``k == 0``;
    * any other entry ``rxn`` adds two leaves under reaction ``rxn``; the first
      is flagged ``left=False`` and the second ``left=True``.

    Then one node is added per reaction, in the iteration order of
    ``prog.rxn_tree``.  Each reaction node gets an edge to the node of every
    child reaction (``left`` taken from that child's ``'child'`` attribute in
    ``prog.rxn_tree``) and to each of its leaves.  The ``left`` flag is stored
    both on the edge and on the child node.

    A child reaction must already have been visited when its parent is
    processed, otherwise the lookup raises ``KeyError``; the same happens for
    a reaction that has no leaves.

    Returns:
        The constructed ``nx.DiGraph``.
    """
    tree = nx.DiGraph()
    leaves_of = {}  # reaction -> [(leaf node id, is_left), ...]
    next_id = 0
    for entry in prog.entries:
        if not isinstance(entry, tuple):
            # Non-tuple entry: two consecutive leaves hang off this reaction.
            tree.add_node(next_id)
            tree.add_node(next_id + 1)
            leaves_of.setdefault(entry, []).extend(
                [(next_id, False), (next_id + 1, True)]
            )
            next_id += 2
            continue
        # Tuple entry: a single leaf under reaction entry[0].
        tree.add_node(next_id)
        leaves_of.setdefault(entry[0], []).append((next_id, entry[1] == 0))
        next_id += 1

    node_of_rxn = {}  # reaction -> id of the node created for it
    for rxn in prog.rxn_tree:
        parent = len(tree)  # next unused integer id
        tree.add_node(parent)
        for child_rxn in list(prog.rxn_tree[rxn]):
            sub_root = node_of_rxn[child_rxn]
            is_left = prog.rxn_tree.nodes[child_rxn]['child'] == 'left'
            tree.add_edge(parent, sub_root, left=is_left)
            tree.nodes[sub_root]['left'] = is_left
        for leaf, is_left in leaves_of[rxn]:
            tree.add_edge(parent, leaf, left=is_left)
            tree.nodes[leaf]['left'] = is_left
        node_of_rxn[rxn] = parent
    return tree
    

In [None]:
import heapq
import multiprocessing as mp
from multiprocessing.pool import ThreadPool
from tqdm import tqdm
from tdc import Oracle

# Capacity of `heap`: once it holds this many entries, re_eval pushes new
# entries with heappushpop so the smallest one is dropped.
max_size = 128
# Together with batches_per_program, caps how many shuffled products re_eval
# scores for a single program.
batch_size = 10000

# Min-heap (maintained with heapq inside re_eval) of the best entries seen so far.
heap = []


# def scoring(prog):
#     oracle = Oracle(name="QED")     
#     prog.product_map.load()
#     score = 0.
#     prod_map = prog.product_map._product_map
#     e = 1 if d == 2 else 0
#     scores = [oracle(prod) for prod in list(prod_map[e])]
#     score = max(scores)
#     prog.product_map.unload()
#     print(score)
#     return score


def score_single(smi, d, i):
    """Score one SMILES string, or every SMILES string in a list, with the QED oracle.

    Args:
        smi: a single SMILES string, or a ``list`` of SMILES strings.
        d: passed through unchanged into the result tuple(s).
        i: passed through unchanged into the result tuple(s).

    Returns:
        ``(score, d, i)`` when ``smi`` is not a list; otherwise a list with one
        ``(score, d, i)`` tuple per element of ``smi``, in input order.
    """
    oracle = Oracle(name="QED")
    if isinstance(smi, list):
        # Score each element on its own; passing the whole list to the oracle
        # would give every element the same (wrong) score.
        return [(oracle(s), d, i) for s in smi]
    return (oracle(smi), d, i)

def score_list(lis):
    """Score a list of ``(smi, d, i)`` triples with the QED oracle.

    Args:
        lis: a ``list`` of ``(smi, d, i)`` triples; anything else fails the
            assertion.

    Returns:
        A list of ``(score, d, i)`` tuples in the same order as ``lis``, where
        ``score`` is the oracle's output for ``smi``.
    """
    assert isinstance(lis, list)
    qed = Oracle(name="QED")
    return [(qed(smiles), depth, index) for smiles, depth, index in lis]
    
    


In [None]:
import random
# Seed Python's RNG so the random.shuffle in re_eval is repeatable.
random.seed(0)
# re_eval keeps the first batches_per_program * batch_size shuffled products
# of each program.
batches_per_program = 1
fpath = 'SynthesisNet/scores-qed.txt'
# Reset the heap and its capacity (both were already initialised in the
# previous cell).
heap = []
max_size = 128

def load_file(fpath):
    """Read a score file and return its cleaned, non-empty lines.

    The first line of the file is skipped.  Each remaining line has ``(``
    characters stripped from both ends and then ``)`` / newline characters
    stripped from the right; lines that end up empty are dropped.

    Args:
        fpath: path of the text file to read.

    Returns:
        A list of the cleaned lines, in file order.
    """
    with open(fpath) as handle:
        body = handle.readlines()[1:]
    cleaned = (
        raw.strip('(').rstrip(')\n')
        for raw in tqdm(body, desc="processing lines")
    )
    return [entry for entry in cleaned if entry]
    
# lines = load_file(fpath)

def re_eval():
    """Re-score a random sample of products for every program listed in ``lines``.

    Works entirely on module-level state: it reads ``lines``, ``programs``,
    ``batch_size``, ``batches_per_program`` and ``max_size``, and updates
    ``heap`` in place.  Each element of ``lines`` must look like
    ``"score, d, ind"``; the program is looked up as ``programs[d][ind]``.

    NOTE(review): ``lines`` is not assigned in this cell (the ``load_file``
    call is commented out), so it must be set before calling this function.

    For each program, the products in ``prod_map[1]`` (when ``d == 2``) or
    ``prod_map[0]`` (otherwise) are shuffled, truncated to
    ``batches_per_program * batch_size`` items, scored in parallel with
    ``score_list``, and pushed onto ``heap`` as ``(score, d, ind, smi)``.
    ``heap`` is kept at no more than ``max_size`` entries by evicting the
    smallest entry.

    Returns:
        None.  Prints each input line and the heap after every program.
    """
    n_chunks = 100  # number of work chunks, and of worker processes, per program
    for line in lines:
        _, d, ind = line.split(', ')
        print(line)
        d = int(d)
        i = int(ind)
        prog = programs[d][i]
        # NOTE(review): the product map is loaded here and never released in
        # this function - confirm whether it should be unloaded afterwards.
        prog.product_map.load()
        prod_map = prog.product_map._product_map
        e = 1 if d == 2 else 0
        smis = [(prod, d, i) for prod in list(prod_map[e])]
        random.shuffle(smis)
        smis = smis[:batches_per_program*batch_size]  # only do first batch
        # Ceil-divide so that n_chunks slices cover every item; trailing
        # slices may be empty.
        inner_batch_size = (len(smis) + n_chunks - 1)//n_chunks
        smi_lists = [smis[k*inner_batch_size:(k+1)*inner_batch_size] for k in range(n_chunks)]
        with mp.Pool(n_chunks) as p:
            scores = p.map(score_list, tqdm(smi_lists, desc="single batch"))
        # p.map and score_list both preserve order, so res lines up with smis.
        res = [r for res_lis in scores for r in res_lis]
        for (score, d, j), (smi, _, _) in zip(res, smis):
            # Store the SMILES string itself, not the whole (smi, d, i) triple:
            # lines_to_json treats the fourth field as a SMILES string.
            entry = (score, d, j, smi)
            if len(heap) < max_size:
                heapq.heappush(heap, entry)
            else:
                # Push then pop the smallest, keeping the max_size best entries.
                heapq.heappushpop(heap, entry)
        print(heap)


# Re-point fpath at the logP score file; this replaces the QED path assigned
# earlier in this cell.
fpath = 'SynthesisNet/scores-logp.txt'
def load_file2(fpath):
    """Read a score file and return its cleaned, non-empty lines.

    Kept as a separate name for existing callers; the parsing is the same as
    ``load_file`` (skip the first line, strip the surrounding parentheses and
    the trailing newline, drop empty lines), so this delegates to it and
    avoids keeping two copies of the same code.

    Args:
        fpath: path of the text file to read.

    Returns:
        A list of the cleaned lines, in file order.
    """
    return load_file(fpath)

In [None]:
from synnet.encoding.fingerprints import fp_2048
import json

def lines_to_json(heap):
    """Turn ``(score, d, j, smi)`` entries into a list of dictionaries.

    For every entry, ``programs[int(d)][int(j)]`` is converted with
    ``prog_to_binary_tree`` and the node with in-degree 0 is taken as the
    root.  The score is not used.

    Args:
        heap: iterable of 4-item entries ``(score, d, j, smi)``; ``d`` and
            ``j`` may be anything ``int()`` accepts.

    Returns:
        One dict per entry with the keys ``'bt'`` (``nx.tree_data`` of the
        tree from its root), ``'fp'`` (``fp_2048(smi)``) and ``'smi'``.
    """
    records = []
    for _score, depth, index, smi in heap:
        tree = prog_to_binary_tree(programs[int(depth)][int(index)])
        # The root is the only node that nothing points to.
        root = next(node for node, indeg in tree.in_degree() if indeg == 0)
        records.append({
            'bt': nx.tree_data(tree, root),
            'fp': fp_2048(smi),
            'smi': smi,
        })
    return records

# json.dump(lines_to_json(heap), open('SynthesisNet/indvs-qed.json', 'w+'))

In [None]:
# For each property, convert its score file into a JSON list of individuals.
for prop in ['drd2','jnk','logp']:
    # Each cleaned line is split into its comma-separated fields;
    # lines_to_json unpacks four fields per entry.
    heap = [l.split(', ') for l in load_file2(f'SynthesisNet/scores-{prop}.txt')]
    # Build the records before opening the output, so a failure during
    # conversion does not leave a truncated file behind.
    records = lines_to_json(heap)
    # The context manager flushes and closes the file once the dump is done.
    with open(f'SynthesisNet/indvs-{prop}.json', 'w+') as out_file:
        json.dump(records, out_file)
    print(f'SynthesisNet/indvs-{prop}.json')

In [None]:
from pathlib import Path
import os
import shutil

# Delete every run directory under `base` that contains no *.ckpt file.
# WARNING: destructive - the directories are removed permanently.
base = 'SynthesisNet/results/logs/gnn'
for version in os.listdir(base):
    version_dir = Path(os.path.join(base, version))
    if not version_dir.is_dir():
        # os.listdir also returns plain files; shutil.rmtree would raise on
        # them and abort the cleanup half-way through.
        continue
    # any() stops at the first checkpoint found anywhere below the directory.
    if any(version_dir.rglob("*.ckpt")):
        print("pass")
    else:
        print(version)
        shutil.rmtree(version_dir)

In [None]:
from synnet.config import DELIM
# Split the reconstruction results into perfectly recovered targets
# (score == 1.0) and everything else.
# NOTE(review): this rebinds the module-level `lines`, which re_eval also
# reads as a global but expects in a different ("score, d, ind") format.
with open('SynthesisNet/output_reconstruct_top_k=3_max_num_rxns=3_max_rxns=-1.txt') as results_file:
    lines = results_file.readlines()
recovered = {'targets': [], 'decoded': []}
unrecovered = {'targets': [], 'decoded': []}
for line in lines:
    if not line.strip():
        # A blank line (e.g. at the end of the file) cannot be unpacked below.
        continue
    # Each line is "<index> <target DELIM decoded DELIM index> <score>".
    index, res, score = line.split(' ')
    target, decoded, index = res.split(DELIM)
    score = float(score)
    bucket = recovered if score == 1.0 else unrecovered
    bucket['targets'].append(target)
    bucket['decoded'].append(decoded)

In [None]:
from tdc import Evaluator

# Report each TDC metric separately for the recovered and unrecovered sets.
for metric in "KL_divergence FCD_Distance Novelty Validity Uniqueness".split():
    evaluator = Evaluator(name=metric)
    try:
        try:
            score_recovered = evaluator(recovered["targets"], recovered["decoded"])
            score_unrecovered = evaluator(unrecovered["targets"], unrecovered["decoded"])
        except TypeError:
            # Some evaluators only take 1 input args, try that.
            score_recovered = evaluator(recovered["decoded"])
            score_unrecovered = evaluator(unrecovered["decoded"])
    except Exception as e:
        # The outer handler also covers failures of the one-argument fallback.
        print(e)
        # float("nan") instead of np.nan: numpy is only imported in a later
        # cell, so np would be undefined here on a fresh kernel.
        score_recovered, score_unrecovered = float("nan"), float("nan")

    print(f"Evaluation metric for {evaluator.name}:")
    print(f"    Recovered score: {score_recovered:.2f}")
    print(f"  Unrecovered score: {score_unrecovered:.2f}")


In [None]:
# Number of targets whose reconstruction score was not exactly 1.0.
len(unrecovered['targets'])

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.colors as mcolors



def _parse_rank(cell):
    """Return the integer found between '(' and '|' in ``cell``; -1 for a missing (NaN) cell."""
    text = str(cell)
    if text == 'nan':
        return -1
    return int(text[text.index('(') + 1:text.index('|')])


# The results table is the same for every pattern, so read it once.
results_df = pd.read_csv('SynthesisNet/data/pmo/results.csv', index_col=0)
# Columns whose name contains any of these substrings are left out.
excluded_tags = ('qed', '7l11', 'drd3')

for pat in ['Top 1','Top 10','Top 100','Top 1 SA','Top 10 SA','Top 100 SA','Top 1 AUC','Top 10 AUC','Top 100 AUC']:
    # Ranks for every column that ends with this pattern.
    heat_plot = {
        col: [_parse_rank(v) for v in results_df[col]]
        for col in results_df.columns
        if col.endswith(pat) and not any(tag in col for tag in excluded_tags)
    }
    if not heat_plot:
        # No matching column: there is nothing to draw, and the max() calls
        # below would yield NaN.
        continue

    # Convert the dictionary to a DataFrame for easier manipulation
    df = pd.DataFrame(heat_plot, index=results_df.index)
    # Guard against the all-missing case, where the largest value is -1.
    max_rank = max(int(df.max().max()), 1)

    # Create a custom color map
    colors = [(0, "green"), (0.5, "yellow"), (1, "red")]  # From green (best) to red (worst)
    cmap = mcolors.LinearSegmentedColormap.from_list("rank_cmap", colors)

    # Normalize the values for the color map (excluding -1)
    norm = mcolors.Normalize(vmin=1, vmax=max_rank)

    # Set up the figure and axis
    fig, ax = plt.subplots()

    # Create the heatmap
    sns.heatmap(df.T, fmt="d", cmap=cmap, cbar=False, linewidths=.5,
                linecolor='black', mask=(df.T == -1), ax=ax, norm=norm)

    # Manually add the black cells for -1.  The heatmap shows df.T, so a value
    # at df.iloc[x, y] is drawn at horizontal position x and vertical position y.
    for y in range(df.shape[1]):
        for x in range(df.shape[0]):
            if df.iloc[x, y] == -1:
                # facecolor rather than color: `color` would override edgecolor.
                ax.add_patch(plt.Rectangle((x, y), 1, 1, fill=True,
                                           facecolor='black', edgecolor='black'))

    # Create the colorbar on the same axis
    cbar = plt.colorbar(mappable=plt.cm.ScalarMappable(norm=norm, cmap=cmap), ax=ax,
                        ticks=range(1, max_rank + 1))
    cbar.ax.invert_yaxis()  # Invert the color bar so that green (best) is at the top

    # Set labels and title
    ax.set_xlabel('Baselines')
    ax.set_ylabel('Metrics')
    ax.set_title('Baseline Rankings Heatmap')

    fig.savefig(f'SynthesisNet/data/pmo/rankings/{pat}.png', bbox_inches='tight')