Merge pull request #852 from sony/feature/20210304-clear-grad-after-propagation

Clear grad after propagation
TakuyaNarihira committed Apr 7, 2021
2 parents c294f2f + f660473 commit f7d1607
Showing 5 changed files with 271 additions and 70 deletions.
6 changes: 5 additions & 1 deletion include/nbla/computation_graph/variable.hpp
@@ -263,6 +263,9 @@ class CgVariable {
@param communicator_callbacks The callback functions invoked when 1)
backward computation of each function is finished and
2) all backward computation is finished.
+ @param clear_initial_grad If true, the input parameter, grad, will be
+ cleared during backward propagation. This flag is only
+ activated when grad is set.
@seealso set_persistent() to prevent a specific variable from being
cleared during forward propagation.
@@ -271,7 +274,8 @@
backward(NdArrayPtr grad = nullptr, bool clear_buffer = false,
vector<CommunicatorBackwardCallbackPtr> communicator_callbacks = {},
function_hook_type pre_callback = nullptr,
- function_hook_type post_callback = nullptr);
+ function_hook_type post_callback = nullptr,
+ const bool clear_initial_grad = false);

/**
*/
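The new clear_initial_grad parameter documented above means that the grad array passed to backward() is only cleared when it was created internally; an array supplied by the caller is shared and left untouched. A minimal usage sketch of that behavior at the Python level, assuming the public nnabla API (nn.Variable, nnabla.functions, NdArray.fill); illustrative only, not part of this diff:

import nnabla as nn
import nnabla.functions as F

x = nn.Variable([1], need_grad=True)
y = F.add_scalar(F.identity(x))
y.forward(clear_no_need_grad=True)

# grad passed as an nn.NdArray: the array is shared as the initial grad,
# so backward does not clear it (clear_initial_grad stays false).
g = nn.NdArray(y.shape)
g.fill(1.0)
y.backward(grad=g, clear_buffer=True)

# grad passed as a scalar: a temporary NdArray is created internally,
# so it may be cleared once backward propagation has consumed it
# (clear_initial_grad becomes true).
x.grad.fill(0)
y.backward(grad=1.0, clear_buffer=True)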
2 changes: 1 addition & 1 deletion python/src/nnabla/_variable.pxd
@@ -79,7 +79,7 @@ cdef extern from "nbla/computation_graph/variable.hpp" namespace "nbla":
int rank() const
void set_rank(int rank) except+
void forward(cpp_bool clear_buffer, cpp_bool clear_no_need_grad, unordered_set[CgFunctionPtr] *fclosed, function_hook_type function_pre_hook, function_hook_type function_post_hook) nogil except+
- void backward(NdArrayPtr grad, cpp_bool clear_buffer, vector[CommunicatorBackwardCallbackPtr] communicator_callbacks, function_hook_type function_pre_hook, function_hook_type function_post_hook) nogil except+
+ void backward(NdArrayPtr grad, cpp_bool clear_buffer, vector[CommunicatorBackwardCallbackPtr] communicator_callbacks, function_hook_type function_pre_hook, function_hook_type function_post_hook, cpp_bool clear_initial_grad) nogil except+
void set_persistent(cpp_bool b)
cpp_bool persistent()
string name() except +
40 changes: 23 additions & 17 deletions python/src/nnabla/_variable.pyx
@@ -749,23 +749,29 @@ cdef class Variable:
"""
cdef NdArrayPtr p
-if grad is None:
-    pass
-elif np.isscalar(grad):
-    arr = NdArray(self.shape)
-    arr.fill(grad)
-    p = ( < NdArray > arr).arr
-elif isinstance(grad, NdArray):
-    p = ( < NdArray > grad).arr
-elif isinstance(grad, np.ndarray):
-    arr = NdArray(grad.shape)
-    arr.data = grad
-    p = ( < NdArray > arr).arr
+cdef cpp_bool clear_initial_grad = False
+if isinstance(grad, NdArray):
+    # Share a user-referred NdArray as an initial grad
+    clear_initial_grad = False
+    p = ( < NdArray > grad).arr
else:
-    # Try to interpret as scalar value
-    arr = NdArray()
-    arr.data = grad
-    p = ( < NdArray > arr).arr
+    # Use a temporary NdArray as an initial grad
+    clear_initial_grad = True
+    if grad is None:
+        pass
+    elif np.isscalar(grad):
+        arr = NdArray(self.shape)
+        arr.fill(grad)
+        p = ( < NdArray > arr).arr
+    elif isinstance(grad, np.ndarray):
+        arr = NdArray(grad.shape)
+        arr.data = grad
+        p = ( < NdArray > arr).arr
+    else:
+        # Try to interpret as scalar value
+        arr = NdArray()
+        arr.data = grad
+        p = ( < NdArray > arr).arr

cdef vector[CommunicatorBackwardCallbackPtr] callback_list
if type(communicator_callbacks) == list:
@@ -783,7 +789,7 @@
function_post_hook_c = create_function_hook_with_object(function_post_hook)

with nogil:
- self.varp.backward(p, clear_buffer, callback_list, function_pre_hook_c, function_post_hook_c)
+ self.varp.backward(p, clear_buffer, callback_list, function_pre_hook_c, function_post_hook_c, clear_initial_grad)

def unlinked(self, need_grad=None):
"""
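The Cython change above decides clear_initial_grad purely from the type of grad: an nn.NdArray supplied by the user is shared as the initial grad, while anything else is converted into a temporary NdArray that backward is allowed to clear. A rough pure-Python paraphrase of that dispatch, using only the NdArray calls already seen in the diff (fill, data); resolve_initial_grad is a hypothetical helper for illustration, not a library function:

import numpy as np
import nnabla as nn

def resolve_initial_grad(grad, shape):
    # Returns (initial_grad, clear_initial_grad), mirroring Variable.backward.
    if isinstance(grad, nn.NdArray):
        # Shared with the caller; must not be cleared after propagation.
        return grad, False
    # Everything below creates a temporary NdArray that may be cleared
    # once backward propagation has consumed it.
    if grad is None:
        # No initial grad array is passed; the flag is ignored in this case
        # (see the variable.hpp doc comment above).
        return None, True
    if np.isscalar(grad):
        arr = nn.NdArray(shape)
        arr.fill(grad)
        return arr, True
    if isinstance(grad, np.ndarray):
        arr = nn.NdArray(grad.shape)
        arr.data = grad
        return arr, True
    # Fallback: try to interpret grad as a scalar value.
    arr = nn.NdArray()
    arr.data = grad
    return arr, True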
129 changes: 129 additions & 0 deletions python/test/test_graph.py
@@ -620,3 +620,132 @@ def test_clear_input_if_no_need_grad_batch_normalization(self, batch_stat):

        y.forward(clear_no_need_grad=True)
        self.check_input_data_clear_called_flags(answer)


class TestClearOutputGrad():

    def check_grad_cleared_flags(self, answer):
        result = clear_called_flag_recorder.get_output_clear_called_flags()
        assert len(result) == len(answer)
        for i, flags in enumerate(answer):
            assert len(result[i]) == len(flags)
            for j, flag in enumerate(flags):
                assert flag == result[i][j][1]

    def setup_method(self):
        clear_called_flag_recorder.activate_clear_called_flag_recorder()

    def teardown_method(self):
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()

    # Test for the type of grad given to backward.
    @pytest.mark.parametrize("grad", [1, None, np.ndarray([1]), nn.NdArray([1])])
    def test_clear_output_grad_argument(self, grad):
        x1 = nn.Variable([1], need_grad=True)

        xx1 = F.identity(x1)
        y1 = F.add_scalar(xx1)

        answer_grad = []
        if grad is None or isinstance(grad, nn.NdArray):
            answer_grad.append([False])  # y1
        else:
            answer_grad.append([True])  # y1
        answer_grad.append([True])  # xx1

        y1.forward(clear_no_need_grad=True)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        y1.backward(clear_buffer=True, grad=grad)

        self.check_grad_cleared_flags(answer_grad)
        assert y1.grad.clear_called == False

    # Test for an inplaced variable.
    def test_clear_output_grad_inplace(self):
        x1 = nn.Variable([1], need_grad=True)

        xx1 = F.identity(x1)
        y1 = F.add_scalar(xx1, inplace=True)
        y2 = F.add_scalar(y1)

        answer_grad = []
        answer_grad.append([True])
        answer_grad.append([True])
        answer_grad.append([True])

        y2.forward(clear_no_need_grad=True)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        y2.backward(clear_buffer=True)

        self.check_grad_cleared_flags(answer_grad)

    # Test for a variable shared with two layer functions.
    def test_clear_output_grad_shared_variable(self):
        x1 = nn.Variable([1], need_grad=True)

        xx1 = F.identity(x1)
        y1 = F.add_scalar(xx1)
        y2 = F.add_scalar(xx1)
        y3 = F.add2(y1, y2)

        answer_grad = []
        answer_grad.append([True])
        answer_grad.append([True])
        answer_grad.append([True])
        answer_grad.append([True])

        y3.forward(clear_no_need_grad=True)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        y3.backward(clear_buffer=True)

        self.check_grad_cleared_flags(answer_grad)

    # Test for a persistent variable.
    def test_clear_output_grad_persistent(self):
        x1 = nn.Variable([1], need_grad=True)

        xx1 = F.identity(x1)
        y1 = F.add_scalar(xx1)
        y2 = F.add_scalar(y1)

        xx1.persistent = True
        y2.persistent = True

        answer_grad = []
        answer_grad.append([False])  # y2
        answer_grad.append([True])  # y1
        answer_grad.append([False])  # xx1

        y2.forward(clear_no_need_grad=True)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        y2.backward(clear_buffer=True)

        self.check_grad_cleared_flags(answer_grad)

    # Test for the input variables of sink.
    # In the case where Function::prohibit_clear_input_buffers returns true,
    # these inputs must not be cleared from any function.
    def test_clear_output_grad_prohibit_clear_input(self):
        x1 = nn.Variable([1], need_grad=True)

        xx1 = F.identity(x1)
        y1 = F.add_scalar(xx1)
        y2 = F.add_scalar(xx1)
        y3 = F.sink(y1, y2)

        answer_grad = []
        answer_grad.append([True])  # y3
        answer_grad.append([False])  # y2
        answer_grad.append([False])  # y1
        answer_grad.append([True])  # xx1

        y3.forward(clear_no_need_grad=True)
        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
        clear_called_flag_recorder.activate_clear_called_flag_recorder()
        y3.backward(clear_buffer=True)

        self.check_grad_cleared_flags(answer_grad)
