2 files changed: +12 −0 lines changed

First changed file (single added line):
+#138470
Second changed file:
@@ -2912,6 +2912,17 @@ def test_dlpack_xla_to_pytorch_cuda(self):
     cuda_t1[0] = cuda_t1[0] + 20
     self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu()))
 
+  @onlyIfTorchSupportsCUDA
+  @onlyIfPJRTDeviceIsCUDA
+  def test_dlpack_xla_to_pytorch_cuda_protocol_conversion(self):
+    xla_t1 = torch.arange(5).to(xm.xla_device())
+    caps_t1 = torch.utils.dlpack.to_dlpack(xla_t1)
+    cuda_t1 = torch.utils.dlpack.from_dlpack(caps_t1)
+    self.assertEqual(cuda_t1.device.type, 'cuda')
+    self.assertEqual(cuda_t1.device.index, xla_t1.device.index)
+    cuda_t1[0] = cuda_t1[0] + 20
+    self.assertTrue(torch.allclose(xla_t1.cpu(), cuda_t1.cpu()))
+
   @onlyIfTorchSupportsCUDA
   @onlyIfPJRTDeviceIsCUDA
   def test_dlpack_non_default_layout(self):
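
The new test goes through an explicit DLPack capsule (to_dlpack followed by from_dlpack) rather than handing the XLA tensor straight to from_dlpack. A minimal, self-contained sketch of that capsule round trip (not part of the diff; it assumes torch_xla is installed and the PJRT device is CUDA):

# Sketch only: the explicit capsule conversion path exercised by
# test_dlpack_xla_to_pytorch_cuda_protocol_conversion above.
# Assumes PJRT_DEVICE=CUDA and a CUDA-enabled torch build.
import torch
import torch.utils.dlpack
import torch_xla.core.xla_model as xm

xla_t = torch.arange(5).to(xm.xla_device())

capsule = torch.utils.dlpack.to_dlpack(xla_t)     # XLA tensor -> DLPack capsule
cuda_t = torch.utils.dlpack.from_dlpack(capsule)  # capsule -> torch CUDA tensor

assert cuda_t.device.type == 'cuda'

# The exchange shares device memory, so an in-place update on the CUDA
# side is observable through the XLA tensor as well.
cuda_t[0] = cuda_t[0] + 20
assert torch.allclose(xla_t.cpu(), cuda_t.cpu())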