Skip to content

Commit

Permalink
move python computation logic into c++ (#208)
Browse files Browse the repository at this point in the history
Moved python computation in `_to_dm()` function in `dm_envpool.py` to
c++ implementation.

After fix, envpool's speedup increased from ~0.7x to ~0.9x:

```
Namespace(domain='cheetah', seed=0, task='run', total_step=200000)
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 200000/200000 [00:16<00:00, 11947.27it/s]
FPS(dmc) = 11945.41
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 200000/200000 [00:17<00:00, 11233.41it/s]
FPS(envpool) = 11233.09
EnvPool Speedup: 0.94x
```

Added two fields in `State`: 

- discount: default by (1.0 - `done`)
- step_type: aligned with `dm_env.StepType` value
  • Loading branch information
wangsiping97 committed Oct 26, 2022
1 parent fee5a0a commit 93474cf
Show file tree
Hide file tree
Showing 19 changed files with 72 additions and 58 deletions.
1 change: 0 additions & 1 deletion envpool/atari/atari_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ class AtariEnvFns {
{conf["stack_num"_] * (conf["gray_scale"_] ? 1 : 3),
conf["img_height"_], conf["img_width"_]},
{0, 255})),
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})),
"info:lives"_.Bind(Spec<int>({-1})),
"info:reward"_.Bind(Spec<float>({-1})),
"info:terminated"_.Bind(Spec<int>({-1}, {0, 1})));
Expand Down
8 changes: 7 additions & 1 deletion envpool/core/env.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,13 @@ class Env {
State Allocate(int player_num = 1) {
slice_ = sbq_->Allocate(player_num, order_);
State state(&slice_.arr);
state["done"_] = IsDone();
bool done = IsDone();
state["done"_] = done;
state["discount"_] = static_cast<float>(!done);
// dm_env.StepType.FIRST == 0
// dm_env.StepType.MID == 1
// dm_env.StepType.LAST == 2
state["step_type"_] = current_step_ == 0 ? 0 : done ? 2 : 1;
state["info:env_id"_] = env_id_;
state["elapsed_step"_] = current_step_;
int* player_env_id(static_cast<int*>(state["info:players.env_id"_].Data()));
Expand Down
4 changes: 3 additions & 1 deletion envpool/core/env_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ auto common_state_spec =
MakeDict("info:env_id"_.Bind(Spec<int>({})),
"info:players.env_id"_.Bind(Spec<int>({-1})),
"elapsed_step"_.Bind(Spec<int>({})), "done"_.Bind(Spec<bool>({})),
"reward"_.Bind(Spec<float>({-1})));
"reward"_.Bind(Spec<float>({-1})),
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})),
"step_type"_.Bind(Spec<int>({})));

/**
* EnvSpec funciton, it constructs the env spec when a Config is passed.
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/acrobot.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ class AcrobotEnvFns {
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
return MakeDict("obs:orientations"_.Bind(Spec<mjtNum>({4})),
"obs:velocity"_.Bind(Spec<mjtNum>({2})),
"obs:velocity"_.Bind(Spec<mjtNum>({2}))
#ifdef ENVPOOL_TEST
"info:qpos0"_.Bind(Spec<mjtNum>({2})),
,
"info:qpos0"_.Bind(Spec<mjtNum>({2}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/ball_in_cup.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ class BallInCupEnvFns {
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
return MakeDict("obs:position"_.Bind(Spec<mjtNum>({4})),
"obs:velocity"_.Bind(Spec<mjtNum>({4})),
"obs:velocity"_.Bind(Spec<mjtNum>({4}))
#ifdef ENVPOOL_TEST
"info:qpos0"_.Bind(Spec<mjtNum>({4})),
,
"info:qpos0"_.Bind(Spec<mjtNum>({4}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/cartpole.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ class CartpoleEnvFns {
" for dmc cartpole.");
}
return MakeDict("obs:position"_.Bind(Spec<mjtNum>({1 + 2 * n_poles})),
"obs:velocity"_.Bind(Spec<mjtNum>({1 + n_poles})),
"obs:velocity"_.Bind(Spec<mjtNum>({1 + n_poles}))
#ifdef ENVPOOL_TEST
,
"info:qpos0"_.Bind(Spec<mjtNum>({1 + n_poles})),
"info:qvel0"_.Bind(Spec<mjtNum>({1 + n_poles})),
"info:qvel0"_.Bind(Spec<mjtNum>({1 + n_poles}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/cheetah.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ class CheetahEnvFns {
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
return MakeDict("obs:position"_.Bind(Spec<mjtNum>({8})),
"obs:velocity"_.Bind(Spec<mjtNum>({9})),
"obs:velocity"_.Bind(Spec<mjtNum>({9}))
#ifdef ENVPOOL_TEST
"info:qpos0"_.Bind(Spec<mjtNum>({9})),
,
"info:qpos0"_.Bind(Spec<mjtNum>({9}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/finger.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,13 @@ class FingerEnvFns {
"obs:velocity"_.Bind(Spec<mjtNum>({3})),
"obs:touch"_.Bind(Spec<mjtNum>({2})),
"obs:target_position"_.Bind(Spec<mjtNum>({2})),
"obs:dist_to_target"_.Bind(Spec<mjtNum>({})),
"obs:dist_to_target"_.Bind(Spec<mjtNum>({}))
#ifdef ENVPOOL_TEST
,
"info:qpos0"_.Bind(Spec<mjtNum>({3})),
"info:target"_.Bind(Spec<mjtNum>({1})),
"info:target"_.Bind(Spec<mjtNum>({1}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/fish.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@ class FishEnvFns {
return MakeDict("obs:joint_angles"_.Bind(Spec<mjtNum>({7})),
"obs:upright"_.Bind(Spec<mjtNum>({})),
"obs:velocity"_.Bind(Spec<mjtNum>({13})),
"obs:target"_.Bind(Spec<mjtNum>({3})),
"obs:target"_.Bind(Spec<mjtNum>({3}))
#ifdef ENVPOOL_TEST
,
"info:qpos0"_.Bind(Spec<mjtNum>({14})),
"info:target0"_.Bind(Spec<mjtNum>({3})),
"info:target0"_.Bind(Spec<mjtNum>({3}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/hopper.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ class HopperEnvFns {
static decltype(auto) StateSpec(const Config& conf) {
return MakeDict("obs:position"_.Bind(Spec<mjtNum>({6})),
"obs:velocity"_.Bind(Spec<mjtNum>({7})),
"obs:touch"_.Bind(Spec<mjtNum>({2})),
"obs:touch"_.Bind(Spec<mjtNum>({2}))
#ifdef ENVPOOL_TEST
"info:qpos0"_.Bind(Spec<mjtNum>({7})),
,
"info:qpos0"_.Bind(Spec<mjtNum>({7}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/humanoid.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ class HumanoidEnvFns {
"obs:torso_vertical"_.Bind(Spec<mjtNum>({3})),
"obs:com_velocity"_.Bind(Spec<mjtNum>({3})),
"obs:position"_.Bind(Spec<mjtNum>({28})),
"obs:velocity"_.Bind(Spec<mjtNum>({27})),
"obs:velocity"_.Bind(Spec<mjtNum>({27}))
#ifdef ENVPOOL_TEST
"info:qpos0"_.Bind(Spec<mjtNum>({28})),
,
"info:qpos0"_.Bind(Spec<mjtNum>({28}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/humanoid_CMU.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ class HumanoidCMUEnvFns {
"obs:extremities"_.Bind(Spec<mjtNum>({12})),
"obs:torso_vertical"_.Bind(Spec<mjtNum>({3})),
"obs:com_velocity"_.Bind(Spec<mjtNum>({3})),
"obs:velocity"_.Bind(Spec<mjtNum>({62})),
"obs:velocity"_.Bind(Spec<mjtNum>({62}))
#ifdef ENVPOOL_TEST
"info:qpos0"_.Bind(Spec<mjtNum>({63})),
,
"info:qpos0"_.Bind(Spec<mjtNum>({63}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/manipulator.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,13 @@ class ManipulatorEnvFns {
"obs:hand_pos"_.Bind(Spec<mjtNum>({4})),
"obs:object_pos"_.Bind(Spec<mjtNum>({4})),
"obs:object_vel"_.Bind(Spec<mjtNum>({3})),
"obs:target_pos"_.Bind(Spec<mjtNum>({4})),
"obs:target_pos"_.Bind(Spec<mjtNum>({4}))
#ifdef ENVPOOL_TEST
,
"info:qpos0"_.Bind(Spec<mjtNum>({11})),
"info:random_info"_.Bind(Spec<mjtNum>({8})),
"info:random_info"_.Bind(Spec<mjtNum>({8}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/pendulum.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ class PendulumEnvFns {
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
return MakeDict("obs:orientation"_.Bind(Spec<mjtNum>({2})),
"obs:velocity"_.Bind(Spec<mjtNum>({1})),
"obs:velocity"_.Bind(Spec<mjtNum>({1}))
#ifdef ENVPOOL_TEST
"info:qpos0"_.Bind(Spec<mjtNum>({1})),
,
"info:qpos0"_.Bind(Spec<mjtNum>({1}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/point_mass.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ class PointMassEnvFns {
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
return MakeDict("obs:position"_.Bind(Spec<mjtNum>({2})),
"obs:velocity"_.Bind(Spec<mjtNum>({2})),
"obs:velocity"_.Bind(Spec<mjtNum>({2}))
#ifdef ENVPOOL_TEST
,
"info:qpos0"_.Bind(Spec<mjtNum>({2})),
"info:wrap_prm"_.Bind(Spec<mjtNum>({4})),
"info:wrap_prm"_.Bind(Spec<mjtNum>({4}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/reacher.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@ class ReacherEnvFns {
static decltype(auto) StateSpec(const Config& conf) {
return MakeDict("obs:position"_.Bind(Spec<mjtNum>({2})),
"obs:to_target"_.Bind(Spec<mjtNum>({2})),
"obs:velocity"_.Bind(Spec<mjtNum>({2})),
"obs:velocity"_.Bind(Spec<mjtNum>({2}))
#ifdef ENVPOOL_TEST
,
"info:qpos0"_.Bind(Spec<mjtNum>({2})),
"info:target"_.Bind(Spec<mjtNum>({2})),
"info:target"_.Bind(Spec<mjtNum>({2}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/swimmer.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,13 @@ class SwimmerEnvFns {
}
return MakeDict("obs:joints"_.Bind(Spec<mjtNum>({n_bodies - 1})),
"obs:to_target"_.Bind(Spec<mjtNum>({2})),
"obs:body_velocities"_.Bind(Spec<mjtNum>({3 * n_bodies})),
"obs:body_velocities"_.Bind(Spec<mjtNum>({3 * n_bodies}))
#ifdef ENVPOOL_TEST
,
"info:qpos0"_.Bind(Spec<mjtNum>({n_bodies + 2})),
"info:target0"_.Bind(Spec<mjtNum>({2})),
"info:target0"_.Bind(Spec<mjtNum>({2}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
7 changes: 4 additions & 3 deletions envpool/mujoco/dmc/walker.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,12 @@ class WalkerEnvFns {
static decltype(auto) StateSpec(const Config& conf) {
return MakeDict("obs:orientations"_.Bind(Spec<mjtNum>({14})),
"obs:height"_.Bind(Spec<mjtNum>({})),
"obs:velocity"_.Bind(Spec<mjtNum>({9})),
"obs:velocity"_.Bind(Spec<mjtNum>({9}))
#ifdef ENVPOOL_TEST
"info:qpos0"_.Bind(Spec<mjtNum>({9})),
,
"info:qpos0"_.Bind(Spec<mjtNum>({9}))
#endif
"discount"_.Bind(Spec<float>({-1}, {0.0, 1.0})));
); // NOLINT
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
Expand Down
12 changes: 2 additions & 10 deletions envpool/python/dm_envpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,11 @@ def _to_dm(
state = treevalue.unflatten(
[(path, vi) for (path, _), vi in zip(tree_pairs, values)]
)
done = state.done
elapse = state.elapsed_step
discount = getattr(state, "discount", (1.0 - done).astype(np.float32))
step_type = (
(elapse == 0) * dm_env.StepType.FIRST +
((elapse > 0) & ~done) * dm_env.StepType.MID +
done * dm_env.StepType.LAST
)
timestep = TimeStep(
step_type=step_type,
step_type=state.step_type,
observation=state.State,
reward=state.reward,
discount=discount,
discount=state.discount,
)
return timestep

Expand Down

0 comments on commit 93474cf

Please sign in to comment.