In [2]:
import time

import jax.numpy as jnp
import numpy as np

from vehicle_model_jax import evalf, get_default_params

In [8]:
# Load the model's default parameter dict and flatten its values into the
# positional tuple that `evalf` is called with in every check below.
# NOTE(review): this relies on the dict's insertion order matching the order in
# which evalf unpacks its parameters — confirm against vehicle_model_jax.
p = get_default_params()
p_tuple = (*p.values(),)

In [11]:
# 1. Steady State
x0 = jnp.zeros(10)
u = jnp.array([0.0, 0.0])

f = evalf(x0, p_tuple, u)
print(f)

assert np.all(f == 0.0)

[ 0.  0. -0.  0.  0.  0.  0.  0.  0.  0.]


In [27]:
# 2. Constant Speed
x0 = jnp.zeros(10)
x0 = x0.at[7].set(10.0)  # v
u = jnp.array([10.0, 0.0])

f = evalf(x0, p_tuple, u)
print(f)

assert f[4] == x0[7]  # velocity
assert f[7] < 0.0  # wind

[ 0.0000000e+00 -5.5555555e+04 -0.0000000e+00  0.0000000e+00
  1.0000000e+01  0.0000000e+00  0.0000000e+00 -2.0000000e-01
  0.0000000e+00  0.0000000e+00]


In [19]:
# 3. Speed-Up
x0 = jnp.zeros(10)
u = jnp.array([20.0, 0.0])

f = evalf(x0, p_tuple, u)
print(f)

assert f[9] == u[0]

[1.1070818e+03 7.0206531e+04 2.2141637e-01 2.1061958e+01 0.0000000e+00
 0.0000000e+00 0.0000000e+00 0.0000000e+00 0.0000000e+00 2.0000000e+01]


In [31]:
# 4. Spin
x0 = jnp.zeros(10)
u = jnp.array([0.0, 1.0])

f = evalf(x0, p_tuple, u)
print(f)

assert f[8] > 0.0  # angular acceleration
assert np.all(f[:8] == 0.0) & np.all(f[9:] == 0.0)

[ 0.          0.         -0.          0.          0.          0.
  0.          0.          0.39999998  0.        ]


In [32]:
# 5. Turn
x0 = jnp.zeros(10)
x0 = x0.at[7].set(10.0)  # v
u = jnp.array([10.0, 1.0])

f = evalf(x0, p_tuple, u)
print(f)

assert f[4] == x0[7]  # velocity
assert f[7] < 0.0  # wind
assert f[8] > 0.0  # angular acceleration

[ 0.0000000e+00 -5.5555555e+04 -0.0000000e+00  0.0000000e+00
  1.0000000e+01  0.0000000e+00  0.0000000e+00 -2.0000000e-01
  3.9999998e-01  0.0000000e+00]
