Skip to content

Commit

Permalink
Merge branch 'feature/20180529-support_fixed_point/rev-4' into 'featu…
Browse files Browse the repository at this point in the history
…re/20180401-file-format-converter'

reimplement the fixed-point value support in NNB (SSPSWALL-856)

See merge request nnabla/nnabla!188
  • Loading branch information
YukioOobuchi committed May 30, 2018
2 parents 02a56bb + 0f75167 commit 8dba9ff
Showing 1 changed file with 58 additions and 30 deletions.
88 changes: 58 additions & 30 deletions python/src/nnabla/utils/converter/nnablart/nnb.py
Expand Up @@ -30,6 +30,12 @@ class Nnb:
'''
NN_DATA_TYPE_FLOAT, NN_DATA_TYPE_INT16, NN_DATA_TYPE_INT8, NN_DATA_TYPE_SIGN = range(
4)
from_type_name = {
'FLOAT32': NN_DATA_TYPE_FLOAT,
'FIXED16': NN_DATA_TYPE_INT16,
'FIXED8': NN_DATA_TYPE_INT8
}
fp_pos_max = {NN_DATA_TYPE_INT16: 15, NN_DATA_TYPE_INT8: 7}


class NnbExporter:
Expand Down Expand Up @@ -63,6 +69,21 @@ def __init__(self, nnp, batch_size):
arg['type'])
self._argument_formats[fn] = argfmt

@staticmethod
def __compute_int_bit_num(param_array):
abs_array = np.abs(param_array)
max_abs = abs_array.max()
if max_abs >= 1:
max_idx = abs_array.argmax()
max_log2 = np.log2(max_abs)
if max_log2.is_integer() and param_array[max_idx] > 0:
int_bit_num = int(max_log2) + 2 # almost impossible
else:
int_bit_num = int(np.ceil(max_log2)) + 1
else:
int_bit_num = 1 # 1 is needed to represent sign
return int_bit_num

def export(self, nnb_output_filename, settings_template_filename, settings_filename, default_type):
settings = collections.OrderedDict()
if settings_filename is not None and len(settings_filename) == 1:
Expand Down Expand Up @@ -120,49 +141,48 @@ def export(self, nnb_output_filename, settings_template_filename, settings_filen
var = self._Variable
var.id = n

# set var.shape and store into NNB
shape = [
x if x >= 0 else self._info._batch_size for x in v.shape.dim]
index, pointer = self._alloc(
data=struct.pack('{}I'.format(len(shape)), *shape))
var.shape = self._List(len(shape), index)

var.type = 0 # NN_DATA_TYPE_FLOAT
var.fp_pos = 0

# parse a type option in YAML given via -settings
if v.name not in settings['variables']:
settings['variables'][v.name] = default_type[0]
type_option = settings['variables'][v.name]
opt_list = type_option.split('_')
type_name = opt_list[0]
fp_pos = int(opt_list[1]) if len(opt_list) == 2 else None

# set var.type, var.data_index, and var.fp_pos in this paragraph
var.type = Nnb.from_type_name[type_name]
if v.type == 'Parameter':
param = self._info._parameters[v.name]
param_data = list(param.data)
v_name = settings['variables'][v.name]
if v_name == 'FLOAT32':
data = struct.pack('{}f'.format(
len(param_data)), *param_data)
var.type = Nnb.NN_DATA_TYPE_FLOAT
elif v_name.startswith('FIXED16'):
fixed16_desc = v_name.split('_')
if (len(fixed16_desc) == 2) and int(fixed16_desc[1]) <= 15:
var.fp_pos = int(fixed16_desc[1])
scale = 1 << var.fp_pos
fixed16_n_data = [int(round(x * scale))
for x in param_data]
data = struct.pack('{}h'.format(
len(fixed16_n_data)), *fixed16_n_data)
var.type = Nnb.NN_DATA_TYPE_INT16
elif v_name.startswith('FIXED8'):
fixed8_desc = v_name.split('_')
if (len(fixed8_desc) == 2) and int(fixed8_desc[1]) <= 7:
var.fp_pos = int(fixed8_desc[1])
scale = 1 << var.fp_pos
fixed8_n_data = [int(round(x * scale)) for x in param_data]
data = struct.pack('{}b'.format(
len(fixed8_n_data)), *fixed8_n_data)
var.type = Nnb.NN_DATA_TYPE_INT8

# store parameter into NNB
array = np.array(self._info._parameters[v.name].data)
if type_name == 'FLOAT32':
fmt_base = '{}f'
else: # type_name == 'FIXED16' or type_name == 'FIXED8'
fmt_base = '{}h' if type_name == 'FIXED16' else '{}b'
# if fp_pos is not specified, compute it looking at its distribution
if fp_pos is None:
int_bit_num = NnbExporter.__compute_int_bit_num(array)
fp_pos = (Nnb.fp_pos_max[var.type] + 1) - int_bit_num
else:
pass # do nothing
# convert float to fixed point values
scale = 1 << fp_pos
array = np.round(array * scale).astype(int)
fmt = fmt_base.format(len(array))
data = struct.pack(fmt, *array)
index, pointer = self._alloc(data=data)
var.data_index = index
elif v.type == 'Buffer':
# check fp_pos
if var.type != Nnb.NN_DATA_TYPE_FLOAT and fp_pos is None:
msg = 'fp_pos must be specified for Buffer Variable'
raise ValueError(msg)
# FIXME: remove the following workaround
if n in vidx_to_abidx:
# n which is NOT in vidx_to_abidx can appear
Expand All @@ -172,6 +192,14 @@ def export(self, nnb_output_filename, settings_template_filename, settings_filen
# this var doesn't make sense, but add it
# so that nn_network_t::variables::size is conserved
var.data_index = -1
# check fp_pos and set var.fp_pos
if var.type == Nnb.NN_DATA_TYPE_INT16 or var.type == Nnb.NN_DATA_TYPE_INT8:
if 0 <= fp_pos or fp_pos <= Nnb.fp_pos_max[var.type]:
var.fp_pos = fp_pos
else:
raise ValueError('invalid fp_pos was given')
else:
var.fp_pos = 0

variable = struct.pack('IiIBi',
var.id,
Expand Down

0 comments on commit 8dba9ff

Please sign in to comment.