From 014924b116bf00006ee5389754fa618a03a91678 Mon Sep 17 00:00:00 2001
From: Yukio Oobuchi
Date: Tue, 11 Dec 2018 17:22:38 +0900
Subject: [PATCH] Support multiple datasets in .proto

Change `dataset_name` in the Optimizer and Monitor messages from a
singular field to a repeated one so that a configuration can reference
more than one dataset. The Python and C++ loaders currently still
consume only the first dataset (marked as Todo / guarded by an error).
---
 doc/requirements.txt                          |  2 +-
 .../src/nnabla/utils/cli/compare_with_cpu.py  | 12 ++-
 python/src/nnabla/utils/cli/forward.py        | 58 +++++++------
 python/src/nnabla/utils/cli/profile.py        | 82 +++++++++++--------
 python/src/nnabla/utils/cli/train.py          | 16 ++--
 python/src/nnabla/utils/cli/uploader.py       |  2 +-
 python/src/nnabla/utils/cli/utility.py        | 18 ++--
 python/src/nnabla/utils/data_source.py        | 50 ++++++-----
 python/src/nnabla/utils/load.py               | 20 ++++-
 python/src/nnabla/utils/network.py            |  6 +-
 src/nbla/proto/nnabla.proto.tmpl              |  4 +-
 src/nbla_utils/nnp_impl.cpp                   | 10 ++-
 src/nbla_utils/nnp_impl_monitor.cpp           |  5 +-
 src/nbla_utils/nnp_impl_optimizer.cpp         |  5 +-
 14 files changed, 183 insertions(+), 107 deletions(-)

diff --git a/doc/requirements.txt b/doc/requirements.txt
index 0fd0acfd1..9fbd9d4c8 100644
--- a/doc/requirements.txt
+++ b/doc/requirements.txt
@@ -1,4 +1,4 @@
-nnabla==1.0.10
+nnabla==1.0.11
 sphinx-rtd-theme
 sphinxcontrib-actdiag
 sphinxcontrib-blockdiag
diff --git a/python/src/nnabla/utils/cli/compare_with_cpu.py b/python/src/nnabla/utils/cli/compare_with_cpu.py
index 3beb2dec5..23394a24d 100644
--- a/python/src/nnabla/utils/cli/compare_with_cpu.py
+++ b/python/src/nnabla/utils/cli/compare_with_cpu.py
@@ -80,10 +80,12 @@ def compare_optimizer(config, parameters, config_cpu, parameters_cpu, result_arr
 
         for v, d in o.dataset_assign.items():
             let_data_to_variable(v.variable_instance, data[
-                                 di.variables.index(d)])
+                                 di.variables.index(d)],
+                                 data_name=d, variable_name=v.name)
         for v, d in o_cpu.dataset_assign.items():
             let_data_to_variable(v.variable_instance, data[
-                                 di.variables.index(d)])
+                                 di.variables.index(d)],
+                                 data_name=d, variable_name=v.name)
 
         # Generate data
         generated = {}
@@ -92,12 +94,14 @@
             dest_context = config.global_config.default_context if not o.forward_sequence or v not in o.forward_sequence[
                 0].inputs else None
             let_data_to_variable(v.variable_instance,
-                                 data=generated[v.name], ctx=dest_context)
+                                 data=generated[v.name], ctx=dest_context,
+                                 variable_name=v.name)
         for v, generator in o_cpu.generator_assign.items():
             dest_context = config.global_config.default_context if not o.forward_sequence or v not in o.forward_sequence[
                 0].inputs else None
             let_data_to_variable(v.variable_instance,
-                                 data=generated[v.name], ctx=dest_context)
+                                 data=generated[v.name], ctx=dest_context,
+                                 variable_name=v.name)
 
         last_max_diff = 1e-5
 
diff --git a/python/src/nnabla/utils/cli/forward.py b/python/src/nnabla/utils/cli/forward.py
index 02a5431ec..a229aba18 100644
--- a/python/src/nnabla/utils/cli/forward.py
+++ b/python/src/nnabla/utils/cli/forward.py
@@ -97,7 +97,7 @@ def _update_result(args, index, result, values, output_index, type_end_names, ou
             # CSV type
             with open(full_path, 'w') as f:
                 writer = csv.writer(f, lineterminator='\n')
-                x = np.array(d, dtype=np.float32)
+                x = np.array(d)
                 writer.writerows(x)
             outputs[data_index].append(os.path.join('.', file_name))
             output_index += 1
@@ -120,17 +120,22 @@ class ForwardResult:
             vind = variables.index(d)
             if v.variable_instance.d.shape != data[vind].shape:
                 let_data_to_variable(v.variable_instance,
-                                     np.reshape(data[vind], v.variable_instance.d.shape))
+                                     np.reshape(
+                                         data[vind], v.variable_instance.d.shape),
+                                     data_name=d, variable_name=v.name)
             else:
                 let_data_to_variable(v.variable_instance,
-                                     data[vind].astype(v.variable_instance.d.dtype))
+                                     data[vind].astype(
+                                         v.variable_instance.d.dtype),
+                                     data_name=d, variable_name=v.name)
 
         # Generate data
         for v, generator in e.generator_assign.items():
             v.variable_instance.d = generator(v.shape)
 
         # Forward recursive
-        sum = [np.zeros(o.shape) for o in e.output_assign.keys()]
+        sum = [np.zeros(o.shape, dtype=o.variable_instance.d.dtype)
+               for o in e.output_assign.keys()]
         for i in range(e.num_evaluations):
             e.network.forward(e.forward_sequence)
             if e.need_back_propagation:
@@ -195,6 +200,7 @@ class ForwardConfig:
             batch_size=config.networks[0].batch_size,
             shuffle=False,
             normalize=normalize,
+            with_memory_cache=False,
             with_file_cache=False))
 
     # load dataset as csv
@@ -207,29 +213,31 @@ class ForwardConfig:
     rows = list(map(lambda row: list(map(lambda x: x if is_float(
         x) else compute_full_path(root_path, x), row)), rows))
 
-    with data_iterator() as di:
-        index = 0
-        while index < di.size:
-            data = di.next()
-            result, outputs = _forward(args, index, config, data, di.variables)
-            if index == 0:
-                for name, dim in zip(result.names, result.dims):
-                    if dim == 1:
-                        row0.append(name)
-                    else:
-                        for d in range(dim):
-                            row0.append(name + '__' + str(d))
-            for i, output in enumerate(outputs):
-                if index + i < len(rows):
-                    rows[index + i].extend(output)
-            index += len(outputs)
-            logger.log(
-                99, 'data {} / {}'.format(min([index, len(rows)]), len(rows)))
-
     with open(os.path.join(args.outdir, 'output_result.csv'), 'w') as f:
         writer = csv.writer(f, lineterminator='\n')
-        writer.writerow(row0)
-        writer.writerows(rows)
+        with data_iterator() as di:
+            index = 0
+            while index < di.size:
+                data = di.next()
+                result, outputs = _forward(
+                    args, index, config, data, di.variables)
+                if index == 0:
+                    for name, dim in zip(result.names, result.dims):
+                        if dim == 1:
+                            row0.append(name)
+                        else:
+                            for d in range(dim):
+                                row0.append(name + '__' + str(d))
+                    writer.writerow(row0)
+                for i, output in enumerate(outputs):
+                    if index + i < len(rows):
+                        import copy
+                        row = copy.deepcopy(rows[index + i])
+                        row.extend(output)
+                        writer.writerow(row)
+                index += len(outputs)
+                logger.log(
+                    99, 'data {} / {}'.format(min([index, len(rows)]), len(rows)))
 
     logger.log(99, 'Forward Completed.')
     progress(None)
diff --git a/python/src/nnabla/utils/cli/profile.py b/python/src/nnabla/utils/cli/profile.py
index 4926c6a2b..91792d16e 100644
--- a/python/src/nnabla/utils/cli/profile.py
+++ b/python/src/nnabla/utils/cli/profile.py
@@ -23,6 +23,7 @@
 
 import nnabla as nn
 import nnabla.function as F
+from nnabla.ext_utils import import_extension_module
 from nnabla.logger import logger
 from nnabla.parameter import save_parameters
 from nnabla.utils.progress import configure_progress, progress
@@ -30,28 +31,25 @@
 import nnabla.utils.load as load
 
 
-def profile(config, name, func, result_dict):
-    # for sync CPU/GPU
-    identity = F.Identity(config.global_config.default_context)
-    tmp_in = nn.Variable((1,))
-    tmp_out = nn.Variable((1,))
-    identity.setup([tmp_in], [tmp_out])
-
-    tmp_in.d = [0.]
-    identity.forward([tmp_in], [tmp_out])
+def profile(config, name, func, result_dict, synchronize):
+    # Warm-up
+    func()
+    synchronize()
 
     # Profile
-    start = time.time()
+    start_0 = time.time()
+    result = 0
     count = 0
-    while time.time() < start + 1.0 or count < 100:
+    while time.time() < start_0 + 1.0 or count < 100:
+        start = time.time()
         func()
+        synchronize()
+        stop = time.time()
+        result += stop - start
         count += 1
 
-    # sync CPU/GPU
-    identity.forward([tmp_in], [tmp_out])
-    data = tmp_out.d
+    t = result * 1000 / count
 
-    t = (time.time() - start) * 1000 / count
     logger.log(99, '%s %f(ms)' % (name, t))
     result_dict[name] = t
     return result_dict
@@ -74,7 +72,7 @@ def add_result(title, result_dict, result_array):
     return result_array
 
 
-def profile_optimizer(config, result_array):
+def profile_optimizer(config, result_array, synchronize):
     # Profile Training
     for opt in config.optimizers.values():
         o = opt.optimizer
@@ -83,6 +81,10 @@
         result_dict = OrderedDict()
         logger.log(99, 'Profiling ' + result_name + ' ...')
 
+        # Clear weight
+        for name, p in o.parameters.items():
+            if name[-2:] in ('/W', '/b'):
+                p.data.zero()
 
         # Load dataset
         def load_dataset():
@@ -90,7 +92,7 @@
            di = opt.data_iterator
            loaded_data[di] = di.next()
            return loaded_data
-        profile(config, 'load_dataset', load_dataset, result_dict)
+        profile(config, 'load_dataset', load_dataset, result_dict, synchronize)
 
         # Let data
         loaded_data = load_dataset()
@@ -103,24 +105,32 @@ def let_data():
                    print(opt.data_iterator.variables)
                    raise ValueError(
                        'Data "' + d + '" is not found in dataset.')
-                let_data_to_variable(v.variable_instance, data=data)
+                let_data_to_variable(v.variable_instance, data=data,
+                                     data_name=d, variable_name=v.name)
             profile(config, 'let_data (%s to %s)' %
-                    (d, v.name), let_data, result_dict)
+                    (d, v.name), let_data, result_dict, synchronize)
 
         # Generate data
         for v, generator in o.generator_assign.items():
             def generate_data():
                 let_data_to_variable(v.variable_instance,
-                                     data=generator(v.shape))
+                                     data=generator(v.shape),
+                                     variable_name=v.name)
             profile(config, 'generate_data (%s)' %
-                    v.name, generate_data, result_dict)
+                    v.name, generate_data, result_dict, synchronize)
 
+        '''
         # Setup (detail)
         for func in o.forward_sequence:
             def setup():
                 o.network.setup_function(func)
             profile(config, 'setup_function (%s : %s)' % (
-                func.name, func.function_instance.name), setup, result_dict)
+                func.name, func.function_instance.name), setup, result_dict, synchronize)
+        '''
+        # Warm-up
+        o.network.forward(o.forward_sequence)
+        o.network.prepare_backward(o.backward_sequence)
+        o.network.backward(o.backward_sequence)
 
         # Forward (detail)
         for func in o.forward_sequence:
@@ -129,12 +139,13 @@ def forward():
             in_place_str = ' : in_place' if func.function_instance.inplace_data(
                 0) > 0 else ''
             profile(config, 'forward_function (%s : %s%s)' % (
-                func.name, func.function_instance.name, in_place_str), forward, result_dict)
+                func.name, func.function_instance.name, in_place_str), forward, result_dict, synchronize)
 
         # Backward (detail)
         def prepare_backward():
             o.network.prepare_backward(o.backward_sequence)
-        profile(config, 'prepare_backward', prepare_backward, result_dict)
+        profile(config, 'prepare_backward',
+                prepare_backward, result_dict, synchronize)
 
         for seq in o.backward_sequence.sequence:
             o.network.prepare_backward(o.backward_sequence)
@@ -143,41 +154,42 @@ def backward():
             in_place_str = ' : in_place' if seq.func.function_instance.inplace_grad(
                 0) > 0 else ''
             profile(config, 'backward_function (%s : %s%s)' % (
-                seq.func.name, seq.func.function_instance.name, in_place_str), backward, result_dict)
+                seq.func.name, seq.func.function_instance.name, in_place_str), backward, result_dict, synchronize)
 
         # Forward (all)
         def forward_all():
            o.network.forward(o.forward_sequence)
-        profile(config, 'forward_all', forward_all, result_dict)
+        profile(config, 'forward_all', forward_all, result_dict, synchronize)
 
         # Backward (all)
        def backward_all():
            o.network.backward(o.backward_sequence)
-        profile(config, 'backward_all', backward_all, result_dict)
+        profile(config, 'backward_all', backward_all, result_dict, synchronize)
 
         # Backward (all)
         def backward_all_wo_zero_grad():
             o.network.backward(o.backward_sequence, parameter_zero_grad=False)
         profile(config, 'backward_all(wo param zero_grad)',
-                backward_all_wo_zero_grad, result_dict)
+                backward_all_wo_zero_grad, result_dict, synchronize)
 
         # Update (weight decay)
         if o.weight_decay > 0:
             def weight_decay():
                 o.solver.weight_decay(o.weight_decay)
             profile(config, 'weight_decay (%s)' %
-                    o.solver.name, weight_decay, result_dict)
+                    o.solver.name, weight_decay, result_dict, synchronize)
 
         # Update
         def update():
             o.solver.update()
-        profile(config, 'update (%s)' % o.solver.name, update, result_dict)
+        profile(config, 'update (%s)' %
+                o.solver.name, update, result_dict, synchronize)
 
         # Monitor loss
         def monitor_loss():
             for l in o.loss_variables:
                 np.mean(l.variable_instance.d)
-        profile(config, 'monitor_loss', monitor_loss, result_dict)
+        profile(config, 'monitor_loss', monitor_loss, result_dict, synchronize)
 
         result_array = add_result(result_name, result_dict, result_array)
 
@@ -215,6 +227,12 @@ class MonConfig:
             m.data_iterator = None
         config.monitors[name] = m
 
+    ext_module = import_extension_module(
+        config.global_config.default_context.backend[0].split(':')[0])
+
+    def synchronize(): return ext_module.synchronize(
+        device_id=config.global_config.default_context.device_id)
+
     result_array = [['time in ms']]
 
     # Profile Optimizer
@@ -222,7 +240,7 @@ class MonConfig:
         for name, o in config.optimizers.items():
             o.data_iterator = stack.enter_context(
                 o.optimizer.data_iterator())
-        result_array = profile_optimizer(config, result_array)
+        result_array = profile_optimizer(config, result_array, synchronize)
 
     # Write profiling result
     import csv
diff --git a/python/src/nnabla/utils/cli/train.py b/python/src/nnabla/utils/cli/train.py
index 974804e27..e5bfa051f 100644
--- a/python/src/nnabla/utils/cli/train.py
+++ b/python/src/nnabla/utils/cli/train.py
@@ -78,9 +78,7 @@ def _save_parameters(args, suffix, epoch, force=False):
 
     if suffix == 'best':
         base = os.path.join(args.outdir, 'results')
     filename = base + '.nnp'
-
-    if not os.path.exists(filename) and \
-            (force or timediff > 180.0 or epochdiff > 10):
+    if force or timediff > 180.0 or epochdiff > 10:
 
         version_filename = base + '_version.txt'
@@ -134,14 +132,16 @@ def _sum_cost():
             dest_context = config.global_config.default_context if not o.forward_sequence or v not in o.forward_sequence[
                 0].inputs else None
             let_data_to_variable(v.variable_instance, data[
-                                 di.variables.index(d)], ctx=dest_context)
+                                 di.variables.index(d)], ctx=dest_context,
+                                 data_name=d, variable_name=v.name)
 
         # Generate data
         for v, generator in o.generator_assign.items():
             dest_context = config.global_config.default_context if not o.forward_sequence or v not in o.forward_sequence[
                 0].inputs else None
             let_data_to_variable(v.variable_instance,
-                                 data=generator(v.shape), ctx=dest_context)
+                                 data=generator(v.shape), ctx=dest_context,
+                                 variable_name=v.name)
 
         # Monitor loss before forward to prepare input data while processing on
         # GPU
@@ -227,14 +227,16 @@ def _sum_error(sum, error):
             dest_context = config.global_config.default_context if not m.forward_sequence or v not in m.forward_sequence[
                 0].inputs else None
             let_data_to_variable(v.variable_instance, data[
-                                 di.variables.index(d)], ctx=dest_context)
+                                 di.variables.index(d)], ctx=dest_context,
+                                 data_name=d, variable_name=v.name)
 
         # Generate data
         for v, generator in m.generator_assign.items():
             dest_context = config.global_config.default_context if not m.forward_sequence or v not in m.forward_sequence[
                 0].inputs else None
             let_data_to_variable(v.variable_instance,
-                                 data=generator(v.shape), ctx=dest_context)
+                                 data=generator(v.shape), ctx=dest_context,
+                                 variable_name=v.name)
 
         # Sum error before forward to prepare input data while processing
         # on GPU
diff --git a/python/src/nnabla/utils/cli/uploader.py b/python/src/nnabla/utils/cli/uploader.py
index ea4426bb4..922bb8fcb 100644
--- a/python/src/nnabla/utils/cli/uploader.py
+++ b/python/src/nnabla/utils/cli/uploader.py
@@ -73,7 +73,7 @@ def createTemporaryTar(self, name, csv_data, data_files, tmpdir):
 
         self._log('Create index.csv')
         self._progress.init(len(csv_data), 'Create index.csv')
-        with open(indexcsvfilename, 'w') as f:
+        with open(indexcsvfilename, 'w', newline='') as f:
             csvwriter = csv.writer(f)
             for row in csv_data:
                 csvwriter.writerow(row)
diff --git a/python/src/nnabla/utils/cli/utility.py b/python/src/nnabla/utils/cli/utility.py
index 16405bef3..f7f217033 100644
--- a/python/src/nnabla/utils/cli/utility.py
+++ b/python/src/nnabla/utils/cli/utility.py
@@ -33,11 +33,19 @@ def compute_full_path(root_path, file_path):
     return full_path
 
 
-def let_data_to_variable(variable, data, ctx=None):
-    if data.dtype <= np.float64:
-        variable.data.cast(data.dtype)[...] = data
-    else:
-        variable.d = data
+def let_data_to_variable(variable, data, ctx=None, data_name=None, variable_name=None):
+    try:
+        if data.dtype <= np.float64:
+            variable.data.cast(data.dtype)[...] = data
+        else:
+            variable.d = data
+    except:
+        if variable.shape != data.shape:
+            logger.critical('Shape does not match between data{} and variable{} ({} != {}).'.format(
+                ' "' + data_name + '"' if data_name else '',
+                ' "' + variable_name + '"' if variable_name else '',
+                data.shape, variable.shape))
+        raise
     variable.need_grad = False
 
     # Copy to device
diff --git a/python/src/nnabla/utils/data_source.py b/python/src/nnabla/utils/data_source.py
index f431adea7..2a5401e3c 100644
--- a/python/src/nnabla/utils/data_source.py
+++ b/python/src/nnabla/utils/data_source.py
@@ -229,26 +229,37 @@ def get_data(args):
                 data[n].append(d)
 
         logger.info('Creating cache file {}'.format(cache_filename))
-        if self._cache_file_format == ".h5":
-            h5 = h5py.File(cache_filename, 'w')
+        try:
+            if self._cache_file_format == ".h5":
+                h5 = h5py.File(cache_filename, 'w')
+                for k, v in data.items():
+                    h5.create_dataset(k, data=v)
+                h5.close()
+            else:
+                retry_count = 1
+                is_create_cache_incomplete = True
+                while is_create_cache_incomplete:
+                    try:
+                        with open(cache_filename, 'wb') as f:
+                            for v in data.values():
+                                numpy.save(f, v)
+                        is_create_cache_incomplete = False
+                    except OSError:
+                        retry_count += 1
+                        if retry_count > 10:
+                            raise
+                        logger.info(
+                            'Creating cache retry {}/10'.format(retry_count))
+        except:
+            logger.critical(
+                'An error occurred while creating cache file from dataset.')
             for k, v in data.items():
-                h5.create_dataset(k, data=v)
-            h5.close()
-        else:
-            retry_count = 1
-            is_create_cache_imcomplete = True
-            while is_create_cache_imcomplete:
-                try:
-                    with open(cache_filename, 'wb') as f:
-                        for v in data.values():
-                            numpy.save(f, v)
-                    is_create_cache_imcomplete = False
-                except OSError:
-                    retry_count += 1
-                    if retry_count > 10:
-                        raise
-                    logger.info(
-                        'Creating cache retry {}/10'.format(retry_count))
+                size = v[0].shape
+                for d in v:
+                    if size != d.shape:
+                        logger.critical('The sizes of data "{}" are not the same. ({} != {})'.format(
+                            k, size, d.shape))
+            raise
 
         self._cache_file_names.append(cache_filename)
         self._cache_file_order.append(len(self._cache_file_order))
@@ -469,6 +480,7 @@ class DataSourceWithMemoryCache(DataSource):
     '''
 
     def _get_data_func(self, position):
+        # return self._data_source._get_data(position)
        return [numpy.array(x, dtype=numpy.float32) for x in self._data_source._get_data(position)]
 
     def _get_data(self, position):
diff --git a/python/src/nnabla/utils/load.py b/python/src/nnabla/utils/load.py
index cdf5578e4..957256256 100644
--- a/python/src/nnabla/utils/load.py
+++ b/python/src/nnabla/utils/load.py
@@ -179,7 +179,7 @@ def _create_function(ctx, network, f, variable_index):
     if f.type == "Reshape":
         shape = resolve_reshape_params(inputs, f, network.batch_size)
         function_instance = F.Reshape(
-            ctx, shape=shape, inplace=f.reshape_param.inplace)
+            ctx, shape=shape, inplace=True)
     elif f.type == "RepeatStart":
         function_instance = F.Identity(ctx)
     elif f.type == "RepeatEnd":
@@ -363,7 +363,10 @@ class Optimizer:
         optimizer.order = o.order
         optimizer.update_interval = o.update_interval if o.update_interval > 0 else 1
         optimizer.network = networks[o.network_name]
-        optimizer.data_iterator = datasets[o.dataset_name].data_iterator
+        optimizer.data_iterator = OrderedDict()
+        for d in o.dataset_name:
+            optimizer.data_iterator[d] = datasets[d].data_iterator
+        optimizer.data_iterator = list(optimizer.data_iterator.values())[0]  # Todo
 
         optimizer.dataset_assign = OrderedDict()
         for d in o.data_variable:
@@ -617,7 +620,10 @@ class Monitor:
         monitor = Monitor()
 
         monitor.network = networks[m.network_name]
-        monitor.data_iterator = datasets[m.dataset_name].data_iterator
+        monitor.data_iterator = OrderedDict()
+        for d in m.dataset_name:
+            monitor.data_iterator[d] = datasets[d].data_iterator
+        monitor.data_iterator = list(monitor.data_iterator.values())[0]  # Todo
 
         monitor.dataset_assign = OrderedDict()
         for d in m.data_variable:
@@ -732,7 +738,13 @@ class Info:
     if ext in ['.nntxt', '.prototxt']:
         if not parameter_only:
             with open(filename, 'rt') as f:
-                text_format.Merge(f.read(), proto)
+                try:
+                    text_format.Merge(f.read(), proto)
+                except:
+                    logger.critical('Failed to read {}.'.format(filename))
+                    logger.critical(
+                        'Multi-byte characters may be used in the file or folder name.')
+                    raise
         if len(proto.parameter) > 0:
             if not exclude_parameter:
                 nn.load_parameters(filename)
diff --git a/python/src/nnabla/utils/network.py b/python/src/nnabla/utils/network.py
index fd2c64ae2..f70d9cf7f 100644
--- a/python/src/nnabla/utils/network.py
+++ b/python/src/nnabla/utils/network.py
@@ -70,7 +70,7 @@ def forward(self, forward_sequence):
             except:
                 index = forward_sequence.index(func)
                 print_network_traceback(
-                    forward_sequence[min(0, index - 4):index + 1])
+                    forward_sequence[max(0, index - 4):index + 1])
                 raise
 
     def forward_function(self, func):
@@ -178,7 +178,7 @@ def backward(self, backward_sequence, parameter_zero_grad=True):
         except:
             index = backward_sequence.sequence.index(seq)
             print_network_traceback(
-                [seq.func for seq in backward_sequence.sequence[min(0, index - 4):index + 1]])
+                [seq.func for seq in backward_sequence.sequence[max(0, index - 4):index + 1]])
             raise
 
     def backward_function(self, seq):
@@ -229,7 +229,7 @@ def setup(self, optimize=False):
                 self.setup_function(func)
             except:
                 print_network_traceback(list(self.functions.values())[
-                    min(0, i - 4):i + 1])
+                    max(0, i - 4):i + 1])
                 raise
 
     # set link structure to each layer
diff --git a/src/nbla/proto/nnabla.proto.tmpl b/src/nbla/proto/nnabla.proto.tmpl
index 381296ecf..08ea55a0f 100644
--- a/src/nbla/proto/nnabla.proto.tmpl
+++ b/src/nbla/proto/nnabla.proto.tmpl
@@ -129,7 +129,7 @@ message Optimizer {
   int64 order = 3;
 
   string network_name = 10;
-  string dataset_name = 20;
+  repeated string dataset_name = 20;
 
   Solver solver = 30;
   int64 update_interval = 40;
@@ -223,7 +223,7 @@ message Monitor {
   string name = 1;
 
   string network_name = 10;
-  string dataset_name = 20;
+  repeated string dataset_name = 20;
 
   repeated DataVariable data_variable = 50;
   repeated GeneratorVariable generator_variable = 60;
diff --git a/src/nbla_utils/nnp_impl.cpp b/src/nbla_utils/nnp_impl.cpp
index ba170f578..11fda3480 100644
--- a/src/nbla_utils/nnp_impl.cpp
+++ b/src/nbla_utils/nnp_impl.cpp
@@ -780,9 +780,12 @@ shared_ptr<Optimizer> NnpImpl::get_optimizer(const string &name) {
     if (it->name() != name) {
       continue;
     }
+    if (it->dataset_name_size() != 1) {
+      NBLA_ERROR(error_code::value, "Currently only one dataset is supported.");
+    }
     return shared_ptr<Optimizer>(new Optimizer(
         new OptimizerImpl(ctx_, *it, get_network(it->network_name()),
-                          get_dataset(it->dataset_name()))));
+                          get_dataset(it->dataset_name()[0]))));
   }
   NBLA_ERROR(error_code::value, "Optimizer `%s` not found", name.c_str());
 }
@@ -822,9 +825,12 @@ shared_ptr<Monitor> NnpImpl::get_monitor(const string &name) {
     if (it->name() != name) {
      continue;
     }
+    if (it->dataset_name_size() != 1) {
+      NBLA_ERROR(error_code::value, "Currently only one dataset is supported.");
+    }
     return shared_ptr<Monitor>(
         new Monitor(new MonitorImpl(ctx_, *it, get_network(it->network_name()),
-                                    get_dataset(it->dataset_name()))));
+                                    get_dataset(it->dataset_name()[0]))));
   }
   NBLA_ERROR(error_code::value, "Monitor `%s` not found", name.c_str());
 }
diff --git a/src/nbla_utils/nnp_impl_monitor.cpp b/src/nbla_utils/nnp_impl_monitor.cpp
index 1723e1f8a..231e24225 100644
--- a/src/nbla_utils/nnp_impl_monitor.cpp
+++ b/src/nbla_utils/nnp_impl_monitor.cpp
@@ -48,7 +48,10 @@ string MonitorImpl::network_name() const {
   return monitor_proto_.network_name();
 }
 string MonitorImpl::dataset_name() const {
-  return monitor_proto_.dataset_name();
+  if (monitor_proto_.dataset_name_size() != 1) {
+    NBLA_ERROR(error_code::value, "Currently only one dataset is supported.");
+  }
+  return monitor_proto_.dataset_name()[0];
 }
 
 vector<Monitor::DataVariable> MonitorImpl::get_data_variables() {
diff --git a/src/nbla_utils/nnp_impl_optimizer.cpp b/src/nbla_utils/nnp_impl_optimizer.cpp
index ee0c4f77c..7859363cd 100644
--- a/src/nbla_utils/nnp_impl_optimizer.cpp
+++ b/src/nbla_utils/nnp_impl_optimizer.cpp
@@ -71,7 +71,10 @@ string OptimizerImpl::network_name() const {
   return optimizer_proto_.network_name();
 }
 string OptimizerImpl::dataset_name() const {
-  return optimizer_proto_.dataset_name();
+  if (optimizer_proto_.dataset_name_size() != 1) {
+    NBLA_ERROR(error_code::value, "Currently only one dataset is supported.");
+  }
+  return optimizer_proto_.dataset_name()[0];
 }
 const int OptimizerImpl::update_interval() const {
   return optimizer_proto_.update_interval();
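
With `dataset_name` now a repeated field, an Optimizer (or Monitor) block in a
.nntxt configuration can reference several datasets simply by repeating the
key. A minimal sketch in protobuf text format (the dataset names here are
hypothetical; as the Todo comments above note, the Python and C++ loaders
still consume only the first entry):

    optimizer {
      name: "Optimizer"
      network_name: "Main"
      dataset_name: "Training1"
      dataset_name: "Training2"
    }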