In [1]:
import jax
from jax import numpy as jnp
from matplotlib import pyplot as plt
from jax.config import config

config.update("jax_enable_x64", True)


In [2]:
def grad_hat(x_hat, basis, nx, ny):
    x_tilde = jnp.matmul(basis, x_hat)
    x_tilde_grad = jnp.gradient(x_tilde.reshape(nx, ny))
    return jnp.stack(
        (
            jnp.matmul(x_tilde_grad[0].flatten(), basis),
            jnp.matmul(x_tilde_grad[1].flatten(), basis),
        )
    ).T


In [3]:
data1 = jnp.load("mu_0.9.npy")
data2 = jnp.load("mu_0.95.npy")
data3 = jnp.load("mu_1.05.npy")
data4 = jnp.load("mu_1.1.npy")

x_ref = jnp.mean(jnp.concatenate([data1, data2, data3, data4], axis=0), axis=0)

S, N = data1.shape
end_time = 2
dt = 2 / (S - 1)
nx = 60
ny = 60

assert nx * ny == N

data1_dot = (data1[1:] - data1[:-1]) / dt
data1 = data1[1:]

data2_dot = (data2[1:] - data2[:-1]) / dt
data2 = data2[1:]

data3_dot = (data3[1:] - data3[:-1]) / dt
data3 = data3[1:]

data4_dot = (data4[1:] - data4[:-1]) / dt
data4 = data4[1:]

X = jnp.concatenate([data1, data2, data3, data4], axis=0) - x_ref
X_dot = jnp.concatenate([data1_dot, data2_dot, data3_dot, data4_dot], axis=0)

S, N = X.shape

X_test = jnp.load("mu_1.0.npy")


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
# number of basis
n = 20

D, V = jnp.linalg.eigh(jnp.matmul(X.T, X))
idx = jnp.argsort(D)[::-1]
D = D[idx]
V = V[:, idx]
basis = V[:, :n]


In [5]:
X_hat = jnp.matmul(X, basis)
X_dot_hat = jnp.matmul(X_dot, basis)

X_test_hat = jnp.matmul(X_test - x_ref, basis)


In [6]:
D = jnp.zeros((S, n + n * n + 2 * n * n))
for i in range(S):
    D = D.at[i, 0:n].set(X_hat[i])
    D = D.at[i, n : n + n * n].set(jnp.outer(X_hat[i], X_hat[i]).flatten())
    D = D.at[i, n + n * n :].set(
        jnp.outer(X_hat[i], grad_hat(X_hat[i], basis, nx, ny)).flatten()
    )


In [7]:
res = jnp.linalg.lstsq(D, X_dot_hat, rcond=1e-11)
AHG = res[0]
A = AHG[0:n, :].T
H = AHG[n : n + n * n, :].T
G = AHG[n + n * n :, :].T


In [8]:
def nr_solve(f, df, x, atol=1e-10, rtol=1e-8, max_itr=50, args=None):
    r = f(x, *args)
    nr = jnp.linalg.norm(r)
    nr0 = nr
    print("Itr = {:}, residual norm = {:.4E}".format(0, nr))
    if (nr < atol):
        return x
    for i in range(max_itr):
        p = jnp.linalg.solve(df(x, *args), r)
        x -= p
        r = f(x, *args)
        nr = jnp.linalg.norm(r)
        print("Itr = {:}, residual norm = {:.4E}".format(i + 1, nr))
        if (nr < atol or nr < rtol * nr0):
            return x
        
    raise Exception("solve failed")
        

In [9]:
def f(x, t):
    x_dot = 0
    x_dot += jnp.matmul(A, x)
    x_dot += jnp.matmul(H, jnp.outer(x, x).flatten())
    grad_x = grad_hat(x, basis, nx, ny)
    x_dot += jnp.matmul(G, jnp.outer(x, grad_x).flatten())
    return x_dot

df = jax.jit(jax.jacrev(f))

def residual(x, t, x_old, dt):
    return x - x_old - dt * f(x, t)


jacobian = jax.jit(jax.jacrev(residual))


In [10]:
x_hat = X_test_hat[0]
t = 0
sol = {t: x_hat}
dt = 2 / S
dt_min = dt / 100
while t < 2:
    print("Time = {:}, dt = {:}".format(t + dt, dt))
    try:
        x_hat = nr_solve(residual, jacobian, x_hat, args=(t, x_hat, dt), max_itr=20)
    except:
        dt = dt / 2
        if dt < dt_min:
            raise Exception("dt is too small")
        continue
    t = t + dt
    sol[t] = x_hat

Time = 0.0003333333333333333, dt = 0.0003333333333333333
Itr = 0, residual norm = 1.7807E-02
Itr = 1, residual norm = 1.7683E-06
Itr = 2, residual norm = 9.5176E-13
Time = 0.0006666666666666666, dt = 0.0003333333333333333
Itr = 0, residual norm = 1.7817E-02
Itr = 1, residual norm = 1.7587E-06
Itr = 2, residual norm = 8.7690E-13
Time = 0.001, dt = 0.0003333333333333333
Itr = 0, residual norm = 1.7827E-02
Itr = 1, residual norm = 1.7644E-06
Itr = 2, residual norm = 1.0761E-12
Time = 0.0013333333333333333, dt = 0.0003333333333333333
Itr = 0, residual norm = 1.7837E-02
Itr = 1, residual norm = 1.7489E-06
Itr = 2, residual norm = 9.5033E-13
Time = 0.0016666666666666666, dt = 0.0003333333333333333
Itr = 0, residual norm = 1.7846E-02
Itr = 1, residual norm = 1.7411E-06
Itr = 2, residual norm = 1.1173E-12
Time = 0.002, dt = 0.0003333333333333333
Itr = 0, residual norm = 1.7856E-02
Itr = 1, residual norm = 1.7342E-06
Itr = 2, residual norm = 9.7471E-13
Time = 0.0023333333333333335, dt = 0.00033

Exception: dt is too small

In [12]:
t

0.045468749999999974