Remember to create a virtual environment and install the requirements:

```
python3 -m venv .venv
source .venv/bin/activate
pip install -r reqs.txt
```

# Introduction to JAX

In short, JAX is an array-oriented numerical computing library built around composable function transformations. These include just-in-time (JIT) compilation, automatic vectorization and automatic differentiation. Because it compiles code with XLA (Accelerated Linear Algebra), the same program can run natively on CPUs, GPUs and TPUs.

```python
import jax
import jax.numpy as jnp

# List the devices JAX can see (CPU, GPU or TPU)
print("JAX devices:", jax.devices())

# Define a simple computation
def compute(x):
    return jnp.dot(x, x.T) + jnp.sum(x)

# Compile once; a jitted function runs on the device its inputs are placed on
compute_jit = jax.jit(compute)

# Create input data
x = jnp.ones((3, 3))

# Run on every available platform by placing the input on that platform's device
for platform in ("cpu", "gpu", "tpu"):
    try:
        device = jax.devices(platform)[0]
    except RuntimeError:
        continue  # this backend is not available on the current machine
    result = compute_jit(jax.device_put(x, device))
    print(f"Result on {platform.upper()}:", result)
```

```
JAX devices: [CpuDevice(id=0)]
Result on CPU: [[12. 12. 12.]
 [12. 12. 12.]
 [12. 12. 12.]]
```
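The transformations mentioned at the start compose freely. As a minimal sketch (the function `loss` and its inputs are made up for illustration), `jax.grad` computes a gradient, `jax.vmap` maps that gradient over a batch of inputs, and `jax.jit` compiles the composed result with XLA:

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar-valued function of a weight vector w and an input vector x
def loss(w, x):
    return jnp.sum(jnp.tanh(w * x))

# grad differentiates with respect to the first argument (w)
grad_loss = jax.grad(loss)

# vmap maps grad_loss over axis 0 of x, keeping w fixed (in_axes=None)
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# jit compiles the whole composed function
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([0.5, -1.0, 2.0])
xs = jnp.ones((4, 3))  # a batch of 4 input vectors

grads = fast_batched_grad(w, xs)
print(grads.shape)  # (4, 3)
```

Since the derivative of `sum(tanh(w * x))` with respect to `w` is `x * (1 - tanh(w * x)**2)`, every row of `grads` here equals `1 - tanh(w)**2`.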


JAX offers a NumPy-inspired array API, so `jax.numpy` can often be used as a near drop-in replacement for NumPy. The most notable difference between `numpy` arrays (usually imported as `np`) and `jax.numpy` arrays (usually imported as `jnp`) is that JAX arrays are always immutable.

```python
import numpy as np
import jax.numpy as jnp

# Create a NumPy array
np_array = np.array([1, 2, 3])
print("Original NumPy array:", np_array)
np_array[0] = 10  # Mutating a NumPy array is allowed
print("Modified NumPy array:", np_array)

# Create a JAX array
jnp_array = jnp.array([1, 2, 3])

try:
    jnp_array[0] = 10  # Attempting to mutate a JAX array raises a TypeError
except TypeError as e:
    print("\t[!] JAX array mutation error:", e)

# Instead of mutating, we create a new array
new_jnp_array = jnp_array.at[0].set(10)
print("Original JAX array:", jnp_array)
print("New JAX array after modification:", new_jnp_array)
```

```
Original NumPy array: [1 2 3]
Modified NumPy array: [10  2  3]
	[!] JAX array mutation error: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
Original JAX array: [1 2 3]
New JAX array after modification: [10  2  3]
```
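Besides `.set`, the `.at[]` property offers other functional updates such as `.add`, `.multiply` and `.min`, each returning a new array. One difference from NumPy worth knowing: with repeated indices, `.at[].add` applies every update, whereas NumPy's `x[idx] += y` keeps only the last one. A short sketch (the array values are arbitrary):

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Each .at[...] method returns a new array; x itself is never modified
added = x.at[1].add(5)          # -> [1, 7, 3]
scaled = x.at[:2].multiply(10)  # -> [10, 20, 3]
clipped = x.at[2].min(0)        # x[2] = min(x[2], 0) -> [1, 2, 0]

# Repeated indices accumulate every update: index 0 receives +1 twice
accumulated = x.at[jnp.array([0, 0])].add(1)  # -> [3, 2, 3]

print(x, added, scaled, clipped, accumulated)
```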


All JAX operations are ultimately expressed as XLA operations, so whenever we use `jnp` the computation is handed to the XLA compiler. The lower-level primitives that `jnp` is built on are exposed through `jax.lax`, a more powerful but stricter API.

In essence, `jnp` is a high-level wrapper designed to mirror NumPy, while `jax.lax` contains the underlying low-level operations. Here are two equivalent ways to multiply two matrices:

```python
import jax.numpy as jnp

# Define input matrices
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])

# High-level matrix multiplication
def high_level(x, y):
    return jnp.dot(x, y)

C = high_level(A, B)
print("High-level (jnp.dot) result:\n", C)
```

```
High-level (jnp.dot) result:
 [[19 22]
 [43 50]]
```


```python
from jax import lax
import jax.numpy as jnp

# Define input matrices
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])

def low_level(x, y):
    return lax.dot_general(
        x, y,
        dimension_numbers=(((1,), (0,)), ((), ()))  # Batch dimensions are empty
    )

print("Low-level (lax.dot_general) result:\n", low_level(A, B))
```

```
Low-level (lax.dot_general) result:
 [[19 22]
 [43 50]]
```
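In `dimension_numbers`, the first pair lists the contracting axes of each operand and the second pair lists the batch axes, which are empty in the example above. As a sketch of how batch axes work, here is a batched matrix multiply that treats axis 0 of both operands as the batch axis; the shapes and values are arbitrary choices for illustration:

```python
import jax.numpy as jnp
from jax import lax

# Batches of 2x2 matrices with shape (batch, rows, cols) = (3, 2, 2)
X = jnp.arange(12.0).reshape(3, 2, 2)
Y = jnp.ones((3, 2, 2))

# dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
# Contract axis 2 of X with axis 1 of Y; axis 0 of both is a batch axis
Z = lax.dot_general(X, Y, dimension_numbers=(((2,), (1,)), ((0,), (0,))))

# Output axes are ordered batch, lhs free, rhs free: the same layout as jnp.matmul
print(Z.shape)                           # (3, 2, 2)
print(jnp.allclose(Z, jnp.matmul(X, Y)))  # True
```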


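We can use `jax.make_jaxpr` to print the jaxpr (short for "JAX expression", JAX's intermediate representation) of each version and see how the `jnp.dot` and `lax.dot_general` implementations translate to low-level code: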

```python
jaxpr_high_level = jax.make_jaxpr(high_level)(A, B)
print("JAXpr for jnp.dot:\n", jaxpr_high_level)
```

```
JAXpr for jnp.dot:
 { lambda ; a:i32[2,2] b:i32[2,2]. let
    c:i32[2,2] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=int32
    ] a b
  in (c,) }
```


```python
jaxpr_low_level = jax.make_jaxpr(low_level)(A, B)
print("JAXpr for lax.dot_general:\n", jaxpr_low_level)
```

```
JAXpr for lax.dot_general:
 { lambda ; a:i32[2,2] b:i32[2,2]. let
    c:i32[2,2] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a b
  in (c,) }
```


JAX implements its transformations with tracers: abstract placeholders that stand in for arrays and are passed through a function so that JAX can record the sequence of operations the function encodes.


```python
@jax.jit
def f(x):
    print("x = ", x)
    return x + 1

x = jnp.arange(5)
y = f(x)
print("f(x) = ", y)
```

```
x =  Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace>
f(x) =  [1 2 3 4 5]
```


The printed value is not the array `x` itself but a tracer with the same shape and dtype. By running the function on tracers, JAX records the sequence of operations the function performs before any of them is actually executed. Transformations can then map this recorded sequence of operations to a transformed one.
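Because the Python body of a jitted function runs only while JAX traces it, side effects such as `print` fire at trace time, and the compiled program is reused for later calls with the same input shapes and dtypes. The sketch below makes this visible with a hypothetical counter `trace_count`, and uses `jax.debug.print`, which is staged into the compiled program and therefore prints runtime values on every call:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, incremented only while JAX traces g

@jax.jit
def g(x):
    global trace_count
    trace_count += 1                      # Python side effect: runs at trace time only
    jax.debug.print("runtime x = {}", x)  # staged into the program: runs on every call
    return x * 2

g(jnp.arange(3))       # first call: trace and compile (trace_count -> 1)
g(jnp.arange(3) + 10)  # same shape and dtype: cached program reused, no retrace
g(jnp.arange(4))       # new shape int32[4]: traced and compiled again (trace_count -> 2)
print("number of traces:", trace_count)
```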

## Just-in-Time Compilation

## Automatic Vectorization

## Automatic Differentiation


# Using JAX with PennyLane

## Shots and Samples