Skip to content

Commit

Permalink
Move batch_indices and advance function to flexible pool
Browse files Browse the repository at this point in the history
  • Loading branch information
hartikainen committed Jul 5, 2018
1 parent b8d4928 commit 5221555
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 12 deletions.
12 changes: 8 additions & 4 deletions softlearning/replay_pools/flexible_replay_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ def add_fields(self, fields):
field_shape = [self._max_size] + list(field_attrs['shape'])
setattr(self, field_name, np.zeros(field_shape))

def _advance(self, count=1):
self._pointer = (self._pointer + count) % self._max_size
if self._size + count <= self._max_size:
self._size += count

def add_sample(self, **kwargs):
for field_name in self.field_names:
getattr(self, field_name)[self._pointer] = kwargs.pop(field_name)
Expand Down Expand Up @@ -69,9 +74,8 @@ def __setstate__(self, pool_state):
self._pointer = pool_state['_pointer']
self._size = pool_state['_size']

@abstractmethod
def batch_indices(self, batch_size):
pass
def random_indices(self, batch_size):
return np.random.randint(0, self._size, batch_size)

def random_batch(self, batch_size, field_name_filter=None):
field_names = self.field_names
Expand All @@ -81,7 +85,7 @@ def random_batch(self, batch_size, field_name_filter=None):
if field_name_filter(field_name)
]

indices = self.batch_indices(batch_size)
indices = self.random_indices(batch_size)

return {
field_name: getattr(self, field_name)[indices]
Expand Down
8 changes: 0 additions & 8 deletions softlearning/replay_pools/simple_replay_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,3 @@ def __init__(self, env_spec, *args, **kwargs):

def terminate_episode(self):
pass

def _advance(self):
self._pointer = (self._pointer + 1) % self._max_size
if self._size < self._max_size:
self._size += 1

def batch_indices(self, batch_size):
return np.random.randint(0, self._size, batch_size)

0 comments on commit 5221555

Please sign in to comment.