In [1]:
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset, Audio
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the English-only tiny Whisper checkpoint and its matching processor.
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")

# Small LibriSpeech validation split; cast the audio column to a 16 kHz Audio feature.
dataset = load_dataset(
    "hf-internal-testing/librispeech_asr_dummy",
    "clean",
    split="validation",
).cast_column("audio", Audio(sampling_rate=16_000))

In [3]:
# load and pre-process an audio input
# Take the first example's audio and pre-process the raw waveform into model input features.
sample = dataset[0]["audio"]
input_features = processor(
    sample["array"],
    sampling_rate=sample["sampling_rate"],
    return_tensors="pt",
).input_features

In [4]:
# Transcribe, additionally requesting per-step scores (for token probabilities)
# and per-token timestamps.
generate_outputs = model.generate(
    input_features,
    output_scores=True,
    return_token_timestamps=True,
)

# Decode the full sequences; special tokens such as <|notimestamps|> are kept.
pred_text = processor.batch_decode(generate_outputs.sequences)
print(pred_text)

['<|startoftranscript|><|notimestamps|> Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.<|endoftext|>']


In [5]:
def get_token_probabilities(generate_outputs):
    """Return the probability the model assigned to each token of each sequence.

    Args:
        generate_outputs: result of ``model.generate(..., output_scores=True)``.
            ``sequences`` is a ``(batch, seq_len)`` tensor of token ids and
            ``scores`` is a sequence with one ``(batch, vocab)`` tensor of
            (processed) logits per generated step.

    Returns:
        A ``(batch, seq_len)`` tensor aligned position-by-position with
        ``sequences``. The leading decoder-prompt positions (at least the BOS
        token) have no score and are reported with probability 1.0.

    Raises:
        ValueError: if there are more score steps than sequence positions.

    NOTE(review): assumes one score row per returned sequence (greedy or
    sampling). Beam-search scores are per beam and would not line up; TODO
    confirm before using with ``num_beams > 1``.
    """
    sequences = generate_outputs.sequences
    scores = generate_outputs.scores
    batch_size, seq_len = sequences.shape

    # Generated tokens are the last len(scores) positions. Everything before
    # them is the decoder prompt, which may be more than one token (e.g. when
    # forced ids are passed as decoder input ids), and has no score.
    num_generated = len(scores)
    num_prompt = seq_len - num_generated
    if num_prompt < 0:
        raise ValueError(
            f"got {num_generated} score steps for sequences of length {seq_len}"
        )
    if num_generated == 0:
        return torch.ones((batch_size, seq_len), device=sequences.device)

    # (batch, steps, vocab): one row of logits per generated position.
    probabilities = torch.stack(tuple(scores), dim=1).softmax(dim=-1)
    generated_ids = sequences[:, num_prompt:]
    token_probs = probabilities.gather(-1, generated_ids.unsqueeze(-1)).squeeze(-1)

    # Prompt positions have no score, so report them as certain (1.0). Match
    # dtype/device so the concatenation also works on GPU / half precision.
    prompt_probs = torch.ones(
        (batch_size, num_prompt),
        dtype=token_probs.dtype,
        device=token_probs.device,
    )
    return torch.cat([prompt_probs, token_probs], dim=-1)

In [6]:
# Probability of every token in each generated sequence (1.0 for unscored prompt tokens).
token_probabilities = get_token_probabilities(generate_outputs)
token_probabilities

tensor([[1.0000, 1.0000, 0.9269, 0.9681, 0.7533, 0.9181, 0.9928, 0.9911, 0.9942,
         0.8014, 0.9976, 0.9960, 0.7689, 0.9324, 0.5493, 0.9934, 0.9979, 0.7633,
         0.9978, 0.9965, 0.9989, 0.9734, 0.6287, 0.8823, 0.9955]])

In [7]:
def combine_results(outputs, skip_special_tokens=False):
    """Pair each generated token with its decoded text, timestamp and probability.

    Args:
        outputs: generation output indexable as a mapping with ``"sequences"``
            and ``"token_timestamps"`` entries (as produced by
            ``model.generate(..., return_token_timestamps=True, output_scores=True)``).
            It is also passed to ``get_token_probabilities``.
        skip_special_tokens: if True, keep only tokens whose id is below
            ``model.config.eos_token_id``. Presumably this drops all of
            Whisper's special/timestamp tokens, which sit at or above EOS in
            its vocabulary; verify for other checkpoints.

    Returns:
        One list per batch item of ``(text, token_id, timestamp, probability)``
        tuples, where ``timestamp`` and ``probability`` are Python floats.

    Relies on the notebook-level ``processor`` and ``model`` objects.
    """
    sequences = outputs["sequences"]
    timestamps = outputs["token_timestamps"]
    # Compute once for the whole batch instead of once per batch item.
    probabilities = get_token_probabilities(outputs)
    eos_token_id = model.config.eos_token_id

    combined = []
    for seq, seq_timestamps, seq_probs in zip(sequences, timestamps, probabilities):
        # No .squeeze(): it would turn a length-1 row into a scalar and break decoding.
        token_ids = seq.tolist()
        # Decoding the flat id list yields one text piece per token id.
        words = processor.batch_decode(token_ids)
        combined.append([
            (word, token_id, timestamp.item(), probability.item())
            for word, token_id, timestamp, probability in zip(
                words, token_ids, seq_timestamps, seq_probs
            )
            if not skip_special_tokens or token_id < eos_token_id
        ])
    return combined

In [8]:
# Merge text, token ids, timestamps and probabilities, dropping special tokens.
results = combine_results(
    generate_outputs,
    skip_special_tokens=True,
)

In [9]:
# One list per batch item of (decoded text, token id, token timestamp, probability) tuples.
results

[[(' Mr', 1770, 0.0, 0.9268808364868164),
  ('.', 13, 0.8600000143051147, 0.9680864214897156),
  (' Qu', 2264, 1.0199999809265137, 0.7532929182052612),
  ('il', 346, 1.0199999809265137, 0.9181373715400696),
  ('ter', 353, 1.0800000429153442, 0.9928463101387024),
  (' is', 318, 1.2400000095367432, 0.991145670413971),
  (' the', 262, 1.4800000190734863, 0.9941967725753784),
  (' apostle', 46329, 1.6799999475479126, 0.8014128804206848),
  (' of', 286, 2.0799999237060547, 0.9975734353065491),
  (' the', 262, 2.359999895095825, 0.9960063099861145),
  (' middle', 3504, 2.5, 0.7688519954681396),
  (' classes', 6097, 2.700000047683716, 0.9324130415916443),
  (',', 11, 3.200000047683716, 0.5493204593658447),
  (' and', 290, 3.4000000953674316, 0.9934113621711731),
  (' we', 356, 3.559999942779541, 0.9978528022766113),
  (' are', 389, 3.700000047683716, 0.7633094191551208),
  (' glad', 9675, 3.819999933242798, 0.9978243112564087),
  (' to', 284, 4.099999904632568, 0.9964914917945862),
  (' welco