In [2]:
from bokeh.transform import linear_cmap
from bokeh.palettes import Viridis256
import networkx as nx
from bokeh.plotting import figure, show, output_file
from bokeh.models import GraphRenderer, StaticLayoutProvider, Circle, MultiLine, HoverTool, LinearColorMapper, ColorBar
from bokeh.io import output_notebook
output_notebook()

from rdkit import Chem
from RetroTide_agent.node import Node
from RetroTide_agent.mcts import MCTS

In [3]:
root = Node(PKS_product = None,
            PKS_design = None,
            parent = None,
            depth = 0)

mcts = MCTS(root = root,
            target_molecule = Chem.MolFromSmiles("O=C1C=CCC(CO)O1"), # OC(CC(O)CC(O)=O)/C=C/C1=CC=CC=C1 # CCCCCC(=O)O # O=C1C=CCC(CO)O1
            max_depth = 10,
            total_iterations = 15000,
            maxPKSDesignsRetroTide = 3000,
            selection_policy = "UCB1")

mcts.run()


Selected leaf node at depth 0

computing module 1
   testing 2755 designs
Expanded leaf node: 2755 new children
computing module 1
   testing 2755 designs
   best score is 0.75
computing module 2
   testing 1425 designs
   best score is 0.8888888888888888
computing module 3
   testing 1425 designs
   best score is 0.8181818181818182
TARGET REACHED IN SIMULATION THROUGH CYCLIZATION!
<bcs.bcs.Cluster object at 0x13433b730>
Simulation reward = 0.89
Backpropagation complete.

Unable to perform cyclization reaction

Unable to perform cyclization reaction

Unable to perform cyclization reaction

Unable to perform cyclization reaction

Unable to perform cyclization reaction

Selected leaf node at depth 1

computing module 3
   testing 95 designs
Expanded leaf node: 95 new children
computing module 1
   testing 95 designs
   best score is 0.7
computing module 2
   testing 1425 designs
   best score is 0.6666666666666666
Simulation reward = 0.70
Backpropagation complete.

Unable to perform cyc

In [4]:
def visualize_mcts_tree(mcts):
    """Visualizes the MCTS search tree using NetworkX and Bokeh with a depth legend."""

    G = nx.DiGraph()
    for node in mcts.nodes:
        G.add_node(node.node_id,
                   depth = node.depth,
                   visits = node.visits,
                   value = node.value,
                   score = node.selection_score)

    for parent_id, child_id in mcts.edges:
        G.add_edge(parent_id, child_id)

    # Compute layout
    pos = nx.spring_layout(G, seed=42, scale=500)

    # Extract x, y coordinates
    node_x = [pos[node][0] for node in G.nodes()]
    node_y = [pos[node][1] for node in G.nodes()]
    edge_x, edge_y = [], []

    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])

    # Create Bokeh figure
    plot = figure(title="MCTS Search Tree",
                  width=1000, height=800,
                  x_range=(-550, 550), y_range=(-550, 550),
                  tools="pan,wheel_zoom,box_zoom,reset,save")

    # Add hover tool
    plot.add_tools(
        HoverTool(tooltips=[("Node", "@index"),
                            ("Depth", "@depth"),
                            ("Visits", "@visits"),
                            ("Value", "@value"),
                            ("Selection score", "@score"),]))

    graph = GraphRenderer()

    # Get depth values for color mapping
    depths = [G.nodes[n]['depth'] for n in G.nodes()]
    min_depth, max_depth = min(depths), max(depths)

    # Create color mapper and color bar
    color_mapper = LinearColorMapper(palette=Viridis256, low=min_depth, high=max_depth)
    color_bar = ColorBar(color_mapper=color_mapper, location=(0, 0), title="Depth")

    # Set node positions and attributes
    graph.node_renderer.data_source.data = dict(
        index=list(G.nodes()),
        x=node_x,
        y=node_y,
        depth=depths,
        visits=[G.nodes[n]['visits'] for n in G.nodes()],
        score=[G.nodes[n]['score'] for n in G.nodes()],
    )

    graph.node_renderer.glyph = Circle(radius=4, fill_color=linear_cmap("depth", Viridis256, min_depth, max_depth), line_color=None)

    # Set edge positions
    graph.edge_renderer.data_source.data = dict(start=[e[0] for e in G.edges()], end=[e[1] for e in G.edges()])
    graph.edge_renderer.glyph = MultiLine(line_color="gray", line_width=1)

    graph.layout_provider = StaticLayoutProvider(graph_layout=pos)

    plot.add_layout(color_bar, 'right')  # Add the color bar to the right side
    plot.renderers.append(graph)

    show(plot)


In [10]:
visualize_mcts_tree(mcts)

In [5]:
query_depth = 2
query_node_id = 39

for node in mcts.nodes:
    if node.depth == query_depth and node.node_id == query_node_id:
        print(node)

In [6]:
for node in mcts.nodes:
    if node.value != 0.0:
        print(node)

Node ID: 0, Depth: 0, PKS Design: No design, PKS Product: None
Node ID: 109, Depth: 1, PKS Design: [["AT{'substrate': 'Methylmalonyl-CoA'}", 'loading: True'], ["AT{'substrate': 'Malonyl-CoA'}", "KR{'type': 'B'}", 'DH{}', 'loading: False']], PKS Product: CCC=CC(=O)[S]
Node ID: 121, Depth: 1, PKS Design: [["AT{'substrate': 'prop'}", 'loading: True'], ["AT{'substrate': 'Malonyl-CoA'}", "KR{'type': 'B'}", 'DH{}', 'loading: False']], PKS Product: CCC=CC(=O)[S]
Node ID: 571, Depth: 1, PKS Design: [["AT{'substrate': 'Malonyl-CoA'}", 'loading: True'], ["AT{'substrate': 'Malonyl-CoA'}", "KR{'type': 'B'}", 'DH{}', 'loading: False']], PKS Product: CC=CC(=O)[S]
Node ID: 573, Depth: 1, PKS Design: [["AT{'substrate': 'cemal'}", 'loading: True'], ["AT{'substrate': 'Malonyl-CoA'}", "KR{'type': 'B'}", 'DH{}', 'loading: False']], PKS Product: CC=CC(=O)[S]
Node ID: 575, Depth: 1, PKS Design: [["AT{'substrate': 'Acetyl-CoA'}", 'loading: True'], ["AT{'substrate': 'Malonyl-CoA'}", "KR{'type': 'B'}", 'DH{}',

In [7]:
set(mcts.successful_nodes)

set()

In [8]:
mcts.bag_of_graphs

[<rdkit.Chem.rdchem.Mol at 0x133de0b30>,
 <rdkit.Chem.rdchem.Mol at 0x133de0ac0>,
 <rdkit.Chem.rdchem.Mol at 0x133de0f20>,
 <rdkit.Chem.rdchem.Mol at 0x133de0ba0>,
 <rdkit.Chem.rdchem.Mol at 0x133de0c10>,
 <rdkit.Chem.rdchem.Mol at 0x133de04a0>,
 <rdkit.Chem.rdchem.Mol at 0x133de0dd0>,
 <rdkit.Chem.rdchem.Mol at 0x133de0d60>,
 <rdkit.Chem.rdchem.Mol at 0x133de0cf0>,
 <rdkit.Chem.rdchem.Mol at 0x133de05f0>,
 <rdkit.Chem.rdchem.Mol at 0x133de0f90>,
 <rdkit.Chem.rdchem.Mol at 0x133de1230>,
 <rdkit.Chem.rdchem.Mol at 0x133de0510>,
 <rdkit.Chem.rdchem.Mol at 0x133de0580>,
 <rdkit.Chem.rdchem.Mol at 0x133de0430>,
 <rdkit.Chem.rdchem.Mol at 0x133de11c0>,
 <rdkit.Chem.rdchem.Mol at 0x133de1150>,
 <rdkit.Chem.rdchem.Mol at 0x133de10e0>,
 <rdkit.Chem.rdchem.Mol at 0x133de12a0>,
 <rdkit.Chem.rdchem.Mol at 0x133de0c80>,
 <rdkit.Chem.rdchem.Mol at 0x133de1000>,
 <rdkit.Chem.rdchem.Mol at 0x133de1310>,
 <rdkit.Chem.rdchem.Mol at 0x133de1380>,
 <rdkit.Chem.rdchem.Mol at 0x133de13f0>,
 <rdkit.Chem.rdc

In [9]:
mcts.is_PKS_product_in_bag_of_graphs(PKS_product = Chem.MolFromSmiles("CC=CC(O)=O"),
                                     consider_stereo = False)

True

In [53]:
mcts.bag_of_graphs[36]

IndexError: list index out of range

In [54]:
for node in mcts.nodes:
    if node.node_id == 2865:
        query_node = node

In [55]:
query_node

Node ID: 2865, Depth: 1, PKS Design: [["AT{'substrate': 'Methylmalonyl-CoA'}", 'loading: True'], ["AT{'substrate': 'Malonyl-CoA'}", "KR{'type': 'B'}", 'DH{}', 'loading: False']], PKS Product: CCC=CC(=O)[S]

In [56]:
mcts.calculate_subgraph_value(query_node)

1