
cuda out of memory error when GPU0 memory is fully utilized #3477

Closed
lucylw opened this issue Nov 3, 2017 · 16 comments
Labels
module: cuda, triage review

Comments

lucylw commented Nov 3, 2017

I have been experiencing an error on systems with multiple GPUs. When GPU0 is fully utilized by another process, I get RuntimeError: cuda runtime error (2) : out of memory.

It seems that torch.nn.Module.cuda() transfers data not only to my specified GPU, but also to GPU0, whose memory is already being used.

I can reproduce this error with the code below (memory on GPU0 should already be fully utilized by another process):

import torch
import torch.nn as nn
from torch.autograd import Variable

seq_len = 10
features = 50
hidden_size = 50
batch_size = 32

# move the model to GPU 5 explicitly
model = nn.Module()
model.rnn = nn.RNN(input_size=features, hidden_size=hidden_size, num_layers=2)
model.cuda(5)

# move the variables with a bare .cuda() (no device index)
X_train = torch.randn(seq_len, batch_size, features)
y_train = torch.randn(batch_size)
X_train, y_train = Variable(X_train).cuda(), Variable(y_train).cuda()

After model.cuda(5), my nvidia-smi output shows:
+----------------------------------------------------------+
| Processes: GPU Memory |
| GPU PID Type Process name Usage |
|===========================================|
| 0 32317 C python 7773MiB |
| 0 34873 C python 217MiB |
| 1 41080 C python 7775MiB |
| 5 34873 C python 289MiB |
+----------------------------------------------------------+

I am using GPU5, but the same process 34873 is also using memory on GPU0.

It looks like in the device class of torch/cuda/__init__.py, the prev_idx is being reset to 0 and then torch._C._cuda_setDevice is setting the device number to 0 upon exit.

torch/cuda/__init__.py:110

class device(object):
    """Context-manager that changes the selected device.

    Arguments:
        idx (int): device index to select. It's a no-op if this argument
            is negative.
    """

    def __init__(self, idx):
        self.idx = idx
        self.prev_idx = -1

    def __enter__(self):
        if self.idx is -1:
            return
        _lazy_init()
        self.prev_idx = torch._C._cuda_getDevice()
        if self.prev_idx != self.idx:
            torch._C._cuda_setDevice(self.idx)

    def __exit__(self, *args):
        if self.prev_idx != self.idx:
            torch._C._cuda_setDevice(self.prev_idx)
        return False

cc @ngimel

ngimel (Collaborator) commented Nov 3, 2017

That's because the default device is 0, so PyTorch is trying to create a context on it. You can control the devices you are using either with the CUDA_VISIBLE_DEVICES environment variable, or by guarding your computations like this:

with torch.cuda.device(5):
    model.cuda()

    X_train = torch.randn(seq_len, batch_size, features)
    y_train = torch.randn(batch_size)
    X_train, y_train = Variable(X_train).cuda(), Variable(y_train).cuda()

lucylw (Author) commented Nov 3, 2017

Even when I specify the CUDA device for all my transfers, there is still memory being used on GPU0, e.g.:

model = nn.Module()
model.rnn = nn.RNN(input_size=features, hidden_size=hidden_size, num_layers=2)
model.cuda(5)

X_train = torch.randn(seq_len, batch_size, features)
y_train = torch.randn(batch_size)
X_train, y_train = Variable(X_train).cuda(5), Variable(y_train).cuda(5)

Is this the right behavior?

More specifically, model.cuda(5) also puts data onto GPU0 (the default), but the variable transfers do not; the variables seem to go only onto GPU5 when I specify the device number.

ssnl (Collaborator) commented Nov 4, 2017

@lucylw Yes, this is the intended behavior. Since PyTorch still sees your GPU 0 as the first device in CUDA_VISIBLE_DEVICES, it will create some context on it. If you want your script to completely ignore GPU 0, you need to set that environment variable, e.g., for it to use only GPU 5, run CUDA_VISIBLE_DEVICES=5 python my_script.py. Note, however, that inside the script GPU 5 is then referred to as device 0.
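
As a minimal sketch of that remapping (assuming the script was launched with CUDA_VISIBLE_DEVICES=5; my_script.py is just the example name from above):

import torch

print(torch.cuda.device_count())    # 1 -- only the masked-in GPU is visible
print(torch.cuda.current_device())  # 0 -- physical GPU 5 now appears as device 0
x = torch.randn(8).cuda()           # allocates on physical GPU 5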

apaszke (Contributor) commented Nov 4, 2017

Still, I thought we're only initializing contexts on devices that are actually getting used (in the last snippet). I don't think we should create one on GPU0 in this case. It's worth looking into that.

ssnl (Collaborator) commented Nov 5, 2017

@apaszke In my experience with 0.2.0, it always creates a ~250MB context on the first visible GPU, whether or not that GPU is used.

apaszke (Contributor) commented Nov 5, 2017

That's weird, I don't think it should be like this

lucylw (Author) commented Nov 6, 2017

@ssnl, @apaszke

It looks like in the context manager in torch/cuda/__init__.py, prev_idx gets reset in __enter__ to the default device index (which is the first visible GPU), and then the device gets set back to that in __exit__ rather than to -1. So a context first gets created on the specified GPU (i.e. GPU5), then some more context gets created on GPU0, and then all the variable transfers go back to GPU5.
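
A minimal probe of that sequence, assuming a multi-GPU machine and device index 5 as in the original snippet (exact behavior depends on the PyTorch version):

import torch
from torch.autograd import Variable

with torch.cuda.device(5):
    x = Variable(torch.randn(10)).cuda()  # context created on GPU 5

print(torch.cuda.current_device())  # prints 0, i.e. the first visible GPU, not 5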

csarofeen (Contributor) commented:
Shouldn't __enter__/__exit__ only get called via a 'with' statement?
Does this still happen if you first call torch.cuda.set_device(devID) (before making any other CUDA calls)?

jekbradbury (Contributor) commented Nov 6, 2017

There are a bunch of things in PyTorch that can currently lead to initialization of a context on the first visible GPU; things like CPU-GPU copies and .tolist() all need to be inside with torch.cuda.device_of(tensor): blocks. None of this is a problem if you use torch.cuda.set_device, but the devs typically recommend CUDA_VISIBLE_DEVICES instead.
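
A rough sketch of both guards (the device index 1 here is just an example, not taken from this thread):

import torch

torch.cuda.set_device(1)       # make GPU 1 the default device for this process

t = torch.randn(4, 4).cuda()   # lands on GPU 1

with torch.cuda.device_of(t):  # follow t's device for operations that would
    values = t.tolist()        # otherwise touch the default (first visible) GPU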

csarofeen (Contributor) commented Nov 6, 2017

CUDA_VISIBLE_DEVICES may not be the best thing to always rely on: "robust applications should use the CUDA API to enumerate and select devices with appropriate capabilities at run time. To learn how, read the section on Device Enumeration in the CUDA Programming Guide. But the CUDA_VISIBLE_DEVICES environment variable is handy for restricting execution to a specific device or set of devices for debugging and testing." (https://devblogs.nvidia.com/parallelforall/cuda-pro-tip-control-gpu-visibility-cuda_visible_devices/)

It also has known conflicts with multi-process NCCL (as in, it doesn't work).
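
As a rough sketch of run-time enumeration from within PyTorch (here simply picking the visible GPU with the largest total memory, since a free-memory query was not exposed at the time; requires a version that has torch.cuda.get_device_properties):

import torch

# pick the visible GPU with the largest total memory -- a crude stand-in for
# "appropriate capabilities"; adapt the criterion to your own needs
best = max(range(torch.cuda.device_count()),
           key=lambda i: torch.cuda.get_device_properties(i).total_memory)
torch.cuda.set_device(best)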

TomHeaven (Contributor) commented Mar 27, 2019

I hit the same problem using PyTorch 0.4.1. I have two GTX 1080 Ti GPUs (11GB RAM each). When I run a training script on GPU0, it's OK. When I run the same script on GPU1 by setting CUDA_VISIBLE_DEVICES, the program reports a CUDA out of memory error.

It's odd, since GPU0 actually has less free memory because it is connected to the monitor.

Free GPU memory before running the training code:

./cuda-semi 
Device 0 [PCIe 0:1:0.0]: GeForce GTX 1080 Ti (CC 6.1): 9247.5 of 11264 MB (i.e. 82.1%) Free
Device 1 [PCIe 0:2:0.0]: GeForce GTX 1080 Ti (CC 6.1): 11090 of 11264 MB (i.e. 98.5%) Free

After running the training code on GPU0:

Device 0 [PCIe 0:1:0.0]: GeForce GTX 1080 Ti (CC 6.1): 342.02 of 11264 MB (i.e. 3.04%) Free
Device 1 [PCIe 0:2:0.0]: GeForce GTX 1080 Ti (CC 6.1): 11090 of 11264 MB (i.e. 98.5%) Free

Error message when running on GPU1:

Traceback (most recent call last):
  File "train.py", line 93, in <module>
    out = net(images)
  File "/Library/Python/2.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/Volumes/Data/比赛/提交代码M2Det/code/m2det.py", line 133, in forward
    conf.append(c(x).permute(0, 2, 3, 1).contiguous())
  File "/Library/Python/2.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
    result = self.forward(*input, **kwargs)
  File "/Library/Python/2.7/site-packages/torch/nn/modules/conv.py", line 301, in forward
    self.padding, self.dilation, self.groups)
RuntimeError: CUDA error: out of memory

soumith (Member) commented Mar 27, 2019

@TomHeaven did you set CUDA_VISIBLE_DEVICES outside the Python process? If that's the case, PyTorch should not even have driver-level access to your GPU0.

Ideally:

CUDA_VISIBLE_DEVICES=1 python foo.py
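
One quick way to confirm from inside the script what the process actually sees (a diagnostic sketch, not part of the original report):

import os
import torch

print(os.environ.get("CUDA_VISIBLE_DEVICES"))  # should print "1" if it was set before launch
print(torch.cuda.device_count())               # should be 1; device 0 then maps to physical GPU 1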

TomHeaven (Contributor) commented:
Yes, I did it exactly that way. And if I restart the computer, GPU 1 can sometimes run the code without a problem.

saurabheights commented Dec 6, 2019

Just tried this and the behaviour is really perplexing. I am working on a two-GPU system shared with my colleagues. cuda:1 is in use by a colleague, and though cuda:0 is mostly empty (8300 MB free), I get an OOM error.

[Screenshot attached: 2019-12-06 01-16-14]

Let me know if you need more information.

gchanan added the module: cuda and triage review labels on Dec 6, 2019
saurabheights commented Dec 6, 2019

@gchanan Please ignore my earlier comment. I have just been informed by the owner of the system that, due to some glitch, the GPU ids have been reversed: the RTX is actually GPU 1, not 0, and nvidia-smi gives the wrong info. I am really sorry for wasting your time.

Note: I have simply passed this information along from my supervisor so that you don't spend time looking into my issue, but I have yet to confirm it myself. I will confirm tonight as soon as I have system access.

UPDATE: After studying https://stackoverflow.com/a/13785789/1874627 and https://discuss.pytorch.org/t/how-to-specify-gpu-usage/945/7, I checked with deviceQuery (for anyone hitting this issue: run locate deviceQuery on Ubuntu/macOS to find the executable; it requires CUDA):

deviceQuery, CUDA Driver = CUDART, CUDA Driver Version = 10.0, CUDA Runtime Version = 10.0, NumDevs = 2, Device0 = TITAN V, Device1 = GeForce RTX 2080 Ti

It turns out the issue is a discrepancy between the device ordering used by nvidia-smi and the ordering used by the CUDA runtime.

P.S. Setting CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 python might work, but I cannot test this. [Both GPUs are currently in use.]
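
A sketch of the same idea from inside Python, in case setting the variables on the command line is inconvenient (they must be set before anything initializes CUDA):

import os

# must happen before the first CUDA initialization, i.e. before importing
# anything that touches the driver
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

print(torch.cuda.device_count())      # 1
print(torch.cuda.get_device_name(0))  # should now match nvidia-smi's GPU 0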

gchanan (Contributor) commented Dec 9, 2019

I think we've tackled a number of these issues. I'm going to close this now since we haven't seen a reproduction in quite a while. Please reopen if you see this again.

gchanan closed this as completed on Dec 9, 2019