<a href="https://colab.research.google.com/github/profteachkids/CHE2064/blob/master/DAE_ChebyshevCollocation_vs_JacobianOfBroydenSolution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [59]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import config
from jax.config import config
config.update("jax_enable_x64", True)  # JAX default is 32bit single precision
from jax.experimental.host_callback import id_print
from scipy.integrate import solve_ivp
from plotly.subplots import make_subplots
import plotly.io as pio

from numpy.polynomial.chebyshev import chebval, chebder, chebfit, Chebyshev
from scipy.optimize import minimize

pio.templates.default = "plotly_dark"



In [2]:
def broyden(func, x, J=None, max_iter=100, verbose=0, xmax=jnp.inf, xmin=-jnp.inf):
    """Solve ``func(x) = 0`` with a damped Broyden (quasi-Newton) iteration.

    The inverse Jacobian is formed once from an exact Jacobian and afterwards
    maintained with rank-one Sherman-Morrison updates.  Every proposed step is
    first shortened so that it does not cross ``xmin``/``xmax`` and is then
    halved until the residual norm no longer increases.  When the step
    fraction falls to 0.01 or below, the inverse Jacobian is rebuilt from an
    exact Jacobian at the current point and the iteration is retried.

    Args:
        func: Residual function; its Jacobian must be square and invertible.
        x: Initial guess.
        J: Optional callable returning the Jacobian of ``func`` at a point.
            ``jax.jacobian(func)`` is used when omitted.
        max_iter: Maximum number of outer iterations (rejected steps count).
        verbose: 0 is silent, 1 prints every iterate, 2 additionally prints
            every line-search trial.
        xmax: Upper bound(s) on ``x``.
        xmin: Lower bound(s) on ``x``.

    Returns:
        Tuple ``(x, f)``: the last accepted iterate and its residual.
        Convergence is not guaranteed; the loop ends when the residual norm
        drops below 1e-12 or after ``max_iter`` iterations.
    """
    Jf = jax.jacobian(func) if J is None else J
    Jinv = jnp.linalg.inv(jnp.atleast_2d(Jf(x)))
    f = jnp.atleast_1d(func(x))

    for i in range(max_iter):

        dx = -Jinv @ f
        if verbose > 0:
            print(f"\nIter: {i}  dx: {dx}")
        # Largest fraction of the full step that keeps every component in bounds.
        alpha_max_limits = jnp.min(jnp.where(x + dx > xmax, (xmax - x) / (dx), 1))
        alpha_min_limits = jnp.min(jnp.where(x + dx < xmin, (xmin - x) / (dx), 1))
        alpha = min(alpha_max_limits, alpha_min_limits)

        # Backtracking: halve the step while it increases the residual norm.
        while alpha > 0.01:
            dx_try = alpha * dx
            xp = x + dx_try
            fp = jnp.atleast_1d(func(xp))
            dnorm = jnp.linalg.norm(fp) - jnp.linalg.norm(f)
            if verbose > 1:
                print(
                    f"Alpha {alpha}   dnorm {dnorm}  dx_try {dx_try}   f {f}    fp {fp}"
                )
            if dnorm > 0:
                alpha *= 0.5
            else:
                break
        if alpha <= 0.01:
            # No acceptable step along this direction: discard the quasi-Newton
            # approximation and restart from an exact Jacobian.  atleast_2d
            # keeps this working for a 0-d ``x`` (scalar problems).
            if verbose > 0:
                print("reevaluate J")
            Jinv = jnp.linalg.inv(jnp.atleast_2d(Jf(x)))
            continue

        # Secant residual y - B s for the accepted step s = alpha * dx, where
        # y = fp - f and B is the current Jacobian approximation.  Because
        # B dx = -f, this is fp - (1 - alpha) * f; it reduces to fp only for
        # an undamped step (alpha == 1).
        secant_residual = fp - (1.0 - alpha) * f

        dx = dx_try
        f = fp
        x = xp
        if verbose > 0:
            print(x, f)
        if jnp.linalg.norm(f) < 1e-12:
            break

        u = jnp.expand_dims(secant_residual, 1)
        v = jnp.expand_dims(dx, 1) / jnp.linalg.norm(dx) ** 2
        Jinv = Jinv - Jinv @ u @ v.T @ Jinv / (1 + v.T @ Jinv @ u)  # Sherman-Morrison
    return x, f

In [3]:
mu = 1.0  # coefficient of the (1 - x**2) * y term in the second residual

@jax.jit
def f(t, v, dv):
    """Residuals of the implicit system F(t, v, dv) = 0.

    Args:
        t: Independent variable; accepted for interface symmetry but not used
            by either equation.
        v: Pair ``(x, y)`` of state values.
        dv: Pair ``(dx, dy)`` of state derivatives.

    Returns:
        Length-2 array ``[eq1, eq2]``; both entries are zero when ``dv`` is
        consistent with ``v``.
    """
    state_x, state_y = v
    rate_x, rate_y = dv
    residuals = [
        rate_x - state_y + jnp.sin(rate_y),
        rate_y - mu * (1 - state_x ** 2) * state_y + state_x + jnp.cos(rate_x),
    ]
    return jnp.array(residuals)

def dv(t, v):
    """Return the derivative vector that satisfies ``f(t, v, dv) = 0``.

    The implicit equations are solved with ``broyden``, warm-started from the
    module-level ``dv_sol``.  The new solution is written back to ``dv_sol``
    so the next call starts from it.
    """
    global dv_sol
    residual_for_rate = lambda rate: f(t, v, rate)
    dv_sol, _ = broyden(residual_for_rate, dv_sol)
    return dv_sol

@jax.jit
def _implicit_rate_jacobian(t, v, rate):
    """Jacobian d(rate)/d(v) of the implicit solution of ``f(t, v, rate) = 0``."""
    # Implicit function theorem: d(rate)/dv = -(df/d rate)^{-1} (df/dv).
    # argnums 1 is the state ``v``; argnums 2 is the rate, whose Jacobian is
    # the matrix that has to be inverted.
    df_dv, df_drate = jax.jacobian(f, (1, 2))(t, v, rate)
    return -jnp.linalg.solve(df_drate, df_dv)


def Jdv(t, v):
    """Jacobian of the ODE right-hand side ``dv(t, v)`` with respect to ``v``.

    Args:
        t: Independent variable.
        v: State vector ``(x, y)``.

    Returns:
        2x2 array ``d(dv)/d(v)`` evaluated with the module-level ``dv_sol``
        (the most recent rate stored by ``dv``) as the rate at ``v``.
    """
    # ``dv_sol`` is handed to the jitted helper as an argument.  A global read
    # inside a jitted function is captured once at trace time and would stay
    # at its first value on every later call.
    return _implicit_rate_jacobian(t, jnp.asarray(v), dv_sol)

In [44]:
# Integration horizon, initial state, and the warm start that dv() reads and
# updates through the global dv_sol.
tend = 1.0
v0 = jnp.array([1.0, 1.0])
dv_sol = jnp.array([1.0, 1.0])

# Integrate with implicit Radau; dense_output=True makes res.sol available
# for evaluation at arbitrary times.
res = solve_ivp(
    fun=dv,
    t_span=(0, tend),
    y0=v0,
    method="Radau",
    dense_output=True,
    jac=Jdv,
)


In [52]:

# Evaluate the dense Radau solution on 100 evenly spaced times and save it;
# sol[0] is x(t) and sol[1] is y(t).
tplot = np.linspace(0, tend, num=100)
sol = res.sol(tplot)
np.savetxt(fname="RadauSol.csv", X=sol, delimiter=",")

In [46]:
# Plot both state components of the Radau solution; named traces and axis
# titles let the figure stand on its own.
fig = make_subplots()
fig.add_scatter(x=tplot, y=sol[0], name="x (Radau)")
fig.add_scatter(x=tplot, y=sol[1], name="y (Radau)")
fig.update_layout(
    width=800,
    height=400,
    title="Radau solution",
    xaxis_title="t",
    yaxis_title="state",
)

In [123]:
def f_cheb(c, t, x, y):
    """Integrated squared residual of the implicit ODE for Chebyshev trial functions.

    Args:
        c: Flat coefficient vector laid out as ``[x coefficients, y coefficients]``
            (two halves of equal length), matching ``np.concatenate([x.coef, y.coef])``.
        t: Quadrature nodes at which the residuals are evaluated.
        x: Chebyshev series for the first state; its ``coef`` is overwritten.
        y: Chebyshev series for the second state; its ``coef`` is overwritten.

    Returns:
        ``(tend / 2) * sum(w * (eq1**2 + eq2**2))``, using the module-level
        quadrature weights ``w``, horizon ``tend`` and parameter ``mu``.
    """
    # Copy so the polynomials never alias the optimizer's work array, and split
    # into the two halves (row 0 -> x, row 1 -> y).
    c = np.array(c, dtype=float).reshape(2, -1)
    # The coefficient attribute of a Chebyshev series is ``coef``; assigning to
    # any other attribute name would leave the polynomial unchanged.
    x.coef, y.coef = c[0], c[1]
    dx, dy = x.deriv(), y.deriv()
    xt, yt, dxt, dyt = x(t), y(t), dx(t), dy(t)
    eq1 = dxt - yt + np.sin(dyt)
    eq2 = dyt - mu * (1 - xt ** 2) * yt + xt + np.cos(dxt)
    return (tend / 2) * np.sum(w * (eq1 ** 2 + eq2 ** 2))

In [145]:
# Initial guess for the collocation polynomials: least-squares quadratic
# Chebyshev fits to the sampled Radau solution on [0, tend].
x = Chebyshev.fit(tplot, sol[0], deg=2, domain=(0, tend))
y = Chebyshev.fit(tplot, sol[1], deg=2, domain=(0, tend))

In [146]:
# Gauss-Legendre rule with N nodes on [-1, 1], mapped affinely onto [0, tend].
N = 20
t_leg, w = np.polynomial.legendre.leggauss(N)
t = tend / 2 * t_leg + tend / 2

# Starting point for the optimizer: x's coefficients followed by y's.
c0 = np.concatenate((x.coef, y.coef))

In [147]:
# Minimise the integrated squared residual over the stacked coefficients.
res = minimize(lambda c: f_cheb(c, t, x, y), c0)

# Store the optimiser's solution in the polynomials explicitly rather than
# relying on whatever the last objective evaluation left behind.  c0 is
# [x.coef, y.coef] concatenated, so res.x splits back at len(x.coef).
x.coef, y.coef = np.split(res.x, [len(x.coef)])

In [148]:
# Plot the Chebyshev series x and y on the same time grid as the Radau
# samples; named traces and axis titles let the figure stand on its own.
fig = make_subplots()
fig.add_scatter(x=tplot, y=x(tplot), name="x (Chebyshev)")
fig.add_scatter(x=tplot, y=y(tplot), name="y (Chebyshev)")
fig.update_layout(
    width=800,
    height=400,
    title="Chebyshev collocation solution",
    xaxis_title="t",
    yaxis_title="state",
)

In [132]:
# Display the Chebyshev coefficients currently stored in x.
x.coef

array([ 1.538952  ,  0.40710712, -0.11236267,  0.02579797])

In [61]:
# Evaluate the series x at t = 0.1.
x(0.1)

-1.74

In [62]:
# Display the Chebyshev coefficients currently stored in x.
x.coef

array([1., 2., 3.])

In [63]:
# Scratch experiment: replace x's coefficients in place to check that the
# series picks up the new values.
# NOTE(review): this discards whatever coefficients x held before (e.g. a
# fitted solution) and stores a plain list rather than an ndarray.
x.coef=[3.,4.,1.]

In [64]:
# Evaluate x at t = 0.1 again; the value reflects the coefficients assigned above.
x(0.1)

2.42

In [69]:
# Chebyshev needs a coefficient array; calling it with no arguments raises
# TypeError.  Catch and show the error so the notebook still runs top to
# bottom and x keeps its previous value.
try:
    x = Chebyshev()
except TypeError as err:
    print(f"TypeError: {err}")

TypeError: ignored