Skip to content

Commit

Permalink
Refactored samplers
Browse files Browse the repository at this point in the history
  • Loading branch information
sharif1093 committed May 15, 2019
1 parent 0ffd7a6 commit 82d1229
Show file tree
Hide file tree
Showing 6 changed files with 386 additions and 380 deletions.
4 changes: 2 additions & 2 deletions digideep/agent/ddpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch.nn.functional as F
import torch.utils.data

from digideep.agent.samplers.default import sampler_re
# from digideep.agent.samplers.default import check_shape
from digideep.agent.samplers.ddpg import sampler_re
# from digideep.agent.samplers.common import check_shape

from digideep.utility.toolbox import get_class
from digideep.utility.logging import logger
Expand Down
6 changes: 3 additions & 3 deletions digideep/agent/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch.nn.functional as F
import torch.utils.data

from digideep.agent.samplers.default import sampler_ff, sampler_rn
# from digideep.agent.samplers.default import check_shape
from digideep.agent.samplers.ppo import sampler_ff, sampler_rn
# from digideep.agent.samplers.common import check_shape

from digideep.utility.toolbox import get_class
from digideep.utility.logging import logger
Expand Down Expand Up @@ -122,7 +122,7 @@ def step(self):
* ``/agents/<agent_name>/artifacts/advantages``
* ``/agents/<agent_name>/artifacts/returns``
The last two keys are added by the :mod:`digideep.agent.samplers.default`, while the rest are added at
The last two keys are added by the :mod:`digideep.agent.samplers`, while the rest are added at
:class:`~digideep.environment.explorer.Explorer`.
"""
Expand Down
147 changes: 147 additions & 0 deletions digideep/agent/samplers/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
"""
This module provides helping functions and tools to create a sampler from the memory.
The samplers take advantage of a highly modular pattern in order to create new samplers
or change the behavior of the current ones much easier.
One can build modular samplers by cascading functions using :class:`Compose` class.
All function must have the following signature:
.. code-block:: python
def func(data, info)
* ``data`` is a dictionary where all data is stored. It can be the whole memory at the first sampler
block, and then narrowing down to a small sampled chunk of data at the end.
* ``info`` is a dict containing information that is passed through to the last sampler. It basically
contains information that one sampler may need, e.g. ``batch_size``, ``memory_size``, etc.
Parts of this module is inspired by `pytorch-a2c-ppo-acktr <https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/>`_.
"""

import numpy as np
import os, inspect
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from digideep.utility.logging import logger

import warnings

class Compose:
"""A class to create a composed function from a list.
Args:
functions (list): A list of functions with the following prototype which will be called
in cascade after the class is called. The first function in the list will be called
first and results will be passed to the second one and so on.
.. code-block:: python
f_composed = Compose([f1, f2])
"""
def __init__(self, functions):
self.functions = functions
def __call__(self, data, info):
for f in self.functions:
data = f(data, info)
return data


##################################
## SAMPLER ##
##################################
def get_memory_params(memory, info):
"""A sampler function to get memory parameters and store them in the ``info`` for the future samplers.
Args:
memory: The main memory object.
info (dict): A dictionary that can be used to transfer hyper-information among sampler functions.
Returns:
dict: A reference to the internal buffer of the memory.
Todo:
We may have a ``include_keys``/``exclude_keys`` argument to filter keys in the memory.
It can help control the downstream flow of keys in the memory. Only those keys would be
passed through that satisfy both ``include_keys`` and ``exclude_keys``. Example:
.. code-block:: python
include_keys:["/actions/*","/observations"]
exclude_keys:["/info*"]
"""

info["num_steps"] = memory.get_chunk_sample_num()
info["num_workers"] = memory.get_num_batches()
info["num_records"] = memory.get_last_trans_index()

buffer = memory.get_buffer()
return buffer


####################
###### HELPER ######
####################
def flatten_first_two(batch):
"""This is a helper function that is used in other sampler functions.
It flattens the first two dimensions of each key entry
in the batch, thus making the data flattened.
"""
# The data must be intact up to preprocess.
# After that we are free.
for key in batch:
batch[key] = batch[key].reshape(-1, *batch[key].shape[2:])
return batch

def truncate_datalists(chunk, info):
"""This sampler function truncates the last data entry in the chunk and returns the rest.
For an example key in the chunk, if the input shape is ``(batch_size, T+1, ...)``, it will become
``(batch_size, T, ...)``. This is basically to remove the final "half-step" taken in the environment.
"""
params = info["truncate_datalists"]
n = params["n"]

for key in chunk:
chunk[key] = chunk[key][:,:-n]
return chunk


#########################
###### CHECK CHUNK ######
#########################
# if torch.isnan(torch.tensor( ... )).any():
def check_nan(chunk, info):
"""This sampler function has debugging purposes and will publish a warning message if there are NaN values in the chunk.
"""
for key in chunk:
if np.isnan(chunk[key]).any():
logger.warn("%s:%s[%d]: Found NaN '%s'." %
(os.path.basename(inspect.stack()[2].filename),
inspect.stack()[2].function,
inspect.stack()[2].lineno,
key))
return chunk

def check_shape(chunk, info):
"""This sampler function has debugging purposes and reports the shapes of every key in the data chunk.
"""
logger.warn("%s:%s[%d]: Checking shapes:" %
(os.path.basename(inspect.stack()[2].filename),
inspect.stack()[2].function,
inspect.stack()[2].lineno))
for key in chunk:
logger.warn("%s %s" % ( key, str(chunk[key].shape)))
return chunk

def check_stats(chunk, info):
"""This sampler function has debugging purposes and will report the mean and standard deviation of every key in the data chunk.
"""
logger.warn("%s:%s[%d]: Checking stats:" %
(os.path.basename(inspect.stack()[2].filename),
inspect.stack()[2].function,
inspect.stack()[2].lineno))

for key in chunk:
logger.warn("{} = {:.2f} (\xB1{:.2f} 95%)".format(key, np.nanmean(chunk[key]), 2*np.nanstd(chunk[key])))
return chunk

def print_line(chunk, info):
logger.warn("=========================================")
return chunk
66 changes: 66 additions & 0 deletions digideep/agent/samplers/ddpg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import numpy as np
import os, inspect, warnings
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from digideep.utility.logging import logger


from .common import Compose, get_memory_params, check_nan, check_shape, check_stats, print_line
from .common import flatten_first_two

def get_sample_memory(buffer, info):
"""Sampler function for DDPG-like algorithms where we want to sample data from an experience replay buffer.
This function does not sample from final steps where mask equals ``0`` (as they don't have subsequent observations.)
This function adds the following key to the memory:
* ``/observations_2``
Returns:
dict: One sampled batch to be used in the DDPG algorithm for one step of training. The shape of each
key in the output batch will be: ``(batch_size, *key_shape[2:])``
"""
batch_size = info["batch_size"]

num_workers = info["num_workers"]
N = info["num_records"] - 1 # We don't want to consider the last "incomplete" record, hence "-1"

masks_arr = buffer["/masks"][:,:N]
masks_arr = masks_arr.reshape(-1)
total_arr = np.arange(0,num_workers*N)
valid_arr = total_arr[masks_arr.astype(bool)]

if batch_size >= len(valid_arr):
# We don't have enough data in the memory yet.
warnings.warn("batch_size ({}) should be smaller than total number of records (~ {}={}x{}).".format(batch_size, num_workers*N, num_workers, N))
return None

# Sampling with replacement:
sample_indices = np.random.choice(valid_arr, batch_size, replace=True)

sample_tabular = [[sample_indices // N], [sample_indices % N]]
sample_tabular_2 = [[sample_indices // N], [sample_indices % N + 1]]

# Extracting the indices
batch = {}
for key in buffer:
batch[key] = buffer[key][sample_tabular[0],sample_tabular[1]]
# Adding predictive keys
batch["/observations_2"] = buffer["/observations"][sample_tabular_2[0],sample_tabular_2[1]]

batch = flatten_first_two(batch)
return batch


#############################
### Composing the sampler ###
#############################

# Sampler with replay buffer
sampler_re = Compose([get_memory_params, # Must be present: It gets the memory parameters and passes them to the rest of functions through "info".
get_sample_memory, # Sample
# check_nan, # It complains about existing NaNs in the chunk.
# check_shape, # It prints the shapes of the existing keys in the chunk.
# check_stats,
# print_line, # This only prints a line for more beautiful debugging.
])

0 comments on commit 82d1229

Please sign in to comment.