In [21]:
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
import numpy as np
import time
from functools import partial

print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")

JAX version: 0.6.1
Available devices: [CudaDevice(id=0)]


In [23]:
def slow_function(x):
    """A function without JIT compilation"""
    for i in range(100):
        x = jnp.sin(x) + jnp.cos(x)
    return x

@jit
def fast_function(x):
    """The same function with JIT compilation"""
    for i in range(100):
        x = jnp.sin(x) + jnp.cos(x)
    return x

x = jnp.array([1.0, 2.0, 3.0])

# Time the functions
print("First call (includes compilation time):")
start = time.time()
result_fast = fast_function(x)
fast_time_first = time.time() - start
print(f"JIT function first call: {fast_time_first:.6f} seconds")

start = time.time()
result_slow = slow_function(x)
slow_time = time.time() - start
print(f"Regular function: {slow_time:.6f} seconds")

print("\nSecond call (compiled version):")
start = time.time()
result_fast = fast_function(x)
result_fast.block_until_ready()  # Ensure the computation is complete
fast_time_second = time.time() - start
print(f"JIT function second call: {fast_time_second:.6f} seconds")

print(f"\nSpeedup: {slow_time / fast_time_second:.2f}x")

First call (includes compilation time):
JIT function first call: 1.789231 seconds
Regular function: 0.014970 seconds

Second call (compiled version):
JIT function second call: 0.001622 seconds

Speedup: 9.23x


In [25]:
@jit
def add_multiply(x, y):
    print(x)
    return x + y, x * y

# First call - triggers compilation
print("First call with specific shapes:")
a = jnp.array([1.0, 2.0])
b = jnp.array([3.0, 4.0])
result1 = add_multiply(a, b)
print(f"Result: {result1}")

# Second call with same shapes - uses cached compilation
print("\nSecond call with same shapes:")
c = jnp.array([5.0, 6.0])
d = jnp.array([7.0, 8.0])
result2 = add_multiply(c, d)
print(f"Result: {result2}")

# Different shapes trigger recompilation
print("\nCall with different shapes (triggers recompilation):")
e = jnp.array([1.0, 2.0, 3.0])  # Different shape!
f = jnp.array([4.0, 5.0, 6.0])
result3 = add_multiply(e, f)
print(f"Result: {result3}")

First call with specific shapes:
Traced<float32[2]>with<DynamicJaxprTrace>
Result: (Array([4., 6.], dtype=float32), Array([3., 8.], dtype=float32))

Second call with same shapes:
Result: (Array([12., 14.], dtype=float32), Array([35., 48.], dtype=float32))

Call with different shapes (triggers recompilation):
Traced<float32[3]>with<DynamicJaxprTrace>
Result: (Array([5., 7., 9.], dtype=float32), Array([ 4., 10., 18.], dtype=float32))


In [36]:
print("=== Static vs Dynamic Arguments in JIT ===\n")

# Example 1: Loop bounds MUST be static
@jit
def bad_dynamic_loop(x, num_iterations):
    """This will fail - loop bounds must be known at compile time"""
    for i in range(num_iterations):  # ERROR: num_iterations must be static!
        x = x * 1.1 + 0.01
    return x

@partial(jit, static_argnums=(1,))
def good_static_loop(x, num_iterations):
    """This works - num_iterations is marked as static"""
    for i in range(num_iterations):
        x = x * 1.1 + 0.01
    return x

x = jnp.array([1.0, 2.0, 3.0])

# This will fail with a TracerIntegerConversionError
print("❌ Trying dynamic loop bounds:")
try:
    result = bad_dynamic_loop(x, 5)
    print(f"Unexpected success: {result}")
except Exception as e:
    print(f"Failed as expected: {type(e).__name__}")
    print(f"Error message: {str(e)[:100]}...")

print("\n✅ Using static loop bounds:")
result = good_static_loop(x, 5)
print(f"5 iterations: {result}")

result = good_static_loop(x, 3)  # Different static value = recompilation
print(f"3 iterations: {result}")



=== Static vs Dynamic Arguments in JIT ===

❌ Trying dynamic loop bounds:
Failed as expected: TracerIntegerConversionError
Error message: The __index__() method was called on traced array with shape int32[]
The error occurred while tracin...

✅ Using static loop bounds:
5 iterations: [1.6715611 3.282071  4.8925824]
3 iterations: [1.3641001 2.6951    4.0261006]


In [34]:
print("=== Control Flow in JIT: Good vs Bad Examples ===\n")

# ✅ GOOD: Using jnp.where for conditional logic
@jit
def good_conditional(x, threshold=5.0):
    """Control flow that works with JIT"""
    print(x)
    return jnp.where(x > threshold, jnp.sin(x), jnp.cos(x))

# ❌ BAD: Python if/else with array-dependent conditions
def bad_python_if(x):
    """This will cause issues with JIT due to Python control flow"""
    if jnp.sum(x) > 5.0:  # Condition depends on array values
        return jnp.sin(x)
    else:
        return jnp.cos(x)

# Let's see what happens when we JIT the bad version
try:
    bad_jit = jit(bad_python_if)
    x_test = jnp.array([1.0, 2.0, 3.0])
    result = bad_jit(x_test)
    
except Exception as e:
    print(f"❌ JIT failed on bad function: {e}")
    
print(good_conditional(x_test))


=== Control Flow in JIT: Good vs Bad Examples ===

❌ JIT failed on bad function: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function bad_python_if at /tmp/ipykernel_836/737641638.py:11 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
Traced<float32[3]>with<DynamicJaxprTrace>
[ 0.5403023 -0.4161468 -0.9899925]


In [40]:
print("=== Loops in JIT: Python vs JAX primitives ===\n")

# Example 1: Moving average with different approaches
print("1. Moving Average Computation:")

# ❌ Python loop approach (works but not optimal)
@partial(jit, static_argnums=(1,))
def moving_avg_python_loop(data, window_size):
    """Moving average using Python loop - requires static window_size"""
    n = len(data)
    result = jnp.zeros(n - window_size + 1)
    
    for i in range(len(result)):  # Loop bound must be static!
        result = result.at[i].set(jnp.mean(data[i:i+window_size]))
    
    return result

# ✅ Using jax.lax.scan (better for dynamic operations)
@partial(jit, static_argnums=(1,))
def moving_avg_scan(data, window_size):
    """Moving average using scan - more flexible"""
    def scan_fn(i, _):
        avg = jnp.mean(jax.lax.dynamic_slice(data, (i,), (window_size,)))
        return i + 1, avg

    n_windows = data.shape[0] - window_size + 1
    _, averages = jax.lax.scan(scan_fn, 0, xs=None, length=n_windows)
    return averages

# Test data
data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
window_size = 3

print("Data:", data)
print("Python loop result:", moving_avg_python_loop(data, window_size))
print("Scan result:", moving_avg_scan(data, window_size))


=== Loops in JIT: Python vs JAX primitives ===

1. Moving Average Computation:
Data: [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10.]
Python loop result: [2. 3. 4. 5. 6. 7. 8. 9.]
Scan result: [2. 3. 4. 5. 6. 7. 8. 9.]


In [41]:
# Cell 7: Combining JIT with Other JAX Transformations
@jit  # JIT the gradient computation
def loss_function(params, x, y):
    """A simple quadratic loss function"""
    pred = params['w'] * x + params['b']
    return jnp.mean((pred - y)**2)

# Create the gradient function and JIT it
grad_loss = jit(grad(loss_function, argnums=0))

# Sample data
x_data = jnp.array([1.0, 2.0, 3.0, 4.0])
y_data = jnp.array([2.0, 4.0, 6.0, 8.0])  # y = 2x
params = {'w': 1.5, 'b': 0.1}

print("Original loss:", loss_function(params, x_data, y_data))
print("Gradients:", grad_loss(params, x_data, y_data))

# JIT with vmap for batch processing
@jit
def batch_prediction(params, x_batch):
    """Vectorized prediction"""
    return vmap(lambda x: params['w'] * x + params['b'])(x_batch)

batch_x = jnp.array([[1., 2.], [3., 4.], [5., 6.]])
batch_pred = batch_prediction(params, batch_x)
print(f"Batch predictions: {batch_pred}")

Original loss: 1.6350002
Gradients: {'b': Array(-2.3000002, dtype=float32, weak_type=True), 'w': Array(-7.0000005, dtype=float32, weak_type=True)}
Batch predictions: [[1.6 3.1]
 [4.6 6.1]
 [7.6 9.1]]


In [42]:
# Cell 8: Debugging JIT Functions
@jit
def debug_function(x):
    """Function to demonstrate debugging techniques"""
    intermediate = x * 2
    # Use jax.debug.print for debugging inside JIT
    jax.debug.print("Debug: intermediate value = {}", intermediate)
    result = jnp.sin(intermediate)
    return result

# Temporarily disable JIT for debugging
def debug_function_no_jit(x):
    """Same function without JIT for debugging"""
    intermediate = x * 2
    print(f"Debug: intermediate value = {intermediate}")  # Regular print works
    result = jnp.sin(intermediate)
    return result

x = jnp.array([1.0, 2.0])
print("With JIT (using jax.debug.print):")
result_jit = debug_function(x)

print("\nWithout JIT (for debugging):")
result_no_jit = debug_function_no_jit(x)

print(f"Results match: {jnp.allclose(result_jit, result_no_jit)}")

With JIT (using jax.debug.print):
Debug: intermediate value = [2. 4.]

Without JIT (for debugging):
Debug: intermediate value = [2. 4.]
Results match: True


In [43]:
# Cell 10: Common Pitfalls and Best Practices
print("=== Common JIT Pitfalls and Solutions ===\n")

# Pitfall 1: Side effects
print("1. Side Effects Issue:")
counter = 0

@jit
def bad_side_effect(x):
    global counter
    counter += 1  # This won't work as expected in JIT!
    return x * 2

# The counter increment happens only during compilation
result1 = bad_side_effect(jnp.array([1.0]))
result2 = bad_side_effect(jnp.array([2.0]))  # Same shape, uses cached version
print(f"Counter after two calls: {counter}")  # Will be 1, not 2!

# Solution: Keep state functional
def good_functional_approach(x, counter):
    return x * 2, counter + 1

# Pitfall 2: Using Python containers
print("\n2. Python Containers Issue:")

@jit
def list_function(x):
    # This works but isn't efficient
    result_list = []
    for i in range(3):
        result_list.append(x[i] * 2)
    return jnp.array(result_list)

# Better approach
@jit
def array_function(x):
    return x[:3] * 2

x = jnp.array([1, 2, 3, 4, 5])
print(f"List approach: {list_function(x)}")
print(f"Array approach: {array_function(x)}")

=== Common JIT Pitfalls and Solutions ===

1. Side Effects Issue:
Counter after two calls: 1

2. Python Containers Issue:
List approach: [2 4 6]
Array approach: [2 4 6]
