remove multibuf
Trinkle23897 committed Jul 11, 2020
1 parent 1d5058b commit 306dd68
Showing 1 changed file with 13 additions and 75 deletions.
88 changes: 13 additions & 75 deletions tianshou/data/collector.py
@@ -31,12 +31,6 @@ class Collector(object):
:param BaseNoise action_noise: add a noise to continuous action. Normally
a policy already has a noise param for exploration in training phase,
so this is recommended to use in test collector for some purpose.
:param get_final_reward_fn: to be used in multi-agent RL. The reward to
report is of shape [reward_length], but we need to return a single
scalar to monitor training. This function specifies what is the
desired metric, i.e. the reward of agent 1 or the average reward
over all agents. By default, the behavior is to select the
reward of agent 1.
The ``preprocess_fn`` is a function called before the data has been added
to the buffer with batch format, which receives up to 7 keys as listed in
@@ -93,14 +87,10 @@ class Collector(object):
def __init__(self,
policy: BasePolicy,
env: Union[gym.Env, BaseVectorEnv],
buffer: Optional[Union[ReplayBuffer, List[ReplayBuffer]]]
= None,
buffer: Optional[ReplayBuffer] = None,
preprocess_fn: Callable[[Any], Union[dict, Batch]] = None,
stat_size: Optional[int] = 100,
action_noise: Optional[BaseNoise] = None,
reward_length: int = 1,
get_final_reward_fn: Optional[Callable[[np.ndarray], float]]
= None,
**kwargs) -> None:
super().__init__()
self.env = env
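
After this change the buffer argument only accepts a single ReplayBuffer (or None); the list-of-buffers form and the reward_length / get_final_reward_fn parameters are gone. A minimal construction sketch under the new signature, assuming an already-built policy (any BasePolicy) and using an arbitrary Gym task and buffer size, not code from this commit:

    import gym
    from tianshou.data import Collector, ReplayBuffer

    # `policy` is assumed to be an existing BasePolicy instance (not constructed here)
    env = gym.make('CartPole-v0')
    buf = ReplayBuffer(size=20000)      # one shared buffer; a list of buffers is no longer accepted
    collector = Collector(policy, env, buffer=buf)

    # the pre-commit multi-buffer form no longer works:
    # Collector(policy, vector_env, buffer=[ReplayBuffer(20000) for _ in range(4)], reward_length=2)

For continuous-action test collectors, an action_noise such as GaussianNoise from tianshou.exploration can also be passed, as documented in the first hunk.
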
@@ -111,35 +101,14 @@ def __init__(self,
self.buffer = buffer
self.policy = policy
self.preprocess_fn = preprocess_fn
self.reward_length = reward_length

def _get_final_reward_fn(x):
# x can be a scalar or ndarray with shape [reward_length]
if isinstance(x, np.ndarray):
return x[0]
return x
self.get_final_reward_fn = get_final_reward_fn or _get_final_reward_fn
# if preprocess_fn is None:
# def _prep(**kwargs):
# return kwargs
# self.preprocess_fn = _prep
self.process_fn = policy.process_fn
self._multi_env = isinstance(env, BaseVectorEnv)
self._multi_buf = False # True if buf is a list
# need multiple cache buffers only if storing in one buffer
self._cached_buf = []
if self._multi_env:
self.env_num = len(env)
if isinstance(self.buffer, list):
assert len(self.buffer) == self.env_num, \
'The number of data buffer does not match the number of ' \
'input env.'
self._multi_buf = True
elif isinstance(self.buffer, ReplayBuffer) or self.buffer is None:
self._cached_buf = [
ListReplayBuffer() for _ in range(self.env_num)]
else:
raise TypeError('The buffer in data collector is invalid!')
self._cached_buf = [
ListReplayBuffer() for _ in range(self.env_num)]
self.stat_size = stat_size
self._action_noise = action_noise
self.reset()
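
The rewritten __init__ keeps only the per-env cache buffers: with a vectorized environment, every worker gets its own ListReplayBuffer, and finished episodes are merged into the single main buffer (see the collect hunks further down). A rough, hand-written illustration of that pattern, not the library's exact code:

    from tianshou.data import ListReplayBuffer, ReplayBuffer

    main_buf = ReplayBuffer(size=1000)
    cached = [ListReplayBuffer() for _ in range(4)]   # one cache per env worker

    # transitions from env i accumulate in its own cache ...
    cached[2].add(obs=0, act=1, rew=1.0, done=True, obs_next=1, info={})
    # ... and are flushed into the single main buffer once that env finishes an episode
    main_buf.update(cached[2])
    cached[2].reset()
    assert len(main_buf) == 1
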
@@ -160,12 +129,8 @@ def reset(self) -> None:

def reset_buffer(self) -> None:
"""Reset the main data buffer."""
if self._multi_buf:
for b in self.buffer:
b.reset()
else:
if self.buffer is not None:
self.buffer.reset()
if self.buffer is not None:
self.buffer.reset()

def get_env_num(self) -> int:
"""Return the number of environments the collector have."""
@@ -183,10 +148,10 @@ def reset_env(self) -> None:
self._act, self._rew, self._done, self._info = \
Batch(), Batch(), Batch(), Batch()
if self._multi_env:
self.reward = np.zeros((self.env_num, self.reward_length))
self.reward = np.zeros(self.env_num)
self.length = np.zeros(self.env_num)
else:
self.reward, self.length = np.zeros(self.reward_length), 0
self.reward, self.length = 0, 0
for b in self._cached_buf:
b.reset()

@@ -320,7 +285,7 @@ def collect(self,
if render > 0:
time.sleep(render)
self.length += 1
self.reward += self._rew.reshape(self.reward.shape)
self.reward += self._rew
if self.preprocess_fn:
result = self.preprocess_fn(
obs=self._obs, act=self._act, rew=self._rew,
@@ -343,10 +308,6 @@ def collect(self,
if self._cached_buf:
warning_count += 1
self._cached_buf[i].add(**data)
elif self._multi_buf:
warning_count += 1
self.buffer[i].add(**data)
cur_step += 1
else:
warning_count += 1
if self.buffer is not None:
@@ -362,8 +323,7 @@ def collect(self,
cur_step += len(self._cached_buf[i])
if self.buffer is not None:
self.buffer.update(self._cached_buf[i])
self.reward[i] = np.zeros(self.reward_length)
self.length[i] = 0
self.reward[i], self.length[i] = 0, 0
if self._cached_buf:
self._cached_buf[i].reset()
self._reset_state(i)
@@ -389,12 +349,7 @@ def collect(self,
cur_episode += 1
reward_sum += self.reward[0]
length_sum += self.length
if self._multi_env:
self.reward = np.zeros(
self.reward_length, self.env_num)
else:
self.reward = np.zeros(self.reward_length)
self.length = 0
self.reward, self.length = 0, 0
self.state = None
obs_next = self._make_batch(self.env.reset())
if self.preprocess_fn:
@@ -418,13 +373,12 @@ def collect(self,
n_episode = np.sum(n_episode)
else:
n_episode = max(cur_episode, 1)
reward_sum = self.get_final_reward_fn(reward_sum / n_episode)
return {
'n/ep': cur_episode,
'n/st': cur_step,
'v/st': self.step_speed.get(),
'v/ep': self.episode_speed.get(),
'rew': reward_sum,
'rew': reward_sum / n_episode,
'len': length_sum / n_episode,
}
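
With reward_length and get_final_reward_fn removed, collect now reports the mean episode reward as a plain scalar (reward_sum / n_episode). A hedged sketch of reading the returned statistics, reusing the collector from the construction sketch above:

    stats = collector.collect(n_step=100)
    print(stats['n/ep'], stats['n/st'])   # episodes and steps gathered in this call
    print(stats['v/st'], stats['v/ep'])   # collection speed statistics
    print(stats['rew'], stats['len'])     # mean episode reward (now a scalar) and mean episode length
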

@@ -437,22 +391,6 @@ def sample(self, batch_size: int) -> Batch:
the buffer, otherwise it will extract the data with the given
batch_size.
"""
if self._multi_buf:
if batch_size > 0:
lens = [len(b) for b in self.buffer]
total = sum(lens)
batch_index = np.random.choice(
len(self.buffer), batch_size, p=np.array(lens) / total)
else:
batch_index = np.array([])
batch_data = Batch()
for i, b in enumerate(self.buffer):
cur_batch = (batch_index == i).sum()
if batch_size and cur_batch or batch_size <= 0:
batch, indice = b.sample(cur_batch)
batch = self.process_fn(batch, b, indice)
batch_data.cat_(batch)
else:
batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice)
batch_data, indice = self.buffer.sample(batch_size)
batch_data = self.process_fn(batch_data, self.buffer, indice)
return batch_data
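
sample likewise loses its multi-buffer branch and always draws from the single self.buffer, passing the result through policy.process_fn before returning it. Usage sketch, with the same assumed collector as above:

    minibatch = collector.sample(batch_size=64)   # random transitions, already run through process_fn
    everything = collector.sample(0)              # per the docstring, batch_size = 0 returns the whole buffer
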
