In [1]:
import os
import time
import numpy as np
from tqdm import tqdm
import jax
import jax.numpy as jnp
from jax import random
import optax
from flax import linen as nn
from flax.training import train_state

import pickle
from scipy.integrate import solve_ivp


savedir = "models/simple_64_another"

In [103]:
class MLP(nn.Module):
    dim: int
    out_dim: int = 1
    w: int = 128

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.w)(x)
        x = nn.selu(x)
        x = nn.Dense(self.w)(x)
        x = nn.selu(x)
        x = nn.Dense(self.w)(x)
        x = nn.selu(x)
        x = nn.Dense(self.out_dim)(x)
        return x
model = MLP(dim=4)

In [104]:
@jax.jit
def predict(params, inputs):
    return model.apply({"params": params}, inputs)

In [105]:
with open(f'w.pkl', 'rb') as f:
    params = pickle.load(f)

In [106]:
params

{'Dense_0': {'bias': Array([ 1.03462823e-02, -4.36216630e-02,  6.35367353e-04,  2.83446703e-02,
          4.26050136e-03, -1.68262944e-02, -1.11784860e-02, -5.94135709e-02,
         -9.15442780e-03, -2.19262931e-02,  2.50141341e-02, -4.06255797e-02,
         -5.48381619e-02,  6.49208054e-02,  4.87201214e-02,  3.56901102e-02,
          1.99148413e-02,  5.65144457e-02,  2.78431680e-02, -1.51757544e-04,
         -4.25914228e-02,  3.31006721e-02, -3.07021774e-02, -9.85401049e-02,
          3.10358638e-03,  1.67842526e-02, -1.47928605e-02,  7.61809899e-03,
         -4.08222713e-02,  3.09791733e-02, -4.41862792e-02,  8.83590356e-02,
         -1.66929718e-02,  9.50063467e-02,  6.15200698e-02, -5.06070629e-02,
         -3.85314487e-02,  3.53376009e-02,  7.21366256e-02,  1.10821195e-01,
         -9.42450985e-02,  9.65999886e-02, -3.44303735e-02,  2.44618729e-01,
         -6.41742125e-02,  2.68931519e-02,  3.06071136e-02, -5.11502549e-02,
          4.83847819e-02, -9.29455906e-02,  5.80329970e-0

In [113]:
def ode_function(t, m, d, e):
    inputs = jnp.array([m[0], d, e, t]).reshape(1,-1)
    vt = predict(params, inputs)
    return vt[0]

In [114]:
def d_by_m_e(m, e):
    noise = np.random.normal(scale=1e-4, size=1).item()
    d = np.power(e, 2) * np.power(m, 3) + m * np.exp(-np.abs(0.2 - e)) + noise
    return d

errors, m_sol, d_err = [], [], []

for _ in tqdm(range(10)):
    m0 = np.random.uniform(size=1).item()
    m = 0.2
    e = 0.1 
    d = d_by_m_e(m, e)

    solution = solve_ivp(ode_function, t_span=[0, 1], y0=[m0], t_eval=None, args=(d, e))
    m_sol.append(solution.y[0][-1])
    errors.append(np.abs(m_sol[-1] - m))
    d_err.append(np.abs(d_by_m_e(m_sol[-1], e) - d))

100%|██████████| 10/10 [00:00<00:00, 20.53it/s]


In [116]:
errors

[0.32066546050052086,
 0.22113395795466112,
 0.3402325772666651,
 0.4045555562910952,
 0.015011115604092906,
 0.2739414747709257,
 0.3628347536266723,
 0.1822880436511642,
 0.2744153463652132,
 0.37982655195275744]

In [109]:
def ode_function(t, m, d, e):
    inputs = jnp.array([m[0], d, e, t])
    print(inputs.shape)
    return predict(params, inputs)

In [110]:
m0 = 0.51
m = 0.2
e = 0.1 
noise = np.random.normal(scale=1e-4, size=1)
d = np.power(e, 2) * np.power(m, 3) + m * np.exp(-np.abs(0.2 - e)) + noise
d = d.item()

solution = solve_ivp(ode_function, t_span=[0, 1], y0=[m0], t_eval=None, args=(d, e))

(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)
(4,)


In [111]:
solution

  message: The solver successfully reached the end of the integration interval.
  success: True
   status: 0
        t: [ 0.000e+00  8.423e-02  4.735e-01  1.000e+00]
        y: [[ 5.100e-01  5.282e-01  5.097e-01  2.337e-01]]
      sol: None
 t_events: None
 y_events: None
     nfev: 20
     njev: 0
      nlu: 0

In [117]:
@jax.jit
def sample_conditional_pt(x0, x1, t, sigma):
    t = t.reshape(-1, *([1] * (x0.ndim - 1)))
    mu_t = t * x1 + (1 - t) * x0
    epsilon = jax.random.normal(jax.random.PRNGKey(42), x0.shape)
    return mu_t + sigma * epsilon

In [120]:
x0 = jnp.array([0,1,2,3,4,5,6])
x1 = jnp.array([0,1,2,3,4,5,6])
t = jnp.array([0,1,2,3,4,5,6])

In [125]:
sample_conditional_pt(x0, x1, t, 0.01)

Array([4.8962492e-03, 1.0025226e+00, 1.9937261e+00, 2.9972436e+00,
       4.0071573e+00, 4.9943151e+00, 5.9825382e+00], dtype=float32)