From 7079b8587f0e0342c5d82d63041400160049e81d Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Wed, 12 Sep 2018 12:49:08 -0700 Subject: [PATCH 1/4] Only involve tensor device in CUDA -> CPU copy, not current device. This also unifies the device usage between the async and sync case. Fixes https://github.com/pytorch/pytorch/issues/10832. --- aten/src/THC/generic/THCTensorCopy.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/aten/src/THC/generic/THCTensorCopy.cpp b/aten/src/THC/generic/THCTensorCopy.cpp index 96ab307182639..0c20edfbd9fd3 100644 --- a/aten/src/THC/generic/THCTensorCopy.cpp +++ b/aten/src/THC/generic/THCTensorCopy.cpp @@ -58,6 +58,13 @@ void THTensor_(copyCuda)(THCState *state, THTensor *self, struct THCTensor *src) { THTensor *selfc = THTensor_(newContiguous)(self); + int tensorDevice = THCTensor_(getDevice)(state, src); + int currentDevice; + THCudaCheck(cudaGetDevice(&currentDevice)); + + if (currentDevice != tensorDevice) { + THCudaCheck(cudaSetDevice(tensorDevice)); + } src = THCTensor_(newContiguous)(state, src); cudaStream_t stream = THCState_getCurrentStream(state); @@ -68,6 +75,10 @@ void THTensor_(copyCuda)(THCState *state, THTensor *self, struct THCTensor *src) stream)); THCudaCheck(cudaStreamSynchronize(stream)); + if (currentDevice != tensorDevice) { + THCudaCheck(cudaSetDevice(currentDevice)); + } + THCTensor_(free)(state, src); THTensor_(freeCopyTo)(selfc, self); } From 635b6cdd3bf6985c4c75a3d4d35b35d3395781f4 Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Wed, 12 Sep 2018 14:10:14 -0700 Subject: [PATCH 2/4] Add test. 
--- test/run_test.py | 1 + test/test_cuda_primary_ctx.py | 56 +++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+) create mode 100644 test/test_cuda_primary_ctx.py diff --git a/test/run_test.py b/test/run_test.py index d7af8e47ab876..aa72d412cf3ac 100644 --- a/test/run_test.py +++ b/test/run_test.py @@ -21,6 +21,7 @@ 'cpp_extensions', 'c10d', 'cuda', + 'cuda_primary_ctx.py' 'dataloader', 'distributed', 'distributions', diff --git a/test/test_cuda_primary_ctx.py b/test/test_cuda_primary_ctx.py new file mode 100644 index 0000000000000..4e788fbcacd3b --- /dev/null +++ b/test/test_cuda_primary_ctx.py @@ -0,0 +1,56 @@ +import ctypes +import torch +from common import TestCase, run_tests, skipIfRocm +import unittest + +# NOTE: this needs to be run in a brand new process + +# We cannot import TEST_CUDA and TEST_MULTIGPU from common_cuda here, +# because if we do that, the TEST_CUDNN line from common_cuda will be executed +# multiple times as well during the execution of this test suite, and it will +# cause CUDA OOM error on Windows. 
+TEST_CUDA = torch.cuda.is_available() +TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2 + +if not TEST_CUDA: + print('CUDA not available, skipping tests') + TestCase = object # noqa: F811 + +def get_is_primary_context_created(device): + flags = ctypes.cast( (ctypes.c_uint*1)(), ctypes.POINTER(ctypes.c_uint) ) + active = ctypes.cast( (ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int) ) + result = torch.cuda.cudart().cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active) + assert result == 0, 'cuDevicePrimaryCtxGetState failed' + return bool(active[0]) + +class TestCudaPrimaryCtx(TestCase): + @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") + @skipIfRocm + def test_cuda_primary_ctx(self): + # Ensure context has not been created beforehand + self.assertFalse(get_is_primary_context_created(0)) + self.assertFalse(get_is_primary_context_created(1)) + + x = torch.randn(1, device='cuda:1') + + # We should have only created context on 'cuda:1' + self.assertFalse(get_is_primary_context_created(0)) + self.assertTrue(get_is_primary_context_created(1)) + + print(x) + + # We should still have only created context on 'cuda:1' + self.assertFalse(get_is_primary_context_created(0)) + self.assertTrue(get_is_primary_context_created(1)) + + y = torch.randn(1, device='cpu') + y.copy_(x) + + # We should still have only created context on 'cuda:1' + self.assertFalse(get_is_primary_context_created(0)) + self.assertTrue(get_is_primary_context_created(1)) + + # DO NOT ADD ANY OTHER TESTS HERE! ABOVE TEST REQUIRES FRESH PROCESS + +if __name__ == '__main__': + run_tests() From 9d2e8b4d5cd3a3fd70905f969b810ff9da454533 Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Thu, 13 Sep 2018 10:34:38 -0700 Subject: [PATCH 3/4] Correct run_test. 
--- test/run_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/run_test.py b/test/run_test.py index aa72d412cf3ac..1e3c2f60e1df3 100644 --- a/test/run_test.py +++ b/test/run_test.py @@ -21,7 +21,7 @@ 'cpp_extensions', 'c10d', 'cuda', - 'cuda_primary_ctx.py' + 'cuda_primary_ctx', 'dataloader', 'distributed', 'distributions', From e87748b1d3d4301e8aea7495e42c51a5996c9aff Mon Sep 17 00:00:00 2001 From: Gregory Chanan Date: Thu, 13 Sep 2018 11:46:45 -0700 Subject: [PATCH 4/4] Fix flake8. --- test/test_cuda_primary_ctx.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_cuda_primary_ctx.py b/test/test_cuda_primary_ctx.py index 4e788fbcacd3b..2006b340aa22d 100644 --- a/test/test_cuda_primary_ctx.py +++ b/test/test_cuda_primary_ctx.py @@ -16,13 +16,15 @@ print('CUDA not available, skipping tests') TestCase = object # noqa: F811 + def get_is_primary_context_created(device): - flags = ctypes.cast( (ctypes.c_uint*1)(), ctypes.POINTER(ctypes.c_uint) ) - active = ctypes.cast( (ctypes.c_int*1)(), ctypes.POINTER(ctypes.c_int) ) + flags = ctypes.cast((ctypes.c_uint * 1)(), ctypes.POINTER(ctypes.c_uint)) + active = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int)) result = torch.cuda.cudart().cuDevicePrimaryCtxGetState(ctypes.c_int(device), flags, active) assert result == 0, 'cuDevicePrimaryCtxGetState failed' return bool(active[0]) + class TestCudaPrimaryCtx(TestCase): @unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected") @skipIfRocm