New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Autograd refactor #1016

Merged
merged 8 commits into from May 1, 2017
Next

Refactor attribute names in autograd

  • Loading branch information...
apaszke committed Mar 15, 2017
commit 56f294b8e599a5ce5727e73e2b1767b629677ea9
View
@@ -272,14 +272,16 @@ def run(self):
"torch/csrc/autograd/engine.cpp",
"torch/csrc/autograd/function.cpp",
"torch/csrc/autograd/variable.cpp",
"torch/csrc/autograd/grad_buffer.cpp",
"torch/csrc/autograd/input_buffer.cpp",
"torch/csrc/autograd/python_function.cpp",
"torch/csrc/autograd/python_cpp_function.cpp",
"torch/csrc/autograd/python_variable.cpp",
"torch/csrc/autograd/python_engine.cpp",
"torch/csrc/autograd/python_hook.cpp",
"torch/csrc/autograd/functions/batch_normalization.cpp",
"torch/csrc/autograd/functions/convolution.cpp",
"torch/csrc/autograd/functions/basic_ops.cpp",
"torch/csrc/autograd/functions/utils.cpp",
"torch/csrc/autograd/functions/init.cpp",
"torch/csrc/nn/THNN_generic.cpp",
]
View
@@ -77,7 +77,7 @@ def to_gpu(obj, type_map={}):
elif torch.is_storage(obj):
return obj.new().resize_(obj.size()).copy_(obj)
elif isinstance(obj, Variable):
assert obj.creator is None
assert obj.is_leaf
t = type_map.get(type(obj.data), get_gpu_type(type(obj.data)))
return Variable(obj.data.clone().type(t), requires_grad=obj.requires_grad)
elif isinstance(obj, list):
View
@@ -270,15 +270,15 @@ def test_volatile(self):
z = x ** 2
self.assertFalse(z.volatile)
self.assertTrue(z.requires_grad)
self.assertIsNotNone(z.creator)
self.assertIsNotNone(z.grad_fn)
z.backward(torch.ones(5, 5))
self.assertEqual(x.grad.data, torch.ones(5, 5) * 2)
w = z + y
self.assertTrue(w.volatile)
self.assertFalse(w.requires_grad)
self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
self.assertIsNone(w.creator)
self.assertIsNone(w.grad_fn)
def test_indexing(self):
x = torch.arange(1, 17).resize_(4, 4)
@@ -376,23 +376,23 @@ def test_backward_no_grad(self):
with self.assertRaises(RuntimeError):
torch.autograd.backward([b], [None])
def test_previous_functions(self):
def test_next_functions(self):
x = Variable(torch.randn(5, 5), requires_grad=True)
y = Variable(torch.randn(5, 5), requires_grad=True)
a = x + y
self.assertIsNotNone(a.creator)
previous_functions = a.creator.previous_functions
self.assertEqual(len(previous_functions), 2)
self.assertIs(previous_functions[0][0], x)
self.assertEqual(previous_functions[0][1], 0)
self.assertIs(previous_functions[1][0], y)
self.assertEqual(previous_functions[1][1], 0)
self.assertIsNotNone(a.grad_fn)
next_functions = a.grad_fn.next_functions
self.assertEqual(len(next_functions), 2)
self.assertIs(next_functions[0][0], x)
self.assertEqual(next_functions[0][1], 0)
self.assertIs(next_functions[1][0], y)
self.assertEqual(next_functions[1][1], 0)
b = a + 5
previous_functions = b.creator.previous_functions
self.assertEqual(len(previous_functions), 1)
self.assertIs(previous_functions[0][0], a.creator)
next_functions = b.grad_fn.next_functions
self.assertEqual(len(next_functions), 1)
self.assertIs(next_functions[0][0], a.grad_fn)
def test_inplace(self):
x = Variable(torch.ones(5, 5), requires_grad=True)
@@ -543,7 +543,7 @@ def __del__(self):
gc.collect()
for i in range(10):
Variable(torch.randn(10, 10), creator=CollectOnDelete())
Variable(torch.randn(10, 10), grad_fn=CollectOnDelete())
@unittest.skipIf(not torch.cuda.is_available() or torch.cuda.device_count() < 2,
"CUDA not available or <2 GPUs detected")
@@ -567,7 +567,7 @@ def test_detach(self):
y = x * 2
y = y.detach()
self.assertFalse(y.requires_grad)
self.assertIsNone(y.creator)
self.assertIsNone(y.grad_fn)
z = x + y
z.sum().backward()
# This is an incorrect gradient, but we assume that's what the user
@@ -669,7 +669,7 @@ def backward(self, grad_a, grad_b):
fn = Inplace(True)
q, p = fn(x, y)
self.assertIs(q, x)
self.assertIs(q.creator, fn)
self.assertIs(q.grad_fn, fn)
self.assertTrue(q.requires_grad)
q.sum().backward()
self.assertEqual(y.grad.data, torch.ones(5, 5))
@@ -682,7 +682,7 @@ def test_leaf_assignment(self):
x[0] = y
x[1] = 2 * z
self.assertTrue(x.requires_grad)
self.assertIsNot(x.creator, None)
self.assertIsNot(x.grad_fn, None)
x.sum().backward()
self.assertEqual(y.grad.data, torch.ones(5))
self.assertEqual(z.grad.data, torch.ones(5) * 2)
@@ -1293,7 +1293,7 @@ def unpack_variables(args):
def do_test(self, cls=cls, constructor_args=new_constructor_args,
call_args=call_args, test_name=test_name):
input = create_input(call_args)
self.assertEqual(gradcheck(cls(*constructor_args), input, eps=1e-6, atol=PRECISION), True)
self.assertTrue(gradcheck(lambda *input: cls(*constructor_args)(*input), input, eps=1e-6, atol=PRECISION))
if test_name not in ignore_inplace and issubclass(cls, InplaceFunction):
output = cls(*constructor_args)(*input)
@@ -38,7 +38,10 @@ def backward(variables, grad_variables, retain_variables=False):
specify ``True`` if you want to differentiate some subgraph multiple
times.
"""
grad_variables = tuple(var if isinstance(var, Variable) or var is None
else Variable(var, volatile=True)
for var in grad_variables)
Variable._execution_engine.run_backward(
tuple(variables), tuple(grad_variables), retain_variables)
tuple(variables), grad_variables, retain_variables)
assert torch._C._autograd_init()
View

This file was deleted.

Oops, something went wrong.
@@ -36,10 +36,6 @@ class Function(_C._FunctionBase):
num_outputs: Number of tensors returned by :func:`forward`.
requires_grad: Boolean indicating whether the :func:`backward` will
ever need to be called.
previous_functions: Tuple of (int, Function) pairs of length
:attr:`num_inputs`. Each entry contains a reference to a
:class:`Function` that created corresponding input, and an index
of the previous function output that's been used.
"""
__call__ = _C._FunctionBase._do_forward
@@ -140,8 +140,8 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3):
def fn(input):
return _as_tuple(func(*input))[i].data
numerical = get_numerical_jacobian(fn, inputs, inputs, eps)
analytical = get_analytical_jacobian(_as_tuple(inputs), o)
numerical = get_numerical_jacobian(fn, inputs, inputs, eps)
for a, n in zip(analytical, numerical):
if not ((a - n).abs() <= (atol + rtol * n.abs())).all():
View
@@ -13,7 +13,7 @@ class Variable(_C._VariableBase):
Variable is a thin wrapper around a Tensor object, that also holds
the gradient w.r.t. to it, and a reference to a function that created it.
This reference allows retracing the whole chain of operations that
created the data. If the Variable has been created by the user, its creator
created the data. If the Variable has been created by the user, its grad_fn
will be ``None`` and we call such objects *leaf* Variables.
Since autograd only supports scalar valued function differentiation, grad
@@ -33,8 +33,9 @@ class Variable(_C._VariableBase):
inference mode, i.e. don't save the history. See
:ref:`excluding-subgraphs` for more details.
Can be changed only on leaf Variables.
creator: Function of which the variable was an output. For leaf
(user created) variables it's ``None``. Read-only attribute.
is_leaf: Boolean indicating if the Variable is a graph leaf (i.e
if it was created by the user).
grad_fn: Gradient function graph trace.
Parameters:
data (any tensor class): Tensor to wrap.
@@ -82,7 +83,7 @@ def __setitem__(self, key, value):
return SetItem(key, value)(self)
def __deepcopy__(self, memo):
if self.creator is not None:
if not self.is_leaf:
raise RuntimeError("Only Variables created explicitly by the user "
"(graph leaves) support the deepcopy protocol at the moment")
result = type(self)(self.data.clone())
@@ -106,7 +107,7 @@ def __setstate__(self, state):
# legacy serialization of Variable
self.data = state[0]
state = (state[3], state[4], state[2])
if self.creator is not None:
if not self.is_leaf:
raise RuntimeError('__setstate__ can be only called on leaf variables')
self.requires_grad, self.volatile, self._backward_hooks = state
@@ -143,6 +144,10 @@ def backward(self, gradient=None, retain_variables=False):
'backward should be called only on a scalar (i.e. 1-element tensor) '
'or with gradient w.r.t. the variable')
gradient = self.data.new().resize_as_(self.data).fill_(1)
if not isinstance(gradient, Variable):
if gradient is not None and not torch.is_tensor(gradient):
raise TypeError("gradient has to be a Tensor, Variable or None")
gradient = Variable(gradient, volatile=True)
self._execution_engine.run_backward((self,), (gradient,), retain_variables)
def register_hook(self, hook):
@@ -177,8 +182,8 @@ def register_hook(self, hook):
"doesn't require gradient")
if self._backward_hooks is None:
self._backward_hooks = OrderedDict()
if self.creator is not None:
self.creator._register_hook_dict(self)
if self.grad_fn is not None:
self.grad_fn._register_hook_dict(self)
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[handle.id] = hook
return handle
@@ -194,10 +199,10 @@ def reinforce(self, reward):
reward(Tensor): Tensor with per-element rewards. It has to match
the device location and shape of Variable's data.
"""
if not isinstance(self.creator, StochasticFunction):
if not isinstance(self.grad_fn, StochasticFunction):
raise RuntimeError("reinforce() can be only called on outputs "
"of stochastic functions")
self.creator._reinforce(reward)
self.grad_fn._reinforce(reward)
def detach(self):
"""Returns a new Variable, detached from the current graph.
@@ -212,12 +217,12 @@ def detach(self):
errors in correctness checks.
"""
result = NoGrad()(self) # this is needed, because it merges version counters
result._creator = None
result._grad_fn = None
return result
def detach_(self):
"""Detaches the Variable from the graph that created it, making it a leaf."""
self._creator = None
self._grad_fn = None
self.requires_grad = False
def contiguous(self):
@@ -895,5 +900,5 @@ def addr(cls, *args):
setattr(Variable._torch, method, as_static)
from .engine import ImperativeEngine
from torch._C import _ImperativeEngine as ImperativeEngine
Variable._execution_engine = ImperativeEngine()
Oops, something went wrong.
ProTip! Use n and p to navigate between commits in a pull request.