asarray: take the default device into consideration. #106779

Closed
wants to merge 7 commits
23 changes: 23 additions & 0 deletions test/test_tensor_creation_ops.py
@@ -3966,6 +3966,29 @@ def test_numpy_scalars(self, device):
        self.assertEqual(tensor.item(), zerodim_arr.item())
        self.assertEqual(tensor.dtype, torch.int32)

    def test_default_device(self, device):
        import array

        examples = [
            3,
            np.arange(5),
            torch.arange(5),
            array.array("f", [1, 2, 3, 4]),

A Collaborator commented:

very minor: "f" means this constructs a floating-point array, which is tested for exact equality below; using "i" for integers may be a little more robust (however, the tests seem happy, so feel free to ignore this comment).

@ysiraichi (Collaborator, Author) replied on Aug 9, 2023:

Ah, you are right. I used "f" here because I could just call asarray without any arguments (otherwise I would have had to pass the dtype). But I guess that's more robust. Will change.

(A short illustration of the float32 point appears after this hunk.)

        ]

        for data in examples:
            with torch.device(device):
                tensor = torch.asarray(data)
                self.assertEqual(tensor.device, torch.device(device))

                # Check the contents of the tensor.
                if isinstance(data, int):
                    self.assertEqual(data, tensor.item())
                else:
                    self.assertEqual(len(data), len(tensor))
                    for i in range(len(data)):
                        self.assertEqual(data[i], tensor[i])


instantiate_device_type_tests(TestTensorCreation, globals())
instantiate_device_type_tests(TestRandomTensorCreation, globals())
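
As a quick aside illustrating the reviewer's point in the thread above (not part of the diff): array.array("f") stores 32-bit C floats, which cannot represent every integer exactly, while "i" stores exact C ints. The values 1 through 4 used in the test are exact in float32, which is why the equality checks pass either way.

    import array

    a = array.array("f", [16777217])   # "f" = 32-bit C float
    print(a[0])                        # 16777216.0 -- 2**24 + 1 is not representable in float32

    b = array.array("i", [16777217])   # "i" = C int, stored exactly
    print(b[0])                        # 16777217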
5 changes: 3 additions & 2 deletions torch/_torch_docs.py
@@ -1257,7 +1257,7 @@ def merge_dicts(*dicts):

When :attr:`obj` is none of the above but a scalar, or a sequence of scalars then the
returned tensor will, by default, infer its datatype from the scalar values, be on the
CPU device, and not share its memory.
current default device, and not share its memory.

.. seealso::

@@ -1282,7 +1282,8 @@ def merge_dicts(*dicts):
        If ``False`` then the returned tensor shares its memory with :attr:`obj` and an
        error is thrown if it cannot.
    device (:class:`torch.device`, optional): the device of the returned tensor.
        Default: ``None``, which causes the device of :attr:`obj` to be used.
        Default: ``None``, which causes the device of :attr:`obj` to be used. Or, if
        :attr:`obj` is a Python sequence, the current default device will be used.
    requires_grad (bool, optional): whether the returned tensor requires grad.
        Default: ``False``, which causes the returned tensor not to require a gradient.
        If ``True``, then the returned tensor will require a gradient, and if :attr:`obj`
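
A minimal sketch of the documented defaults for device=None (illustrative only, assuming a stock CPU build; the memory-sharing behavior shown for the ndarray case is from the existing documentation above):

    import numpy as np
    import torch

    arr = np.arange(3)

    t = torch.asarray(arr)           # ndarray: keeps obj's device (CPU) and shares memory
    arr[0] = 10
    print(t[0])                      # tensor(10) -- same underlying buffer

    t2 = torch.asarray([1, 2, 3])    # Python sequence: placed on the current default device
    print(t2.device)                 # cpu, unless the default device has been changed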
5 changes: 5 additions & 0 deletions torch/csrc/autograd/python_torch_functions_manual.cpp
@@ -333,6 +333,11 @@ static PyObject* THPVariable_asarray(
  ParsedArgs<5> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (r.has_torch_function()) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPVariableFunctionsModule, "torch");
  }

  if (r.idx == 0) {
    auto obj = r.pyobject(0);
    auto dtype = r.scalartypeOptional(1);
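
The handle_torch_function hook above is what lets __torch_function__ modes, including the torch.device context manager, intercept torch.asarray. A small hedged sketch of that dispatch from Python; LoggingMode is a made-up name used only for illustration:

    import torch
    from torch.overrides import TorchFunctionMode

    class LoggingMode(TorchFunctionMode):
        def __torch_function__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            print("intercepted:", func.__name__)
            return func(*args, **kwargs)

    with LoggingMode():
        torch.asarray([1, 2, 3])   # with this change, prints "intercepted: asarray"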
3 changes: 3 additions & 0 deletions torch/csrc/utils/tensor_new.cpp
@@ -1633,6 +1633,9 @@ Tensor asarray(
  bool force_alias = !copy.value_or(true);
  bool should_warn_numpy_not_writable = false;

  // Used when:
  // 1. 'obj' implements the buffer protocol and no type is given.
  // 2. creating a new tensor from a Python sequence.
  auto dtype_unwrapped =
      dtype.value_or(torch::tensors::get_default_scalar_type());
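
For the Python-sequence case (item 2 in the comment above), the effect is visible from Python: with no dtype given, float values follow the current default dtype, while integer sequences still infer int64. A small illustration (assumes stock defaults):

    import torch

    torch.set_default_dtype(torch.float64)
    print(torch.asarray([1.0, 2.0]).dtype)   # torch.float64
    torch.set_default_dtype(torch.float32)   # restore the usual default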

1 change: 1 addition & 0 deletions torch/utils/_device.py
@@ -51,6 +51,7 @@ def _device_constructors():
        torch.tensor,
        torch.as_tensor,
        torch.scalar_tensor,
        torch.asarray,
    }

# NB: This is directly called from C++ in torch/csrc/Device.cpp
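
With torch.asarray registered in _device_constructors(), the TorchFunctionMode defined in this file (DeviceContext), which backs the torch.device context manager and torch.set_default_device, now fills in the device argument for asarray as well. That is what the new test above exercises. A minimal sketch; the "meta" device is used here only because it is always available:

    import torch

    with torch.device("meta"):
        t = torch.asarray([1.0, 2.0, 3.0])

    print(t.device)   # meta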