Skip to content

Commit

Permalink
feat(exp_cfg): fix 'start_policy' empty config bug
Browse files Browse the repository at this point in the history
This commit ensures that a descriptive error message is thrown if the
'start_policy' variable is specified as None in the `exp_cfg` file.
  • Loading branch information
rickstaa committed Feb 8, 2024
1 parent cefac06 commit e7f3cf9
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 14 deletions.
1 change: 0 additions & 1 deletion experiments/lac_cartpole_cost_experiment_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,4 @@ replay_size: "int(1e6)"
seed: 0 234 567 # NOTE: Use comma/space separated string for hyperparameter variants
device: "cpu"
save_freq: 1
start_policy:
export: False
1 change: 0 additions & 1 deletion experiments/lac_oscillator_experiment_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,4 @@ replay_size: "int(1e6)"
seed: 0 234 567 # NOTE: Use comma/space separated string for hyperparameter variants
device: "cpu"
save_freq: 1
start_policy:
export: False
7 changes: 4 additions & 3 deletions stable_learning_control/algos/pytorch/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,10 +1155,11 @@ def lac(
policy.restore(start_policy)
logger.log("Model successfully restored.", type="info")
except Exception as e:
err_str = e.args[0].lower().rstrip(".")
logger.log(
"Shutting down training since {}.".format(
e.args[0].lower().rstrip(".")
),
f"Training process has been terminated. Unable to restore the "
f"'start_policy' from '{start_policy}'. Please ensure the "
f"'start_policy' is correct and try again. Error details: {err_str}.",
type="error",
)
sys.exit(0)
Expand Down
7 changes: 4 additions & 3 deletions stable_learning_control/algos/pytorch/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,10 +1021,11 @@ def sac(
policy.restore(start_policy)
logger.log("Model successfully restored.", type="info")
except Exception as e:
err_str = e.args[0].lower().rstrip(".")
logger.log(
"Shutting down training since {}.".format(
e.args[0].lower().rstrip(".")
),
f"Training process has been terminated. Unable to restore the "
f"'start_policy' from '{start_policy}'. Please ensure the "
f"'start_policy' is correct and try again. Error details: {err_str}.",
type="error",
)
sys.exit(0)
Expand Down
7 changes: 4 additions & 3 deletions stable_learning_control/algos/tf2/lac/lac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,10 +1096,11 @@ def lac(
policy.restore(start_policy)
logger.log("Model successfully restored.", type="info")
except Exception as e:
err_str = e.args[0].lower().rstrip(".")
logger.log(
"Shutting down training since {}.".format(
e.args[0].lower().rstrip(".")
),
f"Training process has been terminated. Unable to restore the "
f"'start_policy' from '{start_policy}'. Please ensure the "
f"'start_policy' is correct and try again. Error details: {err_str}.",
type="error",
)
sys.exit(0)
Expand Down
7 changes: 4 additions & 3 deletions stable_learning_control/algos/tf2/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,10 +963,11 @@ def sac(
policy.restore(start_policy)
logger.log("Model successfully restored.", type="info")
except Exception as e:
err_str = e.args[0].lower().rstrip(".")
logger.log(
"Shutting down training since {}.".format(
e.args[0].lower().rstrip(".")
),
f"Training process has been terminated. Unable to restore the "
f"'start_policy' from '{start_policy}'. Please ensure the "
f"'start_policy' is correct and try again. Error details: {err_str}.",
type="error",
)
sys.exit(0)
Expand Down

0 comments on commit e7f3cf9

Please sign in to comment.