Fix a couple of selection reduce function autograd bugs #1702

Merged (3 commits), Jun 3, 2017
Changes from all commits
test/test_autograd.py: 28 changes (18 additions, 10 deletions)
@@ -1368,16 +1368,19 @@ class dont_convert(tuple):
     (Unfold, (), ((S, S, S), 1, 3, 1)),
     (Unfold, (), ((S, S, S), 2, 3, 2), 'lastdim'),
     (Min, (), ((S, S, S),),),
-    (Max, (1,), ((S, S, S),), 'dim', [0]),
-    (Min, (1,), ((S, S, S),), 'dim', [0]),
-    (Max, (1, False), ((S, S, S),), 'keepdim_false_dim', [0]),
-    (Min, (1, False), ((S, S, S),), 'keepdim_false_dim', [0]),
-    (Mode, (1,), ((S, S, S),), 'dim', [0]),
-    (Mode, (1, False,), ((S, S, S),), 'keepdim_false_dim', [0]),
-    (Kthvalue, (2, 0), ((S, S, S),),),
-    (Kthvalue, (2, 0, False), ((S, S, S),), "keepdim_false"),
-    (Median, (0,), ((S, S, S),),),
-    (Median, (0, False, ), ((S, S, S),), "keepdim_false"),
+    (Max, (), ((S, S, S), 1), 'dim', [0]),
+    (Min, (), ((S, S, S), 1), 'dim', [0]),
+    (Max, (), ((S, S, S), 1, False), 'keepdim_false_dim', [0]),
+    (Min, (), ((S, S, S), 1, False), 'keepdim_false_dim', [0]),
+    (Mode, (), ((S, S, S),),),
+    (Mode, (), ((S, S, S), 1), 'dim', [0]),
+    (Mode, (), ((S, S, S), 1, False), 'keepdim_false_dim', [0]),
+    (Kthvalue, (), ((S, S, S), 2),),
+    (Kthvalue, (), ((S, S, S), 2, 0), 'dim0'),
+    (Kthvalue, (), ((S, S, S), 2, 0, False), "keepdim_false"),
+    (Median, (), ((S, S, S),),),
+    (Median, (), ((S, S, S), 0), 'dim0'),
+    (Median, (), ((S, S, S), 0, False), "keepdim_false"),
     (Norm, (1.5,), (torch.rand(S, S, S),), '1_5'),
     (Norm, (), ((S, S, S),), '2'),
     (Norm, (3,), ((S, S, S),), '3'),
@@ -1491,8 +1494,13 @@ class dont_convert(tuple):
     ('mean', (S, S, S), ()),
     ('mean', (S, S, S), (1,), 'dim', [0]),
     ('mean', (S, S, S), (1, False,), 'keepdim_false_dim', [0]),
+    ('kthvalue', (S, S, S), (2,)),
+    ('kthvalue', (S, S, S), (2, 1,), 'dim', [1]),
+    ('kthvalue', (S, S, S), (2, 1, False,), 'keepdim_false_dim', [1]),
+    ('median', (S, S, S), ()),
     ('median', (S, S, S), (1,), 'dim', [0]),
     ('median', (S, S, S), (1, False,), 'keepdim_false_dim', [0]),
+    ('mode', (S, S, S), ()),
     ('mode', (S, S, S), (1,), 'dim', [0]),
     ('mode', (S, S, S), (1, False,), 'keepdim_false_dim', [0]),
     ('sum', (S, S, S), ()),
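The rewritten entries exercise the new calling convention: dim, keepdim, and k now travel in the argument tuple handed to apply() rather than in a Function constructor. A minimal sketch of the difference, assuming the 0.1.x-era Variable API this PR targets (S stands in for the suite's size constant, taken as 5 here):

    import torch
    from torch.autograd import Variable
    from torch.autograd._functions.reduce import Max

    S = 5  # stand-in for the test suite's size constant
    x = Variable(torch.randn(S, S, S), requires_grad=True)

    # Old style, removed above:  Max(1, False)(x)
    # New style, added above: the extra arguments ride along in apply()
    value, indices = Max.apply(x, 1, False)  # dim=1, keepdim=False
    value.sum().backward()                   # gradient flows only to the max positions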
torch/autograd/_functions/reduce.py: 65 changes (32 additions, 33 deletions)
@@ -130,51 +130,50 @@ class _SelectionFunction(Function):
     # additional_args is prepended before dim when calling the tensor
     # function. It's a no-op for subclasses other than kthvalue.
     # kthvalue not only requires us to pass a dim, but also preceed it with k.
-    additional_args = tuple()
-
-    def __init__(self, dim=None, keepdim=True):
-        super(_SelectionFunction, self).__init__()
-        self.dim = dim
-        self.keepdim = keepdim
-
-    def forward(self, input):
-        fn = getattr(input, type(self).__name__.lower())
-        self.input_size = input.size()
-        if self.dim is None and self.has_all_reduce:
-            value = fn(*self.additional_args)
-            self.indices = tuple(input.eq(value).nonzero()[0])
+
+    @classmethod
+    def forward(cls, ctx, input, dim=None, keepdim=True, additional_args=tuple()):
+        fn = getattr(input, cls.__name__.lower())
+        ctx.dim = dim
+        ctx.keepdim = keepdim
+        ctx.additional_args = additional_args
+        ctx.input_size = input.size()
+        if ctx.dim is None and cls.has_all_reduce:
+            value = fn(*additional_args)
+            ctx.indices_tuple = tuple(input.eq(value).nonzero()[0])
             return input.new((value,))
         else:
-            if self.dim is None:
+            if ctx.dim is None:
                 dim = input.dim() - 1
             else:
-                dim = self.dim
-            args = (dim, self.keepdim)
-            if self.additional_args:
-                args = self.additional_args + args
+                dim = ctx.dim
+            args = (dim, keepdim)
+            if additional_args:
+                args = additional_args + args
             output, indices = fn(*args)
-            self.save_for_backward(indices)
-            self.mark_non_differentiable(indices)
+            ctx.save_for_backward(indices)
+            ctx.mark_non_differentiable(indices)
             return output, indices

-    def backward(self, grad_output, grad_indices=None):
-        grad_input = grad_output.new(*self.input_size).zero_()
-        if self.dim is None and self.has_all_reduce:
-            grad_input[self.indices] = grad_output[0]
+    @classmethod
+    def backward(cls, ctx, grad_output, grad_indices=None):
+        grad_input = Variable(grad_output.data.new(*ctx.input_size).zero_())
+        if ctx.dim is None and cls.has_all_reduce:
+            grad_input[ctx.indices_tuple] = grad_output.data[0]
         else:
-            if self.dim is None:
-                dim = input.dim() - 1
+            if ctx.dim is None:
+                dim = len(ctx.input_size) - 1
             else:
-                dim = self.dim
+                dim = ctx.dim

-            indices, = self.saved_tensors
-            if self.keepdim is False:
+            indices, = ctx.saved_variables
+            if ctx.keepdim is False:
                 grad_output = grad_output.unsqueeze(dim)
                 grad_indices = grad_indices.unsqueeze(dim)
                 indices = indices.unsqueeze(dim)

             grad_input.scatter_(dim, indices, grad_output)
-        return grad_input
+        return grad_input, None, None, None


class Max(_SelectionFunction):
@@ -196,9 +195,9 @@ class Median(_SelectionFunction):
 class Kthvalue(_SelectionFunction):
     has_all_reduce = False

-    def __init__(self, k, dim=None, keepdim=True):
-        super(Kthvalue, self).__init__(dim, keepdim)
-        self.additional_args = (k,)
+    @classmethod
+    def forward(cls, ctx, input, k, dim=None, keepdim=True):
+        return super(Kthvalue, cls).forward(ctx, input, dim, keepdim, (k,))


class Norm(Function):
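Two fixes land in this class. First, backward() no longer reads the undefined name input; the default dim is recovered from ctx.input_size. Second, because forward() now takes dim, keepdim, and additional_args as inputs, backward() must return one slot per input, hence the trailing Nones. The gradient routing itself is unchanged; this standalone sketch mimics what the dim-reduction branch computes, in plain tensor ops rather than through the class:

    import torch

    x = torch.Tensor([[3, 1, 2],
                      [0, 5, 4]])
    value, indices = x.max(1)               # selection along dim=1

    grad_output = torch.ones(value.size())  # pretend upstream gradient
    grad_input = x.new(x.size()).zero_()
    if indices.dim() < x.dim():             # the keepdim=False unsqueeze step
        indices = indices.unsqueeze(1)
        grad_output = grad_output.unsqueeze(1)
    grad_input.scatter_(1, indices, grad_output)
    # grad_input routes each output gradient to the winning element:
    # 1 0 0
    # 0 1 0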
torch/autograd/variable.py: 16 changes (8 additions, 8 deletions)
@@ -451,21 +451,21 @@ def mean(self, dim=None, keepdim=True):
     def max(self, dim=None, keepdim=True):
         if isinstance(dim, Variable):
             return Cmax.apply(self, dim)
-        return Max(dim, keepdim)(self)
+        return Max.apply(self, dim, keepdim)

     def min(self, dim=None, keepdim=True):
         if isinstance(dim, Variable):
             return Cmin.apply(self, dim)
-        return Min(dim, keepdim)(self)
+        return Min.apply(self, dim, keepdim)

-    def mode(self, dim, keepdim=True):
-        return Mode(dim, keepdim)(self)
+    def mode(self, dim=None, keepdim=True):
+        return Mode.apply(self, dim, keepdim)

-    def median(self, dim, keepdim=True):
-        return Median(dim, keepdim)(self)
+    def median(self, dim=None, keepdim=True):
+        return Median.apply(self, dim, keepdim)

-    def kthvalue(self, dim, keepdim=True):
-        return Kthvalue(dim, keepdim)(self)
+    def kthvalue(self, k, dim=None, keepdim=True):
+        return Kthvalue.apply(self, k, dim, keepdim)

     def sort(self, dim=None, descending=False):
         return Sort.apply(self, dim, descending, True)
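The Variable-level fix worth calling out is kthvalue: the old signature had no k parameter at all, so the caller's dim and keepdim landed in Kthvalue's k and dim slots. mode and median also gain an optional dim, making their all-reduce paths reachable. A usage sketch against the 2017-era Variable wrapper (in modern PyTorch these are plain tensor methods):

    import torch
    from torch.autograd import Variable

    x = Variable(torch.randn(4, 6), requires_grad=True)

    val, idx = x.kthvalue(2, dim=1)  # 2nd-smallest along dim 1; k is now threaded through
    med = x.median()                 # dim may be omitted: the all-reduce path
    val.sum().backward()             # backward returns a gradient only for the input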