From 0c5040715b13aeb9c6f0f8944f7079591bbcb577 Mon Sep 17 00:00:00 2001 From: Stephane Janel Date: Tue, 24 Oct 2023 09:53:54 +0200 Subject: [PATCH] Finishes all futures when one throws an exception --- src/api/common/include/exchangeprivateapi.hpp | 4 +- src/api/exchanges/src/binanceprivateapi.cpp | 15 +- src/engine/src/exchangesorchestrator.cpp | 71 +++++----- src/tech/CMakeLists.txt | 2 + src/tech/include/threadpool.hpp | 129 +++++++++++------- src/tech/test/threadpool_test.cpp | 30 +++- 6 files changed, 151 insertions(+), 100 deletions(-) diff --git a/src/api/common/include/exchangeprivateapi.hpp b/src/api/common/include/exchangeprivateapi.hpp index ccf486a9..84b2381f 100644 --- a/src/api/common/include/exchangeprivateapi.hpp +++ b/src/api/common/include/exchangeprivateapi.hpp @@ -117,8 +117,8 @@ class ExchangePrivate : public ExchangeBase { /// Returns the amounts actually traded with the final amount balance on this currency TradedAmountsVectorWithFinalAmount queryDustSweeper(CurrencyCode currencyCode); - /// Builds en ExchangeName wrapping the exchange and the key name - ExchangeName exchangeName() const { return ExchangeName(_exchangePublic.name(), _apiKey.name()); } + /// Builds an ExchangeName wrapping the exchange and the key name + ExchangeName exchangeName() const { return {_exchangePublic.name(), _apiKey.name()}; } const ExchangeInfo &exchangeInfo() const { return _exchangePublic.exchangeInfo(); } diff --git a/src/api/exchanges/src/binanceprivateapi.cpp b/src/api/exchanges/src/binanceprivateapi.cpp index 0a339277..bfe2ef88 100644 --- a/src/api/exchanges/src/binanceprivateapi.cpp +++ b/src/api/exchanges/src/binanceprivateapi.cpp @@ -10,6 +10,7 @@ #include "recentdeposit.hpp" #include "ssl_sha.hpp" #include "stringhelpers.hpp" +#include "timedef.hpp" #include "timestring.hpp" #include "tradeinfo.hpp" @@ -61,7 +62,7 @@ void SetNonceAndSignature(const APIKey& apiKey, CurlPostData& postData, Duration bool CheckErrorDoRetry(int statusCode, const json& ret, QueryDelayDir& queryDelayDir, Duration& sleepingTime, Duration& queryDelay) { - static constexpr Duration kInitialDurationQueryDelay = std::chrono::milliseconds(200); + static constexpr Duration kInitialDurationQueryDelay = TimeInMs(200); switch (statusCode) { case kInvalidTimestamp: { auto msgIt = ret.find("msg"); @@ -80,7 +81,7 @@ bool CheckErrorDoRetry(int statusCode, const json& ret, QueryDelayDir& queryDela } queryDelay -= sleepingTime; log::warn("Our local time is ahead of Binance server's time. Query delay modified to {} ms", - std::chrono::duration_cast(queryDelay).count()); + std::chrono::duration_cast(queryDelay).count()); // Ensure Nonce is increasing while modifying the query delay std::this_thread::sleep_for(sleepingTime); return true; @@ -96,7 +97,7 @@ bool CheckErrorDoRetry(int statusCode, const json& ret, QueryDelayDir& queryDela } queryDelay += sleepingTime; log::warn("Our local time is behind of Binance server's time. Query delay modified to {} ms", - std::chrono::duration_cast(queryDelay).count()); + std::chrono::duration_cast(queryDelay).count()); return true; } } @@ -131,7 +132,7 @@ json PrivateQuery(CurlHandle& curlHandle, const APIKey& apiKey, HttpRequestType json ret; for (int retryPos = 0; retryPos < kNbOrderRequestsRetries; ++retryPos) { if (retryPos != 0) { - log::trace("Wait {} ms...", std::chrono::duration_cast(sleepingTime).count()); + log::trace("Wait {} ms...", std::chrono::duration_cast(sleepingTime).count()); std::this_thread::sleep_for(sleepingTime); sleepingTime = (3 * sleepingTime) / 2; } @@ -153,7 +154,7 @@ json PrivateQuery(CurlHandle& curlHandle, const APIKey& apiKey, HttpRequestType break; } if (throwIfError) { - log::error("Full Binance json error: '{}'", ret.dump()); + log::error("Full Binance json error for {}: '{}'", apiKey.name(), ret.dump()); throw exception("Error: {}, msg: {}", MonetaryAmount(statusCode), ret["msg"].get()); } return ret; @@ -267,7 +268,7 @@ Orders BinancePrivate::queryOpenedOrders(const OrdersConstraints& openedOrdersCo } int64_t millisecondsSinceEpoch = orderDetails["time"].get(); - TimePoint placedTime{std::chrono::milliseconds(millisecondsSinceEpoch)}; + TimePoint placedTime{TimeInMs(millisecondsSinceEpoch)}; if (!openedOrdersConstraints.validatePlacedTime(placedTime)) { continue; } @@ -697,7 +698,7 @@ MonetaryAmount BinancePrivate::queryWithdrawDelivery(const InitiatedWithdrawInfo MonetaryAmount amountReceived(depositDetail["amount"].get(), currencyCode); int64_t millisecondsSinceEpoch = depositDetail["insertTime"].get(); - TimePoint timestamp{std::chrono::milliseconds(millisecondsSinceEpoch)}; + TimePoint timestamp{TimeInMs(millisecondsSinceEpoch)}; closestRecentDepositPicker.addDeposit(RecentDeposit(amountReceived, timestamp)); } diff --git a/src/engine/src/exchangesorchestrator.cpp b/src/engine/src/exchangesorchestrator.cpp index 1a7d08d2..c6d3a193 100644 --- a/src/engine/src/exchangesorchestrator.cpp +++ b/src/engine/src/exchangesorchestrator.cpp @@ -97,8 +97,8 @@ ExchangeHealthCheckStatus ExchangesOrchestrator::healthCheck(ExchangeNameSpan ex ExchangeHealthCheckStatus ret(selectedExchanges.size()); - _threadPool.parallel_transform(selectedExchanges.begin(), selectedExchanges.end(), ret.begin(), - [](Exchange *exchange) { return std::make_pair(exchange, exchange->healthCheck()); }); + _threadPool.parallelTransform(selectedExchanges.begin(), selectedExchanges.end(), ret.begin(), + [](Exchange *exchange) { return std::make_pair(exchange, exchange->healthCheck()); }); return ret; } @@ -109,7 +109,7 @@ ExchangeTickerMaps ExchangesOrchestrator::getTickerInformation(ExchangeNameSpan UniquePublicSelectedExchanges selectedExchanges = _exchangeRetriever.selectOneAccount(exchangeNames); ExchangeTickerMaps ret(selectedExchanges.size()); - _threadPool.parallel_transform( + _threadPool.parallelTransform( selectedExchanges.begin(), selectedExchanges.end(), ret.begin(), [](Exchange *exchange) { return std::make_pair(exchange, exchange->queryAllApproximatedOrderBooks(1)); }); @@ -124,8 +124,8 @@ MarketOrderBookConversionRates ExchangesOrchestrator::getMarketOrderBooks(Market equiCurrencyCode.isNeutral() ? "" : equiCurrencyCode); UniquePublicSelectedExchanges selectedExchanges = _exchangeRetriever.selectOneAccount(exchangeNames); std::array isMarketTradable; - _threadPool.parallel_transform(selectedExchanges.begin(), selectedExchanges.end(), isMarketTradable.begin(), - [mk](Exchange *exchange) { return exchange->queryTradableMarkets().contains(mk); }); + _threadPool.parallelTransform(selectedExchanges.begin(), selectedExchanges.end(), isMarketTradable.begin(), + [mk](Exchange *exchange) { return exchange->queryTradableMarkets().contains(mk); }); FilterVector(selectedExchanges, isMarketTradable); @@ -141,7 +141,7 @@ MarketOrderBookConversionRates ExchangesOrchestrator::getMarketOrderBooks(Market } return std::make_tuple(exchange->name(), std::move(marketOrderBook), optConversionRate); }; - _threadPool.parallel_transform(selectedExchanges.begin(), selectedExchanges.end(), ret.begin(), marketOrderBooksFunc); + _threadPool.parallelTransform(selectedExchanges.begin(), selectedExchanges.end(), ret.begin(), marketOrderBooksFunc); return ret; } @@ -157,9 +157,9 @@ BalancePerExchange ExchangesOrchestrator::getBalance(std::span balancePortfolios(balanceExchanges.size()); - _threadPool.parallel_transform( + _threadPool.parallelTransform( balanceExchanges.begin(), balanceExchanges.end(), balancePortfolios.begin(), - [&](Exchange *exchange) { return exchange->apiPrivate().getAccountBalance(balanceOptions); }); + [&balanceOptions](Exchange *exchange) { return exchange->apiPrivate().getAccountBalance(balanceOptions); }); BalancePerExchange ret; ret.reserve(balanceExchanges.size()); @@ -202,7 +202,7 @@ WalletPerExchange ExchangesOrchestrator::getDepositInfo(std::span walletPerExchange(depositInfoExchanges.size()); - _threadPool.parallel_transform( + _threadPool.parallelTransform( depositInfoExchanges.begin(), depositInfoExchanges.end(), walletPerExchange.begin(), [depositCurrency](Exchange *exchange) { return exchange->apiPrivate().queryDepositWallet(depositCurrency); }); WalletPerExchange ret; @@ -221,7 +221,7 @@ OpenedOrdersPerExchange ExchangesOrchestrator::getOpenedOrders(std::spanapiPrivate().queryOpenedOrders(openedOrdersConstraints))); }); @@ -236,7 +236,7 @@ NbCancelledOrdersPerExchange ExchangesOrchestrator::cancelOrders(std::spanapiPrivate().cancelOpenedOrders(ordersConstraints)); }); @@ -252,7 +252,7 @@ DepositsPerExchange ExchangesOrchestrator::getRecentDeposits(std::spanapiPrivate().queryRecentDeposits(depositsConstraints)); }); @@ -268,7 +268,7 @@ WithdrawsPerExchange ExchangesOrchestrator::getRecentWithdraws(std::spanapiPrivate().queryRecentWithdraws(withdrawsConstraints)); }); @@ -280,7 +280,7 @@ ConversionPathPerExchange ExchangesOrchestrator::getConversionPaths(Market mk, E log::info("Query {} conversion path from {}", mk, ConstructAccumulatedExchangeNames(exchangeNames)); UniquePublicSelectedExchanges selectedExchanges = _exchangeRetriever.selectOneAccount(exchangeNames); ConversionPathPerExchange conversionPathPerExchange(selectedExchanges.size()); - _threadPool.parallel_transform( + _threadPool.parallelTransform( selectedExchanges.begin(), selectedExchanges.end(), conversionPathPerExchange.begin(), [mk](Exchange *exchange) { return std::make_pair(exchange, exchange->apiPublic().findMarketsPath(mk.base(), mk.quote())); }); @@ -305,8 +305,8 @@ MarketsPerExchange ExchangesOrchestrator::getMarketsPerExchange(CurrencyCode cur [cur1, cur2](Market mk) { return mk.canTrade(cur1) && (cur2.isNeutral() || mk.canTrade(cur2)); }); return std::make_pair(exchange, std::move(ret)); }; - _threadPool.parallel_transform(selectedExchanges.begin(), selectedExchanges.end(), marketsPerExchange.begin(), - marketsWithCur); + _threadPool.parallelTransform(selectedExchanges.begin(), selectedExchanges.end(), marketsPerExchange.begin(), + marketsWithCur); return marketsPerExchange; } @@ -315,7 +315,7 @@ UniquePublicSelectedExchanges ExchangesOrchestrator::getExchangesTradingCurrency bool shouldBeWithdrawable) { UniquePublicSelectedExchanges selectedExchanges = _exchangeRetriever.selectOneAccount(exchangeNames); std::array isCurrencyTradablePerExchange; - _threadPool.parallel_transform( + _threadPool.parallelTransform( selectedExchanges.begin(), selectedExchanges.end(), isCurrencyTradablePerExchange.begin(), [currencyCode, shouldBeWithdrawable](Exchange *exchange) { CurrencyExchangeFlatSet currencies = exchange->queryTradableCurrencies(); @@ -332,9 +332,8 @@ UniquePublicSelectedExchanges ExchangesOrchestrator::getExchangesTradingMarket(M ExchangeNameSpan exchangeNames) { UniquePublicSelectedExchanges selectedExchanges = _exchangeRetriever.selectOneAccount(exchangeNames); std::array isMarketTradablePerExchange; - _threadPool.parallel_transform(selectedExchanges.begin(), selectedExchanges.end(), - isMarketTradablePerExchange.begin(), - [mk](Exchange *exchange) { return exchange->queryTradableMarkets().contains(mk); }); + _threadPool.parallelTransform(selectedExchanges.begin(), selectedExchanges.end(), isMarketTradablePerExchange.begin(), + [mk](Exchange *exchange) { return exchange->queryTradableMarkets().contains(mk); }); // Erases Exchanges which do not propose asked market FilterVector(selectedExchanges, isMarketTradablePerExchange); @@ -426,7 +425,7 @@ TradedAmountsPerExchange LaunchAndCollectTrades(ThreadPool &threadPool, Exchange ExchangeAmountMarketsPathVector::iterator last, CurrencyCode toCurrency, const TradeOptions &tradeOptions) { TradedAmountsPerExchange tradeAmountsPerExchange(std::distance(first, last)); - threadPool.parallel_transform(first, last, tradeAmountsPerExchange.begin(), [toCurrency, &tradeOptions](auto &tuple) { + threadPool.parallelTransform(first, last, tradeAmountsPerExchange.begin(), [toCurrency, &tradeOptions](auto &tuple) { Exchange *exchange = std::get<0>(tuple); return std::make_pair( exchange, exchange->apiPrivate().trade(std::get<1>(tuple), toCurrency, tradeOptions, std::get<2>(tuple))); @@ -438,7 +437,7 @@ template TradedAmountsPerExchange LaunchAndCollectTrades(ThreadPool &threadPool, Iterator first, Iterator last, const TradeOptions &tradeOptions) { TradedAmountsPerExchange tradeAmountsPerExchange(std::distance(first, last)); - threadPool.parallel_transform(first, last, tradeAmountsPerExchange.begin(), [&tradeOptions](auto &tuple) { + threadPool.parallelTransform(first, last, tradeAmountsPerExchange.begin(), [&tradeOptions](auto &tuple) { Exchange *exchange = std::get<0>(tuple); return std::make_pair(exchange, exchange->apiPrivate().trade(std::get<1>(tuple), std::get<2>(tuple), tradeOptions, std::get<3>(tuple))); @@ -718,11 +717,11 @@ TradedAmountsVectorWithFinalAmountPerExchange ExchangesOrchestrator::dustSweeper _exchangeRetriever.select(ExchangeRetriever::Order::kInitial, privateExchangeNames); TradedAmountsVectorWithFinalAmountPerExchange ret(selExchanges.size()); - _threadPool.parallel_transform(selExchanges.begin(), selExchanges.end(), ret.begin(), - [currencyCode](Exchange *exchange) { - return std::make_pair(static_cast(exchange), - exchange->apiPrivate().queryDustSweeper(currencyCode)); - }); + _threadPool.parallelTransform(selExchanges.begin(), selExchanges.end(), ret.begin(), + [currencyCode](Exchange *exchange) { + return std::make_pair(static_cast(exchange), + exchange->apiPrivate().queryDustSweeper(currencyCode)); + }); return ret; } @@ -747,8 +746,8 @@ DeliveredWithdrawInfoWithExchanges ExchangesOrchestrator::withdraw(MonetaryAmoun throw exception("Cannot withdraw to the same account"); } std::array currencyExchangeSets; - _threadPool.parallel_transform(exchangePair.begin(), exchangePair.end(), currencyExchangeSets.begin(), - [](Exchange *exchange) { return exchange->queryTradableCurrencies(); }); + _threadPool.parallelTransform(exchangePair.begin(), exchangePair.end(), currencyExchangeSets.begin(), + [](Exchange *exchange) { return exchange->queryTradableCurrencies(); }); DeliveredWithdrawInfoWithExchanges ret{{&fromExchange, &toExchange}, DeliveredWithdrawInfo{}}; @@ -783,10 +782,10 @@ MonetaryAmountPerExchange ExchangesOrchestrator::getWithdrawFees(CurrencyCode cu UniquePublicSelectedExchanges selectedExchanges = getExchangesTradingCurrency(currencyCode, exchangeNames, true); MonetaryAmountPerExchange withdrawFeePerExchange(selectedExchanges.size()); - _threadPool.parallel_transform(selectedExchanges.begin(), selectedExchanges.end(), withdrawFeePerExchange.begin(), - [currencyCode](Exchange *exchange) { - return std::make_pair(exchange, exchange->queryWithdrawalFee(currencyCode)); - }); + _threadPool.parallelTransform(selectedExchanges.begin(), selectedExchanges.end(), withdrawFeePerExchange.begin(), + [currencyCode](Exchange *exchange) { + return std::make_pair(exchange, exchange->queryWithdrawalFee(currencyCode)); + }); return withdrawFeePerExchange; } @@ -796,7 +795,7 @@ MonetaryAmountPerExchange ExchangesOrchestrator::getLast24hTradedVolumePerExchan UniquePublicSelectedExchanges selectedExchanges = getExchangesTradingMarket(mk, exchangeNames); MonetaryAmountPerExchange tradedVolumePerExchange(selectedExchanges.size()); - _threadPool.parallel_transform( + _threadPool.parallelTransform( selectedExchanges.begin(), selectedExchanges.end(), tradedVolumePerExchange.begin(), [mk](Exchange *exchange) { return std::make_pair(exchange, exchange->queryLast24hVolume(mk)); }); return tradedVolumePerExchange; @@ -809,7 +808,7 @@ LastTradesPerExchange ExchangesOrchestrator::getLastTradesPerExchange(Market mk, UniquePublicSelectedExchanges selectedExchanges = getExchangesTradingMarket(mk, exchangeNames); LastTradesPerExchange ret(selectedExchanges.size()); - _threadPool.parallel_transform( + _threadPool.parallelTransform( selectedExchanges.begin(), selectedExchanges.end(), ret.begin(), [mk, nbLastTrades](Exchange *exchange) { return std::make_pair(static_cast(exchange), exchange->queryLastTrades(mk, nbLastTrades)); }); @@ -822,7 +821,7 @@ MonetaryAmountPerExchange ExchangesOrchestrator::getLastPricePerExchange(Market UniquePublicSelectedExchanges selectedExchanges = getExchangesTradingMarket(mk, exchangeNames); MonetaryAmountPerExchange lastPricePerExchange(selectedExchanges.size()); - _threadPool.parallel_transform( + _threadPool.parallelTransform( selectedExchanges.begin(), selectedExchanges.end(), lastPricePerExchange.begin(), [mk](Exchange *exchange) { return std::make_pair(exchange, exchange->queryLastPrice(mk)); }); return lastPricePerExchange; diff --git a/src/tech/CMakeLists.txt b/src/tech/CMakeLists.txt index 814aec3b..a01a0b9e 100644 --- a/src/tech/CMakeLists.txt +++ b/src/tech/CMakeLists.txt @@ -113,6 +113,8 @@ add_unit_test( add_unit_test( threadpool_test test/threadpool_test.cpp + DEFINITIONS + CCT_DISABLE_SPDLOG ) add_unit_test( diff --git a/src/tech/include/threadpool.hpp b/src/tech/include/threadpool.hpp index d4c74e08..778c3be5 100644 --- a/src/tech/include/threadpool.hpp +++ b/src/tech/include/threadpool.hpp @@ -1,16 +1,19 @@ #pragma once #include +#include #include #include #include #include #include -#include #include #include #include "cct_const.hpp" +#include "cct_exception.hpp" +#include "cct_invalid_argument_exception.hpp" +#include "cct_log.hpp" #include "cct_smallvector.hpp" #include "cct_vector.hpp" @@ -19,41 +22,20 @@ namespace cct { /// @brief C++ ThreadPool implementation. Number of threads is to be specified at creation of the object. /// @note original code taken from https://github.com/progschj/ThreadPool/blob/master/ThreadPool.h, with modifications: /// - Rule of 5: delete all special members. -/// - Utility function parallel_transform added. +/// - Utility function parallelTransform added. /// - C++20 version with std::invoke_result instead of std::result_of and std::jthread that calls join /// automatically class ThreadPool { public: - explicit ThreadPool(int nbThreads = 1) { - if (nbThreads < 1) { - throw std::invalid_argument("number of threads should be strictly positive"); - } - _workers.reserve(static_cast(nbThreads)); - for (decltype(nbThreads) threadPos = 0; threadPos < nbThreads; ++threadPos) { - _workers.emplace_back([this] { - while (true) { - TasksQueue::value_type task; - - { - std::unique_lock lock(this->_queueMutex); - this->_condition.wait(lock, [this] { return this->_stop || !this->_tasks.empty(); }); - if (this->_stop && this->_tasks.empty()) { - break; - } - task = std::move(this->_tasks.front()); - this->_tasks.pop(); - } - task(); - } - }); - } - } + explicit ThreadPool(int nbThreads = 1); ThreadPool(const ThreadPool&) = delete; ThreadPool(ThreadPool&&) = delete; ThreadPool& operator=(const ThreadPool&) = delete; ThreadPool& operator=(ThreadPool&&) = delete; + ~ThreadPool(); + auto nbWorkers() const noexcept { return _workers.size(); } // add new work item to the pool @@ -64,25 +46,19 @@ class ThreadPool { // This function will first enqueue all the tasks at one, using waiting threads of the thread pool, // and then retrieves and moves the results to 'out', as for std::transform. template - OutputIt parallel_transform(InputIt beg, InputIt end, OutputIt out, UnaryOperation op); + OutputIt parallelTransform(InputIt first, InputIt last, OutputIt out, UnaryOperation unary_op); // Parallel version of std::transform with binary operation. template - OutputIt parallel_transform(InputIt1 first1, InputIt1 last1, InputIt2 first2, OutputIt d_first, - BinaryOperation binary_op); - - ~ThreadPool() { - stopRequested(); - _condition.notify_all(); - } + OutputIt parallelTransform(InputIt1 first1, InputIt1 last1, InputIt2 first2, OutputIt out, BinaryOperation binary_op); private: using TasksQueue = std::queue>; - void stopRequested() { - std::unique_lock lock(_queueMutex); - _stop = true; - } + template + OutputIt retrieveAllResults(Futures& futures, OutputIt out); + + void stopRequested(); // the task queue TasksQueue _tasks; @@ -97,6 +73,36 @@ class ThreadPool { vector _workers; }; +inline ThreadPool::ThreadPool(int nbThreads) { + if (nbThreads < 1) { + throw invalid_argument("number of threads should be strictly positive"); + } + _workers.reserve(static_cast(nbThreads)); + for (decltype(nbThreads) threadPos = 0; threadPos < nbThreads; ++threadPos) { + _workers.emplace_back([this] { + while (true) { + TasksQueue::value_type task; + + { + std::unique_lock lock(this->_queueMutex); + this->_condition.wait(lock, [this] { return this->_stop || !this->_tasks.empty(); }); + if (this->_stop && this->_tasks.empty()) { + break; + } + task = std::move(this->_tasks.front()); + this->_tasks.pop(); + } + task(); + } + }); + } +} + +inline ThreadPool::~ThreadPool() { + stopRequested(); + _condition.notify_all(); +} + template inline std::future::type> ThreadPool::enqueue(Func&& f, Args&&... args) { using return_type = typename std::invoke_result::type; @@ -120,32 +126,55 @@ inline std::future::type> ThreadPool: } template -inline OutputIt ThreadPool::parallel_transform(InputIt beg, InputIt end, OutputIt out, UnaryOperation op) { - using FutureT = std::future>; +inline OutputIt ThreadPool::parallelTransform(InputIt first, InputIt last, OutputIt out, UnaryOperation unary_op) { + using FutureT = std::future>; SmallVector futures; - for (; beg != end; ++beg) { - futures.emplace_back(enqueue(op, *beg)); + for (; first != last; ++first) { + futures.emplace_back(enqueue(unary_op, *first)); } - auto nbFutures = futures.size(); - for (decltype(nbFutures) runPos = 0; runPos < nbFutures; ++runPos, ++out) { - *out = futures[runPos].get(); - } - return out; + return retrieveAllResults(futures, out); } template -inline OutputIt ThreadPool::parallel_transform(InputIt1 first1, InputIt1 last1, InputIt2 first2, OutputIt out, - BinaryOperation binary_op) { +inline OutputIt ThreadPool::parallelTransform(InputIt1 first1, InputIt1 last1, InputIt2 first2, OutputIt out, + BinaryOperation binary_op) { using FutureT = std::future>; SmallVector futures; for (; first1 != last1; ++first1, ++first2) { futures.emplace_back(enqueue(binary_op, *first1, *first2)); } + return retrieveAllResults(futures, out); +} + +template +inline OutputIt ThreadPool::retrieveAllResults(Futures& futures, OutputIt out) { auto nbFutures = futures.size(); + int nbExceptionsThrown = 0; for (decltype(nbFutures) runPos = 0; runPos < nbFutures; ++runPos, ++out) { - *out = futures[runPos].get(); + try { + *out = futures[runPos].get(); + } catch (const std::exception& e) { + // When a future throws an exception, it will be rethrown at the get() method call. + // We need to catch it and finish getting all the results before we can rethrow it. + using OutputType = std::remove_cvref_t; + // value initialize the result for this thread. Probably not needed, but safer. + *out = OutputType(); + log::critical("exception caught in thread pool: {}", e.what()); + ++nbExceptionsThrown; + } + } + if (nbExceptionsThrown != 0) { + // In this command line implementation of coincenter, I choose to rethrow any exception thrown by threads. + // In a server implementation, we could maybe only log the error and not rethrow the exception + throw exception("{} exception(s) thrown in thread pool", nbExceptionsThrown); } + return out; } +inline void ThreadPool::stopRequested() { + std::unique_lock lock(_queueMutex); + _stop = true; +} + } // namespace cct \ No newline at end of file diff --git a/src/tech/test/threadpool_test.cpp b/src/tech/test/threadpool_test.cpp index 728c397e..818f637a 100644 --- a/src/tech/test/threadpool_test.cpp +++ b/src/tech/test/threadpool_test.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include "cct_vector.hpp" @@ -15,12 +16,15 @@ namespace cct { namespace { using namespace std::chrono_literals; -int SlowDouble(int val) { +int SlowDouble(const int &val) { + if (val == 42) { + throw std::invalid_argument("42 is not the answer to the ultimate question of life"); + } std::this_thread::sleep_for(10ms); return val * 2; } -int SlowAdd(int lhs, int rhs) { +int SlowAdd(const int &lhs, const int &rhs) { std::this_thread::sleep_for(10ms); return lhs + rhs; } @@ -47,7 +51,7 @@ TEST(ThreadPoolTest, ParallelTransformRandomInputIt) { std::iota(data.begin(), data.end(), 0); vector res(data.size()); - threadPool.parallel_transform(data.begin(), data.end(), res.begin(), SlowDouble); + threadPool.parallelTransform(data.begin(), data.end(), res.begin(), SlowDouble); for (int elem = 0; elem < kNbElems; ++elem) { EXPECT_EQ(2 * data[elem], res[elem]); @@ -61,13 +65,29 @@ TEST(ThreadPoolTest, ParallelTransformForwardInputIt) { std::iota(data.begin(), data.end(), 0); std::forward_list res(kNbElems); - threadPool.parallel_transform(data.begin(), data.end(), res.begin(), SlowDouble); + threadPool.parallelTransform(data.begin(), data.end(), res.begin(), SlowDouble); for (auto dataIt = data.begin(), resIt = res.begin(); dataIt != data.end(); ++dataIt, ++resIt) { EXPECT_EQ(2 * *dataIt, *resIt); } } +TEST(ThreadPoolTest, ParallelTransformException) { + ThreadPool threadPool(3); + constexpr int kNbElems = 5; + vector data(kNbElems); + std::iota(data.begin(), data.end(), 40); + vector res(data.size(), 40); + + try { + threadPool.parallelTransform(data.begin(), data.end(), res.begin(), SlowDouble); + EXPECT_TRUE(false); // should not arrive here + } catch (...) { + } + + EXPECT_EQ(res, (vector{80, 82, 0, 86, 88})); +} + TEST(ThreadPoolTest, ParallelTransformBinaryOperation) { ThreadPool threadPool(2); constexpr int kNbElems = 11; @@ -79,7 +99,7 @@ TEST(ThreadPoolTest, ParallelTransformBinaryOperation) { std::iota(data2.begin(), data2.end(), 3); vector res(kNbElems); - threadPool.parallel_transform(data1.begin(), data1.end(), data2.begin(), res.begin(), SlowAdd); + threadPool.parallelTransform(data1.begin(), data1.end(), data2.begin(), res.begin(), SlowAdd); auto resIt = res.begin(); auto data1It = data1.begin();