In [1]:
from rdkit import Chem, RDLogger

from collections import defaultdict

import queue
import pandas as pd
import heapq

from pathlib import Path
import sys
import time

proj_root = (Path.cwd().parent).resolve()
sys.path.insert(0, str(proj_root))

from src.evolve import evol_single
from src.degen import *
from src.utils import *
from src.rediscovery import *

# One-character code per growth operation (not referenced in the cells shown;
# presumably used for compact labels elsewhere — verify before relying on it).
OPERATION = dict(
    connect="c",
    vinyl="v",
    ethynyl="e",
    annulate_pl2="2",
    annulate_pl4="4",
    phenyl="p",
)

# Generation increment charged for each operation. The search cell adds this
# to the parent's generation when a child is first discovered and orders the
# expansion heap by the resulting value.
operation_improvement = dict(
    connect=1,
    vinyl=1,
    ethynyl=1,
    annulate_pl2=2,
    annulate_pl4=3,
    phenyl=4,
)


In [2]:
# Forward enumeration guided by the retrosynthesis scorer: starting from H2, apply
# every growth operation to each node and keep the children whose score exceeds
# SCORE_THRESHOLD (presumably "child is still compatible with the target" —
# TODO confirm RediscoverySubstructure.score semantics).
#
# Nodes are expanded from a min-heap keyed by generation (cumulative
# `operation_improvement` cost at first discovery). `paths[smi]` accumulates the
# number of distinct operation sequences from H2 to `smi`; that count is only
# correct if every parent of a node is expanded before the node itself, which is
# checked explicitly below instead of being assumed.
target_smi = Chem.MolToSmiles(Chem.MolFromSmiles("C12=CC=CC(C3=C(C4=CC=C5)C5=CC=C3)=C1C4=CC=C2"))
# Retrosynthesis analysis
rds = RediscoverySubstructure(target_smi)

ROOT_SMI = "[H][H]"
SCORE_THRESHOLD = -0.5  # children scoring at or below this value are pruned
SEARCH_OPERATIONS = tuple(operation_improvement)  # every operation that has a cost entry

edges = set()               # (parent_smi, child_smi, operation) triples
nodes = []                  # every discovered SMILES, in discovery order
paths = defaultdict(int)    # SMILES -> number of operation sequences from ROOT_SMI
gens = defaultdict(int)     # SMILES -> generation assigned at first discovery
parents = defaultdict(set)  # child SMILES -> set of parent SMILES
nodes.append(ROOT_SMI)
nodes_pool = [(0, ROOT_SMI)]  # min-heap of (generation, SMILES) awaiting expansion
heapq.heapify(nodes_pool)
paths[ROOT_SMI] = 1
gens[ROOT_SMI] = 0

expanded = set()  # nodes already popped; their path counts have been propagated
score_cache = {}  # SMILES -> rds.score, so rediscovered children are not re-scored

# Debug tracer: follows the first accepted child of each traced node, from the root.
prev_repl = ROOT_SMI
while nodes_pool:
    print(f"nodes_pool size: {len(nodes_pool)}, nodes size: {len(nodes)}, edges size: {len(edges)}")
    _, smi_p = heapq.heappop(nodes_pool)
    expanded.add(smi_p)
    print(smi_p, gens[smi_p], paths[smi_p])
    if prev_repl == smi_p:
        print(f"updated!!: {smi_p}")
    for operation, children in evol_single(smi_p, operations=SEARCH_OPERATIONS).items():
        step_cost = operation_improvement[operation]
        # Order-preserving de-duplication: each unique (parent, child, operation)
        # edge contributes to `paths` exactly once, matching the `edges` set.
        for smi_c in dict.fromkeys(children):
            if smi_c not in score_cache:
                score_cache[smi_c] = rds.score(smi_c)
            # Written as `not (score > t)` so NaN scores are pruned, as before.
            if not score_cache[smi_c] > SCORE_THRESHOLD:
                continue
            if smi_c in expanded:
                # The child's count was already pushed to its own children, so
                # adding to it now would silently under-count every descendant.
                raise RuntimeError(
                    f"child {smi_c} was expanded before its parent {smi_p}; "
                    "generation order is not topological, path counts would be wrong"
                )
            if prev_repl == smi_p:
                prev_repl = smi_c
                print(f"  child: {smi_c}")
            edges.add((smi_p, smi_c, operation))
            parents[smi_c].add(smi_p)
            if smi_c not in paths:
                nodes.append(smi_c)
                gens[smi_c] = gens[smi_p] + step_cost
                heapq.heappush(nodes_pool, (gens[smi_c], smi_c))
            paths[smi_c] += paths[smi_p]

nodes_pool size: 1, nodes size: 1, edges size: 0
[H][H] 0 1
updated!!: [H][H]
  child: C=C
nodes_pool size: 2, nodes size: 3, edges size: 2
C=C 1 1
updated!!: C=C
  child: C=CC=C
nodes_pool size: 3, nodes size: 5, edges size: 5
C=CC=C 2 1
updated!!: C=CC=C
  child: C=CC(=C)C=C
nodes_pool size: 6, nodes size: 9, edges size: 11
C=CC(=C)C=C 3 1
updated!!: C=CC(=C)C=C
  child: C=CC(=C)C(=C)C=C
nodes_pool size: 11, nodes size: 15, edges size: 19
C=CC=CC=C 3 1
nodes_pool size: 15, nodes size: 20, edges size: 29
C=CC(=C)C(=C)C=C 4 1
updated!!: C=CC(=C)C(=C)C=C
  child: C=c1ccccc1=C
nodes_pool size: 20, nodes size: 26, edges size: 38
C=CC=C(C=C)C=C 4 2
nodes_pool size: 27, nodes size: 34, edges size: 51
C=CC=CC(=C)C=C 4 2
nodes_pool size: 39, nodes size: 47, edges size: 74
C=CC=CC=CC=C 4 1
nodes_pool size: 44, nodes size: 53, edges size: 87
c1ccccc1 4 4
nodes_pool size: 45, nodes size: 55, edges size: 90
C=CC(=C)C(=C)C(=C)C=C 5 1
nodes_pool size: 52, nodes size: 63, edges size: 101
C=CC(=C)C=C

In [3]:
# Number of distinct operation sequences from H2 to the target. `.get` avoids
# inserting a spurious 0 entry into the `paths` defaultdict if the target was
# never reached by the search.
print(paths.get(target_smi, 0))

592431631


In [4]:
# Histogram of discovered nodes keyed by the sum of num_unit_method(smiles),
# printed as a defaultdict so the output format matches earlier runs.
node_by_generation = defaultdict(int)
for unit_total in (sum(num_unit_method(smi)) for smi in nodes):
    node_by_generation[unit_total] += 1
print(node_by_generation)

defaultdict(<class 'int'>, {0: 1, 1: 1, 4: 5, 2: 1, 5: 14, 3: 2, 6: 41, 7: 134, 8: 376, 9: 837, 10: 1219, 11: 884, 12: 323, 13: 68, 14: 7, 15: 1})


In [14]:
# Cross-check forward evolution against degeneration for every child node:
#  1. each degeneration product of a child should be a recorded parent of it;
#  2. each recorded parent should appear among the child's degeneration products.
# Mismatches are printed, not raised.
for child, parent_set in parents.items():
    degen_results = degen_single(
        smi=child,
        operations=("connect", "vinyl", "ethynyl", "annulate_pl2", "annulate_pl4", "phenyl"),
    )
    seen_degen = set()
    for degen_group in degen_results.values():
        unmatched = (smi for smi in degen_group if smi not in parent_set)
        for degen_smi in unmatched:
            print(f"parent -> child not found (degen): parent:{degen_smi}, child:{child}")
        seen_degen.update(degen_group)
    for parent in (p for p in parent_set if p not in seen_degen):
        print(f"child -> parent not found (evol): parent:{parent}, child:{child}")

In [5]:
# Count search edges by operation family: single-unit extensions
# (connect/vinyl/ethynyl) versus annulation and phenyl additions.
unit_ops = {"connect", "vinyl", "ethynyl"}
comp_ops = {"annulate_pl2", "annulate_pl4", "phenyl"}
unit = sum(1 for _parent, _child, op in edges if op in unit_ops)
comp = sum(1 for _parent, _child, op in edges if op in comp_ops)
print(unit, comp)

26974 4098
