Skip to content

Commit

Permalink
Merge branch 'feature/20180516-support-fixed16-nnb' into 'feature/201…
Browse files Browse the repository at this point in the history
…80401-file-format-converter'

add fixed data type for nnb

See merge request nnabla/nnabla!173
  • Loading branch information
YukioOobuchi committed May 16, 2018
2 parents 8186afb + cbc4707 commit 13c20b5
Showing 1 changed file with 42 additions and 16 deletions.
58 changes: 42 additions & 16 deletions python/src/nnabla/utils/converter/nnablart/nnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@

from .utils import create_nnabart_info

class Nnb:
'''
Nnb is only used as namespace
'''
NN_DATA_TYPE_FLOAT, NN_DATA_TYPE_INT16, NN_DATA_TYPE_INT8, NN_DATA_TYPE_SIGN = range(4)


class NnbExporter:
def _align(self, size):
Expand Down Expand Up @@ -56,7 +62,6 @@ def __init__(self, nnp, batch_size):
self._argument_formats[fn] = argfmt

def export(self, nnb_output_filename, settings_template_filename, settings_filename, default_type):

settings = collections.OrderedDict()
if settings_filename is not None and len(settings_filename) == 1:
settings = nnabla.utils.converter.load_yaml_ordered(
Expand Down Expand Up @@ -122,11 +127,34 @@ def export(self, nnb_output_filename, settings_template_filename, settings_filen
var.type = 0 # NN_DATA_TYPE_FLOAT
var.fp_pos = 0

if v.name not in settings['variables']:
settings['variables'][v.name] = default_type[0]

if v.type == 'Parameter':
param = self._info._parameters[v.name]
param_data = list(param.data)
data = struct.pack('{}f'.format(
len(param_data)), *param_data)
v_name = settings['variables'][v.name]
if v_name == 'FLOAT32':
data = struct.pack('{}f'.format(
len(param_data)), *param_data)
var.type = Nnb.NN_DATA_TYPE_FLOAT
elif v_name.startswith('FIXED16'):
fixed16_desc = v_name.split('_')
if (len(fixed16_desc) == 2) and int(fixed16_desc[1]) <= 15:
var.fp_pos = int(fixed16_desc[1])
scale = 1 << var.fp_pos
fixed16_n_data = [int(round(x * scale)) for x in param_data]
data = struct.pack('{}h'.format(len(fixed16_n_data)), *fixed16_n_data)
var.type = Nnb.NN_DATA_TYPE_INT16
elif v_name.startswith('FIXED8'):
fixed8_desc = v_name.split('_')
if (len(fixed8_desc) == 2) and int(fixed8_desc[1]) <= 7:
var.fp_pos = int(fixed8_desc[1])
scale = 1 << var.fp_pos
fixed8_n_data = [int(round(x * scale)) for x in param_data]
data = struct.pack('{}b'.format(len(fixed8_n_data)), *fixed8_n_data)
var.type = Nnb.NN_DATA_TYPE_INT8

index, pointer = self._alloc(data=data)
var.data_index = index
elif v.type == 'Buffer':
Expand All @@ -140,22 +168,20 @@ def export(self, nnb_output_filename, settings_template_filename, settings_filen
# so that nn_network_t::variables::size is conserved
var.data_index = -1

if v.name in settings['variables']:
if settings['variables'][v.name] != default_type[0]:
if v.type == 'Parameter':
# TODO convert parameter here.
print('Convert {} to {}.'.format(
v.name, settings['variables'][v.name]))
pass
# TODO set var.type here
else:
settings['variables'][v.name] = default_type[0]

# if v.name in settings['variables']:
# if settings['variables'][v.name] != default_type[0]:
# if v.type == 'Parameter':
# # TODO convert parameter here.
# print('Convert {} to {}.'.format(
# v.name, settings['variables'][v.name]))
# pass
# # TODO set var.type here
# else:
# settings['variables'][v.name] = default_type[0]
variable = struct.pack('IiIBi',
var.id,
var.shape.size, var.shape.list_index,
(var.type & 0xf << 4) | (
var.fp_pos & 0xf),
((var.fp_pos & 0xf) << 4 | (var.type & 0xf)),
var.data_index)
index, pointer = self._alloc(data=variable)
vindexes.append(index)
Expand Down

0 comments on commit 13c20b5

Please sign in to comment.