Sampling should always be with replacement
sharif1093 committed Aug 14, 2019
1 parent 44254ce commit 4cacef7
Showing 1 changed file with 26 additions and 17 deletions.
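Why replacement matters here: `np.random.choice` with `replace=False` effectively permutes the whole candidate pool before slicing out the batch, so its cost grows with the pool size, while `replace=True` simply draws `batch_size` independent indices. A minimal, self-contained timing sketch of the difference (the pool and batch sizes below are made-up illustration values, not taken from digideep):

    import time
    import numpy as np

    valid_arr = np.arange(10**6)   # made-up pool size; real replay buffers can be this large or larger
    batch_size = 128               # made-up minibatch size

    t0 = time.time()
    np.random.choice(valid_arr, batch_size, replace=True)    # draws batch_size independent indices
    t1 = time.time()
    np.random.choice(valid_arr, batch_size, replace=False)   # permutes the whole pool first
    t2 = time.time()

    print("with replacement:    {:.6f} s".format(t1 - t0))
    print("without replacement: {:.6f} s".format(t2 - t1))

With a pool of this size the second call is typically orders of magnitude slower, which is the overhead the KeepTime probes added in this commit are meant to make visible.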
43 changes: 26 additions & 17 deletions digideep/agent/samplers/ddpg.py
@@ -5,6 +5,8 @@
 from .common import flatten_first_two
 from digideep.utility.logging import logger
 
+from digideep.utility.profiling import KeepTime
+
 def get_sample_memory(buffer, info):
     """Sampler function for DDPG-like algorithms where we want to sample data from an experience replay buffer.
@@ -24,9 +26,11 @@ def get_sample_memory(buffer, info):
 
     N = info["num_records"] - 1 # We don't want to consider the last "incomplete" record, hence "-1"
 
-    masks_arr = buffer["/masks"][:,:N]
-    masks_arr = masks_arr.reshape(-1)
-    total_arr = np.arange(0,num_workers*N)
+    with KeepTime("mask_array"):
+        masks_arr = buffer["/masks"][:,:N]
+        masks_arr = masks_arr.reshape(-1)
+    with KeepTime("total_array"):
+        total_arr = np.arange(0,num_workers*N)
     ## This is if we want to mask final states (mask equals 0 for final state, 1 otherwise.)
     # valid_arr = total_arr[masks_arr.astype(bool)]
     valid_arr = total_arr
@@ -36,24 +40,29 @@ def get_sample_memory(buffer, info):
         logger.debug("batch_size ({}) should be smaller than total number of records (~ {}={}x{}).".format(batch_size, num_workers*N, num_workers, N))
         return None
 
-    # Sampling with replacement:
-    # sample_indices = np.random.choice(valid_arr, batch_size, replace=True)
-    sample_indices = np.random.choice(valid_arr, batch_size, replace=False)
+    with KeepTime("sampling_by_choice"):
+        # Sampling with replacement:
+        sample_indices = np.random.choice(valid_arr, batch_size, replace=True)
+        # NOTE: Never ever use sampling without replacement: its time scales up with the array size.
+        # sample_indices = np.random.choice(valid_arr, batch_size, replace=False)
 
-    sample_tabular = [[sample_indices // N], [sample_indices % N]]
-    sample_tabular_2 = [[sample_indices // N], [sample_indices % N + 1]]
+    with KeepTime("tabular_index_extraction"):
+        sample_tabular = [[sample_indices // N], [sample_indices % N]]
+        sample_tabular_2 = [[sample_indices // N], [sample_indices % N + 1]]
 
-    # Extracting the indices
-    batch = {}
-    for key in buffer:
-        batch[key] = buffer[key][sample_tabular[0],sample_tabular[1]]
+    # Extracting the indices
+    batch = {}
+    for key in buffer:
+        batch[key] = buffer[key][sample_tabular[0],sample_tabular[1]]
 
-    observation_path = "/observations" + observation_path
-    batch["/obs_with_key"] = batch[observation_path]
-    # Adding predictive keys
-    batch["/obs_with_key_2"] = buffer[observation_path][sample_tabular_2[0],sample_tabular_2[1]]
+    with KeepTime("post_key_generation"):
+        observation_path = "/observations" + observation_path
+        batch["/obs_with_key"] = batch[observation_path]
+        # Adding predictive keys
+        batch["/obs_with_key_2"] = buffer[observation_path][sample_tabular_2[0],sample_tabular_2[1]]
 
-    batch = flatten_first_two(batch)
+    with KeepTime("flatten_first_two"):
+        batch = flatten_first_two(batch)
     return batch

