Commit

Merge branch 'master' of github.com:takuseno/d3rlpy
takuseno committed Jun 1, 2024
2 parents 18d710a + da46ef1 commit 5f810eb
Showing 3 changed files with 18 additions and 2 deletions.
6 changes: 5 additions & 1 deletion d3rlpy/algos/qlearning/base.py
@@ -613,6 +613,7 @@ def fit_online(
         random_steps: int = 0,
         eval_env: Optional[GymEnv] = None,
         eval_epsilon: float = 0.0,
+        eval_n_trials: int = 10,
         save_interval: int = 1,
         experiment_name: Optional[str] = None,
         with_timestamp: bool = True,
@@ -764,7 +765,10 @@ def fit_online(
             # evaluation
             if eval_env:
                 eval_score = evaluate_qlearning_with_environment(
-                    self, eval_env, epsilon=eval_epsilon
+                    self,
+                    eval_env,
+                    n_trials=eval_n_trials,
+                    epsilon=eval_epsilon,
                 )
                 logger.add_metric("evaluation", eval_score)

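For reference, a minimal usage sketch of the new eval_n_trials argument (not taken from this commit; the environment, algorithm, and step counts are illustrative assumptions):

import gym
import d3rlpy

env = gym.make("Pendulum-v1")
eval_env = gym.make("Pendulum-v1")

sac = d3rlpy.algos.SACConfig().create()
buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=100000, env=env)

# evaluation scores are now averaged over 20 rollouts instead of the default 10
sac.fit_online(
    env,
    buffer,
    n_steps=10000,
    eval_env=eval_env,
    eval_epsilon=0.0,
    eval_n_trials=20,
)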
8 changes: 8 additions & 0 deletions d3rlpy/dataset/replay_buffer.py
@@ -751,6 +751,7 @@ def create_fifo_replay_buffer(
     trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
     writer_preprocessor: Optional[WriterPreprocessProtocol] = None,
     env: Optional[GymEnv] = None,
+    write_at_termination: bool = False,
 ) -> ReplayBuffer:
     """Builds FIFO replay buffer.
@@ -770,6 +771,8 @@ def create_fifo_replay_buffer(
             Writer preprocessor implementation. If ``None`` is given,
             ``BasicWriterPreprocess`` is used by default.
         env: Gym environment to extract shapes of observations and action.
+        write_at_termination (bool): Flag to write experiences to the buffer at the
+            end of an episode all at once.

     Returns:
         Replay buffer.
@@ -782,6 +785,7 @@ def create_fifo_replay_buffer(
         trajectory_slicer=trajectory_slicer,
         writer_preprocessor=writer_preprocessor,
         env=env,
+        write_at_termination=write_at_termination,
     )


@@ -791,6 +795,7 @@ def create_infinite_replay_buffer(
     trajectory_slicer: Optional[TrajectorySlicerProtocol] = None,
     writer_preprocessor: Optional[WriterPreprocessProtocol] = None,
     env: Optional[GymEnv] = None,
+    write_at_termination: bool = False,
 ) -> ReplayBuffer:
     """Builds infinite replay buffer.
@@ -809,6 +814,8 @@ def create_infinite_replay_buffer(
             Writer preprocessor implementation. If ``None`` is given,
             ``BasicWriterPreprocess`` is used by default.
         env: Gym environment to extract shapes of observations and action.
+        write_at_termination (bool): Flag to write experiences to the buffer at the
+            end of an episode all at once.

     Returns:
         Replay buffer.
@@ -821,4 +828,5 @@ def create_infinite_replay_buffer(
         trajectory_slicer=trajectory_slicer,
         writer_preprocessor=writer_preprocessor,
         env=env,
+        write_at_termination=write_at_termination,
     )
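A minimal sketch of the new write_at_termination flag (the environment name and buffer size are assumptions): with the flag enabled, collected experiences are written to the buffer at the end of each episode all at once, rather than step by step.

import gym
import d3rlpy

env = gym.make("Pendulum-v1")

# buffer that defers writes until an episode terminates
buffer = d3rlpy.dataset.create_fifo_replay_buffer(
    limit=1000000,
    env=env,
    write_at_termination=True,
)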
6 changes: 5 additions & 1 deletion reproductions/finetuning/cal_ql_finetune.py
@@ -59,7 +59,9 @@ def main() -> None:
         n_steps=1000000,
         n_steps_per_epoch=1000,
         save_interval=10,
-        evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
+        evaluators={
+            "environment": d3rlpy.metrics.EnvironmentEvaluator(env, n_trials=20)
+        },
         experiment_name=f"CalQL_pretraining_{args.dataset}_{args.seed}",
     )

@@ -68,6 +70,7 @@ def main() -> None:
         limit=1000000,
         env=env,
         transition_picker=transition_picker,
+        write_at_termination=True,
     )

# sample half from offline dataset and the rest from online buffer
@@ -90,6 +93,7 @@ def main() -> None:
         n_updates=1000,
         update_interval=1000,
         save_interval=10,
+        eval_n_trials=20,
     )


