KWS with CTC loss training and CTC prefix beam search detection. #135

Merged · 14 commits · Aug 16, 2023
57 changes: 57 additions & 0 deletions examples/hi_xiaowen/s0/README.md
@@ -1,3 +1,5 @@
Comparison among different backbones,
all models use Max-Pooling loss.
FRRs with FAR fixed at once per hour:

| model | params(K) | epoch | hi_xiaowen | nihao_wenwen |
@@ -8,3 +10,58 @@ FRRs with FAR fixed at once per hour:
| DS_TCN(spec_aug) | 287 | 80(avg30) | 0.008176 | 0.005075 |
| MDTC | 156 | 80(avg10) | 0.007142 | 0.005920 |
| MDTC_Small | 31 | 80(avg10) | 0.005357 | 0.005920 |

Next, we use CTC loss to train the model, with DS_TCN and FSMN backbones,
and use CTC prefix beam search to decode and detect keywords.
Detection can run in either a non-streaming or a streaming fashion.

Since the FAR is quite low when using CTC loss,
the following results are FRRs with FAR fixed at once per 12 hours:

Comparison between Max-pooling and CTC loss.
The CTC model is fine-tuned from a base model pretrained on WenetSpeech (23 epochs, not fully converged).
FRRs with FAR fixed at once per 12 hours:

| model | loss | hi_xiaowen | nihao_wenwen | model ckpt |
|-----------------------|-------------|------------|--------------|------------|
| DS_TCN(spec_aug) | Max-pooling | 0.051217 | 0.021896 | [dstcn-maxpooling](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn/files) |
| DS_TCN(spec_aug) | CTC | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |


Comparison between DS_TCN (pretrained on WenetSpeech, 23 epochs, not converged)
and FSMN (pretrained from the ModelScope-released xiaoyunxiaoyun model, fully converged).
FRRs with FAR fixed at once per 12 hours:

| model | params(K) | hi_xiaowen | nihao_wenwen | model ckpt |
|-----------------------|-------------|------------|--------------|-------------------------------------------------------------------------------|
| DS_TCN(spec_aug) | 955 | 0.056574 | 0.056856 | [dstcn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_dstcn_ctc/files) |
| FSMN(spec_aug) | 756 | 0.031012 | 0.022460 | [fsmn-ctc](https://modelscope.cn/models/thuduj12/kws_wenwen_fsmn_ctc/files) |

For now, the DS_TCN model with CTC loss may not reach its best performance, because its
pretraining did not converge sufficiently. We recommend using the pretrained
FSMN model as the initial checkpoint when training your own model.

Comparison between stream_score_ctc and score_ctc (streaming vs. non-streaming detection).
FRRs with FAR fixed at once per 12 hours:

| model | stream | hi_xiaowen | nihao_wenwen |
|-----------------------|-------------|------------|--------------|
| DS_TCN(spec_aug) | no | 0.056574 | 0.056856 |
| DS_TCN(spec_aug) | yes | 0.132694 | 0.057044 |
| FSMN(spec_aug) | no | 0.031012 | 0.022460 |
| FSMN(spec_aug) | yes | 0.115215 | 0.020205 |

Note: when using CTC prefix beam search to detect keywords in the streaming case (detection runs on every frame),
we record the probability of a keyword in a decoding path as soon as the keyword appears in that path.
Since this probability keeps increasing over time, we effectively record a lower value than the final one,
which results in a higher False Rejection Rate on the Detection Error Tradeoff (DET) curve.
The actual FRR will therefore be lower than what the DET curve gives at a given threshold.
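
The reported operating points can be read back from the per-keyword stats files that
wekws/bin/compute_det_ctc.py writes in this recipe (one `threshold false_alarms_per_hour FRR`
triple per line). A minimal sketch, assuming an example stats path; adjust it to your own result directory:

```python
def frr_at_far(stats_file, target_fa_per_hour=1.0 / 12):
    """Return (threshold, FRR) with the lowest FRR whose FAR stays within budget."""
    best = None
    with open(stats_file, encoding='utf8') as fin:
        for line in fin:
            threshold, fa_per_hour, frr = map(float, line.split())
            if fa_per_hour <= target_fa_per_hour and (best is None or frr < best[1]):
                best = (threshold, frr)
    return best  # None if no threshold meets the FAR budget

print(frr_at_far('exp/fsmn_ctc/test_avg_30.pt/stats.嗨_小_问.txt'))  # example path
```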

On some small-data KWS tasks, we believe the FSMN-CTC model is more robust
than a classification model trained with CE/Max-pooling loss.
For more information and results on the FSMN-CTC KWS model, see [modelscope](https://modelscope.cn/models/damo/speech_charctc_kws_phone-wenwen/summary).

For realtime CTC-KWS, the wave input should be processed in a streaming fashion,
including feature extraction, keyword decoding and detection, and some postprocessing.
Here is a [demo](https://modelscope.cn/studios/thuduj12/KWS_Nihao_Xiaojing/summary) in Python;
the core code is in wekws/bin/stream_kws_ctc.py, which you can refer to when implementing the runtime code.
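
As a rough, hypothetical illustration of the streaming front-end (the real implementation is
wekws/bin/stream_kws_ctc.py), the sketch below buffers incoming samples and emits fbank frames
chunk by chunk; the emitted frames would then be fed to the model and the CTC prefix beam search:

```python
import torch
import torchaudio.compliance.kaldi as kaldi

class StreamingFbank:
    """Toy streaming fbank front-end: 25 ms frames, 10 ms shift at 16 kHz."""

    def __init__(self, num_mel_bins=40, frame_len=400, frame_shift=160):
        self.num_mel_bins = num_mel_bins
        self.frame_len = frame_len
        self.frame_shift = frame_shift
        self.buffer = torch.zeros(0)

    def accept_waveform(self, samples: torch.Tensor) -> torch.Tensor:
        self.buffer = torch.cat([self.buffer, samples])
        if self.buffer.numel() < self.frame_len:
            return torch.zeros(0, self.num_mel_bins)
        feats = kaldi.fbank(self.buffer.unsqueeze(0),
                            num_mel_bins=self.num_mel_bins,
                            frame_length=25, frame_shift=10,
                            dither=0.0, sample_frequency=16000)
        # keep the unconsumed tail so the next chunk continues seamlessly
        self.buffer = self.buffer[feats.size(0) * self.frame_shift:]
        return feats

fe = StreamingFbank()
for _ in range(5):                      # pretend five 0.1 s chunks arrive
    frames = fe.accept_waveform(torch.randn(1600) * 0.01)
    print(frames.shape)                 # feed these frames to the KWS model next
```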
50 changes: 50 additions & 0 deletions examples/hi_xiaowen/s0/conf/ds_tcn_ctc.yaml
@@ -0,0 +1,50 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 40
frame_shift: 10
frame_length: 25
dither: 1.0
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 256

model:
hidden_dim: 256
preprocessing:
type: linear
backbone:
type: tcn
ds: true
num_layers: 4
kernel_size: 8
dropout: 0.1
activation:
type: identity


optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.0001

training_config:
grad_clip: 5
max_epoch: 80
log_interval: 10
criterion: ctc

50 changes: 50 additions & 0 deletions examples/hi_xiaowen/s0/conf/ds_tcn_ctc_base.yaml
@@ -0,0 +1,50 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 40
frame_shift: 10
frame_length: 25
dither: 1.0
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 200

model:
hidden_dim: 256
preprocessing:
type: linear
backbone:
type: tcn
ds: true
num_layers: 4
kernel_size: 8
dropout: 0.1
activation:
type: identity


optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.0001

training_config:
grad_clip: 5
max_epoch: 50
log_interval: 100
criterion: ctc

64 changes: 64 additions & 0 deletions examples/hi_xiaowen/s0/conf/fsmn_ctc.yaml
@@ -0,0 +1,64 @@
dataset_conf:
filter_conf:
max_length: 2048
min_length: 0
resample_conf:
resample_rate: 16000
speed_perturb: false
feature_extraction_conf:
feature_type: 'fbank'
num_mel_bins: 80
frame_shift: 10
frame_length: 25
dither: 1.
context_expansion: true
context_expansion_conf:
left: 2
right: 2
frame_skip: 3
spec_aug: true
spec_aug_conf:
num_t_mask: 1
num_f_mask: 1
max_t: 20
max_f: 10
shuffle: true
shuffle_conf:
shuffle_size: 1500
batch_conf:
batch_size: 256

model:
input_dim: 400
preprocessing:
type: none
hidden_dim: 128
backbone:
type: fsmn
input_affine_dim: 140
num_layers: 4
linear_dim: 250
proj_dim: 128
left_order: 10
right_order: 2
left_stride: 1
right_stride: 1
output_affine_dim: 140
classifier:
type: identity
dropout: 0.1
activation:
type: identity


optim: adam
optim_conf:
lr: 0.001
weight_decay: 0.0001

training_config:
grad_clip: 5
max_epoch: 80
log_interval: 10
criterion: ctc

11 changes: 9 additions & 2 deletions examples/hi_xiaowen/s0/run.sh
@@ -3,8 +3,8 @@

. ./path.sh

stage=0
stop_stage=4
stage=$1
stop_stage=$2
num_keywords=2

config=conf/ds_tcn.yaml
@@ -98,6 +98,7 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
python wekws/bin/score.py \
--config $dir/config.yaml \
--test_data data/test/data.list \
--gpu 0 \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \
@@ -111,6 +112,12 @@ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
--score_file $result_dir/score.txt \
--stats_file $result_dir/stats.${keyword}.txt
done

# plot det curve
python wekws/bin/plot_det_curve.py \
--keywords_dict dict/words.txt \
--stats_dir $result_dir \
--figure_file $result_dir/det.png
fi


223 changes: 223 additions & 0 deletions examples/hi_xiaowen/s0/run_ctc.sh
@@ -0,0 +1,223 @@
#!/bin/bash
# Copyright 2021 Binbin Zhang(binbzha@qq.com)
# 2023 Jing Du(thuduj12@163.com)

. ./path.sh

stage=$1
stop_stage=$2
num_keywords=2599

config=conf/ds_tcn_ctc.yaml
norm_mean=true
norm_var=true
gpus="0"

checkpoint=
dir=exp/ds_tcn_ctc
average_model=true
num_average=30
if $average_model ;then
score_checkpoint=$dir/avg_${num_average}.pt
else
score_checkpoint=$dir/final.pt
fi

download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir

. tools/parse_options.sh || exit 1;
window_shift=50

#Whether to train base model. If set true, must put train+dev data in trainbase_dir
trainbase=false
trainbase_dir=data/base
trainbase_config=conf/ds_tcn_ctc_base.yaml
trainbase_exp=exp/base

if [ ${stage} -le -3 ] && [ ${stop_stage} -ge -3 ]; then
echo "Download and extracte all datasets"
local/mobvoi_data_download.sh --dl_dir $download_dir
fi


if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
echo "Preparing datasets..."
mkdir -p dict
echo "<filler> -1" > dict/words.txt
echo "Hi_Xiaowen 0" >> dict/words.txt
echo "Nihao_Wenwen 1" >> dict/words.txt

for folder in train dev test; do
mkdir -p data/$folder
for prefix in p n; do
mkdir -p data/${prefix}_$folder
json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json
local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \
data/${prefix}_$folder
done
cat data/p_$folder/wav.scp data/n_$folder/wav.scp > data/$folder/wav.scp
cat data/p_$folder/text data/n_$folder/text > data/$folder/text
rm -rf data/p_$folder data/n_$folder
done
fi

if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
# Here we Use Paraformer Large(https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
# to transcribe the negative wavs, and upload the transcription to modelscope.
git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
for folder in train dev test; do
if [ -f data/$folder/text ];then
mv data/$folder/text data/$folder/text.label
fi
cp mobvoi_kws_transcription/$folder.text data/$folder/text
done

# and we also copy the tokens and lexicon that used in
# https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun/summary
cp mobvoi_kws_transcription/tokens.txt data/tokens.txt
cp mobvoi_kws_transcription/lexicon.txt data/lexicon.txt

fi

if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
echo "Compute CMVN and Format datasets"
tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
--in_scp data/train/wav.scp \
--out_cmvn data/train/global_cmvn

for x in train dev test; do
tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur

# Here we use tokens.txt and lexicon.txt to convert txt into index
tools/make_list.py data/$x/wav.scp data/$x/text \
data/$x/wav.dur data/$x/data.list \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
done
fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ] && [ $trainbase == true ]; then
for x in train dev ; do
if [ ! -f $trainbase_dir/$x/wav.scp ] || [ ! -f $trainbase_dir/$x/text ]; then
echo "If You Want to Train Base KWS-CTC Model, You Should Prepare ASR Data by Yourself."
echo "The wav.scp and text in KALDI-format is Needed, You Should Put Them in $trainbase_dir/$x"
exit
fi
if [ ! -f $trainbase_dir/$x/wav.dur ]; then
tools/wav_to_duration.sh --nj 128 $trainbase_dir/$x/wav.scp $trainbase_dir/$x/wav.dur
fi

# Here we use tokens.txt and lexicon.txt to convert txt into index
if [ ! -f $trainbase_dir/$x/data.list ]; then
tools/make_list.py $trainbase_dir/$x/wav.scp $trainbase_dir/$x/text \
$trainbase_dir/$x/wav.dur $trainbase_dir/$x/data.list \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
fi
done

echo "Start base training ..."
mkdir -p $trainbase_exp
cmvn_opts=
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')
torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
wekws/bin/train.py --gpus $gpus \
--config $trainbase_config \
--train_data $trainbase_dir/train/data.list \
--cv_data $trainbase_dir/dev/data.list \
--model_dir $trainbase_exp \
--num_workers 2 \
--ddp.dist_backend nccl \
--num_keywords $num_keywords \
--min_duration 50 \
--seed 666 \
$cmvn_opts # \
#--checkpoint $trainbase_exp/23.pt
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
echo "Start training ..."
mkdir -p $dir
cmvn_opts=
$norm_mean && cmvn_opts="--cmvn_file data/train/global_cmvn"
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')

if $trainbase; then
echo "Use the base model you trained as checkpoint: $trainbase_exp/final.pt"
checkpoint=$trainbase_exp/final.pt
else
echo "Use the base model trained with WenetSpeech as checkpoint: mobvoi_kws_transcription/23.pt"
if [ ! -d mobvoi_kws_transcription ] ;then
git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
fi
checkpoint=mobvoi_kws_transcription/23.pt # this ckpt may not converge well.
fi

torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
wekws/bin/train.py --gpus $gpus \
--config $config \
--train_data data/train/data.list \
--cv_data data/dev/data.list \
--model_dir $dir \
--num_workers 8 \
--num_keywords $num_keywords \
--min_duration 50 \
--seed 666 \
$cmvn_opts \
${checkpoint:+--checkpoint $checkpoint}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Do model average, Compute FRR/FAR ..."
if $average_model; then
python wekws/bin/average_model.py \
--dst_model $score_checkpoint \
--src_path $dir \
--num ${num_average} \
--val_best
fi
result_dir=$dir/test_$(basename $score_checkpoint)
mkdir -p $result_dir
stream=true # we detect keyword online with ctc_prefix_beam_search
score_prefix=""
if $stream ; then
score_prefix=stream_
fi
python wekws/bin/${score_prefix}score_ctc.py \
--config $dir/config.yaml \
--test_data data/test/data.list \
--gpu 0 \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \
--num_workers 8 \
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt

python wekws/bin/compute_det_ctc.py \
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--test_data data/test/data.list \
--window_shift $window_shift \
--step 0.001 \
--score_file $result_dir/score.txt \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
python wekws/bin/export_jit.py \
--config $dir/config.yaml \
--checkpoint $score_checkpoint \
--jit_model $dir/$jit_model
python wekws/bin/export_onnx.py \
--config $dir/config.yaml \
--checkpoint $score_checkpoint \
--onnx_model $dir/$onnx_model
fi
175 changes: 175 additions & 0 deletions examples/hi_xiaowen/s0/run_fsmn_ctc.sh
@@ -0,0 +1,175 @@
#!/bin/bash
# Copyright 2021 Binbin Zhang(binbzha@qq.com)
# 2023 Jing Du(thuduj12@163.com)

. ./path.sh

stage=$1
stop_stage=$2
num_keywords=2599

config=conf/fsmn_ctc.yaml
norm_mean=true
norm_var=true
gpus="0"

checkpoint=
dir=exp/fsmn_ctc
average_model=true
num_average=30
if $average_model ;then
score_checkpoint=$dir/avg_${num_average}.pt
else
score_checkpoint=$dir/final.pt
fi

download_dir=/mnt/52_disk/back/DuJing/data/nihaowenwen # your data dir

. tools/parse_options.sh || exit 1;
window_shift=50

if [ ${stage} -le -2 ] && [ ${stop_stage} -ge -2 ]; then
echo "Download and extracte all datasets"
local/mobvoi_data_download.sh --dl_dir $download_dir
fi


if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "Preparing datasets..."
mkdir -p dict
echo "<filler> -1" > dict/words.txt
echo "Hi_Xiaowen 0" >> dict/words.txt
echo "Nihao_Wenwen 1" >> dict/words.txt

for folder in train dev test; do
mkdir -p data/$folder
for prefix in p n; do
mkdir -p data/${prefix}_$folder
json_path=$download_dir/mobvoi_hotword_dataset_resources/${prefix}_$folder.json
local/prepare_data.py $download_dir/mobvoi_hotword_dataset $json_path \
data/${prefix}_$folder
done
cat data/p_$folder/wav.scp data/n_$folder/wav.scp > data/$folder/wav.scp
cat data/p_$folder/text data/n_$folder/text > data/$folder/text
rm -rf data/p_$folder data/n_$folder
done
fi

if [ ${stage} -le -0 ] && [ ${stop_stage} -ge -0 ]; then
# Here we Use Paraformer Large(https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
# to transcribe the negative wavs, and upload the transcription to modelscope.
git clone https://www.modelscope.cn/datasets/thuduj12/mobvoi_kws_transcription.git
for folder in train dev test; do
if [ -f data/$folder/text ];then
mv data/$folder/text data/$folder/text.label
fi
cp mobvoi_kws_transcription/$folder.text data/$folder/text
done

# and we also copy the tokens and lexicon that used in
# https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun/summary
cp mobvoi_kws_transcription/tokens.txt data/tokens.txt
cp mobvoi_kws_transcription/lexicon.txt data/lexicon.txt

fi

if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
echo "Compute CMVN and Format datasets"
tools/compute_cmvn_stats.py --num_workers 16 --train_config $config \
--in_scp data/train/wav.scp \
--out_cmvn data/train/global_cmvn

for x in train dev test; do
tools/wav_to_duration.sh --nj 8 data/$x/wav.scp data/$x/wav.dur

# Here we use tokens.txt and lexicon.txt to convert txt into index
tools/make_list.py data/$x/wav.scp data/$x/text \
data/$x/wav.dur data/$x/data.list \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
done
fi

if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then

echo "Use the base model from modelscope"
if [ ! -d speech_charctc_kws_phone-xiaoyun ] ;then
git lfs install
git clone https://www.modelscope.cn/damo/speech_charctc_kws_phone-xiaoyun.git
fi
checkpoint=speech_charctc_kws_phone-xiaoyun/train/base.pt
cp speech_charctc_kws_phone-xiaoyun/train/feature_transform.txt.80dim-l2r2 data/global_cmvn.kaldi

echo "Start training ..."
mkdir -p $dir
cmvn_opts=
$norm_mean && cmvn_opts="--cmvn_file data/global_cmvn.kaldi"
$norm_var && cmvn_opts="$cmvn_opts --norm_var"
num_gpus=$(echo $gpus | awk -F ',' '{print NF}')

torchrun --standalone --nnodes=1 --nproc_per_node=$num_gpus \
wekws/bin/train.py --gpus $gpus \
--config $config \
--train_data data/train/data.list \
--cv_data data/dev/data.list \
--model_dir $dir \
--num_workers 8 \
--num_keywords $num_keywords \
--min_duration 50 \
--seed 666 \
$cmvn_opts \
${checkpoint:+--checkpoint $checkpoint}
fi

if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
echo "Do model average, Compute FRR/FAR ..."
if $average_model; then
python wekws/bin/average_model.py \
--dst_model $score_checkpoint \
--src_path $dir \
--num ${num_average} \
--val_best
fi
result_dir=$dir/test_$(basename $score_checkpoint)
mkdir -p $result_dir
stream=true # we detect keyword online with ctc_prefix_beam_search
score_prefix=""
if $stream ; then
score_prefix=stream_
fi
python wekws/bin/${score_prefix}score_ctc.py \
--config $dir/config.yaml \
--test_data data/test/data.list \
--gpu 0 \
--batch_size 256 \
--checkpoint $score_checkpoint \
--score_file $result_dir/score.txt \
--num_workers 8 \
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt

python wekws/bin/compute_det_ctc.py \
--keywords "\u55e8\u5c0f\u95ee,\u4f60\u597d\u95ee\u95ee" \
--test_data data/test/data.list \
--window_shift $window_shift \
--step 0.001 \
--score_file $result_dir/score.txt \
--token_file data/tokens.txt \
--lexicon_file data/lexicon.txt
fi


if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
jit_model=$(basename $score_checkpoint | sed -e 's:.pt$:.zip:g')
onnx_model=$(basename $score_checkpoint | sed -e 's:.pt$:.onnx:g')
# For now, FSMN can not export to JITScript
# python wekws/bin/export_jit.py \
# --config $dir/config.yaml \
# --checkpoint $score_checkpoint \
# --jit_model $dir/$jit_model
python wekws/bin/export_onnx.py \
--config $dir/config.yaml \
--checkpoint $score_checkpoint \
--onnx_model $dir/$onnx_model
fi
169 changes: 167 additions & 2 deletions tools/make_list.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2023 Jing Du(thuduj12@163.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,14 +16,156 @@
# limitations under the License.

import argparse
import logging
import json
import re

symbol_str = '[’!"#$%&\'()*+,-./:;<>=?@,。?★、…【】《》?“”‘’![\\]^_`{|}~]+'

def split_mixed_label(input_str):
tokens = []
s = input_str.lower()
while len(s) > 0:
match = re.match(r'[A-Za-z!?,<>()\']+', s)
if match is not None:
word = match.group(0)
else:
word = s[0:1]
tokens.append(word)
s = s.replace(word, '', 1).strip(' ')
return tokens

def query_token_set(txt, symbol_table, lexicon_table):
tokens_str = tuple()
tokens_idx = tuple()

parts = split_mixed_label(txt)
for part in parts:
if part == '!sil' or part == '(sil)' or part == '<sil>':
tokens_str = tokens_str + ('!sil', )
elif part == '<blk>' or part == '<blank>':
tokens_str = tokens_str + ('<blk>', )
elif part == '(noise)' or part == 'noise)' or \
part == '(noise' or part == '<noise>':
tokens_str = tokens_str + ('<GBG>', )
elif part in symbol_table:
tokens_str = tokens_str + (part, )
elif part in lexicon_table:
for ch in lexicon_table[part]:
tokens_str = tokens_str + (ch, )
else:
# case with symbols or meaningless english letter combination
part = re.sub(symbol_str, '', part)
for ch in part:
tokens_str = tokens_str + (ch, )

for ch in tokens_str:
if ch in symbol_table:
tokens_idx = tokens_idx + (symbol_table[ch], )
elif ch == '!sil':
if 'sil' in symbol_table:
tokens_idx = tokens_idx + (symbol_table['sil'], )
else:
tokens_idx = tokens_idx + (symbol_table['<blk>'], )
elif ch == '<GBG>':
if '<GBG>' in symbol_table:
tokens_idx = tokens_idx + (symbol_table['<GBG>'], )
else:
tokens_idx = tokens_idx + (symbol_table['<blk>'], )
else:
if '<GBG>' in symbol_table:
tokens_idx = tokens_idx + (symbol_table['<GBG>'], )
logging.info(
f'{ch} is not in token set, replace with <GBG>')
else:
tokens_idx = tokens_idx + (symbol_table['<blk>'], )
logging.info(
f'{ch} is not in token set, replace with <blk>')

return tokens_str, tokens_idx


def query_token_list(txt, symbol_table, lexicon_table):
tokens_str = []
tokens_idx = []

parts = split_mixed_label(txt)
for part in parts:
if part == '!sil' or part == '(sil)' or part == '<sil>':
tokens_str.append('!sil')
elif part == '<blk>' or part == '<blank>':
tokens_str.append('<blk>')
elif part == '(noise)' or part == 'noise)' or \
part == '(noise' or part == '<noise>':
tokens_str.append('<GBG>')
elif part in symbol_table:
tokens_str.append(part)
elif part in lexicon_table:
for ch in lexicon_table[part]:
tokens_str.append(ch)
else:
# case with symbols or meaningless english letter combination
part = re.sub(symbol_str, '', part)
for ch in part:
tokens_str.append(ch)

for ch in tokens_str:
if ch in symbol_table:
tokens_idx.append(symbol_table[ch])
elif ch == '!sil':
if 'sil' in symbol_table:
tokens_idx.append(symbol_table['sil'])
else:
tokens_idx.append(symbol_table['<blk>'])
elif ch == '<GBG>':
if '<GBG>' in symbol_table:
tokens_idx.append(symbol_table['<GBG>'])
else:
tokens_idx.append(symbol_table['<blk>'])
else:
if '<GBG>' in symbol_table:
tokens_idx.append(symbol_table['<GBG>'])
logging.info(
f'{ch} is not in token set, replace with <GBG>')
else:
tokens_idx.append(symbol_table['<blk>'])
logging.info(
f'{ch} is not in token set, replace with <blk>')

return tokens_str, tokens_idx

def read_token(token_file):
tokens_table = {}
with open(token_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
assert len(arr) == 2
tokens_table[arr[0]] = int(arr[1]) - 1
fin.close()
return tokens_table


def read_lexicon(lexicon_file):
lexicon_table = {}
with open(lexicon_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().replace('\t', ' ').split()
assert len(arr) >= 2
lexicon_table[arr[0]] = arr[1:]
fin.close()
return lexicon_table


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='')
parser.add_argument('wav_file', help='wav file')
parser.add_argument('text_file', help='text file')
parser.add_argument('duration_file', help='duration file')
parser.add_argument('output_file', help='output list file')
parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
args = parser.parse_args()

wav_table = {}
@@ -39,16 +182,38 @@
assert len(arr) == 2
duration_table[arr[0]] = float(arr[1])

token_table = None
if args.token_file:
token_table = read_token(args.token_file)
lexicon_table = None
if args.lexicon_file:
lexicon_table = read_lexicon(args.lexicon_file)

with open(args.text_file, 'r', encoding='utf8') as fin, \
open(args.output_file, 'w', encoding='utf8') as fout:
for line in fin:
arr = line.strip().split(maxsplit=1)
key = arr[0]
txt = int(arr[1])
tokens = None
if token_table is not None and lexicon_table is not None :
                if len(arr) < 2:  # for some utterances there is no text
txt = [1] # the <blank>/sil is indexed by 1
tokens = ["sil"]
else:
tokens, txt = query_token_list(arr[1],
token_table,
lexicon_table)
else:
txt = int(arr[1])
assert key in wav_table
wav = wav_table[key]
assert key in duration_table
duration = duration_table[key]
line = dict(key=key, txt=txt, duration=duration, wav=wav)
if tokens is None:
line = dict(key=key, txt=txt, duration=duration, wav=wav)
else:
line = dict(key=key, tok=tokens, txt=txt,
duration=duration, wav=wav)

json_line = json.dumps(line, ensure_ascii=False)
fout.write(json_line + '\n')
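
For reference, a toy call of the new query_token_list (the real tables are built from
data/tokens.txt and data/lexicon.txt by read_token/read_lexicon; the tables below are illustrative only):

```python
from tools.make_list import query_token_list

symbol_table = {'<blk>': 0, '嗨': 1, '小': 2, '问': 3}   # toy token table
lexicon_table = {}                                       # no multi-token words in this toy setup

print(query_token_list('嗨小问', symbol_table, lexicon_table))
# -> (['嗨', '小', '问'], [1, 2, 3])
print(query_token_list('嗨小问呀', symbol_table, lexicon_table))
# -> (['嗨', '小', '问', '呀'], [1, 2, 3, 0]); the OOV character falls back to <blk>
```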
271 changes: 271 additions & 0 deletions wekws/bin/compute_det_ctc.py
@@ -0,0 +1,271 @@
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
# 2022 Shaoqing Yu(954793264@qq.com)
# 2023 Jing Du(thuduj12@163.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging
import glob
import json
import re
import os
import numpy as np
import matplotlib.pyplot as plt
import pypinyin # for Chinese Character
from tools.make_list import query_token_set, read_lexicon, read_token

def split_mixed_label(input_str):
tokens = []
s = input_str.lower()
while len(s) > 0:
match = re.match(r'[A-Za-z!?,<>()\']+', s)
if match is not None:
word = match.group(0)
else:
word = s[0:1]
tokens.append(word)
s = s.replace(word, '', 1).strip(' ')
return tokens


def space_mixed_label(input_str):
splits = split_mixed_label(input_str)
space_str = ''.join(f'{sub} ' for sub in splits)
return space_str.strip()

def load_label_and_score(keywords_list, label_file, score_file, true_keywords):
score_table = {}
with open(score_file, 'r', encoding='utf8') as fin:
# read score file and store in table
for line in fin:
arr = line.strip().split()
key = arr[0]
is_detected = arr[1]
if is_detected == 'detected':
keyword = true_keywords[arr[2]]
if key not in score_table:
score_table.update({
key: {
'kw': space_mixed_label(keyword),
'confi': float(arr[3])
}
})
else:
if key not in score_table:
score_table.update({key: {'kw': 'unknown', 'confi': -1.0}})

label_lists = []
with open(label_file, 'r', encoding='utf8') as fin:
for line in fin:
obj = json.loads(line.strip())
label_lists.append(obj)

# build empty structure for keyword-filler infos
keyword_filler_table = {}
for keyword in keywords_list:
keyword = true_keywords[keyword]
keyword = space_mixed_label(keyword)
keyword_filler_table[keyword] = {}
keyword_filler_table[keyword]['keyword_table'] = {}
keyword_filler_table[keyword]['keyword_duration'] = 0.0
keyword_filler_table[keyword]['filler_table'] = {}
keyword_filler_table[keyword]['filler_duration'] = 0.0

for obj in label_lists:
assert 'key' in obj
assert 'wav' in obj
assert 'tok' in obj # here we use the tokens
assert 'duration' in obj

key = obj['key']
txt = "".join(obj['tok'])
txt = space_mixed_label(txt)
txt_regstr_lrblk = ' ' + txt + ' '
duration = obj['duration']
assert key in score_table

for keyword in keywords_list:
keyword = true_keywords[keyword]
keyword = space_mixed_label(keyword)
keyword_regstr_lrblk = ' ' + keyword + ' '
if txt_regstr_lrblk.find(keyword_regstr_lrblk) != -1:
if keyword == score_table[key]['kw']:
keyword_filler_table[keyword]['keyword_table'].update(
{key: score_table[key]['confi']})
else:
                    # utterance detected but does not match this keyword
keyword_filler_table[keyword]['keyword_table'].update(
{key: -1.0})
keyword_filler_table[keyword]['keyword_duration'] += duration
else:
if keyword == score_table[key]['kw']:
keyword_filler_table[keyword]['filler_table'].update(
{key: score_table[key]['confi']})
else:
                    # even if this utterance was detected (as another keyword), it is not a false alarm for this keyword
keyword_filler_table[keyword]['filler_table'].update(
{key: -1.0})
keyword_filler_table[keyword]['filler_duration'] += duration

return keyword_filler_table

def load_stats_file(stats_file):
values = []
with open(stats_file, 'r', encoding='utf8') as fin:
for line in fin:
arr = line.strip().split()
threshold, fa_per_hour, frr = arr
values.append([float(fa_per_hour), float(frr) * 100])
values.reverse()
return np.array(values)

def plot_det(dets_dir, figure_file, xlim=5, x_step=1, ylim=35, y_step=5):
det_title = "DetCurve"
plt.figure(dpi=200)
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.rcParams['font.size'] = 12

for file in glob.glob(f'{dets_dir}/*stats*.txt'):
logging.info(f'reading det data from {file}')
label = os.path.basename(file).split('.')[1]
label = "".join(pypinyin.lazy_pinyin(label))
values = load_stats_file(file)
plt.plot(values[:, 0], values[:, 1], label=label)

plt.xlim([0, xlim])
plt.ylim([0, ylim])
plt.xticks(range(0, xlim + x_step, x_step))
plt.yticks(range(0, ylim + y_step, y_step))
plt.xlabel('False Alarm Per Hour')
plt.ylabel('False Rejection Rate (%)')
plt.grid(linestyle='--')
plt.legend(loc='best', fontsize=6)
plt.savefig(figure_file)

if __name__ == '__main__':
parser = argparse.ArgumentParser(description='compute det curve')
parser.add_argument('--test_data', required=True, help='label file')
parser.add_argument('--keywords', type=str, default=None,
help='keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')
parser.add_argument('--score_file', required=True, help='score file')
parser.add_argument('--step', type=float, default=0.01,
help='threshold step')
parser.add_argument('--window_shift', type=int, default=50,
help='window_shift is used to '
'skip the frames after triggered')
parser.add_argument('--stats_dir',
required=False,
default=None,
help='false reject/alarm stats dir, '
'default in score_file')
parser.add_argument('--det_curve_path',
required=False,
default=None,
help='det curve path, default is stats_dir/det.png')
parser.add_argument(
'--xlim',
type=int,
default=5,
help='xlim:range of x-axis, x is false alarm per hour')
parser.add_argument('--x_step', type=int, default=1, help='step on x-axis')
parser.add_argument(
'--ylim',
type=int,
default=35,
help='ylim:range of y-axis, y is false rejection rate')
parser.add_argument('--y_step', type=int, default=5, help='step on y-axis')

args = parser.parse_args()
window_shift = args.window_shift
logging.info(f"keywords is {args.keywords}, "
f"Chinese is converted into Unicode.")

keywords = args.keywords.encode('utf-8').decode('unicode_escape')
keywords_list = keywords.strip().split(',')

token_table = read_token(args.token_file)
lexicon_table = read_lexicon(args.lexicon_file)
true_keywords = {}
for keyword in keywords_list:
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
true_keywords[keyword] = ''.join(strs)

keyword_filler_table = load_label_and_score(
keywords_list, args.test_data, args.score_file, true_keywords)

for keyword in keywords_list:
keyword = true_keywords[keyword]
keyword = space_mixed_label(keyword)
keyword_dur = keyword_filler_table[keyword]['keyword_duration']
keyword_num = len(keyword_filler_table[keyword]['keyword_table'])
filler_dur = keyword_filler_table[keyword]['filler_duration']
filler_num = len(keyword_filler_table[keyword]['filler_table'])
        assert keyword_num > 0, \
            'Can\'t compute det for {} without positive sample'.format(keyword)
        assert filler_num > 0, \
            'Can\'t compute det for {} without negative sample'.format(keyword)

logging.info('Computing det for {}'.format(keyword))
logging.info(' Keyword duration: {} Hours, wave number: {}'.format(
keyword_dur / 3600.0, keyword_num))
logging.info(' Filler duration: {} Hours'.format(filler_dur / 3600.0))

if args.stats_dir :
stats_dir = args.stats_dir
else:
stats_dir = os.path.dirname(args.score_file)
stats_file = os.path.join(
stats_dir, 'stats.' + keyword.replace(' ', '_') + '.txt')
with open(stats_file, 'w', encoding='utf8') as fout:
threshold = 0.0
while threshold <= 1.0:
num_false_reject = 0
num_true_detect = 0
            # traverse the whole keyword_table
for key, confi in \
keyword_filler_table[keyword]['keyword_table'].items():
if confi < threshold:
num_false_reject += 1
else:
num_true_detect += 1

num_false_alarm = 0
            # traverse the whole filler_table
for key, confi in keyword_filler_table[
keyword]['filler_table'].items():
if confi >= threshold:
num_false_alarm += 1
# print(f'false alarm: {keyword}, {key}, {confi}')

false_reject_rate = num_false_reject / keyword_num
true_detect_rate = num_true_detect / keyword_num

num_false_alarm = max(num_false_alarm, 1e-6)
false_alarm_per_hour = num_false_alarm / (filler_dur / 3600.0)
false_alarm_rate = num_false_alarm / filler_num

fout.write('{:.3f} {:.6f} {:.6f}\n'.format(
threshold, false_alarm_per_hour, false_reject_rate))
threshold += args.step
if args.det_curve_path :
det_curve_path = args.det_curve_path
else:
det_curve_path = os.path.join(stats_dir, 'det.png')
plot_det(stats_dir, det_curve_path,
args.xlim, args.x_step, args.ylim, args.y_step)
3 changes: 3 additions & 0 deletions wekws/bin/export_onnx.py
@@ -41,6 +41,9 @@ def main():
configs = yaml.load(fin, Loader=yaml.FullLoader)
feature_dim = configs['model']['input_dim']
model = init_model(configs['model'])
if configs['training_config'].get('criterion', 'max_pooling') == 'ctc':
        # if we use ctc_loss, the logits need to be converted into probs
model.forward = model.forward_softmax
print(model)

load_checkpoint(model, args.checkpoint)
2 changes: 1 addition & 1 deletion wekws/bin/score.py
Original file line number Diff line number Diff line change
@@ -106,7 +106,7 @@ def main():
score_abs_path = os.path.abspath(args.score_file)
with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
for batch_idx, batch in enumerate(test_data_loader):
keys, feats, target, lengths = batch
keys, feats, target, lengths, target_lengths = batch
feats = feats.to(device)
lengths = lengths.to(device)
logits, _ = model(feats)
219 changes: 219 additions & 0 deletions wekws/bin/score_ctc.py
@@ -0,0 +1,219 @@
# Copyright (c) 2021 Binbin Zhang(binbzha@qq.com)
# 2022 Shaoqing Yu(954793264@qq.com)
# 2023 Jing Du(thuduj12@163.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import argparse
import copy
import logging
import os
import sys
import math

import torch
import yaml
from torch.utils.data import DataLoader

from wekws.dataset.dataset import Dataset
from wekws.model.kws_model import init_model
from wekws.utils.checkpoint import load_checkpoint
from wekws.model.loss import ctc_prefix_beam_search
from tools.make_list import query_token_set, read_lexicon, read_token

def get_args():
parser = argparse.ArgumentParser(description='recognize with your model')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--test_data', required=True, help='test data file')
parser.add_argument('--gpu',
type=int,
default=-1,
help='gpu id for this rank, -1 for cpu')
parser.add_argument('--checkpoint', required=True, help='checkpoint model')
parser.add_argument('--batch_size',
default=16,
type=int,
help='batch size for inference')
parser.add_argument('--num_workers',
default=0,
type=int,
help='num of subprocess workers for reading')
parser.add_argument('--pin_memory',
action='store_true',
default=False,
help='Use pinned memory buffers used for reading')
parser.add_argument('--prefetch',
default=100,
type=int,
help='prefetch number')
parser.add_argument('--score_file',
required=True,
help='output score file')
parser.add_argument('--jit_model',
action='store_true',
default=False,
                        help='use a jit script model for inference')
parser.add_argument('--keywords', type=str, default=None,
help='the keywords, split with comma(,)')
parser.add_argument('--token_file', type=str, default=None,
help='the path of tokens.txt')
parser.add_argument('--lexicon_file', type=str, default=None,
help='the path of lexicon.txt')

args = parser.parse_args()
return args

def is_sublist(main_list, check_list):
if len(main_list) < len(check_list):
return -1

if len(main_list) == len(check_list):
return 0 if main_list == check_list else -1

for i in range(len(main_list) - len(check_list)):
if main_list[i] == check_list[0]:
for j in range(len(check_list)):
if main_list[i + j] != check_list[j]:
break
else:
return i
else:
return -1


def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

with open(args.config, 'r') as fin:
configs = yaml.load(fin, Loader=yaml.FullLoader)

test_conf = copy.deepcopy(configs['dataset_conf'])
test_conf['filter_conf']['max_length'] = 102400
test_conf['filter_conf']['min_length'] = 0
test_conf['speed_perturb'] = False
test_conf['spec_aug'] = False
test_conf['shuffle'] = False
test_conf['feature_extraction_conf']['dither'] = 0.0
test_conf['batch_conf']['batch_size'] = args.batch_size

test_dataset = Dataset(args.test_data, test_conf)
test_data_loader = DataLoader(test_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
prefetch_factor=args.prefetch)

if args.jit_model:
model = torch.jit.load(args.checkpoint)
# For script model, only cpu is supported.
device = torch.device('cpu')
else:
# Init asr model from configs
model = init_model(configs['model'])
load_checkpoint(model, args.checkpoint)
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)
model.eval()
score_abs_path = os.path.abspath(args.score_file)

token_table = read_token(args.token_file)
lexicon_table = read_lexicon(args.lexicon_file)
# 4. parse keywords tokens
assert args.keywords is not None, 'at least one keyword is needed'
logging.info(f"keywords is {args.keywords}, "
f"Chinese is converted into Unicode.")
keywords_str = args.keywords.encode('utf-8').decode('unicode_escape')
keywords_list = keywords_str.strip().replace(' ', '').split(',')
keywords_token = {}
keywords_idxset = {0}
keywords_strset = {'<blk>'}
keywords_tokenmap = {'<blk>': 0}
for keyword in keywords_list:
strs, indexes = query_token_set(keyword, token_table, lexicon_table)
keywords_token[keyword] = {}
keywords_token[keyword]['token_id'] = indexes
keywords_token[keyword]['token_str'] = ''.join('%s ' % str(i)
for i in indexes)
[keywords_strset.add(i) for i in strs]
[keywords_idxset.add(i) for i in indexes]
for txt, idx in zip(strs, indexes):
if keywords_tokenmap.get(txt, None) is None:
keywords_tokenmap[txt] = idx

token_print = ''
for txt, idx in keywords_tokenmap.items():
token_print += f'{txt}({idx}) '
logging.info(f'Token set is: {token_print}')

with torch.no_grad(), open(score_abs_path, 'w', encoding='utf8') as fout:
for batch_idx, batch in enumerate(test_data_loader):
keys, feats, target, lengths, target_lengths = batch
feats = feats.to(device)
lengths = lengths.to(device)
logits, _ = model(feats)
logits = logits.softmax(2) # (batch_size, maxlen, vocab_size)
logits = logits.cpu()
for i in range(len(keys)):
key = keys[i]
score = logits[i][:lengths[i]]
hyps = ctc_prefix_beam_search(score,
lengths[i],
keywords_idxset)
hit_keyword = None
hit_score = 1.0
start = 0
end = 0
for one_hyp in hyps:
prefix_ids = one_hyp[0]
# path_score = one_hyp[1]
prefix_nodes = one_hyp[2]
assert len(prefix_ids) == len(prefix_nodes)
for word in keywords_token.keys():
lab = keywords_token[word]['token_id']
offset = is_sublist(prefix_ids, lab)
if offset != -1:
hit_keyword = word
start = prefix_nodes[offset]['frame']
end = prefix_nodes[offset + len(lab) - 1]['frame']
for idx in range(offset, offset + len(lab)):
hit_score *= prefix_nodes[idx]['prob']
break
if hit_keyword is not None:
hit_score = math.sqrt(hit_score)
break

if hit_keyword is not None:
fout.write('{} detected {} {:.3f}\n'.format(
key, hit_keyword, hit_score))
logging.info(
f"batch:{batch_idx}_{i} detect {hit_keyword} "
f"in {key} from {start} to {end} frame. "
f"duration {end - start}, "
f"score {hit_score}, Activated.")
else:
fout.write('{} rejected\n'.format(key))
logging.info(f"batch:{batch_idx}_{i} {key} Deactivated.")

if batch_idx % 10 == 0:
print('Progress batch {}'.format(batch_idx))
sys.stdout.flush()


if __name__ == '__main__':
main()
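
A small illustration of how a decoded prefix is matched against a keyword's token ids with
is_sublist (toy ids; in the script the keyword ids come from query_token_set and the prefix
from ctc_prefix_beam_search):

```python
from wekws.bin.score_ctc import is_sublist

prefix_ids = [7, 21, 56, 88, 102, 3]   # hypothetical decoded token ids
keyword_ids = [56, 88, 102]            # hypothetical keyword token ids

offset = is_sublist(prefix_ids, keyword_ids)
if offset != -1:
    # the keyword spans prefix positions offset .. offset + len(keyword_ids) - 1;
    # score_ctc.py multiplies the per-token node probs over this span and takes the sqrt
    print('keyword starts at decoded position', offset)   # -> 2
```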
587 changes: 587 additions & 0 deletions wekws/bin/stream_kws_ctc.py

Large diffs are not rendered by default.

363 changes: 363 additions & 0 deletions wekws/bin/stream_score_ctc.py

Large diffs are not rendered by default.

15 changes: 12 additions & 3 deletions wekws/bin/train.py
@@ -134,7 +134,8 @@ def main():
output_dim = args.num_keywords

# Write model_dir/config.yaml for inference and export
configs['model']['input_dim'] = input_dim
if 'input_dim' not in configs['model']:
configs['model']['input_dim'] = input_dim
configs['model']['output_dim'] = output_dim
if args.cmvn_file is not None:
configs['model']['cmvn'] = {}
@@ -156,8 +157,16 @@ def main():
# Try to export the model by script, if fails, we should refine
# the code to satisfy the script export requirements
if rank == 0:
script_model = torch.jit.script(model)
script_model.save(os.path.join(args.model_dir, 'init.zip'))
pass
        # TODO: for now the streaming FSMN does not support export to JITScript,
        # TODO: because there is an nn.Sequential with Tuple input
        # in the current FSMN modules.
# the issue is in https://stackoverflow.com/questions/75714299/
# pytorch-jit-script-error-when-sequential-container-
# takes-a-tuple-input/76553450#76553450

# script_model = torch.jit.script(model)
# script_model.save(os.path.join(args.model_dir, 'init.zip'))
executor = Executor()
# If specify checkpoint, load some info from checkpoint
if args.checkpoint is not None:
10 changes: 10 additions & 0 deletions wekws/dataset/dataset.py
@@ -162,6 +162,16 @@ def Dataset(data_list_file, conf,
spec_aug_conf = conf.get('spec_aug_conf', {})
dataset = Processor(dataset, processor.spec_aug, **spec_aug_conf)

context_expansion = conf.get('context_expansion', False)
if context_expansion:
context_expansion_conf = conf.get('context_expansion_conf', {})
dataset = Processor(dataset, processor.context_expansion,
**context_expansion_conf)

frame_skip = conf.get('frame_skip', 1)
if frame_skip > 1:
dataset = Processor(dataset, processor.frame_skip, frame_skip)

if shuffle:
shuffle_conf = conf.get('shuffle_conf', {})
dataset = Processor(dataset, processor.shuffle, **shuffle_conf)
67 changes: 64 additions & 3 deletions wekws/dataset/processor.py
@@ -263,6 +263,51 @@ def shuffle(data, shuffle_size=1000):
for x in buf:
yield x

def context_expansion(data, left=1, right=1):
""" expand left and right frames
Args:
data: Iterable[{key, feat, label}]
left (int): feature left context frames
right (int): feature right context frames
Returns:
data: Iterable[{key, feat, label}]
"""
for sample in data:
index = 0
feats = sample['feat']
ctx_dim = feats.shape[0]
ctx_frm = feats.shape[1] * (left + right + 1)
feats_ctx = torch.zeros(ctx_dim, ctx_frm, dtype=torch.float32)
for lag in range(-left, right + 1):
feats_ctx[:, index:index + feats.shape[1]] = torch.roll(
feats, -lag, 0)
index = index + feats.shape[1]

# replication pad left margin
for idx in range(left):
for cpx in range(left - idx):
feats_ctx[idx, cpx * feats.shape[1]:(cpx + 1)
* feats.shape[1]] = feats_ctx[left, :feats.shape[1]]

feats_ctx = feats_ctx[:feats_ctx.shape[0] - right]
sample['feat'] = feats_ctx
yield sample


def frame_skip(data, skip_rate=1):
""" skip frame
Args:
data: Iterable[{key, feat, label}]
skip_rate (int): take every N-frames for model input
Returns:
data: Iterable[{key, feat, label}]
"""
for sample in data:
feats_skip = sample['feat'][::skip_rate, :]
sample['feat'] = feats_skip
yield sample

def batch(data, batch_size=16):
""" Static batch the data by `batch_size`
@@ -302,12 +347,24 @@ def padding(data):
[sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
sorted_feats = [sample[i]['feat'] for i in order]
sorted_keys = [sample[i]['key'] for i in order]
sorted_labels = torch.tensor([sample[i]['label'] for i in order],
dtype=torch.int64)
padded_feats = pad_sequence(sorted_feats,
batch_first=True,
padding_value=0)
yield (sorted_keys, padded_feats, sorted_labels, feats_lengths)

if isinstance(sample[0]['label'], int):
padded_labels = torch.tensor([sample[i]['label'] for i in order],
dtype=torch.int32)
label_lengths = torch.tensor([1 for i in order],
dtype=torch.int32)
else:
sorted_labels = [
torch.tensor(sample[i]['label'], dtype=torch.int32) for i in order
]
label_lengths = torch.tensor([len(sample[i]['label']) for i in order],
dtype=torch.int32)
padded_labels = pad_sequence(
sorted_labels, batch_first=True, padding_value=-1)
yield (sorted_keys, padded_feats, padded_labels, feats_lengths, label_lengths)


def add_reverb(data, reverb_source, aug_prob):
@@ -320,6 +377,8 @@ def add_reverb(data, reverb_source, aug_prob):
rir_io = io.BytesIO(rir_data)
_, rir_audio = wavfile.read(rir_io)
rir_audio = rir_audio.astype(np.float32)
if len(rir_audio.shape) > 1:
rir_audio = rir_audio[:, 0]
rir_audio = rir_audio / np.sqrt(np.sum(rir_audio**2))
out_audio = signal.convolve(audio, rir_audio,
mode='full')[:audio_len]
@@ -348,6 +407,8 @@ def add_noise(data, noise_source, aug_prob):
snr_range = [0, 15]
_, noise_audio = wavfile.read(io.BytesIO(noise_data))
noise_audio = noise_audio.astype(np.float32)
if len(noise_audio.shape) > 1:
noise_audio = noise_audio[:, 0]
if noise_audio.shape[0] > audio_len:
start = random.randint(0, noise_audio.shape[0] - audio_len)
noise_audio = noise_audio[start:start + audio_len]
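
The new context_expansion and frame_skip processors are what make `input_dim: 400` and
`frame_skip: 3` in conf/fsmn_ctc.yaml fit together; a quick shape check with illustrative values:

```python
import torch
from wekws.dataset import processor

sample = {'key': 'utt1', 'feat': torch.randn(100, 80), 'label': [2, 3, 4]}

expanded = processor.context_expansion([sample], left=2, right=2)
skipped = processor.frame_skip(expanded, skip_rate=3)
out = next(skipped)
print(out['feat'].shape)   # torch.Size([33, 400]): 80 dims x 5 context frames, every 3rd frame kept
```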
558 changes: 558 additions & 0 deletions wekws/model/fsmn.py

Large diffs are not rendered by default.

52 changes: 49 additions & 3 deletions wekws/model/kws_model.py
@@ -1,4 +1,5 @@
# Copyright (c) 2021 Binbin Zhang
# 2023 Jing Du
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,14 +26,15 @@
NoSubsampling)
from wekws.model.tcn import TCN, CnnBlock, DsCnnBlock
from wekws.model.mdtc import MDTC
from wekws.utils.cmvn import load_cmvn
from wekws.utils.cmvn import load_cmvn, load_kaldi_cmvn
from wekws.model.fsmn import FSMN


class KWSModel(nn.Module):
"""Our model consists of four parts:
1. global_cmvn: Optional, (idim, idim)
    2. preprocessing: feature dimension projection, (idim, hdim)
3. backbone: backbone or feature extractor of the whole network, (hdim, hdim)
3. backbone: backbone of the whole network, (hdim, hdim)
4. classifier: output layer or classifier of KWS model, (hdim, odim)
5. activation:
nn.Sigmoid for wakeup word
@@ -72,6 +74,20 @@ def forward(
x = self.activation(x)
return x, out_cache

def forward_softmax(self,
x: torch.Tensor,
in_cache: torch.Tensor = torch.zeros(
0, 0, 0, dtype=torch.float)
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.global_cmvn is not None:
x = self.global_cmvn(x)
x = self.preprocessing(x)
x, out_cache = self.backbone(x, in_cache)
x = self.classifier(x)
x = self.activation(x)
x = x.softmax(2)
return x, out_cache

def fuse_modules(self):
self.preprocessing.fuse_modules()
self.backbone.fuse_modules()
@@ -80,7 +96,10 @@ def fuse_modules(self):
def init_model(configs):
cmvn = configs.get('cmvn', {})
if 'cmvn_file' in cmvn and cmvn['cmvn_file'] is not None:
mean, istd = load_cmvn(cmvn['cmvn_file'])
if "kaldi" in cmvn['cmvn_file']:
mean, istd = load_kaldi_cmvn(cmvn['cmvn_file'])
else:
mean, istd = load_cmvn(cmvn['cmvn_file'])
global_cmvn = GlobalCMVN(
torch.from_numpy(mean).float(),
torch.from_numpy(istd).float(),
@@ -135,6 +154,20 @@ def init_model(configs):
hidden_dim,
kernel_size,
causal=causal)
elif backbone_type == 'fsmn':
input_affine_dim = configs['backbone']['input_affine_dim']
num_layers = configs['backbone']['num_layers']
linear_dim = configs['backbone']['linear_dim']
proj_dim = configs['backbone']['proj_dim']
left_order = configs['backbone']['left_order']
right_order = configs['backbone']['right_order']
left_stride = configs['backbone']['left_stride']
right_stride = configs['backbone']['right_stride']
output_affine_dim = configs['backbone']['output_affine_dim']
backbone = FSMN(input_dim, input_affine_dim, num_layers, linear_dim,
proj_dim, left_order, right_order, left_stride,
right_stride, output_affine_dim, output_dim)

else:
print('Unknown body type {}'.format(backbone_type))
sys.exit(1)
@@ -154,6 +187,8 @@ def init_model(configs):
# last means we use last frame to do backpropagation, so the model
# can be infered streamingly
classifier = LastClassifier(classifier_base)
elif classifier_type == 'identity':
classifier = nn.Identity()
else:
print('Unknown classifier type {}'.format(classifier_type))
sys.exit(1)
@@ -162,6 +197,17 @@ def init_model(configs):
classifier = LinearClassifier(hidden_dim, output_dim)
activation = nn.Sigmoid()

    # Here we add a possible "activation_type",
    # so one can choose another activation function.
    # We use nn.Identity for the CTC loss.
if "activation" in configs:
activation_type = configs["activation"]["type"]
if activation_type == 'identity':
activation = nn.Identity()
else:
print('Unknown activation type {}'.format(activation_type))
sys.exit(1)

kws_model = KWSModel(input_dim, output_dim, hidden_dim, global_cmvn,
preprocessing, backbone, classifier, activation)
return kws_model
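
For reference, an untested sketch of building the FSMN-CTC model directly from a config dict
that mirrors the model section of conf/fsmn_ctc.yaml (input_dim/output_dim are normally filled
in by wekws/bin/train.py):

```python
import torch
from wekws.model.kws_model import init_model

configs = {
    'input_dim': 400,      # 80-dim fbank x 5 context frames
    'output_dim': 2599,    # number of CTC tokens in this recipe
    'hidden_dim': 128,
    'preprocessing': {'type': 'none'},
    'backbone': {
        'type': 'fsmn', 'input_affine_dim': 140, 'num_layers': 4,
        'linear_dim': 250, 'proj_dim': 128, 'left_order': 10, 'right_order': 2,
        'left_stride': 1, 'right_stride': 1, 'output_affine_dim': 140,
    },
    'classifier': {'type': 'identity'},
    'activation': {'type': 'identity'},
}
model = init_model(configs)
logits, _ = model(torch.randn(1, 33, 400))   # expected: (1, 33, 2599) CTC logits
```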
338 changes: 337 additions & 1 deletion wekws/model/loss.py
@@ -1,4 +1,5 @@
# Copyright (c) 2021 Binbin Zhang
# 2023 Jing Du
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,7 +14,11 @@
# limitations under the License.

import torch
import math
import sys
import torch.nn.functional as F
from collections import defaultdict
from typing import List, Tuple

from wekws.utils.mask import padding_mask

@@ -93,6 +98,65 @@ def acc_frame(
correct = pred.eq(target.long().view_as(pred)).sum().item()
return correct * 100.0 / logits.size(0)

def acc_utterance(logits: torch.Tensor, target: torch.Tensor,
logits_length: torch.Tensor, target_length: torch.Tensor):
if logits is None:
return 0

logits = logits.softmax(2) # (1, maxlen, vocab_size)
logits = logits.cpu()
target = target.cpu()

total_word = 0
total_ins = 0
total_sub = 0
total_del = 0
calculator = Calculator()
for i in range(logits.size(0)):
score = logits[i][:logits_length[i]]
hyps = ctc_prefix_beam_search(score, logits_length[i], None, 3, 5)
lab = [str(item) for item in target[i][:target_length[i]].tolist()]
rec = []
if len(hyps) > 0:
rec = [str(item) for item in hyps[0][0]]
result = calculator.calculate(lab, rec)
# print(f'result:{result}')
if result['all'] != 0:
total_word += result['all']
total_ins += result['ins']
total_sub += result['sub']
total_del += result['del']

return float(total_word - total_ins - total_sub
- total_del) * 100.0 / total_word

def ctc_loss(logits: torch.Tensor,
target: torch.Tensor,
logits_lengths: torch.Tensor,
target_lengths: torch.Tensor,
need_acc: bool = False):
""" CTC Loss
Args:
        logits: (B, L, D), D is the vocabulary size (blank included)
        target: (B, S), padded label sequences
        logits_lengths: (B), number of frames per utterance
        target_lengths: (B), number of labels per utterance
Returns:
(float): loss of current batch
"""

acc = 0.0
if need_acc:
acc = acc_utterance(logits, target, logits_lengths, target_lengths)

# logits: (B, L, D) -> (L, B, D)
logits = logits.transpose(0, 1)
logits = logits.log_softmax(2)
loss = F.ctc_loss(
logits, target, logits_lengths, target_lengths, reduction='sum')
loss = loss / logits.size(1) # batch mean

return loss, acc
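
# Usage sketch (illustrative shapes; `target` is padded with -1 by the new
# padding() in wekws/dataset/processor.py, and F.ctc_loss only reads the first
# `target_lengths[b]` labels of each row):
#
#     logits = torch.randn(2, 50, 2599)               # (B, L, D)
#     target = torch.tensor([[2, 3, 4], [5, 6, -1]])  # (B, S)
#     loss, acc = ctc_loss(logits, target,
#                          torch.tensor([50, 50]),    # logits_lengths
#                          torch.tensor([3, 2]))      # target_lengths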

def cross_entropy(logits: torch.Tensor, target: torch.Tensor):
""" Cross Entropy Loss
@@ -114,12 +178,284 @@ def criterion(type: str,
logits: torch.Tensor,
target: torch.Tensor,
lengths: torch.Tensor,
min_duration: int = 0):
target_lengths: torch.Tensor = None,
min_duration: int = 0,
validation: bool = False, ):
if type == 'ce':
loss, acc = cross_entropy(logits, target)
return loss, acc
elif type == 'max_pooling':
loss, acc = max_pooling_loss(logits, target, lengths, min_duration)
return loss, acc
elif type == 'ctc':
loss, acc = ctc_loss(
logits, target, lengths, target_lengths, validation)
return loss, acc
else:
exit(1)

def ctc_prefix_beam_search(
logits: torch.Tensor,
logits_lengths: torch.Tensor,
keywords_tokenset: set = None,
score_beam_size: int = 3,
path_beam_size: int = 20,
) -> List[Tuple[Tuple[int, ...], float, List[dict]]]:
    """ CTC prefix beam search inner implementation
    Args:
        logits (torch.Tensor): per-frame CTC posteriors, (max_len, vocab_size)
        logits_lengths (torch.Tensor): (1, )
        keywords_tokenset (set): token set for filtering score
        score_beam_size (int): beam size for score
        path_beam_size (int): beam size for path
    Returns:
        List[Tuple]: nbest results, each one is (prefix ids, score, node list)
"""
maxlen = logits.size(0)
# ctc_probs = logits.softmax(1) # (1, maxlen, vocab_size)
ctc_probs = logits

cur_hyps = [(tuple(), (1.0, 0.0, []))]

# 2. CTC beam search step by step
for t in range(0, maxlen):
probs = ctc_probs[t] # (vocab_size,)
# key: prefix, value (pb, pnb), default value(-inf, -inf)
next_hyps = defaultdict(lambda: (0.0, 0.0, []))

# 2.1 First beam prune: select topk best
top_k_probs, top_k_index = probs.topk(
score_beam_size) # (score_beam_size,)

# filter prob score that is too small
filter_probs = []
filter_index = []
for prob, idx in zip(top_k_probs.tolist(), top_k_index.tolist()):
if keywords_tokenset is not None:
if prob > 0.05 and idx in keywords_tokenset:
filter_probs.append(prob)
filter_index.append(idx)
else:
if prob > 0.05:
filter_probs.append(prob)
filter_index.append(idx)

if len(filter_index) == 0:
continue

for s in filter_index:
ps = probs[s].item()

for prefix, (pb, pnb, cur_nodes) in cur_hyps:
last = prefix[-1] if len(prefix) > 0 else None
if s == 0: # blank
n_pb, n_pnb, nodes = next_hyps[prefix]
n_pb = n_pb + pb * ps + pnb * ps
nodes = cur_nodes.copy()
next_hyps[prefix] = (n_pb, n_pnb, nodes)
elif s == last:
if not math.isclose(pnb, 0.0, abs_tol=0.000001):
# Update *ss -> *s;
n_pb, n_pnb, nodes = next_hyps[prefix]
n_pnb = n_pnb + pnb * ps
nodes = cur_nodes.copy()
if ps > nodes[-1]['prob']: # update frame and prob
nodes[-1]['prob'] = ps
nodes[-1]['frame'] = t
next_hyps[prefix] = (n_pb, n_pnb, nodes)

if not math.isclose(pb, 0.0, abs_tol=0.000001):
# Update *s-s -> *ss, - is for blank
n_prefix = prefix + (s, )
n_pb, n_pnb, nodes = next_hyps[n_prefix]
n_pnb = n_pnb + pb * ps
nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t,
prob=ps)) # to record token prob
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)
else:
n_prefix = prefix + (s, )
n_pb, n_pnb, nodes = next_hyps[n_prefix]
if nodes:
if ps > nodes[-1]['prob']: # update frame and prob
# nodes[-1]['prob'] = ps
# nodes[-1]['frame'] = t
# avoid changing other beams which share this node.
nodes.pop()
nodes.append(dict(token=s, frame=t, prob=ps))
else:
nodes = cur_nodes.copy()
nodes.append(dict(token=s, frame=t,
prob=ps)) # to record token prob
n_pnb = n_pnb + pb * ps + pnb * ps
next_hyps[n_prefix] = (n_pb, n_pnb, nodes)

# 2.2 Second beam prune
next_hyps = sorted(
next_hyps.items(), key=lambda x: (x[1][0] + x[1][1]), reverse=True)

cur_hyps = next_hyps[:path_beam_size]

hyps = [(y[0], y[1][0] + y[1][1], y[1][2]) for y in cur_hyps]
return hyps
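
# Illustrative sketch (not part of the original patch): running
# ctc_prefix_beam_search on random, already-softmaxed frame posteriors and
# checking whether an assumed keyword token sequence survives in a prefix.
# Token ids, tensor shapes and the 0.5 score threshold are assumptions made
# only for this example.
def _demo_ctc_prefix_beam_search():  # hypothetical helper, for illustration only
    num_frames, vocab_size = 80, 10
    probs = torch.randn(num_frames, vocab_size).softmax(dim=1)  # (max_len, vocab_size)
    lengths = torch.tensor([num_frames])
    keyword_tokens = [3, 5, 7]  # assumed token ids of one keyword
    hyps = ctc_prefix_beam_search(probs, lengths,
                                  keywords_tokenset=set(keyword_tokens))

    def contains(prefix, sub):
        # True if `sub` occurs as a contiguous run inside `prefix`
        return any(list(prefix[i:i + len(sub)]) == sub
                   for i in range(len(prefix) - len(sub) + 1))

    for prefix, prob, nodes in hyps:
        # `nodes` additionally records the frame index and probability of each
        # emitted token, which a streaming detector can inspect
        if contains(prefix, keyword_tokens) and prob > 0.5:
            return True, prefix, nodes
    return False, None, None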


class Calculator:

def __init__(self):
self.data = {}
self.space = []
self.cost = {}
self.cost['cor'] = 0
self.cost['sub'] = 1
self.cost['del'] = 1
self.cost['ins'] = 1

def calculate(self, lab, rec):
# Initialization
lab.insert(0, '')
rec.insert(0, '')
while len(self.space) < len(lab):
self.space.append([])
for row in self.space:
for element in row:
element['dist'] = 0
element['error'] = 'non'
while len(row) < len(rec):
row.append({'dist': 0, 'error': 'non'})
for i in range(len(lab)):
self.space[i][0]['dist'] = i
self.space[i][0]['error'] = 'del'
for j in range(len(rec)):
self.space[0][j]['dist'] = j
self.space[0][j]['error'] = 'ins'
self.space[0][0]['error'] = 'non'
for token in lab:
if token not in self.data and len(token) > 0:
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
for token in rec:
if token not in self.data and len(token) > 0:
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
# Computing edit distance
for i, lab_token in enumerate(lab):
for j, rec_token in enumerate(rec):
if i == 0 or j == 0:
continue
min_dist = sys.maxsize
min_error = 'none'
dist = self.space[i - 1][j]['dist'] + self.cost['del']
error = 'del'
if dist < min_dist:
min_dist = dist
min_error = error
dist = self.space[i][j - 1]['dist'] + self.cost['ins']
error = 'ins'
if dist < min_dist:
min_dist = dist
min_error = error
if lab_token == rec_token:
dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
error = 'cor'
else:
dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
error = 'sub'
if dist < min_dist:
min_dist = dist
min_error = error
self.space[i][j]['dist'] = min_dist
self.space[i][j]['error'] = min_error
# Tracing back
result = {
'lab': [],
'rec': [],
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
i = len(lab) - 1
j = len(rec) - 1
while True:
if self.space[i][j]['error'] == 'cor': # correct
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
result['all'] = result['all'] + 1
result['cor'] = result['cor'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'sub': # substitution
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
result['all'] = result['all'] + 1
result['sub'] = result['sub'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'del': # deletion
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
result['all'] = result['all'] + 1
result['del'] = result['del'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, '')
i = i - 1
elif self.space[i][j]['error'] == 'ins': # insertion
if len(rec[j]) > 0:
self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
result['ins'] = result['ins'] + 1
result['lab'].insert(0, '')
result['rec'].insert(0, rec[j])
j = j - 1
elif self.space[i][j]['error'] == 'non': # starting point
break
else: # shouldn't reach here
print(
'this should not happen, '
'i = {i} , j = {j} , error = {error}'
.format(i=i, j=j, error=self.space[i][j]['error']))
return result

def overall(self):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in self.data:
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result

def cluster(self, data):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in data:
if token in self.data:
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result

def keys(self):
return list(self.data.keys())
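
# Illustrative sketch (not part of the original patch): using Calculator to
# align a label sequence against a recognition result and read the error
# counts. The example tokens are made up; calculate() mutates its list
# arguments (it inserts a leading ''), so copies are passed here.
def _demo_calculator():  # hypothetical helper, for illustration only
    calc = Calculator()
    lab = ['ni', 'hao', 'wen', 'wen']
    rec = ['ni', 'hao', 'wen']
    result = calc.calculate(lab[:], rec[:])  # per-utterance 'all/cor/sub/ins/del'
    overall = calc.overall()                 # counts accumulated over all calls
    return result, overall
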
48 changes: 48 additions & 0 deletions wekws/utils/cmvn.py
@@ -15,6 +15,7 @@

import json
import math
import re

import numpy as np

@@ -42,3 +43,50 @@ def load_cmvn(json_cmvn_file):
variance[i] = 1.0 / math.sqrt(variance[i])
cmvn = np.array([means, variance])
return cmvn

def load_kaldi_cmvn(cmvn_file):
""" Load the kaldi format cmvn stats file and no need to calculate
Args:
cmvn_file: cmvn stats file in kaldi format
Returns:
a numpy array of [means, vars]
"""

means = None
variance = None
copy_times = 1  # default when no <Splice> component is found
with open(cmvn_file) as f:
all_lines = f.readlines()
for idx, line in enumerate(all_lines):
if line.find('AddShift') != -1:
segs = line.strip().split(' ')
assert len(segs) == 3
next_line = all_lines[idx + 1]
means_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
means_list = means_str.strip().split(' ')
means = [0 - float(s) for s in means_list]
assert len(means) == int(segs[1])
elif line.find('Rescale') != -1:
segs = line.strip().split(' ')
assert len(segs) == 3
next_line = all_lines[idx + 1]
vars_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
vars_list = vars_str.strip().split(' ')
variance = [float(s) for s in vars_list]
assert len(variance) == int(segs[1])
elif line.find('Splice') != -1:
segs = line.strip().split(' ')
assert len(segs) == 3
next_line = all_lines[idx + 1]
splice_str = re.findall(r'[\[](.*?)[\]]', next_line)[0]
splice_list = splice_str.strip().split(' ')
assert len(splice_list) * int(segs[2]) == int(segs[1])
copy_times = len(splice_list)
else:
continue

cmvn = np.array([means, variance])
cmvn = np.tile(cmvn, (1, copy_times))

return cmvn
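
# Illustrative sketch (not part of the original patch): the Kaldi nnet-style
# text this loader expects and how it could be called. The component lines,
# dimensions and values below are made up purely to satisfy the
# Splice/AddShift/Rescale parsing above.
def _demo_load_kaldi_cmvn():  # hypothetical helper, for illustration only
    import tempfile  # local import keeps the sketch self-contained
    fake = (
        '<Splice> 15 3\n'
        '[ -2 -1 0 1 2 ]\n'
        '<AddShift> 3 3\n'
        '[ -0.1 -0.2 -0.3 ]\n'
        '<Rescale> 3 3\n'
        '[ 1.0 2.0 3.0 ]\n'
    )
    with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
        f.write(fake)
        path = f.name
    cmvn = load_kaldi_cmvn(path)  # expected shape: (2, 3 * 5) = (2, 15)
    return cmvn
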
15 changes: 11 additions & 4 deletions wekws/utils/executor.py
@@ -34,17 +34,20 @@ def train(self, model, optimizer, data_loader, device, writer, args):
min_duration = args.get('min_duration', 0)

for batch_idx, batch in enumerate(data_loader):
key, feats, target, feats_lengths = batch
key, feats, target, feats_lengths, label_lengths = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
label_lengths = label_lengths.to(device)
num_utts = feats_lengths.size(0)
if num_utts == 0:
continue
logits, _ = model(feats)
loss_type = args.get('criterion', 'max_pooling')
loss, acc = criterion(loss_type, logits, target, feats_lengths,
min_duration)
target_lengths=label_lengths,
min_duration=min_duration,
validation=False)
optimizer.zero_grad()
loss.backward()
grad_norm = clip_grad_norm_(model.parameters(), clip)
@@ -67,16 +70,20 @@ def cv(self, model, data_loader, device, args):
total_acc = 0.0
with torch.no_grad():
for batch_idx, batch in enumerate(data_loader):
key, feats, target, feats_lengths = batch
key, feats, target, feats_lengths, label_lengths = batch
feats = feats.to(device)
target = target.to(device)
feats_lengths = feats_lengths.to(device)
label_lengths = label_lengths.to(device)
num_utts = feats_lengths.size(0)
if num_utts == 0:
continue
logits, _ = model(feats)
loss, acc = criterion(args.get('criterion', 'max_pooling'),
logits, target, feats_lengths)
logits, target, feats_lengths,
target_lengths=label_lengths,
min_duration=0,
validation=True)
if torch.isfinite(loss):
num_seen_utts += num_utts
total_loss += loss.item() * num_utts