Snowdar/asv-subtools

ASV-Subtools: An Open Source Toolkit for Speaker Recognition

ASV-Subtools is built on Pytorch and Kaldi for tasks such as speaker recognition and language identification.
The 'sub' in 'subtools' means that the toolkit consists of many modular tools, and these parts make up the whole.

Copyright: [TalentedSoft-XMU Speech Lab] XMU Speech Lab (Xiamen University, China) TalentedSoft (TalentedSoft, China) Apache 2.0

Authors: Miao Zhao (Email: snowdar@stu.xmu.edu.cn), Jianfeng Zhou, Zheng Li, Hao Lu, Fuchuan Tong, Dexin Liao, Tao Jiang
Current Maintainer: Tao Jiang (Email: sssyousen@163.com)
Co-author: Lin Li, Qingyang Hong

Citation:

@inproceedings{tong2021asv,
  title={{ASV-Subtools}: {Open} Source Toolkit for Automatic Speaker Verification},
  author={Tong, Fuchuan and Zhao, Miao and Zhou, Jianfeng and Lu, Hao and Li, Zheng and Li, Lin and Hong, Qingyang},
  booktitle={ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  pages={6184--6188},
  year={2021},
  organization={IEEE}
}


Introduction

In ASV-Subtools, Kaldi is used to extract acoustic features and for back-end scoring, while Pytorch is used to build models freely and train them in a custom style.

The project structure, training framework and data pipeline shown below give some insight into how ASV-Subtools is organized.

If you cannot see the pictures on GitHub, try checking the DNS settings of your network or using a VPN. Students of XMU may find the campus VPN helpful for this kind of problem (see https://vpn.xmu.edu.cn for configuration). As a last resort, you can clone ASV-Subtools to your local machine.

Project Structure

ASV-Subtools contains three main branches:

  • Basic Shell Scripts: data processing, back-end scoring (most are based on Kaldi)
  • Kaldi: training of basic model (i-vector, TDNN, F-TDNN and multi-task learning x-vector)
  • Pytorch: training of custom models (fewer limitations)


For the Pytorch branch, there are two important concepts:

  • Model Blueprint: the path of your_model.py
  • Model Creation: the code that initializes a model class, such as resnet(40, 1211, loss="AM")

In ASV-Subtools, each model is independent: to use a model in the training or testing module, we only need to know the path of its model.py and how to initialize its model class. This design avoids frequently modifying the code of static modules. For example, if the embedding extractor were written as a fixed program that imports a fixed model from a fixed model.py with an inline statement such as from my_model_py import my_model, it would not work for model_2.py, model_3.py and so on.
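The decoupling described above can be sketched in plain Python: a generic loader receives the blueprint path and the creation string and never hard-codes an import. This is a minimal sketch with a hypothetical helper named `load_model`; it is not the ASV-Subtools implementation, and the toy `resnet` class only stands in for a real model.

```python
import importlib.util
import os
import tempfile


def load_model(blueprint_path, model_creation):
    """Hypothetical helper: import the module at blueprint_path and evaluate
    the creation expression (e.g. 'resnet(40, 1211, loss="AM")') against it."""
    spec = importlib.util.spec_from_file_location("blueprint", blueprint_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the blueprint file
    # Only the names defined in the blueprint are visible to the expression.
    return eval(model_creation, {"__builtins__": {}}, vars(module))


# Toy blueprint standing in for a real your_model.py.
TOY_BLUEPRINT = '''
class resnet:
    def __init__(self, depth, num_targets, loss="softmax"):
        self.depth = depth
        self.num_targets = num_targets
        self.loss = loss
'''

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "my_model.py")
    with open(path, "w") as f:
        f.write(TOY_BLUEPRINT)
    model = load_model(path, 'resnet(40, 1211, loss="AM")')

print(model.depth, model.num_targets, model.loss)  # 40 1211 AM
```

Switching to model_2.py then only means passing a different path and creation string; the loader itself is untouched. Because `eval` executes the creation string, it should only be fed strings you trust.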

Note that all models (torch.nn.Module) should inherit the libs.nnet.framework.TopVirtualNnet class to get some default functions, such as automatically saving the model creation and blueprint, extracting the embedding of a whole utterance, step-training and computing accuracy. An original Pytorch model is easily converted into an ASV-Subtools model by inheriting this class; just modify your model.py following this x-vector example.
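The "auto-saving model creation" behavior can be illustrated with a tiny base class that records the constructor arguments, so a creation string can be written out next to a checkpoint. This is a plain-Python sketch of the idea only, using a hypothetical base class `CreationRecorder`; it is not how `TopVirtualNnet` is implemented, and a real ASV-Subtools model inherits `TopVirtualNnet` instead.

```python
class CreationRecorder:
    """Hypothetical base class: remembers how each instance was constructed."""

    def __new__(cls, *args, **kwargs):
        obj = super().__new__(cls)
        # Store the arguments before __init__ runs, so subclasses need no extra code.
        obj._creation_args = (args, kwargs)
        return obj

    def model_creation(self):
        """Rebuild a creation string such as "resnet(40, 1211, loss='AM')"."""
        args, kwargs = self._creation_args
        parts = [repr(a) for a in args]
        parts += [f"{key}={value!r}" for key, value in kwargs.items()]
        return f"{type(self).__name__}({', '.join(parts)})"


class resnet(CreationRecorder):
    def __init__(self, depth, num_targets, loss="softmax"):
        self.depth = depth
        self.num_targets = num_targets
        self.loss = loss


model = resnet(40, 1211, loss="AM")
print(model.model_creation())  # resnet(40, 1211, loss='AM')
```

The subclass only defines its own `__init__`; the recording happens in the inherited `__new__`, which is the kind of "free" default behavior that inheriting a common base class provides.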

Training Framework

The basic training framework is shown here. The relations between the modules are clear, so it is not complicated to change anything when you want a custom ASV-Subtools.

Note that libs/support/utils.py contains many common functions, so it is imported by most *.py files.



Data Pipeline

Here, a data pipeline is given to show the relation between Kaldi and Pytorch. There are only two interfaces, reading acoustic features and writing x-vectors, and both are implemented with kaldi_io.

This data pipeline also illustrates the basic principle of x-vector-based speaker recognition.
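The step that lets an x-vector network turn a variable-length sequence of frame-level features into a fixed-dimensional embedding is a pooling layer over time. The sketch below shows plain statistics pooling (per-dimension mean followed by per-dimension standard deviation) in pure Python; it assumes features are given as a list of frames and omits the learned frame-level and segment-level layers of a real x-vector model.

```python
import math


def statistics_pooling(frames):
    """Pool a (num_frames x dim) feature matrix into a 2*dim vector:
    per-dimension mean followed by per-dimension standard deviation."""
    num_frames = len(frames)
    dim = len(frames[0])
    means = [sum(frame[d] for frame in frames) / num_frames for d in range(dim)]
    stds = [
        math.sqrt(sum((frame[d] - means[d]) ** 2 for frame in frames) / num_frames)
        for d in range(dim)
    ]
    return means + stds


# Utterances of different lengths map to vectors of the same size.
short_utt = [[1.0, 2.0], [3.0, 2.0]]
long_utt = [[0.0, 1.0], [2.0, 1.0], [4.0, 1.0], [6.0, 1.0]]
print(statistics_pooling(short_utt))      # [2.0, 2.0, 1.0, 0.0]
print(len(statistics_pooling(long_utt)))  # 4
```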



Update Pipeline

  • 20221113
  • 20220707
    • Online datasets are implemented (including online feature extraction, online VAD, online augmentation and online x-vector extraction)
    • Support for mixed precision training.
    • Runtime module for exporting JIT models.
    • Updated some models.
    • Feature Decomposition and Cosine Similar Adversarial Learning (FD-AL)

Support List

Ready to Start

1. Install Kaldi

Pytorch training does not depend much on Kaldi, but we do not currently provide other interfaces to connect acoustic features to the training module. So if you don't want to use Kaldi, you could change the libs.egs.egs.ChunkEgs class, where the features are passed to Pytorch only through torch.utils.data.Dataset. You should also change the interface for extracting x-vectors after training. Note that most scripts which require Kaldi, such as subtools/makeFeatures.sh and subtools/augmentDataByNoise.sh, will not be available in this case.
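If you replace the Kaldi interface, the replacement essentially has to behave like a map-style dataset that serves fixed-length feature chunks. The sketch below assumes features are already in memory as lists of frames with integer speaker labels; the class name `ChunkDatasetSketch` is hypothetical and this is not the interface of libs.egs.egs.ChunkEgs. It is written without torch so it runs standalone; in practice you would subclass torch.utils.data.Dataset, which expects `__getitem__` and, for the default sampler, `__len__`.

```python
import random


class ChunkDatasetSketch:
    """Hypothetical map-style dataset: each item is a random fixed-length
    chunk of one utterance plus its speaker label."""

    def __init__(self, utterances, labels, chunk_size, seed=0):
        # utterances: list of feature matrices (lists of frames); labels: ints.
        self.utterances = utterances
        self.labels = labels
        self.chunk_size = chunk_size
        self.rng = random.Random(seed)

    def __len__(self):
        return len(self.utterances)

    def __getitem__(self, index):
        frames = self.utterances[index]
        if len(frames) < self.chunk_size:
            # Pad short utterances by repeating them until they are long enough.
            repeats = -(-self.chunk_size // len(frames))  # ceiling division
            frames = frames * repeats
        start = self.rng.randint(0, len(frames) - self.chunk_size)
        return frames[start:start + self.chunk_size], self.labels[index]


feats = [[[float(t)] for t in range(10)], [[float(t)] for t in range(3)]]
dataset = ChunkDatasetSketch(feats, labels=[0, 1], chunk_size=4)
chunk, label = dataset[1]
print(len(chunk), label)  # 4 1
```

Random chunk positions act as a light form of augmentation; repeating short utterances is only one of several padding choices.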

If you prefer to use Kaldi, first install it following http://www.kaldi-asr.org/doc/install.html.

Here is a summary of the steps:

# Download Kaldi
git clone https://github.com/kaldi-asr/kaldi.git kaldi --origin upstream
cd kaldi

# Check the INSTALL file in the current directory for more installation details
cat INSTALL

# Compile the tools first
cd tools
bash extras/check_dependencies.sh
make -j 4

# Configure src before compiling
cd ../src
./configure --shared

# Compute dependencies and compile
make depend -j 4
make -j 4
cd ..

2. Create Project

Create your project as a 4-level path counted from the Kaldi root directory (level 1), such as kaldi/egs/xmuspeech/sre. This layout matters for the project environment; for more details, see subtools/path.sh.

# Suppose the current directory is the parent directory of the Kaldi root
mkdir -p kaldi/egs/xmuspeech/sre
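One plausible reason the depth matters is that a script inside the project can locate the Kaldi root by walking up three directories. The sketch below illustrates that relation with pathlib using a hypothetical helper `kaldi_root_of`; it is an assumption about the convention, not a transcription of subtools/path.sh.

```python
from pathlib import PurePosixPath


def kaldi_root_of(project_dir):
    """Hypothetical helper: return the Kaldi root of a project located at
    <kaldi>/egs/<group>/<name>, i.e. the directory three levels above it."""
    project = PurePosixPath(project_dir)
    if len(project.parents) < 3:
        raise ValueError(f"{project_dir} is not nested deep enough")
    return project.parents[2]


print(kaldi_root_of("/home/user/kaldi/egs/xmuspeech/sre"))  # /home/user/kaldi
```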

3. Clone ASV-Subtools

ASV-Subtools can be seen as a set of tools like Kaldi's 'utils' or 'steps', so only two extra steps are needed to complete the installation:

  • Clone ASV-Subtools into your project.
  • Install the Python requirements (Python 3 is recommended).

Here is how to clone ASV-Subtools from GitHub:

# Clone asv-subtools from github
cd kaldi/egs/xmuspeech/sre
git clone https://github.com/Snowdar/asv-subtools.git subtools

4. Install Python Requirements

  • Pytorch>=1.10:
    conda create -n subtools python=3.8
    conda activate subtools
    conda install pytorch=1.10.0 torchaudio=0.10.0 cudatoolkit=11.1 -c pytorch -c conda-forge
  • Other requirements: numpy, thop, pandas, progressbar2, matplotlib, scipy (optional), sklearn (optional)
    # progressbar must be installed before progressbar2
    pip3 install progressbar
    pip3 install progressbar2
    pip3 install -r subtools/requirements.txt

5. Support Multi-GPU Training

ASV-Subtools provides both DDP (recommended) and Horovod solutions to support multi-GPU training.

For how to use multi-GPU training, see subtools/pytorch/launcher/runSnowdarXvector.py. It is now convenient and easy.

Requirements List:

  • DDP: Pytorch, NCCL
  • Horovod: Pytorch, NCCL, Openmpi, Horovod

An Example of Installing NCCL Based on Linux-Centos-7 and CUDA-10.2
Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html.

# In the simplest case, there are only three steps.
# [1] Download rpm package of nvidia
wget https://developer.download.nvidia.com/compute/machine-learning/repos/rhel7/x86_64/nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm

# [2] Add nvidia repo to yum (NOKEY could be ignored)
sudo rpm -i nvidia-machine-learning-repo-rhel7-1.0.0-1.x86_64.rpm

# [3] Install NCCL by yum
sudo yum install libnccl-2.6.4-1+cuda10.2 libnccl-devel-2.6.4-1+cuda10.2 libnccl-static-2.6.4-1+cuda10.2

These yum cleanup commands can be useful if you run into trouble with yum.

# Install yum-utils first
yum -y install yum-utils

# Stop unfinished transactions
yum-complete-transaction --cleanup-only

# Clean duplicate and conflict
package-cleanup --cleandupes

# Clean cached headers and packages
yum clean all

If you want to install Openmpi and Horovod, see https://github.com/horovod/horovod for more details.

6. Extra Installation (Optional)

Some special applications require extra installation.

Train A Multi-Task Learning Model Based on Kaldi

Use subtools/kaldi/runMultiTaskXvector.sh to train a model with multi-task learning; this requires some extra code.

# Enter your project, such as kaldi/egs/xmuspeech/sre and make sure ASV-Subtools is cloned here
# Just run this patch to compile some extra C++ commands with Kaldi's format
cd kaldi/egs/xmuspeech/sre
bash subtools/kaldi/patch/runPatch-multitask.sh

Accelerate X-vector Extractor of Kaldi

When extracting x-vectors with Kaldi, much time is spent compiling nnet3 models for utterances with different numbers of frames. To address this, ASV-Subtools provides an offline modification (MOD) in subtools/kaldi/sid/nnet3/xvector/extract_xvectors.sh to accelerate extraction. This MOD requires two extra commands, nnet3-compile-xvector-net and nnet3-offline-xvector-compute. When extracting x-vectors, models for all the different input chunk sizes are compiled first; utterances with the same number of frames can then share a compiled nnet3 network. This saves much time by avoiding many duplicate dynamic compilations.

In addition, the scp file is split according to utterance length (subtools/splitDataByLength.sh) to balance the number of frames across jobs (nj) when multiple processes are used.

# Enter your project, such as kaldi/egs/xmuspeech/sre and make sure ASV-Subtools is cloned here
# Just run this patch to compile some extra C++ commands with Kaldi's format

# Target *.cc:
#     src/nnet3bin/nnet3-compile-xvector-net.cc
#     src/nnet3bin/nnet3-offline-xvector-compute.cc

cd kaldi/egs/xmuspeech/sre
bash subtools/kaldi/patch/runPatch-base-command.sh
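Two ideas from the section above, sharing one compiled network among utterances with the same number of frames and balancing frames across parallel jobs, can be illustrated in a few lines of Python. This is a sketch under stated assumptions: utterance lengths are given as a dict of frame counts, the greedy longest-first assignment is a swapped-in technique (subtools/splitDataByLength.sh may split differently), and `compile_network` is a hypothetical stand-in for the nnet3 compilation step.

```python
import heapq
from functools import lru_cache

compile_calls = []


@lru_cache(maxsize=None)
def compile_network(num_frames):
    """Hypothetical stand-in for compiling a network for one input size;
    the cache means each distinct frame count is compiled only once."""
    compile_calls.append(num_frames)
    return f"net-{num_frames}"


def split_by_length(utt2frames, num_jobs):
    """Greedy longest-first split: give each utterance to the job that
    currently has the fewest frames. Returns (frames, [utt ids]) per job."""
    heap = [(0, job, []) for job in range(num_jobs)]
    for utt, frames in sorted(utt2frames.items(), key=lambda kv: -kv[1]):
        total, job, utts = heapq.heappop(heap)
        utts.append(utt)
        heapq.heappush(heap, (total + frames, job, utts))
    return [(total, utts) for total, job, utts in sorted(heap, key=lambda x: x[1])]


utt2frames = {"a": 500, "b": 400, "c": 300, "d": 300, "e": 200}
print(split_by_length(utt2frames, num_jobs=2))
# [(800, ['a', 'd']), (900, ['b', 'c', 'e'])]

for utt, frames in utt2frames.items():
    compile_network(frames)
print(compile_calls)  # [500, 400, 300, 200]
```

Utterances "c" and "d" have the same length, so the second one reuses the cached network instead of triggering another compilation.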

Add A MMI-GMM Classifier for The Back-End

If you have already run subtools/kaldi/patch/runPatch-base-command.sh, you do not need to run it again.

# Enter your project, such as kaldi/egs/xmuspeech/sre and make sure ASV-Subtools is cloned here
# Just run this patch to compile some extra C++ commands with Kaldi's format

# Target *.cc:
#    src/gmmbin/gmm-global-init-from-feats-mmi.cc
#    src/gmmbin/gmm-global-est-gaussians-ebw.cc
#    src/gmmbin/gmm-global-est-map.cc
#    src/gmmbin/gmm-global-est-weights-ebw.cc

cd kaldi/egs/xmuspeech/sre
bash subtools/kaldi/patch/runPatch-base-command.sh

Training Model

If you have completed the Ready to Start stage, then you could try to train a model with ASV-Subtools.

For Kaldi training, some launcher scripts named run*.sh can be found in subtools/kaldi/.

For Pytorch training, some launcher scripts named run*.py can be found in subtools/pytorch/launcher/, and some models named *.py can be found in subtools/pytorch/model/. Note that the model is called from the launcher script.

Here is a Pytorch training example. Before training, you should follow a recipe pipeline to prepare your data and features; data preprocessing is not complex and is the same as in Kaldi.

# Suppose you have followed the recipe and prepared your data and features; training can then be run as follows.
# Enter your project, such as kaldi/egs/xmuspeech/sre and make sure ASV-Subtools is cloned here

# First, copy a launcher to your project
cp subtools/pytorch/launcher/runSnowdarXvector.py ./

# Modify this launcher and run it
# Most of the time, only two files, model.py and launcher.py, need to be changed.
subtools/runLauncher.sh runSnowdarXvector.py --gpu-id=0,1,2,3 --stage=0

Recipe

[1] Voxceleb Recipe [Speaker Recognition]

Voxceleb is a popular dataset for the task of speaker recognition. It has two parts now, Voxceleb1 and Voxceleb2.

There are two recipes for Voxceleb:

i. Test Voxceleb1-O only

In this setting, the training set can be sampled from both Voxceleb1.dev and Voxceleb2 under a fixed training condition. The training script is available in subtools/recipe/voxceleb/runVoxceleb.sh.

The voxceleb1 recipe with mfcc23&pitch features is available:
Link: https://pan.baidu.com/s/1nMXaAXiOnFGRhahzVyrQmg
Password: 24sg

# Download this recipe to kaldi/egs/xmuspeech directory
cd kaldi/egs/xmuspeech
tar xzf voxceleb1_recipe.tar.gz
cd voxceleb1

# Clone ASV-Subtools (Suppose the configuration of related environment has been done)
git clone https://github.com/Snowdar/asv-subtools.git subtools

# Train an extended x-vector model (do not use multi-GPU training, because it is not stable with SpecAugment)
subtools/runPytorchLauncher.sh runSnowdarXvector-extended-spec-am.py --stage=0

# Score (EER = 2.444% for voxceleb1.test)
subtools/recipe/voxceleb/gather_results_from_epochs.sh --vectordir exp/extended_spec_am --epochs 21 --score plda
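The EER reported by the scoring step is the operating point where the false-rejection rate of target trials equals the false-acceptance rate of non-target trials. A minimal pure-Python sketch of that definition is shown below, assuming scores and 0/1 target labels are already available as lists. It sweeps thresholds over the observed scores and averages the two rates at the closest crossing, which is one common approximation and not necessarily what gather_results_from_epochs.sh computes.

```python
def compute_eer(scores, labels):
    """Approximate equal error rate. labels: 1 = target, 0 = non-target.
    A trial is accepted when score >= threshold."""
    num_target = sum(labels)
    num_nontarget = len(labels) - num_target
    best_gap, eer = float("inf"), None
    for threshold in sorted(set(scores)):
        frr = sum(1 for s, l in zip(scores, labels) if l == 1 and s < threshold) / num_target
        far = sum(1 for s, l in zip(scores, labels) if l == 0 and s >= threshold) / num_nontarget
        if abs(frr - far) < best_gap:
            best_gap, eer = abs(frr - far), (frr + far) / 2
    return eer


scores = [0.9, 0.8, 0.7, 0.4, 0.3, 0.2]
labels = [1, 1, 0, 1, 0, 0]
print(f"EER = {100 * compute_eer(scores, labels):.2f}%")  # EER = 33.33%
```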

Results of Voxceleb1-O with Voxceleb1.dev.aug1:1 Training only

results-1.png

Results of Voxceleb1-O with Voxceleb1&2.dev.aug1:1 Training

results-2.png

Note that 2000 utterances are selected from the non-augmented training set as the cohort set of AS-Norm; the same applies below.
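AS-Norm (adaptive symmetric score normalization) rescales each trial score with statistics of the most similar cohort scores. The sketch below follows the common formulation: it averages the z-normalized score on the enrollment side and on the test side, with each side using the mean and standard deviation of its top-N cohort scores. It assumes cosine scoring of plain Python vectors, and `top_n` is an illustrative parameter rather than the value used in these recipes.

```python
import math


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def top_n_stats(vector, cohort, top_n):
    """Mean and standard deviation of the top_n cosine scores against the cohort."""
    scores = sorted((cosine(vector, c) for c in cohort), reverse=True)[:top_n]
    mean = sum(scores) / len(scores)
    std = math.sqrt(sum((s - mean) ** 2 for s in scores) / len(scores))
    return mean, max(std, 1e-8)  # guard against a zero spread


def as_norm(enroll, test, cohort, top_n=300):
    """Adaptive symmetric normalization of the cosine score of one trial."""
    raw = cosine(enroll, test)
    mean_e, std_e = top_n_stats(enroll, cohort, top_n)
    mean_t, std_t = top_n_stats(test, cohort, top_n)
    return 0.5 * ((raw - mean_e) / std_e + (raw - mean_t) / std_t)


cohort = [[0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.5, -1.0]]
enroll, genuine, impostor = [1.0, 0.0], [0.9, 0.1], [0.0, 1.0]
print(as_norm(enroll, genuine, cohort, top_n=2))   # positive
print(as_norm(enroll, impostor, cohort, top_n=2))  # negative
```

Using only the top-N closest cohort scores makes the normalization adapt to each enrollment and test embedding rather than to the whole cohort.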


ii. Test Voxceleb1-O/E/H

In this setting, the training set can only be sampled from Voxceleb2 under a fixed training condition.

Old Results of Voxceleb1-O/E/H with Voxceleb2.dev.aug1:4 Training (EER%)

results-3.png

These models were trained with Adam + warm restarts and are now old, so the related scripts were removed. Note that Voxceleb1.dev is used as the back-end training set for the Voxceleb1-O* task, and Voxceleb2.dev for the others.

These basic models perform well, but the results are not yet state of the art. I found that training strategies, such as the number of epochs, the weight decay value and the choice of optimizer, can have an important influence on final performance. Unfortunately, I do not have enough time or GPUs to fine-tune so many models, especially when training on a large dataset like Voxceleb2, whose duration is more than 2300 h (in this case, training one fbank81-based Resnet2d model for 6 epochs takes 1~2 days on 4 V100 GPUs).

--#--Snowdar--2020-06-02--#--

New Results of Voxceleb1-O/E/H with Voxceleb2.dev.aug1:4 Training (EER%)

This is a resnet34 benchmark model. The training script, which has more details, is available in subtools/recipe/voxcelebSRC/runVoxcelebSRC.sh. (by Snowdar)

EER% vox1-O vox1-O-clean vox1-E vox1-E-clean vox1-H vox1-H-clean
Baseline 1.304 1.159 1.35 1.223 2.357 2.238
Submean 1.262 1.096 1.338 1.206 2.355 2.223
AS-Norm 1.161 1.026 - - - -
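The Baseline and Submean rows differ only in back-end processing. Assuming, as the row name suggests, that Submean subtracts the mean embedding of a training set before cosine scoring, the idea looks like the sketch below. The helper names are hypothetical, and the actual back-end (which may include further steps) is defined in the recipe scripts.

```python
import math


def mean_vector(vectors):
    dim = len(vectors[0])
    return [sum(v[d] for v in vectors) / len(vectors) for d in range(dim)]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def submean_cosine(enroll, test, train_vectors):
    """Cosine score after subtracting the training-set mean from both embeddings."""
    mu = mean_vector(train_vectors)
    centered_e = [x - m for x, m in zip(enroll, mu)]
    centered_t = [x - m for x, m in zip(test, mu)]
    return cosine(centered_e, centered_t)


# All embeddings share a large common offset (+10 in the first dimension).
train = [[11.0, 1.0], [9.0, -1.0], [10.0, 0.0]]
enroll, test = [11.0, 0.0], [10.0, 1.0]
print(round(cosine(enroll, test), 3))                 # 0.995
print(round(submean_cosine(enroll, test, train), 3))  # 0.0
```

A shared offset makes every pair of embeddings look similar under raw cosine; removing the mean exposes the speaker-dependent directions.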

New Results of Voxceleb1-O/E/H with Voxceleb2.dev.aug.speed1:4:2 Training (EER%)

This is an ECAPA benchmark model. The training script, which has more details, is available in subtools/pytorch/launcher/runEcapaXvector.py. (by Fuchuan Tong) ==new==

EER% vox1-O vox1-O-clean vox1-E vox1-E-clean vox1-H vox1-H-clean
Baseline 1.506 1.393 1.583 1.462 2.811 2.683
Submean 1.225 1.112 1.515 1.394 2.781 2.652
AS-Norm 1.140 0.963 - - - -

New Results of Voxceleb1-O/E/H with original Voxceleb2.dev (without data augmentation) Training (EER%)

This is a statistics pooling and Xi-vector embedding benchmark model (implemented on TDNN). The training script is available in subtools/pytorch/launcher/runSnowdar_Xivector.py. We would like to thank Dr. Kong Aik Lee for providing code and useful discussion. (experiments conducted by Fuchuan Tong) ==2021-10-30==

EER% vox1-O vox1-E vox1-H
Statistics Pooling 1.85 2.01 3.57
Multi-head 1.76 2.00 3.54
Xi-Vector(∅,σ) 1.59 1.90 3.38

New Results of Voxceleb1-O/E/H with Voxceleb2.dev (online random augmentation) Training (EER%)

This is a resnet34 benchmark model. The training script, which has more details, is available in subtools/pytorch/launcher/runResnetXvector_online.py. (experiments conducted by Dexin Liao) ==2022-07-07==

EER% vox1-O vox1-O-clean vox1-E vox1-E-clean vox1-H vox1-H-clean
Submean 1.071 0.920 1.257 1.135 2.205 2.072
AS-Norm 0.970 0.819 - - - -

This is an ECAPA benchmark model. The training script, which has more details, is available in subtools/pytorch/launcher/runEcapaXvector_online.py. (experiments conducted by Dexin Liao) ==2022-07-07==

EER% vox1-O vox1-O-clean vox1-E vox1-E-clean vox1-H vox1-H-clean
Submean 1.045 0.904 1.330 1.211 2.430 2.303
AS-Norm 0.991 0.856 - - - -

New Results of Voxceleb1-O/E/H with Voxceleb2.dev (online random augmentation) Training (EER%)

This is a Conformer benchmark model. The training script, which has more details, is available in subtools/pytorch/launcher/runTransformerXvector.py. (experiments conducted by Dexin Liao) ==2022-11-15==

  • Egs = Voxceleb2_dev(online random aug) + random chunk(3s)
  • Optimization = [adamW (lr = 1e-6 - 1e-3) + 1cycle] x 4 GPUs (total batch-size=512)
  • Conformer + FC-Swish-LN + ASP + FC-LN + AAM-Softmax (margin = 0.2)
  • Back-end = near + Cosine
  • LM: Large-Margin Fine-tune (margin: 0.2 --> 0.5, chunk: 6s)

Config            EER%      vox1-O  vox1-O-clean  vox1-E  vox1-E-clean  vox1-H  vox1-H-clean
6L-256D-4H-4Sub   Submean   1.204   1.074         1.386   1.267         2.416   2.294
                  AS-Norm   1.029   0.952         -       -             -       -
+SAM training     cosine    1.103   0.984         1.350   1.234         2.380   2.257
                  LM        1.034   0.899         1.181   1.060         2.079   1.953
                  AS-Norm   0.943   0.792         -       -             -       -
6L-256D-4H-2Sub   cosine    1.066   0.915         1.298   1.177         2.167   2.034
                  LM        1.029   0.888         1.160   1.043         1.923   1.792
                  AS-Norm   0.949   0.792         -       -             -       -
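The optimization line of the Conformer configuration pairs AdamW with a one-cycle learning-rate schedule between 1e-6 and 1e-3. The sketch below shows the general shape of such a schedule (linear warm-up to the peak, then cosine annealing back down) as a pure function of the step. The 30% warm-up fraction and the annealing shape are assumptions for illustration, not the exact scheduler settings of runTransformerXvector.py; in Pytorch, torch.optim.lr_scheduler.OneCycleLR provides a built-in variant.

```python
import math


def one_cycle_lr(step, total_steps, min_lr=1e-6, max_lr=1e-3, warmup_frac=0.3):
    """Learning rate at `step` (0-based) for an illustrative one-cycle schedule:
    linear warm-up from min_lr to max_lr, then cosine annealing back to min_lr."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return min_lr + (max_lr - min_lr) * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - 1 - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))


total = 100
lrs = [one_cycle_lr(s, total) for s in range(total)]
print(f"{lrs[0]:.1e} {max(lrs):.1e} {lrs[-1]:.1e}")  # 1.0e-06 1.0e-03 1.0e-06
```

With a total batch size of 512 split evenly over 4 GPUs, each GPU processes 128 chunks per step.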

Results of RTF

  • RTF is evaluated on LibTorch-based runtime, see subtools/runtime
  • One thread is used for CPU threading and TorchScript inference.
  • CPU: Intel(R) Xeon(R) Gold 5218R CPU @ 2.10GHz.

Model      Config            Params  RTF
ResNet34   base32            6.80M   0.090
ECAPA      C1024             16.0M   0.071
           C512              6.53M   0.030
Conformer  6L-256D-4H-4Sub   18.8M   0.025
           6L-256D-4H-2Sub   22.5M   0.070
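RTF (real-time factor) is the processing time divided by the duration of the processed audio, so values below 1 mean faster than real time; for example, an RTF of 0.025 means one second of audio takes about 25 ms. The sketch below shows the measurement itself with a dummy workload. `fake_extract` is a hypothetical stand-in for the TorchScript embedding extractor in subtools/runtime, and the 16 kHz sample rate is an assumed parameter.

```python
import time


def real_time_factor(extract, waveforms, sample_rate=16000):
    """RTF = total processing time / total audio duration (both in seconds)."""
    audio_seconds = sum(len(w) for w in waveforms) / sample_rate
    start = time.perf_counter()
    for wav in waveforms:
        extract(wav)
    elapsed = time.perf_counter() - start
    return elapsed / audio_seconds


def fake_extract(wav):
    """Hypothetical stand-in for an embedding extractor: just sums the samples."""
    return sum(wav)


waveforms = [[0.0] * 16000 * 3 for _ in range(5)]  # five 3-second dummy utterances
print(f"RTF = {real_time_factor(fake_extract, waveforms):.4f}")
```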

[2] OLR Challenge 2020 Baseline Recipe [Language Identification]

OLR Challenge 2020 is closed now.

Baseline: subtools/recipe/ap-olr2020-baseline.

The top-level training script of the baseline is available in subtools/recipe/ap-olr2020-baseline/run.sh, and the baseline results can be found in subtools/recipe/ap-olr2020-baseline/results.txt.

Plan: Zheng Li, Miao Zhao, Qingyang Hong, Lin Li, Zhiyuan Tang, Dong Wang, Liming Song and Cheng Yang: AP20-OLR Challenge: Three Tasks and Their Baselines, submitted to APSIPA ASC 2020.

[3] OLR Challenge 2021 Baseline Recipe [Language Identification]

Baseline: subtools/recipe/olr2021-baseline.

The top-level training script of the baseline is available in subtools/recipe/olr2021-baseline/run.sh.

Plan: Binling Wang, Wenxuan Hu, Jing Li, Yiming Zhi, Zheng Li, Qingyang Hong, Lin Li, Dong Wang, Liming Song and Cheng Yang: OLR 2021 Challenge: Datasets, Rules and Baselines, submitted to APSIPA ASC 2021.

For previous challenges (2016-2020), see http://olr.cslt.org.

[4] CNSRC 2022 Baseline Recipe [Speaker Recognition]

Baseline: subtools/recipe/cnsrc.

The top-level training scripts of the baseline are available in subtools/recipe/cnsrc/sv/run-cnsrc_sv.sh and subtools/recipe/cnsrc/sr/run-cnsrc_sr.sh.

Plan: Dong Wang, Qingyang Hong, Liantian Li, Hui Bu: CNSRC 2022 Evaluation Plan.

For more information, see http://cnceleb.org. For any challenge questions, please contact lilt@cslt.org; for any baseline questions, contact sssyousen@163.com.

Task        Training     Evaluation   Metrics
Task1 SV    CN-Celeb.T   CN-Celeb.E   minDCF: 0.463, EER: 9.141%
Task2 SR    CN-Celeb.T   SR.eval      mAP: 0.242
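minDCF is the minimum over thresholds of the detection cost C_miss * P_miss * P_target + C_fa * P_fa * (1 - P_target), usually normalized by the cost of a trivial system, min(C_miss * P_target, C_fa * (1 - P_target)). The sketch below uses P_target = 0.01 and C_miss = C_fa = 1 as illustrative defaults; check the CNSRC 2022 evaluation plan for the exact parameters behind the numbers in the table.

```python
def compute_min_dcf(scores, labels, p_target=0.01, c_miss=1.0, c_fa=1.0):
    """Normalized minimum detection cost. labels: 1 = target, 0 = non-target.
    A trial is accepted when score >= threshold."""
    num_target = sum(labels)
    num_nontarget = len(labels) - num_target
    # Include a threshold above every score so that "reject all" is considered.
    thresholds = sorted(set(scores)) + [float("inf")]
    best = float("inf")
    for threshold in thresholds:
        p_miss = sum(1 for s, l in zip(scores, labels) if l == 1 and s < threshold) / num_target
        p_fa = sum(1 for s, l in zip(scores, labels) if l == 0 and s >= threshold) / num_nontarget
        cost = c_miss * p_miss * p_target + c_fa * p_fa * (1 - p_target)
        best = min(best, cost)
    default_cost = min(c_miss * p_target, c_fa * (1 - p_target))
    return best / default_cost


scores = [0.9, 0.8, 0.7, 0.4, 0.3, 0.2]
labels = [1, 1, 0, 1, 0, 0]
print(round(compute_min_dcf(scores, labels), 4))  # 0.3333
```

Because P_target is small, false acceptances are weighted far more heavily than misses, so the minimum is usually reached at a strict threshold.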

New Results of CN-Celeb.E with CN-Celeb.T (online random augmentation) Training (EER%)

Config       Pretrained ASR   EER%    minDCF
6L-256D-4H   -                8.39%   0.4748
             Multi-CN         7.95%   0.4534
             WenetSpeech      7.42%   0.4427

Feedback

  • If you find bugs or have questions, please create a GitHub issue in this repository so that everyone knows about it and a good solution can be contributed.
  • If you want to ask questions, just send an e-mail to sssyousen@163.com (Tao Jiang) or snowdar@stu.xmu.edu.cn (Snowdar) for SRE questions, and to wangbling1207@stu.xmu.edu.cn for LID questions. We will generally reply in our free time.
  • If you want to join the WeChat group of asv-subtools, scan the QR code on the left to follow XMUSPEECH and reply "join group" + your institution/university + your name. Alternatively, scan the QR code on the right and the person there will invite you to the chat group.

Acknowledgement