
Whisper inference support in cpp runtime #2320

Merged: 12 commits, merged into wenet-e2e:main on Jan 25, 2024

Conversation

zhr1201 (Contributor) commented on Jan 23, 2024:

We were trying to reproduce the steps to load whisper in wenet and finetune it for streaming on LibriSpeech, following https://github.com/wenet-e2e/wenet/tree/main/examples/aishell/whisper. Finetuning did seem to work, but when we tried to load the model in the cpp runtime, it only produced empty results. After further inspection, it turns out whisper's feature extraction is very different from the one implemented in the wenet runtime, mainly regarding:

  • STFT window: wenet uses a Povey window, whisper a periodic Hanning window.
  • Mel scale: there are two commonly used versions: (1) the HTK formula, which wenet and Kaldi use, and (2) the Slaney formula from the MATLAB Auditory Toolbox, which librosa implements and which was used to train whisper (see the sketch below).
  • Mel filter weights: Kaldi computes the triangular weights from frequencies on the mel scale, whereas whisper training uses weights computed on the original linear frequency scale.
  • Normalization after computing the PSD: the base of the log operation and how the output is scaled both differ.
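
To make the mel-scale difference concrete, here is a minimal sketch of the two conversion formulas, assuming the standard HTK constants (used by Kaldi/wenet) and the Slaney constants (used by librosa, and hence whisper); the snippet is illustrative, not the PR's actual code:

```cpp
#include <cmath>

enum class MelType { HTK, Slaney };

static inline float MelScale(float freq, MelType mel_type) {
  if (mel_type == MelType::HTK) {
    // HTK/Kaldi convention: logarithmic everywhere.
    return 1127.0f * logf(1.0f + freq / 700.0f);
  }
  // Slaney convention (librosa default): linear below 1 kHz, log above.
  const float f_sp = 200.0f / 3.0f;             // ~66.67 Hz per mel
  const float min_log_hz = 1000.0f;
  const float min_log_mel = min_log_hz / f_sp;  // 15 mel at 1 kHz
  const float logstep = logf(6.4f) / 27.0f;     // log spacing above 1 kHz
  if (freq >= min_log_hz) {
    return min_log_mel + logf(freq / min_log_hz) / logstep;
  }
  return freq / f_sp;
}
```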

Examples showing it works:

decoder_main using whisper (screenshot: whisper_cli)

The CLI is still compatible with the original code; the default behavior doesn't change (screenshot: compatible_cli).

There are still a couple of places that are not perfect:

  1. The token list in the above example is exported directly from the whisper tokenizer, so the output contains things like <|notimestamp|> and <|space|>, which doesn't look nice. Maybe this can be solved purely by cleaning up the token list, but it's also possible that we need some runtime code changes to make the output look good. If we do, I think that's a separate problem and deserves a separate PR; however, if there are simple fixes, I am happy to do them in this PR.
  2. The computed feature is still a bit different from the one computed in python by directly calling the whisper code. I suspect the difference mainly comes from the FFT computation; more details in the comments below.

I know these are a lot of changes, so I am more than happy to restructure them into whatever style you prefer if you think there is a better way.

zhr1201 commented on Jan 23, 2024:

The Hanning window generated in the runtime versus the one from torch; the differences are very small. (screenshot: window)

zhr1201 commented on Jan 23, 2024:

The mel filters generated by the cpp runtime are also very close to the ones from librosa. (screenshot: filters)

zhr1201 commented on Jan 23, 2024:

The computed feature is pretty similar, but NOT THE SAME! Notice the small difference in the max and min values. (screenshot)

I dumped the PSD after the STFT as well and compared it with the one computed using torch; I think this might be the source of the difference. (screenshot: psd_difference)

Diff excerpt — the Slaney weight computation:

```cpp
        weight = (right_mel - mel) / (right_mel - center_mel);
      } else if (mel_type == MelType::Slaney) {
        if (mel <= center_mel) {
          weight = (InverseMelScale(mel, mel_type) -
```
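
For reference, a minimal sketch of a Slaney-style triangular weight computed directly on the linear frequency scale (as librosa does with norm="slaney"); the helper name and signature are hypothetical, not the PR's code:

```cpp
// Hypothetical helper: the triangle rises and falls linearly in Hz, then
// gets librosa's "slaney" area normalization of 2 / bandwidth.
float SlaneyWeight(float freq_hz, float left_hz, float center_hz,
                   float right_hz) {
  float w = 0.0f;
  if (freq_hz > left_hz && freq_hz < right_hz) {
    w = (freq_hz <= center_hz)
            ? (freq_hz - left_hz) / (center_hz - left_hz)
            : (right_hz - freq_hz) / (right_hz - center_hz);
  }
  return w * 2.0f / (right_hz - left_hz);
}
```

Kaldi instead evaluates the same triangle on mel-scaled frequencies, which yields slightly different weights.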

Diff excerpt — the periodic Hanning window:

```cpp
      window_[i] = pow(0.5 - 0.5 * cos(a * i), 0.85);
    } else if (window_type == WindowType::Hanning) {
      // periodic hanning window
      double a = M_2PI / (frame_length_);
```
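
As an aside, the periodic/symmetric distinction is just the denominator; a minimal illustrative sketch (not the PR's code), where the periodic variant matches the default of torch.hann_window used by whisper:

```cpp
#include <cmath>
#include <vector>

std::vector<double> HanningWindow(int n, bool periodic) {
  // periodic: denominator n (matches torch.hann_window's default);
  // symmetric: denominator n - 1 (the classic textbook definition).
  const double denom = periodic ? n : n - 1;
  std::vector<double> w(n);
  for (int i = 0; i < n; ++i) {
    w[i] = 0.5 - 0.5 * std::cos(2.0 * M_PI * i / denom);
  }
  return w;
}
```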

Diff excerpt — MelScale gains a mel-type parameter:

```diff
-static inline float MelScale(float freq) {
-  return 1127.0f * logf(1.0f + freq / 700.0f);
+static inline float MelScale(float freq, MelType mel_type = MelType::HTK) {
```

Diff excerpt — InverseMelScale also takes a mel type:

```cpp
    }
  }

  static inline float InverseMelScale(float mel_freq,
```

zhr1201 commented on Jan 24, 2024 (inline review):

This can be further optimized if needed; there are a lot of repeated computations, but the compiler may already optimize them away through constant propagation.


Diff excerpt — optional input scaling:

```cpp
if (scaled_float_as_input_) {
  for (int j = 0; j < frame_length_; ++j) {
    data[j] = data[j] / S16_TO_FLOAT_SCALE;
```

zhr1201 (inline review): The data fed into this pipeline is integer audio converted to float without scaling, whereas the whisper training code loads audio as floats between -1 and 1.

zhr1201 changed the title from "Whisper inference support in runtime" to "Whisper inference support in cpp runtime" on Jan 23, 2024
zhr1201 marked this pull request as ready for review on January 23, 2024 at 23:25
robin1001 (Collaborator) commented:

Great job, it's clear.

xingchensong (Member) commented on Jan 24, 2024:

Great job!

I think kaldifeat's implementation for computing whisper features could be beneficial for our work.

xingchensong commented:

BTW, why does the CTC result contain <|notimestamp|>? The label provided to the CTC loss function doesn't include this tag (only the label given to the CE loss has it), so <|notimestamp|> shouldn't appear during CTC decoding.

zhr1201 commented on Jan 24, 2024:

Thanks for referencing this implementation, it's really helpful! Yeah, they are doing basically the same thing we want to do. I think we have two ways to support whisper inference in wenet:

  1. Create another FBank implementation, like kaldifeat did, copy over their code, and implement an interface with Compute so we can integrate it into wenet's feature pipeline.

  2. Update Fbank with more options, like what we are doing in this PR.

For 1, since the Fbank computation is not a lot of code, I think it makes sense to have a separate WhisperFbankComputer even without reusing the current FBank code. However, for 1 I don't like that (a) they hard-coded the mel weights (https://github.com/csukuangfj/kaldifeat/blob/master/kaldifeat/csrc/whisper-mel-bank.h), which is less flexible, and (b) they compute the STFT using torch. Since wenet supports runtimes other than torch, we probably don't want to depend on torch in the feature extraction part. That said, if we do think 1 is the preferred structure, we can reuse the current wenet FFT so (b) is not a problem, and we can reuse the code in this PR to generate the filters so (a) is not a problem either.

Based on the above, I think it's mostly a style question: we need to decide whether to reuse Fbank or create a separate WhisperFbank.
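
For illustration, option 2 might look roughly like this at the call site; the setter names here are hypothetical, not the actual API of this PR:

```cpp
// Hypothetical usage sketch of option 2: one Fbank class with extra options.
Fbank fbank(/*num_bins=*/80, /*sample_rate=*/16000,
            /*frame_length=*/400, /*frame_shift=*/160);
fbank.set_window_type(WindowType::Hanning);  // periodic hanning, like whisper
fbank.set_mel_type(MelType::Slaney);         // librosa-style mel scale
fbank.set_log_base(10.0);                    // whisper normalizes with log10
```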

zhr1201 commented on Jan 24, 2024:

> BTW, why does the CTC result contain <|notimestamp|>? The label provided to the CTC loss function doesn't include this tag (only the label given to the CE loss has it), so <|notimestamp|> shouldn't appear during CTC decoding.

Very good question, I am curious as well. I checked again and it does look like those special whisper tokens are not added in the CTC loss calculation, so this shouldn't be possible. Maybe something is wrong with my training setup?

Currently my only hypothesis is that our model didn't converge well (and it generalized very badly on out-of-domain data). It looks like we did make a mistake in training by not changing

language="zh",

when finetuning for English, but I guess that still doesn't explain why we get this <|notimestamp|> token. I will try finetuning again, but it would be great if you could confirm whether the same thing happens with the AISHELL whisper recipe.

Diff excerpt:

```diff
@@ -24,13 +24,43 @@
 #include "frontend/fft.h"
 #include "utils/log.h"

+#define S16_ABS_MAX (2 << 15)
```
zhr1201 commented on Jan 24, 2024 (inline review):
Another concern: this probably shouldn't be hard-coded here, as scaling the input should be the responsibility of the wav reader rather than of the feature extraction pipeline. Moving it to the wav reader would also add flexibility for audio encoded as pcm_s32 or pcm_s8 instead of fixing it to pcm_s16.

However, doing that would also require changes in the http server, websocket server, and grpc server code, and possibly other places like the jni bindings for android. That feels like a decision for the main maintainers of wenet.

(We could do it in a hacky way, e.g. letting the cli take a scale factor and making it a parameter of the feature extraction pipeline, but that doesn't feel right: if the cli takes a list of wav files encoded with different bit depths, a single scale factor won't work, since different samples require different scaling factors.)

I will leave this as it is for now, but let me know your thoughts if you want to update it in this PR; if you would rather fix it in a separate PR, that also works.
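
A minimal sketch of what scaling inside the wav reader could look like, assuming the bit depth is taken from the wav header; the names and signature are illustrative:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical: normalize integer PCM to floats in [-1, 1) based on the
// bit depth from the wav header, instead of hard-coding pcm_s16.
std::vector<float> NormalizePcm(const std::vector<int32_t>& raw,
                                int bits_per_sample) {
  const float scale = 1.0f / static_cast<float>(1u << (bits_per_sample - 1));
  std::vector<float> out(raw.size());
  for (size_t i = 0; i < raw.size(); ++i) {
    out[i] = raw[i] * scale;  // pcm_s16 -> /32768, pcm_s8 -> /128, ...
  }
  return out;
}
```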

robin1001 (Collaborator) commented:
> Based on the above, I think it's mostly a style question: we need to decide whether to reuse Fbank or create a separate WhisperFbank.

Yes, I think the current implementation is okay. There is a lot of hard-coded mel weight data for the whisper fbank in kaldifeat, which is not preferred.

zhr1201 commented on Jan 25, 2024:

Regarding the STFT difference, I think there is no way to make them match exactly. The reason is that when fft_length != window_length, torch's STFT pads the window on both the left and the right: https://github.com/pytorch/pytorch/blob/2d7a360911fb7b27be82c51ca86b4b34b6f1b087/aten/src/ATen/native/SpectralOps.cpp#L936. Normally FFT energy doesn't depend on how you pad the input, but this is how torch splits the audio into frames: https://github.com/pytorch/pytorch/blob/2d7a360911fb7b27be82c51ca86b4b34b6f1b087/aten/src/ATen/native/SpectralOps.cpp#L949. Because of this, padding the window in different places makes different parts of the raw signal get multiplied by different parts of the window, resulting in a different PSD.

But I think it's probably fine; the result will be the same if we shift the sequence by (fft_length - window_length) / 2:

```python
import torch
import torch.nn.functional as F

# Assumed defined elsewhere:
#   wav: 1-D waveform tensor in [-1, 1]
#   filters_512: slaney mel filterbank for n_fft=512,
#                e.g. from librosa.filters.mel(sr=16000, n_fft=512, n_mels=80)
window = torch.hann_window(400)

# pad the input so the window aligns with the audio the same way wenet does;
# 56 = (fft_length - window_length) / 2 = (512 - 400) / 2
padded_wav = F.pad(wav, (56, 56), "constant", 0)
stft = torch.stft(padded_wav,
                  512,
                  160,
                  window=window,
                  center=False,  # this is another trivial source of difference
                  win_length=400,
                  return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2
mel_spec_512 = filters_512 @ magnitudes
log_spec_before_norm_512 = torch.clamp(mel_spec_512, min=1e-10).log10()
log_spec_before_norm_512 = torch.maximum(log_spec_before_norm_512,
                                         log_spec_before_norm_512.max() - 8.0)
log_spec_after_norm_512 = (log_spec_before_norm_512 + 4.0) / 4.0
```

and we get almost the same result. (screenshot)

I think this is a feature, not a bug: the ASR result should not change when the input is shifted by some number of samples.

robin1001 previously approved these changes on Jan 25, 2024
robin1001 merged commit baaa27a into wenet-e2e:main on Jan 25, 2024
6 checks passed
zhr1201 commented on Jan 25, 2024:

Thanks everyone for the review and approval. May wenet keep growing stronger and gaining more users!

robin1001 (Collaborator) commented:

Open source depends on everyone. Thanks for the contribution!

zhr1201 commented on Feb 2, 2024:

For anyone confused about <|notimestamps|>: the token list I used was wrong. That token actually corresponds to the CTC blank token, which won't appear in the final transcript.

Related issue: #2329
