Skip to content

Commit dc20b2d

Browse files
DLPack: add test using PyTorch DLPack functions. (#8294)
Co-authored-by: iefgnoix <isaacwxf23@gmail.com>
1 parent 8177447 commit dc20b2d

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

.torch_pin

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
#138470

test/test_operations.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2912,6 +2912,17 @@ def test_dlpack_xla_to_pytorch_cuda(self):
29122912
cuda_t1[0] = cuda_t1[0] + 20
29132913
self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu()))
29142914

2915+
@onlyIfTorchSupportsCUDA
2916+
@onlyIfPJRTDeviceIsCUDA
2917+
def test_dlpack_xla_to_pytorch_cuda_protocol_conversion(self):
2918+
xla_t1 = torch.arange(5).to(xm.xla_device())
2919+
caps_t1 = torch.utils.dlpack.to_dlpack(xla_t1)
2920+
cuda_t1 = torch.utils.dlpack.from_dlpack(caps_t1)
2921+
self.assertEqual(cuda_t1.device.type, 'cuda')
2922+
self.assertEqual(cuda_t1.device.index, xla_t1.device.index)
2923+
cuda_t1[0] = cuda_t1[0] + 20
2924+
self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu()))
2925+
29152926
@onlyIfTorchSupportsCUDA
29162927
@onlyIfPJRTDeviceIsCUDA
29172928
def test_dlpack_non_default_layout(self):

0 commit comments

Comments
 (0)