
Commit

minor fix
youkaichao committed Jul 11, 2020
1 parent b6beb67 commit 5ce2692
Showing 3 changed files with 16 additions and 4 deletions.
3 changes: 3 additions & 0 deletions test/base/env.py
@@ -4,6 +4,9 @@


 class MyTestEnv(gym.Env):
+    """
+    This is a "going right" task. The task is to go right ``size`` steps.
+    """
     def __init__(self, size, sleep=0, dict_state=False, ma_rew=0):
         self.size = size
         self.sleep = sleep
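For context on what the new docstring describes, here is a minimal usage sketch of MyTestEnv. It assumes the classic gym reset()/step() API and that action 1 means "move one step right" with a reward of 1 only on the terminal step; these semantics are inferred from the expected rew values in the test below, not stated in this diff.

from test.base.env import MyTestEnv  # assumes running from the repo root

env = MyTestEnv(size=5)   # the task: go right 5 steps
obs = env.reset()         # start at position 0
done = False
while not done:
    obs, rew, done, info = env.step(1)   # keep going right
    print(obs, rew, done)  # rew stays 0 until the final step, then becomes 1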
10 changes: 8 additions & 2 deletions test/base/test_collector.py
@@ -138,10 +138,16 @@ def reward_metric(x):
     batch = c1.sample(10)
     print(batch)
     c0.buffer.update(c1.buffer)
-    assert np.allclose(c0.buffer[:len(c0.buffer)].obs, [
+    obs = [
         0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1., 2., 3., 4., 0., 1.,
         0., 1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4., 0., 1., 0.,
-        1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.])
+        1., 2., 0., 1., 0., 1., 2., 3., 0., 1., 2., 3., 4.]
+    assert np.allclose(c0.buffer[:len(c0.buffer)].obs, obs)
+    rew = [0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1,
+           0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0,
+           0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1]
+    assert np.allclose(c0.buffer[:len(c0.buffer)].rew,
+                       [[x] * 4 for x in rew])
     c2 = Collector(policy, envs, ReplayBuffer(size=100, stack_num=4),
                    preprocess_fn, reward_metric=reward_metric)
     c2.collect(n_episode=[0, 0, 0, 10])
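A quick note on the new rew assertion: the test appears to run the env in multi-agent mode with 4 agents (the ma_rew constructor flag above), so each stored transition carries a length-4 reward vector, one copy of the scalar step reward per agent. A self-contained sketch of what the expected value expands to:

# Illustration of [[x] * 4 for x in rew] for one episode's rewards.
rew = [0, 0, 0, 0, 1]               # scalar reward per step; terminal step pays 1
expected = [[x] * 4 for x in rew]   # replicated for each of the 4 agents
print(expected)
# -> [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1]]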
7 changes: 5 additions & 2 deletions tianshou/data/collector.py
@@ -34,7 +34,7 @@ class Collector(object):
     :param function reward_metric: to be used in multi-agent RL. The reward to
         report is of shape [reward_length], but we need to return a single
         scalar to monitor training. This function specifies what is the desired
-        metric, i.e. the reward of agent 1 or the average reward over all
+        metric, e.g. the reward of agent 1 or the average reward over all
         agents. By default, the behavior is to select the reward of agent 1.

     The ``preprocess_fn`` is a function called before the data has been added
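To make the docstring concrete, here is a hedged sketch of the two metrics it mentions, assuming reward_metric receives the per-agent reward vector of shape [reward_length] and must return a scalar, as stated above. The function names are illustrative, not part of the library:

import numpy as np

def agent1_reward(rew):
    # the documented default: report agent 1's reward
    return rew[0]

def average_reward(rew):
    # alternative metric: the average reward over all agents
    return np.mean(rew)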
@@ -117,7 +117,10 @@ def __init__(self,
         self._action_noise = action_noise

         def _rew_metric(x):
-            assert np.asanyarray(x) == 1, 'Please specify the reward_metric ' \
+            # this internal function is designed for single-agent RL
+            # for multi-agent RL, a reward_metric must be provided
+            assert np.asanyarray(x).size == 1,\
+                'Please specify the reward_metric '\
                 'since the reward is not a scalar.'
             return x

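Why the .size fix matters: the old assert compared the reward's value to 1, so a valid scalar reward such as 0.5 would trip the assertion, while a multi-agent reward array would raise an unrelated "truth value is ambiguous" ValueError before the intended message could appear. Checking .size tests what was actually meant, namely that the reward has exactly one element:

import numpy as np

print(np.asanyarray(0.5) == 1)        # False -> old assert wrongly fired on a valid scalar
print(np.asanyarray(0.5).size == 1)   # True  -> fixed assert passes for any scalar
print(np.asanyarray([1., 0.]).size)   # 2     -> non-scalar rewards now get the clear message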
