In [712]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
from jax.experimental import optimizers

## Use JIT to make things faster

In [43]:
def sigmoid(x):
    return jnp.exp(x) / (1 + jnp.exp(-x))

In [44]:
x = jnp.ones((1000, 1000))

In [49]:
%timeit -n 10 -r 100 sigmoid(x).block_until_ready()

2.54 ms ± 144 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)


In [None]:
jitted_sigmoid = jit(sigmoid)

In [52]:
%timeit -n 10 -r 100 jitted_sigmoid(x).block_until_ready()

611 µs ± 113 µs per loop (mean ± std. dev. of 100 runs, 10 loops each)


This is ~5 times speedup!

## vmap

In [72]:
key = jax.random.PRNGKey(618)
x = random.normal(key, (30,))
batch_x = random.normal(key, (20, 30))
W = random.normal(key, (50, 30))

def forward(x, W):
    return x @ W.T

In [80]:
%timeit -n 100 -r 10 forward(x, W).block_until_ready()

332 µs ± 51.2 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)


In [115]:
@jit
def forward_jit(x, W):
    return forward(x, W)

In [116]:
%timeit -n 100 -r 10 forward_jit(x, W).block_until_ready()

The slowest run took 11.45 times longer than the fastest. This could mean that an intermediate result is being cached.
94.3 µs ± 133 µs per loop (mean ± std. dev. of 10 runs, 100 loops each)


In [141]:
def forward(x, W):
    return x @ W.T

def loop_forward(batch_x, W):
    return jnp.stack([forward(x, W) for v in batch_x])

def batch_forward(batch_x, W):
    return batch_x @ W.T

vmap_forward = vmap(forward, in_axes=(0, None))

loop_forward_jit = jit(loop_forward)
batch_forward_jit = jit(batch_forward)
vmap_forward_jit = jit(vmap_forward)

In [156]:
%timeit -n 10 -r 10 loop_forward(batch_x, W).block_until_ready()
%timeit -n 10 -r 10 loop_forward_jit(batch_x, W).block_until_ready()
%timeit -n 10 -r 10 batch_forward(batch_x, W).block_until_ready()
%timeit -n 10 -r 10 batch_forward_jit(batch_x, W).block_until_ready()
%timeit -n 10 -r 10 vmap_forward(batch_x, W).block_until_ready()
%timeit -n 10 -r 10 vmap_forward_jit(batch_x, W).block_until_ready()

9.49 ms ± 771 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
36.7 µs ± 10.8 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
305 µs ± 19.6 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
40.2 µs ± 10.5 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
968 µs ± 132 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)
55.2 µs ± 15.8 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)


Only slightly slower than the batched version.

## Simple Prediction 

In [297]:
import bokeh
from bokeh.plotting import figure
from bokeh.plotting import output_notebook
from bokeh.plotting import show
import numpy as np
output_notebook()

In [298]:
n = 100
dim_in = 1
dim_out = 1
xs = np.random.normal(size=(n, dim_in))
noise = np.random.normal(scale=1, size=xs.shape)
ys = xs * 3 - 1 + noise

In [299]:
def plot(xs, ys, ys_pred=None):
    xs = xs.squeeze()
    ys = ys.squeeze()
    fig = figure(title='source data')
    fig.circle(x=xs, y=ys)
    if ys_pred is not None:
        ys_pred = ys_pred.squeeze()
        fig.line(x=xs, y=ys_pred)
    show(fig)

In [300]:
def predict(x, A, b):
    return x @ A.T + b

In [301]:
def loss(y_pred, y_true):
    return np.mean((y_pred - y_true) ** 2)

In [302]:
@jit
def forward(x, A, b, y_true):
    y_pred = predict(x, A, b)
    return loss(y_pred, y_true)

In [303]:
A = np.random.randn(dim_out, dim_in)
b = np.random.randn(dim_out)
take_grad = grad(forward, argnums=(1, 2))

In [304]:
# before
ys_pred = predict(xs, A, b)
plot(xs, ys, ys_pred)

# train
alpha = 0.01
for _ in range(100):
    dydA, dydb = take_grad(xs, A, b, ys)
    A -= alpha * dydA
    b -= alpha * dydb

# after
ys_pred = predict(xs, A, b)
plot(xs, ys, ys_pred)

## Simple MLP for MNIST

In [553]:
import torchvision
import torch
from string import ascii_lowercase
import random
from collections import namedtuple

In [548]:
def get_random_name():
    num_chars = 10
    return ''.join([random.choice(ascii_lowercase) for _ in range(num_chars)])

In [1001]:
class MLP(object):
    def __init__(self, dim_in, dim_out):
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.name = get_random_name()
        
    @staticmethod
    def initialize_params(key, mlp):
        keys = jax.random.split(key, 2)
        # kaiming initialization
        A = jax.random.normal(keys[0], (mlp.dim_out, mlp.dim_in)) * (jnp.sqrt(2 / mlp.dim_out))
        b = jax.random.normal(keys[0], (mlp.dim_out,)) * (jnp.sqrt(2 / mlp.dim_out))
        return {mlp.name: (A, b)}
        
    @staticmethod
    def forward(params, x):
        A, b = params
        return x @ A.T + b

In [1002]:
def relu(x):
    return x * (x > 0)

def softmax(x):
    return jnp.exp(x) / jnp.sum(jnp.exp(x), keepdims=True)
    
eps = jnp.finfo(jnp.float32).eps
def cross_entropy_loss(y_pred, y_true):
    y_pred += eps
    return jnp.sum(y_true * -jnp.log(y_pred))

def one_hot(x, num_classes):
    vec = jnp.zeros(num_classes)
    vec = vec.at[x].set(1)
    return vec

In [1057]:
class Module(object):
    @staticmethod
    def get_config(dim_in, dim_out, hidden_sizes):
        mlps = []
        all_sizes = [dim_in, *hidden_sizes, dim_out]
        for i in range(len(hidden_sizes)):
            layer_dim_in = all_sizes[i]
            layer_dim_out = all_sizes[i + 1]
            mlps.append(MLP(layer_dim_in, layer_dim_out))
        return {
            'mlps': mlps,
            'l2_lambda': 1e-7
        }
    
    @staticmethod
    def initialize_params(key, config):
        params = {}
        for mlp in config['mlps']:
            key, new_key = jax.random.split(key)
            params.update(MLP.initialize_params(new_key, mlp))
        return params
            
    @staticmethod
    def forward(params, config, x, y_true):
        for mlp in config['mlps']:
            x = mlp.forward(params[mlp.name], x)
            x = relu(x)
            
        y_pred = softmax(x)
        
        num_classes = config['mlps'][-1].dim_out
        y_true = one_hot(y_true, num_classes)
        
        pred_loss = cross_entropy_loss(y_pred, y_true)
        
        reg_loss = 0
        for mlp in config['mlps']:
            for w in params[mlp.name]:
                reg_loss += jnp.sum(config['l2_lambda'] * (w ** 2))
        
        return pred_loss + reg_loss

@jit
def update(params, config, x, y_true, opt_state, opt_update):
    forward = vmap(Module.forward, in_axes=(None, None, 0, 0))
    loss, grads = value_and_grad(forward)(params, config, x, y_true)
    opt_update(0, )

In [1058]:
batch_size = 128

In [1059]:
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        '/tmp/', train=True, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            np.ravel
        ])
    ),
    batch_size=batch_size,
    shuffle=True
)

In [1060]:
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        '/tmp/', train=False, download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            np.ravel
        ])
    ),
    batch_size=batch_size,
    shuffle=True
)

In [1061]:
sample_x, sample_y = next(iter(train_loader))
sample_x = jnp.asarray(sample_x)
sample_y = jnp.asarray(sample_y)

In [1062]:
dim_in = sample_x.shape[1]
hidden_sizes = [32, 16, 10]
dim_out = jnp.max(sample_y)

In [1063]:
key = jax.random.PRNGKey(0)
config = Module.get_config(dim_in, dim_out, hidden_sizes)
key, new_key = jax.random.split(key)
params = Module.initialize_params(new_key, config)

In [1064]:
num_epochs = 10
opt_step_size = 1e-3
opt_init, opt_update, get_params = optimizers.adam(step_size)
opt_state = opt_init(params)

batch_forward = vmap(Module.forward, in_axes=(None, None, 0, 0))
def batch_forward_sum_loss(params, config, batch_x, batch_y):
    loss = batch_forward(params, config, batch_x, batch_y)
    return jnp.sum(loss)

for epoch in range(num_epochs):
    losses = []
    for idx, (batch_x, batch_y) in enumerate(train_loader):
        batch_x = jnp.asarray(batch_x)
        batch_y = jnp.asarray(batch_y)
        update_fn = value_and_grad(batch_forward_sum_loss)
        loss, grads = update_fn(params, config, batch_x, batch_y)
        losses.append(loss)
        opt_state = opt_update(idx, grads, opt_state)
        params = get_params(opt_state)
    print('loss', np.mean(losses))

loss 199.77393
loss 96.81453
loss 76.54293


KeyboardInterrupt: 