Skip to content

Commit

Permalink
Add save_interval argument to algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Jan 27, 2021
1 parent 76ba368 commit 0d91df6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
6 changes: 6 additions & 0 deletions d3rlpy/algos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ def fit_online(
eval_env: Optional[gym.Env] = None,
eval_epsilon: float = 0.0,
save_metrics: bool = True,
save_interval: int = 1,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logdir: str = "d3rlpy_logs",
Expand All @@ -214,6 +215,7 @@ def fit_online(
eval_epsilon: :math:`\\epsilon`-greedy factor during evaluation.
save_metrics: flag to record metrics. If False, the log
directory is not created and the model parameters are not saved.
save_interval: interval, in epochs, between model saves (the model is saved whenever ``epoch % save_interval == 0``).
experiment_name: experiment name for logging. If not passed,
the directory name will be ``{class name}_online_{timestamp}``.
with_timestamp: flag to add timestamp string to the last of
Expand Down Expand Up @@ -245,6 +247,7 @@ def fit_online(
eval_env=eval_env,
eval_epsilon=eval_epsilon,
save_metrics=save_metrics,
save_interval=save_interval,
experiment_name=experiment_name,
with_timestamp=with_timestamp,
logdir=logdir,
Expand All @@ -266,6 +269,7 @@ def fit_batch_online(
eval_env: Optional[gym.Env] = None,
eval_epsilon: float = 0.0,
save_metrics: bool = True,
save_interval: int = 1,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logdir: str = "d3rlpy_logs",
Expand All @@ -289,6 +293,7 @@ def fit_batch_online(
eval_epsilon: :math:`\\epsilon`-greedy factor during evaluation.
save_metrics: flag to record metrics. If False, the log
directory is not created and the model parameters are not saved.
save_interval: interval, in epochs, between model saves (the model is saved whenever ``epoch % save_interval == 0``).
experiment_name: experiment name for logging. If not passed,
the directory name will be ``{class name}_online_{timestamp}``.
with_timestamp: flag to add timestamp string to the last of
Expand Down Expand Up @@ -320,6 +325,7 @@ def fit_batch_online(
eval_env=eval_env,
eval_epsilon=eval_epsilon,
save_metrics=save_metrics,
save_interval=save_interval,
experiment_name=experiment_name,
with_timestamp=with_timestamp,
logdir=logdir,
Expand Down
12 changes: 9 additions & 3 deletions d3rlpy/online/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def train_single_env(
eval_env: Optional[gym.Env] = None,
eval_epsilon: float = 0.0,
save_metrics: bool = True,
save_interval: int = 1,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logdir: str = "d3rlpy_logs",
Expand All @@ -121,6 +122,7 @@ def train_single_env(
eval_epsilon: :math:`\\epsilon`-greedy factor during evaluation.
save_metrics: flag to record metrics. If False, the log
directory is not created and the model parameters are not saved.
save_interval: interval, in epochs, between model saves (the model is saved whenever ``epoch % save_interval == 0``).
experiment_name: experiment name for logging. If not passed,
the directory name will be ``{class name}_online_{timestamp}``.
with_timestamp: flag to add timestamp string to the last of
Expand Down Expand Up @@ -251,6 +253,7 @@ def train_single_env(
if eval_scorer:
logger.add_metric("evaluation", eval_scorer(algo))

if epoch % save_interval == 0:
# save metrics
logger.commit(epoch, total_step)
logger.save_model(total_step, algo)
Expand All @@ -268,6 +271,7 @@ def train_batch_env(
eval_env: Optional[gym.Env] = None,
eval_epsilon: float = 0.0,
save_metrics: bool = True,
save_interval: int = 1,
experiment_name: Optional[str] = None,
with_timestamp: bool = True,
logdir: str = "d3rlpy_logs",
Expand All @@ -291,6 +295,7 @@ def train_batch_env(
eval_epsilon: :math:`\\epsilon`-greedy factor during evaluation.
save_metrics: flag to record metrics. If False, the log
directory is not created and the model parameters are not saved.
save_interval: interval, in epochs, between model saves (the model is saved whenever ``epoch % save_interval == 0``).
experiment_name: experiment name for logging. If not passed,
the directory name will be ``{class name}_online_{timestamp}``.
with_timestamp: flag to add timestamp string to the last of
Expand Down Expand Up @@ -419,9 +424,10 @@ def train_batch_env(
if eval_scorer:
logger.add_metric("evaluation", eval_scorer(algo))

# save metrics
logger.commit(epoch, total_step)
logger.save_model(total_step, algo)
if epoch % save_interval == 0:
# save metrics
logger.commit(epoch, total_step)
logger.save_model(total_step, algo)

# finish all process
env.close()

0 comments on commit 0d91df6

Please sign in to comment.