Fix handling of leaf Variables in autograd #391

Merged
merged 3 commits into from Jan 2, 2017
95 changes: 86 additions & 9 deletions test/test_autograd.py
@@ -360,23 +360,23 @@ def test_unused_output_gpu(self):
y.sum().backward()
self.assertEqual(x.grad, torch.ones(5, 5) * 2)

def test_no_grad(self):
def test_detach(self):
x = Variable(torch.randn(10, 10), requires_grad=True)
y = x + 2
y = y.no_grad()
y = y.detach()

z = y * 4 + 2
self.assertFalse(y.requires_grad)
self.assertFalse(z.requires_grad)

x = Variable(torch.randn(10, 10), requires_grad=True)
y = x * 2
y = y.no_grad()
y = y.detach()
self.assertFalse(y.requires_grad)
self.assertFalse(y.creator.requires_grad)
z = x + y
z.sum().backward()
# This is an incorrect gradient, but we assume that's what the user
# wanted. no_grad() is an advanced option.
# wanted. detach() is an advanced option.
self.assertEqual(x.grad, torch.ones(10, 10))
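
Note: a minimal usage sketch of the renamed API, assuming the 0.1-era Variable interface used throughout this test file. detach() cuts the returned Variable out of the graph, so gradients only flow through direct uses of x:

    x = Variable(torch.randn(3, 3), requires_grad=True)
    y = (x * 2).detach()       # y.requires_grad is False
    z = x + y                  # only the direct use of x is tracked
    z.sum().backward()
    # x.grad == torch.ones(3, 3); the factor of 2 hidden behind detach() is
    # dropped, which is the "incorrect gradient" the comment above refers to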

def test_type_conversions(self):
@@ -399,6 +399,67 @@ def test_type_conversions(self):
self.assertIs(type(x2.data), torch.cuda.FloatTensor)
self.assertIs(x2.get_device(), 1)

def test_return_leaf(self):
class Identity(Function):
def forward(self, a, b):
return a, a + b

def backward(self, grad_a, grad_b):
return grad_a + grad_b, grad_b

class Inplace(InplaceFunction):
def forward(self, a, b):
self.mark_dirty(a)
return a.add_(b), b + 2

def backward(self, grad_a, grad_b):
return grad_a, grad_a + grad_b

x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=True)

q, p = Identity()(x, y)
# Make sure hooks only receive grad from usage of q, not x.
q.register_hook(
'test', lambda grad: self.assertEqual(grad, torch.ones(5, 5)))
(q + p + x).sum().backward()
self.assertEqual(x.grad, torch.ones(5, 5) * 3)
self.assertEqual(y.grad, torch.ones(5, 5))
        del q, p  # these need to be freed, or the next part will raise an error

def test_return_leaf_inplace(self):
class Inplace(InplaceFunction):
def forward(self, a, b):
self.mark_dirty(a)
return a.add_(b), b + 2

def backward(self, grad_a, grad_b):
return grad_a, grad_a + grad_b

x = Variable(torch.randn(5, 5))
y = Variable(torch.randn(5, 5), requires_grad=True)

fn = Inplace(True)
q, p = fn(x, y)
self.assertIs(q, x)
self.assertIs(q.creator, fn)
self.assertTrue(q.requires_grad)
q.sum().backward()
self.assertEqual(y.grad, torch.ones(5, 5))

def test_leaf_assignment(self):
x = Variable(torch.randn(5, 5))
y = Variable(torch.randn(5), requires_grad=True)
z = Variable(torch.randn(5), requires_grad=True)

x[0] = y
x[1] = 2 * z
self.assertTrue(x.requires_grad)
self.assertIsNot(x.creator, None)
x.sum().backward()
self.assertEqual(y.grad, torch.ones(5))
self.assertEqual(z.grad, torch.ones(5) * 2)

def test_backward_copy(self):
        # This test checks the backward engine for a very subtle bug that appeared
        # in one of the initial versions of autograd. Gradient tensors were
@@ -480,18 +541,18 @@ def test_save_none_for_backward(self):
class MyFn(Function):
def forward(self, input):
self.save_for_backward(None, input, None)
return input
return input * input

def backward(self, grad_output):
n1, input, n2 = self.saved_tensors
test_case.assertIsNone(n1)
test_case.assertIsNone(n2)
return input * grad_output
return 2 * input * grad_output

x = Variable(torch.randn(5, 5), requires_grad=True)
y = MyFn()(x)
y.sum().backward()
self.assertEqual(x.grad, x.data)
self.assertEqual(x.grad, 2 * x.data)

def test_too_many_grads(self):
class MyFn(Function):
@@ -582,8 +643,20 @@ def assert_strict_equal(var1, var2):
assert_strict_equal(zc, z)


def index_variable(num_indices, max_indices):
index = torch.randperm(max_indices)[:num_indices].long()
def index_variable(shape, max_indices):
if not isinstance(shape, tuple):
shape = (shape,)
index = torch.rand(*shape).mul_(max_indices).floor_().long()
return Variable(index, requires_grad=False)

def gather_variable(shape, index_dim, max_indices):
assert len(shape) == 2
assert index_dim < 2
batch_dim = 1 - index_dim
index = torch.LongTensor(*shape)
for i in range(shape[index_dim]):
index.select(index_dim, i).copy_(
torch.randperm(max_indices)[:shape[batch_dim]])
return Variable(index, requires_grad=False)
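
Note: a rough illustration of what these helpers produce; an assumption about intent, not part of the diff. index_variable draws arbitrary valid indices, while gather_variable fills each slice with a truncated random permutation, so every index stays below max_indices and no index repeats along the dimension the Gather/Scatter tests index over, which keeps the numerical gradient checks well-defined:

    idx = gather_variable((4, 5), 0, 7)    # 4x5 LongTensor Variable, values in [0, 7), distinct within each row
    src = Variable(torch.randn(4, 7), requires_grad=True)
    out = src.gather(1, idx)               # Variable.gather is added by this PR (see variable.py below)
    out.sum().backward()                   # the gradient is scattered back into src's shape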


@@ -679,6 +752,10 @@ def index_variable(num_indices, max_indices):
(IndexCopy, (0,), ((S, S), index_variable(2, S), (2, S)) ),
(IndexFill, (0, 2), ((S, S), index_variable(2, S)) ),
(IndexSelect, (0,), ((S, S), index_variable(2, S)) ),
(Gather, (0,), ((M, S), gather_variable((S, S), 1, M)) ),
(Gather, (1,), ((M, S), gather_variable((M, S//2), 0, S)), 'dim1'),
(Scatter, (0,), ((M, S), gather_variable((S, S), 1, M), (S, S))),
(Scatter, (1,), ((M, S), gather_variable((M, S//2), 0, S), (M, S//2)), 'dim1'),
(Concat, (0,), ((1, S, S), (2, S, S), (3, S, S)) ),
(Resize, (S*S, S), ((S, S, S),) ),
(Diag, (), ((S, S),), '2d' ),
47 changes: 46 additions & 1 deletion torch/autograd/_functions/tensor.py
@@ -521,7 +521,52 @@ def backward(self, *grad_output):
return grad_input


# TODO: gather
class Gather(Function):

def __init__(self, dim):
super(Gather, self).__init__()
self.dim = dim

def forward(self, input, index):
assert not self.needs_input_grad[1], "Gather can't differentiate " \
"the index"
self.input_size = input.size()
self.save_for_backward(index)
return input.gather(self.dim, index)

def backward(self, grad_output):
index, = self.saved_tensors
grad_input = grad_output.new(self.input_size).zero_()
return grad_input.scatter_(self.dim, index, grad_output), None


class Scatter(InplaceFunction):

def __init__(self, dim, inplace=False):
super(Scatter, self).__init__(inplace)
self.dim = dim

def forward(self, input, index, source):
assert not self.needs_input_grad[1], "Scatter can't differentiate " \
"the index"
if self.inplace:
self.mark_dirty(input)
else:
input = input.clone()
self.save_for_backward(index)
return input.scatter_(self.dim, index, source)

def backward(self, grad_output):
index, = self.saved_tensors
grad_input = grad_source = None
if self.needs_input_grad[0]:
grad_input = grad_output.clone()
grad_input.scatter_(self.dim, index, 0)
if self.needs_input_grad[2]:
grad_source = grad_output.gather(self.dim, index)
return grad_input, None, grad_source
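
Note: a small numeric sanity check of the backward formulas above; a sketch, not part of the diff. Gather's backward routes each incoming gradient back to the input slot it was read from; Scatter's backward zeroes the gradient of the input slots that were overwritten and hands source the gradient of the slots it wrote:

    index = torch.LongTensor([[0, 2], [1, 0]])
    # Gather backward: grad_output has index's shape (2x2)
    grad_output_g = torch.ones(2, 2)
    grad_input_g = grad_output_g.new(2, 3).zero_().scatter_(1, index, grad_output_g)
    # grad_input_g == [[1, 0, 1], [1, 1, 0]]
    # Scatter backward: grad_output has the output's shape (2x3)
    grad_output_s = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    grad_input_s = grad_output_s.clone().scatter_(1, index, 0)   # overwritten slots get no gradient
    # grad_input_s == [[0, 2, 0], [0, 0, 6]]
    grad_source_s = grad_output_s.gather(1, index)               # each source element takes its slot's gradient
    # grad_source_s == [[1, 3], [5, 4]]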


# TODO: kthvalue
# TODO: repeat
# TODO: sort
13 changes: 11 additions & 2 deletions torch/autograd/variable.py
@@ -99,7 +99,7 @@ def requires_grad(self, value):
if value is False:
hint = (" If you want to use a computed variable in a subgraph "
"that doesn't require differentiation use "
"var_no_grad = var.no_grad().")
"var_no_grad = var.detach().")
else:
hint = ''
raise RuntimeError("you can only change requires_grad flags of "
@@ -259,7 +259,7 @@ def reinforce(self, reward):
"of stochastic functions")
self.creator._reinforce(reward)

def no_grad(self):
def detach(self):
"""Detaches the Variable from the graph that created it."""
return NoGrad()(self)

@@ -628,6 +628,15 @@ def index_fill_(self, dim, index, value):
def index_select(self, dim, index):
return IndexSelect(dim)(self, index)

def gather(self, dim, index):
return Gather(dim)(self, index)

def scatter(self, dim, index, source):
return Scatter(dim)(self, index, source)

def scatter_(self, dim, index, source):
return Scatter(dim, True)(self, index, source)
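
Note: a hypothetical usage sketch of the new Variable methods, not taken from the PR. gather and scatter build differentiable graph nodes, while scatter_ applies Scatter in place and marks the Variable dirty:

    x = Variable(torch.randn(2, 3), requires_grad=True)
    src = Variable(torch.randn(2, 2), requires_grad=True)
    idx = Variable(torch.LongTensor([[0, 2], [1, 0]]))
    picked = x.gather(1, idx)         # differentiable w.r.t. x
    spread = x.scatter(1, idx, src)   # out-of-place; differentiable w.r.t. x and src
    spread.sum().backward()           # src.grad is all ones; x.grad is zero where src was written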

def masked_copy(self, mask, variable):
return MaskedCopy()(self, mask, variable)

66 changes: 52 additions & 14 deletions torch/csrc/autograd/function.cpp
@@ -2,6 +2,7 @@
#include <structmember.h>

#include <unordered_map>
#include <unordered_set>
#include <exception>

#include "THP.h"
@@ -101,7 +102,8 @@ PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)

using t2var_type = std::unordered_map<PyObject *, THPVariable *>;

static void _mark_dirty(THPFunction *self, t2var_type &t2var)
static void _mark_dirty(THPFunction *self, t2var_type &t2var,
std::unordered_set<PyObject *> &dirty_inputs)
{
// Increase versions of modified tensors
if (!self->dirty_tensors) return;
@@ -112,6 +114,7 @@ static void _mark_dirty(THPFunction *self, t2var_type &t2var)
Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors);
for (int i = 0; i < num_dirty; i++) {
PyObject *tensor = PyTuple_GET_ITEM(self->dirty_tensors, i);
dirty_inputs.insert(tensor);
THPVariable *variable;
try {
variable = t2var.at(tensor);
@@ -135,7 +138,8 @@ static void _mark_dirty(THPFunction *self, t2var_type &t2var)
}

static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
PyObject *raw_output, PyObject *outputs)
std::unordered_set<PyObject *> &dirty_inputs, PyObject *raw_output,
PyObject *outputs)
{
// Wrap outputs in Variables
Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
@@ -161,16 +165,49 @@ static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
Py_INCREF(self);
input_var->creator = (PyObject*)self;
} else {
// If it's a leaf it's not as simple. Leaves will raise an error in
// backward if they've been changed, or they're no longer leaves. In
// some cases (e.g. broadcast) it's perfectly valid to return the same
// tensor untouched, so instead of moving it we're going to create a
// copy and join their version counters. This works for broadcast,
// and if the use wasn't valid we'll still detect an error, because
// the leaf will have a version != 0.
output_var = (THPVariable*)THPVariable_New(output, (PyObject*)self, self->requires_grad);
if (!output_var) throw python_error();
output_var->version_counter->join_with(*input_var->version_counter);
// If the Variable has been changed, we have to move it after the
// current function to ensure the gradient is computed correctly.
// There are two cases now:
// 1. If it requires grad, it is an error, and this will be caught
// when its _do_backward is called, because it won't be a leaf anymore.
// Also we'll change its version.
// 2. If it doesn't require grad, we can safely move it in the graph,
// because its _do_backward will never be called.
if (dirty_inputs.count(output) > 0) {
Py_INCREF(input_var);
output_var = input_var;
Py_INCREF(self);
output_var->creator = (PyObject*)self;
if (!output_var->requires_grad && self->requires_grad) {
// Now, there's another subtlety. We move the input in the graph
// and we change its requires_grad to True. However, remember
          // that we're still holding a reference to it as a previous
          // function. The backward engine will think that it was really a
          // leaf that initially did require grad and call its _do_backward
// and that will throw. Because of this, we need to allocate
// a dummy leaf that doesn't require grad and put it as our
// previous function.
output_var->requires_grad = self->requires_grad;
PyObject* dummy_prev_fn = THPVariable_New(output, NULL, false);
if (!dummy_prev_fn) throw python_error();
self->previous_functions[i] = THPFunctionPtr(dummy_prev_fn, 0);
}
} else {
// An input has been returned, but it wasn't modified. It's better
// not to move the Variable, because there are some legitimate cases
// where making it non-leaf would break stuff (e.g. broadcast). Also,
// returning the input Variable is not a good option either,
// because if someone registers hooks on it, they will fire with grads
// from all usages, not only from usages of this output. This is why
// we'll return a copy and join their version counters. This has
// a side-effect of making in-place ops on any of these Variables an
// immediate error, but it would be raised anyway once someone
// calls backward.
output_var = (THPVariable*)THPVariable_New(output, (PyObject*)self,
self->requires_grad);
if (!output_var) throw python_error();
output_var->version_counter->join_with(*input_var->version_counter);
}
}
}
if (!output_var) throw python_error();
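
Note: an illustrative Python-level sketch of the dirty-input branch above, modeled on test_return_leaf_inplace from this PR; Passthrough is a hypothetical helper, not part of the diff. The untouched-input branch is the one exercised by test_return_leaf, where the returned Variable is a fresh copy whose version counter is joined with the input's:

    class Passthrough(InplaceFunction):
        def forward(self, a, b):
            self.mark_dirty(a)           # a is returned modified, so it has to move after this node
            return a.add_(b)

        def backward(self, grad_output):
            return grad_output, grad_output

    x = Variable(torch.randn(5, 5))      # leaf that does not require grad
    y = Variable(torch.randn(5, 5), requires_grad=True)
    out = Passthrough(True)(x, y)        # out is x itself, now with a creator and requires_grad=True
    out.sum().backward()                 # the dummy leaf allocated above keeps the engine from
                                         # calling _do_backward on x; y.grad becomes all ones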
@@ -390,8 +427,9 @@ PyObject *THPFunction_do_forward(THPFunction *self, PyObject *inputs)
self->previous_functions[i] = THPFunctionPtr(prev_fn, input_var->output_nr);
}

_mark_dirty(self, t2var);
_wrap_outputs(self, t2var, raw_output, outputs);
std::unordered_set<PyObject *> dirty_inputs;
_mark_dirty(self, t2var, dirty_inputs);
_wrap_outputs(self, t2var, dirty_inputs, raw_output, outputs);
_join_version_counters(self, t2var);
if (self->requires_grad ||
PyObject_IsInstance((PyObject*)self, THPStochasticFunctionClass)) {
8 changes: 7 additions & 1 deletion torch/csrc/autograd/variable.cpp
@@ -42,13 +42,19 @@ PyObject * THPVariable_NewVolatile(PyObject *data)
}

// This function DOES NOT steal a reference to data and creator
// To create a leaf Variable pass NULL as creator.
PyObject * THPVariable_New(PyObject *data, PyObject *creator, char requires_grad)
{
if (num_cached > 0) {
Py_INCREF(data);
Py_INCREF(creator);
Py_XINCREF(creator);
return (PyObject*)pop_cache(data, creator, requires_grad);
}
  // We can't pass a NULL creator to this Python call, because Py_BuildValue
  // (used under the hood by PyObject_CallFunction) tries to be overly smart:
  // when it's handed a NULL argument while no error flag is currently set, it
  // raises its own error.
creator = creator ? creator : Py_None;
return PyObject_CallFunction(THPVariableClass, "OObb", data, creator, (char)0, requires_grad);
}
