In [13]:
from functools import partial
import math
import time
import os
import json

import jax
import jax.numpy as jnp
import numpy as np
import optax

from tqdm import tqdm

import matplotlib.pyplot as plt

In [2]:
RNG_SEED = 1234
TRAIN_EPOCHS = 1000
ODE_STEPS = 20           # explicit Euler steps -> step size = 1/ODE_STEPS
LEAKY_ALPHA = 0.01
LR_MAX = 1e-3
LR_MIN = 1e-5
CYCLE_LENGTH = 200       # epochs per cosine cycle for cyclic LR
TEST_POINTS = 1000
BATCH_SIZE = None        # None -> full-batch for small N
RESULTS_DIR = "results_neural_ode_experiment"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Grids for experiments (paper):
TRAIN_SIZES = [10, 50, 100, 500, 1000]
HS = [5, 10, 50, 100]

### Utils

In [15]:
def target_fn(x):
    return jnp.sin(10.0 * x) + x

key = jax.random.PRNGKey(RNG_SEED)

def leaky_relu(x, alpha=LEAKY_ALPHA):
    return jnp.where(x >= 0, x, alpha * x)

# parameter counting utility
def count_params(params):
    def tree_size(t):
        return sum([p.size for p in jax.tree_util.tree_map.tree_leaves(t)])
    return tree_size(params)


### Models

In [5]:
def init_linear_params(key, in_dim, out_dim, scale=0.1):
    k1, k2 = jax.random.split(key)
    W = scale * jax.random.normal(k1, (out_dim, in_dim))
    b = jnp.zeros((out_dim,))
    return {'W': W, 'b': b}

# Single-hidden-layer: l2 o sigma o l1
def init_single_hidden(key, d):
    k1, k2 = jax.random.split(key)
    p1 = init_linear_params(k1, 1, d)
    p2 = init_linear_params(k2, d, 1)
    return {'l1': p1, 'l2': p2}

def apply_single_hidden(params, x):
    # x: (batch,1)
    z = jnp.dot(x, params['l1']['W'].T) + params['l1']['b']  # (batch,d)
    z = leaky_relu(z)
    out = jnp.dot(z, params['l2']['W'].T) + params['l2']['b']  # (batch,1)
    return out

In [6]:
# Neural ODE block: phi is flow of du/dt = sigma(Au + b_ode) with explicit Euler

def init_neural_ode_params(key, d):
    # params: l1 (1->d), l2 (d->1), A (dxd), b_ode (d)
    k1, k2, k3 = jax.random.split(key, 3)
    l1 = init_linear_params(k1, 1, d)
    l2 = init_linear_params(k2, d, 1)
    A = 0.1 * jax.random.normal(k3, (d, d))
    b_ode = jnp.zeros((d,))
    return {'l1': l1, 'l2': l2, 'A': A, 'b': b_ode}

@partial(jax.jit, static_argnums=(2,))
def phi_flow(u0, A, b_ode, steps=ODE_STEPS):
    # u0: (batch,d)
    dt = 1.0 / steps
    def body(i, u):
        # compute sigma(Au + b) for each batch
        v = jnp.dot(u, A.T) + b_ode  # (batch,d)
        v = leaky_relu(v)
        return u + dt * v
    u = jax.lax.fori_loop(0, steps, body, u0)
    return u

def apply_neural_ode(params, x):
    # x: (batch,1)
    u0 = jnp.dot(x, params['l1']['W'].T) + params['l1']['b']  # (batch,d)
    uT = phi_flow(u0, params['A'], params['b'])
    out = jnp.dot(uT, params['l2']['W'].T) + params['l2']['b']
    return out

In [7]:
# Two-hidden-layer: l3 o sigma o l2 o sigma o l1

def init_two_hidden(key, d):
    k1, k2, k3 = jax.random.split(key, 3)
    l1 = init_linear_params(k1, 1, d)
    l2 = init_linear_params(k2, d, d)
    l3 = init_linear_params(k3, d, 1)
    return {'l1': l1, 'l2': l2, 'l3': l3}

def apply_two_hidden(params, x):
    z1 = jnp.dot(x, params['l1']['W'].T) + params['l1']['b']  # (batch,d)
    z1 = leaky_relu(z1)
    z2 = jnp.dot(z1, params['l2']['W'].T) + params['l2']['b']  # (batch,d)
    z2 = leaky_relu(z2)
    out = jnp.dot(z2, params['l3']['W'].T) + params['l3']['b']  # (batch,1)
    return out

### Training utils

In [24]:
def make_dataset(key, N):
    xs = jax.random.uniform(key, (N, 1), minval=0.0, maxval=1.0)
    ys = target_fn(xs)
    return xs, ys

# cyclic cosine annealing schedule
def cyclic_cosine_lr(step, lr_max=LR_MAX, lr_min=LR_MIN, cycle_length=CYCLE_LENGTH):
    # step in epochs (integer)
    t = (step % cycle_length) / float(cycle_length)
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + jnp.cos(jnp.pi * t))

# sgd step functions using optax but with dynamic lr per epoch

@partial(jax.jit, static_argnums=(0,))
def loss_fn_apply(apply_fn, params, x, y):
    preds = apply_fn(params, x)
    return jnp.mean((preds - y) ** 2)

# wrapper trainer that trains a single model instance and returns test MSE history

def train_model(rng, init_fn, apply_fn, train_x, train_y, test_x, test_y, epochs=TRAIN_EPOCHS):
    # init params
    rng, sk = jax.random.split(rng)
    params = init_fn(sk)
    # optax optimizer with placeholder LR -> we'll update lr in each epoch by creating a new opt state when lr changes
    # To keep things simple, use optax.adam with a constant lr that we manually update by scaling grads.
    # Instead we will implement a simple optax chain with zero lr and scale grads manually per epoch.

    optimizer = optax.adam(1.0)  # lr will be applied manually when updating parameters
    opt_state = optimizer.init(params)

    # jit'd gradient function
    # @jax.jit
    def compute_grads(params, x, y):
        loss, grads = jax.value_and_grad(lambda p: loss_fn_apply(apply_fn, p, x, y))(params)
        return loss, grads

    history = {'train_loss': [], 'test_loss': []}

    n_train = train_x.shape[0]
    batch_size = BATCH_SIZE or n_train

    for epoch in tqdm(range(epochs), 'training loop....'):
        # learning rate for current epoch
        lr = float(cyclic_cosine_lr(epoch))

        # simple full-batch or mini-batch loop
        perm = jax.random.permutation(rng, n_train)
        rng, _ = jax.random.split(rng)
        for i in range(0, n_train, batch_size):
            idx = perm[i:i+batch_size]
            xb = train_x[idx]
            yb = train_y[idx]
            loss, grads = compute_grads(params, xb, yb)
            updates, opt_state = optimizer.update(grads, opt_state, params)
            # scale updates by lr
            scaled_updates = jax.tree_util.tree_map(lambda u: lr * u, updates)
            params = optax.apply_updates(params, scaled_updates)

        # record losses
        train_loss = float(loss_fn_apply(apply_fn, params, train_x, train_y))
        test_loss = float(loss_fn_apply(apply_fn, params, test_x, test_y))
        history['train_loss'].append(train_loss)
        history['test_loss'].append(test_loss)

        # optional progress print every 100 epochs
        if (epoch + 1) % 100 == 0:
            print(f"epoch {epoch+1}/{epochs}  lr={lr:.2e}  train_mse={train_loss:.6e}  test_mse={test_loss:.6e}")

    return params, history

In [18]:
def run_experiments():
    global key
    results = {}

    # fixed test set
    key, sk = jax.random.split(key)
    test_x = jax.random.uniform(sk, (TEST_POINTS, 1), minval=0.0, maxval=1.0)
    test_y = target_fn(test_x)

    for d in HS:
        for N in TRAIN_SIZES:
            print(f"\nRunning experiment: hidden dim d={d}, train size N={N}")
            # generate single training set for this pair (N,d)
            key, sk = jax.random.split(key)
            train_x, train_y = make_dataset(sk, N)

            # initialize model RNGs deterministically to keep fair comparison
            k_single, k_ode, k_two = jax.random.split(key, 3)

            # init fns
            init_s = lambda k: init_single_hidden(k, d)
            apply_s = apply_single_hidden

            init_o = lambda k: init_neural_ode_params(k, d)
            apply_o = apply_neural_ode

            init_t = lambda k: init_two_hidden(k, d)
            apply_t = apply_two_hidden

            # train each model
            t0 = time.time()
            params_s, hist_s = train_model(k_single, init_s, apply_s, train_x, train_y, test_x, test_y)
            t1 = time.time()
            params_o, hist_o = train_model(k_ode, init_o, apply_o, train_x, train_y, test_x, test_y)
            t2 = time.time()
            params_t, hist_t = train_model(k_two, init_t, apply_t, train_x, train_y, test_x, test_y)
            t3 = time.time()

            print(f"Times (s): single={t1-t0:.1f}, ode={t2-t1:.1f}, two={t3-t2:.1f}")

            results[(d, N)] = {
                'single': {'params': params_s, 'hist': hist_s},
                'ode': {'params': params_o, 'hist': hist_o},
                'two': {'params': params_t, 'hist': hist_t},
                'train_x': np.array(train_x), 'train_y': np.array(train_y),
                'test_x': np.array(test_x), 'test_y': np.array(test_y)}

            # save intermediate results to disk to keep a record
            save_path = os.path.join(RESULTS_DIR, f"result_d{d}_N{N}.npz")
            # we cannot save JAX DeviceArray directly; convert histories and losses
            np.savez(save_path,
                     train_x=np.array(train_x), train_y=np.array(train_y),
                     test_x=np.array(test_x), test_y=np.array(test_y),
                     hist_single=np.array(hist_s['test_loss']),
                     hist_ode=np.array(hist_o['test_loss']),
                     hist_two=np.array(hist_t['test_loss']))

    # save summary
    with open(os.path.join(RESULTS_DIR, 'meta_summary.json'), 'w') as f:
        json.dump({'HS': HS, 'TRAIN_SIZES': TRAIN_SIZES, 'ODE_STEPS': ODE_STEPS}, f)

    return results

In [25]:
print("Starting experiments. This may take a while depending on hardware.")
results = run_experiments()
print("Experiments finished. Results saved in:", RESULTS_DIR)

Starting experiments. This may take a while depending on hardware.

Running experiment: hidden dim d=5, train size N=10


training loop....:   0%|          | 0/1000 [00:00<?, ?it/s]

training loop....:  11%|█         | 109/1000 [00:02<00:13, 63.89it/s]

epoch 100/1000  lr=5.13e-04  train_mse=8.386896e-01  test_mse=8.715138e-01


training loop....:  21%|██        | 207/1000 [00:04<00:12, 64.39it/s]

epoch 200/1000  lr=1.01e-05  train_mse=8.129290e-01  test_mse=8.494824e-01


training loop....:  31%|███▏      | 313/1000 [00:05<00:10, 66.92it/s]

epoch 300/1000  lr=5.13e-04  train_mse=7.126792e-01  test_mse=7.648321e-01


training loop....:  41%|████      | 411/1000 [00:07<00:09, 62.35it/s]

epoch 400/1000  lr=1.01e-05  train_mse=6.917909e-01  test_mse=7.474537e-01


training loop....:  51%|█████     | 509/1000 [00:08<00:08, 60.31it/s]

epoch 500/1000  lr=5.13e-04  train_mse=6.111264e-01  test_mse=6.814530e-01


training loop....:  61%|██████    | 607/1000 [00:10<00:06, 62.95it/s]

epoch 600/1000  lr=1.01e-05  train_mse=5.945447e-01  test_mse=6.681502e-01


training loop....:  71%|███████   | 712/1000 [00:11<00:04, 64.21it/s]

epoch 700/1000  lr=5.13e-04  train_mse=5.312324e-01  test_mse=6.184835e-01


training loop....:  81%|████████  | 810/1000 [00:13<00:02, 64.59it/s]

epoch 800/1000  lr=1.01e-05  train_mse=5.184230e-01  test_mse=6.087028e-01


training loop....:  91%|█████████ | 908/1000 [00:15<00:01, 62.07it/s]

epoch 900/1000  lr=5.13e-04  train_mse=4.701903e-01  test_mse=5.730073e-01


training loop....: 100%|██████████| 1000/1000 [00:16<00:00, 60.91it/s]


epoch 1000/1000  lr=1.01e-05  train_mse=4.606131e-01  test_mse=5.661880e-01


training loop....:   0%|          | 0/1000 [00:00<?, ?it/s]


ValueError: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>, JitTracer<float32[5]>. The error was:
TypeError: unhashable type: 'DynamicJaxprTracer'


In [22]:
def build_and_plot_comparison(results_dir=RESULTS_DIR, HS=HS, TRAIN_SIZES=TRAIN_SIZES, save_png=True):
    """
    Load per-(d,N) .npz result files saved by run_experiments and build comparison plots.

    - For each hidden size d in HS, plots test MSE vs number of training samples N for the three models
        (single, ode, two) on the same subplot.
    - Saves a single figure with subplots for all d into results_dir (mse_vs_N.png) and per-d PNGs.

    Returns a dict containing numpy arrays for final test MSEs with shape (len(HS), len(TRAIN_SIZES)).
    """
    # prepare arrays
    H = len(HS)
    S = len(TRAIN_SIZES)
    mse_single = np.zeros((H, S))
    mse_ode = np.zeros((H, S))
    mse_two = np.zeros((H, S))

    for i, d in enumerate(HS):
        for j, N in enumerate(TRAIN_SIZES):
            fname = os.path.join(results_dir, f"result_d{d}_N{N}.npz")
            if not os.path.exists(fname):
                print(f"Warning: file not found {fname}, filling with NaN")
                mse_single[i, j] = np.nan
                mse_ode[i, j] = np.nan
                mse_two[i, j] = np.nan
                continue
            data = np.load(fname)
            # hist arrays stored as hist_single/test_loss etc -> final element is last epoch
            try:
                hist_s = data['hist_single']
                hist_o = data['hist_ode']
                hist_t = data['hist_two']
            except Exception:
                # backwards compatibility / alternative key names
                hist_s = data.get('hist_single', data.get('hist_s', None))
                hist_o = data.get('hist_ode', data.get('hist_o', None))
                hist_t = data.get('hist_two', data.get('hist_t', None))
            if hist_s is None or hist_o is None or hist_t is None:
                print(f"Missing history arrays in {fname}; filling with NaN")
                mse_single[i, j] = np.nan
                mse_ode[i, j] = np.nan
                mse_two[i, j] = np.nan
                continue
            mse_single[i, j] = float(hist_s[-1])
            mse_ode[i, j] = float(hist_o[-1])
            mse_two[i, j] = float(hist_t[-1])

    # Plotting: one figure with subplots per hidden size
    cols = int(math.ceil(math.sqrt(H)))
    rows = int(math.ceil(H / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 3 * rows), squeeze=False)
    for idx, d in enumerate(HS):
        r = idx // cols
        c = idx % cols
        ax = axes[r][c]
        ax.plot(TRAIN_SIZES, mse_single[idx, :], marker='o', label='single (1 hidden)')
        ax.plot(TRAIN_SIZES, mse_ode[idx, :], marker='s', label='neural ODE')
        ax.plot(TRAIN_SIZES, mse_two[idx, :], marker='^', label='two-hidden')
        ax.set_xscale('log')
        ax.set_yscale('log')
        ax.set_xlabel('Training samples (N)')
        ax.set_ylabel('Test MSE')
        ax.set_title(f'hidden dim d={d}')
        ax.grid(True, which='both', ls='--', alpha=0.4)
        ax.legend()

    # hide empty subplots
    for idx in range(H, rows * cols):
        r = idx // cols
        c = idx % cols
        axes[r][c].axis('off')

    plt.tight_layout()
    out_path = os.path.join(results_dir, 'mse_vs_N.png')
    if save_png:
        fig.savefig(out_path, dpi=200)
        print(f"Saved combined MSE plot to {out_path}")
        # also save per-d plots
        for idx, d in enumerate(HS):
            fig2, ax2 = plt.subplots(figsize=(6,4))
            ax2.plot(TRAIN_SIZES, mse_single[idx, :], marker='o', label='single (1 hidden)')
            ax2.plot(TRAIN_SIZES, mse_ode[idx, :], marker='s', label='neural ODE')
            ax2.plot(TRAIN_SIZES, mse_two[idx, :], marker='^', label='two-hidden')
            ax2.set_xscale('log')
            ax2.set_yscale('log')
            ax2.set_xlabel('Training samples (N)')
            ax2.set_ylabel('Test MSE')
            ax2.set_title(f'hidden dim d={d}')
            ax2.grid(True, which='both', ls='--', alpha=0.4)
            ax2.legend()
            out_path_d = os.path.join(results_dir, f'mse_vs_N_d{d}.png')
            fig2.tight_layout()
            plt.show()
            fig2.savefig(out_path_d, dpi=200)
            plt.close(fig2)
            print(f"Saved per-d plot to {out_path_d}")

    # plt.close(fig)
    plt.show()

    return {'HS': HS, 'TRAIN_SIZES': TRAIN_SIZES,
            'mse_single': mse_single, 'mse_ode': mse_ode, 'mse_two': mse_two}

In [23]:
summary = build_and_plot_comparison()
print("Done. Summary keys:", list(summary.keys()))



ValueError: Data has no positive values, and therefore cannot be log-scaled.

Error in callback <function _draw_all_if_interactive at 0x72fc23d89580> (for post_execute), with arguments args (),kwargs {}:


ValueError: Data has no positive values, and therefore cannot be log-scaled.

ValueError: Data has no positive values, and therefore cannot be log-scaled.

<Figure size 800x600 with 4 Axes>

In [None]:
exit

: 