
Memory not being deallocated in backward() #18643

Open
mdlockyer opened this issue Mar 30, 2019 · 17 comments
Labels
module: autograd (Related to torch.autograd, and the autograd engine in general)
module: memory usage (PyTorch is using more memory than it should, or it is leaking memory)
quansight-nack (High-prio issues that have been reviewed by Quansight and are judged to be not actionable)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

mdlockyer commented Mar 30, 2019

🐛 Bug

I've recently discovered an issue with memory not being freed after the first iteration of training. It's not a leak, as memory usage stays consistent after the second pass through the loop. It appears on both CPU and GPU; however, it is much more significant when running on CPU.

The issue seems to come from either backward() or optimizer.step(), as removing those calls gives stable memory usage.

I ran into this while attempting to train a rather large model that uses pretty much all of my available GPU memory. It completes the first iteration successfully, then OOMs during the second.

To Reproduce

Steps to reproduce the behavior:

I have put together minimal CPU and GPU gists that should reproduce this issue:

CPU
GPU

The CPU gist uses the memory-profiler package, so that will need to be installed with pip.
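
For reference, here is a minimal sketch of the kind of loop the CPU gist runs. The model here is a stand-in (the exact architecture is in the gist); the train function mirrors the profiled code shown below.

    import gc
    import torch
    import torch.nn as nn
    from memory_profiler import profile

    # Stand-in model: just needs to produce an output matching y's shape (1, 1, 8, 8).
    model = nn.Sequential(
        nn.Conv2d(3, 512, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(512, 1, 3, padding=1),
    )
    criterion = nn.MSELoss()
    optim_ = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0)

    @profile
    def train(model, criterion, optim):
        x = torch.rand(1, 3, 8, 8)
        y = torch.ones(1, 1, 8, 8)
        out = model(x)
        loss = criterion(out, y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        optim.zero_grad()
        del x, y, out, loss
        gc.collect()

    if __name__ == "__main__":
        for _ in range(2):
            train(model, criterion, optim_)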

Expected behavior

The memory usage should be relatively the same in the first pass through the training loop, and all following loops.

Environment

PyTorch version: 1.0.1.post2
Is debug build: No
CUDA used to build PyTorch: None

OS: Mac OSX 10.13.6
GCC version: Could not collect
CMake version: version 3.9.4

Python version: 3.6
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip3] numpy==1.16.2
[pip3] torch==1.0.1.post2
[pip3] torchvision==0.2.2.post3
[conda] torch 1.0.1.post2
[conda] torchsummary 1.5.1
[conda] torchvision 0.2.1

Additional context

I ran some profiles on the CPU memory usage that highlight the issue:

With backward pass and update:


Iteration 1

Line #    Mem usage    Increment   Line Contents
================================================
    23    360.0 MiB    360.0 MiB   @profile
    24                             def train(model, criterion, optim):
    25    360.1 MiB      0.0 MiB       x = torch.rand(1, 3, 8, 8)
    26    360.1 MiB      0.0 MiB       y = torch.ones(1, 1, 8, 8)
    27                             
    28    402.7 MiB     42.6 MiB       out = model(x)
    29    402.7 MiB      0.1 MiB       loss = criterion(out, y)
    30                             
    31    402.7 MiB      0.0 MiB       optim.zero_grad()
    32    663.8 MiB    261.1 MiB       loss.backward()
    33    664.0 MiB      0.1 MiB       optim.step()
    34    664.0 MiB      0.0 MiB       optim.zero_grad()
    35    664.0 MiB      0.0 MiB       del x, y, out, loss
    36    664.0 MiB      0.0 MiB       gc.collect()

Iteration 2

Line #    Mem usage    Increment   Line Contents
================================================
    23    664.0 MiB    664.0 MiB   @profile
    24                             def train(model, criterion, optim):
    25    664.0 MiB      0.0 MiB       x = torch.rand(1, 3, 8, 8)
    26    664.0 MiB      0.0 MiB       y = torch.ones(1, 1, 8, 8)
    27                             
    28    701.7 MiB     37.7 MiB       out = model(x)
    29    701.7 MiB      0.0 MiB       loss = criterion(out, y)
    30                             
    31    701.7 MiB      0.0 MiB       optim.zero_grad()
    32    671.7 MiB      0.0 MiB       loss.backward()
    33    671.7 MiB      0.0 MiB       optim.step()
    34    671.7 MiB      0.0 MiB       optim.zero_grad()
    35    671.7 MiB      0.0 MiB       del x, y, out, loss
    36    671.7 MiB      0.0 MiB       gc.collect()

Without backward pass and update:


Iteration 1

Line #    Mem usage    Increment   Line Contents
================================================
    23    351.2 MiB    351.2 MiB   @profile
    24                             def train(model, criterion, optim):
    25    351.2 MiB      0.0 MiB       x = torch.rand(1, 3, 8, 8)
    26    351.3 MiB      0.0 MiB       y = torch.ones(1, 1, 8, 8)
    27                             
    28    392.4 MiB     41.1 MiB       out = model(x)
    29    392.5 MiB      0.1 MiB       loss = criterion(out, y)
    30                             
    31    392.5 MiB      0.0 MiB       optim.zero_grad()
    32                                 #loss.backward()
    33                                 #optim.step()
    34    392.5 MiB      0.0 MiB       optim.zero_grad()
    35    361.7 MiB      0.0 MiB       del x, y, out, loss
    36    361.7 MiB      0.0 MiB       gc.collect()

Iteration 2

Line #    Mem usage    Increment   Line Contents
================================================
    23    361.7 MiB    361.7 MiB   @profile
    24                             def train(model, criterion, optim):
    25    361.7 MiB      0.0 MiB       x = torch.rand(1, 3, 8, 8)
    26    361.7 MiB      0.0 MiB       y = torch.ones(1, 1, 8, 8)
    27                             
    28    392.0 MiB     30.3 MiB       out = model(x)
    29    392.0 MiB      0.0 MiB       loss = criterion(out, y)
    30                             
    31    392.0 MiB      0.0 MiB       optim.zero_grad()
    32                                 #loss.backward()
    33                                 #optim.step()
    34    392.0 MiB      0.0 MiB       optim.zero_grad()
    35    361.7 MiB      0.0 MiB       del x, y, out, loss
    36    361.7 MiB      0.0 MiB       gc.collect()

cc @ezyang @gchanan @zou3519 @ssnl @albanD @gqchen

ssnl (Collaborator) commented Mar 30, 2019

After the first backward, the parameters' grad buffers are created, so the model takes 2x memory, as expected. In the first optim.step, if the optimizer maintains state buffers (e.g., Adam, or SGD with momentum), those buffers will be created, also as expected.
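
A quick way to see that lazy allocation (a minimal sketch with a small stand-in model, not the reproducer's network):

    import torch
    import torch.nn as nn

    model = nn.Linear(1000, 1000)                 # stand-in model
    optim = torch.optim.Adam(model.parameters())  # any stateful optimizer works here

    print(all(p.grad is None for p in model.parameters()))  # True: no grad buffers yet
    print(len(optim.state))                                  # 0: no optimizer state yet

    loss = model(torch.rand(8, 1000)).sum()
    loss.backward()        # grad buffers are allocated here (~2x parameter memory)
    optim.step()           # Adam allocates its exp_avg / exp_avg_sq buffers here

    print(all(p.grad is not None for p in model.parameters()))  # True
    print(len(optim.state))                                      # one state entry per parameter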

mdlockyer (Author) commented Mar 30, 2019

That definitely explains the bulk of the memory usage; however, it doesn't explain the increase from the first iteration to the second. That increase is more at the root of the issue, and I may have chosen a bad title. If you look at the peak usage, it is about 40 MB higher in the second pass. In the model I was training when I discovered this it was more exaggerated, almost 1 GB higher. I've checked and double-checked that no tensors are accidentally staying referenced and escaping garbage collection. In fact, I tried the same pattern as the reproduction gists, where I don't return anything at all and use del on all tensors. It still runs about 1 GB higher from the second iteration onward for that architecture. Nothing is being stored internally within the model's child Modules, so I can't explain it.

Also, I should note that no momentum was being used in any of the models I've tested. All have been SGD with momentum=0, so it shouldn't be storing any momentum data after the first pass.

ezyang added the module: memory usage, triage review, module: autograd, triaged, and high priority labels, and removed the triage review label, Apr 2, 2019
soumith (Member) commented May 27, 2019

@ezyang @gchanan could you make sure someone looks into this

ezyang (Contributor) commented Jun 11, 2019

cc @malvika2147

malvika2147 commented

@mdlockyer Can you please add the gists again? The links are currently broken.

mdlockyer (Author) commented

@malvika2147 They should be back up. Sorry about that. I finally got around to changing my old username, which broke all those links.

KaiQiao1992 commented

I'm encountering the same problem: memory is about 8 GB higher when executing the second loss.backward(). I do not know why.

mdlockyer (Author) commented Jun 26, 2019

@KaiQiao1992 it may be worth noting in this discussion that if you are using adaptive optimizers like Adam, there are a lot of buffers being created under the hood. They are very memory hungry. Adam creates two buffers that are of equal size to the weights being optimized (so in memory, it's model size x3), and if amsgrad=True it will add a third. Not sure if that is relevant to your case, but I thought I'd put it out there. And as @ssnl mentioned, those buffers are created on the first call to step(), so they will only appear after the first iteration.
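
For reference, a minimal sketch of inspecting those buffers (buffer names as they appear in the optimizer state for Adam):

    import torch
    import torch.nn as nn

    model = nn.Linear(100, 100)
    optim = torch.optim.Adam(model.parameters(), amsgrad=True)

    model(torch.rand(1, 100)).sum().backward()
    optim.step()

    # Each parameter gets exp_avg and exp_avg_sq (same shape as the parameter),
    # plus max_exp_avg_sq when amsgrad=True: roughly 3x extra parameter memory.
    for p in model.parameters():
        shapes = {k: tuple(v.shape) for k, v in optim.state[p].items() if torch.is_tensor(v)}
        print(shapes)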

mdlockyer (Author) commented

@KaiQiao1992 8 GB sounds steep for optimizer buffers though. That is a significant amount. Hopefully you're able to figure it out. This may help. It's a memory profiler for PyTorch. I haven't tested it out, but it could be of use.

KaiQiao1992 commented

@mdlockyer I did indeed use the Adam optimizer. Strangely, after rebooting the machine, I do not encounter the "out of memory" error again, even though I'm using the same Adam. Because my fc layer has size 800000*1000, the memory consumption is large.

mdlockyer (Author) commented

@KaiQiao1992 that's huge!! 800M parameters! My biggest model was 45M and I thought that was gigantic. Glad you're not getting the OOM errors now though.

prasunanand (Contributor) commented Jul 30, 2019

When I add time.sleep(20) to the reproducer code, there is no such issue. I believe the gc kicks in during the sleep().

https://gist.github.com/prasunanand/0926fe1ea453a785c967d2c444a22402

ezyang (Contributor) commented Jul 30, 2019

If that's true, swapping time.sleep(20) with gc.collect() ought to work too. Sounds like a reference cycle, in that case?
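
A minimal sketch of that check, reusing the train function and names from the reproducer sketch above:

    import gc

    for i in range(5):
        train(model, criterion, optim_)
        # Swap prasunanand's time.sleep(20) for an explicit collection. gc.collect()
        # also returns how many unreachable objects it found, which hints at whether
        # reference cycles are holding on to the memory.
        found = gc.collect()
        print(f"iteration {i}: gc.collect() found {found} unreachable objects")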

mdlockyer (Author) commented

@prasunanand that's interesting. I'll test your gist on my end when I get a chance. @ezyang In my reproduction, I have a call to collect() after each iteration already. Not sure why the sleep works but not explicit garbage collection.

TaehwanKwon commented

I had the same issue on Python 3.6.9 + torch 1.3.0, but it works fine on Python 3.7.5 + torch 1.3.0.

peterbell10 self-assigned this Nov 11, 2019
peterbell10 (Collaborator) commented

I've been unable to reproduce the sharp memory spikes from the issue, only a very slight rise in memory usage. The profile looks roughly the same for all of the Python and PyTorch versions I tried.

[memory usage plot across iterations, Python 3.6]

@TaehwanKwon would you mind posting the memory profile that you see running the cpu script on python 3.6? Also, are you using macOS like @mdlockyer?

peterbell10 removed their assignment Nov 21, 2019
rgommers added the quansight-nack label Jan 26, 2020
ezyang (Contributor) commented Jan 27, 2020

Given that this reproduces inconsistently / can be fixed by upgrading (either PyTorch or Python), I'm downgrading the priority of this issue. If someone can come up with a clear configuration on the newest Python/PyTorch that exactly causes the problem, please let us know.
