Skip to content

Commit

Permalink
Merge branch 'master' into feature/20180401-file-format-converter
Browse files Browse the repository at this point in the history
  • Loading branch information
YukioOobuchi committed Jun 19, 2018
2 parents 1b28d8b + 6884a4a commit 6e02015
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 53 deletions.
5 changes: 3 additions & 2 deletions python/src/nnabla/ext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def get_extension_context(ext_name, **kw):
"""
if ext_name == 'cuda.cudnn':
from nnabla import logger
logger.warn('Deprecated extension name "cuda.cudnn" passed.')
extensin_name = 'cudnn'
logger.warn(
'Deprecated extension name "cuda.cudnn" passed. Use "cudnn" instead.')
ext_name = 'cudnn'
mod = import_extension_module(ext_name)
return mod.context(**kw)
22 changes: 17 additions & 5 deletions python/src/nnabla/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def set_parameter(key, param):
current_scope[names[0]] = param


def get_parameter_or_create(name, shape, initializer=None, need_grad=True):
def get_parameter_or_create(name, shape=None, initializer=None, need_grad=True):
"""
Returns an existing parameter variable with the provided name.
If a variable with the provided name does not exist,
Expand All @@ -162,8 +162,8 @@ def get_parameter_or_create(name, shape, initializer=None, need_grad=True):
name(str): The name under the current scope. If it already exists, the name is queried from the
parameter manager.
shape (:obj:`tuple` of :obj:`int`): Shape of created parameter. The shape of the specified
parameter must match with this shape.
initializer (~nnabla.initializer.BaseInitializer): An initialization function to be applied to the parameter.
parameter must match with this shape. The default is None which is only valid if initializer is given as an :obj:`numpy.ndarray`.
initializer (:obj:`nnabla.initializer.BaseInitializer` or :obj:`numpy.ndarray`): An initialization function to be applied to the parameter. :obj:`numpy.ndarray` can also be given to initialize parameters from numpy array data.
need_grad (bool): The value for `need_grad` .
The default is True.
Expand All @@ -178,9 +178,21 @@ class VariableInfo:
pass
info = VariableInfo()
info.initializer = initializer
param = nn.Variable(shape, need_grad=need_grad)

if initializer is not None:
param.d = initializer(shape=param.shape)
if isinstance(initializer, numpy.ndarray): # numpy init
param = nn.Variable(initializer.shape, need_grad=need_grad)
param.d = initializer
elif isinstance(initializer, nn.initializer.BaseInitializer): # initializer init
assert shape is not None
param = nn.Variable(shape, need_grad=need_grad)
param.d = initializer(shape=param.shape)
else:
raise ValueError(
"`initializer` must be either the :obj:`numpy.ndarray` or an instance inherited from `nnabla.initializer.BaseInitializer`.")
else: # default init
assert shape is not None
param = nn.Variable(shape, need_grad=need_grad)
set_parameter(name, param)
else:
assert param.shape == tuple(shape)
Expand Down

0 comments on commit 6e02015

Please sign in to comment.