Skip to content

Commit

Permalink
examples : refactor in order to reuse code and reduce duplication (ggerganov#482)
Browse files Browse the repository at this point in the history

* examples : refactor common code into a library

* examples : refactor common SDL code into a library

* make : update Makefile to use common libs

* common : fix MSVC M_PI ..

* addon.node : link common lib
  • Loading branch information
ggerganov authored and rock3125 committed Feb 21, 2023
1 parent a0f28a5 commit d72d881
Show file tree
Hide file tree
Showing 19 changed files with 580 additions and 1,254 deletions.
19 changes: 11 additions & 8 deletions Makefile
Expand Up @@ -197,18 +197,21 @@ clean:

CC_SDL=`sdl2-config --cflags --libs`

main: examples/main/main.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp ggml.o whisper.o -o main $(LDFLAGS)
# Shared example helper sources. They are listed as prerequisites and compiled
# directly into each example binary that needs them:
#   SRC_COMMON     - used by main, stream, command and talk
#   SRC_COMMON_SDL - used by the SDL-based examples (stream, command, talk)
SRC_COMMON = examples/common.cpp
SRC_COMMON_SDL = examples/common-sdl.cpp

main: examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o -o main $(LDFLAGS)
./main -h

stream: examples/stream/stream.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)

command: examples/command/command.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/command/command.cpp ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)

talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)

bench: examples/bench/bench.cpp ggml.o whisper.o
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
Expand Down
2 changes: 1 addition & 1 deletion bindings/javascript/whisper.js

Large diffs are not rendered by default.

31 changes: 31 additions & 0 deletions examples/CMakeLists.txt
Expand Up @@ -14,6 +14,37 @@ if (WHISPER_SUPPORT_SDL2)
message(STATUS "SDL2_LIBRARIES = ${SDL2_LIBRARIES}")
endif()

# common
#
# Static helper library shared by the example targets, built from
# common.h / common.cpp in this directory.

set(TARGET common)

add_library(${TARGET} STATIC
common.h
common.cpp
)

# Included after TARGET is set; presumably applies the project's default
# options to ${TARGET} - verify against the DefaultTargetOptions module.
include(DefaultTargetOptions)

# Build the archive as position-independent code.
# NOTE(review): presumably required so it can be linked into shared modules
# such as the Node addon - confirm.
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)

# The SDL helper library is only defined when SDL2 support is enabled.
if (WHISPER_SUPPORT_SDL2)
# common-sdl
#
# Static helper library built from common-sdl.h / common-sdl.cpp.

set(TARGET common-sdl)

add_library(${TARGET} STATIC
common-sdl.h
common-sdl.cpp
)

include(DefaultTargetOptions)

# SDL2 include directories are PUBLIC so that targets linking common-sdl
# inherit them; the SDL2 libraries themselves are linked PRIVATE.
target_include_directories(${TARGET} PUBLIC ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE ${SDL2_LIBRARIES})

set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()

# examples

include_directories(${CMAKE_CURRENT_SOURCE_DIR})
Expand Down
2 changes: 1 addition & 1 deletion examples/addon.node/CMakeLists.txt
Expand Up @@ -23,7 +23,7 @@ string(REPLACE "\"" "" NODE_ADDON_API_DIR ${NODE_ADDON_API_DIR})
target_include_directories(${TARGET} PRIVATE ${NODE_ADDON_API_DIR})
#==================================================================

target_link_libraries(${TARGET} ${CMAKE_JS_LIB} whisper ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} ${CMAKE_JS_LIB} common whisper ${CMAKE_THREAD_LIBS_INIT})

if(MSVC AND CMAKE_JS_NODELIB_DEF AND CMAKE_JS_NODELIB_TARGET)
# Generate node.lib
Expand Down
102 changes: 11 additions & 91 deletions examples/addon.node/addon.cpp
@@ -1,15 +1,13 @@
#include <cstdint>
#include "napi.h"
#include "common.h"

#include "whisper.h"

#include <string>
#include <thread>
#include <vector>
#include <cmath>

#include "napi.h"

#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"

#include "whisper.h"
#include <cstdint>

struct whisper_params {
int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
Expand Down Expand Up @@ -44,7 +42,7 @@ struct whisper_params {
std::string model = "../../ggml-large.bin";

std::vector<std::string> fname_inp = {};
std::vector<std::string> fname_outp = {};
std::vector<std::string> fname_out = {};
};

struct whisper_print_user_data {
Expand Down Expand Up @@ -143,7 +141,6 @@ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, voi
}

int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {

if (params.fname_inp.empty()) {
fprintf(stderr, "error: no input files specified\n");
return 2;
Expand Down Expand Up @@ -181,91 +178,14 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {

for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_outp = f < (int)params.fname_outp.size() && !params.fname_outp[f].empty() ? params.fname_outp[f] : params.fname_inp[f];
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];

std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM

// WAV input
{
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin

if (fname_inp == "-") {
{
uint8_t buf[1024];
while (true)
{
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}

if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return 4;
}

fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
}
else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return 5;
}

if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "error: WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
return 6;
}

if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
fprintf(stderr, "error: WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
return 6;
}

if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "error: WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return 8;
}

if (wav.bitsPerSample != 16) {
fprintf(stderr, "error: WAV file '%s' must be 16-bit\n", fname_inp.c_str());
return 9;
}

const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);

std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);

// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}

if (params.diarize) {
// convert to stereo, float
pcmf32s.resize(2);

pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str());
continue;
}

// print system information
Expand Down
1 change: 1 addition & 0 deletions examples/command.wasm/CMakeLists.txt
Expand Up @@ -11,6 +11,7 @@ add_executable(${TARGET}
include(DefaultTargetOptions)

target_link_libraries(${TARGET} PRIVATE
common
whisper
)

Expand Down
62 changes: 3 additions & 59 deletions examples/command.wasm/emscripten.cpp
@@ -1,4 +1,5 @@
#include "ggml.h"
#include "common.h"
#include "whisper.h"

#include <emscripten.h>
Expand Down Expand Up @@ -27,24 +28,6 @@ std::string g_transcribed = "";

std::vector<float> g_pcmf32;

// Return a copy of `s` with leading and trailing whitespace removed.
//
// Whitespace is the set matched by the regex class `\s` in the default "C"
// locale: space, \t, \n, \v, \f and \r. Interior whitespace is preserved.
// An empty or whitespace-only input yields an empty string.
//
// Implemented with plain string searches so that no std::regex has to be
// compiled on every call.
static std::string trim(const std::string & s) {
    static const char * const k_whitespace = " \t\n\v\f\r";

    const size_t first = s.find_first_not_of(k_whitespace);
    if (first == std::string::npos) {
        // nothing but whitespace (or empty)
        return "";
    }

    // a non-whitespace character exists, so this search cannot fail
    const size_t last = s.find_last_not_of(k_whitespace);

    return s.substr(first, last - first + 1);
}

// Filter `data` in place with a single-coefficient recursive filter.
//
//   data        - samples to process; modified in place, data[0] is left as is
//   cutoff      - cutoff frequency, in the same unit as sample_rate
//   sample_rate - sampling rate of `data`
//
// An empty buffer is returned untouched (previously data[0] was read
// unconditionally). A local pi constant is used instead of the non-standard
// M_PI macro, which is not available by default on every toolchain.
//
// NOTE(review): the buffer is overwritten while it is being read, so the
// data[i - 1] term is the previous *output*, which cancels against `y`; each
// sample therefore ends up approximately scaled by alpha. A textbook RC
// high-pass keeps the previous input sample and uses alpha = rc / (rc + dt).
// The numeric behaviour is intentionally kept unchanged here because callers'
// thresholds may depend on it - confirm before changing.
static void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
    if (data.empty()) {
        return;
    }

    // same value as M_PI; kept as double so the arithmetic below matches the original
    const double k_pi = 3.14159265358979323846;

    const float rc    = 1.0f / (2.0f * k_pi * cutoff);
    const float dt    = 1.0f / sample_rate;
    const float alpha = dt / (rc + dt);

    float y = data[0];

    for (size_t i = 1; i < data.size(); i++) {
        y = alpha * (y + data[i] - data[i - 1]);
        data[i] = y;
    }
}

// compute similarity between two strings using Levenshtein distance
static float similarity(const std::string & s0, const std::string & s1) {
const size_t len0 = s0.size() + 1;
Expand Down Expand Up @@ -75,44 +58,6 @@ void command_set_status(const std::string & status) {
g_status = status;
}

// Simple energy-based voice activity check.
//
// Compares the mean absolute amplitude of the trailing `last_ms` milliseconds
// of the buffer with the mean absolute amplitude of the whole buffer.
//
//   pcmf32      - audio samples; filtered in place when freq_thold > 0
//   sample_rate - samples per second of `pcmf32`
//   last_ms     - length of the trailing window, in milliseconds
//   vad_thold   - ratio threshold applied to the overall energy
//   freq_thold  - cutoff handed to high_pass_filter(); values <= 0 skip filtering
//   verbose     - when true, print the measured energies to stderr
//
// Returns false when the trailing window is empty or not shorter than the
// buffer (nothing to compare), or when the trailing energy exceeds
// vad_thold * overall energy. Returns true otherwise.
bool command_vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
    const int n_samples      = static_cast<int>(pcmf32.size());
    const int n_samples_last = (sample_rate * last_ms) / 1000;

    if (n_samples_last <= 0 || n_samples_last >= n_samples) {
        // empty window (would divide by zero below) or not enough samples - assume no speech
        return false;
    }

    if (freq_thold > 0.0f) {
        high_pass_filter(pcmf32, freq_thold, sample_rate);
    }

    // index of the first sample that belongs to the trailing window
    const int i_last = n_samples - n_samples_last;

    float energy_all  = 0.0f;
    float energy_last = 0.0f;

    for (int i = 0; i < n_samples; i++) {
        const float e = fabsf(pcmf32[i]);

        energy_all += e;

        if (i >= i_last) {
            energy_last += e;
        }
    }

    energy_all  /= n_samples;
    energy_last /= n_samples_last;

    if (verbose) {
        fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
    }

    if (energy_last > vad_thold*energy_all) {
        return false;
    }

    return true;
}

std::string command_transcribe(whisper_context * ctx, const whisper_full_params & wparams, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();

Expand Down Expand Up @@ -155,7 +100,7 @@ void command_get_audio(int ms, int sample_rate, std::vector<float> & audio) {
const int64_t n_samples = (ms * sample_rate) / 1000;

int64_t n_take = 0;
if (g_pcmf32.size() < n_samples) {
if (n_samples > (int) g_pcmf32.size()) {
n_take = g_pcmf32.size();
} else {
n_take = n_samples;
Expand Down Expand Up @@ -187,7 +132,6 @@ void command_main(size_t index) {

printf("command: using %d threads\n", wparams.n_threads);

bool is_running = true;
bool have_prompt = false;
bool ask_prompt = true;
bool print_energy = false;
Expand Down Expand Up @@ -233,7 +177,7 @@ void command_main(size_t index) {
{
command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur);

if (command_vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
command_set_status("Speech detected! Processing ...");

Expand Down
3 changes: 1 addition & 2 deletions examples/command/CMakeLists.txt
Expand Up @@ -5,6 +5,5 @@ if (WHISPER_SUPPORT_SDL2)

include(DefaultTargetOptions)

target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS})
target_link_libraries(${TARGET} PRIVATE whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT})
target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${CMAKE_THREAD_LIBS_INIT})
endif ()

0 comments on commit d72d881

Please sign in to comment.