From 4254590a2c08386c0a7ee350cc5bc4458ad35841 Mon Sep 17 00:00:00 2001 From: patrick-wilken Date: Tue, 12 Nov 2019 09:44:39 -0500 Subject: [PATCH] ChoiceLayer: prefix decoding --- returnn/tf/layers/rec.py | 69 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 66 insertions(+), 3 deletions(-) diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index 3bd9df6c6e..9d767d5648 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -4504,7 +4504,7 @@ def __init__(self, beam_size, keep_beams=False, length_normalization=True, length_normalization_exponent=1.0, custom_score_combine=None, - source_beam_sizes=None, scheduled_sampling=False, cheating=False, + source_beam_sizes=None, scheduled_sampling=False, cheating=False, prefix_target=None, explicit_search_sources=None, **kwargs): """ @@ -4546,6 +4546,7 @@ def __init__(self, beam_size, keep_beams=False, self.search_scores_base = None self.search_scores_combined = None # We assume log-softmax here, inside the rec layer. + self.prefix_target = prefix_target if self.search_flag: if cheating: @@ -4699,6 +4700,13 @@ def __init__(self, beam_size, keep_beams=False, cheating_exclusive=cheating_exclusive) self.search_choices.set_src_beams(src_beams) # (batch, beam) -> beam_in idx labels = tf.reshape(labels, [net_batch_dim * beam_size]) # (batch * beam) + + if self.prefix_target: + assert len(self.sources) == 1, "Prefix decoding not yet implemented for multiple sources." + labels, scores = self._enforce_prefixes( + top_k_labels=labels, all_scores=scores_comb, top_k_scores=scores, batch_dim=net_batch_dim, + beam_size=beam_size) + labels = tf.cast(labels, self.output.dtype) if len(self.sources) > 1: @@ -5008,6 +5016,59 @@ def _get_cheating_targets_and_src_beam_idxs(self, scores): src_beams = src_beams[:, 0] # (batch,) return cheating_gold_targets, src_beams + def _enforce_prefixes(self, top_k_labels, all_scores, top_k_scores, batch_dim, beam_size): + """ + This function replaces the target labels from beam search by the ones predefined by the target prefixes as long + as search is still at a position within the prefix. We also replace the scores such that they correspond to a + prediction of the prefixes. + + :param tf.Tensor top_k_labels: target labels from beam seach, shape (batch * beam,) + :param tf.Tensor all_scores: scores before beam pruning, used to lookup prefix scores, shape (batch, beam, dim) + :param tf.Tensor top_k_scores: scores after beam pruning, shape (batch, beam) + :param tf.Tensor|int batch_dim: number of sequences in batch (without beam) + :param int beam_size: outgoing beam size of this layer + :return: labels (batch * beam,) and scores (batch, beam) of self.prefix_target as long as within prefix, after + that top_k_labels and top_k_scores from beam search + :rtype: (tf.Tensor, tf.Tensor) + """ + assert self.prefix_target + + # Get the labels of the prefixes which should be enforced. They are padded with zeros, therefore we will + # get zeros for those sequences where the current timestep is beyond the length of the prefix. + target_prefix_labels = self._get_target_value( + target=self.prefix_target).get_placeholder_as_batch_major() # (batch * beam,), int32 + + # Get prefixes that have already ended (i.e. have a smaller length than the current time step). + target_prefix_ended = tf.equal(target_prefix_labels, 0) + + # Select between the prefixes and the labels from free decoding, depending on whether the prefix + # has still got to be enforced. + labels = tf.where(target_prefix_ended, top_k_labels, target_prefix_labels) + + # Get rid of the redundant beam, all entries are the same, only keep first entry. + target_prefix_labels = tf.reshape(target_prefix_labels, [batch_dim, beam_size])[:, 0] # (batch,) + + # Now also get the scores for the prefixes. First, select only the first entry of the incoming beam as all entries + # are the same if we are still within the prefix. + all_scores = all_scores[:, 0, :] # (batch, dim) + + # Gather scores for the prefix labels. + from TFUtil import nd_indices + target_prefix_nd_indices = nd_indices(target_prefix_labels) + prefix_scores = tf.expand_dims(tf.gather_nd(all_scores, target_prefix_nd_indices), axis=-1) # (batch, 1) + + # Create an artificial beam, where all but the first scores are infinite. Tiling the one entry we have would + # lead to a beam of all equal hypotheses for the rest of the search. + # Conceptually similar to TFUtil.filter_ended_scores(). + prefix_scores_padding = tf.fill([batch_dim, beam_size - 1], -1.e30) + prefix_scores = tf.concat([prefix_scores, prefix_scores_padding], axis=1) + + # Use prefix scores for sequences where the prefix has not ended yet. + target_prefix_ended = tf.reshape(target_prefix_ended, [batch_dim, beam_size]) + scores = tf.where(target_prefix_ended, top_k_scores, prefix_scores) # (batch, beam) + + return labels, scores + @classmethod def transform_config_dict(cls, d, network, get_layer): """ @@ -5063,8 +5124,8 @@ def _create_search_beam(cls, name, beam_size, sources, network): name="%s%s" % (network.get_absolute_name_prefix(), name)) @classmethod - def get_out_data_from_opts(cls, name, sources, target, network, - beam_size, search=NotSpecified, scheduled_sampling=False, cheating=False, **kwargs): + def get_out_data_from_opts(cls, name, sources, target, network, beam_size, search=NotSpecified, + scheduled_sampling=False, cheating=False, prefix_target=None, **kwargs): """ :param str name: :param list[LayerBase] sources: @@ -5099,6 +5160,8 @@ def get_out_data_from_opts(cls, name, sources, target, network, out_data.batch = out_data.batch.copy_set_beam(out_data.beam) if cheating or scheduled_sampling or not search: cls._static_get_target_value(target=target, network=network, mark_data_key_as_used=True) # mark as used + if search and prefix_target: + cls._static_get_target_value(target=prefix_target, network=network, mark_data_key_as_used=True) # mark as used return out_data def get_sub_layer(self, layer_name):