In [None]:
import numpy as onp
import jax
import jaxlie
import trimesh
import viser
from jax import numpy as jnp


from solution import solve_spline_1d, evaluate_spline_1d

%load_ext autoreload
%autoreload 2

In [None]:
# Start position (one entry per axis) and gate positions: column k of `gates`
# (i.e. gates[:, k]) is the 3-D position of gate k + 1.
x0 = jnp.array([0.0, 0.0, 0.0])
gates = jnp.array([[1.0, 2.0], [0.75, -0.5], [0.1, 0.6]])
# Segment durations; their sum is the total time span sampled below.
dt = jnp.array([2.0, 1.0])
# Solve one 1-D spline per axis (vmap over the leading axis of x0 / gates),
# all axes sharing the same segment durations.
coeffs = jax.vmap(solve_spline_1d, in_axes=(0, 0, None))(x0, gates, dt)

N_EVAL = 100  # number of time samples used to draw the reference trajectory
ts = jnp.linspace(0.0, jnp.sum(dt), N_EVAL)


def eval_at_t(t: float, coeffs=coeffs, dt=dt) -> jnp.ndarray:
    """Evaluate every per-axis spline at time ``t`` and return item ``[0]`` of the result.

    Args:
        t: Query time (a traced scalar when called under ``jax.vmap``).
        coeffs: Per-axis spline coefficients. Defaults to this cell's ``coeffs``,
            bound when the function is defined, so later cells that rebind the
            global ``coeffs`` cannot silently change what this evaluates.
        dt: Segment durations; defaults to this cell's ``dt`` for the same reason.

    Returns:
        The first item of the vmapped ``evaluate_spline_1d`` output. The plotting
        code treats it as a 3-vector position.
        NOTE(review): ``[0]`` assumes ``evaluate_spline_1d`` returns a tuple whose
        first entry is the position — confirm against solution.py.
    """
    return jax.vmap(evaluate_spline_1d, in_axes=(0, None, None))(coeffs, t, dt)[0]


# Reference positions sampled at every time in `ts`.
trajectory = jax.vmap(eval_at_t)(ts)

In [None]:
# Create the viser server on the first run only. Later runs clear its scene so
# nodes from a previous execution do not pile up.
if globals().get("server") is None:
    server = viser.ViserServer()
else:
    server.scene.reset()

server.scene.add_grid("grid")

# Draw the sampled reference as a polyline: segment j joins sample j to sample j + 1.
server.scene.add_line_segments(
    "trajectory",
    onp.stack([trajectory[:-1], trajectory[1:]], axis=1),
    colors=(1.0, 0.0, 0.0),
    line_width=2.5,
)

# Coordinate frames marking the start point and each gate (gates[:, k] is gate k + 1).
server.scene.add_frame("init_pos", position=onp.array(x0), axes_length=0.25)
g1 = server.scene.add_frame("gate1", position=onp.array(gates[:, 0]), axes_length=0.25)
g2 = server.scene.add_frame("gate2", position=onp.array(gates[:, 1]), axes_length=0.25)

server.scene.show()

In [None]:
from solution import DroneState, simulate_trajectory

# Initial state: origin position (.posW) and identity rotation (.R).
# NOTE(review): DroneState fields are passed positionally — confirm the order and
# meaning of the two trailing zero vectors against its definition.
init_state = DroneState(jnp.zeros((3,)), jnp.eye(3), jnp.zeros((3,)), jnp.zeros((3,)))

sim_dt = 0.01  # simulation step; also used as the per-frame playback delay below
coeffs = jax.vmap(solve_spline_1d, in_axes=(0, 0, None))(init_state.posW, gates, dt)
states = simulate_trajectory(coeffs, dt, sim_dt, init_state)

# Resample the reference spline from the same coefficients handed to the simulator.
# The earlier `trajectory` was built from `x0`, so reusing it would draw the wrong
# curve whenever init_state.posW differs from x0.
desired_ts = jnp.linspace(0.0, jnp.sum(dt), 100)
desired_trajectory = jax.vmap(
    lambda t: jax.vmap(evaluate_spline_1d, in_axes=(0, None, None))(coeffs, t, dt)[0]
)(desired_ts)

# Convert the rollout to NumPy once. The animation loop below then indexes host
# arrays instead of dispatching several small JAX ops per frame.
drone_positions = onp.asarray(states.posW)
drone_wxyzs = onp.asarray(jax.vmap(jaxlie.SO3.from_matrix)(states.R).wxyz)

# Reuse one viser server across re-runs; clear its scene if it already exists.
if 'server' not in dir() or server is None:
    server = viser.ViserServer()
else:
    server.scene.reset()

server.scene.add_grid("grid")
# Simulated path (blue) and reference spline (red), each drawn as a polyline of
# consecutive-sample segments.
server.scene.add_line_segments(
    "trajectory",
    onp.stack([drone_positions[:-1], drone_positions[1:]], axis=1),
    colors=(0.0, 0.0, 1.0),
    line_width=2.5,
)
server.scene.add_line_segments(
    "desired_trajectory",
    onp.stack([desired_trajectory[:-1], desired_trajectory[1:]], axis=1),
    colors=(1.0, 0.0, 0.0),
    line_width=2.5,
)
server.scene.add_frame(
    "init_pos",
    position=onp.array(init_state.posW),
    axes_length=0.25,
)
g1 = server.scene.add_frame(
    "gate1",
    position=onp.array(gates[:, 0]),
    axes_length=0.25,
)
g2 = server.scene.add_frame(
    "gate2",
    position=onp.array(gates[:, 1]),
    axes_length=0.25,
)

# Rotate the mesh +90 degrees about x before display (presumably to map the
# model's Y-up convention onto viser's Z-up world — verify against the asset).
drone_mesh = trimesh.load("data/quadcopter_drone.glb").apply_transform(
    trimesh.transformations.rotation_matrix(onp.pi / 2, [1, 0, 0])
)
drone = server.scene.add_mesh_trimesh("drone", drone_mesh, scale=0.25)

# Record one pose update per simulation step, separated by a sleep of sim_dt,
# so the serialized playback runs at simulation time.
serializer = server.get_scene_serializer()
for position, wxyz in zip(drone_positions, drone_wxyzs):
    drone.position = position
    drone.wxyz = wxyz
    serializer.insert_sleep(sim_dt)

serializer.show()