In [1]:
import jax.random as jrand
import numpy as np
import optax

from cc import save, load
from cc.collect import collect_sample
from cc.env import make_env, ModelBasedEnv
from cc.model import LinearModel, LinearModelOptions, NonlinearModelOptions, NonlinearModel
from cc.model.eval_model import eval_model
from cc.train import train_model, TrainingOptionsModel



In [2]:
# Episode settings, reused below when building the model-based environment.
# With these values each rollout has time_limit / control_timestep + 1 = 1001 samples
# (compare the prediction shape further down).
time_limit = 10.0
control_timestep = 0.01

env = make_env(
    "two_segments_v1",
    time_limit=time_limit,
    control_timestep=control_timestep,
    random=1,  # NOTE(review): presumably seeds the environment's randomness — confirm in cc.env
)

In [3]:
# Training data: one rollout per entry of `seeds_gp` plus one per entry of `seeds_cos`.
# NOTE(review): seed 3 is absent from the training `seeds_gp` — intentional?
train_sample = collect_sample(
    env,
    seeds_gp=[0, 1, 2, *range(4, 15)],
    seeds_cos=list(range(1, 16)),
)

# Test data (5 + 5 = 10 rollouts). Despite its name, `seeds_cos` holds frequencies, not seeds.
# NOTE(review): frequencies 5.0 and 10.0 also occur in the training `seeds_cos` (1..15),
# so those test rollouts may not be held out — confirm this is intended.
test_sample = collect_sample(
    env,
    seeds_gp=list(range(15, 20)),
    seeds_cos=[2.5, 5.0, 7.5, 10.0, 12.5],
)

            If this message is display continuously then you probably forgot to compile the model or controller. 
            This can be fixed by calling `*model/controller* = equniox.filter_jit(*model/controller*).
            


In [29]:
# Architecture of the nonlinear model: an MLP `f` (input: state + action, output: state-sized)
# and an MLP `g` (input: state, output: observation), as shown in the printed model at the end.
options = NonlinearModelOptions(
    12,                # state size
    1,                 # input (action) size
    1,                 # output (observation) size
    "EE",              # integration method (presumably explicit Euler — TODO confirm)
    jrand.PRNGKey(1),  # PRNG key for parameter initialisation
    depth_f=2,         # number of hidden layers of `f`
    width_f=25,        # units per hidden layer of `f`
    depth_g=0,         # `g` has no hidden layers, i.e. it is a single linear map
)

In [30]:
# Instantiate the (still untrained) model from the options above.
model = NonlinearModel(options)

In [31]:
# Single forward step with a constant action of 0.2 (the model's input size is 1).
# `np` is imported in the top-level imports cell rather than here.
action = np.array([0.2])
# Calling the model returns a new model with an updated internal state,
# together with the predicted observation.
new_model, predicted_obs = model(action)

            If this message is display continuously then you probably forgot to compile the model or controller. 
            This can be fixed by calling `*model/controller* = equniox.filter_jit(*model/controller*).
            


In [32]:
# The prediction is a dict keyed by observation name (same key as the environment's observation).
predicted_obs

OrderedDict([('xpos_of_segment_end',
              DeviceArray([0.01827196], dtype=float32))])

In [33]:
# The step above returned another NonlinearModel, carrying the updated internal state.
type(new_model)

cc.model.model.NonlinearModel

In [35]:
training_options = TrainingOptionsModel(
    optax.adam(3e-3),  # optimizer: Adam with learning rate 3e-3
    0.05,  # NOTE(review): meaning not visible here (regularisation weight?) — confirm
    1000,  # number of training steps (matches the 1000-step progress bar)
    1,     # NOTE(review): meaning not visible here — confirm
    True,  # NOTE(review): meaning not visible here (progress output?) — confirm
)

# Train from a freshly initialised model instead of from the current `model`.
# This keeps the cell idempotent: re-running it reproduces the same result rather
# than continuing to train an already-trained model. The PRNG key stored in
# `options` fixes the parameter initialisation, so this should match the `model`
# built above.
# Takes ~25 s on the author's machine and reaches a test loss of ~4.2 on v1.
model, losses = train_model(
    NonlinearModel(options),
    train_sample,
    training_options=training_options,
    test_sample=test_sample,
)

  0%|          | 0/1000 [00:00<?, ?it/s]

            If this message is display continuously then you probably forgot to compile the model or controller. 
            This can be fixed by calling `*model/controller* = equniox.filter_jit(*model/controller*).
            
            If this message is display continuously then you probably forgot to compile the model or controller. 
            This can be fixed by calling `*model/controller* = equniox.filter_jit(*model/controller*).
            
            If this message is display continuously then you probably forgot to compile the model or controller. 
            This can be fixed by calling `*model/controller* = equniox.filter_jit(*model/controller*).
            
            If this message is display continuously then you probably forgot to compile the model or controller. 
            This can be fixed by calling `*model/controller* = equniox.filter_jit(*model/controller*).
            


Trainings-Loss:     0.8498 | Test-Loss:     7.1419 | ParamsRegu:     0.0259:   0%|          | 1/1000 [00:01<19:37,  1.18s/it]

            If this message is display continuously then you probably forgot to compile the model or controller. 
            This can be fixed by calling `*model/controller* = equniox.filter_jit(*model/controller*).
            
            If this message is display continuously then you probably forgot to compile the model or controller. 
            This can be fixed by calling `*model/controller* = equniox.filter_jit(*model/controller*).
            
            If this message is display continuously then you probably forgot to compile the model or controller. 
            This can be fixed by calling `*model/controller* = equniox.filter_jit(*model/controller*).
            
            If this message is display continuously then you probably forgot to compile the model or controller. 
            This can be fixed by calling `*model/controller* = equniox.filter_jit(*model/controller*).
            


Trainings-Loss:     0.3967 | Test-Loss:     4.2053 | ParamsRegu:     0.0280: 100%|██████████| 1000/1000 [00:25<00:00, 39.09it/s]


In [36]:
# `eval_model` runs the trained model on every trajectory of `test_sample` and returns
# the predicted observations together with the RMSE for each observation key.
predicted_observation, test_rmse = eval_model(model, test_sample)

            If this message is display continuously then you probably forgot to compile the model or controller. 
            This can be fixed by calling `*model/controller* = equniox.filter_jit(*model/controller*).
            
            If this message is display continuously then you probably forgot to compile the model or controller. 
            This can be fixed by calling `*model/controller* = equniox.filter_jit(*model/controller*).
            


In [37]:
# Presumably (n_test_trajectories, n_timesteps, output_size): 10 = 5 GP + 5 cosine test
# rollouts, 1001 = time_limit / control_timestep + 1 — TODO confirm the axis order.
predicted_observation["xpos_of_segment_end"].shape

(10, 1001, 1)

In [38]:
# Test-set RMSE, one value per observation key.
test_rmse

OrderedDict([('xpos_of_segment_end', DeviceArray(1.9408518, dtype=float32))])

Finally, you can replace the `MuJoCo` physics component of the environment with the trained model.

This creates a new environment that looks exactly the same from the outside: it is stepped the same way and returns observations under the same keys as `env`.

In [39]:
# Wrap the trained model in an environment that uses the same time settings as `env`.
env_model = ModelBasedEnv(env, model, time_limit=time_limit, control_timestep=control_timestep)

In [40]:
# Step it like `env`; the returned TimeStep carries the model's predicted observation.
env_model.step([0.2])

TimeStep(step_type=<StepType.MID: 1>, reward=array(0.), discount=array(1., dtype=float32), observation=OrderedDict([('xpos_of_segment_end', array([0.04514713], dtype=float32))]))

In [41]:
# Persist the trained model to disk.
save(model, "model_for_two_segments_v1.pkl")

In [42]:
# Reload the model to check the round trip; the output shows the restored architecture.
# NOTE(review): a `.pkl` file is presumably unpickled here — only load files you trust.
load("model_for_two_segments_v1.pkl")

NonlinearModel(
  rhs=NonlinearRHS(
    f=MLP(
      layers=[
        Linear(
          weight=f32[25,13],
          bias=f32[25],
          in_features=13,
          out_features=25,
          use_bias=True
        ),
        Linear(
          weight=f32[25,25],
          bias=f32[25],
          in_features=25,
          out_features=25,
          use_bias=True
        ),
        Linear(
          weight=f32[12,25],
          bias=f32[12],
          in_features=25,
          out_features=12,
          use_bias=True
        )
      ],
      activation=<CompiledFunction>,
      final_activation=<function <lambda>>,
      in_size=13,
      out_size=12,
      width_size=25,
      depth=2
    ),
    g=MLP(
      layers=[
        Linear(
          weight=f32[1,12],
          bias=f32[1],
          in_features=12,
          out_features=1,
          use_bias=True
        )
      ],
      activation=<CompiledFunction>,
      final_activation=<function <lambda>>,
      in_size=12,
      out_si