Skip to content

Commit

Permalink
unittest mountain car
Browse files Browse the repository at this point in the history
  • Loading branch information
Emerald01 committed Nov 8, 2023
1 parent 24af50b commit 2093cd2
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def test_reset_pool(self):
env_wrapper.cuda_data_manager.get_reset_pool("state"))
reset_pool_mean = reset_pool.mean(axis=0).squeeze()

self.assertTrue(reset_pool.std(axis=0).mean() > 1e-4)

env_wrapper.cuda_data_manager.data_on_device_via_torch("_done_")[:] = torch.from_numpy(
np.array([1, 1, 0])
).cuda()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def test_reset_pool(self):
env_wrapper.cuda_data_manager.get_reset_pool("state"))
reset_pool_mean = reset_pool.mean(axis=0).squeeze()

# we only need to check the 0th element of state because state[1] = 0 for reset always
self.assertTrue(reset_pool.std(axis=0).squeeze()[0] > 1e-4)

env_wrapper.cuda_data_manager.data_on_device_via_torch("_done_")[:] = torch.from_numpy(
np.array([1, 1, 0])
).cuda()
Expand All @@ -77,13 +80,12 @@ def test_reset_pool(self):
state_values_env1_mean = np.stack(state_values[1]).mean(axis=0).squeeze()
state_values_env2_mean = np.stack(state_values[2]).mean(axis=0).squeeze()

for i in range(len(reset_pool_mean)):
self.assertTrue(np.absolute(state_values_env0_mean[i] - reset_pool_mean[i]) < 0.1 * abs(reset_pool_mean[i]))
self.assertTrue(np.absolute(state_values_env1_mean[i] - reset_pool_mean[i]) < 0.1 * abs(reset_pool_mean[i]))
self.assertTrue(
np.absolute(
state_values_env2_mean[i] - state_after_initial_reset[0][i]
) < 0.001 * abs(state_after_initial_reset[0][i])
self.assertTrue(np.absolute(state_values_env0_mean[0] - reset_pool_mean[0]) < 0.1 * abs(reset_pool_mean[0]))
self.assertTrue(np.absolute(state_values_env1_mean[0] - reset_pool_mean[0]) < 0.1 * abs(reset_pool_mean[0]))
self.assertTrue(
np.absolute(
state_values_env2_mean[0] - state_after_initial_reset[0][0]
) < 0.001 * abs(state_after_initial_reset[0][0])
)


10 changes: 10 additions & 0 deletions warp_drive/training/example_training_script_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from example_envs.tag_continuous.tag_continuous import TagContinuous
from example_envs.tag_gridworld.tag_gridworld import CUDATagGridWorld, CUDATagGridWorldWithResetPool
from example_envs.single_agent.classic_control.cartpole.cartpole import CUDAClassicControlCartPoleEnv
from example_envs.single_agent.classic_control.mountain_car.mountain_car import CUDAClassicControlMountainCarEnv
from warp_drive.env_wrapper import EnvWrapper
from warp_drive.training.trainer import Trainer
from warp_drive.training.utils.distributed_train.distributed_trainer_numba import (
Expand All @@ -35,6 +36,7 @@
_TAG_GRIDWORLD_WITH_RESET_POOL = "tag_gridworld_with_reset_pool"

_CLASSIC_CONTROL_CARTPOLE = "single_cartpole"
__CLASSIC_CONTROL_MOUNTAIN_CAR = "single_mountain_car"

# Example usages (from the root folder):
# >> python warp_drive/training/example_training_script.py -e tag_gridworld
Expand Down Expand Up @@ -92,6 +94,14 @@ def setup_trainer_and_train(
event_messenger=event_messenger,
process_id=device_id,
)
elif run_configuration["name"] == __CLASSIC_CONTROL_MOUNTAIN_CAR:
env_wrapper = EnvWrapper(
CUDAClassicControlMountainCarEnv(**run_configuration["env"]),
num_envs=num_envs,
env_backend="numba",
event_messenger=event_messenger,
process_id=device_id,
)
else:
raise NotImplementedError(
f"Currently, the environments supported are ["
Expand Down
43 changes: 43 additions & 0 deletions warp_drive/training/run_configs/single_mountain_car.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2021, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root
# or https://opensource.org/licenses/BSD-3-Clause

# YAML configuration for the tag gridworld environment
name: "single_mountain_car"
# Environment settings
env:
episode_length: 500
reset_pool_size: 1000
# Trainer settings
trainer:
num_envs: 100 # number of environment replicas
num_episodes: 200000 # number of episodes to run the training for. Can be arbitrarily high!
train_batch_size: 50000 # total batch size used for training per iteration (across all the environments)
env_backend: "numba" # environment backend, pycuda or numba
# Policy network settings
policy: # list all the policies below
shared:
to_train: True # flag indicating whether the model needs to be trained
algorithm: "A2C" # algorithm used to train the policy
vf_loss_coeff: 1 # loss coefficient schedule for the value function loss
entropy_coeff: 0.05 # loss coefficient schedule for the entropy loss
clip_grad_norm: True # flag indicating whether to clip the gradient norm or not
max_grad_norm: 3 # when clip_grad_norm is True, the clip level
normalize_advantage: False # flag indicating whether to normalize advantage or not
normalize_return: False # flag indicating whether to normalize return or not
gamma: 0.99 # discount factor
lr: 0.001 # learning rate
model: # policy model settings
type: "fully_connected" # model type
fc_dims: [32, 32] # dimension(s) of the fully connected layers as a list
model_ckpt_filepath: "" # filepath (used to restore a previously saved model)
# Checkpoint saving setting
saving:
metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics
model_params_save_freq: 5000 # how often (in iterations) to save the model parameters
basedir: "/tmp" # base folder used for saving
name: "single_mountain_car" # base folder used for saving
tag: "experiments" # experiment name

1 change: 1 addition & 0 deletions warp_drive/utils/numba_utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def get_default_env_directory(env_name):
"TagGridWorld": "example_envs.tag_gridworld.tag_gridworld_step_numba",
"TagContinuous": "example_envs.tag_continuous.tag_continuous_step_numba",
"ClassicControlCartPoleEnv": "example_envs.single_agent.classic_control.cartpole.cartpole_step_numba",
"ClassicControlMountainCarEnv": "example_envs.single_agent.classic_control.mountain_car.mountain_car_step_numba",
"YOUR_ENVIRONMENT": "PYTHON_PATH_TO_YOUR_ENV_SRC",
}
return envs.get(env_name, None)
Expand Down

0 comments on commit 2093cd2

Please sign in to comment.