# Runtime Comparison: Sequential vs. Async DORAnetMCTS

This notebook compares the runtime performance of sequential (`DORAnetMCTS`) and asynchronous (`AsyncExpansionDORAnetMCTS`) implementations for the target molecule **5,6-dihydroyangonin**.

**Target molecule**: 5,6-dihydroyangonin  
**SMILES**: `COC1=CC(OC(C=CC2=CC=C(OC)C=C2)C1)=O`

## 1. Setup and Imports

In [None]:
import os
import time
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors
from rdkit.Chem import Draw

from DORAnet_agent import DORAnetMCTS, AsyncExpansionDORAnetMCTS, Node, clear_smiles_cache
from DORAnet_agent.policies import (
    SpawnRetroTideOnDatabaseCheck,
    SAScore_and_TerminalRewardPolicy,
)

# Matplotlib styling: higher-DPI figures, automatic layout, larger tick labels
plt.rcParams.update({
    'figure.dpi': 150,
    'figure.autolayout': True,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
})

# Repository root, resolved relative to the current working directory
# (presumably the notebook's own folder, one level below the repo root)
REPO_ROOT = Path('../').resolve()
print(f"Repository root: {REPO_ROOT}")

## 2. Define Target Molecule

In [None]:
# 5,6-dihydroyangonin
TARGET_SMILES = "COC1=CC(OC(C=CC2=CC=C(OC)C=C2)C1)=O"
TARGET_NAME = "5,6-dihydroyangonin"

# Validate SMILES: MolFromSmiles returns None (it does not raise) on unparsable input
target_mol = Chem.MolFromSmiles(TARGET_SMILES)
if target_mol is None:
    raise ValueError(f"Invalid SMILES: {TARGET_SMILES}")

print(f"Target: {TARGET_NAME}")
print(f"SMILES: {TARGET_SMILES}")
# Use the explicitly imported descriptor modules: `from rdkit import Chem` alone does
# not guarantee that `Chem.Descriptors` / `Chem.rdMolDescriptors` are loaded, so the
# attribute-style access only worked if some other import had loaded them first.
print(f"Molecular formula: {rdMolDescriptors.CalcMolFormula(target_mol)}")
print(f"Molecular weight: {Descriptors.MolWt(target_mol):.2f} g/mol")

# Display molecule (last expression -> rendered as the cell's image output)
Draw.MolToImage(target_mol, size=(300, 300))

## 3. Configure Data Files

In [None]:
# Cofactors files (compounds to exclude from results)
cofactors_files = [
    REPO_ROOT / "data/raw/all_cofactors.csv",
    REPO_ROOT / "data/raw/chemistry_helpers.csv",
]

# PKS library
pks_library_file = REPO_ROOT / "data/processed/expanded_PKS_SMILES_V3.txt"

# Sink compounds (terminal building blocks)
sink_compounds_files = [
    REPO_ROOT / "data/processed/biological_building_blocks.txt",
    REPO_ROOT / "data/processed/chemical_building_blocks.txt",
]

# Prohibited chemicals
prohibited_chemicals_file = REPO_ROOT / "data/processed/prohibited_chemical_SMILES.txt"

# Verify files exist. Report every file's status, then fail fast if any are
# missing, rather than discovering the problem partway through (or silently
# skewing) the long benchmark runs below.
print("Verifying data files:")
missing_files = []
for data_file in cofactors_files + [pks_library_file] + sink_compounds_files + [prohibited_chemicals_file]:
    exists = data_file.exists()
    status = "OK" if exists else "MISSING"
    print(f"  [{status}] {data_file.name}")
    if not exists:
        missing_files.append(data_file)

if missing_files:
    raise FileNotFoundError(
        "Missing required data file(s):\n"
        + "\n".join(f"  {path}" for path in missing_files)
    )

## 4. Configure MCTS Parameters

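Both runs use identical MCTS parameters (built by the same `create_common_kwargs()` helper) so that runtime differences reflect the execution strategy (sequential vs. async expansion) rather than the search configuration.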

In [None]:
# MCTS configuration shared verbatim by the sequential and async runs
TOTAL_ITERATIONS = 100                          # passed to the agents as total_iterations
MAX_DEPTH = 3                                   # passed as max_depth
MAX_CHILDREN_PER_EXPAND = 30                    # passed as max_children_per_expand
CHILD_DOWNSELECTION_STRATEGY = "most_thermo_feasible"
SELECTION_POLICY = "depth_biased"
DEPTH_BONUS_COEFFICIENT = 4.0                   # used together with the depth_biased policy

# RetroTide parameters (for PKS fragment verification)
RETROTIDE_KWARGS = dict(
    max_depth=5,
    total_iterations=50,
    maxPKSDesignsRetroTide=500,
)

print("MCTS Configuration:")
for setting_label, setting_value in (
    ("Total iterations", TOTAL_ITERATIONS),
    ("Max depth", MAX_DEPTH),
    ("Max children per expand", MAX_CHILDREN_PER_EXPAND),
    ("Child downselection", CHILD_DOWNSELECTION_STRATEGY),
    ("Selection policy", SELECTION_POLICY),
    ("Depth bonus coefficient", DEPTH_BONUS_COEFFICIENT),
):
    print(f"  {setting_label}: {setting_value}")

## 5. Define Common Agent Configuration

In [None]:
def create_common_kwargs():
    """Build the keyword arguments shared by the sequential and async agents.

    New rollout and reward policy objects are constructed on every call, so the
    two agents never share policy instances. Every other value is read from the
    module-level configuration defined in the earlier cells.

    Returns:
        dict: keyword arguments to pass (with ``**``) to ``DORAnetMCTS`` or
        ``AsyncExpansionDORAnetMCTS`` alongside ``root``.
    """
    # Policies are built first, one fresh instance of each per call
    rollout = SpawnRetroTideOnDatabaseCheck(success_reward=1.0, failure_reward=0.0)
    reward = SAScore_and_TerminalRewardPolicy(sink_terminal_reward=1.0, pks_terminal_reward=1.0)

    # File paths are handed to the agents as plain strings rather than Path objects
    return dict(
        target_molecule=target_mol,
        total_iterations=TOTAL_ITERATIONS,
        max_depth=MAX_DEPTH,
        max_children_per_expand=MAX_CHILDREN_PER_EXPAND,
        child_downselection_strategy=CHILD_DOWNSELECTION_STRATEGY,
        use_enzymatic=True,
        use_synthetic=True,
        generations_per_expand=1,
        cofactors_files=list(map(str, cofactors_files)),
        pks_library_file=str(pks_library_file),
        sink_compounds_files=list(map(str, sink_compounds_files)),
        prohibited_chemicals_file=str(prohibited_chemicals_file),
        rollout_policy=rollout,
        reward_policy=reward,
        spawn_retrotide=True,
        retrotide_kwargs=RETROTIDE_KWARGS,
        sink_terminal_reward=1.0,
        selection_policy=SELECTION_POLICY,
        depth_bonus_coefficient=DEPTH_BONUS_COEFFICIENT,
        stop_on_first_pathway=False,
    )

print("Common configuration function defined.")

## 6. Run Sequential DORAnetMCTS

In [None]:
# Reset the SMILES cache before the run so it cannot reuse previously cached work
clear_smiles_cache()

# Fresh root node and sequential agent, both built from the shared configuration
root_seq = Node(fragment=target_mol, parent=None, depth=0, provenance="target")
seq_agent = DORAnetMCTS(root=root_seq, **create_common_kwargs())

print(f"Running Sequential DORAnetMCTS for {TARGET_NAME}...")
print(f"  Iterations: {TOTAL_ITERATIONS}")
print()

# Only the search itself is timed; agent construction above is excluded
seq_start = time.perf_counter()
seq_agent.run()
seq_runtime = time.perf_counter() - seq_start

# Terminal results reported by the agent
seq_sink_compounds = seq_agent.get_sink_compounds()
seq_pks_matches = seq_agent.get_pks_matches()

print("\n".join([
    "\nSequential MCTS Results:",
    f"  Runtime: {seq_runtime:.2f} seconds",
    f"  Total nodes: {len(seq_agent.nodes)}",
    f"  Sink compounds found: {len(seq_sink_compounds)}",
    f"  PKS matches found: {len(seq_pks_matches)}",
    f"  Terminal nodes: {len(seq_sink_compounds) + len(seq_pks_matches)}",
]))

## 7. Run Async DORAnetMCTS

In [None]:
# Reset the SMILES cache so this run cannot reuse work cached by the sequential run.
# NOTE(review): this run executes second, so any one-time warm-up cost that
# clear_smiles_cache() does not reset is paid only by the sequential run — consider
# a warm-up run or alternating the order if the speedup looks suspicious.
clear_smiles_cache()

# Create root node (fresh for async)
root_async = Node(fragment=target_mol, parent=None, depth=0, provenance="target")

# Worker count: one fewer than the CPU count (presumably to leave a core for the
# coordinating process), but never below 1 — on a single-core machine
# cpu_count() - 1 would be 0. os.cpu_count() may return None; fall back to 4.
cpu_total = os.cpu_count()
num_workers = max(1, cpu_total - 1) if cpu_total else 4

# Create async agent
async_agent = AsyncExpansionDORAnetMCTS(
    root=root_async,
    num_workers=num_workers,
    max_inflight_expansions=num_workers,
    **create_common_kwargs()
)

print(f"Running Async DORAnetMCTS for {TARGET_NAME}...")
print(f"  Iterations: {TOTAL_ITERATIONS}")
print(f"  Workers: {num_workers}")
print()

# Only the search itself is timed; agent construction above is excluded
async_start = time.perf_counter()
async_agent.run()
async_runtime = time.perf_counter() - async_start

# Collect results
async_sink_compounds = async_agent.get_sink_compounds()
async_pks_matches = async_agent.get_pks_matches()

print("\nAsync MCTS Results:")
print(f"  Runtime: {async_runtime:.2f} seconds")
print(f"  Total nodes: {len(async_agent.nodes)}")
print(f"  Sink compounds found: {len(async_sink_compounds)}")
print(f"  PKS matches found: {len(async_pks_matches)}")
print(f"  Terminal nodes: {len(async_sink_compounds) + len(async_pks_matches)}")

## 8. Runtime Comparison Summary

In [None]:
# Speedup > 1 means the async run finished faster than the sequential run.
speedup = seq_runtime / async_runtime if async_runtime > 0 else float('inf')

# Relative band around 1.0x inside which the runs are reported as "similar".
# An exact float comparison (speedup == 1) would essentially never be true.
SPEEDUP_TOLERANCE = 0.05


def nodes_per_second(agent, runtime):
    """Return len(agent.nodes) / runtime, or NaN when runtime is not positive."""
    return len(agent.nodes) / runtime if runtime > 0 else float('nan')


# Create comparison table (runtime and rate are pre-formatted to 2 decimals)
comparison_df = pd.DataFrame({
    'Metric': [
        'Runtime (seconds)',
        'Total nodes',
        'Sink compounds',
        'PKS matches',
        'Terminal nodes',
        'Nodes/second'
    ],
    'Sequential': [
        f"{seq_runtime:.2f}",
        len(seq_agent.nodes),
        len(seq_sink_compounds),
        len(seq_pks_matches),
        len(seq_sink_compounds) + len(seq_pks_matches),
        f"{nodes_per_second(seq_agent, seq_runtime):.2f}"
    ],
    'Async': [
        f"{async_runtime:.2f}",
        len(async_agent.nodes),
        len(async_sink_compounds),
        len(async_pks_matches),
        len(async_sink_compounds) + len(async_pks_matches),
        f"{nodes_per_second(async_agent, async_runtime):.2f}"
    ]
})

print("="*60)
print("RUNTIME COMPARISON SUMMARY")
print("="*60)
print(f"\nTarget: {TARGET_NAME}")
print(f"Iterations: {TOTAL_ITERATIONS}")
print(f"Async workers: {num_workers}")
print()
display(comparison_df)

print(f"\n{'='*60}")
print(f"SPEEDUP: {speedup:.2f}x")
print(f"{'='*60}")

if abs(speedup - 1) <= SPEEDUP_TOLERANCE:
    print(f"\nBoth methods have similar performance (within {SPEEDUP_TOLERANCE:.0%})")
elif speedup > 1:
    print(f"\nAsync is {speedup:.2f}x faster than sequential")
elif speedup > 0:
    print(f"\nSequential is {1/speedup:.2f}x faster than async")
else:
    # Only reachable if the sequential runtime measured as zero
    print("\nSpeedup could not be interpreted (sequential runtime was zero)")

## 9. Visualize Runtime Comparison

In [None]:
def annotate_bar_values(ax, bars, labels):
    """Write each label centred just above the top of its bar (3-point offset)."""
    for bar, label in zip(bars, labels):
        ax.annotate(label,
                    xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha='center', va='bottom', fontsize=11, fontweight='bold')


fig, axes = plt.subplots(1, 2, figsize=(12, 5))
methods = ['Sequential', f'Async\n({num_workers} workers)']
colors = ['#1f77b4', '#ff7f0e']

# Left panel: wall-clock runtime of each run
ax1 = axes[0]
runtimes = [seq_runtime, async_runtime]
bars = ax1.bar(methods, runtimes, color=colors, edgecolor='black', linewidth=1.2)
ax1.set_ylabel('Runtime (seconds)', fontsize=12)
ax1.set_title(f'Runtime Comparison\n{TARGET_NAME}', fontsize=14, fontweight='bold')
annotate_bar_values(ax1, bars, [f'{runtime:.1f}s' for runtime in runtimes])

# Right panel: number of nodes each agent created
ax2 = axes[1]
nodes = [len(seq_agent.nodes), len(async_agent.nodes)]
bars2 = ax2.bar(methods, nodes, color=colors, edgecolor='black', linewidth=1.2)
ax2.set_ylabel('Total Nodes', fontsize=12)
ax2.set_title(f'Nodes Explored\n{TARGET_NAME}', fontsize=14, fontweight='bold')
annotate_bar_values(ax2, bars2, [f'{n}' for n in nodes])

for panel in axes:
    panel.spines['top'].set_visible(False)
    panel.spines['right'].set_visible(False)

fig.tight_layout()

# Save figure under the repository root, consistent with the data paths above
figures_dir = REPO_ROOT / 'figures'
figures_dir.mkdir(parents=True, exist_ok=True)
save_path = figures_dir / 'sequential_vs_async_runtime_comparison.png'
fig.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Figure saved to: {save_path}")

plt.show()

## 10. Throughput Comparison

In [None]:
# Calculate throughput metrics (per wall-clock second of the timed run() call).
seq_throughput = len(seq_agent.nodes) / seq_runtime
async_throughput = len(async_agent.nodes) / async_runtime

# Iteration rates use the configured TOTAL_ITERATIONS; this assumes both agents ran
# every iteration (stop_on_first_pathway=False) — TODO confirm for the async agent.
seq_iter_rate = TOTAL_ITERATIONS / seq_runtime
async_iter_rate = TOTAL_ITERATIONS / async_runtime

fig, ax = plt.subplots(figsize=(8, 6))

metrics = ['Nodes/second', 'Iterations/second']
seq_values = [seq_throughput, seq_iter_rate]
async_values = [async_throughput, async_iter_rate]

x = np.arange(len(metrics))
width = 0.35

bars1 = ax.bar(x - width/2, seq_values, width, label='Sequential', color='#1f77b4', edgecolor='black')
bars2 = ax.bar(x + width/2, async_values, width, label=f'Async ({num_workers} workers)', color='#ff7f0e', edgecolor='black')

ax.set_ylabel('Rate', fontsize=12)
ax.set_title(f'Throughput Comparison\n{TARGET_NAME}', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(metrics)
ax.legend()

# Add value labels. Two decimals match the summary table and keep slow rates
# (e.g. 0.04 iterations/s) from being rendered as "0.0".
for bars in [bars1, bars2]:
    for bar in bars:
        height = bar.get_height()
        ax.annotate(f'{height:.2f}',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha='center', va='bottom', fontsize=10)

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

fig.tight_layout()

# figures_dir is created in the runtime-comparison plotting section
save_path = figures_dir / 'sequential_vs_async_throughput.png'
fig.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"Figure saved to: {save_path}")

plt.show()

## 11. Final Summary

In [None]:
print("="*70)
print("FINAL SUMMARY: Sequential vs. Async DORAnetMCTS")
print("="*70)
print(f"\nTarget molecule: {TARGET_NAME}")
print(f"SMILES: {TARGET_SMILES}")
print(f"\nConfiguration:")
print(f"  - Iterations: {TOTAL_ITERATIONS}")
print(f"  - Max depth: {MAX_DEPTH}")
print(f"  - Max children per expand: {MAX_CHILDREN_PER_EXPAND}")
print(f"  - Selection policy: {SELECTION_POLICY}")
print(f"  - Async workers: {num_workers}")
print(f"\nResults:")
print(f"  Sequential runtime: {seq_runtime:.2f} seconds")
print(f"  Async runtime: {async_runtime:.2f} seconds")
print(f"  Speedup: {speedup:.2f}x")
print(f"\n  Sequential nodes: {len(seq_agent.nodes)}")
print(f"  Async nodes: {len(async_agent.nodes)}")
print(f"\n  Sequential terminal nodes: {len(seq_sink_compounds) + len(seq_pks_matches)}")
print(f"  Async terminal nodes: {len(async_sink_compounds) + len(async_pks_matches)}")
print("="*70)