Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Whisper inference support in cpp runtime #2320

Merged
merged 12 commits into from
Jan 25, 2024
14 changes: 13 additions & 1 deletion runtime/core/decoder/params.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ DEFINE_int32(core_number, 1, "Core number of process");
// FeaturePipelineConfig flags
DEFINE_int32(num_bins, 80, "num mel bins for fbank feature");
DEFINE_int32(sample_rate, 16000, "sample rate for audio");
DEFINE_string(feat_type, "kaldi", "Type of feature extraction: kaldi, whisper");

// TLG fst
DEFINE_string(fst_path, "", "TLG fst path");
Expand Down Expand Up @@ -115,9 +116,20 @@ DEFINE_int32(language_type, 0,
DEFINE_bool(lowercase, true, "lowercase final result if needed");

namespace wenet {

// Translates the textual value of --feat_type into a FeatureType.
//
// Only the exact, case-sensitive spellings "kaldi" and "whisper" are accepted;
// every other string results in std::invalid_argument being thrown.
FeatureType StringToFeatureType(const std::string& feat_type_str) {
  const bool is_kaldi = (feat_type_str == "kaldi");
  if (is_kaldi) {
    return FeatureType::KALDI;
  }

  const bool is_whisper = (feat_type_str == "whisper");
  if (is_whisper) {
    return FeatureType::Whisper;
  }

  throw std::invalid_argument("Unsupported feat type!");
}

std::shared_ptr<FeaturePipelineConfig> InitFeaturePipelineConfigFromFlags() {
FeatureType feat_type = StringToFeatureType(FLAGS_feat_type);
auto feature_config = std::make_shared<FeaturePipelineConfig>(
FLAGS_num_bins, FLAGS_sample_rate);
FLAGS_num_bins, FLAGS_sample_rate, feat_type);
return feature_config;
}

Expand Down
201 changes: 167 additions & 34 deletions runtime/core/frontend/fbank.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,53 @@
#define FRONTEND_FBANK_H_

#include <cstring>
#include <fstream>
#include <limits>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "frontend/fft.h"
#include "utils/log.h"

#define S16_TO_FLOAT_SCALE 32768

namespace wenet {

// This code is based on kaldi Fbank implementation, please see
// https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc

// Analysis window that Fbank::InitWindow precomputes and that is multiplied
// onto every frame before the FFT.
enum class WindowType {
  Povey,    // pow(0.5 - 0.5 * cos(M_2PI * i / (N - 1)), 0.85)
  Hanning,  // periodic form: 0.5 * (1 - cos(M_2PI * i / N))
};

// Mel-scale formula used by Fbank::MelScale / Fbank::InverseMelScale and by
// the triangular filter construction.
enum class MelType {
  HTK,     // mel = 1127 * ln(1 + hz / 700)
  Slaney,  // linear below 1000 Hz (200/3 Hz per mel), logarithmic above
};

// Post-processing applied to the computed features at the end of
// Fbank::Compute.
enum class NormalizationType {
  KALDI,    // features are returned as computed
  Whisper,  // Fbank::WhisperNorm is run over all frames
};

// Base of the logarithm taken of each mel energy when log output is enabled.
enum class LogBase {
  BaseE = 0,   // natural logarithm (logf)
  Base10 = 1,  // base-10 logarithm (log10)
};

class Fbank {
public:
Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift)
Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift,
float low_freq = 20, bool pre_emphasis = true,
bool scaled_float_as_input = false,
float log_floor = std::numeric_limits<float>::epsilon(),
LogBase log_base = LogBase::BaseE,
WindowType window_type = WindowType::Povey,
MelType mel_type = MelType::HTK,
NormalizationType norm_type = NormalizationType::KALDI)
: num_bins_(num_bins),
sample_rate_(sample_rate),
frame_length_(frame_length),
Expand All @@ -39,7 +71,12 @@ class Fbank {
remove_dc_offset_(true),
generator_(0),
distribution_(0, 1.0),
dither_(0.0) {
dither_(0.0),
pre_emphasis_(pre_emphasis),
scaled_float_as_input_(scaled_float_as_input),
log_floor_(log_floor),
log_base_(log_base),
norm_type_(norm_type) {
fft_points_ = UpperPowerOfTwo(frame_length_);
// generate bit reversal table and trigonometric function table
const int fft_points_4 = fft_points_ / 4;
Expand All @@ -50,29 +87,48 @@ class Fbank {

int num_fft_bins = fft_points_ / 2;
float fft_bin_width = static_cast<float>(sample_rate_) / fft_points_;
int low_freq = 20, high_freq = sample_rate_ / 2;
float mel_low_freq = MelScale(low_freq);
float mel_high_freq = MelScale(high_freq);
int high_freq = sample_rate_ / 2;
float mel_low_freq = MelScale(low_freq, mel_type);
float mel_high_freq = MelScale(high_freq, mel_type);
float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);
bins_.resize(num_bins_);
center_freqs_.resize(num_bins_);

for (int bin = 0; bin < num_bins; ++bin) {
float left_mel = mel_low_freq + bin * mel_freq_delta,
center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
center_freqs_[bin] = InverseMelScale(center_mel);
center_freqs_[bin] = InverseMelScale(center_mel, mel_type);
std::vector<float> this_bin(num_fft_bins);
int first_index = -1, last_index = -1;
for (int i = 0; i < num_fft_bins; ++i) {
float freq = (fft_bin_width * i); // Center frequency of this fft
// bin.
float mel = MelScale(freq);
float mel = MelScale(freq, mel_type);
if (mel > left_mel && mel < right_mel) {
float weight;
if (mel <= center_mel)
weight = (mel - left_mel) / (center_mel - left_mel);
else
weight = (right_mel - mel) / (right_mel - center_mel);
if (mel_type == MelType::HTK) {
if (mel <= center_mel)
weight = (mel - left_mel) / (center_mel - left_mel);
else if (mel > center_mel)
weight = (right_mel - mel) / (right_mel - center_mel);
} else if (mel_type == MelType::Slaney) {
if (mel <= center_mel) {
weight = (InverseMelScale(mel, mel_type) -
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InverseMelScale(left_mel, mel_type)) /
(InverseMelScale(center_mel, mel_type) -
InverseMelScale(left_mel, mel_type));
weight *= 2.0 / (InverseMelScale(right_mel, mel_type) -
InverseMelScale(left_mel, mel_type));
} else if (mel > center_mel) {
weight = (InverseMelScale(right_mel, mel_type) -
InverseMelScale(mel, mel_type)) /
(InverseMelScale(right_mel, mel_type) -
InverseMelScale(center_mel, mel_type));
weight *= 2.0 / (InverseMelScale(right_mel, mel_type) -
InverseMelScale(left_mel, mel_type));
}
}
this_bin[i] = weight;
if (first_index == -1) first_index = i;
last_index = i;
Expand All @@ -86,13 +142,7 @@ class Fbank {
bins_[bin].second[i] = this_bin[first_index + i];
}
}

// povey window
povey_window_.resize(frame_length_);
double a = M_2PI / (frame_length - 1);
for (int i = 0; i < frame_length; ++i) {
povey_window_[i] = pow(0.5 - 0.5 * cos(a * i), 0.85);
}
InitWindow(window_type);
}

// Switches the log step that Compute() applies to each mel energy on or off.
void set_use_log(bool use_log) {
  use_log_ = use_log;
}
Expand All @@ -105,12 +155,56 @@ class Fbank {

// Number of mel filterbank bins, i.e. the length of each feature frame
// produced by Compute().
int num_bins() const {
  return num_bins_;
}

static inline float InverseMelScale(float mel_freq) {
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
// Fills window_ with frame_length_ coefficients of the requested window.
//
//   Povey:   w[i] = (0.5 - 0.5 * cos(M_2PI * i / (frame_length_ - 1)))^0.85
//   Hanning: w[i] = 0.5 * (1 - cos(M_2PI * i / frame_length_))
//
// The Hanning variant divides by frame_length_ rather than frame_length_ - 1,
// which is the periodic form of the window. For any other window_type the
// coefficients keep the values produced by resize().
void InitWindow(WindowType window_type) {
  window_.resize(frame_length_);
  if (window_type == WindowType::Povey) {
    // povey window
    const double a = M_2PI / (frame_length_ - 1);
    for (int i = 0; i < frame_length_; ++i) {
      window_[i] = pow(0.5 - 0.5 * cos(a * i), 0.85);
    }
  } else if (window_type == WindowType::Hanning) {
    // periodic hanning window
    const double a = M_2PI / (frame_length_);
    for (int i = 0; i < frame_length_; ++i) {
      window_[i] = 0.5 * (1.0 - cos(i * a));
    }
  }
}

// Converts a mel-scale value back to a frequency in Hz.
//
//   HTK:    hz = 700 * (exp(mel / 1127) - 1)
//   Slaney: hz = (200 / 3) * mel                      below the 1000 Hz knee
//           hz = 1000 * exp(logstep * (mel - knee))   at or above it,
//           with knee = 1000 / (200 / 3) mel and logstep = ln(6.4) / 27.
//
// The HTK formula is also the fallback for any unexpected mel_type value, so
// every control path returns a value.
//
// Review note from the change author: this can be optimized further if
// needed; there are a lot of repeated computations, although some compilers
// may already remove them through constant propagation.
static inline float InverseMelScale(float mel_freq,
                                    MelType mel_type = MelType::HTK) {
  if (mel_type == MelType::Slaney) {
    const float f_min = 0.0f;
    const float f_sp = 200.0f / 3.0f;
    const float min_log_hz = 1000.0f;
    const float min_log_mel = (min_log_hz - f_min) / f_sp;
    const float logstep = logf(6.4f) / 27.0f;
    if (mel_freq >= min_log_mel) {
      // Logarithmic region above the knee.
      return min_log_hz * expf(logstep * (mel_freq - min_log_mel));
    }
    // Linear region below the knee.
    return f_min + f_sp * mel_freq;
  }
  return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
}

static inline float MelScale(float freq) {
return 1127.0f * logf(1.0f + freq / 700.0f);
// Converts a frequency in Hz to the mel scale.
//
//   HTK:    mel = 1127 * ln(1 + hz / 700)
//   Slaney: mel = hz / (200 / 3)                          below 1000 Hz
//           mel = knee + ln(hz / 1000) / logstep          at or above it,
//           with knee = 1000 / (200 / 3) mel and logstep = ln(6.4) / 27.
//
// The HTK formula is also the fallback for any unexpected mel_type value, so
// every control path returns a value.
static inline float MelScale(float freq, MelType mel_type = MelType::HTK) {
  if (mel_type == MelType::Slaney) {
    const float f_min = 0.0f;
    const float f_sp = 200.0f / 3.0f;
    const float min_log_hz = 1000.0f;
    const float min_log_mel = (min_log_hz - f_min) / f_sp;
    const float logstep = logf(6.4f) / 27.0f;
    if (freq >= min_log_hz) {
      // Logarithmic region above the knee.
      return min_log_mel + logf(freq / min_log_hz) / logstep;
    }
    // Linear region below the knee.
    return (freq - f_min) / f_sp;
  }
  return 1127.0f * logf(1.0f + freq / 700.0f);
}

static int UpperPowerOfTwo(int n) {
Expand All @@ -125,26 +219,50 @@ class Fbank {
(*data)[0] -= coeff * (*data)[0];
}

// Apply povey window on data in place
void Povey(std::vector<float>* data) const {
CHECK_GE(data->size(), povey_window_.size());
for (size_t i = 0; i < povey_window_.size(); ++i) {
(*data)[i] *= povey_window_[i];
// Scales the first window_.size() samples of *data by the matching window
// coefficient, in place. CHECK_GE enforces that *data holds at least as many
// samples as the window; samples beyond the window length are left untouched.
void ApplyWindow(std::vector<float>* data) const {
  CHECK_GE(data->size(), window_.size());
  float* samples = data->data();
  const size_t count = window_.size();
  for (size_t k = 0; k != count; ++k) {
    samples[k] = samples[k] * window_[k];
  }
}

// Normalizes every frame of *feat in place.
//
// For each of the first num_bins_ values of every frame:
//   1. values below (max_mel_energy - 8) are raised to that floor;
//   2. the result is mapped to (value + 4) / 4.
//
// feat            frames to normalize; each frame must hold at least
//                 num_bins_ values.
// max_mel_energy  reference maximum that the floor is derived from.
void WhisperNorm(std::vector<std::vector<float>>* feat, float max_mel_energy) {
  // The floor is the same for every element, so compute it once.
  const float floor_energy = max_mel_energy - 8;
  for (auto& frame : *feat) {
    for (int j = 0; j < num_bins_; ++j) {
      const float clamped = frame[j] < floor_energy ? floor_energy : frame[j];
      frame[j] = (clamped + 4.0f) / 4.0f;
    }
  }
}

// Compute fbank feat, return num frames
int Compute(const std::vector<float>& wave,
std::vector<std::vector<float>>* feat) {
int num_samples = wave.size();

if (num_samples < frame_length_) return 0;
int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_);
feat->resize(num_frames);
std::vector<float> fft_real(fft_points_, 0), fft_img(fft_points_, 0);
std::vector<float> power(fft_points_ / 2);

float max_mel_engery = std::numeric_limits<float>::min();

for (int i = 0; i < num_frames; ++i) {
std::vector<float> data(wave.data() + i * frame_shift_,
wave.data() + i * frame_shift_ + frame_length_);

if (scaled_float_as_input_) {
for (int j = 0; j < frame_length_; ++j) {
data[j] = data[j] / S16_TO_FLOAT_SCALE;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Data fed into this pipeline is integer-valued but is converted to float without scaling, whereas the Whisper training code loads it as float in the range -1 to 1.

}
}

// optional add noise
if (dither_ != 0.0) {
for (size_t j = 0; j < data.size(); ++j)
Expand All @@ -158,8 +276,10 @@ class Fbank {
for (size_t j = 0; j < data.size(); ++j) data[j] -= mean;
}

PreEmphasis(0.97, &data);
Povey(&data);
if (pre_emphasis_) {
PreEmphasis(0.97, &data);
}
ApplyWindow(&data);
// copy data to fft_real
memset(fft_img.data(), 0, sizeof(float) * fft_points_);
memset(fft_real.data() + frame_length_, 0,
Expand All @@ -174,6 +294,7 @@ class Fbank {

(*feat)[i].resize(num_bins_);
// cepstral coefficients, triangle filter array

for (int j = 0; j < num_bins_; ++j) {
float mel_energy = 0.0;
int s = bins_[j].first;
Expand All @@ -182,14 +303,20 @@ class Fbank {
}
// optional use log
if (use_log_) {
if (mel_energy < std::numeric_limits<float>::epsilon())
mel_energy = std::numeric_limits<float>::epsilon();
mel_energy = logf(mel_energy);
}
if (mel_energy < log_floor_) mel_energy = log_floor_;

if (log_base_ == LogBase::BaseE)
mel_energy = logf(mel_energy);
else if (log_base_ == LogBase::Base10)
mel_energy = log10(mel_energy);
}
if (max_mel_engery < mel_energy) max_mel_engery = mel_energy;
(*feat)[i][j] = mel_energy;
}
}
if (norm_type_ == NormalizationType::Whisper)
WhisperNorm(feat, max_mel_engery);

return num_frames;
}

Expand All @@ -200,9 +327,15 @@ class Fbank {
int fft_points_;
bool use_log_;
bool remove_dc_offset_;
bool pre_emphasis_;
bool scaled_float_as_input_;
float log_floor_;
LogBase log_base_;
NormalizationType norm_type_;

std::vector<float> center_freqs_;
std::vector<std::pair<int, std::vector<float>>> bins_;
std::vector<float> povey_window_;
std::vector<float> window_;
std::default_random_engine generator_;
std::normal_distribution<float> distribution_;
float dither_;
Expand Down
4 changes: 3 additions & 1 deletion runtime/core/frontend/feature_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ FeaturePipeline::FeaturePipeline(const FeaturePipelineConfig& config)
: config_(config),
feature_dim_(config.num_bins),
fbank_(config.num_bins, config.sample_rate, config.frame_length,
config.frame_shift),
config.frame_shift, config.low_freq, config.pre_emphasis,
config.scaled_float_as_input, config.log_floor, config.log_base,
config.window_type, config.mel_type, config.norm_type),
num_frames_(0),
input_finished_(false) {}

Expand Down
Loading
Loading