You may have noticed that saving trained models and controllers is no longer as straightforward as it used to be.

However, it is still possible. Let's see how.

In [1]:
from cc import save, load  
from cc.examples.feedforward_controller import make_feedforward_controller
from cc.env import make_env
import equinox as eqx 
from cc.env.collect import collect
import jax.numpy as jnp 

In [2]:
env = make_env("two_segments_v1", random=1)
controller = make_feedforward_controller(jnp.ones((1000,1)))

`ReplaySample` objects can still be saved and loaded with the more convenient `save` and `load` functions.

In [3]:
replay_sample = collect(env, controller)

save(replay_sample, "replay_sample.pkl"); 
_ = load("replay_sample.pkl")

However, this does not work for objects that inherit from `eqx.Module`, such as the controller:

In [4]:
save(controller, "controller.pkl")

Exception: Not possible. Use `eqx.tree_serialise_leaves(path, obj)` instead.
            To de-serialise use `eqx.tree_deserialise_leaves`.

Instead, we have to use Equinox's more verbose serialisation functions:

In [5]:
# the `.eqx` extension is only a convention
# this call dumps only the parameters (array leaves) to disk,
# not the structure of the controller
eqx.tree_serialise_leaves("controller.eqx", controller)

# hence, to de-serialise, we have to provide a controller with the same
# structure; a copy of it is returned with the parameters read from disk
controller = eqx.tree_deserialise_leaves("controller.eqx", controller)

In [6]:
controller

FeedforwardController(us=f32[1000,1], count=i32[1])

What happens if you still try to load this file using `load`?

In [7]:
load("controller.eqx")

UnpicklingError: STACK_GLOBAL requires str