In [2]:
from jasyntho import SynthTree
from jasyntho.extract import ExtractReaction
import matplotlib.pyplot as plt
import networkx as nx
from networkx.drawing.nx_agraph import graphviz_layout

def extract_subgraph(graph, start_node):
    """Return an independent copy of the portion of `graph` reachable from `start_node`.

    `start_node` itself is always part of the result. Reachability follows
    networkx's traversal rules (successor edges for directed graphs).
    """
    reachable = nx.descendants(graph, start_node)
    reachable.add(start_node)
    # .copy() detaches the result from `graph` (subgraph() alone is a view).
    return graph.subgraph(reachable).copy()

def plot_graph(G, prog="dot", figsize=(10, 7)):
    """Draw `G` with a Graphviz layout and display it.

    Args:
        G: networkx graph to draw; node names are used as labels and edges
            are drawn with arrows.
        prog: Graphviz layout program passed to `graphviz_layout`
            (defaults to "dot", as before).
        figsize: matplotlib figure size in inches (defaults to (10, 7)).
    """
    # Compute the layout first so a Graphviz failure does not leave an empty
    # figure open.
    pos = graphviz_layout(G, prog=prog)
    _, ax = plt.subplots(figsize=figsize)
    # Draw on an explicit axes instead of relying on pyplot's "current axes".
    nx.draw(G, pos, ax=ax, with_labels=True, arrows=True)
    plt.show()

def update_graph(G, new_edges, rm_edges, rm_nodes, names_mapping):
    """Apply manual curation edits to a reaction graph and return the edited graph.

    Edits are applied in this fixed order, so every name in `new_edges`,
    `rm_edges` and `rm_nodes` must already be the *relabelled* name:

    1. rename nodes via `names_mapping` (nx.relabel_nodes with its default
       copy=True, so the caller's `G` is not modified);
    2. add each (u, v) pair of `new_edges` (only the first two items of each
       entry are used; the notebook's convention is (product, reactant));
    3. remove each (u, v) pair of `rm_edges`;
    4. remove each node in `rm_nodes`;
    5. for every entry of `SynthTree.get_reach_subgraphs` whose subgraph has
       exactly one node, remove that entry's key from the graph.

    Steps 3 and 4 use `remove_edge` / `remove_node`, so a name that is not
    present raises instead of being silently ignored.
    """
    curated = nx.relabel_nodes(G, names_mapping)

    for edge in new_edges:
        curated.add_edge(edge[0], edge[1])
    for edge in rm_edges:
        curated.remove_edge(edge[0], edge[1])

    for node in rm_nodes:
        curated.remove_node(node)

    # Prune trivial one-node reachability subgraphs. NOTE(review): assumes the
    # keys returned by get_reach_subgraphs are node names in `curated` — confirm.
    for head, subgraph in SynthTree.get_reach_subgraphs(curated).items():
        if len(subgraph) == 1:
            curated.remove_node(head)
    return curated

async def extract_tree(path, model='gpt-3.5-turbo', mode='vision'):
    """Build a SynthTree from a paper directory and run reaction extraction on it.

    Args:
        path: directory handed to `SynthTree.from_dir`.
        model: model name given to `ExtractReaction`.
        mode: extraction mode forwarded to `tree.async_extract_rss`
            (was hard-coded to 'vision'; that remains the default).

    Returns:
        The tree, with `raw_prods` set to the extraction result, `products`
        set to the entries of `raw_prods` that are not empty, and
        `partition()` already called.
    """
    tree = SynthTree.from_dir(path)
    tree.rxn_extract = ExtractReaction(model=model)

    tree.raw_prods = await tree.async_extract_rss(mode=mode)
    tree.products = [p for p in tree.raw_prods if not p.isempty()]

    # partition() is called for its effect on `tree`; its return value was
    # never used. NOTE(review): later cells read `tree.full_g`, presumably
    # populated by this call — confirm.
    tree.partition()
    return tree

  from .autonotebook import tqdm as notebook_tqdm


# ja074300t

In [None]:
path = '../benchmark/papers/ja074300t'
tree = await extract_tree(path)

In [None]:
# `os` must be imported here: this cell runs before any other cell imports it,
# so on a fresh kernel os.path.join below would raise NameError.
import os
import pickle

fg = tree.full_g.copy()

# Extractor label -> compound ID used in the paper.
names_mapping = {
    'SI–1 2': 'SI-1',
    'oxazolidinone-lactol 70': '70',
    '(Z)-114': '114',
    'Aldehyde': 'SI-30',
}

# a <- b reaction.  a is product, b is reactant
new_edges = [
    ('58', '(-)-36'),
    ('58', '57'),
    ('58', 'DME'),
    ('57', 'EtOAc'),
    ('98', '93'),
    ('98', 'sodium chlorite'),
    ('98', 'MeCN'),
    ('98', 'sodium phosphate'),
    ('98', 'TEMPO'),
    ('98', 'MTBE'),
    ('99', '98'),
    ('99', 'MeI'),
    ('99', 'potassium carbonate'),
    ('99', 'DMF'),
    ('SI-17', 'SI-16'),
    ('SI-17', 'THF'),
    ('SI-17', 'NH4Cl'),
    ('SI-17', '4-pentenyl-1-magnesium bromide'),
    ('101', 'SI-17'),
    ('101', 'DMSO'),
    ('101', 'Et3N'),
    ('101', 'oxalyl chloride'),
    ('SI-18', '101'),
    ('SI-18', '102'),
    ('SI-18', 'N-potassiumhexamethyldisilazane'),
    ('103', 'HF'),
    ('103', 'SI-18'),
    ('113', '(E)-107'),
    ('113', '(Z)-107'),
    ('118', '117'),
    ('118', '116'),
    ('122', 'SI-23'),
    ('123', '122'),
    ('SI-24', '123'),
    ('124', 'SI-24'),
]

rm_edges = [
    ('101', 'SI-16'),
    ('101', 'Mg'),
    ('101', '5-bromo-1-pentene'),
]

rm_nodes = [
    'OOOTBSOTIPS1163',
    'TBSSI-223',
    'NNOTIPSOOOTBSSI-23',
    '17b'
]

# Ground-truth head nodes (only used by the inspection loop below).
gthn = [
    '32',
    '75',
    '70',
    '100',
    '105',
    '114',
    '1',
    '131',
    '132',
    'SI-29'
]

# Apply the manual corrections before saving; without this call the pickle
# would contain the raw, uncorrected extraction.
fg = update_graph(fg, new_edges, rm_edges, rm_nodes, names_mapping)

reach_sgs = SynthTree.get_reach_subgraphs(fg)
print(len(reach_sgs))

# Inspection helper (disabled): plot non-trivial components that contain no
# ground-truth head node.
# for g in reach_sgs.values():
#     if len(g) > 1 and not any(n in g for n in gthn):
#         plot_graph(g)
#         print(g.nodes)

with open(os.path.join(path, "gt_graph.pickle"), "wb") as f:
    pickle.dump(fg, f)

# jacs.0c00308

In [None]:
path = '../benchmark/papers/jacs.0c00308'
tree = await extract_tree(path)


In [None]:

import os
import pickle

fg = tree.full_g.copy()

# Extractor label -> compound ID used in the paper.
names_mapping = {
    'epoxide 25': '25',
    'cyclohexadienone 26': '26',
    'tetracycle 22': '22',
    'diol 24': '24',
    'diketone 13': '13',
    'enone 23': '23',
    'alcohol 27': '27',
    'spirolleycle 30': '30',
    'corresponding ketone': '18',
    '4,4-dimethyl cyclopentanone 37': '37',
    'Acetonide 38': '38',
    'tetracyclic diketones C7-deoxy-13': 'C7-deoxy-13',
    'MOM ether S-6': 'S-6',
    'vinylphenol C7-deoxy-14': 'C7-deoxy-14',
}

# Edges as (product, reactant) pairs, grouped one product per line.
new_edges = [
    ('S-3', 'IBX'), ('S-3', 'DMSO'), ('S-3', 'NaOH'),
    ('S-1', '16'),
    ('S-2', 'BzOH'), ('S-2', 'PPh3'), ('S-2', 'DIAD'), ('S-2', 'p-TsOH•H2O'),
    ('15', 'H2NNHCONH2•HCl'), ('15', 'Pb(OAc)4'), ('15', 'Pd/BaSO4'), ('15', 'SiO2'),
    ('C7-epi-18', '15'), ('C7-epi-18', '17'), ('C7-epi-18', 'n-BuLi'),
    ('18', 'C7-epi-18'),
    ('C7-deoxy-C8,C13-diepi-13', 'C7-deoxy-14'),
    ('C7-deoxy-C8,C13-diepi-13', 'HFIP'),
    ('C7-deoxy-C8,C13-diepi-13', 'PIFA'),
    ('C7-deoxy-C8,C13-diepi-13', 'Na2S2O3'),
    ('S-6', 'S-5'), ('S-6', 'NaHMDS'),
    ('37', '33'),
    ('39', '38'),
    ('S-10', '39'), ('S-9', '39'), ('41', '39'),
    ('7', '42'), ('7', 'VO(acac)2'), ('7', '4A molecular sieve'), ('7', 'TBHP'), ('7', 'NaBH4'),
]

# Wrong (product, reactant) links produced by the extractor.
rm_edges = [
    ('S-2', 'S-3'), ('S-1', 'S-2'), ('S-1', 'S-3'), ('S-3', 'S-1'),
    ('15', 'S-1'),
    ('S-6', '18'),
    ('38', '39'),
    ('7', '1'), ('7', '5'), ('7', '2'), ('7', '10'),
]

rm_nodes = ['17b', '9', '5', '4', '2', '10', '11', 'P1', '18a', 'S4', '8']

# Ground-truth head nodes (only used by the inspection loop below).
gthn = [
    'C7-deoxy-13', 'C7,C8,C13-triepi-13', 'C7-deoxy-C8,C13-diepi-13',
    'C14-epi-27', '21', '30', '34', '35', '36',
    'S-8', 'S-9', 'S-10', 'S-11', '46', '7',
]

fg = update_graph(
    fg,
    new_edges=new_edges,
    rm_edges=rm_edges,
    rm_nodes=rm_nodes,
    names_mapping=names_mapping,
)

reach_sgs = SynthTree.get_reach_subgraphs(fg)
print(len(reach_sgs))

# Inspection helper (disabled): plot every non-trivial component.
# for g in reach_sgs.values():
#     if len(g) > 1:
#         plot_graph(g)
#         print(g.nodes)

with open(os.path.join(path, "gt_graph.pickle"), "wb") as f:
    pickle.dump(fg, f)

# jacs.0c00363

In [5]:
path = '../benchmark/papers/jacs.0c00363'
tree = await extract_tree(path)

Finished processing batch. Cost: 0.019095
Finished processing batch. Cost: 0.02968
Finished processing batch. Cost: 0.030175000000000004
Finished processing batch. Cost: 0.030520000000000002
Finished processing batch. Cost: 0.03175
Finished processing batch. Cost: 0.030160000000000003
Finished processing batch. Cost: 0.031120000000000002
Finished processing batch. Cost: 0.030070000000000003
Finished processing batch. Cost: 0.029725
Finished processing batch. Cost: 0.031735000000000006
Finished processing batch. Cost: 0.031810000000000005
Finished processing batch. Cost: 0.035125
Finished processing batch. Cost: 0.031465
Finished processing batch. Cost: 0.03241
Finished processing batch. Cost: 0.03572500000000001
Finished processing batch. Cost: 0.029710000000000004
Finished processing batch. Cost: 0.030010000000000002
Finished processing batch. Cost: 0.041095000000000007
Finished processing batch. Cost: 0.046150000000000004
Finished processing batch. Cost: 0.047170000000000004
Finished

In [40]:
import os
import pickle

fg = tree.full_g.copy()

# Extractor label -> compound ID used in the paper.
names_mapping = {
    'Compound 12': '12',
    'Compound 13': '13',
    'Compound 16': '16',
    'Compound 24': '24',
    'Compound S4': 'S4',
    'Compound 19': '19',
    'Compound 20': '20',
    'Compound S5': 'S5',
    'Crude S3': 'S3',
    'Mixture of 15a, 15b, and 15c': '15a',
}

# Edges as (product, reactant) pairs, grouped one product per line.
new_edges = [
    ('18', '17'),
    ('19', '18'), ('19', 'LiAlH4'), ('19', 'NaOH'),
    ('20', '19'), ('20', 'NaHCO3'), ('20', 'DMP'), ('20', 'S4'),
    ('S5', '20'),
    ('S6', 'S5'),
    ('10', 'S1'),
    ('S2', '10'),
    ('11', 'S2'),
    ('S3', '11'),
    ('12', 'S3'),
    ('15b', '12'), ('15b', 'n-C4F9SO2F'), ('15b', 'DBU'),
    ('15c', '12'), ('15c', 'n-C4F9SO2F'), ('15c', 'DBU'),
]

# Wrong (product, reactant) links produced by the extractor.
rm_edges = [
    ('S4', 'S2'), ('S4', 'S3'),
    ('S5', 'S3'), ('S5', 'S1'),
    ('S2', 'S1'),
    ('12', '11'),
]

rm_nodes = [
    '2', '3', '4',
    'Natural Propindilactone G', 'Synthetic Propindilactone G',
    'Compound B', 'Compound A',
    '17a', 'P1', '18a',
]

# Ground-truth head nodes (only used by the inspection loop below).
gthn = ['1', '24', '15a', '15b', '15c', '13']

fg = update_graph(
    fg,
    new_edges=new_edges,
    rm_edges=rm_edges,
    rm_nodes=rm_nodes,
    names_mapping=names_mapping,
)

reach_sgs = SynthTree.get_reach_subgraphs(fg)
print(len(reach_sgs))

# Inspection helper (disabled): plot non-trivial components that contain no
# ground-truth head node.
# for g in reach_sgs.values():
#     if len(g) > 1 and not any(n in g for n in gthn):
#         plot_graph(g)
#         print(g.nodes)

with open(os.path.join(path, "gt_graph.pickle"), "wb") as f:
    pickle.dump(fg, f)

6
