# Training the RNNo with a custom loss function

This notebook shows how to train an RNNo network with a custom loss function instead of the default mean-reduced angle error. As an example, we scale the error by a softmax over the time axis, which puts more weight on time steps with a larger deviation than on those with a smaller one.

In [None]:
import jax
import jax.numpy as jnp
import tree_utils
from jax.nn import softmax
import matplotlib.pyplot as plt
import mediapy

import x_xy
from x_xy.subpkgs import ml, sim2real, sys_composer

Set the batch size and number of training episodes according to the available hardware.

In [None]:
BATCHSIZE = 32
NUM_TRAINING_EPISODES = 1500

## Defining the systems

We use two separate systems, both parsed from XML strings: one for training (`sys`) and one for inference (`sys_inference`).

In [None]:
sys_str = r"""
<x_xy model="three_segment_kinematic_chain">
    <options gravity="0 0 9.81" dt="0.01"/>
    <defaults>
        <geom color="orange"/>
    </defaults>
    <worldbody>
        <body name="seg2" joint="free" pos="0 0 2">
            <geom type="box" mass="0.1" pos="0.5 0 0" dim="1 0.25 0.2"/>
            <body name="seg1" joint="ry">
                <geom type="box" mass="0.1" pos="-0.5 0 0" dim="1 0.25 0.2"/>
                <body name="imu1" joint="frozen" pos="-0.5 0 0.125">
                    <geom type="box" mass="0.05" dim="0.2 0.2 0.05" color="red"/>
                </body>
            </body>
            <body name="seg3" joint="rz" pos="1 0 0">
                <geom type="box" mass="0.1" pos="0.5 0 0" dim="1 0.25 0.2"/>
                <body name="imu2" joint="frozen" pos="0.5 0 -0.125">
                    <geom type="box" mass="0.05" dim="0.2 0.2 0.05" color="red"/>
                </body>
            </body>
        </body>
    </worldbody>
</x_xy>
"""
sys = x_xy.io.load_sys_from_str(sys_str)

In [None]:
dustin_exp_xml_seg1 = r"""
<x_xy model="dustin_exp">
    <options gravity="0 0 9.81" dt="0.01"/>
    <defaults>
        <geom color="white"/>
    </defaults>
    <worldbody>
        <body name="seg1" joint="free">
            <geom type="box" mass="10" pos="-0.5 0 0" dim="1 0.25 0.2"/>
            <body name="seg2" joint="ry">
                <geom type="box" mass="10" pos="0.5 0 0" dim="1 0.25 0.2"/>
                <body name="seg3" joint="rz" pos="0.2 0 0" >
                    <geom type="box" mass="10" pos="0.5 0 0" dim="1 0.25 0.2"/>
                </body>
            </body>
        </body>
    </worldbody>
</x_xy>
"""
sys_inference = x_xy.io.load_sys_from_str(dustin_exp_xml_seg1)

## Generating the motion data

Our motion data is generated automatically by a `Generator`, which can be customised using a `MotionConfig`. The `Generator` produces both `q`, the state of all joint angles in the system, and `xs`, which describes the orientations of all links in the system. To use this data for training our RNNo, we first have to bring it into the correct form using a `finalise_fn`.

In [None]:
def finalise_fn(key: jax.Array, q: jax.Array, xs: x_xy.Transform, sys: x_xy.System):
    def xs_by_name(name: str):
        return xs.take(sys.name_to_idx(name), axis=1)

    key, *consume = jax.random.split(key, 3)

    # the input X to our RNNo is the IMU data of segments 1 and 3
    X = {
        "seg1": x_xy.imu(xs_by_name("imu1"), sys.gravity, sys.dt, consume[0], True),
        "seg3": x_xy.imu(xs_by_name("imu2"), sys.gravity, sys.dt, consume[1], True),
    }

    # seg2 has no IMU, but we still need to make an entry in our X
    X["seg2"] = tree_utils.tree_zeros_like(X["seg1"])

    # the output of the RNNo is the estimated relative poses of our segments
    y = x_xy.algorithms.rel_pose(sys_scan=sys_inference, xs=xs, sys_xs=sys)

    return X, y

config = x_xy.algorithms.MotionConfig(dpos_max=0.3, ang0_min=0.0, ang0_max=0.0)

gen = x_xy.build_generator(sys, config, finalize_fn=finalise_fn)
gen = x_xy.batch_generator(gen, BATCHSIZE)

## Custom loss function

To customise the loss function of the RNNo, we transform the error values before they are averaged. The inputs to our loss function are $q$, the real joint state, and $\hat{q}$, the joint state estimated by our RNNo. `q` and `q_hat` are both `jax.Array`s of shape `(T_tbp, 4)`, where the first axis runs over time (the TBPTT length) and the second axis holds the 4 components of a quaternion.
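
To illustrate how a pair of quaternion arrays of this shape reduces to one error angle per time step, here is a minimal sketch. It uses the generic formula $2\arccos(|\langle q, \hat{q}\rangle|)$ for unit quaternions; the helper `quat_angle_error` is hypothetical and is not necessarily how `x_xy.maths.angle_error` is implemented. The example values assume a `(w, x, y, z)` component order.

```python
import jax.numpy as jnp


def quat_angle_error(q, q_hat):
    """Rotation angle between unit quaternions along the last axis (hypothetical helper)."""
    # |<q, q_hat>| is the cosine of half the relative rotation angle;
    # the absolute value accounts for q and -q describing the same rotation
    dot = jnp.abs(jnp.sum(q * q_hat, axis=-1))
    # clip guards against values slightly above 1 due to rounding
    return 2.0 * jnp.arccos(jnp.clip(dot, 0.0, 1.0))


half = jnp.pi / 4  # half of a 90 degree rotation
q = jnp.array([[1.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0]])
q_hat = jnp.array(
    [[1.0, 0.0, 0.0, 0.0], [jnp.cos(half), 0.0, 0.0, jnp.sin(half)]]
)

# one angle per time step: the input shape (T, 4) becomes (T,)
angles = quat_angle_error(q, q_hat)
```

The first time step has no error, while the second is off by a 90 degree rotation about the z-axis.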

In this notebook we want to change the relative weighting of the errors at different times using a softmax function in order to put more weight on larger errors. First we convert the errors from quaternions to (squared) angles. Then we scale each of them by a factor calculated from a softmax over all of them. The calculation of the factors includes a call to `jax.lax.stop_gradient`, so that gradients flow only through the errors themselves and not through the factors as well.
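
To see what `jax.lax.stop_gradient` changes, consider the following small sketch. It assumes that the weighted errors are averaged afterwards, and it uses made-up error values; `weighted_mean` and its `stop` flag are hypothetical names used only for this illustration.

```python
import jax
import jax.numpy as jnp
from jax.nn import softmax

BETA = 1.0  # illustrative value


def weighted_mean(sq_errors, stop):
    """Mean of softmax-weighted errors; `stop` toggles stop_gradient on the weights."""
    weights_input = jax.lax.stop_gradient(sq_errors) if stop else sq_errors
    factors = sq_errors.shape[-1] * softmax(BETA * weights_input, axis=-1)
    return jnp.mean(factors * sq_errors)


sq_errors = jnp.array([0.1, 0.5, 2.0])  # made-up squared angle errors

# gradient w.r.t. the errors, with and without stopping the gradient of the factors
grad_stopped = jax.grad(weighted_mean)(sq_errors, True)
grad_full = jax.grad(weighted_mean)(sq_errors, False)
```

With `stop_gradient`, the factors act as constants, so the gradient with respect to each error is simply its softmax weight. Without it, the gradient picks up an additional term from differentiating through the softmax.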

In [None]:
def make_loss_fn(beta):
    def metric_fn(q, q_hat):
        return x_xy.maths.angle_error(q, q_hat) ** 2

    if beta is not None:

        def loss_fn(q, q_hat):
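            # q.shape == q_hat.shape == (T_tbp, 4)
            # squared angle errors, one per time step
            angles = metric_fn(q, q_hat)

            # softmax weights over time, rescaled so that they average to one;
            # stop_gradient keeps the weights out of the backward pass
            factors = angles.shape[-1] * softmax(
                beta * jax.lax.stop_gradient(angles), axis=-1
            )

            errors = factors * angles

            return errors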

    else:
        loss_fn = metric_fn

    return loss_fn

`beta` determines the strength of our weighting: the larger `beta`, the more relative weight we put on the larger errors, while `beta = 0.0` makes all scaling factors equal to one and gives us back our unweighted errors. Alternatively, `beta = None` bypasses the scaling altogether.
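
The effect of `beta` on the scaling factors can be checked on a toy example. This is a sketch with made-up error values; `weighting_factors` is a hypothetical helper that mirrors the factor computation of the loss function without the `stop_gradient` call, which does not affect the forward values.

```python
import jax.numpy as jnp
from jax.nn import softmax


def weighting_factors(sq_errors, beta):
    """Softmax weights over the last axis, rescaled so that they average to one."""
    return sq_errors.shape[-1] * softmax(beta * sq_errors, axis=-1)


sq_errors = jnp.array([0.1, 0.5, 2.0])  # made-up squared angle errors

factors_uniform = weighting_factors(sq_errors, 0.0)  # all factors equal to one
factors_mild = weighting_factors(sq_errors, 1.0)  # largest error weighted more
factors_sharp = weighting_factors(sq_errors, 5.0)  # largest error dominates
```

For every `beta` the factors sum to the number of time steps, so the weighting only redistributes emphasis between time steps.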

In [None]:
beta = 1.0

In [None]:
rnno = ml.make_rnno(sys_inference)

loss_fn = make_loss_fn(beta)

save_params = ml.callbacks.SaveParamsTrainingLoopCallback(
    "parameters.pickle", upload=False
)

ml.train(gen, NUM_TRAINING_EPISODES, rnno, callbacks=[save_params], loss_fn=loss_fn)

To visualise the predictions of our network, we can render them using mediapy. First we generate some motion data.

In [None]:
gen = x_xy.build_generator(sys, config)

key = jax.random.PRNGKey(1)

q, xs = gen(key)

We again need to bring the motion data into the correct form for our RNNo and can then run inference on the generated data.

In [None]:
params = ml.load("parameters.pickle")

X, y = finalise_fn(key, q, xs, sys)

X = tree_utils.add_batch_dim(X)

_, state = rnno.init(key, X)

state = tree_utils.add_batch_dim(state)

y_hat, _ = rnno.apply(params, state, X)
y_hat = tree_utils.to_2d_if_3d(y_hat, strict=True)

First we want to plot the angle error for both segment 2 and segment 3 over time.

In [None]:
y["seg2"][:10]

In [None]:
y_hat["seg2"]

In [None]:

fig, ax = plt.subplots()

angle_error2 = jnp.rad2deg(x_xy.maths.angle_error(y["seg2"], y_hat["seg2"]))
angle_error3 = jnp.rad2deg(x_xy.maths.angle_error(y["seg3"], y_hat["seg3"]))

T = jnp.arange(angle_error2.size) * sys_inference.dt

ax.plot(T, angle_error2, label="seg2")
ax.plot(T, angle_error3, label="seg3")

ax.set_xlabel("time [s]")
ax.set_ylabel("abs. angle error [deg]")

ax.legend()

plt.show()

Next we have to create an `xs_hat` of the estimated orientations, so that we can render them.

In [None]:
# Extract translations from data-generating system...
translations, rotations = sim2real.unzip_xs(
    sys_inference, sim2real.match_xs(sys_inference, xs, sys)
)

# invert the predicted quaternions (`jax.tree_map` is deprecated in newer JAX versions)
y_hat_inv = jax.tree_util.tree_map(x_xy.maths.quat_inv, y_hat)

# ... swap rotations with predicted ones...
rotations_hat = [] 
for i, name in enumerate(sys_inference.link_names):
    if name in y_hat_inv:
        rotations_name = x_xy.Transform.create(rot=y_hat_inv[name])
    else:
        rotations_name = rotations.take(i, axis=1)
    rotations_hat.append(rotations_name)

# ... and plug the positions and rotations back together.
rotations_hat = rotations_hat[0].batch(*rotations_hat[1:]).transpose((1, 0, 2))
xs_hat = sim2real.zip_xs(sys_inference, translations, rotations_hat)

# Create combined system that shall be rendered and its transforms
sys_render = sys_composer.inject_system(sys, sys_inference.add_prefix_suffix(suffix="_hat"))
xs_render = x_xy.Transform.concatenate(xs, xs_hat, axis=1)

Now we can render both the predicted system (in white) as well as the real system (in orange).

In [None]:
xs_list = [xs_render[i] for i in range(xs_render.shape())]

frames = x_xy.render(sys_render, xs_list, camera="targetfar")
mediapy.show_video([frame[..., :3] for frame in frames], fps=int(1 / sys.dt))
