In [2]:
import jax
from jax import numpy as jnp
from matplotlib import pyplot as plt
from jax.config import config

config.update("jax_enable_x64", True)


In [3]:
def grad_hat(x_hat, basis, nx, ny):
    """Project the grid gradient of a reconstructed field back onto the basis.

    The reduced vector ``x_hat`` is lifted with ``basis``, reshaped onto an
    ``nx`` by ``ny`` grid, differentiated along both grid axes with
    ``jnp.gradient`` (default unit spacing), and each derivative field is
    flattened and multiplied with ``basis`` again.

    Args:
        x_hat: Reduced coordinates; one entry per column of ``basis``.
        basis: Matrix with ``nx * ny`` rows and one column per basis vector.
        nx: Grid size along the first axis.
        ny: Grid size along the second axis.

    Returns:
        Array with one row per basis vector and two columns: the projected
        derivative along axis 0 and along axis 1.
    """
    field = jnp.matmul(basis, x_hat).reshape(nx, ny)
    d_axis0, d_axis1 = jnp.gradient(field)
    projected = [
        jnp.matmul(derivative.flatten(), basis)
        for derivative in (d_axis0, d_axis1)
    ]
    return jnp.stack(projected).T


In [4]:
# Snapshot matrix: the first axis is differenced in time below, the second
# axis holds the flattened nx-by-ny grid.
data = jnp.load("mu_0.9.npy")
S, N = data.shape

# Grid dimensions; their product must match the snapshot width.
nx, ny = 60, 60
assert nx * ny == N

# Final time of the data set and the uniform spacing between snapshots.
end_time = 2
dt = end_time / (S - 1)

# First-order differences between consecutive snapshots, paired with the
# later snapshot of each pair; this leaves one sample fewer than was loaded.
data_dot = (data[1:] - data[:-1]) / dt
data = data[1:]
S = S - 1

X, X_dot = data, data_dot


No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [5]:
# Number of basis vectors to keep.
n = 20

# Eigen-decomposition of the symmetric matrix X^T X.
D, V = jnp.linalg.eigh(X.T @ X)

# Sort eigenpairs by decreasing eigenvalue and keep the n leading eigenvectors.
idx = jnp.argsort(D)[::-1]
D, V = D[idx], V[:, idx]
basis = V[:, :n]


In [6]:
# Reduced coordinates of the snapshots and of their time derivatives.
X_hat = X @ basis
X_dot_hat = X_dot @ basis


In [7]:
def _library_row(x):
    """Return the feature row [x, x (outer) x, x (outer) grad_hat(x)] for one reduced state."""
    return jnp.concatenate(
        (
            x,
            jnp.outer(x, x).flatten(),
            jnp.outer(x, grad_hat(x, basis, nx, ny)).flatten(),
        )
    )


# Build each row once and stack at the end. Outside of jit every
# ``D.at[...].set(...)`` may return a fresh copy of the whole matrix, so
# filling it row by row copies the full array several times per sample.
# ``astype(float)`` keeps the default float dtype that ``jnp.zeros`` used.
D = jnp.stack([_library_row(X_hat[i]) for i in range(S)]).astype(float)

# Linear, quadratic and gradient-coupled feature blocks, in that column order.
assert D.shape == (S, n + n * n + 2 * n * n)


In [8]:
# Least-squares fit of the stacked operator mapping feature rows of D to the
# reduced time derivatives; singular values below rcond are cut off.
res = jnp.linalg.lstsq(D, X_dot_hat, rcond=1e-11)
AHG = res[0]

# Split the stacked solution by feature block (linear, quadratic,
# gradient-coupled) and transpose so each operator acts on a feature vector.
A, H, G = (
    AHG[rows].T
    for rows in (slice(0, n), slice(n, n + n * n), slice(n + n * n, None))
)


In [9]:
def nr_solve(f, df, x, atol=1e-10, rtol=1e-8, max_itr=50, args=None):
    """Solve ``f(x, *args) = 0`` with Newton-Raphson iterations.

    Args:
        f: Residual function called as ``f(x, *args)``.
        df: Jacobian of ``f`` with respect to ``x``, called as ``df(x, *args)``.
        x: Initial guess. It is converted to a JAX array and never modified
            in place, so the caller's array (which may also be passed through
            ``args``) is left untouched.
        atol: Absolute tolerance on the residual norm.
        rtol: Tolerance on the residual norm relative to the initial norm.
        max_itr: Maximum number of Newton updates.
        args: Optional extra positional arguments forwarded to ``f`` and
            ``df``. ``None`` means no extra arguments.

    Returns:
        The iterate at which the residual norm met ``atol`` or ``rtol``.

    Raises:
        Exception: If no iterate satisfies the tolerances within ``max_itr``
            updates.
    """
    # The default ``args=None`` cannot be star-unpacked; treat it as "no extras".
    extra = () if args is None else tuple(args)
    # Work on a JAX array so the update below always rebinds instead of
    # writing into a caller-owned NumPy buffer.
    x = jnp.asarray(x)

    r = f(x, *extra)
    nr = jnp.linalg.norm(r)
    nr0 = nr
    print("Itr = {:}, residual norm = {:.4E}".format(0, nr))
    if nr < atol:
        return x

    for i in range(max_itr):
        p = jnp.linalg.solve(df(x, *extra), r)
        x = x - p
        r = f(x, *extra)
        nr = jnp.linalg.norm(r)
        print("Itr = {:}, residual norm = {:.4E}".format(i + 1, nr))
        if nr < atol or nr < rtol * nr0:
            return x

    raise Exception("solve failed")
        

In [10]:
def f(x, t):
    """Evaluate the fitted reduced right-hand side at state ``x``.

    The result is the sum of a linear term ``A x``, a quadratic term ``H``
    applied to the flattened outer product of ``x`` with itself, and a term
    ``G`` applied to the flattened outer product of ``x`` with
    ``grad_hat(x, basis, nx, ny)``.

    Args:
        x: Reduced state vector.
        t: Time; accepted for the solver's call signature but not used here.

    Returns:
        The modelled time derivative of ``x``.
    """
    quadratic_features = jnp.outer(x, x).flatten()
    gradient_features = jnp.outer(x, grad_hat(x, basis, nx, ny)).flatten()
    contributions = (
        jnp.matmul(A, x),
        jnp.matmul(H, quadratic_features),
        jnp.matmul(G, gradient_features),
    )
    # ``sum`` starts from 0 and adds left to right, matching a running total.
    return sum(contributions)

# Jacobian of f with respect to its first argument (the state), jit-compiled.
df = jax.jit(jax.jacrev(f, argnums=0))


def residual(x, t, x_old, dt):
    """Residual of one implicit step: ``x - x_old - dt * f(x, t)``.

    Args:
        x: Candidate state at the end of the step.
        t: Time forwarded to ``f``.
        x_old: State at the start of the step.
        dt: Step size.

    Returns:
        The residual vector; it is zero when ``x`` solves the implicit step.
    """
    increment = x - x_old
    return increment - dt * f(x, t)


# Jacobian of the step residual with respect to x, jit-compiled for the Newton solve.
jacobian = jax.jit(jax.jacrev(residual, argnums=0))


In [11]:
# Implicit time integration of the reduced model: each step solves
# residual(x, t, x_old, dt) = 0 with Newton's method and halves dt whenever
# the solve fails.
# NOTE(review): X_hat[0] is the projection of data[1:][0], i.e. the second
# raw snapshot, yet it is stored under t = 0 - confirm the intended offset.
x_hat = X_hat[0]
t = 0
sol = {t: x_hat}
dt = end_time / S
dt_min = dt / 100
# Compare against end_time with a small relative tolerance so round-off
# accumulated in t cannot trigger a spurious extra step.
while end_time - t > 1e-12 * end_time:
    # Never step past the end of the simulation window.
    step = min(dt, end_time - t)
    print("Time = {:}, dt = {:}".format(t + step, step))
    try:
        # The residual is evaluated at the end-of-step time t + step.
        x_hat = nr_solve(
            residual, jacobian, x_hat, args=(t + step, x_hat, step), max_itr=20
        )
    except Exception as err:
        # Only solver failures are retried; KeyboardInterrupt and SystemExit
        # propagate instead of being mistaken for a failed step.
        dt = dt / 2
        if dt < dt_min:
            raise Exception("dt is too small") from err
        continue
    t = t + step
    sol[t] = x_hat

Time = 0.0013333333333333333, dt = 0.0013333333333333333
Itr = 0, residual norm = 5.7980E-02
Itr = 1, residual norm = 1.1782E-02
Itr = 2, residual norm = 1.3120E-06
Itr = 3, residual norm = 3.2176E-13
Time = 0.0026666666666666666, dt = 0.0013333333333333333
Itr = 0, residual norm = 5.8076E-02
Itr = 1, residual norm = 1.1922E-02
Itr = 2, residual norm = 1.3268E-06
Itr = 3, residual norm = 3.5371E-13
Time = 0.004, dt = 0.0013333333333333333
Itr = 0, residual norm = 5.8165E-02
Itr = 1, residual norm = 1.2062E-02
Itr = 2, residual norm = 1.3416E-06
Itr = 3, residual norm = 2.9907E-13
Time = 0.005333333333333333, dt = 0.0013333333333333333
Itr = 0, residual norm = 5.8246E-02
Itr = 1, residual norm = 1.2201E-02
Itr = 2, residual norm = 1.3564E-06
Itr = 3, residual norm = 3.6145E-13
Time = 0.006666666666666666, dt = 0.0013333333333333333
Itr = 0, residual norm = 5.8319E-02
Itr = 1, residual norm = 1.2340E-02
Itr = 2, residual norm = 1.3712E-06
Itr = 3, residual norm = 2.9770E-13
Time = 0.008,

Exception: dt is too small

In [12]:
# Display the value of `t` left behind by the time-stepping loop.
t

0.03675000000000001