In [None]:
# Open figures in separate interactive Qt windows.
# NOTE(review): this backend needs a Qt binding and a display; on a headless
# kernel (CI, remote server) switch to `%matplotlib inline` or `%matplotlib widget`.
%matplotlib qt
# Reload edited project modules (e.g. ICARUS, test_integrators) before each
# cell executes, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp
from time import time

# Enable 64-bit (double-precision) dtypes in JAX via the "jax_enable_x64" flag.
# It is set here, before any jnp array is created, so the initial states and
# system matrices built in later cells are 64-bit rather than JAX's 32-bit default.
import jax

jax.config.update("jax_enable_x64", True)

In [None]:
from ICARUS.dynamical_systems import NonLinearSystem
from ICARUS.dynamical_systems import SecondOrderSystem
from ICARUS.dynamical_systems.integrate import (
    BackwardEulerIntegrator,
    ForwardEulerIntegrator,
    RK4Integrator,
    RK45Integrator,
    CrankNicolsonIntegrator,
    GaussLegendreIntegrator,
    NewmarkIntegrator,
)

In [None]:
from test_integrators import test_all_integrators 

# Simple Mass-Spring-Damper System

In [None]:
# Define a simple mass-spring-damper (m-c-k) system:  m x'' + c x' + k x = 0
m, c, k = 1.0, 0.1, 1.0


def f(t: float, x: jnp.ndarray) -> jnp.ndarray:
    """First-order right-hand side of the free m-c-k oscillator.

    The state is ``x = [position, velocity]``. ``t`` is unused because the
    system is autonomous. The module-level ``m``, ``c`` and ``k`` are read at
    call time.

    Returns ``[velocity, acceleration]`` with
    ``acceleration = -c/m * velocity - k/m * position``.
    """
    position = x[0]
    velocity = x[1]
    acceleration = -c / m * velocity - k / m * position
    return jnp.array([velocity, acceleration])


# Wrap the first-order right-hand side f(t, x) in an ICARUS system object.
system = NonLinearSystem(f)

# Run every integrator on this system starting from position 1, velocity 0.
# The three positional numbers are presumably (t0, t_end, dt) = (0, 100, 1e-4);
# confirm against test_integrators.test_all_integrators.
# NOTE(review): if dt is the step size, this is ~1e6 steps per integrator,
# which makes the cell slow under Restart & Run All.
x_data, t_data = test_all_integrators(
    system,
    jnp.array([1.0, 0.0]),
    0.0,
    100.0,
    1e-4,
    compare_with_scipy=True,
)



# Higher-Order Systems

## Second-Order Systems

In [None]:
# Define a 2-DOF second-order system from four callables of (t, x):
# mass M, damping C, internal term f_int and external force f_ext.
# NOTE(review): how SecondOrderSystem combines these is not visible here;
# f_int returns a constant 2x2 matrix, presumably a stiffness matrix.

# Diagonal per-DOF parameters. Only m1 and m2 are used below. C and f_int
# return the fixed coupled matrices C_MATRIX and K_MATRIX, not matrices built
# from c1, c2, k1, k2. Those names are kept so existing references still work,
# but changing them has no effect on the simulation.
m1 = 1.0
c1 = 0.1
k1 = 1.0

m2 = 1.0
c2 = 0.1
k2 = 1.0

# Fixed coupled damping and stiffness matrices actually used by this example.
# Note that C_MATRIX is not symmetric.
C_MATRIX = jnp.array(
    [
        [0.023, 1.024],
        [-0.364, 3.31],
    ]
)
K_MATRIX = jnp.array(
    [
        [1.97, 0.034],
        [0.034, 3.95],
    ]
)

# External load: a constant on DOF 0 and a harmonic sin(t) term on DOF 1.
F_EXT_CONSTANT = 0.078
F_EXT_AMPLITUDE = 10 * 0.466


def M(t: float, x: jnp.ndarray) -> jnp.ndarray:
    """Mass matrix ``diag(m1, m2)``; independent of ``t`` and ``x``.

    Reads the module-level ``m1`` and ``m2`` at call time.
    """
    return jnp.array([[m1, 0.0], [0.0, m2]])


def C(t: float, x: jnp.ndarray) -> jnp.ndarray:
    """Damping matrix: the constant ``C_MATRIX`` for every ``t`` and ``x``."""
    return C_MATRIX


def f_int(t: float, x: jnp.ndarray) -> jnp.ndarray:
    """Internal term: the constant 2x2 ``K_MATRIX`` for every ``t`` and ``x``.

    NOTE(review): despite the ``f_`` name this returns a matrix, not a force
    vector; confirm how SecondOrderSystem consumes it.
    """
    return K_MATRIX


def f_ext(t: float, x: jnp.ndarray) -> jnp.ndarray:
    """External force vector ``[F_EXT_CONSTANT, F_EXT_AMPLITUDE * sin(t)]``.

    Independent of the state ``x``.
    """
    return jnp.array(
        [
            F_EXT_CONSTANT,
            F_EXT_AMPLITUDE * jnp.sin(t),
        ]
    )


# Assemble the second-order system from its mass, damping, internal and
# external-force callables (all functions of (t, x)).
system = SecondOrderSystem(M, C, f_int, f_ext)

In [None]:
# Test the integrators on the 2-DOF system, starting from an all-zero state.
# test_all_integrators returns two items (see the first experiment). Unpack
# them instead of leaving the raw return value as this cell's displayed output.
# Distinct names keep the first experiment's x_data / t_data intact.
# NOTE(review): with dt = 1e-4 over [0, 100] this cell is slow.
x_data_2dof, t_data_2dof = test_all_integrators(
    system, jnp.array([0.0, 0.0, 0.0, 0.0]), 0.0, 100.0, 1e-4, compare_with_scipy=True
)