In [19]:
import jax.numpy as jnp
import numpy as np
import jax
from scipy.spatial.distance import cdist
from scipy.integrate import RK45
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import equinox as eqx
from tqdm import tqdm

jax.config.update("jax_platform_name", "cpu") # force cpu usage for maximum compatibility

In [20]:

seed = 42
np.random.seed(seed)
key = jax.random.key(seed)
depth = 2
width = 50

In [21]:
try:
    data = np.load('/Users/haydenoutlaw/Documents/Courses/SciML/ncsu-sciml/hw2/hw2_p3_data.npz')
    x_tr = data['x_tr']
    y_tr = data['y_tr']
    print("dataset loaded.\n")
except:
    raise RuntimeError("error loading dataset.\n")

indices = np.arange(len(x_tr))
train_idx, test_idx = train_test_split(indices, test_size=0.33, random_state=seed)

X_train, X_test = jnp.array(x_tr[train_idx]), jnp.array(x_tr[test_idx])
y_train, y_test = jnp.array(y_tr[train_idx]), jnp.array(y_tr[test_idx])

print("dataset format:")
print("xtr:", X_train.shape, "xts:", X_test.shape)
print("ytr:", y_train.shape, "yts", y_test.shape)


dataset loaded.

dataset format:
xtr: (335, 2) xts: (165, 2)
ytr: (335, 1) yts (165, 1)


In [22]:
## CUSTOM PRIMITIVE DEFINITIONS

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size))

    def __call__(self, x):
        return self.weight @ x + self.bias


class MLP(eqx.Module):
    layers: list
    activations: list

    def __init__(self, architecture, key, activation=jax.nn.relu):
        keys = jax.random.split(key, len(architecture) - 1)
        self.layers = [
            Linear(architecture[i], architecture[i+1], keys[i])
            for i in range(len(architecture) - 1)
        ]
        self.activations = [activation] * (len(self.layers) - 1) + [eqx.nn.Identity()]

    def __call__(self, x):
        for layer, act in zip(self.layers, self.activations):
            x = act(layer(x))
        return x
    

# CUSTOM STEP FUNCTIONALS

def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jnp.mean((y - pred_y) ** 2)

# train and eval step
@eqx.filter_jit
def train_step(model, opt_state, x, y, optimizer):
    loss, grads = eqx.filter_value_and_grad(loss_fn)(model, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

@eqx.filter_jit
def eval_step(model, x, y):
    return loss_fn(model, x, y)

In [23]:
arch = [x_tr.shape[1]] + [width] * depth + [y_tr.shape[1]]
new_model = MLP(arch, key, activation=jax.nn.gelu)
new_model = eqx.tree_deserialise_leaves("dnn_model_opt.eqx", new_model)

print("model architecture:")
print(new_model)


model architecture:
MLP(
  layers=[
    Linear(weight=f32[50,2], bias=f32[50]),
    Linear(weight=f32[50,50], bias=f32[50]),
    Linear(weight=f32[1,50], bias=f32[1])
  ],
  activations=[<function gelu>, <function gelu>, Identity()]
)


In [24]:
# sample random points from x_tr
key, subkey = jax.random.split(key)
num_samples = 500
idx = jax.random.choice(subkey, x_tr.shape[0], (num_samples,), replace=False)

# get true and predicted values on grid
X_sample = jnp.array(x_tr[idx])
Y_true = jnp.array(y_tr[idx]).flatten()
Y_pred = jax.vmap(lambda x: new_model(x), in_axes=0)(X_sample)

# convert everything to flat numpy arrays of shape (500,)
X_sample = np.array(X_sample)
Y_true   = np.array(Y_true).reshape(-1)
Y_pred   = np.array(Y_pred).reshape(-1)
print(Y_true.shape)
print(Y_pred.shape)
errors   = np.abs(Y_true - Y_pred)

(500,)
(500,)


In [25]:
print("overall L2 loss on domain:")
print(loss_fn(new_model, X_sample, Y_true))

overall L2 loss on domain:
0.18218222
