Memory 'leak' issue with MapTable and clearState #1141

Closed
achalddave opened this issue Feb 20, 2017 · 1 comment

achalddave commented Feb 20, 2017

Calling clearState() on MapTable seems to lead to a significant increase in memory usage in later iterations. I was unable to find the cause, but the following script demonstrates the bug. (Of course, the exact memory amounts may differ across systems, but the relative amounts should hold.)

--[[ Displays memory 'leak' with nn.MapTable after clearState() ]]--

local nn = require 'nn'
local cunn = require 'cunn'

-- SpatialConvolution(nInputPlane, nOutputPlane, kW, kH, dW, dH, padW, padH)
local model = nn.MapTable():add(nn.SpatialConvolution(3, 256, 3, 3, 1, 1, 1, 1))
                           :cuda()
local i = {torch.rand(30, 3, 224, 224):cuda(), torch.rand(30, 3, 224, 224):cuda()}

-- Print this process's GPU memory usage as reported by nvidia-smi.
local function check_mem() os.execute('nvidia-smi | grep luajit') end

-- Run two forward passes without clearState:
print('Before training 1'); check_mem() -- 277 MiB
local o = model:forward(i)
print('After forward 1'); check_mem() -- 3254 MiB

-- Second forward pass:
print('Before training 2'); check_mem() -- 3254 MiB
o = model:forward(i)
print('After forward 2'); check_mem() -- 3254 MiB

-- Clear state:
model:clearState()
collectgarbage()
collectgarbage()

-- Run a final forward pass after clearState. This call increases memory usage
-- for the rest of the program!
print('Before training 3 (after clearState)'); check_mem() -- 3254 MiB
o = model:forward(i)
print('After forward 3 (after clearState)'); check_mem() -- 4724 MiB!

-- Garbage collection doesn't fix it.
collectgarbage()
collectgarbage()
print('After collectgarbage()'); check_mem() -- 4724 MiB!

The script is at: https://gist.github.com/achalddave/6ac8390e06a23ecc6d67e3fa22ef0f04

A few notes:

  • The issue does not seem to occur if I forward an input table with only one element.
  • The issue does not occur if nn.MapTable is removed (and replaced with just a single SpatialConvolution operating on a single input tensor).

achalddave (contributor, issue author) commented:

Update: This issue is fixed if clearState() is called on each module in the MapTable before it is removed from the MapTable, which happens here:

function MapTable:clearState()
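In other words, the fix is to clear each per-input clone before dropping it. The following is a minimal sketch of that idea as a runtime patch, not the actual commit. It assumes (from general knowledge of torch/nn, not from this thread) that `nn.MapTable` derives from `nn.Container`, keeps its per-input clones in `self.modules[2..n]`, and drops them in its own `clearState`.

```lua
require 'nn'

-- Assumption: nn.MapTable derives from nn.Container and keeps its
-- per-input clones in self.modules[2..n]; a sketch, not upstream code.
local MapTable = torch.getmetatable('nn.MapTable')
local Container = torch.getmetatable('nn.Container')

function MapTable:clearState()
   -- Release each clone's buffers before dropping the reference to it,
   -- instead of only dropping the reference.
   for k = 2, #self.modules do
      if self.modules[k] then
         self.modules[k]:clearState()
      end
      self.modules[k] = nil
   end
   -- Clear this container's own output/gradInput and modules[1].
   return Container.clearState(self)
end
```

The next forward call re-creates the clones, so this only changes what happens to the old clones' buffers.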

achalddave added a commit to achalddave/nn that referenced this issue Feb 20, 2017
soumith closed this as completed Feb 21, 2017
achalddave added a commit to achalddave/predictive-corrective that referenced this issue Jul 20, 2017
This should fix the same issue as in
torch/nn#1141
achalddave added a commit to achalddave/predictive-corrective that referenced this issue Jul 20, 2017
Calling clearState() seems to cause issues that, after 4-5 days of
debugging, I haven't been able to fix. See, for example:

torch/nn#1141
torch/cunn#441

Further, it's unclear to me whether `getParameters`, and memory management in
general, work well when a call to `clearState` can destroy modules (and
therefore weight tensors). The easiest solution to all of this is simply
never to call clearState on the model while it is training.

When saving the model, we create a copy of it on the CPU, and call
clearState on this CPU copy, which we then save to disk.
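The save-time approach described in this commit message can be sketched as follows. This is a hypothetical helper (`save_cleared_cpu_copy` is an illustrative name, not code from the commit). It assumes nn-style modules whose `float()` converts all parameters and buffers. The training model is never cleared; only its CPU copy is.

```lua
require 'nn'

-- Hypothetical helper (illustrative name, not from the commit): write a
-- cleared CPU copy of `model` to `path` without touching `model` itself.
local function save_cleared_cpu_copy(model, path)
   -- clone() deep-copies through serialization; for a GPU model this briefly
   -- needs room for a second copy on the GPU before the conversion below.
   local copy = model:clone()
   copy:float()      -- move every parameter and buffer to torch.FloatTensor
   copy:clearState() -- free intermediate buffers in the copy only
   torch.save(path, copy)
   return copy
end
```

Because clearState runs only on the copy, the training model keeps its MapTable clones and buffers between checkpoints.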