diff --git a/packages/react-native-audio-api/common/cpp/audioapi/core/utils/Constants.h b/packages/react-native-audio-api/common/cpp/audioapi/core/utils/Constants.h index 4d82dcab2..8d53d8d21 100644 --- a/packages/react-native-audio-api/common/cpp/audioapi/core/utils/Constants.h +++ b/packages/react-native-audio-api/common/cpp/audioapi/core/utils/Constants.h @@ -16,4 +16,9 @@ static constexpr float MOST_NEGATIVE_SINGLE_FLOAT = static_cast(std::nume static float LOG2_MOST_POSITIVE_SINGLE_FLOAT = std::log2(MOST_POSITIVE_SINGLE_FLOAT); static float LOG10_MOST_POSITIVE_SINGLE_FLOAT = std::log10(MOST_POSITIVE_SINGLE_FLOAT); static constexpr float PI = static_cast(M_PI); + +// buffer sizes +static constexpr size_t PROMISE_VENDOR_THREAD_POOL_WORKER_COUNT = 4; +static constexpr size_t PROMISE_VENDOR_THREAD_POOL_LOAD_BALANCER_QUEUE_SIZE = 32; +static constexpr size_t PROMISE_VENDOR_THREAD_POOL_WORKER_QUEUE_SIZE = 32; } // namespace audioapi diff --git a/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.cpp b/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.cpp index 6487019a1..6fa25a498 100644 --- a/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.cpp +++ b/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.cpp @@ -66,8 +66,10 @@ jsi::Value PromiseVendor::createAsyncPromise( &&function) { auto &runtime = *runtime_; auto callInvoker = callInvoker_; + auto threadPool = threadPool_; auto promiseCtor = runtime.global().getPropertyAsFunction(runtime, "Promise"); - auto promiseLambda = [callInvoker = std::move(callInvoker), + auto promiseLambda = [threadPool = std::move(threadPool), + callInvoker = std::move(callInvoker), function = std::move(function)]( jsi::Runtime &runtime, const jsi::Value &thisValue, @@ -78,32 +80,26 @@ jsi::Value PromiseVendor::createAsyncPromise( auto rejectLocal = arguments[1].asObject(runtime).asFunction(runtime); auto reject = std::make_shared(std::move(rejectLocal)); - /// Here we can swap later for thread pool instead of creating a new thread - /// each time - std::thread( - [callInvoker = std::move(callInvoker), - function = std::move(function), - resolve = std::move(resolve), - reject = std::move(reject)](jsi::Runtime &runtime) { - auto result = function(runtime); - if (std::holds_alternative(result)) { - auto valueShared = std::make_shared( - std::move(std::get(result))); - callInvoker->invokeAsync( - [resolve, &runtime, valueShared]() -> void { - resolve->call(runtime, *valueShared); - }); - } else { - auto errorMessage = std::get(result); - callInvoker->invokeAsync( - [reject, &runtime, errorMessage]() -> void { - auto error = jsi::JSError(runtime, errorMessage); - reject->call(runtime, error.value()); - }); - } - }, - std::ref(runtime)) - .detach(); + threadPool->schedule([callInvoker = std::move(callInvoker), + function = std::move(function), + resolve = std::move(resolve), + reject = std::move(reject), + &runtime]() { + auto result = function(runtime); + if (std::holds_alternative(result)) { + auto valueShared = std::make_shared( + std::move(std::get(result))); + callInvoker->invokeAsync([resolve, &runtime, valueShared]() -> void { + resolve->call(runtime, *valueShared); + }); + } else { + auto errorMessage = std::get(result); + callInvoker->invokeAsync([reject, &runtime, errorMessage]() -> void { + auto error = jsi::JSError(runtime, errorMessage); + reject->call(runtime, error.value()); + }); + } + }); return jsi::Value::undefined(); }; auto promiseFunction = jsi::Function::createFromHostFunction( diff --git a/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.h b/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.h index 9d8ad6883..985bc64dd 100644 --- a/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.h +++ b/packages/react-native-audio-api/common/cpp/audioapi/jsi/JsiPromise.h @@ -1,5 +1,7 @@ #pragma once + +#include #include #include #include @@ -8,6 +10,7 @@ #include #include #include +#include namespace audioapi { @@ -32,7 +35,11 @@ class Promise { class PromiseVendor { public: - PromiseVendor(jsi::Runtime *runtime, const std::shared_ptr &callInvoker): runtime_(runtime), callInvoker_(callInvoker) {} + PromiseVendor(jsi::Runtime *runtime, const std::shared_ptr &callInvoker): + runtime_(runtime), callInvoker_(callInvoker), threadPool_(std::make_shared( + audioapi::PROMISE_VENDOR_THREAD_POOL_WORKER_COUNT, + audioapi::PROMISE_VENDOR_THREAD_POOL_LOAD_BALANCER_QUEUE_SIZE, + audioapi::PROMISE_VENDOR_THREAD_POOL_WORKER_QUEUE_SIZE)) {} jsi::Value createPromise(const std::function)> &function); @@ -40,6 +47,7 @@ class PromiseVendor { /// @param function The function to execute asynchronously. It should return either a jsi::Value on success or a std::string error message on failure. /// @return The created promise. /// @note The function is executed on a different thread, and the promise is resolved or rejected based on the function's outcome. + /// @note IMPORTANT: This function is not thread-safe and should be called from a single thread only. (comes from underlying ThreadPool implementation) /// @example /// ```cpp /// auto promise = promiseVendor_->createAsyncPromise( @@ -56,6 +64,7 @@ class PromiseVendor { private: jsi::Runtime *runtime_; std::shared_ptr callInvoker_; + std::shared_ptr threadPool_; }; } // namespace audioapi diff --git a/packages/react-native-audio-api/common/cpp/audioapi/utils/ThreadPool.hpp b/packages/react-native-audio-api/common/cpp/audioapi/utils/ThreadPool.hpp new file mode 100644 index 000000000..5fec2ff19 --- /dev/null +++ b/packages/react-native-audio-api/common/cpp/audioapi/utils/ThreadPool.hpp @@ -0,0 +1,104 @@ +#pragma once +#include +#include +#include +#include +#include + +namespace audioapi { + +/// @brief A simple thread pool implementation using lock-free SPSC channels for task scheduling and execution. +/// @note The thread pool consists of a load balancer thread and multiple worker threads. +/// @note The load balancer receives tasks and distributes them to worker threads in a round-robin fashion. +/// @note Each worker thread has its own SPSC channel to receive tasks from the load balancer. +/// @note The thread pool can be shut down gracefully by sending a stop event to the load balancer, which then propagates the stop event to all worker threads. +/// @note IMPORTANT: ThreadPool is not thread-safe and events should be scheduled from a single thread only. +class ThreadPool { + struct StopEvent {}; + struct TaskEvent { std::function task; }; + using Event = std::variant; + + using Sender = channels::spsc::Sender; + using Receiver = channels::spsc::Receiver; +public: + /// @brief Construct a new ThreadPool + /// @param numThreads The number of worker threads to create + /// @param loadBalancerQueueSize The size of the load balancer's queue + /// @param workerQueueSize The size of each worker thread's queue + ThreadPool(size_t numThreads, size_t loadBalancerQueueSize = 32, size_t workerQueueSize = 32) { + auto [sender, receiver] = channels::spsc::channel(loadBalancerQueueSize); + loadBalancerSender = std::move(sender); + std::vector workerSenders; + workerSenders.reserve(numThreads); + for (size_t i = 0; i < numThreads; ++i) { + auto [workerSender, workerReceiver] = channels::spsc::channel(workerQueueSize); + workers.emplace_back(&ThreadPool::workerThreadFunc, this, std::move(workerReceiver)); + workerSenders.emplace_back(std::move(workerSender)); + } + loadBalancerThread = std::thread(&ThreadPool::loadBalancerThreadFunc, this, std::move(receiver), std::move(workerSenders)); + } + ~ThreadPool() { + loadBalancerSender.send(StopEvent{}); + loadBalancerThread.join(); + for (auto& worker : workers) { + worker.join(); + } + } + + /// @brief Schedule a task to be executed by the thread pool + /// @param task The task to be executed + /// @note This function is lock-free and most of the time wait-free, but may block if the load balancer queue is full. + /// @note Please remember that the task will be executed in a different thread, so make sure to capture any required variables by value. + /// @note The task should not throw exceptions, as they will not be caught. + /// @note The task should end at some point, otherwise the thread pool will never be able to shut down. + /// @note IMPORTANT: This function is not thread-safe and should be called from a single thread only. + void schedule(std::function &&task) noexcept { + loadBalancerSender.send(TaskEvent{std::move(task)}); + } + +private: + std::thread loadBalancerThread; + std::vector workers; + Sender loadBalancerSender; + + void workerThreadFunc(Receiver &&receiver) { + Receiver localReceiver = std::move(receiver); + while (true) { + auto event = localReceiver.receive(); + /// We use [[unlikely]] and [[likely]] attributes to help the compiler optimize the branching. + /// we expect most of the time to receive TaskEvent, and rarely StopEvent. + /// and whenever we receive StopEvent we can burn some cycles as it will not be expected to execute fast. + if (std::holds_alternative(event)) [[ unlikely ]] { + break; + } else if (std::holds_alternative(event)) [[ likely ]] { + std::get(event).task(); + } + } + } + + void loadBalancerThreadFunc(Receiver &&receiver, std::vector &&workerSenders) { + Receiver localReceiver = std::move(receiver); + std::vector localWorkerSenders = std::move(workerSenders); + size_t nextWorker = 0; + while (true) { + auto event = localReceiver.receive(); + /// We use [[unlikely]] and [[likely]] attributes to help the compiler optimize the branching. + /// we expect most of the time to receive TaskEvent, and rarely StopEvent. + /// and whenever we receive StopEvent we can burn some cycles as it will not be expected to execute fast. + if (std::holds_alternative(event)) [[ unlikely ]] { + // Propagate stop event to all workers + for (size_t i = 0; i < localWorkerSenders.size(); ++i) { + localWorkerSenders[i].send(StopEvent{}); + } + break; + } else if (std::holds_alternative(event)) [[ likely ]] { + // Dispatch task to the next worker in round-robin fashion + auto& taskEvent = std::get(event); + localWorkerSenders[nextWorker].send(std::move(taskEvent)); + nextWorker = (nextWorker + 1) % localWorkerSenders.size(); + } + } + } +}; + +}; // namespace audioapi