Skip to content

Commit

Permalink
Fix handling shape(arg of reshape).
Browse files Browse the repository at this point in the history
  • Loading branch information
YukioOobuchi committed Jun 1, 2018
1 parent 458c2ee commit 679abe4
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 4 deletions.
8 changes: 5 additions & 3 deletions python/src/nnabla/utils/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,12 @@ def _create_function(ctx, network, f, variable_index):
outputs = [network.variables[v_name] for v_name in output_variable_names]

if f.type == "Reshape":
reshape_shape = (network.batch_size,) + \
tuple(f.reshape_param.shape.dim)
shape = tuple(
[d if d >= 0 else network.batch_size for d in f.reshape_param.shape.dim])
if numpy.prod(shape) != numpy.prod(inputs[0].shape):
shape = (network.batch_size,) + shape
function_instance = F.Reshape(ctx,
shape=reshape_shape)
shape=shape)
elif f.type == "RepeatStart":
function_instance = F.Identity(ctx)
elif f.type == "RepeatEnd":
Expand Down
4 changes: 3 additions & 1 deletion python/src/nnabla/utils/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,6 @@ def collect_info(func):
else:
v.type = 'Buffer'
# TODO: The first dimension is always considered as batch size.
# No problem?
if len(shape) > 0:
shape[0] = -1
v.shape.dim.extend(shape)
Expand All @@ -241,6 +240,9 @@ def collect_info(func):

for name, function in functions.items():
f = n.function.add()
if function['type'] == 'Reshape':
# TODO: The first dimension is always considered as batch size.
function['args']['shape'][0] = -1
_create_function_nntxt(f, name, function)

return n
Expand Down
21 changes: 21 additions & 0 deletions src/nbla_utils/nnp_impl_create_function.cpp.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,27 @@ shared_ptr<nbla::CgFunction> NetworkImpl::create_cgfunction(const ::Function& fu
${name}Parameter param = func.${func['snake_name']}_param();
% for argname, arg in func['arguments'].items():
${proto_to_ctype(argname, arg)};
% if name == 'Reshape' and argname == 'shape':

int prod_shape = 1;
for( auto &v : arg_shape ) {
prod_shape *= v;
if( v < 0 ) {
v = batch_size();
}
}
int prod_input = 1;
for( auto inp : func.input() ) {
auto var_it = variable_protos_.find(inp);
const ::Variable *var = var_it->second;
for ( auto d : var->shape().dim() ) {
prod_input *= d;
}
}
if( prod_input != prod_shape ) {
arg_shape.insert(arg_shape.begin(), batch_size());
}
% endif
% endfor
% endif
nbla::FunctionPtr fp = create_${name}(ctx_${''.join([', arg_' + argname for argname in func.get('arguments', {}).keys()])});
Expand Down

0 comments on commit 679abe4

Please sign in to comment.