In [1]:
from datasets import Audio, load_dataset
from transformers import pipeline
from evaluate import load
import torch

## Single dataset

In [2]:
# Stream the TED-LIUM release3 validation split and keep only its first 32 samples.
dataset = load_dataset("LIUM/tedlium", "release3", split="validation", streaming=True)
dataset = dataset.take(32)

# Use the first GPU when CUDA is available; otherwise fall back to the CPU
# (device=-1) so the notebook still runs on machines without a GPU.
device = 0 if torch.cuda.is_available() else -1

whisper_asr = pipeline(
    "automatic-speech-recognition", model="openai/whisper-tiny.en", device=device
)

# Remove token ids 6 and 12 from the model config's suppress list.
# NOTE(review): presumably these are the apostrophe and hyphen tokens of the
# English-only vocabulary — confirm against the tokenizer.
# The membership guard avoids a ValueError when an id is already absent.
suppress_tokens = whisper_asr.model.config.suppress_tokens
for token_id in (6, 12):
    if token_id in suppress_tokens:
        suppress_tokens.remove(token_id)

wer_metric = load("wer")

In [3]:
# Cast the audio column to an `Audio` feature with a 16 kHz sampling rate.
dataset_cast = dataset.cast_column("audio", Audio(sampling_rate=16000))

In [4]:
# helper function: fetch the transcript, whichever column a dataset stores it in
def get_text(sample):
    """Return the reference transcript stored in ``sample``.

    The transcript is read from the first of these keys that is present, in
    this order: ``"text"``, ``"sentence"``, ``"normalized_text"``,
    ``"transcript"``.

    Raises:
        ValueError: if ``sample`` contains none of those keys.
    """
    for column in ("text", "sentence", "normalized_text", "transcript"):
        if column in sample:
            return sample[column]
    raise ValueError(f"Sample: {sample.keys()} has no transcript.")

In [5]:
def normalise(batch):
    """Add a ``"norm_text"`` entry holding the normalised reference transcript.

    The transcript is fetched with ``get_text`` and passed through the
    ``_normalize`` method of the ``whisper_asr`` tokenizer. ``batch`` is
    modified in place and also returned.
    """
    reference = get_text(batch)
    batch["norm_text"] = whisper_asr.tokenizer._normalize(reference)
    return batch

In [6]:
# Attach the normalised reference text ("norm_text") to every sample.
dataset_norm = dataset_cast.map(function=normalise)

In [7]:
def is_target_text_in_range(ref):
    """Return True when ``ref`` is a usable reference transcript.

    After stripping surrounding whitespace, a reference is rejected if it is
    empty or equal to the placeholder ``"ignore time segment in scoring"``.
    """
    stripped = ref.strip()
    return stripped not in ("", "ignore time segment in scoring")

In [8]:
# Keep only the samples whose normalised reference passes is_target_text_in_range.
dataset_filter = dataset_norm.filter(
    function=is_target_text_in_range,
    input_columns=["norm_text"],
)

In [9]:
def data(my_dataset):
    """Lazily yield the ``"audio"`` entry of each sample in ``my_dataset``."""
    yield from (sample["audio"] for sample in my_dataset)

In [12]:
%%time
predictions = []
references = []

for out in whisper_asr(data(dataset_filter), batch_size=8):
    predictions.append(whisper_asr.tokenizer._normalize((out["text"])))
    
dataset_text = dataset_filter.remove_columns("audio")

for i, sample in enumerate(dataset_text):
    references.append(sample["norm_text"])

CPU times: user 16.1 s, sys: 309 ms, total: 16.4 s
Wall time: 3.52 s


In [13]:
# WER over the collected predictions/references, expressed as a percentage.
wer_metric.compute(predictions=predictions, references=references) * 100

3.9458850056369785

## Multi dataset

In [25]:
# Every dataset below is opened in streaming mode.

# LibriSpeech: the "test.clean" and "test.other" splits
librispeech_clean = load_dataset(
    "librispeech_asr", "all", split="test.clean", streaming=True
)
librispeech_other = load_dataset(
    "librispeech_asr", "all", split="test.other", streaming=True
)

# Common Voice 11 (English), loaded from the "streaming" revision with the
# locally stored auth token
common_voice = load_dataset(
    "mozilla-foundation/common_voice_11_0",
    "en",
    revision="streaming",
    split="test",
    streaming=True,
    use_auth_token=True,
)

voxpopuli = load_dataset(
    "facebook/voxpopuli", "en", split="test", streaming=True
)

tedlium = load_dataset(
    "LIUM/tedlium", "release3", split="test", streaming=True
)

# GigaSpeech ("xs" config) and SPGISpeech ("S" config) are also loaded with the
# auth token
gigaspeech = load_dataset(
    "speechcolab/gigaspeech",
    "xs",
    split="test",
    streaming=True,
    use_auth_token=True,
)

spgispeech = load_dataset(
    "kensho/spgispeech",
    "S",
    split="test",
    streaming=True,
    use_auth_token=True,
)

#earnings22 = load_dataset("anton-l/earnings22_baseline_5_gram", split="test", streaming=True)

ami = load_dataset(
    "edinburghcstr/ami", "ihm", split="test", streaming=True
)

In [26]:
# Display name -> streamed test split; insertion order is the evaluation order.
esb_datasets = {
    "LibriSpeech Clean": librispeech_clean,
    "LibriSpeech Other": librispeech_other,
    "Common Voice": common_voice,
    "VoxPopuli": voxpopuli,
    "TEDLIUM": tedlium,
    "GigaSpeech": gigaspeech,
    "SPGISpeech": spgispeech,
    #"Earnings-22": earnings22,
    "AMI": ami,
}

In [27]:
%%time
# batch size for extracting references and predictions
batch_size = 4

wer_results = []

# loop over all the datasets in the ESB benchmark
for dataset_name, dataset in esb_datasets.items():    
    # first 32 samples
    dataset = dataset.take(8)

    # resample to 16kHz
    dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))

    # normalise references
    dataset = dataset.map(normalise)

    # remove any empty references
    dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])

    # run streamed inference
    predictions = []
    references = []

    for out in whisper_asr(data(dataset), batch_size=batch_size):
        predictions.append(whisper_asr.tokenizer._normalize((out["text"])))

    dataset = dataset.remove_columns("audio")

    for i, sample in enumerate(dataset):
        references.append(sample["norm_text"])

    # compute the WER
    wer = wer_metric.compute(references=references, predictions=predictions)
    wer = round(100 * wer, 2)

    wer_results.append(wer)

Reading metadata...: 16354it [00:00, 51617.76it/s]
Reading metadata...: 16354it [00:00, 74394.16it/s]


CPU times: user 47.6 s, sys: 1.16 s, total: 48.8 s
Wall time: 17.4 s
