<a target="_blank" href="https://colab.research.google.com/github/pywavelet/pywavelet/blob/main/docs/runtime.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>


# Runtime Comparisons

In [1]:
import importlib
import numpy as np
import jax.numpy as jnp
from tqdm.auto import tqdm
from pywavelet.types import FrequencySeries
from pywavelet.transforms.phi_computer import phitilde_vec_norm
from timeit import repeat as timing_repeat



# `import importlib` alone does not guarantee that the `importlib.util`
# submodule is loaded, so import it explicitly before using `find_spec`.
import importlib.util

# True when a `cupy` distribution can be located, without importing it yet.
cupy_available = importlib.util.find_spec("cupy") is not None

if cupy_available:
    # Only bind `cp` when CuPy is installed; code using `cp` must check
    # `cupy_available` first.
    import cupy as cp
    

def generate_freq_domain_signal(
        ND, f0=20.0, dt=0.0125, A=2
) -> FrequencySeries:
    """
    Builds a sinusoid of ND samples and wraps it in a FrequencySeries.

    Parameters:
    ND (int): Number of data points.
    f0 (float): Frequency of the sinusoid. Default is 20.0.
    dt (float): Spacing between consecutive sample points. Default is 0.0125.
    A (float): Amplitude of the sinusoid. Default is 2.

    Returns:
    FrequencySeries: Constructed from the sinusoid samples and the array
    of sample points (``arange(ND) * dt``).

    NOTE(review): the samples are A*sin(2*pi*f0*t) evaluated on a uniform
    grid and no Fourier transform is applied before wrapping them in a
    FrequencySeries -- confirm this is intended (it may not matter for
    pure runtime benchmarks).
    """
    sample_points = np.arange(0, ND) * dt
    sinusoid = A * np.sin(2 * np.pi * f0 * sample_points)
    return FrequencySeries(sinusoid, sample_points)


def generate_func_args(ND, backend="numpy"):
    """
    Assembles the argument tuple handed to a transform being benchmarked.

    Parameters:
    ND (int): Number of data points in the generated signal.
    backend (str): "jax" converts the arrays with ``jnp.array``; "cupy"
        converts them with ``cp.array`` when CuPy is installed. Any other
        value (including "cupy" without CuPy installed) leaves the arrays
        as produced, with no warning.

    Returns:
    tuple: ``(yf, Nf, Nt, phif)`` where ``Nt = int(log2(ND))``,
    ``Nf = ND // Nt``, ``yf`` is the ``.data`` of the generated signal and
    ``phif`` is the result of ``phitilde_vec_norm(Nf, Nt, d=4.0)``.

    NOTE(review): because of the floor division, ``Nf * Nt`` can be less
    than ``ND`` (e.g. ND=32 gives Nt=5, Nf=6) -- confirm the transform
    under test tolerates this. ND=1 gives Nt=0 and raises
    ZeroDivisionError.
    """
    Nt = int(np.log2(ND))
    Nf = ND // Nt
    signal = generate_freq_domain_signal(ND).data
    window = phitilde_vec_norm(Nf, Nt, d=4.0)

    if backend == "jax":
        return jnp.array(signal), Nf, Nt, jnp.array(window)
    if backend == "cupy" and cupy_available:
        return cp.array(signal), Nf, Nt, cp.array(window)
    # Default path: hand back the arrays exactly as they were generated.
    return signal, Nf, Nt, window


def collect_runtime(
        func, func_args, n=5, nreps=5
):
    """
    Times repeated calls of ``func(*func_args)``.

    One untimed warm-up call is made first, then ``timeit.repeat`` collects
    ``nreps`` timing samples of ``n`` calls each.

    If the value returned by ``func`` exposes a callable
    ``block_until_ready`` attribute (as JAX arrays do), it is called inside
    the timed region. Without this, asynchronously dispatched work would be
    timed at dispatch cost only, understating the real runtime.

    Parameters:
    func (callable): Function to benchmark.
    func_args (sequence): Positional arguments unpacked into ``func``.
    n (int): Number of calls per timing sample. Default is 5.
    nreps (int): Number of timing samples. Default is 5.

    Returns:
    tuple: ``(median, standard deviation)`` of the ``nreps`` samples, in
    seconds. Each sample is the total time for ``n`` calls, not the
    per-call time.

    NOTE(review): results without ``block_until_ready`` (e.g. CuPy arrays)
    are not synchronized here -- confirm whether GPU backends need an
    explicit device synchronization.
    """
    def run_once():
        result = func(*func_args)
        # Duck-typed so NumPy (and other eager) results are unaffected.
        wait = getattr(result, "block_until_ready", None)
        if callable(wait):
            wait()

    run_once()  # Warm up run (also absorbs any one-off compilation cost)
    times = timing_repeat(run_once, number=n, repeat=nreps)
    return (np.median(times), np.std(times))



def collect_runtimes(func, backend, ND_values, number=5, repeat=5):
    """
    Benchmarks ``func`` once per entry of ``ND_values``.

    Parameters:
    func (callable): Function to benchmark.
    backend (str): Backend name forwarded to ``generate_func_args``.
    ND_values (iterable of int): Signal sizes to benchmark.
    number (int): Forwarded to ``collect_runtime`` as ``n``. Default is 5.
    repeat (int): Forwarded to ``collect_runtime`` as ``nreps``. Default is 5.

    Returns:
    dict: Maps each ND to the tuple returned by ``collect_runtime``.
    """
    timings = {}
    progress = tqdm(ND_values, desc="Running")
    for size in progress:
        # Show the current size as a power of two next to the progress bar.
        progress.set_postfix(ND=f"2**{int(np.log2(size))}")
        timings[size] = collect_runtime(
            func, generate_func_args(size, backend), number, repeat
        )
    return timings





  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from pywavelet.transforms.numpy.forward.from_freq import (
    transform_wavelet_freq_helper,
)

# Benchmark sizes: the powers of two 2**4 through 2**14 (11 values).
ND = [1 << exponent for exponent in range(4, 15)]

# Time the NumPy implementation of the transform helper at every size.
numpy_runtimes = collect_runtimes(
    func=transform_wavelet_freq_helper,
    backend="numpy",
    ND_values=ND,
    number=5,
    repeat=5,
)


Running:   9%|▉         | 1/11 [00:03<00:39,  3.93s/it, ND=2**6]

In [None]:
# Rebinds `transform_wavelet_freq_helper` to the JAX implementation.
from pywavelet.transforms.jax.forward.from_freq import (
    transform_wavelet_freq_helper,
)

# Time the JAX implementation over the size list `ND` defined earlier in
# the notebook.
jax_runtimes = collect_runtimes(
    func=transform_wavelet_freq_helper,
    backend="jax",
    ND_values=ND,
    number=5,
    repeat=5,
)


Running:   0%|          | 0/13 [00:00<?, ?it/s, ND=2**2]INFO:2025-03-24 16:31:45,464:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'


INFO:2025-03-24 16:31:45,471:jax._src.xla_bridge:927: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/avaj0001/miniforge3/envs/pywavelet/bin/../lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
