In [40]:
import os
import time
import numpy as np
from tqdm import tqdm
import jax
import jax.numpy as jnp
from jax import random
import optax
from flax import linen as nn
from flax.training import train_state

import pickle
from scipy.integrate import solve_ivp


savedir = "models/simple_64_another_2"

In [41]:
class MLP(nn.Module):
    dim: int
    out_dim: int = 1
    w: int = 64

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.w)(x)
        x = nn.selu(x)
        x = nn.Dense(self.w)(x)
        x = nn.selu(x)
        x = nn.Dense(self.w)(x)
        x = nn.selu(x)
        x = nn.Dense(self.out_dim)(x)
        return x
model = MLP(dim=4)

In [42]:
@jax.jit
def predict(params, inputs):
    return model.apply({"params": params}, inputs)

In [43]:
with open(f'{savedir}/w.pkl', 'rb') as f:
    params = pickle.load(f)

In [44]:
params

{'Dense_0': {'bias': Array([ 0.09206923, -0.346633  , -0.16783881,  0.7796735 ,  0.11282883,
         -0.12830903, -0.21261385, -0.05855507, -0.6620905 ,  0.01374729,
          0.54079086,  0.04040891,  0.02714562, -0.03558638,  0.2224404 ,
         -0.01543862, -0.26361072,  0.10217989, -0.5372539 ,  0.2284727 ,
          0.61656344, -0.9127779 , -0.1701792 , -0.01294379, -0.20482953,
         -0.19118284, -0.5200493 , -0.30188486,  0.1637638 , -0.30561835,
          0.09131347, -0.0759472 ,  0.16333155, -0.38095048, -0.10012918,
          0.2368503 , -0.34774688, -0.10274654,  0.11005658,  0.05618269,
          0.04169287, -0.06281649,  0.33957514,  0.57285744, -0.08232248,
          0.20374869,  0.29558802, -0.76348466,  0.44540855, -0.22040068,
         -0.19969437, -0.28088045,  0.14401032,  0.00487827, -0.45642883,
          0.11923363, -0.06899499,  0.19890033, -0.5095794 , -0.12242357,
          0.01299737, -0.8300294 , -0.2693881 , -0.80380183], dtype=float32),
  'kernel': Arr

In [45]:
def ode_function(t, m, d, e):
    inputs = jnp.array([m[0], d, e, t]).reshape(1,-1)
    vt = predict(params, inputs)
    return vt[0]

In [46]:
def ode_function(t, m, d, e):
    inputs = jnp.array([m[0], d, e, t])
    return predict(params, inputs)[0]

def d_by_m_e(m, e):
    noise = np.random.normal(scale=1e-4, size=1).item()
    d = np.power(e, 2) * np.power(m, 3) + m * np.exp(-np.abs(0.2 - e)) + noise
    return d

errors, m_sol, d_err = [], [], []

for _ in tqdm(range(1000)):
    m0 = np.random.uniform(size=1).item()
    m = 0.2
    e = 0.1 
    d = d_by_m_e(m, e)

    solution = solve_ivp(ode_function, t_span=[0, 1], y0=[m0], t_eval=None, args=(d, e))
    m_sol.append(solution.y[0][-1])
    errors.append(np.abs(m_sol[-1] - m))
    d_err.append(np.abs(d_by_m_e(m_sol[-1], e) - d))


100%|██████████| 1000/1000 [00:59<00:00, 16.83it/s]


In [47]:
np.mean(errors)

0.022723877849682095

In [49]:
solution

  message: The solver successfully reached the end of the integration interval.
  success: True
   status: 0
        t: [ 0.000e+00  1.311e-01  5.594e-01  7.215e-01  8.836e-01
             9.601e-01  1.000e+00]
        y: [[ 2.788e-01  2.696e-01  2.406e-01  2.296e-01  2.189e-01
              2.160e-01  2.212e-01]]
      sol: None
 t_events: None
 y_events: None
     nfev: 56
     njev: 0
      nlu: 0