-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Rewrite Networks: well-documented, add initialization method, add network block making function i.e. MLP or CNN * Update Policies: Rewrite policies, more general API for Categorical and Gaussian policies, user only needs to define feature layer network, and output layer handled automatically by the policy * Add Test: networks and policies
- Loading branch information
1 parent
941f228
commit 32006ea
Showing
21 changed files
with
3,131 additions
and
1,640 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,9 @@ | ||
from lagom.core.networks.base_network import BaseNetwork | ||
from lagom.core.networks.base_mlp import BaseMLP | ||
from lagom.core.networks.base_cnn import BaseCNN | ||
from lagom.core.networks.base_vae import BaseVAE | ||
from lagom.core.networks.base_mdn import BaseMDN | ||
from .base_network import BaseNetwork | ||
from .base_vae import BaseVAE | ||
from .base_mdn import BaseMDN | ||
|
||
from .init import ortho_init | ||
|
||
from .make_blocks import make_fc | ||
from .make_blocks import make_cnn | ||
from .make_blocks import make_transposed_cnn |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch.nn as nn | ||
|
||
|
||
def ortho_init(module, nonlinearity=None, weight_scale=1.0, constant_bias=0.0):
    r"""Applies orthogonal initialization for the parameters of a given module.

    Args:
        module (nn.Module): A module to apply orthogonal initialization over its parameters.
        nonlinearity (str, optional): Nonlinearity followed by forward pass of the module. When nonlinearity
            is not ``None``, the gain will be calculated and :attr:`weight_scale` will be ignored.
            Default: ``None``
        weight_scale (float, optional): Scaling factor to initialize the weight. Ignored when
            :attr:`nonlinearity` is not ``None``. Default: 1.0
        constant_bias (float, optional): Constant value to initialize the bias. Default: 0.0

    .. note::

        Currently, the only supported :attr:`module` are elementary neural network layers, e.g.
        nn.Linear, nn.Conv2d, nn.LSTM. The submodules are not supported.

    Example::

        >>> a = nn.Linear(2, 3)
        >>> ortho_init(a)
    """
    # Get the gain (scaling factor)
    if nonlinearity is not None:  # based on nonlinearity
        gain = nn.init.calculate_gain(nonlinearity)
    else:  # user provided
        gain = weight_scale

    # Initialization
    if isinstance(module, (nn.RNNBase, nn.RNNCellBase)):  # RNN
        # Recurrent layers store several weight/bias tensors (e.g. weight_ih_l0,
        # bias_hh_l0), so iterate over named parameters and dispatch by name.
        for name, param in module.named_parameters():
            if 'weight' in name:  # Weight
                nn.init.orthogonal_(param, gain=gain)
            elif 'bias' in name:  # Bias
                nn.init.constant_(param, constant_bias)
    else:  # other modules with single .weight and .bias
        # Weight
        nn.init.orthogonal_(module.weight, gain=gain)
        # Bias: may be absent when the layer was created with bias=False
        # (e.g. nn.Linear(..., bias=False) before a BatchNorm layer).
        if module.bias is not None:
            nn.init.constant_(module.bias, constant_bias)
Oops, something went wrong.