torch.cuda.set_per_process_memory_fraction() does not perform VRAM isolation #69688

@ORippler

🐛 Describe the bug

Hi,

torch.cuda.set_per_process_memory_fraction() is intended to allow GPU sharing across models and was introduced by #48172. However, what the docs and the PR currently fail to mention is that this limit is only a check performed after the memory has already been successfully allocated via CUDA, and it does not actually enforce proper isolation.
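A minimal sketch of the intended usage, for reference (the 0.5 fraction and the allocation size are arbitrary illustration values, not taken from my training run):

```python
import torch

# Cap this process at 50% of GPU 0's memory (fraction chosen arbitrarily).
torch.cuda.set_per_process_memory_fraction(0.5, device=0)

total = torch.cuda.get_device_properties(0).total_memory

try:
    # Request more than the configured fraction in a single tensor; the
    # caching allocator reports an out-of-memory error for this request.
    x = torch.empty(int(0.75 * total), dtype=torch.uint8, device="cuda:0")
except RuntimeError as err:
    print("OOM raised once the fraction is exceeded:", err)
```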

I just had the same code fail on an RTX 3090 limited to 5936 MiB that passed just fine on a GTX 1660 (which has 5936 MiB natively).

Inspecting the output of torch.cuda.max_memory_allocated() when the limit via torch.cuda.set_per_process_memory_fraction() is omitted yields

  • 3682 MiB for the GTX 1660, and
  • 9638 MiB for the RTX 3090.

For context: I trained a simple FCN-32s on PascalVOC (and can provide training scripts as well as a Docker image if required).

Note that nvidia-smi shows similar VRAM usage for both settings, so I guess that allocating one big temporary chunk followed by GC is more efficient in the case of the RTX 3090 and is done irrespective of the set limit.
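For completeness, a minimal sketch of how the peak-usage numbers above can be collected (the training step is a placeholder; the actual FCN-32s/PascalVOC script is not reproduced here):

```python
import torch

# Reset the peak counter before the run being measured.
torch.cuda.reset_peak_memory_stats(0)

# ... run the training loop here (FCN-32s on PascalVOC in my case) ...

peak_mib = torch.cuda.max_memory_allocated(0) / 2**20
print(f"peak allocated: {peak_mib:.0f} MiB")

# memory_summary() helps cross-check the allocator's view against nvidia-smi.
print(torch.cuda.memory_summary(0))
```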

Suggestions

  1. Incorporate the fraction imposed by torch.cuda.set_per_process_memory_fraction() into the memory allocation procedure itself, providing proper isolation.
  2. Should 1 not be feasible: update the docs to reflect that edge cases exist where temporary allocations exceed the imposed limit, causing unexpected OOM errors.
  3. A little off-topic: update the docs to reflect that torch.cuda.set_per_process_memory_fraction() does not account for overheads such as the CUDA context (see #48172 and #58466); a sketch of budgeting for this manually follows below.
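To illustrate suggestion 3: today a caller has to budget the CUDA context overhead by hand, roughly as in the sketch below (the overhead value is an assumption for illustration, not a measured number):

```python
import torch

device = 0
total_bytes = torch.cuda.get_device_properties(device).total_memory

target_mib = 5936            # desired per-process cap, as used above
context_overhead_mib = 500   # assumed CUDA context overhead, not measured

# The fraction is relative to total device memory and does not account for
# the CUDA context, so an estimated overhead has to be subtracted manually.
fraction = (target_mib - context_overhead_mib) * 2**20 / total_bytes
torch.cuda.set_per_process_memory_fraction(fraction, device=device)
```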

Versions

Collecting environment information...
PyTorch version: 1.8.2+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.2 LTS (x86_64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.31

Python version: 3.9.6 | packaged by conda-forge | (default, Jul 11 2021, 03:39:48)  [GCC 9.3.0] (64-bit runtime)
Python platform: Linux-5.13.0-22-lowlatency-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1660
Nvidia driver version: 495.44
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] torch==1.8.2+cu111
[pip3] torchaudio==0.8.2
[pip3] torchvision==0.9.2+cu111
[conda] numpy                     1.21.2           py39hdbf815f_0    conda-forge
[conda] torch                     1.8.2+cu111              pypi_0    pypi
[conda] torchaudio                0.8.2                    pypi_0    pypi
[conda] torchvision               0.9.2+cu111              pypi_0    pypi

cc @brianjo @mruberry @ngimel

Labels

    module: cuda: Related to torch.cuda, and CUDA support in general
    module: docs: Related to our documentation, both in docs/ and docblocks
    module: memory usage: PyTorch is using more memory than it should, or it is leaking memory
    triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
