# predTED Speed Benchmark

Comparison of three methods for computing RNA structure distances:

1. **predted** — LightGBM model prediction in Python (approximate tree edit distance, TED)
2. **RNA.tree_edit_distance()** — ViennaRNA Python API (exact TED)
3. **RNAdistance CLI** — ViennaRNA command-line tool (exact TED)

In [None]:
import time
import subprocess
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import RNA

import predted

## Load test structures

In [None]:
structures_file = Path("..") / "data" / "structures.txt"
# One structure per line. Blank lines (e.g. a stray empty line mid-file) are
# skipped so they are never benchmarked as empty structures.
all_structures = [
    line.strip()
    for line in structures_file.read_text(encoding="utf-8").splitlines()
    if line.strip()
]
# Later cells index all_structures[0] and all_structures[1] directly.
if len(all_structures) < 2:
    raise ValueError(
        f"{structures_file} must contain at least 2 structures, found {len(all_structures)}"
    )
print(f"Loaded {len(all_structures)} structures")
print(f"Length range: {min(len(s) for s in all_structures)} – {max(len(s) for s in all_structures)}")

## Helper functions

In [None]:
def bench_predted_matrix(structures: list[str]) -> float:
    """Time one ``predted.predict_matrix`` call over *structures*.

    The predicted matrix itself is discarded; only the call is timed.

    Returns:
        Elapsed wall-clock time in seconds, measured with ``time.perf_counter``.
    """
    t_begin = time.perf_counter()
    predted.predict_matrix(structures)
    t_end = time.perf_counter()
    return t_end - t_begin


def bench_rna_ted_python(structures: list[str]) -> float:
    """Time ViennaRNA's Python tree edit distance over every unordered pair.

    The dot-bracket -> expanded tree-string conversion happens before the
    clock starts. Each pair still builds (``RNA.make_tree``) and frees
    (``RNA.free_tree``) its two trees inside the timed region, so that cost
    is part of the measurement.

    Returns:
        Elapsed wall-clock time in seconds.
    """
    count = len(structures)
    # Excluded from the measurement: one tree string per structure.
    expanded = [RNA.db_to_tree_string(db, RNA.STRUCTURE_TREE_EXPANDED) for db in structures]
    t_begin = time.perf_counter()
    for a in range(count):
        for b in range(a + 1, count):
            tree_a = RNA.make_tree(expanded[a])
            tree_b = RNA.make_tree(expanded[b])
            RNA.tree_edit_distance(tree_a, tree_b)
            RNA.free_tree(tree_a)
            RNA.free_tree(tree_b)
    return time.perf_counter() - t_begin


def bench_rnadistance_cli(structures: list[str]) -> float:
    """Time one ``RNAdistance -Df -Xm`` run (matrix mode) over *structures*.

    All structures are piped to the CLI on stdin, one per line, and the whole
    distance matrix is computed by a single process. Process start-up is
    therefore included in the measured time; the CLI's output is captured
    and discarded.

    Returns:
        Elapsed wall-clock seconds, or ``float("nan")`` if the executable
        cannot be found or exits with a non-zero status. A failed run must
        not be reported as a (misleadingly fast) timing. NaN propagates
        through the speedup table and is skipped when plotting.
    """
    input_text = "\n".join(structures) + "\n"
    start = time.perf_counter()
    try:
        result = subprocess.run(
            ["RNAdistance", "-Df", "-Xm"],
            input=input_text,
            capture_output=True,
            text=True,
            check=False,  # non-zero exit is handled explicitly below
        )
    except FileNotFoundError:
        # Keep the remaining benchmarks running when ViennaRNA's CLI is absent.
        print("RNAdistance error: executable not found on PATH")
        return float("nan")
    elapsed = time.perf_counter() - start
    if result.returncode != 0:
        print(f"RNAdistance error: {result.stderr[:200]}")
        return float("nan")
    return elapsed

## Run benchmarks

In [None]:
test_sizes = [10, 25, 50, 100, 250, 500]

# Warm up predted (first call loads the LightGBM model) so loading is not
# charged to the first benchmarked size.
_ = predted.predict(all_structures[0], all_structures[1])

results = []

for n in test_sizes:
    # A size larger than the dataset would silently benchmark fewer structures
    # (slicing truncates), while n_pairs and the per-pair costs would still be
    # computed from the requested N. Skip such sizes instead. N < 2 has no
    # pairs and would divide by zero below.
    if n < 2 or n > len(all_structures):
        print(f"\n--- N = {n}: skipped ({len(all_structures)} structures available) ---")
        continue

    subset = all_structures[:n]
    n_pairs = n * (n - 1) // 2
    print(f"\n--- N = {n} ({n_pairs:,} pairs) ---")

    # predted
    t_predted = bench_predted_matrix(subset)
    print(f"  predted:              {t_predted:.3f}s  ({t_predted / n_pairs * 1000:.3f} ms/pair)")

    # RNA Python API
    t_rna_py = bench_rna_ted_python(subset)
    print(f"  RNA.tree_edit_dist(): {t_rna_py:.3f}s  ({t_rna_py / n_pairs * 1000:.3f} ms/pair)")

    # RNAdistance CLI
    t_cli = bench_rnadistance_cli(subset)
    print(f"  RNAdistance CLI:      {t_cli:.3f}s  ({t_cli / n_pairs * 1000:.3f} ms/pair)")

    results.append({
        "n": n,
        "n_pairs": n_pairs,
        "predted_s": t_predted,
        "rna_python_s": t_rna_py,
        "rnadistance_cli_s": t_cli,
    })

## Results table

In [None]:
def _speedup(reference_s: float, predted_s: float) -> float:
    """Return reference time / predted time, or inf when predted took no time."""
    return reference_s / predted_s if predted_s > 0 else float('inf')


print(f"{'N':>5} {'Pairs':>8} │ {'predted':>9} {'RNA Py':>9} {'CLI':>9} │ {'Speedup vs RNA':>15} {'Speedup vs CLI':>15}")
print("─" * 85)
for row in results:
    speedup_rna = _speedup(row['rna_python_s'], row['predted_s'])
    speedup_cli = _speedup(row['rnadistance_cli_s'], row['predted_s'])
    timings = f"{row['predted_s']:>8.3f}s {row['rna_python_s']:>8.3f}s {row['rnadistance_cli_s']:>8.3f}s"
    print(f"{row['n']:>5} {row['n_pairs']:>8,} │ {timings} │ {speedup_rna:>14.1f}x {speedup_cli:>14.1f}x")

## Visualisation

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_total, ax_per_pair = axes

ns = [r["n"] for r in results]
pairs = [r["n_pairs"] for r in results]
t_predted = [r["predted_s"] for r in results]
t_rna = [r["rna_python_s"] for r in results]
t_cli = [r["rnadistance_cli_s"] for r in results]

# (marker, legend label on the left panel, legend label on the right panel, seconds).
# Order matters: it fixes each method's colour via the per-axes colour cycle.
series = [
    ("o-", "predted (LightGBM)", "predted", t_predted),
    ("s-", "RNA.tree_edit_distance()", "RNA.tree_edit_distance()", t_rna),
    ("^-", "RNAdistance CLI", "RNAdistance CLI", t_cli),
]

# Left panel: absolute time; right panel: time per pair in milliseconds.
for marker, total_label, pair_label, seconds in series:
    ax_total.plot(ns, seconds, marker, label=total_label, linewidth=2, markersize=6)
    per_pair_ms = [s / p * 1000 for s, p in zip(seconds, pairs)]
    ax_per_pair.plot(ns, per_pair_ms, marker, label=pair_label, linewidth=2, markersize=6)

panel_text = [
    (ax_total, "Wall-clock time (s)", "Pairwise distance computation time"),
    (ax_per_pair, "Time per pair (ms)", "Per-pair computation cost"),
]
for ax, ylabel, title in panel_text:
    ax.set_xlabel("Number of structures (N)")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.set_yscale("log")
    ax.grid(True, alpha=0.3)

plt.tight_layout()
for output_path, extra in (("speed_benchmark.pdf", {}), ("speed_benchmark.png", {"dpi": 150})):
    plt.savefig(output_path, bbox_inches="tight", **extra)
plt.show()
print("Saved: speed_benchmark.pdf, speed_benchmark.png")

## Single-pair timing (median over 1000 repeats)

In [None]:
s1 = all_structures[0]
s2 = all_structures[1]
n_repeats = 1000

# predted single pair
times_predted = []
for _ in range(n_repeats):
    t0 = time.perf_counter()
    predted.predict(s1, s2)
    times_predted.append(time.perf_counter() - t0)

# RNA Python single pair. Tree-string conversion is done once up front; tree
# construction/freeing stays inside the timed region, as in bench_rna_ted_python.
t1_str = RNA.db_to_tree_string(s1, RNA.STRUCTURE_TREE_EXPANDED)
t2_str = RNA.db_to_tree_string(s2, RNA.STRUCTURE_TREE_EXPANDED)
times_rna = []
for _ in range(n_repeats):
    t0 = time.perf_counter()
    t1 = RNA.make_tree(t1_str)
    t2 = RNA.make_tree(t2_str)
    RNA.tree_edit_distance(t1, t2)
    RNA.free_tree(t1)
    RNA.free_tree(t2)
    times_rna.append(time.perf_counter() - t0)

med_predted = np.median(times_predted) * 1000
med_rna = np.median(times_rna) * 1000

# Always express the ratio as slower-time / faster-time (>= 1). The previous
# form printed med_predted / med_rna even when predted was faster, yielding
# nonsense such as "0.2x faster".
if med_predted > med_rna:
    ratio, verdict = med_predted / med_rna, "slower"
else:
    ratio, verdict = med_rna / med_predted, "faster"

print(f"Single-pair timing (median of {n_repeats} repeats):")
print(f"  predted:              {med_predted:.3f} ms")
print(f"  RNA.tree_edit_dist(): {med_rna:.3f} ms")
print(f"  Ratio:                {ratio:.1f}x {verdict}")