# F16 JAX version

In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
from jax import random, vmap, jit
import jax
import jax.numpy as jnp
from F16_jax.F16Dynamics import update


n = 1000000
rng = jax.random.PRNGKey(42)
# Draw x and u from independent subkeys. Reusing a single key for both draws
# makes the two inputs statistically dependent (JAX keys must not be reused).
rng, x_key, u_key = jax.random.split(rng, 3)
x_jnp = random.uniform(x_key, (n, 12))
u_jnp = random.uniform(u_key, (n, 5))
# Batch `update` over the leading axis of x and u; dt (None axis) is shared by all rows.
update_F16_vmap = jit(vmap(update, in_axes=(0, 0, None)))
# The first call triggers jit compilation, so later timings measure steady-state runs.
result_F16_jnp = update_F16_vmap(x_jnp, u_jnp, 0.02)

In [None]:
%timeit update_F16_vmap(x_jnp, u_jnp, 0.02)

# F16 PyTorch version

In [None]:
import torch
from torchdiffeq import odeint_adjoint as odeint
from F16_torch.F16Dynamics import F16Dynamics
# Use the first GPU when CUDA is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

dynamics = F16Dynamics(device)

    
def update_F16_torch(x, u, dt):
    """Advance the F16 state by one explicit Euler step of length ``dt``.

    The control ``u`` is appended column-wise to the state ``x``, the augmented
    vector is integrated over ``[0, dt]`` with ``odeint(..., method='euler')``,
    and the first 12 columns at ``t = dt`` are returned.

    Args:
        x: batched state tensor (presumably shape (n, 12) -- TODO confirm).
        u: batched control tensor (presumably shape (n, 5) -- TODO confirm).
        dt: integration step size.

    Returns:
        The first 12 columns of the solution at the final time point.
    """
    augmented = torch.hstack((x, u))
    t_span = torch.tensor([0., dt], device=device)
    trajectory = odeint(dynamics, augmented, t_span, method='euler')
    # Index 0 along the time axis is the initial condition; index 1 is t = dt.
    return trajectory[1, :, :12]

# Feed torch exactly the same inputs the JAX version used. torch.tensor copies the
# data, avoiding torch.from_numpy's warning about the read-only NumPy buffer
# returned by jax.device_get.
x_tensor = torch.tensor(jax.device_get(x_jnp), device=device)
u_tensor = torch.tensor(jax.device_get(u_jnp), device=device)
result_F16_tensor = update_F16_torch(x_tensor, u_tensor, 0.02)

In [None]:
%timeit update_F16_torch(x, u, 0.02)

# J20 original version

In [7]:
from uav_plant.flight_dynamics_model.plane import Plane
import numpy as np

n = 1000000
# NOTE(review): this used dt=0.2 while the JAX version below steps by 0.02;
# use the same step so the comparison section is like-for-like.
DT_J20 = 0.02
np_rng = np.random.default_rng(42)  # seeded so the reference run is reproducible

cmd_rows = []
state_rows = []
for _ in range(n):
    cmdInput = np_rng.random(12) * 1000 + 1000  # uniform in [1000, 2000)
    plane = Plane()
    plane.update(DT_J20, cmdInput)
    cmd_rows.append(cmdInput)
    # Copy the state so each row is independent of the plane object.
    state_rows.append(np.array(plane.dynamics.motionState.state).ravel())

# Stack once at the end: calling np.vstack inside the loop re-copied the growing
# array on every iteration (quadratic in n).
cmdInput_numpy = np.vstack(cmd_rows) if cmd_rows else None
result_J20_numpy = np.vstack(state_rows) if state_rows else None

# J20 JAX version

In [1]:
import os
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
import jax
import jax.numpy as jnp
from J20_jax.flight_dynamics_model import plane

n = 1000000

# Batched initial conditions: row i of every array belongs to plane i.
latitude = 31.835 * jnp.ones(n)
longitude = 117.089 * jnp.ones(n)
altitude = 31.0 * jnp.ones(n)
roll, pitch, yaw = (jnp.zeros(n) for _ in range(3))
velNED, angVel, accelNED = (jnp.zeros((n, 3)) for _ in range(3))
fuelVolume = -jnp.ones(n)
CSD = jnp.zeros((n, 6))

createPlane_args = (
    latitude, longitude, altitude,
    roll, pitch, yaw,
    velNED, angVel, accelNED,
    fuelVolume, CSD,
)
# Map createPlane over axis 0 of every positional argument.
create_J20_batch = jax.jit(
    jax.vmap(plane.createPlane, in_axes=(0,) * len(createPlane_args)))
J20Plane = create_J20_batch(*createPlane_args)

In [2]:
key = jax.random.PRNGKey(42)
# NOTE(review): these commands are standard-normal draws, while the NumPy reference
# run uses uniform values in [1000, 2000). The correlation in the comparison section
# is only meaningful when both runs see the same commands, e.g.
#     cmdInput_jnp = jnp.array(cmdInput_numpy.reshape(-1, 12))
cmdInput_jnp = jax.random.normal(key, shape=(n, 12))
# Batch plane.update over planes and commands; dt (None axis) is shared by every plane.
update_J20_vmap = jax.jit(jax.vmap(plane.update, in_axes=(0, None, 0)))
# Store the stepped planes under a new name so re-running this cell does not advance
# the initial J20Plane by another step.
J20Plane_next = update_J20_vmap(J20Plane, 0.02, cmdInput_jnp)
motion = J20Plane_next.dynamics.motionState
result_J20_jnp = jnp.hstack((
    motion.position_NED,
    motion.velocity_Body,
    motion.quaternion_Body2NED,
    motion.angularSpeed_Body,
    motion.accel_Body,
))

In [3]:
%timeit update_J20_vmap(J20Plane, 0.02, cmdInput_jnp)

56.1 ms ± 4.24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Compare results

In [7]:
# Bring both F16 results to host NumPy arrays for SciPy. jax.device_get passes
# NumPy arrays through unchanged, and the torch conversion is guarded, so re-running
# this cell is harmless. detach() drops any autograd history, which .numpy() rejects.
result_F16_jnp = jax.device_get(result_F16_jnp)
if torch.is_tensor(result_F16_tensor):
    result_F16_tensor = result_F16_tensor.detach().cpu().numpy()

In [None]:

from scipy.stats import pearsonr

assert result_F16_jnp.shape == result_F16_tensor.shape, (
    result_F16_jnp.shape, result_F16_tensor.shape)
# Correlate each state column across the batch; pearsonr in older SciPy releases
# only accepts 1-D samples and errors on 2-D arrays.
p1 = [pearsonr(result_F16_jnp[:, j], result_F16_tensor[:, j])
      for j in range(result_F16_jnp.shape[1])]
print(p1)
print("max abs diff:", abs(result_F16_jnp - result_F16_tensor).max())

In [None]:
from scipy.stats import pearsonr

# NOTE(review): this comparison is only meaningful when both J20 runs used the same
# cmdInput and time step; check the J20 sections above before trusting p1.
result_J20_jnp = jax.device_get(result_J20_jnp)
assert result_J20_jnp.shape == result_J20_numpy.shape, (
    result_J20_jnp.shape, result_J20_numpy.shape)
# Correlate column by column; pearsonr in older SciPy releases only accepts 1-D samples.
p1 = [pearsonr(result_J20_jnp[:, j], result_J20_numpy[:, j])
      for j in range(result_J20_jnp.shape[1])]
print(result_J20_numpy)
print(result_J20_jnp)
print(p1)