In [2]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model and processor
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.config.forced_decoder_ids = None

# load dummy dataset and read audio files
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
sample = ds[0]["audio"]
input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features 

Found cached dataset librispeech_asr_dummy (/Users/Tony/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_dummy/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)


## Vanilla N-beam search with `generate`

In [8]:
# generate token ids
predicted_ids = model.generate(input_features, num_beams=3)
# decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
predicted_ids, transcription

(tensor([[50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
            286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
           7062,   465, 21443,    13, 50256]]),
 [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'])

## Get top-N sequences from beam search

In [9]:
# generate token ids
predicted_ids = model.generate(input_features, num_beams=3, num_return_sequences=3)
# decode token ids to text
transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
predicted_ids, transcription

(tensor([[50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
            286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
           7062,   465, 21443,    13, 50256],
         [50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
            286,   262,  3504,  6097,   290,   356,   389,  9675,   284,  7062,
            465, 21443,    13, 50256, 50256],
         [50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
            286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
           7062,   465, 23244,    13, 50256]]),
 [' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.',
  ' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.',
  ' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his Gospel.'])

## Retrieve scores from `generate`

In [10]:
outputs = model.generate(input_features, num_beams=3, output_scores=True, return_dict_in_generate=True)
outputs.keys()

odict_keys(['sequences', 'sequences_scores', 'scores', 'beam_indices'])

In [15]:
outputs.sequences, outputs.sequences_scores

(tensor([[50257, 50362,  1770,    13,  2264,   346,   353,   318,   262, 46329,
            286,   262,  3504,  6097,    11,   290,   356,   389,  9675,   284,
           7062,   465, 21443,    13, 50256]]),
 tensor([-0.1096]))

In [20]:
# Returns all scores for all N-beams:
len(outputs), outputs.scores[1]

(4,
 tensor([[-11.7886,     -inf,     -inf,  ..., -15.2796, -16.0519, -18.6459],
         [-11.7886,     -inf,     -inf,  ..., -15.2796, -16.0519, -18.6459],
         [-11.7886,     -inf,     -inf,  ..., -15.2796, -16.0519, -18.6459]]))

In [21]:
# To get only the scores of the words of interest, use the following:
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)

transition_scores

tensor([[    -inf, -11.7038, -13.6538, -11.9018, -15.5696, -15.5865, -15.5451,
          -9.8790, -11.2768, -16.8047,  -9.5083, -10.5973, -16.2507, -17.7821,
         -11.9694, -12.2877,  -9.6139, -15.9645, -15.0510, -11.4573, -13.6452,
         -15.4599, -15.6448, -12.2026,  -4.7137],
        [    -inf, -11.7038, -11.4163, -19.3818, -16.5448, -16.7483, -15.1039,
          -9.5228, -11.3053, -17.6024,  -9.5014, -10.5143, -13.7391, -17.9175,
          -7.8405, -13.8592, -12.2392, -13.2750, -17.6305,  -9.5785, -17.5459,
         -13.4226, -17.2916, -12.1076,  -0.4173],
        [    -inf, -11.7038, -11.3371, -18.9002, -10.6250,  -2.9960, -16.9284,
          -9.8154, -11.2018, -16.4499,  -9.4196, -10.7639, -13.8377, -17.9994,
         -11.9418, -12.3365, -12.6365, -15.1623, -15.3699, -11.4702, -13.8234,
         -15.5515, -16.2069, -10.6529,  -4.6477]])