
Merge branch 'fix/20180530-reshape-problem' into feature/20180401-file-format-converter
YukioOobuchi committed Jun 1, 2018
2 parents dfe65d2 + 679abe4 commit 7508066
Showing 3 changed files with 34 additions and 30 deletions.
37 changes: 13 additions & 24 deletions python/src/nnabla/utils/load.py
@@ -116,13 +116,12 @@ def _create_function(ctx, network, f, variable_index):
     outputs = [network.variables[v_name] for v_name in output_variable_names]
 
     if f.type == "Reshape":
-        batch_size = network.batch_size
-        if network.batch_size < 1:
-            batch_size = 1
-        reshape_shape = (batch_size,) + \
-            tuple(f.reshape_param.shape.dim)
+        shape = tuple(
+            [d if d >= 0 else network.batch_size for d in f.reshape_param.shape.dim])
+        if numpy.prod(shape) != numpy.prod(inputs[0].shape):
+            shape = (network.batch_size,) + shape
         function_instance = F.Reshape(ctx,
-                                      shape=reshape_shape)
+                                      shape=shape)
     elif f.type == "RepeatStart":
         function_instance = F.Identity(ctx)
     elif f.type == "RepeatEnd":
@@ -361,20 +360,14 @@ def _context(proto):
     logger.warn('Old-style context. Updating to new format.')
     # Update from old Context
     if proto.backend == 'cpu|cuda':
-        try:
-            if 'cudnn' in proto.compute_backend:
-                import nnabla_ext.cudnn
-                ctx = nnabla_ext.cudnn.context(device_id=proto.device_id)
-            elif 'default' in proto.compute_backend:
-                import nnabla_ext.cuda
-                ctx = nnabla_ext.cuda.context(device_id=proto.device_id)
-            else:
-                raise ValueError('Invalid context {}'.format(proto))
-        except ImportError:
-            logger.log(
-                99, 'Could not import extension. Fallback into CPU context.')
-            import nnabla_ext.cpu
-            ctx = nnabla_ext.cpu.context()
+        if 'cudnn' in proto.compute_backend:
+            import nnabla_ext.cudnn
+            ctx = nnabla_ext.cudnn.context(device_id=proto.device_id)
+        elif 'default' in proto.compute_backend:
+            import nnabla_ext.cuda
+            ctx = nnabla_ext.cuda.context(device_id=proto.device_id)
+        else:
+            raise ValueError('Invalid context {}'.format(proto))
     elif proto.backend == 'cpu':
         import nnabla_ext.cpu
         ctx = nnabla_ext.cpu.context()
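
With the try/except removed, a missing CUDA extension now surfaces as an ImportError instead of silently falling back to the CPU context. A caller that still wants the old fallback behavior can request it explicitly; the snippet below is hypothetical caller code, not part of this change:

    try:
        import nnabla_ext.cudnn
        ctx = nnabla_ext.cudnn.context(device_id='0')
    except ImportError:
        # Explicit, caller-chosen fallback to the CPU context.
        import nnabla_ext.cpu
        ctx = nnabla_ext.cpu.context()
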
@@ -600,8 +593,6 @@ class Info:
         if ext in ['.nntxt', '.prototxt']:
             with open(filename, 'rt') as f:
                 text_format.Merge(f.read(), proto)
-            if len(proto.parameter) > 0:
-                nn.load_parameters(filename)
         elif ext in ['.protobuf', '.h5']:
             nn.load_parameters(filename)

@@ -619,8 +610,6 @@ class Info:
             nnp.extract(name, tmpdir)
             with open(os.path.join(tmpdir, name), 'rt') as f:
                 text_format.Merge(f.read(), proto)
-            if len(proto.parameter) > 0:
-                nn.load_parameters(os.path.join(tmpdir, name))
         elif ext in ['.protobuf', '.h5']:
             nnp.extract(name, tmpdir)
             nn.load_parameters(os.path.join(tmpdir, name))
4 changes: 3 additions & 1 deletion python/src/nnabla/utils/save.py
@@ -217,7 +216,6 @@ def collect_info(func):
         else:
             v.type = 'Buffer'
             # TODO: The first dimension is always considered as batch size.
-            # No problem?
             if len(shape) > 0:
                 shape[0] = -1
         v.shape.dim.extend(shape)
@@ -241,6 +240,9 @@ def collect_info(func):
 
     for name, function in functions.items():
         f = n.function.add()
+        if function['type'] == 'Reshape':
+            # TODO: The first dimension is always considered as batch size.
+            function['args']['shape'][0] = -1
         _create_function_nntxt(f, name, function)
 
     return n
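
A minimal sketch of what this save-side normalization does to a Reshape entry; the dictionary layout mirrors the one used above, and the concrete values are illustrative:

    function = {'type': 'Reshape', 'args': {'shape': [64, 3, 1024]}}
    if function['type'] == 'Reshape':
        # The first dimension is treated as the batch size, so it is stored
        # as -1 and resolved to the actual batch size at load time.
        function['args']['shape'][0] = -1
    assert function['args']['shape'] == [-1, 3, 1024]
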
23 changes: 18 additions & 5 deletions src/nbla_utils/nnp_impl_create_function.cpp.tmpl
@@ -52,12 +52,25 @@ shared_ptr<nbla::CgFunction> NetworkImpl::create_cgfunction(const ::Function& fu
 % for argname, arg in func['arguments'].items():
   ${proto_to_ctype(argname, arg)};
 % if name == 'Reshape' and argname == 'shape':
   // Add batch_size.
-  int bs = batch_size();
-  if( bs < 1 ) {
-    bs = 1;
-  }
-  arg_shape.insert(arg_shape.begin(), bs);
+  int prod_shape = 1;
+  for( auto &v : arg_shape ) {
+    prod_shape *= v;
+    if( v < 0 ) {
+      v = batch_size();
+    }
+  }
+  int prod_input = 1;
+  for( auto inp : func.input() ) {
+    auto var_it = variable_protos_.find(inp);
+    const ::Variable *var = var_it->second;
+    for ( auto d : var->shape().dim() ) {
+      prod_input *= d;
+    }
+  }
+  if( prod_input != prod_shape ) {
+    arg_shape.insert(arg_shape.begin(), batch_size());
+  }
 % endif
 % endfor
 % endif
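
Worked numbers for the element-count check generated above, with illustrative values:

    arg_shape = [3, 1024]                # saved Reshape shape without a batch dim
    input_shape = (64, 3, 32, 32)        # shape of the function's input variable
    prod_shape = arg_shape[0] * arg_shape[1]   # 3072
    prod_input = 64 * 3 * 32 * 32              # 196608
    # The counts differ, so the batch size (64) is prepended and the
    # requested shape becomes [64, 3, 1024].
    assert prod_shape != prod_input
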
