In [None]:
%%bash
# Download the requirements file used by this notebook, install it, and
# list the installed jax / flax / optax versions for reproducibility.
curl -sLO https://raw.githubusercontent.com/enakai00/colab_jaxbook/main/requirements.txt
pip install -r requirements.txt
pip list | grep -E '(jax|flax|optax)'

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import jax
from jax import random, numpy as jnp

plt.rcParams.update({'font.size': 12})

In [None]:
# Training targets: one value per month for months 1..12 (plotted later
# against 'Month' with y-label 'C').
train_t = jnp.asarray([5.2, 5.7, 8.6, 14.9, 18.2, 20.4, 25.5, 26.4, 22.8, 17.5, 11.1, 6.6])

In [None]:
# Reshape to a (12, 1) column so it lines up with the (N, 1) output of predict().
train_t = train_t.reshape([12, 1])

In [None]:
train_t

In [None]:
# Design matrix for a degree-4 polynomial: row m is [1, m, m^2, m^3, m^4]
# for month m = 1..12, giving shape (12, 5).
train_x = jnp.asarray([[month ** n for n in range(0, 5)] for month in range(1, 13)])

In [None]:
train_x

In [None]:
train_x.shape

In [None]:
# Random initial coefficients, one per polynomial term: shape (5, 1).
key, key1 = random.split(random.PRNGKey(0))
w = random.normal(key1, [5, 1])
w

In [None]:
@jax.jit
def predict(w, x):
  """Evaluate the linear model on a batch of feature rows.

  Args:
    w: weight column; in this notebook shape (5, 1).
    x: feature matrix with one row per sample; in this notebook (N, 5).
  Returns:
    The matrix product x @ w, one prediction per row.
  """
  return x @ w

In [None]:
@jax.jit
def loss_fn(w, train_x, train_t):
  """Mean squared error between predict(w, train_x) and the targets train_t."""
  residuals = predict(w, train_x) - train_t
  return jnp.mean(residuals ** 2)

In [None]:
# Gradient of loss_fn with respect to its first argument (w), jit-compiled.
grad_loss = jax.jit(jax.grad(loss_fn))

In [None]:
%%time
# Plain gradient descent on w. The features are unscaled (the m**4 column
# reaches 12**4 = 20736), which is presumably why the step size is this
# small and so many steps are run.
learning_rate = 1e-8 * 1.4
for step in range(1, 5000001):
  grads = grad_loss(w, train_x, train_t)
  w = w - learning_rate * grads
  # Report the training loss every 500,000 steps.
  if step % 500000 == 0:
    loss_val = loss_fn(w, train_x, train_t)
    print('Step: {}, Loss: {:0.4f}'.format(step, loss_val), flush=True)

In [None]:
# Fitted polynomial coefficients after training.
w

In [None]:
# Evaluate the fitted polynomial on a dense grid of months and draw it
# over the 12 training points.
xs = np.linspace(1, 12, 100)
inputs = jnp.asarray([[month ** n for n in range(0, 5)] for month in xs])
ys = predict(w, inputs)

fig = plt.figure(figsize=(6, 4))
subplot = fig.add_subplot(1, 1, 1)
subplot.set_xlim(1, 12)
subplot.set_ylim(0, 30)
# One tick per month (1..12), matching the 12 data points.
subplot.set_xticks(range(1, 13))
subplot.set_xlabel('Month')
subplot.set_ylabel('C')

subplot.scatter(range(1, 13), train_t)
_ = subplot.plot(xs, ys)

## Least Squares Method

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pandas import DataFrame

import jax, optax
from jax import random, numpy as jnp
from flax import linen as nn
from flax.training import train_state

plt.rcParams.update({'font.size': 12})

In [None]:
# Features m, m^2, m^3, m^4 only: the constant term is now provided by the
# bias of the Dense layer, so the column of ones is dropped. Shape (12, 4).
train_x = jnp.asarray([[month ** n for n in range(1, 5)] for month in range(1, 13)])

In [None]:
train_x

In [None]:
class TemperatureModel(nn.Module):
  """Linear regression model: a single Dense unit (kernel plus bias)."""

  @nn.compact
  def __call__(self, x):
    # One output value per input row.
    return nn.Dense(features=1)(x)

In [None]:
# Initialise parameters; the sample input fixes the Dense kernel's input size.
key, key1 = random.split(random.PRNGKey(0))
variables = TemperatureModel().init(key1, train_x)

In [None]:
variables

In [None]:
# Forward pass with the untrained (randomly initialised) parameters.
TemperatureModel().apply(variables, train_x)

In [None]:
# Bundle apply_fn, the parameters and an Adam optimizer (lr=0.001) into one state.
state = train_state.TrainState.create(
    apply_fn=TemperatureModel().apply,
    params=variables['params'],
    tx=optax.adam(learning_rate=0.001)
)

In [None]:
@jax.jit
def loss_fn(params, state, inputs, labels):
  """Mean optax.l2_loss (0.5 * squared error) of the model's predictions.

  `params` is passed separately from `state` so jax.grad can differentiate
  with respect to it.
  """
  predictions = state.apply_fn({'params': params}, inputs)
  per_sample = optax.l2_loss(predictions, labels)
  return per_sample.mean()

In [None]:
@jax.jit
def train_step(state, inputs, labels):
  """Run one optimizer step; return (updated_state, loss_before_update)."""
  value_and_grad_fn = jax.value_and_grad(loss_fn)
  loss, grads = value_and_grad_fn(state.params, state, inputs, labels)
  return state.apply_gradients(grads=grads), loss

In [None]:
%%time
# Full-batch training: every step uses all 12 samples. The per-step loss is
# stored as a Python float for plotting.
loss_history = []
for step in range(1, 100001):
  state, loss_val = train_step(state, train_x, train_t)
  loss_history.append(jax.device_get(loss_val).tolist())
  if step % 10000 == 0:
    print('Step: {}, Loss: {:0.4f}'.format(step, loss_val), flush=True)

In [None]:
df = DataFrame({'Loss': loss_history})
df.index.name = 'Steps'
_ = df.plot(figsize=(6, 4), xlim=(0, 100))

df = DataFrame({'Loss': loss_history})
df.index.name = 'Steps'
_ = df.plot(figsize=(6, 4), ylim=(0, 8))

In [None]:
state.params

In [None]:
# Evaluate the trained model on a dense grid of months (features m..m^4)
# and draw it over the 12 training points.
xs = np.linspace(1, 12, 100)
inputs = jnp.asarray([[month ** n for n in range(1, 5)] for month in xs])
ys = state.apply_fn({'params': state.params}, inputs)

fig = plt.figure(figsize=(6, 4))
subplot = fig.add_subplot(1, 1, 1)
subplot.set_xlim(1, 12)
subplot.set_ylim(0, 30)
# One tick per month (1..12), matching the 12 data points.
subplot.set_xticks(range(1, 13))
subplot.set_xlabel('Month')
subplot.set_ylabel('C')

subplot.scatter(range(1, 13), train_t)
_ = subplot.plot(xs, ys)

## Logistic regression

In [None]:
key, key1, key2, key3 = random.split(random.PRNGKey(0), 4)
# Class 0: n0 points from an isotropic Gaussian (covariance variance0 * I)
# centred at mu0.
n0, mu0, variance0 = 20, [10, 11], 20
data0 = random.multivariate_normal(
    key1, jnp.asarray(mu0), jnp.eye(2)*variance0, (n0,)  # shape as a tuple of ints
)
# Append a label column of zeros: each row becomes [x1, x2, t=0].
data0 = jnp.hstack([data0, jnp.zeros([n0, 1])])

In [None]:
# First rows of class 0: columns are x1, x2, label.
data0[:10]

In [None]:
# Class 1: n1 points from an isotropic Gaussian (covariance variance1 * I)
# centred at mu1.
n1, mu1, variance1 = 15, [18, 20], 22
data1 = random.multivariate_normal(
    key2, jnp.asarray(mu1), jnp.eye(2)*variance1, (n1,)  # shape as a tuple of ints
)
# Append a label column of ones: each row becomes [x1, x2, t=1].
data1 = jnp.hstack([data1, jnp.ones([n1, 1])])

In [None]:
data1[:10]

In [None]:
# Shuffle the rows of both classes together (each row stays intact).
data = random.permutation(key3, jnp.vstack([data0, data1]))

In [None]:
# Columns 0-1 are the inputs (x1, x2); column 2 is the 0/1 label.
train_x, train_t = jnp.split(data, [2], axis=1)

In [None]:
train_x[:10]

In [None]:
train_t[:10]

In [None]:
class LogisticRegression(nn.Module):
  """Binary logistic regression: Dense(1) followed by a sigmoid.

  Pass get_logits=True to skip the sigmoid and get the raw score, e.g. for
  a loss that works on logits.
  """

  @nn.compact
  def __call__(self, x, get_logits=False):
    logits = nn.Dense(features=1)(x)
    return logits if get_logits else nn.sigmoid(logits)

In [None]:
# Initialise parameters with a fresh subkey split from the data-generation key.
key, key1 = random.split(key, 2)
variables = LogisticRegression().init(key1, train_x)

In [None]:
variables

In [None]:
# Bundle apply_fn, the parameters and an Adam optimizer (lr=0.001) into one state.
state = train_state.TrainState.create(
    apply_fn=LogisticRegression().apply,
    params=variables['params'],
    tx=optax.adam(learning_rate=0.001)
)

In [None]:
@jax.jit
def loss_fn(params, state, inputs, labels):
  """Binary cross-entropy on logits, plus accuracy as an auxiliary output.

  Returns (loss, acc). With 0/1 labels a sample counts as correct when its
  logit is positive for label 1 and negative for label 0; a logit of exactly
  0 is never counted as correct.
  """
  logits = state.apply_fn({'params': params}, inputs, get_logits=True)
  bce = optax.sigmoid_binary_cross_entropy(logits, labels)
  predicted_sign = jnp.sign(logits)
  target_sign = jnp.sign(labels - 0.5)
  return bce.mean(), jnp.mean(predicted_sign == target_sign)

In [None]:
@jax.jit
def train_step(state, inputs, labels):
  """Apply one optimizer step (state.tx) and return (new_state, loss, acc).

  loss and acc are measured with the parameters before the update.
  """
  compute = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, acc), grads = compute(state.params, state, inputs, labels)
  next_state = state.apply_gradients(grads=grads)
  return next_state, loss, acc

In [None]:
%%time
# Full-batch training; loss and accuracy are recorded at every step.
loss_history, acc_history = [], []
for step in range(1, 10001):
  state, loss, acc = train_step(state, train_x, train_t)
  loss_history.append(jax.device_get(loss).tolist())
  acc_history.append(jax.device_get(acc).tolist())
  if step % 1000 == 0:
    print('Step: {}, Loss: {:.4f}, Acc: {:.4f}'.format(step, loss, acc), flush=True)

In [None]:
# Training curves, one figure each: accuracy, then loss.
df = DataFrame({'Accuracy': acc_history})
df.index.name = 'Steps'
_ = df.plot(figsize=(6, 4))

df = DataFrame({'Loss': loss_history})
df.index.name = 'Steps'
_ = df.plot(figsize=(6, 4))

In [None]:
state.params

In [None]:
# Unpack the (2, 1) kernel and (1,) bias of the Dense layer into scalars.
[w1], [w2] = state.params['Dense_0']['kernel']
[b] = state.params['Dense_0']['bias']

In [None]:
# Split the training points by label so each class gets its own marker.
train_set0 = [jax.device_get(x).tolist()
              for x, t in zip(train_x, train_t) if t == 0]
train_set1 = [jax.device_get(x).tolist()
              for x, t in zip(train_x, train_t) if t == 1]

In [None]:
fig = plt.figure(figsize=(7, 7))
subplot = fig.add_subplot(1, 1, 1)
subplot.set_xlim([0, 30])
subplot.set_ylim([0, 30])
subplot.set_xlabel('x1')
subplot.set_ylabel('x2')
subplot.scatter([x for x, y in train_set1],
                [y for x, y in train_set1], marker='x')
subplot.scatter([x for x, y in train_set0],
                [y for x, y in train_set0], marker='o')

# Decision boundary w1*x1 + w2*x2 + b = 0 solved for x2 (requires w2 != 0).
xs = np.linspace(0, 30, 10)
ys = - (w1*xs/w2 + b/w2)
subplot.plot(xs, ys)

# Model output (sigmoid probability) on a 100x100 grid. x2 is the outer loop,
# so after reshaping rows follow x2 and columns follow x1, matching imshow
# with origin='lower'.
locations = [[x1, x2] for x2 in np.linspace(0, 30, 100) 
                      for x1 in np.linspace(0, 30, 100)]
p_vals = state.apply_fn(
    {'params': state.params}, np.array(locations)).reshape([100, 100])
_ = subplot.imshow(p_vals, origin='lower', extent=(0, 30, 0, 30),
                   vmin=0, vmax=1, cmap=plt.cm.gray_r, alpha=0.4)

## Softmax (multiclass logistic) regression on MNIST

In [None]:
import numpy as np
import matplotlib.pyplot as plt  # `plt` is the alias every cell in this section uses
from pandas import DataFrame  # used by the history and prediction cells below
from tensorflow.keras.datasets import mnist

plt.rcParams.update({'font.size': 12})

In [None]:
# Load MNIST through Keras; labels are integers at this point.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
print(train_labels[:3])
# Flatten each 28x28 image into a 784-vector and scale pixel values by 1/255.
train_images = train_images.reshape([-1, 784]).astype('float32') / 255.
test_images = test_images.reshape([-1, 784]).astype('float32') / 255.
train_labels = np.eye(10)[train_labels]  # one hot encoding
test_labels = np.eye(10)[test_labels]

In [None]:
# Preview the first 10 training images with their digit as the title.
fig = plt.figure(figsize=(8, 4))
for c, (image, label) in enumerate(zip(train_images[:10], train_labels[:10])):
  subplot = fig.add_subplot(2, 5, c + 1)
  subplot.set_xticks([])
  subplot.set_yticks([])
  subplot.set_title(np.argmax(label))  # recover the digit from the one-hot vector
  subplot.imshow(image.reshape([28, 28]), vmin=0, vmax=1, cmap=plt.cm.gray_r)

In [None]:
import jax, optax
from jax import random, numpy as jnp
from flax import linen as nn
from flax.training import train_state

In [None]:
def create_batches(data, batch_size):
  """Split `data` along its first axis into consecutive mini-batches.

  Every batch holds `batch_size` rows except possibly the last, which holds
  the remainder.

  Args:
    data: array-like indexable along axis 0 (e.g. a numpy array).
    batch_size: positive number of rows per batch.
  Returns:
    List of jnp arrays, in order. Empty when `data` has no rows.
  Raises:
    ValueError: if batch_size is not positive.
  """
  if batch_size <= 0:
    raise ValueError('batch_size must be positive, got {}'.format(batch_size))
  # Slicing also covers len(data) < batch_size. There the old
  # np.split(..., 0) call raised because zero sections were requested.
  return [jnp.asarray(data[start:start + batch_size])
          for start in range(0, len(data), batch_size)]

In [None]:
class SoftmaxModel(nn.Module):
  """Multiclass linear classifier: Dense(10) followed by softmax.

  With get_logits=True the pre-softmax scores are returned instead.
  """

  @nn.compact
  def __call__(self, x, get_logits=False):
    logits = nn.Dense(features=10)(x)
    if not get_logits:
      return nn.softmax(logits)
    return logits

In [None]:
# Initialise from a single example; its 784 features fix the kernel shape.
key, key1 = random.split(random.PRNGKey(0))
variables = SoftmaxModel().init(key1, train_images[:1])
# Show parameter shapes rather than values: kernel (784, 10), bias (10,).
jax.tree_util.tree_map(lambda x: x.shape, variables['params'])

In [None]:
# Bundle apply_fn, the parameters and an Adam optimizer (lr=0.001) into one state.
state = train_state.TrainState.create(
    apply_fn=SoftmaxModel().apply,
    params=variables['params'],
    tx=optax.adam(learning_rate=0.001)
)

In [None]:
@jax.jit
def loss_fn(params, state, inputs, labels):
  """Mean softmax cross-entropy against one-hot labels, plus accuracy.

  Returns (loss, acc). acc is the fraction of rows whose arg-max logit
  matches the arg-max of the one-hot label.
  """
  logits = state.apply_fn({'params': params}, inputs, get_logits=True)
  xent = optax.softmax_cross_entropy(logits, labels)  # categorical cross entropy
  hits = jnp.argmax(logits, -1) == jnp.argmax(labels, -1)
  return xent.mean(), jnp.mean(hits)

In [None]:
@jax.jit
def train_step(state, inputs, labels):
  """One optimizer update on a batch; returns (new_state, loss, acc).

  loss and acc are computed with the parameters before the update.
  """
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, acc), grads = grad_fn(state.params, state, inputs, labels)
  updated_state = state.apply_gradients(grads=grads)
  return updated_state, loss, acc

In [None]:
def train_epoch(state, input_batched, label_batched, eval):
  """Run train_step over every batch and return (state, mean_loss, mean_acc).

  Args:
    state: TrainState passed to train_step.
    input_batched: list of input batches (see create_batches).
    label_batched: list of label batches aligned with input_batched.
    eval: if True, the updated state from each train_step is discarded, so
      the returned state is the one passed in (used for test evaluation).
      The name shadows the builtin `eval`; it is kept because callers pass it
      by keyword.

  Returns:
    (state, loss, acc). loss and acc are per-sample averages: each batch
    mean is weighted by its batch size, so a short final batch is not
    over-weighted. Both are NaN when there are no batches.
  """
  losses, accs, sizes = [], [], []
  for inputs, labels in zip(input_batched, label_batched):
    # train_step evaluates loss/acc at the parameters *before* its update, so
    # in eval mode these are the metrics of the unchanged `state`. Gradients
    # are still computed there, and then discarded.
    new_state, loss, acc = train_step(state, inputs, labels)
    if not eval:
      state = new_state
    losses.append(jax.device_get(loss).tolist())
    accs.append(jax.device_get(acc).tolist())
    sizes.append(len(inputs))
  if not sizes:
    return state, np.nan, np.nan
  return state, np.average(losses, weights=sizes), np.average(accs, weights=sizes)

In [None]:
def fit(state, train_inputs, train_labels, test_inputs, test_labels, epochs, batch_size):
  """Train for `epochs` epochs, evaluating on the test split after each one.

  Returns:
    (final_state, history), where history maps 'loss_train', 'acc_train',
    'loss_test' and 'acc_test' to lists with one entry per epoch.
  """
  train_batches = (create_batches(train_inputs, batch_size),
                   create_batches(train_labels, batch_size))
  test_batches = (create_batches(test_inputs, batch_size),
                  create_batches(test_labels, batch_size))

  history = {'loss_train': [], 'acc_train': [], 'loss_test': [], 'acc_test': []}

  for epoch in range(1, epochs + 1):
    # Training pass: parameters are updated batch by batch.
    state, loss_train, acc_train = train_epoch(state, *train_batches, eval=False)
    history['loss_train'].append(loss_train)
    history['acc_train'].append(acc_train)

    # Evaluation pass: the returned state is discarded.
    _, loss_test, acc_test = train_epoch(state, *test_batches, eval=True)
    history['loss_test'].append(loss_test)
    history['acc_test'].append(acc_test)

    print('Epoch: {}, Loss: {:.4f}, Acc: {:.4f} / '.format(epoch, loss_train, acc_train), end='', flush=True)
    print('Loss(test): {:.4f}, Acc(test): {:.4f}'.format(loss_test, acc_test), flush=True)

  return state, history

In [None]:
%%time
# 16 epochs of mini-batch training (batch size 128), with a test-set
# evaluation after every epoch.
state, history = fit(
    state, train_images, train_labels, test_images, test_labels,
    epochs=16, batch_size=128
)

In [None]:
# Per-epoch accuracy and loss. index.name becomes the x-axis label; the tick
# labels are shifted by one because the DataFrame index starts at 0 while
# epochs are numbered from 1.
df = DataFrame({'Acc(train)': history['acc_train'], 'Acc(test)': history['acc_test']})
df.index.name = 'Epochs'
ax = df.plot(figsize=(6, 4))
ax.set_xticks(df.index)
_ = ax.set_xticklabels(df.index+1)

df = DataFrame({'Loss(train)': history['loss_train'], 'Loss(test)': history['loss_test']})
df.index.name = 'Epochs'
ax = df.plot(figsize=(6, 4))
ax.set_xticks(df.index)
_ = ax.set_xticklabels(df.index+1)

In [None]:
# Class probabilities for the whole test set, copied to host memory.
predictions = jax.device_get(
    state.apply_fn({'params': state.params}, test_images))

# Predicted vs. true digit for each test image (argmax over the 10 classes).
df = DataFrame({'pred': list(map(np.argmax, predictions)),
                'label': list(map(np.argmax, jax.device_get(test_labels)))})
correct = df[df['pred']==df['label']]
incorrect = df[df['pred']!=df['label']]

# One row per predicted digit i: up to 3 correct then up to 3 incorrect
# examples. Each title reads 'prediction / label'.
fig = plt.figure(figsize=(8, 16))
for i in range(10):
    indices = list(correct[correct['pred']==i].index[:3]) \
                + list(incorrect[incorrect['pred']==i].index[:3])
    for c, image in enumerate(test_images[indices]):
        subplot = fig.add_subplot(10, 6, i*6+c+1)
        subplot.set_xticks([])
        subplot.set_yticks([])
        subplot.set_title('{} / {}'.format(i, df['label'][indices[c]]))
        subplot.imshow(image.reshape([28, 28]),
                       vmin=0, vmax=1, cmap=plt.cm.gray_r)

## MLP

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pandas import DataFrame

import jax, optax
from jax import random, numpy as jnp
from flax import linen as nn
from flax.training import train_state

plt.rcParams.update({'font.size': 12})

In [None]:
def generate_datablock(key, n, mu, cov, t):
  """Sample n labelled points from a multivariate Gaussian.

  Args:
    key: PRNG key for jax.random.
    n: number of samples.
    mu: mean vector (length 2 in this notebook).
    cov: covariance matrix matching mu.
    t: label value written into the extra last column.
  Returns:
    Array of shape (n, len(mu) + 1): the sampled coordinates followed by a
    constant column equal to t.
  """
  points = random.multivariate_normal(
      key, jnp.asarray(mu), jnp.asarray(cov), (n,)  # shape as a tuple of ints
  )
  return jnp.hstack([points, jnp.ones([n, 1]) * t])

key, key1, key2, key3, key4, key5 = random.split(random.PRNGKey(0), 6)
# Three Gaussian clusters labelled 0 and one correlated cluster
# (off-diagonal covariance 4) labelled 1.
data1 = generate_datablock(key1, 15, [-3, -8], [[22, 0], [0, 22]], 0)
data2 = generate_datablock(key2, 15, [13, -8], [[22, 0], [0, 22]], 0)
data3 = generate_datablock(key3, 20, [-2, 8], [[40, 0], [0, 40]], 0)
data4 = generate_datablock(key4, 25, [8, 3], [[14, 4], [4, 14]], 1)

# Shuffle all rows together, then split into inputs (x1, x2) and labels t.
data = random.permutation(key5, jnp.vstack([data1, data2, data3, data4]))
train_x, train_t = jnp.split(data, [2], axis=1)

In [None]:
# Sanity check: shape and first row of each block, of the shuffled data,
# and the shapes of the input/label split.
print(data1.shape, data1[0])
print(data2.shape, data2[0])
print(data3.shape, data3[0])
print(data4.shape, data4[0])
print(data.shape, data[0])
print(train_x.shape, train_t.shape)

In [None]:
class SingleLayerModel(nn.Module):
  """MLP with one 2-unit tanh hidden layer and a single sigmoid output.

  get_logits=True returns the pre-sigmoid output instead of the probability.
  """

  @nn.compact
  def __call__(self, x, get_logits=False):
    hidden = nn.tanh(nn.Dense(features=2, name='HiddenLayer')(x))
    logits = nn.Dense(features=1, name='OutputLayer')(hidden)
    if get_logits:
      return logits
    return nn.sigmoid(logits)

In [None]:
key, key1 = random.split(key)
variables = SingleLayerModel().init(key1, train_x)

# Parameter shapes: HiddenLayer kernel (2, 2) / bias (2,),
# OutputLayer kernel (2, 1) / bias (1,).
jax.tree_util.tree_map(lambda x: x.shape, variables['params'])

In [None]:
# Bundle apply_fn, the parameters and an Adam optimizer (lr=0.001) into one state.
state = train_state.TrainState.create(
    apply_fn=SingleLayerModel().apply,
    params=variables['params'],
    tx=optax.adam(learning_rate=0.001)
)

In [None]:
@jax.jit
def loss_fn(params, state, inputs, labels):
  """Sigmoid cross-entropy on the model's logits, with accuracy as aux output.

  Returns:
    (loss, acc): mean binary cross-entropy, and the fraction of samples whose
    logit sign agrees with the sign of (label - 0.5).
  """
  logits = state.apply_fn({'params': params}, inputs, get_logits=True)
  agree = jnp.sign(logits) == jnp.sign(labels - 0.5)
  loss = jnp.mean(optax.sigmoid_binary_cross_entropy(logits, labels))
  return loss, jnp.mean(agree)

In [None]:
@jax.jit
def train_step(state, inputs, labels):
  """Apply one optimizer update and return (new_state, loss, acc).

  loss and acc are evaluated at the parameters before the update.
  """
  (loss, acc), grads = jax.value_and_grad(loss_fn, has_aux=True)(
      state.params, state, inputs, labels)
  return state.apply_gradients(grads=grads), loss, acc

In [None]:
%%time
# Full-batch training; loss and accuracy are recorded at every step.
loss_history, acc_history = [], []
for step in range(1, 5001):
  state, loss, acc = train_step(state, train_x, train_t)
  loss_history.append(jax.device_get(loss).tolist())
  acc_history.append(jax.device_get(acc).tolist())
  if step % 1000 == 0:
    print('Step: {}, Loss: {:.4f}, Acc: {:.4f}'.format(step, loss, acc), flush=True)

In [None]:
# Training curves, one figure each. Setting index.name labels the x-axis 'Steps'.
df = DataFrame({'Acc': acc_history})
df.index.name = 'Steps'
_ = df.plot(figsize=(6, 4))

df = DataFrame({'Loss': loss_history})
df.index.name = 'Steps'
_ = df.plot(figsize=(6, 4))

In [None]:
# Split the training points by label so each class gets its own marker.
train_set0 = [jax.device_get(x).tolist()
              for x, t in zip(train_x, train_t) if t == 0]
train_set1 = [jax.device_get(x).tolist()
              for x, t in zip(train_x, train_t) if t == 1]

fig = plt.figure(figsize=(7, 7))
subplot = fig.add_subplot(1, 1, 1)
subplot.set_ylim([-15, 15])
subplot.set_xlim([-15, 15])
subplot.set_xlabel('x1')
subplot.set_ylabel('x2')
subplot.scatter([x for x, y in train_set1],
                [y for x, y in train_set1], marker='x')
subplot.scatter([x for x, y in train_set0],
                [y for x, y in train_set0], marker='o')

# Model probability on a 500x500 grid over [-15, 15]^2. x2 is the outer loop,
# so rows of the reshaped array follow x2, matching imshow(origin='lower').
locations = [[x1, x2] for x2 in np.linspace(-15, 15, 500) 
                      for x1 in np.linspace(-15, 15, 500)]
p_vals = state.apply_fn({'params': state.params},
                        np.array(locations)).reshape([500, 500])
_ = subplot.imshow(p_vals, origin='lower', extent=(-15, 15, -15, 15),
                   vmin=0, vmax=1, cmap=plt.cm.gray_r, alpha=0.4)

## MLP on MNIST

In [None]:
class SingleLayerSoftmaxModel(nn.Module):
  """MLP for 10-class output: Dense(num_nodes) + ReLU, then Dense(10).

  get_logits=True returns the pre-softmax scores instead of probabilities.
  """
  # Width of the hidden layer.
  num_nodes: int = 1024

  @nn.compact
  def __call__(self, x, get_logits=False):
    hidden = nn.relu(nn.Dense(features=self.num_nodes)(x))
    logits = nn.Dense(features=10)(hidden)
    return logits if get_logits else nn.softmax(logits)