
Feature Request - Seq2Seq Inference Helper w/o Embeddings #12065

Closed
RylanSchaeffer opened this issue Aug 6, 2017 · 44 comments

Labels
stat:awaiting tensorflower Status - Awaiting response from tensorflower

Comments


RylanSchaeffer commented Aug 6, 2017

tf.contrib.seq2seq has two Helper classes to use during inference, SampleEmbeddingHelper and GreedyEmbeddingHelper. However, both make use of embeddings, which is unhelpful when building sequence-to-sequence models that operate on non-embedded target sequences (my target sequence already consists of meaningful vectors).

I'd like a new Helper class that pipes the output of the decoder RNN at one time step into the decoder RNN at the following time step. It should permit the start_tokens to be vectors (tensors?) and the end_token to be a vector (tensor?) as well. Right now, I'm attempting to use ScheduledOutputTrainingHelper with sampling_probability set equal to 1.0, but I'm struggling to get it to work. Something like a simple OutputInferenceHelper would be very nice :)

If there already exists an easy way to do what I'm suggesting, please let me know!


RylanSchaeffer commented Aug 8, 2017

I'm trying to implement my suggestion, but I'm running into an error that I'm unsure of how to debug.

My code:

from tensorflow.contrib.seq2seq.python.ops.helper import Helper
from tensorflow.python.framework import dtypes, ops
from tensorflow.python.ops import array_ops, control_flow_ops, math_ops


class OutputInferenceHelper(Helper):

    def __init__(self, start_tensors, end_tensor, name=None):
        """Initializer.
        
        Args:
          start_tensors: `float32` tensor shaped `[batch_size, ...]`, the start
           tensors.
          end_tensor: `float32` tensor shaped `[...]`, the tensor that marks 
           end of decoding.
        """
        with ops.name_scope(name, "OutputInferenceHelper"):
            self._start_tensors = ops.convert_to_tensor(
                start_tensors, dtype=dtypes.float32, name="start_tensors")
            self._end_tensor = ops.convert_to_tensor(
                end_tensor, dtype=dtypes.float32, name="end_tensor")
            self._batch_size = array_ops.shape(start_tensors)[0]

    @property
    def batch_size(self):
        return self._batch_size

    def initialize(self, name=None):
        with ops.name_scope(name, "OutputInferenceHelperInitialize"):
            finished = array_ops.tile([False], [self._batch_size])
            return (finished, self._start_tensors)

    def sample(self, time, outputs, state, name=None):
        with ops.name_scope(name, "OutputInferenceHelperSample"):
            del time, state
            if not isinstance(outputs, ops.Tensor):
                raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                                type(outputs))
            return outputs

    def next_inputs(self, time, outputs, state, sample_ids, name=None):
        with ops.name_scope(name, "OutputInferenceHelperNextInputs"):
            del time, sample_ids
            finished = math_ops.equal(outputs, self._end_tensor)
            all_finished = math_ops.reduce_all(finished)
            next_inputs = control_flow_ops.cond(
                all_finished,
                # If we're finished, the next_inputs value doesn't matter
                lambda: self._start_tensors,
                lambda: outputs)

            return (finished, next_inputs, state)

I'm receiving the following error: tensorflow.python.framework.errors_impl.InvalidArgumentError: Shapes must be equal rank, but are 1 and 2 for 'add_inference/add_decoder/decoder/while/Select' (op: 'Select') with input shapes: [128,4], [?], [?].

The argument I pass for start_tensors when instantiating an OutputInferenceHelper has shape [128, 4], but beyond that, I'm lost.

My Traceback:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 671, in _call_cpp_shape_fn_impl
    input_tensors_as_shapes, status)
  File "/usr/local/Cellar/python3/3.6.1/Frameworks/Python.framework/Versions/3.6/lib/python3.6/contextlib.py", line 89, in __exit__
    next(self.gen)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  <my function calls>
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 286, in dynamic_decode
    swap_memory=swap_memory)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2770, in while_loop
    result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2599, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2549, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/contrib/seq2seq/python/ops/decoder.py", line 242, in body
    sequence_lengths)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 2328, in where
    return gen_math_ops._select(condition=condition, t=x, e=y, name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/ops/gen_math_ops.py", line 2145, in _select
    name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
    op_def=op_def)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2508, in create_op
    set_shapes_for_outputs(ret)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1873, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1823, in call_with_requiring
    return call_cpp_shape_fn(op, require_shape_fn=True)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn
    debug_python_shape_fn, require_shape_fn)
  File "/usr/local/lib/python3.6/site-packages/tensorflow/python/framework/common_shapes.py", line 676, in _call_cpp_shape_fn_impl
    raise ValueError(err.message)
ValueError: Shapes must be equal rank, but are 1 and 2 for 'add_inference/add_decoder/decoder/while/Select' (op: 'Select') with input shapes: [128,4], [?], [?].


tatatodd commented Aug 9, 2017

@ebrevdo @lukaszkaiser Can you comment on this? Thanks!

@tatatodd added the stat:awaiting tensorflower (Status - Awaiting response from tensorflower) label on Aug 9, 2017
@lukaszkaiser
Contributor

I'm not too familiar with this code, will wait for @ebrevdo or @lmthang to comment.

@RylanSchaeffer
Author

One mistake I'm making (I think) is that finished = math_ops.equal(outputs, self._end_tensor) operates element-wise, which means finished will be a tensor of shape [batch_size, ...] when it should instead be a tensor of shape [batch_size]. I believe the correct statement is finished = math_ops.reduce_all(math_ops.equal(outputs, self._end_tensor), axis=1), but I'm not sure whether this generalizes to tensors of higher rank.
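A sketch of how that reduction might generalize to outputs of arbitrary rank (assuming outputs has a statically known rank and self._end_tensor broadcasts against it):

    # Reduce over every axis except the batch axis, so `finished` ends up
    # with shape [batch_size] regardless of the rank of `outputs`.
    reduction_axes = list(range(1, outputs.shape.ndims))
    finished = math_ops.reduce_all(
        math_ops.equal(outputs, self._end_tensor), axis=reduction_axes)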


ebrevdo commented Aug 10, 2017 via email


adarob commented Aug 10, 2017

This is actually a more specific case of what I am building, which is a general sampling helper that takes a layer like the ScheduledOutputTrainingHelper. The problem with using the training helper is that it doesn't output the actual samples so it can't be used during inference.

Rylan, your solution will almost work, but you are going to hit the issue that sample_ids is currently required to be a scalar integer, something I am also addressing with my change. If you can wait until next week, I can have something that will work for you. In the meantime, I'd recommend hacking up the CustomHelper to work for your use case. You'll also need to adjust the BasicDecoder output_size and output_dtype to match the shape of your sample tensor.


RylanSchaeffer commented Aug 10, 2017

@adarob , thanks for responding! I need to meet a deadline this Sunday for my internship project, but I have another month after that to continue working on the project. Could I email you directly for help with hacking a solution in the short term? (I'm happy to talk here, but I know TensorFlow developers like to keep conversations focused on the original issue).


adarob commented Aug 10, 2017

Sure @RylanSchaeffer


RylanSchaeffer commented Aug 10, 2017

@adarob , thanks! Sent!


ppyht2 commented Aug 13, 2017

@adarob @RylanSchaeffer Hi guys, I've also been looking for an output helper for this purpose. Would it be possible to share your solution with me?

Thanks in advance.


RylanSchaeffer commented Aug 13, 2017

@ppyht2 Adam's solution for my problem might not work generally. In my case, my decoder's output is a (4,) tensor of binary values, so he suggested using a CustomHelper that maps the decoder's outputs to sample_ids by treating each output as an integer written in binary, e.g. [0, 0, 0, 0]'s sample_id is 0, [0, 1, 0, 0]'s sample_id is 4, etc.

Let me know if that makes sense.
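For concreteness, a minimal sketch of that mapping (assuming [batch_size, 4] logits and treating index 0 as the most significant bit, so [0, 1, 0, 0] maps to 4):

    # Threshold the logits to 0/1, then read each row as a 4-bit integer.
    binary = tf.cast(tf.round(tf.nn.sigmoid(outputs)), tf.int32)             # [batch_size, 4]
    sample_ids = tf.reduce_sum(binary * tf.constant([8, 4, 2, 1]), axis=1)   # [batch_size]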


adarob commented Aug 13, 2017 via email

@RylanSchaeffer
Author

@adarob , would it be possible to also get a corresponding modified version of sequence_loss that doesn't require targets to be shaped [batch_size x sequence_length]?


ppyht2 commented Aug 14, 2017

@RylanSchaeffer I'm currently working on a problem where my decoder's output has float values; in that case the solution will no longer work, is that correct?

@adarob will the InferenceHelper resolve this issue?

Thanks for your help, guys :)


RylanSchaeffer commented Aug 14, 2017

@ppyht2 That's correct. However, I think there's an easier solution if all you want is the decoder's output to be passed as input at the next time step and you don't care about sample_ids.

If you look at dynamic_decode, you'll see that the function calls its decoder's step method. Assuming you're using a BasicDecoder, the step method does three things:

  1. Runs the decoder's cell for one step (line 139)
  2. Feeds the cell's output to the Helper's sample_fn to generate sample_ids (line 142)
  3. Feeds the cell's output and the sample_ids (and a few other parameters) to the Helper's next_inputs_fn (line 144)

If all you want is for the cell's outputs to be the decoder's inputs at the next time step, you can write a CustomHelper with a sample_fn that returns arbitrary integers (to pass a later type and shape check) and then write a next_inputs_fn that returns outputs as next_inputs if the decoder isn't done.

Hope that makes sense!

Disclaimer: My CustomHelper isn't working, so I'm not sure if this is a viable solution.


ppyht2 commented Aug 17, 2017

@RylanSchaeffer This solution sounds like it could work; I will give it a crack. Thanks for your help.

Do you have any explanation as to why the CustomHelper is not working?

@RylanSchaeffer
Author

@ppyht2 I suspect that I'm incorrectly using the TrainingHelper. I've been receiving help from people on another issue at the NMT tutorial (tensorflow/nmt#3), but I haven't been able to fix my issue yet.


adarob commented Aug 17, 2017

The new InferenceHelper added in e9a8d75 resolves this issue.


ebrevdo commented Aug 17, 2017

@RylanSchaeffer does @adarob's PR solve your issue?

@RylanSchaeffer
Author

@ebrevdo Let's presume yes, and I'll reopen the issue with additional details if it doesn't.

@RylanSchaeffer
Author

@adarob , thank you!

@BalthazarFitzpatrick

@RylanSchaeffer I am facing a similar issue, just not working on vectors but on regular good old floating point numbers for time series forecasting. Would you mind me contacting you via mail with a few short questions regarding your InferenceHelper implementation?

@RylanSchaeffer
Author

@fritzfitzpatrick , I'd rather help you here in case anyone else runs into a similar issue.

Here's the code I used. In my case, I had a (4,)-shaped tensor of zeros and ones as outputs, hence 16 possible outcomes, hence my sampling function. However, the sampling function was just there to help me debug; you can get by with a function that does nothing.

inference

        elif self.mode == 'inference':

            def initialize_fn():
                finished = tf.tile([False], [FLAGS.batch_size])
                start_inputs = tf.fill([FLAGS.batch_size, 4], -1.)
                return (finished, start_inputs)

            def sample_fn(time, outputs, state):
                del time, state
                outputs = tf.cast(tf.round(tf.nn.sigmoid(outputs)),
                                  dtype=tf.int32)
                sample_ids = outputs[:, 0] + 2 * outputs[:, 1] + \
                             4 * outputs[:, 2] + 8 * outputs[:, 3]
                return sample_ids

            def next_inputs_fn(time, outputs, state, sample_ids):
                del time, sample_ids
                squashed_logits = tf.nn.sigmoid(outputs)
                binary_decoder_outputs = tf.round(squashed_logits)
                finished = tf.equal(
                    0.,
                    tf.reduce_sum(binary_decoder_outputs, axis=1))
                all_finished = tf.reduce_all(finished)
                next_inputs = tf.cond(
                    all_finished,
                    # If we're finished, the next_inputs value doesn't matter
                    lambda: tf.zeros_like(outputs),
                    lambda: squashed_logits)
                return (finished, next_inputs, state)

            helper = CustomHelper(initialize_fn=initialize_fn,
                                  sample_fn=sample_fn,
                                  next_inputs_fn=next_inputs_fn)

@RylanSchaeffer
Author

@adarob @ebrevdo Just my personal opinion, but even though the issue is technically solved by giving the programmer the ability to implement their own helper through InferenceHelper, I don't feel it is practically solved: one should be able to pipe the output at one time step into the next time step without needing to write custom code.


adarob commented Jan 25, 2018

@RylanSchaeffer are you suggesting having a CategoricalInferenceHelper that provides categorical sampling, as is most often used? I think it's reasonable to add that.


RylanSchaeffer commented Jan 25, 2018

@adarob that would be helpful, but I was referring to something slightly different. I haven't looked at this in 5+ months, so maybe the module has changed, but my understanding is that the sample_fn is used to collapse the RNN cell's output tensor into a single number (either deterministically or stochastically), which is then reconstituted as an input tensor for the next step. This is useful if I want the input tensor to represent a discrete input (e.g. a word), but in the context in which I was using the library, I wanted the output tensor to be passed to the next step unchanged.

For concreteness, my understanding of the current implementation is like this:
output at time T is [0.5, 0.1, 0.1, 0.1] => 0th element is sampled => [1., 0., 0., 0.] is passed to time step T+1

My desired behavior:
output at time T is [0.5, 0.1, 0.1, 0.1] => [0.5, 0.1, 0.1, 0.1] is passed to time step T+1

@BalthazarFitzpatrick

@RylanSchaeffer Thanks for sharing. My code is a bit more bare bones, but I should be able to wire something together based on your snippet.

@adarob Passing the decoder output at time step T as the decoder input at time step T+1 is exactly what I am after, and I have been struggling a bit with implementing this through other helpers (mostly because of my limited understanding of sampling in TF and of how to deal with function arguments like time, which I now see Rylan simply deletes from the function scope).

Overall I am really happy with TensorFlow and it's a great entry point for ML beginners like me, so thanks a lot!


adarob commented Jan 26, 2018

What you're requesting is possible by setting sample_fn=tf.identity.
However, it is redundant to have the unmodified outputs be the next inputs of an RNN, since they are already passed forward as part of the state.
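For illustration, a minimal sketch of wiring this up with the new InferenceHelper (dim, start_inputs, and batch_size here are placeholder names, not from this thread):

    # A pass-through inference helper: the raw decoder output at step T
    # becomes the decoder input at step T+1.
    helper = tf.contrib.seq2seq.InferenceHelper(
        sample_fn=tf.identity,      # emit the decoder output unchanged as sample_ids
        sample_shape=[dim],         # per-step output shape, e.g. [4]
        sample_dtype=tf.float32,
        start_inputs=start_inputs,  # float32 tensor of shape [batch_size, dim]
        end_fn=lambda sample_ids: tf.tile([False], [batch_size]))  # never finish early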

@BalthazarFitzpatrick

@adarob Thanks, I'll try tf.identity.

I will pass the decoder output of time step T (modified through a dense layer) as the decoder input at time step T+1; I was unclear in my post above, sorry for that.

@nishaskinner

I'm facing a similar issue. I've described what I'm trying to do on Stack Overflow:
https://stackoverflow.com/questions/48216786/seq2seq-in-tensorflow-without-embeddings

Has anyone managed to use a CustomHelper for this?


BalthazarFitzpatrick commented Feb 9, 2018

@nishaskinner I think Rylan managed to get it to work, but I am getting stuck.

@RylanSchaeffer can I pick your brain about this over email? I want to understand and share my knowledge, as apparently there are more people out there who need help implementing the CustomHelper outside the NMT domain.

I am creating a generic example for regression using a sequence to sequence architecture that will work with any input sequence length and number of features, as well as any output sequence length and number of features.

The work-in-progress notebook can be found here, and I want to turn this into a beginner-friendly blog post, as tutorials on non-language applications are few and far between.

Thanks a lot in advance!

@RylanSchaeffer
Author

@adarob , you're correct, but the lack of documentation and examples (and a general explanation of what a Helper even does) makes writing a CustomHelper difficult for people who aren't familiar with the library, as these comments demonstrate.

@nishaskinner Yes, my CustomHelper worked as I intended. I posted the code above.

@fritzfitzpatrick , what exactly would you like to talk about? My email address is rylanschaeffer@gmail.com, but like I said above, I'd prefer to keep my conversations public in case others have similar questions.


ebrevdo commented Feb 10, 2018 via email


BalthazarFitzpatrick commented Feb 16, 2018

@RylanSchaeffer I have tried using a go token as well as the last value of the encoder input sequence as my start_inputs in the initializer_fn, and I have tried using the projection layer outputs as well as the original signal as the next_inputs in the next_inputs_fn.

I am training my model on a linear time series: it sees the sequence [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6] and should learn to predict [0.7, 0.8]. During training and validation, the L2-regularized loss goes down to 0.010, and the validation prediction using the training helper looks something like [0.69, 0.79], so close enough.

When I save that model, restore the variables into an identical inference model, and run it, the output looks more like [-0.14, 0.19]. I have no idea why it is off like that. Attached is the model in its current form, but as I said above, I have been through a few configurations. I can't seem to find the reason the inference prediction is so far off the mark.

# define inference encoder
inf_enc_out, inf_enc_state = tf.nn.dynamic_rnn(
    enc_cell, 
    inf_enc_inp,
    dtype = tf.float32,
    sequence_length = seq_length_inp,
    time_major = time_major)

with tf.variable_scope('projection_layer', reuse = tf.AUTO_REUSE):
    from tensorflow.python.layers.core import Dense
    projection_layer = Dense(features_dec_exp_out)

# define inference custom helper
def initialize_fn():
    finished = tf.tile([False], [batch_size])
    enc_inp_end = inf_enc_inp[0, observation_length - 1, 0]
    start_inputs = tf.reshape(enc_inp_end, shape=[1, 1]) 
    return (finished, start_inputs)

def sample_fn(time, outputs, state):
    return tf.constant([0])

def next_inputs_fn(time, outputs, state, sample_ids):
    finished = time >= prediction_length
    next_inputs = outputs
    return (finished, next_inputs, state)

inf_custom_helper = tf.contrib.seq2seq.CustomHelper(
    initialize_fn = initialize_fn,
    sample_fn = sample_fn,                      
    next_inputs_fn = next_inputs_fn)

# create inference decoder
inf_decoder = tf.contrib.seq2seq.BasicDecoder(
    dec_cell, 
    inf_custom_helper, 
    inf_enc_state,
    projection_layer)

# create inference dynamic decoding
inf_dec_out, inf_dec_state, inf_dec_out_seq_length = tf.contrib.seq2seq.dynamic_decode(
    inf_decoder, 
    output_time_major = time_major)

# extract prediction from decoder output
inf_output_dense = inf_dec_out.rnn_output


MrfksIv commented Jun 26, 2018

@fritzfitzpatrick did you have any luck with this? I have been struggling with this for days; training is fine with the TrainingHelper, but I am still completely lost as to what should happen during inference.

@VinoJose

@fritzfitzpatrick I'm also struggling with the same issue. Could you please share the solution or guide me, if you have managed to fix this?

@BalthazarFitzpatrick

@MrfksIv @VinoJose

I have indeed made some progress using a seq2seq architecture, but it is not very accurate and tends to collapse towards a perfectly flat horizontal line after just a few iterations. I have, however, tested another architecture that works quite nicely. I am currently pretty tapped out with business travel, but I will try to upload a Python notebook for you to peruse.

If anyone else has finally gotten multistep numerical time-signal predictions going with the inference or custom helper in a seq2seq architecture, please let us know. Can't be that hard, can it?


MrfksIv commented Jun 28, 2018

@VinoJose @fritzfitzpatrick
By following this code I have successfully trained the network to an RMSE of 0.02.
The problem with the code is that the training helper expects and requires the true next-timestep targets, which are of course unknown during inference. My naive solution was to feed the network the predicted sequence as if it were the correct one.
The predict_sequence_end below is test data that was neither trained nor validated on.

for i in range(steps_ahead):
    predict_sequence_end = sess.run(
        h,
        feed_dict={
            enc_inp: predict_sequence_end.reshape((1, batch_steps, 1)),
            expect: predict_sequence_end.reshape((1, batch_steps, 1)),
            expect_length: [n_steps] * predict_sequence_end.shape[0],
            keep_prob: keepprob})

From my limited understanding, this bypasses the problem of the training helper, does it not? By not running the train_op, I assume that the network weights remain constant. This gives quite good results, although I am not sure if I am 'cheating' in any way. Do you have any ideas on this?


nimroha commented Jul 16, 2018

Lacking a CategoricalInferenceHelper, I used the GreedyEmbeddingHelper with
embedding = tf.eye(VOCAB_SIZE, dtype=tf.float32, name='embedding')

Hope this helps anyone who reaches this thread.
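A minimal sketch of that wiring (dec_cell, enc_state, start_ids, and END_ID are placeholder names, not from this thread):

    # Use an identity matrix as the "embedding": looking up token i just
    # yields the one-hot vector for i, so no learned embedding is involved.
    embedding = tf.eye(VOCAB_SIZE, dtype=tf.float32, name='embedding')
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding=embedding,
        start_tokens=start_ids,  # int32 [batch_size]
        end_token=END_ID)        # scalar int32
    decoder = tf.contrib.seq2seq.BasicDecoder(dec_cell, helper, enc_state)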

@Andreea-G

I got it to work without embeddings in a much simpler way, using a very rudimentary InferenceHelper:

inference_helper = tf.contrib.seq2seq.InferenceHelper(
            sample_fn=lambda outputs: outputs,
            sample_shape=[dim],
            sample_dtype=dtypes.float32,
            start_inputs=start_tokens,
            end_fn=lambda sample_ids: False)

My inputs are floats with the shape [batch_size, time, dim]. For the example above with @MrfksIv's plot, dim would be 1, but this can easily be extended to more dimensions. Here's the relevant chunk of the code:

# Dense layer to translate the decoder's output at each time
# step.
projection_layer = tf.layers.Dense(
    units=1,  # = dim
    kernel_initializer=tf.truncated_normal_initializer(
        mean=0.0, stddev=0.1))

# Training Decoder
training_decoder_output = None
with tf.variable_scope("decode"):
    # target_data doesn't exist during the prediction phase.
    if target_data is not None:
        # Prepend the "go" token
        go_tokens = tf.constant(go_token, shape=[batch_size, 1, 1])
        dec_input = tf.concat([go_tokens, target_data], axis=1)

        # Helper for the training process.
        training_helper = tf.contrib.seq2seq.TrainingHelper(
            inputs=dec_input,
            sequence_length=[output_size] * batch_size)

        # Basic decoder
        training_decoder = tf.contrib.seq2seq.BasicDecoder(
            dec_cell,  training_helper, enc_state, projection_layer)

        # Perform dynamic decoding using the decoder
        training_decoder_output = tf.contrib.seq2seq.dynamic_decode(
            training_decoder, impute_finished=True,
            maximum_iterations=output_size)[0]

# Inference Decoder
# Reuses the same parameters trained by the training process.
with tf.variable_scope("decode", reuse=tf.AUTO_REUSE):
    start_tokens = tf.constant(
        go_token, shape=[batch_size, 1])

    # The sample_ids are the actual output in this case (not dealing with any logits here).
    # My end_fn is always False because I'm working with a generator that will stop giving 
    # more data. You may extend the end_fn as you wish. E.g. you can append end_tokens 
    # and make end_fn be true when the sample_id is the end token.
    inference_helper = tf.contrib.seq2seq.InferenceHelper(
        sample_fn=lambda outputs: outputs,
        sample_shape=[1],  # again because dim=1
        sample_dtype=dtypes.float32,
        start_inputs=start_tokens,
        end_fn=lambda sample_ids: False)

    # Basic decoder
    inference_decoder = tf.contrib.seq2seq.BasicDecoder(dec_cell,
        inference_helper, enc_state, projection_layer)

    # Perform dynamic decoding using the decoder
    inference_decoder_output = tf.contrib.seq2seq.dynamic_decode(
        inference_decoder, impute_finished=True,
        maximum_iterations=output_size)[0]


muggin commented Aug 27, 2018

@Andreea-G could you please elaborate on the flag end_fn=lambda sample_ids: False in your InferenceHelper? Wouldn't this be equivalent to setting impute_finished=False?

I use the following lambda function end_fn=lambda outputs: tf.greater_equal(tf.shape(outputs)[1], self.outputs_len), where self.outputs_len holds the true lengths of the target sequences. In this setting, the validation loss doesn't decrease during training. If I set impute_finished=False, the validation loss does decrease. I'm having difficulty understanding this behavior.


tungk commented Nov 12, 2018

@Andreea-G Could you elaborate on how to use the CustomHelper with a ConvLSTMCell, i.e. where the input has shape [batch_size, time_step, row_number, column_number, channel_number]?

@shivam13juna

Quoting @ebrevdo's earlier reply (sent via email):

"Sorry to respond late. Is anyone here interested in adding a comprehensive docstring to the module explaining helpers and how to create your own? You could base it on existing unit tests."

@ebrevdo I would love to do that, if it's still required?


ebrevdo commented Dec 31, 2018 via email
