In [66]:
import jax
from jax import numpy as jnp, random as jr
from flax import linen as nn, optim
import numpy as np
import pandas as pd
import altair as alt

## JAX basics

In [21]:
# A scalar function and its derivative: jax.grad turns a Python function of a
# scalar into a new function that returns d/dx of it.
def x_squared(x):
    """Return ``x`` squared."""
    return x**2

g = jax.grad(x_squared)  # g(x) == 2 * x
x = 2.
print(f"x_squared({x}): {x_squared(x)}")
print(f"d/dx x_squared({x}): {g(x)}")

# use vmap to vectorize both functions over a 1-D array of inputs
v_x_squared = jax.vmap(x_squared)
v_g = jax.vmap(g)
xs = jnp.linspace(-5, 5, 100)
ys = v_x_squared(xs)
gs = v_g(xs)

x_squared(2.0): 4.0
d/dx x_squared(2.0): 4.0


In [22]:
# Long-format frame, one row per (x, kind), so Altair can colour the value
# curve f(x) and the gradient curve g(x) as separate lines.
df = pd.concat([
    pd.DataFrame({'x': xs, 'y': curve, 'kind': kind})
    for kind, curve in (('value', ys), ('grad', gs))
])
chart = (
    alt.Chart(df)
    .mark_line(size=3)
    .encode(x='x', y='y', color='kind')
)
chart

## Two spirals dataset

In [118]:
def make_spirals(n_samples, noise_std=0., rotations=1., seed=None):
    """Sample 2-D points from two interleaved spirals, one spiral per class.

    Sample ``i`` has radius ``r = sqrt(t_i)`` for ``t`` evenly spaced in
    ``[0, 1]`` and angle ``theta = r * rotations * 2*pi``. A random sign of
    +1 or -1 then puts the point on one arm or on its mirror image through
    the origin.

    Args:
        n_samples: number of points to generate.
        noise_std: standard deviation of Gaussian noise added independently
            to each coordinate (0 means noise-free).
        rotations: number of full turns each arm makes by radius 1.
        seed: if given, samples come from a private
            ``np.random.RandomState(seed)`` and are reproducible. If ``None``
            (the default), they come from NumPy's global random state.

    Returns:
        ``(points, labels)``: ``points`` is a ``(n_samples, 2)`` JAX array and
        ``labels`` a NumPy int array of shape ``(n_samples,)``. The label is
        1 where the sign was +1 and 0 where it was -1.
    """
    # The legacy RandomState exposes the same randint/randn API as the
    # np.random module, so the unseeded path keeps its exact draw sequence.
    rng = np.random if seed is None else np.random.RandomState(seed)

    ts = jnp.linspace(0, 1, n_samples)
    rs = ts ** 0.5
    thetas = rs * rotations * 2 * np.pi
    signs = rng.randint(0, 2, (n_samples,)) * 2 - 1
    labels = (signs > 0).astype(int)

    xs = rs * signs * jnp.cos(thetas) + rng.randn(n_samples) * noise_std
    ys = rs * signs * jnp.sin(thetas) + rng.randn(n_samples) * noise_std
    points = jnp.stack([xs, ys], axis=1)
    return points, labels

In [169]:
# Draw 100 noisy samples from the two spirals and plot them coloured by class.
points, labels = make_spirals(100, noise_std=0.05)
df = pd.DataFrame({
    'x': points[:, 0],
    'y': points[:, 1],
    'label': labels,
})
chart = (
    alt.Chart(df, width=350, height=300)
    .mark_circle()
    .encode(x='x', y='y', color='label:N')
)
chart.save('two_spirals.html')
chart

## A simple classifier

In [120]:
class MLPClassifier(nn.Module):
    """Fully connected classifier: Dense+ReLU hidden layers, then a log-softmax head.

    Attributes:
        hidden_layers: number of Dense+ReLU hidden layers.
        hidden_dim: width of every hidden layer.
        n_classes: number of output classes (width of the final Dense layer).
    """
    hidden_layers: int = 2
    hidden_dim: int = 512
    n_classes: int = 2

    @nn.compact
    def __call__(self, x):
        """Return log-probabilities over ``n_classes`` classes for inputs ``x``."""
        h = x
        for _ in range(self.hidden_layers):
            h = nn.relu(nn.Dense(self.hidden_dim)(h))
        # log_softmax, so callers get log-probabilities rather than raw logits.
        log_probs = nn.log_softmax(nn.Dense(self.n_classes)(h))
        return log_probs

### Helper functions for initializing and training

In [173]:
# Instantiating a Flax module gives an object describing the architecture,
# NOT one holding parameters: parameters are produced by
# `classifier_fns.init(...)` and passed back in via `classifier_fns.apply(...)`.
classifier_fns = MLPClassifier(hidden_layers=2, hidden_dim=512, n_classes=2)

def cross_entropy(logprobs, labels):
    """Mean negative log-likelihood of integer ``labels`` under ``logprobs``.

    Args:
        logprobs: log-probabilities whose LAST axis indexes classes, e.g.
            shape ``(batch, n_classes)``. A single unbatched ``(n_classes,)``
            vector is also accepted.
        labels: integer class ids with ``logprobs``'s shape minus the class axis.

    Returns:
        Scalar cross-entropy averaged over all leading (batch) positions.
    """
    # Use the last axis so batched and unbatched inputs both work.
    n_classes = logprobs.shape[-1]
    one_hot_labels = jax.nn.one_hot(labels, n_classes)
    # The one-hot mask picks out the log-probability of the true class.
    return -jnp.mean(jnp.sum(one_hot_labels * logprobs, axis=-1))

def loss_fn(params, batch):
    """Mean cross-entropy of the classifier on one batch.

    Args:
        params: the classifier's parameter pytree (the ``'params'`` collection
            produced by ``classifier_fns.init``).
        batch: ``(inputs, labels)`` pair; ``labels`` are integer class ids.

    Returns:
        Scalar loss.
    """
    inputs, labels = batch[0], batch[1]
    # MLPClassifier ends in log_softmax, so its outputs are log-probabilities,
    # not raw logits.
    logprobs = classifier_fns.apply({'params': params}, inputs)
    # cross_entropy already averages over the batch; no extra mean is needed.
    return cross_entropy(logprobs, labels)

# Returns (loss, d loss / d params) from a single forward + backward pass.
loss_and_grad_fn = jax.value_and_grad(loss_fn)

### API for the classifier

In [179]:
def init_fn(input_shape, seed, learning_rate=1e-3):
    """Initialize classifier parameters and wrap them in an Adam optimizer.

    Args:
        input_shape: shape of ONE input example, without the batch axis
            (e.g. ``(2,)`` for 2-D points).
        seed: PRNG seed. It is cast to an integer, so a float or a traced
            scalar (e.g. under ``jax.vmap``) is fine.
        learning_rate: Adam step size (default ``1e-3``).

    Returns:
        A ``flax.optim`` optimizer whose ``.target`` holds the parameters.
    """
    # NOTE(review): flax.optim is deprecated in newer Flax releases in favour
    # of optax; consider migrating.
    rng = jr.PRNGKey(jnp.array(seed, int))
    # init only needs the input shape/dtype; one all-ones example suffices.
    dummy_input = jnp.ones((1, *input_shape))
    params = classifier_fns.init(rng, dummy_input)['params']
    optimizer_def = optim.Adam(learning_rate=learning_rate)
    return optimizer_def.create(params)

@jax.jit  # jit makes it go brrr
def train_step_fn(optimizer, batch):
    """Take one optimizer step on ``batch``.

    Args:
        optimizer: ``flax.optim`` optimizer whose ``.target`` holds the params.
        batch: ``(inputs, labels)`` pair forwarded to ``loss_fn``.

    Returns:
        ``(new_optimizer, loss)``, where ``loss`` is evaluated at the
        parameters BEFORE this update.
    """
    # value_and_grad already returns the loss, so no separate forward pass
    # is needed.
    loss, grad = loss_and_grad_fn(optimizer.target, batch)
    optimizer = optimizer.apply_gradient(grad)
    return optimizer, loss

@jax.jit  # jit makes it go brrr
def predict_fn(optimizer, x):
    """Run the classifier with the parameters held in ``optimizer.target``.

    Args:
        optimizer: ``flax.optim`` optimizer holding the classifier params.
        x: array-like batch of inputs (in this notebook, ``(N, 2)`` points).

    Returns:
        The classifier's per-class log-probabilities for each input.
    """
    inputs = jnp.array(x)
    variables = {'params': optimizer.target}
    return classifier_fns.apply(variables, inputs)

### Running the network

In [194]:
# Train one model full-batch on the spirals and report the final loss.
model_state = init_fn(input_shape=(2,), seed=0)
train_batch = (points, labels)
n_steps = 100
for _ in range(n_steps):
    model_state, loss = train_step_fn(model_state, train_batch)
print(loss)

0.011945118


In [127]:
def all_preds(model_state, grid_size=30, width=1.5):
    """Evaluate the classifier on a regular square grid of 2-D points.

    Args:
        model_state: optimizer whose ``.target`` holds the classifier params
            (as returned by ``init_fn`` / ``train_step_fn``).
        grid_size: number of grid points along each axis (default 30).
        width: the grid spans ``[-width, width]`` on both axes (default 1.5).

    Returns:
        ``(xs, preds)``. ``xs`` is a ``(grid_size**2, 2)`` NumPy array of grid
        points and ``preds`` holds the classifier's per-class log-probabilities
        for them.
    """
    axis = np.linspace(-width, width, grid_size)
    # 'ij' indexing: after stacking, row a*grid_size + b is (axis[a], axis[b]),
    # i.e. every (x0, x1) pair with x0 varying slowest.
    x0s, x1s = np.meshgrid(axis, axis, indexing='ij')
    xs = np.stack([x0s, x1s], axis=-1).reshape((-1, 2))
    preds = predict_fn(model_state, xs)
    return xs, preds

xs, preds = all_preds(model_state)

In [174]:
# preds are log-probabilities, so exp(...)[:, 1] is P(class 1) per grid point.
prob_class1 = jnp.exp(preds)[:, 1]
data = {'x': xs[:, 0], 'y': xs[:, 1], 'pred': prob_class1}
df = pd.DataFrame(data)
colour = alt.Color('pred', scale=alt.Scale(scheme='viridis'))
chart = (
    alt.Chart(df, width=320, height=320)
    .mark_square(size=70, opacity=1)
    .encode(x='x', y='y', color=colour)
)
chart.save('mlp_pred.html')
chart

## Parallelizing training

In [195]:
# Train several independently initialized models at once. vmap maps init_fn
# over the seeds and train_step_fn over the stacked model states, while every
# model shares the same (points, labels) batch (in_axes=None).
n_models = 10
parallel_init_fn = jax.vmap(init_fn, in_axes=(None, 0))
parallel_train_step_fn = jax.vmap(train_step_fn, in_axes=(0, None))

# One integer seed per model; seed 0 matches the single-model run above.
seeds = jnp.arange(n_models)

model_states = parallel_init_fn((2,), seeds)
for _ in range(100):
    model_states, losses = parallel_train_step_fn(model_states, (points, labels))
print(losses)

[0.01194512 0.01250279 0.01615315 0.01403342 0.01800855 0.01515956
 0.00658712 0.00957206 0.00750575 0.00901282]
