
Conversation

zdevito
Contributor

@zdevito zdevito commented Apr 21, 2023

Stack from ghstack (oldest at bottom):

This PR adds calls to nvml during an OOM to find out the total memory
in use by the process and any other CUDA processes on the device.

This makes it easier to identify cases where non-PyTorch libraries have
allocated memory or another process (such as a data loader) has also
allocated memory on the device.

This also rewords the other parts of the error message to make the meaning
of the memory statistics clearer with this new information:

"""
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 138.00 MiB.
GPU 0 has a total capacity of 15.90 GiB of which 8.44 MiB is free.
Process 1246069 has 577.00 MiB memory in use. Including non-PyTorch memory,
this process has 15.32 GiB memory in use. Of the allocated memory
14.12 GiB is allocated by PyTorch, and 410.41 MiB is reserved
by PyTorch but unallocated. If reserved but unallocated memory is large
try setting max_split_size_mb to avoid fragmentation. See documentation
for Memory Management and PYTORCH_CUDA_ALLOC_CONF
"""

@pytorch-bot

pytorch-bot bot commented Apr 21, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/99699

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 Failure

As of commit 8b56875:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

zdevito added a commit that referenced this pull request Apr 21, 2023
ghstack-source-id: df6cd1e
Pull Request resolved: #99699
@zdevito zdevito added the topic: not user facing label Apr 21, 2023
@zdevito zdevito requested a review from ngimel April 21, 2023 05:45
@zdevito zdevito added the ciflow/trunk label Apr 21, 2023
#include <c10/util/llvmMathExtras.h>

#if !defined(USE_ROCM) && defined(PYTORCH_EXPANDABLE_SEGMENTS_SUPPORTED)
#define PYTORCH_C10_DRIVER_API_SUPPORTED
Collaborator


should this be passed as compiler arg?

Contributor Author


yeah, accidentally left in

nvmlDevice_t nvml_device;
TORCH_INTERNAL_ASSERT(
    NVML_SUCCESS ==
    DriverAPI::get()->nvmlDeviceGetHandleByIndex_v2_(device, &nvml_device));
Collaborator


The nvml index and the runtime index are not necessarily the same (e.g. CUDA_VISIBLE_DEVICES changes runtime indices); the Python nvml calls go through an annoying process of mapping one index to the other. Should we just have a callback on the Python side when throwing OOMError?

Contributor Author


Switched to the PCI bus id so that we don't have to parse CUDA_VISIBLE_DEVICES and friends. Using the PCI bus id rather than the UUID because the docs say that looking a device up by UUID might need to initialize other GPUs to check their UUIDs, whereas the PCI bus id only has to initialize the device it specifies.
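
A minimal sketch of that mapping (my own illustration under the assumption that cuda_runtime.h and nvml.h are available, not the PR's exact code): the runtime index is turned into a PCI bus id, and the bus id is handed to NVML, so no index translation is needed.

```cpp
#include <cstdio>
#include <cuda_runtime.h>
#include <nvml.h>

int main() {
  int runtime_index = 0;  // index as seen by the CUDA runtime
                          // (already filtered/remapped by CUDA_VISIBLE_DEVICES)

  // Ask the CUDA runtime for the PCI bus id of that device...
  char bus_id[NVML_DEVICE_PCI_BUS_ID_BUFFER_SIZE];
  cudaDeviceGetPCIBusId(bus_id, sizeof(bus_id), runtime_index);

  // ...and hand the same bus id to NVML, so no index translation is needed
  // and only this one device has to be initialized.
  nvmlInit_v2();
  nvmlDevice_t nvml_device;
  if (nvmlDeviceGetHandleByPciBusId_v2(bus_id, &nvml_device) == NVML_SUCCESS) {
    std::printf("runtime device %d is NVML device at %s\n",
                runtime_index, bus_id);
  }
  nvmlShutdown();
  return 0;
}
```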

zdevito added a commit that referenced this pull request Apr 21, 2023
ghstack-source-id: 7234c36
Pull Request resolved: #99699
zdevito added a commit that referenced this pull request Apr 21, 2023
ghstack-source-id: 9d4716a
Pull Request resolved: #99699
@zdevito zdevito requested a review from ngimel April 21, 2023 18:11
@zdevito
Contributor Author

zdevito commented Apr 21, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@facebook-github-bot facebook-github-bot deleted the gh/zdevito/237/head branch June 8, 2023 19:28
pytorchmergebot pushed a commit that referenced this pull request Nov 2, 2023
Since #99699 introduced a dependency on nvml for OOM reporting in `c10/cuda/driver_api.h`, `c10/cuda/driver_api.cpp`, and `reportProcessMemoryInfo` in `c10/cuda/CUDACachingAllocator.cpp`, we've seen failures in NVIDIA's internal CI around CUDA expandable segments and OOM reporting, specifically on Jetson devices, which don't have nvml support because nvml is incompatible with Jetson. Example failures using the latest upstream on an Orin AGX node:

`python test/test_cuda.py -k test_notifies_oom` generates

```
Traceback (most recent call last):
  File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/opt/pytorch/pytorch/test/test_cuda.py", line 1643, in _worker
    results[t] = torch.nn.functional.conv2d(results[t], weight, padding=0)
RuntimeError: CUDA driver error: out of memory
```

`python test/test_cuda_expandable_segments.py` generates

```
Traceback (most recent call last):
  File "/opt/pytorch/pytorch/test/test_cuda_expandable_segments.py", line 12, in <module>
    exec(compile(open(filepath).read(), filepath, mode='exec'))
  File "/opt/pytorch/pytorch/test/test_cuda.py", line 66, in <module>
    class TestCuda(TestCase):
  File "/opt/pytorch/pytorch/test/test_cuda.py", line 1609, in TestCuda
    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
  File "/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_utils.py", line 4628, in wrapped
    self._value = self._cb()
  File "/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_cuda.py", line 20, in <lambda>
    TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))
RuntimeError: handle_0 INTERNAL ASSERT FAILED at "/opt/pytorch/pytorch/c10/cuda/driver_api.cpp":15, please report a bug to PyTorch.
```

This PR intends to fix the issue by adding dlopen checks to make sure nvml actually exists, and by safely falling back to the older libcuda-based paths for CUDA expandable segments and OOM reporting when nvml is not found.

Pull Request resolved: #112121
Approved by: https://github.com/eqy, https://github.com/ngimel, https://github.com/albanD
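
For reference, a minimal sketch of the kind of dlopen probe and fallback described in that commit message (my own illustration, not the code from #112121):

```cpp
// Build with e.g.: g++ nvml_probe.cpp -ldl
#include <dlfcn.h>
#include <cstdio>

// Returns a handle to NVML if it can be loaded, or nullptr otherwise.
static void* try_load_nvml() {
  // This only succeeds if libnvidia-ml is actually installed on the system.
  return dlopen("libnvidia-ml.so.1", RTLD_LAZY | RTLD_GLOBAL);
}

int main() {
  void* nvml = try_load_nvml();
  if (nvml == nullptr) {
    // No NVML (e.g. Jetson): skip per-process reporting and rely only on the
    // information available from libcuda / the CUDA runtime.
    std::printf("nvml not available: %s\n", dlerror());
  } else {
    std::printf("nvml loaded; per-process OOM reporting can be enabled\n");
    dlclose(nvml);
  }
  return 0;
}
```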
xuhancn pushed a commit to xuhancn/pytorch that referenced this pull request Nov 7, 2023
Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Nov 14, 2023
pytorchmergebot pushed a commit that referenced this pull request Jun 30, 2024
Seems to be removed following #99699?
Pull Request resolved: #129546
Approved by: https://github.com/Skylion007
pytorchmergebot pushed a commit to khushi-411/pytorch that referenced this pull request Jul 2, 2024

Labels

ciflow/trunk · Merged · topic: not user facing
