diff --git a/python/src/nnabla/utils/nnp_graph.py b/python/src/nnabla/utils/nnp_graph.py index 61132e1b7..86ff012ce 100644 --- a/python/src/nnabla/utils/nnp_graph.py +++ b/python/src/nnabla/utils/nnp_graph.py @@ -267,40 +267,42 @@ def _get_variable_or_create(self, v, callback, current_scope): assert np.all(np.array(shape) > 0), "Shape must be positive. Given {}.".format(shape) - # The variable is a parameter, then get from parameter registry. - if pvar.type == 'Parameter': - try: - param = get_parameter(name) - if param is None: - logger.info('Paramter `{}` is not found. Initializing.'.format( - name)) - tmp = _create_variable(pvar, name, shape, self.rng) - param = tmp.variable_instance - set_parameter(name, param) - # Always copy param to current scope even if it already exists. - with nn.parameter_scope('', current_scope): - set_parameter(name, param) - except: - import sys - import traceback - raise ValueError( - 'An error occurs during creation of a variable `{}` as a' - ' parameter variable. The error was:\n----\n{}\n----\n' - 'The parameters registered was {}'.format( - name, traceback.format_exc(), - '\n'.join( - list(nn.get_parameters(grad_only=False).keys())))) - assert shape == param.shape - param = param.get_unlinked_variable(need_grad=v.need_grad) - v.variable = param - param.name = name - return param - - # Create a new one and returns. - var = nn.Variable(shape) - v.variable = var - var.name = name - return var + if pvar.type != 'Parameter': + # Create a new variable and returns. + var = nn.Variable(shape) + v.variable = var + var.name = name + return var + + # Trying to load the parameter from .nnp file. + callback.verbose( + 'Loading parameter `{}` from .nnp.'.format(name)) + try: + param = get_parameter(name) + if param is None: + logger.info( + 'Parameter `{}` is not found. Initializing.'.format(name)) + tmp = _create_variable(pvar, name, shape, self.rng) + param = tmp.variable_instance + set_parameter(name, param) + # Always copy param to current scope even if it already exists. + with nn.parameter_scope('', current_scope): + set_parameter(name, param) + except: + import sys + import traceback + raise ValueError( + 'An error occurs during creation of a variable `{}` as a' + ' parameter variable. The error was:\n----\n{}\n----\n' + 'The parameters registered was {}'.format( + name, traceback.format_exc(), + '\n'.join( + list(nn.get_parameters(grad_only=False).keys())))) + assert shape == param.shape + param = param.get_unlinked_variable(need_grad=v.need_grad) + v.variable = param + param.name = name + return param def _create_inputs(self, inputs, callback, current_scope): input_vars = []