Feature Request: multi-epoch alternative to tf.QueueBase.close() #2514

Closed · markpwoodward opened this issue May 26, 2016 · 62 comments
Labels: stat:contribution welcome (Status - Contributions welcome), type:feature (Feature requests)
@markpwoodward (Author)

Examples on the web demonstrate signaling the end of queue data by calling queue.close() in the data producer and then catching the tf.errors.OutOfRangeError exception in the data consumer.

This works fine for a single epoch, but I do multiple epochs, alternating between training data and testing data, and I can't reuse the queue after calling queue.close().

The two solutions that I have thought of using the existing code are:

  1. enqueue() some sentinel at the end of an epoch in the data producer, then tf.Assert() against the sentinel and catch the tf.errors.InvalidArgumentError in the data consumer.
  2. Know the number of enqueue()'s for the epoch and only dequeue that number.

Both seem a little hacky.

Multi-epoch use of queues might be simplified by adding one of the following:

  1. A queue.reset() that throws a single tf.errors.OutOfRangeError (or some other exception) on dequeue().
  2. A queue.close(reset=True) that throws only a single tf.errors.OutOfRangeError (or some other exception) on dequeue().

example usage of 1):

q = tf.FIFOQueue(...)
placeholder = ...
enqueue_op = q.enqueue(placeholder)
....

def producer(data_dir, sess, q, enqueue_op, placeholder):
  for ...:
    sess.run(enqueue_op, {placeholder:...})
  sess.run(q.reset())

def do_epoch(data_dir, learn):
  threading.Thread(target=producer, args=(data_dir, sess, q, enqueue_op, placeholder)).start()
  while True:
    try:
      sess.run(...)
    except tf.errors.OutOfRangeError:
      break

for epoch in range(NUM_EPOCHS):
  ... = do_epoch(TRAIN_DIR, learn=True)
  ... = do_epoch(TEST_DIR, learn=False)
@vrv commented May 26, 2016

cc @ebrevdo @mrry @josh11b

@ebrevdo (Contributor) commented May 26, 2016

Why not build a separate graph with its own queue for eval? You'll also get the flexibility of being able to extend the eval graph for eval-specific purposes.

@markpwoodward (Author)

Thank you for the suggestion. I guess I stayed away from two graphs to avoid code duplication between train and eval graphs and the accompanying maintenance issues, and because it would require updating the eval model weights with the trained model weights. I could define them both within one graph and share the weights, but that feels messy and still has the duplication.

With a single model, I can just run sess.run([loss, train_op], ...) for training, or leave off train_op for eval: sess.run([loss], ...). It seems clean: there is only one place to modify the model, and the weights are ready from the previous training invocation.

@markpwoodward (Author)

If this isn't how others are structuring their code, then there is no urgency/necessity. I have been pre-computing the number of enqueue()'s that will be performed in an epoch, and then running the consumer the appropriate number of times, with asserts that the producer is not alive and the queue is empty before starting the next epoch.

@ebrevdo (Contributor) commented May 27, 2016

People usually use one method to create the core of the model, a separate method to take its output and generate the loss and gradients, and a third to generate the eval losses. This way you can share code and variables between the two graphs. See tf.contrib.learn.
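
For illustration, a minimal sketch of that pattern in TF 1.x-style graph construction (the function and placeholder names here are made up for the example, not taken from tf.contrib.learn):

import tensorflow as tf

def core(inputs):
  # Shared model body; variables are created on the first call and reused afterwards.
  w = tf.get_variable('w', [4, 2])
  b = tf.get_variable('b', [2], initializer=tf.constant_initializer(0.0))
  return tf.matmul(inputs, w) + b

def train_graph(inputs, labels):
  logits = core(inputs)
  loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits))
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
  return loss, train_op

def eval_graph(inputs, labels):
  logits = core(inputs)
  return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits))

train_x = tf.placeholder(tf.float32, [None, 4])
train_y = tf.placeholder(tf.int64, [None])
eval_x = tf.placeholder(tf.float32, [None, 4])
eval_y = tf.placeholder(tf.int64, [None])

# Same variable scope, reuse=True on the second call, so the train and eval
# towers share the same weights.
with tf.variable_scope('model'):
  train_loss, train_op = train_graph(train_x, train_y)
with tf.variable_scope('model', reuse=True):
  eval_loss = eval_graph(eval_x, eval_y)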

@markpwoodward (Author)

I think I see what you are saying. Just to be certain, do you by chance have a specific example? I looked through tensorflow/contrib/learn/python/learn but did not find an example of this on my first scan. No worries if not, thank you.

@22csnyder commented Jun 2, 2016

I have exactly the same issue as markpwoodward. If you have TensorFlow v0.8, you might be able to use this workaround, where by default you stream directly from the training queue but can occasionally pass test data in through a feed_dict. To be fair, I haven't tried it, and I came here looking for a better solution. I imagine it would look like this:

## Within the model
input = tf.placeholder_with_default(training_dequeue, shape)

### Do training until the end of the epoch
sess.run(model.update)

### Do a test/validation epoch
test_batch = sess.run(test_dequeue)
sess.run(model.accuracy, feed_dict={input: test_batch})


@aselle added the enhancement and stat:contribution welcome labels on Jun 6, 2016
@osdf commented Jul 18, 2016

Picking up this old thread, a question: why don't you use make_template and have three queues feeding into three automatically shared model graphs? True, you still need to keep track of when an epoch ends and switch over to, e.g., the validation queue, but this can be done rather simply by either counting or having the queues return an additional element (additional w.r.t. inputs and targets) that indicates one full pass through the data (training or validation data) is done. Or is this approach considered harmful in some way?

@markpwoodward (Author)

I ended up doing the switching between validation and train in a python
producer thread, and counting the number of elements processed. But I have
been thinking about this a bit, and I am leaning towards a similar (maybe
same) idea as yours, where there are two queues (train and validation),
with a tf.cond switching between queues, controlled by a placeholder. This
would allow for a single model graph (no need for multiple instantiations
with shared weights). Funny enough, Rohit Girdhar just emailed the list
with a question, and it seems that this is the setup he is using. I copied
the relevant portion below.

"I'm training a deep network with two data input pipelines, one for
training and one for validation. They use shuffle_batch_join and
batch_join respectively for parallel data reading. The data stream that
is used in the network is decided using a tf.cond operation on top of
these two pipelines, which is controlled by a is_training placeholder
that is set to true for a training iteration and false when doing
validation. I have 4 threads for reading training data and 1 thread for
validation..."


@markpwoodward (Author)

Oops, tf.cond may be the wrong choice, as ebrevdo mentions in the parallel
thread which I will stop referencing now. From the documentation, it seems
that tf.cond executes both paths up to, but not including, the final
operation, so it wouldn't work as I had hoped.

Perhaps it could still work if what was passed to tf.cond() was the
queue.dequeue(), or queue.dequeue_many(), operations for the train_queue
and validation_queue. But those may themselves be the end of other tf
operations, so it might still not work.


@ebrevdo (Contributor) commented Jul 19, 2016

What can work is a cond(filter_predicate, lambda: queue.enqueue(..), tf.no_op). This allows controlling what goes into a queue. Most other ways of combining queues and tf.cond usually don't have the behavior you would expect.
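
For illustration, a rough (untested) sketch of that conditional-enqueue pattern; example and keep_example below are made-up names, not part of any API:

import tensorflow as tf

q = tf.FIFOQueue(capacity=100, dtypes=[tf.float32])
example = tf.placeholder(tf.float32, [])
keep_example = example > 0.0  # stand-in for a real filter predicate

# Enqueue only when the predicate holds, otherwise do nothing;
# sess.run(maybe_enqueue, {example: v}) then adds v to the queue only if v > 0.
maybe_enqueue = tf.cond(keep_example, lambda: q.enqueue([example]), tf.no_op)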


@osdf commented Jul 19, 2016

With make_template there won't be multiple copies of the graph. TensorBoard nicely shows that the two queues (one for training data, the other for validation data) share one graph with the weights. (True, you need two different, smaller loss expressions, each associated with its queue and placed on top of the one graph that holds the trainable parameters.) Deciding which queue is used happens in the outer training iteration loop; no tf.cond or other expression-based logic is necessary: you just sess.run the respective loss expression and the underlying queue is polled.

@markpwoodward (Author)

Christian, just a clarification, doesn't calling make_template twice create
duplicate operation paths in the graph, one for training and one for
validation? That duplication is probably harmless, it just seems messy when
all we want to do is switch between two input sources (although I don't
have an alternative to propose).


@osdf commented Jul 19, 2016

@markpwoodward hmm, I don't know what an operation path is. Say you call it twice, with two different input tensors (e.g. coming from a dequeue op on your training/validation set respectively). What happens is that in the first call, the graph with the parameters is constructed; in the second call, this graph is simply reused. If you open up TensorBoard, you see exactly this view, sort of a bottlenecked picture (if your shared graph is followed by some additional, loss-connected ops, which you do need to duplicate, that's true): at the bottom the two queues and their nodes, both feeding into the network expression made up of shared parameters, then splitting up again into two loss-connected paths. With respect to messiness, I thought it the least messy solution, as I only had to add a for loop over the queues and kept the rest the same, with make_template taking care of the rest. But messiness is probably quite subjective :-).

@markpwoodward (Author)

@osdf, thank you for the response. Would you mind including the picture from TensorBoard? I certainly may be missing something about how TensorFlow execution happens. Also, maybe brief pseudo code for your usage of make_template(), where the goal of the pseudo code is just inference (no loss or training), but on two different queues.
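
(For reference, a rough sketch of what such make_template usage might look like for inference on two queues; this is illustrative and untested, not osdf's actual code:)

import tensorflow as tf

train_queue = tf.FIFOQueue(100, [tf.float32], shapes=[[4]])
valid_queue = tf.FIFOQueue(100, [tf.float32], shapes=[[4]])

def _inference(inputs):
  w = tf.get_variable('w', [4, 2])
  return tf.matmul(tf.expand_dims(inputs, 0), w)

model = tf.make_template('model', _inference)  # variables are shared across calls

train_out = model(train_queue.dequeue())  # first call creates the variables
valid_out = model(valid_queue.dequeue())  # second call reuses them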


@josh11b (Contributor) commented Jul 27, 2016

You may be able to use QueueBase.from_list to dynamically select which queue to dequeue from, see:
https://www.tensorflow.org/versions/r0.9/api_docs/python/io_ops.html#QueueBase.from_list

@markpwoodward (Author)

I don't know how I missed that! Thank you. I just tested it. It works great; two FIFOQueues, one placeholder to select the queue.

This feature request got sidetracked a bit; I will leave it open, as my original request (a way to signal the last dequeue of an epoch without needing to count dequeues) still stands. Low priority, since it is easy enough to count dequeues, and this feature is likely less relevant for larger datasets, where people don't usually do things on epoch boundaries.

@aselle removed the triaged label on Jul 28, 2016
@zaheersm commented Aug 6, 2016

@markpwoodward Can you explain how to use the feature pointed out by @josh11b? A code snippet would be great.

@danijar (Contributor) commented Sep 2, 2016

Hi, I am using the same solution that was mentioned here, having a boolean placeholder is_training, and combining the training and evaluation queues using batch = tf.cond(is_training, lambda: training_batch, lambda: testing_batch). Why would this not work? It seems to work for me, but I didn't check if it always dequeues both batches.

@mrry (Contributor) commented Sep 2, 2016

@danijar Assuming training_batch and testing_batch are the results of calls to tf.train.batch() that occur before the tf.cond(), it will always dequeue both batches. See here for an explanation (and substitute "dequeue operation" for "assignment").
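
To make the point concrete, a minimal sketch (train_example, test_example and is_training are made-up stand-ins for the real pipeline tensors):

import tensorflow as tf

is_training = tf.placeholder(tf.bool, [])
train_example = tf.random_uniform([4])  # stand-ins for real input tensors
test_example = tf.random_uniform([4])

# tf.train.batch creates its queue and dequeue op right here, outside the
# cond, so both pipelines are consumed on every step.
training_batch = tf.train.batch([train_example], batch_size=32)[0]
testing_batch = tf.train.batch([test_example], batch_size=32)[0]

# The cond only selects which already-dequeued value is returned.
batch = tf.cond(is_training, lambda: training_batch, lambda: testing_batch)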

@danijar (Contributor) commented Sep 6, 2016

Thanks. So I basically need to pass functions into tf.cond that do a tf.identity and perform the desired calls within the tf.control_dependencies of the identity, correct? What about something like this instead:

queue = tf.cond(is_training, train_queue, test_queue)
batch = queue.dequeue_many(batch_size)

@cancan101 (Contributor)

The issues that I have with the two workarounds quoted from the original post are:

  1. "enqueue() some sentinel at the end of an epoch in the data producer and then tf.Assert() against the sentinel and catch the tf.errors.InvalidArgumentError in the data consumer."

     What to do with batches? I.e., when the single rows are handled by a buffering batch creator.

  2. "know the number of enqueue's for the epoch and only dequeue that number."

     Again with batching: if the epoch size is not a multiple of the batch size, then part of the next epoch ends up in the last batch.

Perhaps an alternative solution to the OP's proposals would be a queue that raises OutOfRangeError but stays open. That way the local variable for epochs can just be reset.

@cancan101 (Contributor) commented Sep 9, 2016

One solution that I have found is to create a new input_producer each epoch and then to initialize_variables only the new local variable (the epoch count). This works but yields warnings for the old queues:

[[Node: input_producer_18/fraction_of_32_full_Dequeue = QueueDequeue[_class=["loc:@input_producer_18/input_producer/fraction_of_32_full/fraction_of_32_full"], component_types=[DT_INT32], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](input_producer_18/input_producer/fraction_of_32_full/fraction_of_32_full)]]
W tensorflow/core/framework/op_kernel.cc:940] Out of range: FIFOQueue '_37_input_producer_18/input_producer/fraction_of_32_full/fraction_of_32_full' is closed and has insufficient elements (requested 1, current size 0)

@markpwoodward (Author) commented Sep 9, 2016

Regarding the side track, not the original request.
@MuhammadZaheer, here is an example usage of QueueBase.from_list() that you asked for.
@danijar, I would recommend using QueueBase.from_list() over tf.cond().

import tensorflow as tf

q1 = tf.FIFOQueue(capacity=100, dtypes=[tf.int32])
input1 = tf.placeholder(tf.int32, [])
enq1 = q1.enqueue(input1)

q2 = tf.FIFOQueue(capacity=100, dtypes=[tf.int32])
input2 = tf.placeholder(tf.int32, [])
enq2 = q2.enqueue(input2)

select_q = tf.placeholder(tf.int32, [])
q = tf.QueueBase.from_list(select_q, [q1, q2])
data = q.dequeue()

with tf.Session() as sess:
  # enqueue values (these would typically be in their own thread(s))
  q1_vals = [1,2,3]
  print("q1_vals = " + str(q1_vals))
  for v in q1_vals:
    sess.run(enq1, { input1: v })

  q2_vals = [4,5,6]
  print("q2_vals = " + str(q2_vals))
  for v in q2_vals:
    sess.run(enq2, { input2: v })

  # run an op that pulls from the queue, specifying which queue
  for i in range(3):
    print("q1.dequeue = " + str(sess.run(data, {select_q: 0})))
    print("q2.dequeue = " + str(sess.run(data, {select_q: 1})))

outputs

q1_vals = [1, 2, 3]
q2_vals = [4, 5, 6]
q1.dequeue = 1
q2.dequeue = 4
q1.dequeue = 2
q2.dequeue = 5
q1.dequeue = 3
q2.dequeue = 6

@ebrevdo (Contributor) commented Sep 10, 2016

QueueBase.from_list may be deprecated soon. +@mrry


@rsethur commented Jan 27, 2017

Many thanks to @yaroslavvb : The time taken to add the operators was just 90 seconds in my case - so make_template is a reasonable approach and works great for me.

@TimZaman: Thanks a LOT for taking the time to share your approach using scopes instead of make_template - this works fine as well in my case. The latter throws an error in case tf.Variable is used - that seems to be an advantage.

@shaayaansayed

@yaroslavvb, I'd also greatly appreciate some sample code on how to use tf.train.maybe_batch. I imagine that keep_input needs to be a placeholder passed in during graph eval.

@TimZaman, I also implemented train and val with different feeds, with two different models in different name_scopes but still reusing variables across the graphs. I found that if I created the train graph first and then the val graph, evaluating the val graph still caused the train input pipeline to dequeue, implying that the graphs were not separated well enough. I see you mentioned that issue earlier in this thread. How do you ensure graph separation when using different input feeds?

I created a google groups discussion to share different tensorflow workflows. I think it would be very helpful to me and other more inexperienced users if some of you guys could maybe share your code designs. Thanks!

@aselle added the type:feature label and removed the enhancement label on Feb 9, 2017
@lballes commented Feb 28, 2017

Does anyone by now have a minimum working example using maybe_batch to switch between train and test queues? I'd greatly appreciate it!

@mrry (Contributor) commented Feb 28, 2017

We're planning to move away from queues and provide first-class support for multi-epoch processing in the redesigned input pipeline API. Please feel free to comment on #7951 if there are particular features you'd like to see in the new API!

@hughperkins

In the meantime, what is the standard approach for running multiple training epochs using a queue? I've searched around for a while, and the closest thing I could find was http://stackoverflow.com/a/39209186/212731, for which @mrry provides a workaround. Is this the standard technique we should follow for now? Or ...?

(basically, I have a bunch of examples, which I'm happy to store in a file as tfrecords, but I need to run indefinitely; specifying the number of epochs at the start is not really ideal for me. Having to guess how many steps per epoch is also not ideal).

@danijar (Contributor) commented Apr 14, 2017

How would I use tf.train.maybe_batch() to select between training and testing batches?

@LucasMahieu

Has someone used tf.train.maybe_batch() and wants to share the way they used it? It would be awesome.

@danijar (Contributor) commented May 24, 2017

@mrry?

@b3nk4n commented May 30, 2017

In my opinion, tf.train.maybe_batch() is not working properly yet. By looking at the code, it is exactly the same method as tf.train.batch(). The latter simply uses keep_input=True.

According to the API documentation, keep_input is a bool Tensor, so it should accept a tf.placeholder(tf.bool) as well, right? But when I use a placeholder, and feed it with the value True, the queue is simply blocking and nothing is happening.

I tried something like that:

is_training = tf.placeholder(tf.bool, shape=[])

...

image_batch, label_batch = tf.cond(is_training,
                                   true_fn=lambda: tf.train.maybe_batch([train_image, train_label],
                                                                        keep_input=is_training,
                                                                        batch_size=BATCH_SIZE),
                                   false_fn=lambda: tf.train.maybe_batch([test_image, test_label],
                                                                         keep_input=tf.logical_not(is_training),
                                                                         batch_size=BATCH_SIZE))
with tf.Session() as sess:
    # initialize the variables
    sess.run(tf.local_variables_initializer())
    sess.run(tf.global_variables_initializer())

    # initialize the queue threads to start to shovel data
    threads = tf.train.start_queue_runners(coord=tf.train.Coordinator())

    for i in range(10):
        print(sess.run(label_batch, feed_dict={is_training: True}))
...

I thought of tf.train.maybe_batch() as being similar to the approach of using tf.cond() to switch between the input queue (for training) and feeding (for validation), but probably without the downside described here. But it's also possible that I got this wrong...

@tpatel0409

Hi,
So, has anyone found a way to use tf.train.maybe_batch() appropriately?
I have not found a way to switch between the train/test data queues during the training or eval phase.

Thanks

@danijar (Contributor) commented Jun 12, 2017

Here is a hacky solution using tf.FIFOQueue.from_list():

def select_batch(batches, index):
  """
  Select a batch based on the current value of the index. Only the active batch
  will be consumed. Each batch can be an arbitrarily nested tuple or list.
  """

  def _get_dtypes(tensors):
    if isinstance(tensors, (list, tuple)):
      return type(tensors)(_get_dtypes(tensor) for tensor in tensors)
    return tensors.dtype

  def _get_shapes(tensors):
    if isinstance(tensors, (list, tuple)):
      return type(tensors)(_get_shapes(tensor) for tensor in tensors)
    return tensors.shape

  def _flatten(collection):
    if isinstance(collection, (list, tuple)):
      return sum([_flatten(element) for element in collection], [])
    return [collection]

  def _unflatten(iterator, shapes):
    if isinstance(shapes, (list, tuple)):
      return type(shapes)(_unflatten(iterator, shape) for shape in shapes)
    return next(iterator)

  queues = []
  for batch in batches:
    dtypes, shapes = _get_dtypes(batch), _get_shapes(batch)
    queue = tf.FIFOQueue(10, _flatten(dtypes), _flatten(shapes))
    runner = tf.train.QueueRunner(queue, (queue.enqueue(_flatten(batch)),))
    tf.train.add_queue_runner(runner)
    queues.append(queue)
  batch = tf.FIFOQueue.from_list(index, queues).dequeue()
  return _unflatten(iter(batch), shapes)
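
A possible usage, assuming train_batch and valid_batch are (nested) tuples of tensors coming from two input pipelines:

index = tf.placeholder(tf.int32, [])  # 0 selects the first pipeline, 1 the second
batch = select_batch([train_batch, valid_batch], index)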

@JulienSiems

@tpatel0409 @LucasMahieu @danijar Have you found an example of how to use tf.train.maybe_batch in the meantime? Or did you end up switching to the new Dataset API?

@LucasMahieu commented Jun 29, 2017 via email

@JulienSiems

@LucasMahieu Thank you for your quick response. Does the tf slim api provide any way of solving the problem? I am using tf slim as well but just for using predefined layers. I am currently looking into tf.train.maybe_batch and will post here if I come up with any example.

@JulienSiems

Maybe @ebrevdo, since you suggested it? It would be great if you could post a snippet.

@alexwal commented Jun 30, 2017

Also curious and will be reading through TF Slim data API today. I will post if I come across a solution.

zhuangh pushed a commit to zhuangh/tensorflow that referenced this issue Jul 14, 2017
…' guide.

This is a potential solution to issue tensorflow#2514.

PiperOrigin-RevId: 161732107
@JulienSiems commented Aug 3, 2017

https://www.tensorflow.org/versions/r1.3/programmers_guide/datasets

Looks like the Dataset API is the recommended way to do it (see, for example, the TFRecords reader). Since I am using the input pipeline, I ended up evaluating batches of test data and then feeding them back into the training queue, as suggested by #2514 (comment).
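
For reference, a rough sketch of how the original multi-epoch request looks with the Dataset API (tf.data names from later releases; the file name is a placeholder and parsing is elided):

import tensorflow as tf

train_files = ['train.tfrecords']  # placeholder path
dataset = tf.data.TFRecordDataset(train_files).batch(32)  # a .map(parse_fn) would go here
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
  for epoch in range(10):
    sess.run(iterator.initializer)        # "reopen" the pipeline for a new epoch
    while True:
      try:
        sess.run(next_batch)              # training step would go here
      except tf.errors.OutOfRangeError:   # end of epoch
        break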

tillahoffmann referenced this issue in caisq/tensorflow Aug 25, 2017
With this change, it becomes possible to use a Python generator as the source
dataset for a `tf.contrib.data` input pipeline. This enables easier integration
with non-TensorFlow data sources. The generator can yield a nested structure of
NumPy arrays, or values convertible to NumPy arrays.

This addresses a concern raised in issue tensorflow#7951.

PiperOrigin-RevId: 165663857
@panicooo commented Nov 16, 2017

What if the labels for train and test are different? For example, if I have two labels for train, but only one label for test.

@markpwoodward (Author)

If you are using tf.estimator.Estimator, which I think is the current best way to go (https://stackoverflow.com/questions/46925196/does-tf-estimator-estimator-train-maintain-input-fn-state), you can pass a different input_fn for train and evaluate, and you can use the mode passed to model_fn to create the appropriate graph.
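
For illustration, a bare-bones sketch of that setup; the model and input functions below are toy stand-ins rather than a recommended recipe:

import numpy as np
import tensorflow as tf

def model_fn(features, labels, mode):
  logits = tf.layers.dense(features, 2)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  if mode == tf.estimator.ModeKeys.TRAIN:
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
  return tf.estimator.EstimatorSpec(mode, loss=loss)

def input_fn():
  # Toy in-memory data; one pass (no repeat) corresponds to one epoch.
  x = np.random.rand(100, 4).astype(np.float32)
  y = np.random.randint(0, 2, size=100).astype(np.int32)
  dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(10)
  return dataset.make_one_shot_iterator().get_next()

estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=input_fn)     # pass a training input_fn here
estimator.evaluate(input_fn=input_fn)  # and a separate eval input_fn here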

@markpwoodward (Author)

I just closed this. Feel free to re-open it. In my opinion, tf.estimator and tf.data are the way to go, and they solve my original issue.

If you need to do things at the end of each epoch then just run estimator.train() with a Dataset that does not repeat. Alternate this with estimator.evaluate() or whatever you need.

Oftentimes the dataset is too big to wait for epoch boundaries, so use a Dataset that repeats and pass a listener to estimator.train() that runs estimator.evaluate() whenever a checkpoint is saved. See the link in my last comment above.

I initially had a problem with this approach because of the perceived overhead of creating the graph on each call to estimator.train() or estimator.evaluate(). But, for me, the overhead has been negligible, and having checkpoints be the transfer of information between train() and evaluate() feels like the right approach.
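
In code, the epoch-boundary alternation described above amounts to roughly this (a sketch; train_input_fn and eval_input_fn are assumed to be non-repeating input functions):

for epoch in range(10):
  estimator.train(input_fn=train_input_fn)      # stops at OutOfRangeError, i.e. end of one epoch
  metrics = estimator.evaluate(input_fn=eval_input_fn)
  print(metrics)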

@panicooo

@TimZaman Is there an example (using make_template, or sharing weights by scope) of doing such things (different feeds for train and test)? I tried the shared-weights method, but the end result is that the testing data always uses the weights as initialized, while the weights for training have been updated.
Here is my code:

with tf.name_scope('train') as scope:
    dist_train = network_out(network, x_train, keep_prob=FLAGS.keep_prob, phase_train=True, batch_size=FLAGS.batch_size)
with tf.name_scope('eval') as scope:
    with tf.variable_scope(tf.get_variable_scope(), reuse=True) as vscope:
        dist_test = network_out(network, x_test, keep_prob=FLAGS.keep_prob, phase_train=False, batch_size=FLAGS.eval_batch_size)

Can anyone help me?

@rsethur commented Dec 29, 2017

@markpwoodward Many thanks for your update and stackoverflow link - I would like to migrate to estimator API as well. I still have a challenge - can you share your thoughts please?

I would like to run evaluation before my epoch ends (as stated in your post as well, because the dataset is huge). Earlier, I used to run evaluation and then save a checkpoint if the score was better than the old one. Is this logical?

From your post I understand that evaluation (by a listener) is run whenever a checkpoint is saved. What triggers a checkpoint save? How would I implement the above logic?
Thank you!

@markpwoodward (Author)

@rsethur Saving only the best checkpoint would be efficient, but I am not sure how to do that with this setup. Also, evaluation loss may not be the thing you want to optimize, you might want to review a number of evaluation metrics and pick the checkpoint that looks best in a general sense. I just keep all checkpoints, and visually inspect if I need to train more.

As for setting the frequency that estimator.train() saves checkpoints, there are probably other ways to do this, but I do it in a RunConfig object passed to Estimator's constructor.

estimator = tf.estimator.Estimator(
  model_fn=...,
  config=tf.estimator.RunConfig().replace( # I'm not sure why I use replace here and not the constructor
    save_checkpoints_steps=1000, # or whatever you want
    keep_checkpoint_max=None, # defaults to last 5, but I keep all
  ),
  ...
)

Also, take a look at the new Estimator.train_and_evaluate(), you may prefer it. I still prefer my proposed approach as I actually run evaluation on a fixed subset of my training data in addition to running evaluation on my validation set. I haven't been able to get train_and_evaluate() to support this (e.g. multiple EvalSpec's)

@rsethur commented Dec 31, 2017

@markpwoodward Many thanks for your detailed response - much appreciated!
