## Imports

In [1]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.config.forced_decoder_ids = None

# load dummy dataset and read audio files
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features 

Found cached dataset librispeech_asr_dummy (/Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


## Vanilla N-beam search with `generate`

In [2]:
# generate token ids
predicted_ids = model.generate(input_features, num_beams=3)
# decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
predicted_ids, transcription



(tensor([[50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
            286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
           7062,   465, 21443,    13, 50256]]),
 [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'])

## Get top-N sequences from beam search

In [3]:
# generate token ids
predicted_ids = model.generate(input_features, num_beams=3, num_return_sequences=3)
# decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
predicted_ids, transcription

(tensor([[50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
            286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
           7062,   465, 21443,    13, 50256],
         [50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
            286,   262,  3504,  6097,   290,   356,   389,  9675,   284,  7062,
            465, 21443,    13, 50256, 50256],
         [50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
            286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
           7062,   465, 23244,    13, 50256]]),
 [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.',
  ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.',
  ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his Gospel.'])

## Retrieve scores from `generate`

In [11]:
outputs = model.generate(input_features, num_beams=3, num_return_sequences=3, output_scores=True, return_dict_in_generate=True)
outputs.keys()

odict_keys(['sequences', 'sequences_scores', 'scores', 'beam_indices'])

### `outputs.scores`

In [12]:
len(outputs.scores) # corresponds to the length of the output sequence

25

In [13]:
# Note: outputs.scores[0] is not interesting because it would try
# to predict the special token "<|notimestamps|>".

outputs.scores[0].shape, outputs.scores[0]

(torch.Size([3, 51864]),
 tensor([[-inf, -inf, -inf,  ..., -inf, -inf, -inf],
         [-inf, -inf, -inf,  ..., -inf, -inf, -inf],
         [-inf, -inf, -inf,  ..., -inf, -inf, -inf]]))

In [14]:
outputs.scores[1].shape, outputs.scores[1]

(torch.Size([3, 51864]),
 tensor([[-11.7886,     -inf,     -inf,  ..., -15.2796, -16.0519, -18.6459],
         [-11.7886,     -inf,     -inf,  ..., -15.2796, -16.0519, -18.6459],
         [-11.7886,     -inf,     -inf,  ..., -15.2796, -16.0519, -18.6459]]))

In [15]:
outputs.sequences.shape, outputs.sequences

(torch.Size([3, 25]),
 tensor([[50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
            286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
           7062,   465, 21443,    13, 50256],
         [50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
            286,   262,  3504,  6097,   290,   356,   389,  9675,   284,  7062,
            465, 21443,    13, 50256, 50256],
         [50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
            286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
           7062,   465, 23244,    13, 50256]]))

### `outputs.sequences_scores`

In [16]:
outputs.sequences_scores  # score (log-probability) of the whole sentence (best scores from the K beams)

tensor([-0.1096, -0.1257, -0.1332])

We are sometimes interested in the transition probability at EACH generation step.
To get these with the `generate` method, one can use the `model.compute_transition_scores` method.

### `model.compute_transition_scores`

In [17]:
outputs = model.generate(input_features, num_beams=3, output_scores=True, return_dict_in_generate=True)
outputs.keys()

odict_keys(['sequences', 'sequences_scores', 'scores', 'beam_indices'])

In [18]:
# To get only the scores of the words of interest, use the following:
transition_scores = model.compute_transition_scores(
    sequences=outputs.sequences,
    scores=outputs.scores
)

transition_scores.shape

torch.Size([3, 25])

According to the Huggingface [documentation](https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.compute_transition_scores), `transition_scores` is:

> A torch.Tensor of shape (batch_size*num_return_sequences, sequence_length) containing the transition scores (logits)

In [19]:
transition_scores

tensor([[    -inf, -11.7545, -13.6545, -11.9025, -15.5699, -15.5868, -15.5456,
          -9.8802, -11.2776, -16.8050,  -9.5086, -10.5978, -16.2510, -17.7871,
         -11.9698, -12.2878,  -9.6142, -15.9646, -15.0512, -11.4573, -13.6455,
         -15.4601, -15.6462, -12.2083,  -4.7140],
        [    -inf, -11.7545, -11.4174, -19.3821, -16.5451, -16.7485, -15.1044,
          -9.5242, -11.3064, -17.6027,  -9.5018, -10.5148, -13.7400, -17.9222,
          -7.8407, -13.8596, -12.2393, -13.2752, -17.6306,  -9.5788, -17.5460,
         -13.4245, -17.2971, -12.1118,  -0.4220],
        [    -inf, -11.7545, -11.3382, -18.9005, -10.6259,  -2.9965, -16.9288,
          -9.8166, -11.2027, -16.4502,  -9.4199, -10.7644, -13.8387, -18.0040,
         -11.9421, -12.3366, -12.6366, -15.1623, -15.3701, -11.4702, -13.8237,
         -15.5517, -16.2082, -10.6575,  -4.6480]])