
[Bug] Memory leak on Convnet on CPU #5285

Closed
EKami opened this issue Feb 17, 2018 · 21 comments

@EKami
commented Feb 17, 2018

  • OS: Ubuntu 16.04
  • PyTorch version: 0.3.1b0+2b47480
  • How you installed PyTorch (conda, pip, source): source
  • Python version: 3.6.4
  • CUDA/cuDNN version: CUDA 9.0/CuDnn 7
  • GPU models and configuration: GTX 1080Ti (as well as GTX 1070)
  • GCC version (if compiling from source): gcc (Ubuntu 5.4.0-6ubuntu1~16.04.6) 5.4.0 20160609

Hello,
While implementing the SRGAN paper, I ran into a memory leak while doing inference on my CPU.
This issue should be easily reproducible on your side, as my implementation is here. All you have to do is clone the repository with git clone -b showcase/memory-leak git@github.com:EKami/Torchlite.git, cd into the examples folder, then run the script with python srgan.py eval --on_cpu to run the inference on the CPU. You should then get a memory leak with the following message:

RuntimeError: $ Torch: not enough memory: you tried to allocate 116GB. Buy new RAM!

If you run the same script but with python srgan.py eval (defaults to CUDA), the memory leak vanishes. The exact line that causes the memory leak is this one. Remove that line to get:

        block_x2 = self.block_x2(block1 + block_x1)  # ElementWise sum

        # TODO causes a memory leak on CPU
        # block_x3 = self.block_x3(block_x2)

        return (F.tanh(block_x2) + 1) / 2

and run the script on the CPU again with python srgan.py eval --on_cpu, and poof, the memory leak vanishes.
I tried on two different computers, each with the same software installed but different hardware. The first has:

  • AMD FX 8350
  • GTX 1070

The second:

  • Intel i7 7700k
  • GTX 1080Ti

And I get the same memory leak on both machines.

@EKami EKami changed the title [Bug] Memory leak on inference on Convnet on CPU [Bug] Memory leak on Convnet on CPU Feb 17, 2018

@goldsborough

Contributor

commented Feb 20, 2018

Can you run valgrind on your code and see if it reports anything suspicious?
@ezyang may have thoughts

@zou3519

Contributor

commented Feb 20, 2018

I'm suspecting an integer overflow somewhere. @EKami, when I try to run your code, I get the following:

Traceback (most recent call last):
  File "srgan.py", line 18, in <module>
    import torchlight.data.fetcher as fetcher
  File "/private/home/rzou/Torchlight/examples/torchlight/data/fetcher.py", line 3, in <module>
    from kaggle_data.downloader import KaggleDataDownloader
ModuleNotFoundError: No module named 'kaggle_data'

How do I install kaggle_data?

@EKami

Author

commented Feb 20, 2018

Forgot about this one, sorry @zou3519. Here is the dependency, or just pip install -U git+https://github.com/EKami/kaggle-data-downloader.git

@zou3519

Contributor

commented Feb 20, 2018

Hitting the following now:

Traceback (most recent call last):
  File "srgan.py", line 22, in <module>
    from torchlight.nn.train_callbacks import ModelSaverCallback, ReduceLROnPlateau, TensorboardVisualizerCallback
  File "/private/home/rzou/Torchlight/examples/torchlight/nn/train_callbacks.py", line 9, in <module>
    from tensorboardX import SummaryWriter
  File "/private/home/rzou/local/miniconda3/lib/python3.6/site-packages/tensorboardX/__init__.py", line 4, in <module>
    from .writer import FileWriter, SummaryWriter
  File "/private/home/rzou/local/miniconda3/lib/python3.6/site-packages/tensorboardX/writer.py", line 25, in <module>
    from .src import summary_pb2
  File "/private/home/rzou/local/miniconda3/lib/python3.6/site-packages/tensorboardX/src/summary_pb2.py", line 25, in <module>
    dependencies=[tensorboard_dot_src_dot_tensor__pb2.DESCRIPTOR,])
  File "/private/home/rzou/local/miniconda3/lib/python3.6/site-packages/google/protobuf/descriptor.py", line 829, in __new__
    return _message.default_pool.AddSerializedFile(serialized_pb)
TypeError: Couldn't build proto file into descriptor pool!
Invalid proto descriptor for file "tensorboard/src/summary.proto":
  tensorboard/src/summary.proto: A file with this name is already in the pool.

In your code,

# TODO causes a memory leak on CPU
# block_x3 = self.block_x3(block_x2)

What is the output of block_x2.size()? I'm thinking of just making a random input and sending it into a convolution and seeing if that causes an OOM.

@EKami

Author

commented Feb 20, 2018

@zou3519 You can remove TensorboardVisualizerCallback from the import and on this line to fix your issue. As for the size of block_x2, it is [16, 64, 224, 224].

@zou3519

Contributor

commented Feb 20, 2018

@EKami
Does the following code OOM for you?

import torch
import torch.nn as nn
from torch.autograd import Variable

model = nn.Conv2d(64, 3, kernel_size=9, padding=4)
x = Variable(torch.randn([16, 64, 224, 224]))
model(x)

I removed the two pieces you suggested (while trying to run your code) and am still getting a similar error. I've never worked with tensorboard before so I'm not sure what the error is saying.

@EKami

Author

commented Feb 20, 2018

@zou3519 The code runs from start to end without crashing, but it takes around 15 GB of RAM, which isn't normal IMO.

@zou3519

Contributor

commented Feb 27, 2018

I agree, I've seen around the same memory usage. Am looking into it.

@fmassa

Member

commented Feb 27, 2018

The 15 GB of RAM for this input size wouldn't surprise me, because we use the (memory-consuming) unfolding of the image to perform the convolution.
In your example, the unfolded image takes roughly 16 * 64 * 9 * 9 * 224 * 224 * 4 bytes, which is roughly 15 GB (exact code here).

The dimensions of the input image are too big for the kernel size.
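
For reference, here is a back-of-the-envelope sketch of that estimate (assuming float32 and that the whole batch is unfolded at once):

# im2col ("unfold") buffer for Conv2d(64, 3, kernel_size=9, padding=4)
# applied to a [16, 64, 224, 224] input, float32 assumed.
batch, in_channels, kh, kw, out_h, out_w = 16, 64, 9, 9, 224, 224
unfolded_bytes = batch * in_channels * kh * kw * out_h * out_w * 4
print("%.1f GiB" % (unfolded_bytes / 1024**3))  # ~15.5 GiB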

@EKami

Author

commented Feb 28, 2018

@fmassa OK, so I ran a little benchmark of the following script on both the GPU and the CPU:

import signal
import torch
import torch.nn as nn
from torch.autograd import Variable

model = nn.Conv2d(64, 3, kernel_size=9, padding=4).cuda() # cuda removed for CPU
x = Variable(torch.randn([16, 64, 224, 224])).cuda() # cuda removed for CPU
model(x)
print("Done")
signal.pause()

Here are the results:
On GPU:

+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
|=============================================================================|
|    0      1403      G   /usr/lib/xorg/Xorg                            90MiB |
|    0      7061      C   python                                      1341MiB |
+-----------------------------------------------------------------------------+

On CPU:

$ ps -eo size,pid,user,command --sort -size | awk '{ hr=$1/1024 ; printf("%13.2f Mb ",hr) } { for ( x=4 ; x<=NF ; x++ ) { printf("%s ",$x) } print "" }' | cut -d "" -f2 | cut -d "-" -f1 > output

         0.00 Mb COMMAND
     16607.09 Mb python test.py
      1172.36 Mb node current/index.js
       696.78 Mb /usr/lib/x86_64
       641.43 Mb /usr/bin/dockerd
       506.49 Mb /usr/lib/x86_64
       438.69 Mb /usr/sbin/unity
       427.19 Mb /usr/lib/snapd/snapd
       ....

But once the program reaches the signal.pause(), the memory gets freed up after a while.
Do you really think this is still normal?
Even if it's not a memory leak, is there an alternative that would let me run the code on the CPU without it taking 17 GB, but around 1.5 GB instead, as on the GPU? Thanks

@fmassa

Member

commented Feb 28, 2018

The convolutions on the GPU use cuDNN, which does not use the same unfold technique and so uses much less memory.

For the moment, I'd say that the only ways of reducing memory usage are to either go through the NNPACK binding, which in the master branch is enabled in the following cases, or to reduce the batch size / image size that you feed to your model.

I've mentioned the large memory requirements of convolutions on CPU in the past, but we didn't reach an agreement.

@EKami

Author

commented Feb 28, 2018

Thanks a lot for this information @fmassa!
Well, it's very limited =/ . At a high level, does that mean I only have to use the .cpu() directive if I want to use NNPACK (assuming my neural network is compatible)? That's sad, because I wanted to run inference on the CPU on AWS Lambda for my algorithm (with Lambda being limited to 3 GB of RAM), but now that this issue exists I'll have to look at another framework like TF :( .

@fmassa

Member

commented Feb 28, 2018

There is some work being done to use MKL-DNN in PyTorch #4186, but I'm not sure about its status. They seem to have a branch linked in that thread. Another possibility would be to make the im2col operate on single elements of the batch (instead of the whole batch); that would reduce the memory requirements by a factor of 16 in your case, but would make things a bit slower.
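
A rough user-level sketch of the same idea (not the internal im2col change, just splitting the batch at the call site during inference; it reuses the shapes from the snippet above and the 0.3.x Variable API):

import torch
import torch.nn as nn
from torch.autograd import Variable

model = nn.Conv2d(64, 3, kernel_size=9, padding=4)
x = Variable(torch.randn([16, 64, 224, 224]), volatile=True)  # volatile: inference only

# Feed one sample at a time so only a batch of 1 is unfolded at once,
# then stitch the outputs back together.
out = torch.cat([model(x[i:i + 1]) for i in range(x.size(0))], dim=0)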

One question: it seems that you want to generate an image that is of size 1800x1800 (8x upsample of 224x224), is that right?

@EKami

Author

commented Feb 28, 2018

One question: it seems that you want to generate an image that is of size 1800x1800 (8x upsample of 224x224), is that right?

Yes, when I run inference I only process one batch at a time, but my resulting image can be even bigger than 1800x1800 (depending on the input image size). Here is the code which does that. It seems there isn't much hope for CPU inference for me for now, considering the upscaling factor will take even more memory, I believe.

@mingfeima

Contributor

commented Mar 15, 2018

@EKami I ran your benchmark on our repo with MKL-DNN integrated. The memory consumption is 3.1 GB (bench.py is mine; the rest of the items belong to other users):

      3157.19 Mb python bench.py
      2532.11 Mb python
      2181.91 Mb ./LSA start
      1265.99 Mb /usr/libexec/mysqld
      579.24 Mb vim applications/convergence.py
      441.38 Mb /usr/lib/polkit
      432.87 Mb /usr/libexec/xdg

I also tried to run python srgan.py eval --on_cpu but got the error below, because I didn't compile PyTorch with CUDA. But why do I need CUDA if I'm running CPU inference? Am I missing anything?

Traceback (most recent call last):
  File "srgan.py", line 161, in <module>
    main()
  File "srgan.py", line 155, in main
    evaluate(args)
  File "srgan.py", line 76, in evaluate
    ModelSaverCallback.restore_model([netG], saved_model_dir.absolute())
  File "/home/mingfeim/pytorch/Torchlite/examples/torchlight/nn/train_callbacks.py", line 269, in restore_model
    model.load_state_dict(torch.load(file))
  File "/home/mingfeim/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 267, in load
    return _load(f, map_location, pickle_module)
  File "/home/mingfeim/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 432, in _load
    result = unpickler.load()
  File "/home/mingfeim/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 401, in persistent_load
    data_type(size), location)
  File "/home/mingfeim/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 87, in default_restore_location
    result = fn(storage, location)
  File "/home/mingfeim/anaconda3/lib/python3.6/site-packages/torch/serialization.py", line 69, in _cuda_deserialize
    return obj.cuda(device)
  File "/home/mingfeim/anaconda3/lib/python3.6/site-packages/torch/_utils.py", line 68, in _cuda
    with torch.cuda.device(device):
  File "/home/mingfeim/anaconda3/lib/python3.6/site-packages/torch/cuda/__init__.py", line 225, in __enter__
    self.prev_idx = torch._C._cuda_getDevice()
AttributeError: module 'torch._C' has no attribute '_cuda_getDevice'
@EKami

Author

commented Mar 15, 2018

Hey @mingfeima, that's awesome! Thanks a lot! Do you plan to merge this into the official PyTorch repo?
As for the CUDA error with python srgan.py eval --on_cpu, I just pushed a fix if you want to test again with this command from the showcase/memory-leak branch.

@mingfeima

Contributor

commented Mar 16, 2018

@EKami First of all, sorry to say that I didn't solve your problem at the root.
I believe you also need to update here to enable CPU inference:

    # ModelSaverCallback.restore_model([netG], saved_model_dir.absolute())
    load_with_cpu = False if args.cuda else True
    ModelSaverCallback.restore_models([netG], saved_model_dir.absolute(), load_with_cpu)

Anyway, I ran python srgan.py eval --on_cpu and the code still hangs inside MKL-DNN because there is not enough memory. The reason is that MKL-DNN has several paths for the convolution computation depending on input_channel and output_channel. Only an output_channel that is a multiple of 16 goes through the direct path (small memory footprint), and self.block_x3 = nn.Conv2d(64, 3, kernel_size=9, padding=4) has an output channel of 3, which goes through im2col (very large footprint). And your input for block_x3, which is [1, 64, 808, 1152], will take too much memory. We need to solve this at the root, which is probably going to take some time.
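
For a rough sense of scale, here is a sketch of the im2col buffer that this input implies (assuming float32 and a same-size output, given the 9x9 kernel with padding 4):

# im2col buffer estimate for block_x3 on the [1, 64, 808, 1152] input, float32 assumed.
batch, in_channels, kh, kw, out_h, out_w = 1, 64, 9, 9, 808, 1152
unfolded_bytes = batch * in_channels * kh * kw * out_h * out_w * 4
print("%.1f GiB" % (unfolded_bytes / 1024**3))  # ~18 GiB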

@EKami

Author

commented Mar 16, 2018

@mingfeima Oh right, excuse me, I pushed a fix but didn't test for CPU inference before fixing, as this branch of my code is now very far behind the one on master.

Oh OK, thanks for your solution anyway. I believe we won't see a memory footprint improvement for CPU inference anytime soon, as people mainly use PyTorch for rapid prototyping on CUDA and the CPU is a bit forgotten...

@apaszke

Member

commented Mar 16, 2018

@EKami it's true that it's not as good as our CUDA backend, but we are working on it, so it should get better soon!

@Katrien-Declercq

commented Feb 15, 2019

Hello,
What is the status of this memory leak issue in CPU mode?
I'm running into the same problem and have no access to a GPU.

@soumith

Member

commented Feb 15, 2019

This original issue has been fixed, as PyTorch now ships with MKL-DNN by default.

@soumith soumith closed this Feb 15, 2019
