[Bug] MaxPool3d causes GPU memory leaking #6222
Comments
Thanks for the report, @xichangzun. I can reproduce this and am looking into it.
zou3519 added a commit to zou3519/pytorch that referenced this issue on Apr 3, 2018:

Fixes pytorch#6222. We don't need to make sure gradInput is contiguous, because it is always passed in as an empty tensor (see CUDAFloatType.cpp after it gets codegen-ed). This was increasing the reference count on gradInput and leaking it. I'm not sure if there's a good way to test this. I put together a script that 1) prints out when a tensor is allocated and deallocated, and 2) checks allocations vs. deallocations after running a Python script, and verified that each allocation matches each deallocation.
soumith pushed a commit that referenced this issue on Apr 3, 2018, with the same commit message (Fixes #6222).
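The allocation-tracing script mentioned in the commit message is not included in this thread. As an illustrative alternative (my own sketch, not the author's script, and assuming a PyTorch build that provides torch.cuda.memory_allocated()), a leak like this can also be spotted from Python by checking whether allocated GPU memory keeps growing across iterations:

```python
# A rough Python-side check for this kind of leak (NOT the
# allocation-tracing script from the commit message): run forward and
# backward repeatedly and watch torch.cuda.memory_allocated(); on an
# affected build the number keeps climbing instead of flattening out.
import torch
import torch.nn as nn

def shows_memory_growth(module, input_shape, iters=20):
    module = module.cuda()
    readings = []
    for _ in range(iters):
        x = torch.randn(*input_shape, device="cuda", requires_grad=True)
        out = module(x)
        out.sum().backward()
        del x, out
        torch.cuda.synchronize()
        readings.append(torch.cuda.memory_allocated())
    # Compare the last reading against one from the middle of the run;
    # a healthy build settles, a leaking one keeps growing.
    return readings[-1] > readings[len(readings) // 2]

if __name__ == "__main__":
    # Arbitrary example shape: batch 2, 3 channels, 16x32x32 volume.
    print(shows_memory_growth(nn.MaxPool3d(kernel_size=2), (2, 3, 16, 32, 32)))
```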
Original report from @xichangzun:

When I try to train my model, which contains MaxPool3d, it always ends with an "out of memory" error.

My environment info is here:

I can reproduce this bug with the following script:
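The reporter's environment details and reproduction script were not captured in this snapshot. The following is a hedged reconstruction of the kind of minimal script that triggers the leak, not the reporter's original code; the pooling parameters, tensor shape, and loop length are my own assumptions:

```python
# Hypothetical minimal repro (not the original reporter's script):
# repeatedly run MaxPool3d forward + backward. On affected builds,
# allocated GPU memory grows every iteration until CUDA reports
# "out of memory".
import torch
import torch.nn as nn

pool = nn.MaxPool3d(kernel_size=2).cuda()

for step in range(1000):
    x = torch.randn(4, 8, 32, 64, 64, device="cuda", requires_grad=True)
    out = pool(x)
    out.sum().backward()
    if step % 100 == 0:
        print(step, torch.cuda.memory_allocated())
```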