Improve collector #125

Merged (20 commits), Jul 12, 2020

Changes from 1 commit
8 changes: 4 additions & 4 deletions test/base/test_collector.py
@@ -27,16 +27,16 @@ def learn(self):

 def preprocess_fn(**kwargs):
     # modify info before adding into the buffer
-    if kwargs.get('info', None) is not None:
+    # if info is not provided from env, it will be a `Batch()`.
+    if not kwargs.get('info', Batch()).is_empty():
         n = len(kwargs['obs'])
         info = kwargs['info']
         for i in range(n):
             info[i].update(rew=kwargs['rew'][i])
         return {'info': info}
-        # or
-        # return Batch(info=info)
+        # or: return Batch(info=info)
     else:
-        return {}
+        return Batch()


 class Logger(object):
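Assembled from the hunk above, the updated hook is a complete function on its own. A minimal runnable sketch (the dummy call at the end is illustrative only, not part of the PR):

```python
import numpy as np

from tianshou.data import Batch


def preprocess_fn(**kwargs):
    # modify info before adding into the buffer;
    # if info is not provided from env, it will be a `Batch()`.
    if not kwargs.get('info', Batch()).is_empty():
        n = len(kwargs['obs'])
        info = kwargs['info']
        for i in range(n):
            info[i].update(rew=kwargs['rew'][i])
        return {'info': info}  # or: return Batch(info=info)
    return Batch()


# a step from 2 parallel envs that carries no info: nothing to patch
print(preprocess_fn(obs=np.zeros((2, 4)),
                    rew=np.array([1.0, -1.0]),
                    info=Batch()))  # -> empty Batch()
```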
204 changes: 96 additions & 108 deletions tianshou/data/collector.py
Expand Up @@ -8,8 +8,8 @@
 from tianshou.utils import MovAvg
 from tianshou.env import BaseVectorEnv
 from tianshou.policy import BasePolicy
-from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy
 from tianshou.exploration import BaseNoise
+from tianshou.data import Batch, ReplayBuffer, ListReplayBuffer, to_numpy


 class Collector(object):
@@ -31,6 +31,11 @@ class Collector(object):
     :param BaseNoise action_noise: add a noise to continuous actions. Normally
         a policy already has a noise parameter for exploration during
         training, so this is recommended for use in the test collector.
+    :param function reward_metric: to be used in multi-agent RL. The reward to
+        report has shape [reward_length], but we need a single scalar to
+        monitor training. This function specifies the desired metric, e.g.
+        the reward of agent 1 or the average reward over all agents. By
+        default the reward must already be a scalar, and it is returned as is.

     The ``preprocess_fn`` is a function called before the data is added to
     the buffer in batch format; it receives up to 7 keys, as listed in
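(Aside on the new ``reward_metric`` parameter documented above: it maps the per-step reward array to one scalar. Hypothetical callables matching the documented signature; the two-agent reward shape is an assumption for illustration, not something this PR defines:)

```python
import numpy as np


def first_agent_reward(rew: np.ndarray) -> float:
    # monitor only agent 1's reward from a shape-[2] reward vector
    return float(rew[0])


def mean_agent_reward(rew: np.ndarray) -> float:
    # alternative: average the reward over all agents
    return float(np.mean(rew))
```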
@@ -91,13 +96,12 @@ def __init__(self,
                  preprocess_fn: Callable[[Any], Union[dict, Batch]] = None,
                  stat_size: Optional[int] = 100,
                  action_noise: Optional[BaseNoise] = None,
+                 reward_metric: Optional[Callable[[np.ndarray], float]] = None,
                  **kwargs) -> None:
         super().__init__()
         self.env = env
         self.env_num = 1
-        self.collect_time = 0
-        self.collect_step = 0
-        self.collect_episode = 0
+        self.collect_time, self.collect_step, self.collect_episode = 0, 0, 0
         self.buffer = buffer
         self.policy = policy
         self.preprocess_fn = preprocess_fn
@@ -107,23 +111,28 @@ def __init__(self,
         self._cached_buf = []
         if self._multi_env:
             self.env_num = len(env)
-            self._cached_buf = [
-                ListReplayBuffer() for _ in range(self.env_num)]
+            self._cached_buf = [ListReplayBuffer()
+                                for _ in range(self.env_num)]
         self.stat_size = stat_size
         self._action_noise = action_noise
+
+        def _rew_metric(x):
+            assert np.isscalar(x), 'Please specify the reward_metric ' \
+                'since the reward is not a scalar.'
+            return x
+
+        self._rew_metric = reward_metric or _rew_metric
         self.reset()

     def reset(self) -> None:
         """Reset all related variables in the collector."""
+        self.data = Batch(state={}, obs={}, act={}, rew={}, done={}, info={},
+                          obs_next={}, policy={})
         self.reset_env()
         self.reset_buffer()
         # state over batch is either a list, an np.ndarray, or a torch.Tensor
-        self.state = None
         self.step_speed = MovAvg(self.stat_size)
         self.episode_speed = MovAvg(self.stat_size)
-        self.collect_step = 0
-        self.collect_episode = 0
-        self.collect_time = 0
+        self.collect_time, self.collect_step, self.collect_episode = 0, 0, 0
         if self._action_noise is not None:
             self._action_noise.reset()
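With the constructor above, a vector reward would trip the assertion in the default ``_rew_metric``, so a multi-agent setup now passes its own metric. Hypothetical usage; ``policy``, ``envs``, and the buffer size are placeholders:

```python
import numpy as np

from tianshou.data import Collector, ReplayBuffer

# `policy` and `envs` are assumed to be defined elsewhere; the env returns
# a reward vector, so we reduce it to a scalar for the training statistics.
collector = Collector(policy, envs, buffer=ReplayBuffer(size=20000),
                      reward_metric=lambda rew: float(np.mean(rew)))
```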

@@ -140,35 +149,28 @@ def reset_env(self) -> None:
"""Reset all of the environment(s)' states and reset all of the cache
buffers (if need).
"""
self._obs = self.env.reset()
obs = self.env.reset()
if not self._multi_env:
self._obs = self._make_batch(self._obs)
obs = self._make_batch(obs)
if self.preprocess_fn:
self._obs = self.preprocess_fn(obs=self._obs).get('obs', self._obs)
self._act, self._rew, self._done, self._info = \
Batch(), Batch(), Batch(), Batch()
if self._multi_env:
self.reward = np.zeros(self.env_num)
self.length = np.zeros(self.env_num)
else:
self.reward, self.length = 0, 0
obs = self.preprocess_fn(obs=obs).get('obs', obs)
self.data.obs = obs
self.reward = 0 # will be specified when the first data is ready
self.length = np.zeros(self.env_num)
for b in self._cached_buf:
b.reset()

     def seed(self, seed: Optional[Union[int, List[int]]] = None) -> None:
         """Reset all the seed(s) of the given environment(s)."""
-        if hasattr(self.env, 'seed'):
-            return self.env.seed(seed)
+        return self.env.seed(seed)

     def render(self, **kwargs) -> None:
         """Render all the environment(s)."""
-        if hasattr(self.env, 'render'):
-            return self.env.render(**kwargs)
+        return self.env.render(**kwargs)

     def close(self) -> None:
         """Close the environment(s)."""
-        if hasattr(self.env, 'close'):
-            self.env.close()
+        self.env.close()

     def _make_batch(self, data: Any) -> np.ndarray:
         """Return [data]."""
@@ -178,20 +180,14 @@ def _make_batch(self, data: Any) -> np.ndarray:
             return np.array([data])

     def _reset_state(self, id: Union[int, List[int]]) -> None:
-        """Reset self.state[id]."""
-        if self.state is None:
-            return
-        if isinstance(self.state, list):
-            self.state[id] = None
-        elif isinstance(self.state, torch.Tensor):
-            self.state[id].zero_()
-        elif isinstance(self.state, np.ndarray):
-            if isinstance(self.state.dtype == np.object):
-                self.state[id] = None
-            else:
-                self.state[id] = 0
-        elif isinstance(self.state, Batch):
-            self.state.empty_(id)
+        """Reset self.data.state[id]."""
+        state = self.data.state  # it is a reference
+        if isinstance(state, torch.Tensor):
+            state[id].zero_()
+        elif isinstance(state, np.ndarray):
+            state[id] = None if state.dtype == np.object else 0
+        elif isinstance(state, Batch):
+            state.empty_(id)

     def collect(self,
                 n_step: int = 0,
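The rewritten ``_reset_state`` above collapses the old nested branches into one conditional on the array's dtype. A standalone sketch of that dispatch, using the plain ``object`` spelling (``np.object`` in the 2020-era code above is just an alias for it):

```python
import numpy as np

hidden = np.array([None, None, None], dtype=object)  # per-env RNN states
counts = np.zeros(3)                                 # numeric state

# same dispatch as _reset_state: object arrays take None, numeric take 0
for arr in (hidden, counts):
    arr[1] = None if arr.dtype == object else 0

print(hidden[1], counts[1])  # -> None 0.0
```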
@@ -227,26 +223,27 @@ def collect(self,
         * ``rew`` the mean reward over collected episodes.
         * ``len`` the mean length over collected episodes.
         """
-        warning_count = 0
         if not self._multi_env:
             n_episode = np.sum(n_episode)
         start_time = time.time()
         assert sum([(n_step != 0), (n_episode != 0)]) == 1, \
             "One and only one collection number specification is permitted!"
-        cur_step = 0
-        cur_episode = np.zeros(self.env_num) if self._multi_env else 0
-        reward_sum = 0
-        length_sum = 0
+        cur_step, cur_episode = 0, np.zeros(self.env_num)
+        reward_sum, length_sum = 0, 0
         while True:
-            if warning_count >= 100000:
+            if cur_step >= 100000 and cur_episode.sum() == 0:
                 warnings.warn(
                     'There are already many steps in an episode. '
                     'You should add a time limitation to your environment!',
                     Warning)
-            batch = Batch(
-                obs=self._obs, act=self._act, rew=self._rew,
-                done=self._done, obs_next=Batch(), info=self._info,
-                policy=Batch())
+
+            # restore the state and the input data
+            last_state = self.data.state
+            if last_state.is_empty():
+                last_state = None
+            self.data.update(state=Batch(), obs_next=Batch(), policy=Batch())
+
+            # calculate the next action
             if random:
                 action_space = self.env.action_space
                 if isinstance(action_space, list):
@@ -255,65 +252,54 @@
                     result = Batch(act=self._make_batch(action_space.sample()))
             else:
                 with torch.no_grad():
-                    result = self.policy(batch, self.state)
+                    result = self.policy(self.data, last_state)

-            # save hidden state to policy._state, in order to save into buffer
-            self.state = result.get('state', None)
+            # convert None to Batch()
+            state = result.get('state', Batch())
+            if state is None:
+                state = Batch()
+            self.data.state = state
             if hasattr(result, 'policy'):
-                self._policy = to_numpy(result.policy)
-                if self.state is not None:
-                    self._policy._state = self.state
-            elif self.state is not None:
-                self._policy = Batch(_state=self.state)
-            else:
-                self._policy = [{}] * self.env_num
+                self.data.policy = to_numpy(result.policy)
+                # save hidden state to policy._state, in order to save into buffer
+                self.data.policy._state = self.data.state

-            self._act = to_numpy(result.act)
+            self.data.act = to_numpy(result.act)
             if self._action_noise is not None:
-                self._act += self._action_noise(self._act.shape)
-            obs_next, self._rew, self._done, self._info = self.env.step(
-                self._act if self._multi_env else self._act[0])
+                self.data.act += self._action_noise(self.data.act.shape)
+
+            # step in env
+            obs_next, rew, done, info = self.env.step(
+                self.data.act if self._multi_env else self.data.act[0])
+
+            # move data to self.data
             if not self._multi_env:
                 obs_next = self._make_batch(obs_next)
-                self._rew = self._make_batch(self._rew)
-                self._done = self._make_batch(self._done)
-                self._info = self._make_batch(self._info)
+                rew = self._make_batch(rew)
+                done = self._make_batch(done)
+                info = self._make_batch(info)
+            self.data.obs_next = obs_next
+            self.data.rew = rew
+            self.data.done = done
+            self.data.info = info

             if log_fn:
-                log_fn(self._info if self._multi_env else self._info[0])
+                log_fn(info if self._multi_env else info[0])
             if render:
-                self.env.render()
+                self.render()
                 if render > 0:
                     time.sleep(render)

             # add data into the buffer
             self.length += 1
-            self.reward += self._rew
+            self.reward += self.data.rew
             if self.preprocess_fn:
-                result = self.preprocess_fn(
-                    obs=self._obs, act=self._act, rew=self._rew,
-                    done=self._done, obs_next=obs_next, info=self._info,
-                    policy=self._policy)
-                self._obs = result.get('obs', self._obs)
-                self._act = result.get('act', self._act)
-                self._rew = result.get('rew', self._rew)
-                self._done = result.get('done', self._done)
-                obs_next = result.get('obs_next', obs_next)
-                self._info = result.get('info', self._info)
-                self._policy = result.get('policy', self._policy)
-            if self._multi_env:
+                result = self.preprocess_fn(**self.data)
+                self.data.update(result)
+            if self._multi_env:  # cached_buffer branch
                 for i in range(self.env_num):
-                    data = {
-                        'obs': self._obs[i], 'act': self._act[i],
-                        'rew': self._rew[i], 'done': self._done[i],
-                        'obs_next': obs_next[i], 'info': self._info[i],
-                        'policy': self._policy[i]}
-                    if self._cached_buf:
-                        warning_count += 1
-                        self._cached_buf[i].add(**data)
-                    else:
-                        warning_count += 1
-                        if self.buffer is not None:
-                            self.buffer.add(**data)
-                        cur_step += 1
-                    if self._done[i]:
+                    self._cached_buf[i].add(**self.data[i])
+                    if self.data.done[i]:
                         if n_step != 0 or np.isscalar(n_episode) or \
                                 cur_episode[i] < n_episode[i]:
                             cur_episode[i] += 1
@@ -327,11 +313,13 @@
                             if self._cached_buf:
                                 self._cached_buf[i].reset()
                             self._reset_state(i)
-                if sum(self._done):
-                    obs_next = self.env.reset(np.where(self._done)[0])
+                obs_next = self.data.obs_next
+                if sum(self.data.done):
+                    obs_next = self.env.reset(np.where(self.data.done)[0])
                     if self.preprocess_fn:
                         obs_next = self.preprocess_fn(obs=obs_next).get(
                             'obs', obs_next)
+                    self.data.obs_next = obs_next
                 if n_episode != 0:
                     if isinstance(n_episode, list) and \
                             (cur_episode >= np.array(n_episode)).all() or \
@@ -340,27 +328,27 @@
                     break
             else:
                 if self.buffer is not None:
-                    self.buffer.add(
-                        self._obs[0], self._act[0], self._rew[0],
-                        self._done[0], obs_next[0], self._info[0],
-                        self._policy[0])
+                    self.buffer.add(**self.data[0])
                 cur_step += 1
-                if self._done:
+                if self.data.done[0]:
                     cur_episode += 1
-                    reward_sum += self.reward[0]
+                    reward_sum += self.reward
                     length_sum += self.length
                     self.reward, self.length = 0, 0
-                    self.state = None
+                    self.data.state = Batch()
                     obs_next = self._make_batch(self.env.reset())
                     if self.preprocess_fn:
                         obs_next = self.preprocess_fn(obs=obs_next).get(
                             'obs', obs_next)
+                    self.data.obs_next = obs_next
                 if n_episode != 0 and cur_episode >= n_episode:
                     break
             if n_step != 0 and cur_step >= n_step:
                 break
-            self._obs = obs_next
-        self._obs = obs_next
+            self.data.obs = self.data.obs_next
+        self.data.obs = self.data.obs_next

         # generate the statistics
         if self._multi_env:
             cur_episode = sum(cur_episode)
         duration = max(time.time() - start_time, 1e-9)
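Finally, a hypothetical driver for the ``collect()`` method shown above, reading back the two stats keys visible in this diff (``rew`` and ``len``); ``collector`` is assumed to be constructed as sketched earlier:

```python
# gather experience for 1000 environment steps, then inspect the stats
stats = collector.collect(n_step=1000)
print('mean episode reward:', stats['rew'])
print('mean episode length:', stats['len'])
```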