In [None]:
import re

from datasets import load_dataset, concatenate_datasets
import datasets

import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load a small slice (first 10 samples of each requested split) of every training corpus.
# LibriSpeech combines its three training splits into a single dataset via the `+` split syntax.
librispeech = load_dataset("librispeech_asr", "all", split="train.clean.100[:10]+train.clean.360[:10]+train.other.500[:10]", use_auth_token=True)

common_voice_9_0 = load_dataset("mozilla-foundation/common_voice_9_0", "en", split="train[:10]", use_auth_token=True)

voxpopuli = load_dataset("google/xtreme_s", "voxpopuli.en", split="train[:10]", use_auth_token=True)

tedlium = load_dataset("LIUM/tedlium", "release3", split="train[:10]", use_auth_token=True)

gigaspeech = load_dataset("speechcolab/gigaspeech", "l", split="train[:10]", use_auth_token=True)

earnings22 = load_dataset("sanchit-gandhi/earnings22_robust_split", split="train[:10]", use_auth_token=True)

# NOTE(review): unlike the other corpora, this split carries no `[:10]` slice, so the whole
# "train" split is requested (pinned to a fixed dataset revision) — confirm this is intentional.
spgispeech = load_dataset("kensho/spgispeech", "L", split="train", use_auth_token=True, revision="f4d7d3b3f9b66414a09532ec937e285197afeaf6")

switchboard = load_dataset("ldc/switchboard", "switchboard", split="train[:10]", use_auth_token=True)

# Parallel lists: the i-th entry of each list below describes the i-th dataset in `train_datasets`.
# `error_correction` indexes them by position, so their order must be kept in sync.
train_datasets = [librispeech, common_voice_9_0, voxpopuli, tedlium, gigaspeech, earnings22, spgispeech, switchboard]
# names are substring-matched inside `error_correction` to select the dataset-specific clean-up rules
ds_name = ["librispeech", "common_voice_9_0", "voxpopuli", "tedlium", "gigaspeech", "earnings22", "spgispeech", "switchboard"]

# name of the transcription column of each dataset (renamed to "text" during processing)
transcript_column_names = ['text', 'sentence', 'transcription', 'text', 'text', 'sentence', 'transcript', 'text']
# name of the identifier column of each dataset (renamed to "id" during processing)
id_column_names = ['id', 'client_id', 'id', 'id', 'segment_id', 'source_id', 'wav_filename', 'id']
# whether the transcription of each dataset is lower-cased before the clean-up rules are applied
do_lower_cases = [True, False, True, True, True, False, False, True]


# TED-LIUM: spaced contractions; each is replaced by itself minus the leading space (it 's -> it's)
tedlium_contractions = [" 's", " 't", " 're", " 've", " 'm", " 'll", " 'd", " 'clock", " 'all"]
# GigaSpeech: spelled-out punctuation tokens (with their leading space) -> punctuation symbol
gigaspeech_punctuation = {" <comma>": ",", " <period>": ".", " <questionmark>": "?", " <exclamationpoint>": "!"}
# GigaSpeech: marker tokens deleted from the transcription
gigaspeech_disfluencies = ["<other>", "<sil>"]
# Switchboard: marker tokens / fragments deleted from the transcription
swb_disfluencies = ["[noise]", "[laughter]", "[silence]", "<a_aside>", "<b_aside>", "<e_aside>", "[laughter-",
                    "[vocalized-noise]", "_1"]
# Switchboard: bracket characters deleted after the markers above have been removed
swb_punctuations = ["{", "}", "[", "]-", "]"]
# Earnings 22: marker tokens deleted from the transcription
earnings_disfluencies = ["<crosstalk>", "<affirmative>", "<inaudible>", "inaudible", "<laugh>"]
# A sample is dropped entirely when its lower-cased transcription is exactly one of these
# values (the trailing "" covers empty transcriptions).
ignore_segments = ["ignore_time_segment_in_scoring", "<noise>", "<music>", "[noise]", "[laughter]", "[silence]",
                   "[vocalized-noise]", "<crosstalk>", "<affirmative>", "<inaudible>", "<laugh>", "<other>",
                   "<sil>", ""]

Reusing dataset librispeech_asr (/home/sanchit_huggingface_co/.cache/huggingface/datasets/librispeech_asr/all/2.1.0/14c8bffddb861b4b3a4fcdff648a56980dbb808f3fc56f5a3d56b18ee88458eb)
Reusing dataset common_voice_9_0 (/home/sanchit_huggingface_co/.cache/huggingface/datasets/mozilla-foundation___common_voice_9_0/en/9.0.0/c8491634a4579fef5745ab949ee9aa4265b7203d7e2ecf44f45879a6419cd40d)
Reusing dataset xtreme_s (/home/sanchit_huggingface_co/.cache/huggingface/datasets/google___xtreme_s/voxpopuli.en/2.0.0/1384f19b49cc1beade2a9bf2ca44abe870cd95f85819a16f6f44671d4fdad7e2)
Reusing dataset tedlium (/home/sanchit_huggingface_co/.cache/huggingface/datasets/LIUM___tedlium/release3/1.0.1/3534cf671f9fe252aa91994765f9fbe95f9a077a67d56255dcd6645776ab997d)
Reusing dataset gigaspeech (/home/sanchit_huggingface_co/.cache/huggingface/datasets/speechcolab___gigaspeech/l/0.0.0/0db31224ad43470c71b459deb2f2b40956b3a4edfde5fb313aaec69ec7b50d3c)
Using custom data configuration sanchit-gandhi--earnings22_robust_

In [None]:
def error_correction(datasets, ds_name, transcript_column_names, id_column_names, do_lower_cases):
    """Clean the transcriptions of several ASR datasets and attach length statistics.

    For every dataset, in order:

    1. the transcription and id columns are renamed to ``"text"`` and ``"id"``;
    2. samples whose lower-cased transcription is listed in the module-level
       ``ignore_segments`` are dropped;
    3. dataset-specific clean-up rules (selected by substring match on the
       dataset name) are applied to the text, followed by whitespace
       normalisation;
    4. two columns are added: ``"input_length"`` (audio duration in seconds)
       and ``"words_length"`` (number of whitespace-separated words);
    5. samples with at most one second of audio, or with no words left after
       the clean-up, are dropped;
    6. the dataset name, sample count and total audio duration are printed.

    The clean-up rules read the module-level constants ``tedlium_contractions``,
    ``gigaspeech_punctuation``, ``gigaspeech_disfluencies``, ``swb_disfluencies``,
    ``swb_punctuations`` and ``earnings_disfluencies``.

    Args:
        datasets: list of dataset objects supporting ``rename_column``,
            ``filter``, ``map``, column access and ``len``. Note that inside
            this function the name shadows the imported ``datasets`` module.
        ds_name: list of dataset names, parallel to ``datasets``.
        transcript_column_names: per-dataset name of the transcription column.
        id_column_names: per-dataset name of the identifier column.
        do_lower_cases: per-dataset flag; when true the transcription is
            lower-cased before the clean-up rules run.

    Returns:
        The ``datasets`` list itself; each entry is replaced in place by its
        processed counterpart.
    """

    # The three filter predicates do not depend on the dataset being processed,
    # so they are defined once rather than on every loop iteration.
    def is_target_labels(input_str):
        # keep only samples whose transcription is not a pure non-speech marker
        return input_str.lower() not in ignore_segments

    def is_audio_not_none(audio_length):
        # `audio_length` is in seconds. The one-sample placeholder used for
        # unreadable audio lasts 1/16000 s and is therefore removed here.
        # NOTE(review): the threshold also removes every real sample of one
        # second or less — confirm that this is the intended cut-off.
        return audio_length > 1

    def is_text_not_none(words_length):
        return words_length > 0

    for i, ds in enumerate(datasets):
        dataset_name = ds_name[i]
        text_column_name = transcript_column_names[i]
        id_column_name = id_column_names[i]
        do_lower_case = do_lower_cases[i]

        # harmonise the column names so the rest of the pipeline is dataset-agnostic
        if text_column_name != "text":
            ds = ds.rename_column(text_column_name, "text")
        if id_column_name != "id":
            ds = ds.rename_column(id_column_name, "id")

        ds = ds.filter(is_target_labels, input_columns=["text"], desc="filtering text...")

        def prepare_dataset(batch):
            # Pre-process audio
            try:
                sample = batch["audio"]
            except ValueError:
                # E22: some samples are empty (no audio). Reading the empty audio array will trigger
                # a soundfile ValueError. For now, we'll manually set these arrays to a zero array.
                # They will be filtered in the subsequent filtering stage and so are
                # explicitly ignored during training.
                sample = {"array": np.array([0.]), "sampling_rate": 16000}

            # time in s
            batch["input_length"] = len(sample["array"]) / sample["sampling_rate"]

            # 'Error correction' of targets
            input_str = batch["text"].lower() if do_lower_case else batch["text"]
            # LibriSpeech ASR
            if "librispeech" in dataset_name:
                pass  # no error correction necessary
            # VoxPopuli
            if "voxpopuli" in dataset_name:
                pass  # no error correction necessary
            # Common Voice 9
            if "common_voice_9_0" in dataset_name:
                if input_str.startswith('"') and input_str.endswith('"'):
                    # we can remove the enclosing quotation marks as they do not affect the transcription
                    input_str = input_str[1:-1]
                # replace doubled quotation marks with a single one
                input_str = input_str.replace('""', '"')
            # TED-LIUM (Release 3)
            if "tedlium" in dataset_name:
                # delete the <unk> token from the text
                input_str = input_str.replace("<unk>", "")
                # replace spaced apostrophes with un-spaced (it 's -> it's)
                for contraction in tedlium_contractions:
                    input_str = input_str.replace(contraction, contraction[1:])
            # GigaSpeech
            if "gigaspeech" in dataset_name:
                for disfluency in gigaspeech_disfluencies:
                    input_str = input_str.replace(disfluency, "")
                # convert spelled out punctuation to symbolic form
                for punctuation, replacement in gigaspeech_punctuation.items():
                    input_str = input_str.replace(punctuation, replacement)
            # Switchboard (also matches the "switchboard/callhome" test set)
            if "switchboard" in dataset_name:
                for disfluency in swb_disfluencies:
                    input_str = input_str.replace(disfluency, "")
                # remove parenthesised text (test data only)
                # raw string: the same pattern as before, without invalid escape sequences
                input_str = re.sub(r"[\(].*?[\)]", "", input_str)
                for punctuation in swb_punctuations:
                    input_str = input_str.replace(punctuation, "")
                # replace anomalous words with their correct transcriptions:
                # for every "/"-separated segment the last space-separated token is dropped,
                # then the last token of the final segment is appended again
                segments = input_str.split("/")
                if len(segments) > 1:
                    input_str = " ".join(
                        [" ".join([" ".join(segment.split(" ")[:-1]) for segment in segments])]
                        + [segments[-1].split(" ")[-1]])
            # Earnings 22: still figuring out best segmenting method. Thus, dataset name subject to change
            if "earnings22" in dataset_name:
                for disfluency in earnings_disfluencies:
                    input_str = input_str.replace(disfluency, "")
            # SPGISpeech
            if "spgispeech" in dataset_name:
                pass  # no error correction necessary
            # JIWER compliance (for WER/CER calc.)
            # remove multiple spaces
            input_str = re.sub(r"\s\s+", " ", input_str)
            # strip leading and trailing spaces
            input_str = input_str.strip()
            batch["text"] = input_str
            # `"".split(" ")` returns [""] (length 1), which would let transcriptions that
            # became empty during the clean-up slip through the `words_length > 0` filter.
            # `split()` without a separator returns [] for an empty string.
            batch["words_length"] = len(input_str.split())
            return batch

        ds = ds.map(prepare_dataset, desc="pre-processing...", num_proc=1)

        ds = ds.filter(is_audio_not_none, input_columns=["input_length"], desc="filtering audio...")

        ds = ds.filter(is_text_not_none, input_columns=["words_length"], desc="filtering text...")

        # the caller's list is updated in place (and returned below)
        datasets[i] = ds

        print(100*"=")
        print(dataset_name)
        print("Num samples: ", len(ds))
        # input_length is in seconds: divide by 60 ** 2 to report hours
        print("Total audio length: ", np.sum(ds["input_length"]) / 60 ** 2, "hours")

    return datasets

In [None]:
# Clean the training splits; error_correction updates the list in place and returns it.
train_datasets = error_correction(
    datasets=train_datasets,
    ds_name=ds_name,
    transcript_column_names=transcript_column_names,
    id_column_names=id_column_names,
    do_lower_cases=do_lower_cases,
)

In [None]:
# Load a small slice (first 10 samples) of the validation and test splits of every corpus.
# LibriSpeech provides separate "clean" and "other" variants of both splits.
librispeech_dev_clean = load_dataset("librispeech_asr", "all", split="validation.clean[:10]", use_auth_token=True)
librispeech_dev_other = load_dataset("librispeech_asr", "all", split="validation.other[:10]", use_auth_token=True)
librispeech_test_clean = load_dataset("librispeech_asr", "all", split="test.clean[:10]", use_auth_token=True)
librispeech_test_other = load_dataset("librispeech_asr", "all", split="test.other[:10]", use_auth_token=True)

common_voice_9_0_dev = load_dataset("mozilla-foundation/common_voice_9_0", "en", split="validation[:10]", use_auth_token=True)
common_voice_9_0_test = load_dataset("mozilla-foundation/common_voice_9_0", "en", split="test[:10]", use_auth_token=True)

voxpopuli_dev = load_dataset("google/xtreme_s", "voxpopuli.en", split="validation[:10]", use_auth_token=True)
voxpopuli_test = load_dataset("google/xtreme_s", "voxpopuli.en", split="test[:10]", use_auth_token=True)

tedlium_dev = load_dataset("LIUM/tedlium", "release3", split="validation[:10]", use_auth_token=True)
tedlium_test = load_dataset("LIUM/tedlium", "release3", split="test[:10]", use_auth_token=True)

gigaspeech_dev = load_dataset("speechcolab/gigaspeech", "l", split="validation[:10]", use_auth_token=True)
gigaspeech_test = load_dataset("speechcolab/gigaspeech", "l", split="test[:10]", use_auth_token=True)

earnings22_dev = load_dataset("sanchit-gandhi/earnings22_robust_split", split="validation[:10]", use_auth_token=True)
earnings22_test = load_dataset("sanchit-gandhi/earnings22_robust_split", split="test[:10]", use_auth_token=True)

# NOTE(review): as for the training split, no `[:10]` slice is applied here, so the whole
# validation / test splits are requested — confirm this is intentional.
spgispeech_dev = load_dataset("kensho/spgispeech", "L", split="validation", use_auth_token=True, revision="f4d7d3b3f9b66414a09532ec937e285197afeaf6")
spgispeech_test = load_dataset("kensho/spgispeech", "L", split="test", use_auth_token=True, revision="f4d7d3b3f9b66414a09532ec937e285197afeaf6")

# Only test splits are loaded for Switchboard: "test.switchboard" is placed in the dev list
# below and "test.callhome" in the test list.
switchboard_test = load_dataset("ldc/switchboard", "switchboard", split="test.switchboard[:10]", use_auth_token=True)
callhome_test = load_dataset("ldc/switchboard", "switchboard", split="test.callhome[:10]", use_auth_token=True)

# Parallel lists of datasets and names; the names still contain the substrings that
# `error_correction` matches on to select the dataset-specific clean-up rules.
dev_ds = [librispeech_dev_clean, librispeech_dev_other, common_voice_9_0_dev, voxpopuli_dev, tedlium_dev, gigaspeech_dev, earnings22_dev, spgispeech_dev, switchboard_test]
dev_name = ["librispeech_asr/validation.clean", "librispeech_asr/validation.other", "common_voice_9_0/validation", "voxpopuli/validation", "tedlium/validation", "gigaspeech/validation", "earnings22/validation", "spgispeech/validation", "switchboard/test"]

test_ds = [librispeech_test_clean, librispeech_test_other, common_voice_9_0_test, voxpopuli_test, tedlium_test, gigaspeech_test, earnings22_test, spgispeech_test, callhome_test]
test_name = ["librispeech_asr/test.clean", "librispeech_asr/test.other", "common_voice_9_0/test", "voxpopuli/test", "tedlium/test", "gigaspeech/test", "earnings22/test", "spgispeech/test", "switchboard/callhome"]

# LibriSpeech occupies two slots in the dev/test lists (clean + other), so its per-dataset
# settings (index 0 of the training configuration) are duplicated at the front.
# The same three lists are reused for both the dev and the test sets.
dev_transcript_column_names = [transcript_column_names[0], *transcript_column_names]
dev_id_column_names = [id_column_names[0], *id_column_names]
dev_do_lower_cases = [do_lower_cases[0], *do_lower_cases]

In [None]:
# Clean the validation and test splits; both share the dev_* per-dataset settings.
dev_ds = error_correction(
    datasets=dev_ds,
    ds_name=dev_name,
    transcript_column_names=dev_transcript_column_names,
    id_column_names=dev_id_column_names,
    do_lower_cases=dev_do_lower_cases,
)
test_ds = error_correction(
    datasets=test_ds,
    ds_name=test_name,
    transcript_column_names=dev_transcript_column_names,
    id_column_names=dev_id_column_names,
    do_lower_cases=dev_do_lower_cases,
)

In [None]:
# Merge the train, dev and test data of every corpus so that the statistics are
# accumulated over all of its splits.
# LibriSpeech has two dev and two test entries (indices 0 and 1 of dev_ds / test_ds).
librispeech_all = concatenate_datasets(
    [train_datasets[0], dev_ds[0], dev_ds[1], test_ds[0], test_ds[1]]
)
# For the remaining corpora the dev/test lists are shifted by one position relative to
# the training list, because of the extra LibriSpeech entry at the front.
all_datasets = [librispeech_all] + [
    concatenate_datasets([train_datasets[eval_idx - 1], dev_ds[eval_idx], test_ds[eval_idx]])
    for eval_idx in range(2, len(dev_name))
]

In [None]:
# Show each combined dataset next to the name of its corpus.
for position, combined in enumerate(all_datasets):
    print(ds_name[position], combined)

In [None]:
# Report mean and standard deviation of the audio duration (seconds) and of the
# transcription length (words) for every combined dataset.
for position, combined in enumerate(all_datasets):
    print(ds_name[position])
    durations = combined["input_length"]
    print("Mean sample duration: ", np.mean(durations), "+-", np.std(durations), "s")
    word_counts = combined["words_length"]
    print("Mean transcript length: ", np.mean(word_counts), "+-", np.std(word_counts), "words")