# Pitching Airfoil 

In [140]:
import jax.numpy as jnp
import numpy as np
import jax
import matplotlib.pyplot as plt
from jax import grad, jit, vmap
from typing import Callable
from functools import partial
from jax import lax


# Dynamical Systems

In [141]:
class DynamicalSystem:
    # This class is the base class for all dynamical elements
    def __init__(self) -> None:
        raise NotImplementedError
    
    def jacobian(self, x: jnp.ndarray, t: float) -> jnp.ndarray:
        raise NotImplementedError

    def __call__(self, x: jnp.ndarray, t: float) -> jnp.ndarray:
        raise NotImplementedError

class LinearSystem(DynamicalSystem):
    def __init__(self, A: jnp.ndarray, B: jnp.ndarray) -> None:
        self.A = A
        self.B = B

    def jacobian(self, x: jnp.ndarray, t: float) -> jnp.ndarray:
        return self.A

    @partial(jit, static_argnums=(0,))
    def __call__(self, x: jnp.ndarray, t: float) -> jnp.ndarray:
        return self.A @ x + self.B

class NonlinearSystem(DynamicalSystem):
    def __init__(self, f: Callable[[jnp.ndarray,float], jnp.ndarray]) -> None:
        self.f = jit(f)
    
    @partial(jit, static_argnums=(0,))
    def jacobian(self, x: jnp.ndarray, t: float) -> jnp.ndarray:
        return jax.jacobian(self.f, argnums=0)(x, t)

    @partial(jit, static_argnums=(0,))
    def linearize(self, x: jnp.ndarray, t: float) -> tuple[jnp.ndarray, jnp.ndarray]:
        A = self.jacobian(x, t)
        B = self.f(x, t) - A @ x
        return A, B
    
    @partial(jit, static_argnums=(0,))
    def __call__(self, x: jnp.ndarray, t: float) -> jnp.ndarray:
        # Linearize the system first and then use the linearized system to get the output
        linearized: LinearSystem = LinearSystem(*self.linearize(x, t))
        return linearized(x, t)

class SecondOrderSystem(DynamicalSystem):
    """
    This class is for second order systems of the form:
    M(u,t) u'' + C(u,t) u' + f_int(u,t) = f_ext(u,t)
    """

    def __init__(
        self, 
        M: Callable[[jnp.ndarray, float], jnp.ndarray]  | jnp.ndarray, 
        C: Callable[[jnp.ndarray, float], jnp.ndarray]  | jnp.ndarray, 
        f_int: Callable[[jnp.ndarray, float], jnp.ndarray] | jnp.ndarray, 
        f_ext: Callable[[jnp.ndarray, float], jnp.ndarray]
    ) -> None:
        if isinstance(M, jnp.ndarray):
            self.M = jit(lambda x, t: M@x[:x.shape[0]//2])
        else:
            self.M = jit(M)

        if isinstance(C, jnp.ndarray):
            self.C = jit(lambda x, t: C@x[x.shape[0]//2:])
        else:
            self.C = jit(C)

        if isinstance(f_int, jnp.ndarray):
            self.f_int = jit(lambda x, t: f_int * x[:x.shape[0]//2])
        else:
            self.f_int = jit(f_int)


        self.f_ext = jit(f_ext)

    def jacobian(self, x: jnp.ndarray, t: float) -> jnp.ndarray:
        # Convert system to a first order system of ODEs by defining the state as [u, u']
        # Then the jacobian is given by:
        # J = [[0, I], [-M^-1 @ f_int, -M^-1 @ C]]
        n = x.shape[0]
        I = jnp.eye(n)
        M = self.M(x, t)
        C = self.C(x, t)
        f_int = self.f_int(x, t)
        J = jnp.block([[jnp.zeros((n,n)), I], [-jnp.linalg.solve(M, f_int), -jnp.linalg.solve(M, C)]])
        return J
    
    @partial(jit, static_argnums=(0,))
    def __call__(self, x: jnp.ndarray, t: float) -> jnp.ndarray:
        # Convert system to a first order system of ODEs by defining the state as [u, u']
        # Then the output is given by:
        # [u', -M^-1 @ (f_int - f_ext) - C @ u]
        n = x.shape[0]
        u = x[:n//2]
        v = x[n//2:]
        M = self.M(u, t)
        C = self.C(u, t)
        f_int = self.f_int(u, t)
        f_ext = self.f_ext(u, t)
        return jnp.concatenate([v, -jnp.linalg.solve(M, f_int - f_ext) - jnp.dot(C, v)])
    
    @partial(jit, static_argnums=(0,))
    def linearize(self, x: jnp.ndarray, t: float) -> tuple[jnp.ndarray, jnp.ndarray]:
        # Linearize the system first and then use the linearized system to get the output
        n = x.shape[0]
        u = x[:n//2]
        v = x[n//2:]
        M = self.M(u, t)
        C = self.C(u, t)
        f_int = self.f_int(u, t)
        f_ext = self.f_ext(u, t)
        A = jnp.block([[jnp.zeros((n//2,n//2)), jnp.eye(n//2)], [-jnp.linalg.solve(M, f_int - f_ext), -jnp.dot(jnp.linalg.solve(M, C), jnp.eye(n//2))]])
        B = jnp.block([[jnp.zeros((n//2,n//2)), jnp.zeros((n//2,n//2))], [jnp.zeros((n//2,n//2)), -jnp.linalg.solve(M, jnp.eye(n//2))]])
        return A, B

# Integrators

In [146]:
class Integrator:
    def __init__(self, dt: float, system: DynamicalSystem) -> None:
        self.name = None
        self.type = None
        raise NotImplementedError

    def step(self, x: jnp.ndarray, t: float) -> jnp.ndarray:
        raise NotImplementedError
    
    def simulate(self, x0: jnp.ndarray, t0: float, tf: float) -> tuple[jnp.ndarray, jnp.ndarray]:
        raise NotImplementedError
    
class ForwardEulerIntegrator(Integrator):
    def __init__(self, dt: float, system: DynamicalSystem) -> None:
        self.name = "Forward Euler"
        self.type = "Explicit"
        self.dt = dt
        self.system = system

    @partial(jit, static_argnums=(0,))
    def step(self, x: jnp.ndarray, t: float) -> jnp.ndarray:
        return x + self.dt * self.system(x, t)
    
    def simulate(self, x0: jnp.ndarray, t0: float, tf: float) -> tuple[jnp.ndarray, jnp.ndarray]:
        num_steps = jnp.ceil((tf - t0) / self.dt).astype(int)
        x = x0
        t = t0
        # Preallocate the trajectory array
        trajectory = jnp.zeros((num_steps + 1, x0.shape[0]))
        trajectory = trajectory.at[0].set(x0)
        times = jnp.linspace(start= t0, stop = tf, num = num_steps + 1)

        times, trajectory = self._simulate(trajectory, times, num_steps)
        return times, trajectory

    @partial(jit, static_argnums=(0,))
    def _simulate(self, trajectory: jnp.ndarray, times: jnp.ndarray, num_steps: int) -> tuple[jnp.ndarray, jnp.ndarray]:
        # Create a loop using lax.fori_loop that integrates the system using the forward Euler method and
        # stores the results in the trajectory array
        def body(i, trajectory):
            x = trajectory[i - 1]
            t = times[i - 1]
            trajectory = trajectory.at[i].set(self.step(x, t + i * self.dt))
            return trajectory
        
        times, trajectory = lax.fori_loop(1, num_steps + 1, body, (times, trajectory))
        return times, trajectory

class BackwardEulerIntegrator(Integrator):
    def __init__(self, dt: float, system: DynamicalSystem) -> None:
        self.name = "Backward Euler"
        self.type = "Implicit"
        self.dt: float = dt
        self.system = system

    @partial(jit, static_argnums=(0,))
    def step(self, x: jnp.ndarray, t: float) -> jnp.ndarray:
        # Implement the backward Euler method
        return jnp.linalg.solve(jnp.eye(x.shape[0]) - self.dt * self.system.jacobian(x, t), x + self.dt * self.system(x, t))
    
    def simulate(self, x0: jnp.ndarray, t0: float, tf: float) -> tuple[jnp.ndarray, jnp.ndarray]:
        num_steps = jnp.ceil((tf - t0) / self.dt).astype(int)
        x = x0
        t = t0
        trajectory = jnp.zeros((num_steps + 1, x0.shape[0]))
        trajectory = trajectory.at[0].set(x0)
        trajectory = self._simulate(trajectory, x, t, num_steps)
        times = np.linspace(start= t0, stop = tf, num = num_steps + 1)
        return jnp.array(times), trajectory
    
    @partial(jit, static_argnums=(0,))
    def _simulate(self, trajectory: jnp.ndarray, x: jnp.ndarray, t: float, num_steps: int) -> jnp.ndarray:
        # Create a loop using lax.fori_loop that integrates the system using the backward Euler method and
        # stores the results in the trajectory array
        def body(i, trajectory):
            x = trajectory[i - 1]
            trajectory = trajectory.at[i].set(self.step(x, t + i * self.dt))
            return trajectory
        
        trajectory = lax.fori_loop(1, num_steps + 1, body, trajectory)
        return trajectory
    
class RK4Integrator(Integrator):
    def __init__(self, dt: float, system: DynamicalSystem) -> None:
        self.name = "RK4"
        self.type = "Explicit"
        self.dt = dt
        self.system = system

    @partial(jit, static_argnums=(0,))
    def step(self, x: jnp.ndarray, t: float) -> jnp.ndarray:
        k1 = self.system(x, t)
        k2 = self.system(x + 0.5 * self.dt * k1, t + 0.5 * self.dt)
        k3 = self.system(x + 0.5 * self.dt * k2, t + 0.5 * self.dt)
        k4 = self.system(x + self.dt * k3, t + self.dt)
        return x + (self.dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
    
    def simulate(self, x0: jnp.ndarray, t0: float, tf: float) -> tuple[jnp.ndarray, jnp.ndarray]:
        num_steps = jnp.ceil((tf - t0) / self.dt).astype(int)
        x = x0
        t = t0
        trajectory = jnp.zeros((num_steps + 1, x0.shape[0]))
        trajectory = trajectory.at[0].set(x0)


        times = np.linspace(start= t0, stop = tf, num = num_steps + 1)
        trajectory = self._simulate(trajectory, x, t, num_steps)
        
        return jnp.array(times), trajectory
    
    @partial(jit, static_argnums=(0,))
    def _simulate(self, trajectory: jnp.ndarray, x: jnp.ndarray, t: float, num_steps: int) -> jnp.ndarray:
        # Create a loop using lax.fori_loop that integrates the system using the RK4 method and
        # stores the results in the trajectory array
        def body(i, trajectory):
            x = trajectory[i - 1]
            trajectory = trajectory.at[i].set(self.step(x, t + i * self.dt))
            return trajectory
        
        trajectory = lax.fori_loop(1, num_steps + 1, body, trajectory)
        return trajectory
    
class RK45Integrator(Integrator):
    """
    This class implements the Runge-Kutta Fehlberg 4(5) method for solving ODEs
    RK45 is a 5th order method with an embedded 4th order method for error estimation
    """
    def __init__(self, dt: float, system: DynamicalSystem) -> None:
        self.name = "RK45"
        self.type = "Explicit"
        self.dt = dt
        self.system = system

    # @partial(jit, static_argnums=(0,))
    def step(self, x: jnp.ndarray, t: float, dt) -> tuple[jnp.ndarray,float]:
        # Implement the Runge-Kutta Fehlberg 4(5) method
        tol=1e-6

        k1 = self.system(x, t)
        k2 = self.system(x + 0.25 * self.dt * k1, t + 0.25 * self.dt)
        k3 = self.system(x + (3/32) * self.dt * k1 + (9/32) * self.dt * k2, t + (3/8) * self.dt)
        k4 = self.system(x + (1932/2197) * self.dt * k1 - (7200/2197) * self.dt * k2 + (7296/2197) * self.dt * k3, t + (12/13) * self.dt)
        k5 = self.system(x + (439/216) * self.dt * k1 - 8 * self.dt * k2 + (3680/513) * self.dt * k3 - (845/4104) * self.dt * k4, t + self.dt)
        k6 = self.system(x - (8/27) * self.dt * k1 + 2 * self.dt * k2 - (3544/2565) * self.dt * k3 + (1859/4104) * self.dt * k4 - (11/40) * self.dt * k5, t + 0.5 * self.dt)

        # Compute the 4th and 5th order estimates
        x_4 = x + (25/216) * self.dt * k1 + (1408/2565) * self.dt * k3 + (2197/4104) * self.dt * k4 - (1/5) * self.dt * k5
        x_5 = x + (16/135) * self.dt * k1 + (6656/12825) * self.dt * k3 + (28561/56430) * self.dt * k4 - (9/50) * self.dt * k5 + (2/55) * self.dt * k6

        # Compute the error
        error = jnp.abs(x_5 - x_4)

        # if np.max(error) < tol:
        #     t = t + dt
        #     y = x_5

        # Adjust step size based on error
        # Calculate new step size
        delta = 0.84 * (tol / np.max(error))**0.25
        delta = min(1.1, max(0.1, delta))
        dt = dt * delta

        return x_5, dt
        
    def simulate(self, x0: jnp.ndarray, t0: float, tf: float) -> tuple[jnp.ndarray, jnp.ndarray]:
        times, trajectory = self._simulate(x0, t0, tf)        
        return times, trajectory

    def _simulate(self, x0: jnp.ndarray, t0: float, tf: float) -> tuple[jnp.ndarray, jnp.ndarray]:
        # Create a loop using lax.while that integrates the system using the RK45 method and
        # stores the results in the trajectory array

        # Preallocate the trajectory array
        # Since we don't know the number of steps, we will use a python list to store the trajectory
        trajectory = [x0]
        times = [t0]
        x = x0
        dt = self.dt
        t = t0

        while t < tf:
            x, dt = self.step(x, t, dt)
            t += dt
            trajectory.append(x)
            times.append(t)
            print(f"t = {t}, dt = {dt} and x = {x}")

        
        return jnp.array(times), jnp.array(trajectory)
                

# MCK

In [147]:
# Define a simple m-c-k system
m = 1.0
c0 = 0.1
k = 1.0


def f(x: jnp.ndarray, t: float) -> jnp.ndarray:
    return jnp.array([
        x[1],                       # x' = v
        -c0*np.exp(-t)/m * x[1] - k/m * x[0]]   # a = -c/m * v - k/m * x
    )

# Create the system
system = NonlinearSystem(f)

# Create all the integrators
integrator_feuler = ForwardEulerIntegrator(0.01, system)
integrator_beuler = BackwardEulerIntegrator(0.01, system)
integrator_rk4 = RK4Integrator(0.01, system)
integrator_rk45 = RK45Integrator(0.01, system)

# Simulate the system
x = jnp.array([1.0, 0.0])
t = 0.0
t_end = 100.0

# Simulate the system using all the integrators
x_data = {}
t_data = {}

integrators = [
    # integrator_beuler,
    integrator_feuler,
    # integrator_rk4,
    # integrator_rk45,
]
for integrator in integrators:
    print(f"Simulating using {integrator.name} integrator")
   
    t_data[integrator.name], x_data[integrator.name] = integrator.simulate(x, t, t_end)
    print(f"\tSimulated using {integrator.name} integrator")

# Plot the results
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111)
for integrator in integrators:
    ax.plot(t_data[integrator.name], x_data[integrator.name][:,0], label=integrator.name)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Displacement (m)")
ax.legend()
plt.show()

# Calculate the errors between the integrators
errors = {}
for integrator in integrators:
    errors[integrator.name] = jnp.linalg.norm(x_data[integrator.name] - x_data[integrators[0].name])

print(f"Errors between integrators: {errors}")


Simulating using Forward Euler integrator


TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[].
The error occurred while tracing the function while_body_fun at c:\Users\tryfo\anaconda3\envs\aero\Lib\site-packages\jax\_src\lax\control_flow\loops.py:1845 for while_loop. This concrete value was not available in Python because it depends on the value of the argument loop_carry[0].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError