BUG: fix torch._numpy.arange(5, dtype="float32") (#110005)
Make `np.arange` respect an explicitly provided dtype.

Also remove a duplicated test class:
torch_np/test_function_base.py::TestArange is a dupe of
torch_np/numpy_tests/core/test_multiarray.py::TestArange.

Fixes #109975

Pull Request resolved: #110005
Approved by: https://github.com/lezcano
ev-br authored and pytorchmergebot committed Sep 28, 2023
1 parent 5f7eff0 commit 3603f64
Showing 4 changed files with 36 additions and 95 deletions.
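For context, a minimal sketch of the behavior being fixed, assuming the failure mode reported in #109975 (an explicitly requested dtype was folded into type promotion with start/stop/step, so the promoted dtype won):

    import torch._numpy as np

    # Before: dtype="float32" was promoted together with the integer
    # start/stop/step, yielding float64. After this commit, the request wins:
    assert np.arange(5, dtype="float32").dtype == np.float32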
8 changes: 8 additions & 0 deletions test/torch_np/numpy_tests/core/test_multiarray.py

@@ -6822,7 +6822,11 @@ def test_arange_booleans(self):
     def test_error_paths_and_promotion(self, which):
         args = [0, 1, 2]  # start, stop, and step
         args[which] = np.float64(2.0)  # should ensure float64 output
         assert np.arange(*args).dtype == np.float64
 
+        # repeat with non-empty ranges
+        args = [0, 8, 2]
+        args[which] = np.float64(2.0)
+        assert np.arange(*args).dtype == np.float64
 
         # Cover stranger error path, test only to achieve code coverage!
@@ -6831,6 +6835,10 @@ def test_error_paths_and_promotion(self, which):
             # Fails discovering start dtype
             np.arange(*args)
 
+    @parametrize("dt", [np.float32, np.uint8, complex])
+    def test_explicit_dtype(self, dt):
+        assert np.arange(5.0, dtype=dt).dtype == dt
+
 
 class TestRichcompareScalar(TestCase):
     @xfail  # (reason="comparison: builtin.bools or...?")
84 changes: 2 additions & 82 deletions test/torch_np/test_function_base.py

@@ -1,92 +1,12 @@
 # Owner(s): ["module: dynamo"]
 
-from unittest import expectedFailure as xfail
-
-import pytest
-
 import torch._numpy as np
-from pytest import raises as assert_raises
-from torch._numpy.testing import assert_array_equal, assert_equal
+from torch._numpy.testing import assert_equal
 
-from torch.testing._internal.common_utils import (
-    instantiate_parametrized_tests,
-    parametrize,
-    run_tests,
-    TestCase,
-)
-
-
-@instantiate_parametrized_tests
-class TestArange(TestCase):
-    def test_infinite(self):
-        assert_raises(
-            (RuntimeError, ValueError), np.arange, 0, np.inf
-        )  # "size exceeded",
-
-    def test_nan_step(self):
-        assert_raises(
-            (RuntimeError, ValueError), np.arange, 0, 1, np.nan
-        )  # "cannot compute length",
-
-    def test_zero_step(self):
-        assert_raises(ZeroDivisionError, np.arange, 0, 10, 0)
-        assert_raises(ZeroDivisionError, np.arange, 0.0, 10.0, 0.0)
-
-        # empty range
-        assert_raises(ZeroDivisionError, np.arange, 0, 0, 0)
-        assert_raises(ZeroDivisionError, np.arange, 0.0, 0.0, 0.0)
-
-    def test_require_range(self):
-        assert_raises(TypeError, np.arange)
-        assert_raises(TypeError, np.arange, step=3)
-        assert_raises(TypeError, np.arange, dtype="int64")
-
-    @xfail  # (reason="XXX: arange(start=0, stop, step=1)")
-    def test_require_range_2(self):
-        assert_raises(TypeError, np.arange, start=4)
-
-    def test_start_stop_kwarg(self):
-        keyword_stop = np.arange(stop=3)
-        keyword_zerotostop = np.arange(start=0, stop=3)
-        keyword_start_stop = np.arange(start=3, stop=9)
-
-        assert len(keyword_stop) == 3
-        assert len(keyword_zerotostop) == 3
-        assert len(keyword_start_stop) == 6
-        assert_array_equal(keyword_stop, keyword_zerotostop)
-
-    @xfail  # (reason="XXX: arange(..., dtype=bool)")
-    def test_arange_booleans(self):
-        # Arange makes some sense for booleans and works up to length 2.
-        # But it is weird since `arange(2, 4, dtype=bool)` works.
-        # Arguably, much or all of this could be deprecated/removed.
-        res = np.arange(False, dtype=bool)
-        assert_array_equal(res, np.array([], dtype="bool"))
-
-        res = np.arange(True, dtype="bool")
-        assert_array_equal(res, [False])
-
-        res = np.arange(2, dtype="bool")
-        assert_array_equal(res, [False, True])
-
-        # This case is especially weird, but drops out without special case:
-        res = np.arange(6, 8, dtype="bool")
-        assert_array_equal(res, [True, True])
-
-        with pytest.raises(TypeError):
-            np.arange(3, dtype="bool")
-
-    @parametrize("which", [0, 1, 2])
-    def test_error_paths_and_promotion(self, which):
-        args = [0, 10, 2]  # start, stop, and step
-        args[which] = np.float64(2.0)  # should ensure float64 output
-        assert np.arange(*args).dtype == np.float64
-
-        # Cover stranger error path, test only to achieve code coverage!
-        args[which] = [None, []]
-        with pytest.raises((ValueError, RuntimeError)):
-            # Fails discovering start dtype
-            np.arange(*args)
+from torch.testing._internal.common_utils import run_tests, TestCase
 
 
 class TestAppend(TestCase):
12 changes: 12 additions & 0 deletions torch/_numpy/_dtypes_impl.py

@@ -126,6 +126,18 @@ def _dtype_for_scalar(py_type):
     }[py_type]
 
 
+def _dtype_for_scalar_or_tensor(x):
+    return x.dtype if isinstance(x, torch.Tensor) else _dtype_for_scalar(type(x))
+
+
+def is_float_or_fp_tensor(x):
+    return _dtype_for_scalar_or_tensor(x).is_floating_point
+
+
+def is_complex_or_complex_tensor(x):
+    return _dtype_for_scalar_or_tensor(x).is_complex
+
+
 def _category(dtype):
     return {
         torch.bool: 0,
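The new helpers accept either a Python scalar or a torch.Tensor, so arange can inspect its now-uncoerced arguments uniformly. A quick illustrative sketch (not part of the commit):

    import torch
    from torch._numpy import _dtypes_impl

    assert _dtypes_impl.is_float_or_fp_tensor(0.5)                # Python float
    assert _dtypes_impl.is_float_or_fp_tensor(torch.tensor(0.5))  # floating-point tensor
    assert not _dtypes_impl.is_float_or_fp_tensor(3)              # Python int
    assert _dtypes_impl.is_complex_or_complex_tensor(1j)          # Python complex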
27 changes: 14 additions & 13 deletions torch/_numpy/_funcs_impl.py

@@ -339,9 +339,9 @@ def logspace(
 
 
 def arange(
-    start: Optional[ArrayLike] = None,
-    stop: Optional[ArrayLike] = None,
-    step: Optional[ArrayLike] = 1,
+    start: Optional[ArrayLikeOrScalar] = None,
+    stop: Optional[ArrayLikeOrScalar] = None,
+    step: Optional[ArrayLikeOrScalar] = 1,
     dtype: Optional[DTypeLike] = None,
     *,
     like: NotImplementedType = None,
@@ -359,22 +359,23 @@ def arange(
 
     # the dtype of the result
     if dtype is None:
-        dtype = _dtypes_impl.default_dtypes().int_dtype
-    # XXX: default values do not get normalized
-    start, stop, step = (_util._coerce_to_tensor(x) for x in (start, stop, step))
-
-    dummy = torch.empty(1, dtype=dtype)
-    target_dtype = _dtypes_impl.result_type_impl(start, stop, step, dummy)
+        dtype = (
+            _dtypes_impl.default_dtypes().float_dtype
+            if any(_dtypes_impl.is_float_or_fp_tensor(x) for x in (start, stop, step))
+            else _dtypes_impl.default_dtypes().int_dtype
+        )
+    work_dtype = torch.float64 if dtype.is_complex else dtype
 
-    # work around RuntimeError: "arange_cpu" not implemented for 'ComplexFloat'
-    work_dtype = torch.float64 if target_dtype.is_complex else target_dtype
+    # RuntimeError: "lt_cpu" not implemented for 'ComplexFloat'. Fall back to eager.
+    if any(_dtypes_impl.is_complex_or_complex_tensor(x) for x in (start, stop, step)):
+        raise NotImplementedError
 
     if (step > 0 and start > stop) or (step < 0 and start < stop):
        # empty range
-        return torch.empty(0, dtype=target_dtype)
+        return torch.empty(0, dtype=dtype)
 
     result = torch.arange(start, stop, step, dtype=work_dtype)
-    result = _util.cast_if_needed(result, target_dtype)
+    result = _util.cast_if_needed(result, dtype)
     return result
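End to end, the rewritten arange should now behave as below (a sketch assuming the stock default dtypes, torch.int64 and torch.float64, are in effect):

    import torch._numpy as np

    assert np.arange(5).dtype == np.int64            # all-integer inputs: default int dtype
    assert np.arange(0, 5, 0.5).dtype == np.float64  # any float among start/stop/step: default float
    assert np.arange(5, dtype="float32").dtype == np.float32  # explicit dtype is honored

    # Complex inputs hit comparisons torch does not implement; arange now
    # raises NotImplementedError so dynamo can fall back to eager.
    try:
        np.arange(1j)
    except NotImplementedError:
        pass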
