Beam Search 4 Translate.py #654

Closed
marcellofederico opened this issue Dec 30, 2015 · 54 comments
Labels
stat:contribution welcome Status - Contributions welcome

Comments

@marcellofederico

Hi,
I'm wondering which steps are necessary to move from the greedy decoder currently implemented to an actual beam search decoder. Is this enhancement already on someone's roadmap? If not, could anyone tell me where in the code this functionality should be added?
Thanks a lot!
Marcello

@marcotrombetti

I would like to contribute too.

@vrv @lukaszkaiser knowing the proper place to add it would be very useful.

@lukaszkaiser
Contributor

When writing the seq2seq module, the idea was that the loop_function argument (e.g. of attention_decoder) could be used to provide various forms of decoding. It works for the greedy case, but we have not implemented a beam-search yet. I think it should be possible using loop_function and the top-k op from tensorflow. But we'll see - if it's too hard to do it inside the graph, then we can change the design and go with a python-side decoder. Having a decoder in the graph has advantages though, esp. when building more complex models, so I'd like to try that first. All ideas, comments, remarks and code are welcome of course!
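For readers new to this hook: the decoders call loop_function(prev, i) with the previous step's output and use its return value as the next step's input (the beam search example below uses the same (prev, i) signature). The NumPy sketch below mimics the greedy case, argmax followed by an embedding lookup. It is an illustration under that assumption, not the TensorFlow implementation, and the factory name make_greedy_loop_function is hypothetical.

```python
import numpy as np


def make_greedy_loop_function(embedding, output_projection=None):
    """Build a greedy loop_function(prev, i) -> next_input (NumPy stand-in).

    embedding: [num_symbols, embedding_size] array.
    output_projection: optional (W, b) pair mapping decoder outputs to logits.
    """
    def loop_function(prev, i):
        # prev: [batch_size, output_size] decoder output from the previous step.
        if output_projection is not None:
            w, b = output_projection
            prev = prev @ w + b
        # Greedy decoding: keep only the single best symbol per example.
        symbols = np.argmax(prev, axis=1)
        # Embedding lookup turns the chosen symbols into next-step inputs.
        return embedding[symbols]
    return loop_function


embedding = np.arange(15, dtype=float).reshape(5, 3)
loop_fn = make_greedy_loop_function(embedding)
logits = np.array([[0.1, 0.9, 0.0, 0.0, 0.0],
                   [0.0, 0.0, 0.0, 0.0, 2.0]])
next_inputs = loop_fn(logits, 1)  # rows 1 and 4 of the embedding matrix
```

A beam-search loop_function keeps the same signature but also has to remember beam scores and parent pointers between calls.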

@PrajitR

PrajitR commented Dec 31, 2015

@lukaszkaiser
Here's a self contained example demonstrating a possible beam search implementation:

from __future__ import division
import numpy as np
import tensorflow as tf

with tf.Graph().as_default():
    beam_size = 3 # Number of hypotheses in beam.
    num_symbols = 5 # Output vocabulary size.
    embedding_size = 10
    num_steps = 3
    embedding = tf.zeros([num_symbols, embedding_size])
    output_projection = None

    # log_beam_probs: list of [beam_size, 1] Tensors
    #  Ordered log probabilities of the `beam_size` best hypotheses
    #  found in each beam step (highest probability first).
    # beam_symbols: list of [beam_size] Tensors 
    #  The ordered `beam_size` words / symbols extracted by the beam
    #  step, which will be appended to their corresponding hypotheses
    #  (corresponding hypotheses found in `beam_path`).
    # beam_path: list of [beam_size] Tensors
    #  The ordered `beam_size` parent indices. Their values range
    #  from [0, `beam_size`), and they denote which previous
    #  hypothesis each word should be appended to.
    log_beam_probs, beam_symbols, beam_path  = [], [], []
    def beam_search(prev, i):
        if output_projection is not None:
            prev = tf.nn.xw_plus_b(
                prev, output_projection[0], output_projection[1])

        # Compute
        #  log P(next_word, hypothesis) =
        #  log [P(next_word | hypothesis) * P(hypothesis)] =
        #  log P(next_word | hypothesis) + log P(hypothesis)
        # for each hypothesis separately, then join them together
        # along the same tensor dimension to form the example's
        # beam probability distribution:
        # [P(word1, hypothesis1), P(word2, hypothesis1), ...,
        #  P(word1, hypothesis2), P(word2, hypothesis2), ...]

        # tf.nn.log_softmax(prev) computes the same log probabilities
        # in a more numerically stable way.
        probs = tf.log(tf.nn.softmax(prev))
        # i == 1 corresponds to the input being "<GO>", with
        # uniform prior probability and only the empty hypothesis
        # (each row is a separate example).
        if i > 1:
            probs = tf.reshape(probs + log_beam_probs[-1], 
                               [-1, beam_size * num_symbols])

        # Get the top `beam_size` candidates and reshape them such
        # that the number of rows = batch_size * beam_size, which
        # allows us to process each hypothesis independently.
        best_probs, indices = tf.nn.top_k(probs, beam_size)
        indices = tf.stop_gradient(tf.squeeze(tf.reshape(indices, [-1, 1])))
        best_probs = tf.stop_gradient(tf.reshape(best_probs, [-1, 1]))

        symbols = indices % num_symbols # Which word in vocabulary.
        beam_parent = indices // num_symbols # Which hypothesis it came from.

        beam_symbols.append(symbols)
        beam_path.append(beam_parent)
        log_beam_probs.append(best_probs)
        return tf.nn.embedding_lookup(embedding, symbols)

    # Setting up graph.
    inputs = [tf.placeholder(tf.float32, shape=[None, num_symbols])
              for i in range(num_steps)]
    for i in range(num_steps):
        beam_search(inputs[i], i + 1)

    # Running the graph.
    input_vals = [0, 0, 0]
    l = np.log
    eps = -10 # exp(-10) ~= 0

    # These values mimic the distribution of vocabulary words
    # from each hypothesis independently (in log scale since
    # they will be put through exp() in softmax).
    input_vals[0] = np.array([[0, eps, l(2), eps, l(3)]])
    # Step 1 beam hypotheses =
    # (1) Path: [4], prob = log(1 / 2)
    # (2) Path: [2], prob = log(1 / 3)
    # (3) Path: [0], prob = log(1 / 6)

    input_vals[1] = np.array([[l(1.2), 0, 0, l(1.1), 0], # Path [4] 
                              [0,   eps, eps, eps, eps], # Path [2]
                              [0,  0,   0,   0,   0]])   # Path [0]
    # Step 2 beam hypotheses =
    # (1) Path: [2, 0], prob = log(1 / 3) + log(1)
    # (2) Path: [4, 0], prob = log(1 / 2) + log(1.2 / 5.3)
    # (3) Path: [4, 3], prob = log(1 / 2) + log(1.1 / 5.3)

    input_vals[2] = np.array([[0,  l(1.1), 0,   0,   0], # Path [2, 0]
                              [eps, 0,   eps, eps, eps], # Path [4, 0]
                              [eps, eps, eps, eps, 0]])  # Path [4, 3]
    # Step 3 beam hypotheses =
    # (1) Path: [4, 0, 1], prob = log(1 / 2) + log(1.2 / 5.3) + log(1)
    # (2) Path: [4, 3, 4], prob = log(1 / 2) + log(1.1 / 5.3) + log(1)
    # (3) Path: [2, 0, 1], prob = log(1 / 3) + log(1) + log(1.1 / 5.1)

    input_feed = {inputs[i]: input_vals[i][:beam_size, :]
                  for i in range(num_steps)}
    output_feed = beam_symbols + beam_path + log_beam_probs
    session = tf.InteractiveSession()
    outputs = session.run(output_feed, feed_dict=input_feed)

    expected_beam_symbols = [[4, 2, 0],
                             [0, 0, 3],
                             [1, 4, 1]]
    expected_beam_path = [[0, 0, 0],
                          [1, 0, 0],
                          [1, 2, 0]]

    print("predicted beam_symbols vs. expected beam_symbols")
    for ind, predicted in enumerate(outputs[:num_steps]):
        print(list(predicted), expected_beam_symbols[ind])
    print("\npredicted beam_path vs. expected beam_path")
    for ind, predicted in enumerate(outputs[num_steps:num_steps * 2]):
        print(list(predicted), expected_beam_path[ind])
    print("\nlog beam probs")
    for log_probs in outputs[2 * num_steps:]:
        print(log_probs)

Output:


predicted beam_symbols vs. expected beam_symbols
([4, 2, 0], [4, 2, 0])
([0, 0, 3], [0, 0, 3])
([1, 4, 1], [1, 4, 1])

predicted beam_path vs. expected beam_path
([0, 0, 0], [0, 0, 0])
([1, 0, 0], [1, 0, 0])
([1, 2, 0], [1, 2, 0])

log beam probs
[[-0.6931622 ]
 [-1.09862733]
 [-1.79177451]]
[[-1.098809  ]
 [-2.17854738]
 [-2.26555872]]
[[-2.17872906]
 [-2.26574016]
 [-2.63273931]]

A simple function very similar to the decoding step of Viterbi is needed to extract the best hypothesis. Furthermore, there will have to be logic outside the function that extracts the correct recurrent state for each hypothesis. This can be done with an embedding lookup of prev_state using beam_parent multiplied by each example's index in batch * beam_size.
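The Viterbi-style backtracking step can be sketched in plain Python, assuming a single example (batch_size = 1) and that beam_symbols and beam_path have been fetched as per-step lists, as in the session.run() output above. The function name backtrack is hypothetical.

```python
def backtrack(beam_symbols, beam_path):
    """Recover every final hypothesis from per-step symbols and parent indices.

    beam_symbols[t][k]: symbol placed in beam slot k at step t.
    beam_path[t][k]: slot at step t - 1 that slot k extends (unused at t = 0).
    Returns one token sequence per final slot, in the last step's order.
    """
    num_steps = len(beam_symbols)
    hypotheses = []
    for k in range(len(beam_symbols[-1])):
        tokens, slot = [], k
        # Walk the parent pointers from the last step back to the first.
        for t in range(num_steps - 1, -1, -1):
            tokens.append(int(beam_symbols[t][slot]))
            slot = int(beam_path[t][slot])
        hypotheses.append(tokens[::-1])
    return hypotheses


# Expected values from the example above.
symbols = [[4, 2, 0], [0, 0, 3], [1, 4, 1]]
paths = [[0, 0, 0], [1, 0, 0], [1, 2, 0]]
print(backtrack(symbols, paths))  # [[4, 0, 1], [4, 3, 4], [2, 0, 1]]
```

The recovered paths match the "Step 3 beam hypotheses" listed in the example's comments.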

One downside to this implementation is that it does not remove hypotheses that have reached the <EOS> token from the beam. It might be possible to do that, though I think it would require significantly more complicated logic. I'm interested in hearing your thoughts!

@wchan

wchan commented Jan 2, 2016

FYI, I have a variation of a beam search for TF here:
https://github.com/wchan/tensorflow/blob/master/speech4/models/las_decoder.py

@NickShahML

I wouldn't mind helping out either. I currently use sampling with temperature to generate different outputs given the same input.

@lukaszkaiser
Contributor

I like the first code a lot. I think it's advantageous to do the beam search in the graph, so we can just feed the graph once and get the whole sequence. I think the only thing missing was pulling finished hypotheses out of the beam, right? But we have the top-k op in TensorFlow; wouldn't that suffice?

@marcellofederico
Author

Hi,
let me first say that my knowledge of TensorFlow is still quite limited, so forgive me if I write something wrong. From what I have read, I understood that (1) it is not good to go back and forth between computations in the graph (on the GPU) and computations in Python (on the CPU), and (2) it is advisable to exploit parallel computation in the graph (on the GPU) as much as possible, which in our case means computing the expansions of alternative hypotheses in parallel.

I have put down the following figure to explain how the beam search could work.

[Figure: nmt-search]

The figure focuses on the decoding step. At the top, it shows the outputs of decoding steps that can be computed in parallel, together with the best K=2 output words for each step. Since no recombination is possible with an RNN (each output word depends on its whole history), we have to select the output words (with possible repetitions) that yield the top B=3 cumulative scores. (Note that K should ideally equal B to get a beam search with beam size B.) These B words become the input for the following step. To allow backtracking of the best translation, we only need to store the provenance of each input with respect to the inputs of the previous step (see the dotted arrows); a numeric position index should work, too. Once all computations are finished, we can backtrack through the input trellis (at the bottom), starting from the input word </s> with the best global score. This can be done with a single linear pass over all columns, from right to left.
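The two-stage selection in the figure (top K words per hypothesis, then the top B cumulative scores among the B×K candidates) can be sketched with NumPy for a single example. This assumes K = B, as recommended above, and uses a full argsort for clarity where a real implementation would use a partial top-k such as tf.nn.top_k; the function name two_stage_select is hypothetical.

```python
import numpy as np


def two_stage_select(log_probs, beam_scores, beam_size):
    """One beam step for a single example.

    log_probs: [beam, vocab] log P(word | hypothesis).
    beam_scores: [beam] cumulative log probability of each hypothesis.
    Returns (words, parents, scores) for the new beam, best first.
    """
    k = beam_size  # K = B
    # Stage 1: best K words for each hypothesis (top of the figure).
    top_words = np.argsort(-log_probs, axis=1, kind="stable")[:, :k]
    top_lp = np.take_along_axis(log_probs, top_words, axis=1)
    # Add each hypothesis's cumulative score to its own candidates.
    candidates = top_lp + beam_scores[:, None]
    # Stage 2: best B cumulative scores among the beam * K candidates.
    flat = candidates.ravel()
    best = np.argsort(-flat, kind="stable")[:beam_size]
    parents = best // k               # provenance (the dotted arrows)
    words = top_words.ravel()[best]   # inputs for the next step
    return words, parents, flat[best]


log_probs = np.log(np.array([[0.5, 0.3, 0.2],
                             [0.1, 0.6, 0.3]]))
beam_scores = np.log(np.array([0.6, 0.4]))
words, parents, scores = two_stage_select(log_probs, beam_scores, 2)
# words: [0, 1], parents: [0, 1], exp(scores): approximately [0.3, 0.24]
```

Choosing K = B loses nothing in stage 1: a candidate outside its hypothesis's top B words is beaten by at least B other candidates, so it could never enter the global top B.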

@martinwicke martinwicke added the stat:contribution welcome Status - Contributions welcome label Jan 5, 2016
@PrajitR

PrajitR commented Jan 5, 2016

@lukaszkaiser In the Seq2Seq paper they say, "As soon as the <EOS>
symbol is appended to a hypothesis, it is removed from the beam and is added to the set of complete
hypotheses". The problem with the implementation I did above is that once an EOS token is appended, the hypothesis remains on the beam. This means that the effective beam size is reduced by one. The easy way to avoid this is to extract 2k hypotheses, which guarantees that we will follow the Seq2Seq approach. Of course, it has the downside of doubling computation.
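A minimal sketch of the 2k idea, assuming a single example and that the candidate expansions arrive sorted best-first as (score, tokens) pairs (a hypothetical representation). Each of the k live hypotheses contributes at most one EOS expansion, so the top 2k candidates always contain at least k non-EOS ones and the beam can be refilled to full size. The function name split_finished is hypothetical.

```python
def split_finished(candidates, eos_id, beam_size):
    """Move EOS-terminated candidates out of the beam.

    candidates: the top 2 * beam_size expansions as (score, tokens) pairs,
    sorted best-first. Returns (alive, finished).
    """
    alive, finished = [], []
    for score, tokens in candidates:
        if tokens[-1] == eos_id:
            # Complete hypothesis: leaves the beam, joins the finished set.
            finished.append((score, tokens))
        elif len(alive) < beam_size:
            alive.append((score, tokens))
    return alive, finished


EOS = 0
candidates = [(-0.1, [1, 0]), (-0.2, [1, 2]), (-0.3, [2, 0]),
              (-0.4, [2, 1]), (-0.5, [1, 1]), (-0.6, [2, 2])]
alive, finished = split_finished(candidates, EOS, beam_size=3)
# alive keeps [1, 2], [2, 1], [1, 1]; finished gets [1, 0] and [2, 0]
```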

@mfederico In TF, Python is just a front end language -- no operations are actually run in Python. Whether an operation runs on GPU depends on whether the operation has a GPU kernel (e.g. .cu file) and other scheduling heuristics. Computing hypotheses in parallel is equivalent to decreasing the batch size for each model replica (actually faster because data transfer between GPUs doesn't have to happen).

I like the idea of computing top k on each hypothesis, then computing top k on the combination of remaining words. I think this would be faster if the top k operation is computed in parallel for each row (O(n + k^2) vs. O(nk), where n = vocab size, k = beam size, k << n). @lukaszkaiser would this be faster?

@wchan

wchan commented Jan 5, 2016

@PrajitR, I really like the approach as well. One question: how would you stop the graph early, i.e., when you know that all future partial hypotheses will be worse than the best completed hypothesis?
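One possible stopping test, sketched under the assumption that hypotheses are scored by raw cumulative log probability (no length normalization): every appended token adds a log probability ≤ 0, so a live hypothesis can only get worse. Once the best live score is no better than the best finished score, decoding can stop. With length normalization this bound no longer holds. The function name can_stop_early is hypothetical.

```python
def can_stop_early(alive_scores, finished_scores):
    """True once no live hypothesis can beat the best finished one.

    Assumes unnormalized cumulative log probabilities, which never increase
    as tokens are appended.
    """
    if not finished_scores:
        return False  # nothing has finished yet, keep decoding
    if not alive_scores:
        return True   # every hypothesis has finished
    return max(alive_scores) <= max(finished_scores)


print(can_stop_early([-3.0, -4.0], [-2.5]))  # True
print(can_stop_early([-2.0], [-2.5]))        # False
```

In the graph, the negation of this check could serve as the condition of a tf.while_loop instead of unrolling a fixed number of steps.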

@NickShahML

I like the idea of computing top k on each hypothesis, then computing top k on the combination of remaining words.

I like this idea too, but isn't computing top k on the combination of remaining words expensive?

@marcellofederico
Author

If the k top-k lists are sorted, they remain sorted after we add the cumulative score of the input hypothesis each was generated from. Then finding the global top-k takes O(k log k) time in the worst case, with O(k) space (see the algorithm at http://stackoverflow.com/a/21051271).
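A sketch of such a merge with Python's heapq, assuming each hypothesis's candidate list is sorted best-first as (log_prob, word) pairs. It is a standard k-way heap merge, not necessarily the exact algorithm in the linked answer. The heap holds at most one entry per hypothesis, so extracting k items costs O(k log k) once the k lists are available. The function name merge_topk is hypothetical.

```python
import heapq


def merge_topk(sorted_lists, hyp_scores, k):
    """Global top-k over per-hypothesis candidate lists.

    sorted_lists[h]: (log_prob, word) pairs for hypothesis h, best first.
    hyp_scores[h]: cumulative log probability of hypothesis h.
    Returns up to k (score, parent, word) triples, best first.
    """
    heap = []  # heapq is a min-heap, so scores are stored negated
    for h, candidates in enumerate(sorted_lists):
        if candidates:
            heapq.heappush(heap, (-(candidates[0][0] + hyp_scores[h]), h, 0))
    top = []
    while heap and len(top) < k:
        neg_score, h, j = heapq.heappop(heap)
        top.append((-neg_score, h, sorted_lists[h][j][1]))
        # Adding a constant keeps list h sorted, so only its next item can
        # be the next-best entry from hypothesis h.
        if j + 1 < len(sorted_lists[h]):
            next_lp = sorted_lists[h][j + 1][0]
            heapq.heappush(heap, (-(next_lp + hyp_scores[h]), h, j + 1))
    return top


lists = [[(-0.1, "a"), (-1.0, "b")],   # hypothesis 0, best first
         [(-0.2, "c"), (-0.3, "d")]]   # hypothesis 1, best first
top = merge_topk(lists, hyp_scores=[-1.0, -0.5], k=3)
print([(parent, word) for _, parent, word in top])  # [(1, 'c'), (1, 'd'), (0, 'a')]
```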

@giancds

giancds commented Feb 17, 2016

@PrajitR, I'm trying to implement your suggestion in my experiments, using the seq2seq interfaces. I think your code is very clever and makes a lot of sense. Nevertheless, I'm still confused by one thing you mentioned:

"(...) extracts the correct recurrent state for each hypothesis. This can be done with an embedding lookup of prev_state using beam_parent multiplied by each example's index in batch * beam_size."

If I'm using a batch of 1 (i.e., one sentence at a time; am I right about this?), wouldn't just the lookup of prev_state suffice? If not, I think I didn't get it.

In addition, I think that if we combine the predicted symbols with their parents right after producing them, we could keep a list of complete hypotheses, add those that reach the EOS symbol to it, and try to remove them from beam_path. This would not stop the graph earlier, but I think it would solve the problem of the reduced effective beam size.
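On the prev_state lookup quoted above: one way to read it is that, with states stored example-major as [batch_size * beam_size, state_size], the row a new hypothesis should inherit is example_index * beam_size + beam_parent. With batch_size = 1 the offset is 0 and beam_parent alone suffices. The NumPy sketch below assumes that layout; in a TensorFlow graph the same index vector could be passed to tf.gather. The function name reorder_states is hypothetical.

```python
import numpy as np


def reorder_states(prev_state, beam_parent, batch_size, beam_size):
    """Give each new hypothesis the recurrent state of its parent.

    prev_state: [batch_size * beam_size, state_size], rows grouped by example.
    beam_parent: [batch_size * beam_size] parent slots in [0, beam_size).
    """
    # Row offset of each example's block of beam_size rows.
    offsets = np.repeat(np.arange(batch_size) * beam_size, beam_size)
    flat_index = offsets + np.asarray(beam_parent)
    return prev_state[flat_index]


# Two examples, beam of two; each state row holds its own row number.
prev_state = np.array([[0.0], [1.0], [2.0], [3.0]])
new_state = reorder_states(prev_state, [1, 1, 0, 1], batch_size=2, beam_size=2)
# new_state rows: [1.0], [1.0], [2.0], [3.0]
```

The same reordering has to be applied to every part of the state the decoder carries forward (for example LSTM cell and hidden states, or attention states).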

@lukaszkaiser
Contributor

Just as a comment: there is some experimental support for session.partial_run() in TensorFlow in the 0.7 release. Since partial_run does not deallocate the tensors, it should make step-by-step decoding from seq2seq models much easier. And since we can have a few of them in parallel, it could also greatly simplify beam_search. But it's experimental for now, so beware - I'm just testing it.

@amirj

amirj commented Mar 14, 2016

Finally, I'm interested to know whether there is a version of translate.py that supports beam search in decoding.

@JinXinDeep

I am currently learning about machine translation and have found many papers that describe beam search for statistical machine translation. The main idea is that, among translations that cover the same number of source-sentence words, at most k best candidates are kept, possibly combined with a future cost estimation method. Are there any papers that describe beam search for neural machine translation (NMT)? At each step, is it enough to select the k best candidates of the same length? Thanks!

@tilneyyang

@giancds Is it possible to get the code of your implementation?

@giancds

giancds commented Apr 18, 2016

Hi tilneyyang,

I have decided against implementing the solution suggested here. I am currently trying some new options in the decoder and have found that this particular solution does not give me the flexibility I need.

However, I am still doing beam search, but running one step at a time, so part of the beam search is done outside the graph (trading decoding speed for that freedom).

My code is largely based on the code of @kyunghyuncho (you can find his code here), but he uses Theano in the implementation.

You can find mine here, in the nmt_models.py file.
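For comparison, a minimal out-of-graph beam search in the spirit described above might look like the following. It assumes a single example and a hypothetical step_fn(token, state) -> (log_probs, new_state) that wraps one decoder step (for instance one session.run call); it is not the code in nmt_models.py. Because every candidate carries its parent's state, no explicit state reordering is needed.

```python
import math


def beam_search(step_fn, init_state, go_id, eos_id, beam_size, max_len):
    """Out-of-graph beam search for one example.

    step_fn(token, state) -> (log_probs over the vocabulary, new_state).
    Returns up to beam_size (score, tokens) pairs, best first.
    """
    beam = [(0.0, [go_id], init_state)]  # (score, tokens, decoder state)
    finished = []
    for _ in range(max_len):
        candidates = []
        for score, tokens, state in beam:
            log_probs, new_state = step_fn(tokens[-1], state)
            for word, lp in enumerate(log_probs):
                # Each candidate keeps its own parent's state.
                candidates.append((score + lp, tokens + [word], new_state))
        candidates.sort(key=lambda c: c[0], reverse=True)
        beam = []
        for cand in candidates:
            if cand[1][-1] == eos_id:
                finished.append(cand)  # complete: leaves the beam
            else:
                beam.append(cand)
                if len(beam) == beam_size:
                    break
        if not beam:
            break
    finished.extend(beam)  # hypotheses still open at max_len
    finished.sort(key=lambda c: c[0], reverse=True)
    return [(score, tokens[1:]) for score, tokens, _ in finished[:beam_size]]


def toy_step(token, step):
    # Hypothetical decoder with vocabulary {0: EOS, 1: "a", 2: "b"}.
    probs = [0.1, 0.6, 0.3] if step == 0 else [0.9, 0.05, 0.05]
    return [math.log(p) for p in probs], step + 1


print(beam_search(toy_step, 0, go_id=3, eos_id=0, beam_size=2, max_len=3))
```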

@nikitakit

nikitakit commented Jun 21, 2016

I'd like to share my in-graph beam search implementation that uses the loop_function approach

https://gist.github.com/nikitakit/6ab61a73b86c50ad88d409bac3c3d09f

I believe it correctly implements length-bucketed beam search. So far I've only tested that the outputs look reasonable (as opposed to comparing with a known-good beam search implementation).

The loop_function API is really nice, but I'm not sure if bending over backwards to avoid writing a custom op was the best decision here.

EDIT 2: After further thought, I've discovered a serious flaw in the original beam search implementation I posted. The gist is now updated with a fix. In the process, I had to abandon the loop_function API, since it was not sufficient to correctly implement beam search.

@avati

avati commented Jun 27, 2016

For those interested in a dynamic_rnn() based beam search, here is an implementation that is working well for us - stanfordmlgroup/nlc@b5088d1

@nikitakit

@avati It appears that your implementation only works on a single example at a time (batch_size=1). Is this by accident or by design?

For example, compare your code with mine. Note that your call to top_k is on a flat vector (so if there is a batch size >1, the examples compete with each other), while mine is on a 2D matrix. My code has hacks for allowing batch_size>1, though I have now discovered that it's flawed in a different way.
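The difference can be illustrated with NumPy stand-ins for the two top_k calls. tf.nn.top_k itself selects along the last dimension, so applying it to a [batch_size, beam_size * num_symbols] score matrix keeps examples separate, while flattening first lets them compete. The function names below are hypothetical.

```python
import numpy as np


def topk_per_row(scores, k):
    # Like top_k on a 2D [batch, candidates] matrix: each example picks its own k.
    return np.argsort(-scores, axis=1, kind="stable")[:, :k]


def topk_flat(scores, k):
    # Like top_k on the flattened matrix: all examples compete for k slots.
    idx = np.argsort(-scores.ravel(), kind="stable")[:k]
    # Convert flat positions back to (example, candidate) pairs.
    return np.stack([idx // scores.shape[1], idx % scores.shape[1]], axis=1)


scores = np.array([[-0.1, -0.2, -3.0],    # example 0
                   [-5.0, -6.0, -0.5]])   # example 1
print(topk_per_row(scores, 2).tolist())  # [[0, 1], [2, 0]]
print(topk_flat(scores, 2).tolist())     # [[0, 0], [0, 1]]: both from example 0
```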

@avati

avati commented Jun 28, 2016

As of now, batch_size=1 is by design. I suppose it is possible to support batch_size>1 with dynamic_rnn(), but the code quickly became unreadable the first time I tried. dynamic_rnn()-based code is generally less readable than time-step-unrolled code (which can use native Python loops, etc.). I plan to revisit batch_size>1 support sometime in the future.

@nikitakit

FYI my gist is now updated with fixes and a new API.

I'd love to get feedback from others on whether this API would suit their needs!

I think I've mostly converged on something that works well with typical uses of tensorflow. I'm curious if there is interest in getting something like this added to tf.contrib (maybe @lukaszkaiser can comment?) If that is the case, I'm willing to add in some more optimizations (e.g. dynamic early-stopping like the code by @avati) and do general profiling/cleanup on the code.

@pbhatia243

Here is a link to an extension of the TensorFlow seq2seq model for conversation models, which has an option for beam search and for setting the beam size:

https://github.com/pbhatia243/Neural_Conversation_Models

@DogNick

DogNick commented Aug 25, 2016

@pbhatia243 Hi,
I have read your beam_search code at https://github.com/pbhatia243/Neural_Conversation_Models/blob/master/my_seq2seq.py#L729, and I have a couple of questions; maybe you can shed some light on them.
1) Line 119:
emb_prev = tf.reshape(emb_prev,[beam_size,embedding_size]). It seems that emb_prev has shape [batch_size * beam_size, embedding_size] before the reshape, which I thought would make the reshape fail. If it does work, could you explain why?

2) Line 100:
    probs = tf.reshape(probs + log_beam_probs[-1], [-1, beam_size * num_symbols])

As I understand it, this reshape only works when all beams of one example sit in adjacent rows of the probs matrix, like:
batch0_beam0_symbol_vec
batch0_beam1_symbol_vec
batch0_beam2_symbol_vec
....
But I found that in lines 668-670

for kk in range(beam_size):
    states.append(state)
state = tf.reshape(tf.concat(0, states), [-1, state_size])

the state is built by stacking beam_size copies of the whole batch, so its rows are ordered batch0, batch1, ..., batch0, batch1, .... That order can't be aligned with probs (which comes from the last cell output, and that cell takes the concatenation of state and x as input).
Is there anything I missed?
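The row-order mismatch described above can be reproduced with NumPy. Concatenating beam_size copies of the batch along axis 0 (what the quoted loop does) interleaves examples, whereas repeating each example's row beam_size times keeps each example's beams adjacent, which is the order a reshape to [-1, beam_size * num_symbols] assumes. In TensorFlow 1.x the repeated order can be obtained by tiling along a new axis 1 and then reshaping; this is a sketch of the issue, not a patch to the repo.

```python
import numpy as np

batch_size, beam_size = 2, 3
state = np.array([[10.0], [20.0]])  # one state row per example

# beam_size copies of the whole batch: rows ex0, ex1, ex0, ex1, ex0, ex1.
tiled = np.concatenate([state] * beam_size, axis=0)

# Each example's row repeated beam_size times: ex0, ex0, ex0, ex1, ex1, ex1.
repeated = np.repeat(state, beam_size, axis=0)

# Grouping consecutive beam_size rows per example, as the probs reshape does:
print(tiled.reshape(batch_size, beam_size).tolist())     # [[10.0, 20.0, 10.0], [20.0, 10.0, 20.0]]
print(repeated.reshape(batch_size, beam_size).tolist())  # [[10.0, 10.0, 10.0], [20.0, 20.0, 20.0]]
```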

@jihunchoi
Contributor

jihunchoi commented Jan 13, 2017

Can't we use the newly added dynamic_rnn_decoder?
According to the documentation, we can use the context_state parameter of decoder_fn to implement beam search; however, I am not sure how to use it.
I could look into implementing beam search with this function in a couple of weeks. Is anyone interested in this?

@nikitakit

@jihunchoi Thanks for bringing this to my attention. As far as I can tell dynamic_rnn_decoder is just a thin wrapper around raw_rnn. That is, the raw_rnn based implementation I have would be functionally identical to a dynamic_rnn_decoder one, so no point in modifying something that already works.

@alrojo
Contributor

alrojo commented Jan 13, 2017

The dynamic_rnn_decoder was built to, eventually, support beam search.

@amirj

amirj commented Feb 2, 2017

In Google's original Smart Reply paper:

Our search is conducted as follows. First, the elements of R (the set of all possible responses) are organized into a trie. Then, we conduct a left-to-right beam search, but only retain hypotheses that appear in the trie. This search process has complexity O(bl) for beam size b and maximum response length l. Both b and l are typically in the range of 10-30, so this method dramatically reduces the time to find the top responses and is a critical element of making this system deployable.

I want to implement something similar. Some questions:

  1. How to represent the set of all possible responses (R) in a trie data structure in TF?
  2. Does the current implementation of beam search support retaining only hypotheses that appear in the trie?
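As a rough illustration of the idea (outside TensorFlow, in plain Python), the response set can be stored as a nested-dict trie, and beam search can expand only the children of each hypothesis's trie node, so only prefixes of known responses survive. This is a sketch under stated assumptions: the None end-of-response marker, the log_prob(prefix, token) scoring function, and both function names are hypothetical, and it is not the Smart Reply implementation.

```python
def build_trie(responses):
    """Nested-dict trie; the key None marks the end of a complete response."""
    trie = {}
    for tokens in responses:
        node = trie
        for tok in tokens:
            node = node.setdefault(tok, {})
        node[None] = {}
    return trie


def trie_beam_search(trie, log_prob, beam_size, max_len):
    """Beam search that only keeps hypotheses that are prefixes in the trie."""
    beam = [(0.0, [], trie)]  # (score, prefix, trie node)
    finished = []
    for _ in range(max_len):
        candidates = []
        for score, prefix, node in beam:
            for tok, child in node.items():
                if tok is None:
                    finished.append((score, prefix))  # a whole response
                else:
                    candidates.append(
                        (score + log_prob(prefix, tok), prefix + [tok], child))
        candidates.sort(key=lambda c: c[0], reverse=True)
        beam = candidates[:beam_size]
        if not beam:
            break
    # Responses that were completed exactly at max_len.
    finished.extend((s, p) for s, p, node in beam if None in node)
    finished.sort(key=lambda c: c[0], reverse=True)
    return finished[:beam_size]


responses = [["how", "are", "you"], ["how", "is", "it"], ["ok"]]
scores = {"how": -0.1, "are": -0.5, "is": -0.2, "you": -0.1,
          "it": -1.0, "ok": -2.0}
best = trie_beam_search(build_trie(responses),
                        lambda prefix, tok: scores[tok], beam_size=2, max_len=4)
print([tokens for _, tokens in best])  # [['how', 'are', 'you'], ['how', 'is', 'it']]
```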

@Syndrome777
Contributor

Syndrome777 commented Feb 10, 2017

Hi, @pbhatia243

Thank you for your repo. It's a good extension.
But as @nikitakit says, it seems that the code in your Neural Conversation Models repo is not a complete beam search. It may be something in between the argmax greedy method and beam search. Worse still, it will always produce wrong responses.

    symbols = indices % num_symbols # Which word in vocabulary.
    beam_parent = indices // num_symbols # Which hypothesis it came from.

When beam_parent changes from the previous time step, we must reset the RNN state for that beam path, and then use this state and this input to infer the output.
If we don't reset the state, decoding will mix up the different beam paths. Do you agree?

In standard beam search, the decoder's hypotheses are pruned step by step, so the RNN state needs to be reset during decoding.

Could you add a step to your code that resets the state according to beam_parent?

@pbhatia243

@Syndrome777: I am about to take a second pass at beam search. Could you elaborate on resetting the state? Do you mean reordering?

@pbhatia243

@nikitakit: How has the performance been on your large dataset? Using my approach, even though it's an approximation, I do get realistic answers and lower perplexity. I am about to take a second pass at beam search. Could you elaborate on resetting the state? Do you mean reordering?

@junskang

@pbhatia243 You have probably figured this out already, but I am leaving this comment in the hope that it helps people reading this thread. I believe what Syndrome777 and nikitakit meant to point out was "reordering" the states after you reorder the output sequences.

If you do not reorder the states along with the outputs (beams), you could end up generating the next token of sequence A with a state that was used to generate sequence B in the previous step, which most likely does not know the context of sequence A.

@chenghuige

chenghuige commented Mar 19, 2017

I agree with junskang. I have used nikitakit's in-graph beam search and faced this problem. If you use attention, you also need to reorder the attention states. For anyone trying to use beam search, I suggest referring to https://github.com/google/seq2seq/tree/master/seq2seq, which has a dynamic implementation of in-graph beam search. For out-of-graph beam search, you can refer to models/im2txt.

@Kyubyong

@chenghuige Thanks! Could you give us an example of the Google implementation? Does it support mini-batches? If we can only process one sequence at a time, does that matter for decoding time?

@lukaszkaiser
Contributor

I'd recommend using the google/seq2seq github repo. It's a very good, complete and maintained seq2seq implementation. Beam search is here: https://github.com/google/seq2seq/blob/master/seq2seq/inference/beam_search.py

@gidim

gidim commented Mar 27, 2017

@lukaszkaiser any plans on adding beam search to tensorflow from google/seq2seq?
The google/seq2seq repo seems very good for experimentation but it does a lot of things different than TF (e.g., conf files to define classes).

@lukaszkaiser
Contributor

I think we always welcome contributions and stable things moving to tf.contrib, but I'll not be working on it personally.

@gidim

gidim commented Mar 27, 2017

Thanks, @lukaszkaiser. If I end up implementing it, I'll send a pull request.

@lukaszkaiser
Contributor

Great! I think tensorflow/contrib/seq2seq/ would be a good place for it.

@somah1411

Hi All,
I have a problem with NMT: all of the output is unknown. My parallel data is small. Is this the problem, and how can I solve it?

@classicCoder16

Hello, I am currently trying to use nikitakit's beam search code in conjunction with an attention mechanism. @nikitakit mentioned in a previous post that the code now works with attention but later @chenghuige said that it doesn't reorder attention -- does this mean it is not possible to currently use nikitakit's code with an attention mechanism? Thanks!

@nikitakit

My code works fine with attention; all of my own models that I used it for relied on attention. However, you do need to make sure that cell_transform is set correctly (instructions are in the docstring).

That said, the code still uses the TensorFlow 0.12 API, and I haven't updated it for TensorFlow 1.0.

@classicCoder16

Great, thanks! Out of curiosity, when you've integrated attention have you been using AttentionCellWrapper, or some other means? Also, does the code in general support custom attention cell wrappers?

@nikitakit

@classicCoder16 I used custom cells, but AttentionCellWrapper will work too. Anything that implements the RNNCell interface is compatible -- the interface specifies how the cells are supposed to store their state, and the beam search has access to that for state reordering.

@lmkostas

Hi, @nikitakit , I am currently working on an application in which I would like to incorporate attention with your implementation of beam search. So far I have been struggling to implement my own custom attention wrapper and I wanted to ask if you would be willing to share your code for the custom attention cells which you referenced in your most recent post? Thanks for the help in advance.

tarasglek pushed a commit to tarasglek/tensorflow that referenced this issue Jun 20, 2017
Update compression model README with results for comparison.
@yangshao

Hi @nikitakit, does your implementation work with TensorFlow >= 1.0? For example, if I use TrainingHelper during training, how should I use your implementation during testing? Thanks.

@eduOS

eduOS commented May 16, 2018

@PrajitR I have extended your version and tested it in my seq2seq model. Hope it helps.

https://stackoverflow.com/a/50304227/3552975

Any suggestions are highly appreciated. Thanks.

darkbuck pushed a commit to darkbuck/tensorflow that referenced this issue Jan 23, 2020
…pstream-rm-rocm-profiler

Remove rocm-profiler dependency