In [None]:
from pathlib import Path

import imt
import jax
import mediapy
import numpy as np
import qmt
import ring
import tree

In [2]:
# Location of the recording. Built from the current user's home directory so
# the notebook is not tied to one specific account name; on the original
# machine this resolves to the same absolute path as before.
DATA_DIR = Path.home() / "Downloads" / "STS"
path_pkl = str(DATA_DIR / "STS_TK_correct1.pickle")

# NOTE(review): unpickling can execute arbitrary code -- only load files that
# come from a trusted source.
data = ring.utils.pickle_load(path_pkl)

In [3]:
# Parent array of the kinematic tree: graph[i] is the index of the parent of
# body i, and -1 marks the body whose parent is the world/earth frame.
graph = [
    -1,  # body 0
    0,   # body 1
    0,   # body 2
    2,   # body 3
    3,   # body 4
    0,   # body 5
    5,   # body 6
    6,   # body 7
]

# Sensor identifiers, ordered by body index: entry i is the key under which
# the IMU attached to body i is stored in the loaded `data`.
imu_to_body_convention = [
    "10B40AFE",
    "10B40B14",
    "10B40AF8",
    "10B40AF4",
    "10B40AFA",
    "10B40AF9",
    "10B40B15",
    "10B40AF7",
]

In [36]:
# Raw accelerometer and gyroscope samples per body index, converted straight
# to NumPy arrays so that `imu_data` is bound exactly once (no intermediate
# dict of DataFrames under the same name).
imu_data = {
    body: dict(
        acc=data[imu_id][["Acc_X", "Acc_Y", "Acc_Z"]].to_numpy(),
        gyr=data[imu_id][["Gyr_X", "Gyr_Y", "Gyr_Z"]].to_numpy(),
    )
    for body, imu_id in enumerate(imu_to_body_convention)
}

# One solution entry per body of `graph`; the count is derived from the graph
# instead of a hard-coded 8 so the two cannot drift apart.
# NOTE(review): 0.01 is presumably the sampling period in seconds (the same
# value is used as `dt` elsewhere in this notebook) -- confirm.
vqf = imt.solutions.VQF_Solution()
qhat = imt.Solver(graph, [vqf] * len(graph), 0.01).step(imu_data)

In [46]:
# XML description of the rendered model ("gait"). Body "0" is the root with a
# free joint, placed at "0 0 2"; all other bodies use spherical joints.
# The nesting mirrors the parent array `graph`:
#   0 -> 1
#   0 -> 2 -> 3 -> 4
#   0 -> 5 -> 6 -> 7
# Each body carries an "xyz" geom, with `dim` shrinking further down a chain.
# NOTE(review): body "4" is offset by -.25 from its parent while its
# counterpart "7" is offset by -.3 -- confirm the asymmetry is intended.
sys_str = """
<x_xy model="gait">
    <options gravity="0 0 9.81" dt="0.01"/>
    <worldbody>
        <body name="0" joint="free" pos="0 0 2" damping="5 5 5 25 25 25">
            <body name="1" joint="spherical" pos=".3 0 -.1" damping="5 5 5">
                <geom type="xyz" dim=".06"/>
            </body>
            <geom type="xyz" dim=".08"/>
            <body name="2" joint="spherical" pos="0 -.15 0" damping="5 5 5">
                <geom type="xyz" dim=".06"/>
                <body name="3" joint="spherical" pos="-.3 0 0" damping="5 5 5">
                    <geom type="xyz" dim=".05"/>
                    <body name="4" joint="spherical" pos="-.25 0 0" damping="5 5 5">
                        <geom type="xyz" dim=".04"/>
                    </body>
                </body>
            </body>
            <body name="5" joint="spherical" pos="0 .15 0" damping="5 5 5">
                <geom type="xyz" dim=".06"/>
                <body name="6" joint="spherical" pos="-.3 0 0" damping="5 5 5">
                    <geom type="xyz" dim=".05"/>
                    <body name="7" joint="spherical" pos="-.3 0 0" damping="5 5 5">
                        <geom type="xyz" dim=".04"/>
                    </body>
                </body>
            </body>
        </body>
    </worldbody>
</x_xy>
"""

# Build the ring system object from the XML description above.
# NOTE(review): the name `sys` is also the name of the standard-library `sys`
# module; a later `import sys` in this notebook would rebind this variable.
sys = ring.System.create(sys_str)

In [47]:
def dead_reckoning_position(quat, acc, dt):
    """Estimate a position trajectory by integrating acceleration twice.

    The acceleration samples are first rotated with `quat` via `qmt.rotate`.
    The per-axis mean over the first axis is then removed, which forces the
    summed acceleration to zero over the recording. Finally the samples are
    cumulatively summed twice along the first axis and scaled by `dt` squared.

    Args:
        quat: Orientation quaternion(s) handed to `qmt.rotate`.
        acc: Acceleration samples; the first axis is the one integrated over.
        dt: Sampling period used for both integration steps.

    Returns:
        Array with the same shape as the rotated acceleration, holding the
        doubly integrated signal.
    """
    acc_earth = qmt.rotate(quat, acc)

    # Remove the mean in place on the freshly rotated array.
    np.subtract(acc_earth, np.mean(acc_earth, axis=0), out=acc_earth)

    # First running sum is the (unscaled) velocity; the dt^2 factor is applied
    # before the second running sum.
    velocity_sum = np.cumsum(acc_earth, axis=0)
    return np.cumsum(velocity_sum * dt * dt, axis=0)

In [48]:
# Position trajectory of body 0 (the free-joint root of the model), obtained
# from its estimated orientation and its accelerometer signal.
pos = dead_reckoning_position(
    quat=qhat[0],
    acc=imu_data[0]["acc"],
    dt=0.01,
)

In [49]:
# Register an identity quaternion under index -1 so that the root body, whose
# parent in `graph` is -1, can be handled like every other body below.
# NOTE(review): this assumes `qhat` is a mapping keyed by body index -- confirm.
qhat[-1] = np.array([1.0, 0, 0, 0])

# For every body: inverse of the relative quaternion between its parent's and
# its own orientation estimate.
parent_to_child_rots = [
    qmt.qinv(qmt.qrel(qhat[parent], qhat[child]))
    for child, parent in enumerate(graph)
]

# Per-timestep coordinate vector: root rotation, root position, then the
# rotations of the remaining bodies in index order.
q = np.concatenate(
    [parent_to_child_rots[0], pos, *parent_to_child_rots[1:]],
    axis=-1,
)

# Run forward kinematics independently for every timestep.
x = jax.vmap(
    lambda q_t: ring.algorithms.forward_kinematics_transforms(sys, q_t)[0]
)(q)

# Every 4th frame is rendered and played at 25 fps.
mediapy.show_video(
    sys.render(
        x,
        camera="targetfar",
        height=720,
        width=1280,
        render_every_nth=4,
    ),
    fps=25,
)

Rendering frames..: 100%|██████████| 338/338 [00:02<00:00, 113.94it/s]


0
This browser does not support the video tag.
