[numpy] torch.lgamma: promote integer inputs to float #50140

Closed
35 changes: 3 additions & 32 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -657,38 +657,9 @@ Tensor& mvlgamma_(Tensor& self, int64_t p) {
return self.copy_(args.lgamma_().sum(-1).add_(p * (p - 1) * std::log(M_PI) / 4.));
}

// NB: If you use this macro, you may also need to add a CUDA forwarding
// stub in CUDAUnaryOps

#define IMPLEMENT_UNARY_OP_CORE(op) \
Tensor op(const Tensor& self) { \
Tensor result = at::empty({0}, self.options()); \
at::op##_out(result, self); \
return result; \
}

#define IMPLEMENT_UNARY_OP_OUT_INPLACE(op, prefix, device) \
Tensor& _##op##__##prefix(Tensor& self) { \
return at::op##_out(self, self); \
} \
Tensor& _##op##_out_##prefix(Tensor& result, const Tensor& self) { \
checkDeviceType(#op, result, DeviceType::device); \
checkLayout(#op, result, Layout::Strided); \
auto iter = TensorIterator::unary_op(result, self); \
op##_stub(iter.device_type(), iter); \
return result; \
}

#define IMPLEMENT_UNARY_OP_VEC(op) \
IMPLEMENT_UNARY_OP_CORE(op) \
IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cpu, CPU)

#define IMPLEMENT_UNARY_OP_VEC_CUDA(op) \
IMPLEMENT_UNARY_OP_CORE(op) \
IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cpu, CPU) \
IMPLEMENT_UNARY_OP_OUT_INPLACE(op, cuda, CUDA)

IMPLEMENT_UNARY_OP_VEC_CUDA(lgamma)
Tensor& lgamma_out(Tensor& result, const Tensor& self) { return unary_op_impl_float_out(result, self, lgamma_stub); }
Tensor lgamma(const Tensor& self) { return unary_op_impl_float(self, lgamma_stub); }
Tensor& lgamma_(Tensor& self) { return unary_op_impl_(self, at::lgamma_out); }

DEFINE_DISPATCH(abs_stub);
DEFINE_DISPATCH(angle_stub);
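Taken together, the new wrappers give lgamma NumPy-style type promotion: integer and bool inputs are computed in the default floating-point dtype, and the in-place variant on an integer tensor fails because the promoted result cannot be stored back. A minimal sketch of the expected behavior with this change applied (dtype comments assume the default dtype is float32):

import torch

t = torch.arange(1, 5)                         # int64 input
print(torch.lgamma(t).dtype)                   # expected: torch.float32 (the default float dtype)
print(torch.lgamma(torch.tensor(True)).dtype)  # bool promotes the same way

# The in-place variant is expected to raise, since the floating-point result
# cannot be written back into the int64 input.
try:
    t.lgamma_()
except RuntimeError as err:
    print("lgamma_ on int64 raised:", err)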
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/UnaryGammaKernels.cu
@@ -41,7 +41,7 @@ void polygamma_kernel_cuda(TensorIterator& iter, int64_t n) {
}

void lgamma_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "lgamma_cuda", [&]() {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.common_dtype(), "lgamma_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
return ::lgamma(a);
});
14 changes: 6 additions & 8 deletions aten/src/ATen/native/native_functions.yaml
@@ -5165,12 +5165,6 @@
dispatch:
CPU, CUDA: __irshift__

- func: lgamma_(Tensor(a!) self) -> Tensor(a!)
variants: method
dispatch:
CPU: _lgamma__cpu
CUDA: _lgamma__cuda

- func: atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)
variants: method
dispatch:
@@ -6019,8 +6013,12 @@
- func: lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
dispatch:
CPU: _lgamma_out_cpu
CUDA: _lgamma_out_cuda
CPU, CUDA: lgamma_out

- func: lgamma_(Tensor(a!) self) -> Tensor(a!)
variants: method
dispatch:
CPU, CUDA: lgamma_

- func: lgamma(Tensor self) -> Tensor
variants: method, function
1 change: 0 additions & 1 deletion test/test_torch.py
@@ -6932,7 +6932,6 @@ def inner(self, device, dtype):
('round', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
('trunc', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
('ceil', '', _small_3d, lambda t, d: [], 1e-5, 1e-2, 1e-5, _float_types, [torch.bfloat16]),
('lgamma', '', _small_3d, lambda t, d: [], 1e-2, 1e-1, 1e-5, _float_types_no_half, [torch.bfloat16]),
]

# Creates and decorates a generic test and adds it to the class.
2 changes: 0 additions & 2 deletions test/test_unary_ufuncs.py
@@ -1709,8 +1709,6 @@ def _medium_2d(dtype, device):
_TorchMathTestMeta('frac', reffn='fmod', refargs=lambda x: (x.numpy(), 1)),
_TorchMathTestMeta('trunc'),
_TorchMathTestMeta('round'),
# FIXME lgamma produces different result compared to scipy at -inf
_TorchMathTestMeta('lgamma', reffn='gammaln', ref_backend='scipy', replace_inf_with_nan=True),
_TorchMathTestMeta('polygamma', args=[0], substr='_0', reffn='polygamma',
refargs=lambda x: (0, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
ref_backend='scipy'),
5 changes: 3 additions & 2 deletions torch/csrc/jit/tensorexpr/kernel.cpp
@@ -1350,8 +1350,9 @@ Tensor* TensorExprKernel::computeValue(const torch::jit::Value* v) {
} break;

case aten::lgamma: {
return computeOneOperand(
"aten_lgamma", v, [](const ExprHandle& a) { return lgamma(a); });
return computeOneOperand("aten_lgamma", v, [](const ExprHandle& a) {
return lgamma(promoteIntegerToDefaultType(a));
});
} break;

case prim::ConstantChunk: {
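The TensorExpr lowering above mirrors the eager-mode promotion: promoteIntegerToDefaultType widens integer operands before calling lgamma, so a scripted graph containing aten::lgamma should match the eager result. A small sanity check, assuming a build that includes this change (whether the NNC fuser actually handles such a tiny graph is beside the point):

import torch

@torch.jit.script
def scripted_lgamma(x):
    return torch.lgamma(x)

t = torch.arange(1, 5)                  # int64 input
eager = torch.lgamma(t)
scripted = scripted_lgamma(t)
print(eager.dtype, scripted.dtype)      # both expected to be the default float dtype
print(torch.allclose(eager, scripted))  # expected: True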
44 changes: 44 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1414,6 +1414,29 @@ def reference_sigmoid(x):
return (1 / (1 + np.exp(-x)))
return scipy.special.expit(x)

def reference_lgamma(x):
# scipy.special.gammaln returns `-inf` when input is `-inf`.
Collaborator:
This is a bug in SciPy. scipy.special.gammaln is documented as equivalent to math.lgamma, but math.lgamma returns inf when given -inf. Would you file an issue with SciPy, @kshitij12345? cc @rgommers

# While Pytorch, C and C++, all return `inf` when input is `-inf`.
Collaborator:
This is a good comment. Would you extend it to include Python's math.lgamma, too?

# Reference:
# https://en.cppreference.com/w/cpp/numeric/math/lgamma
# https://en.cppreference.com/w/c/numeric/math/lgamma

# To handle the above discrepancy,
# we replace -inf with inf so values
# that were originally -inf map to inf as expected
if x.dtype.kind == 'f':
x = np.where(x == float('-inf'), np.array(float('inf'), dtype=x.dtype), x)

out = scipy.special.gammaln(x)

if x.dtype == np.float16:
# `scipy.special.gammaln` returns output of float32 when input is float16,
# while `torch.lgamma` preserves `float16`. But due to smaller range of float16,
# Pytorch version outputs `inf` while SciPy returns finite values.
out = out.astype(np.float16)

return out
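
For context, the discrepancy that reference_lgamma works around can be reproduced directly; the values below follow the review discussion above, and the exact SciPy output may depend on the SciPy version:

import math

import numpy as np
import scipy.special
import torch

x = float('-inf')
print(math.lgamma(x))                        # inf, matching the C/C++ lgamma semantics
print(scipy.special.gammaln(np.float64(x)))  # -inf per the discussion above (arguably a SciPy bug)
print(torch.lgamma(torch.tensor(x)))         # tensor(inf), matching math.lgamma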

op_db_scipy_reference: List[OpInfo] = [
UnaryUfuncInfo('sigmoid',
ref=reference_sigmoid,
@@ -1491,6 +1514,27 @@ def reference_sigmoid(x):
dtypes=[torch.bfloat16]),
)
),
UnaryUfuncInfo('lgamma',
ref=reference_lgamma,
decorators=(precisionOverride({torch.float16: 7e-1}),),
dtypes=all_types_and(torch.bool),
dtypesIfCPU=all_types_and(torch.bool, torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.bool, torch.half),
skips=(
# Reference: https://github.com/pytorch/pytorch/pull/50140#discussion_r552615345
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
dtypes=[torch.bfloat16]),
Collaborator Author:
Skipping bfloat16:

>>> import torch
>>> torch.tensor(500.0, dtype=torch.bfloat16)
tensor(500., dtype=torch.bfloat16)
>>> t = torch.tensor(500.0, dtype=torch.bfloat16)
>>> torch.lgamma(t)
tensor(2608., dtype=torch.bfloat16)
>>> torch.lgamma(t.to(torch.float32))
tensor(2605.1160)

Collaborator:
Good skip. This again highlights that we should break up test_reference_numerics into different value ranges (small values, large values, and extremal values, for example).

Collaborator Author:
It would be helpful if ranges could be specified per dtype or per group of dtypes (maybe something similar to precision-override), because we don't want to shrink the range for all dtypes just to support one dtype with a smaller range.

# Reference: https://github.com/pytorch/pytorch/pull/50140#issuecomment-756150214
SkipInfo('TestUnaryUfuncs', 'test_reference_numerics',
dtypes=[torch.float32, torch.float64], active_if=IS_WINDOWS),
# Backward of `lgamma` uses `digamma` but `digamma`
Collaborator:
This is interesting. From the perspective of this PR this skip is correct, but I thought test_variant_consistency_jit wrapped the backward call in a try/except block to catch these errors?

Collaborator Author:
I have verified; I don't see any try/except block in the failing test.

Stack Trace
======================================================================
ERROR: test_variant_consistency_jit_lgamma_cpu_bfloat16 (__main__.TestCommonCPU)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/home/kshiteej/.conda/envs/PytorchENV/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 284, in instantiated_test
    raise rte
  File "/home/kshiteej/.conda/envs/PytorchENV/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 279, in instantiated_test
    result = test_fn(self, *args)
  File "/home/kshiteej/.conda/envs/PytorchENV/lib/python3.7/site-packages/torch/testing/_internal/common_device_type.py", line 253, in test_wrapper
    return test(*args, **kwargs)
  File "test/test_ops.py", line 290, in test_variant_consistency_jit
    no_grad=(dtype not in dtypes_to_grad_check))
  File "/home/kshiteej/.conda/envs/PytorchENV/lib/python3.7/site-packages/torch/testing/_internal/common_jit.py", line 77, in check_against_reference
    allow_unused=allow_unused)
  File "/home/kshiteej/.conda/envs/PytorchENV/lib/python3.7/site-packages/torch/autograd/__init__.py", line 225, in grad
    inputs, allow_unused, accumulate_grad=False)
RuntimeError: "digamma" not implemented for 'BFloat16'

----------------------------------------------------------------------
Ran 94 tests in 5.282s

FAILED (errors=1, skipped=8)

The code fails in check_against_reference:

pytorch/test/test_ops.py, lines 272 to 300 in 5546a12:
with disable_autodiff_subgraph_inlining():
    def fn(*inputs, **kwargs):
        output = func(*inputs, **kwargs)
        return op.output_func(output)

    # bfloat16 grad doesn't work for some operators
    dtypes_to_grad_check = floating_and_complex_types_and(torch.half) \
        if op.skip_bfloat16_grad else floating_and_complex_types_and(torch.half, torch.bfloat16)

    # Check scripted forward, grad, and grad grad
    script_fn = create_script_fn(self, name, func_type, op.output_func)
    check_against_reference(self,
                            script_fn,
                            fn,
                            (*sample.input,) + sample.args,
                            sample.kwargs,
                            no_grad=(dtype not in dtypes_to_grad_check))

    # Check traced forward, grad, and grad grad
    traced_fn = create_traced_fn(self, variant)
    check_against_reference(self,
                            traced_fn,
                            fn,
                            (*sample.input,) + sample.args,
                            sample.kwargs,
                            no_grad=(dtype not in dtypes_to_grad_check))

    # Check alias annotation schema for correctness (make

The code for check_against_reference also doesn't have any try/except:

def check_against_reference(self, func, reference_func, args, kwargs=None,
                            allow_unused=True, check_types=True, no_grad=False):
    kwargs = kwargs if kwargs else {}

    def allSum(vs):
        if isinstance(vs, torch.Tensor):
            vs = (vs,)
        return sum((i + 1) * v.sum()
                   for i, v in enumerate(vs)
                   if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))

    def clone_inputs(requires_grad):
        inputs = [
            arg.detach().clone().requires_grad_(requires_grad and arg.requires_grad)
            if isinstance(arg, torch.Tensor) else arg for arg in args
        ]
        return inputs, [input for input in inputs if isinstance(input, torch.Tensor) and input.requires_grad]

    nograd_inputs, nograd_tensors = clone_inputs(False)
    recording_inputs, recording_tensors = clone_inputs(True)

    # test no gradients case
    outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
    with enable_profiling_mode_for_profiling_tests():
        outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
    self.assertEqual(outputs, outputs_test)

    if check_types:
        check_output_types(self, func, outputs_test, nograd_inputs, kwargs)

    if no_grad:
        # skip grad tests
        return

    with enable_profiling_mode_for_profiling_tests():
        # test single grad case
        outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
        grads = torch.autograd.grad(allSum(outputs), recording_tensors,
                                    allow_unused=allow_unused)
        outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
        grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
                                         allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)

        # test the grad grad case
        if self._testMethodName in nn_functional_single_grad:
            return

        outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
        l1 = allSum(outputs)
        grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
                                    allow_unused=allow_unused)
        l2 = (allSum(grads) * l1)
        grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)

        recording_inputs, recording_tensors = clone_inputs(True)

        outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
        l1_test = allSum(outputs_test)
        grads_test = torch.autograd.grad(
            l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
        l2_test = (allSum(grads_test) * l1_test)
        grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)

        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)
        for g2, g2_test in zip(grads2, grads2_test):
            if g2 is None and g2_test is None:
                continue
            self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4))

Note: the test_variant_consistency_eager test has a try/except block which checks whether eager mode raises or not.

pytorch/test/test_ops.py, lines 183 to 237 in 5546a12:

# Tests that the forward and backward passes of operations produce the
# same values for the cross-product of op variants (method, inplace)
# against eager's gold standard op function variant
@ops(op_db)
def test_variant_consistency_eager(self, device, dtype, op):
    samples = op.sample_inputs(device, dtype, requires_grad=True)
    if len(samples) == 0:
        self.skipTest("Skipped! No sample inputs!")

    for sample in samples:
        # Acquires variants to test
        method = op.get_method()
        inplace = op.get_inplace()
        variants = (v for v in (method, inplace) if v is not None)

        # Computes expected forward
        # below calls op's function variant
        expected_forward = op(*sample.input, *sample.args, **sample.kwargs)

        # Computes expected backward
        # NOTE: backward may fail for some dtypes
        exception_during_backwards = False
        expected_grad = None
        try:
            expected_forward.sum().backward()
            expected_grad = sample.input.grad
            sample.input.grad = None
        except Exception as e:
            exception_during_backwards = True

        # Test eager consistency
        for variant in variants:
            # Verifies that inplace operations that promote int->float fail
            # on tensors with integer dtypes.
            if (variant is inplace and op.promotes_integers_to_float and
                    dtype in (torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)):
                try:
                    variant_forward = variant(*(clone_input_helper(input) for input in sample.input),
                                              *sample.args,
                                              **sample.kwargs)
                except Exception as e:
                    continue
                self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")

            # Compares variant's forward
            # Note: copy the tensor-type inputs when testing inplace operation
            variant_forward = variant(*(clone_input_helper(input) if variant is inplace else input
                                        for input in sample.input),
                                      *sample.args,
                                      **sample.kwargs)
            self.assertEqual(variant_forward, expected_forward)

            # Compares variant's backward
            if variant is not inplace or op.test_inplace_grad:
                self.check_variant_backward(sample.input, variant_forward,
                                            expected_grad, exception_during_backwards)

Collaborator:
Aha! That's why I was confused. Thank you for verifying. I'll update the test suite tracker.

# is not implemented for `BFloat16`
# Error Raised:
# RuntimeError: "digamma" not implemented for 'BFloat16'
SkipInfo('TestCommon', 'test_variant_consistency_jit',
dtypes=[torch.bfloat16]),
),
promotes_integers_to_float=True),
Collaborator:
This PR has to be rebased because this will cause a logical merge conflict with @peterbell10's 0436ea1. I think it just needs to be changed to safe_casts_outputs.

Collaborator Author:
Thanks! Updated.
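
Circling back to the backward skip above: the gradient of lgamma is digamma (d/dx log Γ(x) = ψ(x)), so autograd for lgamma needs a digamma kernel in the same dtype, which is why bfloat16 hits the "digamma" not implemented for 'BFloat16' error shown in the stack trace. A quick float32 illustration of that relationship:

import torch

x = torch.tensor([0.5, 1.5, 2.5], requires_grad=True)
loss = torch.lgamma(x).sum()
grad, = torch.autograd.grad(loss, x)
print(torch.allclose(grad, torch.digamma(x)))  # True: lgamma's gradient is digamma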

OpInfo('xlogy',
dtypes=all_types_and(torch.bool),
dtypesIfCPU=all_types_and(torch.bool, torch.half, torch.bfloat16),