[magpietts] added an argument 'binarize_atten_prior' to trigger whether apply prior binarization. #14166
base: magpietts_2503
Conversation
Apply prior binarization (or not) during training and evaluation, and make the prior's past and future context windows configurable during inference. Hardcoded `text_lens - 2` to decide whether the text sentence is finished.
Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>
Pull Request Overview
This PR adds control over attention prior binarization during training and evaluation, makes the prior context window and decay configurable for inference, updates a hardcoded finish threshold, and fixes minor typos.
- Introduces a `binarize_attn_prior` flag to toggle binarization of attention priors.
- Adds `inference_prior_{future,past}_{context,decay}` and `inference_prior_current_value` parameters for inference.
- Changes the finish condition threshold from `text_lens - 5` to `text_lens - 2` and corrects typos.
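The binarization toggle described above can be sketched as follows. This is a minimal numpy illustration with hypothetical helper names; the actual NeMo implementation operates on torch tensors inside the model and may differ in detail:

```python
import numpy as np

def binarize_prior(attn_prior: np.ndarray) -> np.ndarray:
    """Turn a soft attention prior (audio_frames x text_tokens) into a hard
    one-hot prior by keeping only the argmax text position per audio frame.
    Hypothetical helper, illustrating the idea rather than the PR's code."""
    hard = np.zeros_like(attn_prior)
    hard[np.arange(attn_prior.shape[0]), attn_prior.argmax(axis=1)] = 1.0
    return hard

def maybe_binarize(attn_prior: np.ndarray, binarize_attn_prior: bool = True) -> np.ndarray:
    """Mirror the new flag: True reproduces the previous always-binary
    behavior; False keeps the soft prior untouched."""
    return binarize_prior(attn_prior) if binarize_attn_prior else attn_prior
```

With `binarize_attn_prior: false` in the config, the soft prior flows through unchanged, which was not possible before this PR.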
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| nemo/collections/tts/models/magpietts.py | Added `binarize_attn_prior` flag, configurable inference prior parameters, refactored attention prior construction, updated finish threshold, and typo fixes. |
| examples/tts/conf/magpietts/magpietts_multilingual_v1.yaml | Added `binarize_attn_prior` default to multilingual config. |
| examples/tts/conf/magpietts/magpietts_lhotse_dc_en.yaml | Added `binarize_attn_prior` default to lhotse DC English config. |
| examples/tts/conf/magpietts/magpietts_inference_multilingual_v1.yaml | Added `binarize_attn_prior` default to inference multilingual config. |
| examples/tts/conf/magpietts/magpietts_inference_en.yaml | Added `binarize_attn_prior` default to inference English config. |
| examples/tts/conf/magpietts/magpietts_en.yaml | Added `binarize_attn_prior` default to English config. |
| examples/tts/conf/magpietts/magpietts_dc_en.yaml | Added `binarize_attn_prior` default to DC English config. |
Comments suppressed due to low confidence (5)
nemo/collections/tts/models/magpietts.py:286
- The newly added `inference_prior_*` parameters lack documentation. Please update the class or method docstring to explain their purpose and valid value ranges.
  `# Inference prior configuration`

nemo/collections/tts/models/magpietts.py:1222
- Since behavior now changes based on `binarize_attn_prior`, add unit tests covering both the True and False cases to ensure the logic branches produce the expected prior matrices.
  `if self.binarize_attn_prior:`

nemo/collections/tts/models/magpietts.py:289
- The attribute `prior_future_decay` is not defined before use. Consider obtaining it from cfg or defining a default value earlier to avoid an AttributeError.
  `self.inference_prior_future_decay = self.cfg.get('inference_prior_future_decay', self.prior_future_decay)`

nemo/collections/tts/models/magpietts.py:290
- The attribute `prior_past_decay` is not defined before use. Consider obtaining it from cfg or defining a default value earlier to avoid an AttributeError.
  `self.inference_prior_past_decay = self.cfg.get('inference_prior_past_decay', self.prior_past_decay)`

nemo/collections/tts/models/magpietts.py:1222
- When `binarize_attn_prior` is False, `aligner_attn_hard` is never assigned but is likely used later. Add an else branch or default assignment to prevent an UnboundLocalError.
  `if self.binarize_attn_prior:`
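The last suppressed comment can be illustrated with a small sketch: assign `aligner_attn_hard` on every path so the False branch cannot raise an UnboundLocalError. The function name, the `binarize_fn` parameter, and the fall-back-to-soft behavior are assumptions for illustration, not the PR's actual code:

```python
def select_prior(binarize_attn_prior, attn_soft, binarize_fn):
    # Hypothetical sketch of the suggested fix: give aligner_attn_hard a
    # value on both branches so later references are always defined.
    if binarize_attn_prior:
        aligner_attn_hard = binarize_fn(attn_soft)
    else:
        aligner_attn_hard = None  # explicit default for the False branch
    # Assumed downstream behavior: use the hard prior when available,
    # otherwise keep the soft prior.
    return aligner_attn_hard if aligner_attn_hard is not None else attn_soft
```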
```diff
@@ -1501,14 +1528,14 @@ def construct_inference_prior(self, prior_epsilon, cross_attention_scores,
             if bidx not in end_indices:
                 unfinished_texts[bidx] = True

-            if text_time_step_attended[bidx] >= text_lens[bidx] - 5 or bidx in end_indices:
+            if text_time_step_attended[bidx] >= text_lens[bidx] - 2 or bidx in end_indices:
```
[nitpick] This uses a magic number (-2) to detect sentence completion. Consider extracting it into a named constant or making it configurable for clarity and easier tuning.
```diff
-if text_time_step_attended[bidx] >= text_lens[bidx] - 2 or bidx in end_indices:
+if text_time_step_attended[bidx] >= text_lens[bidx] - self.EOS_WINDOW_OFFSET or bidx in end_indices:
```
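The nitpick's named-constant pattern can be sketched as below. `EOS_WINDOW_OFFSET` comes from the suggestion itself; the class wrapper and default value are hypothetical:

```python
class PriorConfig:
    """Hypothetical holder for the named constant suggested above, so the
    finish condition is self-documenting and tunable in one place."""
    EOS_WINDOW_OFFSET = 2  # tokens before text end that count as "finished"

def is_text_finished(text_time_step_attended: int, text_len: int,
                     offset: int = PriorConfig.EOS_WINDOW_OFFSET) -> bool:
    """True once attention has reached the last `offset` text tokens."""
    return text_time_step_attended >= text_len - offset
```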
Added an argument `binarize_atten_prior` to toggle whether prior binarization is applied during training and evaluation. Previously, training always applied a binary prior; now it is configurable. Also hardcoded `text_lens - 2` to decide whether the text sentence is finished.
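The configurable inference prior mentioned in the commit message can be sketched as a per-decoder-step weighting over text tokens: full weight at the currently attended token, geometrically decayed weight over past and future context windows, and a small epsilon elsewhere. Parameter names mirror the PR's `inference_prior_{past,future}_{context,decay}` and `inference_prior_current_value`; the exact math is an assumption for illustration:

```python
import numpy as np

def inference_prior_row(text_len, attended_pos, past_ctx, future_ctx,
                        current_value=1.0, past_decay=0.9, future_decay=0.9,
                        epsilon=1e-8):
    """Sketch of one row of an inference attention prior (illustrative, not
    the PR's implementation): weight tokens near the attended position and
    suppress everything outside the configured context windows."""
    row = np.full(text_len, epsilon)
    row[attended_pos] = current_value
    # Future context: decay geometrically ahead of the attended token.
    for d in range(1, future_ctx + 1):
        if attended_pos + d < text_len:
            row[attended_pos + d] = current_value * (future_decay ** d)
    # Past context: decay geometrically behind the attended token.
    for d in range(1, past_ctx + 1):
        if attended_pos - d >= 0:
            row[attended_pos - d] = current_value * (past_decay ** d)
    return row
```

Widening `future_ctx` or raising the decays flattens the prior and loosens the alignment constraint; shrinking them pushes attention toward strict monotonicity.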