Improve cuda OOM message #32101

Open
vadimkantorov opened this issue Jan 13, 2020 · 11 comments
Labels
module: bootcamp We plan to do a full writeup on the issue, and then get someone to do it for onboarding
module: cuda Related to torch.cuda, and CUDA support in general
module: docs Related to our documentation, both in docs/ and docblocks
triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@vadimkantorov
Contributor

vadimkantorov commented Jan 13, 2020

This is based on #31497 (after the last messages with @albanD and @ezyang)

In short, I had a script there that printed:

print('memory_allocated', torch.cuda.memory_allocated() / 1e9, 'memory_cached', torch.cuda.memory_cached() / 1e9)

in a model eval loop. And at some point it got an OOM:

memory_allocated 9.7478528 memory_cached 22.013804544
memory_allocated 11.03991552 memory_cached 24.29550592

RuntimeError: CUDA out of memory. Tried to allocate 152.00 MiB (GPU 0; 31.72 GiB total capacity; 24.89 GiB already allocated; 6.12 MiB free; 30.70 GiB reserved in total by PyTorch)

Some things:

  1. The torch.cuda.memory_allocated() value of ~11 GB is not reported in the OOM message.

  2. Terminology discrepancy: torch.cuda.memory_cached seems to be equivalent to "already allocated" from the OOM message. Given that torch.cuda.memory_allocated also exists, this is confusing. The OOM message should probably also say "cached". Otherwise, it's pretty easy to mix up allocated, cached and reserved.

  3. It would be nice for the OOM message to include by default a small glossary explaining all these memory counters (a rough sketch of one follows below).
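Not part of the original report, just a minimal sketch of what such a glossary could look like in code, assuming a single device and recent torch.cuda APIs (memory_reserved is the newer name for memory_cached):

import torch

def print_memory_glossary(device=0):
    # Bytes occupied by live tensors, i.e. what the user's code is actually holding.
    allocated = torch.cuda.memory_allocated(device)
    # Bytes reserved by the caching allocator: allocated + cached-but-free blocks
    # kept around for reuse ("reserved in total by PyTorch" in the OOM message).
    reserved = torch.cuda.memory_reserved(device)
    # Total device capacity, shared with every other process on the GPU.
    total = torch.cuda.get_device_properties(device).total_memory
    print(f"allocated {allocated / 1e9:.2f} GB | "
          f"reserved {reserved / 1e9:.2f} GB | "
          f"device total {total / 1e9:.2f} GB")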

Related: my earlier feature request to add some measure of fragmentation by default (#29554), and my old issue about allocator stats (#1529).

Recently there were a few reports of the default allocator causing problems with highly varying batch sizes (myself included). To confirm the guess, it would be nice to have an easily interpretable allocator state visualization / dump (a way to dump an HTML visualization would be super cool). Currently there exist torch.cuda.memory_stats, torch.cuda.memory_usage and torch.cuda.memory_snapshot. It would be nice to have some default advice on what to save / use when debugging suspected fragmentation (a strawman sketch is below).

In addition, they are not searchable for whatever reason: https://pytorch.org/docs/master/search.html?q=memory_usage&check_keywords=yes&area=default#
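As a strawman for the "what to save" advice (assuming the current torch.cuda.memory_stats / memory_snapshot / memory_summary APIs), a dump could look like:

import json
import torch

def dump_allocator_state(path_prefix="cuda_alloc"):
    # Condensed counters: current/peak allocated and reserved bytes, segment
    # counts, alloc retries, etc.
    with open(f"{path_prefix}_stats.json", "w") as f:
        json.dump(torch.cuda.memory_stats(), f, indent=2)
    # Full per-segment / per-block view of the caching allocator; this is the
    # structure a fragmentation visualization would be built from.
    with open(f"{path_prefix}_snapshot.json", "w") as f:
        json.dump(torch.cuda.memory_snapshot(), f, indent=2)
    # Human-readable per-device summary, handy to paste into an issue.
    print(torch.cuda.memory_summary())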

cc @ngimel @jlin27 @mruberry

@ezyang
Contributor

ezyang commented Jan 13, 2020

It seems quite doable to hack something up to look at this; the question is what exactly to hack up...

@vadimkantorov
Contributor Author

@ezyang assuming you're talking about the vis: a view over all device memory (even though other processes can use it) with all allocated tensors + storages. One way would be a long ribbon, possibly wrapped over several lines; allocated tensors could use odd/even highlighting to distinguish different tensors allocated next to each other (+ maybe another color for allocated-but-unused tensor storage). Then even a coarse view would give an idea of fragmentation.
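Purely as an illustration of the idea (not a proposal for the actual implementation), a text-mode ribbon could be built from torch.cuda.memory_snapshot(), assuming its per-block "size" / "state" fields:

import torch

def ascii_ribbon(bytes_per_char=1 << 20, width=80):
    """Rough text rendering of the caching allocator's segments: '#'/'%' for
    allocated blocks (alternating so that neighbouring tensors are
    distinguishable), '.' for cached-but-free blocks, one character per
    bytes_per_char bytes."""
    lines = []
    for segment in torch.cuda.memory_snapshot():
        pieces, flip = [], False
        for block in segment["blocks"]:
            n = max(1, block["size"] // bytes_per_char)
            if block["state"] == "active_allocated":
                pieces.append(("#" if flip else "%") * n)
                flip = not flip
            else:
                pieces.append("." * n)
        ribbon = "".join(pieces)
        lines.extend(ribbon[i:i + width] for i in range(0, len(ribbon), width))
        lines.append("")  # blank line between segments
    return "\n".join(lines)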

@vadimkantorov
Contributor Author

A measure of fragmentation could be some relation like the amount of uncommitted memory versus the maximum number of bytes that can be allocated contiguously.
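One way such a measure could be computed from the allocator snapshot (again assuming the "size" / "state" fields of torch.cuda.memory_snapshot(); the exact definition is up for discussion):

import torch

def fragmentation_ratio():
    """Total free-but-reserved memory divided by the largest single free block.
    A value near 1 means the cached free memory is mostly contiguous; a large
    value means it is scattered across many small blocks."""
    free_blocks = [
        block["size"]
        for segment in torch.cuda.memory_snapshot()
        for block in segment["blocks"]
        if block["state"] == "inactive"
    ]
    if not free_blocks:
        return 1.0
    return sum(free_blocks) / max(free_blocks)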

@zhangguanheng66 zhangguanheng66 added module: cuda Related to torch.cuda, and CUDA support in general module: docs Related to our documentation, both in docs/ and docblocks triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jan 14, 2020
@vadimkantorov
Contributor Author

vadimkantorov commented Jun 6, 2020

One problem with the default OOM message is that it doesn't report how much GPU memory is taken by other processes (or conversely, how much capacity is actually available given all the other processes).

I recently again had RuntimeError: CUDA out of memory. Tried to allocate 78.00 MiB (GPU 0; 31.75 GiB total capacity; 604.93 MiB already allocated; 11.69 MiB free; 626.00 MiB reserved in total by PyTorch). The problem was simply other processes, but that's not clear from these numbers; the total capacity is deceptive here.
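One way to surface this, sketched below on the assumption that the NVML Python bindings (pynvml / nvidia-ml-py, not part of PyTorch) are available, would be to list the other compute processes and their memory at OOM time:

import pynvml  # assumption: nvidia-ml-py (pynvml) is installed and NVML works

def other_process_usage(device_index=0):
    """Print every compute process on the device and its GPU memory use,
    which is exactly the context the current OOM message leaves out."""
    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
        for proc in pynvml.nvmlDeviceGetComputeRunningProcesses(handle):
            used = proc.usedGpuMemory or 0  # may be None under some drivers
            print(f"pid {proc.pid}: {used / 2**20:.0f} MiB")
    finally:
        pynvml.nvmlShutdown()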

@ezyang
Contributor

ezyang commented Jun 7, 2020

I suppose we could shell out to nvidia-smi in this situation :>

@vadimkantorov
Contributor Author

Is nvidia-smi guaranteed to be available on PATH? I think it has happened to me that nvidia-smi was available and nvcc was not. Well, if it is available, this info would be useful.

Anyway, the global memory situation, or info on the presence of other processes using the GPU, would be useful.

@ngimel
Collaborator

ngimel commented Jun 7, 2020

nvidia-smi is not guaranteed to be available on the PATH, and in case of isolation (e.g. with Docker) it might not report other processes running on the same GPU. In cases where it is available and all processes are visible it will be useful, but it's not a guarantee.
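If this route were taken anyway, it would presumably have to be best-effort, along the lines of this sketch (the query flags are standard nvidia-smi options, but, as noted, the binary may simply be missing):

import shutil
import subprocess

def try_nvidia_smi_processes():
    """Best-effort listing of per-process GPU memory via nvidia-smi; returns
    None when the binary is not on PATH or the query fails."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(
        ["nvidia-smi",
         "--query-compute-apps=pid,used_memory",
         "--format=csv,noheader"],
        capture_output=True, text=True, check=False,
    )
    return result.stdout if result.returncode == 0 else None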

@vadimkantorov
Contributor Author

vadimkantorov commented Jun 7, 2020

It has also happened to me that it did not report other processes even when they were inside the same Docker container (just some background processes), but at least the memory counter showed that they do exist - even that memory counter would be useful for the OOM message.

@ezyang ezyang added the module: bootcamp We plan to do a full writeup on the issue, and then get someone to do it for onboarding label Nov 6, 2020
@CorentinJ

Note that you can get all tensors allocated on CUDA in the program using Python's garbage collector:

import gc
import torch

def _active_cuda_tensors():
    """
    Yields all tensors currently resident on CUDA devices.
    """
    for obj in gc.get_objects():
        try:
            # Some objects tracked by the GC raise on attribute access; skip them.
            if torch.is_tensor(obj) and obj.device.type == "cuda":
                yield obj
        except Exception:
            pass

The size used in bytes by each tensor can then be computed with tensor.element_size() * tensor.nelement().

@vadimkantorov
Contributor Author

Nowadays more and more tensors are allocated from C++ (including output volumes from ops, gradients, etc.). They still add memory pressure on the caching allocator, but cannot be inspected like this :(

@CorentinJ

I don't know about that; I wrote a function to compare the size of tensors in memory with what PyTorch reports as allocated, and it matches:

VRAM usage summary:
	All tensors above:       2.39 GiB
	Currently in use:        2.39 GiB
	Peak usage so far:       3.96 GiB
	Reserved by torch:       4.09 GiB
	Max available on device: 8.00 GiB
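The function itself wasn't posted; a minimal sketch of how such a summary could be produced (reusing the _active_cuda_tensors generator above) might look like this:

import torch

def print_vram_summary(device=0):
    # Sum of live tensor sizes found via the garbage collector (see above).
    tensor_bytes = sum(t.element_size() * t.nelement()
                       for t in _active_cuda_tensors())
    allocated = torch.cuda.memory_allocated(device)
    peak = torch.cuda.max_memory_allocated(device)
    reserved = torch.cuda.memory_reserved(device)
    total = torch.cuda.get_device_properties(device).total_memory
    gib = 2 ** 30
    print("VRAM usage summary:")
    print(f"\tAll tensors above:       {tensor_bytes / gib:.2f} GiB")
    print(f"\tCurrently in use:        {allocated / gib:.2f} GiB")
    print(f"\tPeak usage so far:       {peak / gib:.2f} GiB")
    print(f"\tReserved by torch:       {reserved / gib:.2f} GiB")
    print(f"\tMax available on device: {total / gib:.2f} GiB")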
