Skip to content

Commit

Permalink
[RLlib] move evaluation to trainer.step() such that the result is pro…
Browse files Browse the repository at this point in the history
…perly logged (#12708)
  • Loading branch information
Maltimore committed Jan 25, 2021
1 parent 964689b commit b4702de
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
8 changes: 0 additions & 8 deletions rllib/agents/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,14 +535,6 @@ def train(self) -> ResultDict:
if hasattr(self, "workers") and isinstance(self.workers, WorkerSet):
self._sync_filters_if_needed(self.workers)

if self.config["evaluation_interval"] == 1 or (
self._iteration > 0 and self.config["evaluation_interval"]
and self._iteration % self.config["evaluation_interval"] == 0):
evaluation_metrics = self._evaluate()
assert isinstance(evaluation_metrics, dict), \
"_evaluate() needs to return a dict."
result.update(evaluation_metrics)

return result

def _sync_filters_if_needed(self, workers: WorkerSet):
Expand Down
12 changes: 12 additions & 0 deletions rllib/agents/trainer_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,18 @@ def _init(self, config: TrainerConfigDict,
@override(Trainer)
def step(self):
res = next(self.train_exec_impl)

# self._iteration gets incremented after this function returns,
# meaning that e. g. the first time this function is called,
# self._iteration will be 0. We check `self._iteration+1` in the
# if-statement below to reflect that the first training iteration
# is already over.
if (self.config["evaluation_interval"] and (self._iteration + 1) %
self.config["evaluation_interval"] == 0):
evaluation_metrics = self._evaluate()
assert isinstance(evaluation_metrics, dict), \
"_evaluate() needs to return a dict."
res.update(evaluation_metrics)
return res

@override(Trainer)
Expand Down

0 comments on commit b4702de

Please sign in to comment.