In [1]:
import os
import pickle
from jasyntho import SynthTree
from jasyntho.extract import ExtractReaction

async def extract_tree(path, model='gpt-3.5-turbo', method='text'):
    tree = SynthTree.from_dir(path)
    tree.rxn_extract = ExtractReaction(llm=model)

    tree.raw_prods = await tree.async_extract_rss(mode=method)
    tree.products = [p for p in tree.raw_prods if not p.isempty()]

    reach_sgs = tree.partition()
    return tree

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load ground truth graph
import networkx as nx

path = '../benchmark/papers/ja074300t'
with open(os.path.join(path, 'gt_graph.pickle'), 'rb') as f:
    gt_G = pickle.load(f)

In [3]:
# Extract a graph from paper

# tree = await extract_tree(path, model='gpt-3.5-turbo', method='vision')
# extracted_G = tree.full_g

# with open(os.path.join(path, 'extracted_graph_gpt35_vision.pickle'), 'wb') as f:
#     pickle.dump(extracted_G, f)

with open(os.path.join(path, 'extracted_graph_gpt35_vision.pickle'), 'rb') as f:
    extracted_G = pickle.load(f)

# Ways of comparing the two graphs

- Graph Edit Distance  # very slow
- Subgraph matching
- Spectral analysis
- Edge overlap

In [None]:
# sorts of subgraph matching

# Find subgraphs of extracted graph, and compare with ground truth graph
def find_subgraphs_larger_than_n(G, N):
    subgraphs = []
    
    # Find all connected components (subgraphs) in the graph
    connected_subgraphs = nx.connected_components(G)
    
    # Iterate over each connected subgraph
    for subgraph_nodes in connected_subgraphs:

        # Check if the size of the subgraph is greater than N
        if len(subgraph_nodes) > N:
            # Create a subgraph from the nodes
            subgraph = G.subgraph(subgraph_nodes)
            subgraphs.append(subgraph)
    
    return subgraphs

# Find subgraphs of size greater than 3
N = 3
subg_3 = find_subgraphs_larger_than_n(extracted_G.to_undirected(), N)


def subgraph_in_gt(subgraph, gt_G):
    # Check if the subgraph is present in the host graph
    matcher = nx.algorithms.isomorphism.GraphMatcher(gt_G, subgraph)
    is_match = matcher.subgraph_is_isomorphic()

    if is_match:
        print("The subgraph is present in the host graph.")
        # Get the mapping of nodes between the subgraph and the host graph
        mapping = matcher.mapping
        print("Mapping:", mapping)
    else:
        print("The subgraph is not present in the host graph.")

for subgraph in subg_3:
    subgraph_in_gt(subgraph, gt_G.to_undirected())

In [5]:
# Compare the immediate neighborhood of the nodes. e.g. is the local structure preserved?

def get_neighborhood_subgraph(G, node):
    # Get the incoming and outgoing neighbors of the node
    in_neighbors = [edge[0] for edge in G.in_edges(node)]
    out_neighbors = [edge[1] for edge in G.out_edges(node)]
    neighbors = set(in_neighbors + out_neighbors + [node])
    subgraph = G.subgraph(neighbors)
    return subgraph

def subgraph_in_gt_exact(subgraph, gt_G):
    """Check if the subgraph is present in the host graph."""
    subg_gt = gt_G.subgraph(subgraph.nodes)
    if len(subg_gt) == len(subgraph):
        return True
    return False

def subgraph_in_gt_isomorphic(subgraph, gt_G):
    matcher = nx.algorithms.isomorphism.GraphMatcher(gt_G, subgraph)
    is_match = matcher.subgraph_is_isomorphic()
    return is_match

def compare_local_exact_0(G, gt_G):
    quant = []
    for node in G.nodes:
        sg = get_neighborhood_subgraph(G, node)
        if len(sg)>1:
            v = subgraph_in_gt_exact(sg, gt_G)
            quant.append(v)
    return sum(quant)/len(quant)
    
def compare_local_exact(G, gt_G):
    c1 = compare_local_exact_0(gt_G, G)
    c2 = compare_local_exact_0(G, gt_G)
    return c1, c2

def compare_local_iso(G, gt_G):
    quant = []
    for node in G.nodes:
        sg = get_neighborhood_subgraph(G, node)
        if len(sg)>1:
            v = subgraph_in_gt_isomorphic(sg, gt_G)
            quant.append(v)
    return sum(quant)/len(quant)


In [6]:
print(compare_local_exact(gt_G, gt_G))
print(compare_local_exact(extracted_G, gt_G))

# Compare with other syntheses
path2 = '../benchmark/papers/jacs.0c00308'
with open(os.path.join(path2, 'gt_graph.pickle'), 'rb') as f:
    other_G = pickle.load(f)
    print(compare_local_exact(other_G, gt_G))
path2 = '../benchmark/papers/jacs.0c00363'
with open(os.path.join(path2, 'gt_graph.pickle'), 'rb') as f:
    other_G = pickle.load(f)
    print(compare_local_exact(other_G, gt_G))

(1.0, 1.0)
(0.3471502590673575, 0.18584070796460178)
(0.0, 0.007874015748031496)
(0.0051813471502590676, 0.011904761904761904)


In [7]:
# Similar thing, but with paths (testing more long-range structure)

def get_paths(G):
    paths = []
    for n0 in G.nodes:
        for n1 in G.nodes:
            if n0 != n1:
                if nx.has_path(G, n0, n1):
                    ps = nx.all_simple_paths(G, source=n0, target=n1)
                    paths += list(ps)
    return paths

def compare_path_exact_0(G, gt_G):
    quant = []
    subgraphs = get_paths(G)
    for path in subgraphs:
        if len(path)>1:
            sg = G.subgraph(path)
            v = subgraph_in_gt_exact(sg, gt_G)
            quant.append(v)
    return sum(quant)/len(quant)

def compare_path_exact(G, gt_G):
    c0 = compare_path_exact_0(G, gt_G)
    c1 = compare_path_exact_0(gt_G, G)
    return c0, c1


print(compare_path_exact(gt_G, gt_G))
print(compare_path_exact(extracted_G, gt_G))

# Compare with other syntheses
path2 = '../benchmark/papers/jacs.0c00308'
with open(os.path.join(path2, 'gt_graph.pickle'), 'rb') as f:
    other_G = pickle.load(f)
    print(compare_path_exact(other_G, gt_G))
path2 = '../benchmark/papers/jacs.0c00363'
with open(os.path.join(path2, 'gt_graph.pickle'), 'rb') as f:
    other_G = pickle.load(f)
    print(compare_path_exact(other_G, gt_G))

(1.0, 1.0)
(0.1385886840432295, 0.39069250709788916)
(0.00035056967572304995, 0.0)
(0.0034423407917383822, 0.00037032465127762005)


In [8]:
# TODO next: compare routes extracted with different methods! see if they make sense

async def extractg(path, model='gpt-3.5-turbo', method='text'):

    tree = await extract_tree(path, model=model, method=method)
    extracted_G = tree.full_g

    if model=='gpt-3.5-turbo':
        k = "gpt35"
    elif model=='gpt-4-turbo':
        k = "gpt4t"
    elif model=='gpt-4':
        k = "gpt4"
    elif model=='gpt-4o':
        k = "gpt4o"

    with open(os.path.join(path, f'extracted_graph_{k}_{method}.pickle'), 'wb') as f:
        pickle.dump(extracted_G, f)

    return extracted_G

In [9]:
eg = await extractg(path, model='gpt-3.5-turbo', method='text')
print(compare_path_exact(eg, gt_G))

eg = await extractg(path, model='gpt-4-turbo', method='text')
print(compare_path_exact(eg, gt_G))

eg = await extractg(path, model='gpt-3.5-turbo', method='vision')
print(compare_path_exact(eg, gt_G))

eg = await extractg(path, model='gpt-4-turbo', method='vision')
print(compare_path_exact(eg, gt_G))

Found key 'SI-7' in multiple children.
Found key 'SiO2' in multiple children.
Found key 'SiO2' in multiple children.
Found key 'NaHMDS' in multiple children.
Found key 'THF' in multiple children.
Found key 'THF' in multiple children.
[93mTotal paragraphs: 87[39m
[93mProcessed paragraphs: 53[39m
[93mFound 39 empty paragraphs.[39m
[93m	Validation error.: 2[39m
[93m	No product found: 37[39m
(0.5382059800664452, 0.18219972842858906)
Found key '99' in multiple children.
Found key 'SI-7' in multiple children.
[93mTotal paragraphs: 87[39m
[93mProcessed paragraphs: 65[39m
[93mFound 25 empty paragraphs.[39m
[93m	No product found: 25[39m
(0.8092224231464737, 0.1575114183434144)
Error in processing batch: Error code: 400 - {'error': {'message': 'Your input image may contain content that is not allowed by our safety system.', 'type': 'invalid_request_error', 'param': None, 'code': 'content_policy_violation'}}
Error in processing batch: Error code: 400 - {'error': {'message': 'You

In [10]:
import os

os.getenv("OPENAI_API_KEY")

'sk-proj-OdwtjV3T1Q1V8rXdcDPQT3BlbkFJPq3j4kgQtLbAFuaKzBW0'

In [9]:

import matplotlib.pyplot as plt
import networkx as nx
from networkx.drawing.nx_agraph import graphviz_layout

def extract_subgraph(graph, start_node):
    """Use BFS to find all nodes reachable from start_node."""
    reachable_nodes = set(nx.bfs_tree(graph, start_node))
    return graph.subgraph(reachable_nodes).copy()

def plot_graph(G):
    fig = plt.figure(figsize=(10, 7))
    pos = graphviz_layout(G, prog="dot")
    nx.draw(G, pos, with_labels=True, arrows=True)
    plt.show()


reach_sgs = SynthTree.get_reach_subgraphs(extracted_G)
print(len(reach_sgs))

# for g in reach_sgs.values():
#     if len(g) > 1:
#             plot_graph(g)
#             print(g.nodes)

68
