Skip to content

Commit

Permalink
Auto-format
Browse files Browse the repository at this point in the history
  • Loading branch information
YukioOobuchi committed Apr 20, 2018
1 parent d63b8a2 commit 261dae2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
34 changes: 19 additions & 15 deletions python/src/nnabla/utils/converter/nnablart/nnb.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from .utils import create_nnabart_info


class NnbExporter:
def _align(self, size):
return int(math.ceil(size / 4) * 4)
Expand Down Expand Up @@ -233,16 +234,18 @@ def __save_variable_buffer(self):
buf_var_lives = self.__make_buf_var_lives(info)
actual_buf_sizes = self.__compute_actual_buf_sizes(info, buf_var_lives)
buf_var_refs = self.__make_buf_var_refs(info, buf_var_lives)
vidx_to_abidx = self.__assign_actual_buf_to_variable(info, actual_buf_sizes, buf_var_refs)
vidx_to_abidx = self.__assign_actual_buf_to_variable(
info, actual_buf_sizes, buf_var_refs)
return (list(actual_buf_sizes), vidx_to_abidx)

def __make_buf_var_lives(self, info):
# buf_var_lives is to remember from when and until when each
# Buffer Variables must be alive
buf_var_num = len(info._variable_buffer_index)
buf_var_num = len(info._variable_buffer_index)
buf_var_lives = [_LifeSpan() for _ in range(buf_var_num)]
name_to_vidx = {v.name: i for i, v in enumerate(info._network.variable)}
name_to_var = {v.name: v for v in info._network.variable}
name_to_vidx = {v.name: i for i,
v in enumerate(info._network.variable)}
name_to_var = {v.name: v for v in info._network.variable}

# set _LifeSpan.begin_func_idx and .end_func_idx along info._network
for func_idx, func in enumerate(info._network.function):
Expand All @@ -258,7 +261,7 @@ def __make_buf_var_lives(self, info):
pass
buf_var_life.end_func_idx = func_idx
else:
pass # ignore 'Parameter'
pass # ignore 'Parameter'

return buf_var_lives

Expand Down Expand Up @@ -287,7 +290,7 @@ def __make_buf_var_refs(self, info, buf_var_lives):
buf_var_refs[func_idx][crsr] = buf_idx
crsr += 1
else:
pass # only focus on buffers used in this func
pass # only focus on buffers used in this func

return buf_var_refs

Expand All @@ -306,7 +309,7 @@ def __compute_actual_buf_sizes(self, info, buf_var_lives):
tmp_size_array[crsr] = info._variable_buffer_size[buf_idx]
crsr += 1
else:
pass # only focus on buffers used in this func
pass # only focus on buffers used in this func

# update sizes of actual buffers
tmp_size_array = np.sort(tmp_size_array)
Expand Down Expand Up @@ -339,7 +342,7 @@ def __assign_actual_buf_to_variable(self, info, actual_buf_sizes, buf_var_refs):
abidx = vidx_to_abidx[vidx]
actual_assigned_flags[abidx] = True
else:
pass # determine assigment for this vidx in the follwoing for loop
pass # determine assigment for this vidx in the follwoing for loop

# determine new assigments of actual buffers to Variables
for ref_crsr in range(actual_buf_num):
Expand All @@ -357,11 +360,11 @@ def __assign_actual_buf_to_variable(self, info, actual_buf_sizes, buf_var_refs):
needed_size = info._variable_buffer_size[buf_idx]
abidx = 0
while abidx != actual_buf_num:
cond = not actual_assigned_flags[abidx]
cond = not actual_assigned_flags[abidx]
cond &= needed_size <= actual_buf_sizes[abidx]
if cond:
actual_assigned_flags[abidx] = True
vidx_to_abidx[vidx] = abidx
vidx_to_abidx[vidx] = abidx
break
else:
abidx += 1
Expand All @@ -370,19 +373,20 @@ def __assign_actual_buf_to_variable(self, info, actual_buf_sizes, buf_var_refs):
if abidx == actual_buf_num:
for abidx in range(actual_buf_num):
if not actual_assigned_flags[abidx]:
actual_buf_sizes[abidx] = needed_size
actual_buf_sizes[abidx] = needed_size
actual_assigned_flags[abidx] = True
vidx_to_abidx[vidx] = abidx
vidx_to_abidx[vidx] = abidx
break

return vidx_to_abidx


class _LifeSpan:
def __init__(self):
self.begin_func_idx = -1
self.end_func_idx = -1
self.end_func_idx = -1

def needed_at(self, func_idx):
needed = self.begin_func_idx <= func_idx
needed &= self.end_func_idx >= func_idx
needed = self.begin_func_idx <= func_idx
needed &= self.end_func_idx >= func_idx
return needed
2 changes: 0 additions & 2 deletions src/nbla/function/global_average_pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ namespace nbla {
// float
template class GlobalAveragePooling<float>;


// half
template class GlobalAveragePooling<Half>;

}

0 comments on commit 261dae2

Please sign in to comment.