From 24af50b759112c7ab5f76b375f3de20434df3c75 Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Tue, 7 Nov 2023 21:48:28 -0800 Subject: [PATCH 1/2] mountain car wip --- .../mountain_car/mountain_car.py | 124 ++++++++++++++++++ .../mountain_car/mountain_car_step_numba.py | 68 ++++++++++ .../classic_control/test_mountain_car.py | 89 +++++++++++++ 3 files changed, 281 insertions(+) create mode 100644 example_envs/single_agent/classic_control/mountain_car/mountain_car.py create mode 100644 example_envs/single_agent/classic_control/mountain_car/mountain_car_step_numba.py create mode 100644 tests/example_envs/numba_tests/single_agent/classic_control/test_mountain_car.py diff --git a/example_envs/single_agent/classic_control/mountain_car/mountain_car.py b/example_envs/single_agent/classic_control/mountain_car/mountain_car.py new file mode 100644 index 0000000..ea76d1a --- /dev/null +++ b/example_envs/single_agent/classic_control/mountain_car/mountain_car.py @@ -0,0 +1,124 @@ +import numpy as np +from warp_drive.utils.constants import Constants +from warp_drive.utils.data_feed import DataFeed +from warp_drive.utils.gpu_environment_context import CUDAEnvironmentContext + +from example_envs.single_agent.base import SingleAgentEnv, map_to_single_agent, get_action_for_single_agent +from gym.envs.classic_control.mountain_car import MountainCarEnv + +_OBSERVATIONS = Constants.OBSERVATIONS +_ACTIONS = Constants.ACTIONS +_REWARDS = Constants.REWARDS + + +class ClassicControlMountainCarEnv(SingleAgentEnv): + + name = "ClassicControlMountainCarEnv" + + def __init__(self, episode_length, env_backend="cpu", reset_pool_size=0, seed=None): + super().__init__(episode_length, env_backend, reset_pool_size, seed=seed) + + self.gym_env = MountainCarEnv() + + self.action_space = map_to_single_agent(self.gym_env.action_space) + self.observation_space = map_to_single_agent(self.gym_env.observation_space) + + def step(self, action=None): + self.timestep += 1 + action = get_action_for_single_agent(action) + state, reward, terminated, _, _ = self.gym_env.step(action) + + obs = map_to_single_agent(state) + rew = map_to_single_agent(reward) + done = {"__all__": self.timestep >= self.episode_length or terminated} + info = {} + + return obs, rew, done, info + + def reset(self): + self.timestep = 0 + if self.reset_pool_size < 2: + # we use a fixed initial state all the time + initial_state, _ = self.gym_env.reset(seed=self.seed) + else: + initial_state, _ = self.gym_env.reset(seed=None) + obs = map_to_single_agent(initial_state) + + return obs + + +class CUDAClassicControlMountainCarEnv(ClassicControlMountainCarEnv, CUDAEnvironmentContext): + + def get_data_dictionary(self): + data_dict = DataFeed() + initial_state, _ = self.gym_env.reset(seed=self.seed) + + if self.reset_pool_size < 2: + data_dict.add_data( + name="state", + data=np.atleast_2d(initial_state), + save_copy_and_apply_at_reset=True, + ) + else: + data_dict.add_data( + name="state", + data=np.atleast_2d(initial_state), + save_copy_and_apply_at_reset=False, + ) + + data_dict.add_data_list( + [ + ("min_position", self.gym_env.min_position), + ("max_position", self.gym_env.max_position), + ("max_speed", self.gym_env.max_speed), + ("goal_position", self.gym_env.goal_position), + ("goal_velocity", self.gym_env.goal_velocity), + ("force", self.gym_env.force), + ("gravity", self.gym_env.gravity), + ] + ) + return data_dict + + def get_tensor_dictionary(self): + tensor_dict = DataFeed() + return tensor_dict + + def get_reset_pool_dictionary(self): + reset_pool_dict = 
DataFeed() + if self.reset_pool_size >= 2: + state_reset_pool = [] + for _ in range(self.reset_pool_size): + initial_state, _ = self.gym_env.reset(seed=None) + state_reset_pool.append(np.atleast_2d(initial_state)) + state_reset_pool = np.stack(state_reset_pool, axis=0) + assert len(state_reset_pool.shape) == 3 and state_reset_pool.shape[2] == 2 + + reset_pool_dict.add_pool_for_reset(name="state_reset_pool", + data=state_reset_pool, + reset_target="state") + return reset_pool_dict + + def step(self, actions=None): + self.timestep += 1 + args = [ + "state", + _ACTIONS, + "_done_", + _REWARDS, + _OBSERVATIONS, + "min_position", + "max_position", + "max_speed", + "goal_position", + "goal_velocity", + "force", + "gravity", + "_timestep_", + ("episode_length", "meta"), + ] + if self.env_backend == "numba": + self.cuda_step[ + self.cuda_function_manager.grid, self.cuda_function_manager.block + ](*self.cuda_step_function_feed(args)) + else: + raise Exception("CUDAClassicControlMountainCarEnv expects env_backend = 'numba' ") diff --git a/example_envs/single_agent/classic_control/mountain_car/mountain_car_step_numba.py b/example_envs/single_agent/classic_control/mountain_car/mountain_car_step_numba.py new file mode 100644 index 0000000..c936dd0 --- /dev/null +++ b/example_envs/single_agent/classic_control/mountain_car/mountain_car_step_numba.py @@ -0,0 +1,68 @@ +import numba.cuda as numba_driver +import math + + +@numba_driver.jit +def _clip(v, min, max): + if v < min: + return min + if v > max: + return max + return v + + +@numba_driver.jit +def NumbaClassicControlMountainCarEnvStep( + state_arr, + action_arr, + done_arr, + reward_arr, + observation_arr, + min_position, + max_position, + max_speed, + goal_position, + goal_velocity, + force, + gravity, + env_timestep_arr, + episode_length): + + kEnvId = numba_driver.blockIdx.x + kThisAgentId = numba_driver.threadIdx.x + + assert kThisAgentId == 0, "We only have one agent per environment" + + env_timestep_arr[kEnvId] += 1 + + assert 0 < env_timestep_arr[kEnvId] <= episode_length + + reward_arr[kEnvId, kThisAgentId] = 0.0 + + action = action_arr[kEnvId, kThisAgentId, 0] + + position = state_arr[kEnvId, kThisAgentId, 0] + velocity = state_arr[kEnvId, kThisAgentId, 1] + + velocity += (action - 1) * force + math.cos(3 * position) * (-gravity) + velocity = _clip(velocity, -max_speed, max_speed) + position += velocity + position = _clip(position, min_position, max_position) + if position == min_position and velocity < 0: + velocity = 0 + + state_arr[kEnvId, kThisAgentId, 0] = position + state_arr[kEnvId, kThisAgentId, 1] = velocity + + observation_arr[kEnvId, kThisAgentId, 0] = state_arr[kEnvId, kThisAgentId, 0] + observation_arr[kEnvId, kThisAgentId, 1] = state_arr[kEnvId, kThisAgentId, 1] + + terminated = bool( + position >= goal_position and velocity >= goal_velocity + ) + + # as long as not reset, we assign reward -1. 
This is consistent with original cartpole logic + reward_arr[kEnvId, kThisAgentId] = -1.0 + + if env_timestep_arr[kEnvId] == episode_length or terminated: + done_arr[kEnvId] = 1 \ No newline at end of file diff --git a/tests/example_envs/numba_tests/single_agent/classic_control/test_mountain_car.py b/tests/example_envs/numba_tests/single_agent/classic_control/test_mountain_car.py new file mode 100644 index 0000000..e0085cd --- /dev/null +++ b/tests/example_envs/numba_tests/single_agent/classic_control/test_mountain_car.py @@ -0,0 +1,89 @@ +import unittest +import numpy as np +import torch + +from warp_drive.env_cpu_gpu_consistency_checker import EnvironmentCPUvsGPU +from example_envs.single_agent.classic_control.mountain_car.mountain_car import \ + ClassicControlMountainCarEnv, CUDAClassicControlMountainCarEnv +from warp_drive.env_wrapper import EnvWrapper + + +env_configs = { + "test1": { + "episode_length": 500, + "reset_pool_size": 0, + "seed": 32145, + }, + "test2": { + "episode_length": 200, + "reset_pool_size": 0, + "seed": 54231, + }, +} + + +class MyTestCase(unittest.TestCase): + """ + CPU v GPU consistency unit tests + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.testing_class = EnvironmentCPUvsGPU( + cpu_env_class=ClassicControlMountainCarEnv, + cuda_env_class=CUDAClassicControlMountainCarEnv, + env_configs=env_configs, + gpu_env_backend="numba", + num_envs=5, + num_episodes=2, + ) + + def test_env_consistency(self): + try: + self.testing_class.test_env_reset_and_step() + except AssertionError: + self.fail("ClassicControlMountainCarEnv environment consistency tests failed") + + def test_reset_pool(self): + env_wrapper = EnvWrapper( + env_obj=CUDAClassicControlMountainCarEnv(episode_length=100, reset_pool_size=3), + num_envs=3, + env_backend="numba", + ) + env_wrapper.reset_all_envs() + env_wrapper.env_resetter.init_reset_pool(env_wrapper.cuda_data_manager, seed=12345) + self.assertTrue(env_wrapper.cuda_data_manager.reset_target_to_pool["state"] == "state_reset_pool") + + # squeeze() the agent dimension which is 1 always + state_after_initial_reset = env_wrapper.cuda_data_manager.pull_data_from_device("state").squeeze() + + reset_pool = env_wrapper.cuda_data_manager.pull_data_from_device( + env_wrapper.cuda_data_manager.get_reset_pool("state")) + reset_pool_mean = reset_pool.mean(axis=0).squeeze() + + env_wrapper.cuda_data_manager.data_on_device_via_torch("_done_")[:] = torch.from_numpy( + np.array([1, 1, 0]) + ).cuda() + + state_values = {0: [], 1: [], 2: []} + for _ in range(10000): + env_wrapper.env_resetter.reset_when_done(env_wrapper.cuda_data_manager, mode="if_done", undo_done_after_reset=False) + res = env_wrapper.cuda_data_manager.pull_data_from_device("state") + state_values[0].append(res[0]) + state_values[1].append(res[1]) + state_values[2].append(res[2]) + + state_values_env0_mean = np.stack(state_values[0]).mean(axis=0).squeeze() + state_values_env1_mean = np.stack(state_values[1]).mean(axis=0).squeeze() + state_values_env2_mean = np.stack(state_values[2]).mean(axis=0).squeeze() + + for i in range(len(reset_pool_mean)): + self.assertTrue(np.absolute(state_values_env0_mean[i] - reset_pool_mean[i]) < 0.1 * abs(reset_pool_mean[i])) + self.assertTrue(np.absolute(state_values_env1_mean[i] - reset_pool_mean[i]) < 0.1 * abs(reset_pool_mean[i])) + self.assertTrue( + np.absolute( + state_values_env2_mean[i] - state_after_initial_reset[0][i] + ) < 0.001 * abs(state_after_initial_reset[0][i]) + ) + + From 
2093cd229d6b716b45562ab62416fd3df25ffef6 Mon Sep 17 00:00:00 2001 From: Tian Lan Date: Wed, 8 Nov 2023 10:22:42 -0800 Subject: [PATCH 2/2] unittest mountain car --- .../classic_control/test_cartpole.py | 2 + .../classic_control/test_mountain_car.py | 16 ++++--- .../training/example_training_script_numba.py | 10 +++++ .../run_configs/single_mountain_car.yaml | 43 +++++++++++++++++++ warp_drive/utils/numba_utils/misc.py | 1 + 5 files changed, 65 insertions(+), 7 deletions(-) create mode 100644 warp_drive/training/run_configs/single_mountain_car.yaml diff --git a/tests/example_envs/numba_tests/single_agent/classic_control/test_cartpole.py b/tests/example_envs/numba_tests/single_agent/classic_control/test_cartpole.py index 790ad7a..30e92ec 100644 --- a/tests/example_envs/numba_tests/single_agent/classic_control/test_cartpole.py +++ b/tests/example_envs/numba_tests/single_agent/classic_control/test_cartpole.py @@ -61,6 +61,8 @@ def test_reset_pool(self): env_wrapper.cuda_data_manager.get_reset_pool("state")) reset_pool_mean = reset_pool.mean(axis=0).squeeze() + self.assertTrue(reset_pool.std(axis=0).mean() > 1e-4) + env_wrapper.cuda_data_manager.data_on_device_via_torch("_done_")[:] = torch.from_numpy( np.array([1, 1, 0]) ).cuda() diff --git a/tests/example_envs/numba_tests/single_agent/classic_control/test_mountain_car.py b/tests/example_envs/numba_tests/single_agent/classic_control/test_mountain_car.py index e0085cd..f157321 100644 --- a/tests/example_envs/numba_tests/single_agent/classic_control/test_mountain_car.py +++ b/tests/example_envs/numba_tests/single_agent/classic_control/test_mountain_car.py @@ -61,6 +61,9 @@ def test_reset_pool(self): env_wrapper.cuda_data_manager.get_reset_pool("state")) reset_pool_mean = reset_pool.mean(axis=0).squeeze() + # we only need to check the 0th element of state because state[1] = 0 for reset always + self.assertTrue(reset_pool.std(axis=0).squeeze()[0] > 1e-4) + env_wrapper.cuda_data_manager.data_on_device_via_torch("_done_")[:] = torch.from_numpy( np.array([1, 1, 0]) ).cuda() @@ -77,13 +80,12 @@ def test_reset_pool(self): state_values_env1_mean = np.stack(state_values[1]).mean(axis=0).squeeze() state_values_env2_mean = np.stack(state_values[2]).mean(axis=0).squeeze() - for i in range(len(reset_pool_mean)): - self.assertTrue(np.absolute(state_values_env0_mean[i] - reset_pool_mean[i]) < 0.1 * abs(reset_pool_mean[i])) - self.assertTrue(np.absolute(state_values_env1_mean[i] - reset_pool_mean[i]) < 0.1 * abs(reset_pool_mean[i])) - self.assertTrue( - np.absolute( - state_values_env2_mean[i] - state_after_initial_reset[0][i] - ) < 0.001 * abs(state_after_initial_reset[0][i]) + self.assertTrue(np.absolute(state_values_env0_mean[0] - reset_pool_mean[0]) < 0.1 * abs(reset_pool_mean[0])) + self.assertTrue(np.absolute(state_values_env1_mean[0] - reset_pool_mean[0]) < 0.1 * abs(reset_pool_mean[0])) + self.assertTrue( + np.absolute( + state_values_env2_mean[0] - state_after_initial_reset[0][0] + ) < 0.001 * abs(state_after_initial_reset[0][0]) ) diff --git a/warp_drive/training/example_training_script_numba.py b/warp_drive/training/example_training_script_numba.py index a3c48a8..71ab94f 100644 --- a/warp_drive/training/example_training_script_numba.py +++ b/warp_drive/training/example_training_script_numba.py @@ -20,6 +20,7 @@ from example_envs.tag_continuous.tag_continuous import TagContinuous from example_envs.tag_gridworld.tag_gridworld import CUDATagGridWorld, CUDATagGridWorldWithResetPool from example_envs.single_agent.classic_control.cartpole.cartpole import 
CUDAClassicControlCartPoleEnv
+from example_envs.single_agent.classic_control.mountain_car.mountain_car import CUDAClassicControlMountainCarEnv
 from warp_drive.env_wrapper import EnvWrapper
 from warp_drive.training.trainer import Trainer
 from warp_drive.training.utils.distributed_train.distributed_trainer_numba import (
@@ -35,6 +36,7 @@
 _TAG_GRIDWORLD_WITH_RESET_POOL = "tag_gridworld_with_reset_pool"
 _CLASSIC_CONTROL_CARTPOLE = "single_cartpole"
+__CLASSIC_CONTROL_MOUNTAIN_CAR = "single_mountain_car"

 # Example usages (from the root folder):
 # >> python warp_drive/training/example_training_script.py -e tag_gridworld
@@ -92,6 +94,14 @@ def setup_trainer_and_train(
             event_messenger=event_messenger,
             process_id=device_id,
         )
+    elif run_configuration["name"] == __CLASSIC_CONTROL_MOUNTAIN_CAR:
+        env_wrapper = EnvWrapper(
+            CUDAClassicControlMountainCarEnv(**run_configuration["env"]),
+            num_envs=num_envs,
+            env_backend="numba",
+            event_messenger=event_messenger,
+            process_id=device_id,
+        )
     else:
         raise NotImplementedError(
             f"Currently, the environments supported are ["
diff --git a/warp_drive/training/run_configs/single_mountain_car.yaml b/warp_drive/training/run_configs/single_mountain_car.yaml
new file mode 100644
index 0000000..d1730c1
--- /dev/null
+++ b/warp_drive/training/run_configs/single_mountain_car.yaml
@@ -0,0 +1,43 @@
+# Copyright (c) 2021, salesforce.com, inc.
+# All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+# For full license text, see the LICENSE file in the repo root
+# or https://opensource.org/licenses/BSD-3-Clause
+
+# YAML configuration for the mountain car environment
+name: "single_mountain_car"
+# Environment settings
+env:
+    episode_length: 500
+    reset_pool_size: 1000
+# Trainer settings
+trainer:
+    num_envs: 100 # number of environment replicas
+    num_episodes: 200000 # number of episodes to run the training for. Can be arbitrarily high!
+ train_batch_size: 50000 # total batch size used for training per iteration (across all the environments) + env_backend: "numba" # environment backend, pycuda or numba +# Policy network settings +policy: # list all the policies below + shared: + to_train: True # flag indicating whether the model needs to be trained + algorithm: "A2C" # algorithm used to train the policy + vf_loss_coeff: 1 # loss coefficient schedule for the value function loss + entropy_coeff: 0.05 # loss coefficient schedule for the entropy loss + clip_grad_norm: True # flag indicating whether to clip the gradient norm or not + max_grad_norm: 3 # when clip_grad_norm is True, the clip level + normalize_advantage: False # flag indicating whether to normalize advantage or not + normalize_return: False # flag indicating whether to normalize return or not + gamma: 0.99 # discount factor + lr: 0.001 # learning rate + model: # policy model settings + type: "fully_connected" # model type + fc_dims: [32, 32] # dimension(s) of the fully connected layers as a list + model_ckpt_filepath: "" # filepath (used to restore a previously saved model) +# Checkpoint saving setting +saving: + metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics + model_params_save_freq: 5000 # how often (in iterations) to save the model parameters + basedir: "/tmp" # base folder used for saving + name: "single_mountain_car" # base folder used for saving + tag: "experiments" # experiment name + diff --git a/warp_drive/utils/numba_utils/misc.py b/warp_drive/utils/numba_utils/misc.py index 68c7985..ebcd1c3 100644 --- a/warp_drive/utils/numba_utils/misc.py +++ b/warp_drive/utils/numba_utils/misc.py @@ -18,6 +18,7 @@ def get_default_env_directory(env_name): "TagGridWorld": "example_envs.tag_gridworld.tag_gridworld_step_numba", "TagContinuous": "example_envs.tag_continuous.tag_continuous_step_numba", "ClassicControlCartPoleEnv": "example_envs.single_agent.classic_control.cartpole.cartpole_step_numba", + "ClassicControlMountainCarEnv": "example_envs.single_agent.classic_control.mountain_car.mountain_car_step_numba", "YOUR_ENVIRONMENT": "PYTHON_PATH_TO_YOUR_ENV_SRC", } return envs.get(env_name, None)
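
For reference, the per-step dynamics in NumbaClassicControlMountainCarEnvStep are meant to mirror Gym's MountainCarEnv update. Below is a minimal plain-Python restatement for readers reviewing the kernel; the helper name is made up for illustration, and the default constant values are the ones Gym's MountainCarEnv exposes (the kernel itself receives them through the data dictionary: force, gravity, max_speed, min_position, max_position).

import math

def mountain_car_update(position, velocity, action,
                        force=0.001, gravity=0.0025, max_speed=0.07,
                        min_position=-1.2, max_position=0.6):
    # action is 0 (push left), 1 (no push), or 2 (push right)
    velocity += (action - 1) * force + math.cos(3 * position) * (-gravity)
    # clip velocity to [-max_speed, max_speed]
    velocity = max(-max_speed, min(velocity, max_speed))
    position += velocity
    # clip position to [min_position, max_position]
    position = max(min_position, min(position, max_position))
    if position == min_position and velocity < 0:
        # the car cannot move past the left wall
        velocity = 0.0
    return position, velocity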
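
As a quick smoke test of the new GPU environment, the following sketch uses only calls that already appear in the added unit test (EnvWrapper construction, reset_all_envs, init_reset_pool, pull_data_from_device); the parameter values are illustrative, not prescribed by the patch.

from example_envs.single_agent.classic_control.mountain_car.mountain_car import (
    CUDAClassicControlMountainCarEnv,
)
from warp_drive.env_wrapper import EnvWrapper

# Three GPU replicas of the single-agent environment; reset_pool_size >= 2
# makes resets draw initial states from a pool rather than replaying one
# fixed initial state.
env_wrapper = EnvWrapper(
    env_obj=CUDAClassicControlMountainCarEnv(episode_length=500, reset_pool_size=3),
    num_envs=3,
    env_backend="numba",
)
env_wrapper.reset_all_envs()
env_wrapper.env_resetter.init_reset_pool(env_wrapper.cuda_data_manager, seed=12345)

# Each replica stores a (position, velocity) pair for its single agent.
state = env_wrapper.cuda_data_manager.pull_data_from_device("state")
print(state.shape)  # expected: (3, 1, 2)

To train with the new run config, the example-usage comment already in example_training_script_numba.py suggests a command along the lines of: python warp_drive/training/example_training_script_numba.py -e single_mountain_car (the exact flag is an assumption based on the "-e tag_gridworld" example in that script).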