In [None]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [None]:
# Build a PRNG key from seed 0 and draw ten standard-normal samples with it.
# `key` and `x` are module-level names that later cells read and rebind.
key = random.key(0)
x = random.normal(key, (10,))
print(x)

In [None]:
# Time a (size x size) float32 matrix product on an array created by JAX.
# block_until_ready() makes %timeit wait until the result is actually computed.
# NOTE(review): `key` is reused from the previous cell instead of being split.
# NOTE(review): the trailing "runs on the GPU" remark only holds when a GPU is
# the default JAX device — confirm for the target environment.
size = 5000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

In [None]:
import numpy as np
# Same product, but `x` is now a NumPy array (reuses `size` from the cell above).
# NOTE(review): presumably the timing then includes moving `x` to the JAX device
# on every call — confirm before drawing conclusions from the numbers.
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()

In [None]:
from jax import device_put

# Create the NumPy array, hand it to JAX once with device_put, then time the
# product on the resulting JAX array.
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

In [None]:
def selu(x, alpha=1.67, lmbda=1.05):
  """Scaled exponential linear unit, applied element-wise.

  Returns `lmbda * x` where `x > 0` and `lmbda * (alpha * exp(x) - alpha)`
  everywhere else.
  """
  non_positive_branch = alpha * jnp.exp(x) - alpha
  selected = jnp.where(x > 0, x, non_positive_branch)
  return lmbda * selected

# One million standard-normal samples (drawn with the shared `key`) used to
# time the un-jitted selu.
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()

In [None]:
# Wrap selu with jit and time it on the same `x` for comparison with the cell above.
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()

In [None]:
# A fixed (150, 100) matrix and a batch of ten length-100 vectors, both drawn
# with the shared `key`.
mat = random.normal(key, shape=(150, 100))
batched_x = random.normal(key, shape=(10, 100))


def apply_matrix(v):
  """Multiply the module-level matrix `mat` by `v` using `jnp.dot`."""
  product = jnp.dot(mat, v)
  return product

In [None]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

In [None]:
@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

In [None]:
@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

In [None]:
# 1-D convolution of a three-element integer array with ten ones through the
# NumPy-style jnp.convolve API; the result is displayed as the cell output.
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)

In [None]:
from jax import lax
# Same inputs as the jnp.convolve cell, expressed through the lower-level
# lax.conv_general_dilated: both operands are reshaped to rank 3 and the integer
# `x` is cast to float by hand.
# NOTE(review): presumably the rank-3 layout is (batch, channel, width) under
# the default dimension numbers — confirm against the lax documentation.
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
# Drop the two leading singleton axes to show the 1-D result.
result[0, 0]

In [None]:
# NOTE(review): scratch arithmetic unrelated to the surrounding JAX cells —
# candidate for removal from the finished notebook.
.56*375

In [None]:
@jit
def get_negatives(x):
  """Select the negative entries of `x` with a boolean mask.

  The length of the result depends on the values in `x`, so the indexing
  cannot be traced: calling this jitted function raises
  NonConcreteBooleanIndexError.
  """
  return x[x < 0]

x = jnp.array(np.random.randn(10))

# The failure is the point of this cell. Catch and display it so that
# "Restart & Run All" continues past the demonstration instead of stopping here.
try:
  get_negatives(x)
except IndexError as err:  # NonConcreteBooleanIndexError subclasses IndexError
  print(f"{type(err).__name__}: {str(err).splitlines()[0]}")

In [None]:
@jit
def f(x, y):
  """Return jnp.dot(x + 1, y + 1), printing the inputs and the result on the way.

  The print calls are ordinary Python side effects. Under @jit they are
  expected to run while JAX traces the function, so they show traced
  placeholders rather than concrete numbers.
  """
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

# NumPy inputs: a (3, 4) matrix and a length-4 vector.
x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

In [None]:
from jax import random

In [None]:
# Walk through explicit PRNG key handling: draw with a key, split it, and draw
# again with the resulting keys.
key = random.key(1)
print(f"key: {key}")
# The same key is used for both draws below.
print(f"random 1 = {random.normal(key, shape=(1,))}")
print(f"random 1 = {random.normal(key, shape=(1,))}")
# Splitting rebinds `key` and produces a new `subkey`.
key, subkey = random.split(key)
print(f"random 2 = {random.normal(key, shape=(1,))}")
print(f"subrandom 1 = {random.normal(subkey, shape=(1,))}")
key, subkey2 = random.split(key)
print(f"random 3 = {random.normal(key, shape=(1,))}")
# NOTE(review): this draws with `subkey` again; `subkey2` from the split above
# is never used — confirm which key was intended here.
print(f"subrandom 2 = {random.normal(subkey, shape=(1,))}")
subkey, subsubkey = random.split(subkey)
# NOTE(review): this label lacks the " = " separator used by every other line.
print(f"subrandom 3 {random.normal(subkey, shape=(1,))}")
print(f"subsubrandom 1 = {random.normal(subsubkey, shape=(1,))}")

In [None]:
def nansum(x):
  """Return the sum of the entries of `x` that are not NaN."""
  keep = jnp.logical_not(jnp.isnan(x))  # True where the entry is a real number
  return x[keep].sum()

In [None]:
# Five-element input with one NaN; the non-NaN entries add up to 10.
x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))

In [None]:
@jit
def nansum_2(x):
  """NaN-ignoring sum that keeps the input shape: NaNs are replaced by 0, not filtered out."""
  is_nan = jnp.isnan(x)
  nan_zeroed = jnp.where(is_nan, 0, x)
  return nan_zeroed.sum()

# Uses the NaN-containing `x` defined two cells above.
print(nansum_2(x))

In [None]:
def sum_single(carry, x):
    """Scan step: add `x` to the running total unless `x` is NaN.

    The choice is made with `jnp.where` instead of a Python conditional because
    `x` is a traced value inside `lax.scan` and cannot be used in an `if`.

    Returns:
      `(new_carry, x)`, the pair shape that `lax.scan` expects from its body.
    """
    return jnp.where(jnp.isnan(x), carry, carry + x), x

def nansum_3(x):
    """Sum the elements of `x` along its leading axis with `lax.scan`, ignoring NaNs.

    Returns the final running total only; the per-step outputs that `lax.scan`
    also produces are discarded.
    """
    x = jnp.asarray(x)
    # Start the carry as a zero of the input dtype so it keeps one type
    # through every scan step.
    start = jnp.zeros((), dtype=x.dtype)
    total, _ = lax.scan(sum_single, start, x)
    return total

In [None]:
# Request float64 uniform samples and display the dtype that actually comes back.
# NOTE(review): unless the `jax_enable_x64` option is turned on, JAX is
# documented to fall back to float32 here — confirm this is the intended demo.
x = random.uniform(random.key(0), (1000,), dtype=jnp.float64)
x.dtype

In [None]:
# Ten million consecutive integers, used to time a dot product on the default device.
# NOTE(review): with JAX's default 32-bit integers the numerical result
# presumably overflows; only the timing is used here — confirm that is acceptable.
long_vector = jnp.arange(int(1e7))

%timeit jnp.dot(long_vector, long_vector).block_until_ready()

In [None]:
long_vector_on_cpu = jax.device_put(long_vector, device=jax.devices('cpu')[0])

%timeit jnp.dot(long_vector_on_cpu, long_vector_on_cpu).block_until_ready()
                                                                

---

## Tutorial: JAX 101

In [None]:
# This cell runs before the section's import cell further down, so it brings
# the names it needs into scope itself.
import jax
import jax.numpy as jnp


def sum_of_squares(x):
  """Return the sum of the squared entries of `x`."""
  return jnp.sum(x**2)

# Gradient of sum_of_squares with respect to its only argument.
sum_of_squares_dx = jax.grad(sum_of_squares)

x = jnp.asarray([1.0, 2.0, 3.0, 4.0])

print(sum_of_squares(x))

print(sum_of_squares_dx(x))

In [None]:
# This cell runs before the section's import cell further down, so it brings
# the names it needs into scope itself.
import jax
import jax.numpy as jnp


def sum_squared_error(x, y):
  """Return the sum of squared element-wise differences between `x` and `y`."""
  return jnp.sum((x-y)**2)

# jax.grad differentiates with respect to the first argument (`x`) by default.
sum_squared_error_dx = jax.grad(sum_squared_error)

y = jnp.asarray([1.1, 2.1, 3.1, 4.1])

# Gradient with respect to `x`; uses `x` from the previous cell and `y` from this one.
print(sum_squared_error_dx(x, y))

First training loop

In [None]:
import jax.numpy as jnp
import jax
from jax import grad, jit, vmap
from jax import random
import numpy as np
import matplotlib.pyplot as plt
import lovely_jax as lj
from lovely_numpy import lo
# NOTE(review): presumably swaps in lovely_jax's summary repr for JAX arrays,
# which would change how arrays print in the cells below — confirm.
lj.monkey_patch()


In [None]:
# Synthetic linear data: 100 points with ys = 3 * xs - 1 plus Gaussian noise (scale 0.1).
# NOTE(review): NumPy's global RNG is not seeded, so the data differ between runs.
xs = np.random.normal(size=(100,))
noise = np.random.normal(scale=0.1, size=(100,))
ys = xs * 3 - 1 + noise

plt.scatter(xs, ys);

In [None]:
def model(theta, x):
  """Evaluate the linear model slope * x + intercept, where `theta` unpacks to (slope, intercept)."""
  slope, intercept = theta
  return slope * x + intercept

@jit
def loss_fn(theta, x, y):
  """Mean squared error between `model(theta, x)` and `y`."""
  residual = model(theta, x) - y
  return jnp.mean(residual**2)

def update(theta, x, y, lr=0.1):
  """Take one gradient-descent step on `loss_fn` and return the new `theta`."""
  gradient = jax.grad(loss_fn)(theta, x, y)
  return theta - lr * gradient

In [None]:
# Start from w = b = 1 and run 1000 gradient-descent steps on the synthetic data.
theta = jnp.array([1., 1.])

for _ in range(1000):
  theta = update(theta, xs, ys)

# Overlay the fitted line (red) on the scattered data.
plt.scatter(xs, ys)
plt.plot(xs, model(theta, xs), c='r')

# Report the fitted parameters; the data were generated with w = 3 and b = -1.
w, b = theta
print(f"w: {w:<.2f}, b: {b:<.2f}")

In [None]:
# Print the jaxpr that JAX traces for one `update` step on the current arguments.
print(jax.make_jaxpr(update)(theta, xs, ys))

In [None]:
import jax
import jax.numpy as jnp

# Example signal (0..6) and a three-tap window.
x = jnp.arange(7)
w = jnp.array([3., 3., 4.])

def convolve(x, w):
  """Slide the 3-element window `w` over `x`, producing one dot product per interior position."""
  windows = (x[center - 1:center + 2] for center in range(1, len(x) - 1))
  return jnp.array([jnp.dot(window, w) for window in windows])

convolve(x, w)

In [None]:
# Two-row batches built by stacking the single example twice.
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

def manually_batched_convolve(xs, ws):
  """Run `convolve` on each (xs[i], ws[i]) pair in a Python loop and stack the results."""
  per_example = [convolve(xs[row], ws[row]) for row in range(xs.shape[0])]
  return jnp.stack(per_example)

manually_batched_convolve(xs, ws)

In [None]:
def manually_vectorized_convolve(xs, ws):
  """Batched 3-tap convolution: every window position is computed for all rows at once.

  For each interior position, the three-column slice of `xs` is multiplied by
  `ws` and summed along axis 1; the per-position results are stacked as columns.
  """
  n_columns = xs.shape[-1]
  per_position = [
      jnp.sum(xs[:, start:start + 3] * ws, axis=1)
      for start in range(n_columns - 2)
  ]
  return jnp.stack(per_position, axis=1)

# Uses the stacked `xs` / `ws` batches built in the previous cell.
manually_vectorized_convolve(xs, ws)

In [None]:
# Let jax.vmap add the batch dimension to the single-example `convolve`
# (default: map over axis 0 of every argument), then call it on the batches.
auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws)

In [None]:
def f(x):
    """Return the dot product of `x` with itself."""
    self_product = jnp.dot(x, x)
    return self_product

# Forward-mode Jacobian of f evaluated at a vector of ones.
print(jax.jacfwd(f)(jnp.array([1., 1., 1.])))

In [None]:
import jax
import jax.numpy as jnp

# A handful of differently shaped containers to flatten as pytrees:
# lists, tuples (including an empty one), dicts, and a bare array.
example_trees = [
    [1, 'a', object()],
    (1, (2, 3), ()),
    [1, {'k1': 2, 'k2': (3, 4)}, 5],
    {'a': 2, 'b': (2, 3)},
    jnp.array([1, 2, 3]),
]

# Let's see how many leaves they have:
for pytree in example_trees:
  # tree_leaves returns the flattened list of leaves of the container.
  leaves = jax.tree_util.tree_leaves(pytree)
  print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")

In [None]:
# tree_flatten_with_path returns (list of (key_path, leaf) pairs, treedef).
# Unpack each pair so that keystr receives only the key path, not the pair.
[jax.tree_util.keystr(path)
 for path, _leaf in jax.tree_util.tree_flatten_with_path(example_trees)[0]]

In [None]:
import numpy as np

def init_mlp_params(layer_widths):
  """Build one {'weights', 'biases'} dict per pair of consecutive layer widths.

  Weights are normal samples of shape (n_in, n_out) scaled by sqrt(2 / n_in);
  biases are all ones with shape (n_out,).
  """
  params = []
  width_pairs = zip(layer_widths[:-1], layer_widths[1:])
  for n_in, n_out in width_pairs:
    layer = {
        'weights': np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
        'biases': np.ones(shape=(n_out,)),
    }
    params.append(layer)
  return params

# MLP with one input, two hidden layers of width 128, and one output.
params = init_mlp_params([1, 128, 128, 1])

# Show the shape of every parameter array while keeping the container structure.
jax.tree.map(lambda x: x.shape, params)

In [None]:
def forward(params, x):
  """Run `x` through the MLP: ReLU after every layer except the last, which stays linear."""
  *hidden_layers, output_layer = params
  activations = x
  for layer in hidden_layers:
    pre_activation = activations @ layer['weights'] + layer['biases']
    activations = jax.nn.relu(pre_activation)
  return activations @ output_layer['weights'] + output_layer['biases']

def loss_fn(params, x, y):
  """Mean squared error of `forward(params, x)` against `y`."""
  error = forward(params, x) - y
  return jnp.mean(error ** 2)

LEARNING_RATE = 0.0001

@jax.jit
def update(params, x, y):
  """Apply one SGD step with step size LEARNING_RATE and return the new parameters."""
  # `jax.grad` supports pytrees, so `grads` mirrors the structure of `params`.
  grads = jax.grad(loss_fn)(params, x, y)

  def sgd_step(param, grad_leaf):
    return param - LEARNING_RATE * grad_leaf

  # Walk both trees in lockstep and update every leaf.
  return jax.tree.map(sgd_step, params, grads)

import matplotlib.pyplot as plt

# Fit y = x**2 on 128 random inputs with 1000 SGD steps, then plot the model's
# predictions next to the targets.
# NOTE(review): NumPy's global RNG is not seeded, so the data differ between runs.
xs = np.random.normal(size=(128, 1))
ys = xs ** 2

for _ in range(1000):
  params = update(params, xs, ys)

plt.scatter(xs, ys)
plt.scatter(xs, forward(params, xs), label='Model prediction')
plt.legend();

In [None]:
a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]

# Try to make another tree with ones instead of zeros
shapes = jax.tree.map(lambda x: x.shape, a_tree)
print(shapes)
# NOTE(review): each shape is a tuple, and tuples are pytree containers, so this
# presumably calls jnp.ones on every integer inside the tuples rather than once
# per shape — verify the output. jax.tree.map(jnp.ones_like, a_tree) would give
# ones with the original shapes.
jax.tree.map(jnp.ones, shapes)


In [None]:
# NOTE(review): interactive documentation lookup left in the notebook —
# candidate for removal from the finished version.
help(jax.random.normal)

In [None]:
# Split one seed key into two independent keys: k1 for the inputs, k2 for the noise.
k = jax.random.key(43)
k1, k2 = jax.random.split(k)

# generate the data
n = 1000 # number of data points
n_epochs = 1000  # number of full-batch update steps run below
true_slope = 3.5
true_bias = 1.0
noise_amplitude = 0.3
xs = jax.random.normal(k1, (n,))
ys = true_slope * xs + true_bias + noise_amplitude * jax.random.normal(k2, (n,))

# Initial guess for the two parameters being fitted, stored as a dict pytree.
parameters = {'slope': 1.0, 'bias': 0.5}

def forward(params: dict, x: np.ndarray | jnp.ndarray) -> np.ndarray | jnp.ndarray:
    return params['slope'] * x + params['bias']

def loss_fn(params:dict, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.float32:
    """Sum of squared residuals between `ys` and `forward(params, xs)`."""
    residuals = ys - forward(params, xs)
    return jnp.sum(residuals**2)

@jit
def update(params: dict, xs, ys, lr=0.0005):
    """Return `params` after one gradient-descent step of size `lr` on `loss_fn`."""
    grads = jax.grad(loss_fn)(params, xs, ys)

    def descend(value, gradient):
        return value - lr * gradient

    # `grads` has the same dict structure as `params`, so map over both together.
    return jax.tree.map(descend, params, grads)

# Run n_epochs full-batch updates, rebinding `parameters` to the new values each time.
for _ in range(n_epochs):
    parameters = update(parameters, xs, ys)

# Fitted values, to compare against true_slope and true_bias.
print(f"slope: {parameters['slope']}")
print(f"bias: {parameters['bias']}")
    