Commit c57b53a: benchmark test setup
svsawant committed May 24, 2024 (1 parent 94cbcc4)
Showing 5 changed files with 39 additions and 35 deletions.
54 changes: 28 additions & 26 deletions examples/rl/config_overrides/cartpole/cartpole_stab.yaml
@@ -8,65 +8,67 @@ task_config:
 
   # state initialization
   init_state:
-    init_x: 0.1
-    init_x_dot: -1.5
-    init_theta: -0.155
-    init_theta_dot: 0.75
+    init_x: 0.5
+    init_x_dot: 0.0
+    init_theta: 0.0
+    init_theta_dot: 0.0
   randomized_init: True
   randomized_inertial_prop: False
 
   init_state_randomization_info:
     init_x:
       distrib: 'uniform'
-      low: -2
-      high: 2
+      low: 0.0
+      high: 0.0
     init_x_dot:
       distrib: 'uniform'
-      low: -2
-      high: 2
+      low: -0.05
+      high: 0.05
     init_theta:
       distrib: 'uniform'
-      low: -0.16
-      high: 0.16
+      low: -0.05
+      high: 0.05
     init_theta_dot:
       distrib: 'uniform'
-      low: -1
-      high: 1
+      low: -0.05
+      high: 0.05
 
   task: stabilization
   task_info:
-    stabilization_goal: [0.7, 0]
+    stabilization_goal: [0.0, 0]
     stabilization_goal_tolerance: 0.0
 
   inertial_prop:
     pole_length: 0.5
     cart_mass: 1
     pole_mass: 0.1
 
-  episode_len_sec: 10
+  episode_len_sec: 5
   cost: rl_reward
   obs_goal_horizon: 0
 
   # RL Reward
-  rew_state_weight: [1, 1, 1, 1]
+  rew_state_weight: [1, 0.1, 1, 0.1]
   rew_act_weight: 0.1
   rew_exponential: True
 
   # Disturbances
   disturbances:
     observation:
       - disturbance_func: white_noise
         std: 0.0001
 
   # constraints
   constraints:
   - constraint_form: default_constraint
     constrained_variable: state
-    upper_bounds:
-    - 2
-    - 2
-    - 0.16
-    - 1
-    lower_bounds:
-    - -2
-    - -2
-    - -0.16
-    - -1
+    upper_bounds: [10, 10, 5, 10]
+    lower_bounds: [-10, -10, -5, -10]
+  - constraint_form: default_constraint
+    constrained_variable: input
-  done_on_out_of_bound: True
+    upper_bounds: [10]
+    lower_bounds: [-10]
+  done_on_out_of_bound: False
+  done_on_violation: False
+  use_constraint_penalty: False
+  constraint_penalty: -10.0
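For orientation only (not part of the commit): the tightened randomization bounds amount to independent uniform draws per state dimension, as a minimal numpy sketch; how the environment combines the draw with init_state is decided by the env code and is not shown here.

import numpy as np

# Sketch of the updated init_state_randomization_info: each entry is an
# independent uniform draw within the configured [low, high] bounds.
rng = np.random.default_rng()
sampled = {
    'init_x': rng.uniform(0.0, 0.0),             # cart position, effectively fixed
    'init_x_dot': rng.uniform(-0.05, 0.05),      # cart velocity
    'init_theta': rng.uniform(-0.05, 0.05),      # pole angle (rad)
    'init_theta_dot': rng.uniform(-0.05, 0.05),  # pole angular velocity
}
print(sampled)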
2 changes: 1 addition & 1 deletion examples/rl/config_overrides/cartpole/ppo_cartpole.yaml
@@ -25,7 +25,7 @@ algo_config:
   max_grad_norm: 0.5
 
   # runner args
-  max_env_steps: 300000
+  max_env_steps: 720000
   num_workers: 1
   rollout_batch_size: 4
   rollout_steps: 150
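As a back-of-the-envelope check (not from the commit), the raised step budget corresponds to roughly 1200 PPO update cycles under the rollout settings shown above, assuming each cycle consumes rollout_batch_size * rollout_steps environment steps:

# Assumed accounting: one PPO update cycle = rollout_batch_size * rollout_steps env steps.
max_env_steps = 720_000
rollout_batch_size = 4
rollout_steps = 150
steps_per_cycle = rollout_batch_size * rollout_steps   # 600
num_update_cycles = max_env_steps // steps_per_cycle   # 1200
print(steps_per_cycle, num_update_cycles)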
8 changes: 4 additions & 4 deletions examples/rl/config_overrides/cartpole/sac_cartpole.yaml
@@ -19,18 +19,18 @@ algo_config:
   entropy_lr: 0.001
 
   # runner args
-  max_env_steps: 150000
+  max_env_steps: 10000
   warm_up_steps: 100
   rollout_batch_size: 4
   num_workers: 1
-  max_buffer_size: 1000000
+  max_buffer_size: 10000
   deque_size: 10
   eval_batch_size: 10
 
   # misc
-  log_interval: 3000
+  log_interval: 500
   save_interval: 0
   num_checkpoints: 0
-  eval_interval: 3000
+  eval_interval: 500
   eval_save_best: True
   tensorboard: False
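Read the same way (again an assumption, not something shown in the commit), the reduced SAC budget lets the replay buffer hold every collected transition and triggers logging and evaluation about 20 times each:

# Rough reading of the new SAC runner settings.
max_env_steps = 10_000
max_buffer_size = 10_000
log_interval = 500
eval_interval = 500
assert max_buffer_size >= max_env_steps     # buffer large enough to never discard data
num_logs = max_env_steps // log_interval    # 20
num_evals = max_env_steps // eval_interval  # 20
print(num_logs, num_evals)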
4 changes: 2 additions & 2 deletions safe_control_gym/controllers/ppo/ppo.py
@@ -200,7 +200,7 @@ def select_action(self, obs, info=None):
     def run(self,
             env=None,
             render=False,
-            n_episodes=10,
+            n_episodes=50,
             verbose=False,
             ):
         '''Runs evaluation with current policy.'''
@@ -218,7 +218,7 @@ def run(self,
 
         obs, info = env.reset()
         obs = self.obs_normalizer(obs)
-        ep_returns, ep_lengths = [], []
+        ep_returns, ep_lengths, eval_return = [], [], 0.0
         frames = []
         while len(ep_returns) < n_episodes:
             action = self.select_action(obs=obs, info=info)
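The hunk only shows eval_return being initialized to 0.0; how it is later filled lies outside the diff. A purely illustrative aggregation over the n_episodes=50 evaluation episodes might look like this (assumed, not taken from the repository):

import numpy as np

def aggregate_eval(ep_returns, ep_lengths):
    '''Hypothetical helper: reduce per-episode results to scalars for logging.'''
    eval_return = float(np.mean(ep_returns))
    eval_ep_length = float(np.mean(ep_lengths))
    return eval_return, eval_ep_length

# Example with made-up returns from three of the evaluation episodes.
print(aggregate_eval([210.0, 185.5, 198.2], [250, 250, 250]))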
6 changes: 4 additions & 2 deletions safe_control_gym/envs/gym_control/cartpole.py
@@ -156,6 +156,8 @@ def __init__(self,
         # BenchmarkEnv constructor, called after defining the custom args,
         # since some BenchmarkEnv init setup can be task(custom args)-dependent.
         super().__init__(init_state=init_state, inertial_prop=inertial_prop, **kwargs)
+        self.Q = np.diag(self.rew_state_weight)
+        self.R = np.diag(self.rew_act_weight)
 
         # Create PyBullet client connection.
         self.PYB_CLIENT = -1
@@ -607,8 +609,8 @@ def _get_reward(self):
         act = np.asarray(self.current_noisy_physical_action)
         if self.TASK == Task.STABILIZATION:
             state_error = state - self.X_GOAL
-            dist = np.sum(self.rew_state_weight * state_error * state_error)
-            dist += np.sum(self.rew_act_weight * act * act)
+            dist = 0.5 * np.sum(self.rew_state_weight * state_error * state_error)
+            dist += 0.5 * np.sum(self.rew_act_weight * act * act)
         if self.TASK == Task.TRAJ_TRACKING:
             wp_idx = min(self.ctrl_step_counter + 1, self.X_GOAL.shape[0] - 1)  # +1 because state has already advanced but counter not incremented.
             state_error = state - self.X_GOAL[wp_idx]
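Taken together with the config change, the stabilization reward is now a 1/2-scaled quadratic (LQR-style) cost mapped through an exponential when rew_exponential is True; since Q and R are diagonal, the elementwise form in _get_reward is equivalent to the matrix form below. A standalone sketch that mirrors, but is not copied from, the updated code:

import numpy as np

def stabilization_reward(state, action, x_goal,
                         rew_state_weight=(1, 0.1, 1, 0.1),
                         rew_act_weight=(0.1,),
                         rew_exponential=True):
    '''Quadratic cost 0.5*e^T Q e + 0.5*u^T R u, returned as exp(-cost) when exponential.'''
    Q = np.diag(rew_state_weight)   # analogous to self.Q = np.diag(self.rew_state_weight)
    R = np.diag(rew_act_weight)     # analogous to self.R = np.diag(self.rew_act_weight)
    e = np.asarray(state, dtype=float) - np.asarray(x_goal, dtype=float)
    u = np.atleast_1d(np.asarray(action, dtype=float))
    cost = 0.5 * e @ Q @ e + 0.5 * u @ R @ u
    return float(np.exp(-cost)) if rew_exponential else float(-cost)

# Example: small deviation from the upright goal under the new reward weights.
print(stabilization_reward([0.1, 0.0, 0.05, 0.0], [0.5], [0.0, 0.0, 0.0, 0.0]))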
