Skip to content

Commit

Permalink
Merge pull request #480 from sony/feature/20190701-model-auto-forward…
Browse files Browse the repository at this point in the history
…-support

[Models] auto-forward support and modified parameter loading method.
  • Loading branch information
TakuyaNarihira committed Jul 3, 2019
2 parents cb17b15 + f34b56a commit ee6307b
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions python/src/nnabla/utils/nnp_graph.py
Expand Up @@ -274,20 +274,34 @@ def _get_variable_or_create(self, v, callback, current_scope):
var.name = name
return var

# Trying to load the parameter from .nnp file.
callback.verbose(
'Loading parameter `{}` from .nnp.'.format(name))
# Trying to load the parameter from the global scope.
try:
with nn.parameter_scope('', current_scope):
param = get_parameter(name)

if param is not None:
assert shape == param.shape
param = param.get_unlinked_variable(need_grad=v.need_grad)
v.variable = param
param.name = name
return param

# Parameter does not exist in the global scope.
# Then try to load the parameter from .nnp file.
callback.verbose(
'Loading parameter `{}` from .nnp.'.format(name))
param = get_parameter(name)

if param is None:
logger.info(
'Parameter `{}` is not found. Initializing.'.format(name))
tmp = _create_variable(pvar, name, shape, self.rng)
param = tmp.variable_instance
set_parameter(name, param)
# Always copy param to current scope even if it already exists.

# Register the parameter to the current (global) scope.
with nn.parameter_scope('', current_scope):
set_parameter(name, param)

except:
import sys
import traceback
Expand All @@ -298,6 +312,7 @@ def _get_variable_or_create(self, v, callback, current_scope):
name, traceback.format_exc(),
'\n'.join(
list(nn.get_parameters(grad_only=False).keys()))))

assert shape == param.shape
param = param.get_unlinked_variable(need_grad=v.need_grad)
v.variable = param
Expand All @@ -321,7 +336,7 @@ def _create_function(self, f, callback, current_scope):
function_instance = _create_function(inputs, f.proto, self.batch_size)

outputs = function_instance(
*inputs, n_outputs=len(f.outputs), auto_forward=False)
*inputs, n_outputs=len(f.outputs), auto_forward=nn.get_auto_forward())
if not isinstance(outputs, tuple):
outputs = (outputs,)

Expand Down

0 comments on commit ee6307b

Please sign in to comment.