Skip to content

Commit

Permalink
Pull epoch_callback
Browse files Browse the repository at this point in the history
  • Loading branch information
takuseno committed Jun 23, 2023
1 parent 27d7355 commit e7840e0
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions d3rlpy/algos/qlearning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ def fit(
save_interval: int = 1,
evaluators: Optional[Dict[str, EvaluatorProtocol]] = None,
callback: Optional[Callable[[Self, int, int], None]] = None,
epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
) -> List[Tuple[int, Dict[str, float]]]:
"""Trains with the given dataset.
Expand All @@ -393,6 +394,9 @@ def fit(
evaluators: list of evaluators.
callback: callable function that takes ``(algo, epoch, total_step)``
, which is called every step.
epoch_callback: callable function that takes
``(algo, epoch, total_step)``, which is called at the end of
every epoch.
Returns:
list of result tuples (epoch, metrics) per epoch.
Expand All @@ -409,6 +413,7 @@ def fit(
save_interval,
evaluators,
callback,
epoch_callback,
)
)
return results
Expand All @@ -425,6 +430,7 @@ def fitter(
save_interval: int = 1,
evaluators: Optional[Dict[str, EvaluatorProtocol]] = None,
callback: Optional[Callable[[Self, int, int], None]] = None,
epoch_callback: Optional[Callable[[Self, int, int], None]] = None,
) -> Generator[Tuple[int, Dict[str, float]], None, None]:
"""Iterate over epochs steps to train with the given dataset. At each
iteration algo methods and properties can be changed or queried.
Expand All @@ -450,6 +456,9 @@ def fitter(
evaluators: list of evaluators.
callback: callable function that takes ``(algo, epoch, total_step)``
, which is called every step.
epoch_callback: callable function that takes
``(algo, epoch, total_step)``, which is called at the end of
every epoch.
Returns:
iterator yielding current epoch and metrics dict.
Expand Down Expand Up @@ -530,6 +539,10 @@ def fitter(
if callback:
callback(self, epoch, total_step)

# call epoch_callback if given
if epoch_callback:
epoch_callback(self, epoch, total_step)

if evaluators:
for name, evaluator in evaluators.items():
test_score = evaluator(self, dataset)
Expand Down

0 comments on commit e7840e0

Please sign in to comment.