torch.cuda.memory_allocated to return {} if not initialized (#51179)
Summary:
Fixes #49952

Pull Request resolved: #51179

Reviewed By: ngimel

Differential Revision: D26094932

Pulled By: malfet

fbshipit-source-id: 0ec28ef9b0604245753d3f2b0e3536286700668d
malfet authored and facebook-github-bot committed Jan 29, 2021
1 parent 916af89 commit 43f0ccd
Showing 3 changed files with 18 additions and 4 deletions.
test/test_cuda.py: 10 additions, 0 deletions
@@ -3113,6 +3113,16 @@ def test_batch_norm_gather_stats(self):
         self.assertEqual(mean, torch.ones(3, device='cuda'))
         self.assertEqual(invstd, torch.ones(3, device='cuda'))

+    @unittest.skipIf(not TEST_MULTIGPU, "Test needs multiple GPUs")
+    def test_cuda_device_memory_allocated(self):
+        from torch.cuda import memory_allocated
+        device_count = torch.cuda.device_count()
+        current_alloc = [memory_allocated(idx) for idx in range(device_count)]
+        x = torch.ones(10, device="cuda:0")
+        self.assertTrue(memory_allocated(0) > current_alloc[0])
+        self.assertTrue(all(memory_allocated(torch.cuda.device(idx)) == current_alloc[idx] for idx in range(1, device_count)))


 class TestCudaComm(TestCase):
     def _test_broadcast(self, input):
         if not TEST_MULTIGPU:
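For readers following along outside the PyTorch test harness, the snippet below is a standalone sketch of what the new test asserts: allocating a tensor on cuda:0 should raise the allocation counter for device 0 only, leaving every other device's counter unchanged. The device_count() guard stands in for the TEST_MULTIGPU skip; the variable names are illustrative and not part of the commit.

    # Standalone sketch (not from the repository) of the behavior the new test checks.
    import torch

    if torch.cuda.device_count() >= 2:
        before = [torch.cuda.memory_allocated(i) for i in range(torch.cuda.device_count())]
        x = torch.ones(10, device="cuda:0")                      # allocate only on the first GPU
        assert torch.cuda.memory_allocated(0) > before[0]        # device 0 counter grew
        for i in range(1, torch.cuda.device_count()):
            assert torch.cuda.memory_allocated(i) == before[i]   # other devices are untouched
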
torch/cuda/_utils.py: 2 additions, 0 deletions
@@ -29,6 +29,8 @@ def _get_device_index(device: Union[Device, str, int, None], optional: bool = Fa
                 raise ValueError('Expected a cuda or cpu device, but got: {}'.format(device))
         elif device.type != 'cuda':
             raise ValueError('Expected a cuda device, but got: {}'.format(device))
+    if isinstance(device, torch.cuda.device):
+        return device.idx
     return _torch_get_device_index(device, optional, allow_cpu)


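The two added lines let the helper short-circuit when it is handed a torch.cuda.device context-manager object, returning the index stored on the wrapper instead of routing it through the generic device parsing path; this is what the new test relies on when it calls memory_allocated(torch.cuda.device(idx)). A small illustration, not part of the commit, that pokes at the internal helper purely for demonstration:

    # Illustration only: torch.cuda.device wraps a device index in a context-manager object.
    import torch
    from torch.cuda._utils import _get_device_index

    dev = torch.cuda.device(0)                # wrapper object; dev.idx == 0
    print(_get_device_index(dev))             # 0, returned by the new isinstance branch
    print(torch.cuda.memory_allocated(dev))   # the public API accepts the same object
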
torch/cuda/memory.py: 6 additions, 4 deletions
@@ -193,6 +193,8 @@ def _recurse_add_to_result(prefix, obj):

 def memory_stats_as_nested_dict(device: Union[Device, int] = None) -> Dict[str, Any]:
     r"""Returns the result of :func:`~torch.cuda.memory_stats` as a nested dictionary."""
+    if not is_initialized():
+        return {}
     device = _get_device_index(device, optional=True)
     return torch._C._cuda_memoryStats(device)

@@ -303,7 +305,7 @@ def memory_allocated(device: Union[Device, int] = None) -> int:
     needs to be created on GPU. See :ref:`cuda-memory-management` for more
     details about GPU memory management.
     """
-    return memory_stats(device=device)["allocated_bytes.all.current"]
+    return memory_stats(device=device).get("allocated_bytes.all.current", 0)


 def max_memory_allocated(device: Union[Device, int] = None) -> int:
@@ -325,7 +327,7 @@ def max_memory_allocated(device: Union[Device, int] = None) -> int:
     See :ref:`cuda-memory-management` for more details about GPU memory
     management.
     """
-    return memory_stats(device=device)["allocated_bytes.all.peak"]
+    return memory_stats(device=device).get("allocated_bytes.all.peak", 0)


 def memory_reserved(device: Union[Device, int] = None) -> int:
@@ -341,7 +343,7 @@ def memory_reserved(device: Union[Device, int] = None) -> int:
     See :ref:`cuda-memory-management` for more details about GPU memory
     management.
     """
-    return memory_stats(device=device)["reserved_bytes.all.current"]
+    return memory_stats(device=device).get("reserved_bytes.all.current", 0)


 def max_memory_reserved(device: Union[Device, int] = None) -> int:
@@ -363,7 +365,7 @@ def max_memory_reserved(device: Union[Device, int] = None) -> int:
     See :ref:`cuda-memory-management` for more details about GPU memory
     management.
     """
-    return memory_stats(device=device)["reserved_bytes.all.peak"]
+    return memory_stats(device=device).get("reserved_bytes.all.peak", 0)


 def memory_cached(device: Union[Device, int] = None) -> int:
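Taken together, the memory.py changes mean that querying allocator statistics no longer initializes CUDA as a side effect: an uninitialized context yields an empty stats dict from memory_stats_as_nested_dict, and the per-metric accessors fall back to 0 through dict.get. A minimal sketch of the resulting behavior in a fresh process (not part of the diff):

    # Fresh Python process: no CUDA work has been done, so the context is uninitialized.
    import torch

    assert not torch.cuda.is_initialized()
    print(torch.cuda.memory_stats())           # empty stats mapping, no forced initialization
    print(torch.cuda.memory_allocated())       # 0
    print(torch.cuda.max_memory_reserved())    # 0
    assert not torch.cuda.is_initialized()     # still uninitialized afterwards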
