Skip to content

Commit

Permalink
Merge pull request #1751 from Hguimaraes/whisper_hidden_states
Browse files Browse the repository at this point in the history
Enabling the retrieval of whisper's hidden states
  • Loading branch information
Adel-Moumen committed Dec 28, 2022
2 parents 8235fd1 + c988618 commit ff4366a
Showing 1 changed file with 41 additions and 9 deletions.
50 changes: 41 additions & 9 deletions speechbrain/lobes/models/huggingface_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ class HuggingFaceWhisper(nn.Module):
HuggingFace hub name: e.g "openai/whisper-tiny"
save_path : str
Path (dir) of the downloaded model.
output_all_hiddens: bool (default: False)
If True, the forward function outputs the hidden states from all transformer layers of the encoder.
For example whisper-base has 6 transformer layers and the output is of shape (7, B, T, C),
where the CNN output is added at the beginning.
If False, the forward function outputs the hidden states only from the last transformer layer of the encoder.
Example
-------
>>> model_hub = "openai/whisper-tiny"
Expand All @@ -65,13 +70,15 @@ def __init__(
freeze=False,
freeze_encoder=False,
output_attentions=True,
output_all_hiddens=False,
):
super().__init__()
self.sampling_rate = sampling_rate
self.encoder_only = encoder_only
self.freeze = freeze
self.freeze_encoder = freeze_encoder
self.output_attentions = output_attentions
self.output_all_hiddens = output_all_hiddens

self.tokenizer = None
# Download the tokenizer only if we are going to use the Decoder.
Expand Down Expand Up @@ -131,18 +138,29 @@ def forward(self, wav, decoder_input_ids=None):
out_encoder = self.forward_encoder(wav)
if self.encoder_only:
return out_encoder
logits, attn = self.forward_decoder(
out_encoder, decoder_input_ids
)

if self.output_all_hiddens:
logits, attn = self.forward_decoder(
out_encoder[-1], decoder_input_ids
)
else:
logits, attn = self.forward_decoder(
out_encoder, decoder_input_ids
)
return out_encoder, logits, attn
else:
if self.encoder_only:
return self.forward_encoder(wav)
else:
out_encoder = self.forward_encoder(wav)
logits, attn = self.forward_decoder(
out_encoder, decoder_input_ids
)
if self.output_all_hiddens:
logits, attn = self.forward_decoder(
out_encoder[-1], decoder_input_ids
)
else:
logits, attn = self.forward_decoder(
out_encoder, decoder_input_ids
)
return out_encoder, logits, attn

def forward_encoder(self, wav):
Expand All @@ -155,10 +173,24 @@ def forward_encoder(self, wav):

if self.freeze_encoder:
with torch.no_grad():
mel = self._get_mel(wav)
return self.model.encoder(mel).last_hidden_state
return self._get_encoder_states(wav)
else:
return self._get_encoder_states(wav)

def _get_encoder_states(self, wav):
    """Takes an input waveform and returns its corresponding encoder states.

    Returns the last hidden state of the encoder, or all hidden states
    stacked into a single tensor if ``output_all_hiddens`` is True.

    Arguments
    ---------
    wav : torch.Tensor (signal)
        A batch of audio signals to transform to features.

    Returns
    -------
    torch.Tensor
        When ``self.output_all_hiddens`` is True, the result of
        ``torch.stack`` over the encoder's ``hidden_states`` (one entry per
        returned state along a new leading dimension). Otherwise the
        encoder's ``last_hidden_state``.
    """
    # Both branches consume the same mel features, so extract them only once.
    mel = self._get_mel(wav)
    if self.output_all_hiddens:
        states = self.model.encoder(mel, output_hidden_states=True)
        return torch.stack(states.hidden_states)
    return self.model.encoder(mel).last_hidden_state

def _get_mel(self, wav):
Expand Down

0 comments on commit ff4366a

Please sign in to comment.