<a target="_blank" href="https://colab.research.google.com/github/pywavelet/pywavelet/blob/main/docs/runtime.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>


# Runtime Comparisons

In [None]:
import importlib
import numpy as np
import jax.numpy as jnp
import pandas as pd 
import jax
from pywavelet.backend import cuda_available
from tqdm.auto import tqdm
from pywavelet.types import FrequencySeries
from pywavelet.transforms.phi_computer import phitilde_vec_norm
from timeit import repeat as timing_repeat
import matplotlib.pyplot as plt

jax.config.update("jax_enable_x64", False)
    
JAX_DEVICE = jax.default_backend()
JAX_PRECISION = "x64" if jax.config.jax_enable_x64 else "x32"


if cuda_available:
    import cupy as cp


def generate_freq_domain_signal(
        ND, f0=20.0, dt=0.0125, A=2
) -> FrequencySeries:
    """
    Generates a frequency domain signal.

    Parameters:
    ND (int): Number of data points.
    f0 (float): Frequency of the signal. Default is 20.0.
    dt (float): Time step. Default is 0.0125.
    A (float): Amplitude of the signal. Default is 2.

    Returns:
    FrequencySeries: The generated frequency domain signal.
    """
    ts = np.arange(0, ND) * dt
    y = A * np.sin(2 * np.pi * f0 * ts)
    yf = FrequencySeries(y, ts)
    return yf


def generate_func_args(ND, backend="numpy"):
    Nf = Nt = int(np.sqrt(ND))
    yf = generate_freq_domain_signal(ND).data
    phif = phitilde_vec_norm(Nf, Nt, d=4.0)
    if backend == "jax":
        yf = jnp.array(yf)
        phif = jnp.array(phif)
    if backend == "cupy" and cuda_available:
        yf = cp.array(yf)
        phif = cp.array(phif)
    return yf, Nf, Nt, phif


def collect_runtime(
        func, func_args, n=5, nreps=5
):
    func(*func_args)  # Warm up run
    times = timing_repeat(
        lambda: func(*func_args),
        number=n,
        repeat=nreps
    )

    return (np.median(times), (np.std(times)))


def collect_runtimes(func, backend, NF_values, number=5, repeat=5):
    results = {}
    bar = tqdm(NF_values, desc="Running")
    for Nf in bar:
        ND = Nf * Nf
        bar.set_postfix(ND=f"2**{int(np.log2(ND))}")
        func_args = generate_func_args(ND, backend)
        try:
            results[ND] = collect_runtime(func, func_args, number, repeat)
        except Exception as e:
            print(f"Error processing ND={ND}: {e}")
            results[ND] = (np.nan, np.nan)
    return results


def run_transforms():
    from pywavelet.transforms.jax.forward.from_freq import transform_wavelet_freq_helper as jax_transform
    from pywavelet.transforms.numpy.forward.from_freq import transform_wavelet_freq_helper as np_transform

    min_pow2 = 2
    max_pow2 = 12
    NF = [2 ** i for i in range(min_pow2, max_pow2)]

    runtimes = {}
    runtimes["numpy"] = collect_runtimes(np_transform, "numpy", NF, number=5, repeat=5)
    
    max_pow2 = 15
    NF = [2 ** i for i in range(min_pow2, max_pow2)]
    runtimes["jax"] = collect_runtimes(jax_transform, "jax", NF, number=5, repeat=5)

    if cuda_available:
        from pywavelet.transforms.cupy.forward.from_freq import transform_wavelet_freq_helper as cp_transform
        runtimes['cupy'] = collect_runtimes(cp_transform, "cupy", NF, number=10, repeat=10)
    
    
    return runtimes


def plot(runtimes):
    fig, ax = plt.subplots(figsize=(4, 3.5))
    for i, backend in enumerate(runtimes.keys()):
        _plot_backend_runtime(ax, runtimes[backend], backend, f"C{i}")
    ax.set_yscale("log")
    ax.set_xscale("log")
    ax.set_xlabel("Number of Data Points")
    ax.set_ylabel("Runtime (s)")
    ax.legend(frameon=False)
    return fig, ax


def _plot_backend_runtime(ax, runtimes, backend, color):
    NDs = list(runtimes.keys())
    times = [runtimes[ND][0] for ND in NDs]
    stds = [runtimes[ND][1] for ND in NDs]
    # plot a band around the median runtime
    ax.fill_between(NDs, np.array(times) - np.array(stds), np.array(times) + np.array(stds), alpha=0.3, color=color)
    ax.plot(NDs, times, label=f"{backend}", color=color)
    ax.set_xlim(min(NDs), max(NDs))


runtimes = run_transforms()
# save runtime data as a txt file
runtime_data = pd.DataFrame(runtimes)
runtime_data.to_csv(f"runtime_data_{JAX_DEVICE}.csv")
    

fig, ax = plot(runtimes)
fig.savefig("runtime.png", bbox_inches="tight")



In [24]:
nan = np.nan


gpu_runtimes = {'numpy': {16: (np.float64(2.7647000024444424e-05),
   np.float64(1.6699892612917253e-05)),
  64: (np.float64(4.7878000032142154e-05), np.float64(7.901239258106135e-06)),
  256: (np.float64(0.00014417199997751595),
   np.float64(2.8245273864177316e-05)),
  1024: (np.float64(0.00025147099995592725),
   np.float64(8.368707136149373e-06)),
  4096: (np.float64(0.0008757960000593812),
   np.float64(0.00039357430227278433)),
  16384: (np.float64(0.0017043970000258923),
   np.float64(0.0016274004399110828)),
  65536: (np.float64(0.006220592999966357),
   np.float64(0.00021847848737320755)),
  262144: (np.float64(0.023843203999945217),
   np.float64(0.0008785219521943771)),
  1048576: (np.float64(0.11971158300002571), np.float64(0.02209812945217875)),
  4194304: (np.float64(0.7133738380000523), np.float64(0.006134709481659322))},
 'jax': {16: (np.float64(0.0005724070000496795),
   np.float64(0.0001069087688290125)),
  64: (np.float64(0.0007151319999820771), np.float64(0.0001322757281663352)),
  256: (np.float64(0.0009230550000438598), np.float64(0.00031332690291921287)),
  1024: (np.float64(0.001453614000070047), np.float64(0.00012679139424215402)),
  4096: (np.float64(0.00050203600005716), np.float64(8.404927733069893e-05)),
  16384: (np.float64(0.0007342449999896417), np.float64(6.60918000741e-05)),
  65536: (np.float64(0.0009854250000671527),
   np.float64(9.820818468604813e-05)),
  262144: (np.float64(0.0005175009999902613),
   np.float64(4.419451355415323e-05)),
  1048576: (np.float64(0.0003900709999697938),
   np.float64(7.364248018965731e-05)),
  4194304: (np.float64(0.0004053449999901204),
   np.float64(7.663108380403174e-05)),
  16777216: (np.float64(0.00039541899991490936),
   np.float64(6.804382786659215e-05)),
  67108864: (np.float64(0.00040576800006419944),
   np.float64(0.00010823411767385025)),
  268435456: (np.float64(0.00041394100003344647),
   np.float64(0.00010930952627993826))},
 'cupy': {16: (np.float64(0.019849124999950618),
   np.float64(0.0012326430575985117)),
  64: (np.float64(0.022546809499999654), np.float64(0.001692504727864258)),
  256: (np.float64(0.021691686499991647), np.float64(0.00294086495457803)),
  1024: (np.float64(0.022270052500005022), np.float64(0.003807071679929202)),
  4096: (np.float64(0.02085154750000129), np.float64(0.0010313077174266435)),
  16384: (np.float64(0.02093089249996183), np.float64(0.0010492676283390795)),
  65536: (np.float64(0.022014402500019514), np.float64(0.0006802123910234058)),
  262144: (np.float64(0.024796750999996675),
   np.float64(0.0057645733683490815)),
  1048576: (np.float64(0.02190001700000721),
   np.float64(0.0008732466576075845)),
  4194304: (np.float64(0.05879033200005779), np.float64(0.014904883472313363)),
  16777216: (np.float64(0.23404195400001981), np.float64(0.07533314482865239)),
  67108864: (nan, nan),
  268435456: (nan, nan)}}

cpu_runtimes = {'numpy': {16: (np.float64(9.805700028664432e-05),
   np.float64(0.0002719699903221544)),
  64: (np.float64(8.721999984118156e-05), np.float64(3.582454577267885e-05)),
  256: (np.float64(0.00017828199997893535),
   np.float64(1.2305009703125784e-05)),
  1024: (np.float64(0.00036261100103729405),
   np.float64(2.4494629976772865e-05)),
  4096: (np.float64(0.0006539139994856669), np.float64(0.0008354864099797487)),
  16384: (np.float64(0.002017660999626969), np.float64(0.0005277675716976345)),
  65536: (np.float64(0.007982569000887452), np.float64(0.0007685313919699793)),
  262144: (np.float64(0.03098107999903732),
   np.float64(0.00045184700913424144)),
  1048576: (np.float64(0.12877668900000572), np.float64(0.002616350943173567)),
  4194304: (np.float64(0.719053980001263), np.float64(0.0469158455496374))},
 'jax': {16: (np.float64(0.00023953800155140925),
   np.float64(0.00015563753017088007)),
  64: (np.float64(0.00035316300090926234), np.float64(0.0002318705268616366)),
  256: (np.float64(0.00029555400033132173), np.float64(8.832475279764218e-05)),
  1024: (np.float64(0.0004884099998889724),
   np.float64(0.00024645926639514133)),
  4096: (np.float64(0.0004465590009203879),
   np.float64(0.00019698850450255553)),
  16384: (np.float64(0.0005183960001886589),
   np.float64(0.00025487691910374257)),
  65536: (np.float64(0.00037294600042514503),
   np.float64(0.0008727913545748731)),
  262144: (np.float64(0.0033758919998945203),
   np.float64(0.002688101382975047)),
  1048576: (np.float64(0.019741617999898153),
   np.float64(0.010757652737059507)),
  4194304: (np.float64(0.10961731399947894), np.float64(0.05839073400005447))},
 'tpu_jax': {16: (np.float64(0.0001429589999588643),
   np.float64(3.5459720815065046e-05)),
  64: (np.float64(0.00017167299995435314), np.float64(4.2662062404195546e-05)),
  256: (np.float64(0.00014238400001431728),
   np.float64(2.8613417994725626e-05)),
  1024: (np.float64(0.00016977199993561953),
   np.float64(0.0002816987455046532)),
  4096: (np.float64(0.00015566100000796723),
   np.float64(4.2144860721421086e-05)),
  16384: (np.float64(0.00015526099991802766),
   np.float64(3.326513666243062e-05)),
  65536: (np.float64(0.00011258099993938231),
   np.float64(5.987320498690532e-05)),
  262144: (np.float64(0.00011269100002664345),
   np.float64(3.9259844590263475e-05)),
  1048576: (np.float64(0.00012115300000914431),
   np.float64(4.568541504776411e-05)),
  4194304: (np.float64(0.00013279199993121438),
   np.float64(3.920358955286624e-05)),
  16777216: (np.float64(0.00012158700008058076),
   np.float64(2.9125096906138658e-05)),
  67108864: (np.float64(0.0002409089998991476), np.float64(2.604900528570504)),
  268435456: (nan, nan)}}

