In [1]:
import jax
import jax.numpy as jnp
import numpy as np

# Exploring grad

In [112]:
def piece_wise(arr):
    """Mean of `arr ** 2`, written piecewise so grad must pass through `jnp.where`.

    Entries equal to 0 take the constant branch and the rest take `arr ** 2`.
    Both branches agree at 0, so the value is simply the mean of the squares.
    """
    squared = arr ** 2
    masked = jnp.where(arr == 0, 0, squared)
    return jnp.mean(masked)


grad_piece_wise = jax.grad(piece_wise)
arr = jnp.array([0., 1., 2., 0., 100., 200.])
# Expected: 2 * arr / arr.size, with exactly 0 at the zero entries.
print(grad_piece_wise(arr))

[ 0.          0.33333334  0.6666667   0.         33.333336   66.66667   ]


In [113]:
def sinx(arr):
    """Mean of `sin(arr)`; its gradient is `cos(arr) / arr.size`."""
    return jnp.mean(jnp.sin(arr))


grad_sinx = jax.grad(sinx)
arr = jnp.array([0., 1., 2., 0., 100., 200.])
print(grad_sinx(arr))
# Reference: the gradient above equals this divided by arr.size (here 6).
print(jnp.cos(arr))

[ 0.16666667  0.09005038 -0.06935781  0.16666667  0.14371982  0.08119795]
[ 1.          0.5403023  -0.41614684  1.          0.8623189   0.48718768]


In [114]:
def loop(arr):
    """Apply `sinx` ten times in a row and return the mean of the result.

    NOTE(review): `sinx` reduces with `.mean()`, so after the first pass the
    value is a 0-d scalar and the remaining nine passes act on that scalar.
    Every output therefore depends on all inputs, which is why the Hessian of
    `loop` has off-diagonal terms. If elementwise composition of sin was
    intended, `jnp.sin` should be used instead of `sinx` — confirm.
    """
    out = arr
    for _ in range(10):
        out = sinx(out)
    return out.mean()


grad_loop = jax.grad(loop)
arr = jnp.array([0., 0., 10., 20.])
print(grad_loop(arr))

[ 0.24070115  0.24070115 -0.20196548  0.09822582]


# Jacobians and hessians
`jacfwd` uses forward-mode autodiff and so is better for "tall" Jacobian matrices, while `jacrev` uses reverse-mode and so is better for "wide" Jacobian matrices (with a grad corresponding to the special case of a Jacobian matrix consisting of a single row).

In [109]:
def square(arr):
    """Elementwise square of `arr`; its Jacobian is `diag(2 * arr)`."""
    return jnp.power(arr, 2)


# Both transforms must wrap `square`, defined just above. `square_sum` is never
# defined, so referencing it raises NameError on a fresh kernel.
jacobian1 = jax.jacfwd(square)  # forward mode: efficient for "tall" Jacobians
jacobian2 = jax.jacrev(square)  # reverse mode: efficient for "wide" Jacobians

In [110]:
arr = jnp.array([1., 2., 3.])
# Forward mode (jacobian1) and reverse mode (jacobian2) should print the same
# matrix; the saved output below shows diag(2, 4, 6) for both.
print(jacobian1(arr))
print(jacobian2(arr))

[[2. 0. 0.]
 [0. 4. 0.]
 [0. 0. 6.]]
[[2. 0. 0.]
 [0. 4. 0.]
 [0. 0. 6.]]


In [115]:
# Hessian = Jacobian of the gradient. This is the forward-over-reverse
# composition: reverse mode (grad) for the scalar loss, then forward mode
# (jacfwd) over the resulting gradient vector.
hess_piece_wise, hess_sinx, hess_loop = (
    jax.jacfwd(jax.grad(fn)) for fn in (piece_wise, sinx, loop)
)

In [116]:
arr = jnp.array([1., 2., 3.])
# Each call returns a 3x3 Hessian for this 3-element input.
print(hess_piece_wise(arr))
print(hess_sinx(arr))
# Unlike the two above, this one has off-diagonal terms: `loop` feeds every
# entry through sinx's mean, so each output depends on all inputs.
print(hess_loop(arr))

[[0.6666667 0.        0.       ]
 [0.        0.6666667 0.       ]
 [0.        0.        0.6666667]]
[[-0.28049034 -0.         -0.        ]
 [-0.         -0.30309916 -0.        ]
 [-0.         -0.         -0.04704   ]]
[[-0.10576932  0.02106125  0.05010367]
 [ 0.02106125 -0.10096754 -0.03859041]
 [ 0.05010366 -0.03859041 -0.10495693]]


# Jit

In [117]:
# `jax.jit` is used much like numba's `@jit`: wrap (or decorate) a function and it is compiled the first time it is called.

# Vmap

In [233]:
from functools import partial


def l2_v1(n, pred, target):
    """Squared error between `pred` and `target`, pre-divided by `n`.

    Summing the result over `n` items gives the mean squared error.
    """
    return (pred - target) ** 2 / n


def l2_v2(pred, target):
    """Plain squared error between `pred` and `target`."""
    return (pred - target) ** 2


pred = jnp.array([2., 3., 4.5])
target = jnp.array([1., 2., 3.])

# `n` is bound up front with partial, so vmap only maps over (pred, target).
loss_v1 = jax.vmap(partial(l2_v1, len(pred)))
loss_v2 = jax.vmap(l2_v2)

In [235]:
# Argument order matches the signatures, (pred, target). The squared loss is
# symmetric, so the swapped order gave the same numbers, but any asymmetric
# loss would silently differ. The two results differ only in the last float32
# digit: l2_v1 divides each squared residual by n before summing, while
# .mean() divides the total once.
print(loss_v1(pred, target).sum())
print(loss_v2(pred, target).mean())

1.4166667
1.4166666


# JAX's `grad`, `jit` and `vmap` are designed for `pure functions`
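What are `pure functions`?
They are functions that:
1. Neither read nor modify global state (e.g. global variables).
2. Return a value that depends only on their arguments: the same arguments always give the same output. This is a deterministic mapping, not necessarily a one-to-one mapping.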

In [245]:
def weighted_sin_pure(x, w=(10,)):
    """Return `sin(x) * w[0]`, a pure function of its arguments.

    The default for `w` is an immutable tuple rather than a list, so the
    default object can never be mutated and shared between calls. Any
    indexable sequence (list or tuple) may still be passed. By default,
    `jax.grad` differentiates only with respect to the first argument, `x`.
    """
    return jnp.sin(x) * w[0]


pure = jax.grad(weighted_sin_pure)
print(pure(10.))

-8.390715


In [247]:
w = []


def weighted_sin_impure(x):
    """Return `sin(x)` scaled by the first element of the module-level list `w`.

    Deliberately impure: every call appends 10 to the global `w`, and the
    result depends on that global rather than only on `x`. Because `w[0]` is
    always 10 here, the output still looks the same call after call.
    NOTE(review): per the JAX "sharp bits" docs, under `jax.jit` Python side
    effects like this append run only while tracing — confirm before relying on it.
    """
    w.append(10)  # mutating (not rebinding) the global needs no `global` statement
    weight = w[0]
    return jnp.sin(x) * weight


impure = jax.grad(weighted_sin_impure)
print(impure(10.))

-8.390715


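I don't fully understand `impure functions` yet: the impure version above returns the same gradient as the pure one. That is because `jax.grad` on its own runs the Python function on every call, so the side effect (`w.append(10)`) simply happens each time and `w[0]` stays 10. The problems should show up under `jax.jit`, which runs Python side effects only once, while tracing. I will come back to this when I hit an error.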

# Linear regression

In [251]:
# Seed the global NumPy RNG so this synthetic data (and any later np.random
# draws in the notebook) is reproducible under Restart & Run All.
np.random.seed(0)
X = np.random.rand(10, 2)  # 10 samples x 2 features, uniform on [0, 1)
y = np.random.rand(10)     # 10 targets, uniform on [0, 1)
# NOTE(review): `w` is a single scalar while X has 2 columns, and this rebinds
# the list `w` used by weighted_sin_impure — confirm a (2,) weight was not intended.
w = np.random.rand()