Clear grad after propagation #852

Merged
include/nbla/computation_graph/variable.hpp (5 additions, 1 deletion)

@@ -263,6 +263,9 @@ class CgVariable {
       @param communicator_callbacks The callback functions invoked when 1)
       backward computation of each function is finished and
       2) all backward computation is finished.
+      @param clear_initial_grad If true, the input parameter, grad, will be
+      cleared during backward propagation. This flag is only
+      activated when grad is set.
 
       @seealso set_persistent() to prevent a specific variable to be cleared
       during forward propagation.
@@ -271,7 +274,8 @@
   backward(NdArrayPtr grad = nullptr, bool clear_buffer = false,
            vector<CommunicatorBackwardCallbackPtr> communicator_callbacks = {},
            function_hook_type pre_callback = nullptr,
-           function_hook_type post_callback = nullptr);
+           function_hook_type post_callback = nullptr,
+           const bool clear_initial_grad = false);
 
   /**
   */
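At the Python level the new argument never needs to be passed explicitly; the Cython wrapper further down selects it from the type of grad. A minimal usage sketch (assuming the usual import nnabla as nn / import nnabla.functions as F aliases used by the tests in this PR):

    import nnabla as nn
    import nnabla.functions as F

    x = nn.Variable([1], need_grad=True)
    y = F.add_scalar(F.identity(x))

    # grad given as a scalar (or None / numpy array): the wrapper builds a
    # temporary NdArray internally, so backward may clear it once propagated.
    y.forward(clear_no_need_grad=True)
    y.backward(grad=1, clear_buffer=True)

    # grad given as an NdArray: the user's array is shared as the initial grad
    # and is therefore left uncleared.
    g = nn.NdArray(y.shape)
    g.fill(1.0)
    y.forward(clear_no_need_grad=True)
    y.backward(grad=g, clear_buffer=True)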
python/src/nnabla/_variable.pxd (1 addition, 1 deletion)

@@ -79,7 +79,7 @@ cdef extern from "nbla/computation_graph/variable.hpp" namespace "nbla":
         int rank() const
         void set_rank(int rank) except+
         void forward(cpp_bool clear_buffer, cpp_bool clear_no_need_grad, unordered_set[CgFunctionPtr] *fclosed, function_hook_type function_pre_hook, function_hook_type function_post_hook) nogil except+
-        void backward(NdArrayPtr grad, cpp_bool clear_buffer, vector[CommunicatorBackwardCallbackPtr] communicator_callbacks, function_hook_type function_pre_hook, function_hook_type function_post_hook) nogil except+
+        void backward(NdArrayPtr grad, cpp_bool clear_buffer, vector[CommunicatorBackwardCallbackPtr] communicator_callbacks, function_hook_type function_pre_hook, function_hook_type function_post_hook, cpp_bool clear_initial_grad) nogil except+
         void set_persistent(cpp_bool b)
         cpp_bool persistent()
         string name() except +
python/src/nnabla/_variable.pyx (23 additions, 17 deletions)

@@ -749,23 +749,29 @@ cdef class Variable:
 
         """
         cdef NdArrayPtr p
-        if grad is None:
-            pass
-        elif np.isscalar(grad):
-            arr = NdArray(self.shape)
-            arr.fill(grad)
-            p = ( < NdArray > arr).arr
-        elif isinstance(grad, NdArray):
-            p = ( < NdArray > grad).arr
-        elif isinstance(grad, np.ndarray):
-            arr = NdArray(grad.shape)
-            arr.data = grad
-            p = ( < NdArray > arr).arr
+        cdef cpp_bool clear_initial_grad = False
+        if isinstance(grad, NdArray):
+            # Share a user-refered NdArray as a initial grad
+            clear_initial_grad = False
+            p = ( < NdArray > grad).arr
         else:
-            # Try to interpret as scalar value
-            arr = NdArray()
-            arr.data = grad
-            p = ( < NdArray > arr).arr
+            # Use a temporary NdArray as a initial grad
+            clear_initial_grad = True
+            if grad is None:
+                pass
+            elif np.isscalar(grad):
+                arr = NdArray(self.shape)
+                arr.fill(grad)
+                p = ( < NdArray > arr).arr
+            elif isinstance(grad, np.ndarray):
+                arr = NdArray(grad.shape)
+                arr.data = grad
+                p = ( < NdArray > arr).arr
+            else:
+                # Try to interpret as scalar value
+                arr = NdArray()
+                arr.data = grad
+                p = ( < NdArray > arr).arr
 
         cdef vector[CommunicatorBackwardCallbackPtr] callback_list
         if type(communicator_callbacks) == list:
@@ -783,7 +789,7 @@ cdef class Variable:
                 function_post_hook_c = create_function_hook_with_object(function_post_hook)
 
         with nogil:
-            self.varp.backward(p, clear_buffer, callback_list, function_pre_hook_c, function_post_hook_c)
+            self.varp.backward(p, clear_buffer, callback_list, function_pre_hook_c, function_post_hook_c, clear_initial_grad)
 
     def unlinked(self, need_grad=None):
         """
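The whole Python-side change therefore boils down to one rule; the helper below is a hypothetical paraphrase of the branch added above, not part of the library (the real wrapper also converts grad into an NdArrayPtr):

    import nnabla as nn

    def clear_initial_grad_for(grad):
        # Grads the wrapper materializes itself (from None, a scalar or a numpy
        # array) are temporaries and may be cleared after propagation; a
        # user-provided NdArray is shared as the initial grad and must survive.
        return not isinstance(grad, nn.NdArray)

The tests below exercise exactly these argument types, plus in-place, shared, persistent and sink-input variables, to confirm which output grads end up cleared.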
python/test/test_graph.py (129 additions, 0 deletions)

@@ -620,3 +620,132 @@ def test_clear_input_if_no_need_grad_batch_normalization(self, batch_stat):
 
         y.forward(clear_no_need_grad=True)
         self.check_input_data_clear_called_flags(answer)
+
+
+class TestClearOutputGrad():
+
+    def check_grad_cleared_flags(self, answer):
+        result = clear_called_flag_recorder.get_output_clear_called_flags()
+        assert len(result) == len(answer)
+        for i, flags in enumerate(answer):
+            assert len(result[i]) == len(flags)
+            for j, flag in enumerate(flags):
+                assert flag == result[i][j][1]
+
+    def setup_method(self):
+        clear_called_flag_recorder.activate_clear_called_flag_recorder()
+
+    def teardown_method(self):
+        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
+
+    # Test for the type of grad given to backward.
+    @pytest.mark.parametrize("grad", [1, None, np.ndarray([1]), nn.NdArray([1])])
+    def test_clear_output_grad_argument(self, grad):
+        x1 = nn.Variable([1], need_grad=True)
+
+        xx1 = F.identity(x1)
+        y1 = F.add_scalar(xx1)
+
+        answer_grad = []
+        if grad is None or isinstance(grad, nn.NdArray):
+            answer_grad.append([False]) # y1
+        else:
+            answer_grad.append([True]) # y1
+        answer_grad.append([True]) # xx1
+
+        y1.forward(clear_no_need_grad=True)
+        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
+        clear_called_flag_recorder.activate_clear_called_flag_recorder()
+        y1.backward(clear_buffer=True, grad=grad)
+
+        self.check_grad_cleared_flags(answer_grad)
+        assert y1.grad.clear_called == False
+
+    # Test for an inplaced variable.
+    def test_clear_output_grad_inplace(self):
+        x1 = nn.Variable([1], need_grad=True)
+
+        xx1 = F.identity(x1)
+        y1 = F.add_scalar(xx1, inplace=True)
+        y2 = F.add_scalar(y1)
+
+        answer_grad = []
+        answer_grad.append([True])
+        answer_grad.append([True])
+        answer_grad.append([True])
+
+        y2.forward(clear_no_need_grad=True)
+        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
+        clear_called_flag_recorder.activate_clear_called_flag_recorder()
+        y2.backward(clear_buffer=True)
+
+        self.check_grad_cleared_flags(answer_grad)
+
+    # Test for a variable shared with two layer functions.
+    def test_clear_output_grad_shared_variable(self):
+        x1 = nn.Variable([1], need_grad=True)
+
+        xx1 = F.identity(x1)
+        y1 = F.add_scalar(xx1)
+        y2 = F.add_scalar(xx1)
+        y3 = F.add2(y1, y2)
+
+        answer_grad = []
+        answer_grad.append([True])
+        answer_grad.append([True])
+        answer_grad.append([True])
+        answer_grad.append([True])
+
+        y3.forward(clear_no_need_grad=True)
+        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
+        clear_called_flag_recorder.activate_clear_called_flag_recorder()
+        y3.backward(clear_buffer=True)
+
+        self.check_grad_cleared_flags(answer_grad)
+
+    # Test for a persistent variable.
+    def test_clear_output_grad_persistent(self):
+        x1 = nn.Variable([1], need_grad=True)
+
+        xx1 = F.identity(x1)
+        y1 = F.add_scalar(xx1)
+        y2 = F.add_scalar(y1)
+
+        xx1.persistent = True
+        y2.persistent = True
+
+        answer_grad = []
+        answer_grad.append([False]) # y2
+        answer_grad.append([True]) # y1
+        answer_grad.append([False]) # xx1
+
+        y2.forward(clear_no_need_grad=True)
+        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
+        clear_called_flag_recorder.activate_clear_called_flag_recorder()
+        y2.backward(clear_buffer=True)
+
+        self.check_grad_cleared_flags(answer_grad)
+
+    # Test for the input variables of sink.
+    # In the case where Function::prohibit_clear_input_buffers returns true,
+    # these inputs must not be cleared from any function.
+    def test_clear_output_grad_prohibit_clear_input(self):
+        x1 = nn.Variable([1], need_grad=True)
+
+        xx1 = F.identity(x1)
+        y1 = F.add_scalar(xx1)
+        y2 = F.add_scalar(xx1)
+        y3 = F.sink(y1, y2)
+
+        answer_grad = []
+        answer_grad.append([True]) # y3
+        answer_grad.append([False]) # y2
+        answer_grad.append([False]) # y1
+        answer_grad.append([True]) # xx1
+
+        y3.forward(clear_no_need_grad=True)
+        clear_called_flag_recorder.deactivate_clear_called_flag_recorder()
+        clear_called_flag_recorder.activate_clear_called_flag_recorder()
+        y3.backward(clear_buffer=True)
+
+        self.check_grad_cleared_flags(answer_grad)