# Introduction to JAX

This notebook provides a brief introduction to [JAX](https://jax.readthedocs.io/), a library for high-performance numerical computing with automatic differentiation. JAX is particularly useful for scientific computing and machine learning because it allows us to:

1. Write NumPy-like code that runs fast
2. Automatically compute derivatives of functions

Other benefits (not necessarily used in this project):

3. Easily vectorize computations
4. Compile code with XLA for fast execution on CPUs, GPUs, and TPUs

These features make JAX a great choice for implementing Physics-Informed Neural Networks (PINNs), and it is widely used by researchers in this field.

In [1]:
import jax
import jax.numpy as jnp
import numpy as np

## JAX vs NumPy: Familiar Syntax

If you know NumPy, you already know most of JAX! The `jax.numpy` module provides almost the same API as NumPy.

In [2]:
# NumPy
x_np = np.array([1.0, 2.0, 3.0])
print("NumPy array:", x_np)
print("Sum:", np.sum(x_np))
print("Mean:", np.mean(x_np))

NumPy array: [1. 2. 3.]
Sum: 6.0
Mean: 2.0


In [3]:
# JAX is almost identical, only replace np with jnp!
x_jax = jnp.array([1.0, 2.0, 3.0])
print("JAX array:", x_jax)
print("Sum:", jnp.sum(x_jax))
print("Mean:", jnp.mean(x_jax))

JAX array: [1. 2. 3.]
Sum: 6.0
Mean: 2.0


Most NumPy operations work the same way in JAX:

In [4]:
# Matrix operations
A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
b = jnp.array([1.0, 2.0])

print("Matrix A:\n", A)
print("Vector b:", b)
print("A @ b:", A @ b)
print("A.T:\n", A.T)
print("jnp.linalg.inv(A):\n", jnp.linalg.inv(A))

Matrix A:
 [[1. 2.]
 [3. 4.]]
Vector b: [1. 2.]
A @ b: [ 5. 11.]
A.T:
 [[1. 3.]
 [2. 4.]]
jnp.linalg.inv(A):
 [[-2.0000002   1.0000001 ]
 [ 1.5000001  -0.50000006]]
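The `-2.0000002` above is floating-point round-off. JAX uses 32-bit floats by default, so numerical results should be compared with a tolerance rather than with `==`. The sketch below checks the inverse that way. If double precision is needed, JAX's documented switch is `jax.config.update("jax_enable_x64", True)`, set at program startup; it is left out here so the example does not change global state.

```python
import jax.numpy as jnp
import numpy as np

A = jnp.array([[1.0, 2.0], [3.0, 4.0]])
A_inv = jnp.linalg.inv(A)

# Default dtype is float32 (unless 64-bit mode has been enabled)
print("dtype:", A_inv.dtype)

# Compare with a tolerance: A_inv @ A should be (close to) the identity
close_to_identity = bool(jnp.allclose(A_inv @ A, jnp.eye(2), atol=1e-5))

# Compare against the exact inverse computed by hand
exact_inv = np.array([[-2.0, 1.0], [1.5, -0.5]])
close_to_exact = bool(np.allclose(np.asarray(A_inv), exact_inv, atol=1e-5))

print("A_inv @ A close to I:", close_to_identity)
print("A_inv close to exact inverse:", close_to_exact)
```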


In [5]:
# Element-wise operations
x = jnp.linspace(0, 2 * jnp.pi, 4)
print("x:", x)
print("sin(x):", jnp.sin(x))
print("exp(x):", jnp.exp(x))

x: [0.        2.0943952 4.1887903 6.2831855]
sin(x): [ 0.0000000e+00  8.6602539e-01 -8.6602545e-01  1.7484555e-07]
exp(x): [  1.         8.120528  65.94297  535.49176 ]


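Convert between NumPy and JAX arrays using `np.asarray()` (JAX to NumPy) and `jnp.asarray()` (NumPy to JAX). Keep conversions to a minimum, since they can involve copying data between host memory and an accelerator such as a GPU.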

In [6]:
jax_array = jnp.array([1.0, 2.0, 3.0])
numpy_array = np.array([4.0, 5.0, 6.0])

# Convert JAX array to NumPy array
converted_to_numpy = np.asarray(jax_array)
print("Converted to NumPy:", converted_to_numpy)
print("Type:", type(converted_to_numpy))

# Convert NumPy array to JAX array
converted_to_jax = jnp.asarray(numpy_array)
print("\nConverted to JAX:", converted_to_jax)
print("Type:", type(converted_to_jax))

Converted to NumPy: [1. 2. 3.]
Type: <class 'numpy.ndarray'>

Converted to JAX: [4. 5. 6.]
Type: <class 'jaxlib._jax.ArrayImpl'>


## Array Immutability

One key difference from NumPy: **JAX arrays are immutable**. You cannot modify them in-place.

In [7]:
# This works in NumPy
x_np = np.array([1, 2, 3])
x_np[0] = 10
print("NumPy after modification:", x_np)

NumPy after modification: [10  2  3]


In [8]:
# This does NOT work in JAX - uncomment to see the error
x_jax = jnp.array([1, 2, 3])
# x_jax[0] = 10  # TypeError: JAX arrays are immutable
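To see the failure without editing the cell above, the assignment can be wrapped in a `try`/`except`. This is only a demonstration sketch; JAX raises a `TypeError` for item assignment, and the original array is left untouched.

```python
import jax.numpy as jnp

x_jax = jnp.array([1, 2, 3])

try:
    x_jax[0] = 10  # item assignment is not supported on JAX arrays
    raised = False
except TypeError as err:
    raised = True
    # Print only the first line of JAX's explanatory error message
    print("TypeError:", str(err).splitlines()[0])

print("Array after the failed assignment:", x_jax)  # unchanged
```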

Instead, JAX provides the `.at` property for creating modified copies of arrays:

In [9]:
x = jnp.array([1, 2, 3, 4, 5])

# Set a single element
x_new = x.at[0].set(10)
print("Original x:", x)  # unchanged!
print("New array: ", x_new)

# Set multiple elements
x_new2 = x.at[1:3].set(99)
print("Set slice: ", x_new2)

Original x: [1 2 3 4 5]
New array:  [10  2  3  4  5]
Set slice:  [ 1 99 99  4  5]


In [10]:
# 2D array example
A = jnp.zeros((3, 3))
A = A.at[0, :].set(1.0)  # Set first row to 1
A = A.at[:, 2].set(2.0)  # Set last column to 2
A = A.at[1, 1].set(5.0)  # Set center element to 5
print(A)

[[1. 1. 2.]
 [0. 5. 2.]
 [0. 0. 2.]]
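Besides `.set()`, the `.at[]` property offers other update methods such as `.add()`, `.multiply()`, and `.max()`, each returning a new array. The sketch below shows a few of them. One detail worth knowing: when an index appears more than once, `.add()` applies every update, so repeated indices accumulate.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0, 4.0])

added = x.at[0].add(10.0)             # add 10 to element 0
doubled = x.at[1:3].multiply(2.0)     # double elements 1 and 2
clipped = x.at[-1].max(0.5)           # element-wise max(4.0, 0.5): unchanged
repeated = x.at[jnp.array([0, 0])].add(1.0)  # index 0 twice: both updates applied

print("add:     ", added)
print("multiply:", doubled)
print("max:     ", clipped)
print("repeated:", repeated)
print("original:", x)  # still unchanged
```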


## Automatic Differentiation with `jax.grad`

This is the main reason we use JAX in this project! `jax.grad` takes a scalar-valued function and returns a new function that computes its gradient (derivative).

In [11]:
# Define a simple function f(x) = x^2
def f(x):
    return x**2


# Evaluate at x = 3.0
x = 3.0

# Create the gradient function: f'(x) = 2x, and evaluate it at x = 3.0
f_x = jax.grad(f)(x)

print(f"f({x}) = {f(x)}")
print(f"f'({x}) = {f_x}")  # Should be 2 * 3 = 6

f(3.0) = 9.0
f'(3.0) = 6.0


In [12]:
# A more complex function: g(x) = sin(x) * exp(-x)
def g(x):
    return jnp.sin(x) * jnp.exp(-x)


# Evaluate at x = 1.0
x = 1.0

# The derivative: g'(x) = cos(x)*exp(-x) - sin(x)*exp(-x) = exp(-x)*(cos(x) - sin(x))
g_x = jax.grad(g)(x)

print(f"g({x}) = {g(x):.6f}")
print(f"g'({x}) = {g_x:.6f}")

# Verify analytically
analytical = jnp.exp(-x) * (jnp.cos(x) - jnp.sin(x))
print(f"Analytical: {analytical:.6f}")

g(1.0) = 0.309560
g'(1.0) = -0.110794
Analytical: -0.110794


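Stop and think about how powerful this is! You can define any scalar-valued function using standard JAX operations, and `jax.grad` will give you a new function that computes its derivative exactly (up to floating-point round-off), rather than approximating it with finite differences.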
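Two restrictions trip up many first-time users of `jax.grad`: the input must be a floating-point value, and the function's output must be a scalar. Both cases raise a `TypeError`, as the sketch below shows. For vector-valued outputs, JAX provides `jax.jacobian` (and `jax.jacfwd`/`jax.jacrev`) instead.

```python
import jax
import jax.numpy as jnp


def f(x):
    return x**2


# 1. Integer input: grad requires floating-point inputs
try:
    jax.grad(f)(3)
    int_failed = False
except TypeError:
    int_failed = True


# 2. Vector output: grad is only defined for scalar-output functions
def h(x):
    return jnp.array([x, x**2])


try:
    jax.grad(h)(3.0)
    vec_failed = False
except TypeError:
    vec_failed = True

print("integer input raised TypeError:", int_failed)
print("vector output raised TypeError:", vec_failed)
print("float input works:", jax.grad(f)(3.0))  # 2 * 3 = 6
```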

### Second-Order Derivatives

To compute second derivatives, simply apply `grad` twice:

In [13]:
# f(x) = x^3
# f'(x) = 3x^2
# f''(x) = 6x
def f(x):
    return x**3


x = 2.0

f_x = jax.grad(f)(x)  # First derivative
f_xx = jax.grad(jax.grad(f))(x)  # Second derivative

print(f"f({x}) = {f(x)}")
print(f"f'({x}) = {f_x}")  # Should be 3 * 4 = 12
print(f"f''({x}) = {f_xx}")  # Should be 6 * 2 = 12

f(2.0) = 8.0
f'(2.0) = 12.0
f''(2.0) = 12.0
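Because `jax.grad` returns an ordinary Python function, higher-order derivatives can be built by applying it in a loop. The helper `nth_derivative` below is a hypothetical convenience function for illustration, not part of JAX.

```python
import jax


def nth_derivative(f, n):
    """Hypothetical helper: apply jax.grad n times to a scalar function f."""
    for _ in range(n):
        f = jax.grad(f)
    return f


def f(x):
    return x**3


x = 2.0
for n in range(4):
    # order 0 is f itself; orders 1-3 are 3x^2, 6x, and 6
    print(f"order {n}: {nth_derivative(f, n)(x)}")
```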


### Gradients of Multivariate Functions

For functions of multiple variables, `grad` computes partial derivatives. By default, it differentiates with respect to the first argument.

In [14]:
# f(x, y) = x^2 * y + y^3
# df/dx = 2xy
# df/dy = x^2 + 3y^2
def f(x, y):
    return x**2 * y + y**3


x, y = 2.0, 3.0

# Gradient with respect to x (first argument, argnums=0 is default)
f_x = jax.grad(f, argnums=0)(x, y)

# Gradient with respect to y (second argument)
f_y = jax.grad(f, argnums=1)(x, y)

print(f"f({x}, {y}) = {f(x, y)}")
print(f"df/dx = {f_x}")  # 2 * 2 * 3 = 12
print(f"df/dy = {f_y}")  # 4 + 27 = 31

f(2.0, 3.0) = 39.0
df/dx = 12.0
df/dy = 31.0
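`argnums` also accepts a tuple. In that case `jax.grad` returns a tuple with one partial derivative per listed argument, so both partials above come from a single call:

```python
import jax


def f(x, y):
    return x**2 * y + y**3


# argnums=(0, 1) returns (df/dx, df/dy)
f_x, f_y = jax.grad(f, argnums=(0, 1))(2.0, 3.0)
print(f"df/dx = {f_x}, df/dy = {f_y}")  # 12.0 and 31.0, as above
```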


### Second-Order Partial Derivatives

For PDEs, we often need second derivatives like $\frac{\partial^2 f}{\partial x^2}$:

In [15]:
# f(x, y) = sin(x) * cos(y)
# d2f/dx2 = -sin(x) * cos(y)
# d2f/dy2 = -sin(x) * cos(y)
# d2f/dxdy = -cos(x) * sin(y)
def f(x, y):
    return jnp.sin(x) * jnp.cos(y)


x, y = jnp.pi / 4, jnp.pi / 3

# Second derivative with respect to x
f_xx = jax.grad(jax.grad(f, argnums=0), argnums=0)(x, y)

# Second derivative with respect to y
f_yy = jax.grad(jax.grad(f, argnums=1), argnums=1)(x, y)

# Mixed derivative
f_xy = jax.grad(jax.grad(f, argnums=0), argnums=1)(x, y)

print(f"f(x, y) = {f(x, y):.6f}")
print(f"d2f/dx2 = {f_xx:.6f}")
print(f"d2f/dy2 = {f_yy:.6f}")
print(f"d2f/dxdy = {f_xy:.6f}")

# Verify: d2f/dx2 = -sin(x)*cos(y)
print(f"\nAnalytical d2f/dx2 = {-jnp.sin(x) * jnp.cos(y):.6f}")

f(x, y) = 0.353553
d2f/dx2 = -0.353553
d2f/dy2 = -0.353553
d2f/dxdy = -0.612372

Analytical d2f/dx2 = -0.353553
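Many PDEs (Poisson, heat, wave) involve the Laplacian $\nabla^2 f = \frac{\partial^2 f}{\partial x^2} + \frac{\partial^2 f}{\partial y^2}$. One way to get it, sketched below, is to pack the point into a vector `p = [x, y]` (a choice made here for illustration), compute the full Hessian with `jax.hessian`, and take its trace. For $f = \sin(x)\cos(y)$ the result should equal $-2\sin(x)\cos(y)$. For high-dimensional inputs, forming the whole Hessian may be more work than needed, so treat this as a small-dimension sketch.

```python
import jax
import jax.numpy as jnp


def f_vec(p):
    """f(x, y) = sin(x) * cos(y), with the point packed as p = [x, y]."""
    x, y = p
    return jnp.sin(x) * jnp.cos(y)


p = jnp.array([jnp.pi / 4, jnp.pi / 3])

H = jax.hessian(f_vec)(p)   # 2x2 matrix [[f_xx, f_xy], [f_yx, f_yy]]
laplacian = jnp.trace(H)    # f_xx + f_yy

print("Hessian:\n", H)
print(f"Laplacian = {laplacian:.6f}")  # analytically -2 * sin(x) * cos(y)
```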


## `jax.value_and_grad`: Get Both Value and Gradient

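Often we need both the function value and its gradient (e.g., for optimization). Calling the function and `jax.grad` separately would evaluate the function twice, since computing the gradient already requires a forward pass. `jax.value_and_grad` returns both from one forward and backward pass: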

In [16]:
def loss(x):
    return jnp.sum(x**2)


x = jnp.array([1.0, 2.0, 3.0])

# Create a function that returns (value, gradient), and evaluate it
value, gradient = jax.value_and_grad(loss)(x)

print(f"x = {x}")
print(f"loss(x) = {value}")
print(f"grad(loss)(x) = {gradient}")  # Should be 2*x

x = [1. 2. 3.]
loss(x) = 14.0
grad(loss)(x) = [2. 4. 6.]


This is particularly useful in training loops where we need the loss value for monitoring and the gradient for updating parameters.
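A minimal sketch of such a loop, using plain gradient descent on the `loss` from the cell above. The learning rate and number of steps are arbitrary values chosen for illustration, not tuned settings.

```python
import jax
import jax.numpy as jnp


def loss(x):
    return jnp.sum(x**2)


learning_rate = 0.1  # illustrative value
num_steps = 50       # illustrative value

loss_and_grad = jax.value_and_grad(loss)
x = jnp.array([1.0, 2.0, 3.0])

for step in range(num_steps):
    value, g = loss_and_grad(x)   # loss for monitoring, gradient for the update
    x = x - learning_rate * g     # gradient descent step
    if step % 10 == 0:
        print(f"step {step:2d}: loss = {value:.6f}")

print("final x:", x)  # each step scales x by 0.8, so x approaches 0
```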

## JIT Compilation with `jax.jit`

`jax.jit` compiles functions using XLA (Accelerated Linear Algebra) for faster execution. The first call is slower because it includes tracing and compilation; later calls with the same input shapes and dtypes reuse the compiled code and are much faster.

In [17]:
def slow_function(x):
    """A function with many operations."""
    for _ in range(100):
        x = x @ x.T
        x = x / jnp.linalg.norm(x)
    return x


# JIT-compile the function
fast_function = jax.jit(slow_function)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])

In [18]:
# Time the non-jitted version
%timeit slow_function(x).block_until_ready()

1.58 ms ± 72.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [19]:
# First call includes compilation time
_ = fast_function(x)

# Time the jitted version (after compilation)
%timeit fast_function(x).block_until_ready()

24.3 μs ± 2.12 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
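The compiled code is cached by input shape and dtype. One way to observe this, sketched below, is a Python side effect (a counter) inside the function: it only runs while JAX traces the function, not when cached compiled code is reused. Calling with a new shape triggers a fresh trace and compilation. The counter is purely a diagnostic device for this example.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def square(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    return x**2


square_jit = jax.jit(square)

square_jit(jnp.ones(3))  # first call: trace and compile
square_jit(jnp.ones(3))  # same shape and dtype: compiled code is reused
print("traces after two calls with shape (3,):", trace_count)

square_jit(jnp.ones(4))  # new shape: traced and compiled again
print("traces after a call with shape (4,):", trace_count)
```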


As a shortcut, we usually apply `jax.jit` as a decorator (`@jax.jit`) when defining a function we want to compile:

In [20]:
@jax.jit
def fast_function(x):
    """A function with many operations."""
    for _ in range(100):
        x = x @ x.T
        x = x / jnp.linalg.norm(x)
    return x
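One caveat: inside a jitted function, arguments are abstract traced values, so a Python `if` that depends on an argument's value fails with a `jax.errors.ConcretizationTypeError`. (The `for` loop above is fine because its length does not depend on `x`.) The sketch below shows the failure and a common fix with `jnp.where`. Other options, not shown here, are `jax.lax.cond` or marking an argument as static with `jax.jit(..., static_argnums=...)`.

```python
import jax
import jax.numpy as jnp


def relu_python(x):
    # A Python `if` needs a concrete bool, which a traced value cannot provide
    if x > 0:
        return x
    return 0.0


@jax.jit
def relu_jnp(x):
    # jnp.where selects element-wise and works on traced values
    return jnp.where(x > 0, x, 0.0)


try:
    jax.jit(relu_python)(2.0)
    failed = False
except jax.errors.ConcretizationTypeError:
    failed = True

print("Python `if` under jit failed:", failed)
print("relu_jnp(2.0) =", relu_jnp(2.0), " relu_jnp(-1.0) =", relu_jnp(-1.0))
```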

### Combining `jit` with `grad`

JAX transformations compose nicely. You can JIT-compile a gradient function:

In [21]:
def loss(params, x):
    """A simple loss function."""
    return jnp.sum((params - x) ** 2)


# JIT-compiled gradient function
grad_loss = jax.jit(jax.grad(loss))

# Or equivalently
grad_loss = jax.grad(loss)
grad_loss = jax.jit(grad_loss)

params = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([0.0, 0.0, 0.0])
print(f"Gradient: {grad_loss(params, x)}")

Gradient: [2. 4. 6.]


## Vectorization with `jax.vmap`

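`jax.vmap` (vectorizing map) automatically converts a function written for a single example into one that operates on a batch. This is typically much faster than a Python loop, because the batched function runs as one vectorized computation instead of one call per example.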

In [22]:
# A function that operates on a single vector
def normalize(x):
    """Normalize a single vector."""
    return x / jnp.linalg.norm(x)


# Batch of vectors (each row is a vector)
batch = jnp.array([[3.0, 4.0], [1.0, 0.0], [0.0, 5.0], [1.0, 1.0]])

# Use vmap to apply normalize to each row
batch_normalize = jax.vmap(normalize)
result = batch_normalize(batch)

print("Batch of vectors:")
print(batch)
print("\nNormalized (using vmap):")
print(result)

Batch of vectors:
[[3. 4.]
 [1. 0.]
 [0. 5.]
 [1. 1.]]

Normalized (using vmap):
[[0.6        0.8       ]
 [1.         0.        ]
 [0.         1.        ]
 [0.70710677 0.70710677]]


In [24]:
# The slow way: Python loop
result_loop = jnp.array([normalize(batch[i]) for i in range(len(batch))])
print("Using loop:")
print(result_loop)

Using loop:
[[0.6        0.8       ]
 [1.         0.        ]
 [0.         1.        ]
 [0.70710677 0.70710677]]


In [25]:
# The fast way: vmap
batch_normalize = jax.vmap(normalize)
result_vmap = batch_normalize(batch)
print("Using vmap:")
print(result_vmap)

Using vmap:
[[0.6        0.8       ]
 [1.         0.        ]
 [0.         1.        ]
 [0.70710677 0.70710677]]
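`vmap` combines naturally with `grad`, which is how PINN-style code typically evaluates derivatives at many points. The `in_axes` argument says which inputs are batched: below, the parameter `w` is shared across the batch (`None`) while the points `x` are mapped over axis 0. The model `w * x**2` is a hypothetical stand-in for a network.

```python
import jax
import jax.numpy as jnp


def model(w, x):
    """Hypothetical scalar model u(x; w) = w * x**2."""
    return w * x**2


# du/dx at a single point (argnums=1 differentiates with respect to x)
du_dx = jax.grad(model, argnums=1)

# Share w across the batch (None) and map over the points x (axis 0)
du_dx_batch = jax.vmap(du_dx, in_axes=(None, 0))

xs = jnp.array([0.0, 1.0, 2.0])
print(du_dx_batch(3.0, xs))  # 2 * w * x, i.e. 0, 6 and 12
```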


## Random Number Generation

JAX handles random numbers differently from NumPy. Instead of a global random state, JAX uses explicit **keys** that you pass to random functions. This makes randomness reproducible and compatible with JIT compilation.

In [26]:
from jax import random

# Create a random key (like setting a seed)
key = random.key(42)
print("Key:", key)

Key: Array((), dtype=key<fry>) overlaying:
[ 0 42]


In [27]:
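# Reusing `key` for two different draws is for illustration only;
# in practice, split the key before each new draw (see below)
uniform_samples = random.uniform(key, shape=(5,))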
print("Uniform [0, 1):", uniform_samples)

normal_samples = random.normal(key, shape=(5,))
print("Standard normal:", normal_samples)

Uniform [0, 1): [0.48870957 0.6797972  0.6162715  0.5610161  0.4506446 ]
Standard normal: [-0.02830462  0.46713185  0.29570296  0.15354592 -0.12403282]


**Important**: Using the same key gives the same random numbers. To get different random numbers, you need to **split** the key:

In [28]:
# Same key = same numbers
print("Same key, call 1:", random.normal(key, shape=(3,)))
print("Same key, call 2:", random.normal(key, shape=(3,)))

# Split the key to get new keys
key, subkey = random.split(key)
print("\nAfter splitting:")
print("Using subkey:", random.normal(subkey, shape=(3,)))

Same key, call 1: [-0.02830462  0.46713185  0.29570296]
Same key, call 2: [-0.02830462  0.46713185  0.29570296]

After splitting:
Using subkey: [ 0.60576403  0.7990441  -0.908927  ]


In [29]:
# You can split into multiple keys at once
key, *subkeys = random.split(key, num=4)  # Get 3 subkeys + 1 new main key
print("Generated 3 different samples:")
for i, sk in enumerate(subkeys):
    print(f"  subkey {i}: {random.normal(sk, shape=(2,))}")

Generated 3 different samples:
  subkey 0: [-0.21089035 -1.3627948 ]
  subkey 1: [-1.8259704  -0.40702963]
  subkey 2: [-1.0296261  0.3765022]
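A typical use of splitting is initializing model parameters, where each layer should get its own independent randomness. The helper `init_mlp` below is a hypothetical sketch: it splits one subkey per layer and uses normal weights scaled by `sqrt(2 / n_in)` (an assumed He-style choice) with zero biases. Reusing the same seed reproduces the same parameters.

```python
import jax.numpy as jnp
from jax import random


def init_mlp(key, layer_sizes):
    """Hypothetical helper: one subkey per layer, scaled normal weights, zero biases."""
    keys = random.split(key, len(layer_sizes) - 1)
    params = []
    for k, n_in, n_out in zip(keys, layer_sizes[:-1], layer_sizes[1:]):
        w = random.normal(k, (n_in, n_out)) * jnp.sqrt(2.0 / n_in)
        b = jnp.zeros(n_out)
        params.append((w, b))
    return params


params = init_mlp(random.key(0), [2, 16, 16, 1])
for w, b in params:
    print("weights", w.shape, "biases", b.shape)

# Same seed, same parameters
params_again = init_mlp(random.key(0), [2, 16, 16, 1])
```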


## Summary

| Feature | What it does | Example |
|---------|-------------|--------|
| `jax.numpy` | NumPy-compatible array operations | `jnp.sin(x)`, `jnp.array([1,2,3])` |
| `.at[].set()` | Immutable array updates | `x.at[0].set(10)` |
| `jax.grad` | Automatic differentiation | `jax.grad(f)(x)` |
| `jax.value_and_grad` | Get value and gradient together | `val, grad = jax.value_and_grad(f)(x)` |
| `jax.jit` | Compile for speed | `jax.jit(f)(x)` |
| `jax.vmap` | Vectorize over batches | `jax.vmap(f)(batch)` |
| `random.key` | Create random key | `key = random.key(42)` |
| `random.split` | Split key for new randomness | `key, subkey = random.split(key)` |

These tools compose arbitrarily: `jax.jit(jax.vmap(jax.grad(f)))` is perfectly valid!
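As a PINN-flavored illustration of that composition, the sketch below computes the residual of the ODE $u'' + u = 0$ at several collocation points: `grad` (applied twice) gives the second derivative at one point, `vmap` evaluates it over all points, and `jit` compiles the whole thing. The ODE and the candidate solution $u(x) = \sin(x)$ are choices made for this example; since $\sin$ solves the ODE exactly, the residual should be zero up to float32 round-off.

```python
import jax
import jax.numpy as jnp


def u(x):
    """Candidate solution; sin(x) satisfies u'' + u = 0 exactly."""
    return jnp.sin(x)


def residual(x):
    """Residual u''(x) + u(x) at a single point."""
    u_xx = jax.grad(jax.grad(u))(x)
    return u_xx + u(x)


# grad -> per-point second derivative, vmap -> many points, jit -> compiled
residual_batch = jax.jit(jax.vmap(residual))

xs = jnp.linspace(0.0, 2 * jnp.pi, 8)
r = residual_batch(xs)
print("max |residual|:", jnp.max(jnp.abs(r)))  # close to 0
```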