
SummaryWriter writes constants with the graph definition #1444

Closed

vgatto opened this issue Mar 9, 2016 · 8 comments
vgatto commented Mar 9, 2016

It looks like SummaryWriter includes the values of constants when writing the graph definition. I'm currently loading word embeddings in Python and using the resulting numpy array to initialize a variable. This causes ~300MB of embeddings to be dumped to disk whenever the graph is written. I understand why a constant would be part of the model graph if the goal were to save and restore the graph, but I don't believe that is the goal of the SummaryWriter. Is it possible to strip these out to save disk space?

Here's a trivial example. If the declaration of random_b is removed, the resulting summary file is ~4.6KB; with it, it's about 7.6MB.

import numpy as np
import os
import tensorflow as tf

session = tf.Session()

logdir = "/tmp/tflogs"

# Initialized by a TensorFlow op: only the op is recorded in the GraphDef.
random_a = tf.Variable(tf.random_normal([1000000]))
# Initialized from a numpy array: the values are embedded in the GraphDef
# as a Const node, which is what bloats the file.
random_b = tf.Variable(np.random.rand(1000000))
tf.histogram_summary("random_var", random_a)
os.makedirs(logdir)
writer = tf.train.SummaryWriter(logdir, session.graph_def)
init = tf.initialize_all_variables()
merged_summary_op = tf.merge_all_summaries()
session.run(init)

summary = session.run(merged_summary_op)
writer.add_summary(summary, 0)
vgatto commented Mar 9, 2016

I'd also like to point out that this makes it impossible to visualize the graph in TensorBoard, since it just hangs while parsing the constants in the graph def.

girving commented Mar 10, 2016

Constants are part of the GraphDef, so unfortunately this is intended behavior. To keep them out of the GraphDef, change from

random_a = tf.Variable(tf.random_normal([1000000]))

to

random_a = tf.Variable(shape=[1000000])
random_a.assign(np.random.randn(*random_a.shape)).eval()

Ugh, actually that'll implicitly create a constant too. @petewarden: What's the right solution here?

mrry commented Mar 10, 2016

The same question came up on StackOverflow today: http://stackoverflow.com/a/35904439/3574081

(FWIW: random_a = tf.Variable(tf.random_normal([1000000])) won't add a large constant to the graph, but random_a = tf.Variable(np.random.normal(size=[1000000])) will.)
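For reference, here's a minimal sketch of the placeholder-based initialization from that answer, assuming the same TF 0.x API used elsewhere in this thread:

import numpy as np
import tensorflow as tf

session = tf.Session()

# Feed the numpy data in at initialization time via a placeholder, so the
# values never become a Const node in the GraphDef.
init_value = tf.placeholder(tf.float64, shape=[1000000])
random_b = tf.Variable(init_value)

session.run(random_b.initializer,
            feed_dict={init_value: np.random.rand(1000000)})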

petewarden commented

To answer what I think is the original question, it's definitely possible to strip out Constant ops from a graph def. We do the opposite in the freeze_graph script, so we could create a 'strip_graph' script or similar to slim down files by replacing large constants with something smaller (though we'd have to do some resize tricks to do it properly).
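As a rough illustration, a stripping pass over the GraphDef proto might look like this (strip_large_constants is a hypothetical helper, not an existing TensorFlow API):

import tensorflow as tf

def strip_large_constants(graph_def, max_bytes=1024):
    # Hypothetical helper: copy the GraphDef and blank out the raw bytes of
    # any Const node larger than max_bytes, keeping names, dtypes, and
    # shapes so the graph still renders in TensorBoard.
    stripped = tf.GraphDef()
    stripped.CopyFrom(graph_def)
    for node in stripped.node:
        if node.op == "Const":
            tensor = node.attr["value"].tensor
            if len(tensor.tensor_content) > max_bytes:
                tensor.tensor_content = b""
    return stripped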

I think the bigger question is 'Why are there these big weights in my graph def file?'. We're working on answering that here by looking at different ways that we can save external weights, without using the Variable/Restore checkpoints since those are fairly specialized for training. We don't have a good design for that yet though.

Does that help at all?

girving commented Mar 10, 2016

@mrry: That solves the code shown, but I believe the original code loads a big numpy array from a file.

vgatto commented Mar 10, 2016

@mrry's suggestion works perfectly as a workaround for my large numpy array (I'm responsible for the Stack Overflow question as well), and I'm able to use TensorBoard's graph visualization again.

@petewarden - FWIW, I never use the Saver to persist the entire session. I'm always cherry-picking variables so that I can deploy just the model to other machines for inference, or train on a different dataset. So any help with managing training state vs. learned model state would be great.
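For concreteness, a minimal sketch of that cherry-picking pattern using tf.train.Saver's var_list argument (the variable names here are hypothetical):

import tensorflow as tf

weights = tf.Variable(tf.random_normal([1000, 10]), name="weights")
biases = tf.Variable(tf.zeros([10]), name="biases")
step = tf.Variable(0, trainable=False, name="global_step")  # training-only state

# Only the model parameters go into the checkpoint; training state such as
# global_step is left out, so the file can be shipped for inference.
saver = tf.train.Saver(var_list=[weights, biases])
# saver.save(session, "/tmp/model.ckpt")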

thomasnedelec commented

@vgatto: "I'd also like to point out that this makes it impossible to visualize the graph in tensorboard, since it just hangs parsing the constants in the graph def." Can you explain why this happens, and how I can display the graph?

girving commented Jun 6, 2016

Closing this as intended behavior. @toma5692 and others: If you want to keep graphs small, use tf.Variable instead of tf.constant.
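A quick way to see the difference, assuming the same TF 0.x API as above (GraphDef is a protobuf, so ByteSize() reports its serialized size):

import tensorflow as tf

# The million literal values are baked into the GraphDef as a Const node:
big_const = tf.constant([0.0] * 1000000)

# Only the initializer op is recorded; the values live in the variable's
# runtime buffer (and in checkpoints), not in the GraphDef:
big_var = tf.Variable(tf.random_normal([1000000]))

print(tf.get_default_graph().as_graph_def().ByteSize())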

girving closed this as completed Jun 6, 2016