# Lab 1: Introduction to JAX
---
Instructor: Dr. Jan Blechschmidt\
Email: Jan.Blechschmidt@math.tu-freiberg.de\
Credits: This tutorial is mainly based on [this tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html) and [the JAX documentation](https://jax.readthedocs.io/en/latest/jax-101/index.html).

---

Other recommended tutorials are:
* [JAX 101](https://jax.readthedocs.io/en/latest/jax-101/index.html) with many subtutorials on individual parts of JAX
* [JAX - The Sharp Bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) discusses the constraints of JAX and how to overcome them
* [JAX for the Impatient](https://flax.readthedocs.io/en/latest/notebooks/jax_for_the_impatient.html) for a quick intro to JAX with a focus on deep learning
* [Flax Basics](https://flax.readthedocs.io/en/latest/notebooks/flax_basics.html) as an introduction to the Flax framework

The following notebook is meant to give a short introduction to JAX. It assumes that you are familiar with basic concepts of Python and NumPy.

## Part A: Motivation

First, you might ask: why should you learn JAX,
especially since there are already so many other deep learning frameworks like [PyTorch](https://pytorch.org/) and [TensorFlow](https://www.tensorflow.org/)?

### Pros

- JAX is a Python library for accelerator-oriented array computation, designed for high-performance numerical computing and large-scale machine learning.
- JAX provides a familiar NumPy-style API for ease of adoption by researchers and engineers.
- JAX includes composable function transformations for
  - just-in-time (JIT) compilation via OpenXLA, an open-source machine learning compiler ecosystem,
  - automatic differentiation,
  - vectorization,
  - parallelization.
- The same code executes on multiple backends including CPU, GPU & TPU.

### Cons

In order to efficiently compile programs just-in-time in JAX, the functions need to be written with certain constraints:
- The functions are not allowed to have side effects, meaning that they must not affect any variable outside of their own scope.
  For instance, in-place operations affect a variable even outside of the function.
  Moreover, stochastic operations such as `torch.rand(...)` change the global state of pseudo-random number generators, which is not allowed in functional JAX (we will see later how JAX handles random number generation).
- JAX compiles the functions based on anticipated shapes of all arrays/tensors in the function.
  This becomes problematic if the shapes or the program flow within the function depend on the values of the tensor.
  For instance, in the operation `y = x[x>3]`, the shape of `y` depends on how many values of `x` are greater than 3.
- Since JAX is typically used in the machine learning context, its default floating-point data type is `float32` (single precision) and not, as one might expect, `float64` (double precision).
  
We will discuss more of these constraints in this and the following notebooks.
However, in most common cases of training neural networks, it is straightforward to write functions within these constraints.
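A quick way to see the `float32` default from the list above, together with a sketch of how double precision can be switched on via the `jax_enable_x64` configuration flag. This assumes the environment variable `JAX_ENABLE_X64` is not already set; since the flag changes the default for all arrays created afterwards, it is usually set once at program start.

```python
import jax
import jax.numpy as jnp

# Without further configuration, floating-point arrays default to single precision.
x = jnp.arange(3.0)
print(x.dtype)

# Double precision has to be enabled explicitly (assumption: JAX_ENABLE_X64 is unset).
jax.config.update("jax_enable_x64", True)
y = jnp.arange(3.0)
print(y.dtype)

# Switch back to the default so that later code behaves as in the rest of this notebook.
jax.config.update("jax_enable_x64", False)
```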

### What else to know before we start?

Throughout our tutorials, we will try to draw comparisons to PyTorch. If you are not familiar with PyTorch, you can ignore these comparisons or take a look at a PyTorch tutorial, e.g. [this tutorial](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial2/Introduction_to_PyTorch.html).
We will also occasionally use PyTorch's data loading utilities together with JAX because of their flexibility.
Furthermore, we use:
- [Flax](https://flax.readthedocs.io/en/latest/) as a neural network library in JAX,
- [Optax](https://optax.readthedocs.io/en/latest/index.html) to implement common deep learning optimizers.

More on these packages later. First, let's get started with some basic JAX operations.

## Part B: First steps

### JAX as NumPy on accelerators

Every deep learning framework has its own API for dealing with data arrays. For example, PyTorch uses `torch.Tensor` as data arrays on which it defines several operations like matrix multiplication, taking the mean of the elements, etc. In JAX, this basic API strongly resembles that of [NumPy](https://numpy.org/) and is even named after it (`jax.numpy`). So, for now, let's think of JAX as NumPy that runs on accelerators. As a first step, let's import JAX and its NumPy API:

In [None]:
import jax
import jax.numpy as jnp
print("Using jax", jax.__version__)

At the time of writing (May 2025), the newest JAX version is `0.6.0`, which supports most of the common NumPy functionality.
Please ensure that your version is at least `0.6.0`, since some parallel computing concepts that we will use in the following weeks only work properly with a recent version.

The NumPy API of JAX is usually imported as `jnp`, to keep a resemblance to NumPy's import as `np`.
In the following subsections, we will discuss the main differences between the classical NumPy API and the one of JAX.

#### Device Arrays

As a first test, let's create some arbitrary arrays like we would do in NumPy. For instance, let's create an array of zeros with shape `[2,5]`.

In [None]:
a = jnp.zeros((2, 5), dtype=jnp.float32)
print(a)

Similarly, we can create an array with the values 0 to 5 by using `arange`:

In [None]:
b = jnp.arange(6)
print(b)

**Task**: What is the data type of the variable `b`? 

In [None]:
# Your code here

You can also specify the datatype when creating most arrays by setting the optional parameter `dtype`.

**Task**: Create a one-dimensional array `c` containing the numbers from 5 to 10 with data type `jnp.float16`. From now on, when writing "array" we mean a **JAX array** and not a NumPy array.

In [None]:
# Your code here

**Task**: What is the class of the variable `b`?

In [None]:
# Your code here

Instead of a simple NumPy array, it shows the type `ArrayImpl` which is what JAX uses to represent arrays.
In contrast to NumPy, JAX can execute the same code on different backends – CPU, GPU and TPU.
An `ArrayImpl` therefore represents an array that lives on one of these devices.

**Task**: Similar to PyTorch, we can check the device of the array `b` by accessing its attribute `b.device`.

In [None]:
# Your code here

Depending on your setup, the output will look different:
- If you have a GPU and installed JAX properly, it will show the default GPU device, e.g. `CudaDevice(id=0)` (older JAX versions print `GpuDevice(id=0)`)
- If you do this exercise on Google Colab, remember to select a GPU in your runtime environment. Then you should also see a GPU device (a restart of the kernel might be necessary)
- Otherwise, the array will be on `CpuDevice(id=0)`

The function `jax.devices()` will show you all currently available devices.

In [None]:
jax.devices()

If you have only a CPU available, then explicitly *printing* the `device` will reveal some more interesting details:

In [None]:
print(b.device)

It prints `TFRT_CPU_0`, which refers to an optimized CPU backend. According to the [website of the developers](https://github.com/ROCm/tensorflow-runtime), "TFRT is a new TensorFlow runtime. It aims to provide a unified, extensible infrastructure layer with best-in-class performance across a wide variety of domain specific hardware. It provides efficient use of multithreaded host CPUs, supports fully asynchronous programming models, and focuses on low-level efficiency."

Therefore, you can work with JAX's device model whether you have a CPU, GPU or TPU: on a machine without accelerators, the CPU simply acts as the device.

To copy an array from a JAX device back to the host, we can use `jax.device_get`:

In [None]:
b_cpu = jax.device_get(b)
print(b_cpu)
print(b_cpu.__class__)

Thus, an array fetched to the host is simply a NumPy array, which allows for easy conversion between the two frameworks! To explicitly push a NumPy array to the accelerator, you can use `jax.device_put`:

In [None]:
b_xpu = jax.device_put(b_cpu)
print(f'Device put: {b_xpu.__class__} on {b_xpu.device}')

Note that when you combine a NumPy array and a JAX array in one operation, JAX resolves the device mismatch itself and returns the result as a JAX array (`ArrayImpl`) again:

In [None]:
out = b_cpu + b_xpu
print(out)
print(out.__class__)

An important technical detail of running operations on JAX arrays is that when a JAX function is called, the corresponding operation **takes place asynchronously on the accelerator** when possible.

For instance, if we call `out = jnp.matmul(b, b)`, JAX immediately returns a placeholder array for `out`, which may not yet be filled with its values when the function call finishes.
This way, Python does not block on follow-up statements; it only waits when we strictly need the value of `out`, for instance for printing it or transferring it to the CPU.
PyTorch uses a very similar principle to allow asynchronous computation.
For more details, see [JAX - Asynchronous Dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html).
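A small sketch that makes asynchronous dispatch visible by timing the call itself and the subsequent `block_until_ready()` separately. The matrix size is an arbitrary choice, and how large the gap between the two timings is depends on your hardware; on a CPU it may be small.

```python
import time

import jax
import jax.numpy as jnp

m = jax.random.normal(jax.random.key(0), (1000, 1000))
jnp.matmul(m, m).block_until_ready()  # warm-up: exclude one-time compilation from timing

start = time.perf_counter()
out = jnp.matmul(m, m)   # returns as soon as the computation has been dispatched
dispatch_time = time.perf_counter() - start

out.block_until_ready()  # wait until the result has actually been computed
total_time = time.perf_counter() - start

print(f"dispatch: {dispatch_time:.5f}s, until ready: {total_time:.5f}s")
```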

#### Immutable tensors

When we would like to change a NumPy array in-place, like replacing the first element of `b` with `1` instead of `0`, we could simply write `b[0]=1`. However, in JAX, this is not possible. A JAX array is *immutable*, which means that no in-place operations are possible. The reason for this goes back to our discussion in the introduction: JAX requires programs to be "pure" functions, i.e. no effects on variables outside of the function are allowed. Allowing in-place operations on variables would make the program analysis for JAX's just-in-time compilation difficult. Instead, we can use the expression `b.at[0].set(1)` which, analogous to the in-place operation, returns a new array that is identical to `b`, except that its value at the first position is 1.

**Task**: Let's try that out below:

In [None]:
print('Original array:', b)
# Your code here

But wasn't JAX supposed to be very efficient?
Creating a new array for every update is indeed less efficient than changing an array in-place, but JAX's just-in-time compilation largely makes up for this:
the compiler can often recognize unnecessary array copies and replace them with in-place operations.
More on **just-in-time** compilation later!
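Besides `set`, the `.at[...]` property offers further functional update methods such as `add` and `multiply`. The sketch below shows that each of them returns a new array while the original stays untouched, and that the same syntax can be used inside a jitted function (the function `scale_then_fix` is a hypothetical example); whether XLA actually performs the update in-place is up to the compiler.

```python
import jax
import jax.numpy as jnp

b = jnp.arange(6)

# Each .at[...] method returns a NEW array; b itself is never modified.
b_set = b.at[0].set(10)        # replace one element
b_add = b.at[1:3].add(100)     # add to a slice
b_mul = b.at[-1].multiply(2)   # multiply the last element

print(b)                        # still holds 0, 1, 2, 3, 4, 5
print(b_set, b_add, b_mul)

@jax.jit
def scale_then_fix(x):
    y = 2 * x               # intermediate array created inside the jitted function
    return y.at[0].set(-1)  # XLA may perform this update in-place on y's buffer

print(scale_then_fix(b))
```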

#### Pseudo Random Numbers in JAX

In machine learning, we come across several situations where we need to generate pseudo random numbers, e.g.:
- randomly shuffling a dataset
- sampling a dropout mask for regularization
- training a variational autoencoder by sampling from the approximate posterior
- many more

In libraries like NumPy and PyTorch, the random number generator is controlled by a seed, which we set initially to obtain the same samples every time we run the code (this is why the numbers are not truly random, hence "pseudo"-random). However, if we call `np.random.normal()` 5 times consecutively, we will get 5 different numbers since every execution changes the (global) state of the pseudo-random number generator (PRNG).

In JAX, this approach is not possible: a function that generates pseudo-random numbers by updating a global state would have an effect outside of itself. To prevent this, JAX takes a different approach and explicitly passes and updates the PRNG state. First, let's create a PRNG key for the seed 42:

In [None]:
rng = jax.random.key(42)

Now, we can use this PRNG key to generate random numbers. Since the random number generation is deterministic for a given key, we sample the same number every time we reuse the key. This is not the case in NumPy, where we set the seed only once before both operations.

To compare the two, we import NumPy here as well.

In [None]:
import numpy as np 

# A non-desirable way of generating pseudo-random numbers...
jax_random_number_1 = jax.random.normal(rng)
jax_random_number_2 = jax.random.normal(rng)
print('JAX - Random number 1:', jax_random_number_1)
print('JAX - Random number 2:', jax_random_number_2)

# Typical random numbers in NumPy
np.random.seed(42)
np_random_number_1 = np.random.normal()
np_random_number_2 = np.random.normal()
print('NumPy - Random number 1:', np_random_number_1)
print('NumPy - Random number 2:', np_random_number_2)

Usually, we want to have a behavior like NumPy where we get a different random number every time we sample. To achieve this, we can *split* the PRNG state to get usable subkeys every time we need a new pseudo-random number. We can do this with `jax.random.split(...)`:

In [None]:
rng, subkey1, subkey2 = jax.random.split(rng, num=3)  # We create 3 new keys
jax_random_number_1 = jax.random.normal(subkey1)
jax_random_number_2 = jax.random.normal(subkey2)
print('JAX new - Random number 1:', jax_random_number_1)
print('JAX new - Random number 2:', jax_random_number_2)

Every time you run this cell, you will obtain different random numbers for both operations since we create new PRNG states before sampling and update `rng` itself. 

**Advice**: In general, you want to split the PRNG key every time before generating a pseudo-random number, to prevent accidentally obtaining the exact same numbers (for instance, sampling the exact same dropout mask every time you run the network makes dropout itself quite useless...). For a deeper dive into the ideas behind the random number generation in JAX, see JAX's tutorial on [Pseudo Random Numbers](https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html).
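A minimal sketch of this "split before every use" pattern in a training-style loop. The toy dataset, the number of epochs and the keep probability are hypothetical choices; `jax.random.permutation` shuffles the data and `jax.random.bernoulli` samples a dropout-style mask.

```python
import jax
import jax.numpy as jnp

# Reusing the same key always reproduces the same sample.
k = jax.random.key(1)
same_sample = jax.random.normal(k) == jax.random.normal(k)

rng = jax.random.key(0)
data = jnp.arange(10)  # hypothetical toy "dataset"

for epoch in range(3):
    # Split off fresh subkeys and keep the new rng for the next iteration.
    rng, shuffle_key, dropout_key = jax.random.split(rng, 3)
    # Shuffle the data differently in every epoch.
    shuffled = jax.random.permutation(shuffle_key, data)
    # Dropout-style mask that keeps each entry with probability 0.5.
    mask = jax.random.bernoulli(dropout_key, p=0.5, shape=data.shape)
    print(epoch, shuffled, mask)
```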

## Part C: Automatic differentiation and Just-in-Time compilation

### Function transformations with JAX expressions (abbreviated `jaxpr`)

Rosalia Schneider and Vladimir Mikulik summarize the key points of JAX in the [JAX 101 tutorial](https://jax.readthedocs.io/en/latest/jax-101/01-jax-basics.html) as follows: 

> The most important difference, and in some sense the root of all the rest, is that JAX is designed to be functional, as in functional programming. The reason behind this is that the kinds of program transformations that JAX enables are much more feasible in functional-style programs. [...] The important feature of functional programming to grok when working with JAX is very simple: don’t write code with side-effects.

Essentially, we want to write our main JAX code in functions that:
- do not affect anything besides their outputs,
- do not change input arrays in-place, and
- do not access (mutable) global variables.

While this might seem limiting at first, you get used to this quite quickly and most JAX functions that need to fulfill these constraints can be written this way without problems.

Note that not all functions involved in training a neural network need to fulfill these constraints. For instance, loading or saving models, logging, or data generation can be done in ordinary (impure) Python functions. Only the training and execution of the machine learning model, which we want to run very efficiently on our accelerator (GPU or TPU), should strictly follow these constraints.

What makes JAX functions so special, and how can we think about them? A good way of understanding how JAX handles functions is to look at its intermediate representation: jaxpr (short for JAX expression).

Conceptually, you can think of any operation that JAX performs on a function as first trace-specializing the Python function into a small and well-behaved intermediate form. This means that JAX records which operations are performed on which arrays, and which shapes these arrays have. Based on this representation, JAX then interprets the function with transformation-specific interpretation rules, for example for automatic differentiation or for compiling the function with XLA (Accelerated Linear Algebra) to use the accelerator efficiently.

To illustrate this intermediate representation, let's consider a simple function to discuss the concept of dynamic computation graphs.

Using common NumPy operations in JAX, we can write the function

$$
y = \frac{1}{n} \sum_{i=1}^n\left[\left(x_i+2\right)^2+3\right]
$$

as follows:

In [None]:
def f(x):
    x = x + 2
    x = x ** 2
    x = x + 3
    y = x.mean()
    return y

x = jnp.arange(3, dtype=jnp.float32)
print('Input', x)
print('Output', f(x))

To view the jaxpr representation of this function, we can use `jax.make_jaxpr`. Since the tracing depends on the shape of the input, we need to pass an input to the function (here of shape `(3,)`):

In [None]:
jax.make_jaxpr(f)(x)

A jaxpr representation follows the structure:

```python
jaxpr ::= { lambda Var* ; Var+.
            let Eqn*
            in  [Expr+] }
```
where `Var*` are constants and `Var+` are input arguments. In the cell above, this is `a:f32[3]`, i.e. an array of shape 3 with type `jnp.float32` (`x`). The list of equations, `Eqn*`, defines the intermediate results of the function. You can see that each operation in the function `f` is translated to a corresponding equation, e.g. `x = x + 2` is translated to `b:f32[3] = add a 2.0`. Furthermore, you can see the specialization of the operations on the input shape, e.g. `x.mean()` being replaced by `e` and `f`, which sum the elements and divide by 3. Finally, `Expr+` in the jaxpr representation are the outputs of the function. In the example, this is `f`, i.e. the final result of the function.
Based on these atomic operations, JAX offers all kinds of function transformations, the most important of which we discuss later in this section.
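The parts of the grammar (`Var+`, `Eqn*`, `Expr+`) can also be inspected programmatically. This sketch assumes the attributes `.jaxpr`, `.invars`, `.eqns` and `.outvars` that the object returned by `jax.make_jaxpr` exposes in current JAX versions; they are rather low-level and may change between releases.

```python
import jax
import jax.numpy as jnp

def f(x):
    x = x + 2
    x = x ** 2
    x = x + 3
    return x.mean()

closed_jaxpr = jax.make_jaxpr(f)(jnp.arange(3, dtype=jnp.float32))
jaxpr = closed_jaxpr.jaxpr

print("inputs (Var+):", jaxpr.invars)
print("equations (Eqn*):")
for eqn in jaxpr.eqns:
    # Each equation applies one primitive to some input variables/literals.
    print("   ", eqn.primitive.name, eqn.invars, "->", eqn.outvars)
print("outputs (Expr+):", jaxpr.outvars)
```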

**Task**: You probably know that this simple function can be written much more concisely. What does the `jaxpr` look like in this case?

In [None]:
# Your code here

Chances are high that your `jaxpr` representation looks identical to the one above. Note that the final division again depends on the input size: we divide by `3.0` since our input array has length `3`.

Hence, you can consider the `jaxpr` representation as an intermediate compilation stage of JAX. What happens if we actually try to look at the jaxpr representation of a function with **side-effect**?
Let's consider the following function, which, as an illustrative example, appends the input to a global list:

In [None]:
global_list = []

# Invalid function with side-effect
def norm(x):
    global_list.append(x)
    x = x ** 2
    y = x.sum()
    y = jnp.sqrt(y)
    return y

jax.make_jaxpr(norm)(x)

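As you can see, the `jaxpr` representation of the function does not contain any operation for `global_list.append(x)`. This is because a `jaxpr` only records operations on the traced arrays and cannot represent Python side effects: the `append` is executed once while tracing, but it is not part of the resulting representation.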

Thus, we need to stick with pure functions without any side effects, to prevent any unwanted errors in our functions. If you are interested in learning more about the `jaxpr` representation, check out the [JAX documentation](https://jax.readthedocs.io/en/latest/jaxpr.html) on it. 

**Task**: Now execute the function `norm` once on the input `x` and check the variable `global_list`.

In [None]:
# Your code here

You see that it is important to stick with pure functions without side effects. But for this tutorial, we just need the basics as discussed above.

### Automatic differentiation

The intermediate jaxpr representation defines a computation graph, on which we can perform an essential operation of deep learning frameworks: automatic differentiation.

In frameworks like PyTorch with a dynamic computation graph, we would compute the gradients based on the loss tensor itself, e.g. by calling `loss.backward()`. However, JAX directly works with functions. Instead of backpropagating gradients through tensors, JAX takes as input a function, and outputs another function which directly calculates the gradients for it. While this might seem quite different to what you are used to from other frameworks, it is quite intuitive: your gradient of parameters is really a function of parameters and data.

The transformation that allows us to do this is `jax.grad`, which takes a function as input and returns another function that computes the gradient of the output with respect to the (first) input argument.

**Task**: Try it out! Use `jax.grad` to create the gradient function of `f`, store it in a variable named `grad_function` (we will use it below), and evaluate it for the input `x`.

In [None]:
# Your code here

The gradient we get here is exactly the one we would obtain when doing the calculation by hand. Moreover, we can also print the jaxpr representation of the gradient function:

In [None]:
jax.make_jaxpr(grad_function)(x)

This shows a unique property of JAX: we can print out the exact computation graph for determining the gradients. Compared to the original function, you can see new equations like `d:f32[3] = integer_pow[y=1] b` and `e:f32[3] = mul 2.0 d`, which model the intermediate gradient of $\partial b_i^2/\partial b_i = 2b_i$. Furthermore, the return value `j` is the multiplication of `e` with $1/3$, which maps to the gradient being:

$$ \frac{\partial y}{\partial x_i} = \frac{2}{3}(x_i + 2)$$

Hence, JAX does not merely evaluate the gradients numerically at a certain input value: it builds an explicit program for the gradient, which we can inspect and compare to the analytical derivative. This is quite a nice feature of JAX!
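As a quick sanity check, we can compare the automatically computed gradient with the analytical formula $\frac{\partial y}{\partial x_i} = \frac{2}{3}(x_i + 2)$ derived above. The sketch redefines `f` in a compact form so that it is self-contained.

```python
import jax
import jax.numpy as jnp

def f(x):
    return ((x + 2) ** 2 + 3).mean()

x = jnp.arange(3, dtype=jnp.float32)

autodiff_grad = jax.grad(f)(x)
analytic_grad = 2.0 / 3.0 * (x + 2)  # dy/dx_i = 2/3 (x_i + 2) for n = 3

print(autodiff_grad, analytic_grad)
```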

Often, we do not only want the gradients, but also the actual output of the function, for instance for logging the loss. This can be efficiently done using `jax.value_and_grad`:

In [None]:
val_grad_function = jax.value_and_grad(f)
val_grad_function(x)

Of course, you can also look at the `jaxpr` representation of the `val_grad_function`.

In [None]:
jax.make_jaxpr(val_grad_function)(x)

Note that the only difference comes after the line

    g:f32[] = reduce_sum[axes=(0,)] f

where `grad_function` discards the variable that represents the value,

    _:f32[] = div g 3.0

while `val_grad_function` assigns it to

    h:f32[] = div g 3.0

and finally returns it as its first output.

Further, we can specialize the gradient function to consider multiple input arguments, and add extra outputs that we may not want to differentiate (for instance the accuracy in classification). We will visit the most important ones in the network training later, and refer to other great resources for more details ([JAX Quickstart](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html#taking-derivatives-with-grad), [Autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html), [Advanced autodiff](https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html)).
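A minimal sketch of both options: `argnums` selects which arguments are differentiated, and `has_aux=True` tells JAX that the function returns a pair `(value, aux)` of which only `value` is differentiated. The toy regression loss and the mean absolute error as auxiliary output are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp

def loss_fn(w, b, x, y):
    pred = x @ w + b
    loss = jnp.mean((pred - y) ** 2)
    # Auxiliary output: returned alongside the loss, but not differentiated.
    mean_abs_error = jnp.mean(jnp.abs(pred - y))
    return loss, mean_abs_error

# Differentiate w.r.t. the first two arguments (w and b).
grad_fn = jax.value_and_grad(loss_fn, argnums=(0, 1), has_aux=True)

x = jnp.ones((4, 3))
y = jnp.zeros(4)
w = jnp.zeros(3)
b = 1.0

(loss, mae), (grad_w, grad_b) = grad_fn(w, b, x, y)
print(loss, mae, grad_w, grad_b)
```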

To train neural networks, we need to determine the gradient for every parameter in the network with respect to the loss. Listing all parameters as input arguments quickly gets annoying and infeasible.
JAX offers an elegant data structure to summarize all parameters: a pytree ([documentation](https://jax.readthedocs.io/en/latest/pytrees.html)). A pytree is a container-like object which structures its elements as a tree. For instance, a linear neural network might have its parameters organized similar to:

```python
params = {
    'linear1': {
        'weights': ...,
        'bias': ...
    },
    ...
}
```
JAX offers functions to process pytrees efficiently, such as obtaining all leaves (i.e. all parameters in a network) or applying a function to each element. We will come back to these structures when training a full network.
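A small sketch of these pytree utilities on a hypothetical two-layer parameter dictionary, using `jax.tree_util.tree_leaves` and `jax.tree_util.tree_map`. It also shows that `jax.grad` of a loss taking such a dictionary returns gradients with exactly the same tree structure; the tiny network and its loss are illustrative only.

```python
import jax
import jax.numpy as jnp

k1, k2 = jax.random.split(jax.random.key(0))

# Hypothetical parameters of a tiny two-layer network, stored as a nested dict.
params = {
    "linear1": {"weights": jax.random.normal(k1, (3, 4)), "bias": jnp.zeros(4)},
    "linear2": {"weights": jax.random.normal(k2, (4, 1)), "bias": jnp.zeros(1)},
}

# All leaves (= all parameter arrays) of the tree.
leaves = jax.tree_util.tree_leaves(params)
num_params = sum(leaf.size for leaf in leaves)

# Apply a function to every leaf while keeping the tree structure.
shapes = jax.tree_util.tree_map(lambda p: p.shape, params)

def loss(params, x):
    h = jnp.tanh(x @ params["linear1"]["weights"] + params["linear1"]["bias"])
    out = h @ params["linear2"]["weights"] + params["linear2"]["bias"]
    return jnp.mean(out ** 2)

# The gradients form a pytree with the same structure as params.
grads = jax.grad(loss)(params, jnp.ones((5, 3)))
print(num_params, shapes)
```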

### Speeding up computation with Just-In-Time compilation

Interestingly, from the previous code cells, you can see in the jaxpr representation of the gradient function that calculating the array `f` and the scalar `g` is unnecessary. Intuitively, the gradient of taking the mean is independent of the actual output of the mean, hence we could drop `f` and `g` without any drawback. Finding such cases to improve efficiency and optimizing the code to take full advantage of the available accelerator hardware is one of the big selling points of JAX. It achieves that by *compiling functions just-in-time* with [XLA](https://www.tensorflow.org/xla) (Accelerated Linear Algebra), using their jaxpr representation. Thereby, XLA fuses operations to reduce the execution time of short-lived operations and eliminates intermediate storage buffers where they are not needed. For more details, see the [XLA documentation](https://docs.w3cub.com/tensorflow~guide/performance/xla/index).

To compile a function, JAX provides the `jax.jit` transformation. We can either apply the transformation directly on a function (as we will do in the next cell), or use the decorator `@jax.jit` before a function.
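For reference, the decorator form looks as follows; it is equivalent to calling `jax.jit` on the function after its definition. The name `f_jit` is chosen here only to avoid clashing with `f` from above.

```python
import jax
import jax.numpy as jnp

@jax.jit  # same as writing f_jit = jax.jit(f_jit) after the definition
def f_jit(x):
    x = x + 2
    x = x ** 2
    x = x + 3
    return x.mean()

x = jnp.arange(3, dtype=jnp.float32)
print(f_jit(x))  # the first call triggers compilation, later calls reuse it
```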

In [None]:
jitted_function = jax.jit(f)

The `jitted_function` takes the same input arguments as the original function `f`. Since the jaxpr representation of a function **depends on the input shape**, the *compilation is started once we put the first input in*.

**Important**: Note that this also means that for every different shape we want to run the function, a new XLA compilation is needed. This is why it is recommended to use padding in cases where your input shape strongly varies, e.g. in the case of transformer architectures.
For now, let's create an array with 1000 random values, on which we apply the jitted function:

In [None]:
# Create a new random subkey for generating new random values
rng, normal_rng = jax.random.split(rng)
large_input = jax.random.normal(normal_rng, (1000,))
# Run the jitted function once to start compilation
_ = jitted_function(large_input)

The output is not any different from what you would get from the non-jitted function. But what about the runtime? Let's time both the original and the jitted function. Due to the asynchronous execution on the accelerator, we add `.block_until_ready()` on the output, which blocks the Python execution until the accelerator has finished computing the result, and hence gives an accurate time estimate.

In [None]:
%%timeit
f(large_input).block_until_ready()

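Without `.block_until_ready()`, we get a misleading estimate, since we mostly measure the time needed to dispatch the operations rather than to compute them: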

In [None]:
%%timeit
f(large_input)

If you are on a CPU, the difference might not be that much, but on a GPU the values typically differ a lot.

**Task**: Now test the jitted function!

In [None]:
%%timeit
# Your code here

We see that the compiled function is considerably faster, often by a factor of 10-15 (the exact speedup depends on your system)! This is quite an improvement in performance, and shows the potential of compiling functions with XLA.
Furthermore, we can also apply multiple transformations on the same function in JAX, such as applying `jax.jit` on a gradient function:

In [None]:
jitted_grad_function = jax.jit(grad_function)
_ = jitted_grad_function(large_input)  # Apply once to compile

Let's time the functions once more:

In [None]:
%%timeit
grad_function(large_input).block_until_ready()

In [None]:
%%timeit
jitted_grad_function(large_input).block_until_ready()

Once more, the jitted function is much faster than the original one. Intuitively, this shows the potential speed up we can gain by using `jax.jit` to compile the whole training step of a network.
Generally, we want to jit the largest possible chunk of computation to give the compiler maximum freedom.

There are situations in which applying jit to a function is not straightforward, for instance, if an input argument cannot be traced, or if you need loops that depend on input arguments.
To keep the tutorial simple, and since most neural network training functions do not run into these issues, we do not discuss such special cases here.

**Some further references on jit-compilation**:
- JAX documentation in [JAX 101 Tutorial](https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html)
- "To JIT or not to JIT" part in [Thinking in JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit)

## Part D: Advanced topics - Vectorization and parallelization

After reading this tutorial, you might wonder why we left out some key advertisement points of JAX: automatic vectorization, easy parallelization on multiple accelerators, etc.

The reason why we did not include them in our previous discussion is that for building simple networks, and actually for most models in our tutorials, you do not really need these methods.

However, since they can be handy at times, for instance if you have access to a large cluster or are faced with functions that are annoying to vectorize, we review them here in a separate section.

This part is inspired by this [JAX tutorial](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).

### Automatic Vectorization with vmap

In machine learning, we often vectorize methods to efficiently process multiple inputs or batch elements at the same time. Usually, we have to write the code ourselves to support additional dimensions to vectorize over. However, since JAX can already transform functions to run efficiently on accelerators or calculate gradients, it can also automatically vectorize a function.
For instance, let's consider a simple linear layer where we write a function for a single input `x` $\in \mathbb{R}^n$, a weight matrix `W` $\in \mathbb{R}^{n \times m}$ and a bias vector `b` $\in \mathbb{R}^m$.

The function should perform the operation:

$$
x^T \cdot W + b
$$

In [None]:
def simple_linear(x, W, b):
    # We could already vectorize this function with matmul, but as an example,
    # let us use a non-vectorized function with same output
    return (x[:, None] * W).sum(axis=0) + b

In [None]:
# Example inputs
rng, x_rng, w_rng, b_rng = jax.random.split(rng, 4)
x = jax.random.normal(x_rng, (4,))
W = jax.random.normal(w_rng, (4, 3))
b = jax.random.normal(b_rng, (3,))

simple_linear(x, W, b)

Now, we would like the function to support a batch dimension on `x`, which means that we would like to evaluate the function `simple_linear` for multiple inputs `x` at once. Sure, you could do this explicitly in this case, but if your input is an image, a sequence of words/tokens or a higher-dimensional tensor, automatic vectorization becomes quite handy.
In general, this additional batch dimension is used as the first dimension, i.e. `x` then has shape `(batch_size, n)`.

Our naive implementation above does not support this, since we specialized the axis we sum over. So, let's make JAX do the work for us and vectorize the function by using `jax.vmap`:

In [None]:
vectorized_linear = jax.vmap(simple_linear,
                             in_axes=(0, None, None),  # Which axes to vectorize for each input
                             out_axes=0  # Which axes to map to in the output
                            )

Specifying `None` for the in-axes of the input arguments `W` and `b` means that we do not want to vectorize over any of their dimensions. With this vmap specification, the function `vectorized_linear` now supports an extra batch dimension in `x`! Let's try it out:

In [None]:
rng, x_vec_rng = jax.random.split(rng, 2)
x_vec = jax.random.normal(x_vec_rng, (10, 4))

vectorized_linear(x_vec, W, b)

**Task**: Of course, you know how to do it explicitly without using `vmap`, right?

In [None]:
# Your code here

The new function indeed vectorized our code, calculating $N = 10$ applications of the weights and bias to the inputs. We can also vectorize the code to run multiple inputs `x` on multiple weights `W` and biases `b` by changing the input argument `in_axes` to `(0, 0, 0)`, or simply `0`. Moreover, we can again stack multiple function transformations, such as jitting a vectorized function. Further details on `jax.vmap` can be found in this [tutorial](https://jax.readthedocs.io/en/latest/jax-101/03-vectorization.html) and its [documentation](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html?highlight=vmap).
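A sketch of both points made above: `in_axes=0` maps over the leading axis of every argument (one weight matrix and bias per input), and the vectorized function is additionally wrapped in `jax.jit`. The batch size `N = 5` is arbitrary, and `jnp.einsum` serves only as an independent reference computation.

```python
import jax
import jax.numpy as jnp

def simple_linear(x, W, b):
    return (x[:, None] * W).sum(axis=0) + b

x_rng, w_rng, b_rng = jax.random.split(jax.random.key(0), 3)
N = 5
xs = jax.random.normal(x_rng, (N, 4))     # N inputs
Ws = jax.random.normal(w_rng, (N, 4, 3))  # N weight matrices
bs = jax.random.normal(b_rng, (N, 3))     # N bias vectors

# The i-th input is combined with the i-th weight matrix and the i-th bias.
per_example_linear = jax.jit(jax.vmap(simple_linear, in_axes=0))
out = per_example_linear(xs, Ws, bs)

# Reference computation for comparison.
expected = jnp.einsum("ni,nij->nj", xs, Ws) + bs
```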

### Parallel evaluation

`jax.vmap` vectorizes a function on a single accelerator. What we can do when multiple GPUs or TPUs are available is the topic of the next lectures and labs.

### Dynamic shapes

JAX has the great advantage of providing just-in-time compilation of functions to speed up the computation. For this, it uses its intermediate representation `jaxpr`, which is specialized to the shapes of the input arguments. However, this also means that a jitted function is specialized to a certain shape, and running the jitted function with a different input shape requires recompiling the function. For instance, consider the following simple function:

In [None]:
def my_function(x):
    print('Running the function with shape', x.shape)
    return x.mean()

jitted_function = jax.jit(my_function)

The print statement is only executed while the function is traced for compilation. For subsequent calls with an already compiled input shape, it is skipped since it is not part of the jaxpr representation. Let's run the function now with multiple different input shapes:

In [None]:
for i in range(10):
    jitted_function(jnp.zeros((i+1,)))

As we can see, the function is compiled once for every new input shape. This becomes inefficient if we work with many different shapes. However, calling the function again with a previously seen input shape does not trigger another compilation:

In [None]:
# Running the functions a second time will not print out anything since
# the functions are already jitted for the respective input shapes.
for i in range(10):
    jitted_function(jnp.zeros((i+1,)))
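To see why a new input shape triggers recompilation, one can inspect the jaxpr that JAX traces for a function with `jax.make_jaxpr`. In this small sketch, `mean_fn` is a hypothetical stand-in for `my_function` without the print statement; the input shape appears explicitly in the jaxpr's type annotations (e.g. `f32[3]`), so each new shape yields a different program.

```python
import jax
import jax.numpy as jnp

def mean_fn(x):
    return x.mean()

# The traced program is annotated with the concrete input shape and dtype.
jaxpr_3 = jax.make_jaxpr(mean_fn)(jnp.zeros((3,)))
jaxpr_5 = jax.make_jaxpr(mean_fn)(jnp.zeros((5,)))
print(jaxpr_3)  # the input is typed as f32[3]
print(jaxpr_5)  # the input is typed as f32[5]
```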

If there is only a small set of different shapes, the compilation overhead is negligible. For instance, when evaluating a model, the last batch is often smaller than the others because the dataset size is not a multiple of the batch size, which causes just one extra compilation. In other applications, however, shapes vary much more often, e.g. in **natural language processing (NLP), time series, and graphs**, where sequence lengths or graph sizes differ between samples.
In these cases, it is recommended to pad the batches to a small number of fixed shapes to prevent many recompilations.
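As a minimal sketch of padding, assuming all inputs are 1D and no longer than a chosen maximum length (`MAX_LEN` and the helper `pad_to` are hypothetical), each input is zero-padded to the same shape and accompanied by a mask so that the padded entries do not affect the mean. The counter `trace_count` is a Python side effect that only runs while JAX traces the function, so it counts compilations.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0  # incremented only while JAX traces (i.e. compiles) the function

@jax.jit
def masked_mean(x, mask):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on every call
    return jnp.sum(x * mask) / jnp.sum(mask)

def pad_to(x, size):
    """Zero-pad a 1D array up to `size` and return (padded, mask)."""
    n = x.shape[0]
    padded = np.zeros(size, dtype=np.float32)
    padded[:n] = x
    mask = np.zeros(size, dtype=np.float32)
    mask[:n] = 1.0  # 1 marks real entries, 0 marks padding
    return padded, mask

MAX_LEN = 16  # hypothetical fixed length; must be >= the longest input
results = []
for i in range(10):
    x = np.arange(i + 1, dtype=np.float32)  # inputs of lengths 1, 2, ..., 10
    padded, mask = pad_to(x, MAX_LEN)
    results.append(masked_mean(padded, mask))

print(trace_count)  # 1: all calls share the input shape (16,)
```

Compared to the loop over `jitted_function` above, the ten calls here trigger a single compilation, at the cost of some wasted computation on the padded entries.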

### Debugging in jitted functions

During coding, we likely want to debug our model and sometimes print out intermediate values. In JAX, this is not as straightforward when jitting functions. As the previous cells showed, a print statement is only executed during tracing and afterwards dropped since it is not part of the jaxpr representation. Furthermore, tracking down NaNs in your code can be tricky (see the [sharp bits tutorial](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#debugging-nans)), and out-of-bounds indexing does not raise an error, since errors cannot be raised from code running on an accelerator: instead, retrievals clamp the index to the bounds of the array, and out-of-bounds updates are silently dropped (see the corresponding section in the [sharp bits tutorial](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing)). If necessary, one can run the unjitted version of the forward pass first, or introduce print statements into the jitted version where needed (see [here](https://github.com/google/jax/issues/196) for a great explanation). Still, it is not as straightforward as in PyTorch, for example.
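In recent JAX versions, `jax.debug.print` prints runtime values from inside jitted functions, and the `jax.disable_jit()` context manager runs jitted functions eagerly. The sketch below contrasts both with a plain `print`; the function `normalize` is just an illustrative example, not part of the notebook.

```python
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    # Plain print: runs only while tracing and shows an abstract tracer, not values
    print("traced with:", x)
    # jax.debug.print: part of the compiled program, prints concrete values on every call
    jax.debug.print("runtime values: {x}", x=x)
    return x / jnp.linalg.norm(x)

normalize(jnp.array([3.0, 4.0]))  # triggers tracing: both messages are printed
normalize(jnp.array([6.0, 8.0]))  # same shape, no retracing: only jax.debug.print fires

# Inside disable_jit the function runs eagerly, so even plain print shows concrete values
with jax.disable_jit():
    y = normalize(jnp.array([3.0, 4.0]))
print(y)
```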

### Miscellaneous divergences from NumPy

While `jax.numpy` closely replicates the behavior of NumPy's API, there are some differences, such as the following.
You can find out more [in this tutorial](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html#miscellaneous-divergences-from-numpy).

**Task**: Compare the behavior of `np.arange(10)[11]` and `jnp.arange(10)[11]`.

In [None]:
# Your code here

Unsafe casts can also give different results in NumPy and JAX, for instance when casting out-of-range floating-point values to `uint8`:

In [None]:
print(np.arange(254.0, 258.0).astype('uint8'))
print(jnp.arange(254.0, 258.0).astype('uint8'))

### Double (64bit) precision

By default, JAX enforces *single-precision* numbers to mitigate the NumPy API's tendency to aggressively promote operands to `double`. This is the desired behavior for many machine learning applications, but it may catch you by surprise!

In [None]:
# Without x64 enabled, the requested float64 is truncated to float32 (JAX emits a UserWarning)
x = jax.random.uniform(jax.random.key(0), (1000,), dtype=jnp.float64)
x.dtype

To use double-precision numbers, you need to set the `jax_enable_x64` configuration variable, ideally at startup.
You can do this either by setting the environment variable `JAX_ENABLE_X64=True` or by updating the JAX configuration via

In [None]:
jax.config.update("jax_enable_x64", True)

In [None]:
x = jax.random.uniform(jax.random.key(0), (1000,), dtype=jnp.float64)
x.dtype

**Important**: XLA doesn’t support 64-bit **convolutions** on all backends! Again, most machine learning models are trained and stored in single precision or other types like `float16` or `bfloat16`.
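Once `jax_enable_x64` is enabled, the default floating-point type of new arrays becomes `float64`. A minimal sketch for keeping model arrays in single precision anyway is to request `float32` explicitly; Python scalars are weakly typed in JAX, so they do not promote such arrays back to 64 bit.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # ideally done at startup, before creating arrays

print(jnp.arange(3.0).dtype)  # float64: the default float type is now 64-bit
x32 = jnp.ones(3, dtype=jnp.float32)  # request single precision explicitly
print(x32.dtype)  # float32
print((x32 * 2.0).dtype)  # float32: the weakly typed Python scalar does not promote
```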