Normalize DLPack stride to 1 where shape < 2 (#83158)
Summary:
Fixes #83069. Also moves all the DLPack tests to a new file, `test_dlpack.py`.

The fix always allocates a "strides" int array when converting to DLPack and frees it when the capsule destructor is called. The strides are copied from the tensor, and `strides[i]` is set to `1` wherever `shape[i] < 2`.
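
As a quick illustration (not part of this commit; it simply mirrors the new `test_dlpack_normalize_strides` test below), the user-visible effect of the normalization is:

    import torch
    from torch.utils.dlpack import from_dlpack

    x = torch.rand(16)
    y = x[::3][:1]               # shape (1,), but stride (3,) inherited from the slice
    print(y.shape, y.stride())   # torch.Size([1]) (3,)

    # With the fix, any dimension with size < 2 is exported with stride 1,
    # so the round-tripped tensor reports normalized strides.
    z = from_dlpack(y)
    print(z.shape, z.stride())   # torch.Size([1]) (1,)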

Pull Request resolved: #83158
Approved by: https://github.com/ezyang

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/4dfa6d28a139e8325fe9b255af86e8d1360ae7ee

Reviewed By: weiwangmeta

Differential Revision: D38947365

fbshipit-source-id: 3b621efe3fc8f1798f2d285cfa818548d3bfca7e
mattip authored and facebook-github-bot committed Aug 24, 2022
1 parent 3ae8459 commit ec42ee0
Showing 3 changed files with 208 additions and 168 deletions.
19 changes: 15 additions & 4 deletions aten/src/ATen/DLConvertor.cpp
@@ -215,11 +215,22 @@ void deleter(DLManagedTensor* arg) {
 // This function returns a shared_ptr to memory managed DLpack tensor
 // constructed out of ATen tensor
 DLManagedTensor* toDLPack(const Tensor& src) {
+  // create a new tensor with possibly normalized strides
+  // gh-83069
+  auto shape = src.sizes();
+  auto strides = src.strides().vec();
+  for (int i=0; i<src.dim(); i++) {
+    if (shape[i] < 2) {
+      strides[i] = 1;
+    }
+  }
+
+  auto view = src.as_strided(shape, strides, src.storage_offset());
   ATenDLMTensor* atDLMTensor(new ATenDLMTensor);
-  atDLMTensor->handle = src;
+  atDLMTensor->handle = view;
   atDLMTensor->tensor.manager_ctx = atDLMTensor;
   atDLMTensor->tensor.deleter = &deleter;
-  atDLMTensor->tensor.dl_tensor.data = src.data_ptr();
+  atDLMTensor->tensor.dl_tensor.data = view.data_ptr();
   int64_t device_id = 0;
   if (src.is_cuda()) {
     device_id = src.get_device();
@@ -229,10 +240,10 @@ DLManagedTensor* toDLPack(const Tensor& src) {
   atDLMTensor->tensor.dl_tensor.dtype = getDLDataType(src);
   atDLMTensor->tensor.dl_tensor.shape =
       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-      const_cast<int64_t*>(src.sizes().data());
+      const_cast<int64_t*>(view.sizes().data());
   atDLMTensor->tensor.dl_tensor.strides =
       // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
-      const_cast<int64_t*>(src.strides().data());
+      const_cast<int64_t*>(view.strides().data());
   atDLMTensor->tensor.dl_tensor.byte_offset = 0;
   return &(atDLMTensor->tensor);
 }
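In Python terms, the change to `toDLPack` above amounts to exporting an `as_strided` view whose strides have been normalized; a minimal sketch of that logic (illustrative only; the helper name is hypothetical and the code is not part of the commit):

    import torch

    def normalized_dlpack_view(src: torch.Tensor) -> torch.Tensor:
        # Copy the strides and force stride 1 for every dimension whose
        # extent is smaller than 2 (gh-83069).
        shape = list(src.shape)
        strides = list(src.stride())
        for i in range(src.dim()):
            if shape[i] < 2:
                strides[i] = 1
        # The sizes/strides of this view are what the DLPack capsule describes.
        return src.as_strided(shape, strides, src.storage_offset())
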
193 changes: 193 additions & 0 deletions test/test_dlpack.py
@@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
# Owner(s): ["module: tests"]

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, onlyCUDA, dtypes, skipMeta,
    onlyNativeDeviceTypes)
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.utils.dlpack import from_dlpack, to_dlpack


class TestTorchDlPack(TestCase):
    exact_dtype = True

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_dlpack_capsule_conversion(self, device, dtype):
        # DLpack does not explicitly support bool (xref dmlc/dlpack#75)
        x = make_tensor((5,), dtype=dtype, device=device)
        z = from_dlpack(to_dlpack(x))
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_dlpack_protocol_conversion(self, device, dtype):
        x = make_tensor((5,), dtype=dtype, device=device)
        z = from_dlpack(x)
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    def test_dlpack_shared_storage(self, device):
        x = make_tensor((5,), dtype=torch.float64, device=device)
        z = from_dlpack(to_dlpack(x))
        z[0] = z[0] + 20.0
        self.assertEqual(z, x)

    @skipMeta
    @onlyCUDA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_dlpack_conversion_with_streams(self, device, dtype):
        # Create a stream where the tensor will reside
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            # Do an operation in the actual stream
            x = make_tensor((5,), dtype=dtype, device=device) + 1
        # DLPack protocol helps establish a correct stream order
        # (hence data dependency) at the exchange boundary.
        # DLPack manages this synchronization for us, so we don't need to
        # explicitly wait until x is populated
        stream = torch.cuda.Stream()
        with torch.cuda.stream(stream):
            z = from_dlpack(x)
        stream.synchronize()
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_from_dlpack(self, device, dtype):
        x = make_tensor((5,), dtype=dtype, device=device)
        y = torch.from_dlpack(x)
        self.assertEqual(x, y)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_from_dlpack_noncontinguous(self, device, dtype):
        x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5)

        y1 = x[0]
        y1_dl = torch.from_dlpack(y1)
        self.assertEqual(y1, y1_dl)

        y2 = x[:, 0]
        y2_dl = torch.from_dlpack(y2)
        self.assertEqual(y2, y2_dl)

        y3 = x[1, :]
        y3_dl = torch.from_dlpack(y3)
        self.assertEqual(y3, y3_dl)

        y4 = x[1]
        y4_dl = torch.from_dlpack(y4)
        self.assertEqual(y4, y4_dl)

        y5 = x.t()
        y5_dl = torch.from_dlpack(y5)
        self.assertEqual(y5, y5_dl)

    @skipMeta
    @onlyCUDA
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_dlpack_conversion_with_diff_streams(self, device, dtype):
        stream_a = torch.cuda.Stream()
        stream_b = torch.cuda.Stream()
        # DLPack protocol helps establish a correct stream order
        # (hence data dependency) at the exchange boundary.
        # the `tensor.__dlpack__` method will insert a synchronization event
        # in the current stream to make sure that it was correctly populated.
        with torch.cuda.stream(stream_a):
            x = make_tensor((5,), dtype=dtype, device=device) + 1
            z = torch.from_dlpack(x.__dlpack__(stream_b.cuda_stream))
            stream_a.synchronize()
        stream_b.synchronize()
        self.assertEqual(z, x)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_from_dlpack_dtype(self, device, dtype):
        x = make_tensor((5,), dtype=dtype, device=device)
        y = torch.from_dlpack(x)
        assert x.dtype == y.dtype

    @skipMeta
    @onlyCUDA
    def test_dlpack_default_stream(self, device):
        class DLPackTensor:
            def __init__(self, tensor):
                self.tensor = tensor

            def __dlpack_device__(self):
                return self.tensor.__dlpack_device__()

            def __dlpack__(self, stream=None):
                if torch.version.hip is None:
                    assert stream == 1
                else:
                    assert stream == 0
                capsule = self.tensor.__dlpack__(stream)
                return capsule

        # CUDA-based tests runs on non-default streams
        with torch.cuda.stream(torch.cuda.default_stream()):
            x = DLPackTensor(make_tensor((5,), dtype=torch.float32, device=device))
            from_dlpack(x)

    @skipMeta
    @onlyNativeDeviceTypes
    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16))
    def test_dlpack_tensor_invalid_stream(self, device, dtype):
        with self.assertRaises(TypeError):
            x = make_tensor((5,), dtype=dtype, device=device)
            x.__dlpack__(stream=object())

    @skipMeta
    def test_dlpack_error_on_bool_tensor(self):
        x = torch.tensor([True], dtype=torch.bool)
        with self.assertRaises(RuntimeError):
            to_dlpack(x)

    # TODO: add interchange tests once NumPy 1.22 (dlpack support) is required
    @skipMeta
    def test_dlpack_export_requires_grad(self):
        x = torch.zeros(10, dtype=torch.float32, requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, r"require gradient"):
            x.__dlpack__()

    @skipMeta
    def test_dlpack_export_is_conj(self):
        x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
        y = torch.conj(x)
        with self.assertRaisesRegex(RuntimeError, r"conjugate bit"):
            y.__dlpack__()

    @skipMeta
    def test_dlpack_export_non_strided(self):
        x = torch.sparse_coo_tensor([[0]], [1], size=(1,))
        y = torch.conj(x)
        with self.assertRaisesRegex(RuntimeError, r"strided"):
            y.__dlpack__()

    @skipMeta
    def test_dlpack_normalize_strides(self):
        x = torch.rand(16)
        y = x[::3][:1]
        self.assertEqual(y.shape, (1,))
        self.assertEqual(y.stride(), (3,))
        z = from_dlpack(y)
        self.assertEqual(z.shape, (1,))
        # gh-83069, make sure __dlpack__ normalizes strides
        self.assertEqual(z.stride(), (1,))


instantiate_device_type_tests(TestTorchDlPack, globals())

if __name__ == '__main__':
    run_tests()
