Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,9 @@ static constexpr float MOST_NEGATIVE_SINGLE_FLOAT = static_cast<float>(std::nume
static float LOG2_MOST_POSITIVE_SINGLE_FLOAT = std::log2(MOST_POSITIVE_SINGLE_FLOAT);
static float LOG10_MOST_POSITIVE_SINGLE_FLOAT = std::log10(MOST_POSITIVE_SINGLE_FLOAT);
static constexpr float PI = static_cast<float>(M_PI);

// buffer sizes
static constexpr size_t PROMISE_VENDOR_THREAD_POOL_WORKER_COUNT = 4;
static constexpr size_t PROMISE_VENDOR_THREAD_POOL_LOAD_BALANCER_QUEUE_SIZE = 32;
static constexpr size_t PROMISE_VENDOR_THREAD_POOL_WORKER_QUEUE_SIZE = 32;
} // namespace audioapi
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,10 @@ jsi::Value PromiseVendor::createAsyncPromise(
&&function) {
auto &runtime = *runtime_;
auto callInvoker = callInvoker_;
auto threadPool = threadPool_;
auto promiseCtor = runtime.global().getPropertyAsFunction(runtime, "Promise");
auto promiseLambda = [callInvoker = std::move(callInvoker),
auto promiseLambda = [threadPool = std::move(threadPool),
callInvoker = std::move(callInvoker),
function = std::move(function)](
jsi::Runtime &runtime,
const jsi::Value &thisValue,
Expand All @@ -78,32 +80,26 @@ jsi::Value PromiseVendor::createAsyncPromise(
auto rejectLocal = arguments[1].asObject(runtime).asFunction(runtime);
auto reject = std::make_shared<jsi::Function>(std::move(rejectLocal));

/// Here we can swap later for thread pool instead of creating a new thread
/// each time
std::thread(
[callInvoker = std::move(callInvoker),
function = std::move(function),
resolve = std::move(resolve),
reject = std::move(reject)](jsi::Runtime &runtime) {
auto result = function(runtime);
if (std::holds_alternative<jsi::Value>(result)) {
auto valueShared = std::make_shared<jsi::Value>(
std::move(std::get<jsi::Value>(result)));
callInvoker->invokeAsync(
[resolve, &runtime, valueShared]() -> void {
resolve->call(runtime, *valueShared);
});
} else {
auto errorMessage = std::get<std::string>(result);
callInvoker->invokeAsync(
[reject, &runtime, errorMessage]() -> void {
auto error = jsi::JSError(runtime, errorMessage);
reject->call(runtime, error.value());
});
}
},
std::ref(runtime))
.detach();
threadPool->schedule([callInvoker = std::move(callInvoker),
function = std::move(function),
resolve = std::move(resolve),
reject = std::move(reject),
&runtime]() {
auto result = function(runtime);
if (std::holds_alternative<jsi::Value>(result)) {
auto valueShared = std::make_shared<jsi::Value>(
std::move(std::get<jsi::Value>(result)));
callInvoker->invokeAsync([resolve, &runtime, valueShared]() -> void {
resolve->call(runtime, *valueShared);
});
} else {
auto errorMessage = std::get<std::string>(result);
callInvoker->invokeAsync([reject, &runtime, errorMessage]() -> void {
auto error = jsi::JSError(runtime, errorMessage);
reject->call(runtime, error.value());
});
}
});
return jsi::Value::undefined();
};
auto promiseFunction = jsi::Function::createFromHostFunction(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once


#include <audioapi/core/utils/Constants.h>
#include <ReactCommon/CallInvoker.h>
#include <jsi/jsi.h>
#include <variant>
Expand All @@ -8,6 +10,7 @@
#include <string>
#include <utility>
#include <functional>
#include <audioapi/utils/ThreadPool.hpp>

namespace audioapi {

Expand All @@ -32,14 +35,19 @@ class Promise {

class PromiseVendor {
public:
PromiseVendor(jsi::Runtime *runtime, const std::shared_ptr<react::CallInvoker> &callInvoker): runtime_(runtime), callInvoker_(callInvoker) {}
PromiseVendor(jsi::Runtime *runtime, const std::shared_ptr<react::CallInvoker> &callInvoker):
runtime_(runtime), callInvoker_(callInvoker), threadPool_(std::make_shared<ThreadPool>(
audioapi::PROMISE_VENDOR_THREAD_POOL_WORKER_COUNT,
audioapi::PROMISE_VENDOR_THREAD_POOL_LOAD_BALANCER_QUEUE_SIZE,
audioapi::PROMISE_VENDOR_THREAD_POOL_WORKER_QUEUE_SIZE)) {}

jsi::Value createPromise(const std::function<void(std::shared_ptr<Promise>)> &function);

/// @brief Creates an asynchronous promise.
/// @param function The function to execute asynchronously. It should return either a jsi::Value on success or a std::string error message on failure.
/// @return The created promise.
/// @note The function is executed on a different thread, and the promise is resolved or rejected based on the function's outcome.
/// @note IMPORTANT: This function is not thread-safe and should be called from a single thread only. (comes from underlying ThreadPool implementation)
/// @example
/// ```cpp
/// auto promise = promiseVendor_->createAsyncPromise(
Expand All @@ -56,6 +64,7 @@ class PromiseVendor {
private:
jsi::Runtime *runtime_;
std::shared_ptr<react::CallInvoker> callInvoker_;
std::shared_ptr<ThreadPool> threadPool_;
};

} // namespace audioapi
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#pragma once
#include <thread>
#include <vector>
#include <functional>
#include <variant>
#include <audioapi/utils/SpscChannel.hpp>

namespace audioapi {

/// @brief A simple thread pool implementation using lock-free SPSC channels for task scheduling and execution.
/// @note The thread pool consists of a load balancer thread and multiple worker threads.
/// @note The load balancer receives tasks and distributes them to worker threads in a round-robin fashion.
/// @note Each worker thread has its own SPSC channel to receive tasks from the load balancer.
/// @note The thread pool can be shut down gracefully by sending a stop event to the load balancer, which then propagates the stop event to all worker threads.
/// @note IMPORTANT: ThreadPool is not thread-safe and events should be scheduled from a single thread only.
class ThreadPool {
struct StopEvent {};
struct TaskEvent { std::function<void()> task; };
using Event = std::variant<TaskEvent, StopEvent>;

using Sender = channels::spsc::Sender<Event, channels::spsc::OverflowStrategy::WAIT_ON_FULL, channels::spsc::WaitStrategy::ATOMIC_WAIT>;
using Receiver = channels::spsc::Receiver<Event, channels::spsc::OverflowStrategy::WAIT_ON_FULL, channels::spsc::WaitStrategy::ATOMIC_WAIT>;
public:
/// @brief Construct a new ThreadPool
/// @param numThreads The number of worker threads to create
/// @param loadBalancerQueueSize The size of the load balancer's queue
/// @param workerQueueSize The size of each worker thread's queue
ThreadPool(size_t numThreads, size_t loadBalancerQueueSize = 32, size_t workerQueueSize = 32) {
auto [sender, receiver] = channels::spsc::channel<Event, channels::spsc::OverflowStrategy::WAIT_ON_FULL, channels::spsc::WaitStrategy::ATOMIC_WAIT>(loadBalancerQueueSize);
loadBalancerSender = std::move(sender);
std::vector<Sender> workerSenders;
workerSenders.reserve(numThreads);
for (size_t i = 0; i < numThreads; ++i) {
auto [workerSender, workerReceiver] = channels::spsc::channel<Event, channels::spsc::OverflowStrategy::WAIT_ON_FULL, channels::spsc::WaitStrategy::ATOMIC_WAIT>(workerQueueSize);
workers.emplace_back(&ThreadPool::workerThreadFunc, this, std::move(workerReceiver));
workerSenders.emplace_back(std::move(workerSender));
}
loadBalancerThread = std::thread(&ThreadPool::loadBalancerThreadFunc, this, std::move(receiver), std::move(workerSenders));
}
~ThreadPool() {
loadBalancerSender.send(StopEvent{});
loadBalancerThread.join();
for (auto& worker : workers) {
worker.join();
}
}

/// @brief Schedule a task to be executed by the thread pool
/// @param task The task to be executed
/// @note This function is lock-free and most of the time wait-free, but may block if the load balancer queue is full.
/// @note Please remember that the task will be executed in a different thread, so make sure to capture any required variables by value.
/// @note The task should not throw exceptions, as they will not be caught.
/// @note The task should end at some point, otherwise the thread pool will never be able to shut down.
/// @note IMPORTANT: This function is not thread-safe and should be called from a single thread only.
void schedule(std::function<void()> &&task) noexcept {
loadBalancerSender.send(TaskEvent{std::move(task)});
}

private:
std::thread loadBalancerThread;
std::vector<std::thread> workers;
Sender loadBalancerSender;

void workerThreadFunc(Receiver &&receiver) {
Receiver localReceiver = std::move(receiver);
while (true) {
auto event = localReceiver.receive();
/// We use [[unlikely]] and [[likely]] attributes to help the compiler optimize the branching.
/// we expect most of the time to receive TaskEvent, and rarely StopEvent.
/// and whenever we receive StopEvent we can burn some cycles as it will not be expected to execute fast.
if (std::holds_alternative<StopEvent>(event)) [[ unlikely ]] {
break;
} else if (std::holds_alternative<TaskEvent>(event)) [[ likely ]] {
std::get<TaskEvent>(event).task();
}
}
}

void loadBalancerThreadFunc(Receiver &&receiver, std::vector<Sender> &&workerSenders) {
Receiver localReceiver = std::move(receiver);
std::vector<Sender> localWorkerSenders = std::move(workerSenders);
size_t nextWorker = 0;
while (true) {
auto event = localReceiver.receive();
/// We use [[unlikely]] and [[likely]] attributes to help the compiler optimize the branching.
/// we expect most of the time to receive TaskEvent, and rarely StopEvent.
/// and whenever we receive StopEvent we can burn some cycles as it will not be expected to execute fast.
if (std::holds_alternative<StopEvent>(event)) [[ unlikely ]] {
// Propagate stop event to all workers
for (size_t i = 0; i < localWorkerSenders.size(); ++i) {
localWorkerSenders[i].send(StopEvent{});
}
break;
} else if (std::holds_alternative<TaskEvent>(event)) [[ likely ]] {
// Dispatch task to the next worker in round-robin fashion
auto& taskEvent = std::get<TaskEvent>(event);
localWorkerSenders[nextWorker].send(std::move(taskEvent));
nextWorker = (nextWorker + 1) % localWorkerSenders.size();
}
}
}
};

}; // namespace audioapi