In [None]:
from cabinet_robot.joint_estimation import FGJointEstimator, sturm_twist_estimation, EstimationResults
import jax
import numpy as np
from jaxlie import SE3 as jaxlie_SE3

# Reload edited modules (e.g. the cabinet_robot package) automatically before each cell runs.
# %reload_ext (instead of %load_ext) keeps this cell quiet and idempotent when re-executed.
%reload_ext autoreload
%autoreload 2

In [None]:
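# Uncomment to disable JIT compilation (much slower, but eases debugging/printing inside jitted code):
# jax.config.update("jax_disable_jit", True)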


In [None]:
# Synthetic demo trajectory: the part keeps its orientation (identity rotation) and
# translates along a fixed direction in the xy-plane at constant height.
N_MOTION_STEPS = 10  # yields N_MOTION_STEPS + 1 poses, including the start pose
START_TRANSLATION = np.array([0.0, 0.0, 0.1])
STEP_TRANSLATION = np.array([0.01, 0.005, 0.0])

part_poses = []
for step in range(N_MOTION_STEPS + 1):
    # Build every pose from scratch instead of mutating a shared "original" pose.
    pose = np.eye(4)
    pose[:3, 3] = START_TRANSLATION + step * STEP_TRANSLATION
    part_poses.append(pose)


In [None]:
# Instantiate the joint estimator; the next cells inspect its graph and then run the estimation.
estimator = FGJointEstimator()


In [None]:
# Build (but do not compile) the graph for this trajectory, so this should run fast.
# The size is taken from the data so it stays in sync with `part_poses`.
graph = estimator._build_graph(len(part_poses))
graph

In [None]:
# Build and compile the graph for this trajectory; expected to be slow the first time.
# The size is taken from the data so it stays in sync with `part_poses`.
estimator.get_compiled_graph(len(part_poses))

In [None]:

# Run the estimation on the synthetic trajectory and report the estimated twist and
# the pose of the twist frame in the base frame, each printed as "<label>=<repr>".
estimation = estimator.estimate_joint_twist(part_poses)
for label, value in (
    ("estimation.twist", estimation.twist),
    ("estimation.twist_frame_in_base_pose", estimation.twist_frame_in_base_pose),
):
    print(f"{label}={value!r}")

In [None]:
# Inspect the auxiliary outputs of the estimation (e.g. the "joint_states" and "latent_poses" entries).
estimation.aux_data

In [None]:
# Sanity check of the joint model:
#   latent_poses["second"][i] ≈ twist_frame_in_base_pose @ exp(twist * joint_states[i])
# Print the reconstruction for the first observation next to the estimated latent pose,
# so the two can be compared directly.
twist_frame_in_base_matrix = estimation.twist_frame_in_base_pose.as_matrix()

print("reconstruction at joint_states[0]:")
print(
    twist_frame_in_base_matrix
    @ jaxlie_SE3.exp(estimation.twist * estimation.aux_data["joint_states"][0]).as_matrix()
)
print("estimated latent pose second[0]:")
print(estimation.aux_data["latent_poses"]["second"][0].as_matrix())

# Reconstruction at the current joint configuration
# (presumably the configuration of the last observation — TODO confirm).
print("reconstruction at current_joint_configuration:")
print(
    twist_frame_in_base_matrix
    @ jaxlie_SE3.exp(estimation.twist * estimation.current_joint_configuration).as_matrix()
)


In [None]:
# Per-observation joint states from the estimation (presumably one per entry of `part_poses` — TODO confirm).
print(estimation.aux_data["joint_states"])

## Explore methods for joint configuration estimation

In [None]:

# Joint model: part_pose = base_transform @ exp(twist * joint_state).
# Hence, coordinate-wise: joint_state = log(base_transform^-1 @ part_pose) / twist,
# but this division by individual twist coordinates can suffer from numerical issues.
part_pose = np.asarray(
    estimation.aux_data["latent_poses"]["second"][-1].as_matrix()
)
twist = np.asarray(estimation.twist)
base_transform = np.asarray(estimation.twist_frame_in_base_pose.as_matrix())

print(f"the GT joint configuration is {estimation.current_joint_configuration}")


In [None]:
# Naive joint-state estimate: divide each coordinate of the observed pose twist by the
# matching joint-twist coordinate and average the ratios. Coordinates where the joint
# twist is (close to) zero make this ill-conditioned, so numerical errors can dominate.
import spatialmath.base as sm

pose_in_twist_frame = np.matmul(sm.trinv(base_transform), part_pose)
print(pose_in_twist_frame)
joint_state = np.divide(sm.trlog(pose_in_twist_frame, check=False, twist=True), twist)
print(joint_state)
print(f"naive estimation = {joint_state.mean()}")

In [None]:
# More robust estimates than the naive average:
#  1. weight each per-coordinate ratio by the relative magnitude of the observed pose twist;
#  2. least squares: q* = <pose_twist, twist> / <twist, twist> minimizes
#     ||pose_twist - q * twist||^2 and never divides by a single twist coordinate.
pose_twist = sm.trlog(pose_in_twist_frame, twist=True, check=False)

pose_twist_l1_norm = np.linalg.norm(pose_twist, ord=1)
if pose_twist_l1_norm > 0:
    normalized_pose_twist = pose_twist / pose_twist_l1_norm
else:
    # Zero pose twist (no motion w.r.t. the twist frame): avoid 0/0, all weights are zero.
    normalized_pose_twist = np.zeros_like(pose_twist)
weights = np.abs(normalized_pose_twist)

# Coordinates where the joint twist is exactly zero carry no information about the joint
# state and would produce inf/nan, so they are shown as nan and excluded from the estimate.
informative = np.asarray(twist) != 0
ratios = np.full(pose_twist.shape, np.nan)
ratios[informative] = pose_twist[informative] / twist[informative]
print(ratios)
print(weights)

informative_weights = weights[informative]
weight_total = informative_weights.sum()
# Renormalize over the informative coordinates; zero total weight means zero motion.
robust_joint_state = (
    np.sum(ratios[informative] * informative_weights) / weight_total if weight_total > 0 else 0.0
)
print(f"more robust estimation = {robust_joint_state}")

least_squares_joint_state = np.dot(pose_twist, twist) / np.dot(twist, twist)
print(f"least-squares estimation = {least_squares_joint_state}")

In [None]:
# but an even better method is required...
# e.g. explicit outlier detection?
# the best option might be to use the results of the FG (which are, however, only valid in the twist frame of the FG).

### Can we get rid of the 'twist frame' by expressing the twist in the base frame?

In [None]:
# Express the twist in the base frame with the adjoint of the twist-frame pose:
#   twist_in_base = Ad(base_transform) @ twist
twist_expressed_in_base = np.matmul(sm.tr2adjoint(base_transform), twist)
print(twist_expressed_in_base, twist, sep="\n")

In [None]:
# now the twist is expressed in the base frame, so its exponential left-multiplies the poses:
# part_pose = exp(twist_expressed_in_base * (joint_state - joint_state[0])) @ part_pose[0]


In [None]:
# Start a rerun recording for visual inspection; spawn=True launches a viewer to stream to.
import rerun 
rerun.init("test-joint-estimation", spawn=True)


In [None]:
# Log the translation of every part pose as a small black point.
rerun.log_points(
    "part_poses",
    positions=np.array([pose[:3, 3] for pose in part_poses]),
    colors=np.zeros((len(part_poses), 3), dtype=np.uint8),
    radii=0.01,
)

In [None]:
# Visualize the estimation result with the project's visualisation helper.
from cabinet_robot.visualisation import visualize_estimation
visualize_estimation(estimation)