From d1219f2ff03e9db26af95fd0f9e94533b875ca73 Mon Sep 17 00:00:00 2001
From: Tanuj Rai
Date: Fri, 13 Jun 2025 14:36:06 +0530
Subject: [PATCH 1/5] Add kwargs support in WhisperForConditionalGeneration

---
 src/transformers/models/whisper/modeling_whisper.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index 7bb07a6c1c6a..52fd7925223a 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -1116,6 +1116,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
         r"""
         input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
@@ -1286,6 +1287,8 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        loss_function: Optional[Callable] = None,
+        **kwargs,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
         r"""
         input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
@@ -1382,10 +1385,10 @@
 
         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            # move labels to correct device to enable PP
             labels = labels.to(lm_logits.device)
-            loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
+            loss_function = loss_function or CrossEntropyLoss()
+            loss = loss_function(
+                lm_logits.view(-1, self.config.vocab_size), labels.view(-1), **kwargs)
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]

From 3c062dd85fb1000d4060231957652efd7f995343 Mon Sep 17 00:00:00 2001
From: Tanuj Rai
Date: Tue, 24 Jun 2025 20:42:46 +0530
Subject: [PATCH 2/5] Update modeling_whisper.py

---
 src/transformers/models/whisper/modeling_whisper.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index 4c6db26d4034..e2b78c8164a2 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -1091,7 +1091,6 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        whisper-loss-function-support
         **kwargs,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
 
@@ -1265,7 +1264,6 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        whisper-loss-function-support
         loss_function: Optional[Callable] = None,
         **kwargs,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:

From f6f5c828c4d6801388055bbbd0cc6abcb7eace26 Mon Sep 17 00:00:00 2001
From: Tanuj Rai
Date: Wed, 25 Jun 2025 11:13:52 +0530
Subject: [PATCH 3/5] Update modeling_whisper.py

---
 src/transformers/models/whisper/modeling_whisper.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index e2b78c8164a2..661cf934706c 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -1093,8 +1093,6 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
-
-        main
         r"""
         input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
             Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
@@ -1267,8 +1265,6 @@ def forward(
         loss_function: Optional[Callable] = None,
         **kwargs,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
-
-        main
         r"""
         input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, sequence_length)`):
             Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by

From 2bbd2d56bf1672c70fe1e94b93497c588df2d69e Mon Sep 17 00:00:00 2001
From: Tanuj Rai
Date: Thu, 26 Jun 2025 12:10:57 +0530
Subject: [PATCH 4/5] Update modeling_whisper.py

---
 src/transformers/models/whisper/modeling_whisper.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index 661cf934706c..13b22291edd2 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -1262,7 +1262,6 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        loss_function: Optional[Callable] = None,
         **kwargs,
     ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
         r"""

From 4ce31353f5835e02cc44ea9c5412129f1aa7fdda Mon Sep 17 00:00:00 2001
From: Tanuj Rai
Date: Thu, 26 Jun 2025 12:20:27 +0530
Subject: [PATCH 5/5] Update modeling_whisper.py

---
 src/transformers/models/whisper/modeling_whisper.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index 13b22291edd2..75be87919fbc 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -1360,8 +1360,8 @@
         loss = None
         if labels is not None:
             labels = labels.to(lm_logits.device)
-            loss_function = loss_function or CrossEntropyLoss()
-            loss = loss_function(
-                lm_logits.view(-1, self.config.vocab_size), labels.view(-1), **kwargs)
+            loss = self.loss_function(
+                logits=lm_logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs
+            )
 
         if not return_dict:
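
Note for reviewers, appended after the series: a minimal Python sketch of what
the kwargs plumbing enables once all five patches are applied. The checkpoint
name and the num_items_in_batch kwarg (the value Trainer forwards so the loss
is averaged correctly under gradient accumulation) are illustrative
assumptions for this sketch, not something the patches themselves define; it
also assumes a transformers version where PreTrainedModel exposes
self.loss_function.

    import torch
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

    # 30 s of silence at 16 kHz stands in for real audio.
    audio = torch.zeros(16_000 * 30).numpy()
    inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")
    labels = processor.tokenizer("hello world", return_tensors="pt").input_ids

    # Without this series the extra kwarg raises a TypeError, since forward()
    # accepted no **kwargs; with it, the kwarg flows through to the loss.
    outputs = model(
        input_features=inputs.input_features,
        labels=labels,
        num_items_in_batch=(labels != -100).sum(),
    )
    print(outputs.loss)

Routing the loss through self.loss_function in patch 5, instead of a bespoke
CrossEntropyLoss, keeps Whisper consistent with the models that already accept
Trainer's loss kwargs and avoids reintroducing an explicit loss_function
parameter after patch 4 removed it.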