Closed

Changes from all commits (30 commits)
2c2f400
cumsum backward for complex types
nikitaved Nov 17, 2020
48652db
enable complex backward for cumprod. It does depend on prod though
nikitaved Nov 18, 2020
b0179ea
modify tests for forward prod
nikitaved Nov 18, 2020
eb178c7
enable complex grads for index_select/index_add but ignore testing fo…
nikitaved Nov 18, 2020
8371d82
prod complex backward support. Some tests are disabled
nikitaved Nov 18, 2020
443a698
enable all test for cumprod
nikitaved Nov 18, 2020
d0671a1
enable complex dispatch in index_add on CUDA
nikitaved Nov 18, 2020
7219437
unlock more tests for prod
nikitaved Nov 18, 2020
638f1ca
fix that cumprod test that caused JIT failures
nikitaved Nov 18, 2020
24baf84
minor
nikitaved Nov 18, 2020
ddcf407
merge master
nikitaved Mar 7, 2021
120aa1f
make sure test_prod pass. TODO: do OpInfo
nikitaved Mar 7, 2021
2847c2d
make sure cumsum tests pass. TODO: add OpInfo
nikitaved Mar 7, 2021
56095df
make cumprod pass, remove cast test for now. TODO: add OpInfo
nikitaved Mar 7, 2021
c61df34
minor. Run on CIs
nikitaved Mar 7, 2021
81a306b
merge master
nikitaved Mar 11, 2021
7f1ad40
add OpInfo for cumprod
nikitaved Mar 11, 2021
941db80
add Op tests for prod
nikitaved Mar 11, 2021
ecf539a
minor
nikitaved Mar 11, 2021
e962d3f
merge master
nikitaved Mar 11, 2021
298aebd
Merge branch 'master' of https://github.com/pytorch/pytorch into nikv…
nikitaved Mar 11, 2021
6e6061f
minor
nikitaved Mar 11, 2021
ca7d279
add a comment on skipping out= tests
nikitaved Mar 12, 2021
d4787c2
Merge branch 'master' of https://github.com/pytorch/pytorch into nikv…
nikitaved Mar 12, 2021
8f08e16
merge master
nikitaved Mar 16, 2021
b521637
merge master
nikitaved Mar 19, 2021
91201d1
make sure cumprod supports complex
nikitaved Mar 19, 2021
100280b
Merge branch 'master' of https://github.com/pytorch/pytorch into nikv…
nikitaved Mar 19, 2021
49a5b9a
remove prod_single_zero for common_utils.py
nikitaved Mar 19, 2021
7977e70
add a comment that input_conj <=> input
nikitaved Mar 19, 2021
24 changes: 15 additions & 9 deletions aten/src/ATen/native/ReduceOps.cpp
@@ -273,10 +273,16 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co
return grad;
}

const auto w = output * grad;
// To enable complex support.
// From this line on `input_conj` and `output_conj`
// are interchangeable with `input` and `output`.
auto input_conj = input.conj();
auto output_conj = output.conj();

const auto w = output_conj * grad;
const auto is_zero = input == 0;
if (!(is_zero.any().item<uint8_t>())) {
return reversed_cumsum(w, dim).div(input);
return reversed_cumsum(w, dim).div(input_conj);
}

// If we are not computing a second order gradient, we can use an
@@ -309,7 +315,7 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co
auto mask = cumsum == 0;
// equiv to grad_input[mask] = deriv[grad]
grad_input.masked_scatter_(mask,
reversed_cumsum(w.masked_fill(~mask, 0.), dim).div_(input).masked_select(mask));
reversed_cumsum(w.masked_fill(~mask, 0.), dim).div_(input_conj).masked_select(mask));
// select everything from the first zero to the second zero [z1, z2)
mask = cumsum == 1;

@@ -332,10 +338,10 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co
// relu_() necessary as gather does not support negative indices
// finally, we do grad_input[z1] = dy_j / dx_z1
grad_input.masked_scatter_(first_zero_mask,
input.masked_fill(~mask, 1.).cumprod(dim)
input_conj.masked_fill(~mask, 1.).cumprod(dim)
.mul_(grad.masked_fill(cumsum != 1, 0.))
.sum(dim, /*keepdim*/true)
.mul_(at::gather(output, dim, (first_zero_index - 1).relu_())
.mul_(at::gather(output_conj, dim, (first_zero_index - 1).relu_())
.masked_fill_(first_zero_index == 0, 1.))
.masked_select(first_zero_mask));
} else { // GradMode::enabled()
@@ -367,14 +373,14 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co
Tensor omitted_products;
for (int k = 0; k < dim_size; ++k) {
if (k == 0) {
prods_from_k_plus_1 = at::cumprod(input.slice(dim, k + 1), dim);
prods_from_k_plus_1 = at::cumprod(input_conj.slice(dim, k + 1), dim);
omitted_products = at::cat({ones, prods_from_k_plus_1}, dim);
} else if (k == dim_size - 1) {
const Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true);
const Tensor prods_until_k = at::prod(input_conj.slice(dim, 0, k), dim, true);
omitted_products = prods_until_k;
} else {
const Tensor prods_until_k = at::prod(input.slice(dim, 0, k), dim, true);
prods_from_k_plus_1 = at::cumprod(input.slice(dim, k+1), dim);
const Tensor prods_until_k = at::prod(input_conj.slice(dim, 0, k), dim, true);
prods_from_k_plus_1 = at::cumprod(input_conj.slice(dim, k+1), dim);
omitted_products = prods_until_k.expand_as(prods_from_k_plus_1) * prods_from_k_plus_1;
omitted_products = at::cat({prods_until_k, omitted_products}, dim);
}
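The conjugations above follow autograd's convention that, for a holomorphic op, the backward pass returns grad_output * conj(d output / d input). A minimal sketch of how the new path can be sanity-checked with gradcheck, assuming a PyTorch build that already contains this change:

import torch

# Minimal sketch, assuming a build that includes this change: gradcheck compares
# the analytical (conjugated) cumprod backward against numerical Wirtinger
# derivatives for complex double inputs.
x = torch.randn(5, dtype=torch.cdouble, requires_grad=True)
assert torch.autograd.gradcheck(lambda t: torch.cumprod(t, dim=0), (x,))

# Planting an exact zero exercises the masked_scatter_ / first-zero branch above.
y = torch.randn(5, dtype=torch.cdouble)
y[2] = 0
y.requires_grad_(True)
assert torch.autograd.gradcheck(lambda t: torch.cumprod(t, dim=0), (y,))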
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Indexing.cu
@@ -939,7 +939,7 @@ namespace {

template <typename mask_t>
void masked_fill_kernel(TensorIterator& iter, const Scalar& value) {
AT_DISPATCH_ALL_TYPES_AND3(
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kBool, kHalf, kBFloat16, iter.common_dtype(), "masked_fill_", [&]() {
const auto value_ = value.to<scalar_t>();
gpu_kernel(
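This dispatch change is what lets the zero-handling path of cumprod_backward call masked_fill on complex CUDA tensors. A small sketch of the behavior it enables, assuming a CUDA build that includes the change:

import torch

# Small sketch, assuming a CUDA build with this dispatch change: masked_fill_
# on a complex CUDA tensor now reaches the complex kernel instead of failing
# inside AT_DISPATCH.
if torch.cuda.is_available():
    t = torch.zeros(4, dtype=torch.cfloat, device='cuda')
    mask = torch.tensor([True, False, True, False], device='cuda')
    t.masked_fill_(mask, 1 + 2j)
    print(t)  # 1+2j at the masked positions, 0 elsewhere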
5 changes: 5 additions & 0 deletions test/test_cuda.py
@@ -1416,6 +1416,11 @@ def test_prod_large(self):
x = torch.ones(240000, device='cuda', dtype=torch.float32)
self.assertEqual(x.prod(), 1)

# test for complex types. Note 240k is divisible by 4
for dtype in [torch.cfloat, torch.cdouble]:
x = torch.ones(240000, device='cuda', dtype=dtype) * (0 + 1j)
self.assertEqual(x.prod(), 1)

def test_multinomial_ext(self):
# Test two corner cases from older PyTorch (Issue #4858)
freqs = torch.cuda.FloatTensor([
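The expected value is 1 because the factors cycle through 1j, -1, -1j, 1, so the product of n copies of 1j is exactly 1 whenever n is divisible by 4, which 240000 is. A CPU-only sketch of the same check (an illustration, not part of the test):

import torch

# The 1j factors multiply in groups of four to exactly 1, and every intermediate
# product stays in {1, -1, 1j, -1j}, so the full product is exactly 1.
x = torch.full((240000,), 1j, dtype=torch.cfloat)
assert x.prod() == 1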
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
@@ -91,7 +91,7 @@
'replication_pad1d', 'replication_pad2d', 'replication_pad3d',
'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum',
'eig', 'lerp', 'linalg_vector_norm'
'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod'
}

# Some operators invalidate the grad_accumulator. Let's reset it.
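Listing 'cumprod' and 'prod' in GRADIENT_IMPLEMENTED_FOR_COMPLEX removes the generated guard that otherwise rejects complex autograd for these ops. A rough sketch of what that enables, assuming a build with this PR applied:

import torch

# Rough sketch, assuming a build with this PR applied: backward through prod
# on a complex input now runs instead of being rejected by the codegen guard.
x = torch.randn(3, dtype=torch.cdouble, requires_grad=True)
x.prod().abs().backward()   # .abs() turns the complex product into a real loss
print(x.grad)               # conjugate Wirtinger gradient w.r.t. x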
6 changes: 3 additions & 3 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -466,7 +466,7 @@ Tensor prod_safe_zeros_backward(const Tensor &grad, const Tensor& inp, int64_t d
Tensor exclusive_reverse_nocp = at::cat({ones, narrow_reverse}, dim);
Tensor exclusive_reverse = reverse_dim(exclusive_reverse_nocp.cumprod(dim), dim);

return grad * (exclusive_normal * exclusive_reverse);
return grad * (exclusive_normal * exclusive_reverse).conj();
}

// note that the gradient for prod is equivalent to:
@@ -482,7 +482,7 @@ Tensor prod_backward(const Tensor& grad, const Tensor& input, const Tensor& resu
}
Tensor zero_idx = (input == 0).nonzero();
if (zero_idx.numel() == 0) {
return (grad * result) / input;
return grad * (result / input).conj();
Collaborator
What is the precision implication of flipping this order? I have #43414 in mind

Collaborator Author
@nikitaved nikitaved Nov 18, 2020

I am not sure they are comparable. The reason being one order is better when input[i] is greater than 1, the other when input[i] is less than 1. And it is only entry-wise dependent... Hmm, I do not know tbh...

Collaborator

The difference in the other one was that we were squaring the denominator before doing the division. But since there is no squaring, it should be fine here, right?

Collaborator Author

But if you feel more comfortable with the original order, I could make it like `grad * result.conj() / input.conj()`

Collaborator

I think the new one will work. It's just that it reminded me of that issue and I didn't want to introduce any regression.

Collaborator Author

I think you are right. I will restore the previous order so that the change does not become bc-breaking.

} else if (zero_idx.size(0) > 1) {
return at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
} else {
@@ -504,7 +504,7 @@ Tensor prod_backward(Tensor grad, const Tensor& input, Tensor result, int64_t di
Tensor slice_zero_count = zero_mask.sum(dim, true);
int64_t total_zeros = slice_zero_count.sum().item<int64_t>();
if (total_zeros == 0) {
return (grad * result) / input;
return grad * (result / input).conj();
} else {
return prod_safe_zeros_backward(grad, input, dim);
}
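For input with no zeros, the formula under discussion is grad_input[i] = grad_output * conj(result / input[i]). A quick numerical check of that identity, assuming this change is in the build:

import torch

# Sketch, assuming this change is in the build: for nonzero complex input,
# the gradient of prod w.r.t. input[i] equals grad_out * conj(result / input[i]).
x = torch.randn(4, dtype=torch.cdouble, requires_grad=True)
out = x.prod()
expected = (out.detach() / x.detach()).conj()
out.backward(torch.tensor(1 + 0j, dtype=torch.cdouble))
print(torch.allclose(x.grad, expected))  # True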
127 changes: 99 additions & 28 deletions torch/testing/_internal/common_methods_invocations.py
@@ -23,7 +23,7 @@
skipCUDAIfRocm, expectedAlertNondeterministic, precisionOverride,)
from torch.testing._internal.common_cuda import CUDA11OrLater
from torch.testing._internal.common_utils import \
(prod_single_zero, random_square_matrix_of_rank,
(random_square_matrix_of_rank,
random_symmetric_matrix, random_symmetric_psd_matrix,
random_symmetric_pd_matrix, make_nonzero_det,
random_fullrank_matrix_distinct_singular_value, set_rng_seed, SEED,
@@ -1437,6 +1437,69 @@ def sample_inputs_clamp(op_info, device, dtype, requires_grad):
output += [SampleInput(empty_tensor, args=(0.0, 1.0)), ]
return output

def sample_inputs_cumprod(op_info, device, dtype, requires_grad):
def make_arg(shape):
# shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
return make_tensor(shape, device, dtype, low=-1, high=+1, requires_grad=requires_grad)

def prod_zeros(dim_select):
assert len(dim_select) == 2
result = make_arg(3 * (S,))
with torch.no_grad():
result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_()
result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_()
result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_()
return result

# will not be needed once OpInfo tests support Iterables
def sample_generator():
for dim in range(3):
yield SampleInput((make_arg((S, S, S)), dim))
# Scalar tensors and empty tensor
for size in [(), (1,), (0,)]:
yield SampleInput((make_arg(size), 0))

yield SampleInput((prod_zeros([0, 1]), 1))
yield SampleInput((prod_zeros([0, 2]), 1))
yield SampleInput((prod_zeros([1, 2]), 1))

# test dtype kwarg
yield SampleInput((prod_zeros([1, 2]), 1), kwargs={'dtype': dtype})

return list(sample_generator())

def sample_inputs_prod(op_info, device, dtype, requires_grad):
def make_arg(shape):
# shrink values to be in the interval [-1, +1] for better precision in gradgradcheck
return make_tensor(shape, device, dtype, low=-1, high=+1, requires_grad=requires_grad)

def prod_single_zero():
result = make_arg(2 * (S,))
with torch.no_grad():
result[0, 1] = 0
return result

# will not be needed once OpInfo tests support Iterables
def sample_generator():
for sample in sample_inputs_cumprod(op_info, device, dtype, requires_grad):
yield SampleInput(sample.input[0]) # only Tensor, ignore other inputs
yield sample
sample.kwargs['keepdim'] = True
yield sample
yield SampleInput(prod_single_zero())
yield SampleInput((make_arg((3, 3, 3)), 1))
yield SampleInput((make_arg((3, 3, 3)), 1), kwargs={'keepdim': True})

# test zero scalar tensor
zero = make_arg(())
with torch.no_grad():
zero.zero_()
yield SampleInput(zero)
yield SampleInput((zero, 0))
yield SampleInput((zero, 0), kwargs={'keepdim': True})

return list(sample_generator())

def sample_inputs_diag(op_info, device, dtype, requires_grad):
vec_sample = SampleInput(make_tensor((M, ), device, dtype, low=None, high=None, requires_grad=requires_grad))

@@ -1998,6 +2061,29 @@ def _make_tensor_helper(shape, low=None, high=None):
SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),
),
sample_inputs_func=sample_inputs_cumsum),
OpInfo('cumprod',
dtypes=all_types_and_complex_and(torch.bool),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16),
test_inplace_grad=False,
skips=(
# Reference: https://github.com/pytorch/pytorch/issues/53360
# For integer inputs,
# inplace variant preserves dtype of `self` while method variant
# always promotes it to torch.long.
# >>> t = torch.randint(2, 10, (3, 2), dtype=torch.int8)
# >>> t.cumprod(0).dtype
# torch.int64
# >>> t.cumprod_(0).dtype
# torch.int8
SkipInfo('TestCommon', 'test_variant_consistency_eager',
dtypes=[torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32]),
SkipInfo('TestCommon', 'test_variant_consistency_jit',
dtypes=[torch.bool, torch.float16]),
# cumprod does not correctly warn when resizing out= inputs
SkipInfo('TestCommon', 'test_out',
dtypes=[torch.float32]),
),
sample_inputs_func=sample_inputs_cumprod),
UnaryUfuncInfo('deg2rad',
ref=np.radians,
decorators=(precisionOverride({torch.bfloat16: 7e-1,
@@ -2533,6 +2619,18 @@ def _make_tensor_helper(shape, low=None, high=None):
dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16),
dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16),
assert_autodiffed=True,),
OpInfo('prod',
dtypes=all_types_and_complex_and(torch.bool),
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
test_inplace_grad=False,
skips=(
SkipInfo('TestCommon', 'test_variant_consistency_jit',
dtypes=[torch.float16, torch.bfloat16]),
# prod does not correctly warn when resizing out= inputs
SkipInfo('TestCommon', 'test_out',
dtypes=[torch.float32]),
),
sample_inputs_func=sample_inputs_prod),
OpInfo('qr',
op=torch.qr,
dtypes=floating_and_complex_types(),
@@ -3600,25 +3698,6 @@ def method_tests():
('nansum', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
('nansum', (S, S, S), ([1, 2],), 'multi_dim'),
('nansum', (S, S, S), ([1, 2], True,), 'multi_dim_keepdim'),
('prod', (S, S, S), NO_ARGS),
('prod', (S, S, S), (1,), 'dim', (), [0]),
('prod', (S, S, S), (1, True,), 'keepdim_dim', (), [0]),
('prod', (), NO_ARGS, 'scalar'),
('prod', (), (0,), 'scalar_dim', (), [0]),
('prod', (), (0, True,), 'scalar_keepdim_dim', (), [0]),
('prod', prod_zeros(S, [0, 1]), NO_ARGS, 'zerodims2'),
('prod', prod_zeros(S, [0, 2]), NO_ARGS, 'zerodims1'),
('prod', prod_zeros(S, [1, 2]), NO_ARGS, 'zerodims0'),
('prod', prod_zeros(S, [0, 1]), (1,), 'zeros_dims2', (), [0]),
('prod', prod_zeros(S, [0, 2]), (1,), 'zeros_dims1', (), [0]),
('prod', prod_zeros(S, [1, 2]), (1,), 'zeros_dims0', (), [0]),
('prod', prod_zeros(S, [0, 1]), (1, True), 'keepdim_zeros_dims2', (), [0]),
('prod', prod_zeros(S, [0, 2]), (1, True), 'keepdim_zeros_dims1', (), [0]),
('prod', prod_zeros(S, [1, 2]), (1, True), 'keepdim_zeros_dims0', (), [0]),
('prod', prod_single_zero(S), NO_ARGS, 'single_zero'),
('prod', (torch.tensor(0., requires_grad=True)), NO_ARGS, 'scalar_zero'),
('prod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_dim_zero', (), [0]),
('prod', (torch.tensor(0., requires_grad=True)), (0, True,), 'scalar_keepdim_dim_zero', (), [0]),
('var_mean', (S, S, S), NO_ARGS, ''),
('var_mean', (S, S, S), (1,), 'dim', [0]),
('var_mean', (S, S, S), (1, True, True), 'keepdim_dim', [0]),
@@ -3642,14 +3721,6 @@ def method_tests():
('cummin', (S, S, S), (1,), 'dim1', (), [0]),
('cummin', (), (0,), 'dim0_scalar', (), [0]),
('cumsum', (S, S, S), (1,), 'dim1_cast', (), [0], (), ident, {'dtype': torch.float64}),
('cumprod', (S, S, S), (0,)),
('cumprod', (S, S, S), (1,), 'dim1', (), [0]),
('cumprod', (), (0,), 'scalar'),
('cumprod', (torch.tensor(0., requires_grad=True)), (0,), 'scalar_zeros'),
('cumprod', prod_zeros(S, [0, 1]), (1,), 'zeros_dim2', (), [0]),
('cumprod', prod_zeros(S, [0, 2]), (1,), 'zeros_dim1', (), [0]),
('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0', (), [0]),
('cumprod', prod_zeros(S, [1, 2]), (1,), 'zeros_dim0_cast', (), [0], (), ident, {'dtype': torch.float64}),
('log_softmax', (S, S, S), (1, torch.float64,), 'kwarg_dtype_would_break_jit_loader', (True,)),
('unfold', (), (0, 1, 1), 'scalar', (), [0]),
('unfold', (S, S, S, S), (0, 3, 1), '4d_dim0_step1', (), [0]),
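A rough, simplified sketch of how OpInfo-based tests consume the sample-input generators added above (the real consumers live in test_ops.py; the attribute names below follow this PR's version of the file):

import torch
from torch.testing._internal.common_methods_invocations import op_db

# Rough sketch: pull the new cumprod OpInfo and run its samples forward.
cumprod_info = next(op for op in op_db if op.name == 'cumprod')
for sample in cumprod_info.sample_inputs('cpu', torch.cdouble, requires_grad=True):
    tensor, dim = sample.input                    # cumprod samples pack (tensor, dim)
    out = torch.cumprod(tensor, dim, **sample.kwargs)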
6 changes: 0 additions & 6 deletions torch/testing/_internal/common_utils.py
@@ -1644,12 +1644,6 @@ def make_tensor(size, device: torch.device, dtype: torch.dtype, *, low=None, hig

return result

def prod_single_zero(dim_size):
result = torch.randn(dim_size, dim_size)
result[0, 1] = 0
return result


def random_square_matrix_of_rank(l, rank, dtype=torch.double, device='cpu'):
assert rank <= l
A = torch.randn(l, l, dtype=dtype, device=device)