Remember to create a virtual environment and install the requirements:

```
python3 -m venv .venv
source .venv/bin/activate
pip install -r reqs.txt
```

In [1]:
import jax
import jax.numpy as jnp
import numpy as np

# Introduction to JAX

In short, JAX is an array-oriented numerical computing library that supports composable function transformations, including just-in-time (JIT) compilation, automatic vectorization and automatic differentiation. Because it builds on the XLA (Accelerated Linear Algebra) compiler, the same code can run natively on CPUs, GPUs and TPUs.

In [2]:
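# Check available devices (jax.devices() lists the default backend's devices)
print("JAX devices:", jax.devices())

def compute(x):
    return jnp.dot(x, x.T) + jnp.sum(x)

x = jnp.ones((3, 3))
compute_jit = jax.jit(compute)

# A jitted function runs on the device that holds its (committed) inputs,
# so we pick the target hardware by placing x on a specific device.

# Run on CPU
cpu_result = compute_jit(jax.device_put(x, jax.devices("cpu")[0]))
print("Result on CPU:", cpu_result)

# Run on GPU (if available); device.platform is "cpu", "gpu" or "tpu"
if any(d.platform == "gpu" for d in jax.devices()):
    gpu_result = compute_jit(jax.device_put(x, jax.devices("gpu")[0]))
    print("Result on GPU:", gpu_result)

# Run on TPU (if available)
if any(d.platform == "tpu" for d in jax.devices()):
    tpu_result = compute_jit(jax.device_put(x, jax.devices("tpu")[0]))
    print("Result on TPU:", tpu_result)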

JAX devices: [CpuDevice(id=0)]
Result on CPU: [[12. 12. 12.]
 [12. 12. 12.]
 [12. 12. 12.]]


JAX provides an array interface that mimics NumPy and can often be used as a drop-in replacement for it. The most notable difference between \texttt{numpy} arrays (usually imported as \texttt{np}) and \texttt{jax.numpy} arrays (usually imported as \texttt{jnp}) is that the latter are always immutable.\newpage

In [3]:
# Create a NumPy array
np_array = np.array([1, 2, 3])
print("Original NumPy array:", np_array)
np_array[0] = 10  # Mutating a NumPy array is allowed
print("Modified NumPy array:", np_array)

# Create a JAX array
jnp_array = jnp.array([1, 2, 3])

try:
    jnp_array[0] = 10  # Attempting to mutate a JAX array
except TypeError as e:
    print("\n[!] JAX array mutation error:", e)

# Instead of mutating, we create a new array
new_jnp_array = jnp_array.at[0].set(10)
print("\nOriginal JAX array:", jnp_array)
print("New JAX array after modification:", new_jnp_array)


Original NumPy array: [1 2 3]
Modified NumPy array: [10  2  3]

[!] JAX array mutation error: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

Original JAX array: [1 2 3]
New JAX array after modification: [10  2  3]
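Besides `.set`, the `.at` property offers other out-of-place updates. The sketch below assumes a reasonably recent JAX version, where `add`, `multiply` and `max` are available on `.at[...]`. It also shows one behavior that differs from NumPy: with repeated indices, JAX applies every update, while NumPy's `x[idx] += y` applies the update to a repeated position only once.

```python
import jax.numpy as jnp
import numpy as np

x = jnp.array([1, 2, 3])

added = x.at[1].add(5)          # [1 7 3]
scaled = x.at[1:].multiply(10)  # [ 1 20 30]
floored = x.at[:].max(2)        # element-wise max(x, 2): [2 2 3]
print(added, scaled, floored)

# Repeated indices: JAX accumulates all updates...
idx = np.array([0, 0, 2])
jax_counts = jnp.zeros(3).at[idx].add(1.0)  # [2. 0. 1.]

# ...while NumPy's fancy-index += applies the update at index 0 only once
np_counts = np.zeros(3)
np_counts[idx] += 1.0                       # [1. 0. 1.]
print(jax_counts, np_counts)

# The original array x is untouched by all of the above
print(x)  # [1 2 3]
```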


JAX operations are ultimately lowered to XLA, so whenever we use \texttt{jnp} we benefit from the XLA compiler. JAX also exposes a lower-level API, \texttt{jax.lax}, which contains wrappers for primitive XLA operations; all \texttt{jnp} operations are implemented in terms of \texttt{jax.lax}.

JAX has an internal representation of programs called the jaxpr language. We can use the \texttt{jax.make\_jaxpr} function to obtain the jaxpr of a function, which is useful for finding out how high-level functions are lowered to primitive operations. For instance, we can see how \texttt{jnp.dot} translates to the more general \texttt{jax.lax.dot\_general}. \newpage


In [7]:
def high_level(x, y):
    return jnp.dot(x, y)

def low_level(x, y):
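    return jax.lax.dot_general(
        x, y,
        # ((contracting dims of x, of y), (batch dims of x, of y)):
        # contract axis 1 of x with axis 0 of y; there are no batch dimensions
        dimension_numbers=(((1,), (0,)), ((), ()))
    )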
    
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
jaxpr_high_level = jax.make_jaxpr(high_level)(A, B)
jaxpr_low_level = jax.make_jaxpr(low_level)(A, B)

print("JAXpr for jnp.dot:\n", jaxpr_high_level, "\n")
print("JAXpr for jax.lax.dot_general:\n", jaxpr_low_level)

JAXpr for jnp.dot:
 { lambda ; a:i32[2,2] b:i32[2,2]. let
    c:i32[2,2] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=int32
    ] a b
  in (c,) } 

JAXpr for jax.lax.dot_general:
 { lambda ; a:i32[2,2] b:i32[2,2]. let
    c:i32[2,2] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a b
  in (c,) }
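Because tracing actually executes the Python code of the function, ordinary Python loops are unrolled: the jaxpr records the operations of each iteration rather than a loop. The `repeated_affine` function below is a hypothetical illustration; comparing the number of jaxpr equations for 3 and 5 iterations makes the unrolling visible.

```python
from functools import partial

import jax

def repeated_affine(x, n):
    # Plain Python loop: it runs while tracing, so it is unrolled into the jaxpr
    acc = x
    for _ in range(n):
        acc = acc * x + 1.0
    return acc

jaxpr_3 = jax.make_jaxpr(partial(repeated_affine, n=3))(1.0)
jaxpr_5 = jax.make_jaxpr(partial(repeated_affine, n=5))(1.0)
print(jaxpr_3)

# More iterations -> more equations; no loop primitive appears in the jaxpr
print(len(jaxpr_3.jaxpr.eqns), len(jaxpr_5.jaxpr.eqns))
```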


Along with a NumPy-like API of functions that operate on arrays, JAX also includes a number of composable transformations that operate on functions. The ones we are most interested in are:
- `jax.jit`: just-in-time compilation.
- `jax.vmap`: automatic vectorization.
- `jax.grad`: automatic differentiation.

To apply transformations, JAX traces the function: it replaces the array inputs with abstract placeholders that have the same shape and dtype, and runs the function on them. This lets JAX record the sequence of operations the function applies to its inputs, independently of their concrete values.

We can see how JAX sees traced arrays by printing inside a function subject to a transform:

In [8]:
@jax.jit
def f(x):
  print("inside the function we see x as ", x)
  return x + 1

x = jnp.arange(5)
print("outside the function we see x as ", x)
y = f(x)
print("f(x) = ", y)

outside the function we see x as  [0 1 2 3 4]
inside the function we see x as  Traced<ShapedArray(int32[5])>with<DynamicJaxprTrace>
f(x) =  [1 2 3 4 5]


The printed value we see inside the function is not the array $x$, but rather an abstract traced representation that has the same shape and type. In this case, the JIT transform has used the traced information to create a compiled version of $f$. 
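Since the compiled code depends only on the shapes and dtypes seen during tracing, JAX can reuse it for any input with the same signature. We can observe this with a Python side effect, which runs only while tracing. `trace_log` and `double` are hypothetical names used only for this sketch.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records one entry per trace

@jax.jit
def double(x):
    # Python side effects run only while JAX traces the function,
    # not when a cached compiled version is reused
    trace_log.append((x.shape, str(x.dtype)))
    return 2 * x

double(jnp.ones(3))       # first call: traced and compiled
double(jnp.zeros(3))      # same shape and dtype: cached code reused, no new trace
double(jnp.ones((2, 2)))  # new shape: traced again
print(trace_log)          # two entries: shape (3,) and shape (2, 2)
```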

## Just-in-Time Compilation

By default, JAX executes operations eagerly, dispatching each operation individually to XLA without ahead-of-time compilation. The \texttt{jax.jit} transform leverages JAX’s tracing mechanism to capture entire computations, allowing the XLA compiler to optimize, fuse, and compile sequences of operations into a single efficient execution.
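As a minimal sketch, we can JIT-compile a function made of several element-wise operations (a SELU activation, chosen here only as an illustration) and check that it returns the same values as the eager version. Speed-ups depend on hardware and input size, so the sketch only suggests how to time it. Since JAX dispatches work asynchronously, we call `block_until_ready()` before using or timing a result.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Several element-wise ops; run eagerly, each one is dispatched separately
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# Traced and compiled into a single XLA computation on the first call
selu_jit = jax.jit(selu)

x = jnp.linspace(-3.0, 3.0, 1_000_000)
eager_out = selu(x)
jit_out = selu_jit(x).block_until_ready()

# In a notebook, compare the two with:
#   %timeit selu(x).block_until_ready()
#   %timeit selu_jit(x).block_until_ready()
print(jnp.allclose(eager_out, jit_out, atol=1e-6))
```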

JIT compilation is very powerful, but it has some limitations. In particular, Python control flow inside a jitted function must be resolvable at trace time. This means that the following function will not work with JIT, because its loop condition depends on the value of \texttt{n}:

In [8]:
def g(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

try:
    jax.jit(g)(10, 20)  # Raises an error
except TypeError as e:
    print("\t[!] JAX JIT error:", e)

	[!] JAX JIT error: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function g at /var/folders/6q/mzgmvrhn3l76p312t6zn3mxw0000gn/T/ipykernel_2089/1615799189.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument n.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError


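These errors arise because, under JIT, Python branching statements are only allowed when they depend on static attributes, such as an array's shape or dtype, which are known at trace time. In \texttt{g}, the loop condition depends on the value of \texttt{n}, which during tracing is only an abstract placeholder.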
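By contrast, branching on a static attribute works fine. In this sketch, the hypothetical `normalize` function picks a code path based on `x.ndim`, which is known at trace time; each distinct input rank simply gets its own trace.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    # x.ndim is a static attribute known at trace time,
    # so this ordinary Python `if` works under jit
    if x.ndim == 1:
        return x / jnp.sum(x)
    # For 2-D input, normalize each row separately
    return x / jnp.sum(x, axis=-1, keepdims=True)

vec = normalize(jnp.array([1.0, 3.0]))                 # [0.25 0.75]
mat = normalize(jnp.array([[1.0, 1.0], [1.0, 3.0]]))   # each row sums to 1
print(vec)
print(mat)
```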

One way to work around value-dependent control flow is to compile only certain parts of the function:

In [9]:
# Python while loop conditioned on i and n; only the loop body is jitted.

@jax.jit
def loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted(x, n):
  i = 0
  while i < n:
    i = loop_body(i)
  return x + i

x = g_inner_jitted(10, 20)
print(x)

30


If we really need to compile the whole function, we can mark certain arguments as static with \texttt{static\_argnames} (or \texttt{static\_argnums}). The compiled version of the function then depends on the values of the static arguments, so JAX has to recompile the function for every new static value. This is only a good strategy if we know that the static argument takes a limited number of distinct values.

In [10]:
from functools import partial

@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
  i = 0
  while i < n:
    i += 1
  return x + i

x = g_jit_decorated(10, 20)
print(x)

30
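Another option, not used elsewhere in this notebook, is JAX's structured control flow. `jax.lax.while_loop(cond_fun, body_fun, init_val)` stages the whole loop into XLA, so the condition may depend on a traced value and `n` does not need to be static. A sketch of `g` rewritten this way (the name `g_while_loop` is ours):

```python
import jax
import jax.numpy as jnp

@jax.jit
def g_while_loop(x, n):
    # The loop itself is compiled, so i < n can depend on the traced value of n
    i = jax.lax.while_loop(
        lambda i: i < n,   # cond_fun: keep looping while i < n
        lambda i: i + 1,   # body_fun: same role as loop_body above
        jnp.int32(0),      # init_val: starting value of i
    )
    return x + i

print(g_while_loop(10, 20))  # 30
print(g_while_loop(10, 5))   # 15; n is not static, so no recompilation is needed
```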


The first time we call a jitted function, it is traced and compiled, and the resulting XLA code is cached. Subsequent calls with arguments of the same shapes and dtypes reuse the cached code. If we specify static arguments, the cached code is only reused for the same values of those arguments. JAX finds the cached code through the hash of the function, so we should not recreate equivalent functions in ways that change that hash (for example with \texttt{partial} or \texttt{lambda} inside a loop): each new function object misses the cache and triggers a recompilation.

In [11]:
from functools import partial

def unjitted_loop_body(prev_i):
  return prev_i + 1

def g_inner_jitted_partial(x, n):
  i = 0
  while i < n:
    # Don't do this! Each call to partial creates a new
    # function object with a different hash, so every iteration recompiles
    i = jax.jit(partial(unjitted_loop_body))(i)
  return x + i

def g_inner_jitted_lambda(x, n):
  i = 0
  while i < n:
    # Don't do this! The lambda also creates a new
    # function object (with a different hash) on every iteration
    i = jax.jit(lambda x: unjitted_loop_body(x))(i)
  return x + i

def g_inner_jitted_normal(x, n):
  i = 0
  while i < n:
    # this is OK, since JAX can find the
    # cached, compiled function
    i = jax.jit(unjitted_loop_body)(i)
  return x + i

In [12]:
%timeit g_inner_jitted_partial(10, 20).block_until_ready()
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()
%timeit g_inner_jitted_normal(10, 20).block_until_ready()

221 ms ± 9.75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
220 ms ± 6.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
565 μs ± 2.56 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Automatic Vectorization

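Automatic vectorization with `jax.vmap` lets us take a function written for a single input and extend it to operate on a batch of inputs. Instead of looping in Python, the batch dimension is pushed down into the underlying primitive operations, so the batched computation can take advantage of hardware acceleration.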

In [35]:
def activation_score(x):
    # x is a 1D array.
    return jax.nn.tanh(jnp.mean(x))

In [36]:
# Single input example (1D array)
x_single = jnp.array([0.2, 0.5, 0.8])
print("Activation (single sample):", activation_score(x_single))

Activation (single sample): 0.46211717


In [37]:
# Batched inputs: each row represents one training example.
x_batch = jnp.array([
    [0.2, 0.5, 0.85],
    [0.1, -0.3, 0.9],
    [-0.4, 0.7, 0.6]
])

In [38]:
for lx in x_batch:
    print(activation_score(lx))

0.47512326
0.22918896
0.2913126


In [39]:
# Direct call on the batch: jnp.mean averages over all 9 entries,
# so we get a single scalar instead of one score per row
activation_direct = activation_score(x_batch)
print("Direct batch call activation:\n", activation_direct)

Direct batch call activation:
 0.33637553


In [42]:
# Vectorize the function over its first (and only) argument.
# in_axes=0 maps over axis 0 (one call per row),
# in_axes=1 maps over axis 1 (one call per column).
vectorize_along_rows = jax.vmap(activation_score, in_axes=0)
vectorize_along_cols = jax.vmap(activation_score, in_axes=1)

# Now, applying it to the batch returns an activation per sample.
activation_vmap = vectorize_along_rows(x_batch)
print("vmap rows:\n", activation_vmap)

print("vmap cols:\n", vectorize_along_cols(x_batch))

vmap rows:
 [0.47512326 0.22918896 0.2913126 ]
vmap cols:
 [-0.03332099  0.2913126   0.65461576]


Let's look at a slightly more complex example, with several arguments:

In [57]:
def weighted_array_sum(scale: float, weights: jnp.ndarray, epsilon: float):
    return scale * jnp.sum(weights) + epsilon

In [58]:
weighted_array_sum(0.5, jnp.array([1,2,3]), 1e-4)

Array(3.0001, dtype=float32, weak_type=True)

How can we extend this function to a batch of inputs?

In [66]:
scales = jnp.array([0.5, 1.0, 1.5])  # Shape: (3,)
weights = jnp.array([
    [1.0, 0.2, 0.3],
    [0.5, 0.8, 1.0],
    [0.7, 0.1, 0.4]
])
epsilon = 1e-4

# we can vectorize along the first two parameters. we'll interpret weights as a set of vectors and scales as a set of scalars.
# (note that we do not say anything about epsilon because we don't vectorize over it)
vectorized_weighted_array_sum = jax.vmap(weighted_array_sum, in_axes=(0, 0, None))

In [67]:
vectorized_weighted_array_sum(scales, weights, epsilon)

Array([0.7501   , 2.3000998, 1.8001001], dtype=float32)

To achieve this without vmap, we would have to write a manual for loop over the batch.
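For comparison, here is a sketch of that manual loop next to the `vmap` version (redefining the data so the block runs on its own). Both produce the same values, but the loop makes one Python-level call per example.

```python
import jax
import jax.numpy as jnp

def weighted_array_sum(scale, weights, epsilon):
    return scale * jnp.sum(weights) + epsilon

scales = jnp.array([0.5, 1.0, 1.5])
weights = jnp.array([
    [1.0, 0.2, 0.3],
    [0.5, 0.8, 1.0],
    [0.7, 0.1, 0.4],
])
epsilon = 1e-4

# Manual batching: call the function once per row and stack the results
manual = jnp.stack([weighted_array_sum(s, w, epsilon) for s, w in zip(scales, weights)])

# vmap: batch axis 0 for scales and weights, epsilon shared by every call
batched = jax.vmap(weighted_array_sum, in_axes=(0, 0, None))(scales, weights, epsilon)

print(manual)                         # approx. [0.7501 2.3001 1.8001]
print(jnp.allclose(manual, batched))  # True
```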

The `in_axes` parameter specifies, for each input, the axis along which we vectorize; `None` means the argument is not batched and the same value is passed to every call. In this case we don't have much room to play with, but we could vectorize along either the columns or the rows of the weight matrix.

With higher-dimensional input data, we would have more axes to choose from.
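For example, with `in_axes=(0, 1, None)` the weight matrix is sliced along its columns instead of its rows, so the i-th call receives column i together with scale i. A small sketch, redefining the data from above and checking the result against a direct computation:

```python
import jax
import jax.numpy as jnp

def weighted_array_sum(scale, weights, epsilon):
    return scale * jnp.sum(weights) + epsilon

scales = jnp.array([0.5, 1.0, 1.5])
weights = jnp.array([
    [1.0, 0.2, 0.3],
    [0.5, 0.8, 1.0],
    [0.7, 0.1, 0.4],
])
epsilon = 1e-4

# in_axes=1 for weights: the i-th call receives column i of the matrix
per_column = jax.vmap(weighted_array_sum, in_axes=(0, 1, None))(scales, weights, epsilon)

# Direct equivalent: scale each column sum by its scale
expected = scales * weights.sum(axis=0) + epsilon
print(per_column)  # approx. [1.1001 1.1001 2.5501]
```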

There is also the `out_axes` parameter, which specifies the axis of the output along which the batched results are stacked.

In [77]:
def identity(x):
    return x

identity(weights)

Array([[1. , 0.2, 0.3],
       [0.5, 0.8, 1. ],
       [0.7, 0.1, 0.4]], dtype=float32)

In [82]:
print("id_rows(x) = \n", jax.vmap(identity, in_axes=0, out_axes=0)(weights), "\n")
print("id_cols(x) = \n", jax.vmap(identity, in_axes=0, out_axes=1)(weights))

id_rows(x) = 
 [[1.  0.2 0.3]
 [0.5 0.8 1. ]
 [0.7 0.1 0.4]] 

id_cols(x) = 
 [[1.  0.5 0.7]
 [0.2 0.8 0.1]
 [0.3 1.  0.4]]


## Automatic differentiation

The automatic differentiation transform `jax.grad` takes a scalar-valued function and returns a new function that computes its gradient. We can apply it repeatedly to obtain higher-order derivatives, and we can choose with respect to which arguments to differentiate using `argnums`.

In [90]:
def func(x, y):
    return x**2 + y**3

grad_x = jax.grad(func, argnums=0)        # Derivative w.r.t. x
grad_y = jax.grad(func, argnums=1)        # Derivative w.r.t. y
grad   = jax.grad(func, argnums= (0, 1))  # Both

x, y = 2.0, 3.0
print("Gradient w.r.t x:", grad_x(x, y))  # 2*x = 4
print("Gradient w.r.t y:", grad_y(x, y))  # 3*y^2 = 27
print("Gradient both:", grad(x, y))       # (4, 27)

Gradient w.r.t x: 4.0
Gradient w.r.t y: 27.0
Gradient both: (Array(4., dtype=float32, weak_type=True), Array(27., dtype=float32, weak_type=True))
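The cells above only use first derivatives. Since `jax.grad` returns an ordinary Python function, we can apply it again to get higher-order derivatives. A short sketch with $f(x) = x^3$, so $f'(x) = 3x^2$, $f''(x) = 6x$ and $f'''(x) = 6$:

```python
import jax

def cube(x):
    return x ** 3

d1 = jax.grad(cube)                      # 3x^2
d2 = jax.grad(jax.grad(cube))            # 6x
d3 = jax.grad(jax.grad(jax.grad(cube)))  # 6

x = 2.0
print(d1(x), d2(x), d3(x))  # 12.0 12.0 6.0
```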


In [92]:
def loss_fn(params, x):
    w, b = params["w"], params["b"]
    return jnp.sum((w * x + b) ** 2)

params = {"w": 2.0, "b": 1.0}
x = jnp.array([1.0, 2.0, 3.0])

grad_fn = jax.grad(loss_fn)
grads = grad_fn(params, x)
print("Gradients:", grads)  # {'w': ..., 'b': ...}

Gradients: {'b': Array(30., dtype=float32, weak_type=True), 'w': Array(68., dtype=float32, weak_type=True)}
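Because the transformations compose, we can combine all three. The sketch below computes per-example gradients of `loss_fn` by wrapping `jax.grad` in `jax.vmap` (mapping over the examples in `x` while sharing `params`) and compiling the result with `jax.jit`. Summing the per-example gradients should recover the batch gradients printed above.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x):
    w, b = params["w"], params["b"]
    return jnp.sum((w * x + b) ** 2)

params = {"w": 2.0, "b": 1.0}
x = jnp.array([1.0, 2.0, 3.0])

# grad w.r.t. params for ONE example, vmapped over the examples in x
# (in_axes=None: the same params for every example), then jit-compiled
per_example_grads = jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(None, 0)))
g = per_example_grads(params, x)
print(g)  # b: [6, 10, 14], w: [6, 20, 42]

# Summing over examples recovers the batch gradient: b = 30, w = 68
summed = {k: v.sum() for k, v in g.items()}
print(summed)
```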
