Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
Summary: Fixes #83069. Also move all the dlpack tests to a new file., `test_dlpack.py`. The fix involves always allocating a "strides" int array when converting to dlPack and deleting the strides when the capsule descructor is called. Then the strides are copied from the tensor, and `strides[i]` is set to `1` where `shape[i] < 2`. Pull Request resolved: #83158 Approved by: https://github.com/ezyang Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/4dfa6d28a139e8325fe9b255af86e8d1360ae7ee Reviewed By: weiwangmeta Differential Revision: D38947365 fbshipit-source-id: 3b621efe3fc8f1798f2d285cfa818548d3bfca7e
- Loading branch information
1 parent
3ae8459
commit ec42ee0
Showing
3 changed files
with
208 additions
and
168 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
# -*- coding: utf-8 -*- | ||
# Owner(s): ["module: tests"] | ||
|
||
import torch | ||
from torch.testing import make_tensor | ||
from torch.testing._internal.common_utils import TestCase, run_tests | ||
from torch.testing._internal.common_device_type import ( | ||
instantiate_device_type_tests, onlyCUDA, dtypes, skipMeta, | ||
onlyNativeDeviceTypes) | ||
from torch.testing._internal.common_dtype import all_types_and_complex_and | ||
from torch.utils.dlpack import from_dlpack, to_dlpack | ||
|
||
|
||
class TestTorchDlPack(TestCase): | ||
exact_dtype = True | ||
|
||
@skipMeta | ||
@onlyNativeDeviceTypes | ||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) | ||
def test_dlpack_capsule_conversion(self, device, dtype): | ||
# DLpack does not explicitly support bool (xref dmlc/dlpack#75) | ||
x = make_tensor((5,), dtype=dtype, device=device) | ||
z = from_dlpack(to_dlpack(x)) | ||
self.assertEqual(z, x) | ||
|
||
@skipMeta | ||
@onlyNativeDeviceTypes | ||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) | ||
def test_dlpack_protocol_conversion(self, device, dtype): | ||
x = make_tensor((5,), dtype=dtype, device=device) | ||
z = from_dlpack(x) | ||
self.assertEqual(z, x) | ||
|
||
@skipMeta | ||
@onlyNativeDeviceTypes | ||
def test_dlpack_shared_storage(self, device): | ||
x = make_tensor((5,), dtype=torch.float64, device=device) | ||
z = from_dlpack(to_dlpack(x)) | ||
z[0] = z[0] + 20.0 | ||
self.assertEqual(z, x) | ||
|
||
@skipMeta | ||
@onlyCUDA | ||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) | ||
def test_dlpack_conversion_with_streams(self, device, dtype): | ||
# Create a stream where the tensor will reside | ||
stream = torch.cuda.Stream() | ||
with torch.cuda.stream(stream): | ||
# Do an operation in the actual stream | ||
x = make_tensor((5,), dtype=dtype, device=device) + 1 | ||
# DLPack protocol helps establish a correct stream order | ||
# (hence data dependency) at the exchange boundary. | ||
# DLPack manages this synchronization for us, so we don't need to | ||
# explicitly wait until x is populated | ||
stream = torch.cuda.Stream() | ||
with torch.cuda.stream(stream): | ||
z = from_dlpack(x) | ||
stream.synchronize() | ||
self.assertEqual(z, x) | ||
|
||
@skipMeta | ||
@onlyNativeDeviceTypes | ||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) | ||
def test_from_dlpack(self, device, dtype): | ||
x = make_tensor((5,), dtype=dtype, device=device) | ||
y = torch.from_dlpack(x) | ||
self.assertEqual(x, y) | ||
|
||
@skipMeta | ||
@onlyNativeDeviceTypes | ||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) | ||
def test_from_dlpack_noncontinguous(self, device, dtype): | ||
x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5) | ||
|
||
y1 = x[0] | ||
y1_dl = torch.from_dlpack(y1) | ||
self.assertEqual(y1, y1_dl) | ||
|
||
y2 = x[:, 0] | ||
y2_dl = torch.from_dlpack(y2) | ||
self.assertEqual(y2, y2_dl) | ||
|
||
y3 = x[1, :] | ||
y3_dl = torch.from_dlpack(y3) | ||
self.assertEqual(y3, y3_dl) | ||
|
||
y4 = x[1] | ||
y4_dl = torch.from_dlpack(y4) | ||
self.assertEqual(y4, y4_dl) | ||
|
||
y5 = x.t() | ||
y5_dl = torch.from_dlpack(y5) | ||
self.assertEqual(y5, y5_dl) | ||
|
||
@skipMeta | ||
@onlyCUDA | ||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) | ||
def test_dlpack_conversion_with_diff_streams(self, device, dtype): | ||
stream_a = torch.cuda.Stream() | ||
stream_b = torch.cuda.Stream() | ||
# DLPack protocol helps establish a correct stream order | ||
# (hence data dependency) at the exchange boundary. | ||
# the `tensor.__dlpack__` method will insert a synchronization event | ||
# in the current stream to make sure that it was correctly populated. | ||
with torch.cuda.stream(stream_a): | ||
x = make_tensor((5,), dtype=dtype, device=device) + 1 | ||
z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream)) | ||
stream_a.synchronize() | ||
stream_b.synchronize() | ||
self.assertEqual(z, x) | ||
|
||
@skipMeta | ||
@onlyNativeDeviceTypes | ||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) | ||
def test_from_dlpack_dtype(self, device, dtype): | ||
x = make_tensor((5,), dtype=dtype, device=device) | ||
y = torch.from_dlpack(x) | ||
assert x.dtype == y.dtype | ||
|
||
@skipMeta | ||
@onlyCUDA | ||
def test_dlpack_default_stream(self, device): | ||
class DLPackTensor: | ||
def __init__(self, tensor): | ||
self.tensor = tensor | ||
|
||
def __dlpack_device__(self): | ||
return self.tensor.__dlpack_device__() | ||
|
||
def __dlpack__(self, stream=None): | ||
if torch.version.hip is None: | ||
assert stream == 1 | ||
else: | ||
assert stream == 0 | ||
capsule = self.tensor.__dlpack__(stream) | ||
return capsule | ||
|
||
# CUDA-based tests runs on non-default streams | ||
with torch.cuda.stream(torch.cuda.default_stream()): | ||
x = DLPackTensor(make_tensor((5,), dtype=torch.float32, device=device)) | ||
from_dlpack(x) | ||
|
||
@skipMeta | ||
@onlyNativeDeviceTypes | ||
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16)) | ||
def test_dlpack_tensor_invalid_stream(self, device, dtype): | ||
with self.assertRaises(TypeError): | ||
x = make_tensor((5,), dtype=dtype, device=device) | ||
x.__dlpack__(stream=object()) | ||
|
||
@skipMeta | ||
def test_dlpack_error_on_bool_tensor(self): | ||
x = torch.tensor([True], dtype=torch.bool) | ||
with self.assertRaises(RuntimeError): | ||
to_dlpack(x) | ||
|
||
# TODO: add interchange tests once NumPy 1.22 (dlpack support) is required | ||
@skipMeta | ||
def test_dlpack_export_requires_grad(self): | ||
x = torch.zeros(10, dtype=torch.float32, requires_grad=True) | ||
with self.assertRaisesRegex(RuntimeError, r"require gradient"): | ||
x.__dlpack__() | ||
|
||
@skipMeta | ||
def test_dlpack_export_is_conj(self): | ||
x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j]) | ||
y = torch.conj(x) | ||
with self.assertRaisesRegex(RuntimeError, r"conjugate bit"): | ||
y.__dlpack__() | ||
|
||
@skipMeta | ||
def test_dlpack_export_non_strided(self): | ||
x = torch.sparse_coo_tensor([[0]], [1], size=(1,)) | ||
y = torch.conj(x) | ||
with self.assertRaisesRegex(RuntimeError, r"strided"): | ||
y.__dlpack__() | ||
|
||
@skipMeta | ||
def test_dlpack_normalize_strides(self): | ||
x = torch.rand(16) | ||
y = x[::3][:1] | ||
self.assertEqual(y.shape, (1,)) | ||
self.assertEqual(y.stride(), (3,)) | ||
z = from_dlpack(y) | ||
self.assertEqual(z.shape, (1,)) | ||
# gh-83069, make sure __dlpack__ normalizes strides | ||
self.assertEqual(z.stride(), (1,)) | ||
|
||
|
||
instantiate_device_type_tests(TestTorchDlPack, globals()) | ||
|
||
if __name__ == '__main__': | ||
run_tests() |
Oops, something went wrong.