Skip to content

Commit

Permalink
Epoch callback for offline .fit method (#286)
Browse files Browse the repository at this point in the history
* added discrete cql regulariser value to output for tracking

* removed obelet scalars in torch_api decorator

* added epoch callback functionality to learnable base

* clearer epoch callback comments

* reformatting
  • Loading branch information
joshuaspear committed Jun 23, 2023
1 parent a11b15d commit d5992b5
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions d3rlpy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,9 @@ def fit(
] = None,
shuffle: bool = True,
callback: Optional[Callable[["LearnableBase", int, int], None]] = None,
epoch_callback: Optional[
Callable[["LearnableBase", int, int], None]
] = None,
) -> List[Tuple[int, Dict[str, float]]]:
"""Trains with the given dataset.
Expand Down Expand Up @@ -398,6 +401,9 @@ def fit(
shuffle: flag to shuffle transitions on each epoch.
callback: callable function that takes ``(algo, epoch, total_step)``
, which is called every step.
epoch_callback: callable function that takes
``(algo, epoch, total_step)``, which is called at the end of
every epoch.
Returns:
list of result tuples (epoch, metrics) per epoch.
Expand All @@ -421,6 +427,7 @@ def fit(
scorers,
shuffle,
callback,
epoch_callback,
)
)
return results
Expand All @@ -445,6 +452,9 @@ def fitter(
] = None,
shuffle: bool = True,
callback: Optional[Callable[["LearnableBase", int, int], None]] = None,
epoch_callback: Optional[
Callable[["LearnableBase", int, int], None]
] = None,
) -> Generator[Tuple[int, Dict[str, float]], None, None]:
"""Iterate over epochs steps to train with the given dataset. At each
iteration algo methods and properties can be changed or queried.
Expand Down Expand Up @@ -480,6 +490,9 @@ def fitter(
shuffle: flag to shuffle transitions on each epoch.
callback: callable function that takes ``(algo, epoch, total_step)``
, which is called every step.
epoch_callback: callable function that takes
``(algo, epoch, total_step)``, which is called at the end of
every epoch.
Returns:
iterator yielding current epoch and metrics dict.
Expand Down Expand Up @@ -654,6 +667,9 @@ def fitter(
if callback:
callback(self, epoch, total_step)

if epoch_callback:
epoch_callback(self, epoch, total_step)

# save loss to loss history dict
self._loss_history["epoch"].append(epoch)
self._loss_history["step"].append(total_step)
Expand Down

0 comments on commit d5992b5

Please sign in to comment.