Skip to content

Commit

Permalink
asarray: take the default device into consideration.
Browse files Browse the repository at this point in the history
Fix: #106773

This PR makes it so `asarray` takes the default device into consideration when called with
a Python sequence as the data.

ghstack-source-id: 0552a22c3828d69fcc3ce93c5fe44d6326f7de50
Pull Request resolved: #106779
  • Loading branch information
ysiraichi committed Aug 9, 2023
1 parent 03c9321 commit 421ee0d
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 2 deletions.
22 changes: 22 additions & 0 deletions test/test_tensor_creation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import unittest
from itertools import product, combinations, combinations_with_replacement, permutations
import random
from typing import Any, Dict, List, Tuple

from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
Expand Down Expand Up @@ -3966,6 +3967,27 @@ def test_numpy_scalars(self, device):
self.assertEqual(tensor.item(), zerodim_arr.item())
self.assertEqual(tensor.dtype, torch.int32)

def test_default_device(self, device):
original = torch.arange(5)

examples: List[Tuple[Any, Dict]] = [
(3, {}),
(original, {}),
(to_numpy(original), {}),
(to_memview(original), {"dtype": original.dtype}),
]

for data, kwargs in examples:
with torch.device(device):
tensor = torch.asarray(data, **kwargs)
self.assertEqual(tensor.device, torch.device(device))

# Check the contents of the tensor.
if isinstance(data, int):
self.assertEqual(data, tensor.item())
else:
self.assertEqual(data, tensor)


instantiate_device_type_tests(TestTensorCreation, globals())
instantiate_device_type_tests(TestRandomTensorCreation, globals())
Expand Down
5 changes: 3 additions & 2 deletions torch/_torch_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1257,7 +1257,7 @@ def merge_dicts(*dicts):
When :attr:`obj` is none of the above but a scalar, or a sequence of scalars then the
returned tensor will, by default, infer its datatype from the scalar values, be on the
CPU device, and not share its memory.
current default device, and not share its memory.
.. seealso::
Expand All @@ -1282,7 +1282,8 @@ def merge_dicts(*dicts):
If ``False`` then the returned tensor shares its memory with :attr:`obj` and an
error is thrown if it cannot.
device (:class:`torch.device`, optional): the device of the returned tensor.
Default: ``None``, which causes the device of :attr:`obj` to be used.
Default: ``None``, which causes the device of :attr:`obj` to be used. Or, if
:attr:`obj` is a Python sequence, the current default device will be used.
requires_grad (bool, optional): whether the returned tensor requires grad.
Default: ``False``, which causes the returned tensor not to require a gradient.
If ``True``, then the returned tensor will require a gradient, and if :attr:`obj`
Expand Down
5 changes: 5 additions & 0 deletions torch/csrc/autograd/python_torch_functions_manual.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,11 @@ static PyObject* THPVariable_asarray(
ParsedArgs<5> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);

if (r.has_torch_function()) {
return handle_torch_function(
r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
}

if (r.idx == 0) {
auto obj = r.pyobject(0);
auto dtype = r.scalartypeOptional(1);
Expand Down
3 changes: 3 additions & 0 deletions torch/csrc/utils/tensor_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1633,6 +1633,9 @@ Tensor asarray(
bool force_alias = !copy.value_or(true);
bool should_warn_numpy_not_writable = false;

// Used when:
// 1. 'obj' implements the buffer protocol and no type is given.
// 2. creating a new tensor from a Python sequence.
auto dtype_unwrapped =
dtype.value_or(torch::tensors::get_default_scalar_type());

Expand Down
1 change: 1 addition & 0 deletions torch/utils/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def _device_constructors():
torch.tensor,
torch.as_tensor,
torch.scalar_tensor,
torch.asarray,
}

# NB: This is directly called from C++ in torch/csrc/Device.cpp
Expand Down

0 comments on commit 421ee0d

Please sign in to comment.