
Add kwargs support in WhisperForConditionalGeneration #38810


Open · wants to merge 9 commits into base: main
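
This PR adds a trailing `**kwargs` to the `forward` methods in `src/transformers/models/whisper/modeling_whisper.py` and threads them through to the loss computation. A minimal caller-side sketch of what that enables (checkpoint and shapes are illustrative, not part of this diff):

```python
import torch
from transformers import WhisperForConditionalGeneration

# Illustrative sketch; any standard Whisper checkpoint works the same way.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

input_features = torch.randn(1, 80, 3000)  # (batch, n_mels, frames) for whisper-tiny
labels = torch.randint(0, model.config.vocab_size, (1, 10))

# Baseline call, unchanged by this PR.
outputs = model(input_features=input_features, labels=labels)
print(outputs.loss)

# With this PR, an extra keyword argument (e.g. a hypothetical
# reduction_scale=...) would no longer raise TypeError in forward();
# it is collected by **kwargs and handed to the loss callable instead.
```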
12 changes: 7 additions & 5 deletions src/transformers/models/whisper/modeling_whisper.py
```diff
@@ -1091,7 +1091,8 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[tuple[torch.Tensor], Seq2SeqModelOutput]:
+        **kwargs,
+    ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
         r"""
         input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
             Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
```
```diff
@@ -1261,7 +1262,8 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-    ) -> Union[tuple[torch.Tensor], Seq2SeqLMOutput]:
+        **kwargs,
+    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
         r"""
         input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
             Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
```
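
Both hunks make the same change: `**kwargs` is appended to the parameter list so that unexpected keyword arguments are collected instead of rejected, and the return annotation switches from the builtin `tuple` to `typing.Tuple` (which must be available in the module's imports). A condensed toy of the pattern (unrelated parameters elided; not the transformers source):

```python
from typing import Optional, Tuple, Union

import torch

def forward(
    input_features: Optional[torch.FloatTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,  # new: collects extra keyword arguments for the loss path
) -> Union[Tuple[torch.Tensor], "Seq2SeqLMOutput"]:  # string annotation stands in for the real class
    ...
```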
```diff
@@ -1357,10 +1359,10 @@ def forward(
 
         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
             # move labels to correct device to enable PP
             labels = labels.to(lm_logits.device)
-            loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
+            loss_fct = loss_fct or CrossEntropyLoss()
+            loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1), **kwargs)
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
```
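
Distilled, the new loss path follows this pattern: keep moving `labels` to the logits device (for pipeline parallelism), instantiate the default `CrossEntropyLoss` only when no loss callable was supplied, and forward any extra keyword arguments to it. A standalone sketch, assuming `loss_fct` arrives as an optional argument (where it is defined is not visible in this hunk):

```python
import torch
from torch.nn import CrossEntropyLoss

def compute_loss(lm_logits, labels, vocab_size, loss_fct=None, **kwargs):
    # Move labels to the logits device to support pipeline parallelism.
    labels = labels.to(lm_logits.device)
    # Only fall back to the default loss when none was provided.
    loss_fct = loss_fct or CrossEntropyLoss()
    # Flatten to (batch * seq, vocab) vs. (batch * seq,) and pass kwargs through.
    return loss_fct(lm_logits.view(-1, vocab_size), labels.view(-1), **kwargs)

logits = torch.randn(2, 5, 100)           # (batch, seq, vocab)
labels = torch.randint(0, 100, (2, 5))
print(compute_loss(logits, labels, vocab_size=100))
```

Note that the stock `CrossEntropyLoss` call accepts only `(input, target)`, so a non-empty `**kwargs` is only meaningful with a loss callable prepared to receive the extra arguments.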