In [47]:
import jax
from jax import numpy as jnp, random as jr
from flax import linen as nn, optim
import numpy as np
import pandas as pd
import altair as alt
import time

## JAX basics

In [3]:
# a scalar function and its gradient
def x_squared(x):
    return x**2
g = jax.grad(x_squared)
x = 2.
print(f"x_squared({x}): {x_squared(x)}")
print(f"d/dx x_squared({x}): {g(x)}")

# use vmap to vectorize the functions
v_x_squared = jax.vmap(x_squared)
v_g = jax.vmap(g)
xs = jnp.linspace(-5, 5, 100)
ys = v_x_squared(xs)
gs = v_g(xs)

x_squared(2.0): 4.0
d/dx x_squared(2.0): 4.0


In [4]:
# stick the results in a dataframe and plot them
df = pd.concat([
    pd.DataFrame({'x': xs, 'y': ys, 'kind': 'value'}), # the values f(x)
    pd.DataFrame({'x': xs, 'y': gs, 'kind': 'grad'}), # the gradients g(x)
])
chart = alt.Chart(df).mark_line(size=3).encode(x='x', y='y', color='kind')
chart

## Two spirals dataset

In [5]:
def make_spirals(n_samples, noise_std=0., rotations=1.):
    ts = jnp.linspace(0, 1, n_samples)
    rs = ts ** 0.5
    thetas = rs * rotations * 2 * np.pi
    signs = np.random.randint(0, 2, (n_samples,)) * 2 - 1
    labels = (signs > 0).astype(int)

    xs = rs * signs * jnp.cos(thetas) + np.random.randn(n_samples) * noise_std
    ys = rs * signs * jnp.sin(thetas) + np.random.randn(n_samples) * noise_std
    points = jnp.stack([xs, ys], axis=1)
    return points, labels

In [42]:
points, labels = make_spirals(100, noise_std=0.05)
df = pd.DataFrame({'x': points[:, 0], 'y': points[:, 1], 'label': labels})

spirals_x_axis = alt.X('x', scale=alt.Scale(domain=[-1.5, 1.5], nice=False))
spirals_y_axis = alt.Y('y', scale=alt.Scale(domain=[-1.5, 1.5], nice=False))

spiral_chart = alt.Chart(df, width=350, height=300).mark_circle(stroke="white", size=80, opacity=1).encode(
    x=spirals_x_axis, y=spirals_y_axis, 
    color=alt.Color('label:N'))
spiral_chart.save('two_spirals.html')
spiral_chart

## A simple classifier

In [7]:
class MLPClassifier(nn.Module):
    hidden_layers: int = 2
    hidden_dim: int = 512
    n_classes: int = 2

    @nn.compact
    def __call__(self, x):
        for layer in range(self.hidden_layers):
            x = nn.Dense(self.hidden_dim)(x)
            x = nn.relu(x)
        x = nn.Dense(self.n_classes)(x)
        x = nn.log_softmax(x)
        return x

### Helper functions for initializing and training

In [8]:
# Somewhat confusingly, instantiating a Flax module gives you an object
# which contains functions, NOT state
classifier_fns = MLPClassifier()

def cross_entropy(logprobs, labels):
    one_hot_labels = jax.nn.one_hot(labels, logprobs.shape[1])
    return -jnp.mean(jnp.sum(one_hot_labels * logprobs, axis=-1))

def loss_fn(params, batch):
    logits = classifier_fns.apply({'params': params}, batch[0])
    loss = jnp.mean(cross_entropy(logits, batch[1]))
    return loss

loss_and_grad_fn = jax.value_and_grad(loss_fn)

### API for the classifier

In [9]:
def init_fn(input_shape, seed):
    rng = jr.PRNGKey(jnp.array(seed, int))
    dummy_input = jnp.ones((1, *input_shape))
    params = classifier_fns.init(rng, dummy_input)['params']
    optimizer_def = optim.Adam(learning_rate=1e-3)
    optimizer = optimizer_def.create(params)
    return optimizer

@jax.jit  # jit makes it go brrr
def train_step_fn(optimizer, batch):
    loss = loss_fn(optimizer.target, batch)
    loss, grad = loss_and_grad_fn(optimizer.target, batch)
    optimizer = optimizer.apply_gradient(grad)
    return optimizer, loss

@jax.jit  # jit makes it go brrr
def predict_fn(optimizer, x):
    x = jnp.array(x)
    return classifier_fns.apply({'params': optimizer.target}, x)

### Running the network

In [126]:
model_state = init_fn(input_shape=(2,), seed=0)
for i in range(100):
    model_state, loss = train_step_fn(model_state, (points, labels))
print(loss)

0.0111716185


In [127]:
def all_preds(model_state):
    grid_size = 50
    width = 1.5
    x0s, x1s = jnp.meshgrid(jnp.linspace(-width, width, grid_size),
                            jnp.linspace(-width, width, grid_size))
    xs = jnp.stack([x0s, x1s]).transpose().reshape((-1, 2))
    preds = predict_fn(model_state, xs)    
    return xs, preds

xs, preds = all_preds(model_state)

In [131]:
data = {'x': xs[:, 0], 'y': xs[:, 1], 'pred': jnp.exp(preds)[:, 1]}
df = pd.DataFrame(data)
pred_chart = alt.Chart(df, width=240, height=240, title="Predictions from MLP").mark_square(size=50, opacity=1).encode(
    x=spirals_x_axis, 
    y=spirals_y_axis,
    color=alt.Color('pred', scale=alt.Scale(scheme='blueorange')),
)
chart = pred_chart + spiral_chart
chart.save('mlp_pred.html')
chart

## Parallelizing training

In [69]:
parallel_init_fn = jax.vmap(init_fn, in_axes=(None, 0))
parallel_train_step_fn = jax.vmap(train_step_fn, in_axes=(0, None))

N = 100
seeds = jnp.linspace(0, N - 1, N)

model_states = parallel_init_fn((2,), seeds)
for i in range(100):
    model_states, losses = parallel_train_step_fn(model_states, (points, labels))
print(losses)

[0.01117162 0.01037653 0.01939761 0.02218039 0.01346352 0.01723901
 0.00587268 0.00587268 0.00967717 0.01087084 0.01558958 0.0154999
 0.01204127 0.01605278 0.01605278 0.01004384 0.01449939 0.0228159
 0.01101634 0.0141475  0.01500179 0.00995679 0.01437751 0.01418262
 0.01447369 0.01447369 0.01164044 0.02000876 0.02000876 0.01940052
 0.01782143 0.01782143 0.02150425 0.01284952 0.01645697 0.01619485
 0.01471215 0.01507873 0.01523398 0.01634016 0.01620772 0.01100599
 0.01763341 0.01177186 0.00967445 0.01040979 0.018541   0.02595497
 0.01012443 0.00891571 0.00891571 0.01305469 0.00918238 0.01823689
 0.01071366 0.01531681 0.01531681 0.01637555 0.00583412 0.02080311
 0.01125573 0.0100053  0.0100053  0.00948207 0.00908887 0.01080691
 0.01182392 0.01861817 0.01880448 0.00895523 0.01122447 0.01460181
 0.01060087 0.01881975 0.01335606 0.00876098 0.01047368 0.01277807
 0.01438036 0.01076297 0.01345763 0.0133969  0.01012769 0.01770655
 0.0082693  0.00877883 0.01998596 0.00964176 0.01486988 0.010239

### Plotting each network's predictions

In [55]:
parallel_all_preds = jax.vmap(all_preds)
xs, batched_preds = parallel_all_preds(model_states)
xs = xs[0]

In [61]:
charts = []
for preds in batched_preds:
#     print(preds.shape)
    data = {'x': xs[:, 0], 'y': xs[:, 1], 'pred': jnp.exp(preds)[:, 1]}
    df = pd.DataFrame(data)
    single_chart = alt.Chart(df, width=240, height=240).mark_square(size=50, opacity=1).encode(
        x=spirals_x_axis, 
        y=spirals_y_axis,
        color=alt.Color('pred', scale=alt.Scale(scheme='blueorange')),
    )
    chart = single_chart + spiral_chart
    charts.append(chart)
chart = alt.hconcat(*charts[:2])
chart.save('multi_mlp_pred.html')
chart

## Bootstrapped ensembles

In [76]:
def get_first_seed(dataset_index):
    return jr.split(jr.PRNGKey(dataset_index))[0, 0]

get_first_seed(0)

DeviceArray(4146024105, dtype=uint32)

In [85]:
@jax.jit
def get_example(data_x, data_y, dataset_index, i):
    """Gets example `i` from the bootstrapped dataset with index `dataset_index`."""
    first_seed = get_first_seed(dataset_index)
    dataset_size = data_x.shape[0]

    # only use dataset_size distinct seeds
    # this makes sure that our bootstrap-sampled dataset includes exactly 
    # `dataset_size` points.
    i = i % dataset_size

    point_seed = first_seed + i
    point_index = jr.randint(jr.PRNGKey(point_seed), shape=(),
                             minval=0, maxval=dataset_size)
    x_i = jax.lax.dynamic_index_in_dim(data_x, point_index,
                                       keepdims=False)
    y_i = jax.lax.dynamic_index_in_dim(data_y, point_index,
                                       keepdims=False)
    return x_i, y_i

get_example(points, labels, 0, 0)

(DeviceArray([ 0.7040819, -0.5774228], dtype=float32),
 DeviceArray(1, dtype=int32))

In [114]:
def bootstrap_multi_iterator(dataset, dataset_indices):
    """Creates an iterator which, at each step, returns a batch of batches.

    The kth batch is sampled from the bootstrapped resample of `dataset`
    with seed `seeds[k]`."""
    
    batch_size = 32
    dataset_indices = jnp.array(dataset_indices)
    data_x, data_y = dataset
    dataset_size = len(data_x)
    
    get_example_from_dataset = jax.partial(get_example, data_x, data_y)
    
    # for sampling a batch of data from one dataset
    get_batch = jax.vmap(get_example_from_dataset, in_axes=(None, 0))
    # for sampling a batch of data from _each_ dataset
    get_multibatch = jax.vmap(get_batch, in_axes=(0, None))

    def iterate_multibatch():
        """Construct an iterator which runs forever, at each step returning
        a batch of batches."""
        i = 0
        while True:
            indices = jnp.arange(i, i + batch_size, dtype=jnp.int32)
            yield get_multibatch(dataset_indices, indices)
            i += batch_size

    loader_iter = iterate_multibatch()
    return loader_iter

In [115]:
# same as before
parallel_init_fn = jax.vmap(init_fn, in_axes=(None, 0))
# vmap over both inputs now
bootstrap_train_step_fn = jax.vmap(train_step_fn, in_axes=(0, 0))

# make seeds 0 to N-1, which we use for initializing the network and bootstrapping
N = 100
seeds = jnp.linspace(0, N - 1, N).astype(jnp.int32)

model_states = parallel_init_fn((2,), seeds)
data_iterator = bootstrap_multi_iterator((points, labels), dataset_indices=seeds)
for i in range(100):
    x_batch, y_batch = next(data_iterator)
    model_states, losses = bootstrap_train_step_fn(model_states, (x_batch, y_batch))
print(losses)

[0.14846763 0.09306543 0.24074371 0.26202717 0.26234168 0.18515839
 0.10521372 0.10521372 0.1059431  0.10932036 0.21017203 0.08179321
 0.14106122 0.2092368  0.2092368  0.07349811 0.1657829  0.25442493
 0.10656386 0.08645718 0.26358885 0.11631659 0.18151587 0.12747103
 0.15158615 0.15158615 0.17667997 0.17182218 0.17182218 0.07631133
 0.21079394 0.21079394 0.15786466 0.17365721 0.17031248 0.23151168
 0.18349862 0.2156504  0.11407475 0.1858593  0.17583668 0.11868238
 0.13702197 0.07849094 0.13553149 0.2394673  0.15495649 0.15809467
 0.1068648  0.1509621  0.1509621  0.18882835 0.06280423 0.18574849
 0.12810391 0.15844561 0.15844561 0.22531107 0.10022258 0.14315307
 0.10456075 0.11673002 0.11673002 0.11814751 0.1691907  0.22820649
 0.12136436 0.18610588 0.14557543 0.1453721  0.14257357 0.11867595
 0.12358713 0.19261132 0.18556489 0.13552938 0.17152317 0.22407383
 0.14999484 0.11981826 0.17370228 0.08921565 0.0759412  0.09591725
 0.12944266 0.11199161 0.12544012 0.19200137 0.14022672 0.0686

In [121]:
### Plotting bootstrapped models' predictions

In [116]:
xs, batched_preds = parallel_all_preds(model_states)
xs = xs[0]
charts = []
for preds in batched_preds:
    data = {'x': xs[:, 0], 'y': xs[:, 1], 'pred': jnp.exp(preds)[:, 1]}
    df = pd.DataFrame(data)
    bootstrap_pred_chart = alt.Chart(df, width=240, height=240).mark_square(size=50, opacity=1).encode(
        x=spirals_x_axis, 
        y=spirals_y_axis,
        color=alt.Color('pred', scale=alt.Scale(scheme='blueorange')),
    )
    chart = bootstrap_pred_chart + spiral_chart
    charts.append(chart)
chart = alt.hconcat(*charts[:2])
chart.save('bootstrap_mlp_pred.html')
chart

In [132]:
# ensemble predictions across our models
batched_probs = jnp.exp(batched_preds)
bootstrapped_probs = jnp.mean(batched_probs, axis=0)

data = {'x': xs[:, 0], 'y': xs[:, 1], 'pred': bootstrapped_probs[:, 1]}
df = pd.DataFrame(data)
ensemble_chart = alt.Chart(df, width=240, height=240, title="Predictions from bootstrap").mark_square(size=50, opacity=1).encode(
    x=spirals_x_axis, 
    y=spirals_y_axis,
    color=alt.Color('pred', scale=alt.Scale(scheme='blueorange', domain=[0, 1])),
)
chart = ensemble_chart + spiral_chart
chart.save('ensemble_mlp_pred.html')
chart

### Compare with a single model trained on the whole dataset

In [133]:
bootstrap_compare_chart = (pred_chart + spiral_chart) | (ensemble_chart + spiral_chart)
bootstrap_compare_chart.save('bootstrap_compare_pred.html')
bootstrap_compare_chart

In [122]:
(pred_chart + spiral_chart) | (ensemble_chart + spiral_chart)

In [123]:
bootstrap_compare_chart