In [1]:
import jax.numpy as jnp
from jax import grad,jit,vmap
from jax import random

key = random.PRNGKey(0)

In [2]:
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.))

0.070650816


`grad(f)` returns a new function; `grad(f)(x)` evaluates the gradient $\nabla f(x)$.

In [3]:
print(grad(grad(jnp.tanh))(2.0))

-0.13621868


In [4]:
def sigmoid(x):
    """Logistic sigmoid 1 / (1 + exp(-x)), evaluated through tanh.

    Uses the identity sigmoid(x) = (1 + tanh(x / 2)) / 2, so no exp(-x) term
    is ever formed explicitly.
    """
    half_x = x / 2
    return (jnp.tanh(half_x) + 1) * 0.5

def predict(W, b, inputs):
    """Logistic-regression forward pass.

    Returns sigmoid(inputs · W + b): for each row of ``inputs``, the model's
    probability that the corresponding label is True.
    """
    logits = jnp.dot(inputs, W) + b
    return sigmoid(logits)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

def loss(W, b):
    """Training loss: negative log-likelihood of the toy dataset.

    Reads the module-level ``inputs`` (features) and ``targets`` (boolean
    labels), so only the model parameters ``W`` and ``b`` are arguments.
    """
    p_true = predict(W, b, inputs)
    p_false = 1 - p_true
    # Probability the model assigns to the label each example actually has:
    # p_true where the target is True, p_false where it is False.
    observed_prob = p_true * targets + p_false * (1 - targets)
    log_likelihood = jnp.sum(jnp.log(observed_prob))
    return -log_likelihood

# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())

In [16]:
W_grad = grad(loss,argnums=0)(W,b)
b_grad = grad(loss,argnums=1)(W,b)
W_grad, b_grad = grad(loss, (0,1))(W,b)

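`grad(f, argnums=i)` computes $\partial_i f$, the partial derivative of $f$ with respect to its $i$-th positional argument (0-indexed). Passing a tuple such as `(0, 1)` returns a tuple of partial derivatives.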

In [19]:
def loss2(params_dict):
    """Negative log-likelihood of the toy dataset, with parameters packed in a dict.

    Computes the same quantity as ``loss(W, b)``, but takes ``{'W': ..., 'b': ...}``.
    The dict input shows that ``grad`` differentiates with respect to an
    arbitrary pytree: the gradient comes back as a dict with the same keys.
    """
    W, b = params_dict['W'], params_dict['b']
    p_true = predict(W, b, inputs)
    observed_prob = p_true * targets + (1 - p_true) * (1 - targets)
    total_log_prob = jnp.sum(jnp.log(observed_prob))
    return - total_log_prob
print(grad(loss2)({'W':W, 'b':b}))

{'W': DeviceArray([-0.16965581, -0.8774649 , -1.4901346 ], dtype=float32), 'b': DeviceArray(-0.2922725, dtype=float32)}


In [20]:
from jax import value_and_grad

loss_val, Wb_grad = value_and_grad(loss, (0,1))(W,b)
loss_val, Wb_grad

(DeviceArray(3.0519397, dtype=float32),
 (DeviceArray([-0.16965581, -0.8774649 , -1.4901346 ], dtype=float32),
  DeviceArray(-0.2922725, dtype=float32)))

In [23]:
eps = 1e-4

b_grad_numerical = (loss(W,b + eps/2.) - loss(W,b - eps/2.))/eps
b_grad_numerical,grad(loss,1)(W,b)

(DeviceArray(-0.29325485, dtype=float32),
 DeviceArray(-0.2922725, dtype=float32))

In [35]:
k,sk = random.split(key)
vec = random.normal(sk, W.shape)
unit_vec = vec / jnp.sqrt(vec @ vec)

In [37]:
unit_vec @ unit_vec

DeviceArray(1., dtype=float32)

In [47]:
W_grad_num = (loss(W + unit_vec * eps/2.,b) - loss(W - unit_vec * eps/2, b))/eps
W_grad_num, unit_vec @ grad(loss,0)(W,b) 

(DeviceArray(-0.19550323, dtype=float32),
 DeviceArray(-0.19909073, dtype=float32))

In [50]:
from jax.test_util import check_grads
%time check_grads(loss, (W,b), order=2)

CPU times: user 187 ms, sys: 3.03 ms, total: 190 ms
Wall time: 187 ms


For a scalar-valued $f : \mathbb{R}^n \to \mathbb{R}$, the Hessian at $x \in \mathbb{R}^n$ is written $\partial^2 f(x)$. We can then evaluate the Hessian-vector product $v \mapsto \partial^2 f(x) \cdot v$.

In [54]:
from jax import jacfwd, jacrev


f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
J

DeviceArray([[ 0.05981752,  0.12883775,  0.08857594],
             [ 0.04015911, -0.04928619,  0.0068453 ],
             [ 0.12188288,  0.01406341, -0.3047072 ],
             [ 0.00140426, -0.00472516,  0.00263774]], dtype=float32)

In [None]:
# Reverse-mode Jacobian of the same function `f`. It should agree with the
# forward-mode Jacobian `J` computed in the previous cell.
# (The placeholder `grad()` that was here raised TypeError: grad needs a function.)
J_rev = jacrev(f)(W)
jnp.allclose(J, J_rev, atol=1e-6)

In [58]:
predict(W,b,inputs)

DeviceArray([0.13262239, 0.9520671 , 0.6249393 , 0.99809873], dtype=float32)

In [33]:
(vec ** 2).sum()

DeviceArray(1.3960315, dtype=float32)

In [34]:
vec @ vec

DeviceArray(1.3960315, dtype=float32)

In [32]:
jnp.dot(vec,vec)

DeviceArray(1.3960315, dtype=float32)

In [60]:
import jax
jax.random.PRNGKey(0)

DeviceArray([0, 0], dtype=uint32)

## JAX to TFLite example

In [61]:
import numpy as np
import tensorflow as tf
import functools

import time
import itertools

import numpy.random as npr

import jax.numpy as jnp
from jax import jit, grad, random
from jax.experimental import optimizers
from jax.experimental import stax



In [62]:
def _one_hot(x, k, dtype=np.float32):
    """Return a one-hot encoding of the integer labels ``x`` over ``k`` classes.

    Row i holds 1 in column x[i] and 0 elsewhere. A label outside [0, k)
    matches no column and yields an all-zero row.
    """
    class_ids = np.arange(k)
    is_class = x[:, None] == class_ids
    return is_class.astype(dtype)

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Cast the raw pixels to float32 *before* scaling to [0, 1]. Dividing the
# integer arrays by 255.0 first would materialize a float64 copy of every
# image that is immediately cast down to float32 again.
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0

# One-hot encode the integer digit labels over the 10 classes.
train_labels = _one_hot(train_labels, 10)
test_labels = _one_hot(test_labels, 10)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [64]:
train_images

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.

In [83]:
def loss(params, batch):
    """Mean cross-entropy of the network on one minibatch.

    ``batch`` is an ``(inputs, targets)`` pair with one-hot ``targets``. The
    network returned by ``stax.serial`` ends in ``stax.LogSoftmax``, so
    ``predict`` yields log-probabilities. Each example contributes minus the
    log-probability of its true class.

    NOTE: this redefines the earlier ``loss(W, b)``; cells using that version
    must run before this one.
    """
    inputs, targets = batch
    log_probs = predict(params, inputs)
    true_class_log_probs = (log_probs * targets).sum(axis=1)
    return -jnp.mean(true_class_log_probs)
def accuracy(params, batch):
    """Fraction of examples in ``batch`` whose arg-max prediction equals the
    arg-max of the one-hot target."""
    inputs, targets = batch
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    true_class = jnp.argmax(targets, axis=1)
    is_correct = predicted_class == true_class
    return jnp.mean(is_correct)

# MLP classifier: two 1024-unit ReLU hidden layers and a 10-way LogSoftmax
# output, so `predict` returns per-class log-probabilities.
# `stax.serial` returns an (init_fun, apply_fun) pair.
# NOTE: this rebinds `predict`, shadowing the logistic-regression version above.
init_random_params, predict = stax.serial(
    stax.Flatten,                     # each example -> flat vector
    stax.Dense(1024), stax.Relu,      # hidden layer 1
    stax.Dense(1024), stax.Relu,      # hidden layer 2
    stax.Dense(10), stax.LogSoftmax,  # class log-probabilities
)

In [84]:
rng = random.PRNGKey(0)

In [85]:
step_size = 0.001
num_epochs = 10
batch_size = 128
momentum_mass = 0.9

In [89]:
out_shape,init_params = init_random_params(rng, (-1,28*28))

In [91]:
tf.nest.map_structure(np.shape, init_params)

[(),
 ((784, 1024), (1024,)),
 (),
 ((1024, 1024), (1024,)),
 (),
 ((1024, 10), (10,)),
 ()]

In [None]:
# opt_state = 

In [95]:
num_train = train_images.shape[0]
sample_batch = train_images[:batch_size], train_labels[:batch_size]



val, loss_grad = value_and_grad(loss)(init_params,sample_batch)

In [111]:
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)

In [113]:
def data_stream():
    """Yield an endless sequence of shuffled ``(images, labels)`` minibatches.

    Every pass over the training set uses a fresh permutation from a NumPy
    ``RandomState`` seeded with 0, so the batch order is reproducible. Each pass
    yields ``num_batches`` batches of at most ``batch_size`` examples; the
    last one is shorter when ``num_train`` is not a multiple of ``batch_size``.
    """
    np_rng = npr.RandomState(0)
    while True:
        order = np_rng.permutation(num_train)
        for start in range(0, num_batches * batch_size, batch_size):
            idx = order[start:start + batch_size]
            yield train_images[idx], train_labels[idx]

batches = data_stream()

In [115]:
opt_init, opt_update, get_params = optimizers.momentum(
    step_size, mass=momentum_mass)

In [117]:
@jit
def update(i, opt_state, batch):
    """Perform one jitted optimizer step and return the new optimizer state.

    Unpacks the current parameters from ``opt_state``, differentiates ``loss``
    on ``batch`` at those parameters, and hands the gradient to ``opt_update``
    together with the step index ``i``.
    """
    current_params = get_params(opt_state)
    grads = grad(loss)(current_params, batch)
    return opt_update(i, grads, opt_state)


In [118]:
_, init_params = init_random_params(rng, (-1, 28 * 28))

In [112]:
npr

<module 'numpy.random' from '/opt/conda/lib/python3.7/site-packages/numpy/random/__init__.py'>

In [110]:
init_params

[(),
 (DeviceArray([[ 0.0599803 ,  0.07477883, -0.02636744, ..., -0.00203827,
                 0.00477604,  0.0088953 ],
               [-0.05266815, -0.04449302,  0.05804382, ...,  0.03099251,
                 0.01045519, -0.01258917],
               [ 0.02642592, -0.05847212,  0.01522273, ...,  0.02278844,
                 0.00729971,  0.02985927],
               ...,
               [ 0.00589537,  0.03279144, -0.01694287, ...,  0.0098981 ,
                -0.07415088, -0.04778202],
               [-0.0611038 , -0.00923215,  0.03037605, ...,  0.02570599,
                -0.00972158, -0.03481346],
               [ 0.0474807 ,  0.01125032, -0.00361404, ..., -0.06323019,
                 0.01199978, -0.00674079]], dtype=float32),
  DeviceArray([-0.00016071,  0.01405241,  0.00313025, ...,  0.01631044,
               -0.02043703,  0.00848408], dtype=float32)),
 (),
 (DeviceArray([[ 0.03682379, -0.02499464,  0.02435319, ..., -0.00554798,
                -0.0256047 , -0.01538447],
          

In [106]:
loss_grad[0]

()

In [82]:
params

<function jax.example_libraries.stax.serial.<locals>.init_fun(rng, input_shape)>