From e3ed6379a6d43b4e70fb3b10959a457e705b513c Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sat, 20 Nov 2021 18:04:49 -0500 Subject: [PATCH 1/6] add cartpole --- envpool/BUILD | 1 + envpool/classic_control/BUILD | 40 ++++ envpool/classic_control/__init__.py | 28 +++ envpool/classic_control/cartpole.h | 126 ++++++++++++ envpool/classic_control/classic_control.cc | 23 +++ .../classic_control/classic_control_test.py | 48 +++++ envpool/entry.py | 22 +++ envpool/make_test.py | 12 ++ examples/cartpole_ppo.py | 183 ++++++++++++++++++ setup.cfg | 1 + 10 files changed, 484 insertions(+) create mode 100644 envpool/classic_control/BUILD create mode 100644 envpool/classic_control/__init__.py create mode 100644 envpool/classic_control/cartpole.h create mode 100644 envpool/classic_control/classic_control.cc create mode 100644 envpool/classic_control/classic_control_test.py create mode 100644 examples/cartpole_ppo.py diff --git a/envpool/BUILD b/envpool/BUILD index 24b74c13..c4026a12 100644 --- a/envpool/BUILD +++ b/envpool/BUILD @@ -27,6 +27,7 @@ py_library( ":entry", ":registration", "//envpool/atari", + "//envpool/classic_control", "//envpool/python", ], ) diff --git a/envpool/classic_control/BUILD b/envpool/classic_control/BUILD new file mode 100644 index 00000000..b32b5de5 --- /dev/null +++ b/envpool/classic_control/BUILD @@ -0,0 +1,40 @@ +load("@pip_requirements//:requirements.bzl", "requirement") +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "cartpole", + hdrs = ["cartpole.h"], + deps = [ + "//envpool/core:async_envpool", + ], +) + +pybind_extension( + name = "classic_control_envpool", + srcs = [ + "classic_control.cc", + ], + deps = [ + ":cartpole", + "//envpool/core:py_envpool", + ], +) + +py_library( + name = "classic_control", + srcs = ["__init__.py"], + data = [":classic_control_envpool.so"], + deps = ["//envpool/python:api"], +) + +py_test( + name = "classic_control_test", + srcs = ["classic_control_test.py"], + deps = [ + "//envpool", + requirement("numpy"), + requirement("absl-py"), + ], +) diff --git a/envpool/classic_control/__init__.py b/envpool/classic_control/__init__.py new file mode 100644 index 00000000..d99bf283 --- /dev/null +++ b/envpool/classic_control/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2021 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Classic control env in EnvPool.""" + +from envpool.python.api import py_env + +from .classic_control_envpool import _CartPoleEnvPool, _CartPoleEnvSpec + +CartPoleEnvSpec, CartPoleDMEnvPool, CartPoleGymEnvPool = py_env( + _CartPoleEnvSpec, _CartPoleEnvPool +) + +__all__ = [ + "CartPoleEnvSpec", + "CartPoleDMEnvPool", + "CartPoleGymEnvPool", +] diff --git a/envpool/classic_control/cartpole.h b/envpool/classic_control/cartpole.h new file mode 100644 index 00000000..d8cd46a8 --- /dev/null +++ b/envpool/classic_control/cartpole.h @@ -0,0 +1,126 @@ +/* + * Copyright 2021 Garena Online Private Limited + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef ENVPOOL_CLASSIC_CONTROL_CARTPOLE_H_ +#define ENVPOOL_CLASSIC_CONTROL_CARTPOLE_H_ + +#include +#include + +#include "envpool/core/async_envpool.h" +#include "envpool/core/env.h" + +namespace classic_control { + +class CartPoleEnvFns { + public: + static decltype(auto) DefaultConfig() { + return MakeDict("max_episode_steps"_.bind(200), + "reward_threshold"_.bind(195.0)); + } + template + static decltype(auto) StateSpec(const Config& conf) { + // TODO(jiayi): specify range with [4.8, fmax, np.pi / 7.5, fmax] + return MakeDict("obs"_.bind(Spec({4}))); + } + template + static decltype(auto) ActionSpec(const Config& conf) { + return MakeDict("action"_.bind(Spec({-1}, {0, 1}))); + } +}; + +typedef class EnvSpec CartPoleEnvSpec; + +class CartPoleEnv : public Env { + protected: + const double kPi = std::acos(-1); + const double kGravity = 9.8; + const double kMassCart = 1.0; + const double kMassPole = 0.1; + const double kMassTotal = kMassCart + kMassPole; + const double kLength = 0.5; + const double kMassPoleLength = kMassPole * kLength; + const double kForceMag = 10.0; + const double kTau = 0.02; + const double kThetaThresholdRadians = 12 * 2 * kPi / 360; + const double kXThreshold = 2.4; + const double kInitRange = 0.05; + int max_episode_steps_, elapsed_step_; + double x_, x_dot_, theta_, theta_dot_; + std::uniform_real_distribution<> dist_; + bool done_; + + public: + CartPoleEnv(const Spec& spec, int env_id) + : Env(spec, env_id), + max_episode_steps_(spec.config["max_episode_steps"_]), + elapsed_step_(max_episode_steps_ + 1), + dist_(-0.05, 0.05), + done_(true) {} + + bool IsDone() override { return done_; } + + void Reset() override { + x_ = dist_(gen_); + x_dot_ = dist_(gen_); + theta_ = dist_(gen_); + theta_dot_ = dist_(gen_); + done_ = false; + elapsed_step_ = 0; + State state = Allocate(); + WriteObs(state, 0.0f); + } + + void Step(const Action& action) override { + done_ = (++elapsed_step_ >= max_episode_steps_); + int act = action["action"_]; + double force = act == 1 ? kForceMag : -kForceMag; + double costheta = std::cos(theta_); + double sintheta = std::sin(theta_); + double temp = + (force + kMassPoleLength * theta_dot_ * theta_dot_ * sintheta) / + kMassTotal; + double theta_acc = + (kGravity * sintheta - costheta * temp) / + (kLength * (4.0 / 3.0 - kMassPole * costheta * costheta / kMassTotal)); + double x_acc = temp - kMassPoleLength * theta_acc * costheta / kMassTotal; + + x_ += kTau * x_dot_; + x_dot_ += kTau * x_acc; + theta_ += kTau * theta_dot_; + theta_dot_ += kTau * theta_acc; + if (x_ < -kXThreshold || x_ > kXThreshold || + theta_ < -kThetaThresholdRadians || theta_ > kThetaThresholdRadians) + done_ = true; + + State state = Allocate(); + WriteObs(state, 1.0f); + } + + private: + void WriteObs(State& state, float reward) { // NOLINT + state["obs"_][0] = static_cast(x_); + state["obs"_][1] = static_cast(x_dot_); + state["obs"_][2] = static_cast(theta_); + state["obs"_][3] = static_cast(theta_dot_); + state["reward"_] = reward; + } +}; + +typedef AsyncEnvPool CartPoleEnvPool; + +} // namespace classic_control + +#endif // ENVPOOL_CLASSIC_CONTROL_CARTPOLE_H_ diff --git a/envpool/classic_control/classic_control.cc b/envpool/classic_control/classic_control.cc new file mode 100644 index 00000000..52ca7054 --- /dev/null +++ b/envpool/classic_control/classic_control.cc @@ -0,0 +1,23 @@ +// Copyright 2021 Garena Online Private Limited +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "envpool/classic_control/cartpole.h" +#include "envpool/core/py_envpool.h" + +typedef PyEnvSpec CartPoleEnvSpec; +typedef PyEnvPool CartPoleEnvPool; + +PYBIND11_MODULE(classic_control_envpool, m) { + REGISTER(m, CartPoleEnvSpec, CartPoleEnvPool) +} diff --git a/envpool/classic_control/classic_control_test.py b/envpool/classic_control/classic_control_test.py new file mode 100644 index 00000000..5c340792 --- /dev/null +++ b/envpool/classic_control/classic_control_test.py @@ -0,0 +1,48 @@ +# Copyright 2021 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from absl.testing import absltest + +from envpool.classic_control import CartPoleEnvSpec, CartPoleGymEnvPool + + +class _CartPoleEnvPoolTest(absltest.TestCase): + + def test_seed(self) -> None: + num_envs = 4 + config = CartPoleEnvSpec.gen_config(max_episode_steps=200, seed=0) + spec = CartPoleEnvSpec(config) + env0 = CartPoleGymEnvPool(spec) + config = CartPoleEnvSpec.gen_config(max_episode_steps=200, seed=0) + spec = CartPoleEnvSpec(config) + env1 = CartPoleGymEnvPool(spec) + config = CartPoleEnvSpec.gen_config(max_episode_steps=200, seed=1) + spec = CartPoleEnvSpec(config) + env2 = CartPoleGymEnvPool(spec) + fmax = np.finfo(np.float32).max + obs_range = np.array([4.8, fmax, np.pi / 7.5, fmax]) + for _ in range(1000): + action = np.random.randint(2, size=(num_envs,)) + obs0 = env0.step(action)[0] + obs1 = env1.step(action)[0] + obs2 = env2.step(action)[0] + np.testing.assert_allclose(obs0, obs1) + self.assertTrue(np.abs(obs0 - obs2).sum() > 0) + self.assertTrue(np.all(np.abs(obs0) < obs_range)) + self.assertTrue(np.all(np.abs(obs2) < obs_range)) + + +if __name__ == "__main__": + absltest.main() diff --git a/envpool/entry.py b/envpool/entry.py index 44fde252..3ab63e58 100644 --- a/envpool/entry.py +++ b/envpool/entry.py @@ -35,3 +35,25 @@ task=game, base_path=base_path, ) + +# Classic Control + +register( + task_id="CartPole-v0", + import_path="envpool.classic_control", + spec_cls="CartPoleEnvSpec", + dm_cls="CartPoleDMEnvPool", + gym_cls="CartPoleGymEnvPool", + max_episode_steps=200, + reward_threshold=195.0, +) + +register( + task_id="CartPole-v1", + import_path="envpool.classic_control", + spec_cls="CartPoleEnvSpec", + dm_cls="CartPoleDMEnvPool", + gym_cls="CartPoleGymEnvPool", + max_episode_steps=500, + reward_threshold=475.0, +) diff --git a/envpool/make_test.py b/envpool/make_test.py index 4f109823..1aa5f2d5 100644 --- a/envpool/make_test.py +++ b/envpool/make_test.py @@ -40,6 +40,18 @@ def test_make_atari(self) -> None: self.assertEqual(env_gym.action_space.n, 18) self.assertEqual(env_dm.action_spec().num_values, 18) + def test_make_classic(self) -> None: + spec = envpool.make_spec("CartPole-v0") + env_gym = envpool.make_gym("CartPole-v1") + env_dm = envpool.make_dm("CartPole-v1") + print(env_dm) + print(env_gym) + self.assertIsInstance(env_gym, gym.Env) + self.assertIsInstance(env_dm, dm_env.Environment) + # check reward threshold + self.assertEqual(spec.reward_threshold, 195.0) + self.assertEqual(env_gym.spec.reward_threshold, 475.0) + if __name__ == "__main__": absltest.main() diff --git a/examples/cartpole_ppo.py b/examples/cartpole_ppo.py new file mode 100644 index 00000000..dd34d9e6 --- /dev/null +++ b/examples/cartpole_ppo.py @@ -0,0 +1,183 @@ +# Copyright 2021 Garena Online Private Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pprint + +import gym +import numpy as np +import pytest +import torch +from tianshou.data import Collector, VectorReplayBuffer +from tianshou.env import DummyVectorEnv +from tianshou.policy import PPOPolicy +from tianshou.trainer import onpolicy_trainer +from tianshou.utils import TensorboardLogger +from tianshou.utils.net.common import Net +from tianshou.utils.net.discrete import Actor, Critic +from torch.utils.tensorboard import SummaryWriter + +import envpool + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--task", type=str, default="CartPole-v1") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--buffer-size", type=int, default=20000) + parser.add_argument("--lr", type=float, default=1e-3) + parser.add_argument("--gamma", type=float, default=0.99) + parser.add_argument("--epoch", type=int, default=10) + parser.add_argument("--step-per-epoch", type=int, default=50000) + parser.add_argument("--episode-per-collect", type=int, default=20) + parser.add_argument("--repeat-per-collect", type=int, default=2) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--hidden-sizes", type=int, nargs="*", default=[64, 64]) + parser.add_argument("--training-num", type=int, default=20) + parser.add_argument("--test-num", type=int, default=100) + parser.add_argument("--logdir", type=str, default="log") + parser.add_argument( + "--device", + type=str, + default="cuda" if torch.cuda.is_available() else "cpu" + ) + # ppo special + parser.add_argument("--vf-coef", type=float, default=0.5) + parser.add_argument("--ent-coef", type=float, default=0.0) + parser.add_argument("--eps-clip", type=float, default=0.2) + parser.add_argument("--max-grad-norm", type=float, default=0.5) + parser.add_argument("--gae-lambda", type=float, default=0.95) + parser.add_argument("--rew-norm", type=int, default=1) + parser.add_argument("--dual-clip", type=float, default=None) + parser.add_argument("--value-clip", type=int, default=1) + parser.add_argument("--watch", action="store_true") + args = parser.parse_known_args()[0] + return args + + +def run_ppo(args=get_args()): + env = gym.make(args.task) + if args.task == "CartPole-v0": + env.spec.reward_threshold = 200 + elif args.task == "CartPole-v1": + env.spec.reward_threshold = 500 + + train_envs = envpool.make( + args.task, num_envs=args.training_num, env_type="gym" + ) + test_envs = envpool.make(args.task, num_envs=args.test_num, env_type="gym") + args.state_shape = env.observation_space.shape or env.observation_space.n + args.action_shape = env.action_space.shape or env.action_space.n + ss_ = train_envs.observation_space.shape or train_envs.observation_space.n + assert ss_ == args.state_shape + as_ = train_envs.action_space.shape or train_envs.action_space.n + assert as_ == args.action_shape + + # seed + np.random.seed(args.seed) + torch.manual_seed(args.seed) + # model + net = Net( + args.state_shape, hidden_sizes=args.hidden_sizes, device=args.device + ) + actor = Actor(net, args.action_shape, device=args.device).to(args.device) + critic = Critic(net, device=args.device).to(args.device) + # orthogonal initialization + for m in list(actor.modules()) + list(critic.modules()): + if isinstance(m, torch.nn.Linear): + torch.nn.init.orthogonal_(m.weight) + torch.nn.init.zeros_(m.bias) + optim = torch.optim.Adam( + set(actor.parameters()).union(critic.parameters()), lr=args.lr + ) + dist = torch.distributions.Categorical + policy = PPOPolicy( + actor, + critic, + optim, + dist, + discount_factor=args.gamma, + max_grad_norm=args.max_grad_norm, + eps_clip=args.eps_clip, + vf_coef=args.vf_coef, + ent_coef=args.ent_coef, + gae_lambda=args.gae_lambda, + reward_normalization=args.rew_norm, + dual_clip=args.dual_clip, + value_clip=args.value_clip, + action_space=env.action_space, + deterministic_eval=True + ) + # collector + train_collector = Collector( + policy, train_envs, VectorReplayBuffer(args.buffer_size, len(train_envs)) + ) + test_collector = Collector(policy, test_envs) + # log + log_path = os.path.join(args.logdir, args.task, "ppo") + writer = SummaryWriter(log_path) + logger = TensorboardLogger(writer) + + def save_fn(policy): + torch.save(policy.state_dict(), os.path.join(log_path, "policy.pth")) + + def stop_fn(mean_rewards): + return mean_rewards >= env.spec.reward_threshold + + def watch(): + # Let's watch its performance! + env = DummyVectorEnv( + [lambda: gym.make(args.task) for _ in range(args.test_num)] + ) + policy.eval() + collector = Collector(policy, env) + result = collector.collect(n_episode=args.test_num) + rews, lens = result["rews"], result["lens"] + print(f"Final reward: {rews.mean()}, length: {lens.mean()}") + assert stop_fn(rews.mean() + 5) + return rews.mean() + + if args.watch: + return watch() + + # trainer + result = onpolicy_trainer( + policy, + train_collector, + test_collector, + args.epoch, + args.step_per_epoch, + args.repeat_per_collect, + args.test_num, + args.batch_size, + episode_per_collect=args.episode_per_collect, + stop_fn=stop_fn, + save_fn=save_fn, + logger=logger + ) + pprint.pprint(result) + assert stop_fn(result["best_reward"]) + return watch() + + +@pytest.mark.parametrize("task", ["CartPole-v0", "CartPole-v1"]) +def test_classic(task: str): + args = get_args() + args.task = task + run_ppo(args) + + +if __name__ == "__main__": + run_ppo(get_args()) diff --git a/setup.cfg b/setup.cfg index aae5af25..67d21516 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,6 +34,7 @@ include = envpool* [options.package_data] envpool = atari/*.so atari/atari_roms/*/*.bin + classic_control/*.so [yapf] based_on_style = yapf From da448d7f43aebc6889814ae526abfbb7e9de7323 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sat, 20 Nov 2021 18:09:10 -0500 Subject: [PATCH 2/6] fix lint --- envpool/classic_control/classic_control_test.py | 1 + examples/cartpole_ppo.py | 10 +--------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/envpool/classic_control/classic_control_test.py b/envpool/classic_control/classic_control_test.py index 5c340792..de5f66f2 100644 --- a/envpool/classic_control/classic_control_test.py +++ b/envpool/classic_control/classic_control_test.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Unit tests for classic control environments.""" import numpy as np from absl.testing import absltest diff --git a/examples/cartpole_ppo.py b/examples/cartpole_ppo.py index dd34d9e6..2bd93ebe 100644 --- a/examples/cartpole_ppo.py +++ b/examples/cartpole_ppo.py @@ -18,7 +18,6 @@ import gym import numpy as np -import pytest import torch from tianshou.data import Collector, VectorReplayBuffer from tianshou.env import DummyVectorEnv @@ -67,7 +66,7 @@ def get_args(): return args -def run_ppo(args=get_args()): +def run_ppo(args): env = gym.make(args.task) if args.task == "CartPole-v0": env.spec.reward_threshold = 200 @@ -172,12 +171,5 @@ def watch(): return watch() -@pytest.mark.parametrize("task", ["CartPole-v0", "CartPole-v1"]) -def test_classic(task: str): - args = get_args() - args.task = task - run_ppo(args) - - if __name__ == "__main__": run_ppo(get_args()) From 112da718ebf4033444175127c4cee9b34b1b4384 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sat, 20 Nov 2021 18:10:45 -0500 Subject: [PATCH 3/6] \t -> ' ' --- setup.cfg | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.cfg b/setup.cfg index 67d21516..231b6c7e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -26,7 +26,7 @@ install_requires = gym>=0.18 numpy>=1.19 types-protobuf>=3.17.3 - typing-extensions + typing-extensions [options.packages.find] include = envpool* @@ -34,7 +34,7 @@ include = envpool* [options.package_data] envpool = atari/*.so atari/atari_roms/*/*.bin - classic_control/*.so + classic_control/*.so [yapf] based_on_style = yapf From ee5f90e131d492fc29bfa14e806c70d00e062b73 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sat, 20 Nov 2021 18:17:30 -0500 Subject: [PATCH 4/6] polish --- envpool/classic_control/cartpole.h | 2 +- envpool/classic_control/classic_control_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/envpool/classic_control/cartpole.h b/envpool/classic_control/cartpole.h index d8cd46a8..f1d7111c 100644 --- a/envpool/classic_control/cartpole.h +++ b/envpool/classic_control/cartpole.h @@ -67,7 +67,7 @@ class CartPoleEnv : public Env { : Env(spec, env_id), max_episode_steps_(spec.config["max_episode_steps"_]), elapsed_step_(max_episode_steps_ + 1), - dist_(-0.05, 0.05), + dist_(-kInitRange, kInitRange), done_(true) {} bool IsDone() override { return done_; } diff --git a/envpool/classic_control/classic_control_test.py b/envpool/classic_control/classic_control_test.py index de5f66f2..fa7baad9 100644 --- a/envpool/classic_control/classic_control_test.py +++ b/envpool/classic_control/classic_control_test.py @@ -34,7 +34,7 @@ def test_seed(self) -> None: env2 = CartPoleGymEnvPool(spec) fmax = np.finfo(np.float32).max obs_range = np.array([4.8, fmax, np.pi / 7.5, fmax]) - for _ in range(1000): + for _ in range(5000): action = np.random.randint(2, size=(num_envs,)) obs0 = env0.step(action)[0] obs1 = env1.step(action)[0] From dcbdbc559d321f30b2b416301377065a3023b098 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sat, 20 Nov 2021 18:19:54 -0500 Subject: [PATCH 5/6] rename --- examples/{ => tianshou_examples}/cartpole_ppo.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/{ => tianshou_examples}/cartpole_ppo.py (100%) diff --git a/examples/cartpole_ppo.py b/examples/tianshou_examples/cartpole_ppo.py similarity index 100% rename from examples/cartpole_ppo.py rename to examples/tianshou_examples/cartpole_ppo.py From 2a5c25340c67245a15f818b7447007fbf167f7f8 Mon Sep 17 00:00:00 2001 From: Jiayi Weng Date: Sat, 20 Nov 2021 19:51:20 -0500 Subject: [PATCH 6/6] polish --- envpool/classic_control/BUILD | 2 +- examples/tianshou_examples/cartpole_ppo.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/envpool/classic_control/BUILD b/envpool/classic_control/BUILD index b32b5de5..22d0a661 100644 --- a/envpool/classic_control/BUILD +++ b/envpool/classic_control/BUILD @@ -33,7 +33,7 @@ py_test( name = "classic_control_test", srcs = ["classic_control_test.py"], deps = [ - "//envpool", + ":classic_control", requirement("numpy"), requirement("absl-py"), ], diff --git a/examples/tianshou_examples/cartpole_ppo.py b/examples/tianshou_examples/cartpole_ppo.py index 2bd93ebe..9d34381c 100644 --- a/examples/tianshou_examples/cartpole_ppo.py +++ b/examples/tianshou_examples/cartpole_ppo.py @@ -24,7 +24,7 @@ from tianshou.policy import PPOPolicy from tianshou.trainer import onpolicy_trainer from tianshou.utils import TensorboardLogger -from tianshou.utils.net.common import Net +from tianshou.utils.net.common import ActorCritic, Net from tianshou.utils.net.discrete import Actor, Critic from torch.utils.tensorboard import SummaryWriter @@ -93,14 +93,13 @@ def run_ppo(args): ) actor = Actor(net, args.action_shape, device=args.device).to(args.device) critic = Critic(net, device=args.device).to(args.device) + actor_critic = ActorCritic(actor, critic) # orthogonal initialization - for m in list(actor.modules()) + list(critic.modules()): + for m in actor_critic.modules(): if isinstance(m, torch.nn.Linear): torch.nn.init.orthogonal_(m.weight) torch.nn.init.zeros_(m.bias) - optim = torch.optim.Adam( - set(actor.parameters()).union(critic.parameters()), lr=args.lr - ) + optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr) dist = torch.distributions.Categorical policy = PPOPolicy( actor,