diff --git a/recipes/LJSpeech/evaluation/README.md b/recipes/LJSpeech/evaluation/README.md new file mode 100644 index 0000000000..2dd974892e --- /dev/null +++ b/recipes/LJSpeech/evaluation/README.md @@ -0,0 +1,46 @@ +# Text-to-Speech (with LJSpeech) +This folder contains the recipes for evaluation of existing pretrained text-to-speech systems using ASR-based evaluators and MOS estimation + +By default, MOS evaluation is performed using a pretrained Transformer model, as defined in `recipes/SOMOS/ttseval/hparams/train.yaml` and available in pre-trained form on HuggingFace in +https://huggingface.co/flexthink/ttseval-wavlm-transformer + +ASR evaluation is performed using the bundled Transformer ASR : https://huggingface.co/speechbrain/asr-transformer-transformerlm-librispeech + +# Tacotron 2 +The recipe contains hyperparameters for the evaluation of Tacotron2 in `hparams/tacotron2.yaml` + +To perform evaluation, run the following script +``` +python evaluate.py --data_folder=/your_folder/LJSpeech-1.1 hparams/tacotron.yaml +``` + + +# FastSpeech2 +The recipe contains hyperparameters for the evaluation of FastSpeech2 in `hparams/fastspeech2.yaml` + +``` +python train.py --data_folder=/your_folder/LJSpeech-1.1 hparams/fastspeech2.yaml +``` + + +# **About SpeechBrain** +- Website: https://speechbrain.github.io/ +- Code: https://github.com/speechbrain/speechbrain/ +- HuggingFace: https://huggingface.co/speechbrain/ + + +# **Citing SpeechBrain** +Please, cite SpeechBrain if you use it for your research or business. + +```bibtex +@misc{speechbrain, + title={{SpeechBrain}: A General-Purpose Speech Toolkit}, + author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio}, + year={2021}, + eprint={2106.04624}, + archivePrefix={arXiv}, + primaryClass={eess.AS}, + note={arXiv:2106.04624} +} +``` + diff --git a/recipes/LJSpeech/evaluation/adapters.py b/recipes/LJSpeech/evaluation/adapters.py new file mode 100644 index 0000000000..7dda465f90 --- /dev/null +++ b/recipes/LJSpeech/evaluation/adapters.py @@ -0,0 +1,81 @@ +"""Adapters for specific TTS system + +Authors +* Artem Ploujnikov, 2024 +""" + +from torch import nn + + +class MelAdapter(nn.Module): + """An adapter for TTSes that output a MEL spectrogram + and require a vocoder to synthesize an + audio wave + + Arguments + --------- + vocoder : torch.nn.Module | speechbrain.inference.Pretrained + the vocoder to be used + vocoder_run_opts : dict + run options for the vocoder + """ + + def __init__(self, vocoder, vocoder_run_opts=None): + super().__init__() + self.vocoder_fn = vocoder + self.vocoder_run_opts = vocoder_run_opts or {} + self.vocoder = None + self.device = None + + def _get_vocoder(self): + """Instantiates the vocoder, if not already instantiated""" + if self.vocoder is None: + run_opts = dict(self.vocoder_run_opts) + if self.device is not None: + run_opts["device"] = self.device + self.vocoder = self.vocoder_fn(run_opts=run_opts) + return self.vocoder + + def forward(self, tts_out): + """Applies a vocoder to the waveform + + Arguments + --------- + tts_out : tuple + a (tensor, tensor) tuple with a MEL spectrogram + of shape (batch x mel x length) + and absolute lengths (as in the output of Tacotron2 + or similar 
models) + + Returns + ------- + wav : torch.Tensor + The waveform + lengths : torch.Tensor + The lengths + """ + mel_outputs, mel_lengths = tts_out[:2] + vocoder = self._get_vocoder() + max_len = mel_lengths.max() + mel_outputs = mel_outputs[:, :, :max_len] + wav = vocoder(mel_outputs) + lengths = mel_lengths / max_len + return wav, lengths + + def to(self, device): + """Transfers the adapter (and the underlying model) to the + specified device + + Arguments + --------- + device : str | torch.Device + The device + + + Returns + ------- + result : MelAdapter + the adapter (i.e. returns itself) + """ + self.device = device + return super().to(device) diff --git a/recipes/LJSpeech/evaluation/evaluate.py b/recipes/LJSpeech/evaluation/evaluate.py new file mode 100644 index 0000000000..272e8c9077 --- /dev/null +++ b/recipes/LJSpeech/evaluation/evaluate.py @@ -0,0 +1,401 @@ +"""Recipe for evaluating a speech synthesis model using one or more of the evaluators provided + +Authors +* Artem Ploujnikov, 2024 +""" + +import csv +import json +import logging +import re +import speechbrain as sb +import string +import sys +import torch +from collections import OrderedDict +from speechbrain.dataio.dataloader import make_dataloader +from speechbrain.inference.eval import itemize +from pathlib import Path +from torch import nn +from types import SimpleNamespace +from hyperpyyaml import load_hyperpyyaml +from tqdm.auto import tqdm + + +logger = logging.getLogger(__name__) + + +class Evaluator: + """Encapsulates the evaluation loop for a TTS evaluation + model + + Arguments + --------- + hparams : dict + Raw hyperparameters + run_opts : dict + The run options + """ + + def __init__( + self, hparams, run_opts=None, + ): + self.hparams = SimpleNamespace(**hparams) + self.run_opts = run_opts or {} + self.device = run_opts.get("device", "cpu") + modules = hparams.get("modules") + self.tts = self.hparams.tts(run_opts={"device": self.device}) + self.modules = ( + nn.ModuleDict(self.hparams.modules).to(self.device) + if modules + else {} + ) + self.modules.tts2wav.to(self.device) + self.enabled_evaluators = set(self.hparams.evaluations.split(",")) + self.evaluators = { + evaluator_key: evaluator_fn(run_opts={"device": self.device}) + for evaluator_key, evaluator_fn in self.hparams.evaluators.items() + if evaluator_key in self.enabled_evaluators + } + + def evaluate(self, dataset): + """Runs the evaluation loop on the specified dataset + + Arguments + --------- + dataset : speechbrain.dataio.dataset.DynamicItemDataset + the dataset + """ + + self.on_evaluate_start() + dataloader = make_dataloader( + dataset, batch_size=self.hparams.batch_size + ) + for batch in tqdm(dataloader, desc="Evaluation"): + self.evaluate_batch(batch) + self.on_evaluate_end() + + def on_evaluate_start(self): + """Invoked at the beginning of evaluation""" + self.evaluators = {} + self.output_files = {} + self.output_writers = {} + self.details = {} + for key, evaluator_fn in self.hparams.evaluators.items(): + self.evaluators[key] = evaluator_fn( + run_opts={"device": self.device} + ) + self.init_evaluator_result(key) + + def init_evaluator_result(self, evaluator_key): + """Opens the CSV file to which evaluation results will be written + and outputs the header + + Arguments + --------- + evaluator_key : str + The evaluator key + """ + file_name = self.hparams.output_files[evaluator_key] + output_path = Path(file_name).parent + output_path.mkdir(parents=True, exist_ok=True) + self.output_file = open(file_name, "w") + columns = 
self.get_report_columns(evaluator_key) + self.output_writers[evaluator_key] = csv.DictWriter( + self.output_file, columns + ) + self.output_writers[evaluator_key].writeheader() + self.details[evaluator_key] = [] + + def on_evaluate_end(self): + """Invoked at the end of evaluation""" + self.flush() + self.close() + self.write_summary() + + def flush(self): + """Flushes all output files to disk""" + for output_file in self.output_files.values(): + output_file.flush() + + def close(self): + """Closes all output files""" + for output_file in self.output_files.values(): + output_file.close() + + def evaluate_batch(self, batch): + """Runs the evaluaion on a single batch + + Arguments + --------- + batch : PaddedBatch + A single item (wrapped in a batch) + """ + batch = batch.to(self.device) + wav, length = self.synthesize(batch.label_norm) + for evaluator_key, evaluator in self.evaluators.items(): + result = evaluator.evaluate( + wav, + length, + text=batch.label_norm_eval, + wavs_ref=batch.sig.data, + length_ref=batch.sig.lengths, + ) + result_items = itemize(result) + self.write_result(evaluator_key, batch.uttid, result_items) + self.details[evaluator_key].extend(result_items) + self.flush() + + def synthesize(self, text): + """Calls the TTS system to synthesize audio from text + + Arguments + --------- + text : str + The text to be synthesized + + Returns + ------- + wav : torch.Tensor + The waveform + length : torch.Tensor + The lengths + """ + tts_out = self.tts(text) + wav, length = self.modules.tts2wav(tts_out) + if wav.dim() > 2: + wav = wav.squeeze(1) + return wav, length + + def write_result(self, key, item_ids, result): + """Outputs a speech evaluation result to the target file + + Arguments + --------- + key : str + The evaluator key + item_id : list + A list of IDs + result : list + speechbrain.inference.eval.SpeechEvaluationResult + The evaluation result from a single evaluator""" + writer = self.output_writers[key] + for item_id, item_result in zip(item_ids, result): + row = { + "id": item_id, + "score": item_result.score, + **item_result.details, + } + writer.writerow(flatten(row)) + + def get_report_columns(self, evaluator_key): + """Returns the columns for the specified evaluator + + Arguments + --------- + evaluator_key : str + the identifier of the evaluator + + Returns + ------- + columns : list[str] + a list of column headers + """ + bogus_wavs = torch.randn(2, 10000, device=self.device) + bogus_length = torch.tensor([1.0, 1.0], device=self.device) + evaluator = self.evaluators[evaluator_key] + result = evaluator.evaluate( + wavs=bogus_wavs, + length=bogus_length, + text=["BOGUS"] * len(bogus_wavs), + wavs_ref=bogus_wavs, + length_ref=bogus_length, + ) + return list( + OrderedDict.fromkeys(["id", "score"] + list(result.details.keys())) + ) + + def compute_summary(self): + """Computes the summarized statistics""" + return { + f"{evaluator_key}_{stat_key}": value + for evaluator_key in self.enabled_evaluators + if evaluator_key in self.details + for metric_key in self.hparams.eval_summary[evaluator_key][ + "descriptive" + ] + for stat_key, value in descriptive_statistics( + items=self.details[evaluator_key], key=metric_key, + ).items() + } + + def write_summary(self): + """Outputs summarized statistics""" + summary = self.compute_summary() + file_name = Path(self.hparams.output_files["summary"]) + file_name.parent.mkdir(parents=True, exist_ok=True) + with open(file_name, "w") as output_file: + json.dump(summary, output_file, indent=4) + + +def dataio_prepare(hparams): + 
"""Prepares the dataset + + Arguments + --------- + hparams : dict + Raw hyperparameters""" + + data_folder = hparams["data_folder"] + eval_dataset = hparams["eval_dataset"] + json_path = hparams[f"{eval_dataset}_json"] + + dataset = sb.dataio.dataset.DynamicItemDataset.from_json( + json_path=json_path, + replacements={"data_root": data_folder}, + output_keys=["uttid", "label"], + ) + dataset.add_dynamic_item(label_norm_pipeline) + dataset.add_dynamic_item(audio_ref_pipeline) + dataset.set_output_keys( + ["uttid", "label_norm_eval", "label_norm", "label_norm_length", "sig"] + ) + + if hparams["sorting"] == "ascending": + dataset = dataset.filtered_sorted(sort_key="label_norm_length") + elif hparams["sorting"] == "descending": + dataset = dataset.filtered_sorted( + sort_key="label_norm_length", reverse=True + ) + return dataset + + +def flatten(value): + """Converts tensors to scalars and lists of strings to strings + + Arguments + --------- + value : dict + the dictionary to flatten + + Returns + ------- + result : dict + a flattened dictionary + """ + return { + key: item_value.item() if torch.is_tensor(item_value) else item_value + for key, item_value in value.items() + } + + +def descriptive_statistics(items, key): + """Computes descriptive statistics for the summary + + Arguments + --------- + items : list + a list of dictionaries with metric values for each item + key : str + """ + values = torch.tensor([item.details[key] for item in items]) + quantiles = torch.tensor([0.25, 0.5, 0.75]) + q1, median, q3 = values.quantile(quantiles) + stats = { + "mean": values.mean(), + "std": values.std(), + "min": values.min(), + "max": values.max(), + "median": median, + "q1": q1, + "q3": q3, + "iqr": q3 - q1, + } + return { + f"{key}_{stat_key}": value.item() for stat_key, value in stats.items() + } + + +RE_PUNCTUATION = re.compile( + "|".join(re.escape(char) for char in string.punctuation) +) + + +@sb.utils.data_pipeline.takes("label") +@sb.utils.data_pipeline.provides( + "label_norm", "label_norm_length", "label_norm_eval" +) +def label_norm_pipeline(label): + """Normalizes labels for ASR comparison, converting to uppercase and removing + punctuation + + Arguments + --------- + label : str + The unnormalized label + + Returns + ------- + result : str + The normalized label + """ + label_norm = label.upper() + yield label_norm + yield len(label_norm) + label_norm_eval = RE_PUNCTUATION.sub("", label_norm) + yield label_norm_eval + + +@sb.utils.data_pipeline.takes("wav") +@sb.utils.data_pipeline.provides("sig") +def audio_ref_pipeline(wav): + """The audio loading pipeline for references + + Arguments + --------- + wav : str + The file path + + Returns + ------- + sig : torch.Tensor + The waveform + """ + sig = sb.dataio.dataio.read_audio(wav) + return sig + + +if __name__ == "__main__": + # Load hyperparameters file with command-line overrides + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + from ljspeech_prepare import prepare_ljspeech + + sb.utils.distributed.run_on_main( + prepare_ljspeech, + kwargs={ + "data_folder": hparams["data_folder"], + "save_folder": hparams["prepare_save_folder"], + "splits": hparams["splits"], + "split_ratio": hparams["split_ratio"], + "seed": hparams["seed"], + "skip_prep": 
hparams["skip_prep"], + "skip_ignore_folders": hparams["skip_ignore_folders"], + "frozen_split_path": hparams["frozen_split_path"], + }, + ) + + dataset = dataio_prepare(hparams) + + evaluator = Evaluator(hparams, run_opts) + evaluator.evaluate(dataset) diff --git a/recipes/LJSpeech/evaluation/hparams/fastspeech2.yaml b/recipes/LJSpeech/evaluation/hparams/fastspeech2.yaml new file mode 100644 index 0000000000..32855df5f6 --- /dev/null +++ b/recipes/LJSpeech/evaluation/hparams/fastspeech2.yaml @@ -0,0 +1,70 @@ +############################################################################ +# Model: FastSpeech2 +# Evaluation recipe +# Tokens: Raw characters (English text) +# Training: LJSpeech +# Authors: Artem Ploujnikov +# ############################################################################ + +seed: 1234 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref ./results/fastspeech2/ + +data_folder: !PLACEHOLDER +prepare_save_folder: !ref /prepared +pretrained_model_save_folder: !ref +evaluations: asr,mos +eval_dataset: valid +sorting: descending +splits: ["train", "valid", "test"] +split_ratio: [90, 5, 5] +skip_prep: False +skip_ignore_folders: False +frozen_split_path: null +batch_size: 8 + +train_json: !ref /train.json +valid_json: !ref /valid.json +test_json: !ref /test.json + + +asr_source: speechbrain/asr-transformer-transformerlm-librispeech +tts_source: speechbrain/tts-fastspeech2-ljspeech +mos_source: flexthink/ttseval-wavlm-transformer +vocoder_model: speechbrain/tts-hifigan-libritts-16kHz + +vocoder: !name:speechbrain.inference.vocoders.HIFIGAN.from_hparams + source: !ref + savedir: !ref /vocoder + +tts: !name:speechbrain.inference.TTS.FastSpeech2.from_hparams + source: !ref + savedir: !ref /tts-fastspeech2 + return_lengths: True + +tts2wav: !new:adapters.MelAdapter + vocoder: !ref + +evaluators: + mos: !name:speechbrain.inference.eval.RegressionModelSpeechEvaluator + source: !ref + savedir: !ref /mos + asr: !name:speechbrain.inference.eval.EncoderDecoderASRSpeechEvaluator + source: !ref + savedir: !ref /asr + overrides: + lm_weight: 0.0 + +output_files: + mos: !ref /mos.csv + asr: !ref /asr.csv + summary: !ref /summary.json + +modules: + tts2wav: !ref + +eval_summary: + mos: + descriptive: ["score"] + asr: + descriptive: ["wer", "cer", "wer_ref", "cer_ref", "dwer", "dcer"] diff --git a/recipes/LJSpeech/evaluation/hparams/tacotron2.yaml b/recipes/LJSpeech/evaluation/hparams/tacotron2.yaml new file mode 100644 index 0000000000..b54529c6eb --- /dev/null +++ b/recipes/LJSpeech/evaluation/hparams/tacotron2.yaml @@ -0,0 +1,70 @@ +############################################################################ +# Model: Tacotron2 +# Evaluation recipe +# Tokens: Raw characters (English text) +# losses: Transducer +# Training: LJSpeech +# Authors: Artem Ploujnikov +# ############################################################################ +seed: 1234 +__set_seed: !apply:torch.manual_seed [!ref ] +output_folder: !ref ./results/tacotron2/ + +data_folder: !PLACEHOLDER +prepare_save_folder: !ref /prepared +pretrained_model_save_folder: !ref +evaluations: asr,mos +eval_dataset: valid +sorting: descending +splits: ["train", "valid", "test"] +split_ratio: [90, 5, 5] +skip_prep: False +skip_ignore_folders: False +frozen_split_path: null +batch_size: 8 + +train_json: !ref /train.json +valid_json: !ref /valid.json +test_json: !ref /test.json + + +asr_source: speechbrain/asr-transformer-transformerlm-librispeech +tts_source: speechbrain/tts-tacotron2-ljspeech +mos_source: 
flexthink/ttseval-wavlm-transformer +vocoder_model: speechbrain/tts-hifigan-libritts-16kHz + +vocoder: !name:speechbrain.inference.vocoders.HIFIGAN.from_hparams + source: !ref + savedir: !ref /vocoder + +tts: !name:speechbrain.inference.TTS.Tacotron2.from_hparams + source: !ref + savedir: !ref /tts-tacotron + +tts2wav: !new:adapters.MelAdapter + vocoder: !ref + +evaluators: + mos: !name:speechbrain.inference.eval.RegressionModelSpeechEvaluator + source: !ref + savedir: !ref /mos + asr: !name:speechbrain.inference.eval.EncoderDecoderASRSpeechEvaluator + source: !ref + savedir: !ref /asr + overrides: + lm_weight: 0.0 + + +output_files: + mos: !ref /mos.csv + asr: !ref /asr.csv + summary: !ref /summary.json + +modules: + tts2wav: !ref + +eval_summary: + mos: + descriptive: ["score"] + asr: + descriptive: ["wer", "cer", "wer_ref", "cer_ref", "dwer", "dcer"] diff --git a/recipes/LJSpeech/evaluation/ljspeech_prepare.py b/recipes/LJSpeech/evaluation/ljspeech_prepare.py new file mode 120000 index 0000000000..2de5a21a8d --- /dev/null +++ b/recipes/LJSpeech/evaluation/ljspeech_prepare.py @@ -0,0 +1 @@ +../ljspeech_prepare.py \ No newline at end of file diff --git a/recipes/LJSpeech/ljspeech_prepare.py b/recipes/LJSpeech/ljspeech_prepare.py index 99d10abd96..44de9aec81 100644 --- a/recipes/LJSpeech/ljspeech_prepare.py +++ b/recipes/LJSpeech/ljspeech_prepare.py @@ -16,6 +16,7 @@ import torch import torchaudio import numpy as np +from pathlib import Path from tqdm import tqdm from speechbrain.utils.data_utils import download_file from speechbrain.dataio.dataio import load_pkl, save_pkl @@ -51,6 +52,8 @@ def prepare_ljspeech( pitch_min_f0=65, pitch_max_f0=400, skip_prep=False, + skip_ignore_folders=False, + frozen_split_path=None, use_custom_cleaner=False, device="cpu", ): @@ -83,6 +86,11 @@ def prepare_ljspeech( If True, skip preparation use_custom_cleaner : bool If True, uses custom cleaner defined for this recipe + skip_ignore_folders : bool + Whether to ignore differences in data and save folders when + checking if the dataset has already been prepared. This is + useful on high-performance compute clusters where such + folders are not permanent device : str Device for to be used for computation (used as required) @@ -154,7 +162,7 @@ def prepare_ljspeech( os.makedirs(pitch_folder) # Check if this phase is already done (if so, skip it) - if skip(splits, save_folder, conf): + if skip(splits, save_folder, conf, ignore_foders=skip_ignore_folders): logger.info("Skipping preparation, completed in previous run.") return @@ -165,7 +173,9 @@ def prepare_ljspeech( # Prepare data splits msg = "Creating json file for ljspeech Dataset.." logger.info(msg) - data_split, meta_csv = split_sets(data_folder, splits, split_ratio) + data_split, meta_csv = split_sets( + data_folder, splits, split_ratio, frozen_split_path + ) if "train" in splits: prepare_json( @@ -221,11 +231,23 @@ def prepare_ljspeech( save_pkl(conf, save_opt) -def skip(splits, save_folder, conf): +def skip(splits, save_folder, conf, ignore_foders=False): """ Detects if the ljspeech data_preparation has been already done. If the preparation has been done, we can skip it. 
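+    When ``ignore_folders`` is True, folder-valued options are excluded from the comparison with the saved configuration, so a dataset prepared under different data or save folders (e.g. on a compute cluster) is still recognized as already prepared.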
+ Arguments + --------- + splits : list + The list of split identifiers + save_folder : str | path-like + The folder where the prepared dataset is saved + conf : dict + The dataset preparation configuration + ignore_folders : bool + Whether differences in folder parameters are to be + ignored + Returns ------- bool @@ -250,6 +272,10 @@ def skip(splits, save_folder, conf): if skip is True: if os.path.isfile(save_opt): opts_old = load_pkl(save_opt) + if ignore_foders: + opts_old = remove_folder_opts(opts_old) + conf = remove_folder_opts(opts_old) + if opts_old == conf: skip = True else: @@ -259,7 +285,23 @@ def skip(splits, save_folder, conf): return skip -def split_sets(data_folder, splits, split_ratio): +def remove_folder_opts(conf): + """Removes all folder options from the configuration dict + + Arguments + --------- + conf : dict + The configuration dictionary + + Returns + ------- + conf : dict + The resulting configuration + """ + return {k: v for k, v in conf.items() if not k.endswith("_folder")} + + +def split_sets(data_folder, splits, split_ratio, frozen_split_path): """Randomly splits the wav list into training, validation, and test lists. Note that a better approach is to make sure that all the classes have the same proportion of samples for each session. @@ -284,6 +326,17 @@ def split_sets(data_folder, splits, split_ratio): ) meta_csv = list(csv_reader) + if frozen_split_path is not None: + frozen_split_path = Path(frozen_split_path) + if frozen_split_path.exists(): + logger.info("Found frozen splits in %s", frozen_split_path) + with open(frozen_split_path, "r") as frozen_split_file: + data_split = json.load(frozen_split_file) + return data_split, meta_csv + else: + logger.info( + "Frozen split %s does not exst, splliting", frozen_split_path + ) index_for_sessions = [] session_id_start = "LJ001" @@ -323,6 +376,11 @@ def split_sets(data_folder, splits, split_ratio): if split == "test": data_split[split].extend(index_for_sessions[j]) + if frozen_split_path is not None: + logger.info("Saving frozen splits in %s", frozen_split_path) + with open(frozen_split_path, "w") as frozen_split_file: + json.dump(data_split, frozen_split_file, indent=0) + return data_split, meta_csv diff --git a/recipes/SOMOS/somos_prepare.py b/recipes/SOMOS/somos_prepare.py new file mode 100644 index 0000000000..ba3fb036e4 --- /dev/null +++ b/recipes/SOMOS/somos_prepare.py @@ -0,0 +1,295 @@ +""" +SOMOS data preparation + +Download: https://zenodo.org/records/7378801 +Paper: https://paperswithcode.com/paper/somos-the-samsung-open-mos-dataset-for-the + +Authors + * Artem Ploujnikov 2023 +""" +from pathlib import Path +from zipfile import ZipFile +from speechbrain.dataio.dataio import merge_csvs +import re +import csv +import logging + +logger = logging.getLogger(__name__) + +FILE_AUDIO_ZIP = "audios.zip" +FILE_DATA = "data.csv" +PATH_AUDIOS = "audios" +PATH_METADATA = "training_files/split1/{subset}/{split}_mos_list.txt" +RE_EXT_WAV = re.compile(".wav$") +COLUMN_ID = "ID" +COLUMN_WAV = "wav" +COLUMN_CHAR = "char" +COLUMN_SCORE = "score" +COLUMN_SYSTEM = "system" +TOOLS_PATH = Path(__file__).parent.parent.parent.parent / "tools" +TOOLS_PATH_VOICEMOS = TOOLS_PATH / "voicemos" +VOICEMOS_NORM_SCRIPT = TOOLS_PATH_VOICEMOS / "sub_normRMSE.sh" + + +def prepare_somos( + data_folder, + save_folder, + splits=["train", "valid", "test"], + subset="full", + use_transcripts=False, + char_list_file=None, +): + """Prepares the csv files for the Somos dataset + + Arguments + --------- + data_folder : str | path-like + Path to the 
folder where the original LJspeech dataset is stored + save_folder : str | path-like + The directory where to store the csv/json files + splits : list + List of dataset splits to prepare + subset : str + the subset to use: + "full" for the full dataset + "clean" for clean data only + transcripts : bool + Whether to include transcripts (requires a version of SOMOS where + transcript/gather_transcripts.py has been run) + char_list_file : str|path-like + The list of characters + """ + data_folder = Path(data_folder) + save_folder = Path(save_folder) + + if not data_folder.exists(): + raise ValueError(f"{data_folder} does not exist") + save_folder.mkdir(parents=True, exist_ok=True) + extract_audio(data_folder, save_folder) + # Note: This can be overridden from the command line + if isinstance(splits, str): + splits = splits.split(",") + transcripts = None + char_set = None + if use_transcripts: + if char_list_file is not None: + char_list = read_list_file(char_list_file) + char_set = set(char_list) + transcripts_file_name = data_folder / "transcript/all_transcripts.txt" + if not transcripts_file_name.exists(): + raise ValueError( + f"{transcripts_file_name} does not exist, please run " + "gather_transcripts.py in {data_folder}/transcript" + ) + transcripts = read_transcripts(transcripts_file_name, char_set) + + for split in splits: + process_split(data_folder, save_folder, split, subset, transcripts) + merge_splits(save_folder, splits) + + +def read_list_file(file_name): + """Reads a file with a simple list of items - used for + filtering characters in transcripts + + Arguments + --------- + file_name : str|path-like + The path to the file + + Returns + ------- + items : list + The lines from the file + """ + with open(file_name) as list_file: + return [line.replace("\r", "").replace("\n", "") for line in list_file] + + +def extract_audio(data_folder, save_folder): + """Extracts audio files + + Arguments + --------- + data_folder : str + Path to the folder where the original LJspeech dataset is stored + save_folder : str + The directory where to store the csv/json files + """ + audios_path = Path(data_folder) / PATH_AUDIOS + if audios_path.exists(): + logging.info( + "Skipping audio extraction - %s already exists", audios_path + ) + else: + audio_archive_path = Path(data_folder) / FILE_AUDIO_ZIP + logger.info("Extracting audio to %s", save_folder) + with ZipFile(audio_archive_path) as audio_archive: + audio_archive.extractall(save_folder) + + +def get_metadata_columns(use_transcripts=False): + """Gets the list of columns to be included + + Arguments + --------- + use_transcripts : bool + Whether to include transcripts (requires a version of SOMOS where + transcript/gather_transcripts.py has been run) + + Returns + ------- + columns : list + A list of column names + """ + columns = [COLUMN_ID, COLUMN_WAV, COLUMN_SYSTEM] + if use_transcripts: + columns.append(COLUMN_CHAR) + columns.append(COLUMN_SCORE) + return columns + + +def read_transcripts(file_name, char_set): + """Reads a transcript file + + Arguments + --------- + file_name : str|path-like + The path to the file containing transcripts + + Returns + ------- + result : dict + The transcript dictionary + char_set : set + The whitelist of allwoed characters + """ + + with open(file_name) as transcript_file: + records = ( + parse_transcript_line(line.strip(), char_set) + for line in transcript_file + ) + return {item_id: transcript for item_id, transcript in records} + + +def parse_transcript_line(line, char_set): + """Parses a single line of 
the transcript + + Arguments + --------- + line : str + A raw line from the file + char_set : set + The whitelist of allwoed characters + + Results + ------- + item_id : str + The item ID + + transcript : str + The normalized transcript""" + item_id, transcript = line.split("\t") + transcript = transcript.upper() + if char_set is not None: + transcript = "".join(char for char in transcript if char in char_set) + return item_id, transcript + + +def process_split(data_folder, save_folder, split, subset, transcripts=None): + """Processes metadata for the specified split + + Arguments + --------- + data_folder : str + Path to the folder where the original LJspeech dataset is stored + save_folder : str + The directory where to store the csv/json files + split : str + the split identifier ("train", "valid" or "test") + subset : str + the subset to use: + "full" for the full dataset + "clean" for clean data only + transcripts : dict, optional + The parsed transcripts + """ + src_metadata_file_path = data_folder / PATH_METADATA.format( + split=split, subset=subset + ) + tgt_metadata_file_path = save_folder / f"{split}.csv" + logger.info( + "Processing %s - from %s to %s", + split, + src_metadata_file_path, + tgt_metadata_file_path, + ) + + if not src_metadata_file_path.exists(): + raise ValueError(f"{src_metadata_file_path} does not exist") + + metadata_columns = get_metadata_columns(transcripts is not None) + + with open(src_metadata_file_path) as src_file: + with open(tgt_metadata_file_path, "w") as tgt_file: + reader = csv.DictReader(src_file) + writer = csv.DictWriter(tgt_file, metadata_columns) + writer.writeheader() + for src_item in reader: + src_audio_path = ( + Path(data_folder) / PATH_AUDIOS / src_item["utteranceId"] + ) + if src_audio_path.exists(): + tgt_item = process_item(src_item, transcripts) + writer.writerow(tgt_item) + else: + logger.warn("%s does not exist", src_audio_path) + + +def process_item(item, transcripts): + """Converts a single metadata record to the SpeechBrain + convention + + Arguments + --------- + item : dict + a single record from the source file + transcripts : dict + The parsed transcripts + + Returns + ------- + result: dict + the processed item""" + src_utterance_id = item["utteranceId"] + tgt_utterance_id = RE_EXT_WAV.sub("", src_utterance_id) + system_id = tgt_utterance_id.split("_")[-1] + wav_path = Path("$data_root") / PATH_AUDIOS / src_utterance_id + result = { + "ID": tgt_utterance_id, + "wav": wav_path, + "system": system_id, + "score": item["mean"], + } + if transcripts is not None: + result["char"] = transcripts[tgt_utterance_id] + + return result + + +def merge_splits( + save_folder, splits, +): + """Merges data files into a single file + + Arguments + --------- + save_folder : str | path-like + The directory where to store the csv/json files + splits : list + List of dataset splits to prepare + """ + tgt_file_path = save_folder / FILE_DATA + csvs = [save_folder / f"{split}.csv" for split in splits] + merge_csvs(save_folder, csvs, tgt_file_path) diff --git a/recipes/SOMOS/ttseval/README.md b/recipes/SOMOS/ttseval/README.md new file mode 100644 index 0000000000..d914bf7e9f --- /dev/null +++ b/recipes/SOMOS/ttseval/README.md @@ -0,0 +1,59 @@ +# MOS Estimation (with SOMOS) +This folder contains the recipes for training TTS evaluation systems trained on LJSpeech using the Samsung Open MOS Dataset (SOMOS) + +# Dataset +The dataset can be downloaded from here: +https://zenodo.org/records/7378801 + +# Installing Extra Dependencies + +Before proceeding, 
ensure you have installed the necessary additional dependencies. To do this, simply run the following command in your terminal: + +``` +pip install -r extra_requirements.txt +``` +# TTS Evaluation Model + +The model is based loosely on the baseline model described in the VoiceMOS 2022 challenge described in the paper below: +https://arxiv.org/pdf/2203.11389.pdf + +It is based on the same principle: the weights of the base self-supervised representation model are updated in order +to fine-tune it to the quality assessment task. Linear regression between human and model ratings is used for +assessment. + +Additional enhancements have been added. The updated model featured a shallow encoder-only Tranformer before the pooling +layer in order to introduce an attention mechanism. + +To run this recipe, run the following command + +``` +python train.py --device=cuda:0 --data_folder=/your_folder/SOMOS hparams/train_ssl_wavlm_xformer.yaml +``` + +# Training Results +| Release | Model | hyperparams file | Val R | Test R | HuggingFace Link | Model Link | GPUs | +| ----------- |:-----------------:| ----------------:|:--------------:|:-------------------------------------------------------------------:|:------------------------------------:|:-----------:| +| 2024-02-26 | WavLM Transformer | train.yaml | TBD | TBD | [model](https://huggingface.co/flexthink/ttseval-wavlm-transformer) | [model](https://www.dropbox.com/tbd) | 1xV100 32GB | + + +# **About SpeechBrain** +- Website: https://speechbrain.github.io/ +- Code: https://github.com/speechbrain/speechbrain/ +- HuggingFace: https://huggingface.co/speechbrain/ + + +# **Citing SpeechBrain** +Please, cite SpeechBrain if you use it for your research or business. + +```bibtex +@misc{speechbrain, + title={{SpeechBrain}: A General-Purpose Speech Toolkit}, + author={Mirco Ravanelli and Titouan Parcollet and Peter Plantinga and Aku Rouhe and Samuele Cornell and Loren Lugosch and Cem Subakan and Nauman Dawalatabad and Abdelwahab Heba and Jianyuan Zhong and Ju-Chieh Chou and Sung-Lin Yeh and Szu-Wei Fu and Chien-Feng Liao and Elena Rastorgueva and François Grondin and William Aris and Hwidong Na and Yan Gao and Renato De Mori and Yoshua Bengio}, + year={2021}, + eprint={2106.04624}, + archivePrefix={arXiv}, + primaryClass={eess.AS}, + note={arXiv:2106.04624} +} +``` + diff --git a/recipes/SOMOS/ttseval/contrastive_sampling.py b/recipes/SOMOS/ttseval/contrastive_sampling.py new file mode 100644 index 0000000000..f0f8902a31 --- /dev/null +++ b/recipes/SOMOS/ttseval/contrastive_sampling.py @@ -0,0 +1,199 @@ +""" +Contrastive sampling utilities + +Authors + * Artem Ploujnikov 2023 +""" + +import torch +from functools import partial +from speechbrain.utils.data_pipeline import takes, provides +from speechbrain.utils import checkpoints + + +@checkpoints.register_checkpoint_hooks +class RegressionContrastiveEnhancement: + """An dataset enhancement for contrastive learning on simple regression + tasks with a single metric, useful for cases where consistently estimating + differences is easier than evaluating the metric for an individual sample + and where differences that are too small are not considered useful. Originally, + this was developed for MOS estimation + + For any given sample, a paired sample will be selected by first finding + all "allowed" pairings (by excluding those less than min_delta units away), + sorting them by distance - and then sampling from the uniform distribution + (of indices, not of distances). 
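+    In effect, each anchor item is paired with another item whose metric value differs by at least ``min_delta``, and the partner is drawn uniformly at random from the eligible candidates.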
+ + Arguments + --------- + metric_key : str + The key in the original dataset corresponding to the metric + + min_delta : float + The minimum metric distance (absolute value) for a sample to be pairable + with any given sample + + Examples + -------- + >>> data = { + ... 1: {"score": 3.5}, + ... 2: {"score": 2.7}, + ... 3: {"score": 5.0}, + ... 4: {"score": 1.2}, + ... 5: {"score": 2.5}, + ... 6: {"score": 3.2}, + ... 7: {"score": 3.8}, + ... 8: {"score": 1.7}, + ... 9: {"score": 1.2}, + ... 10: {"score": 4.2}, + ... } + >>> from speechbrain.dataio.dataset import DynamicItemDataset + >>> dataset = DynamicItemDataset(data) + >>> dataset.set_output_keys(["id", "score"]) + >>> from contrastive_sampling import RegressionContrastiveEnhancement + >>> sampling = RegressionContrastiveEnhancement( + ... metric_key="score", + ... min_delta=0.5, + ... seed=42 + ... ) + >>> sampling.bind(dataset) + >>> from speechbrain.dataio.dataloader import make_dataloader + >>> loader = make_dataloader(dataset) + >>> loader_it = iter(loader) + >>> batch = next(loader_it) + >>> batch.score.item() + 3.5 + >>> batch.contrast_score.item() + 2.5 + """ + + def __init__(self, metric_key, min_delta, seed=None): + self.metric_key = metric_key + self.min_delta = min_delta + self.generator = torch.Generator() + if seed is not None: + self.generator.manual_seed(seed) + + def bind(self, dataset): + """Binds the enhancement to a dataset, adding the contrastive pairings + to its pipeline + + Arguments + --------- + dataset : speechbrain.dataio.dataset.DynamicItemDataset + the target dataset + """ + metric_values = torch.tensor( + [item[self.metric_key] for item in dataset] + ) + size = len(dataset) + self.indexes = torch.arange(size) + metric_values_sorted, self.indexes_sorted = metric_values.sort() + metric_diff_abs = ( + metric_values_sorted[None, :] - metric_values_sorted[:, None] + ).abs() + selection_blocked = metric_diff_abs < self.min_delta + min_shift_right = selection_blocked.triu().sum(-1) + self.min_shift_left = selection_blocked.tril().sum(-1) + self.indexes_sorted_mirror = torch.cat( + [self.indexes_sorted.flip(0)[1:], self.indexes_sorted] + ) + self.indexes_sorted_mirror = self.indexes_sorted_mirror.unsqueeze( + 0 + ).expand(len(dataset), self.indexes_sorted_mirror.size(-1)) + for idx in range(len(dataset)): + self.indexes_sorted_mirror[idx] = self.indexes_sorted_mirror[ + idx + ].roll(-idx) + + self.shift_max = (2 * size - 1) - self.min_shift_left - min_shift_right + keys = list(dataset.pipeline.output_mapping.keys()) + self.data_ids = dataset.data_ids + pairings_map = self._get_pairings_map() + if not keys: + raise ValueError("Output keys must be set before binding") + contrastive_keys = [f"contrast_{key}" for key in keys] + self.pipeline = ContrastivePairingPipeline(keys, pairings_map) + pipeline_element = partial(self.pipeline, dataset) + pipeline_element = takes("id")(pipeline_element) + pipeline_element = provides(*contrastive_keys)(pipeline_element) + dataset.add_dynamic_item(pipeline_element) + dataset.set_output_keys(keys + contrastive_keys) + + def _get_pairings_map(self): + """Builds a returns a dictionary of item pairings""" + shift_rel = torch.rand(len(self.indexes), generator=self.generator) + shift_abs = ( + self.min_shift_left + (self.shift_max * shift_rel).floor().int() + ) + indexes_selected = self.indexes_sorted_mirror[ + torch.arange(len(shift_abs)), shift_abs + ] + pairings = torch.zeros_like(indexes_selected) + pairings[self.indexes_sorted] = indexes_selected + return { + data_id: pairing_idx 
+ for data_id, pairing_idx in zip(self.data_ids, pairings) + } + + def shuffle(self): + """Re-samples the pairings""" + pairings_map = self._get_pairings_map() + self.pipeline.pairings = pairings_map + + @checkpoints.mark_as_saver + def save(self, path): + """Saves the current metrics on the specified path.""" + data = { + "generator_state": self.generator.get_state(), + } + torch.save(data, path) + + @checkpoints.mark_as_loader + def load(self, path, end_of_epoch=False, device=None): + """Loads the needed information.""" + del end_of_epoch + del device + data = torch.load(path) + self.generator.set_state(data["generator_state"]) + + +class ContrastivePairingPipeline: + """A helper callable that adds keys from the paired samples + to dataset elements. Instances of this class are intended to + be used as dynamic items in a DynamicItemDataset. + + Arguments + --------- + keys : list + a list of keys in the original dataset to be enhanced + + parirings : dict + a dictionary indicating how IDs are paired - with keys + corresponding to anchor items and values to the paired items + """ + + def __init__(self, keys, pairings): + self.keys = keys + self.pairings = pairings + + def __call__(self, dataset, data_id): + """Provides the data keys from the paired data sample + + Arguments + --------- + dataset : speechbrain.dataio.dataio.DynamicItemDataset + a dataset + data_id : object + the ID of the item + + Returns + ------- + result: generator + the values corresponding to the specified keys from + the paired item""" + pairing_id = self.pairings[data_id] + with dataset.output_keys_as(self.keys): + pairing = dataset[pairing_id] + for key in self.keys: + yield pairing[key] diff --git a/recipes/SOMOS/ttseval/extra_requirements.txt b/recipes/SOMOS/ttseval/extra_requirements.txt new file mode 100644 index 0000000000..45efa8f3f8 --- /dev/null +++ b/recipes/SOMOS/ttseval/extra_requirements.txt @@ -0,0 +1 @@ +seaborn>=0.13.0 diff --git a/recipes/SOMOS/ttseval/hparams/train.yaml b/recipes/SOMOS/ttseval/hparams/train.yaml new file mode 100644 index 0000000000..5977afe645 --- /dev/null +++ b/recipes/SOMOS/ttseval/hparams/train.yaml @@ -0,0 +1,112 @@ +# ############################################################################ +# Model: SSL with Wav2Vec (training from scratch) +# Authors: Artem Ploujnikov, Yingzhi Wang +# # ############################################################################ + + +# Seed needs to be set at top of yaml, before objects with parameters are instantiated +seed: 42 +__set_seed: !apply:torch.manual_seed [!ref ] + +data_folder: !PLACEHOLDER +output_folder: !ref results/ssl_wavlm/ +save_folder: !ref /save +details_folder: !ref /details +train_log: !ref /train_log.txt +train_regression_metric: True +batch_size: 4 +num_workers: 4 +use_transcripts: False +src_sample_rate: 24000 +tgt_sample_rate: 16000 +contrastive: False +contrastive_loss_weight: 0.8 +contrastive_min_delta: 0.5 +classification_epochs: 3 +classification_threshold: 3.5 + +lr: 0.00001 + +dataloader_options: + batch_size: !ref + num_workers: !ref + shuffle: True + +number_of_epochs: 10 +ckpt_interval_minutes: 15 + +wavlm_source: microsoft/wavlm-large +wavlm_folder: !ref /wavlm_checkpoint + +#freeze all wavlm +wavlm_freeze: False +#set to true to freeze the CONV part of the wavlm model +wavlm_freeze_feature_extractor: True +activation: !name:torch.nn.LeakyReLU + +d_model: 512 +d_ffn: 2048 +num_layers: 3 +nhead: 4 +dropout: 0.5 + +splits: ["train", "valid", "test"] +subset: "full" +skip_prep: False + +train_annotation: 
!ref /train.csv +valid_annotation: !ref /valid.csv +test_annotation: !ref /test.csv + +train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger + save_file: !ref + +compute_cost: !name:speechbrain.nnet.losses.l1_loss +compute_cost_contrastive: !name:speechbrain.nnet.losses.mse_loss +compute_cost_classification: !name:speechbrain.nnet.losses.bce_loss + + +wavlm: !new:speechbrain.lobes.models.huggingface_transformers.wavlm.WavLM + source: !ref + output_norm: True + freeze: !ref + freeze_feature_extractor: !ref + save_path: !ref + +model: !new:speechbrain.lobes.models.eval.ssl.TransformerRegression + base_model: !ref + d_model: !ref + d_ffn: !ref + num_layers: !ref + nhead: !ref + dropout: !ref + activation: !ref + +classifier: !new:speechbrain.lobes.models.eval.ssl.BaselineSSLFinetune + base_model: !ref + +modules: + model: !ref + classifier: !ref + +epoch_counter: !new:speechbrain.utils.epoch_loop.EpochCounter + limit: !ref + +opt_class: !name:torch.optim.Adam + lr: !ref + +checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer + checkpoints_dir: !ref + recoverables: + model: !ref + classifier: !ref + lr_annealing_output: !ref + counter: !ref + allow_partial_load: True + + +lr_annealing: !new:speechbrain.nnet.schedulers.NewBobScheduler + initial_value: !ref + improvement_threshold: 0.0025 + annealing_factor: 0.9 + patient: 0 diff --git a/recipes/SOMOS/ttseval/somos_prepare.py b/recipes/SOMOS/ttseval/somos_prepare.py new file mode 120000 index 0000000000..5e65865a54 --- /dev/null +++ b/recipes/SOMOS/ttseval/somos_prepare.py @@ -0,0 +1 @@ +../somos_prepare.py \ No newline at end of file diff --git a/recipes/SOMOS/ttseval/train.py b/recipes/SOMOS/ttseval/train.py new file mode 100644 index 0000000000..e8b1bd950f --- /dev/null +++ b/recipes/SOMOS/ttseval/train.py @@ -0,0 +1,585 @@ +#!/usr/bin/env python3 +"""Recipe for training a TTS evaluation system + +Authors + * Artem Ploujnikov 2024 + * Yingzi Wang 2024 +""" +import sys +import speechbrain as sb +import torchaudio +import logging +from enum import Enum +from hyperpyyaml import load_hyperpyyaml +from pathlib import Path +from contrastive_sampling import RegressionContrastiveEnhancement + + +logger = logging.getLogger(__name__) + +LABEL_MODEL_SCORE = "Model Score" +LABEL_HUMAN_SCORE = "Human Score" +KEY_MODEL_SCORE = "model_score" +KEY_HUMAN_SCORE = "human_score" + + +class TTSEvalTrainMode(Enum): + CLASSIFICATION = "classification" + CONTRASTIVE = "contrastive" + REGRESSION = "regression" + + +# Brain class for TTS evaluation training +class TTSEvalBrain(sb.Brain): + """Class that manages the training loop for TTS evaluation. + See speechbrain.core.Brain.""" + + def compute_forward(self, batch, stage): + """Computes the forward pass + + Arguments + --------- + batch : PaddedBatch + This batch object contains all the relevant tensors for computation. + stage : sb.Stage + One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST. + + Returns + ------- + predictions : Tensor | tuple + predictions + """ + if self.mode == TTSEvalTrainMode.CONTRASTIVE: + return self.compute_forward_contrastive(batch) + + # We first move the batch to the appropriate device. 
+ batch = batch.to(self.device) + + # Compute predictions + model = ( + self.modules.classifier + if self.mode == TTSEvalTrainMode.CLASSIFICATION + else self.modules.model + ) + predictions = model(batch.sig.data, batch.sig.lengths) + + return predictions + + def compute_forward_contrastive(self, batch): + """Computes the forward pass (contrastive mode) + + Arguments + --------- + batch : PaddedBatch + This batch object contains all the relevant tensors for computation. + stage : sb.Stage + One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST. + + Returns + ------- + predictions : Tensor + Tensor that contains the posterior probabilities over the N classes. + """ + batch = batch.to(self.device) + + # Compute predictions + predictions_anchor = self.modules.model( + batch.sig.data, batch.sig.lengths + ) + predictions_contrast = self.modules.model( + batch.contrast_sig.data, batch.contrast_sig.lengths + ) + + return predictions_anchor, predictions_contrast + + def compute_objectives(self, predictions, batch, stage): + """Computes the loss given the predicted and targeted outputs. + + Arguments + --------- + predictions : tensor + The output tensor from `compute_forward`. + batch : PaddedBatch + This batch object contains all the relevant tensors for computation. + stage : sb.Stage + One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST. + + Returns + ------- + loss : torch.Tensor + A one-element tensor used for backpropagating the gradient. + """ + if self.mode == TTSEvalTrainMode.CONTRASTIVE: + loss = self.compute_objectives_contrastive( + predictions, batch, stage + ) + elif self.mode == TTSEvalTrainMode.CLASSIFICATION: + loss = self.compute_objectives_classification( + predictions, batch, stage + ) + else: + loss = self.compute_objectives_regression(predictions, batch, stage) + return loss + + def compute_objectives_regression(self, predictions, batch, stage): + """Computes the classification loss + + Arguments + --------- + predictions : tensor + The output tensor from `compute_forward`. + batch : PaddedBatch + This batch object contains all the relevant tensors for computation. + stage : sb.Stage + One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST. + + Returns + ------- + loss : torch.Tensor + A one-element tensor used for backpropagating the gradient. + """ + scores = batch.score_num[:, None, None] + loss = self.hparams.compute_cost(predictions, scores) + + # Append this batch of losses to the loss metric for easy + self.loss_metric.append( + batch.id, predictions, scores, batch.sig.lengths, reduction="batch" + ) + if self.reg_metric is not None: + self.reg_metric.append(batch.id, predictions, scores) + self.reg_system_metric.append( + batch.id, predictions, scores, groups=batch.system + ) + + return loss + + def compute_objectives_classification(self, predictions, batch, stage): + """Computes the classification loss + + Arguments + --------- + predictions : tensor + The output tensor from `compute_forward`. + batch : PaddedBatch + This batch object contains all the relevant tensors for computation. + stage : sb.Stage + One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST. + + Returns + ------- + loss : torch.Tensor + A one-element tensor used for backpropagating the gradient. 
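+        Scores are binarized against ``classification_threshold`` and compared to the classifier output using the configured classification loss (binary cross-entropy in the provided hyperparameters).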
+ """ + scores = batch.score_num[:, None] + targets = (scores > self.hparams.classification_threshold).float()[ + :, None + ] + loss = self.hparams.compute_cost_classification(predictions, targets) + self.loss_metric_classification.append( + batch.id, predictions, targets, reduction="batch", + ) + return loss + + def compute_objectives_contrastive(self, predictions, batch, stage): + """Computes the loss given the predicted and targeted outputs. + + Arguments + --------- + predictions : tensor + The output tensor from `compute_forward`. + batch : PaddedBatch + This batch object contains all the relevant tensors for computation. + stage : sb.Stage + One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST. + + Returns + ------- + loss : torch.Tensor + A one-element tensor used for backpropagating the gradient. + """ + predictions_anchor, predictions_contrast = predictions + diff_predictions = predictions_anchor - predictions_contrast + scores_anchor = batch.score_num[:, None, None].to( + predictions_anchor.dtype + ) + scores_contrast = batch.contrast_score_num[:, None, None].to( + predictions_contrast.dtype + ) + diff_targets = scores_anchor - scores_contrast + + loss_predictive = 0.5 * ( + self.hparams.compute_cost(scores_anchor, predictions_anchor) + + self.hparams.compute_cost(scores_contrast, predictions_contrast) + ) + loss_contrastive = self.hparams.compute_cost_contrastive( + diff_predictions, diff_targets + ) + + predictive_loss_weight = 1.0 - self.hparams.contrastive_loss_weight + loss = ( + loss_predictive * predictive_loss_weight + + loss_contrastive * self.hparams.contrastive_loss_weight + ) + + # Append this batch of losses to the loss metric for easy + self.loss_metric.append( + batch.id, + diff_predictions, + diff_targets, + batch.sig.lengths, + reduction="batch", + ) + self.loss_metric_contrastive.append( + batch.id, + diff_predictions, + diff_targets, + batch.sig.lengths, + reduction="batch", + ) + if self.reg_metric is not None: + self.reg_metric.append(batch.id, predictions_anchor, scores_anchor) + self.reg_metric.append( + batch.contrast_id, predictions_contrast, scores_contrast + ) + + return loss + + def on_stage_start(self, stage, epoch=None): + """Gets called at the beginning of each epoch. + + Arguments + --------- + stage : sb.Stage + One of sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST. + epoch : int + The currently-starting epoch. This is passed + `None` during the test stage. 
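+        This method also selects the training mode for the stage: classification pretraining while the epoch is within ``classification_epochs``, contrastive training if ``contrastive`` is enabled, and plain regression otherwise.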
+ """ + + # Set up statistics trackers for this stage + if epoch is not None and epoch <= self.hparams.classification_epochs: + logger.info("Classification pretraining mode") + self.mode = TTSEvalTrainMode.CLASSIFICATION + elif self.hparams.contrastive: + logger.info("Contrastive training mode") + self.mode = TTSEvalTrainMode.CONTRASTIVE + else: + logger.info("Regular regression training mode") + self.mode = TTSEvalTrainMode.REGRESSION + + self.loss_metric = sb.utils.metric_stats.MetricStats( + metric=self.hparams.compute_cost + ) + self.loss_metric_contrastive = sb.utils.metric_stats.MetricStats( + metric=self.hparams.compute_cost_contrastive + ) + self.loss_metric_classification = sb.utils.metric_stats.MetricStats( + metric=self.hparams.compute_cost_classification + ) + + if ( + stage != sb.Stage.TRAIN or self.hparams.train_regression_metric + ) and self.mode != TTSEvalTrainMode.CLASSIFICATION: + self.reg_metric = sb.utils.metric_stats.LinearRegressionStats( + scores_label=LABEL_MODEL_SCORE, + targets_label=LABEL_HUMAN_SCORE, + scores_key=KEY_MODEL_SCORE, + targets_key=KEY_HUMAN_SCORE, + ) + self.reg_system_metric = sb.utils.metric_stats.LinearRegressionStats( + scores_label=LABEL_MODEL_SCORE, + targets_label=LABEL_HUMAN_SCORE, + scores_key=KEY_MODEL_SCORE, + targets_key=KEY_HUMAN_SCORE, + grouped=True, + ) + else: + self.reg_metric = None + self.reg_system_metric = None + + def get_stats(self, stage_loss): + """Retrieves statistics for the current stage + + Arguments + --------- + stage_loss : float + The average loss for all of the data processed in this stage. + """ + stats = { + "loss": stage_loss, + "mode": self.mode.value, + } + if self.mode != TTSEvalTrainMode.CLASSIFICATION: + stats["predictive_loss"] = self.loss_metric.summarize("average") + + if self.reg_metric is not None: + stats.update(self.get_prefixed_metric_stats(self.reg_metric, "utt")) + stats.update( + self.get_prefixed_metric_stats(self.reg_system_metric, "sys") + ) + else: + stats["utt_pearson_r"] = 0.0 + + return stats + + def on_stage_end(self, stage, stage_loss, epoch=None): + """Gets called at the end of an epoch. + + Arguments + --------- + stage : sb.Stage + One of sb.Stage.TRAIN, sb.Stage.VALID, sb.Stage.TEST + stage_loss : float + The average loss for all of the data processed in this stage. + epoch : int + The currently-starting epoch. This is passed + `None` during the test stage. + """ + + # Store the train loss until the validation stage. + if stage == sb.Stage.TRAIN: + self.train_stats = self.get_stats(stage_loss) + + # Summarize the statistics from the stage for record-keeping. + else: + stats = self.get_stats(stage_loss) + + # At the end of validation... + if stage == sb.Stage.VALID: + + old_lr, new_lr = self.hparams.lr_annealing(epoch) + sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr) + + # The train_logger writes a summary to stdout and to the logfile. + self.hparams.train_logger.log_stats( + {"Epoch": epoch, "lr": old_lr}, + train_stats=self.train_stats, + valid_stats=stats, + ) + + # Save the current checkpoint and delete previous checkpoints, + self.checkpointer.save_and_keep_only( + meta=stats, max_keys=["utt_pearson_r"] + ) + + # We also write statistics about test data to stdout and to the logfile. 
+ if stage == sb.Stage.TEST: + self.hparams.train_logger.log_stats( + {"Epoch loaded": self.hparams.epoch_counter.current}, + test_stats=stats, + ) + + self.output_details(stage, epoch) + if self.hparams.contrastive: + self.shuffle(stage) + + def get_prefixed_metric_stats(self, metric, prefix): + """Gets statistics from a MetricStats instance and applies + a prfix to them + + Arguments + --------- + metric : speechbrain.utils.metric_stats.MetricStats + A metric instance + prefix : str + The prefix to use + + Returns + ------- + stats : dict + prefixed statistics + """ + stats = metric.summarize() + return {f"{prefix}_{key}": value for key, value in stats.items()} + + def shuffle(self, stage): + """Shuffles contrastive pairings + + Arguments + --------- + stage : speechbrain.Stage + The experiment stage""" + stage_key = stage.name.lower() + self.contrastive_enhancements[stage_key].shuffle() + + def output_details(self, stage, epoch=None): + """Outputs raw CSV stats and regression plots + + Arguments + --------- + stage : speechbrain.Stage + The experiment stage + epoch : int, optional + The epoch number""" + if self.reg_metric is None: + return None + target_path = Path(self.hparams.details_folder) + if epoch is not None: + suffix = str(epoch) + else: + suffix = "test" + target_path = target_path / suffix + target_path.mkdir(exist_ok=True, parents=True) + stage_label = str(stage.name).lower() + csv_file_name = f"raw_{stage_label}.csv" + self.reg_metric.write_csv(target_path / csv_file_name) + try: + plot_file_name = f"regression_{stage_label}.png" + self.reg_metric.plot(target_path / plot_file_name) + plot_file_name = f"regression_{stage_label}_system.png" + self.reg_system_metric.plot(target_path / plot_file_name) + except ImportError: + logger.warn("Unable to output plots, seaborn is not installed") + + +def dataio_prepare(hparams): + """Prepares the dataset + + Arguments + --------- + hparams : dict + Raw hyperparameters""" + + @sb.utils.data_pipeline.takes("score") + @sb.utils.data_pipeline.provides("score_num") + def score_pipeline(score): + return float(score) + + @sb.utils.data_pipeline.takes("wav") + @sb.utils.data_pipeline.provides("sig") + def audio_pipeline(wav): + """Load the signal, and pass it and its length to the corruption class. 
+ This is done on the CPU in the `collate_fn`.""" + sig = sb.dataio.dataio.read_audio(wav) + sig = torchaudio.functional.resample( + sig, + orig_freq=hparams["src_sample_rate"], + new_freq=hparams["tgt_sample_rate"], + ) + return sig + + datasets = {} + for key in ["train", "valid", "test"]: + dataset = sb.dataio.dataset.DynamicItemDataset.from_csv( + hparams[f"{key}_annotation"], + dynamic_items=[score_pipeline, audio_pipeline], + replacements={ + "data_root": hparams["data_folder"], + "processed_folder": str( + Path(hparams["data_folder"]) / "processed" + ), + }, + ) + output_keys = ["id", "sig", "score_num", "system"] + if hparams.get("use_transcripts", False): + output_keys.append("char") + dataset.set_output_keys(output_keys) + datasets[key] = dataset + return datasets + + +def add_contrastive(datasets, hparams): + """Adds contrastive enhancement to the dataset + + Arguments + --------- + datasets : dict + a dictionary of datasets with "train", "valid" + and "test" keys + + hparams : dict + Hyperparameters + + Returns + ------- + contrastive_enhancements : dict + A dictionary with the same keys as the dataset + and corresponding RegressionContrastiveEnhancement + objects as values + """ + contrastive_enhancements = {} + for key, dataset in datasets.items(): + contrastive_enhancement = RegressionContrastiveEnhancement( + metric_key="score_num", + min_delta=hparams["contrastive_min_delta"], + seed=hparams["seed"], + ) + contrastive_enhancement.bind(dataset) + contrastive_enhancements[key] = contrastive_enhancement + return contrastive_enhancements + + +# Recipe begins! +if __name__ == "__main__": + + # Reading command line arguments. + hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:]) + + # Initialize ddp (useful only for multi-GPU DDP training). + sb.utils.distributed.ddp_init_group(run_opts) + + # Load hyperparameters file with command-line overrides. + with open(hparams_file) as fin: + hparams = load_hyperpyyaml(fin, overrides) + + # Create experiment directory + sb.create_experiment_directory( + experiment_directory=hparams["output_folder"], + hyperparams_to_save=hparams_file, + overrides=overrides, + ) + + # Data preparation, to be run on only one process. + from somos_prepare import prepare_somos + + if not hparams["skip_prep"]: + + sb.utils.distributed.run_on_main( + prepare_somos, + kwargs={ + "data_folder": hparams["data_folder"], + "save_folder": hparams["data_folder"], + "splits": hparams["splits"], + "subset": hparams["subset"], + "use_transcripts": hparams.get("use_transcripts", False), + "char_list_file": hparams.get("char_list_file"), + }, + ) + + # Create dataset objects "train", "valid", and "test". + datasets = dataio_prepare(hparams) + + # Initialize the Brain object to prepare for mask training. + ttseval_brain = TTSEvalBrain( + modules=hparams["modules"], + opt_class=hparams["opt_class"], + hparams=hparams, + run_opts=run_opts, + checkpointer=hparams["checkpointer"], + ) + + if hparams["contrastive"]: + contrastive_enhancements = add_contrastive(datasets, hparams) + checkpointer = hparams["checkpointer"] + for key, enhancement in contrastive_enhancements.items(): + checkpointer.add_recoverable(f"contrastive_{key}", enhancement) + ttseval_brain.contrastive_enhancements = contrastive_enhancements + + # The `fit()` method iterates the training loop, calling the methods + # necessary to update the parameters of the model. 
Since all objects + # with changing state are managed by the Checkpointer, training can be + # stopped at any point, and will be resumed on next call. + ttseval_brain.fit( + epoch_counter=ttseval_brain.hparams.epoch_counter, + train_set=datasets["train"], + valid_set=datasets["valid"], + train_loader_kwargs=hparams["dataloader_options"], + valid_loader_kwargs=hparams["dataloader_options"], + ) + + # Load the best checkpoint for evaluation + test_stats = ttseval_brain.evaluate( + test_set=datasets["test"], + min_key="error", + test_loader_kwargs=hparams["dataloader_options"], + ) diff --git a/speechbrain/inference/TTS.py b/speechbrain/inference/TTS.py index 40c7cfab14..e5648f7744 100644 --- a/speechbrain/inference/TTS.py +++ b/speechbrain/inference/TTS.py @@ -375,6 +375,7 @@ class FastSpeech2(Pretrained): HPARAMS_NEEDED = ["spn_predictor", "model", "input_encoder"] def __init__(self, *args, **kwargs): + self.return_lengths = kwargs.pop("return_lengths", False) super().__init__(*args, **kwargs) lexicon = self.hparams.lexicon lexicon = ["@@"] + lexicon @@ -567,7 +568,7 @@ def encode_batch( _, energy, _, - _, + mel_lens, ) = self.hparams.model( tokens_padded, pace=pace, @@ -578,7 +579,12 @@ def encode_batch( # Transposes to make in compliant with HiFI GAN expected format post_mel_outputs = post_mel_outputs.transpose(-1, 1) - return post_mel_outputs, durations, pitch, energy + if self.return_lengths: + mel_lens = mel_lens.to(post_mel_outputs.device) + result = post_mel_outputs, mel_lens, durations, pitch, energy + else: + result = post_mel_outputs, durations, pitch, energy + return result def forward(self, text, pace=1.0, pitch_rate=1.0, energy_rate=1.0): """Batch inference for a tensor of phoneme sequences @@ -593,8 +599,10 @@ def forward(self, text, pace=1.0, pitch_rate=1.0, energy_rate=1.0): energy_rate : float scaling factor for phoneme energies """ + if isinstance(text, str): + text = [text] return self.encode_text( - [text], pace=pace, pitch_rate=pitch_rate, energy_rate=energy_rate + text, pace=pace, pitch_rate=pitch_rate, energy_rate=energy_rate ) @@ -630,6 +638,7 @@ class FastSpeech2InternalAlignment(Pretrained): HPARAMS_NEEDED = ["model", "input_encoder"] def __init__(self, *args, **kwargs): + self.return_lengths = kwargs.pop("return_lengths", False) super().__init__(*args, **kwargs) lexicon = self.hparams.lexicon lexicon = ["@@"] + lexicon @@ -816,7 +825,7 @@ def encode_batch( _, _, _, - _, + mel_lens, _, ) = self.hparams.model( tokens_padded, @@ -828,7 +837,11 @@ def encode_batch( # Transposes to make in compliant with HiFI GAN expected format post_mel_outputs = post_mel_outputs.transpose(-1, 1) - return post_mel_outputs, durations, pitch, energy + if self.return_lengths: + result = post_mel_outputs, mel_lens, durations, pitch, energy + else: + result = post_mel_outputs, durations, pitch, energy + return result def forward(self, text, pace=1.0, pitch_rate=1.0, energy_rate=1.0): """Batch inference for a tensor of phoneme sequences @@ -843,6 +856,8 @@ def forward(self, text, pace=1.0, pitch_rate=1.0, energy_rate=1.0): energy_rate : float scaling factor for phoneme energies """ + if isinstance(text, str): + text = [text] return self.encode_text( - [text], pace=pace, pitch_rate=pitch_rate, energy_rate=energy_rate + text, pace=pace, pitch_rate=pitch_rate, energy_rate=energy_rate ) diff --git a/speechbrain/inference/eval.py b/speechbrain/inference/eval.py new file mode 100644 index 0000000000..f84828b5ee --- /dev/null +++ b/speechbrain/inference/eval.py @@ -0,0 +1,821 @@ +""" Specifies the 
inference interfaces for speech quality +evaluation, used to assess the quality/intelligibility of +text-to-speech systems + +Authors: +* Artem Ploujnikov 2024 +""" + +from speechbrain.inference.interfaces import Pretrained +from speechbrain.inference.ASR import EncoderDecoderASR +from speechbrain.lobes.models.huggingface_transformers import Whisper +from speechbrain.decoders.seq2seq import S2SWhisperGreedySearch +from speechbrain.dataio.batch import PaddedBatch +from speechbrain.utils.metric_stats import ErrorRateStats +from collections import namedtuple +from pathlib import Path +import os +import torch +import torchaudio +import re +import string +import logging +import shutil +import subprocess + +logger = logging.getLogger(__name__) + +RE_PUNCTUATION = re.compile( + "|".join(re.escape(char) for char in string.punctuation) +) + + +SpeechEvaluationResult = namedtuple( + "SpeechEvaluationResult", ["score", "details"] +) + + +class SpeechEvaluator: + """A base class for speech evaluators + + Arguments + --------- + sample_rate : int + The audio sample rate this evaluator expects + """ + + def __init__(self, sample_rate=16000): + self.sample_rate = sample_rate + + def evaluate_file(self, file_name, text=None): + """Evaluates a single file + + Arguments + --------- + file_name : str|pathlib.Path + The file name to evaluate + text : str + The ground truth text, if applicable + + Returns + ------- + result: SpeechEvaluationResult + the evaluation result + """ + wav = self.read_audio(str(file_name)).to(self.device) + result = self.evaluate( + wavs=wav.unsqueeze(0), + length=torch.ones(1).to(self.device), + text=[text], + ) + return SpeechEvaluationResult( + score=result.score.item(), + details={ + key: _unbatchify(value) for key, value in result.details.items() + }, + ) + + def evaluate_files(self, file_names, text=None): + """Evaluates multiple files + + Arguments + --------- + file_names : list + A list of files + + text : list + File transcripts (not required for all evaluators) + + Returns + ------- + result : list + a list of SpeechEvaluationResult instances + """ + if text is None: + text = [None] * len(file_names) + items = [ + {"wav": self.read_audio(str(file_name)), "text": item_text} + for file_name, item_text in zip(file_names, text) + ] + batch = PaddedBatch(items) + return self.evaluate( + wavs=batch.wav.data.to(self.device), + length=batch.wav.lengths.to(self.device), + text=batch.text, + ) + + def read_audio(self, file_name): + """Reads an audio file, resampling if necessary + + Arguments + --------- + file_name : str | path-like + The file path + + Returns + ------- + audio : torch.Tensor + the audio + """ + audio, audio_sample_rate = torchaudio.load(str(file_name)) + return self.resample(audio, audio_sample_rate) + + def evaluate( + self, + wavs, + length, + text=None, + wavs_ref=None, + wavs_length_ref=None, + sample_rate=None, + ): + """Evaluates samples + + Arguments + --------- + wavs : torch.Tensor + the waveforms to evaluate + + length : torch.Tensor + relative lengths (a 1-D tensor) + + text : list + Evaluator-specific metadata + + wavs_ref : torch.Tensor + the reference waveforms + + wavs_length_ref + the reference waveform lengths + + sample_rate: int, optional + The sample rate of the audio. 
If not provided, + the audio is assumed to be at the same sample + rate as the model + + Returns + ------- + result : list + A list of SpeechEvaluationResult objects, + one for each sample""" + raise NotImplementedError() + + def resample(self, audio, sample_rate=None): + """Resamples the audio, if necessary + + Arguments + --------- + audio : torch.Tensor + the audio to be resampled + sample_rate : int + the sample rate of the audio + + Returns + ------- + audio : torch.Tensor + the target audio, resampled if necessary + """ + if sample_rate is not None and sample_rate != self.sample_rate: + audio = torchaudio.functional.resample( + audio, orig_freq=sample_rate, new_freq=self.sample_rate + ) + return audio + + +def _unbatchify(value): + """Removes the batch dimension from the tensor. If a single + number is returned in any shape, the function converts + the result to a numeric value. Values that are not tensors + are returned unmodified + + Arguments + --------- + value : object + the value + + Returns + ------- + value : object + the value with the batch dimension removed, if applicable + """ + if torch.is_tensor(value): + if value.dim() == 0 or not any(dim > 1 for dim in value.shape): + value = value.item() + else: + value = value.squeeze(0) + return value + + +class SpeechEvaluationRegressionModel(Pretrained): + """A pretrained wrapper for regression-based evaluation + models""" + + def __call__(self, wavs, length): + return self.mods.model(wavs, length) + + +class RegressionModelSpeechEvaluator(SpeechEvaluator): + """A speech evaluator that uses a regression model + that produces a quality score (e.g. SSL fine-tuning) + for a sample of speech + + Arguments + --------- + source : str + The source model path or HuggingFace hub name + sample_rate : int + The audio sample rate this evaluator expects + """ + + def __init__(self, source, sample_rate=None, *args, **kwargs): + super().__init__(sample_rate=sample_rate) + self.model = SpeechEvaluationRegressionModel.from_hparams( + source, *args, **kwargs + ) + + def evaluate( + self, + wavs, + length, + text=None, + wavs_ref=None, + length_ref=None, + sample_rate=None, + sample_rate_ref=None, + ): + """Evaluates a batch of waveforms + + Arguments + --------- + wavs: torch.Tensor + the waveforms to evaluate + + length: torch.Tensor + relative lengths (a 1-D tensor) + + text : list, optional + Ground truth text + + wavs_ref : torch.Tensor + the reference waveforms + + length_ref : torch.Tensor + the reference waveform lengths + + sample_rate : int, optional + The sample rate of the audio.
If not provided, + the audio is assumed to be at the same sample + rate as the model + + sample_rate_ref : int, optional + The sample rate of the reference samples + + Returns + ------- + result : SpeechEvaluationResult + an aggregated speech evaluation result with a score + for each item + """ + wavs = self.resample(wavs, sample_rate) + scores = self.model(wavs, length) + while scores.dim() > 1 and scores.size(-1) == 1: + scores = scores.squeeze(-1) + return SpeechEvaluationResult(score=scores, details={"score": scores}) + + +class ASRSpeechEvaluator(SpeechEvaluator): + """A superclass for ASR-based speech evaluators""" + + def evaluate( + self, + wavs, + length, + text=None, + wavs_ref=None, + length_ref=None, + sample_rate=None, + sample_rate_ref=None, + ): + """Evaluates samples + + Arguments + --------- + wavs: torch.Tensor + the waveforms to evaluate + + length: torch.Tensor + relative lengths (a 1-D tensor) + + text : list, optional + Ground truth text + + wavs_ref : torch.Tensor + the reference waveforms + + length_ref : torch.Tensor + the reference waveform lengths + + sample_rate : int, optional + The sample rate of the audio. If not provided, + the audio is assumed to be at the same sample + rate as the model + + sample_rate_ref : int, optional + The sample rate of the reference samples + + Returns + ------- + result : SpeechEvaluationResult + an aggregated speech evaluation result with a score + for each item + """ + details = self.evaluate_samples( + wavs=wavs, length=length, text=text, sample_rate=sample_rate + ) + if wavs_ref is not None: + details_ref = self.evaluate_samples( + wavs=wavs_ref, + length=length_ref, + text=text, + sample_rate=sample_rate_ref, + ) + details.update( + {f"{key}_ref": value for key, value in details_ref.items()} + ) + # Redundant: identical to 'target' + del details["target_ref"] + details.update(self.compute_diff_rate(details, device=wavs.device)) + + return SpeechEvaluationResult(score=details["wer"], details=details,) + + def compute_diff_rate(self, details, device): + """Computes differential scores (dWER and dCER) + + Arguments + --------- + details : dict + A details dictionary, containing keys 'pred' (the ASR prediction of TTS output) + and 'pred_ref' (the ASR prediction of the ground truth) + device : str | torch.device + the device + + Returns + ------- + result : dict + A dictionary with two keys + 'dwer' : the differential Word Error Rate + 'dcer' : the differential Character Error Rate + """ + ids = range(1, len(details["pred"]) + 1) + wer_metric, cer_metric = init_asr_metrics() + pred = self._replace_blanks(details["pred"]) + pred_ref = self._replace_blanks(details["pred_ref"]) + wer_metric.append(ids, pred, pred_ref) + cer_metric.append(ids, pred, pred_ref) + dwer = torch.tensor( + [score["WER"] for score in wer_metric.scores], device=device + ) + dcer = torch.tensor( + [score["WER"] for score in cer_metric.scores], device=device + ) + return {"dwer": dwer, "dcer": dcer} + + def _replace_blanks(self, preds): + return [" " if item == "" else item for item in preds] + + +class EncoderDecoderASRSpeechEvaluator(ASRSpeechEvaluator): + """A speech evaluator implementation based on ASR.
+ Computes the Word Error Rate (WER), Character Error Rate (CER) + and a few other metrics + + Arguments + --------- + sample_rate : int + The audio sample rate this evaluator expects + """ + + def __init__(self, source, sample_rate=None, *args, **kwargs): + super().__init__(sample_rate=sample_rate) + self.asr = EncoderDecoderASR.from_hparams(source, *args, **kwargs) + self.device = next(self.asr.mods.parameters()).device + + def evaluate_samples(self, wavs, length, text=None, sample_rate=None): + """Evaluates a batch of samples + + Arguments + --------- + wav : torch.Tensor + A batch of waveforms + lengths : torch.Tensor + Relative lengths + text : list + A list of ground truth texts, one per sample + sample_rate : int + The sample of the waveforms. If not provided, + it will be assumed to match the underlying + model + + Returns + ------- + results : dict + A results dictionary with the following keys + wer : the word error rates (tensor) + cer : the character error rate (tensor) + pred : text predictions (list of strings) + target : the ground truth (list of strings) + """ + + wavs = self.resample(wavs, sample_rate) + if text is None: + raise ValueError("This evaluator requires ground-truth text") + predicted_words, scores, log_probs = self.transcribe_batch_with_details( + wavs, length + ) + ids = range(1, len(wavs) + 1) + wer_metric, cer_metric = init_asr_metrics() + wer_metric.append(ids, predicted_words, text) + cer_metric.append(ids, predicted_words, text) + wer = torch.tensor( + [score["WER"] for score in wer_metric.scores], device=wavs.device + ) + cer = torch.tensor( + [score["WER"] for score in cer_metric.scores], device=wavs.device + ) + prob_mean = log_probs.exp().mean(dim=-1) + return { + "wer": wer, + "cer": cer, + "beam_score": scores, + "prob_mean": prob_mean, + "pred": predicted_words, + "target": text, + } + + def transcribe_batch_with_details(self, wavs, wav_lens): + """Transcribes the input audio into a sequence of words + + The waveforms should already be in the model's desired format. + You can call: + ``normalized = EncoderDecoderASR.normalizer(signal, sample_rate)`` + to get a correctly converted signal in most cases. 
+ + Arguments + --------- + wavs : torch.Tensor + The batch of waveforms to transcribe + wav_lens : torch.Tensor + The relative lengths of the waveforms + + Returns + ------- + predicted_words : list + The predictions + + best_scores : torch.Tensor + The best scores (from beam search) + + best_log_probs : torch.Tensor + The best log-probabilities + + """ + with torch.no_grad(): + wav_lens = wav_lens.to(self.device) + encoder_out = self.asr.encode_batch(wavs, wav_lens) + ( + hyps, + best_lens, + best_scores, + best_log_probs, + ) = self.asr.mods.decoder(encoder_out, wav_lens) + predicted_words = [ + self.asr.tokenizer.decode_ids(token_seq) for token_seq in hyps + ] + return predicted_words, best_scores, best_log_probs + + def to(self, device): + """Transfers this module to the specified device + + Arguments + --------- + device : str | torch.Device + the target device + """ + self.asr = self.asr.to(device) + return self + + +class WhisperASRSpeechEvaluator(ASRSpeechEvaluator): + """Similar to EncoderDecoderASRSpeechEvaluator, but for the + Whisper-based ASR + + Arguments + --------- + source : str + The model source (path or HF hub) + savedir : str + The save directory + sample_rate : int + The model sample rate + bos_index : int + The index of the BOS token + eos_index : int + The index of the EOS token + min_decode_ratio : float + The minimum decode ratio + max_decode_ratio : float + The maximum decode ratio + run_opts : dict + The run options + """ + + def __init__( + self, + source, + savedir=None, + sample_rate=22050, + bos_index=50363, + eos_index=50257, + min_decode_ratio=0.0, + max_decode_ratio=1.0, + run_opts=None, + ): + if run_opts is None: + run_opts = {} + super().__init__(sample_rate=sample_rate) + if savedir is None: + savedir = "." + self.model = Whisper( + source, savedir, sample_rate, freeze=True, freeze_encoder=True, + ) + self.model.tokenizer.set_prefix_tokens("english", "transcribe", False) + self.searcher = S2SWhisperGreedySearch( + self.model, + bos_index=bos_index, + eos_index=eos_index, + min_decode_ratio=min_decode_ratio, + max_decode_ratio=max_decode_ratio, + ) + self.searcher.set_decoder_input_tokens( + self.model.tokenizer.prefix_tokens + ) + device = run_opts.get("device", next(self.model.parameters()).device) + self.to(device) + + def evaluate_samples(self, wavs, length, text=None, sample_rate=None): + """Evaluates a batch of samples + + Arguments + --------- + wavs : torch.Tensor + A batch of waveforms + length : torch.Tensor + Relative lengths + text : list + A list of ground truth texts, one per sample + sample_rate : int + The sample rate of the waveforms.
If not provided, + it will be assumed to match the underlying + model + + Returns + ------- + results : dict + A results dictionary with the following keys + wer : the word error rates (tensor) + cer : the character error rate (tensor) + pred : text predictions (list of strings) + target : the ground truth (list of strings) + """ + if text is None: + raise ValueError("This evaluator requires ground-truth text") + wavs = self.resample(wavs, sample_rate) + enc_out = self.model.forward_encoder(wavs) + predicted_words, _, _, _ = self.searcher(enc_out, length) + predicted_words = self.model.tokenizer.batch_decode( + predicted_words, skip_special_tokens=True + ) + predicted_words = [self.normalize(text) for text in predicted_words] + ids = range(1, len(wavs) + 1) + wer_metric, cer_metric = init_asr_metrics() + wer_metric.append(ids, predicted_words, text) + cer_metric.append(ids, predicted_words, text) + wer = torch.tensor( + [score["WER"] for score in wer_metric.scores], device=wavs.device + ) + cer = torch.tensor( + [score["WER"] for score in cer_metric.scores], device=wavs.device + ) + return { + "wer": wer, + "cer": cer, + "pred": predicted_words, + "target": text, + } + + def normalize(self, text): + """Normalizes the prediction by converting to uppercase, + and removing leading/trailing spaces and punctuation + + Arguments + --------- + text : str + Unnormalized text + + Returns + ------- + result : str + Normalized text + """ + text = text.upper() + text = text.strip() + text = RE_PUNCTUATION.sub("", text) + return text + + def to(self, device): + """Transfers this module to the spcieifed device + + Arguments + --------- + device : str | torch.Device + the target device + """ + self.model = self.model.to(device) + return self + + +def itemize(result): + """Converts a single batch result into per-item results + + Arguments + --------- + result: SpeechEvaluationResult + a single batch result + + Returns + ------- + results: list + a list of individual SpeechEvaluationResult instances""" + + return [ + SpeechEvaluationResult( + score=result.score[idx], + details={key: value[idx] for key, value in result.details.items()}, + ) + for idx in range(len(result.score)) + ] + + +def init_asr_metrics(): + """Initializes the WER and CER metrics + + Returns + ------- + wer_metric : ErrorRateStats + the Word Error Rate (WER) metric + cer_metric : ErrorRateStats + the Character Error Rate (CER) metric""" + wer_metric = ErrorRateStats() + cer_metric = ErrorRateStats(split_tokens=True) + return wer_metric, cer_metric + + +class BulkSpeechEvaluator: + """A superclass for speech evaluations that can only + evaluate lists of files - useful for wrappers around + external tools""" + + def evaluate_files(self, file_names, text=None, file_names_ref=None): + """Evaluates a collection of files + + Arguments + --------- + file_names : str + A list of file names + text : str + The reference text for the files, if applicable + file_names_ref + The file names of applicable ground truth, if applicable + + Returns + ------- + result: SpeechEvaluationResult + the evaluation result + """ + raise NotImplementedError() + + +class UTMOSSpeechEvaluator(BulkSpeechEvaluator): + """A speech evaluator that uses a pretrained UTMOS from HuggingFace + + Paper: https://arxiv.org/abs/2204.02152 + Implementation: https://huggingface.co/spaces/sarulab-speech/UTMOS-demo + + Arguments + --------- + model_path : str + The path to the pretrained model (i.e. 
a clone of HuggingFace code) + output_folder : str + The output folder + ckpt_path : str + The path to the checkpoint + script : str + The name of the script to be called + python : str + The python interpreter to be used + use_python : bool, optional + Whether to invoke using the script python interpreter. True by default + Set this by default if a custom script is needed (e.g. to set up the + environment) + batch_size : int + The batch size + + """ + + def __init__( + self, + model_path, + output_folder, + ckpt_path, + script="predict.py", + python="python", + use_python=True, + batch_size=8, + ): + self.output_folder = Path(output_folder) + rand = torch.randint(1, 999999999, (1,)).item() + self.eval_path = (self.output_folder / f"eval_{rand}").absolute() + self.model_path = Path(model_path).absolute() + script = self.model_path / script + self.script = script + self.ckpt_path = Path(ckpt_path).absolute() + self.batch_size = batch_size + self.python = python + self.use_python = use_python + + def evaluate_files(self, file_names, text, file_names_ref=None): + """Evaluates a collection of files + + Arguments + --------- + file_names : str + A list of file names + text : str + The reference text for the files. Ignored for UTMOS. + file_names_ref + The file names of applicable ground truth. + Ignored for UTMOS + + Returns + ------- + result: SpeechEvaluationResult + the evaluation result + """ + current_path = os.getcwd() + try: + self.eval_path.mkdir(parents=True, exist_ok=True) + logger.info("Copying the files to '%s'", self.eval_path) + for file_name in file_names: + target_file_name = self.eval_path / Path(file_name).name + shutil.copy(file_name, target_file_name) + + logger.info("Running evaluation") + result_path = self.eval_path / "result.txt" + os.chdir(self.model_path) + cmd = [ + str(self.script), + "--mode", + "predict_dir", + "--bs", + str(self.batch_size), + "--inp_dir", + str(self.eval_path), + "--out_path", + result_path, + "--ckpt_path", + str(self.ckpt_path), + ] + if self.use_python: + cmd = [self.python] + cmd + + output = subprocess.check_output(cmd) + logger.info("Evaluation finished, output: %s", output) + file_names = [path.name for path in self.eval_path.glob("*.wav")] + with open(result_path) as result_path: + scores = [float(line.strip()) for line in result_path] + score_map = dict(zip(file_names, scores)) + scores_ordered = [ + score_map[Path(file_name).name] for file_name in file_names + ] + return SpeechEvaluationResult( + scores_ordered, {"utmos": scores_ordered} + ) + finally: + os.chdir(current_path) + shutil.rmtree(self.eval_path) diff --git a/speechbrain/lobes/models/eval/__init__.py b/speechbrain/lobes/models/eval/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/speechbrain/lobes/models/eval/ssl.py b/speechbrain/lobes/models/eval/ssl.py new file mode 100644 index 0000000000..ad11581483 --- /dev/null +++ b/speechbrain/lobes/models/eval/ssl.py @@ -0,0 +1,224 @@ +""" +Speech quality assessment models based on self-supervised learning (SSL) +model finetuning + +Authors + * Artem Ploujnikov 2023 + * Yingzi Wang 2024 +""" + +import torch +from torch import nn +from speechbrain.nnet.linear import Linear +from speechbrain.lobes.models.transformer.Transformer import ( + TransformerEncoder, + PositionalEncoding, +) +from speechbrain.dataio.dataio import length_to_mask +from speechbrain.nnet.normalization import BatchNorm1d +from speechbrain.nnet.pooling import StatisticsPooling + + +class BaselineSSLFinetune(nn.Module): + """A baseline 
self-supervised learning representation fine-tuning + model, inspired by the following: + + https://github.com/nii-yamagishilab/mos-finetune-ssl + + Arguments + --------- + base_model : torch.nn.Module + The base model to be used + + feats_dim : int, optional + The feature dimension. If omitted, it will be computed automatically + + Example + ------- + >>> from speechbrain.lobes.models.eval.ssl import BaselineSSLFinetune + >>> from speechbrain.nnet.linear import Linear + >>> from torch import nn + >>> import torch + >>> class FakeBaseModel(nn.Module): + ... def __init__(self, output_size): + ... super().__init__() + ... self.lin = Linear( + ... input_size=1, + ... n_neurons=output_size + ... ) + ... def forward(self, x, length): + ... return self.lin(x.unsqueeze(-1)) + >>> fake_base_model = FakeBaseModel(128) + >>> model = BaselineSSLFinetune( + ... base_model=fake_base_model + ... ) + >>> x = torch.randn(4, 100) + >>> length = torch.ones(4) + >>> scores = model(x, length) + >>> scores.shape + torch.Size([4, 1, 1]) + """ + + def __init__(self, base_model, feats_dim=None): + super().__init__() + self.base_model = base_model + if feats_dim is None: + feats_dim = compute_feats_dim(base_model) + self.feats_dim = feats_dim + self.pool = StatisticsPooling(return_std=False) + self.out = Linear(n_neurons=1, input_size=feats_dim) + + def forward(self, wav, length): + """Computes the forward pass + + Arguments + --------- + wav : torch.Tensor + The waveform (in the format understood by the base model) + Typically (Batch x Time) or (Batch x Channel x Time) + length : torch.Tensor + A 1-D tensor of relative lengths + + Returns + ------- + result : torch.Tensor + a 1-D tensor with an estimated speech quality rating + (the scale used depends on the training data) + """ + x = self.base_model(wav, length) + x = self.pool(x, length) + x = self.out(x) + return x + + +class TransformerRegression(nn.Module): + """A simple extension of the SSL fine-tuning model that adds a non-autoregressive + transformer layer on top of SSL representation followed by average pooling. The + idea is to train a new model for the evaluation task instead of - or in addition to + - attempting to update the weights of the base model + + Arguments + --------- + base_model : torch.nn.Module + The base model converting an audio/speech signal to a latent representation + feats_dim : int, optional + The feature dimension. If omitted, it will be computed automatically + d_model : int, optional + The transformer model dimension + d_ffn : int, optional + The transformer feed-forward network dimension + num_layers : int, optional + The number of transformer layers + nhead : int, optional + The number of transformer heads + activation : torch.nn.Module, optional + The type of activation to use (defaults to LeakyRELU) + dropout : float, optional + The dropout probability + max_len : int + The maximum sequence length + + Example + ------- + >>> from speechbrain.lobes.models.eval.ssl import TransformerRegression + >>> from speechbrain.nnet.linear import Linear + >>> from torch import nn + >>> import torch + >>> class FakeBaseModel(nn.Module): + ... def __init__(self, output_size): + ... super().__init__() + ... self.lin = Linear( + ... input_size=1, + ... n_neurons=output_size + ... ) + ... def forward(self, x, length): + ... return self.lin(x.unsqueeze(-1)) + >>> fake_base_model = FakeBaseModel(128) + >>> model = TransformerRegression( + ... base_model=fake_base_model + ... 
) + >>> x = torch.randn(4, 100) + >>> length = torch.ones(4) + >>> scores = model(x, length) + >>> scores.shape + torch.Size([4, 1, 1]) + """ + + def __init__( + self, + base_model, + feats_dim=None, + d_model=512, + d_ffn=2048, + num_layers=3, + nhead=4, + activation=None, + dropout=0.2, + max_len=2500, + ): + super().__init__() + self.base_model = base_model + + if activation is None: + activation = nn.LeakyReLU + + if feats_dim is None: + feats_dim = compute_feats_dim(base_model) + self.feats_norm = BatchNorm1d(input_size=feats_dim) + self.feat_proj = Linear(n_neurons=d_model, input_size=feats_dim) + self.pos_emb = PositionalEncoding(input_size=d_model, max_len=max_len) + + self.transformer = TransformerEncoder( + num_layers=num_layers, + nhead=nhead, + d_model=d_model, + d_ffn=d_ffn, + activation=activation, + dropout=dropout, + normalize_before=True, + ) + self.pool = StatisticsPooling(return_std=False) + self.out_proj = Linear(n_neurons=1, input_size=d_model) + + def forward(self, wav, length): + """Computes the forward pass + + Arguments + --------- + wav : torch.Tensor + The waveform (in the format understood by the base model) + Typically (Batch x Time) or (Batch x Channel x Time) + length : torch.Tensor + A 1-D tensor of relative lengths + + Returns + ------- + result : torch.Tensor + a 1-D tensor with an estimated speech quality rating + (the scale used depends on the training data) + """ + x = self.base_model(wav, length) + x = self.feats_norm(x) + pos_emb = self.pos_emb(x) + x = self.feat_proj(x) + pos_emb + abs_len = torch.round(length * x.shape[1]) + src_key_padding_mask = ~length_to_mask(abs_len).bool() + x, _ = self.transformer(x, src_key_padding_mask=src_key_padding_mask) + x = self.pool(x) + x = self.out_proj(x) + return x + + +def compute_feats_dim(model): + """Computes the feature dimension by feeding a fake tensor to the model + + Arguments + --------- + model : torch.nn.Module + A model that takes audio input + + Returns + ------- + feats_dim : int + The feature dimension of the model output + """ + device = next(model.parameters()).device + wav = torch.randn(1, 1000, device=device) + length = torch.tensor([1.0], device=device) + out = model(wav, length) + return out.size(-1) diff --git a/speechbrain/utils/metric_stats.py b/speechbrain/utils/metric_stats.py index c79efb25ee..1bb0568d6b 100644 --- a/speechbrain/utils/metric_stats.py +++ b/speechbrain/utils/metric_stats.py @@ -9,6 +9,9 @@ """ import torch +import numpy as np +import csv +import logging from joblib import Parallel, delayed from speechbrain.utils.data_utils import undo_padding from speechbrain.utils.edit_distance import wer_summary, wer_details_for_batch @@ -19,6 +22,8 @@ ) from speechbrain.dataio.wer import print_wer_summary, print_alignments +logger = logging.getLogger(__name__) + class MetricStats: """A default class for storing and summarizing arbitrary metrics. @@ -1106,3 +1111,300 @@ def wrapper(*args, **kwargs): return result._asdict() if has_asdict else result return wrapper + + +KEY_ID = "id" +KEY_DIFF = "diff" +KEY_DIFF_SQ = "diff_sq" + + +class LinearRegressionStats(MetricStats): + """Computes a simple linear correlation between two metrics - useful + for regression tasks, such as quality assessment. It provides an optional + grouping option, in which case the correlation is computed between means + of groups rather than individual samples. The original use case for grouping + is producing system-level correlation for the MOS estimation task + (as opposed to utterance-level).
+ + Arguments + --------- + grouped: bool, Optional + If set to true, statistics will be grouped + scores_label : str, optional + The user-facing label for scores, to be shown on plots + targets_label : str, optional + The user-facing label for targets, to be shown on plots + scores_key : str, optional + The field name for scores, to be used for raw data output + targets_key : str, optional + The field name for targets, to be used for raw data output + plot_pad_left : float + The amount of padding on the left + plot_pad_bottom : float + The amount of padding on the bottom + + Example + ------- + >>> import torch + >>> from speechbrain.utils.metric_stats import LinearRegressionStats + >>> reg_stats = LinearRegressionStats() + >>> reg_stats.append( + ... ids=["ID1", "ID2"], + ... predict=torch.tensor([1.25, 2.75]), + ... target=torch.tensor([1.00, 3.00]), + ... ) + >>> reg_stats.append( + ... ids=["ID3", "ID4"], + ... predict=torch.tensor([5.5, 3.5]), + ... target=torch.tensor([5.0, 3.0]), + ... ) + >>> summary = reg_stats.summarize() + >>> summary = {key: round(value, 2) for key, value in summary.items()} + >>> summary["scores_mean"] + 3.25 + >>> summary["scores_std"] + 1.77 + >>> summary["targets_mean"] + 3.0 + >>> summary["targets_std"] + 1.63 + >>> summary["slope"] + 0.91 + >>> summary["intercept"] + 0.05 + >>> summary["pearson_r"] + 0.98 + >>> reg_stats = LinearRegressionStats(grouped=True) + >>> reg_stats.append( + ... ids=["ID1", "ID2", "ID3", "ID4"], + ... predict=torch.tensor([1.25, 2.75]), + ... target=torch.tensor([1.00, 3.00]), + ... groups=["G1", "G2", "G3", "G2"], + ... ) + >>> reg_stats.append( + ... ids=["ID5", "ID6", "ID7", "ID8"], + ... predict=torch.tensor([5.5, 3.5, 2.2, 1.0]), + ... target=torch.tensor([5.0, 3.0, 2.0, 1.2]), + ... groups=["G1", "G2", "G3", "G1"], + ... 
) + >>> summary = reg_stats.summarize() + >>> summary = {key: round(value, 2) for key, value in summary.items()} + >>> summary["scores_mean"] + 3.21 + >>> summary["scores_std"] + 2.01 + >>> summary["targets_mean"] + 2.97 + >>> summary["targets_std"] + 1.82 + >>> summary["slope"] + 0.9 + >>> summary["intercept"] + 0.07 + >>> summary["pearson_r"] + 1.0 + """ + + def __init__( + self, + grouped=False, + scores_label="y", + targets_label="x", + scores_key="y", + targets_key="x", + plot_pad_left=0.2, + plot_pad_bottom=0.1, + ): + self.clear() + self.targets = [] + self.groups = [] + self.grouped = grouped + self.scores_label = scores_label + self.targets_label = targets_label + self.scores_key = scores_key + self.targets_key = targets_key + self.plot_pad_left = plot_pad_left + self.plot_pad_bottom = plot_pad_bottom + + def append( + self, ids, predict, target, groups=None, + ): + """Appends a measurement + + Arguments + --------- + ids : list + a list of item IDs + predict : torch.Tensor + the prediction tensor + target : torch.Tensor + the target tensor + groups : list, optional + the group indicator for each item, ignored + if grouped is set to false + """ + self.ids.extend(ids) + self.scores.extend(_flatten(predict)) + self.targets.extend(_flatten(target)) + if self.grouped: + self.groups.extend(groups) + + def group_data(self): + """Returns the group means of scores and targets""" + grouped_scores = _group(self.scores, self.groups) + grouped_targets = _group(self.targets, self.groups) + groups = sorted(grouped_scores.keys()) + scores = np.array([grouped_scores[group].mean() for group in groups]) + targets = np.array([grouped_targets[group].mean() for group in groups]) + return scores, targets + + def get_regression_data(self): + """Prepares data for regression. If grouping is disabled, collected + scores and targets are converted to arrays. 
If it is enabled, grouped + data will be aggregated first + + Returns + ------- + scores : numpy.array + Estimated scores / metric values + targets : numpy.array + Ground truths""" + if self.grouped: + scores, targets = self.group_data() + else: + scores = np.array(self.scores) + targets = np.array(self.targets) + return scores, targets + + def summarize(self, field=None): + """Summarizes linear regression statistics + + Full set of fields: + - scores_mean - the mean of scores + - scores_std - the standard deviation of scores + - targets_mean - the mean of targets + - targets_std - the standard deviation of targets + - slope - the slope of the regression line + - intercept - the intercept of the regression line + - pearson_r - the Pearson correlation coefficient + """ + scores, targets = self.get_regression_data() + has_data = len(scores) > 0 + if has_data: + x = np.stack([scores, np.ones_like(scores)], axis=1) + solution, _, _, _ = np.linalg.lstsq(x, targets, rcond=None) + slope, intercept = solution.squeeze() + corr_mat = np.corrcoef(scores, targets) + pearson_r = corr_mat[0][1] + self.summary = { + "scores_mean": scores.mean() if has_data else 0.0, + "scores_std": scores.std(ddof=1) if has_data else 0.0, + "targets_mean": targets.mean() if has_data else 0.0, + "targets_std": targets.std(ddof=1) if has_data else 0.0, + "slope": slope if has_data else 0.0, + "intercept": intercept if has_data else 0.0, + "pearson_r": pearson_r if has_data else 0.0, + } + if field: + return self.summary[field] + else: + return self.summary + + def plot(self, output=None): + """Outputs a regression plot, optionally saving it to a file or a + stream, returning a Matplotlib figure. Requires Seaborn. + + Arguments + --------- + output : str | path-like | BytesIO + The path to which the diagram will be saved + + Returns + ------- + fig : figure + A Matplotlib figure""" + try: + import seaborn as sns + import matplotlib + except ImportError: + raise ImportError("Regression plots require Seaborn") + matplotlib.use("Agg") + if self.summary is None: + self.summarize() + scores, targets = self.get_regression_data() + if len(scores) == 0: + logger.warning("Cannot produce a plot - no data found") + return None + h = sns.jointplot(x=targets, y=scores, kind="reg") + + r = self.summary["pearson_r"] + h.figure.suptitle(f"r = {r:.3f}", x=self.plot_pad_left) + + h.ax_joint.set_xlabel(self.targets_label) + h.ax_joint.set_ylabel(self.scores_label) + h.figure.subplots_adjust( + left=self.plot_pad_left, bottom=self.plot_pad_bottom + ) + if output is not None: + h.figure.savefig(output) + return h.figure + + def write_csv(self, file_name): + """Outputs raw data, as CSV + + Arguments + --------- + file_name : str, path-like + The path to which raw statistics will be written""" + scores = np.array(self.scores) + targets = np.array(self.targets) + diff = scores - targets + diff_sq = diff ** 2 + with open(file_name, "w") as csv_file: + writer = csv.writer(csv_file) + header = [ + KEY_ID, + self.scores_key, + self.targets_key, + KEY_DIFF, + KEY_DIFF_SQ, + ] + writer.writerow(header) + rows = zip(self.ids, scores, targets, diff, diff_sq) + writer.writerows(rows) + + +def _flatten(x): + """Removes size-1 dimensions from the end but does not remove the + batch dimension""" + while x.dim() > 1 and x.size(-1) == 1: + x = x.squeeze(-1) + return x.tolist() + + +def _group(data, groups): + """Collects raw data into groups (naive implementation using a simple in-memory + dictionary) + + Arguments + --------- + data : list + a list of numeric 
data + groups : list + a list of group indicators + + Returns + ------- + results : dict + a dictionary with group labels as keys and the corresponding + data as values""" + grouped_data = {} + for item, group in zip(data, groups): + if group not in grouped_data: + grouped_data[group] = [] + grouped_data[group].append(item) + + return { + group: np.array(group_data) + for group, group_data in grouped_data.items() + } diff --git a/tests/recipes/LJSpeech.csv b/tests/recipes/LJSpeech.csv index 2aeabc8f70..15110734a4 100644 --- a/tests/recipes/LJSpeech.csv +++ b/tests/recipes/LJSpeech.csv @@ -8,3 +8,5 @@ TTS,LJSpeech,recipes/LJSpeech/TTS/vocoder/hifi_gan_unit/train.py,recipes/LJSpeec quantization,LJSpeech,recipes/LJSpeech/quantization/train.py,recipes/LJSpeech/quantization/hparams/train_with_hubert.yaml,recipes/LJSpeech/quantization/ljspeech_prepare.py,recipes/LJSpeech/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/TTS --train_json=tests/samples/annotation/TTS_train.json --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]" quantization,LJSpeech,recipes/LJSpeech/quantization/train.py,recipes/LJSpeech/quantization/hparams/train_with_wav2vec.yaml,recipes/LJSpeech/quantization/ljspeech_prepare.py,recipes/LJSpeech/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/TTS --train_json=tests/samples/annotation/TTS_train.json --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]" quantization,LJSpeech,recipes/LJSpeech/quantization/train.py,recipes/LJSpeech/quantization/hparams/train_with_wavlm.yaml,recipes/LJSpeech/quantization/ljspeech_prepare.py,recipes/LJSpeech/quantization/README.md,https://www.dropbox.com/sh/bk5qz0u1ppx15jk/AAAj23FI3AVKtfRKGvyHJYHza?dl=0,,--data_folder=tests/samples/TTS --train_json=tests/samples/annotation/TTS_train.json --skip_prep=True,"file_exists=[log.txt,train.py,env.log,hyperparams.yaml]" +evaluation,LJSpeech,recipes/LJSpeech/evaluation/evaluate.py,recipes/LJSpeech/evaluation/hparams/tacotron2.yaml,recipes/LJSpeech/evaluation/ljspeech_prepare.py,recipes/LJSpeech/evaluation/README.md,,,--data_folder=tests/samples/TTS --train_json=tests/samples/annotation/TTS_train.json --skip_prep=True --eval_dataset train,"file_exists=[mos.csv,asr.csv,log.txt,summary.json]" +evaluation,LJSpeech,recipes/LJSpeech/evaluation/evaluate.py,recipes/LJSpeech/evaluation/hparams/fastspeech2.yaml,recipes/LJSpeech/evaluation/ljspeech_prepare.py,recipes/LJSpeech/evaluation/README.md,,,--data_folder=tests/samples/TTS --train_json=tests/samples/annotation/TTS_train.json --skip_prep=True --eval_dataset train,"file_exists=[mos.csv,asr.csv,log.txt,summary.json]" \ No newline at end of file diff --git a/tests/recipes/SOMOS.csv b/tests/recipes/SOMOS.csv new file mode 100644 index 0000000000..bf6d44187c --- /dev/null +++ b/tests/recipes/SOMOS.csv @@ -0,0 +1,2 @@ +Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks +ttseval,SOMOS,recipes/SOMOS/ttseval/train.py,recipes/SOMOS/ttseval/hparams/train.yaml,recipes/SOMOS/somos_prepare.py,recipes/SOMOS/ttseval/README.md,https://www.dropbox.com/tbd,https://huggingface.co/flexthink/ttseval-wavlm-transformer,--batch_size=2 --number_of_epochs=2 --data_folder=tests/samples/TTS --train_annotation=tests/samples/annotation/TTS_eval_train.csv --valid_annotation=tests/samples/annotation/TTS_eval_train.csv 
--test_annotation=tests/samples/annotation/TTS_eval_train.csv --skip_prep=True --num_workers 0,"file_exists=[train_log.txt,log.txt,env.log,train.py,hyperparams.yaml]" diff --git a/tests/samples/annotation/TTS_eval_train.csv b/tests/samples/annotation/TTS_eval_train.csv new file mode 100644 index 0000000000..99aee5e74f --- /dev/null +++ b/tests/samples/annotation/TTS_eval_train.csv @@ -0,0 +1,2 @@ +ID,wav,system,score +LJ050-0131,$data_root/LJ050-0131.wav,0005,3.95 \ No newline at end of file diff --git a/tests/unittests/test_metrics.py b/tests/unittests/test_metrics.py index 62a7970cd3..7d3f338cd3 100755 --- a/tests/unittests/test_metrics.py +++ b/tests/unittests/test_metrics.py @@ -198,3 +198,27 @@ def test_classification_stats_report(): -> B: 1 / 1 (100.00%) """ assert report == ref_report + + +def test_linear_regression_stats(): + from speechbrain.utils.metric_stats import LinearRegressionStats + + reg_stats = LinearRegressionStats() + reg_stats.append( + ids=["ID1", "ID2"], + predict=torch.tensor([1.25, 2.75]), + target=torch.tensor([1.00, 3.00]), + ) + reg_stats.append( + ids=["ID3", "ID4"], + predict=torch.tensor([5.5, 3.5]), + target=torch.tensor([5.0, 3.0]), + ) + summary = reg_stats.summarize() + assert math.isclose(3.25, summary["scores_mean"], rel_tol=0.01) + assert math.isclose(1.7678, summary["scores_std"], rel_tol=0.01) + assert math.isclose(3.0, summary["targets_mean"], rel_tol=0.01) + assert math.isclose(1.633, summary["targets_std"], rel_tol=0.01) + assert math.isclose(0.9067, summary["slope"], rel_tol=0.01) + assert math.isclose(0.0533, summary["intercept"], rel_tol=0.01) + assert math.isclose(0.9814, summary["pearson_r"], rel_tol=0.01)
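
The following is an illustrative sketch, not part of the patch itself: it shows how the new `LinearRegressionStats` metric might be driven outside a recipe, using made-up utterance IDs, scores, and system labels. The grouped mode mirrors the system-level MOS correlation described in its docstring.

```python
# Illustrative only: tiny, hand-made scores, not real MOS data.
import torch

from speechbrain.utils.metric_stats import LinearRegressionStats

stats = LinearRegressionStats(
    grouped=True,  # correlate per-system means (system-level MOS)
    scores_label="Estimated MOS",
    targets_label="Human MOS",
)

# One append call per evaluation batch: model predictions, human targets,
# and the TTS system each utterance came from.
stats.append(
    ids=["utt1", "utt2", "utt3", "utt4"],
    predict=torch.tensor([3.1, 4.2, 2.8, 4.0]),
    target=torch.tensor([3.0, 4.5, 2.5, 4.1]),
    groups=["sysA", "sysB", "sysA", "sysB"],
)

summary = stats.summarize()
print(summary["pearson_r"], summary["slope"], summary["intercept"])

# Per-utterance raw data can be dumped for later analysis
stats.write_csv("raw_scores.csv")

# Plotting needs seaborn/matplotlib; skip quietly if they are absent
try:
    stats.plot("system_regression.png")
except ImportError:
    pass
```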
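Similarly, a minimal sketch of the new `itemize` helper added in `speechbrain/inference/eval.py`, using synthetic tensors rather than a real evaluator, to show how a batched `SpeechEvaluationResult` is split into per-item results:

```python
# Synthetic batch result; in practice this would come from an evaluator's
# `evaluate` or `evaluate_files` call.
import torch

from speechbrain.inference.eval import SpeechEvaluationResult, itemize

batch_result = SpeechEvaluationResult(
    score=torch.tensor([3.7, 4.1]),
    details={
        "wer": torch.tensor([0.12, 0.05]),
        "pred": ["HELLO WORLD", "GOOD MORNING"],
    },
)

for item in itemize(batch_result):
    # Each item is a SpeechEvaluationResult holding a single sample
    print(item.score, item.details)
```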