In [1]:
import random
import itertools

import jax
import jax.numpy as np # always..
import numpy as onp # original numpy

In [2]:
def sigmoid(x):
  """Logistic function 1 / (1 + e^(-x)), computed elementwise with jax.numpy."""
  denominator = 1 + np.exp(-x)
  return 1 / denominator

def net(params, x):
  """Forward pass of a one-hidden-layer network.

  params unpacks into (w1, b1, w2, b2). The hidden layer is
  tanh(w1 . x + b1) and the output is sigmoid(w2 . hidden + b2).
  """
  w1, b1, w2, b2 = params
  pre_activation = np.dot(w1, x) + b1
  hidden = np.tanh(pre_activation)
  logit = np.dot(w2, hidden) + b2
  return sigmoid(logit)

def loss(params, x, y):
  """Binary cross-entropy between the network output for x and the label y.

  Args:
    params: network parameters, passed straight through to net().
    x: input example, passed straight through to net().
    y: target label (0 or 1 in this notebook).

  Returns:
    -y * log(out) - (1 - y) * log(1 - out), where out = net(params, x).
  """
  out = net(params, x)

  # A sigmoid output can round to exactly 1.0 (or underflow to 0.0) in
  # floating point. log(0) is then -inf, and the term that should vanish
  # becomes 0 * -inf = nan, which poisons the loss and its gradient.
  # Clip only those endpoint values: the bounds are the smallest positive
  # normal number and the largest representable value below 1, so any output
  # strictly inside that range is left untouched.
  info = np.finfo(out.dtype)
  out = np.clip(out, info.tiny, 1 - info.epsneg)

  cross_entropy = -y * np.log(out) - (1 - y) * np.log(1 - out)
  return cross_entropy

def test_all_inputs(inputs, params):
  preds = [int(net(params, x) > 0.5) for x in inputs]
  return (preds == [onp.bitwise_xor(*x) for x in inputs]) # note onp, note *x for unzip

In [3]:
# JAX's own random number generator (jax.random) gives better reproducibility.
# However, this notebook uses original NumPy (onp) for the random initialization -- TODO: explain why.

def initial_params():
  """Draw fresh standard-normal parameters [w1, b1, w2, b2] from NumPy's global RNG."""
  w1 = onp.random.randn(3, 2)  # hidden-layer weights
  b1 = onp.random.randn(3)     # hidden-layer biases
  w2 = onp.random.randn(3)     # output weights
  b2 = onp.random.randn()      # output bias (a plain float)
  return [w1, b1, w2, b2]

In [4]:
# jax.grad differentiates with respect to the first argument of "loss" by
# default, that is, with respect to params.
loss_grad = jax.grad(loss)

eta = 1  # SGD step size

# The four possible two-bit inputs of XOR.
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])

params = initial_params()

for n in itertools.count():
  # Pick one training example at random; its label is the XOR of its two bits.
  x = inputs[onp.random.choice(inputs.shape[0])]
  y = onp.bitwise_xor(*x)

  grads = loss_grad(params, x, y)

  # Plain gradient-descent step applied to every parameter.
  params = [p - eta * g for p, g in zip(params, grads)]

  # Every 100 steps: report progress and stop once all four inputs are right.
  if n % 100 == 0:
    print("Iteration", n)
    if test_all_inputs(inputs, params):
      break

Iteration 0
Iteration 100
Iteration 200


In [5]:
# Show the trained parameters in order: w1, b1, w2, b2.
for param in params:
  print(param)

[[-2.915513  -3.1900535]
 [ 2.7019804  1.3214543]
 [-2.9132302 -3.1299062]]
[ 0.77140146 -2.70173     0.73939276]
[-1.5628929 -1.8294228 -1.5880487]
-2.5038054


In [6]:
# Baseline timing of the un-jitted gradient function, for comparison with the
# jax.jit-compiled version timed in the next cell (JIT gives a nice speedup).
# JIT compilation goes through XLA and also targets GPU / TPU.

%timeit loss_grad(params, x, y) # x, y are left over from the final training iteration

100 loops, best of 3: 12.5 ms per loop


In [7]:
# Rebind loss_grad to a JIT-compiled version of the same gradient function.
loss_grad = jax.jit(jax.grad(loss))

# Call loss_grad once so that compilation happens here, not inside the timing loop.
loss_grad(params, x, y)

# Now time it. This is the relevant measurement, because SGD also calls
# loss_grad many times after the one-off compilation.
%timeit loss_grad(params, x, y)

1000 loops, best of 3: 433 µs per loop


In [8]:
# jax.vmap vectorizes the per-example gradient, which turns the SGD loop into
# minibatch gradient descent. in_axes=(None, 0, 0): params are shared across
# the batch, while x and y are mapped over their leading axis.
loss_grad = jax.jit(
    jax.vmap(jax.grad(loss), in_axes=(None, 0, 0), out_axes=0)
)

params = initial_params()

batch_size = 100

for n in itertools.count():
  # Sample a minibatch of XOR inputs and compute the matching labels.
  x = inputs[onp.random.choice(inputs.shape[0], size=batch_size)]
  y = onp.bitwise_xor(x[:, 0], x[:, 1])

  # grads holds one gradient per example along axis 0; average them for the step.
  grads = loss_grad(params, x, y)
  params = [p - eta * np.mean(g, axis=0) for p, g in zip(params, grads)]

  # (disabled debugging cap: `if n > 500: break`)

  # Every 100 steps: report progress and stop once all four inputs are right.
  if n % 100 == 0:
    print(n)
    if test_all_inputs(inputs, params):
      break

0
100


In [9]:
# Show the parameters learned by the minibatch run, in order: w1, b1, w2, b2.
for param in params:
  print(param)

[[ 1.6345334 -2.2520266]
 [-1.8182846 -1.1511141]
 [ 1.7510666  2.6979718]]
[ 1.9308393   0.63822705 -3.4900067 ]
[-2.5126188 -2.361304  -2.6519485]
-0.49163997


# More JAX basics

In [10]:
# Time a 3000 x 3000 matrix product through jax.numpy (np), using an array
# that was generated with original NumPy (onp).
x = onp.random.normal(size=(3000,3000))
%timeit np.dot(x, x.T)

The slowest run took 7.97 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 58.4 ms per loop


In [11]:
%timeit onp.dot(x, x.T)
# Same product with original NumPy for comparison: notice how much faster
# the jax.numpy version was.

1 loop, best of 3: 786 ms per loop


In [12]:
%timeit np.dot(x, x.T).block_until_ready()
# This takes a bit longer than the earlier jax.numpy timing because JAX uses
# asynchronous dispatch: np.dot returns right away with an array whose value
# may still be being computed, and block_until_ready() waits until the
# 3000 x 3000 result is actually finished.
# Asynchronous dispatch is useful with accelerators, since Python can keep
# queueing work without waiting for each output.

10 loops, best of 3: 80.1 ms per loop


In [13]:
import jax.random as jrandom

# Generate the matrix with JAX's own PRNG (explicit key) instead of NumPy.
key = jrandom.PRNGKey(0)
y = jrandom.normal(key=key, shape=(3000, 3000))
%timeit np.dot(y, y.T)

# This is much faster, because there is no need to transfer a NumPy array
# from the CPU to the GPU, unlike before.
# The array is created directly on the device (GPU) because it comes from jax.random.
# NOTE(review): there is no block_until_ready() here, so with asynchronous
# dispatch part of the difference may be dispatch-only timing -- confirm.

100 loops, best of 3: 20.6 ms per loop


In [None]:
# Computing a Jacobian using forward and reverse mode.
# For y = D(C(B(A(x)))) with a = A(x), b = B(a), c = C(b), the chain rule gives
#   dy/dx = dy/dc * dc/db * db/da * da/dx
# forward mode multiplies starting from the input side:
#   dy/dx = (dy/dc * (dc/db * (db/da * da/dx)))
# reverse mode multiplies starting from the output side:
#   dy/dx = (((dy/dc * dc/db) * db/da) * da/dx)

# Which one is better? Mostly reverse, because x is in R^n and y is in R for
# most optimization problems.

from jax import jacfwd, jacrev

# Isolate w_1: view the loss as a function of w_1 alone, with the other
# parameters and the example (x = [0, 1], y = 1) held fixed.
w_1, b_1, w_2, b_2 = params

def f(w):
  return loss([w, b_1, w_2, b_2], np.array([0, 1]), 1)

# Forward mode first, then reverse mode; both yield a Jacobian shaped like w_1.
for jacobian_of in (jacfwd, jacrev):
  J = jacobian_of(f)(w_1)
  print(J.shape)
  print(J)

(3, 2)
[[0.         0.00120079]
 [0.         0.00013758]
 [0.         0.07751656]]
(3, 2)
[[0.         0.00120079]
 [0.         0.00013758]
 [0.         0.07751656]]


In [None]:
# Much more succinctly: differentiate loss with respect to the whole params
# list at once. The result is a list with one Jacobian per parameter.
J_list = jacfwd(loss)(params, np.array([0,1]), 1)
print(J_list)

[DeviceArray([[0.        , 0.00120079],
             [0.        , 0.00013758],
             [0.        , 0.07751656]], dtype=float32), DeviceArray([0.00120079, 0.00013758, 0.07751656], dtype=float32), DeviceArray([ 0.06271198, -0.06307167,  0.05167241], dtype=float32), DeviceArray(-0.06309009, dtype=float32)]


In [None]:
# Jacobian-vector product (JVP)

from jax import jvp

# Note that v is a tangent "vector" in the space of w_1, i.e. it has the
# same (3, 2) shape as w_1.
v = onp.random.normal(size = w_1.shape)
# jvp returns f(w_1) together with the derivative of f at w_1 along v.
loss_val, jvp_val = jvp(f, (w_1,), (v,))

print(loss_val, jvp_val)

0.06516814 0.023711631


In [None]:
# Vector-Jacobian product (VJP)

# jax.vjp returns f(w_1) and a function (bound to vjp_val here) that maps a
# cotangent shaped like the output to cotangents for the inputs.
loss_val, vjp_val = jax.vjp(f, w_1)

# u is a random cotangent with the same shape as the loss value.
u = onp.random.normal(size = loss_val.shape)

v = vjp_val(u)
print(v) # a 1-tuple whose single entry has the shape of w_1, i.e. (3, 2)

(DeviceArray([[ 0.        ,  0.03020284],
             [-0.        , -0.00504099],
             [-0.        , -0.01350596]], dtype=float32),)


In [None]:
# jit gotcha
# The point of compiling with @jit is that we don't want to compile again,
# so we want the inputs to the function to keep the same form later on too.
# For example, if inp = [2, 3, 4], we expect later inputs to also have the form [a, b, c].

# When the function contains control-flow statements that depend on the input, @jit complains.