In the example translate seq2seq model, mentioned in the tutorial and in the code here (reproduced below):
# Get a 1-element batch to feed the sentence to the model.
encoder_inputs, decoder_inputs, target_weights = model.get_batch(
    {bucket_id: [(token_ids, [])]}, bucket_id)
# Get output logits for the sentence.
_, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                 target_weights, bucket_id, True)
# This is a greedy decoder - outputs are just argmaxes of output_logits.
outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
It uses only a greedy decoder, taking the argmax of the logits at each step to find the best match. Is there a way to get the top N results instead of decoding greedily? I've tried argsort, but everything apart from the 0th index gives garbage results. I've also looked into tf.nn.top_k(), but because this is batched, I get the error "List of Tensors when single Tensor expected", and I'm not sure how to unroll the list within TF.
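A note on the argsort attempt, as a minimal sketch: it assumes output_logits is a list with one array of shape (1, vocab_size) per decoder step, as in the 1-element batch above. The helper name top_n_per_step and the toy logits are made up for illustration. It ranks the tokens at each step independently, which is likely why the lower ranks look like garbage: in this decoding mode each step appears to be conditioned on the previous step's argmax, so the rank-1, rank-2, ... tokens at later steps are all alternatives to the greedy prefix and do not chain into N coherent sentences.

```python
import numpy as np

def top_n_per_step(output_logits, n):
    """Return, for each decoder step, the n highest-scoring token ids (best first).

    output_logits: list of arrays of shape (1, vocab_size), one per decoder step.
    This is an assumption about the layout, based on the 1-element batch above.
    """
    top_ids = []
    for logit in output_logits:
        scores = np.asarray(logit)[0]          # drop the batch dimension
        order = np.argsort(scores)[::-1][:n]   # indices by descending score
        top_ids.append([int(i) for i in order])
    return top_ids

# Toy logits (made-up values): 3 decoder steps, vocabulary of 5 tokens.
toy_logits = [
    np.array([[0.1, 2.0, 0.3, 1.5, -1.0]]),
    np.array([[3.0, 0.2, 2.5, 0.0, 0.1]]),
    np.array([[0.0, 0.1, 0.2, 0.9, 0.8]]),
]
per_step = top_n_per_step(toy_logits, 2)
greedy = [ids[0] for ids in per_step]  # same tokens the argmax decoder picks
```

Reading column 0 of per_step reproduces the greedy output. Reading column 1 does not give the second-best sentence, only the runner-up token at each position given the greedy history.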
To get the top-N results from a sequential network, you need to run a beam search. It isn't implemented in the tutorial, but there are already some suggestions and code on GitHub; please see #654.
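To make the idea concrete, here is a rough, framework-independent sketch of beam search in plain Python. It is not the code from #654 and not part of the tutorial. The callback step_log_probs(prefix) is hypothetical: it stands in for "run the decoder one step given the tokens chosen so far and return log-probabilities over the vocabulary", which the tutorial's model.step does not seem to expose directly, so wiring it into the translate model would take extra work. The toy model and its probabilities are invented.

```python
import math

def beam_search(step_log_probs, beam_size, max_len, eos_id):
    """Return up to beam_size (log_prob, token_ids) pairs, best first.

    step_log_probs(prefix) is a hypothetical callback returning a list of
    log-probabilities (one per vocabulary id) for the next token.
    """
    beams = [(0.0, [])]   # live hypotheses: (cumulative log-prob, tokens)
    finished = []         # hypotheses that have emitted eos_id
    for _ in range(max_len):
        candidates = []
        for score, tokens in beams:
            for token_id, lp in enumerate(step_log_probs(tokens)):
                candidates.append((score + lp, tokens + [token_id]))
        # Keep only the beam_size best expansions across all live hypotheses.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beams = []
        for cand in candidates[:beam_size]:
            if cand[1][-1] == eos_id:
                finished.append(cand)
            else:
                beams.append(cand)
        if not beams:
            break
    finished.extend(beams)  # include hypotheses cut off at max_len
    finished.sort(key=lambda c: c[0], reverse=True)
    return finished[:beam_size]

def toy_step_log_probs(prefix):
    # Invented 3-token model (id 0 is end-of-sentence), for illustration only.
    if not prefix:
        probs = [0.1, 0.6, 0.3]
    elif prefix[-1] == 1:
        probs = [0.7, 0.1, 0.2]
    else:
        probs = [0.9, 0.05, 0.05]
    return [math.log(p) for p in probs]

results = beam_search(toy_step_log_probs, beam_size=2, max_len=4, eos_id=0)
top_sequences = [tokens for _, tokens in results]
```

Unlike per-step ranking, every hypothesis here is scored as a whole sequence, so the second entry in results is a complete alternative output and not a mix of runner-up tokens. A real implementation would usually also normalize scores by length, which this sketch leaves out.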