# Fundamental Concepts of Jax



In [67]:
!pip install jax jaxlib

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


## jax.numpy and numpy
---


In [3]:
import numpy as np
import jax.numpy as jnp

In [4]:
x = jnp.arange(5.0)




In [15]:
x = jnp.arange(5.0)
x.at[2:4].add(10)

Array([ 0.,  1., 12., 13.,  4.], dtype=float32)

In [16]:
from jax import random
key = random.PRNGKey(0)

In [17]:
print(random.normal(key, shape=(1,)))
# Unlike numpy, this will give the exact same results;
print(random.normal(key, shape=(1,)))

[-0.20584226]
[-0.20584226]


In [18]:
key, subkey, _ = random.split(key, 3)
print(random.normal(key, shape=(1,)))
print(random.normal(subkey, shape=(1,)))

[1.1188384]
[0.5781488]


## JIT simple example
---
This is a simple example of using `jit` functionality in Jax. Note that if you run the code twice, things will be different.

In [35]:
import numpy as np
import jax.numpy as jnp
from jax import jit
import jax

In [22]:
@jit
def f(x, y):
  print("Running f():") #Print will introduce an side effect
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)


In [23]:
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>


Array([1.5641091, 2.545023 , 4.3680015], dtype=float32)

### Pure functions and jit
___________________________
Here you have to save the global variable if you do not want any side effect.

In [27]:
g = 0.
def impure_uses_globals(x):
  return x + g

print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  
print ("Second call: ", jit(impure_uses_globals)(5.))
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))

# Question: Why does this happen?

First call:  4.0
Second call:  5.0
Third call, different type:  [14.]


In [28]:
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x

print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  


First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [30]:
from jax import lax

In [31]:
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

45
0


In [32]:
# This works
class Counter:
  """A simple counter."""

  def __init__(self):
    self.n = 0

  def count(self) -> int:
    """Increments the counter and returns the new value."""
    self.n += 1
    return self.n

  def reset(self):
    """Resets the counter to zero."""
    self.n = 0


counter = Counter()

for _ in range(3):
  print(counter.count())

1
2
3


In [33]:
# Doesn't work
counter.reset()
fast_count = jit(counter.count)

for _ in range(3):
  print(fast_count())

1
1
1


In [36]:
from typing import Tuple

CounterState = int

class CounterV2:

  def count(self, n: CounterState) -> Tuple[int, CounterState]:
    return n+1, n+1

  def reset(self) -> CounterState:
    return 0

counter = CounterV2()
state = counter.reset()

fast_count = jax.jit(counter.count)

for _ in range(3):
  value, state = fast_count(state)
  print(value)

1
2
3


### Tracing and static variables/operations in jit


In [37]:
# This will NOT work!
@jit
def f(x, neg):
  return -x if neg else x

f(1, True)

ConcretizationTypeError: ignored

In [38]:
# This works
from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
    print("Need to compile")
    return -x if neg else x

In [39]:
f(1, True)
f(2, False)

Need to compile
Need to compile


Array(2, dtype=int32, weak_type=True)

In [40]:
# This doesn't work
@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)

TypeError: ignored

In [41]:
# To understand why, it is good practice to print which is traced and which is not
@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # comment this out to avoid the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)

x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>


In [42]:
# Just replace it with numpy solves the problem
@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x)

Array([1., 1., 1., 1., 1., 1.], dtype=float32)

### Control flow

In [None]:
# lax.cond is equivalent to 
# def cond(pred, true_fun, false_fun, operand):
#   if pred:
#     return true_fun(operand)
#   else:
#     return false_fun(operand)

In [None]:
@jit
def f(x, neg):
  return lax.cond(neg, 
                  lambda inp: inp, 
                  lambda inp: -inp, 
                  x)
print(f(x, True))
print(f(x, False))

In [None]:
# The definition of scan
# def scan(f, init, xs, length=None):
#   if xs is None:
#     xs = [None] * length
#   carry = init
#   ys = []
#   for x in xs:
#     carry, y = f(carry, x)
#     ys.append(y)
#   return carry, np.stack(ys)

# Example: [1,2,3,4,5,6,7,8,9,10]
# f = lambda x, y:x+y, x+y
# init = 0
# xs = [1,2,3,4,5,6,7,8,9,10]
# carry = 0
# ys = []
# carry, y = 1, 1
# ys = [1]
# carry, y = 3, 3
# ys = [1,3]
# And the code continues

In [None]:
# Exercise: rewrite the following code using scan
def cum(x):
    for i in range(10):
        x+=10
    return x

In [None]:
def cum_v2(x):
    def func_inner(inp, y):
        return inp+10, None
    return lax.scan(func_inner, x, None, 10)

In [None]:
cum_v2(0)

## PyTree

In [45]:
list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

jax.tree_map(lambda x: x*2, list_of_lists)


[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

In [46]:
another_list_of_lists = list_of_lists
jax.tree_map(lambda x, y: x+y*2, list_of_lists, another_list_of_lists)

[[3, 6, 9], [3, 6], [3, 6, 9, 12]]

## vmap and pmap

In [47]:
def predict(W, b, input_vec):
    activations = input_vec
    outputs = jnp.dot(W, activations) + b  # `input_vec` on the right-hand side!
    activations = jnp.tanh(outputs)
    return outputs

In [48]:
from jax import vmap
W, b = jnp.ones((10, 5)), jnp.ones(1)
input_batch = jnp.ones((64,5))
predictions = vmap(predict, in_axes=(None, None, 0))(W, b, input_batch)
print(predictions)

[[6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]
 [6. 6. 6.

# Implement Autograd using Jax

## Autograd basics
---
In general, since we only care about autograd function, we only need to define the so-called Jacobian-vector product (jvp). To understand what is jvp, consider the following approximation of an arbitrary function.

$f(x+v) \approx f(x) + \partial f v$.

Here, $\partial f v$ is the jvp. In most cases, it suffices to to define jvp in order to get complete autograd machenism to work. 



In [49]:
def toy1(w, b):
    return jnp.dot(w, b).sum()

In [50]:
w = jnp.ones((3, 3))
b = jnp.ones(3)

toy1(w, b)

Array(9., dtype=float32)

In [51]:
from jax import grad, value_and_grad
print(grad(toy1)(w, b))

[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]


In [52]:
print(grad(toy1, (0, ))(w, b))

(Array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32),)


In [53]:
print(grad(toy1, (1, ))(w, b))

(Array([3., 3., 3.], dtype=float32),)


In [54]:
print(grad(toy1, (0, 1))(w, b))

(Array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32), Array([3., 3., 3.], dtype=float32))


In [55]:
# A more convenient way is to define the dictionaries. 
def toy2(params):
    return jnp.dot(params['w'], params['b']).sum()

params = dict()
params['w'] = w
params['b'] = b

print(toy2(params))

9.0


In [56]:
print(grad(toy2)(params))

{'b': Array([3., 3., 3.], dtype=float32), 'w': Array([[1., 1., 1.],
       [1., 1., 1.],
       [1., 1., 1.]], dtype=float32)}


## Defining our own autograd functions

In [57]:
from jax import custom_jvp

### This is the original function 
@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

### This is the definition of custom jvp 
@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals # This is the original input
  x_dot, = tangents # This is the tangents associated with x
  ans = log1pexp(x) # The function value
  ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot # The derivative w.r.t. to times the tangets of x
  return ans, ans_dot

## Entmax-$\alpha$ implementation
_________________________________

This comes from [this paper](https://arxiv.org/pdf/1912.11637.pdf).

In [59]:
### The first idea is to try it without any JIT


import jax.numpy as jnp
from jax import grad, jit, value_and_grad
from jax import vmap, pmap
from jax import random
import jax
from jax import lax
from jax import custom_jvp

def p_tau(z, tau, alpha=1.5):
    return jnp.clip((alpha - 1) * z - tau, 0) ** (1 / (alpha - 1))


def get_tau(tau, tau_max, tau_min, z_value):
    return lax.cond(z_value < 1,
                    lambda _: (tau, tau_min),
                    lambda _: (tau_max, tau),
                    operand=None
                    )
    
def body(kwargs, x):
    tau_min = kwargs['tau_min']
    tau_max = kwargs['tau_max']
    z = kwargs['z']
    alpha = kwargs['alpha']

    tau = (tau_min + tau_max) / 2
    z_value = p_tau(z, tau, alpha).sum()
    taus = get_tau(tau, tau_max, tau_min, z_value)
    tau_max, tau_min = taus[0], taus[1]
    return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None

def map_row(z_input, alpha, T):
    z = (alpha - 1) * z_input

    tau_min, tau_max = jnp.min(z) - 1, jnp.max(z) - z.shape[0] ** (1 - alpha)
    result, _ = lax.scan(body, {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, xs=None,
                         length=T)
    tau = (result['tau_max'] + result['tau_min']) / 2
    result = p_tau(z, tau, alpha)
    return result / result.sum()

def _entmax(input, axis=-1, alpha=1.5, T=20):
    result = vmap(lambda z: map_row(z, alpha, T), axis)(input) ## Pay attention here!
    return result

def entmax(input, axis=-1, alpha=1.5, T=10):
    return _entmax(input, axis, alpha, T)



In [60]:
import numpy as np
input = jnp.array(np.random.randn(64, 10)).block_until_ready()
weight = jnp.array(np.random.randn(64, 10)).block_until_ready()

def toy(input, weight):
    return (weight*entmax(input, 0, 1.5, 20)).sum()

toy(input, weight)

Array(-7.328603, dtype=float32)

In [63]:
### The second step is to do it without JIT

@partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
def entmax(input, axis=-1, alpha=1.5, T=10):
    return _entmax(input, axis, alpha, T)

def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
    input = primals[0]
    Y = entmax(input, axis, alpha, T)
    gppr = Y  ** (2 - alpha)
    grad_output = tangents[0]
    dX = grad_output * gppr
    q = dX.sum(axis=axis) / gppr.sum(axis=axis)
    q = jnp.expand_dims(q, axis=axis)
    dX -= q * gppr
    return Y, dX


@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
    return _entmax_jvp_impl(axis, alpha, T, primals, tangents)

In [64]:
import numpy as np
input = jnp.array(np.random.randn(64, 10)).block_until_ready()
weight = jnp.array(np.random.randn(64, 10)).block_until_ready()

def toy(input, weight):
    return (weight*entmax(input, 0, 1.5, 20)).sum()

value_and_grad(toy)(input, weight)

(Array(9.618575, dtype=float32),
 Array([[ 0.0163976 , -0.10273282,  0.06469838,  0.01347553, -0.5293114 ,
         -0.        ,  0.05213852, -0.18185703,  0.10476619, -0.16644944],
        [ 0.15140419, -0.11494707,  0.28096712, -0.06867173,  0.        ,
         -0.28615507,  0.        ,  0.02768484,  0.00857848,  0.19686379],
        [ 0.0899851 ,  0.17949833, -0.15484603,  0.15318573,  0.35233545,
          0.        ,  0.09586761, -0.01077976, -0.00738061, -0.12621234],
        [-0.22487909,  0.04021792,  0.19407374,  0.6367048 ,  0.10712407,
         -0.27746785, -0.2437227 , -0.11114511,  0.09989041, -0.00346325],
        [ 1.0118692 ,  0.35418448, -0.08551174, -1.0712987 ,  0.3389108 ,
         -0.03183718, -0.        , -0.        , -0.1114131 , -0.        ],
        [ 0.        , -0.20795597,  0.11889306,  0.16654454,  0.00884845,
         -0.        , -0.        ,  0.91167074,  0.5254332 , -0.8373338 ],
        [-0.00269561, -0.        ,  0.55038655, -0.10957966,  0.5481675 ,

In [66]:
import jax.numpy as jnp
from jax import grad, jit, value_and_grad
from jax import vmap, pmap
from jax import random
import jax
from jax import lax
from jax import custom_jvp


def p_tau(z, tau, alpha=1.5):
    return jnp.clip((alpha - 1) * z - tau, 0) ** (1 / (alpha - 1))


def get_tau(tau, tau_max, tau_min, z_value):
    return lax.cond(z_value < 1,
                    lambda _: (tau, tau_min),
                    lambda _: (tau_max, tau),
                    operand=None
                    )


def body(kwargs, x):
    tau_min = kwargs['tau_min']
    tau_max = kwargs['tau_max']
    z = kwargs['z']
    alpha = kwargs['alpha']

    tau = (tau_min + tau_max) / 2
    z_value = p_tau(z, tau, alpha).sum()
    taus = get_tau(tau, tau_max, tau_min, z_value)
    tau_max, tau_min = taus[0], taus[1]
    return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None

@partial(jax.jit, static_argnums=(2,))
def map_row(z_input, alpha, T):
    z = (alpha - 1) * z_input

    tau_min, tau_max = jnp.min(z) - 1, jnp.max(z) - z.shape[0] ** (1 - alpha)
    result, _ = lax.scan(body, {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, xs=None,
                         length=T)
    tau = (result['tau_max'] + result['tau_min']) / 2
    result = p_tau(z, tau, alpha)
    return result / result.sum()

@partial(jax.jit, static_argnums=(1,3,))
def _entmax(input, axis=-1, alpha=1.5, T=20):
    result = vmap(lambda z: map_row(z, alpha, T), axis)(input)
    return result

@partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
def entmax(input, axis=-1, alpha=1.5, T=10):
    return _entmax(input, axis, alpha, T)

@partial(jax.jit, static_argnums=(0,2,))    
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
    input = primals[0]
    Y = entmax(input, axis, alpha, T)
    gppr = Y  ** (2 - alpha)
    grad_output = tangents[0]
    dX = grad_output * gppr
    q = dX.sum(axis=axis) / gppr.sum(axis=axis)
    q = jnp.expand_dims(q, axis=axis)
    dX -= q * gppr
    return Y, dX


@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
    return _entmax_jvp_impl(axis, alpha, T, primals, tangents)

import numpy as np
input = jnp.array(np.random.randn(64, 10)).block_until_ready()
weight = jnp.array(np.random.randn(64, 10)).block_until_ready()

def toy(input, weight):
    return (weight*entmax(input, 0, 1.5, 20)).sum()

jax.jit(value_and_grad(toy))(input, weight)

(Array(5.6667066, dtype=float32),
 Array([[ 5.72826028e-01, -6.89338446e-01, -2.99717970e-02,
          1.43840790e-01, -0.00000000e+00,  2.27360785e-01,
         -9.76727128e-01, -0.00000000e+00,  1.70223825e-02,
          1.08865477e-01],
        [ 4.90314923e-02, -4.86451328e-01,  5.29754572e-02,
          4.89378944e-02,  0.00000000e+00,  0.00000000e+00,
         -2.82343719e-02, -2.83114940e-01, -0.00000000e+00,
          2.43406266e-01],
        [ 0.00000000e+00, -1.47417799e-01,  1.36417411e-02,
         -0.00000000e+00, -2.54689276e-01,  2.44241610e-01,
          3.75806056e-02, -2.16792032e-01, -0.00000000e+00,
         -4.15655375e-01],
        [ 0.00000000e+00,  4.33304101e-01,  1.25181958e-01,
         -5.59592582e-02,  2.47171968e-02,  3.68507579e-02,
         -5.63591063e-01, -2.94631310e-02, -0.00000000e+00,
          4.62481171e-01],
        [-5.11221355e-03, -6.72956169e-01, -7.54861832e-01,
         -1.49681762e-01, -1.72253639e-01, -3.39079142e-01,
          1.498270