Skip to content

Commit

Permalink
ctc cleanup beam search
Browse files (browse the repository at this point in the history)
  • Loading branch information
albertz committed May 10, 2024
1 parent 65284f7 commit 782fba6
Showing 1 changed file with 11 additions and 21 deletions.
32 changes: 11 additions & 21 deletions users/zeyer/experiments/exp2024_04_23_baselines/ctc.py
Expand Up @@ -307,37 +307,27 @@ def model_recog(
# Initial state.
beam_dim = Dim(1, name="initial-beam")
batch_dims_ = [beam_dim] + batch_dims
ended = rf.constant(False, dims=batch_dims_)
out_seq_len = rf.constant(0, dims=batch_dims_)
seq_log_prob = rf.constant(0.0, dims=batch_dims_)

label_log_prob = rf.log_softmax(logits, axis=model.wb_target_dim)
label_log_prob = TensorArray.unstack(label_log_prob, axis=enc_spatial_dim)
label_log_prob = rf.log_softmax(logits, axis=model.wb_target_dim) # Batch, Spatial, Vocab
label_log_prob = rf.where(
enc_spatial_dim.get_mask(),
label_log_prob,
rf.sparse_to_dense(model.blank_idx, axis=model.wb_target_dim, label_value=0.0, other_value=-1.0e30),
)
label_log_prob = TensorArray.unstack(label_log_prob, axis=enc_spatial_dim) # t -> Batch, Vocab

t = 0
max_seq_len = int(enc_spatial_dim.get_dim_value())
seq_targets = []
seq_backrefs = []
while True:
for t in range(max_seq_len):
# Filter out finished beams
label_log_prob_ = rf.where(
ended,
rf.sparse_to_dense(model.blank_idx, axis=model.wb_target_dim, label_value=0.0, other_value=-1.0e30),
label_log_prob[t],
)
seq_log_prob = seq_log_prob + label_log_prob_ # Batch, InBeam, Vocab
seq_log_prob = seq_log_prob + label_log_prob[t] # Batch, InBeam, Vocab
seq_log_prob, (backrefs, target), beam_dim = rf.top_k(
seq_log_prob, k_dim=Dim(beam_size, name=f"dec-step{t}-beam"), axis=[beam_dim, model.wb_target_dim]
) # seq_log_prob, backrefs, target: Batch, Beam
seq_targets.append(target)
seq_backrefs.append(backrefs)
ended = rf.gather(ended, indices=backrefs)
out_seq_len = rf.gather(out_seq_len, indices=backrefs)
t += 1

ended = rf.logical_or(ended, rf.copy_to_device(t >= enc_spatial_dim.get_size_tensor()))
if bool(rf.reduce_all(ended, axis=ended.dims).raw_tensor):
break
out_seq_len = out_seq_len + rf.where(ended, 0, 1)

# Backtrack via backrefs, resolve beams.
seq_targets_ = []
Expand All @@ -351,7 +341,7 @@ def model_recog(
seq_targets__ = TensorArray(seq_targets_[0])
for target in seq_targets_:
seq_targets__ = seq_targets__.push_back(target)
out_spatial_dim = Dim(out_seq_len, name="out-spatial")
out_spatial_dim = enc_spatial_dim
seq_targets = seq_targets__.stack(axis=out_spatial_dim)

return seq_targets, seq_log_prob, out_spatial_dim, beam_dim
Expand Down

0 comments on commit 782fba6

Please sign in to comment.