Skip to content

Commit

Permalink
Fix all docstrings to Google style
Browse files Browse the repository at this point in the history
  • Loading branch information
zuoxingdong committed Aug 3, 2019
1 parent ccb9b79 commit eb4d368
Show file tree
Hide file tree
Showing 21 changed files with 87 additions and 170 deletions.
6 changes: 2 additions & 4 deletions examples/vae/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,8 @@ def vae_loss(re_x, x, mu, logvar, mode='BCE'):
logvar (Tensor): log-variance of the latent variable
mode (str): Type of reconstruction loss, supported ['BCE', 'MSE']
Returns
-------
out : dict
a dictionary of selected output such as loss, reconstruction loss and KL loss.
Returns:
dict: a dictionary of selected output such as loss, reconstruction loss and KL loss.
"""
assert mode in ['BCE', 'MSE'], f'expected either BCE or MSE, got {mode}'

Expand Down
45 changes: 17 additions & 28 deletions lagom/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,19 @@
from abc import abstractmethod

from lagom.networks import Module
from lagom.envs import VecEnv


class BaseAgent(Module, ABC):
r"""Base class for all agents.
The agent could select an action from a given observation and update itself by defining a certain learning
mechanism.
The agent could select an action from some data (e.g. observation) and update itself by
defining a certain learning mechanism.
Any agent should subclass this class, e.g. policy-based or value-based.
.. note::
All agents should by default handle batched data e.g. batched observation returned from :class:`VecEnv`
and batched action for each sub-environment of a :class:`VecEnv`.
Args:
config (dict): a dictionary of configurations
env (VecEnv): environment object.
env (Env): environment object.
device (Device): a PyTorch device
**kwargs: keyword arguments used to specify the agent
Expand All @@ -36,27 +30,24 @@ def __init__(self, config, env, device, **kwargs):
self.is_recurrent = None

@abstractmethod
def choose_action(self, obs, **kwargs):
r"""Returns a (batched) action selected by the agent from the received (batched) observation.
def choose_action(self, x, **kwargs):
r"""Returns the selected action given the data.
.. note::
Tensor conversion should be handled here instead of in policy or network forward pass.
It's recommended to handle all dtype/device conversions between CPU/GPU or Tensor/Numpy here.
The output is a dictionary containing useful items, e.g. action, action_logprob, state_value
The output is a dictionary containing useful items.
Args:
obs (object): batched observation returned from the environment. First dimension is treated
as batch dimension.
**kwargs: keyword arguments to specify action selection.
Returns
-------
out : dict
a dictionary of action selection output. It should also contain all useful information
to be stored during interaction with :class:`BaseRunner`. This allows a generic API of
the runner classes for all kinds of agents. Note that everything should be batched even
if for scalar loss, i.e. ``scalar_loss -> [scalar_loss]``
Returns:
dict: a dictionary of action selection output. It contains all useful information (e.g. action,
action_logprob, state_value). This allows the API to be generic and compatible with
different kinds of runner and agents.
"""
pass

Expand All @@ -66,22 +57,20 @@ def learn(self, D, **kwargs):
Args:
D (list): a list of batched data to train the agent e.g. in policy gradient, this can be
a list of :class:`Trajectory` or :class:`Segment`
a list of :class:`Trajectory`.
**kwargs: keyword arguments to specify learning mechanism
Returns
-------
out : dict
a dictionary of learning output. This could contain the loss.
Returns:
dict: a dictionary of learning output. This could contain the loss and other useful metrics.
"""
pass


class RandomAgent(BaseAgent):
r"""A random agent samples action uniformly from action space. """
def choose_action(self, obs, **kwargs):
if isinstance(self.env, VecEnv):
action = [self.env.action_space.sample() for _ in range(len(self.env))]
def choose_action(self, x, **kwargs):
if hasattr(self.env, 'num_envs'):
action = [self.env.action_space.sample() for _ in range(self.env.num_envs)]
else:
action = self.env.action_space.sample()
out = {'raw_action': action}
Expand Down
12 changes: 4 additions & 8 deletions lagom/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@ def train(self, n=None, **kwargs):
n (int, optional): n-th iteration for training.
**kwargs: keyword arguments used for logging.
Returns
-------
out : dict
training output
Returns:
dict: a dictionary of training output
"""
pass

Expand All @@ -53,9 +51,7 @@ def eval(self, n=None, **kwargs):
n (int, optional): n-th iteration for evaluation.
**kwargs: keyword arguments used for logging.
Returns
-------
out : dict
evaluation output
Returns:
dict: a dictionary of evaluation output
"""
pass
7 changes: 2 additions & 5 deletions lagom/envs/make_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@ def make_vec_env(make_env, num_env, init_seed):
num_env (int): number of environments to create.
init_seed (int): initial seed for :class:`Seeder` to sample random seeds.
Returns
-------
env : VecEnv
created vectorized environment
Returns:
VecEnv: created vectorized environment
"""
# Generate different seeds for each environment
seeder = Seeder(init_seed=init_seed)
Expand Down
38 changes: 15 additions & 23 deletions lagom/envs/vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,15 @@ def step(self, actions):
Args:
actions (list): a list of actions, each for one environment.
Returns
-------
observations : list
a list of observations, each returned from one environment after executing the given action.
rewards : list
a list of scalar rewards, each returned from one environment.
dones : list
a list of booleans indicating whether the episode terminates, each returned from one environment.
infos : list
a list of dictionaries of additional information, each returned from one environment.
Returns:
tuple: a tuple of (observations, rewards, dones, infos)
* observations (list): a list of observations, each returned from one environment
after executing the given action.
* rewards (list): a list of scalar rewards, each returned from one environment.
* dones (list): a list of booleans indicating whether the episode terminates, each
returned from one environment.
* infos (list): a list of dictionaries of additional information, each returned
from one environment.
"""
assert len(actions) == len(self)
observations = []
Expand All @@ -80,10 +78,8 @@ def reset(self):
If :meth:`step_async` is still working, then it will be aborted.
Returns
-------
observations : list
a list of initial observations from all environments.
Returns:
list: a list of initial observations from all environments.
"""
observations = [env.reset() for env in self.list_env]
return observations
Expand Down Expand Up @@ -118,20 +114,16 @@ def render(self, mode='human'):
def get_images(self):
r"""Returns a batched RGB array with shape [N, H, W, C] from all environments.
Returns
-------
imgs : ndarray
a batched RGB array with shape [N, H, W, C]
Returns:
ndarray: a batched RGB array with shape [N, H, W, C]
"""
return [env.render(mode='rgb_array') for env in self.list_env]

def get_viewer(self):
r"""Returns an instantiated :class:`ImageViewer`.
Returns
-------
viewer : ImageViewer
an image viewer
Returns:
ImageViewer: an image viewer
"""
if self.viewer is None: # create viewer is not existed
self.viewer = ImageViewer(max_width=500) # set a max width here
Expand Down
13 changes: 4 additions & 9 deletions lagom/envs/wrappers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ def get_wrapper(env, name):
env (Env): environment
name (str): name of the wrapper
Returns
-------
out : env
wrapped environment
Returns:
Env: wrapped environment
"""
if name == env.__class__.__name__:
return env
Expand All @@ -27,11 +25,8 @@ def get_all_wrappers(env):
Args:
env (Env): wrapped environment
Returns
-------
out : list
list of string names of wrappers
Returns:
list: a list of string names of wrappers
"""
out = []
while env is not env.unwrapped:
Expand Down
6 changes: 2 additions & 4 deletions lagom/es.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,8 @@ class BaseES(ABC):
def ask(self):
r"""Sample a set of new candidate solutions.
Returns
-------
solutions : list
sampled candidate solutions
Returns:
list: a list of sampled candidate solutions
"""
pass

Expand Down
6 changes: 2 additions & 4 deletions lagom/experiment/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,8 @@ def make_configs(self):
r"""Generate a list of all possible combinations of configurations, including
grid search and random search.
Returns
-------
list_config : list
a list of all possible configurations
Returns:
list: a list of all possible configurations
"""
keys_fixed = []
keys_grid = []
Expand Down
18 changes: 6 additions & 12 deletions lagom/networks/make_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ def make_fc(input_dim, hidden_sizes):
input_dim (int): input dimension in the first fully connected layer.
hidden_sizes (list): a list of hidden sizes, each for one fully connected layer.
Returns
-------
fc : nn.ModuleList
A ModuleList of fully connected layers.
Returns:
nn.ModuleList: A ModuleList of fully connected layers.
"""
assert isinstance(hidden_sizes, list), f'expected list, got {type(hidden_sizes)}'

Expand Down Expand Up @@ -63,10 +61,8 @@ def make_cnn(input_channel, channels, kernels, strides, paddings):
strides (list): a list of strides, each for one convolution layer.
paddings (list): a list of paddings, each for one convolution layer.
Returns
-------
cnn : nn.ModuleList
A ModuleList of 2D convolution layers.
Returns:
nn.ModuleList: A ModuleList of 2D convolution layers.
"""
N = len(channels)

Expand Down Expand Up @@ -120,10 +116,8 @@ def make_transposed_cnn(input_channel, channels, kernels, strides, paddings, out
paddings (list): a list of paddings, each for one transposed convolution layer.
output_paddings (list): a list of output paddings, each for one transposed convolution layer.
Returns
-------
transposed_cnn : nn.ModuleList
A ModuleList of 2D transposed convolution layers.
Returns:
nn.ModuleList: A ModuleList of 2D transposed convolution layers.
"""
N = len(channels)

Expand Down
12 changes: 4 additions & 8 deletions lagom/networks/mdn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,8 @@ def loss(self, logit_pi, mean, std, target):
std (Tensor): standard deviation of Gaussian mixtures, shape [N, K, D]
target (Tensor): target tensor, shape [N, D]
Returns
-------
loss : Tensor
calculated loss
Returns:
Tensor: calculated loss
"""
# target shape [N, D] to [N, 1, D]
target = target.unsqueeze(1)
Expand Down Expand Up @@ -105,10 +103,8 @@ def sample(self, logit_pi, mean, std, tau=1.0):
* If :math:`\tau > 1`: increase uncertainty
* If :math:`\tau < 1`: decrease uncertainty
Returns
-------
x : Tensor
sampled data with shape [N, D]
Returns:
Tensor: sampled data with shape [N, D]
"""
N, K, D = logit_pi.shape
pi = F.softmax(logit_pi/tau, dim=1)
Expand Down
7 changes: 2 additions & 5 deletions lagom/transform/explained_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,8 @@ def explained_variance(y_true, y_pred, **kwargs):
y_pred (list): predicted output
**kwargs: keyword arguments to specify the estimation of the explained variance.
Returns
-------
out : float
estimated explained variance
Returns:
float: estimated explained variance
"""
y_true = np.squeeze(y_true)
y_pred = np.squeeze(y_pred)
Expand Down
7 changes: 2 additions & 5 deletions lagom/transform/geometric_cumsum.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,8 @@ def geometric_cumsum(alpha, x):
alpha (float): exponential factor between zero and one.
x (list): input data
Returns
-------
out : ndarray
calculated data
Returns:
ndarray: calculated data
"""
x = np.asarray(x)
if x.ndim == 1:
Expand Down
10 changes: 3 additions & 7 deletions lagom/transform/interp_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,9 @@ def interp_curves(x, y):
y (list): a batch of y values.
num_point (int): number of points to generate from the interpolated line.
Returns
-------
out_x : list
interpolated x values (shared for the batch of curves)
out_y : list
interpolated y values
Returns:
tuple: a tuple of two lists: a list of interpolated x values (shared for the batch of curves),
followed by a list of interpolated y values.
"""
new_x = np.unique(np.hstack(x))
assert new_x.ndim == 1
Expand Down
6 changes: 2 additions & 4 deletions lagom/transform/linear_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ def __call__(self, x):
Args:
x (int): the current timestep.
Returns
-------
out : float
current value of the scheduling.
Returns:
float: current value of the scheduling.
"""
assert isinstance(x, int) and x >= 0, f'expected as a non-negative integer, got {x}'

Expand Down

0 comments on commit eb4d368

Please sign in to comment.