In [None]:
%%bash
# Download the book's requirements file, install it, then report which
# jax / flax / optax packages ended up installed.
curl --silent --location --remote-name https://raw.githubusercontent.com/enakai00/colab_jaxbook/main/requirements.txt
pip install --requirement requirements.txt
pip list | grep -E 'jax|flax|optax'

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from jax import random

# Larger default font so axis labels in the figures below stay readable.
plt.rcParams['font.size'] = 12

In [None]:
# Monthly temperatures for months 1..12 (plotted against 'Month' / 'C' below).
train_t = jnp.array([5.2, 5.7, 8.6, 14.9, 18.2, 20.4, 25.5, 26.4, 22.8, 17.5, 11.1, 6.6])

In [None]:
# Turn the 12 targets into a (12, 1) column so they line up with the (N, 1) model output.
train_t = jnp.reshape(train_t, (12, 1))

In [None]:
# Display the target column: one temperature per month, reshaped to (12, 1).
train_t

In [None]:
# Design matrix: row m holds [1, m, m**2, m**3, m**4] for months m = 1..12.
train_x = jnp.array([[m ** p for p in range(5)] for m in range(1, 13)])

In [None]:
# Display the design matrix: each row holds month**0 .. month**4 for months 1-12.
train_x

In [None]:
# Expected (12, 5): 12 months x 5 polynomial powers.
train_x.shape

In [None]:
# Split the seed-0 root key: `key1` draws the initial weights now, `key` is kept spare.
key, key1 = random.split(random.PRNGKey(0), num=2)
# One standard-normal coefficient per power of month (month**0 .. month**4).
w = random.normal(key1, shape=(5, 1))
w

In [None]:
@jax.jit
def predict(w, x):
  """Linear model output x @ w.

  Args:
    w: weight matrix; (5, 1) in this notebook.
    x: feature rows; (N, 5) polynomial features here.

  Returns:
    The matrix product of x and w, one row per input row.
  """
  return x @ w

In [None]:
@jax.jit
def loss_fn(w, train_x, train_t):
  """Mean squared error between predict(w, train_x) and the targets train_t."""
  residuals = predict(w, train_x) - train_t
  return jnp.mean(residuals ** 2)

In [None]:
# Compiled gradient of the MSE with respect to its first argument, the weights w.
grad_loss = jax.jit(jax.grad(loss_fn, argnums=0))

In [None]:
%%time
learning_rate = 1e-8 * 1.4
for step in range(1, 5000001):
  grads = grad_loss(w, train_x, train_t)
  w = w - learning_rate * grads
  if step % 500000 == 0:
    loss_val = loss_fn(w, train_x, train_t)
    print('Step: {}, Loss: {:0.4f}'.format(step, loss_val), flush=True)

In [None]:
# Inspect the weights after gradient descent (coefficients of month**0 .. month**4).
w

In [None]:
# Evaluate the fitted degree-4 polynomial on a dense grid of months and
# overlay it on the 12 monthly training targets.
xs = np.linspace(1, 12, 100)
inputs = jnp.asarray([[month ** n for n in range(0, 5)] for month in xs])
ys = predict(w, inputs)

fig = plt.figure(figsize=(6, 4))
subplot = fig.add_subplot(1, 1, 1)
subplot.set_xlim(1, 12)
subplot.set_ylim(0, 30)
# One tick per month, 1 through 12, matching the x-limits and the scatter below.
subplot.set_xticks(range(1, 13))
subplot.set_xlabel('Month')
subplot.set_ylabel('C')

subplot.scatter(range(1, 13), train_t)
_ = subplot.plot(xs, ys)

## Least Squares Method with Flax and Optax

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pandas import DataFrame

import jax
import optax
from jax import random, numpy as jnp
from flax import linen as nn
from flax.training import train_state

# Larger default font so axis labels in the figures below stay readable.
plt.rcParams['font.size'] = 12

In [None]:
# Features month**1 .. month**4 for months 1..12; unlike the first section there is no constant column.
train_x = jnp.array([[m ** p for p in range(1, 5)] for m in range(1, 13)])

In [None]:
# Display the inputs: month**1 .. month**4 for months 1-12 (no constant column here).
train_x

In [None]:
class TemperatureModel(nn.Module):
  """Single Dense layer with one output unit, mapping polynomial month features to a temperature."""

  @nn.compact
  def __call__(self, x):
    return nn.Dense(1)(x)

In [None]:
# Split the seed-0 root key and initialize the model; the example input
# train_x fixes the input width of the Dense kernel.
key, key1 = random.split(random.PRNGKey(0), num=2)
variables = TemperatureModel().init(key1, train_x)

In [None]:
# Inspect the initialized variable collections (the 'params' entry is used to build the train state).
variables

In [None]:
# Sanity check: run the untrained model on the training inputs.
TemperatureModel().apply(variables, train_x)

In [None]:
# Bundle the initial parameters, an Adam optimizer (learning rate 0.001) and
# the model's apply function into one TrainState.
state = train_state.TrainState.create(
    params=variables['params'],
    tx=optax.adam(learning_rate=1e-3),
    apply_fn=TemperatureModel().apply,
)

In [None]:
@jax.jit
def loss_fn(params, state, inputs, labels):
  """Mean of optax.l2_loss between the model's predictions and labels.

  Args:
    params: parameter tree to evaluate (differentiated by train_step).
    state: TrainState whose apply_fn runs the model.
    inputs: feature rows.
    labels: target values, same shape as the predictions.

  Returns:
    Scalar average loss. optax.l2_loss is 0.5 * squared error, so this is half the MSE.
  """
  predictions = state.apply_fn({'params': params}, inputs)
  per_example = optax.l2_loss(predictions, labels)
  return jnp.mean(per_example)

In [None]:
@jax.jit
def train_step(state, inputs, labels):
  """Apply one optimizer update.

  Returns:
    (new_state, loss), where loss is evaluated at the parameters before the update.
  """
  loss_and_grad = jax.value_and_grad(loss_fn)
  loss, grads = loss_and_grad(state.params, state, inputs, labels)
  return state.apply_gradients(grads=grads), loss

In [None]:
%%time
loss_history = []
for step in range(1, 100001):
  state, loss_val = train_step(state, train_x, train_t)
  loss_history.append(jax.device_get(loss_val).tolist())
  if step % 10000 == 0:
    print('Step: {}, Loss: {:0.4f}'.format(step, loss_val), flush=True)

In [None]:
# Loss curve for the Adam run. Row i holds the loss returned by train_step on
# step i + 1, which is evaluated before that step's update.
# Build the frame once and plot it twice: zoomed to the first 100 steps, then
# the full history with the y-axis limited to [0, 8].
df = DataFrame({'Loss': loss_history})
df.index.name = 'Steps'
_ = df.plot(figsize=(6, 4), xlim=(0, 100))
_ = df.plot(figsize=(6, 4), ylim=(0, 8))

In [None]:
# Inspect the parameters after training.
state.params

In [None]:
# Evaluate the trained Flax model on a dense grid of months and overlay it
# on the 12 monthly training targets.
xs = np.linspace(1, 12, 100)
inputs = jnp.asarray([[month ** n for n in range(1, 5)] for month in xs])
ys = state.apply_fn({'params': state.params}, inputs)

fig = plt.figure(figsize=(6, 4))
subplot = fig.add_subplot(1, 1, 1)
subplot.set_xlim(1, 12)
subplot.set_ylim(0, 30)
# One tick per month, 1 through 12, matching the x-limits and the scatter below.
subplot.set_xticks(range(1, 13))
subplot.set_xlabel('Month')
subplot.set_ylabel('C')

subplot.scatter(range(1, 13), train_t)
_ = subplot.plot(xs, ys)