In [1]:
import jax
import jax.numpy as jnp

from optimizers import SGD
from test_opts import MLP

import matplotlib.pyplot as plt


In [2]:
KEY = jax.random.PRNGKey(42)

In [3]:
XDIM = 128
BS = 256

true_A = jax.random.uniform(KEY, (XDIM, 1))

def generate_batch(batch_size=BS, key: jax.random.PRNGKey=KEY) -> tuple[jnp.ndarray, jnp.ndarray]:
    x = jax.random.uniform(key, (batch_size, XDIM), minval=0, maxval=3)
    y = jnp.dot(x, true_A) + jax.random.normal(key, (batch_size, 1))*0.01
    return x, y

In [4]:
model = MLP()
x0, _ = generate_batch(BS)
params = model.init(KEY, x0)

def loss_fn(params, batch: tuple[jnp.ndarray, jnp.ndarray]) -> jnp.ndarray:
    x, y = batch
    pred = model.apply(params, x)
    return jnp.mean((pred - y) ** 2)

opt = SGD(lr=1e-5, loss_fn=loss_fn)

In [5]:
%%time
losses = []

for step in range(20_000):
	train_key, KEY = jax.random.split(KEY)
	batch = generate_batch(BS, train_key)
	new_params, state = opt.update(params, batch, state=None)
	params = new_params
	if step % 50 == 0 or step == 499:
		loss = loss_fn(params, batch)
		losses.append(loss)
		print(f"Step {step} | Loss: {loss:.4f}")

Step 0 | Loss: 9576.6406
Step 50 | Loss: 16.2252
Step 100 | Loss: 17.3694
Step 150 | Loss: 16.8205
Step 200 | Loss: 13.2355
Step 250 | Loss: 13.0969
Step 300 | Loss: 14.7088
Step 350 | Loss: 14.8927
Step 400 | Loss: 12.4836
Step 450 | Loss: 12.7970
Step 499 | Loss: 10.5046
Step 500 | Loss: 12.8932
Step 550 | Loss: 11.1468
Step 600 | Loss: 11.9507
Step 650 | Loss: 11.1245
Step 700 | Loss: 10.6491
Step 750 | Loss: 10.3700
Step 800 | Loss: 11.6847
Step 850 | Loss: 9.3029
Step 900 | Loss: 10.5449
Step 950 | Loss: 10.9886
Step 1000 | Loss: 8.6645
Step 1050 | Loss: 8.9899
Step 1100 | Loss: 7.6863
Step 1150 | Loss: 6.4369
Step 1200 | Loss: 6.4950
Step 1250 | Loss: 7.0492
Step 1300 | Loss: 7.2859
Step 1350 | Loss: 6.4091
Step 1400 | Loss: 5.4269
Step 1450 | Loss: 5.8520
Step 1500 | Loss: 6.5000
Step 1550 | Loss: 7.0424
Step 1600 | Loss: 6.5510
Step 1650 | Loss: 5.7963
Step 1700 | Loss: 5.5490
Step 1750 | Loss: 5.1196
Step 1800 | Loss: 4.6126
Step 1850 | Loss: 4.9115
Step 1900 | Loss: 5.0157
St

In [6]:
# plt.plot(losses)
# plt.yscale('log')