Skip to content

Commit

Permalink
Separate save_variable_buffer prepare to implement for CSRC.
Browse files Browse the repository at this point in the history
  • Loading branch information
YukioOobuchi committed Apr 27, 2018
1 parent 73f6b01 commit af17de8
Show file tree
Hide file tree
Showing 2 changed files with 223 additions and 195 deletions.
229 changes: 34 additions & 195 deletions python/src/nnabla/utils/converter/nnablart/nnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,37 @@ def export(self, nnb_output_filename, settings_template_filename, settings_filen
if 'variables' not in settings:
settings['variables'] = collections.OrderedDict()

####################################################################
# make 2 data to save Variable Buffers in inference
actual_buf_sizes, vidx_to_abidx = self.__save_variable_buffer()

####################################################################
# Version
version = nnabla.utils.converter.get_category_info_version()

####################################################################
# Varibles name index
vindexes_by_name = {}
for n, v in enumerate(self._info._network.variable):
vindexes_by_name[v.name] = n

####################################################################
# Inputs
input_list = [vindexes_by_name[i]
for i in self._info._input_variables]
index, pointer = self._alloc(data=struct.pack(
'{}I'.format(len(input_list)), *input_list))
inputs = self._List(len(input_list), index)

####################################################################
# Outputs
output_list = [vindexes_by_name[i]
for i in self._info._output_variables]
index, pointer = self._alloc(data=struct.pack(
'{}I'.format(len(output_list)), *output_list))
outputs = self._List(len(output_list), index)

####################################################################
# make 2 data to save Variable Buffers in inference
from .save_variable_buffer import save_variable_buffer
actual_buf_sizes, vidx_to_abidx = save_variable_buffer(self._info)

####################################################################
# Varible buffers
blist = actual_buf_sizes
Expand All @@ -85,12 +108,11 @@ def export(self, nnb_output_filename, settings_template_filename, settings_filen
# Varibles
self._Variable = collections.namedtuple(
'Variable', ('id', 'shape', 'type', 'fp_pos', 'data_index'))
vindexes_by_name = {}
vindexes = []
for n, v in enumerate(self._info._network.variable):
var = self._Variable
var.id = n
vindexes_by_name[v.name] = n

shape = [
x if x >= 0 else self._info._batch_size for x in v.shape.dim]
index, pointer = self._alloc(
Expand Down Expand Up @@ -160,15 +182,15 @@ def export(self, nnb_output_filename, settings_template_filename, settings_filen

finfo = self._info._function_info[f.type]

inputs = [vindexes_by_name[i] for i in f.input]
finputs = [vindexes_by_name[i] for i in f.input]
index, pointer = self._alloc(data=struct.pack(
'{}I'.format(len(inputs)), *inputs))
function_data += struct.pack('iI', len(inputs), index)
'{}I'.format(len(finputs)), *finputs))
function_data += struct.pack('iI', len(finputs), index)

outputs = [vindexes_by_name[o] for o in f.output]
foutputs = [vindexes_by_name[o] for o in f.output]
index, pointer = self._alloc(data=struct.pack(
'{}I'.format(len(outputs)), *outputs))
function_data += struct.pack('iI', len(outputs), index)
'{}I'.format(len(foutputs)), *foutputs))
function_data += struct.pack('iI', len(foutputs), index)

if 'arguments' in finfo and len(finfo['arguments']) > 0:
argfmt = ''
Expand Down Expand Up @@ -207,22 +229,6 @@ def export(self, nnb_output_filename, settings_template_filename, settings_filen
'{}I'.format(len(findexes)), *findexes))
functions = self._List(len(findexes), index)

####################################################################
# Inputs
input_list = [vindexes_by_name[i]
for i in self._info._input_variables]
index, pointer = self._alloc(data=struct.pack(
'{}I'.format(len(input_list)), *input_list))
inputs = self._List(len(input_list), index)

####################################################################
# Outputs
output_list = [vindexes_by_name[i]
for i in self._info._output_variables]
index, pointer = self._alloc(data=struct.pack(
'{}I'.format(len(output_list)), *output_list))
outputs = self._List(len(output_list), index)

network = struct.pack('IiIiIiIiIiIII',
version,
buffers.size,
Expand All @@ -246,170 +252,3 @@ def export(self, nnb_output_filename, settings_template_filename, settings_filen

with open(nnb_output_filename, 'wb') as f:
f.write(network + memory)

def __save_variable_buffer(self):
# make the followings to save memory usage for Variable Buffer:
# - actual_buf_sizes(list): sizes of actual buffers, which lie unfer Variable Buffer.
# indices in this list are hereinafter called 'actual buffer index'
# - vidx_to_abidx(dict): assignment of actual buffers to Variable Buffer.
# the key and the value are Variable index and actual buffer index, respectively
info = self._info
buf_var_lives = self.__make_buf_var_lives(info)
actual_buf_sizes = self.__compute_actual_buf_sizes(info, buf_var_lives)
buf_var_refs = self.__make_buf_var_refs(info, buf_var_lives)
vidx_to_abidx = self.__assign_actual_buf_to_variable(
info, actual_buf_sizes, buf_var_refs)
return (list(actual_buf_sizes), vidx_to_abidx)

def __make_buf_var_lives(self, info):
# buf_var_lives is to remember from when and until when each
# Buffer Variables must be alive
buf_var_num = len(info._variable_buffer_index)
buf_var_lives = [_LifeSpan() for _ in range(buf_var_num)]
name_to_vidx = {v.name: i for i,
v in enumerate(info._network.variable)}
name_to_var = {v.name: v for v in info._network.variable}

# set _LifeSpan.begin_func_idx and .end_func_idx along info._network
for func_idx, func in enumerate(info._network.function):
for var_name in list(func.input) + list(func.output):
if name_to_var[var_name].type == 'Buffer':
var_idx = name_to_vidx[var_name]
buf_idx = info._buffer_ids[var_idx]
buf_var_life = buf_var_lives[buf_idx]
if buf_var_life.begin_func_idx < 0:
buf_var_life.begin_func_idx = func_idx
else:
# only identify a Function which first refers to the Variable
pass
buf_var_life.end_func_idx = func_idx
else:
pass # ignore 'Parameter'

return buf_var_lives

def __count_actual_buf(self, info, buf_var_lives):
# count how many buffers are required at maximum based on buf_var_lives
actual_buf_num = 0
for func_idx, _ in enumerate(info._network.function):
buf_num = 0
for buf_idx, buf_var_life in enumerate(buf_var_lives):
buf_num += int(buf_var_life.needed_at(func_idx))
actual_buf_num = max(actual_buf_num, buf_num)
return actual_buf_num

def __make_buf_var_refs(self, info, buf_var_lives):
# buf_var_refs is to store buffer indices of buffers required in each Function
actual_buf_num = self.__count_actual_buf(info, buf_var_lives)
shape = (len(info._network.function), actual_buf_num)
buf_var_refs = np.empty(shape, dtype=np.int32)
buf_var_refs[:] = -1

# fill buf_var_refs based on buf_var_lives
for func_idx, _ in enumerate(info._network.function):
crsr = 0
for buf_idx, buf_var_life in enumerate(buf_var_lives):
if buf_var_life.needed_at(func_idx):
buf_var_refs[func_idx][crsr] = buf_idx
crsr += 1
else:
pass # only focus on buffers used in this func

return buf_var_refs

def __compute_actual_buf_sizes(self, info, buf_var_lives):
# buf_size_array is to store size values of each actual buffer
actual_buf_num = self.__count_actual_buf(info, buf_var_lives)
buf_size_array = np.zeros(actual_buf_num, dtype=np.int32)

# tmp_size_array is size values when only focusing on a single Function
tmp_size_array = np.empty_like(buf_size_array, dtype=np.int32)
for func_idx, _ in enumerate(info._network.function):
tmp_size_array[:] = -1
crsr = 0
for buf_idx, buf_var_life in enumerate(buf_var_lives):
if buf_var_life.needed_at(func_idx):
tmp_size_array[crsr] = info._variable_buffer_size[buf_idx]
crsr += 1
else:
pass # only focus on buffers used in this func

# update sizes of actual buffers
tmp_size_array = np.sort(tmp_size_array)
for i in range(actual_buf_num):
buf_size_array[i] = max(buf_size_array[i], tmp_size_array[i])

return buf_size_array

def __assign_actual_buf_to_variable(self, info, actual_buf_sizes, buf_var_refs):
# create a dictionary to store assiginment of actual buffers to Variables

# vidx_to_abidx is short for variable index to actual buffer index
vidx_to_abidx = {}

# actual_assigned_flags is to remember if actual buffers are assigned or not
actual_buf_num = len(actual_buf_sizes)
actual_assigned_flags = np.empty(actual_buf_num, dtype=np.bool)

for func_idx, _ in enumerate(info._network.function):
actual_assigned_flags[:] = False
for ref_crsr in range(actual_buf_num):
# minus buf_idx means the corresponding buffer is not needed
buf_idx = buf_var_refs[func_idx][ref_crsr]
if buf_idx < 0:
continue

# restore assignment determined in the previous func_idx
vidx = info._variable_buffer_index[buf_idx][0]
if vidx in vidx_to_abidx:
abidx = vidx_to_abidx[vidx]
actual_assigned_flags[abidx] = True
else:
pass # determine assigment for this vidx in the follwoing for loop

# determine new assigments of actual buffers to Variables
for ref_crsr in range(actual_buf_num):
# minus buf_idx means the corresponding buffer is not needed
buf_idx = buf_var_refs[func_idx][ref_crsr]
if buf_idx < 0:
continue

# skip Variables to which an actual buffer is already assigned
vidx = info._variable_buffer_index[buf_idx][0]
if vidx in vidx_to_abidx:
continue

# search for actual buffers vacant and large enough
needed_size = info._variable_buffer_size[buf_idx]
abidx = 0
while abidx != actual_buf_num:
cond = not actual_assigned_flags[abidx]
cond &= needed_size <= actual_buf_sizes[abidx]
if cond:
actual_assigned_flags[abidx] = True
vidx_to_abidx[vidx] = abidx
break
else:
abidx += 1

# increase size if buffers large enough was NOT found
if abidx == actual_buf_num:
for abidx in range(actual_buf_num):
if not actual_assigned_flags[abidx]:
actual_buf_sizes[abidx] = needed_size
actual_assigned_flags[abidx] = True
vidx_to_abidx[vidx] = abidx
break

return vidx_to_abidx


class _LifeSpan:
def __init__(self):
self.begin_func_idx = -1
self.end_func_idx = -1

def needed_at(self, func_idx):
needed = self.begin_func_idx <= func_idx
needed &= self.end_func_idx >= func_idx
return needed

0 comments on commit af17de8

Please sign in to comment.