KWS with CTC loss training and CTC prefix beam search detection. #135

Merged · 14 commits · Aug 16, 2023
57 changes: 57 additions & 0 deletions examples/hi_xiaowen/s0/README.md
@@ -1,3 +1,5 @@
Comparison among different backbones;
all models use Max-pooling loss.
FRRs with FAR fixed at once per hour:

| model | params(K) | epoch | hi_xiaowen | nihao_wenwen |
@@ -8,3 +10,58 @@ FRRs with FAR fixed at once per hour:
| DS_TCN(spec_aug) | 287 | 80(avg30) | 0.008176 | 0.005075 |
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |
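The `avg10`/`avg30` entries above denote averaging the last N checkpoints before scoring. A minimal sketch of parameter averaging, with plain dicts of floats standing in for model state dicts (the structure here is illustrative, not the wekws checkpoint format):

```python
def average_checkpoints(ckpts):
    """Average a list of state dicts (parameter name -> value)."""
    n = len(ckpts)
    avg = {}
    for name in ckpts[0]:
        # sum each parameter across checkpoints, then divide by the count
        avg[name] = sum(c[name] for c in ckpts) / n
    return avg

# toy example: three "checkpoints" with a single scalar weight each
ckpts = [{"w": 1.0}, {"w": 2.0}, {"w": 3.0}]
print(average_checkpoints(ckpts))  # {'w': 2.0}
```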

Next, we train the models with CTC loss, using DS_TCN and FSMN backbones,
and use CTC prefix beam search to decode and detect keywords,
either in a non-streaming or a streaming fashion.
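A minimal, framework-free sketch of CTC prefix beam search over per-frame log posteriors, with a keyword check on the surviving prefixes (toy vocabulary; the real decoder lives in wekws, this only illustrates the algorithm):

```python
import math
from collections import defaultdict

def log_add(*xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))

def ctc_prefix_beam_search(log_probs, beam_size=3, blank=0):
    """log_probs: T x V per-frame log posteriors. Each prefix keeps two
    scores: log prob of paths ending in blank (pb) and in its last
    non-blank token (pnb)."""
    beams = {(): (0.0, -math.inf)}
    for frame in log_probs:
        nxt = defaultdict(lambda: (-math.inf, -math.inf))
        for prefix, (pb, pnb) in beams.items():
            for v, lp in enumerate(frame):
                if v == blank:
                    b, nb = nxt[prefix]
                    nxt[prefix] = (log_add(b, pb + lp, pnb + lp), nb)
                elif prefix and v == prefix[-1]:
                    # same token again: either extend the repeat in place...
                    b, nb = nxt[prefix]
                    nxt[prefix] = (b, log_add(nb, pnb + lp))
                    # ...or start a new copy after an intervening blank
                    b, nb = nxt[prefix + (v,)]
                    nxt[prefix + (v,)] = (b, log_add(nb, pb + lp))
                else:
                    b, nb = nxt[prefix + (v,)]
                    nxt[prefix + (v,)] = (b, log_add(nb, pb + lp, pnb + lp))
        beams = dict(sorted(nxt.items(), key=lambda kv: log_add(*kv[1]),
                            reverse=True)[:beam_size])
    return beams

def contains_keyword(beams, keyword):
    """Fire when the keyword token sequence occurs inside any surviving prefix."""
    k = tuple(keyword)
    return any(prefix[i:i + len(k)] == k
               for prefix in beams
               for i in range(len(prefix) - len(k) + 1))
```

In the non-streaming case the search runs over the whole utterance; in the streaming case it advances frame by frame and the keyword check fires as soon as a surviving prefix contains the keyword.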

Since the FAR is quite low when using CTC loss,
the following results are FRRs with FAR fixed at once per 12 hours:
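Fixing FAR at once per 12 hours means choosing the detection threshold so the negative data produces that alarm rate, then reading off the miss rate on positives. An illustrative computation (wekws builds the full DET curve instead; names and semantics here are assumptions):

```python
import math

def frr_at_fixed_far(pos_scores, neg_scores, neg_hours, fa_per_hour=1.0 / 12):
    """Pick the threshold yielding at most `fa_per_hour` false alarms per
    hour on negative data, then report the miss rate on positives."""
    allowed_fa = int(neg_hours * fa_per_hour)      # alarms we may tolerate
    neg_sorted = sorted(neg_scores, reverse=True)
    # accepting scores strictly above this value fires exactly `allowed_fa`
    # times on the negatives (assuming distinct scores)
    if allowed_fa < len(neg_sorted):
        threshold = neg_sorted[allowed_fa]
    else:
        threshold = -math.inf
    misses = sum(1 for s in pos_scores if s <= threshold)
    return misses / len(pos_scores), threshold
```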

Comparison between Max-pooling and CTC loss.
The CTC model is fine-tuned from a base model pretrained on WenetSpeech (23 epochs, not converged).
FRRs with FAR fixed at once per 12 hours:

| model | loss | hi_xiaowen | nihao_wenwen | model ckpt |
|-----------------------|-------------|------------|--------------|------------|
| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 | [dstcn-maxpooling](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn/files) |
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |


Comparison between DS_TCN (pretrained on WenetSpeech, 23 epochs, not converged)
and FSMN (pretrained from the modelscope-released xiaoyunxiaoyun model, fully converged).
FRRs with FAR fixed at once per 12 hours:

| model | params(K) | hi_xiaowen | nihao_wenwen | model ckpt |
|-----------------------|-------------|------------|--------------|-------------------------------------------------------------------------------|
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 | [fsmn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_fsmn_ctc/files) |

The DS_TCN model with CTC loss may not reach its best performance here, because
its pretraining did not fully converge. We recommend using the pretrained
FSMN model as the initial checkpoint when training your own model.
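Initializing from a pretrained checkpoint typically copies every parameter whose name and shape match, and leaves the rest (e.g. a keyword-specific output layer sized for a different `num_keywords`) freshly initialized. A sketch of that idea with plain dicts standing in for PyTorch state dicts (names are hypothetical):

```python
def load_pretrained(model_state, ckpt_state):
    """Copy parameters whose name and size match; skip the rest, e.g. an
    output layer whose shape depends on the keyword set."""
    loaded, skipped = [], []
    for name, value in ckpt_state.items():
        if name in model_state and len(value) == len(model_state[name]):
            model_state[name] = value   # take the pretrained weights
            loaded.append(name)
        else:
            skipped.append(name)        # keep the fresh initialization
    return loaded, skipped
```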

Comparison between stream_score_ctc and score_ctc.
FRRs with FAR fixed at once per 12 hours:

| model | stream | hi_xiaowen | nihao_wenwen |
|-----------------------|-------------|------------|--------------|
| DS_TCN(spec_aug) | no | 0.056574 | 0.056856 |
| DS_TCN(spec_aug) | yes | 0.132694 | 0.057044 |
| FSMN(spec_aug) | no | 0.031012 | 0.022460 |
| FSMN(spec_aug) | yes | 0.115215 | 0.020205 |

Note: when using CTC prefix beam search to detect keywords in the streaming case (detection at each frame),
we record the probability of a keyword in a decoding path as soon as the keyword appears in that path.
Since the path probability keeps increasing over time, this records a lower probability value,
which results in a higher False Rejection Rate on the Detection Error Tradeoff curve.
At a given threshold, the actual FRR will be lower than the DET curve suggests.
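A toy illustration of why the streaming score is pessimistic: recording the probability at first appearance can miss the threshold that the growing path probability would later clear (the probabilities below are hypothetical, not wekws output):

```python
def detect_streaming(frame_probs, threshold, keep_max=True):
    """frame_probs: keyword probability seen at each frame once the keyword
    enters a decoding prefix. Recording only the first value (as the
    streaming scorer does) understates the path probability; keeping the
    running max narrows the gap."""
    recorded = None
    for p in frame_probs:
        if recorded is None:
            recorded = p                      # first appearance
        elif keep_max:
            recorded = max(recorded, p)       # let the score keep growing
        if recorded >= threshold:
            return True, recorded
    return False, recorded
```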

On some small-data KWS tasks, we believe the FSMN-CTC model is more robust
than a classification model trained with CE/Max-pooling loss.
For more information and results on the FSMN-CTC KWS model, see [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).

For realtime CTC-KWS, the wave input must be processed in a streaming fashion,
including feature extraction, keyword decoding and detection, and some postprocessing.
Here is a [demo](https://modelscope.cn/studios/thuduj12/KWS_Nihao_Xiaojing/summary) in Python;
the core code is in wekws/bin/stream_kws_ctc.py, which you can refer to when implementing runtime code.
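The realtime pipeline described above can be sketched as a chunked loop; the three callables below are placeholders for the corresponding stages in wekws/bin/stream_kws_ctc.py, and the chunk size is an assumption:

```python
CHUNK = 1600  # 100 ms of 16 kHz samples per step (assumed chunk size)

def stream_detect(samples, extract_feats, decode_chunk, detect):
    """Skeleton of the realtime loop: slice audio into chunks, extract
    features per chunk, run incremental decoding, and record keyword hits."""
    hits = []
    for i in range(0, len(samples), CHUNK):
        chunk = samples[i:i + CHUNK]
        feats = extract_feats(chunk)       # e.g. fbank frames for this chunk
        posteriors = decode_chunk(feats)   # incremental CTC posteriors
        keyword = detect(posteriors)       # prefix beam search + keyword check
        if keyword is not None:
            hits.append((i, keyword))      # sample offset where the hit fired
    return hits
```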
50 changes: 50 additions & 0 deletions examples/hi_xiaowen/s0/conf/ds_tcn_ctc.yaml
@@ -0,0 +1,50 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 40
frame_shift: 10
frame_length: 25
dither: 1.0
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 256

model:
hidden_dim: 256
preprocessing:
type: linear
backbone:
type: tcn
ds: true
num_layers: 4
kernel_size: 8
dropout: 0.1
activation:
type: identity


optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.0001

training_config:
grad_clip: 5
max_epoch: 80
log_interval: 10
criterion: ctc
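The `spec_aug_conf` above masks random time and frequency stripes of each fbank matrix during training. A minimal pure-Python sketch of the idea (not the wekws dataset pipeline, which operates on tensors):

```python
import random

def spec_aug(feats, num_t_mask=1, num_f_mask=1, max_t=20, max_f=10):
    """Zero out random time and frequency stripes of a T x F feature
    matrix, mirroring the spec_aug_conf values above."""
    t, f = len(feats), len(feats[0])
    for _ in range(num_t_mask):
        width = random.randint(1, max_t)
        start = random.randint(0, max(0, t - width))
        for i in range(start, min(t, start + width)):
            feats[i] = [0.0] * f           # time mask: blank whole frames
    for _ in range(num_f_mask):
        width = random.randint(1, max_f)
        start = random.randint(0, max(0, f - width))
        for row in feats:
            for j in range(start, min(f, start + width)):
                row[j] = 0.0               # frequency mask: blank mel bins
    return feats
```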

50 changes: 50 additions & 0 deletions examples/hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml
@@ -0,0 +1,50 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 40
frame_shift: 10
frame_length: 25
dither: 1.0
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 200

model:
hidden_dim: 256
preprocessing:
type: linear
backbone:
type: tcn
ds: true
num_layers: 4
kernel_size: 8
dropout: 0.1
activation:
type: identity


optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.0001

training_config:
grad_clip: 5
max_epoch: 50
log_interval: 100
criterion: ctc

64 changes: 64 additions & 0 deletions examples/hi_xiaowen/s0/conf/fsmn_ctc.yaml
@@ -0,0 +1,64 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.
context_expansion: true
context_expansion_conf:
left: 2
right: 2
frame_skip: 3
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 256

model:
input_dim: 400
preprocessing:
type: none
hidden_dim: 128
backbone:
type: fsmn
input_affine_dim: 140
num_layers: 4
linear_dim: 250
proj_dim: 128
left_order: 10
right_order: 2
left_stride: 1
right_stride: 1
output_affine_dim: 140
classifier:
type: identity
dropout: 0.1
activation:
type: identity


optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.0001

training_config:
grad_clip: 5
max_epoch: 80
log_interval: 10
criterion: ctc
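The `context_expansion_conf` and `frame_skip` above account for the 400-dim `input_dim`: each 80-dim fbank frame is spliced with 2 left and 2 right neighbours (5 × 80 = 400), then every 3rd frame is kept. A sketch, assuming edge frames are padded by repetition (the actual padding in wekws may differ):

```python
def expand_and_skip(frames, left=2, right=2, skip=3):
    """Splice each frame with `left`/`right` neighbours, clamping at the
    edges, then keep every `skip`-th spliced frame."""
    t = len(frames)
    out = []
    for i in range(t):
        spliced = []
        for j in range(i - left, i + right + 1):
            j = min(max(j, 0), t - 1)   # clamp indices at the edges
            spliced.extend(frames[j])
        out.append(spliced)
    return out[::skip]

frames = [[float(i)] * 80 for i in range(10)]   # 10 frames of 80-dim fbank
out = expand_and_skip(frames)
print(len(out), len(out[0]))  # 4 400
```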

11 changes: 9 additions & 2 deletions examples/hi_xiaowen/s0/run.sh
@@ -3,8 +3,8 @@

. ./path.sh

stage=0
stop_stage=4
stage=$1
stop_stage=$2
num_keywords=2

config=conf/ds_tcn.yaml
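With the change above, the stage bounds come from the command line (e.g. `bash run.sh 0 4` reproduces the old hard-coded defaults, `bash run.sh 3 3` reruns only scoring). A sketch of the same pattern with optional fallbacks, which the committed script could adopt:

```shell
#!/bin/sh
# take stage bounds from the command line, falling back to the old defaults
stage=${1:-0}
stop_stage=${2:-4}
echo "running stages ${stage}..${stop_stage}"
```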
@@ -98,6 +98,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
python wekws/bin/score.py \
--config $dir/config.yaml \
--test_data data/test/data.list \
--gpu 0 \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \
@@ -111,6 +112,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--score_file $result_dir/score.txt \
--stats_file $result_dir/stats.${keyword}.txt
done

# plot det curve
python wekws/bin/plot_det_curve.py \
--keywords_dict dict/words.txt \
--stats_dir $result_dir \
--figure_file $result_dir/det.png
fi

