# Fusion Lab: JAX Operator Fusion & Bandwidth Analysis

**Objective:** Demonstrate how JAX's **XLA Compiler (JIT)** fuses element-wise operations into single kernels to overcome memory bottlenecks.

**Key Concepts:**
* **Memory Bandwidth:** The speed at which data moves between VRAM and the Compute Units.
* **Operator Fusion:** Merging multiple math steps (e.g., sin -> cos -> add) into one loop to avoid reading/writing to VRAM repeatedly.
* **Arithmetic Intensity (The "Workload" Checkboxes):**
    * *Light:* Low intensity. The GPU is idle waiting for memory.
    * *Heavy:* High intensity. The GPU is busy calculating, hiding memory latency.

**Instructions:**
1.  **Check Runtime:** Ensure you are using a GPU (Runtime > Change runtime type > T4 GPU).
2.  **Select Workloads:** Use the checkboxes to compare Low vs. High Arithmetic Intensity.
3.  **Run:** Click the **Run Profiler** button to see the Roofline plots.

In [None]:
# @title Fusion Lab

import jax
import jax.numpy as jnp
import time
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# --- Configuration ---
# Element counts for the sweep: start at 1024 and double 15 times,
# ending at 1024 * 2**15 elements.
SIZES = [1024 << shift for shift in range(16)]
# Fixed PRNG key so every run draws the same random inputs.
key = jax.random.key(0)

# --- Kernel Definitions ---
# We want to compare low arithmetic intensity vs high intensity to see
# where the memory bandwidth bottleneck actually hits.

def simple_scale(x):
    """Light workload: one multiply and one add per element.

    So little math per element that runtime is dominated by moving data.
    """
    halved = x * 0.5
    return halved + 1.0

def activation_mix(x):
    """Medium workload: tanh(x) * sin(x) plus a half-scaled copy of x."""
    gated = jnp.tanh(x) * jnp.sin(x)
    linear = x * 0.5
    return gated + linear

def heavy_shader(x):
    """Heavy workload: five chained rounds of sin * cos + tanh.

    Each round feeds its result into the next, so the amount of math per
    element is several times that of the lighter kernels.
    """
    def _round(v):
        # One round of transcendental-heavy arithmetic on the running value.
        return jnp.sin(v) * jnp.cos(v) + jnp.tanh(v)

    result = x
    rounds_left = 5
    while rounds_left > 0:
        result = _round(result)
        rounds_left -= 1
    return result

# Registry of workloads, keyed by the label shown in the UI.
# Order here is the order the checkboxes and plots appear in.
_KERNEL_TABLE = (
    ("Light (Memory Bound)", simple_scale),
    ("Medium (Standard)", activation_mix),
    ("Heavy (Compute Bound)", heavy_shader),
)
kernels = dict(_KERNEL_TABLE)

# --- UI Setup ---
print("Initializing Dashboard...")


def _make_checkbox(label):
    # One pre-ticked checkbox per workload, labelled with the kernel name.
    return widgets.Checkbox(description=label, value=True, indent=False)


# Map each kernel label to its checkbox so the callback can read the state.
checkboxes = {label: _make_checkbox(label) for label in kernels}
checkbox_ui = widgets.HBox([*checkboxes.values()])

# Button that launches the benchmark sweep.
btn_run = widgets.Button(
    icon='rocket',
    button_style='success',
    description="Run Profiler",
)

# All benchmark text and plots are rendered into this output widget.
output_area = widgets.Output()

# --- Benchmarking Logic ---
def _best_seconds(fn, x, repeats=3):
    """Return the fastest wall-clock time in seconds of ``fn(x)``.

    Args:
        fn: Callable taking one array and returning a JAX array.
        x: Input array, used for both the warm-up and the timed calls.
        repeats: Number of timed calls; the minimum is returned.

    Returns:
        The smallest elapsed time observed, in seconds.
    """
    # Warm up on the SAME array that is timed. JAX compiles and caches
    # executables per input shape (for jit, and per primitive in eager
    # mode), so a warm-up on a small slice would leave compilation inside
    # the timed region for the full-size input.
    fn(x).block_until_ready()

    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        # block_until_ready() forces the async dispatch to finish so the
        # measurement covers the actual computation.
        fn(x).block_until_ready()
        best = min(best, time.perf_counter() - t0)
    return best


def run_benchmark(b):
    """Button callback: benchmark the selected kernels and plot throughput.

    For every ticked workload, runs the kernel eagerly and under
    ``jax.jit`` across all sizes in ``SIZES``, converts the timings to an
    approximate GB/s figure (one read plus one write of the input), and
    draws one subplot per workload inside ``output_area``.

    Args:
        b: The clicked button (passed by ``on_click``; not used).
    """
    with output_area:
        clear_output(wait=True)

        # Quick hardware check to manage expectations.
        backend = jax.default_backend()
        device = jax.devices()[0].device_kind
        print(f"Running on: {device} ({backend.upper()})")

        if backend == 'cpu':
            print("Note: You're on CPU. You won't see the roofline effect clearly.")
            print("Tip: Switch runtime to T4 GPU for the actual bandwidth test.")
        else:
            print("GPU Active. Bandwidth saturation test ready.")
        print("-" * 40)

        # Only benchmark the workloads whose checkbox is ticked.
        active_tests = [k for k, cb in checkboxes.items() if cb.value]
        if not active_tests:
            print("Select a workload first!")
            return

        # squeeze=False keeps 'axes' 2D even when a single workload is
        # selected, so axes[idx, 0] always works.
        fig, axes = plt.subplots(
            len(active_tests), 1,
            figsize=(10, 5 * len(active_tests)),
            squeeze=False,
        )

        for idx, test_name in enumerate(active_tests):
            ax = axes[idx, 0]
            func = kernels[test_name]
            jit_func = jax.jit(func)

            print(f"Profiling: {test_name}...")

            data_mb = []
            bw_eager = []
            bw_jit = []

            for n in SIZES:
                # Square input with roughly n elements.
                dim = int(np.sqrt(n))
                x = jax.random.normal(key, (dim, dim))

                # int(sqrt(n)) rounds down when n is not a perfect square,
                # so measure the array that was actually created instead
                # of assuming n float32 elements.
                input_bytes = x.size * x.dtype.itemsize
                # Approximate traffic: read the input once, write the
                # output once.
                total_bytes = input_bytes * 2

                # 1. Eager (op-by-op) execution.
                t_eager = _best_seconds(func, x)

                # 2. JIT (fused) execution; compilation happens in the
                #    helper's warm-up call, outside the timed region.
                t_jit = _best_seconds(jit_func, x)

                # Log stats.
                data_mb.append(input_bytes / 1e6)               # size in MB
                bw_eager.append((total_bytes / t_eager) / 1e9)  # GB/s
                bw_jit.append((total_bytes / t_jit) / 1e9)      # GB/s

            # Visualization
            ax.plot(data_mb, bw_jit, 'o-', c='green', label='JIT (Fused)')
            ax.plot(data_mb, bw_eager, 'o--', c='red', label='Eager (Unfused)')

            ax.set_xscale('log')
            ax.set_ylabel('Throughput (GB/s)')
            ax.set_title(f"Roofline Analysis: {test_name}")
            ax.grid(True, which="both", alpha=0.3)
            ax.legend()

            # Only the bottom subplot carries the shared x-axis label.
            if idx == len(active_tests) - 1:
                ax.set_xlabel('Tensor Size (MB)')

        print("Benchmark Complete.")
        fig.tight_layout()
        plt.show()

# Wire the button to the benchmark callback, then render the dashboard:
# the controls first, with the output area underneath.
btn_run.on_click(run_benchmark)

controls = widgets.VBox([checkbox_ui, btn_run])
display(controls)
display(output_area)