# Benchmark of hydrogen wavefunction integration

In [1]:
import timeit
from typing import Literal

import numpy as np

from numerov.integrate import integrate_hydrogen

# Configurations mirrored from the pytest parametrization. Each tuple is
# (n, l, Z, direction, use_njit): the same state is benchmarked once with
# the pure-Python path and once with the numba-compiled path.
test_cases: list[tuple[int, int, int, Literal["forward", "backward"], bool]] = [
    (1, 0, 1, "backward", use_njit) for use_njit in (False, True)
]

In [2]:
def run_benchmark(number=10, cases=None):
    """Time ``integrate_hydrogen`` for each configuration in ``cases``.

    Each configuration is executed ``number`` times, one call per timing. The
    mean and (population, ``ddof=0``) standard deviation of the per-call
    wall-clock times are reported.

    Args:
        number: Number of timed calls per configuration. Despite the name, this
            is forwarded to ``timeit.repeat`` as ``repeat``; each individual
            timing covers exactly one call. Must be at least 1.
        cases: Sequence of ``(n, l, Z, direction, use_njit)`` tuples to time.
            Defaults to the module-level ``test_cases``.

    Returns:
        A list with one dict per configuration, in input order, holding the keys
        ``"n"``, ``"l"``, ``"Z"``, ``"direction"``, ``"use_njit"``, ``"time"``
        (mean seconds per call) and ``"std"`` (standard deviation in seconds).

    Raises:
        ValueError: If ``number`` is less than 1 (otherwise the mean/std of an
            empty sample would silently be NaN).
    """
    if number < 1:
        raise ValueError(f"number must be >= 1, got {number}")
    if cases is None:
        cases = test_cases
    cases = list(cases)

    results = []
    if not cases:
        # Nothing to time; also avoids indexing cases[0] for the warm-up below.
        return results

    # Fixed integration parameters, mirroring the test file.
    dx = 1e-3
    xmin = dx
    xmax = 80
    epsilon_u = 1e-10

    # One untimed call with use_njit=True so numba's compilation cost is not
    # charged to the first timed njit run.
    n, l, Z, direction, _ = cases[0]
    # NOTE(review): presumably the exact bound-state energy in the units
    # integrate_hydrogen expects (-Z^2/n^2) — confirm against its docs.
    energy = -(Z**2) / (n**2)
    integrate_hydrogen(energy, Z, n, l, dx, xmin, xmax, direction, epsilon_u, use_njit=True)

    for n, l, Z, direction, use_njit in cases:
        energy = -(Z**2) / (n**2)

        # Time a callable rather than a generated source string. This avoids
        # round-tripping floats through repr(), and passes use_njit by keyword
        # exactly like the warm-up call instead of relying on its position.
        # timeit.repeat consumes the lambda before the loop rebinds its free
        # variables, so late binding is not a concern here.
        times = timeit.repeat(
            lambda: integrate_hydrogen(
                energy, Z, n, l, dx, xmin, xmax, direction, epsilon_u, use_njit=use_njit
            ),
            number=1,
            repeat=number,
        )
        avg_time = np.mean(times)
        std_time = np.std(times)

        results.append(
            {"n": n, "l": l, "Z": Z, "direction": direction, "use_njit": use_njit, "time": avg_time, "std": std_time}
        )

    return results

In [3]:
# Execute the benchmark: every configuration in `test_cases` is timed 10 times.
print("Running benchmarks...")
results = run_benchmark(10)

Running benchmarks...


In [4]:
# Render the collected timings as a fixed-width table. The timeit results are
# in seconds; they are shown here in milliseconds.
header = f"{'n':>3} {'l':>3} {'Z':>3} {'direction':>10} {'use_njit':>10} {'Time (ms)':>10} {'Std (ms)':>10}"
# Size the horizontal rule from the header so it always spans exactly the table
# (a hard-coded width drifts out of sync when columns change).
rule = "-" * len(header)

print("\nBenchmark Results:")
print(rule)
print(header)
print(rule)

for r in results:
    print(
        f"{r['n']:3d} {r['l']:3d} {r['Z']:3d} {r['direction']:>10} {str(r['use_njit']):>10} {r['time'] * 1000:10.2f} "
        f"{r['std'] * 1000:10.2f}"
    )


Benchmark Results:
----------------------------------------------------------------------
  n   l   Z  direction   use_njit  Time (ms)   Std (ms)
----------------------------------------------------------------------
  1   0   1   backward      False      89.51       5.98
  1   0   1   backward       True       0.85       0.03
