Skip to content

Commit

Permalink
Merge pull request #449 from sony/feature/20190520-modify-nnpgraph
Browse files Browse the repository at this point in the history
nnpgraph refactoring
  • Loading branch information
TakuyaNarihira committed May 31, 2019
2 parents f856034 + f965db9 commit ed7aea5
Showing 1 changed file with 36 additions and 34 deletions.
70 changes: 36 additions & 34 deletions python/src/nnabla/utils/nnp_graph.py
Expand Up @@ -267,40 +267,42 @@ def _get_variable_or_create(self, v, callback, current_scope):
assert np.all(np.array(shape) >
0), "Shape must be positive. Given {}.".format(shape)

# The variable is a parameter, then get from parameter registry.
if pvar.type == 'Parameter':
try:
param = get_parameter(name)
if param is None:
logger.info('Paramter `{}` is not found. Initializing.'.format(
name))
tmp = _create_variable(pvar, name, shape, self.rng)
param = tmp.variable_instance
set_parameter(name, param)
# Always copy param to current scope even if it already exists.
with nn.parameter_scope('', current_scope):
set_parameter(name, param)
except:
import sys
import traceback
raise ValueError(
'An error occurs during creation of a variable `{}` as a'
' parameter variable. The error was:\n----\n{}\n----\n'
'The parameters registered was {}'.format(
name, traceback.format_exc(),
'\n'.join(
list(nn.get_parameters(grad_only=False).keys()))))
assert shape == param.shape
param = param.get_unlinked_variable(need_grad=v.need_grad)
v.variable = param
param.name = name
return param

# Create a new one and returns.
var = nn.Variable(shape)
v.variable = var
var.name = name
return var
if pvar.type != 'Parameter':
# Create a new variable and returns.
var = nn.Variable(shape)
v.variable = var
var.name = name
return var

# Trying to load the parameter from .nnp file.
callback.verbose(
'Loading parameter `{}` from .nnp.'.format(name))
try:
param = get_parameter(name)
if param is None:
logger.info(
'Parameter `{}` is not found. Initializing.'.format(name))
tmp = _create_variable(pvar, name, shape, self.rng)
param = tmp.variable_instance
set_parameter(name, param)
# Always copy param to current scope even if it already exists.
with nn.parameter_scope('', current_scope):
set_parameter(name, param)
except:
import sys
import traceback
raise ValueError(
'An error occurs during creation of a variable `{}` as a'
' parameter variable. The error was:\n----\n{}\n----\n'
'The parameters registered was {}'.format(
name, traceback.format_exc(),
'\n'.join(
list(nn.get_parameters(grad_only=False).keys()))))
assert shape == param.shape
param = param.get_unlinked_variable(need_grad=v.need_grad)
v.variable = param
param.name = name
return param

def _create_inputs(self, inputs, callback, current_scope):
input_vars = []
Expand Down

0 comments on commit ed7aea5

Please sign in to comment.