Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update whisper.cpp version to 1.6.2 #142

Merged
merged 9 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ exclude = ["examples/full_usage"]

[package]
name = "whisper-rs"
version = "0.11.0"
version = "0.12.0"
edition = "2021"
description = "Rust bindings for whisper.cpp"
license = "Unlicense"
Expand All @@ -14,7 +14,7 @@ repository = "https://github.com/tazz4843/whisper-rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
whisper-rs-sys = { path = "sys", version = "0.8" }
whisper-rs-sys = { path = "sys", version = "0.10.0" }
log = { version = "0.4", optional = true }
tracing = { version = "0.1", optional = true }

Expand Down
50 changes: 48 additions & 2 deletions examples/audio_transcription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,37 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextPar
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
fn main() -> Result<(), &'static str> {
// Load a context and model.
let mut context_param = WhisperContextParameters::default();

// Enable DTW token level timestamp for known model by using model preset
context_param.dtw_parameters.mode = whisper_rs::DtwMode::ModelPreset {
model_preset: whisper_rs::DtwModelPreset::BaseEn,
};

// Enable DTW token level timestamp for unknown model by providing custom aheads
// see details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143
// values correspond to ggml-base.en.bin; the result will be the same as with DtwModelPreset::BaseEn
let custom_aheads = [
(3, 1),
(4, 2),
(4, 3),
(4, 7),
(5, 1),
(5, 2),
(5, 4),
(5, 6),
]
.map(|(n_text_layer, n_head)| whisper_rs::DtwAhead {
n_text_layer,
n_head,
});
context_param.dtw_parameters.mode = whisper_rs::DtwMode::Custom {
aheads: &custom_aheads,
};

let ctx = WhisperContext::new_with_params(
"example/path/to/model/whisper.cpp/models/ggml-base.en.bin",
WhisperContextParameters::default(),
context_param,
)
.expect("failed to load model");
// Create a state
Expand All @@ -33,6 +61,8 @@ fn main() -> Result<(), &'static str> {
params.set_print_progress(false);
params.set_print_realtime(false);
params.set_print_timestamps(false);
// Enable token level timestamps
params.set_token_timestamps(true);

// Open the audio file.
let reader = hound::WavReader::open("audio.wav").expect("failed to open file");
Expand Down Expand Up @@ -87,8 +117,24 @@ fn main() -> Result<(), &'static str> {
.full_get_segment_t1(i)
.expect("failed to get end timestamp");

let first_token_dtw_ts = if let Ok(token_count) = state.full_n_tokens(i) {
if token_count > 0 {
if let Ok(token_data) = state.full_get_token_data(i, 0) {
token_data.t_dtw
} else {
-1i64
}
} else {
-1i64
}
} else {
-1i64
};
// Print the segment to stdout.
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
println!(
"[{} - {} ({})]: {}",
start_timestamp, end_timestamp, first_token_dtw_ts, segment
);

// Format the segment information as a string.
let line = format!("[{} - {}]: {}\n", start_timestamp, end_timestamp, segment);
Expand Down
6 changes: 5 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ pub use standalone::*;
#[cfg(any(feature = "whisper-cpp-log", feature = "whisper-cpp-tracing"))]
use std::sync::Once;
pub use utilities::*;
pub use whisper_ctx::DtwMode;
pub use whisper_ctx::DtwModelPreset;
pub use whisper_ctx::DtwParameters;
pub use whisper_ctx::WhisperContextParameters;
use whisper_ctx::WhisperInnerContext;
pub use whisper_ctx_wrapper::WhisperContext;
Expand All @@ -44,5 +47,6 @@ pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callbac
pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback;
pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback;
pub type WhisperLogitsFilterCallback = whisper_rs_sys::whisper_logits_filter_callback;
pub type WhisperAbortCallback = whisper_rs_sys::whisper_abort_callback;
pub type WhisperAbortCallback = whisper_rs_sys::ggml_abort_callback;
pub type WhisperLogCallback = whisper_rs_sys::ggml_log_callback;
pub type DtwAhead = whisper_rs_sys::whisper_ahead;
4 changes: 2 additions & 2 deletions src/standalone.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ pub struct SystemInfo {
pub f16c: bool,
pub blas: bool,
pub clblast: bool,
pub cublas: bool,
pub cuda: bool,
}

impl Default for SystemInfo {
Expand All @@ -118,7 +118,7 @@ impl Default for SystemInfo {
f16c: whisper_rs_sys::ggml_cpu_has_f16c() != 0,
blas: whisper_rs_sys::ggml_cpu_has_blas() != 0,
clblast: whisper_rs_sys::ggml_cpu_has_clblast() != 0,
cublas: whisper_rs_sys::ggml_cpu_has_cublas() != 0,
cuda: whisper_rs_sys::ggml_cpu_has_cuda() != 0,
}
}
}
Expand Down
160 changes: 157 additions & 3 deletions src/whisper_ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,37 +469,191 @@ impl Drop for WhisperInnerContext {
unsafe impl Send for WhisperInnerContext {}
unsafe impl Sync for WhisperInnerContext {}

pub struct WhisperContextParameters {
/// Parameters used when creating a whisper context; mirrors whisper.cpp's
/// `whisper_context_params` (see `to_c_struct`). The lifetime `'a` ties the
/// struct to any custom DTW alignment heads borrowed in `dtw_parameters`.
pub struct WhisperContextParameters<'a> {
    /// Use GPU if available.
    ///
    /// **Warning**: Does not have an effect if OpenCL is selected as GPU backend
    /// (in that case, GPU is always enabled).
    pub use_gpu: bool,
    /// Enable flash attention, default false
    ///
    /// **Warning** Can't be used with DTW. DTW will be disabled if flash_attn is true
    pub flash_attn: bool,
    /// GPU device id, default 0
    pub gpu_device: c_int,
    /// DTW token level timestamp parameters
    pub dtw_parameters: DtwParameters<'a>,
}

#[allow(clippy::derivable_impls)] // this impl cannot be derived
impl Default for WhisperContextParameters {
impl<'a> Default for WhisperContextParameters<'a> {
    /// Default parameters: GPU enabled only when a GPU backend feature
    /// (`_gpu`) was compiled in, flash attention off, device 0, DTW disabled.
    fn default() -> Self {
        // Whether any GPU backend was selected at compile time.
        let gpu_compiled_in = cfg!(feature = "_gpu");
        Self {
            use_gpu: gpu_compiled_in,
            flash_attn: false,
            gpu_device: 0,
            dtw_parameters: DtwParameters::default(),
        }
    }
}
impl WhisperContextParameters {
impl<'a> WhisperContextParameters<'a> {
pub fn new() -> Self {
Self::default()
}
pub fn use_gpu(&mut self, use_gpu: bool) -> &mut Self {
self.use_gpu = use_gpu;
self
}
pub fn flash_attn(&mut self, flash_attn: bool) -> &mut Self {
self.flash_attn = flash_attn;
self
}
pub fn gpu_device(&mut self, gpu_device: c_int) -> &mut Self {
self.gpu_device = gpu_device;
self
}
pub fn dtw_parameters(&mut self, dtw_parameters: DtwParameters<'a>) -> &mut Self {
self.dtw_parameters = dtw_parameters;
self
}

fn to_c_struct(&self) -> whisper_rs_sys::whisper_context_params {
let dtw_token_timestamps = !matches!(self.dtw_parameters.mode, DtwMode::None);
let mut dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE;
let mut dtw_n_top: c_int = -1;
let mut dtw_aheads = whisper_rs_sys::whisper_aheads {
n_heads: 0,
heads: std::ptr::null(),
};

match &self.dtw_parameters.mode {
DtwMode::None => {}
DtwMode::TopMost { n_top } => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_N_TOP_MOST;
dtw_n_top = *n_top;
}
DtwMode::Custom { aheads } => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_CUSTOM;

dtw_aheads = whisper_rs_sys::whisper_aheads {
n_heads: aheads.len(),
heads: aheads.as_ptr(),
};
}
DtwMode::ModelPreset { model_preset } => match model_preset {
DtwModelPreset::TinyEn => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY_EN;
}
DtwModelPreset::Tiny => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY;
}
DtwModelPreset::BaseEn => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE_EN;
}
DtwModelPreset::Base => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE;
}
DtwModelPreset::SmallEn => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL_EN;
}
DtwModelPreset::Small => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL;
}
DtwModelPreset::MediumEn => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM_EN;
}
DtwModelPreset::Medium => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM;
}
DtwModelPreset::LargeV1 => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V1;
}
DtwModelPreset::LargeV2 => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V2;
}
DtwModelPreset::LargeV3 => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3;
}
},
}

whisper_rs_sys::whisper_context_params {
use_gpu: self.use_gpu,
flash_attn: self.flash_attn,
gpu_device: self.gpu_device,
dtw_token_timestamps,
dtw_aheads_preset,
dtw_n_top,
dtw_aheads,
dtw_mem_size: self.dtw_parameters.dtw_mem_size,
}
}
}

/// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default Disabled
#[derive(Debug, Clone)]
pub struct DtwParameters<'a> {
    /// Timestamping mode; `DtwMode::None` (the default) disables DTW.
    pub mode: DtwMode<'a>,
    /// Working-memory budget handed to whisper.cpp for DTW
    /// (presumably bytes — default is 128 MiB; see the `Default` impl).
    pub dtw_mem_size: usize,
}

impl Default for DtwParameters<'_> {
    /// DTW disabled, with a 128 MiB working-memory budget.
    fn default() -> Self {
        // 1024 * 1024 * 128 bytes, written as a named constant for clarity.
        const DEFAULT_DTW_MEM_SIZE: usize = 128 * 1024 * 1024;
        Self {
            mode: DtwMode::None,
            dtw_mem_size: DEFAULT_DTW_MEM_SIZE,
        }
    }
}

/// Selects how (and whether) DTW token-level timestamps are computed.
#[derive(Debug, Clone)]
pub enum DtwMode<'a> {
    /// DTW token level timestamps disabled
    None,
    /// Use N Top Most layers from loaded model
    TopMost {
        /// Number of top text layers used from model, should be 0 < n_top <= model n_text_layer
        n_top: c_int,
    },
    /// Use custom aheads, non-empty list of whisper_ahead.
    /// 0 < n_text_layer < model n_text_layer, 0 < n_head < model n_text_head for each element
    /// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143
    Custom {
        /// Borrowed slice of alignment heads; the `'a` lifetime keeps the
        /// borrow alive while these parameters are in use.
        aheads: &'a [whisper_rs_sys::whisper_ahead],
    },
    /// Use predefined preset for standard models
    ModelPreset { model_preset: DtwModelPreset },
}

/// Predefined DTW alignment-heads presets for the standard whisper.cpp
/// models. Each variant maps to the matching C-side `WHISPER_AHEADS_*`
/// constant (see `WhisperContextParameters::to_c_struct`).
#[derive(Debug, Clone)]
pub enum DtwModelPreset {
    TinyEn,
    Tiny,
    BaseEn,
    Base,
    SmallEn,
    Small,
    MediumEn,
    Medium,
    LargeV1,
    LargeV2,
    LargeV3,
}

#[cfg(test)]
#[cfg(feature = "test-with-tiny-model")]
mod test_with_tiny_model {
Expand Down
2 changes: 1 addition & 1 deletion sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "whisper-rs-sys"
version = "0.8.1"
version = "0.10.0"
edition = "2021"
description = "Rust bindings for whisper.cpp (FFI bindings)"
license = "Unlicense"
Expand Down
2 changes: 1 addition & 1 deletion sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ fn main() {
}

if cfg!(feature = "cuda") {
config.define("WHISPER_CUBLAS", "ON");
config.define("WHISPER_CUDA", "ON");
}

if cfg!(feature = "openblas") {
Expand Down
Loading
Loading