In [1]:
import jax
import jax.numpy as jnp

from optimizers import PAGE
from test_opts import MLP

import matplotlib.pyplot as plt


In [2]:
KEY = jax.random.PRNGKey(42)

In [3]:
XDIM = 128
BS = 256
BS_HAT = 16

true_A = jax.random.uniform(KEY, (XDIM, 1))

def generate_batch(batch_size=BS, key: jax.random.PRNGKey=KEY) -> tuple[jnp.ndarray, jnp.ndarray]:
    x = jax.random.uniform(key, (batch_size, XDIM), minval=0, maxval=3)
    y = jnp.dot(x, true_A) + jax.random.normal(key, (batch_size, 1))*0.001
    return x, y

In [4]:
model = MLP()
init_batch = generate_batch(BS)
params = model.init(KEY, init_batch[0])

def loss_fn(params: dict, batch: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
    x, y = batch
    pred = model.apply(params, x)
    return jnp.mean((pred - y) ** 2)

opt = PAGE(loss_fn=loss_fn,
           p=BS_HAT/(BS+BS_HAT),
           lr=1e-5,
           bs=BS,
           bs_hat=BS_HAT,
           need_jit=True)
state = opt.init(params, init_batch)

In [5]:
%%time
losses = []

for step in range(20_000):
	KEY, train_key = jax.random.split(KEY)
	batch = generate_batch(BS, train_key)
	new_params, state = opt.update(params, batch, state)
	params = new_params
	if step % 50 == 0 or step == 499:
		loss = loss_fn(params, batch)
		losses.append(loss)
		print(f"Step {step} | Loss: {loss:.4f}")

Step 0 | Loss: 9601.2324
Step 50 | Loss: 19.9341
Step 100 | Loss: 17.4572
Step 150 | Loss: 14.4246
Step 200 | Loss: 17.8038
Step 250 | Loss: 14.2952
Step 300 | Loss: 13.6857
Step 350 | Loss: 14.1862
Step 400 | Loss: 13.9985
Step 450 | Loss: 11.9915
Step 499 | Loss: 12.0208
Step 500 | Loss: 11.0850
Step 550 | Loss: 14.5749
Step 600 | Loss: 10.3985
Step 650 | Loss: 12.1444
Step 700 | Loss: 11.0124
Step 750 | Loss: 10.3109
Step 800 | Loss: 9.9666
Step 850 | Loss: 9.7815
Step 900 | Loss: 10.1695
Step 950 | Loss: 8.0181
Step 1000 | Loss: 8.5729
Step 1050 | Loss: 9.8154
Step 1100 | Loss: 7.8551
Step 1150 | Loss: 8.1890
Step 1200 | Loss: 7.5346
Step 1250 | Loss: 6.4746
Step 1300 | Loss: 6.6429
Step 1350 | Loss: 7.0684
Step 1400 | Loss: 6.9528
Step 1450 | Loss: 6.3874
Step 1500 | Loss: 5.9101
Step 1550 | Loss: 6.1463
Step 1600 | Loss: 6.2912
Step 1650 | Loss: 5.7880
Step 1700 | Loss: 4.9908
Step 1750 | Loss: 5.3412
Step 1800 | Loss: 5.1591
Step 1850 | Loss: 5.0426
Step 1900 | Loss: 5.4189
Step

In [6]:
# plt.plot(losses)
# plt.yscale('log')