-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0ffd7a6
commit 82d1229
Showing
6 changed files
with
386 additions
and
380 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
""" | ||
This module provides helping functions and tools to create a sampler from the memory. | ||
The samplers take advantage of a highly modular pattern in order to create new samplers | ||
or change the behavior of the current ones much easier. | ||
One can build modular samplers by cascading functions using :class:`Compose` class. | ||
All function must have the following signature: | ||
.. code-block:: python | ||
def func(data, info) | ||
* ``data`` is a dictionary where all data is stored. It can be the whole memory at the first sampler | ||
block, and then narrowing down to a small sampled chunk of data at the end. | ||
* ``info`` is a dict containing information that is passed through to the last sampler. It basically | ||
contains information that one sampler may need, e.g. ``batch_size``, ``memory_size``, etc. | ||
Parts of this module is inspired by `pytorch-a2c-ppo-acktr <https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/>`_. | ||
""" | ||
|
||
import numpy as np | ||
import os, inspect | ||
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler | ||
from digideep.utility.logging import logger | ||
|
||
import warnings | ||
|
||
class Compose: | ||
"""A class to create a composed function from a list. | ||
Args: | ||
functions (list): A list of functions with the following prototype which will be called | ||
in cascade after the class is called. The first function in the list will be called | ||
first and results will be passed to the second one and so on. | ||
.. code-block:: python | ||
f_composed = Compose([f1, f2]) | ||
""" | ||
def __init__(self, functions): | ||
self.functions = functions | ||
def __call__(self, data, info): | ||
for f in self.functions: | ||
data = f(data, info) | ||
return data | ||
|
||
|
||
################################## | ||
## SAMPLER ## | ||
################################## | ||
def get_memory_params(memory, info): | ||
"""A sampler function to get memory parameters and store them in the ``info`` for the future samplers. | ||
Args: | ||
memory: The main memory object. | ||
info (dict): A dictionary that can be used to transfer hyper-information among sampler functions. | ||
Returns: | ||
dict: A reference to the internal buffer of the memory. | ||
Todo: | ||
We may have a ``include_keys``/``exclude_keys`` argument to filter keys in the memory. | ||
It can help control the downstream flow of keys in the memory. Only those keys would be | ||
passed through that satisfy both ``include_keys`` and ``exclude_keys``. Example: | ||
.. code-block:: python | ||
include_keys:["/actions/*","/observations"] | ||
exclude_keys:["/info*"] | ||
""" | ||
|
||
info["num_steps"] = memory.get_chunk_sample_num() | ||
info["num_workers"] = memory.get_num_batches() | ||
info["num_records"] = memory.get_last_trans_index() | ||
|
||
buffer = memory.get_buffer() | ||
return buffer | ||
|
||
|
||
#################### | ||
###### HELPER ###### | ||
#################### | ||
def flatten_first_two(batch): | ||
"""This is a helper function that is used in other sampler functions. | ||
It flattens the first two dimensions of each key entry | ||
in the batch, thus making the data flattened. | ||
""" | ||
# The data must be intact up to preprocess. | ||
# After that we are free. | ||
for key in batch: | ||
batch[key] = batch[key].reshape(-1, *batch[key].shape[2:]) | ||
return batch | ||
|
||
def truncate_datalists(chunk, info): | ||
"""This sampler function truncates the last data entry in the chunk and returns the rest. | ||
For an example key in the chunk, if the input shape is ``(batch_size, T+1, ...)``, it will become | ||
``(batch_size, T, ...)``. This is basically to remove the final "half-step" taken in the environment. | ||
""" | ||
params = info["truncate_datalists"] | ||
n = params["n"] | ||
|
||
for key in chunk: | ||
chunk[key] = chunk[key][:,:-n] | ||
return chunk | ||
|
||
|
||
######################### | ||
###### CHECK CHUNK ###### | ||
######################### | ||
# if torch.isnan(torch.tensor( ... )).any(): | ||
def check_nan(chunk, info): | ||
"""This sampler function has debugging purposes and will publish a warning message if there are NaN values in the chunk. | ||
""" | ||
for key in chunk: | ||
if np.isnan(chunk[key]).any(): | ||
logger.warn("%s:%s[%d]: Found NaN '%s'." % | ||
(os.path.basename(inspect.stack()[2].filename), | ||
inspect.stack()[2].function, | ||
inspect.stack()[2].lineno, | ||
key)) | ||
return chunk | ||
|
||
def check_shape(chunk, info): | ||
"""This sampler function has debugging purposes and reports the shapes of every key in the data chunk. | ||
""" | ||
logger.warn("%s:%s[%d]: Checking shapes:" % | ||
(os.path.basename(inspect.stack()[2].filename), | ||
inspect.stack()[2].function, | ||
inspect.stack()[2].lineno)) | ||
for key in chunk: | ||
logger.warn("%s %s" % ( key, str(chunk[key].shape))) | ||
return chunk | ||
|
||
def check_stats(chunk, info): | ||
"""This sampler function has debugging purposes and will report the mean and standard deviation of every key in the data chunk. | ||
""" | ||
logger.warn("%s:%s[%d]: Checking stats:" % | ||
(os.path.basename(inspect.stack()[2].filename), | ||
inspect.stack()[2].function, | ||
inspect.stack()[2].lineno)) | ||
|
||
for key in chunk: | ||
logger.warn("{} = {:.2f} (\xB1{:.2f} 95%)".format(key, np.nanmean(chunk[key]), 2*np.nanstd(chunk[key]))) | ||
return chunk | ||
|
||
def print_line(chunk, info): | ||
logger.warn("=========================================") | ||
return chunk |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import numpy as np | ||
import os, inspect, warnings | ||
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler | ||
from digideep.utility.logging import logger | ||
|
||
|
||
from .common import Compose, get_memory_params, check_nan, check_shape, check_stats, print_line | ||
from .common import flatten_first_two | ||
|
||
def get_sample_memory(buffer, info): | ||
"""Sampler function for DDPG-like algorithms where we want to sample data from an experience replay buffer. | ||
This function does not sample from final steps where mask equals ``0`` (as they don't have subsequent observations.) | ||
This function adds the following key to the memory: | ||
* ``/observations_2`` | ||
Returns: | ||
dict: One sampled batch to be used in the DDPG algorithm for one step of training. The shape of each | ||
key in the output batch will be: ``(batch_size, *key_shape[2:])`` | ||
""" | ||
batch_size = info["batch_size"] | ||
|
||
num_workers = info["num_workers"] | ||
N = info["num_records"] - 1 # We don't want to consider the last "incomplete" record, hence "-1" | ||
|
||
masks_arr = buffer["/masks"][:,:N] | ||
masks_arr = masks_arr.reshape(-1) | ||
total_arr = np.arange(0,num_workers*N) | ||
valid_arr = total_arr[masks_arr.astype(bool)] | ||
|
||
if batch_size >= len(valid_arr): | ||
# We don't have enough data in the memory yet. | ||
warnings.warn("batch_size ({}) should be smaller than total number of records (~ {}={}x{}).".format(batch_size, num_workers*N, num_workers, N)) | ||
return None | ||
|
||
# Sampling with replacement: | ||
sample_indices = np.random.choice(valid_arr, batch_size, replace=True) | ||
|
||
sample_tabular = [[sample_indices // N], [sample_indices % N]] | ||
sample_tabular_2 = [[sample_indices // N], [sample_indices % N + 1]] | ||
|
||
# Extracting the indices | ||
batch = {} | ||
for key in buffer: | ||
batch[key] = buffer[key][sample_tabular[0],sample_tabular[1]] | ||
# Adding predictive keys | ||
batch["/observations_2"] = buffer["/observations"][sample_tabular_2[0],sample_tabular_2[1]] | ||
|
||
batch = flatten_first_two(batch) | ||
return batch | ||
|
||
|
||
############################# | ||
### Composing the sampler ### | ||
############################# | ||
|
||
# Sampler with replay buffer | ||
sampler_re = Compose([get_memory_params, # Must be present: It gets the memory parameters and passes them to the rest of functions through "info". | ||
get_sample_memory, # Sample | ||
# check_nan, # It complains about existing NaNs in the chunk. | ||
# check_shape, # It prints the shapes of the existing keys in the chunk. | ||
# check_stats, | ||
# print_line, # This only prints a line for more beautiful debugging. | ||
]) |
Oops, something went wrong.