In [34]:
import jax.numpy as jnp
import jax

# Toy dynamics f(x, u): state dimension 3, control dimension 2.
# Both are passed as column vectors: x has shape (3, 1), u has shape (2, 1).

def f(x, u):
  """Toy nonlinear dynamics: x**2 + B @ sqrt(u).

  x is the state (used here as a (3, 1) column) and u the control (a (2, 1)
  column). B is the fixed 3x2 control-gain matrix below. jnp.sqrt returns
  nan for negative entries, so u is expected to be non-negative.
  """
  control_gain = jnp.array([[1, 1],
                            [1, 2],
                            [2, 1]])
  state_term = x ** 2
  control_term = control_gain @ jnp.sqrt(u)
  return state_term + control_term

# Test point: x is a (3, 1) state column, u is a (2, 1) control column.
# Built as float32 so they can be passed directly to jax.jacfwd / jax.grad,
# which raise a TypeError on integer-dtype inputs.
x_test = jnp.array([
  [0.],
  [1.],
  [2.]
], dtype=jnp.float32)

u_test = jnp.array([
  [1.],
  [1.]
], dtype=jnp.float32)

f(x_test, u_test)


Array([[2.],
       [4.],
       [7.]], dtype=float32)

In [35]:
# Forward-mode Jacobians of the dynamics wrt each argument.
dynamics_func_dx = jax.jacfwd(f, argnums=0)   # df/dx (first arg)
dynamics_func_du = jax.jacfwd(f, argnums=1)   # df/du (second arg)

# jax.jacfwd rejects integer inputs; cast the evaluation point once.
# (A no-op if the test points are already floating point.)
x_eval = x_test.astype(float)
u_eval = u_test.astype(float)

# jacfwd returns shape out.shape + arg.shape, e.g. (3, 1, 3, 1) for column
# vectors. Flatten to the conventional (n_out, n_arg) Jacobian matrix;
# unlike squeeze((1, 3)), this also works if x/u are given as 1-D vectors.
F_x = dynamics_func_dx(x_eval, u_eval).reshape(-1, x_eval.size)
F_u = dynamics_func_du(x_eval, u_eval).reshape(-1, u_eval.size)

# Both Jacobians must have one row per output of f.
assert F_x.shape[0] == F_u.shape[0]

print(F_x)
print(F_u)

[[0. 0. 0.]
 [0. 2. 0.]
 [0. 0. 4.]]
[[0.5 0.5]
 [0.5 1. ]
 [1.  0.5]]
