Source: https://www.youtube.com/watch?v=SstuvS-tVc0

In [None]:
import jax.numpy as jnp

##### Example 1

In [None]:
xs = jnp.array([0, 1, 2, 3, 4])

In [None]:
type(xs)

jaxlib.xla_extension.Array

In [None]:
xs

Array([0, 1, 2, 3, 4], dtype=int32)

Create a new array that has the same values as `xs`, but with the value at index `2` set to `100`

In [None]:
new_xs = xs.at[2].set(100)

In [None]:
new_xs

Array([  0,   1, 100,   3,   4], dtype=int32)

In [None]:
xs

Array([0, 1, 2, 3, 4], dtype=int32)
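The `.at[...]` property supports other out-of-place updates besides `set`. The sketch below is a small illustration; the names `added`, `scaled`, and `overwritten` are made up for this example and do not come from the video.

```python
import jax.numpy as jnp

xs = jnp.array([0, 1, 2, 3, 4])

# Each call returns a new array; xs itself is never modified.
added = xs.at[2].add(10)           # add 10 to the element at index 2
scaled = xs.at[1:4].multiply(2)    # double the elements at indices 1, 2, 3
overwritten = xs.at[:2].set(-1)    # overwrite the first two elements

print(added, scaled, overwritten, xs)
```

As with `set`, the original `xs` is left untouched by every one of these calls.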

##### Example 2

In [None]:
from jax import random

In [None]:
seed = 42

Create a pseudorandom number generator (PRNG) key with seed `42`

In [None]:
key = random.PRNGKey(seed)

In [None]:
key

Array([ 0, 42], dtype=uint32)
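A key by itself produces no numbers; it is passed explicitly to a sampling function. The sketch below assumes `random.normal` as the sampler and uses the illustrative names `a`, `b`, `c`, and `subkey`, none of which appear in the video.

```python
from jax import random

key = random.PRNGKey(42)

# Reusing the same key always yields the same numbers.
a = random.normal(key, (3,))
b = random.normal(key, (3,))

# Split the key to obtain a fresh subkey for a new, independent draw.
key, subkey = random.split(key)
c = random.normal(subkey, (3,))

print(a, b, c)
```

Here `a` and `b` are identical because they share a key, while `c` comes from a different key.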

### Transform Functions

##### Example 1

In [None]:
import jax

In [None]:
def square(x):
    return x**2

Use `jax.jit` to compile the function `square` to XLA-optimized machine code

In [None]:
modified_square = jax.jit(square)

In [None]:
modified_square(2)

Array(4, dtype=int32, weak_type=True)

##### Example 2

In [None]:
xs = jnp.arange(3.)

In [None]:
import jax
import jax.numpy as jnp

In [None]:
def loss(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

Compute the gradient of the function `loss`

In [None]:
grad_loss = jax.grad(loss)

In [None]:
grad_loss

<function __main__.loss(x)>

In [None]:
xs, grad_loss(xs)

(Array([0., 1., 2.], dtype=float32),
 Array([0.25      , 0.19661197, 0.10499357], dtype=float32))
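The result can be checked by hand: `loss` sums the sigmoid $\sigma(x) = 1 / (1 + e^{-x})$ over the elements, and $\sigma'(x) = \sigma(x)(1 - \sigma(x))$, so the gradient at `0` is `0.5 * 0.5 = 0.25`. The sketch below compares the two; the helper `sigmoid` and the names `autodiff`, `analytic`, and `matches` are introduced here for illustration only.

```python
import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

def sigmoid(x):
    return 1.0 / (1.0 + jnp.exp(-x))

xs = jnp.arange(3.)

autodiff = jax.grad(loss)(xs)                  # gradient computed by JAX
analytic = sigmoid(xs) * (1.0 - sigmoid(xs))   # hand-derived gradient

matches = jnp.allclose(autodiff, analytic)
print(autodiff, analytic, matches)
```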

##### Example 2

In [None]:
import jax

In [None]:
f = lambda x: x**2 + x + 4

Compute the third-order derivative of `f`

In [None]:
dfdx = jax.grad(f)

In [None]:
d2fdx = jax.grad(dfdx)

In [None]:
d3fdx = jax.grad(d2fdx)

In [None]:
d3fdx

<function __main__.<lambda>(x)>

In [None]:
d3fdx(1.)

Array(0., dtype=float32, weak_type=True)
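The third derivative is `0` because `f` is a quadratic: $f'(x) = 2x + 1$, $f''(x) = 2$, $f'''(x) = 0$. The sketch below evaluates the intermediate derivatives and, for contrast, a hypothetical cubic `g` (not part of the original notebook) whose third derivative is non-zero.

```python
import jax

f = lambda x: x**2 + x + 4
g = lambda x: x**3  # hypothetical cubic, added only for contrast

dfdx = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3gdx = jax.grad(jax.grad(jax.grad(g)))

first = dfdx(1.)      # 2*1 + 1
second = d2fdx(1.)    # constant 2
third_g = d3gdx(1.)   # third derivative of x**3 is the constant 6

print(first, second, third_g)
```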

##### Example 3

In [None]:
def multiply(x, y):
    return x * y

In [None]:
xs = [3, 4, 5]
ys = [6, 7, 8]

In [None]:
results = []
for x, y in zip(xs, ys):
    result = multiply(x, y)
    results.append(result)

In [None]:
results

[18, 28, 40]

Apply the function `multiply` to all of the pairs of numbers in `xs` and `ys` at once

In [None]:
import jax
import jax.numpy as jnp

In [None]:
vmapped_multiply = jax.vmap(multiply)
results = vmapped_multiply(jnp.array(xs), jnp.array(ys))

In [None]:
results

Array([18, 28, 40], dtype=int32)
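By default `jax.vmap` maps over the leading axis of every argument. The `in_axes` parameter controls this per argument. A minimal sketch, assuming we want to multiply every element of `xs` by one fixed scalar (the names `scale_all` and `scaled` are illustrative):

```python
import jax
import jax.numpy as jnp

def multiply(x, y):
    return x * y

xs = jnp.array([3, 4, 5])

# in_axes=(0, None): map over axis 0 of x, pass y unchanged to every call.
scale_all = jax.vmap(multiply, in_axes=(0, None))
scaled = scale_all(xs, 10)

print(scaled)
```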

##### Example 4

In [None]:
import numpy as np

In [None]:
x = np.random.randn(3, 5)
y = np.random.randn(5, 3)

In [None]:
type(x), type(y)

(numpy.ndarray, numpy.ndarray)

In [None]:
from jax import jit

In [None]:
@jit
def f(x, y):
    print(f"this is a comment!")
    return jnp.dot(x+1, y+1)

The first call

In [None]:
f(x, y)

this is a comment!


Array([[18.459412 , 11.57047  ,  7.9957714],
       [ 5.205792 ,  9.011772 ,  2.5705764],
       [11.115227 ,  8.656818 ,  8.333181 ]], dtype=float32)

The second call

In [None]:
f(x, y)

Array([[18.459412 , 11.57047  ,  7.9957714],
       [ 5.205792 ,  9.011772 ,  2.5705764],
       [11.115227 ,  8.656818 ,  8.333181 ]], dtype=float32)

Why does the function not print the message when it is called the second time?

**Explain**: `print` is a Python side effect, not a JAX operation. On the first call, `jax.jit` traces the function by running its Python body once, so the message is printed, and only the JAX operations (here `jnp.dot(x+1, y+1)`) are recorded and compiled. On the second call, with inputs of the same shape and dtype, the cached compiled version runs directly without re-executing the Python body, so nothing is printed.
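The tracing behaviour can be observed directly. The sketch below is an illustration under a few assumptions: the function `add_one` and the list `trace_log` are made up for this example, and `jax.debug.print` is used as the way to print from inside compiled code on every call.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records each time Python actually runs the function body

@jax.jit
def add_one(x):
    trace_log.append(x.shape)          # Python side effect: happens only while tracing
    jax.debug.print("x = {x}", x=x)    # staged into the compiled code: runs on every call
    return x + 1

add_one(jnp.ones(3))   # first call: traced and compiled
add_one(jnp.ones(3))   # same shape and dtype: cached version, no retrace
calls_after_repeat = len(trace_log)

add_one(jnp.ones(4))   # new input shape: traced again
calls_after_new_shape = len(trace_log)

print(calls_after_repeat, calls_after_new_shape)
```

The Python body runs once per distinct input shape and dtype, not once per call.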

##### Example 5

##### Example 6

In [None]:
import random

Is `add_random_number` a pure function? Explain your answer (whether it is yes or no)

In [None]:
def add_random_number(x):
    random_num = random.uniform(0, 1)
    print(f"Adding random number {random_num}")
    return x + random_num

**Explain**

No. It's not a pure function because:

- Reason 1: The output depends not only on the input `x` but also on the randomly generated number `random_num`. Given the same input `x`, the function can produce different outputs, which violates the first property of a pure function (the same inputs must always give the same output).

- Reason 2: The function prints the random number to the console, which violates the second property of a pure function (it must have no side effects).

In [None]:
add_random_number(1) == add_random_number(1)

Adding random number 0.28968690517955153
Adding random number 0.4863246276705302


False
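One way to make such a function pure is to pass the source of randomness in as an argument. The sketch below is one possible rewrite, not the notebook's own solution; the name `add_random_number_pure` and its `key` parameter are assumptions made for illustration.

```python
import jax

def add_random_number_pure(key, x):
    # The random number is fully determined by the explicit key argument,
    # and nothing is printed, so there are no side effects.
    random_num = jax.random.uniform(key)  # scalar in [0, 1)
    return x + random_num

key = jax.random.PRNGKey(0)
same = add_random_number_pure(key, 1) == add_random_number_pure(key, 1)

# Different keys give different random numbers.
key_a, key_b = jax.random.split(key)
different = add_random_number_pure(key_a, 1) != add_random_number_pure(key_b, 1)

print(same, different)
```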

##### Example 7

In [None]:
import jax.numpy as jnp

In [None]:
x = jnp.arange(10)

In [None]:
x

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [None]:
x.shape

(10,)

In [None]:
new_x = x.at[42].add(69)

What is the value of `new_x`? Explain

**Hint**: You might expect this to raise an error 😉

**Explain**

The value of `new_x` is the same as `x`, and no error is raised.

JAX handles out-of-bounds indices differently from NumPy arrays and Python lists, which raise an `IndexError`.

In JAX, when an out-of-bounds index is used in an update such as `x.at[idx].add(val)`, the update is skipped by default and the result has the same values as the original array.

In [None]:
new_x

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)
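The out-of-bounds behaviour of `.at[...]` updates can be changed with the `mode` argument. A minimal sketch comparing the default with `mode="clip"`, which clamps the index into the valid range (the names `dropped` and `clipped` are illustrative):

```python
import jax.numpy as jnp

x = jnp.arange(10)

dropped = x.at[42].add(69)                # default: the out-of-bounds update is skipped
clipped = x.at[42].add(69, mode="clip")   # index 42 is clamped to 9, the last valid index

print(dropped)
print(clipped)
```

With `mode="clip"` the value `69` is added to the last element instead of being discarded.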

##### Example 8

In [None]:
from jax import make_jaxpr