diff --git a/test/test_cuda.py b/test/test_cuda.py
index 3d52c99df856..ed1d82a91112 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -3113,6 +3113,16 @@ def test_batch_norm_gather_stats(self):
         self.assertEqual(mean, torch.ones(3, device='cuda'))
         self.assertEqual(invstd, torch.ones(3, device='cuda'))
 
+    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
+    def test_cuda_device_memory_allocated(self):
+        from torch.cuda import memory_allocated
+        device_count = torch.cuda.device_count()
+        current_alloc = [memory_allocated(idx) for idx in range(device_count)]
+        x = torch.ones(10, device="cuda:0")
+        self.assertTrue(memory_allocated(0) > current_alloc[0])
+        self.assertTrue(all(memory_allocated(torch.cuda.device(idx)) == current_alloc[idx] for idx in range(1, device_count)))
+
+
 class TestCudaComm(TestCase):
     def _test_broadcast(self, input):
         if not TEST_MULTIGPU:
diff --git a/torch/cuda/_utils.py b/torch/cuda/_utils.py
index 8f4105623a98..7f6b70a037bb 100644
--- a/torch/cuda/_utils.py
+++ b/torch/cuda/_utils.py
@@ -29,6 +29,8 @@ def _get_device_index(device: Union[Device, str, int, None], optional: bool = Fa
                 raise ValueError('Expected a cuda or cpu device, but got: {}'.format(device))
         elif device.type != 'cuda':
             raise ValueError('Expected a cuda device, but got: {}'.format(device))
+    if isinstance(device, torch.cuda.device):
+        return device.idx
     return _torch_get_device_index(device, optional, allow_cpu)
 
 
diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py
index c0bde95de741..102ca1cb2e9f 100644
--- a/torch/cuda/memory.py
+++ b/torch/cuda/memory.py
@@ -193,6 +193,8 @@ def _recurse_add_to_result(prefix, obj):
 
 def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
     r"""Returns the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
+    if not is_initialized():
+        return {}
     device = _get_device_index(device, optional=True)
     return torch._C._cuda_memoryStats(device)
 
@@ -303,7 +305,7 @@ def memory_allocated(device: Union[Device, int] = None) -> int:
         needs to be created on GPU. See :ref:`cuda-memory-management` for more
         details about GPU memory management.
     """
-    return memory_stats(device=device)["allocated_bytes.all.current"]
+    return memory_stats(device=device).get("allocated_bytes.all.current", 0)
 
 
 def max_memory_allocated(device: Union[Device, int] = None) -> int:
@@ -325,7 +327,7 @@ def max_memory_allocated(device: Union[Device, int] = None) -> int:
     See :ref:`cuda-memory-management` for more details about GPU memory
     management.
     """
-    return memory_stats(device=device)["allocated_bytes.all.peak"]
+    return memory_stats(device=device).get("allocated_bytes.all.peak", 0)
 
 
 def memory_reserved(device: Union[Device, int] = None) -> int:
@@ -341,7 +343,7 @@ def memory_reserved(device: Union[Device, int] = None) -> int:
     See :ref:`cuda-memory-management` for more details about GPU memory
     management.
     """
-    return memory_stats(device=device)["reserved_bytes.all.current"]
+    return memory_stats(device=device).get("reserved_bytes.all.current", 0)
 
 
 def max_memory_reserved(device: Union[Device, int] = None) -> int:
@@ -363,7 +365,7 @@ def max_memory_reserved(device: Union[Device, int] = None) -> int:
     See :ref:`cuda-memory-management` for more details about GPU memory
     management.
     """
-    return memory_stats(device=device)["reserved_bytes.all.peak"]
+    return memory_stats(device=device).get("reserved_bytes.all.peak", 0)
 
 
 def memory_cached(device: Union[Device, int] = None) -> int:
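
A minimal sketch of the behavior this patch enables (not part of the patch itself; it assumes a PyTorch build with CUDA support and at least one visible GPU):

    import torch

    # With the early return added to memory_stats_as_nested_dict() and the
    # dict.get(key, 0) fallbacks, querying allocator stats before CUDA is
    # initialized reports 0 instead of touching torch._C._cuda_memoryStats.
    # Importing torch alone does not initialize CUDA, so these hold here.
    assert torch.cuda.memory_allocated() == 0
    assert torch.cuda.max_memory_reserved() == 0

    # _get_device_index() now unwraps a torch.cuda.device object via its
    # .idx attribute, so the stats functions accept one directly, as the
    # new test_cuda_device_memory_allocated test exercises.
    x = torch.ones(10, device="cuda:0")
    assert torch.cuda.memory_allocated(torch.cuda.device(0)) > 0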