# Fundamental Concepts of Jax/Flax

In [None]:
!pip install jax jaxlib flax 

Collecting flax
[?25l  Downloading https://files.pythonhosted.org/packages/63/1f/63c720200f7a679d9fd408eb2641960d6fa99030c874ecba24091e694f91/flax-0.3.3-py3-none-any.whl (179kB)
[K     |█▉                              | 10kB 16.7MB/s eta 0:00:01[K     |███▋                            | 20kB 22.0MB/s eta 0:00:01[K     |█████▌                          | 30kB 15.0MB/s eta 0:00:01[K     |███████▎                        | 40kB 13.0MB/s eta 0:00:01[K     |█████████▏                      | 51kB 9.7MB/s eta 0:00:01[K     |███████████                     | 61kB 9.2MB/s eta 0:00:01[K     |████████████▉                   | 71kB 8.8MB/s eta 0:00:01[K     |██████████████▋                 | 81kB 9.6MB/s eta 0:00:01[K     |████████████████▌               | 92kB 9.8MB/s eta 0:00:01[K     |██████████████████▎             | 102kB 8.9MB/s eta 0:00:01[K     |████████████████████▏           | 112kB 8.9MB/s eta 0:00:01[K     |██████████████████████          | 122kB 8.9MB/s eta 0:00:

## jax.numpy and numpy
---


In [None]:
import numpy as np
import jax.numpy as jnp

In [None]:
jax_array = jnp.zeros((3,3), dtype=jnp.float32) 




In [None]:
jax_array[0,0]=1 # This will throw an error

TypeError: ignored

In [None]:
from jax.ops import index, index_add, index_update

In [None]:
new_jax_array = index_update(jax_array, index[1, :], 1.)
new_jax_array

DeviceArray([[0., 0., 0.],
             [1., 1., 1.],
             [0., 0., 0.]], dtype=float32)

In [None]:
new_jax_array = index_add(new_jax_array, index[1, :], 1)
new_jax_array

DeviceArray([[0., 0., 0.],
             [2., 2., 2.],
             [0., 0., 0.]], dtype=float32)

In [None]:
from jax import random
key = random.PRNGKey(0)

In [None]:
print(random.normal(key, shape=(1,)))
# Unlike numpy, this will give the exact same results;
print(random.normal(key, shape=(1,)))

[-0.20584235]
[-0.20584235]


In [None]:
key, subkey = random.split(key)
print(random.normal(key, shape=(1,)))
print(random.normal(subkey, shape=(1,)))

[0.14389044]
[-1.2515389]


In [None]:
jnp.add(jax_array, 1) # works

DeviceArray([[1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.]], dtype=float32)

In [None]:
from jax import lax
lax.add(jax_array, 1) # nope

TypeError: ignored

In [None]:
lax.add(jax_array, 1.0)

DeviceArray([[1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.]], dtype=float32)

In [None]:
x = random.uniform(random.PRNGKey(0), (1000,), dtype=jnp.float64)
x.dtype # This will not work!


dtype('float32')

In [None]:
# from jax.config import config
# config.update("jax_enable_x64", True)
# Must run the above commands on start-up to use float64.

## JIT simple example
---
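This is a simple example of using `jit` in JAX. Note that the `print` statements inside `f` show abstract tracer objects rather than concrete numbers, and that running the call cell a second time behaves differently from the first run: the `print` lines only execute while `f` is being traced, not when the already-compiled function is reused.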

In [None]:
import numpy as np
import jax.numpy as jnp
from jax import jit

In [None]:
@jit
def f(x, y):
  """Return ``dot(x + 1, y + 1)`` and print the inputs and the result on the way.

  The ``print`` calls are plain Python, so they show whatever objects the
  function body receives when it runs; under ``jit`` those are the abstract
  values JAX passes in rather than concrete numbers.
  """
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  # Shift both operands by one before taking the product.
  shifted_x = x + 1
  shifted_y = y + 1
  result = jnp.dot(shifted_x, shifted_y)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)


In [None]:
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=0/1)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=0/1)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=0/1)>


DeviceArray([1.1194923, 2.7518752, 3.0375113], dtype=float32)

### Pure functions and jit

In [None]:
g = 0.
def impure_uses_globals(x):
  """Return ``x + g``, where ``g`` is a module-level global rather than an argument."""
  # `g` is looked up in the enclosing module, so the result depends on state
  # that is not part of the function's inputs.
  return x + g

print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  
print ("Second call: ", jit(impure_uses_globals)(5.))
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))

# Question: Why does this happen?

First call:  4.0
Second call:  5.0
Third call, different type:  [14.]


In [None]:
g = 0.
def impure_saves_global(x):
  """Return ``x`` unchanged after storing it in the module-level global ``g``."""
  global g
  g = x  # side effect: rebinds the global to whatever object was passed in
  return x

print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)  


First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>


In [None]:
array = jnp.arange(10)
# The loop body only uses its own arguments and `array`, so every iteration adds array[i].
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
# Here the body also advances a Python iterator, i.e. it depends on hidden state.
# Presumably `next(iterator)` is evaluated only once, while the body is traced,
# which would explain why 0 is added on every iteration.
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

45
0


In [None]:
# This works
class Counter:
  """Minimal stateful counter that keeps its tally on the instance."""

  def __init__(self):
    # The tally starts at zero.
    self.n = 0

  def count(self) -> int:
    """Bump the tally by one and hand back the updated value."""
    updated = self.n + 1
    self.n = updated
    return updated

  def reset(self):
    """Put the tally back to zero."""
    self.n = 0


counter = Counter()

for _ in range(3):
  print(counter.count())

1
2
3


In [None]:
# Doesn't work
counter.reset()
# `count` updates `self.n`, a side effect that lives outside the function's
# inputs and outputs; the jitted version below keeps returning the same value.
fast_count = jit(counter.count)

for _ in range(3):
  print(fast_count())

1
1
1


In [None]:
from typing import Tuple

CounterState = int

class CounterV2:
  """Stateless counter: the running count is passed in and handed back explicitly."""

  def count(self, n: CounterState) -> Tuple[int, CounterState]:
    """Return ``(value, new_state)`` for the count that follows ``n``.

    Both entries are ``n + 1``: the first is the value to report, the second
    is the state to feed into the next call.
    """
    return n+1, n+1

  def reset(self) -> CounterState:
    """Return the initial state, zero."""
    return 0

counter = CounterV2()
state = counter.reset()

# Use the `jit` name imported earlier via `from jax import jit`; the bare
# module name `jax` has not been imported at this point in the notebook, so
# `jax.jit` raised a NameError here.
fast_count = jit(counter.count)

# The state is threaded through the loop by hand instead of living on the object.
for _ in range(3):
  value, state = fast_count(state)
  print(value)

NameError: ignored

### Tracing and static variables/operations in jit


In [None]:
# This will NOT work!
@jit
def f(x, neg):
  # `neg` is an ordinary (non-static) argument of a jitted function here, yet
  # it is used as the condition of a Python conditional expression.
  return -x if neg else x

f(1, True)

In [None]:
# This works
from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  """Return ``-x`` when ``neg`` is truthy, otherwise ``x``.

  Argument 1 (``neg``) is marked static, so inside the function it is a plain
  Python value and can drive ordinary Python branching.
  """
  if neg:
    return -x
  return x

f(1, True)

DeviceArray(-1, dtype=int32)

In [None]:
# This doesn't work
@jit
def f(x):
  # The target length is computed with jnp, so inside jit it is a traced value
  # rather than a concrete Python integer.
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x)

ConcretizationTypeError: ignored

In [None]:
# To understand why, it is good practice to print which is traced and which is not
@jit
def f(x):
  """Print ``x``, its shape, and the jnp-computed element count; returns None."""
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # The reshape below is left commented out so this cell runs without the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x)

x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=0/1)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=0/1)>


In [None]:
# Just replace it with numpy solves the problem
@jit
def f(x):
  """Flatten ``x`` to one dimension.

  The flattened length is computed with NumPy from ``x.shape``, so it is a
  concrete number at the time ``reshape`` is called.
  """
  flat_len = np.prod(x.shape)
  return x.reshape((flat_len,))

f(x)

DeviceArray([1., 1., 1., 1., 1., 1.], dtype=float32)

### Control flow

In [None]:
@jit
def f(x, neg):
  """Return ``-x`` when ``neg`` is true, otherwise ``x``, using ``lax.cond``.

  ``lax.cond(pred, true_fun, false_fun, operand)`` applies ``true_fun`` when
  ``pred`` is true, so the negating branch has to come first. The branches
  were previously swapped, which negated ``x`` exactly when ``neg`` was False.
  """
  return lax.cond(neg,
                  lambda x: -x,  # taken when neg is True
                  lambda x: x,   # taken when neg is False
                  x)
print(f(x, True))
print(f(x, False))

[[1. 1. 1.]
 [1. 1. 1.]]
[[-1. -1. -1.]
 [-1. -1. -1.]]


In [None]:
# Exercise: rewrite the following code using scan
def cum(x):
    """Add 10 to ``x`` ten times with a Python loop and return the result."""
    # The loop index is unused; the loop only repeats the same update.
    for i in range(10):
        x+=10
    return x

## PyTree

In [None]:
from jax import tree_map, tree_multimap
list_of_lists = [
    [1, 2, 3],
    [1, 2],
    [1, 2, 3, 4]
]

tree_map(lambda x: x*2, list_of_lists)

[[2, 4, 6], [2, 4], [2, 4, 6, 8]]

In [None]:
another_list_of_lists = list_of_lists
tree_multimap(lambda x, y: x+y*2, list_of_lists, another_list_of_lists)

[[3, 6, 9], [3, 6], [3, 6, 9, 12]]

## vmap and pmap

In [None]:
def predict(W, b, input_vec):
    """Apply one affine layer to a single input vector.

    Args:
        W: weight matrix; its second dimension must match ``input_vec``.
        b: bias, added to (and broadcast against) ``dot(W, input_vec)``.
        input_vec: one unbatched input vector.

    Returns:
        ``dot(W, input_vec) + b`` with no activation applied.
    """
    # `input_vec` is on the right-hand side of the product, so this function
    # handles exactly one example; batching is left to `vmap`.
    # The tanh that used to be computed here was never returned, so it is gone.
    return jnp.dot(W, input_vec) + b

In [None]:
from jax import vmap
W, b = jnp.ones((10, 5)), jnp.ones(1)
input_batch = jnp.ones((64,5))
predictions = vmap(predict, in_axes=(None, None, 0))(W, b, input_batch)
print(predictions)

# Implement Autograd using Jax

## Autograd basics

In [None]:
def toy1(w, b):
    """Return the sum of all entries of ``dot(w, b)`` as a scalar."""
    product = jnp.dot(w, b)
    return product.sum()

In [None]:
w = jnp.ones((3, 3))
b = jnp.ones(3)

toy1(w, b)

DeviceArray(9., dtype=float32)

In [None]:
from jax import grad, value_and_grad
print(grad(toy1)(w, b))

[[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]]


In [None]:
print(grad(toy1, (0, ))(w, b))

(DeviceArray([[1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.]], dtype=float32),)


In [None]:
print(grad(toy1, (1, ))(w, b))

(DeviceArray([3., 3., 3.], dtype=float32),)


In [None]:
print(grad(toy1, (0, 1))(w, b))

(DeviceArray([[1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.]], dtype=float32), DeviceArray([3., 3., 3.], dtype=float32))


In [None]:
# A more convenient way
def toy2(params):
    """Return the summed ``dot`` of the ``'w'`` and ``'b'`` entries of ``params``."""
    weights = params['w']
    bias = params['b']
    return jnp.dot(weights, bias).sum()

params = dict()
params['w'] = w
params['b'] = b

print(toy2(params))

9.0


In [None]:
print(grad(toy2)(params))

{'b': DeviceArray([3., 3., 3.], dtype=float32), 'w': DeviceArray([[1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.]], dtype=float32)}


## Defining our own autograd functions

In [None]:
from jax import custom_jvp

### The function whose derivative rule is being customised: log(1 + exp(x)).
@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

### The custom JVP rule registered for log1pexp.
@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals # The original input
  x_dot, = tangents # The tangent associated with x
  ans = log1pexp(x) # The function value
  ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot # The derivative w.r.t. x, times the tangent of x
  return ans, ans_dot

## Entmax-$\alpha$ implementation

In [None]:
import jax.numpy as jnp
from jax import grad, jit, value_and_grad
from jax import vmap, pmap
from jax import random
import jax
from jax import lax
from jax import custom_jvp

def p_tau(z, tau, alpha=1.5):
    """Unnormalised entmax weights ``[(alpha - 1) * z - tau]_+ ** (1 / (alpha - 1))``.

    ``z`` holds the raw scores; the ``(alpha - 1)`` scaling is applied here,
    so callers must not pre-scale ``z``.
    """
    return jnp.clip((alpha - 1) * z - tau, 0) ** (1 / (alpha - 1))


def get_tau(tau, tau_max, tau_min, z_value):
    """Shrink the bisection bracket around the threshold.

    Returns ``(new_tau_max, new_tau_min)``. When the weights at the midpoint
    ``tau`` sum to less than one, ``tau`` is too large and becomes the new
    upper bound; otherwise it becomes the new lower bound.
    """
    return lax.cond(z_value < 1,
                    lambda _: (tau, tau_min),
                    lambda _: (tau_max, tau),
                    operand=None
                    )


def body(kwargs, x):
    """One bisection step, written as a ``lax.scan`` body.

    ``kwargs`` is the carry (bracket, raw scores and ``alpha``); ``x`` is the
    per-step scan input, which is unused. Returns the updated carry and no
    per-step output.
    """
    tau_min = kwargs['tau_min']
    tau_max = kwargs['tau_max']
    z = kwargs['z']
    alpha = kwargs['alpha']

    tau = (tau_min + tau_max) / 2
    z_value = p_tau(z, tau, alpha).sum()
    taus = get_tau(tau, tau_max, tau_min, z_value)
    tau_max, tau_min = taus[0], taus[1]
    return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None


def map_row(z_input, alpha, T):
    """Entmax of a single 1-D score vector, solved with ``T`` bisection steps.

    Args:
        z_input: 1-D array of raw scores.
        alpha: entmax exponent (the formulas here divide by ``alpha - 1``).
        T: number of bisection iterations.

    Returns:
        A 1-D array of the same length as ``z_input`` that sums to one.
    """
    # The bracket is expressed in the scaled domain (alpha - 1) * z, which is
    # the quantity `p_tau` compares against tau.
    scaled = (alpha - 1) * z_input
    tau_min = jnp.min(scaled) - 1
    tau_max = jnp.max(scaled) - z_input.shape[0] ** (1 - alpha)

    # `p_tau` applies the (alpha - 1) factor itself, so it must be given the
    # raw scores; passing the pre-scaled vector applied the factor twice.
    carry = {'tau_min': tau_min, 'tau_max': tau_max, 'z': z_input, 'alpha': alpha}
    carry, _ = lax.scan(body, carry, xs=None, length=T)
    tau = (carry['tau_max'] + carry['tau_min']) / 2
    probs = p_tau(z_input, tau, alpha)
    # Renormalise to absorb the remaining bisection error.
    return probs / probs.sum()


def _entmax(input, axis=-1, alpha=1.5, T=20):
    """Apply ``map_row`` along ``axis`` of ``input``; the result keeps ``input``'s shape.

    ``axis`` is the axis that gets normalised. Previously it was passed
    directly as ``vmap``'s batch axis, which normalised along the other axis
    and, for any ``axis`` other than 0, returned the result with axes permuted.
    """
    # Pay attention here: vmap maps over the batch axis, not the axis being
    # normalised, so move `axis` last and batch over everything else.
    moved = jnp.moveaxis(input, axis, -1)
    flat = moved.reshape((-1, moved.shape[-1]))
    out = vmap(lambda z: map_row(z, alpha, T))(flat)
    return jnp.moveaxis(out.reshape(moved.shape), -1, axis)


def entmax(input, axis=-1, alpha=1.5, T=10):
    """Entmax of ``input`` along ``axis`` using ``T`` bisection iterations."""
    return _entmax(input, axis, alpha, T)



In [None]:
import numpy as np
input = jnp.array(np.random.randn(64, 10)).block_until_ready()
weight = jnp.array(np.random.randn(64, 10)).block_until_ready()

def toy(input, weight):
    """Scalar objective: the ``weight``-weighted sum of ``entmax(input, 0, 1.5, 20)``."""
    probs = entmax(input, 0, 1.5, 20)
    weighted = weight * probs
    return weighted.sum()

toy(input, weight)

DeviceArray(8.167998, dtype=float32)

In [None]:
# `jax.partial` was only an accidental re-export of `functools.partial` and is
# missing from newer JAX releases, so use the standard-library function.
from functools import partial


@partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
def entmax(input, axis=-1, alpha=1.5, T=10):
    """Entmax of ``input`` with a hand-written JVP rule.

    ``axis``, ``alpha`` and ``T`` are declared non-differentiable; only
    ``input`` carries a tangent.
    """
    return _entmax(input, axis, alpha, T)

def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
    """Compute the entmax output and its tangent.

    With ``s = Y ** (2 - alpha)`` and ``v`` the incoming tangent, the tangent
    returned is ``s * v - s * sum(s * v) / sum(s)``, where both sums run
    along ``axis``.

    Returns:
        ``(Y, dX)``: the primal output and the output tangent.
    """
    input = primals[0]
    Y = entmax(input, axis, alpha, T)
    gppr = Y  ** (2 - alpha)
    grad_output = tangents[0]
    dX = grad_output * gppr
    # Projection term: one scalar per slice along `axis`, broadcast back below.
    q = dX.sum(axis=axis) / gppr.sum(axis=axis)
    q = jnp.expand_dims(q, axis=axis)
    dX -= q * gppr
    return Y, dX


@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
    """JVP rule registered on ``entmax``; the non-differentiable arguments come first."""
    return _entmax_jvp_impl(axis, alpha, T, primals, tangents)

In [None]:
import numpy as np
input = jnp.array(np.random.randn(64, 10)).block_until_ready()
weight = jnp.array(np.random.randn(64, 10)).block_until_ready()

def toy(input, weight):
    """Scalar objective: the ``weight``-weighted sum of ``entmax(input, 0, 1.5, 20)``."""
    probs = entmax(input, 0, 1.5, 20)
    weighted = weight * probs
    return weighted.sum()

value_and_grad(toy)(input, weight)

(DeviceArray(5.219076, dtype=float32),
 DeviceArray([[ 9.20430869e-02,  2.55589876e-02,  9.70999748e-02,
               -1.48003921e-01,  3.68221849e-01,  8.76951963e-02,
               -2.42806301e-01, -8.62555429e-02,  1.65654510e-01,
                0.00000000e+00],
              [-2.18896680e-02,  6.26884401e-02,  2.98321128e-01,
                0.00000000e+00, -4.10345271e-02, -1.58703014e-01,
                1.54489398e-01, -1.68707818e-01,  0.00000000e+00,
                1.38728499e+00],
              [-9.53757241e-02,  2.80622810e-01, -0.00000000e+00,
               -2.75156528e-01,  1.17984172e-02,  1.63770449e+00,
                9.73208435e-03,  0.00000000e+00,  7.07152393e-03,
                0.00000000e+00],
              [-5.86056001e-02,  3.85744303e-01,  2.90085286e-01,
                0.00000000e+00, -0.00000000e+00,  4.16471958e-02,
               -1.52479634e-01,  2.15069398e-01,  0.00000000e+00,
               -0.00000000e+00],
              [-3.53923082e-01, -3.46

In [None]:
import jax.numpy as jnp
from jax import grad, jit, value_and_grad
from jax import vmap, pmap
from jax import random
import jax
from jax import lax
from jax import custom_jvp


# `jax.partial` was only an accidental re-export of `functools.partial` and is
# missing from newer JAX releases, so use the standard-library function.
from functools import partial


def p_tau(z, tau, alpha=1.5):
    """Unnormalised entmax weights ``[(alpha - 1) * z - tau]_+ ** (1 / (alpha - 1))``.

    ``z`` holds the raw scores; the ``(alpha - 1)`` scaling is applied here,
    so callers must not pre-scale ``z``.
    """
    return jnp.clip((alpha - 1) * z - tau, 0) ** (1 / (alpha - 1))


def get_tau(tau, tau_max, tau_min, z_value):
    """Shrink the bisection bracket around the threshold.

    Returns ``(new_tau_max, new_tau_min)``. When the weights at the midpoint
    ``tau`` sum to less than one, ``tau`` is too large and becomes the new
    upper bound; otherwise it becomes the new lower bound.
    """
    return lax.cond(z_value < 1,
                    lambda _: (tau, tau_min),
                    lambda _: (tau_max, tau),
                    operand=None
                    )


def body(kwargs, x):
    """One bisection step, written as a ``lax.scan`` body.

    ``kwargs`` is the carry (bracket, raw scores and ``alpha``); ``x`` is the
    per-step scan input, which is unused. Returns the updated carry and no
    per-step output.
    """
    tau_min = kwargs['tau_min']
    tau_max = kwargs['tau_max']
    z = kwargs['z']
    alpha = kwargs['alpha']

    tau = (tau_min + tau_max) / 2
    z_value = p_tau(z, tau, alpha).sum()
    taus = get_tau(tau, tau_max, tau_min, z_value)
    tau_max, tau_min = taus[0], taus[1]
    return {'tau_min': tau_min, 'tau_max': tau_max, 'z': z, 'alpha': alpha}, None

@partial(jax.jit, static_argnums=(2,))
def map_row(z_input, alpha, T):
    """Entmax of a single 1-D score vector, solved with ``T`` bisection steps.

    ``T`` is static under ``jit`` because it is used as the scan length.

    Returns:
        A 1-D array of the same length as ``z_input`` that sums to one.
    """
    # The bracket is expressed in the scaled domain (alpha - 1) * z, which is
    # the quantity `p_tau` compares against tau.
    scaled = (alpha - 1) * z_input
    tau_min = jnp.min(scaled) - 1
    tau_max = jnp.max(scaled) - z_input.shape[0] ** (1 - alpha)

    # `p_tau` applies the (alpha - 1) factor itself, so it must be given the
    # raw scores; passing the pre-scaled vector applied the factor twice.
    carry = {'tau_min': tau_min, 'tau_max': tau_max, 'z': z_input, 'alpha': alpha}
    carry, _ = lax.scan(body, carry, xs=None, length=T)
    tau = (carry['tau_max'] + carry['tau_min']) / 2
    probs = p_tau(z_input, tau, alpha)
    # Renormalise to absorb the remaining bisection error.
    return probs / probs.sum()

@partial(jax.jit, static_argnums=(1,3,))
def _entmax(input, axis=-1, alpha=1.5, T=20):
    """Apply ``map_row`` along ``axis`` of ``input``; the result keeps ``input``'s shape.

    ``axis`` is the axis that gets normalised, matching how the JVP rule
    below sums over it. Previously it was passed directly as ``vmap``'s batch
    axis, which normalised along the other axis and, for any ``axis`` other
    than 0, returned the result with axes permuted.
    """
    # vmap maps over the batch axis, not the axis being normalised, so move
    # `axis` last and batch over everything else.
    moved = jnp.moveaxis(input, axis, -1)
    flat = moved.reshape((-1, moved.shape[-1]))
    out = vmap(lambda z: map_row(z, alpha, T))(flat)
    return jnp.moveaxis(out.reshape(moved.shape), -1, axis)

@partial(custom_jvp, nondiff_argnums=(1, 2, 3,))
def entmax(input, axis=-1, alpha=1.5, T=10):
    """Entmax of ``input`` along ``axis`` with a hand-written JVP rule.

    ``axis``, ``alpha`` and ``T`` are declared non-differentiable; only
    ``input`` carries a tangent.
    """
    return _entmax(input, axis, alpha, T)

@partial(jax.jit, static_argnums=(0,2,))
def _entmax_jvp_impl(axis, alpha, T, primals, tangents):
    """Compute the entmax output and its tangent.

    With ``s = Y ** (2 - alpha)`` and ``v`` the incoming tangent, the tangent
    returned is ``s * v - s * sum(s * v) / sum(s)``, where both sums run
    along ``axis``.

    Returns:
        ``(Y, dX)``: the primal output and the output tangent.
    """
    input = primals[0]
    Y = entmax(input, axis, alpha, T)
    gppr = Y  ** (2 - alpha)
    grad_output = tangents[0]
    dX = grad_output * gppr
    # Projection term: one scalar per slice along `axis`, broadcast back below.
    q = dX.sum(axis=axis) / gppr.sum(axis=axis)
    q = jnp.expand_dims(q, axis=axis)
    dX -= q * gppr
    return Y, dX


@entmax.defjvp
def entmax_jvp(axis, alpha, T, primals, tangents):
    """JVP rule registered on ``entmax``; the non-differentiable arguments come first."""
    return _entmax_jvp_impl(axis, alpha, T, primals, tangents)

import numpy as np
input = jnp.array(np.random.randn(64, 10)).block_until_ready()
weight = jnp.array(np.random.randn(64, 10)).block_until_ready()

def toy(input, weight):
    """Scalar objective: the ``weight``-weighted sum of ``entmax(input, 0, 1.5, 20)``."""
    probs = entmax(input, 0, 1.5, 20)
    weighted = weight * probs
    return weighted.sum()

jax.jit(value_and_grad(toy))(input, weight)

(DeviceArray(3.5463314, dtype=float32),
 DeviceArray([[ 0.00000000e+00, -9.28829432e-01,  1.49818644e-01,
                8.19825530e-01,  0.00000000e+00, -9.33015406e-01,
               -1.50957555e-01, -2.59509310e-02,  1.53118121e-02,
               -1.00283481e-01],
              [-5.03038764e-02,  2.88019031e-02,  5.03074788e-02,
               -3.37607563e-01, -8.93707499e-02,  6.79402351e-01,
               -3.41596955e-04, -2.95041502e-01,  5.44383116e-02,
                1.10143265e-02],
              [ 2.80410666e-02,  3.04949880e-02,  0.00000000e+00,
               -0.00000000e+00,  4.97258455e-01,  5.13189249e-02,
               -4.07762965e-03,  0.00000000e+00, -3.51836272e-02,
               -3.70255321e-01],
              [-3.07102263e-01,  7.89324790e-02, -7.85121694e-03,
               -5.40790319e-01,  0.00000000e+00, -2.78178394e-01,
               -9.22627896e-02, -3.53259332e-02, -2.95178555e-02,
               -0.00000000e+00],
              [-2.68901922e-02, -3.8