
Horrible tensorboard graph: #19

Closed
MadcowD opened this issue Feb 24, 2017 · 3 comments

Comments

@MadcowD

MadcowD commented Feb 24, 2017

Hi,

I'm trying to scope certain components of my TensorFlow Fold model, but after rendering it in TensorBoard it seems that it's ignoring the scopes in which I construct the different blocks. See this: http://i.imgur.com/Vtf9OhP.png

The associated source code which generates this graph is:

with tf.variable_scope("context_free_discriminator"):
	with tf.variable_scope("embedding"):
	    pre_process =  (td.InputTransform(lambda tokens: [tokeToVec(s, word2vec, indim) for s in tokens]) 
	                    >> td.Map(td.Vector(indim)))
	    
	    word_cell = td.ScopedLayer(tf.contrib.rnn.BasicLSTMCell(num_units=300), 'word_cell')
	    stacked_cell =  td.ScopedLayer(tf.contrib.rnn.BasicLSTMCell(num_units=100), 'stacked_cell')

	    # word lstm convers word sequence into an lstm.
	    word_lstm = (pre_process >> td.RNN(word_cell))
	    hierarchical_lstm = (word_lstm 
	                         >> td.GetItem(0) #hidden states
	                         >> td.RNN(stacked_cell))
	    embedding =  (hierarchical_lstm  
	                                      >> td.GetItem(1) # Final state (out, hidden)
	                                      >> td.GetItem(0) # Final output state.
	                                      >>  td.FC(outdim))

	with tf.variable_scope("model"):
	    logit_output_block = (embedding_block 
	              >> td.FC(428)
	              >> td.FC(328)
	              >> td.FC(num_mocs, activation=None))
	    compiler = td.Compiler.create(logit_output_block)
	    logit_tensor = compiler.output_tensors

	with tf.variable_scope("training"):
	    _desired = tf.placeholder(tf.int32)
	    _desired_onehot = tf.one_hot(_desired, num_mocs)
	    loss = cross_entropy = tf.reduce_mean(
	        tf.nn.softmax_cross_entropy_with_logits(labels=_desired, logits=logit_tensor))

	    train_op = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
@fabioasdias

My understanding is that since td is built on top of tf, not everything will be accessible the way it would be in "pure" TensorFlow; specifically, scopes are probably used by td to deal with the dynamic part of building the graph. I'm not entirely sure that the tf.variable_scope commands are compatible with the whole td.Compiler machinery. If they were, you wouldn't need a ScopedLayer, IMHO.

A good example of that is the result of the 'word_cell' scope, which is "internally" translated into several blocks.

Of course, I'm a random internet dude, not a contributor to the project, so everything I said might be horribly wrong.

@delesley
Contributor

delesley commented Feb 27, 2017 via email

@moshelooks

Yeah, the TensorBoard graph view is unfortunately not very helpful with Fold. Regarding scoping: if you know what you're doing, you can pass a scope as an argument to td.ScopedLayer and it will be used exactly as provided. If you pass in a string, we go through some gymnastics to ensure that every ScopedLayer really does correspond to a unique scope; we do this to avoid some breakage and gotchas in the TF variable-scoping mechanism. What Delesley said is a very good suggestion for making sure the right variables are being created. If you are observing undersharing (e.g. a single scoped layer whose variables are being created more than once), that is a bug; please do report it. But the view of the graph in TensorBoard is not sufficient to figure out what's going on.
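
For concreteness, a minimal sketch of the distinction described above, passing a scope object rather than a string to td.ScopedLayer so that it is used exactly as provided; the cell size and scope name here are illustrative, not taken from this issue:

import tensorflow as tf
import tensorflow_fold as td

# Create the variable scope up front so the scope object itself can be passed in.
with tf.variable_scope('word_cell') as word_scope:
    pass

# With a string argument, Fold uniquifies the scope name behind the scenes;
# with a tf.VariableScope object, the scope is used as-is.
cell = tf.contrib.rnn.BasicLSTMCell(num_units=300)
word_cell = td.ScopedLayer(cell, word_scope)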

Regarding the while loop, I'm confused; Fold certainly creates only one while loop per compiler. Maybe you ran Block.eval() somewhere? (eval() is only for messing around at the REPL and in unit tests; it shouldn't be used in "real" models.)
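
As a minimal sketch of the intended usage, evaluating through a single td.Compiler rather than Block.eval(); here `block` and `examples` are placeholder names, not taken from this issue:

import tensorflow as tf
import tensorflow_fold as td

# `block` is a td block and `examples` is a list of its inputs, defined elsewhere.
# One compiler per model: this builds the single dynamic while loop once.
compiler = td.Compiler.create(block)
output_tensor = compiler.output_tensors[0]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Batches are fed through the compiler instead of re-tracing the block per call.
    feed_dict = compiler.build_feed_dict(examples)
    result = sess.run(output_tensor, feed_dict=feed_dict)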

This issue was closed.