Calling variable.assign() too many times crashes on memory allocation. #2311

Closed
jhollowayj opened this issue May 10, 2016 · 6 comments

jhollowayj commented May 10, 2016

Background: I'm working on a set of networks that only share some layers, so I have a parameter server that sends new weights for the different clients to use. These clients accept the new weights and biases for the layers they are using and assign the values to the tf.Variables via sess.run(self.w1.assign(new_weights)). However, when I start it up and let it run, it crashes saying:

W tensorflow/core/common_runtime/bfc_allocator.cc:271] Ran out of memory trying to allocate 16B.  See logs for memory state.

(Sometimes it's trying to allocate 16B, other times it's 3.9KiB.)

To give you an idea of the size of the weights, I have three layers of:
Layer 1(W,b): (2, 1000), (1000, )
Layer 2(W,b): (1000, 1000), (1000, )
Layer 3(W,b): (1000, 4), (4, )
I'm running on a Titan X with 12G memory.

With per_process_gpu_memory_fraction = 0.01, the program dies at ~190 assign commands.
With per_process_gpu_memory_fraction = 0.02, the program dies at ~384 assign commands.
With per_process_gpu_memory_fraction = 0.03, the program dies at ~780 assign commands.
With per_process_gpu_memory_fraction = 0.04, the program dies at ~784 assign commands.
With per_process_gpu_memory_fraction = 0.05, the program dies at ~1582 assign commands.
With per_process_gpu_memory_fraction = 0.06, the program dies at ~1586 assign commands.

I've tried setting allow_growth=True and deferred_deletion_bytes=1 in the session's GPUOptions after reading issue #1578, but that didn't get me much further. (I have no idea what deferred_deletion_bytes does...) Looking at the numbers just above (GPU memory fraction vs. assign commands), the relationship is fairly linear, so it seems to me that the assign operation takes some GPU RAM that is never freed. Is there any kind of GC for the GPU memory allocated during the var.assign() op?
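
For reference, here is roughly how I'm constructing the session with those options (a sketch; the fraction value and variable names are just illustrative):

import tensorflow as tf

# Sketch: the GPUOptions settings described above (illustrative fraction value).
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.01,
                            allow_growth=True,
                            deferred_deletion_bytes=1)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))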

It seems that I could delete the session and create a new one, but that sounds expensive, and I'd have to maintain the weights outside the session to be able to restore them correctly. The second idea I had would be to use placeholders and ship the weights in every time with the feed_dict, but again, that seems less than ideal, and I think the optimizer would struggle to know what to optimize if they are just placeholders.

Let me know if you would like any other logs or reports from me. I figure this is the first time someone has tried to use assign operations like this, so I want to be helpful in fixing it if it's a bug.

Thanks

Environment info

Operating System: Ubuntu 16.04
Installed version of CUDA and cuDNN:
/usr/local/cuda/lib/libcudart.so -> libcudart.so.7.0
/usr/local/cuda/lib/libcudart.so.7.0 -> libcudart.so.7.0.28
/usr/local/cuda/lib/libcudart.so.7.0.28
/usr/local/cuda/lib/libcudart_static.a
Built from source. Commit hash: 35cd6a3

jhollowayj commented May 12, 2016

Here's a quick script that should break when you run it (if it helps...). Mine dies on iteration 31.
crash_tf_assign_op.py.txt
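
For reference, the pattern the script exercises is roughly the following (a sketch with illustrative names and shapes, not the attached file itself):

import numpy as np
import tensorflow as tf

# Sketch of the failing pattern (illustrative shapes; not the attached script).
w1 = tf.Variable(tf.zeros([1000, 1000]))
new_value_array = np.random.rand(1000, 1000).astype(np.float32)

gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.01)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
sess.run(tf.initialize_all_variables())

for i in range(3000):
    print "Assigning i:{}".format(i)
    # This is the call that eventually dies with the OOM error above.
    sess.run(w1.assign(new_value_array))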

mrry commented May 16, 2016

The assign op itself is not consuming the memory; the problem is that each instance of new_weights is converted to a constant op and added to the graph. Each constant op owns a buffer containing the value it produces, and a constant op placed on the GPU allocates that buffer in GPU memory.

The fix is to rewrite your program somewhat. Instead of doing:

for i in range(3000):
    print "Assigning i:{}".format(i)
    # Each iteration adds a new constant op (holding new_value_array) and a new
    # assign op to the graph, so GPU memory grows on every call.
    sess.run(w1.assign(new_value_array))

... you should declare the assign op and a placeholder before the loop, and feed different values to the placeholder in each iteration:

# Build the placeholder and assign op once, outside the loop.
assign_placeholder = tf.placeholder(tf.float32, shape=[1000, 1000])
assign_op = w1.assign(assign_placeholder)

for i in range(3000):
    print "Assigning i:{}".format(i)
    # Feed a new value each iteration; no new nodes are added to the graph.
    sess.run(assign_op, feed_dict={assign_placeholder: new_value_array})

mrry closed this as completed May 16, 2016

jhollowayj commented

That totally makes sense now. I never would have guessed to do that though. Thanks so much.

mrry commented May 16, 2016

Indeed - it's a difficult error to disallow, because there are many totally valid patterns that involve adding nodes to the graph. One tip is to try calling tf.get_default_graph().finalize() before your training loop, so that an error will be thrown if you accidentally add a node. (However, we can't do that automatically - e.g. on the first run() call - because it would break a huge number of people :( ...)
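
For example, building on the snippet above (a sketch):

assign_placeholder = tf.placeholder(tf.float32, shape=[1000, 1000])
assign_op = w1.assign(assign_placeholder)

# Finalize once the graph is fully built; any later attempt to add a node
# (e.g. an accidental w1.assign(numpy_array) inside the loop) raises an error.
tf.get_default_graph().finalize()

for i in range(3000):
    sess.run(assign_op, feed_dict={assign_placeholder: new_value_array})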

jhollowayj commented

Thanks @mrry. Everything is running great again.

mrry commented May 17, 2016

Glad to hear it!
