
Commit

Refactor evaluation code
vitchyr committed Oct 16, 2018
1 parent 37718d9 commit b333015
Showing 8 changed files with 56 additions and 80 deletions.
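In short, this commit moves `evaluate` from `TorchRLAlgorithm` into the base `RLAlgorithm` and replaces the `eval_statistics is None` sentinel with an explicit `need_to_update_eval_statistics` flag: each algorithm records statistics for one training batch per epoch, `evaluate` logs them alongside the evaluation rollouts (which may now be passed in via the new `eval_paths` argument), and then re-arms the flag. A minimal sketch of the resulting pattern, using simplified names rather than the real rlkit classes:

from collections import OrderedDict


class Algorithm:
    """Illustration of the flag-based bookkeeping introduced in this commit."""

    def __init__(self):
        self.eval_statistics = OrderedDict()
        self.need_to_update_eval_statistics = True

    def _do_training(self):
        loss = 0.0  # stand-in for a real training step
        if self.need_to_update_eval_statistics:
            # Record statistics for the first batch of each epoch only.
            self.need_to_update_eval_statistics = False
            self.eval_statistics['Loss'] = loss

    def evaluate(self, epoch, eval_paths=None):
        statistics = OrderedDict(self.eval_statistics)
        # ... collect evaluation paths and log `statistics` here ...
        # Re-arm the flag so the next epoch records fresh statistics.
        self.need_to_update_eval_statistics = True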
41 changes: 32 additions & 9 deletions rlkit/core/rl_algorithm.py
@@ -1,10 +1,12 @@
import abc
import pickle
import time
from collections import OrderedDict

import gtimer as gt
import numpy as np

from rlkit.core import eval_util
from rlkit.core import logger
from rlkit.data_management.env_replay_buffer import EnvReplayBuffer
from rlkit.data_management.path_builder import PathBuilder
@@ -85,6 +87,8 @@ def __init__(
)
self.eval_policy = eval_policy
self.eval_sampler = eval_sampler
self.eval_statistics = OrderedDict()
self.need_to_update_eval_statistics = True

self.action_space = env.action_space
self.obs_space = env.observation_space
@@ -234,7 +238,7 @@ def _can_evaluate(self):
"""
return (
len(self._exploration_paths) > 0
and self.replay_buffer.num_steps_can_sample() >= self.batch_size
and not self.need_to_update_eval_statistics
)

def _can_train(self):
@@ -385,14 +389,33 @@ def training_mode(self, mode):
"""
pass

@abc.abstractmethod
def evaluate(self, epoch):
"""
Evaluate the policy, e.g. save/print progress.
:param epoch:
:return:
"""
pass
def evaluate(self, epoch, eval_paths=None):
statistics = OrderedDict()
statistics.update(self.eval_statistics)

logger.log("Collecting samples for evaluation")
if eval_paths:
test_paths = eval_paths
else:
test_paths = self.eval_sampler.obtain_samples()

statistics.update(eval_util.get_generic_path_information(
test_paths, stat_prefix="Test",
))
if len(self._exploration_paths) > 0:
statistics.update(eval_util.get_generic_path_information(
self._exploration_paths, stat_prefix="Exploration",
))
if hasattr(self.env, "log_diagnostics"):
self.env.log_diagnostics(test_paths, logger=logger)
if hasattr(self.env, "get_diagnostics"):
statistics.update(self.env.get_diagnostics(test_paths))

average_returns = eval_util.get_average_returns(test_paths)
statistics['AverageReturn'] = average_returns
for key, value in statistics.items():
logger.record_tabular(key, value)
self.need_to_update_eval_statistics = True

@abc.abstractmethod
def _do_training(self):
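The new base-class `evaluate` above also checks the environment for a `get_diagnostics` hook in addition to `log_diagnostics` (which now receives the logger explicitly). A hedged sketch of an environment exposing that hook, assuming rlkit-style path dictionaries that carry a 'rewards' array; the class and key names here are illustrative, not part of the commit:

from collections import OrderedDict

import numpy as np


class MyEnv:
    """Toy environment exposing the diagnostics hook checked by evaluate()."""

    def get_diagnostics(self, paths):
        # `paths` is the list of evaluation rollouts handed over by evaluate();
        # return a mapping of scalar statistics to merge into the logged stats.
        returns = [np.sum(path['rewards']) for path in paths]
        return OrderedDict([
            ('env/ReturnMean', float(np.mean(returns))),
        ])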
12 changes: 5 additions & 7 deletions rlkit/torch/ddpg/ddpg.py
@@ -113,7 +113,6 @@ def __init__(
self.policy.parameters(),
lr=self.policy_learning_rate,
)
self.eval_statistics = None

def _do_training(self):
batch = self.get_batch()
@@ -208,12 +207,11 @@ def _do_training(self):

self._update_target_networks()

if self.eval_statistics is None:
"""
Eval should set this to None.
This way, these statistics are only computed for one batch.
"""
self.eval_statistics = OrderedDict()
"""
Save some statistics for eval using just one batch.
"""
if self.need_to_update_eval_statistics:
self.need_to_update_eval_statistics = False
self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
policy_loss
6 changes: 3 additions & 3 deletions rlkit/torch/dqn/double_dqn.py
@@ -42,10 +42,10 @@ def _do_training(self):
self._update_target_network()

"""
Save some statistics for eval
Save some statistics for eval using just one batch.
"""
if self.eval_statistics is None:
self.eval_statistics = OrderedDict()
if self.need_to_update_eval_statistics:
self.need_to_update_eval_statistics = False
self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
self.eval_statistics.update(create_stats_ordered_dict(
'Y Predictions',
7 changes: 3 additions & 4 deletions rlkit/torch/dqn/dqn.py
@@ -66,7 +66,6 @@ def __init__(
)
self.qf_criterion = qf_criterion or nn.MSELoss()

self.eval_statistics = None

def _do_training(self):
batch = self.get_batch()
@@ -98,10 +97,10 @@ def _do_training(self):
self._update_target_network()

"""
Save some statistics for eval
Save some statistics for eval using just one batch.
"""
if self.eval_statistics is None:
self.eval_statistics = OrderedDict()
if self.need_to_update_eval_statistics:
self.need_to_update_eval_statistics = False
self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
self.eval_statistics.update(create_stats_ordered_dict(
'Y Predictions',
11 changes: 3 additions & 8 deletions rlkit/torch/sac/sac.py
@@ -55,7 +55,6 @@ def __init__(
self.target_vf = vf.copy()
self.qf_criterion = nn.MSELoss()
self.vf_criterion = nn.MSELoss()
self.eval_statistics = None

self.policy_optimizer = optimizer_class(
self.policy.parameters(),
@@ -132,14 +131,10 @@ def _do_training(self):
self._update_target_network()

"""
Save some statistics for eval
Save some statistics for eval using just one batch.
"""
if self.eval_statistics is None:
"""
Eval should set this to None.
This way, these statistics are only computed for one batch.
"""
self.eval_statistics = OrderedDict()
if self.need_to_update_eval_statistics:
self.need_to_update_eval_statistics = False
self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
self.eval_statistics['VF Loss'] = np.mean(ptu.get_numpy(vf_loss))
self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
12 changes: 5 additions & 7 deletions rlkit/torch/td3/td3.py
@@ -73,7 +73,6 @@ def __init__(
self.policy.parameters(),
lr=policy_learning_rate,
)
self.eval_statistics = None

def _do_training(self):
batch = self.get_batch()
@@ -138,17 +137,16 @@ def _do_training(self):
ptu.soft_update_from_to(self.qf1, self.target_qf1, self.tau)
ptu.soft_update_from_to(self.qf2, self.target_qf2, self.tau)

if self.eval_statistics is None:
"""
Eval should set this to None.
This way, these statistics are only computed for one batch.
"""
"""
Save some statistics for eval using just one batch.
"""
if self.need_to_update_eval_statistics:
self.need_to_update_eval_statistics = False
if policy_loss is None:
policy_actions = self.policy(obs)
q_output = self.qf1(obs, policy_actions)
policy_loss = - q_output.mean()

self.eval_statistics = OrderedDict()
self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
self.eval_statistics['QF2 Loss'] = np.mean(ptu.get_numpy(qf2_loss))
self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
11 changes: 5 additions & 6 deletions rlkit/torch/tdm/tdm.py
@@ -155,12 +155,11 @@ def _do_training(self):

self._update_target_networks()

if self.eval_statistics is None:
"""
Eval should set this to None.
This way, these statistics are only computed for one batch.
"""
self.eval_statistics = OrderedDict()
"""
Save some statistics for eval using just one batch.
"""
if self.need_to_update_eval_statistics:
self.need_to_update_eval_statistics = False
self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(
policy_loss
36 changes: 0 additions & 36 deletions rlkit/torch/torch_rl_algorithm.py
@@ -1,5 +1,4 @@
import abc
from collections import OrderedDict
from typing import Iterable

import numpy as np
@@ -8,16 +7,9 @@
from rlkit.core.rl_algorithm import RLAlgorithm
from rlkit.torch import pytorch_util as ptu
from rlkit.torch.core import PyTorchModule
from rlkit.core import logger, eval_util


class TorchRLAlgorithm(RLAlgorithm, metaclass=abc.ABCMeta):
def __init__(self, *args, render_eval_paths=False, plotter=None, **kwargs):
super().__init__(*args, **kwargs)
self.eval_statistics = None
self.render_eval_paths = render_eval_paths
self.plotter = plotter

def get_batch(self):
batch = self.replay_buffer.random_batch(self.batch_size)
return np_to_pytorch_batch(batch)
@@ -37,34 +29,6 @@ def to(self, device=None):
for net in self.networks:
net.to(device)

def evaluate(self, epoch):
statistics = OrderedDict()
statistics.update(self.eval_statistics)
self.eval_statistics = None

logger.log("Collecting samples for evaluation")
test_paths = self.eval_sampler.obtain_samples()

statistics.update(eval_util.get_generic_path_information(
test_paths, stat_prefix="Test",
))
statistics.update(eval_util.get_generic_path_information(
self._exploration_paths, stat_prefix="Exploration",
))
if hasattr(self.env, "log_diagnostics"):
self.env.log_diagnostics(test_paths)

average_returns = eval_util.get_average_returns(test_paths)
statistics['AverageReturn'] = average_returns
for key, value in statistics.items():
logger.record_tabular(key, value)

if self.render_eval_paths:
self.env.render_paths(test_paths)

if self.plotter:
self.plotter.draw()


def _elem_or_tuple_to_variable(elem_or_tuple):
if isinstance(elem_or_tuple, tuple):
