<a href="https://colab.research.google.com/github/pchanial/python-for-data-scientists/blob/master/Course_Numpy_JAX_APC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Numpy & JAX



## Objectives of the course:
- Understand why Numpy is faster than vanilla Python
- Understand why JAX is faster than Numpy
- Numpy concepts no longer applicable to JAX
- Numpy concepts still applicable to JAX


In [110]:
# imports necessary for the course
import timeit
from typing import Any, Callable, Sequence
import jax
import jax.numpy as jnp
import jaxlib
import numpy as np
import matplotlib.pyplot as mp
#import polars as pl

print('Devices:', jax.devices())


# helpers
def pointer(x: np.ndarray) -> int:
    """Returns the memory address of the first array element."""
    return x.__array_interface__['data'][0]


def bench_one(func: Callable[[], Any]) -> float:
    """Returns execution time in s."""
    repeat = 7
    timer = timeit.Timer(func)
    number, _ = timer.autorange()
    runs = np.array([t / number for t in timer.repeat(repeat=repeat, number=number)])
    runs_ms = runs * 1000
    print(f'{np.min(runs_ms):.3f} ms ± {np.std(runs_ms) * 1000:.2f} µs (min ± std. dev. of {repeat} runs, {number} loops each)')
    return np.min(runs)


def bench(func: Callable, values: Sequence[Any], *, setup: Callable | None = None) -> list[float]:
    elapsed_times = []
    for value in values:
        if setup is not None:
            args = setup(value)
        else:
            args = (value,)
        if isinstance(func, jaxlib.xla_extension.PjitFunction):
            func(*args)
            benchmarked_func = lambda: func(*args).block_until_ready()
        else:
            benchmarked_func = lambda: func(*args)
        elapsed_times.append(bench_one(benchmarked_func))
    return elapsed_times


def bench_many(funcs: Sequence[Callable], values: Sequence[Any], setups: Sequence[Callable | None], labels: Sequence[str]) -> None:
    for func, setup, label in zip(funcs, setups, labels):
        run_times = bench(func, values, setup=setup)
        mp.loglog(values, run_times, marker='.', label=label)
    mp.ylabel('Elapsed time [s]')
    mp.legend()


Devices: [CpuDevice(id=0)]


## Why is Numpy faster than Vanilla Python

The Numpy array model is quite powerful, but before delving into how Numpy arrays can be manipulated, it is worth understanding why they are much more efficient than Python lists.

First, when a Numpy array is created, its elements are stored next to one another: the memory storage is contiguous (see the Numpy memory layout figure below for a 2-dimensional array). In a Python list, by contrast, the elements are created before the list and can be stored anywhere in memory: the storage is scattered (see the vanilla Python memory layout figure below). In most systems, data from the main memory reaches the CPU through layers of caches, and transfers into the cache involve whole chunks of contiguous memory (cache lines), even if only a few bytes of a cache line are actually requested by the CPU. As a consequence, non-contiguous storage forces the transfer of unneeded data and incurs a bandwidth penalty. In addition, modern architectures can anticipate memory transfers by prefetching the next cache lines, a mechanism that works best when the data storage is contiguous.

A second advantage of Numpy arrays over Python lists is that all elements occupy the same number of bytes, so the location of an element in memory (its address) can be cheaply computed from its index and the location of the first element. There is no such relationship in Python lists: the location of each element has to be stored in memory, so every read or write access pays the indirection overhead of first transferring this element location to the CPU.

- Vanilla Python memory layout

![list of list](https://raw.githubusercontent.com/pchanial/python-for-data-scientists/gh-pages/source/layout_listoflist.png)

- Numpy memory layout

![ndarray](https://raw.githubusercontent.com/pchanial/python-for-data-scientists/gh-pages/source/layout_2darray.png)
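
The address computation described above can be checked directly. The following is a minimal sketch, assuming a C-contiguous `float64` array; the names `base_address`, `computed_address` and `actual_address` are illustrative only.

```python
import numpy as np

a = np.arange(12, dtype=np.float64).reshape(3, 4)
base_address = a.__array_interface__['data'][0]

# Every element has the same size, so the address of a[i, j] follows from
# the strides (the number of bytes to skip along each dimension).
i, j = 2, 1
computed_address = base_address + i * a.strides[0] + j * a.strides[1]

# The view a[i:, j:] starts at element a[i, j], so its data pointer is
# the actual address of that element.
actual_address = a[i:, j:].__array_interface__['data'][0]

print('contiguous:', a.flags['C_CONTIGUOUS'])
print('itemsize and strides in bytes:', a.itemsize, a.strides)
print('addresses match:', computed_address == actual_address)
```

No such arithmetic is possible for a Python list, whose storage only holds references to objects located elsewhere in memory.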


To write efficient Numpy and JAX code, one should think in terms of vectors, matrices or tensors, which reduces the back and forth between the Python interpreter and the low-level implementation of mathematical functions.
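
As an illustration of this way of thinking, here is a small sketch that computes a sum of squares twice: once with an explicit Python loop and once with whole-array operations. The function names are hypothetical and the computation is chosen only for its simplicity.

```python
import numpy as np


def sum_of_squares_loop(n: int) -> int:
    """One interpreter round trip per element."""
    total = 0
    for k in range(n):
        total += k * k
    return total


def sum_of_squares_vectorized(n: int) -> int:
    """Two calls into compiled code, whatever the value of n."""
    k = np.arange(n, dtype=np.int64)
    return int(np.sum(k * k))


print(sum_of_squares_loop(1000), sum_of_squares_vectorized(1000))
```

Both functions return the same value, but the vectorized version leaves the iteration over the elements to Numpy's compiled routines.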



**Exercise**: Wave the for loops goodbye!

Compute $\pi$, as given by the Madhava formula
$\pi = \sqrt{12}\sum^\infty_{k=0} \frac{(-\frac{1}{3})^{k}}{2k+1}$.
Generate the indices $k$, ranging from 0 to (let’s say) 29, with the NumPy function `arange`, and compute $\pi$ by calling another NumPy function (`sum`) instead of using a for loop.

In [None]:
N = 30
pi = ...
assert abs(pi - np.pi) < 1e-15

## Why is JAX faster than Numpy

Functions wrapped in `jax.jit` are compiled by XLA. While operation fusion (or kernel fusion) is the flagship feature of XLA, it also performs many other whole-program optimizations, such as specializing to known tensor shapes (allowing for more aggressive constant propagation), analyzing and scheduling memory usage to eliminate intermediate storage buffers[4], performing memory layout operations, and computing only a subset of the requested values if not all of them are returned[5].
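
To make this concrete, the sketch below evaluates the same expression eagerly with Numpy and through `jax.jit`. The names `poly` and `poly_jit` are illustrative; the point is only that Numpy evaluates one operation at a time, each producing a temporary array, whereas the jitted version hands the whole expression to XLA, which is then free to fuse it.

```python
import jax
import numpy as np


def poly(x, y):
    # Works on both Numpy and JAX arrays.
    return 2 * x * y + 3 * y + 1


# The function is traced and compiled on its first call.
poly_jit = jax.jit(poly)

x = np.linspace(0.0, 1.0, 1000, dtype=np.float32)
y = np.linspace(1.0, 2.0, 1000, dtype=np.float32)

# Numpy: each arithmetic operation allocates an intermediate array.
expected = poly(x, y)

# JAX: the whole expression is compiled as one program. The call is
# asynchronous, hence block_until_ready() before using or timing the result.
actual = poly_jit(x, y).block_until_ready()

print(np.allclose(actual, expected, rtol=1e-5))
```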

In [None]:
values = [2**n for n in range(10, 24)]

setup_numpy = lambda n: (np.arange(n, dtype=np.float32),)
setup_jax = lambda n: (jnp.arange(n, dtype=np.float32),)
bench_many(
    [np.sum, jax.jit(jnp.sum)],
    values,
    setups=[setup_numpy, setup_jax],
    labels=['numpy', 'jax'],
)
mp.title('sum(x)')

In [None]:
values = [2**n for n in range(10, 24)]

setup_numpy = lambda n: (np.random.normal(size=n).astype(np.float32), np.random.normal(size=n).astype(np.float32))
key = jax.random.key(0)
key1, key2 = jax.random.split(key)
setup_jax = lambda n: (jax.random.normal(key1, (n,), np.float32), jax.random.normal(key2, (n,), np.float32))

bench_many(
    [lambda x, y: 2 * x * y + 3 * y + 1, jax.jit(lambda x, y: 2 * x * y + 3 * y + 1)],
    values,
    setups=[setup_numpy, setup_jax],
    labels=['numpy', 'jax'],
)
mp.title('2xy + 3y + 1')



## Numpy concepts no longer applicable to JAX
- in-place operations
- promotion rules
- np.random (use `jax.random`)
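
A short sketch of the three differences listed above, assuming JAX's default configuration (64-bit types disabled); the variable names are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

# In-place operations: JAX arrays are immutable.
x = jnp.zeros(3)
try:
    x[0] = 1.0
    assignment_failed = False
except TypeError:
    assignment_failed = True
# The functional alternative returns a new array and leaves x untouched.
y = x.at[0].set(1.0)

# Promotion rules: a Python float does not widen a JAX int32 array to float64.
np_result = np.arange(3, dtype=np.int32) + 1.0
jax_result = jnp.arange(3, dtype=jnp.int32) + 1.0
print(np_result.dtype, jax_result.dtype)

# Random numbers: there is no global state, keys are passed explicitly.
key = jax.random.key(0)
key_a, key_b = jax.random.split(key)
sample_a = jax.random.normal(key_a, (2,))
sample_a_again = jax.random.normal(key_a, (2,))  # same key, same values
sample_b = jax.random.normal(key_b, (2,))
```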


## Numpy concepts still applicable to JAX

- indexing
- broadcasting
- ufunc methods



### Indexing

### Broadcasting
Broadcasting allows operations that are normally element-wise (such as addition, multiplication, etc.) to be carried out on arrays of different shapes. It is a virtual replication of the arrays along the missing dimensions. It can be seen as a generalization of operations involving an array and a scalar.

- the addition of a scalar to a matrix can be seen as the addition of a matrix with identical elements (and same dimensions).

![broadcast scalar](https://raw.githubusercontent.com/pchanial/python-for-data-scientists/gh-pages/source/broadcast_scalar.png)

- the addition of a row to a matrix can be seen as the addition of a matrix with replicated rows (the number of columns must match).

![broadcast column](https://raw.githubusercontent.com/pchanial/python-for-data-scientists/gh-pages/source/broadcast_column.png)

- conversely, the addition of a column to a matrix can be seen as the addition of a matrix with replicated columns (the number of rows must match).

![broadcast row](https://raw.githubusercontent.com/pchanial/python-for-data-scientists/gh-pages/source/broadcast_row.png)

- What if the rank of the arrays is greater than 2? There is no restriction on the rank: any dimension of length 1 is broadcastable and is virtually replicated to match the other array’s dimension length. The two arrays may have different broadcastable dimensions. If this happens, the result of the operation will have more elements than any of the operands.

- Can it work on arrays of different ranks? Sure! Dimensions of length 1 are prepended (added on the left of the array shape) until the two arrays have the same rank. As a consequence, the following operation is possible:

```python
np.zeros((5, 9)) + np.ones(9)
```

but not this one, since the rightmost dimensions are different:

```python
np.zeros((5, 9)) + np.ones(5)
# ValueError: operands could not be broadcast together with shapes (5,9) (5,)
```

So for columns, an additional dimension must be specified and added on the right:

```python
np.zeros((5, 9)) + np.ones(5)[:, None]
```
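
The case mentioned earlier, in which both arrays have a broadcastable dimension and the result is larger than either operand, can be sketched as follows (the array names are illustrative):

```python
import numpy as np

column = np.arange(4).reshape(4, 1)  # shape (4, 1)
row = np.arange(3).reshape(1, 3)     # shape (1, 3)

# Both operands are virtually replicated: table[i, j] = column[i, 0] + row[0, j]
table = column + row

print(table.shape, table.size, column.size, row.size)
```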

Can it work on more than two arrays? Yes again! But you have to find an element-wise operation with more than two operands...
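
One such operation is `np.where`, which takes three operands. In this minimal sketch (arbitrary values), a condition of shape (3, 1), an array of shape (4,) and a scalar are broadcast together:

```python
import numpy as np

condition = np.array([[True], [False], [True]])  # shape (3, 1)
x = np.arange(4)                                 # shape (4,)
fill_value = -1                                  # scalar

# The three operands are broadcast to the common shape (3, 4).
out = np.where(condition, x, fill_value)

print(out.shape)
```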

Since the replication is virtual, no memory is wasted. Broadcasting is fast. Use it wherever possible, just keep an eye on the size of the broadcast result to make sure that it does not become too large.



In [None]:
np.broadcast_shapes((3, 1, 4), (3, 4))

In [None]:
a1 = np.ones((4, 1))
a2 = np.ones((2, 1, 3))
a1_broadcast, a2_broadcast = np.broadcast_arrays(a1, a2)
print('broadcast shapes:', a1_broadcast.shape, a2_broadcast.shape)
print('same memory address:', pointer(a1) == pointer(a1_broadcast))
print(f'{a1.strides=}', f'{a1_broadcast.strides=}')

**Exercise 1**: Can the arrays of the following shapes be broadcast together? If yes, what would be the shape of the result?

    (7, 1) and (7, 4)
    (7,) and (4, 7)
    (3, 3) and (2, 3)
    (1, 1, 1, 8) and (1, 9, 1)
    (4, 1, 9) and (3, 1)

**Exercise 2**: Remove the for loops in this code by using broadcasting and measure the improvement in execution time.


In [None]:
import matplotlib.pyplot as mp
import numpy as np

NDETECTOR = 8
NSAMPLE = 1000
SAMPLING_PERIOD = 0.1
GLITCH_TAU = 0.3
GLITCH_AMPL = 20
GAIN_SIGMA = 0.03
SOURCE_AMPL = 7
SOURCE_PERIOD = 5
NOISE_SIGMA = 0.7

time = np.arange(NSAMPLE) * SAMPLING_PERIOD
glitch = np.zeros(NSAMPLE)
glitch[100:] = GLITCH_AMPL * np.exp(-time[:-100] / GLITCH_TAU)
gain = 1 + GAIN_SIGMA * np.random.standard_normal(NDETECTOR)
offset = np.arange(NDETECTOR)
source = SOURCE_AMPL * np.sin(2 * np.pi * time / SOURCE_PERIOD)
noise = NOISE_SIGMA * np.random.standard_normal((NDETECTOR, NSAMPLE))

signal = np.empty((NDETECTOR, NSAMPLE))
for idet in range(NDETECTOR):
    for isample in range(NSAMPLE):
        signal[idet, isample] = (
            gain[idet] * source[isample]
            + glitch[isample]
            + offset[idet]
            + noise[idet, isample]
        )

mp.figure()
mp.subplot(211)
mp.imshow(signal, aspect='auto', interpolation='none')
mp.xlabel('sample')
mp.ylabel('detector')
mp.subplot(212)
for s in signal:
    mp.plot(time, s)
mp.xlabel('time [s]')
mp.ylabel('signal')
mp.show()


**Exercise 3**: Write a one-liner function that normalizes M real N-dimensional vectors, packed in an array of shape (M, N), by their Euclidean norm. Bonus if the function works with a tensor of any rank, such as (P, Q, M, N).

In [None]:
import numpy as np

def fast_normalize(v):
    return v / ???

vectors = np.random.normal(size=(10, 3))
expected_normalized_vectors = vectors.copy()
for vector in expected_normalized_vectors:
    vector /= np.sqrt(vector[0]**2 + vector[1]**2 + vector[2]**2)
actual_normalized_vectors = fast_normalize(vectors)

assert np.allclose(actual_normalized_vectors, expected_normalized_vectors)

### Universal function methods
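
Universal functions (ufuncs) such as `np.add` or `np.multiply` are objects that expose methods like `reduce`, `accumulate` and `outer`. The following is a minimal Numpy-only sketch with arbitrary values; it does not cover the JAX side.

```python
import numpy as np

x = np.array([1, 2, 3, 4])

# reduce: apply the ufunc along an axis (for np.add, same result as np.sum)
total = np.add.reduce(x)

# accumulate: keep the intermediate results (for np.add, same as np.cumsum)
running = np.add.accumulate(x)

# outer: apply the ufunc to all pairs, table[i, j] = x[i] * x[j]
table = np.multiply.outer(x, x)

print(total, running, table.shape)
```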

