# Fusion Lab: JAX Operator Fusion & Bandwidth Analysis

**Objective:** Demonstrate how JAX's **XLA Compiler (JIT)** fuses element-wise operations into single kernels to overcome memory bottlenecks.

**Key Concepts:**
* **Memory Bandwidth:** The speed at which data moves between VRAM and the Compute Units.
* **Operator Fusion:** Merging multiple math steps (e.g., sin -> cos -> add) into one loop to avoid reading/writing to VRAM repeatedly.
* **Arithmetic Intensity (The "Workload" Checkboxes):**
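    * *Light:* Low intensity. The GPU is idle waiting for memory.
    * *Medium:* Moderate intensity. A few transcendental operations (`tanh`, `sin`) per element, typical of an activation function.
    * *Heavy:* High intensity. The GPU is busy calculating, hiding memory latency.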

**Instructions:**
1.  **Check Runtime:** Ensure you are using a GPU (Runtime > Change runtime type > T4 GPU).
2.  **Select Workloads:** Use the checkboxes to compare Low vs. High Arithmetic Intensity.
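3.  **Run:** Click the **Run Benchmarks** button to plot effective bandwidth (GB/s) against tensor size for the eager and JIT-compiled versions of each selected workload (the plots are titled "Roofline").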

In [None]:
# @title Fusion Lab

import jax
import jax.numpy as jnp
import time
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output

# 1. Define Workloads
# ---------------------------------------------------------
def op_light(x):
    """Low arithmetic-intensity workload: one multiply and one add per element.

    Args:
        x: Array (or scalar) to transform element-wise.

    Returns:
        ``x * 0.5 + 1.0``, with the same shape as ``x``.
    """
    halved = x * 0.5
    return halved + 1.0

def op_medium(x):
    """Medium arithmetic-intensity workload resembling an activation function.

    Args:
        x: Array (or scalar) to transform element-wise.

    Returns:
        ``tanh(x) * sin(x) + 0.5 * x``, with the same shape as ``x``.
    """
    gate = jnp.tanh(x) * jnp.sin(x)
    linear_term = x * 0.5
    return gate + linear_term

def op_heavy(x):
    """High arithmetic-intensity workload that simulates a complex shader.

    Applies ``y -> sin(y) * cos(y) + tanh(y)`` five times in a row, starting
    from ``x``. The loop is a plain Python loop, so the five steps are laid
    out one after another as separate element-wise operations.

    Args:
        x: Array (or scalar) to transform element-wise.

    Returns:
        The value after five applications of the update, same shape as ``x``.
    """
    acc = x
    for _step in range(5):
        oscillation = jnp.sin(acc) * jnp.cos(acc)
        acc = oscillation + jnp.tanh(acc)
    return acc

# Registry of benchmarkable workloads: checkbox label -> element-wise function.
# Insertion order determines the order of the checkboxes and of the plots.
workloads = {
    # One multiply and one add per element.
    "Light (Basic Math)": op_light,
    # tanh / sin mix.
    "Medium (Typical Activation)": op_medium,
    # Five chained sin / cos / tanh steps.
    "Heavy (Complex 'Shader' Loop)": op_heavy,
}

# 2. Interactive Dashboard Setup
# ---------------------------------------------------------
print("--- JAX Operator Fusion Multi-Lab ---")

# One checkbox per workload, all ticked by default and sharing a single layout
# so they sit side by side with a 20px gap on the right.
cb_layout = widgets.Layout(width='auto', margin='0px 20px 0px 0px')
checkboxes = {
    label: widgets.Checkbox(value=True, description=label, indent=False, layout=cb_layout)
    for label in workloads
}

# Row of checkboxes, the trigger button, and the area that receives the
# benchmark's printed text and plots.
checkbox_ui = widgets.HBox(list(checkboxes.values()))
button = widgets.Button(
    description="Run Benchmarks",
    button_style='success',
    icon='play',
    layout=widgets.Layout(margin='20px 0px 0px 0px'),
)
output = widgets.Output()

# 3. Benchmarking Logic
# ---------------------------------------------------------
def _best_time(fn, x, repeats=3):
    """Return the fastest wall-clock time in seconds of ``fn(x)`` over ``repeats`` runs.

    ``block_until_ready`` is called on every result so that JAX's asynchronous
    dispatch cannot let the timer stop before the device has finished. Taking
    the minimum filters out one-off scheduling noise.
    """
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(x).block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best


def run_benchmark_multi(b):
    """Benchmark every ticked workload in eager and JIT mode and plot the results.

    Click handler for the "Run Benchmarks" button. For each selected workload
    it times the plain (eager) function and its ``jax.jit`` version over a
    range of tensor sizes, converts each time to an effective bandwidth
    (one read plus one write of the tensor, divided by the elapsed time), and
    draws one subplot per workload inside the ``output`` widget.

    Args:
        b: The button instance passed by ipywidgets to click handlers (unused).

    Returns:
        None. All results are printed / plotted into ``output``.
    """
    with output:
        clear_output(wait=True)

        # --- HARDWARE CHECK ---
        backend = jax.default_backend()
        device_name = jax.devices()[0].device_kind
        print(f"Hardware Detected: {device_name} ({backend.upper()})")

        if backend == 'cpu':
            print("⚠️ WARNING: You are running on CPU.")
            print("   The 'Heavy' workload will not show the expected Eager degradation.")
            print("   -> Go to 'Runtime' > 'Change runtime type' > Select 'T4 GPU' to see the real effect.")
        else:
            print("✅ GPU detected. Bandwidth saturation tests will be accurate.")
        print("-" * 60)

        selected_keys = [key for key, cb in checkboxes.items() if cb.value]
        n_plots = len(selected_keys)

        if not selected_keys:
            print("Please select at least one workload.")
            return

        # squeeze=False keeps `axes` two-dimensional even for a single plot.
        _, axes = plt.subplots(n_plots, 1, figsize=(11, 5 * n_plots), squeeze=False)

        # Element counts double from 1Ki up to 32Mi (2**25 elements, about
        # 134 MB of float32 per tensor) so the largest sizes saturate the GPU.
        sizes = [1024 * (2**i) for i in range(16)]

        for idx, key in enumerate(selected_keys):
            ax = axes[idx, 0]
            func = workloads[key]
            jit_func = jax.jit(func)

            print(f"Benchmarking: {key}...")

            sizes_mb = []
            eager_bw = []
            jit_bw = []

            for n in sizes:
                dim = int(np.sqrt(n))
                x = jax.random.normal(jax.random.key(0), (dim, dim))

                # dim * dim is smaller than n whenever n is not a perfect
                # square, so measure the array that was actually created.
                tensor_bytes = x.size * x.dtype.itemsize
                total_bytes = tensor_bytes * 2  # one read + one write

                # Warm up on the FULL-size input: JAX compiles per input
                # shape, so warming up on a small slice would leave the
                # compilation for this shape inside the timed region.
                func(x).block_until_ready()
                e_time = _best_time(func, x)

                jit_func(x).block_until_ready()
                j_time = _best_time(jit_func, x)

                sizes_mb.append(tensor_bytes / 1e6)
                eager_bw.append((total_bytes / e_time) / 1e9)
                jit_bw.append((total_bytes / j_time) / 1e9)

            ax.plot(sizes_mb, jit_bw, 'o-', color='#2ecc71', label='JIT (Fused)', linewidth=2, markersize=6)
            ax.plot(sizes_mb, eager_bw, 'o--', color='#e74c3c', label='Eager (Unfused)', linewidth=2, markersize=6)
            ax.set_xscale('log')
            ax.set_ylabel('Effective BW (GB/s)', fontsize=10)
            ax.set_title(f"Roofline: {key}", fontsize=12, fontweight='bold')
            ax.grid(True, which="both", ls="-", alpha=0.3)
            ax.legend()

            # Only the bottom subplot carries the shared x-axis label.
            if idx == n_plots - 1:
                ax.set_xlabel('Tensor Size (MB) - Log Scale', fontsize=10)

        print("Done!")
        plt.tight_layout()
        plt.show()

# Hook the benchmark up to the button, then render the controls (checkbox row
# above the button) followed by the output area that receives the results.
button.on_click(run_benchmark_multi)
display(
    widgets.VBox([checkbox_ui, button]),
    output,
)