Bug fixes #852

Merged (8 commits, Feb 26, 2017)
12 changes: 12 additions & 0 deletions docs/source/nn.rst
@@ -28,6 +28,18 @@ Containers
.. autoclass:: Sequential
:members:

:hidden:`ModuleList`
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ModuleList
:members:

:hidden:`ParameterList`
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: ParameterList
:members:

Convolution Layers
----------------------------------

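The two containers documented above may be unfamiliar; here is a minimal usage sketch (assuming the torch.nn API of this release; the `Stack` module and its sizes are illustrative, and `expand_as` on the 1-D parameter relies on the expand-to-new-dims behaviour exercised later in this PR):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

class Stack(nn.Module):
    """Illustrative module: ModuleList and ParameterList register their
    contents, so all linears and scales show up in .parameters()."""

    def __init__(self):
        super(Stack, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 10) for _ in range(3)])
        self.scales = nn.ParameterList([nn.Parameter(torch.ones(10)) for _ in range(3)])

    def forward(self, x):
        for layer, scale in zip(self.layers, self.scales):
            x = layer(x) * scale.expand_as(x)
        return x

out = Stack()(Variable(torch.randn(2, 10)))
print(len(list(Stack().parameters())))  # 9: three weights, three biases, three scales
```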
75 changes: 57 additions & 18 deletions test/test_autograd.py
@@ -225,14 +225,49 @@ def test_volatile(self):

def test_indexing(self):
x = torch.range(1, 16).resize_(4, 4)
y = Variable(x)
self.assertEqual(x[1], y[1].data)
self.assertEqual(x[1, 1], y[1, 1].data[0])
self.assertEqual(x[1:], y[1:].data)
self.assertEqual(x[:2], y[:2].data)
self.assertEqual(x[:2, 2], y[:2, 2].data)
self.assertEqual(x[1:2, 2], y[1:2, 2].data)
self.assertEqual(x[1, 2:], y[1, 2:].data)
y = Variable(x, requires_grad=True)

def check_index(idx):
y.grad.data.zero_()
indexed_tensor = x[idx]
indexed_var = y[idx]

indexed_var_t = indexed_var.data
if not torch.is_tensor(indexed_tensor):
indexed_var_t = indexed_var_t[0]
self.assertEqual(indexed_tensor, indexed_var_t)

indexed_var.sum().backward()
expected_grad = torch.zeros(4, 4)
expected_grad[idx] = 1
self.assertEqual(y.grad.data, expected_grad)

check_index(1)
check_index((1, 1))
check_index(slice(1, None))
check_index(slice(None, 2))
check_index((slice(None, 2), 2))
check_index((slice(1, 2), 2))
check_index((1, slice(2, None)))
check_index((slice(None, None), slice(2, None)))
check_index(torch.LongTensor([0, 2]))
check_index(torch.rand(4, 4).bernoulli().byte())
check_index((Ellipsis, slice(2, None)))

def test_basic_op_grad(self):
"""Grad output might need to be reshaped to match the second argument."""
x = Variable(torch.randn(4, 6), requires_grad=True)
b = Variable(torch.rand(12, 1) + 1e-2, requires_grad=True)

def y():
# .mm() depends on the grad_output being of correct size
return b.mm(Variable(torch.rand(1, 2) + 1e-2))

(x + y()).sum().backward()
(x - y()).sum().backward()
(x * y()).sum().backward()
(x / y()).sum().backward()
(x.abs() ** y()).sum().backward()

def test_requires_grad(self):
x = Variable(torch.randn(5, 5))
@@ -452,7 +487,6 @@ def test_detach(self):
self.assertEqual(y.grad.data, torch.ones(10, 10) * 2)

def test_type_conversions(self):
import torch.cuda
x = Variable(torch.randn(5, 5))
self.assertIs(type(x.float().data), torch.FloatTensor)
self.assertIs(type(x.int().data), torch.IntTensor)
@@ -836,7 +870,10 @@ def gather_variable(shape, index_dim, max_indices):
(Index, (slice(0, 3),), (torch.rand(S, S, S),), 'slice'),
(Index, ((slice(0, 3), 1),), (torch.rand(S, S, S),), 'slice_index'),
(View, (S * S, S), (torch.rand(S, S, S),)),
(Expand, ((S, 5, S, 5),), ((S, 1, S, 1),)),
(Expand, ((5, S, 5, S, 5),), ((1, S, 1, S, 1),)),
(Expand, ((S, S, S),), ((S, 1),), 'new_dim'),
(Expand, ((S, S, S),), ((1, S),), 'new_dim_front'),
(Expand, ((S, S, S),), ((1,),), 'scalar'),
(Exp, (), (torch.rand(S, S, S),)),
(Log, (), (torch.rand(S, S, S) + 1e-2,)),
(Log1p, (), (torch.rand(S, S, S),)),
@@ -886,7 +923,7 @@ def gather_variable(shape, index_dim, max_indices):
(Addr, (0.1, 0.4), ((S, M), (S,), (M,)), 'coef'),
(Dot, (), ((L,), (L,)),),
(Max, (), ((S, S, S),),),
(Repeat, (torch.Size([2, 3, 1, 4]),), ((S, S, S, S),)),
(Repeat, (torch.Size([2, 3, 1, 2]),), ((S, S, S, S),)),
(Min, (), ((S, S, S),),),
(Max, (0,), ((S, S, S),), 'dim'),
(Min, (0,), ((S, S, S),), 'dim'),
@@ -952,8 +989,10 @@ def gather_variable(shape, index_dim, max_indices):
('t', (1, 2), ()),
('view', (S, S, S), (S * S, S),),
('view_as', (S, S, S), ((S * S, S),)),
('expand', (S, 1, S), (S, S, S)),
('expand', (S, 1, 1), (S, S, S)),
('expand', (torch.Size([S, 1, S]),), (S, S, S), 'size'),
('expand', (S, 1), (S, S, S), 'new_dim'),
('expand', (1,), (S, S, S), 'scalar'),
('exp', (S, S, S), ()),
('log', (S, S, S), ()),
('log1p', (S, S, S), ()),
@@ -1055,18 +1094,18 @@ def gather_variable(shape, index_dim, max_indices):
# TODO: clamp with min/max


def create_input(call_args):
def create_input(call_args, requires_grad=True):
if not isinstance(call_args, tuple):
call_args = (call_args,)

def map_arg(arg):
if isinstance(arg, tuple) and not isinstance(arg[0], Variable):
return Variable(torch.randn(*arg).double(), requires_grad=True)
return Variable(torch.randn(*arg).double(), requires_grad=requires_grad)
elif torch.is_tensor(arg):
if isinstance(arg, torch.FloatTensor):
return Variable(arg.double(), requires_grad=True)
return Variable(arg.double(), requires_grad=requires_grad)
else:
return Variable(arg, requires_grad=True)
return Variable(arg, requires_grad=requires_grad)
else:
return arg
return tuple(map_arg(arg) for arg in call_args)
@@ -1150,8 +1189,8 @@ def fn(input):

def do_test(self, name=name, self_size=self_size, args=args, test_name=test_name):
def check(name):
self_variable = create_input((self_size,))[0]
args_variable = create_input(args)
self_variable = create_input((self_size,), requires_grad=False)[0]
args_variable = create_input(args, requires_grad=False)
self_tensor = deepcopy(self_variable.data)
args_tensor = deepcopy(unpack_variables(args_variable))
output_variable = getattr(self_variable, name)(*args_variable)
2 changes: 1 addition & 1 deletion test/test_nn.py
@@ -1088,7 +1088,7 @@ def forward(self, input):
model_cp = deepcopy(model)
self.assertEqual(model(input).data, model_cp(input).data)

model_cp.linear.weight[:] = 2
model_cp.linear.weight.data[:] = 2
self.assertNotEqual(model(input).data, model_cp(input).data)

def test_RNN_cell(self):
39 changes: 28 additions & 11 deletions test/test_torch.py
@@ -1892,8 +1892,9 @@ def test_index(self):
reference = self._consecutive((5, 5, 5))
idx = torch.LongTensor([2, 4])
self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]]))
self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]]))
self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1])
# TODO: enable once indexing is implemented like in numpy
# self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]]))
# self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1])

# None indexing
self.assertEqual(reference[2, None], reference[2].unsqueeze(0))
@@ -1944,6 +1945,7 @@ def checkPartialAssign(index):
checkPartialAssign((0, 1))
checkPartialAssign((1, 2))
checkPartialAssign((0, 2))
checkPartialAssign(torch.LongTensor((0, 2)))

with self.assertRaises(IndexError):
reference[1, 1, 1, 1] = 1
@@ -1964,10 +1966,8 @@ def checkPartialAssign(index):
with self.assertRaises(TypeError):
reference[0.0, :, 0.0] = 1

# LongTensor assignments are not supported yet
with self.assertRaises(RuntimeError):
reference[torch.LongTensor([2, 4])] = 1
with self.assertRaises(RuntimeError):
# LongTensor assignments are not fully supported yet
with self.assertRaises(TypeError):
reference[0, torch.LongTensor([2, 4])] = 1

def test_index_copy(self):
@@ -2181,13 +2181,30 @@ def test_view(self):
self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))

def test_expand(self):
result = torch.Tensor()
tensor = torch.rand(8, 1)
template = torch.rand(8, 5)
tensor = torch.rand(1, 8, 1)
tensor2 = torch.rand(5)
template = torch.rand(4, 8, 5)
target = template.size()
self.assertEqual(tensor.expand_as(template).size(), target)
self.assertEqual(tensor.expand(8, 5).size(), target)
self.assertEqual(tensor.expand(torch.Size([8, 5])).size(), target)
self.assertEqual(tensor.expand(4, 8, 5).size(), target)
self.assertEqual(tensor.expand(target).size(), target)
self.assertEqual(tensor2.expand_as(template).size(), target)
self.assertEqual(tensor2.expand(4, 8, 5).size(), target)
self.assertEqual(tensor2.expand(target).size(), target)

# test double expand
self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1))

# test non-contiguous
noncontig = torch.randn(5, 2, 1, 3)[:, 0]
assert not noncontig.is_contiguous()
self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1))

# make sure it's compatible with unsqueeze
expanded = tensor2.expand(1, 1, 5)
unsqueezed = tensor2.unsqueeze(0).unsqueeze(1)
self.assertEqual(expanded, unsqueezed)
self.assertEqual(expanded.stride(), unsqueezed.stride())

def test_repeat(self):
result = torch.Tensor()
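A side note on what the new test cases check: `expand()` returns a zero-stride view while `repeat()` materialises copies, which is why the test compares the two only by value. A small sketch (standard torch semantics, sizes chosen arbitrarily):

```python
import torch

t = torch.rand(8, 1)
e = t.expand(8, 5)        # view over the same storage, no copy
print(e.stride())         # (1, 0): the expanded dimension has stride 0
r = t.repeat(1, 5)        # actual copies along the second dimension
print(r.size())           # torch.Size([8, 5])
print(torch.equal(e, r))  # True: same values, different storage layout
```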
21 changes: 16 additions & 5 deletions torch/autograd/_functions/basic_ops.py
@@ -3,63 +3,74 @@
import math


def maybe_view(tensor, size):
if tensor.size() == size:
return tensor
return tensor.contiguous().view(size)


class Add(InplaceFunction):

def forward(self, a, b):
self.b_size = b.size()
if self.inplace:
self.mark_dirty(a)
return a.add_(b)
else:
return a.add(b)

def backward(self, grad_output):
return grad_output, grad_output
return grad_output, maybe_view(grad_output, self.b_size)


class Sub(InplaceFunction):

def forward(self, a, b):
self.b_size = b.size()
if self.inplace:
self.mark_dirty(a)
return a.sub_(b)
else:
return a.sub(b)

def backward(self, grad_output):
return grad_output, grad_output.neg()
return grad_output, maybe_view(grad_output.neg(), self.b_size)


class Mul(Function):

def forward(self, a, b):
self.b_size = b.size()
self.save_for_backward(a, b)
return a.mul(b)

def backward(self, grad_output):
a, b = self.saved_tensors
return grad_output.mul(b), grad_output.mul(a)
return grad_output.mul(b), maybe_view(grad_output.mul(a), self.b_size)


class Div(Function):

def forward(self, a, b):
self.b_size = b.size()
self.save_for_backward(a, b)
return a.div(b)

def backward(self, grad_output):
a, b = self.saved_tensors
return grad_output.div(b), grad_output.neg().mul(a).div_(b).div_(b)
return grad_output.div(b), maybe_view(grad_output.neg().mul(a).div_(b).div_(b), self.b_size)


class Pow(Function):

def forward(self, a, b):
self.b_size = b.size()
self.save_for_backward(a, b)
return a.pow(b)

def backward(self, grad_output):
a, b = self.saved_tensors
return grad_output.mul(b).mul_(a.pow(b - 1)), grad_output.mul(a.pow(b)).mul_(a.log())
return grad_output.mul(b).mul_(a.pow(b - 1)), maybe_view(grad_output.mul(a.pow(b)).mul_(a.log()), self.b_size)


class AddConstant(InplaceFunction):
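The common thread in the changes above: the legacy element-wise ops only require that both operands have the same number of elements, so `grad_output` arrives with the shape of the first argument and `maybe_view` reshapes the second argument's gradient back to its own size. A minimal sketch of the behaviour this guarantees (assumes the Variable API of this release; it mirrors `test_basic_op_grad` above):

```python
import torch
from torch.autograd import Variable

a = Variable(torch.randn(4, 6), requires_grad=True)
b = Variable(torch.randn(12, 2), requires_grad=True)  # same 24 elements, different shape

(a + b).sum().backward()       # grad_output is shaped like a: (4, 6)
print(b.grad.data.size())      # torch.Size([12, 2]), restored by maybe_view
```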
8 changes: 4 additions & 4 deletions torch/autograd/_functions/tensor.py
@@ -18,9 +18,8 @@ def forward(self, i):
return result

def backward(self, grad_output):
# TODO: this won't have to be zeroed
grad_input = grad_output.new(self.input_size).zero_()
grad_input.index(self.index).copy_(grad_output)
grad_input._set_index(self.index, grad_output)

return grad_input


@@ -110,10 +109,11 @@ def __init__(self, sizes):
self.expanded_dims = []

def forward(self, i):
result = i.expand(*self.sizes)
unsqueezed = (1,) * (len(self.sizes) - len(i.size()))
self.expanded_dims = [dim for dim, (expanded, original)
in enumerate(zip(self.sizes, i.size()))
in enumerate(zip(self.sizes, unsqueezed + i.size()))
if expanded != original]
result = i.expand(*self.sizes)
self.mark_shared_storage((i, result))
return result

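To make the `Expand.forward` fix above concrete: the new `unsqueezed` padding treats the input as if it had leading singleton dimensions before comparing against the target sizes, so newly added dims are also recorded. A worked example of just that computation (plain Python, values chosen for illustration):

```python
sizes = (3, 4, 5)        # expand target (self.sizes)
input_size = (4, 1)      # i.size(): fewer dims than the target

unsqueezed = (1,) * (len(sizes) - len(input_size))
expanded_dims = [dim for dim, (expanded, original)
                 in enumerate(zip(sizes, unsqueezed + input_size))
                 if expanded != original]
print(expanded_dims)     # [0, 2]: dim 0 is newly added, dim 2 grew from 1 to 5
                         # without the padding, zip would misalign and report [0, 1]
```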
22 changes: 13 additions & 9 deletions torch/csrc/autograd/python_function.cpp
@@ -220,32 +220,36 @@ static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
output_var = input_var;
input_var_.creator = THPFunction_asFunction(self);
} else {
// If the Variable has been changed, we have to move it after the
// If the leaf Variable has been returned, we have to move it after the
// current function to ensure the gradient is computed correctly.
// There are two cases now:
// 1. If it requires grad, it is an error, and this will be caught
// when its _do_backward is called, because it won't be a leaf anymore.
// Also we'll change its version.
// 2. If it doesn't require grad, we can safely move it in the graph,
// because its _do_backward will never be called.
// 1. It has been modified in-place. If it didn't require_grad it's ok,
// but if it does, then it's a clear error.
// 2. It hasn't been modified. This means that it must have been
// returned unchanged, and we can simply return a new Variable
// referencing the same storage.
if (dirty_inputs.count(output) > 0) {
Py_INCREF(input_var);
output_var = input_var;
auto& output_var_ = *output_var->cdata;
output_var_.creator = THPFunction_asFunction(self);
if (!output_var_.requires_grad && self->cdata.requires_grad) {
if (!output_var_.requires_grad) {
// Now, there's another subtlety. We move the input in the graph
// and we change its requires_grad to True. However, remember
// and possibly change its requires_grad to True. However, remember
// that we're still holding a reference to it as a previous
// function. The backward engine will think that it was really a
// leaf that initially did require grad and call its _do_backward
// and that will throw. Because of this, we need to allocate
// a dummy leaf that doesn't require grad and put it as our
// previous function.
output_var_.requires_grad = true;
// Even if the function doesn't require grad, creating a dummy leaf
// prevents the creation of reference cycles.
output_var_.requires_grad = self->cdata.requires_grad;
auto dummy_prev_fn = std::make_shared<Variable>(
std::unique_ptr<Tensor>(output_var_.data->clone_shallow()), false, false);
self->cdata.previous_functions[i] = std::make_pair<>(dummy_prev_fn, 0);
} else { // output_var_.requires_grad
throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
}
} else {
// An input has been returned, but it wasn't modified. It's better
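Case 1 in the rewritten comment corresponds to a user-visible error. A minimal repro sketch (assumes the Variable API of this release and that the check fires as soon as the in-place function wraps its outputs; the message is the `runtime_error` text above):

```python
import torch
from torch.autograd import Variable

leaf = Variable(torch.randn(3), requires_grad=True)
try:
    leaf.add_(1)   # in-place op directly on a leaf that requires grad
except RuntimeError as e:
    print(e)       # "a leaf Variable that requires grad has been used in an in-place operation."
```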
1 change: 1 addition & 0 deletions torch/csrc/cuda/Module.cpp
@@ -60,6 +60,7 @@ static bool THCPModule_assignStateless()
PyObject *stateless;
INIT_STATELESS(Double);
INIT_STATELESS_DETAIL(Float, Cuda);
INIT_STATELESS(Half);
INIT_STATELESS(Long);
INIT_STATELESS(Int);
INIT_STATELESS(Short);