In [1]:
import os
# set visible GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

In [2]:
import time
import numpy as np
import math
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

# ------------------------------
# Simple random tree data generation (can be replaced with your custom make_tree_data)
# Generates a rooted tree with N nodes (numbered 0..N-1) where root = 0, structured by levels
# ------------------------------
def make_tree_data(N: int, seed: int = 0):
    """
    Generate a tree with N nodes; root node is 0.
    
    Returns:
      edges:       (E, 2) int32 array, where each row is (child, parent); E = N-1.
                   Edges are stored consecutively by "depth from largest to smallest" (leaf -> root)
      diags:       (N,) float32 array, all values are positive
      solves:      (N,) float32 array
      uppers:      (N,) float32 array
      lowers:      (N,) float32 array
      level_start: (L,) int32 array, starting index of each level in the edges array
      level_size:  (L,) int32 array, number of edges in each level
    """
    assert N >= 2, "Number of nodes N must be at least 2"
    rng = np.random.default_rng(seed)

    # --- 1) Randomly generate parent nodes (ensure parent index < child index)
    parents = np.full(N, -1, dtype=np.int32)  # Parent of each node (initialized to -1)
    depths = np.zeros(N, dtype=np.int32)      # Depth of each node (root depth = 0)
    
    for c in range(1, N):
        # Randomly select parent from existing nodes (0 to c-1)
        p = rng.integers(0, c, endpoint=False)
        parents[c] = p
        depths[c] = depths[p] + 1  # Child depth = parent depth + 1

    max_depth = int(depths.max())  # Maximum depth of the tree

    # --- 2) Group children by their depth, generate edges and level slices (leaf -> root)
    # Initialize buckets: one bucket per depth (stores child nodes)
    buckets = [[] for _ in range(max_depth + 1)]
    for c in range(1, N):
        buckets[depths[c]].append(c)

    edges_list = []       # List to collect (child, parent) edge pairs
    level_start = []      # Starting index of each level in edges_list
    level_size = []       # Number of edges in each level
    e_count = 0           # Current count of edges
    
    # Iterate from maximum depth down to 1 (root at depth 0 has no parent edges)
    for d in range(max_depth, 0, -1):
        if not buckets[d]:  # Skip if no nodes at current depth
            continue
        
        s = e_count  # Starting index of current level
        # Add all (child, parent) pairs for nodes at current depth
        for c in buckets[d]:
            edges_list.append((c, parents[c]))
            e_count += 1
        
        level_start.append(s)
        level_size.append(e_count - s)  # Size = current edge count - start index

    # Convert lists to numpy arrays
    edges = np.asarray(edges_list, dtype=np.int32)
    level_start = np.asarray(level_start, dtype=np.int32)
    level_size = np.asarray(level_size, dtype=np.int32)

    # --- 3) Generate numerical values (within normal ranges)
    # Root node's upper/lower values are unused, set to 0
    diags = rng.uniform(2.0, 8.0, size=N).astype(np.float32)    # Diagonal values (positive)
    solves = rng.uniform(0.5, 1.5, size=N).astype(np.float32)   # Solution-related values
    uppers = rng.uniform(0.2, 1.2, size=N).astype(np.float32)   # Upper coefficients
    lowers = rng.uniform(0.05, 0.6, size=N).astype(np.float32)  # Lower coefficients
    uppers[0] = 0.0  # Unused for root
    lowers[0] = 0.0  # Unused for root

    # Optional: Add small safety margin to avoid extreme ratios
    # diags = diags + 1e-3

    # --- 4) Convert numpy arrays to JAX arrays
    return (
        jnp.array(edges, dtype=jnp.int32),
        jnp.array(diags, dtype=jnp.float32),
        jnp.array(solves, dtype=jnp.float32),
        jnp.array(uppers, dtype=jnp.float32),
        jnp.array(lowers, dtype=jnp.float32),
        jnp.array(level_start, dtype=jnp.int32),
        jnp.array(level_size, dtype=jnp.int32),
    )


In [3]:
# ------------------------------
# Method A: Single kernel with internal for-loop across levels (levels-in-kernel)
# ------------------------------
def kernel_levels(
    edges_ref,
    diags_ref,
    solves_ref,
    uppers_ref,
    lowers_ref,
    level_start_ref,
    level_size_ref,
    out_diags_ref,
    out_solves_ref,
):
    tid = pl.program_id(0)  # Thread ID: ranges from 0 to (Wmax - 1)
    num_levels = level_size_ref.shape[0]  # Total number of levels in the tree

    # Iterate across all levels of the tree
    for d in range(num_levels):
        # Load start index and size of current level from reference arrays
        start = pl.load(level_start_ref, d)
        size = pl.load(level_size_ref, d)

        # Check if current thread is active (within the edge count of current level)
        active = tid < size
        e = start + tid  # Current edge index in the edges array

        # Load child and parent indices (use 0/-1 as default if thread is inactive)
        c = pl.load(edges_ref, (e, 0), mask=active, other=0)
        p = pl.load(edges_ref, (e, 1), mask=active, other=-1)
        valid = active & (p >= 0)  # Valid only if thread is active and parent is valid

        # Load coefficients (use default values if invalid)
        up = pl.load(uppers_ref, c, mask=valid, other=0.0)
        lo = pl.load(lowers_ref, c, mask=valid, other=0.0)
        dc = pl.load(out_diags_ref, c, mask=valid, other=1.0)
        sc = pl.load(out_solves_ref, c, mask=valid, other=0.0)

        # Calculate multiplier and update parent nodes using atomic add (avoid race conditions)
        mul = up / dc
        pl.atomic_add(out_diags_ref, p, -lo * mul, mask=valid)
        pl.atomic_add(out_solves_ref, p, -sc * mul, mask=valid)

from functools import partial
@partial(jax.jit, static_argnames=['Wmax'])
# @jax.jit  # Uncomment this line to disable static Wmax (use only if necessary)
def run_levels_single_kernel(edges, diags, solves, uppers, lowers, level_start, level_size, Wmax):
    def _call(edges, diags, solves, uppers, lowers, level_start, level_size):
        # Call Pallas kernel for single-kernel level processing
        return pl.pallas_call(
            kernel_levels,
            # Define output shape and dtype (matches input diags/solves)
            out_shape=(
                jax.ShapeDtypeStruct(diags.shape, diags.dtype),
                jax.ShapeDtypeStruct(solves.shape, solves.dtype),
            ),
            grid=(Wmax,),  # Grid size: Wmax threads (one per thread ID)
            # Input-output aliases: map output indices to input indices
            # Output 0 → Input 1 (diags), Output 1 → Input 2 (solves)
            input_output_aliases={1: 0, 2: 1},
        )(edges, diags, solves, uppers, lowers, level_start, level_size)
    
    return _call(edges, diags, solves, uppers, lowers, level_start, level_size)



In [4]:
# ------------------------------
# Method B: One kernel per level (fori_loop = inter-level barrier)
# ------------------------------
def kernel_one_level(
    edges_ref,
    diags_ref,
    solves_ref,
    uppers_ref,
    lowers_ref,
    start_ref,  # Scalar: start index of current level
    size_ref,   # Scalar: number of edges in current level
    out_diags_ref,
    out_solves_ref,
):
    tid = pl.program_id(0)  # Thread ID: ranges from 0 to (Wmax - 1)

    # Load start index and size of current level (scalar values)
    start = pl.load(start_ref, ())
    size = pl.load(size_ref, ())

    # Check if current thread is active (within the edge count of current level)
    active = tid < size
    e = start + tid  # Current edge index in the edges array

    # Load child and parent indices (use 0/-1 as default if thread is inactive)
    c = pl.load(edges_ref, (e, 0), mask=active, other=0)
    p = pl.load(edges_ref, (e, 1), mask=active, other=-1)
    valid = active & (p >= 0)  # Valid only if thread is active and parent is valid

    # Load coefficients (use default values if invalid)
    up = pl.load(uppers_ref, c, mask=valid, other=0.0)
    lo = pl.load(lowers_ref, c, mask=valid, other=0.0)
    dc = pl.load(diags_ref, c, mask=valid, other=1.0)
    sc = pl.load(solves_ref, c, mask=valid, other=0.0)

    # Calculate multiplier and update parent nodes using atomic add
    mul = up / dc
    pl.atomic_add(out_diags_ref, p, -lo * mul, mask=valid)
    pl.atomic_add(out_solves_ref, p, -sc * mul, mask=valid)

# @jax.jit  # Uncomment this line to disable static Wmax (use only if necessary)
@partial(jax.jit, static_argnames=['Wmax'])
def run_levels_per_kernel(edges, diags, solves, uppers, lowers, level_start, level_size, Wmax = 32):
    # Define Pallas kernel call for a single level
    call_one_level = pl.pallas_call(
        kernel_one_level,
        # Define output shape and dtype (matches input diags/solves)
        out_shape=(
            jax.ShapeDtypeStruct(diags.shape, diags.dtype),
            jax.ShapeDtypeStruct(solves.shape, solves.dtype),
        ),
        grid=(Wmax,),  # Grid size: Wmax threads per level
        # Input-output aliases: map output indices to input indices
        input_output_aliases={1: 0, 2: 1},
    )

    def _run(edges, diags, solves, uppers, lowers, level_start, level_size):
        # For-loop across levels (uses JAX's lax.fori_loop for XLA compatibility)
        def body(d, carry):
            di, so = carry  # Current state: diags (di) and solves (so)
            # Dynamically load start index and size of current level
            s = jax.lax.dynamic_index_in_dim(level_start, d, keepdims=False)
            sz = jax.lax.dynamic_index_in_dim(level_size, d, keepdims=False)
            # Call kernel for current level and return updated state
            out_di, out_so = call_one_level(edges, di, so, uppers, lowers, s, sz)
            return (out_di, out_so)
        
        # Iterate from level 0 to (number of levels - 1), initial state = (diags, solves)
        return jax.lax.fori_loop(0, level_size.shape[0], body, (diags, solves))
    
    return _run(edges, diags, solves, uppers, lowers, level_start, level_size)



In [5]:
# ------------------------------
# Method C: Pure JAX reference implementation (CPU-style scalar loops)
# ------------------------------
def run_cpu_ref(edges, diags, solves, uppers, lowers, level_start, level_size):
    """CPU-style reference implementation with scalar loops (not JIT-compatible directly)"""
    od = diags  # Copy of diags (output diags)
    os = solves  # Copy of solves (output solves)
    num_levels = int(level_size.shape[0])  # Total number of levels

    # Iterate across all levels
    for d in range(num_levels):
        s = int(level_start[d])  # Start index of current level
        sz = int(level_size[d])  # Number of edges in current level
        # Iterate across all edges in current level
        for e in range(s, s + sz):
            # Get child and parent indices (scalar values)
            c, p = int(edges[e, 0]), int(edges[e, 1])
            if p >= 0:  # Skip if parent is invalid
                mul = uppers[c] / od[c]
                # Update parent's diags and solves
                od = od.at[p].add(-lowers[c] * mul)
                os = os.at[p].add(-os[c] * mul)
    
    return od, os

@jax.jit
def run_cpu_ref_jit(edges, diags, solves, uppers, lowers, level_start, level_size):
    """JIT-compatible version of the CPU reference implementation (uses lax.fori_loop)"""
    # For-loop across levels
    def body_level(d, carry):
        od, os = carry  # Current state: diags (od) and solves (os)
        # Dynamically load start index and size of current level
        s = jax.lax.dynamic_index_in_dim(level_start, d, keepdims=False)
        sz = jax.lax.dynamic_index_in_dim(level_size, d, keepdims=False)

        # For-loop across edges in current level
        def body_edge(i, carry2):
            od2, os2 = carry2  # Current state for edge loop
            e = s + i  # Current edge index
            # Dynamically load child and parent indices from edges array
            c = jax.lax.dynamic_index_in_dim(edges[:, 0], e, keepdims=False)
            p = jax.lax.dynamic_index_in_dim(edges[:, 1], e, keepdims=False)

            # Calculate multiplier
            mul = uppers[c] / od2[c]

            # Update parent's diags only if parent is valid (p >= 0)
            od2 = jax.lax.cond(
                p >= 0,
                lambda x: x.at[p].add(-lowers[c] * mul),  # Update if condition is true
                lambda x: x,  # Keep unchanged if condition is false
                operand=od2,
            )

            # Update parent's solves only if parent is valid (p >= 0)
            os2 = jax.lax.cond(
                p >= 0,
                lambda x: x.at[p].add(-os2[c] * mul),  # Update if condition is true
                lambda x: x,  # Keep unchanged if condition is false
                operand=os2,
            )

            return od2, os2

        # Iterate across edges (0 to sz-1), initial state = (od, os)
        return jax.lax.fori_loop(0, sz, body_edge, (od, os))

    # Iterate across levels (0 to num_levels-1), initial state = (diags, solves)
    return jax.lax.fori_loop(0, level_size.shape[0], body_level, (diags, solves))

@partial(jax.jit, static_argnames=['Wmax'])
def run_levels_vectorized(edges, diags, solves, uppers, lowers, level_start, level_size, Wmax):
    """
    Vectorized level-wise elimination/back-substitution (uses scatter-add for updates)
    
    Parameters
    ----
    edges : (E, 2) int32 array
        Each row represents an edge as (child, parent); parent = -1 means no parent
    diags, solves, uppers, lowers : (N,) float32 arrays
        Coefficients for nodes and edges
    level_start, level_size : (L,) int32 arrays
        Interval [start, start+size) for child nodes in each level

    Returns
    ----
    diags, solves : Updated float32 arrays after processing all levels
    """
    
    # For-loop across levels (vectorized processing per level)
    def body(d, carry):
        diags, solves = carry  # Current state: diags and solves
        start = level_start[d]  # Start index of current level
        size = level_size[d]    # Number of edges in current level

        # Dynamically slice edges array to get current level's edges
        # Dynamically slice edges array to get current level's edges (padded to Wmax)
        edge_block = jax.lax.dynamic_slice_in_dim(edges, start, Wmax, axis=0)
        children = edge_block[:, 0]  # Child indices of current level edges
        parents = edge_block[:, 1]   # Parent indices of current level edges

        # Create mask: threads within the actual edge count (ignore padded elements)
        valid_mask = jnp.arange(Wmax) < size
        # Final mask: valid thread AND valid parent (parent >= 0)
        mask = (parents >= 0) & valid_mask

        # Load coefficients with mask (zero out invalid entries)
        up = uppers[children] * mask
        lo = lowers[children] * mask
        # Avoid division by zero: set di to 1 for invalid entries (mask=0)
        di = diags[children] * mask + (~mask)
        so = solves[children] * mask

        # Calculate update values
        multiplier = up / di          # Multiplier for parent updates
        delta_diag = -lo * multiplier # Delta for parent diags
        delta_solve = -so * multiplier# Delta for parent solves

        # Scatter-add updates to parent nodes (apply mask to skip invalid updates)
        diags = diags.at[parents].add(jnp.where(mask, delta_diag, 0.0))
        solves = solves.at[parents].add(jnp.where(mask, delta_solve, 0.0))

        return (diags, solves)

    # Iterate across all levels (from 0 to number of levels - 1)
    diags, solves = jax.lax.fori_loop(
        0, level_start.shape[0], body, (diags, solves)
    )
    return diags, solves



In [6]:
# ------------------------------
# Consistency Check (verify if results from different methods match)
# ------------------------------
def check_allclose(a, b, atol=1e-5, rtol=1e-5, name=""):
    """
    Check if two arrays are element-wise close within tolerance.
    
    Args:
        a, b: Arrays to compare
        atol: Absolute tolerance
        rtol: Relative tolerance
        name: Name of the comparison (for error logging)
    
    Returns:
        bool: True if arrays are close, False otherwise
    """
    ok = jnp.allclose(a, b, atol=atol, rtol=rtol)
    if not bool(ok):
        # Calculate maximum absolute difference for error reporting
        diff = jnp.max(jnp.abs(a - b))
        print(f"[Mismatch] {name}: max|difference|={float(diff)}")
    return bool(ok)


# ------------------------------
# Benchmark Pipeline: Warmup → Multiple Runs → Statistics
# ------------------------------
def benchmark_all(N=500, seed=42, repeats=10, verify=True):
    """
    Benchmark all four tree processing methods and verify result consistency.
    
    Args:
        N: Number of nodes in the generated tree
        seed: Random seed for tree data generation (ensures reproducibility)
        repeats: Number of benchmark runs per method (for statistics)
        verify: Whether to check if results match the reference method (Method C)
    """
    print(f"\n=== Benchmark: N={N}, Repeats={repeats} ===")
    # Generate tree data (edges + coefficients)
    edges, diags, solves, uppers, lowers, level_start, level_size = make_tree_data(N, seed=seed)
    # Wmax = maximum number of edges in any level (for kernel grid sizing)
    Wmax = int(level_size.max().item())

    # ------------------------------
    # Benchmark Method A: Single Kernel (Levels-in-Kernel)
    # ------------------------------
    # Warmup: Run once to compile JIT and initialize resources
    t0 = time.perf_counter()
    A_di, A_so = run_levels_single_kernel(edges, diags, solves, uppers, lowers, level_start, level_size, Wmax)
    jax.block_until_ready((A_di, A_so))  # Wait for computation to finish
    warmup_time = time.perf_counter() - t0  # Ignore warmup time

    # Record time for multiple runs
    times_A = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        A_di, A_so = run_levels_single_kernel(edges, diags, solves, uppers, lowers, level_start, level_size, Wmax)
        jax.block_until_ready((A_di, A_so))
        times_A.append(time.perf_counter() - t0)

    # ------------------------------
    # Benchmark Method B: Per-Level Kernel
    # ------------------------------
    # Warmup
    t0 = time.perf_counter()
    B_di, B_so = run_levels_per_kernel(edges, diags, solves, uppers, lowers, level_start, level_size, Wmax)
    jax.block_until_ready((B_di, B_so))
    warmup_time = time.perf_counter() - t0  # Ignore warmup time

    # Record time for multiple runs
    times_B = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        B_di, B_so = run_levels_per_kernel(edges, diags, solves, uppers, lowers, level_start, level_size, Wmax)
        jax.block_until_ready((B_di, B_so))
        times_B.append(time.perf_counter() - t0)

    # ------------------------------
    # Benchmark Method C: Pure JAX Reference (JIT)
    # ------------------------------
    # Warmup
    t0 = time.perf_counter()
    C_di, C_so = run_cpu_ref_jit(edges, diags, solves, uppers, lowers, level_start, level_size)
    jax.block_until_ready((C_di, C_so))
    warmup_time = time.perf_counter() - t0  # Ignore warmup time

    # Record time for multiple runs
    times_C = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        C_di, C_so = run_cpu_ref_jit(edges, diags, solves, uppers, lowers, level_start, level_size)
        jax.block_until_ready((C_di, C_so))
        times_C.append(time.perf_counter() - t0)

    # ------------------------------
    # Benchmark Method D: Vectorized JAX
    # ------------------------------
    # Warmup
    t0 = time.perf_counter()
    D_di, D_so = run_levels_vectorized(edges, diags, solves, uppers, lowers, level_start, level_size, Wmax)
    jax.block_until_ready((D_di, D_so))
    warmup_time = time.perf_counter() - t0  # Ignore warmup time

    # Record time for multiple runs
    times_D = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        D_di, D_so = run_levels_vectorized(edges, diags, solves, uppers, lowers, level_start, level_size, Wmax)
        jax.block_until_ready((D_di, D_so))
        times_D.append(time.perf_counter() - t0)

    # ------------------------------
    # Consistency Verification (Compare to Reference Method C)
    # ------------------------------
    if verify:
        # Check if Method A matches Reference C
        ok1 = check_allclose(A_di, C_di, name="A vs C (diags)") and check_allclose(A_so, C_so, name="A vs C (solves)")
        # Check if Method B matches Reference C
        ok2 = check_allclose(B_di, C_di, name="B vs C (diags)") and check_allclose(B_so, C_so, name="B vs C (solves)")
        # Check if Method D matches Reference C
        ok3 = check_allclose(D_di, C_di, name="D vs C (diags)") and check_allclose(D_so, C_so, name="D vs C (solves)")
        print(f"Verification Results: A==C → {ok1}, B==C → {ok2}, D==C → {ok3}")

    # ------------------------------
    # Calculate and Print Statistics
    # ------------------------------
    def calculate_stats(times):
        """Convert time list to statistics (mean, median, min, max) in milliseconds"""
        times_arr = jnp.array(times)
        mean = float(times_arr.mean() * 1000)    # Convert to milliseconds
        median = float(jnp.median(times_arr) * 1000)
        min_val = float(times_arr.min() * 1000)
        max_val = float(times_arr.max() * 1000)
        return mean, median, min_val, max_val

    # Compute stats for all methods
    meanA, medA, minA, maxA = calculate_stats(times_A)
    meanB, medB, minB, maxB = calculate_stats(times_B)
    meanC, medC, minC, maxC = calculate_stats(times_C)
    meanD, medD, minD, maxD = calculate_stats(times_D)

    # Print time statistics
    print("\n--- Time Statistics (Milliseconds) ---")
    print(f"A  Single-kernel across levels : Mean={meanA:.3f}, Median={medA:.3f}, Min={minA:.3f}, Max={maxA:.3f}")
    print(f"B  Per-level kernel + barrier  : Mean={meanB:.3f}, Median={medB:.3f}, Min={minB:.3f}, Max={maxB:.3f}")
    print(f"C  Pure JAX (scalar loops)     : Mean={meanC:.3f}, Median={medC:.3f}, Min={minC:.3f}, Max={maxC:.3f}")
    print(f"D  Vectorized JAX              : Mean={meanD:.3f}, Median={medD:.3f}, Min={minD:.3f}, Max={maxD:.3f}")

    # Print speedup relative to Reference Method C (median time ratio)
    print("\n--- Speedup (Relative to Method C, Median Time) ---")
    speedupA = medC / medA if medA > 0 else float('inf')
    speedupB = medC / medB if medB > 0 else float('inf')
    speedupD = medC / medD if medD > 0 else float('inf')
    print(f"A is {speedupA:.2f}× faster than C")
    print(f"B is {speedupB:.2f}× faster than C")
    print(f"D is {speedupD:.2f}× faster than C")

In [7]:
benchmark_all(N=500,seed=42, repeats=10, verify=True)


=== Benchmark: N=500, Repeats=10 ===
[Mismatch] A vs C (diags): max|difference|=0.003604888916015625
[Mismatch] D vs C (diags): max|difference|=0.5986104011535645
Verification Results: A==C → False, B==C → True, D==C → False

--- Time Statistics (Milliseconds) ---
A  Single-kernel across levels : Mean=0.278, Median=0.271, Min=0.220, Max=0.404
B  Per-level kernel + barrier  : Mean=0.309, Median=0.297, Min=0.285, Max=0.388
C  Pure JAX (scalar loops)     : Mean=35.922, Median=35.982, Min=34.479, Max=36.974
D  Vectorized JAX              : Mean=0.367, Median=0.355, Min=0.340, Max=0.469

--- Speedup (Relative to Method C, Median Time) ---
A is 132.92× faster than C
B is 121.27× faster than C
D is 101.50× faster than C


In [30]:
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl

def kernel_all_levels_serial(
    edges_ref, diags_ref, solves_ref, uppers_ref, lowers_ref,
    level_start_ref, level_size_ref, num_levels_ref,
    out_diags_ref, out_solves_ref,
):
    L = pl.load(num_levels_ref, ())

    def level_body(d, _):
        s  = pl.load(level_start_ref, d)
        sz = pl.load(level_size_ref,  d)

        def edge_body(i, __):
            e = s + i
            c = pl.load(edges_ref, (e, 0))
            p = pl.load(edges_ref, (e, 1))

            def do_update(_):
                up = pl.load(uppers_ref, c)
                lo = pl.load(lowers_ref, c)
                dc = pl.load(diags_ref,  c)
                sc = pl.load(solves_ref, c)
                mul = up / dc
                pdi = pl.load(out_diags_ref, p)
                pso = pl.load(out_solves_ref, p)
                pl.store(out_diags_ref, p, pdi - lo * mul)
                pl.store(out_solves_ref, p, pso - sc * mul)
                return None

            return jax.lax.cond(p >= 0, do_update, lambda _: None, None)

        return jax.lax.fori_loop(0, sz, edge_body, None)

    jax.lax.fori_loop(0, L, level_body, None)


@jax.jit
def run_all_levels_single_kernel(edges, diags, solves, uppers, lowers, level_start, level_size):
    num_levels = jnp.asarray(level_size.shape[0], dtype=jnp.int32)

    call_all = pl.pallas_call(
        kernel_all_levels_serial,
        out_shape=(
            jax.ShapeDtypeStruct(diags.shape, diags.dtype),
            jax.ShapeDtypeStruct(solves.shape, solves.dtype),
        ),
        grid=(1,),  
        input_output_aliases={1: 0, 2: 1},
    )

    out_di, out_so = call_all(
        edges, diags, solves, uppers, lowers,
        level_start, level_size, num_levels
    )
    return out_di, out_so


In [31]:
edges, diags, solves, uppers, lowers, level_start, level_size = make_tree_data(500, seed=1)

In [35]:
# Warmup
t0 = time.perf_counter()
D_di, D_so = run_levels_per_kernel(edges, diags, solves, uppers, lowers, level_start, level_size)
jax.block_until_ready((D_di, D_so))
warmup_time = time.perf_counter() - t0  # Ignore warmup time

times_D = []
for _ in range(10):
    t0 = time.perf_counter()
    D_di, D_so = run_all_levels_single_kernel(edges, diags, solves, uppers, lowers, level_start, level_size)
    jax.block_until_ready((D_di, D_so))
    times_D.append(time.perf_counter() - t0)

In [33]:
def calculate_stats(times):
    """Convert time list to statistics (mean, median, min, max) in milliseconds"""
    times_arr = jnp.array(times)
    mean = float(times_arr.mean() * 1000)    # Convert to milliseconds
    median = float(jnp.median(times_arr) * 1000)
    min_val = float(times_arr.min() * 1000)
    max_val = float(times_arr.max() * 1000)
    return mean, median, min_val, max_val

In [36]:
meanD, medD, minD, maxD = calculate_stats(times_D)
print(f"D  Vectorized JAX              : Mean={meanD:.3f}, Median={medD:.3f}, Min={minD:.3f}, Max={maxD:.3f}")

D  Vectorized JAX              : Mean=0.330, Median=0.324, Min=0.311, Max=0.389
