Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cartpole env #25

Merged
merged 9 commits into from
Nov 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions envpool/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ py_library(
srcs = ["entry.py"],
deps = [
"//envpool/atari:atari_registration",
"//envpool/classic_control:classic_control_registration",
],
)

Expand All @@ -27,6 +28,7 @@ py_library(
":entry",
":registration",
"//envpool/atari",
"//envpool/classic_control",
"//envpool/python",
],
)
Expand Down
48 changes: 48 additions & 0 deletions envpool/classic_control/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
load("@pip_requirements//:requirements.bzl", "requirement")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(default_visibility = ["//visibility:public"])

# Header-only C++ library (hdrs, no srcs) exposing cartpole.h; it depends on
# the core async env pool target.
cc_library(
name = "cartpole",
hdrs = ["cartpole.h"],
deps = [
"//envpool/core:async_envpool",
],
)

# pybind11 extension built from classic_control.cc; it pulls in the CartPole
# library and the core Python env pool support target.
pybind_extension(
name = "classic_control_envpool",
srcs = [
"classic_control.cc",
],
deps = [
":cartpole",
"//envpool/core:py_envpool",
],
)

# Python package for the classic control environments. The compiled extension
# (.so) is attached as a data dependency rather than as a Python dep.
py_library(
name = "classic_control",
srcs = ["__init__.py"],
data = [":classic_control_envpool.so"],
deps = ["//envpool/python:api"],
)

# Unit test for the classic control package; needs numpy and absl-py from the
# pip requirements in addition to the package under test.
py_test(
name = "classic_control_test",
srcs = ["classic_control_test.py"],
deps = [
":classic_control",
requirement("numpy"),
requirement("absl-py"),
],
)

# Task registration module. It is a separate target from the package itself
# and only depends on the shared registration library.
py_library(
name = "classic_control_registration",
srcs = ["registration.py"],
deps = [
"//envpool:registration",
],
)
28 changes: 28 additions & 0 deletions envpool/classic_control/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classic control env in EnvPool."""

from envpool.python.api import py_env

from .classic_control_envpool import _CartPoleEnvPool, _CartPoleEnvSpec

# Wrap the raw extension classes with py_env. The three objects it returns are
# bound, in order, to the spec class, the dm_env-flavoured pool class and the
# gym-flavoured pool class.
(
  CartPoleEnvSpec,
  CartPoleDMEnvPool,
  CartPoleGymEnvPool,
) = py_env(_CartPoleEnvSpec, _CartPoleEnvPool)

# Public names re-exported by this package.
__all__ = ["CartPoleEnvSpec", "CartPoleDMEnvPool", "CartPoleGymEnvPool"]
126 changes: 126 additions & 0 deletions envpool/classic_control/cartpole.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright 2021 Garena Online Private Limited
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef ENVPOOL_CLASSIC_CONTROL_CARTPOLE_H_
#define ENVPOOL_CLASSIC_CONTROL_CARTPOLE_H_

#include <cmath>
#include <random>

#include "envpool/core/async_envpool.h"
#include "envpool/core/env.h"

namespace classic_control {

class CartPoleEnvFns {
public:
static decltype(auto) DefaultConfig() {
return MakeDict("max_episode_steps"_.bind(200),
"reward_threshold"_.bind(195.0));
}
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
// TODO(jiayi): specify range with [4.8, fmax, np.pi / 7.5, fmax]
mavenlin marked this conversation as resolved.
Show resolved Hide resolved
return MakeDict("obs"_.bind(Spec<float>({4})));
}
template <typename Config>
static decltype(auto) ActionSpec(const Config& conf) {
return MakeDict("action"_.bind(Spec<int>({-1}, {0, 1})));
}
};

typedef class EnvSpec<CartPoleEnvFns> CartPoleEnvSpec;

/**
 * CartPole simulation exposed through the `Env` interface.
 *
 * The state is (x_, x_dot_, theta_, theta_dot_). Each step applies a force of
 * +kForceMag when the action equals 1 and -kForceMag otherwise, then advances
 * the state with step size kTau. An episode is done once max_episode_steps_
 * steps have elapsed, or x_ leaves [-kXThreshold, kXThreshold], or theta_
 * leaves [-kThetaThresholdRadians, kThetaThresholdRadians].
 */
class CartPoleEnv : public Env<CartPoleEnvSpec> {
 protected:
  // Simulation constants.
  const double kPi = std::acos(-1);
  const double kGravity = 9.8;
  const double kMassCart = 1.0;
  const double kMassPole = 0.1;
  const double kMassTotal = kMassCart + kMassPole;
  const double kLength = 0.5;
  const double kMassPoleLength = kMassPole * kLength;
  const double kForceMag = 10.0;
  const double kTau = 0.02;
  // 12 degrees expressed in radians.
  const double kThetaThresholdRadians = 12 * 2 * kPi / 360;
  const double kXThreshold = 2.4;
  // Half-width of the uniform range used to draw the initial state.
  const double kInitRange = 0.05;

  // Declaration order matters: the constructor initializes elapsed_step_ from
  // max_episode_steps_, so the latter must be declared first.
  int max_episode_steps_;
  int elapsed_step_;
  // State variables are zero-initialized so they never hold indeterminate
  // values before the first Reset().
  double x_{0.0};
  double x_dot_{0.0};
  double theta_{0.0};
  double theta_dot_{0.0};
  std::uniform_real_distribution<> dist_;
  bool done_;

 public:
  /**
   * Builds the environment from its spec.
   *
   * The environment starts out reporting itself as done (done_ is true and
   * elapsed_step_ is already past the cap), so IsDone() returns true until
   * Reset() has been called.
   *
   * @param spec   environment spec; `max_episode_steps` is read from its config
   * @param env_id identifier forwarded to the base class
   */
  CartPoleEnv(const Spec& spec, int env_id)
      : Env<CartPoleEnvSpec>(spec, env_id),
        max_episode_steps_(spec.config["max_episode_steps"_]),
        elapsed_step_(max_episode_steps_ + 1),
        dist_(-kInitRange, kInitRange),
        done_(true) {}

  /** @return true when the current episode has ended (or none has started). */
  bool IsDone() override { return done_; }

  /**
   * Starts a new episode: draws the four state variables from dist_, clears
   * the done flag and step counter, and writes the initial observation with
   * a reward of 0.
   */
  void Reset() override {
    // Draw order is part of the behaviour: it fixes which random number ends
    // up in which state variable for a given generator state.
    x_ = dist_(gen_);
    x_dot_ = dist_(gen_);
    theta_ = dist_(gen_);
    theta_dot_ = dist_(gen_);
    done_ = false;
    elapsed_step_ = 0;
    State state = Allocate();
    WriteObs(state, 0.0f);
  }

  /**
   * Advances the simulation by one step and writes the new observation with
   * a reward of 1.
   *
   * @param action action dictionary; its "action" entry selects the force
   *               direction (1 pushes with +kForceMag, anything else with
   *               -kForceMag)
   */
  void Step(const Action& action) override {
    // Step-limit check first; the bounds check below can only set done_.
    done_ = (++elapsed_step_ >= max_episode_steps_);
    int act = action["action"_];
    double force = act == 1 ? kForceMag : -kForceMag;
    double costheta = std::cos(theta_);
    double sintheta = std::sin(theta_);
    double temp =
        (force + kMassPoleLength * theta_dot_ * theta_dot_ * sintheta) /
        kMassTotal;
    double theta_acc =
        (kGravity * sintheta - costheta * temp) /
        (kLength * (4.0 / 3.0 - kMassPole * costheta * costheta / kMassTotal));
    double x_acc = temp - kMassPoleLength * theta_acc * costheta / kMassTotal;

    // Explicit Euler update: each position is advanced with the velocity from
    // before this step, then the velocity is advanced with the acceleration.
    x_ += kTau * x_dot_;
    x_dot_ += kTau * x_acc;
    theta_ += kTau * theta_dot_;
    theta_dot_ += kTau * theta_acc;
    if (x_ < -kXThreshold || x_ > kXThreshold ||
        theta_ < -kThetaThresholdRadians || theta_ > kThetaThresholdRadians) {
      done_ = true;
    }

    State state = Allocate();
    WriteObs(state, 1.0f);
  }

 private:
  /**
   * Copies the current state into the "obs" entry (narrowed to float) and
   * stores the reward.
   *
   * @param state  output state to fill
   * @param reward value written to the "reward" entry
   */
  void WriteObs(State& state, float reward) {  // NOLINT
    state["obs"_][0] = static_cast<float>(x_);
    state["obs"_][1] = static_cast<float>(x_dot_);
    state["obs"_][2] = static_cast<float>(theta_);
    state["obs"_][3] = static_cast<float>(theta_dot_);
    state["reward"_] = reward;
  }
};

// Asynchronous pool of CartPole environments.
using CartPoleEnvPool = AsyncEnvPool<CartPoleEnv>;

} // namespace classic_control

#endif // ENVPOOL_CLASSIC_CONTROL_CARTPOLE_H_
23 changes: 23 additions & 0 deletions envpool/classic_control/classic_control.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Copyright 2021 Garena Online Private Limited
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "envpool/classic_control/cartpole.h"
#include "envpool/core/py_envpool.h"

// Python-facing wrappers around the C++ CartPole spec and pool types.
using CartPoleEnvSpec = PyEnvSpec<classic_control::CartPoleEnvSpec>;
using CartPoleEnvPool = PyEnvPool<classic_control::CartPoleEnvPool>;

// Defines the Python extension module `classic_control_envpool` and registers
// the CartPole spec and pool wrappers on it through the REGISTER macro.
PYBIND11_MODULE(classic_control_envpool, m) {
REGISTER(m, CartPoleEnvSpec, CartPoleEnvPool)
}
49 changes: 49 additions & 0 deletions envpool/classic_control/classic_control_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for classic control environments."""

import numpy as np
from absl.testing import absltest

from envpool.classic_control import CartPoleEnvSpec, CartPoleGymEnvPool


class _CartPoleEnvPoolTest(absltest.TestCase):
  """Seeding and observation-range checks for the CartPole env pool."""

  def test_seed(self) -> None:
    """Equal seeds replay identically; a different seed diverges.

    Three pools are stepped with the same random actions: two built with
    seed 0 and one with seed 1. On every step the seed-0 pools must yield
    matching observations, the seed-1 pool must differ from them, and all
    observations must stay strictly inside ``obs_range``.
    """
    num_envs = 4

    def make_pool(seed: int):
      # num_envs is passed explicitly so each pool holds as many
      # environments as the action batch sent to step() below; without it
      # the pool would be created with the default size.
      config = CartPoleEnvSpec.gen_config(
        max_episode_steps=200, seed=seed, num_envs=num_envs
      )
      return CartPoleGymEnvPool(CartPoleEnvSpec(config))

    env0 = make_pool(0)
    env1 = make_pool(0)
    env2 = make_pool(1)
    fmax = np.finfo(np.float32).max
    # Per-component bound on |obs|; the two velocity components are only
    # limited by the float32 range.
    obs_range = np.array([4.8, fmax, np.pi / 7.5, fmax])
    for _ in range(5000):
      action = np.random.randint(2, size=(num_envs,))
      obs0 = env0.step(action)[0]
      obs1 = env1.step(action)[0]
      obs2 = env2.step(action)[0]
      np.testing.assert_allclose(obs0, obs1)
      self.assertTrue(np.abs(obs0 - obs2).sum() > 0)
      self.assertTrue(np.all(np.abs(obs0) < obs_range))
      self.assertTrue(np.all(np.abs(obs2) < obs_range))


# Script entry point: hand control to absltest's test runner.
if __name__ == "__main__":
  absltest.main()
36 changes: 36 additions & 0 deletions envpool/classic_control/registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2021 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classic control env registration."""

from envpool.registration import register

# CartPole-v0: episodes capped at 200 steps, reward threshold 195.
register(
  task_id="CartPole-v0",
  import_path="envpool.classic_control",
  spec_cls="CartPoleEnvSpec",
  dm_cls="CartPoleDMEnvPool",
  gym_cls="CartPoleGymEnvPool",
  max_episode_steps=200,
  reward_threshold=195.0,
)

# CartPole-v1: same classes as v0, with episodes capped at 500 steps and a
# reward threshold of 475.
register(
  task_id="CartPole-v1",
  import_path="envpool.classic_control",
  spec_cls="CartPoleEnvSpec",
  dm_cls="CartPoleDMEnvPool",
  gym_cls="CartPoleGymEnvPool",
  max_episode_steps=500,
  reward_threshold=475.0,
)
1 change: 1 addition & 0 deletions envpool/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
"""Entry point for all envs' registration."""

import envpool.atari.registration # noqa: F401
import envpool.classic_control.registration # noqa: F401
12 changes: 12 additions & 0 deletions envpool/make_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,18 @@ def test_make_atari(self) -> None:
self.assertEqual(env_gym.action_space.n, 18)
self.assertEqual(env_dm.action_spec().num_values, 18)

def test_make_classic(self) -> None:
  """CartPole tasks can be built via make_spec, make_gym and make_dm."""
  v0_spec = envpool.make_spec("CartPole-v0")
  gym_pool = envpool.make_gym("CartPole-v1")
  dm_pool = envpool.make_dm("CartPole-v1")
  print(dm_pool)
  print(gym_pool)
  # Each factory must hand back an object of the matching framework type.
  self.assertIsInstance(gym_pool, gym.Env)
  self.assertIsInstance(dm_pool, dm_env.Environment)
  # Every task id carries its own registered reward threshold.
  self.assertEqual(v0_spec.reward_threshold, 195.0)
  self.assertEqual(gym_pool.spec.reward_threshold, 475.0)


# Script entry point: hand control to absltest's test runner.
if __name__ == "__main__":
  absltest.main()
Loading