In [1]:
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from optimizers import PAGE
from test_opts import MLP

import matplotlib.pyplot as plt


In [2]:
KEY = jax.random.PRNGKey(42)

In [3]:
XDIM = 128
BS = 256
BS_HAT = 16

true_A = jax.random.uniform(KEY, (XDIM, 1))

def generate_batch(batch_size=BS, key: jax.random.PRNGKey=KEY) -> tuple[jnp.ndarray, jnp.ndarray]:
    x = jax.random.uniform(key, (batch_size, XDIM), minval=0, maxval=3)
    y = jnp.dot(x, true_A) + jax.random.normal(key, (batch_size, 1))*0.001
    return x, y

In [4]:
model = MLP()
init_batch = generate_batch(BS)
variables = model.init(KEY, init_batch[0])

def loss_fn(variables: FrozenDict, batch: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
    x, y = batch
    pred = model.apply(variables, x)
    return jnp.mean((pred - y) ** 2), {}

opt = PAGE(loss_fn=loss_fn,
           eval_loss_fn=loss_fn,
           p=BS_HAT/(BS+BS_HAT),
           lr=1e-5,
           bs=BS,
           bs_hat=BS_HAT,
           need_jit=True)
state = opt.init(variables, init_batch)

In [5]:
# %%time
losses = []

for step in range(20_000):
	KEY, train_key = jax.random.split(KEY)
	batch = generate_batch(BS, train_key)
	loss, state = opt.update(state, batch)
	if step % 50 == 0 or step == 499:
		losses.append(loss)
		print(f"Step {step} | Loss: {loss:.4f}")

Step 0 | Loss: 9572.6133
Step 50 | Loss: 16.7431
Step 100 | Loss: 18.3288
Step 150 | Loss: 13.5051
Step 200 | Loss: 32.5960
Step 250 | Loss: 12.2121
Step 300 | Loss: 18.8980
Step 350 | Loss: 13.6375
Step 400 | Loss: 15.3263
Step 450 | Loss: 15.2845
Step 499 | Loss: 9.5949
Step 500 | Loss: 8.4468
Step 550 | Loss: 11.4807
Step 600 | Loss: 8.7649
Step 650 | Loss: 14.6219
Step 700 | Loss: 11.0443
Step 750 | Loss: 8.7406
Step 800 | Loss: 6.9903
Step 850 | Loss: 9.2044
Step 900 | Loss: 9.3376
Step 950 | Loss: 4.1958
Step 1000 | Loss: 10.5264
Step 1050 | Loss: 14.8288
Step 1100 | Loss: 5.6612
Step 1150 | Loss: 7.2619
Step 1200 | Loss: 6.9668
Step 1250 | Loss: 6.1871
Step 1300 | Loss: 3.2327
Step 1350 | Loss: 12.3944
Step 1400 | Loss: 5.7206
Step 1450 | Loss: 7.1044
Step 1500 | Loss: 3.1510
Step 1550 | Loss: 7.3383
Step 1600 | Loss: 5.6802
Step 1650 | Loss: 3.1582
Step 1700 | Loss: 5.5416
Step 1750 | Loss: 8.8929
Step 1800 | Loss: 7.3935
Step 1850 | Loss: 3.2722
Step 1900 | Loss: 4.8737
Step 1

In [6]:
# plt.plot(losses)
# plt.yscale('log')