Update Network and Policies
* Rewrite Networks: well documented; add an initialization method and network block-making functions, i.e. for MLP or CNN

* Update Policies: rewrite policies with a more general API for Categorical and Gaussian policies; the user only needs to define the feature-layer network, and the output layer is handled automatically by the policy (see the sketch after this list)

* Add Tests: networks and policies
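
A minimal sketch of how the new policy API might be used, assuming the policy constructor accepts a user-defined feature network and appends the output layer itself; the constructor signature and the `env_spec` argument are assumptions for illustration, not taken from this commit:

```python
import torch.nn as nn
import torch.nn.functional as F
from lagom.core.networks import BaseNetwork

# User-defined feature network: only the feature layers, no action head.
class FeatureNetwork(BaseNetwork):
    def make_params(self, config):
        self.fc1 = nn.Linear(4, 64)   # e.g. a CartPole-sized observation
        self.fc2 = nn.Linear(64, 64)

    def init_params(self, config):
        gain = nn.init.calculate_gain('relu')
        for layer in [self.fc1, self.fc2]:
            nn.init.orthogonal_(layer.weight, gain=gain)
            nn.init.constant_(layer.bias, 0.0)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return x

# Assumed usage: the policy creates the Categorical (or Gaussian) output
# layer automatically from the environment's action space.
# policy = CategoricalPolicy(network=FeatureNetwork(), env_spec=env_spec)
# dist = policy(observation)  # e.g. a torch.distributions.Categorical
```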
zuoxingdong committed Sep 3, 2018
1 parent 941f228 commit 32006ea
Showing 21 changed files with 3,131 additions and 1,640 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -101,7 +101,7 @@ pytest test -v

# Reference

This repo is inspired by [OpenAI rllab](https://github.com/rll/rllab), [OpenAI baselines](https://github.com/openai/baselines), [RLPyTorch](https://github.com/pytorch/ELF/tree/master/src_py/rlpytorch), [TensorForce](https://github.com/reinforceio/tensorforce), and [Intel Coach](https://github.com/NervanaSystems/coach)
This repo is inspired by [OpenAI rllab](https://github.com/rll/rllab), [OpenAI baselines](https://github.com/openai/baselines), [RLPyTorch](https://github.com/pytorch/ELF/tree/master/src_py/rlpytorch), [TensorForce](https://github.com/reinforceio/tensorforce), [Intel Coach](https://github.com/NervanaSystems/coach), and [Dopamine](https://github.com/google/dopamine)

Please use this BibTeX entry if you want to cite this repository in your publications:

18 changes: 10 additions & 8 deletions docs/source/core.rst
@@ -73,15 +73,17 @@ Networks

.. currentmodule:: lagom.core.networks

.. autofunction:: ortho_init

.. autoclass:: BaseNetwork
:members:

.. autoclass:: BaseMLP
:members:
.. autoclass:: BaseCNN
:members:

.. autofunction:: make_fc

.. autofunction:: make_cnn

.. autofunction:: make_transposed_cnn
.. autoclass:: BaseMDN
:members:

Expand Down Expand Up @@ -113,10 +115,10 @@ Policies
.. autoclass:: RandomPolicy
:members:

.. autoclass:: BaseCategoricalPolicy
.. autoclass:: CategoricalPolicy
:members:

.. autoclass:: BaseGaussianPolicy
.. autoclass:: GaussianPolicy
:members:

Transformations
3,247 changes: 2,098 additions & 1,149 deletions examples/policy_gradient/main.ipynb

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions lagom/base_algo.py
@@ -25,7 +25,5 @@ def __call__(self, config):
-------
result : object
result of the execution. If nothing needs to be returned, then ``None`` should be returned.
result2: int
test
"""
raise NotImplementedError
14 changes: 9 additions & 5 deletions lagom/core/networks/__init__.py
@@ -1,5 +1,9 @@
from lagom.core.networks.base_network import BaseNetwork
from lagom.core.networks.base_mlp import BaseMLP
from lagom.core.networks.base_cnn import BaseCNN
from lagom.core.networks.base_vae import BaseVAE
from lagom.core.networks.base_mdn import BaseMDN
from .base_network import BaseNetwork
from .base_vae import BaseVAE
from .base_mdn import BaseMDN

from .init import ortho_init

from .make_blocks import make_fc
from .make_blocks import make_cnn
from .make_blocks import make_transposed_cnn
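
For orientation, a sketch of what the new block-making helpers plausibly provide, assuming `make_fc(input_dim, hidden_sizes)` returns an `nn.ModuleList` of `nn.Linear` layers and leaves activations to the caller; the signature and return type here are assumptions for illustration, not read from this diff:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def make_fc(input_dim, hidden_sizes):
    # Assumed behavior: stack fully-connected layers into an nn.ModuleList,
    # so the caller decides which activation to apply between layers.
    layers = nn.ModuleList()
    in_dim = input_dim
    for out_dim in hidden_sizes:
        layers.append(nn.Linear(in_dim, out_dim))
        in_dim = out_dim
    return layers

fc = make_fc(input_dim=4, hidden_sizes=[64, 64])
x = torch.randn(1, 4)
for layer in fc:
    x = F.relu(layer(x))
print(x.shape)  # torch.Size([1, 64])
```

`make_cnn` and `make_transposed_cnn` would analogously stack `nn.Conv2d` / `nn.ConvTranspose2d` layers.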
46 changes: 0 additions & 46 deletions lagom/core/networks/base_cnn.py

This file was deleted.

5 changes: 5 additions & 0 deletions lagom/core/networks/base_mdn.py
@@ -8,6 +8,11 @@

from .base_network import BaseNetwork

####################
# TODO: update
# - do not overwrite __init__
# - put everything in make_params
# - in example, use make_fc/make_cnn

class BaseMDN(BaseNetwork):
"""
39 changes: 0 additions & 39 deletions lagom/core/networks/base_mlp.py

This file was deleted.

109 changes: 71 additions & 38 deletions lagom/core/networks/base_network.py
@@ -4,20 +4,51 @@


class BaseNetwork(nn.Module):
"""
Base class for neural networks.
r"""Base class for all neural networks.
Any neural network should subclass this class.
Depending on the type of neural networks (e.g. policy network, Q-network), it is recommended
to override the constructor __init__ to provide essential items for the neural network.
The subclass should implement at least the following:
Note that if the subclass overrides __init__, remember to pass the keyword
arguments, i.e. **kwargs, to super().__init__.
- :meth:`make_params`
- :meth:`init_params`
- :meth:`forward`
All inherited subclasses should at least implement the following functions:
1. make_params(self, config)
2. init_params(self, config)
Example::
import torch.nn as nn
import torch.nn.functional as F
from lagom.core.networks import BaseNetwork
class Network(BaseNetwork):
def make_params(self, config):
self.fc1 = nn.Linear(3, 2)
self.fc2 = nn.Linear(2, 1)
def init_params(self, config):
gain = nn.init.calculate_gain('relu')
nn.init.orthogonal_(self.fc1.weight, gain=gain)
nn.init.constant_(self.fc1.bias, 0.0)
nn.init.orthogonal_(self.fc2.weight, gain=gain)
nn.init.constant_(self.fc2.bias, 0.0)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
"""
def __init__(self, config=None, **kwargs):
r"""Initialize the neural network.
Args:
config (dict): A dictionary of configurations.
**kwargs: keyword arguments to specify the network.
"""
super().__init__()

self.config = config
@@ -26,71 +57,73 @@ def __init__(self, config=None, **kwargs):
for key, val in kwargs.items():
self.__setattr__(key, val)

# User-defined function to create all trainable parameters (layers)
# User-defined function to create all trainable parameters/layers
self.make_params(self.config)

# User-defined function to initialize all created parameters
# User-defined function to initialize all created parameters/layers
self.init_params(self.config)

def make_params(self, config):
"""
User-defined function to create all trainable parameters (layers)
r"""User-defined function to create all trainable parameters/layers for the neural network
according to a given configuration.
Args:
config (Config): configurations
Examples:
Refer to the documentation of each inherited subclass.
config (dict): A dictionary of configurations.
"""
raise NotImplementedError

def init_params(self, config):
"""
User-defined function to initialize all created parameters
r"""User-defined function to initialize all created parameters in :meth:`make_params` according
to a given configuration.
Args:
config (Config): configurations
config (dict): A dictionary of configurations.
"""
raise NotImplementedError

@property
def num_params(self):
"""
Returns the number of trainable parameters.
"""
r"""Returns the total number of trainable parameters in the neural network."""
return sum(param.numel() for param in self.parameters() if param.requires_grad)

def save(self, f):
"""
Save the model parameters. It is saved using the recommended approach from the PyTorch documentation.
https://pytorch.org/docs/master/notes/serialization.html#best-practices
r"""Save the network parameters to a file.
It complies with the `recommended approach for saving a model in PyTorch documentation`_.
.. note::
It uses the highest pickle protocol to serialize the network parameters.
Args:
f (str): saving path
f (str): file path.
.. _recommended approach for saving a model in PyTorch documentation:
https://pytorch.org/docs/master/notes/serialization.html#best-practices
"""
torch.save(self.state_dict(), f)
import pickle
torch.save(self.state_dict(), f, pickle_protocol=pickle.HIGHEST_PROTOCOL)

def load(self, f):
"""
Load the model parameters. It is loaded using the recommended approach from the PyTorch documentation.
https://pytorch.org/docs/master/notes/serialization.html#best-practices
r"""Load the network parameters from a file.
It complies with the `recommended approach for saving a model in PyTorch documentation`_.
Args:
f (str): loading path
f (str): file path.
.. _recommended approach for saving a model in PyTorch documentation:
https://pytorch.org/docs/master/notes/serialization.html#best-practices
"""
self.load_state_dict(torch.load(f))

def to_vec(self):
"""
Flatten the network parameters into a single big vector.
"""
r"""Returns the network parameters as a single flattened vector. """
return parameters_to_vector(parameters=self.parameters())

def from_vec(self, x):
"""
Unflatten the given vector as the network parameters.
r"""Set the network parameters from a single flattened vector.
Args:
x (Tensor): flattened single vector, with size consistent with the number of network parameters.
x (Tensor): A single flattened vector of the network parameters with consistent size.
"""
vector_to_parameters(vec=x, parameters=self.parameters())
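
Putting the updated `BaseNetwork` API together, a small end-to-end sketch based on the docstring example above (the file path is illustrative):

```python
import torch.nn as nn
import torch.nn.functional as F
from lagom.core.networks import BaseNetwork

class Network(BaseNetwork):
    def make_params(self, config):
        self.fc1 = nn.Linear(3, 2)
        self.fc2 = nn.Linear(2, 1)

    def init_params(self, config):
        gain = nn.init.calculate_gain('relu')
        nn.init.orthogonal_(self.fc1.weight, gain=gain)
        nn.init.constant_(self.fc1.bias, 0.0)
        nn.init.orthogonal_(self.fc2.weight, gain=gain)
        nn.init.constant_(self.fc2.bias, 0.0)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

net = Network()
print(net.num_params)  # total number of trainable parameters

vec = net.to_vec()     # flatten all parameters into one vector
net.from_vec(vec)      # round-trip: restore from the flat vector

net.save('net.pth')    # illustrative path
net.load('net.pth')
```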
6 changes: 6 additions & 0 deletions lagom/core/networks/base_vae.py
@@ -4,6 +4,12 @@

from .base_network import BaseNetwork

####################
# TODO: update
# - do not overwrite __init__
# - put everything in make_params
# - in example, use make_fc/make_cnn


class BaseVAE(BaseNetwork):
"""
44 changes: 44 additions & 0 deletions lagom/core/networks/init.py
@@ -0,0 +1,44 @@
import torch.nn as nn


def ortho_init(module, nonlinearity=None, weight_scale=1.0, constant_bias=0.0):
r"""Applies orthogonal initialization for the parameters of a given module.
Args:
module (nn.Module): A module to apply orthogonal initialization over its parameters.
nonlinearity (str, optional): Nonlinearity that follows the module's forward pass. When nonlinearity
is not ``None``, the gain will be calculated and :attr:`weight_scale` will be ignored.
Default: ``None``
weight_scale (float, optional): Scaling factor to initialize the weight. Ignored when
:attr:`nonlinearity` is not ``None``. Default: 1.0
constant_bias (float, optional): Constant value to initialize the bias. Default: 0.0
.. note::
Currently, the only supported :attr:`module` types are elementary neural network layers, e.g.
nn.Linear, nn.Conv2d and nn.LSTM. Submodules are not handled recursively.
Example::
>>> a = nn.Linear(2, 3)
>>> ortho_init(a)
"""
# Get the gain (scaling factor)
if nonlinearity is not None: # based on nonlinearity
gain = nn.init.calculate_gain(nonlinearity)
else: # user provided
gain = weight_scale

# Initialization
if isinstance(module, (nn.RNNBase, nn.RNNCellBase)): # RNN
# Iterate over named parameters
for name, param in module.named_parameters():
if 'weight' in name: # Weight
nn.init.orthogonal_(param, gain=gain)
elif 'bias' in name: # Bias
nn.init.constant_(param, constant_bias)
else: # other modules with single .weight and .bias
# Weight
nn.init.orthogonal_(module.weight, gain=gain)
# Bias
nn.init.constant_(module.bias, constant_bias)
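
As a usage sketch, `ortho_init` plugs naturally into `init_params` of a `BaseNetwork` subclass; the small `Network` class here is hypothetical:

```python
import torch.nn as nn
import torch.nn.functional as F
from lagom.core.networks import BaseNetwork, ortho_init

class Network(BaseNetwork):
    def make_params(self, config):
        self.fc1 = nn.Linear(3, 2)
        self.fc2 = nn.Linear(2, 1)

    def init_params(self, config):
        # Gain computed from the nonlinearity; weight_scale is ignored here.
        ortho_init(self.fc1, nonlinearity='relu', constant_bias=0.0)
        # Output layer: no nonlinearity, so an explicit small weight scale.
        ortho_init(self.fc2, weight_scale=0.01, constant_bias=0.0)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))
```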
