Skip to content

Commit

Permalink
Clarify, make consistent, and test the behavior of logspace when dtyp…
Browse files Browse the repository at this point in the history
…e is integral (#47647)

Summary:
torch.logspace doesn't seem to have explained how integers are handled.
Add some clarification and some test when dtype is integral.

The CUDA implementation is also updated to be consistent with CPU implementation.

Pull Request resolved: #47647

Reviewed By: gchanan

Differential Revision: D25843351

Pulled By: walterddr

fbshipit-source-id: 45237574d04c56992c18766667ff1ed71be77ac3
  • Loading branch information
xuhdev authored and facebook-github-bot committed Jan 15, 2021
1 parent 8e74024 commit 0ae0fac
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 7 deletions.
5 changes: 3 additions & 2 deletions aten/src/ATen/native/cuda/RangeFactories.cu
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,11 @@ Tensor& logspace_cuda_out(Tensor& result, Scalar start, Scalar end, c10::optiona
r.fill_(std::pow(base, start.to<double>()));
} else if (isIntegralType(r.scalar_type(), 0)) {
AT_DISPATCH_INTEGRAL_TYPES(r.scalar_type(), "logspace_cuda", [&]() {
float scalar_base = static_cast<float>(base); // Use float to avoid promotion to double
// We use double here to be consistent with CPU implementation
double scalar_base = static_cast<double>(base);
scalar_t scalar_start = start.to<scalar_t>();
scalar_t scalar_end = end.to<scalar_t>();
float step = static_cast<float>(scalar_end - scalar_start) / (steps - 1);
double step = static_cast<double>(scalar_end - scalar_start) / (steps - 1);
const int64_t halfway = steps / 2;
gpu_kernel_with_index(r, [scalar_start, scalar_end, scalar_base, steps, step, halfway]GPU_LAMBDA(int64_t ind) -> scalar_t {
if (ind < halfway) {
Expand Down
17 changes: 17 additions & 0 deletions test/test_tensor_creation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2808,6 +2808,23 @@ def test_logspace(self, device, dtype):
y = torch.logspace(0, 3, 4, base=2, device=device, dtype=dtype, out=x.narrow(1, 1, 2))
self.assertEqual(x, torch.tensor(((0, 1, 2), (0, 4, 8)), device=device, dtype=dtype), atol=0, rtol=0)

@dtypes(torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
def test_logspace_integral(self, device, dtype):
"Check logspace with integer."
for from_, to in ((1, 5),
(1.2, 2),
(1.7, 4),
(2, 2.5)):
res1 = torch.logspace(from_, to, steps=10, device=device, dtype=dtype)
res2 = torch.logspace(int(from_), int(to), steps=10,
device=device, dtype=torch.double).floor().type(dtype)
self.assertEqual(res1, res2)
if not device.startswith('cpu'):
# Compare with CPU output
res2_cpu = torch.logspace(int(from_), int(to), steps=10,
device='cpu', dtype=torch.double).floor().type(dtype)
self.assertEqual(res1, res2_cpu)

@onlyOnCPUAndCUDA
@dtypes(torch.half, torch.float, torch.double)
def test_full_inference(self, device, dtype):
Expand Down
17 changes: 12 additions & 5 deletions torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4665,10 +4665,10 @@ def merge_dicts(*dicts):
out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
""" + r"""
Creates a one-dimensional tensor of size :attr:`steps` whose values are evenly
Creates a one-dimensional tensor of size ``steps`` whose values are evenly
spaced from :math:`{{\text{{base}}}}^{{\text{{start}}}}` to
:math:`{{\text{{base}}}}^{{\text{{end}}}}`, inclusive, on a logarithmic scale
with base :attr:`base`. That is, the values are:
with base ``base``. That is, the values are:
.. math::
(\text{base}^{\text{start}},
Expand All @@ -4678,12 +4678,15 @@ def merge_dicts(*dicts):
\text{base}^{\text{end}})
""" + """
If ``dtype`` is an integral type, ``start`` and ``end`` are cast as integers first and the returned tensor is the floor
of the returned tensor as if ``dtype`` were a floating-point type.
.. warning::
Not providing a value for :attr:`steps` is deprecated. For backwards
compatibility, not providing a value for :attr:`steps` will create a tensor
Not providing a value for ``steps`` is deprecated. For backwards
compatibility, not providing a value for ``steps`` will create a tensor
with 100 elements. Note that this behavior is not reflected in the
documented function signature and should not be relied on. In a future
PyTorch release, failing to provide a value for :attr:`steps` will throw a
PyTorch release, failing to provide a value for ``steps`` will throw a
runtime error.
Args:
Expand All @@ -4705,6 +4708,10 @@ def merge_dicts(*dicts):
tensor([ 1.0000e-10, 1.0000e-05, 1.0000e+00, 1.0000e+05, 1.0000e+10])
>>> torch.logspace(start=0.1, end=1.0, steps=5)
tensor([ 1.2589, 2.1135, 3.5481, 5.9566, 10.0000])
>>> torch.logspace(start=0.0, end=1.0, steps=5)
tensor([ 1.0000, 1.7783, 3.1623, 5.6234, 10.0000])
>>> torch.logspace(start=0.0, end=1.0, steps=5, dtype=torch.int)
tensor([1, 1, 3, 5, 10])
>>> torch.logspace(start=0.1, end=1.0, steps=1)
tensor([1.2589])
>>> torch.logspace(start=2, end=2, steps=1, base=2)
Expand Down

0 comments on commit 0ae0fac

Please sign in to comment.