Skip to content

Commit

Permalink
Code clean-up
Browse files (browse the repository at this point in the history)
  • Loading branch information
JacopoPan committed Aug 21, 2021
1 parent 3b05d6d commit 76ff24c
Show file tree
Hide file tree
Showing 57 changed files with 2,473 additions and 2,350 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number | Diff line number | Diff line change
Expand Up @@ -3,6 +3,7 @@ safe_control_gym/envs/gym_control/assets/cartpole.urdf
experiments/figure6/trained_gp_model/bak_best_model_*.pth
experiments/figure7/safe_exp_results/
experiments/figure8/unsafe_ppo_temp_data/
experiments/figure8/unsafe_ppo_model/bak_unsafe_ppo_model_30000.pt
temp-data/
z_docstring.py
TODOs.md
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -101,4 +101,3 @@ task_config:
distrib: 'uniform'
low: -0.1
high: 0.1

29 changes: 25 additions & 4 deletions experiments/figure6/gp_mpc_experiment.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -15,7 +15,14 @@
from safe_control_gym.utils.configuration import ConfigFactory


def plot_xz_comparison_diag_constraint(prior_run, run, init_ind, dir=None):
def plot_xz_comparison_diag_constraint(prior_run,
run,
init_ind,
dir=None
):
"""
"""
state_inds = [0,2]
goal = [0, 1]
ax = plot_2D_comparison_with_prior(state_inds, prior_run, run, goal, init_ind, dir=dir)
Expand All @@ -34,7 +41,15 @@ def plot_xz_comparison_diag_constraint(prior_run, run, init_ind, dir=None):
plt.tight_layout()


def plot_2D_comparison_with_prior(state_inds, prior_run, run, goal, init_ind, dir=None):
def plot_2D_comparison_with_prior(state_inds,
prior_run,
run, goal,
init_ind,
dir=None
):
"""
"""
horizon_cov = run.state_horizon_cov[init_ind]
horizon_states = run.horizon_states[init_ind]
prior_horizon_states = prior_run.horizon_states[init_ind]
Expand Down Expand Up @@ -94,7 +109,14 @@ def plot_2D_comparison_with_prior(state_inds, prior_run, run, goal, init_ind, di
return ax


def add_2d_cov_ellipse(position, cov, ax, legend=False):
def add_2d_cov_ellipse(position,
cov,
ax,
legend=False
):
"""
"""
evals, evecs = np.linalg.eig(cov)
major_axis_ind = np.argmax(evals)
minor_axis_ind = 0 if major_axis_ind == 1 else 1
Expand Down Expand Up @@ -127,7 +149,6 @@ def add_2d_cov_ellipse(position, cov, ax, legend=False):
fac = ConfigFactory()
fac.add_argument("--train_only", type=bool, default=False, help="True if only training is performed.")
config = fac.merge()

# Create environment.
env_func = partial(make,
config.task,
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -2,7 +2,6 @@ algo_config:
pretraining: False
pretrained: null
constraint_slack: 0.05

log_interval: 1000
save_interval: 1000
num_checkpoints: 10
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,7 +10,6 @@ algo_config:
constraint_eval_interval: 5
constraint_buffer_size: 1000000
constraint_slack: 0.05

log_interval: 10
save_interval: 10
num_checkpoints: 5
Expand All @@ -27,4 +26,3 @@ task_config:
constraint_input_type: 'STATE'
active_dims: 0
done_on_violation: True

1 change: 1 addition & 0 deletions experiments/figure7/create_fig7.sh
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
#!/bin/bash

# Safe Explorer vs PPO baselines.
python3 ./safe_exp_plots.py --plot_dir ./safe_exp_results
21 changes: 11 additions & 10 deletions experiments/figure7/create_safe_exp_results.sh
Original file line number | Diff line number | Diff line change
@@ -1,30 +1,34 @@
#!/bin/bash
# proper cleanup for background processes

# Allow proper cleanup of background processes.
trap "exit" INT TERM ERR
trap "kill 0" EXIT

# Remove previous results.
rm -r -f ./safe_exp_results/

# Writing paths.
OUTPUT_DIR="./"
TAG_ROOT="safe_exp_results"

# Configuration files path.
CONFIG_PATH_ROOT="./config_overrides"

# Options.
seeds=(2 22 222 2222 22222 9 90 998 9999 90001)
thread=1

########################## ppo
# PPO.
TAG="ppo"
CONFIG_PATH="${CONFIG_PATH_ROOT}/ppo_cartpole.yaml"

for seed in "${seeds[@]}"
do
python3 ../main.py --algo ppo --task cartpole --overrides $CONFIG_PATH --output_dir ${OUTPUT_DIR} --tag $TAG_ROOT/$TAG --thread $thread --seed $seed
done

########################## ppo + reward shaping
# PPO with reward shaping.
TAG="ppo_rs"
CONFIG_PATH="${CONFIG_PATH_ROOT}/ppo_rs_cartpole.yaml"

tolerances=(0.15 0.2)
for tolerance in "${tolerances[@]}"
do
Expand All @@ -34,19 +38,16 @@ do
done
done

########################## pretrain safe explorer
# Safe Explorer pre-training.
TAG="safe_exp_pretrain"
CONFIG_PATH="${CONFIG_PATH_ROOT}/safe_explorer_ppo_cartpole_pretrain.yaml"

train_seed=88890
python3 ../main.py --algo safe_explorer_ppo --task cartpole --overrides $CONFIG_PATH --output_dir ${OUTPUT_DIR} --tag $TAG_ROOT/$TAG --thread $thread --seed $train_seed

########################## train safe explorer
# Safe Explorer.
PRETRAINED_PATH=(${OUTPUT_DIR}/$TAG_ROOT/$TAG/seed${train_seed}*)

TAG="safe_exp_"
CONFIG_PATH="${CONFIG_PATH_ROOT}/safe_explorer_ppo_cartpole.yaml"

slacks=(0.15 0.2)
for slack in "${slacks[@]}"
do
Expand Down

0 comments on commit 76ff24c

Please sign in to comment.