Skip to content

Commit

Permalink
Fix a couple of selection reduce function autograd bugs (#1702)
Browse files Browse the repository at this point in the history
* Fix Median/Mode autograd functions.

* Fix kthvalue autograd function.

* Double backward for selection reduce functions.
  • Loading branch information
gchanan authored and soumith committed Jun 3, 2017
1 parent eba3dc8 commit ac1c674
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 51 deletions.
28 changes: 18 additions & 10 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,16 +1368,19 @@ class dont_convert(tuple):
(Unfold, (), ((S, S, S), 1, 3, 1)),
(Unfold, (), ((S, S, S), 2, 3, 2), 'lastdim'),
(Min, (), ((S, S, S),),),
(Max, (1,), ((S, S, S),), 'dim', [0]),
(Min, (1,), ((S, S, S),), 'dim', [0]),
(Max, (1, False), ((S, S, S),), 'keepdim_false_dim', [0]),
(Min, (1, False), ((S, S, S),), 'keepdim_false_dim', [0]),
(Mode, (1,), ((S, S, S),), 'dim', [0]),
(Mode, (1, False,), ((S, S, S),), 'keepdim_false_dim', [0]),
(Kthvalue, (2, 0), ((S, S, S),),),
(Kthvalue, (2, 0, False), ((S, S, S),), "keepdim_false"),
(Median, (0,), ((S, S, S),),),
(Median, (0, False, ), ((S, S, S),), "keepdim_false"),
(Max, (), ((S, S, S), 1), 'dim', [0]),
(Min, (), ((S, S, S), 1), 'dim', [0]),
(Max, (), ((S, S, S), 1, False), 'keepdim_false_dim', [0]),
(Min, (), ((S, S, S), 1, False), 'keepdim_false_dim', [0]),
(Mode, (), ((S, S, S),),),
(Mode, (), ((S, S, S), 1), 'dim', [0]),
(Mode, (), ((S, S, S), 1, False), 'keepdim_false_dim', [0]),
(Kthvalue, (), ((S, S, S), 2),),
(Kthvalue, (), ((S, S, S), 2, 0), 'dim0'),
(Kthvalue, (), ((S, S, S), 2, 0, False), "keepdim_false"),
(Median, (), ((S, S, S),),),
(Median, (), ((S, S, S), 0), 'dim0'),
(Median, (), ((S, S, S), 0, False), "keepdim_false"),
(Norm, (1.5,), (torch.rand(S, S, S),), '1_5'),
(Norm, (), ((S, S, S),), '2'),
(Norm, (3,), ((S, S, S),), '3'),
Expand Down Expand Up @@ -1492,8 +1495,13 @@ class dont_convert(tuple):
('mean', (S, S, S), ()),
('mean', (S, S, S), (1,), 'dim', [0]),
('mean', (S, S, S), (1, False,), 'keepdim_false_dim', [0]),
('kthvalue', (S, S, S), (2,)),
('kthvalue', (S, S, S), (2, 1,), 'dim', [1]),
('kthvalue', (S, S, S), (2, 1, False,), 'keepdim_false_dim', [1]),
('median', (S, S, S), ()),
('median', (S, S, S), (1,), 'dim', [0]),
('median', (S, S, S), (1, False,), 'keepdim_false_dim', [0]),
('mode', (S, S, S), ()),
('mode', (S, S, S), (1,), 'dim', [0]),
('mode', (S, S, S), (1, False,), 'keepdim_false_dim', [0]),
('sum', (S, S, S), ()),
Expand Down
65 changes: 32 additions & 33 deletions torch/autograd/_functions/reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,51 +130,50 @@ class _SelectionFunction(Function):
# additional_args is prepended before dim when calling the tensor
# function. It's a no-op for subclasses other than kthvalue.
# kthvalue not only requires us to pass a dim, but also precede it with k.
additional_args = tuple()

def __init__(self, dim=None, keepdim=True):
super(_SelectionFunction, self).__init__()
self.dim = dim
self.keepdim = keepdim

def forward(self, input):
fn = getattr(input, type(self).__name__.lower())
self.input_size = input.size()
if self.dim is None and self.has_all_reduce:
value = fn(*self.additional_args)
self.indices = tuple(input.eq(value).nonzero()[0])

@classmethod
def forward(cls, ctx, input, dim=None, keepdim=True, additional_args=tuple()):
fn = getattr(input, cls.__name__.lower())
ctx.dim = dim
ctx.keepdim = keepdim
ctx.additional_args = additional_args
ctx.input_size = input.size()
if ctx.dim is None and cls.has_all_reduce:
value = fn(*additional_args)
ctx.indices_tuple = tuple(input.eq(value).nonzero()[0])
return input.new((value,))
else:
if self.dim is None:
if ctx.dim is None:
dim = input.dim() - 1
else:
dim = self.dim
args = (dim, self.keepdim)
if self.additional_args:
args = self.additional_args + args
dim = ctx.dim
args = (dim, keepdim)
if additional_args:
args = additional_args + args
output, indices = fn(*args)
self.save_for_backward(indices)
self.mark_non_differentiable(indices)
ctx.save_for_backward(indices)
ctx.mark_non_differentiable(indices)
return output, indices

def backward(self, grad_output, grad_indices=None):
grad_input = grad_output.new(*self.input_size).zero_()
if self.dim is None and self.has_all_reduce:
grad_input[self.indices] = grad_output[0]
@classmethod
def backward(cls, ctx, grad_output, grad_indices=None):
grad_input = Variable(grad_output.data.new(*ctx.input_size).zero_())
if ctx.dim is None and cls.has_all_reduce:
grad_input[ctx.indices_tuple] = grad_output.data[0]
else:
if self.dim is None:
dim = input.dim() - 1
if ctx.dim is None:
dim = len(ctx.input_size) - 1
else:
dim = self.dim
dim = ctx.dim

indices, = self.saved_tensors
if self.keepdim is False:
indices, = ctx.saved_variables
if ctx.keepdim is False:
grad_output = grad_output.unsqueeze(dim)
grad_indices = grad_indices.unsqueeze(dim)
indices = indices.unsqueeze(dim)

grad_input.scatter_(dim, indices, grad_output)
return grad_input
return grad_input, None, None, None


class Max(_SelectionFunction):
Expand All @@ -196,9 +195,9 @@ class Median(_SelectionFunction):
class Kthvalue(_SelectionFunction):
has_all_reduce = False

def __init__(self, k, dim=None, keepdim=True):
super(Kthvalue, self).__init__(dim, keepdim)
self.additional_args = (k,)
@classmethod
def forward(cls, ctx, input, k, dim=None, keepdim=True):
return super(Kthvalue, cls).forward(ctx, input, dim, keepdim, (k,))


class Norm(Function):
Expand Down
16 changes: 8 additions & 8 deletions torch/autograd/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,21 +451,21 @@ def mean(self, dim=None, keepdim=True):
def max(self, dim=None, keepdim=True):
if isinstance(dim, Variable):
return Cmax.apply(self, dim)
return Max(dim, keepdim)(self)
return Max.apply(self, dim, keepdim)

def min(self, dim=None, keepdim=True):
if isinstance(dim, Variable):
return Cmin.apply(self, dim)
return Min(dim, keepdim)(self)
return Min.apply(self, dim, keepdim)

def mode(self, dim, keepdim=True):
return Mode(dim, keepdim)(self)
def mode(self, dim=None, keepdim=True):
return Mode.apply(self, dim, keepdim)

def median(self, dim, keepdim=True):
return Median(dim, keepdim)(self)
def median(self, dim=None, keepdim=True):
return Median.apply(self, dim, keepdim)

def kthvalue(self, dim, keepdim=True):
return Kthvalue(dim, keepdim)(self)
def kthvalue(self, k, dim=None, keepdim=True):
return Kthvalue.apply(self, k, dim, keepdim)

def sort(self, dim=None, descending=False):
return Sort.apply(self, dim, descending, True)
Expand Down

0 comments on commit ac1c674

Please sign in to comment.