-
Notifications
You must be signed in to change notification settings - Fork 78
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #90 from salesforce/classic_control
Classic control
- Loading branch information
Showing
7 changed files
with
339 additions
and
0 deletions.
There are no files selected for viewing
124 changes: 124 additions & 0 deletions
124
example_envs/single_agent/classic_control/mountain_car/mountain_car.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import numpy as np | ||
from warp_drive.utils.constants import Constants | ||
from warp_drive.utils.data_feed import DataFeed | ||
from warp_drive.utils.gpu_environment_context import CUDAEnvironmentContext | ||
|
||
from example_envs.single_agent.base import SingleAgentEnv, map_to_single_agent, get_action_for_single_agent | ||
from gym.envs.classic_control.mountain_car import MountainCarEnv | ||
|
||
_OBSERVATIONS = Constants.OBSERVATIONS | ||
_ACTIONS = Constants.ACTIONS | ||
_REWARDS = Constants.REWARDS | ||
|
||
|
||
class ClassicControlMountainCarEnv(SingleAgentEnv):
    """Single-agent WarpDrive wrapper around gym's classic-control MountainCar."""

    name = "ClassicControlMountainCarEnv"

    def __init__(self, episode_length, env_backend="cpu", reset_pool_size=0, seed=None):
        super().__init__(episode_length, env_backend, reset_pool_size, seed=seed)
        # The underlying gym environment supplies the actual dynamics.
        self.gym_env = MountainCarEnv()
        # Re-expose the gym spaces in the single-agent (dict-keyed) layout.
        self.action_space = map_to_single_agent(self.gym_env.action_space)
        self.observation_space = map_to_single_agent(self.gym_env.observation_space)

    def step(self, action=None):
        """Advance one timestep; return (obs, rew, done, info) in single-agent form."""
        self.timestep += 1
        gym_action = get_action_for_single_agent(action)
        state, reward, terminated, _, _ = self.gym_env.step(gym_action)

        episode_over = terminated or self.timestep >= self.episode_length
        return (
            map_to_single_agent(state),
            map_to_single_agent(reward),
            {"__all__": episode_over},
            {},
        )

    def reset(self):
        """Reset the gym env; seeded (reproducible) unless a reset pool is in use."""
        self.timestep = 0
        # With a reset pool (size >= 2) each reset draws a fresh random initial
        # state; otherwise the same seeded initial state is reproduced every time.
        use_fixed_seed = self.reset_pool_size < 2
        initial_state, _ = self.gym_env.reset(seed=self.seed if use_fixed_seed else None)
        return map_to_single_agent(initial_state)
|
||
|
||
class CUDAClassicControlMountainCarEnv(ClassicControlMountainCarEnv, CUDAEnvironmentContext):
    """GPU (numba) backend for ClassicControlMountainCarEnv.

    Pushes the per-environment state and the MountainCar physics constants to
    the device, and launches the numba step kernel on every step.
    """

    def get_data_dictionary(self):
        """Build the device data feed: per-env state plus physics constants."""
        data_dict = DataFeed()
        initial_state, _ = self.gym_env.reset(seed=self.seed)

        # Without a reset pool (< 2 entries) the seeded initial state is saved
        # on the device and re-applied at every reset; with a pool, resets draw
        # from the pool instead (see get_reset_pool_dictionary), so no copy is
        # kept here. The two previous if/else branches differed only in this
        # flag, so they are collapsed into a single call.
        data_dict.add_data(
            name="state",
            data=np.atleast_2d(initial_state),
            save_copy_and_apply_at_reset=self.reset_pool_size < 2,
        )

        data_dict.add_data_list(
            [
                ("min_position", self.gym_env.min_position),
                ("max_position", self.gym_env.max_position),
                ("max_speed", self.gym_env.max_speed),
                ("goal_position", self.gym_env.goal_position),
                ("goal_velocity", self.gym_env.goal_velocity),
                ("force", self.gym_env.force),
                ("gravity", self.gym_env.gravity),
            ]
        )
        return data_dict

    def get_tensor_dictionary(self):
        """No trainable device tensors are needed for this environment."""
        return DataFeed()

    def get_reset_pool_dictionary(self):
        """Sample `reset_pool_size` random initial states for randomized resets."""
        reset_pool_dict = DataFeed()
        if self.reset_pool_size >= 2:
            state_reset_pool = []
            for _ in range(self.reset_pool_size):
                initial_state, _ = self.gym_env.reset(seed=None)
                state_reset_pool.append(np.atleast_2d(initial_state))
            state_reset_pool = np.stack(state_reset_pool, axis=0)
            # Expected shape: (pool_size, 1 agent, 2 state dims) — position, velocity.
            assert len(state_reset_pool.shape) == 3 and state_reset_pool.shape[2] == 2

            reset_pool_dict.add_pool_for_reset(
                name="state_reset_pool",
                data=state_reset_pool,
                reset_target="state",
            )
        return reset_pool_dict

    def step(self, actions=None):
        """Launch the numba step kernel across all envs (numba backend only)."""
        self.timestep += 1
        # Argument order must match the kernel signature of
        # NumbaClassicControlMountainCarEnvStep.
        args = [
            "state",
            _ACTIONS,
            "_done_",
            _REWARDS,
            _OBSERVATIONS,
            "min_position",
            "max_position",
            "max_speed",
            "goal_position",
            "goal_velocity",
            "force",
            "gravity",
            "_timestep_",
            ("episode_length", "meta"),
        ]
        if self.env_backend == "numba":
            self.cuda_step[
                self.cuda_function_manager.grid, self.cuda_function_manager.block
            ](*self.cuda_step_function_feed(args))
        else:
            raise Exception("CUDAClassicControlMountainCarEnv expects env_backend = 'numba' ")
68 changes: 68 additions & 0 deletions
68
example_envs/single_agent/classic_control/mountain_car/mountain_car_step_numba.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import numba.cuda as numba_driver | ||
import math | ||
|
||
|
||
@numba_driver.jit
def _clip(v, lo, hi):
    # Clamp v to the closed interval [lo, hi] (device helper).
    # Parameters renamed from min/max to avoid shadowing the builtins.
    if v < lo:
        return lo
    if v > hi:
        return hi
    return v
|
||
|
||
@numba_driver.jit
def NumbaClassicControlMountainCarEnvStep(
        state_arr,
        action_arr,
        done_arr,
        reward_arr,
        observation_arr,
        min_position,
        max_position,
        max_speed,
        goal_position,
        goal_velocity,
        force,
        gravity,
        env_timestep_arr,
        episode_length):
    """One MountainCar physics step per environment (one thread = one agent).

    Mirrors gym's MountainCarEnv.step: update velocity from the action and
    gravity, clip, integrate position, and emit a constant -1 reward.
    """
    kEnvId = numba_driver.blockIdx.x
    kThisAgentId = numba_driver.threadIdx.x

    assert kThisAgentId == 0, "We only have one agent per environment"

    env_timestep_arr[kEnvId] += 1

    assert 0 < env_timestep_arr[kEnvId] <= episode_length

    action = action_arr[kEnvId, kThisAgentId, 0]

    position = state_arr[kEnvId, kThisAgentId, 0]
    velocity = state_arr[kEnvId, kThisAgentId, 1]

    # Actions {0, 1, 2} map to {-force, 0, +force}; gravity pulls along the
    # cosine slope, matching gym's MountainCar dynamics.
    velocity += (action - 1) * force + math.cos(3 * position) * (-gravity)
    velocity = _clip(velocity, -max_speed, max_speed)
    position += velocity
    position = _clip(position, min_position, max_position)
    # Inelastic wall at the left boundary.
    if position == min_position and velocity < 0:
        velocity = 0

    state_arr[kEnvId, kThisAgentId, 0] = position
    state_arr[kEnvId, kThisAgentId, 1] = velocity

    # The observation is the raw state.
    observation_arr[kEnvId, kThisAgentId, 0] = position
    observation_arr[kEnvId, kThisAgentId, 1] = velocity

    terminated = bool(
        position >= goal_position and velocity >= goal_velocity
    )

    # Constant -1 reward per step until reset, consistent with gym's original
    # mountain-car logic. (A redundant 0.0 pre-assignment to reward_arr was
    # removed: the slot was never read between the two writes.)
    reward_arr[kEnvId, kThisAgentId] = -1.0

    if env_timestep_arr[kEnvId] == episode_length or terminated:
        done_arr[kEnvId] = 1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
91 changes: 91 additions & 0 deletions
91
tests/example_envs/numba_tests/single_agent/classic_control/test_mountain_car.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import unittest | ||
import numpy as np | ||
import torch | ||
|
||
from warp_drive.env_cpu_gpu_consistency_checker import EnvironmentCPUvsGPU | ||
from example_envs.single_agent.classic_control.mountain_car.mountain_car import \ | ||
ClassicControlMountainCarEnv, CUDAClassicControlMountainCarEnv | ||
from warp_drive.env_wrapper import EnvWrapper | ||
|
||
|
||
# Two CPU-vs-GPU consistency test configurations. reset_pool_size=0 means a
# fixed (seeded) initial state is re-applied on every reset.
env_configs = {
    "test1": {
        "episode_length": 500,
        "reset_pool_size": 0,
        "seed": 32145,
    },
    "test2": {
        "episode_length": 200,
        "reset_pool_size": 0,
        "seed": 54231,
    },
}
|
||
|
||
class MyTestCase(unittest.TestCase):
    """
    CPU v GPU consistency unit tests
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Harness that steps the CPU env and the numba GPU env in lockstep
        # and compares their outputs (requires a CUDA-capable device).
        self.testing_class = EnvironmentCPUvsGPU(
            cpu_env_class=ClassicControlMountainCarEnv,
            cuda_env_class=CUDAClassicControlMountainCarEnv,
            env_configs=env_configs,
            gpu_env_backend="numba",
            num_envs=5,
            num_episodes=2,
        )

    def test_env_consistency(self):
        # Runs reset + full episodes on both backends; any mismatch surfaces
        # as an AssertionError from the harness.
        try:
            self.testing_class.test_env_reset_and_step()
        except AssertionError:
            self.fail("ClassicControlMountainCarEnv environment consistency tests failed")

    def test_reset_pool(self):
        # Statistical check that done-triggered resets draw states from the
        # device-side reset pool (envs 0 and 1), while an env that is never
        # marked done (env 2) keeps its initial state.
        env_wrapper = EnvWrapper(
            env_obj=CUDAClassicControlMountainCarEnv(episode_length=100, reset_pool_size=3),
            num_envs=3,
            env_backend="numba",
        )
        env_wrapper.reset_all_envs()
        env_wrapper.env_resetter.init_reset_pool(env_wrapper.cuda_data_manager, seed=12345)
        self.assertTrue(env_wrapper.cuda_data_manager.reset_target_to_pool["state"] == "state_reset_pool")

        # squeeze() the agent dimension which is 1 always
        state_after_initial_reset = env_wrapper.cuda_data_manager.pull_data_from_device("state").squeeze()

        reset_pool = env_wrapper.cuda_data_manager.pull_data_from_device(
            env_wrapper.cuda_data_manager.get_reset_pool("state"))
        reset_pool_mean = reset_pool.mean(axis=0).squeeze()

        # we only need to check the 0th element of state because state[1] = 0 for reset always
        self.assertTrue(reset_pool.std(axis=0).squeeze()[0] > 1e-4)

        # Mark envs 0 and 1 done so reset_when_done refreshes them on every
        # pass of the loop below; env 2 stays untouched throughout.
        env_wrapper.cuda_data_manager.data_on_device_via_torch("_done_")[:] = torch.from_numpy(
            np.array([1, 1, 0])
        ).cuda()

        state_values = {0: [], 1: [], 2: []}
        for _ in range(10000):
            env_wrapper.env_resetter.reset_when_done(env_wrapper.cuda_data_manager, mode="if_done", undo_done_after_reset=False)
            res = env_wrapper.cuda_data_manager.pull_data_from_device("state")
            state_values[0].append(res[0])
            state_values[1].append(res[1])
            state_values[2].append(res[2])

        state_values_env0_mean = np.stack(state_values[0]).mean(axis=0).squeeze()
        state_values_env1_mean = np.stack(state_values[1]).mean(axis=0).squeeze()
        state_values_env2_mean = np.stack(state_values[2]).mean(axis=0).squeeze()

        # Resampled envs should average close (within 10%) to the pool mean;
        # the untouched env should remain essentially at its initial state.
        self.assertTrue(np.absolute(state_values_env0_mean[0] - reset_pool_mean[0]) < 0.1 * abs(reset_pool_mean[0]))
        self.assertTrue(np.absolute(state_values_env1_mean[0] - reset_pool_mean[0]) < 0.1 * abs(reset_pool_mean[0]))
        self.assertTrue(
            np.absolute(
                state_values_env2_mean[0] - state_after_initial_reset[0][0]
            ) < 0.001 * abs(state_after_initial_reset[0][0])
        )
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Copyright (c) 2021, salesforce.com, inc. | ||
# All rights reserved. | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
# For full license text, see the LICENSE file in the repo root | ||
# or https://opensource.org/licenses/BSD-3-Clause | ||
|
||
# YAML configuration for the single-agent mountain car environment
name: "single_mountain_car" | ||
# Environment settings | ||
env: | ||
episode_length: 500 | ||
reset_pool_size: 1000 | ||
# Trainer settings | ||
trainer: | ||
num_envs: 100 # number of environment replicas | ||
num_episodes: 200000 # number of episodes to run the training for. Can be arbitrarily high! | ||
train_batch_size: 50000 # total batch size used for training per iteration (across all the environments) | ||
env_backend: "numba" # environment backend, pycuda or numba | ||
# Policy network settings | ||
policy: # list all the policies below | ||
shared: | ||
to_train: True # flag indicating whether the model needs to be trained | ||
algorithm: "A2C" # algorithm used to train the policy | ||
vf_loss_coeff: 1 # loss coefficient schedule for the value function loss | ||
entropy_coeff: 0.05 # loss coefficient schedule for the entropy loss | ||
clip_grad_norm: True # flag indicating whether to clip the gradient norm or not | ||
max_grad_norm: 3 # when clip_grad_norm is True, the clip level | ||
normalize_advantage: False # flag indicating whether to normalize advantage or not | ||
normalize_return: False # flag indicating whether to normalize return or not | ||
gamma: 0.99 # discount factor | ||
lr: 0.001 # learning rate | ||
model: # policy model settings | ||
type: "fully_connected" # model type | ||
fc_dims: [32, 32] # dimension(s) of the fully connected layers as a list | ||
model_ckpt_filepath: "" # filepath (used to restore a previously saved model) | ||
# Checkpoint saving setting | ||
saving: | ||
metrics_log_freq: 100 # how often (in iterations) to log (and print) the metrics | ||
model_params_save_freq: 5000 # how often (in iterations) to save the model parameters | ||
basedir: "/tmp" # base folder used for saving | ||
name: "single_mountain_car" # base folder used for saving | ||
tag: "experiments" # experiment name | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters