This example is available as a Jupyter notebook [here](https://github.com/SimiPixel/x_xy_v2/blob/main/docs/notebooks/machine_learning.ipynb).

It is also available on Google Colab [here](https://colab.research.google.com/github/SimiPixel/x_xy_v2/blob/main/docs/notebooks/machine_learning.ipynb).

Set up the environment if this notebook is executed on Google Colab.

*Make sure to change the runtime type to `GPU`. To do this go to `Runtime` -> `Change runtime type` -> `GPU`*

Otherwise, rendering won't work in Google Colab.

In [1]:
import os

# Detect Google Colab by probing for the `google.colab` package.
# Only `ImportError` is caught, so that unrelated failures (and
# KeyboardInterrupt / SystemExit) are not silently turned into "not in Colab".
try:
    import google.colab  # noqa: F401  (imported only to test availability)
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

# On Colab the dependencies are not preinstalled, so install them first.
if IN_COLAB:
    os.system("pip install --quiet 'x_xy[all_muj] @ git+https://github.com/SimiPixel/x_xy_v2'")
    os.system("pip install --quiet mediapy")

import x_xy
# automatically detects colab or not
x_xy.utils.setup_colab_env()

In [2]:
from x_xy.subpkgs import ml, exp, sys_composer, sim2real
import mediapy
import jax.numpy as jnp
import tree_utils

In [3]:
def load_data_and_prediction(motion, sys, params):
    """Load a recorded motion of subject "S_04" and run the RNNO filter on its IMU data.

    Args:
        motion: Name of the motion, forwarded to `exp.load_data`.
        sys: System object; used to convert the raw data into `xs`, to look up
            link indices, and to derive the IMU-free system for the filter.
        params: Parameters handed to `ml.RNNOFilter`.

    Returns:
        A tuple `(xs_translated, yhat)`: the transforms built from the
        experimental data with shifted x-positions, and the filter prediction
        with the batch dimension sliced off again.

    Note:
        The "mag" entry is popped from the loaded IMU data in place, and the
        IMU data of "seg3" is replaced by `tree_utils.tree_zeros_like` of itself
        before it is fed to the filter.
    """
    subject = "S_04"
    recording = exp.load_data(subject, motion)
    xml = exp.load_xml_str(subject)
    link_data = exp.link_name_pos_rot_data(recording, xml)
    xs = sim2real.xs_from_raw(sys, link_data, qinv=True)

    # Purely cosmetic: shift the x-position of all non-IMU links by -0.03 and
    # of all IMU links by +0.03.
    translations, rotations = sim2real.unzip_xs(sys, xs)
    non_root_names = sys.link_names[1:]
    seg_mask = jnp.array(
        [sys.name_to_idx(n) for n in non_root_names if not n.startswith("imu")]
    )
    imu_mask = jnp.array(
        [sys.name_to_idx(n) for n in non_root_names if n.startswith("imu")]
    )
    pos = translations.pos
    pos = pos.at[:, seg_mask, 0].set(pos[:, seg_mask, 0] - 0.03)
    pos = pos.at[:, imu_mask, 0].set(pos[:, imu_mask, 0] + 0.03)
    translations = translations.replace(pos=pos)
    xs_translated = sim2real.zip_xs(sys, translations, rotations)

    # Assemble the filter input, one entry per segment.
    X = {}
    for seg in ("seg2", "seg3", "seg4"):
        imu_data = recording[seg]["imu_rigid"]
        imu_data.pop("mag")
        if seg == "seg3":
            imu_data = tree_utils.tree_zeros_like(imu_data)
        X[seg] = dict(imu_data)

    sys_noimu, _ = sys_composer.make_sys_noimu(sys)
    rnno = ml.RNNOFilter(params=params)
    # Initialise with the first slice of X, then predict on X with an added
    # batch dimension.
    rnno.init(sys_noimu, tree_utils.tree_slice(X, 0))
    yhat_batched = rnno.predict(tree_utils.add_batch_dim(X))
    return xs_translated, tree_utils.tree_slice(yhat_batched, 0)


In [4]:
# Pretrained parameters and the recorded motion to visualise.
params = ml.load(pretrained="rr_rr_unknown")
motion = "thomas_fast"
sys = exp.load_sys(
    "S_04",
    morph_yaml_key="seg2",
    delete_after_morph=["seg5", "imu3"],
)

xs, yhat = load_data_and_prediction(motion, sys, params)

# Render the data together with the prediction, viewed through an extra
# camera "c" that is added to the scene via `add_cameras`.
frames = x_xy.render_prediction(
    sys,
    xs,
    yhat,
    stepframe=4,
    width=640,
    height=480,
    camera="c",
    add_cameras={
        -1: '<camera name="c" mode="targetbody" target="3" pos=".5 -.5 1.25"/>',
    },
)

Rendering frames..: 100%|██████████| 1150/1150 [00:07<00:00, 160.92it/s]


In [7]:
# Display the rendered frames inline as a video at 25 frames per second.
mediapy.show_video(frames, fps=25.0)

0
This browser does not support the video tag.
