In [1]:
%load_ext autoreload

In [4]:
%autoreload 2
import jax
import jax.numpy as jnp
import numpy as np

import random
import itertools
from __future__ import print_function

In [47]:
# Based on Colin Raffel's blog post: learn XOR with a small neural network.
# Single hidden layer of 3 tanh neurons, sigmoid output unit, cross-entropy
# loss, trained with plain SGD.

# Logistic squashing function, applied to the network's final output.
def sigmoid(x):
    """Elementwise logistic function 1 / (1 + e^-x)."""
    denom = 1 + jnp.exp(-x)
    return 1 / denom

def net(params, x):
    """Forward pass: one tanh hidden layer, then a sigmoid output unit.

    `params` is the list [w1, b1, w2, b2]; the result is
    sigmoid(w2 . tanh(w1 . x + b1) + b2).
    """
    w1, b1, w2, b2 = params
    pre_activation = jnp.dot(w1, x) + b1
    logit = jnp.dot(w2, jnp.tanh(pre_activation)) + b2
    return sigmoid(logit)

def ce_loss(params, x, y):
    """Binary cross-entropy of the network's prediction for one example.

    The loss is evaluated from the pre-sigmoid logit instead of from
    net(params, x). In float32 the sigmoid output saturates to exactly 0.0
    or 1.0 for moderately large logits, and then log(y_hat) or
    log(1 - y_hat) is log(0): the loss becomes inf and the gradients
    inf/nan, which would corrupt every later parameter update.

    Args:
        params: [w1, b1, w2, b2], the same layout `net` unpacks.
        x: a single input vector.
        y: the target label (0 or 1 in this notebook).

    Returns:
        Scalar -y*log(sigmoid(z)) - (1-y)*log(1-sigmoid(z)), where z is the
        output logit.
    """
    w1, b1, w2, b2 = params
    # Same forward pass as `net`, stopping before the final sigmoid;
    # keep the two in sync if the architecture changes.
    hidden = jnp.tanh(jnp.dot(w1, x) + b1)
    logit = jnp.dot(w2, hidden) + b2
    # -log(sigmoid(z)) = log(1 + e^-z) and -log(1 - sigmoid(z)) = log(1 + e^z);
    # logaddexp(0, t) evaluates log(1 + e^t) without overflow.
    return y * jnp.logaddexp(0., -logit) + (1 - y) * jnp.logaddexp(0., logit)

def test_all_inputs(inputs, params):
    """Print the net's 0/1 prediction for every input row and check it
    against the XOR of that row's two bits.

    Returns True only if every prediction matches its XOR target.
    """
    predicted = []
    for row in inputs:
        # net outputs a probability; cutting at 0.5 turns it into a class label
        predicted.append(int(net(params, row) > 0.5))
    for idx, row in enumerate(inputs):
        print(row, '->', predicted[idx])
    expected = [np.bitwise_xor(*row) for row in inputs]
    return predicted == expected

In [48]:
inputs = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# Params are initialized randomly once, before training, so plain numpy is
# enough here -- no jax needed (nothing to compile or differentiate).
def init_params(n_in=2, n_hidden=3, rng=None):
    """Return randomly initialized parameters [w1, b1, w2, b2] for `net`.

    Every entry is drawn from a standard normal distribution.

    Args:
        n_in: number of input features (2 for XOR).
        n_hidden: number of hidden units (3 in the original setup).
        rng: optional ``np.random.Generator`` for reproducible draws. When
            None, numpy's global random state is used, drawing the same
            values as the original ``np.random.randn`` calls.

    Returns:
        [w1 of shape (n_hidden, n_in), b1 of shape (n_hidden,),
         w2 of shape (n_hidden,), b2 as a scalar float]
    """
    normal = np.random.standard_normal if rng is None else rng.standard_normal
    return [
        normal((n_hidden, n_in)),  # w1
        normal(n_hidden),          # b1
        normal(n_hidden),          # w2
        normal(),                  # b2
    ]

In [49]:
# Baseline: evaluate a randomly initialized (untrained) network. Bound to its
# own name so it does not clobber `y`, which later cells use for XOR targets.
untrained_solves_xor = test_all_inputs(inputs, init_params())
untrained_solves_xor

[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1


In [50]:
# train: jax.grad takes a fn and returns a new fn that computes the gradient
# of the original fn w.r.t. its 1st argument (use argnums to change that)
loss_grad = jax.grad(ce_loss)

# sgd learning rate
lr = 1.
# Upper bound on SGD steps so the cell always terminates, even if training
# stalls or diverges and the XOR check never passes.
max_iters = 100_000

params = init_params()

for n in range(max_iters):
    # get a single random training example (pick a random row)
    x = inputs[np.random.choice(inputs.shape[0])]
    # target output is the XOR of the two input bits
    y = np.bitwise_xor(*x)
    # gradient of the loss w.r.t. params, in the same list layout as params
    grads = loss_grad(params, x, y)
    # plain gradient-descent update
    params = [param - lr * grad for param, grad in zip(params, grads)]
    # every 100 iterations, check whether the net has learned XOR
    if n % 100 == 0:
        print(f'iter {n}')
        if test_all_inputs(inputs, params):
            break
else:
    print(f'did not learn XOR within {max_iters} iterations')
               

iter 0
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 1
iter 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
iter 200
[0 0] -> 0
[0 1] -> 0
[1 0] -> 1
[1 1] -> 0
iter 300
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [51]:
# jax.jit
# time orig grad fn
%timeit loss_grad(params, x, y)

3.98 ms ± 40.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [53]:
loss_grad = jax.jit(jax.grad(ce_loss))

# run once to trigger jit comp
loss_grad(params, x, y)
%timeit loss_grad(params, x, y)

5.01 μs ± 9.55 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [55]:
# train again, now with the jit-compiled gradient function
# (loss_grad was rebound to jax.jit(jax.grad(ce_loss)) in an earlier cell)
params = init_params()
# Upper bound on SGD steps so the cell always terminates, even if training
# stalls or diverges and the XOR check never passes.
max_iters = 100_000

for n in range(max_iters):
    # random training example and its XOR target
    x = inputs[np.random.choice(inputs.shape[0])]
    y = np.bitwise_xor(*x)
    grads = loss_grad(params, x, y)
    params = [param - lr * grad for param, grad in zip(params, grads)]

    # every 100 iterations, check whether the net has learned XOR
    if n % 100 == 0:
        print(f'iter {n}')
        if test_all_inputs(inputs, params):
            break
else:
    print(f'did not learn XOR within {max_iters} iterations')

iter 0
[0 0] -> 1
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
iter 100
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
iter 200
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0


In [None]:
# TODO: jax.vmap section (not yet written)
