Skip to content

Commit

Permalink
more efficient ctc beam search
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed May 10, 2024
1 parent f34f70d commit a6d950e
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions users/zeyer/experiments/exp2024_04_23_baselines/ctc.py
Expand Up @@ -307,25 +307,32 @@ def model_recog(
# Initial state.
beam_dim = Dim(1, name="initial-beam")
batch_dims_ = [beam_dim] + batch_dims
seq_log_prob = rf.constant(0.0, dims=batch_dims_)
seq_log_prob = rf.constant(0.0, dims=batch_dims_) # Batch, Beam

label_log_prob = rf.log_softmax(logits, axis=model.wb_target_dim) # Batch, Spatial, Vocab
label_log_prob = rf.where(
enc_spatial_dim.get_mask(),
label_log_prob,
rf.sparse_to_dense(model.blank_idx, axis=model.wb_target_dim, label_value=0.0, other_value=-1.0e30),
)
label_log_prob = TensorArray.unstack(label_log_prob, axis=enc_spatial_dim) # t -> Batch, Vocab
label_log_prob_pre_filter, (backrefs_pre_filter,), pre_filter_beam_dim = rf.top_k(
label_log_prob, k_dim=Dim(beam_size, name=f"pre-filter-beam"), axis=[model.wb_target_dim]
) # label_log_prob_pre_filter, backrefs_pre_filter: Batch, Spatial, PreFilterBeam. backrefs_pre_filter -> Vocab
label_log_prob_pre_filter_ta = TensorArray.unstack(
label_log_prob_pre_filter, axis=enc_spatial_dim
) # t -> Batch, PreFilterBeam
backrefs_pre_filter_ta = TensorArray.unstack(backrefs_pre_filter, axis=enc_spatial_dim) # t -> Batch, PreFilterBeam

max_seq_len = int(enc_spatial_dim.get_dim_value())
seq_targets = []
seq_backrefs = []
for t in range(max_seq_len):
# Filter out finished beams
seq_log_prob = seq_log_prob + label_log_prob[t] # Batch, InBeam, Vocab
seq_log_prob = seq_log_prob + label_log_prob_pre_filter_ta[t] # Batch, InBeam, PreFilterBeam
seq_log_prob, (backrefs, target), beam_dim = rf.top_k(
seq_log_prob, k_dim=Dim(beam_size, name=f"dec-step{t}-beam"), axis=[beam_dim, model.wb_target_dim]
) # seq_log_prob, backrefs, target: Batch, Beam
seq_log_prob, k_dim=Dim(beam_size, name=f"dec-step{t}-beam"), axis=[beam_dim, pre_filter_beam_dim]
) # seq_log_prob, backrefs, target: Batch, Beam. backrefs -> InBeam. target -> PreFilterBeam.
target = rf.gather(backrefs_pre_filter_ta[t], indices=target) # Batch, Beam -> Vocab
seq_targets.append(target)
seq_backrefs.append(backrefs)

Expand Down

0 comments on commit a6d950e

Please sign in to comment.