Skip to content

Commit

Permalink
Factored out more 'common' code into the abstract audio_async superclass
Browse files Browse the repository at this point in the history
  • Loading branch information
shanelenagh committed Feb 8, 2024
1 parent 9c1a270 commit 2a2307b
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 123 deletions.
120 changes: 113 additions & 7 deletions examples/common-audioasync.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <vector>
#include <mutex>
#include <thread>
#include <cstring>

// command-line parameters
struct whisper_params {
Expand Down Expand Up @@ -40,18 +41,123 @@ struct whisper_params {
//
class audio_async {
public:
audio_async(int len_ms) { };
audio_async(int len_ms) {
m_len_ms = len_ms;

m_running = false;
};
// Virtual so that destroying a derived capture object through an audio_async
// pointer also runs the derived destructor.
virtual ~audio_async() = default;

virtual bool init(whisper_params params, int sample_rate) = 0;
// Record the capture sample rate and size the circular buffer so that it can
// hold m_len_ms milliseconds of audio at that rate.
//
// params      - not used by the base implementation
// sample_rate - samples per second delivered by the audio source
//
// Returns false (leaving the object unchanged) when the sample rate or the
// configured buffer length is not positive, true otherwise.
virtual bool init(whisper_params params, int sample_rate) {
    (void) params;

    if (sample_rate <= 0 || m_len_ms <= 0) {
        fprintf(stderr, "%s: invalid sample rate (%d) or buffer length (%d ms)\n", __func__, sample_rate, m_len_ms);
        return false;
    }

    // widen before multiplying: rate * milliseconds can overflow a 32-bit int for long buffers
    const long long n_samples = (static_cast<long long>(sample_rate) * m_len_ms) / 1000;

    m_sample_rate = sample_rate;
    m_audio.resize(static_cast<size_t>(n_samples));

    return true;
}

virtual bool resume() = 0;
virtual bool pause() = 0;
virtual bool clear() = 0;
// Mark the capture as running so that callback() stores incoming samples
// and get() hands them out. Always returns true.
virtual bool resume() {
    m_running.store(true);

    return true;
}
// Mark the capture as stopped; callback() and get() return early while in
// this state. Always returns true.
virtual bool pause() {
    m_running.store(false);

    return true;
}
// Discard all buffered audio by resetting the write position and the fill
// level while holding the buffer mutex. Always returns true.
virtual bool clear() {
    std::lock_guard<std::mutex> guard(m_mutex);

    m_audio_len = 0;
    m_audio_pos = 0;

    return true;
}
// Report whether the capture is currently marked as running.
bool is_running() {
    return m_running.load();
}

// get audio data from the circular buffer
virtual void get(int ms, std::vector<float> & audio) = 0;
// Copy the most recent `ms` milliseconds of audio from the circular buffer
// into `result`, oldest sample first.
//
// ms     - amount of audio wanted; a value <= 0 requests the full buffer length
// result - receives the samples; holds fewer than requested when less audio
//          has been buffered, and is left untouched when the capture is not
//          running
virtual void get(int ms, std::vector<float> & result) {
    if (!m_running) {
        fprintf(stderr, "%s: not running!\n", __func__);
        return;
    }

    result.clear();

    std::lock_guard<std::mutex> lock(m_mutex);

    if (ms <= 0) {
        ms = m_len_ms;
    }

    // widen before multiplying so that long requests cannot overflow a 32-bit int
    const long long n_wanted = (static_cast<long long>(m_sample_rate) * ms) / 1000;

    size_t n_samples = n_wanted > 0 ? static_cast<size_t>(n_wanted) : 0;
    if (n_samples > m_audio_len) {
        n_samples = m_audio_len;
    }

    // nothing buffered (or no storage yet): skip indexing and zero-length copies
    if (n_samples == 0) {
        return;
    }

    const size_t n_total = m_audio.size();

    // first sample of the requested window; n_samples <= m_audio_len <= n_total,
    // so the unsigned sum cannot underflow and the modulo handles the wrap-around
    const size_t s0 = (m_audio_pos + n_total - n_samples) % n_total;

    // samples that fit before the end of the storage; the remainder continues at index 0
    const size_t n0 = (s0 + n_samples > n_total) ? n_total - s0 : n_samples;

    result.resize(n_samples);

    memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
    if (n0 < n_samples) {
        memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
    }
}

// callback to be called by audio source
virtual void callback(uint8_t * stream, int len) = 0;
// Append raw samples delivered by the audio source to the circular buffer.
//
// stream - bytes holding float samples
// len    - number of bytes available at `stream`
//
// Data is dropped while the capture is not running. When more samples arrive
// than the buffer can hold, only the trailing part of the stream is kept.
void callback(uint8_t * stream, int len) {
    if (!m_running) {
        return;
    }

    // a null stream or non-positive length carries no samples; without this guard
    // a negative len would be converted into a huge unsigned sample count below
    if (stream == nullptr || len <= 0) {
        return;
    }

    std::lock_guard<std::mutex> lock(m_mutex);

    const size_t n_total = m_audio.size();

    size_t n_samples = static_cast<size_t>(len) / sizeof(float);

    if (n_samples > n_total) {
        n_samples = n_total;

        // skip ahead so that the last n_total samples worth of bytes are the ones stored
        stream += (static_cast<size_t>(len) - (n_samples * sizeof(float)));
    }

    // nothing to store, or the buffer has not been sized yet (the modulo below needs n_total > 0)
    if (n_samples == 0) {
        return;
    }

    //fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);

    // samples that fit before the end of the storage; any remainder wraps to index 0
    const size_t n0 = std::min(n_samples, n_total - m_audio_pos);

    memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
    if (n0 < n_samples) {
        memcpy(&m_audio[0], stream + n0 * sizeof(float), (n_samples - n0) * sizeof(float));
    }

    m_audio_pos = (m_audio_pos + n_samples) % n_total;
    m_audio_len = std::min(m_audio_len + n_samples, n_total);
}

private:
int m_len_ms = 0;
int m_sample_rate = 0;

std::atomic_bool m_running;
std::mutex m_mutex;

std::vector<float> m_audio;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
113 changes: 9 additions & 104 deletions examples/common-sdl.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
#include "common-sdl.h"

audio_async_sdl::audio_async_sdl(int len_ms) : audio_async(len_ms) {

m_len_ms = len_ms;

m_running = false;
}
// Forward the buffer length to the audio_async base, which keeps the shared
// capture state.
audio_async_sdl::audio_async_sdl(int len_ms)
    : audio_async(len_ms) {
}

audio_async_sdl::~audio_async_sdl() {
if (m_dev_id_in) {
Expand Down Expand Up @@ -70,11 +65,7 @@ bool audio_async_sdl::init(whisper_params params, int sample_rate) {
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
}

m_sample_rate = capture_spec_obtained.freq;

m_audio.resize((m_sample_rate*m_len_ms)/1000);

return true;
return audio_async::init(params, capture_spec_obtained.freq);
}

bool audio_async_sdl::resume() {
Expand All @@ -83,16 +74,14 @@ bool audio_async_sdl::resume() {
return false;
}

if (m_running) {
if (is_running()) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}

SDL_PauseAudioDevice(m_dev_id_in, 0);

m_running = true;

return true;
return audio_async::resume();
}

bool audio_async_sdl::pause() {
Expand All @@ -101,16 +90,14 @@ bool audio_async_sdl::pause() {
return false;
}

if (!m_running) {
if (!is_running()) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}

SDL_PauseAudioDevice(m_dev_id_in, 1);

m_running = false;

return true;
return audio_async::pause();
}

bool audio_async_sdl::clear() {
Expand All @@ -119,55 +106,7 @@ bool audio_async_sdl::clear() {
return false;
}

if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}

{
std::lock_guard<std::mutex> lock(m_mutex);

m_audio_pos = 0;
m_audio_len = 0;
}

return true;
}

// callback to be called by SDL
void audio_async_sdl::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}

size_t n_samples = len / sizeof(float);

if (n_samples > m_audio.size()) {
n_samples = m_audio.size();

stream += (len - (n_samples * sizeof(float)));
}

//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);

{
std::lock_guard<std::mutex> lock(m_mutex);

if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;

memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], stream + n0 * sizeof(float), (n_samples - n0) * sizeof(float));

m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));

m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
return audio_async::clear();
}

void audio_async_sdl::get(int ms, std::vector<float> & result) {
Expand All @@ -176,41 +115,7 @@ void audio_async_sdl::get(int ms, std::vector<float> & result) {
return;
}

if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}

result.clear();

{
std::lock_guard<std::mutex> lock(m_mutex);

if (ms <= 0) {
ms = m_len_ms;
}

size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}

result.resize(n_samples);

int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}

if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;

memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
audio_async::get(ms, result);
}

bool sdl_poll_events() {
Expand All @@ -227,4 +132,4 @@ bool sdl_poll_events() {
}

return true;
}
}
13 changes: 1 addition & 12 deletions examples/common-sdl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
//
// SDL Audio capture
//

class audio_async_sdl : public audio_async {
public:
audio_async_sdl(int len_ms);
Expand All @@ -28,23 +27,13 @@ class audio_async_sdl : public audio_async {
bool clear() override;

// callback to be called by SDL
void callback(uint8_t * stream, int len) override;
void callback(uint8_t * stream, int len);

// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio) override;

private:
SDL_AudioDeviceID m_dev_id_in = 0;

int m_len_ms = 0;
int m_sample_rate = 0;

std::atomic_bool m_running;
std::mutex m_mutex;

std::vector<float> m_audio;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};

// Return false if need to quit
Expand Down

0 comments on commit 2a2307b

Please sign in to comment.