diff --git a/examples/asr/conf/ssl/conformer/conformer_ssl.yaml b/examples/asr/conf/ssl/conformer/conformer_ssl.yaml index e3776535ac25..505e3ae528c6 100644 --- a/examples/asr/conf/ssl/conformer/conformer_ssl.yaml +++ b/examples/asr/conf/ssl/conformer/conformer_ssl.yaml @@ -27,8 +27,10 @@ model: train_ds: manifest_filepath: ??? + manifest_speaker_verification_fp: ??? + manifest_content_fp: ??? sample_rate: ${model.sample_rate} - batch_size: 16 # you may increase batch_size if your memory allows + batch_size: 32 # you may increase batch_size if your memory allows shuffle: true num_workers: 8 pin_memory: false diff --git a/examples/tts/conf/ssl_tts.yaml b/examples/tts/conf/ssl_tts.yaml new file mode 100644 index 000000000000..1accbfa1e731 --- /dev/null +++ b/examples/tts/conf/ssl_tts.yaml @@ -0,0 +1,198 @@ +# This config contains the default values for self-supervised pre-training of a Conformer ASR model, large size (~120M). + +# Architecture and training config: +# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective +# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. +# Here are the recommended configs for different variants of Conformer-CTC, other parameters are the same as in this config file. +# One extra layer (compared to original paper) is added to the medium and large variants to compensate for replacing the LSTM decoder with a linear one. +# +# +-------------+---------+---------+----------+------------+-----+ +# | Model | d_model | n_heads | n_layers | time_masks | lr | +# +=============+=========+========+===========+============+=====+ +# | Small (13M)| 176 | 4 | 16 | 5 | 5.0 | +# +-------------+---------+--------+-----------+------------+-----+ +# | Medium (30M)| 256 | 4 | 18 | 5 | 5.0 | +# +-------------+---------+--------+-----------+------------+-----+ +# | Large (121M)| 512 | 8 | 18 | 10 | 2.0 | +# +---------------------------------------------------------------+ +# +# If you do not want to train with AMP, you may use weight decay of 0.0 or reduce the number of time maskings to 2 +# with time_width=100. It may help when you want to train for fewer epochs and need faster convergence. +# With weight_decay=0.0, learning rate may need to get reduced to 2.0. 
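Note: the new `manifest_speaker_verification_fp` and `manifest_content_fp` keys point at NeMo-style JSON-lines manifests. The sketch below shows what one record in each manifest might look like; the paths and transcript are placeholders, and only the integer `speaker` index and the `text: "dummy"` convention are taken from `scripts/speaker_tasks/speaker_json_formatting.py` added later in this patch — the rest follows the usual NeMo manifest fields.

```python
import json

# Hypothetical manifest records, one JSON object per line (NeMo manifest convention).
# The speaker-verification manifest carries an integer `speaker` index produced by
# speaker_json_formatting.py; the content manifest needs real transcripts for the CTC head.
sv_record = {
    "audio_filepath": "/data/vox1/id10001/utt_0001.wav",  # placeholder path
    "duration": 8.2,
    "label": "id10001",
    "speaker": 0,
    "text": "dummy",
}
content_record = {
    "audio_filepath": "/data/LibriSpeech/train-clean-100/103/1240/103-1240-0000.wav",  # placeholder path
    "duration": 14.1,
    "text": "chapter one missus rachel lynde is surprised",
}

with open("train_sv_manifest.json", "w") as f:
    f.write(json.dumps(sv_record) + "\n")
with open("train_content_manifest.json", "w") as f:
    f.write(json.dumps(content_record) + "\n")
```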
+ +name: "Conformer-SSL" +phoneme_dict_path: "scripts/tts_dataset_files/cmudict-0.7b_nv22.01" +heteronyms_path: "scripts/tts_dataset_files/heteronyms-030921" +init_from_pretrained_model: "ssl_en_conformer_large" + +model: + sample_rate: 16000 + combined_loss: true + train_ds: + manifest_filepath: "/home/shehzeenh/datasets/speaker_verification_full/vox1/train_formatted.json" + manifest_speaker_verification_fp: "/home/shehzeenh/datasets/speaker_verification_full/vox1/train_formatted.json" + manifest_content_fp: "/home/shehzeenh/datasets/All_LibriSpeech/train_clean_100.json" + sample_rate: ${model.sample_rate} + batch_size: 16 # you may increase batch_size if your memory allows + shuffle: true + num_workers: 8 + pin_memory: false + use_start_end_token: true + trim_silence: false + max_duration: 16.7 + min_duration: 8.0 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + + validation_ds: + manifest_filepath: "/home/shehzeenh/datasets/speaker_verification_full/vox1/dev_formatted.json" + manifest_speaker_verification_fp: "/home/shehzeenh/datasets/speaker_verification_full/vox1/dev_formatted.json" + manifest_content_fp: "/home/shehzeenh/datasets/All_LibriSpeech/dev_clean.json" + sample_rate: ${model.sample_rate} + batch_size: 16 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + min_duration: 8.0 + + # text_tokenizer: + # _target_: nemo.collections.tts.torch.tts_tokenizers.EnglishPhonemesTokenizer + # punct: true + # stresses: true + # chars: true + # apostrophe: true + # pad_with_space: true + # g2p: + # _target_: nemo.collections.tts.torch.g2ps.EnglishG2p + # phoneme_dict: ${phoneme_dict_path} + # heteronyms: ${heteronyms_path} + # phoneme_probability: 0.5 + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 16 + pad_value: 0.0 + + spec_augment: + _target_: nemo.collections.asr.modules.MaskedPatchAugmentation + freq_masks: 3 + freq_width: 20 + patch_size: 48 + mask_patches: 0.5 + + downstream_heads: + task_names: ['speaker_verification', 'content'] + speaker_embed_size: 256 + num_speakers: 1211 + content_embed_size: 256 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 18 + d_model: 512 + + # Sub-sampling params + subsampling: striding # vggnet or striding, vggnet may give better results but needs more memory + subsampling_factor: 4 # must be power of 2 + subsampling_conv_channels: -1 # -1 sets it to d_model + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's 
params + conv_kernel_size: 31 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + decoder_out: 128 + + + optim_backbone: + _target_: torch.optim.Adam + lr: 1e-5 + # optimizer arguments + # betas: [0.9, 0.98] + sched: + min_lr: 1e-6 + warmup_steps: 1000 + + optim_downstream: + _target_: torch.optim.Adam + lr: 1e-4 + sched: + min_lr: 1e-6 + warmup_steps: 1000 + + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: 1000 + max_steps: null # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + # gradient_clip_val: 1.0 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 10 # Interval of logging. + progress_bar_refresh_rate: 10 + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_loss" + mode: "min" + save_top_k: 5 + + # you need to set these two to True to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/tts/conf/ssl_tts_ngc.yaml b/examples/tts/conf/ssl_tts_ngc.yaml new file mode 100644 index 000000000000..c8b279a1b646 --- /dev/null +++ b/examples/tts/conf/ssl_tts_ngc.yaml @@ -0,0 +1,188 @@ +# This config contains the default values for self-supervised pre-training of a Conformer ASR model, large size (~120M). + +# Architecture and training config: +# Default learning parameters in this config are set for effective batch size of 2K. To train it with smaller effective +# batch sizes, you may need to re-tune the learning parameters or use higher accumulate_grad_batches. +# Here are the recommended configs for different variants of Conformer-CTC, other parameters are the same as in this config file. +# One extra layer (compared to original paper) is added to the medium and large variants to compensate for replacing the LSTM decoder with a linear one. 
+# +# +-------------+---------+---------+----------+------------+-----+ +# | Model | d_model | n_heads | n_layers | time_masks | lr | +# +=============+=========+========+===========+============+=====+ +# | Small (13M)| 176 | 4 | 16 | 5 | 5.0 | +# +-------------+---------+--------+-----------+------------+-----+ +# | Medium (30M)| 256 | 4 | 18 | 5 | 5.0 | +# +-------------+---------+--------+-----------+------------+-----+ +# | Large (121M)| 512 | 8 | 18 | 10 | 2.0 | +# +---------------------------------------------------------------+ +# +# If you do not want to train with AMP, you may use weight decay of 0.0 or reduce the number of time maskings to 2 +# with time_width=100. It may help when you want to train for fewer epochs and need faster convergence. +# With weight_decay=0.0, learning rate may need to get reduced to 2.0. + +name: "Conformer-SSL" +phoneme_dict_path: "scripts/tts_dataset_files/cmudict-0.7b_nv22.01" +heteronyms_path: "scripts/tts_dataset_files/heteronyms-030921" +init_from_pretrained_model: "ssl_en_conformer_large" + +model: + sample_rate: 16000 + + train_ds: + manifest_filepath: "/raid/ssl_tts_tmp/train_vox_ngc.json" + manifest_speaker_verification_fp: "/raid/ssl_tts_tmp/train_vox_ngc.json" + manifest_content_fp: "/raid/ssl_tts_tmp/train_libri_ngc.json" + sample_rate: ${model.sample_rate} + batch_size: 12 # you may increase batch_size if your memory allows + shuffle: true + num_workers: 8 + pin_memory: false + use_start_end_token: true + trim_silence: false + max_duration: 16.7 + min_duration: 8.0 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + + validation_ds: + manifest_filepath: "/raid/ssl_tts_tmp/dev_vox_ngc.json" + manifest_speaker_verification_fp: "/raid/ssl_tts_tmp/dev_vox_ngc.json" + manifest_content_fp: "/raid/ssl_tts_tmp/dev_libri_ngc.json" + sample_rate: ${model.sample_rate} + batch_size: 12 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + min_duration: 8.0 + + # text_tokenizer: + # _target_: nemo.collections.tts.torch.tts_tokenizers.EnglishPhonemesTokenizer + # punct: true + # stresses: true + # chars: true + # apostrophe: true + # pad_with_space: true + # g2p: + # _target_: nemo.collections.tts.torch.g2ps.EnglishG2p + # phoneme_dict: ${phoneme_dict_path} + # heteronyms: ${heteronyms_path} + # phoneme_probability: 0.5 + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 16 + pad_value: 0.0 + + spec_augment: + _target_: nemo.collections.asr.modules.MaskedPatchAugmentation + freq_masks: 3 + freq_width: 20 + patch_size: 48 + mask_patches: 0.5 + + downstream_heads: + task_names: ['speaker_verification', 'content'] + speaker_embed_size: 256 + num_speakers: 1211 + content_embed_size: 256 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 18 + d_model: 512 + + # Sub-sampling params + subsampling: striding # vggnet or striding, vggnet may give better results but needs more memory + subsampling_factor: 4 # must be power of 2 + 
subsampling_conv_channels: -1 # -1 sets it to d_model + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 31 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + decoder_out: 128 + + + optim: + name: adam + lr: 1e-5 + # optimizer arguments + # betas: [0.9, 0.98] + + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: 1000 + max_steps: null # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + # gradient_clip_val: 1.0 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 10 # Interval of logging. + progress_bar_refresh_rate: 10 + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_loss" + mode: "min" + save_top_k: 5 + + # you need to set these two to True to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/tts/sleep_command.py b/examples/tts/sleep_command.py new file mode 100644 index 000000000000..1e7e05d05a2e --- /dev/null +++ b/examples/tts/sleep_command.py @@ -0,0 +1,3 @@ +import time +time.sleep(120*60*60) +print("done sleeping") \ No newline at end of file diff --git a/examples/tts/ssl_tts.py b/examples/tts/ssl_tts.py new file mode 100644 index 000000000000..a38e49a9c322 --- /dev/null +++ b/examples/tts/ssl_tts.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytorch_lightning as pl + +from nemo.collections.common.callbacks import LogEpochTimeCallback +from nemo.collections.tts.models import ssl_tts +from nemo.core.config import hydra_runner +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="ssl_tts") +def main(cfg): + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + model = ssl_tts.SSLDisentangler(cfg=cfg.model, trainer=trainer) + model.maybe_init_from_pretrained_checkpoint(cfg=cfg) + lr_logger = pl.callbacks.LearningRateMonitor() + epoch_time_logger = LogEpochTimeCallback() + trainer.callbacks.extend([lr_logger, epoch_time_logger]) + trainer.fit(model) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/tts/ssl_tts_both.py b/examples/tts/ssl_tts_both.py new file mode 100644 index 000000000000..8fe9e0e79d39 --- /dev/null +++ b/examples/tts/ssl_tts_both.py @@ -0,0 +1,36 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
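Note: the `examples/tts/ssl_tts.py` launcher above is driven by Hydra, but the same config can be loaded and overridden programmatically, which is essentially what `scripts/ssl_tts/eer_eval.py` does further down in this patch. A minimal sketch, assuming the repo-relative config path from this PR and placeholder manifest filenames:

```python
import omegaconf
import pytorch_lightning as pl

from nemo.collections.tts.models import ssl_tts

# Load the config added in this PR and point it at local manifests (placeholder names).
cfg = omegaconf.OmegaConf.load("examples/tts/conf/ssl_tts.yaml")
cfg.model.train_ds.manifest_filepath = "dummy"  # unused by the SV/content loaders; mirrors eer_eval.py
cfg.model.validation_ds.manifest_filepath = "dummy"
cfg.model.train_ds.manifest_speaker_verification_fp = "train_sv_manifest.json"
cfg.model.train_ds.manifest_content_fp = "train_content_manifest.json"
cfg.model.validation_ds.manifest_speaker_verification_fp = "dev_sv_manifest.json"
cfg.model.validation_ds.manifest_content_fp = "dev_content_manifest.json"

# Single-GPU smoke test; the full trainer settings live in the config's `trainer` section.
trainer = pl.Trainer(devices=1, accelerator="auto", max_epochs=1, precision=32)
model = ssl_tts.SSLDisentangler(cfg=cfg.model, trainer=trainer)
trainer.fit(model)
```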
+ +import pytorch_lightning as pl + +from nemo.collections.common.callbacks import LogEpochTimeCallback +from nemo.collections.tts.models import ssl_tts_both +from nemo.core.config import hydra_runner +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="ssl_tts") +def main(cfg): + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + model = ssl_tts_both.SSLDisentangler(cfg=cfg.model, trainer=trainer) + model.maybe_init_from_pretrained_checkpoint(cfg=cfg) + lr_logger = pl.callbacks.LearningRateMonitor() + epoch_time_logger = LogEpochTimeCallback() + trainer.callbacks.extend([lr_logger, epoch_time_logger]) + trainer.fit(model) + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/nemo/collections/tts/models/ssl_tts.py b/nemo/collections/tts/models/ssl_tts.py new file mode 100644 index 000000000000..e68d2eeaae7b --- /dev/null +++ b/nemo/collections/tts/models/ssl_tts.py @@ -0,0 +1,336 @@ +from typing import Dict, Optional, Union +import itertools +import torch +import torch.nn as nn +from pytorch_lightning import Trainer +from pytorch_lightning.trainer.supporters import CombinedLoader +from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger +from nemo.core.classes import ModelPT +from nemo.collections.asr.models import ssl_models +from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss +import nemo.collections.tts.torch.data as TTSData +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from hydra.utils import instantiate +from dataclasses import dataclass +from nemo.collections.tts.torch.tts_tokenizers import BaseTokenizer, EnglishCharsTokenizer, EnglishPhonemesTokenizer +from omegaconf import DictConfig, OmegaConf, open_dict +from nemo.core.optim.lr_scheduler import WarmupPolicy + +def decode(tokenizer, token_list): + return tokenizer.sep.join(tokenizer._id2token[t] for t in token_list) + + +class SSLDisentangler(ModelPT): + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg=cfg, trainer=trainer) + self.preprocessor = SSLDisentangler.from_config_dict(self._cfg.preprocessor) + self.encoder = SSLDisentangler.from_config_dict(self._cfg.encoder) + self._tb_logger = None + + self.downstream_nets = nn.ModuleDict() + for task in self._cfg.downstream_heads.task_names: + + if task == 'speaker_verification': + in_dim = self._cfg.encoder.d_model + out_dim = self._cfg.downstream_heads.speaker_embed_size + num_speakers = self._cfg.downstream_heads.num_speakers + self.downstream_nets[task] = nn.Linear(in_dim,out_dim) + self.sv_linear = nn.Linear(out_dim,num_speakers) + self.sv_loss = AngularSoftmaxLoss(scale=30, margin=0.4) + # self.sv_loss = nn.CrossEntropyLoss() + + elif task == 'content': + in_dim = self._cfg.encoder.d_model + out_dim = self._cfg.downstream_heads.content_embed_size + num_chars = len(self._text_tokenizer.tokens) #list of english tokens + self.downstream_nets[task] = nn.Linear(in_dim,out_dim) + self.content_linear = nn.Linear(out_dim,num_chars) + self.ctc_loss = nn.CTCLoss(blank=self._text_tokenizer.blank) + + self.automatic_optimization = False + + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. 
+ """ + results = [] + + model = PretrainedModelInfo( + pretrained_model_name="ssl_en_conformer_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/ssl_en_conformer_large/versions/1.10.1/files/ssl_en_conformer_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="ssl_en_conformer_xlarge", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_xlarge", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/ssl_en_conformer_xlarge/versions/1.10.0/files/ssl_en_conformer_xlarge.nemo", + ) + results.append(model) + + return results + + @property + def tb_logger(self): + if self._tb_logger is None: + if self.logger is None and self.logger.experiment is None: + return None + tb_logger = self.logger.experiment + if isinstance(self.logger, LoggerCollection): + for logger in self.logger: + if isinstance(logger, TensorBoardLogger): + tb_logger = logger.experiment + break + self._tb_logger = tb_logger + return self._tb_logger + + def __setup_dataloader_from_config(self, data_config): + # _text_tokenizer = instantiate(self._cfg.text_tokenizer) + _text_tokenizer = self._text_tokenizer = EnglishCharsTokenizer(add_blank_at="last") + for task in self._cfg.downstream_heads.task_names: + if task == 'speaker_verification': + sv_dataset = TTSData.TTSDataset( + manifest_filepath=data_config['manifest_speaker_verification_fp'], + sample_rate=16000, + text_tokenizer=_text_tokenizer, + sup_data_types=['speaker_id']) + sv_loader = torch.utils.data.DataLoader( + sv_dataset, + batch_size=data_config['batch_size'], + collate_fn=sv_dataset.general_collate_fn, + shuffle=data_config['shuffle'], + num_workers=data_config.get('num_workers', 0), + pin_memory=data_config.get('pin_memory', False)) + + if task == 'content': + content_dataset = TTSData.TTSDataset( + manifest_filepath=data_config['manifest_content_fp'], + sample_rate=16000, + text_tokenizer=_text_tokenizer, + max_duration=16.7 + ) + content_loader = torch.utils.data.DataLoader( + content_dataset, + batch_size=data_config['batch_size'], + collate_fn=content_dataset.general_collate_fn, + shuffle=data_config['shuffle'], + num_workers=data_config.get('num_workers', 0), + pin_memory=data_config.get('pin_memory', False)) + + loaders = {"sv": sv_loader, "content": content_loader} + return loaders + + + def setup_training_data(self, cfg): + self._train_dl = self.__setup_dataloader_from_config(self._cfg.train_ds) + + def setup_validation_data(self, cfg): + self._validation_dl = CombinedLoader(self.__setup_dataloader_from_config(self._cfg.validation_ds)) + + def configure_optimizers(self): + optim_backbone_config = self._cfg.optim_backbone.copy() + optim_downstream_config = self._cfg.optim_downstream.copy() + + OmegaConf.set_struct(optim_backbone_config, False) + sched_backbone_config = optim_backbone_config.pop("sched", None) + OmegaConf.set_struct(optim_backbone_config, True) + + OmegaConf.set_struct(optim_downstream_config, False) + sched_downstream_config = optim_downstream_config.pop("sched", None) + OmegaConf.set_struct(optim_downstream_config, True) + + optim_backbone = instantiate(optim_backbone_config, params=self.encoder.parameters(),) + optim_downstream = instantiate(optim_downstream_config, params=itertools.chain(self.downstream_nets.parameters(), self.sv_linear.parameters(), 
self.content_linear.parameters(), self.sv_loss.parameters() ),) + + if sched_backbone_config is not None and sched_downstream_config is not None: + + scheduler_backbone = WarmupPolicy( + optimizer=optim_backbone, max_steps=None, min_lr=sched_backbone_config.min_lr, warmup_steps=sched_backbone_config.warmup_steps, + ) # Use warmup to delay start + sch1_dict = { + 'scheduler': scheduler_backbone, + 'interval': 'step', + } + + scheduler_downstream = WarmupPolicy( + optimizer=optim_downstream, max_steps=None, min_lr=sched_downstream_config.min_lr, warmup_steps=sched_downstream_config.warmup_steps, + ) + sch2_dict = { + 'scheduler': scheduler_downstream, + 'interval': 'step', + } + + return [optim_backbone, optim_downstream], [sch1_dict, sch2_dict] + else: + return [optim_backbone, optim_downstream] + + + def forward(self, input_signal=None, input_signal_length=None): + + processed_signal, processed_signal_length = self.preprocessor( + input_signal=input_signal, length=input_signal_length, + ) + + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) #b,c,t + + for task in self._cfg.downstream_heads.task_names: + if task == "speaker_verification": + speaker_embedding = self.downstream_nets['speaker_verification'](encoded[:,:,0]) + l2_norm = torch.norm(speaker_embedding, p=2,dim=-1, keepdim=True) + speaker_embedding_normalized = speaker_embedding/l2_norm + speaker_logits = self.sv_linear(speaker_embedding_normalized) + + elif task == "content": + encoded_btc = encoded.permute(0, 2, 1) + content_embedding = self.downstream_nets['content'](encoded_btc) + content_logits = self.content_linear(content_embedding) + content_log_probs = content_logits.log_softmax(dim=2) + content_log_probs = content_log_probs.permute(1, 0, 2) #t,b,c for ctc + + + return speaker_logits, speaker_embedding_normalized, content_embedding, content_log_probs, encoded_len + + + def training_step(self, batch, batch_idx): + loss = 0.0 + optim_backbone, optim_downstream = self.optimizers() + schedulers = self.lr_schedulers() + + for key in batch.keys(): + if key == 'sv': + signal = batch[key]['audio'] + signal_len = batch[key]['audio_lens'] + speaker_id = batch[key]['speaker_id'] + + sv_logits, sv_emb, _, _, _ = self.forward( + input_signal=signal, input_signal_length=signal_len + ) + + pred_speaker = torch.argmax(sv_logits, dim=1) + + sv_loss = self.sv_loss(logits=sv_logits, labels=speaker_id) + loss += sv_loss + if not self._cfg.combined_loss: + optim_backbone.zero_grad() + optim_downstream.zero_grad() + self.manual_backward(sv_loss) + optim_backbone.step() + optim_downstream.step() + + correct = pred_speaker.eq(speaker_id.data.view_as(pred_speaker)).sum().item() + acc = (correct/len(speaker_id))*100 + + self.log("t_sv_loss", sv_loss.item()) + self.log("t_sv_accuracy", acc) + + elif key == "content": + signal = batch[key]['audio'] + signal_len = batch[key]['audio_lens'] + target = batch[key]['text'] # (B, T) + target_len = batch[key]['text_lens'] + + _, _, content_embedding, content_log_probs, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len + ) + + ctc_loss = self.ctc_loss(content_log_probs, target, encoded_len, target_len) + loss += ctc_loss + + if not self._cfg.combined_loss: + optim_backbone.zero_grad() + optim_downstream.zero_grad() + self.manual_backward(ctc_loss) + optim_backbone.step() + optim_downstream.step() + + self.log("t_content_loss", ctc_loss.item()) + + + if self._cfg.combined_loss: + optim_backbone.zero_grad() + 
optim_downstream.zero_grad() + self.manual_backward(loss) + optim_backbone.step() + optim_downstream.step() + + if schedulers is not None: + sch1, sch2 = schedulers + sch1.step() + sch2.step() + + if self.trainer.global_step % 10 == 0: + # self.log("lr backbone", optim_backbone.param_groups[0]['lr'] ) + # self.log("lr downstream", optim_downstream.param_groups[0]['lr'] ) + # self.log("t_loss", loss) + print ("Loss", loss.item()) + # print ("lr backbone", optim_backbone.param_groups[0]['lr']) + # print ("lr down", optim_downstream.param_groups[0]['lr']) + + # return {'loss': loss} + + def validation_step(self, batch, batch_idx): + + loss_total = 0 + for key in batch.keys(): + if key == 'sv': + signal = batch[key]['audio'] + signal_len = batch[key]['audio_lens'] + speaker_id = batch[key]['speaker_id'] + sv_logits, sv_emb, _, _, _ = self.forward( + input_signal=signal, input_signal_length=signal_len + ) + + pred_speaker = torch.argmax(sv_logits, dim=1) + sv_loss = self.sv_loss(logits=sv_logits, labels=speaker_id) + loss_total += sv_loss + + correct = pred_speaker.eq(speaker_id.data.view_as(pred_speaker)).sum().item() + acc = (correct/len(speaker_id))*100 + acc_val = torch.as_tensor(acc) + + if key == 'content': + signal = batch[key]['audio'] + signal_len = batch[key]['audio_lens'] + target = batch[key]['text'] # (B, T) + target_len = batch[key]['text_lens'] + + _, _, content_embedding, content_log_probs, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len + ) + + ctc_loss = self.ctc_loss(content_log_probs, target, encoded_len, target_len) + loss_total += ctc_loss + pred_char_batch = torch.argmax(content_log_probs, dim=2) + pred_char_batch = pred_char_batch.permute(1,0) + pred_char = decode(self._text_tokenizer, pred_char_batch[0].tolist() ) + target_char = decode(self._text_tokenizer, target[0].tolist() ) + + + return { + 'val_loss': loss_total.cpu(), + 'sv_loss' : sv_loss.cpu(), + 'ctc_loss' : ctc_loss.cpu(), + 'accuracy_sv': acc_val.cpu() + } + + def validation_epoch_end(self, outputs): + collect = lambda key: torch.stack([x[key] for x in outputs]).mean() + val_loss = collect("val_loss") + val_sv_loss = collect("sv_loss") + val_ctc_loss = collect("ctc_loss") + accuracy_sv = collect("accuracy_sv") + self.log("val_loss", val_loss) + self.log("sv_loss", val_sv_loss) + self.log("ctc_loss", val_ctc_loss) + self.log("accuracy_sv", accuracy_sv) + + + + diff --git a/nemo/collections/tts/models/ssl_tts_both.py b/nemo/collections/tts/models/ssl_tts_both.py new file mode 100644 index 000000000000..04fa020475a2 --- /dev/null +++ b/nemo/collections/tts/models/ssl_tts_both.py @@ -0,0 +1,262 @@ +from typing import Dict, Optional, Union + +import torch +import torch.nn as nn +from omegaconf import DictConfig +from pytorch_lightning import Trainer +from pytorch_lightning.trainer.supporters import CombinedLoader +from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger +from nemo.core.classes import ModelPT +from nemo.collections.asr.models import ssl_models +from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss +import nemo.collections.tts.torch.data as TTSData +from nemo.core.classes.common import PretrainedModelInfo, typecheck +from hydra.utils import instantiate +from dataclasses import dataclass +from nemo.collections.tts.torch.tts_tokenizers import BaseTokenizer, EnglishCharsTokenizer, EnglishPhonemesTokenizer + + +def decode(tokenizer, token_list): + return tokenizer.sep.join(tokenizer._id2token[t] for t in token_list) + + +class 
SSLDisentangler(ModelPT): + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg=cfg, trainer=trainer) + self.preprocessor = SSLDisentangler.from_config_dict(self._cfg.preprocessor) + self.encoder = SSLDisentangler.from_config_dict(self._cfg.encoder) + self._tb_logger = None + + self.downstream_nets = nn.ModuleDict() + for task in self._cfg.downstream_heads.task_names: + + if task == 'speaker_verification': + in_dim = self._cfg.encoder.d_model + out_dim = self._cfg.downstream_heads.speaker_embed_size + num_speakers = self._cfg.downstream_heads.num_speakers + self.downstream_nets[task] = nn.Linear(in_dim,out_dim) + self.sv_linear = nn.Linear(out_dim,num_speakers) + self.sv_loss = AngularSoftmaxLoss(scale=30, margin=0.4) + # self.sv_loss = nn.CrossEntropyLoss() + + elif task == 'content': + in_dim = self._cfg.encoder.d_model + out_dim = self._cfg.downstream_heads.content_embed_size + num_chars = len(self._text_tokenizer.tokens) #list of english tokens + self.downstream_nets[task] = nn.Linear(in_dim,out_dim) + self.content_linear = nn.Linear(out_dim,num_chars) + self.ctc_loss = nn.CTCLoss(blank=self._text_tokenizer.blank) + + + @classmethod + def list_available_models(cls) -> Optional[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + results = [] + + model = PretrainedModelInfo( + pretrained_model_name="ssl_en_conformer_large", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_large", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/ssl_en_conformer_large/versions/1.10.1/files/ssl_en_conformer_large.nemo", + ) + results.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="ssl_en_conformer_xlarge", + description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_xlarge", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/ssl_en_conformer_xlarge/versions/1.10.0/files/ssl_en_conformer_xlarge.nemo", + ) + results.append(model) + + return results + + @property + def tb_logger(self): + if self._tb_logger is None: + if self.logger is None and self.logger.experiment is None: + return None + tb_logger = self.logger.experiment + if isinstance(self.logger, LoggerCollection): + for logger in self.logger: + if isinstance(logger, TensorBoardLogger): + tb_logger = logger.experiment + break + self._tb_logger = tb_logger + return self._tb_logger + + def __setup_dataloader_from_config(self, data_config): + # _text_tokenizer = instantiate(self._cfg.text_tokenizer) + _text_tokenizer = self._text_tokenizer = EnglishCharsTokenizer(add_blank_at="last") + for task in self._cfg.downstream_heads.task_names: + if task == 'speaker_verification': + sv_dataset = TTSData.TTSDataset( + manifest_filepath=data_config['manifest_speaker_verification_fp'], + sample_rate=16000, + text_tokenizer=_text_tokenizer, + max_duration=16.7, + sup_data_types=['speaker_id']) + sv_loader = torch.utils.data.DataLoader( + sv_dataset, + batch_size=data_config['batch_size'], + collate_fn=sv_dataset.general_collate_fn, + shuffle=data_config['shuffle'], + num_workers=data_config.get('num_workers', 0), + pin_memory=data_config.get('pin_memory', False)) + + if task == 'content': + content_dataset = TTSData.TTSDataset( + manifest_filepath=data_config['manifest_content_fp'], + sample_rate=16000, + 
text_tokenizer=_text_tokenizer, + max_duration=16.7) + content_loader = torch.utils.data.DataLoader( + content_dataset, + batch_size=data_config['batch_size'], + collate_fn=content_dataset.general_collate_fn, + shuffle=data_config['shuffle'], + num_workers=data_config.get('num_workers', 0), + pin_memory=data_config.get('pin_memory', False)) + + loaders = {"sv": sv_loader, "content": content_loader} + return loaders + + + def setup_training_data(self, cfg): + self._train_dl = self.__setup_dataloader_from_config(self._cfg.train_ds) + + def setup_validation_data(self, cfg): + self._validation_dl = CombinedLoader(self.__setup_dataloader_from_config(self._cfg.validation_ds)) + + def forward(self, input_signal=None, input_signal_length=None): + + processed_signal, processed_signal_length = self.preprocessor( + input_signal=input_signal, length=input_signal_length, + ) + + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) #b,c,t + + for task in self._cfg.downstream_heads.task_names: + if task == "speaker_verification": + speaker_embedding = self.downstream_nets['speaker_verification'](encoded[:,:,0]) + l2_norm = torch.norm(speaker_embedding, p=2,dim=-1, keepdim=True) + speaker_embedding_normalized = speaker_embedding/l2_norm + speaker_logits = self.sv_linear(speaker_embedding_normalized) + + elif task == "content": + encoded_btc = encoded.permute(0, 2, 1) + content_embedding = self.downstream_nets['content'](encoded_btc) + content_logits = self.content_linear(content_embedding) + content_log_probs = content_logits.log_softmax(dim=2) + content_log_probs = content_log_probs.permute(1, 0, 2) #t,b,c for ctc + + + return speaker_logits, speaker_embedding_normalized, content_embedding, content_log_probs, encoded_len + + + def training_step(self, batch, batch_idx): + + loss = 0.0 + for key in batch.keys(): + if key == 'sv': + signal = batch[key]['audio'] + signal_len = batch[key]['audio_lens'] + speaker_id = batch[key]['speaker_id'] + sv_logits, sv_emb, _, _, _ = self.forward( + input_signal=signal, input_signal_length=signal_len + ) + + pred_speaker = torch.argmax(sv_logits, dim=1) + + sv_loss = self.sv_loss(logits=sv_logits, labels=speaker_id) + loss += sv_loss + + correct = pred_speaker.eq(speaker_id.data.view_as(pred_speaker)).sum().item() + acc = (correct/len(speaker_id))*100 + + self.log("t_sv_loss", sv_loss) + self.log("t_sv_accuracy", acc) + + elif key == "content": + signal = batch[key]['audio'] + signal_len = batch[key]['audio_lens'] + target = batch[key]['text'] # (B, T) + target_len = batch[key]['text_lens'] + + _, _, content_embedding, content_log_probs, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len + ) + + ctc_loss = self.ctc_loss(content_log_probs, target, encoded_len, target_len) + loss += ctc_loss + + self.log("t_content_loss", ctc_loss) + + + self.log("t_loss", loss) + + return {'loss': loss} + + def validation_step(self, batch, batch_idx): + + loss_total = 0 + for key in batch.keys(): + if key == 'sv': + signal = batch[key]['audio'] + signal_len = batch[key]['audio_lens'] + speaker_id = batch[key]['speaker_id'] + sv_logits, sv_emb, _, _, _ = self.forward( + input_signal=signal, input_signal_length=signal_len + ) + + pred_speaker = torch.argmax(sv_logits, dim=1) + sv_loss = self.sv_loss(logits=sv_logits, labels=speaker_id) + loss_total += sv_loss + + correct = pred_speaker.eq(speaker_id.data.view_as(pred_speaker)).sum().item() + acc = (correct/len(speaker_id))*100 + acc_val = torch.as_tensor(acc) + + 
if key == 'content': + signal = batch[key]['audio'] + signal_len = batch[key]['audio_lens'] + target = batch[key]['text'] # (B, T) + target_len = batch[key]['text_lens'] + + _, _, content_embedding, content_log_probs, encoded_len = self.forward( + input_signal=signal, input_signal_length=signal_len + ) + + ctc_loss = self.ctc_loss(content_log_probs, target, encoded_len, target_len) + loss_total += ctc_loss + pred_char_batch = torch.argmax(content_log_probs, dim=2) + pred_char_batch = pred_char_batch.permute(1,0) + pred_char = decode(self._text_tokenizer, pred_char_batch[0].tolist() ) + target_char = decode(self._text_tokenizer, target[0].tolist() ) + + + return { + 'val_loss': loss_total, + 'sv_loss' : sv_loss, + 'ctc_loss' : ctc_loss, + 'accuracy_sv': acc_val + } + + def validation_epoch_end(self, outputs): + collect = lambda key: torch.stack([x[key] for x in outputs]).mean() + val_loss = collect("val_loss") + val_sv_loss = collect("sv_loss") + val_ctc_loss = collect("ctc_loss") + accuracy_sv = collect("accuracy_sv") + self.log("val_loss", val_loss) + self.log("sv_loss", val_sv_loss) + self.log("ctc_loss", val_ctc_loss) + self.log("accuracy_sv", accuracy_sv) + + + + diff --git a/nemo/core/optim/lr_scheduler.py b/nemo/core/optim/lr_scheduler.py index 92da9938703f..a3fe96ca8e3a 100644 --- a/nemo/core/optim/lr_scheduler.py +++ b/nemo/core/optim/lr_scheduler.py @@ -71,7 +71,7 @@ def get_lr(self): if step <= self.warmup_steps and self.warmup_steps > 0: return self._get_warmup_lr(step) - if step > self.max_steps: + if self.max_steps is not None and step > self.max_steps: return [self.min_lr for _ in self.base_lrs] return self._get_lr(step) @@ -789,19 +789,24 @@ def prepare_lr_scheduler( num_workers = scheduler_config.get('t_num_workers') # Compute effective num max_steps - num_samples = len(train_dataloader.dataset) + if isinstance(train_dataloader, dict): + _train_dataloader = train_dataloader[list(train_dataloader.keys())[0]] + else: + _train_dataloader = train_dataloader + + num_samples = len(_train_dataloader.dataset) # TODO: not sure if this will be the correct LR schedule for Megatron # we may need to override ModelPT setup_optimization - if train_dataloader.batch_size is not None: - batch_size = train_dataloader.batch_size - elif hasattr(train_dataloader, 'batch_sampler') and train_dataloader.batch_sampler is not None: - if train_dataloader.batch_sampler.micro_batch_size is not None: - batch_size = train_dataloader.batch_sampler.micro_batch_size + if _train_dataloader.batch_size is not None: + batch_size = _train_dataloader.batch_size + elif hasattr(_train_dataloader, 'batch_sampler') and _train_dataloader.batch_sampler is not None: + if _train_dataloader.batch_sampler.micro_batch_size is not None: + batch_size = _train_dataloader.batch_sampler.micro_batch_size else: - raise ValueError(f'Could not find batch_size from batch_sampler: {train_dataloader.batch_sampler}') + raise ValueError(f'Could not find batch_size from batch_sampler: {_train_dataloader.batch_sampler}') else: - raise ValueError(f'Could not find batch_size from train_dataloader: {train_dataloader}') - drop_last = train_dataloader.drop_last + raise ValueError(f'Could not find batch_size from train_dataloader: {_train_dataloader}') + drop_last = _train_dataloader.drop_last max_steps = compute_max_steps( max_epochs=max_epochs, diff --git a/scripts/speaker_tasks/speaker_json_formatting.py b/scripts/speaker_tasks/speaker_json_formatting.py new file mode 100644 index 000000000000..c89ce535853e --- /dev/null +++ 
b/scripts/speaker_tasks/speaker_json_formatting.py @@ -0,0 +1,36 @@ +import json + +def format_json(source_json_fp, out_json_fp, speaker_mapping=None): + if speaker_mapping is None: + speaker_mapping = {} + spk_idx = 0 + + all_records = [] + with open(source_json_fp, 'r') as f: + lines = f.readlines() + for line in lines: + record = json.loads(line) + if record['label'] in speaker_mapping: + record['speaker'] = speaker_mapping[record['label']] + else: + record['speaker'] = spk_idx + speaker_mapping[record['label']] = spk_idx + spk_idx += 1 + record['text'] = "dummy" + all_records.append(record) + + with open(out_json_fp, 'w') as f: + out_str = "" + for record in all_records: + out_str += json.dumps(record) + '\n' + out_str = out_str[:-1] + + f.write(out_str) + + return speaker_mapping +if __name__ == '__main__': + speaker_mapping = format_json("/raid/datasets/train.json", "/raid/datasets/train_formatted.json") + format_json("/raid/datasets/dev.json", "/raid/datasets/dev_formatted.json", speaker_mapping) + + # speaker_mapping = format_json("/home/shehzeenh/train.json", "/home/shehzeenh/train_formatted.json") + # format_json("/home/shehzeenh/dev.json", "/home/shehzeenh/dev_formatted.json", speaker_mapping) \ No newline at end of file diff --git a/scripts/ssl_tts/eer_eval.py b/scripts/ssl_tts/eer_eval.py new file mode 100644 index 000000000000..bbbe658f9df8 --- /dev/null +++ b/scripts/ssl_tts/eer_eval.py @@ -0,0 +1,89 @@ +from nemo.collections.tts.models import ssl_tts +import omegaconf +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +import os +import json +from sklearn.metrics import roc_curve, auc +import numpy as np + + +def sv_emb_from_audio(audio_path,ssl_model): + wav = wav_featurizer.process(audio_path) + audio_signal = wav[None].cuda() + audio_signal_length = torch.tensor( [ wav.shape[0] ]).cuda() + ssl_model_cuda = ssl_model.cuda() + emb = ssl_model_cuda(audio_signal,audio_signal_length) + + return emb[1] + +def get_similarity(audio_path_1, audio_path_2,ssl_model): +# audio_path_1 = "/home/shehzeenh/datasets/speaker_verification_full/vox1/segments/id10986/KH-yJAsKo1Q/00019_0_4.wav" +# audio_path_2 = "/home/shehzeenh/datasets/speaker_verification_full/vox1/segments/id10986/KH-yJAsKo1Q/00037_0_4.wav" + sv_emb1 = sv_emb_from_audio(audio_path_1,ssl_model) + sv_emb2 = sv_emb_from_audio(audio_path_2,ssl_model) + similarity = F.cosine_similarity(sv_emb1,sv_emb2) + similarity = similarity.item() + # print(similarity) + return similarity + +def get_checkpoint(folder_path): + + ckpt_path = None + for filename in os.listdir(folder_path): + if filename.endswith('last.ckpt'): + ckpt_path = os.path.join(folder_path, filename) + + return ckpt_path + +path = "/home/shehzeenh/nemo_local/NeMo/examples/tts/conf/ssl_tts.yaml" +cfg = omegaconf.OmegaConf.load(path) +cfg.model.train_ds.manifest_filepath = "dummy" +cfg.model.validation_ds.manifest_filepath = "dummy" +ssl_model = ssl_tts.SSLDisentangler(cfg=cfg.model) +cfg.pop('init_from_pretrained_model') + +# ckpt_path = '/home/shehzeenh/nemo_local/NeMo/examples/tts/nemo_experiments/Conformer-SSL/2022-07-13_10-56-22/checkpoints/Conformer-SSL--val_loss=1.1792-epoch=11-last.ckpt' +ckpt_path = get_checkpoint('/home/shehzeenh/nemo_local/NeMo/examples/tts/nemo_experiments/Conformer-SSL/2022-07-13_10-56-22/checkpoints/') +print ("CKPT PATH", ckpt_path) +cfg.init_from_ptl_ckpt = ckpt_path +ssl_model.maybe_init_from_pretrained_checkpoint(cfg=cfg) 
+wav_featurizer = WaveformFeaturizer(sample_rate=16000, int_values=False, augmentor=None) + +y_score = [] +y_true = [] + +# with open('/home/shehzeenh/datasets/speaker_verification_full/vox1_test/validation_test_pairs.txt') as f: +with open('/home/shehzeenh/datasets/speaker_verification_full/vox_o_trial.txt') as f: + lines = f.readlines() # list containing lines of file + + for line in lines: + line = line.strip() # remove leading/trailing white spaces +# print(line) + label, wav1, wav2 = line.split(' ') + # print(label, wav1, wav2) + + wav1_path = '/home/shehzeenh/datasets/speaker_verification_full/test_wav/'+ str(wav1) + wav2_path = '/home/shehzeenh/datasets/speaker_verification_full/test_wav/'+ str(wav2) + + ssl_model.eval() + with torch.no_grad(): + sim_score = get_similarity(wav1_path,wav2_path,ssl_model) + + y_score.append(sim_score) + y_true.append(int(label)) + + + +fpr, tpr, thresholds = roc_curve(y_true, y_score) +_auc = auc(fpr, tpr) +fnr = 1 - tpr +eer_threshold = thresholds[np.nanargmin(np.absolute((fnr - fpr)))] +eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] +eer_verify = fnr[np.nanargmin(np.absolute((fnr - fpr)))] + +assert abs(eer - eer_verify) < 1.0 +print ("eer", eer) +print ("auc", _auc) \ No newline at end of file
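Note: the EER recipe used in `eer_eval.py` above (the FPR/FNR crossing point from sklearn's `roc_curve`) can be sanity-checked on toy scores. The snippet below uses made-up cosine-similarity values and is independent of any checkpoint or trial file:

```python
import numpy as np
from sklearn.metrics import roc_curve

# Toy scores: label 1 means "same speaker", and higher similarity should favor label 1.
y_true = [1, 1, 1, 0, 0, 0, 1, 0]
y_score = [0.82, 0.91, 0.40, 0.35, 0.55, 0.10, 0.77, 0.60]

fpr, tpr, thresholds = roc_curve(y_true, y_score)
fnr = 1 - tpr
idx = np.nanargmin(np.abs(fnr - fpr))     # operating point where FPR ~= FNR
eer = (fpr[idx] + fnr[idx]) / 2           # averaging the two rates; eer_eval.py reports fpr[idx]
print(f"EER ~ {eer:.3f} at threshold {thresholds[idx]:.2f}")
```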