Skip to content

Commit

Permalink
Merge branch 'master' into async-tb-fix2
Browse files Browse the repository at this point in the history
  • Loading branch information
muupan authored Dec 14, 2020
2 parents d05ffba + 7026ca8 commit 55e35e1
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions pfrl/experiments/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,15 @@ def create_tb_writer(outdir):
return tb_writer


def record_tb_stats(summary_writer, agent_stats, eval_stats, t):
def record_tb_stats(summary_writer, agent_stats, eval_stats, env_stats, t):
cur_time = time.time()

for stat, value in agent_stats:
summary_writer.add_scalar("agent/" + stat, value, t, cur_time)

for stat, value in env_stats:
summary_writer.add_scalar("env/" + stat, value, t, cur_time)

for stat in ("mean", "median", "max", "min", "stdev"):
value = eval_stats[stat]
summary_writer.add_scalar("eval/" + stat, value, t, cur_time)
Expand All @@ -326,8 +329,8 @@ def record_tb_stats_loop(outdir, queue, stop_event):

while not (stop_event.wait(1e-6) and queue.empty()):
if not queue.empty():
agent_stats, eval_stats, t = queue.get()
record_tb_stats(tb_writer, agent_stats, eval_stats, t)
agent_stats, eval_stats, env_stats, t = queue.get()
record_tb_stats(tb_writer, agent_stats, eval_stats, env_stats, t)


def save_agent(agent, t, outdir, logger, suffix=""):
Expand Down Expand Up @@ -452,7 +455,7 @@ def evaluate_and_update_max_score(self, t, episodes):
record_stats(self.outdir, values)

if self.use_tensorboard:
record_tb_stats(self.tb_writer, agent_stats, eval_stats, t)
record_tb_stats(self.tb_writer, agent_stats, eval_stats, env_stats, t)

if mean > self.max_score:
self.logger.info("The best score is updated %s -> %s", self.max_score, mean)
Expand Down Expand Up @@ -567,7 +570,7 @@ def evaluate_and_update_max_score(self, t, episodes, env, agent):
record_stats(self.outdir, values)

if self.record_tb_stats_queue is not None:
self.record_tb_stats_queue.put([agent_stats, eval_stats, t])
self.record_tb_stats_queue.put([agent_stats, eval_stats, env_stats, t])

with self._max_score.get_lock():
if mean > self._max_score.value:
Expand Down

0 comments on commit 55e35e1

Please sign in to comment.