Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Optim] - Avoid copying arguments in parallelTransform of threadpool #528

Merged
merged 1 commit into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/api-objects/include/withdrawinfo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ class DeliveredWithdrawInfo {
template <>
struct fmt::formatter<cct::DeliveredWithdrawInfo> {
constexpr auto parse(format_parse_context &ctx) -> decltype(ctx.begin()) {
auto it = ctx.begin(), end = ctx.end();
const auto it = ctx.begin();
const auto end = ctx.end();
if (it != end && *it != '}') {
throw format_error("invalid format");
}
Expand Down
2 changes: 1 addition & 1 deletion src/engine/include/parseoptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@ auto ParseOptions(ParserType &parser, int argc, const char *argv[]) {
groupParsedOptions.mergeGlobalWith(globalOptions);
}

return parsedOptions;
return std::make_pair(std::move(programName), parsedOptions);
}
} // namespace cct
108 changes: 72 additions & 36 deletions src/engine/src/exchangesorchestrator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "cct_log.hpp"
#include "cct_smallvector.hpp"
#include "cct_string.hpp"
#include "cct_type_traits.hpp"
#include "commonapi.hpp"
#include "currencycode.hpp"
#include "currencyexchangeflatset.hpp"
Expand All @@ -31,7 +32,6 @@
#include "exchangeretriever.hpp"
#include "exchangeretrieverbase.hpp"
#include "market.hpp"
#include "marketorderbook.hpp"
#include "monetaryamount.hpp"
#include "monetaryamountbycurrencyset.hpp"
#include "ordersconstraints.hpp"
Expand All @@ -56,16 +56,43 @@ template <class MainVec>
void FilterVector(MainVec &main, std::span<const bool> considerSpan) {
const auto begIt = main.begin();
const auto endIt = main.end();

main.erase(std::remove_if(begIt, endIt, [=](const auto &val) { return !considerSpan[&val - &*begIt]; }), endIt);
}

using ExchangeAmountPairVector = SmallVector<std::pair<Exchange *, MonetaryAmount>, kTypicalNbPrivateAccounts>;
using ExchangeAmountMarketsPathVector =
SmallVector<std::tuple<Exchange *, MonetaryAmount, MarketsPath>, kTypicalNbPrivateAccounts>;
using ExchangeAmountToCurrency = std::tuple<Exchange *, MonetaryAmount, CurrencyCode, MarketsPath>;
using ExchangeAmountToCurrencyToAmount =
std::tuple<Exchange *, MonetaryAmount, CurrencyCode, MarketsPath, MonetaryAmount>;

struct ExchangeAmountMarkets {
Exchange *exchange;
MonetaryAmount amount;
MarketsPath marketsPath;

using trivially_relocatable = is_trivially_relocatable<MarketsPath>::type;
};

using ExchangeAmountMarketsPathVector = SmallVector<ExchangeAmountMarkets, kTypicalNbPrivateAccounts>;

struct ExchangeAmountToCurrency {
Exchange *exchange;
MonetaryAmount amount;
CurrencyCode currency;
MarketsPath marketsPath;

using trivially_relocatable = is_trivially_relocatable<MarketsPath>::type;
};

using ExchangeAmountToCurrencyVector = SmallVector<ExchangeAmountToCurrency, kTypicalNbPrivateAccounts>;

struct ExchangeAmountToCurrencyToAmount {
Exchange *exchange;
MonetaryAmount amount;
CurrencyCode currency;
MarketsPath marketsPath;
MonetaryAmount endAmount;

using trivially_relocatable = is_trivially_relocatable<MarketsPath>::type;
};

using ExchangeAmountToCurrencyToAmountVector = SmallVector<ExchangeAmountToCurrencyToAmount, kTypicalNbPrivateAccounts>;

template <class VecWithExchangeFirstPos>
Expand Down Expand Up @@ -138,12 +165,12 @@ MarketOrderBookConversionRates ExchangesOrchestrator::getMarketOrderBooks(Market
equiCurrencyCode.isNeutral()
? std::nullopt
: exchange->apiPublic().estimatedConvert(MonetaryAmount(1, mk.quote()), equiCurrencyCode);
MarketOrderBook marketOrderBook(depth ? exchange->queryOrderBook(mk, *depth) : exchange->queryOrderBook(mk));
if (!optConversionRate && !equiCurrencyCode.isNeutral()) {
log::warn("Unable to convert {} into {} on {}", marketOrderBook.market().quote(), equiCurrencyCode,
exchange->name());
log::warn("Unable to convert {} into {} on {}", mk.quote(), equiCurrencyCode, exchange->name());
}
return std::make_tuple(exchange->name(), std::move(marketOrderBook), optConversionRate);
return std::make_tuple(exchange->name(),
depth ? exchange->queryOrderBook(mk, *depth) : exchange->queryOrderBook(mk),
optConversionRate);
};
_threadPool.parallelTransform(selectedExchanges.begin(), selectedExchanges.end(), ret.begin(), marketOrderBooksFunc);
return ret;
Expand All @@ -157,17 +184,17 @@ BalancePerExchange ExchangesOrchestrator::getBalance(std::span<const ExchangeNam
log::info("Query balance from {}{}{} with{} balance in use", ConstructAccumulatedExchangeNames(privateExchangeNames),
equiCurrency.isNeutral() ? "" : " with equi currency ", equiCurrency, withBalanceInUse ? "" : "out");

ExchangeRetriever::SelectedExchanges balanceExchanges =
ExchangeRetriever::SelectedExchanges selectedExchanges =
_exchangeRetriever.select(ExchangeRetriever::Order::kInitial, privateExchangeNames);

SmallVector<BalancePortfolio, kTypicalNbPrivateAccounts> balancePortfolios(balanceExchanges.size());
SmallVector<BalancePortfolio, kTypicalNbPrivateAccounts> balancePortfolios(selectedExchanges.size());
_threadPool.parallelTransform(
balanceExchanges.begin(), balanceExchanges.end(), balancePortfolios.begin(),
selectedExchanges.begin(), selectedExchanges.end(), balancePortfolios.begin(),
[&balanceOptions](Exchange *exchange) { return exchange->apiPrivate().getAccountBalance(balanceOptions); });

BalancePerExchange ret;
ret.reserve(balanceExchanges.size());
std::transform(balanceExchanges.begin(), balanceExchanges.end(), std::make_move_iterator(balancePortfolios.begin()),
ret.reserve(selectedExchanges.size());
std::transform(selectedExchanges.begin(), selectedExchanges.end(), std::make_move_iterator(balancePortfolios.begin()),
std::back_inserter(ret), [](Exchange *exchange, BalancePortfolio &&balancePortfolio) {
return std::make_pair(exchange, std::move(balancePortfolio));
});
Expand Down Expand Up @@ -483,26 +510,35 @@ TradeResultPerExchange LaunchAndCollectTrades(ThreadPool &threadPool, ExchangeAm
ExchangeAmountMarketsPathVector::iterator last, CurrencyCode toCurrency,
const TradeOptions &tradeOptions) {
TradeResultPerExchange tradeResultPerExchange(std::distance(first, last));
threadPool.parallelTransform(first, last, tradeResultPerExchange.begin(), [toCurrency, &tradeOptions](auto &tuple) {
Exchange *exchange = std::get<0>(tuple);
const MonetaryAmount from = std::get<1>(tuple);
TradedAmounts tradedAmounts = exchange->apiPrivate().trade(from, toCurrency, tradeOptions, std::get<2>(tuple));
return std::make_pair(exchange, TradeResult(std::move(tradedAmounts), from));
});
threadPool.parallelTransform(first, last, tradeResultPerExchange.begin(),
[toCurrency, &tradeOptions](ExchangeAmountMarkets &exchangeAmountMarketsPath) {
Exchange *exchange = exchangeAmountMarketsPath.exchange;
const MonetaryAmount from = exchangeAmountMarketsPath.amount;
const auto &marketsPath = exchangeAmountMarketsPath.marketsPath;

TradedAmounts tradedAmounts =
exchange->apiPrivate().trade(from, toCurrency, tradeOptions, marketsPath);

return std::make_pair(exchange, TradeResult(std::move(tradedAmounts), from));
});
return tradeResultPerExchange;
}

template <class Iterator>
TradeResultPerExchange LaunchAndCollectTrades(ThreadPool &threadPool, Iterator first, Iterator last,
const TradeOptions &tradeOptions) {
TradeResultPerExchange tradeResultPerExchange(std::distance(first, last));
threadPool.parallelTransform(first, last, tradeResultPerExchange.begin(), [&tradeOptions](auto &tuple) {
Exchange *exchange = std::get<0>(tuple);
const MonetaryAmount from = std::get<1>(tuple);
const CurrencyCode toCurrency = std::get<2>(tuple);
TradedAmounts tradedAmounts = exchange->apiPrivate().trade(from, toCurrency, tradeOptions, std::get<3>(tuple));
return std::make_pair(exchange, TradeResult(std::move(tradedAmounts), from));
});
using ObjType = decltype(*first);
threadPool.parallelTransform(
first, last, tradeResultPerExchange.begin(), [&tradeOptions](ObjType &exchangeAmountMarketsPath) {
Exchange *exchange = exchangeAmountMarketsPath.exchange;
const MonetaryAmount from = exchangeAmountMarketsPath.amount;
const CurrencyCode toCurrency = exchangeAmountMarketsPath.currency;
const auto &marketsPath = exchangeAmountMarketsPath.marketsPath;

TradedAmounts tradedAmounts = exchange->apiPrivate().trade(from, toCurrency, tradeOptions, marketsPath);
return std::make_pair(exchange, TradeResult(std::move(tradedAmounts), from));
});
return tradeResultPerExchange;
}

Expand Down Expand Up @@ -549,17 +585,17 @@ TradeResultPerExchange ExchangesOrchestrator::trade(MonetaryAmount from, bool is
if (!exchangeAmountMarketsPathVector.empty()) {
// Sort exchanges from largest to lowest available amount (should be after filter on markets and conversion paths)
std::ranges::stable_sort(exchangeAmountMarketsPathVector,
[](const auto &lhs, const auto &rhs) { return std::get<1>(lhs) > std::get<1>(rhs); });
[](const auto &lhs, const auto &rhs) { return lhs.amount > rhs.amount; });

// Locate the point where there is enough available amount to trade for this currency
if (isPercentageTrade) {
MonetaryAmount totalAvailableAmount = std::accumulate(
exchangeAmountMarketsPathVector.begin(), exchangeAmountMarketsPathVector.end(), currentTotalAmount,
[](MonetaryAmount tot, const auto &tuple) { return tot + std::get<1>(tuple); });
MonetaryAmount totalAvailableAmount =
std::accumulate(exchangeAmountMarketsPathVector.begin(), exchangeAmountMarketsPathVector.end(),
currentTotalAmount, [](MonetaryAmount tot, const auto &tuple) { return tot + tuple.amount; });
from = (totalAvailableAmount * from.toNeutral()) / 100;
}
for (auto endIt = exchangeAmountMarketsPathVector.end(); it != endIt && currentTotalAmount < from; ++it) {
MonetaryAmount &amount = std::get<1>(*it);
MonetaryAmount &amount = it->amount;
if (currentTotalAmount + amount > from) {
// Cap last amount such that total start trade on all exchanges reaches exactly 'startAmount'
amount = from - currentTotalAmount;
Expand Down Expand Up @@ -628,8 +664,8 @@ TradeResultPerExchange ExchangesOrchestrator::smartBuy(MonetaryAmount endAmount,
}
MonetaryAmount avAmount = balance.get(fromCurrency);
if (avAmount > 0 &&
std::none_of(trades.begin(), trades.begin() + nbTrades, [pExchange, fromCurrency](const auto &tuple) {
return std::get<0>(tuple) == pExchange && std::get<1>(tuple).currencyCode() == fromCurrency;
std::none_of(trades.begin(), trades.begin() + nbTrades, [pExchange, fromCurrency](const auto &obj) {
return obj.exchange == pExchange && obj.amount.currencyCode() == fromCurrency;
})) {
auto conversionPath = exchangePublic.findMarketsPath(fromCurrency, toCurrency, markets, fiats,
api::ExchangePublic::MarketPathMode::kStrict);
Expand All @@ -649,7 +685,7 @@ TradeResultPerExchange ExchangesOrchestrator::smartBuy(MonetaryAmount endAmount,
}
// Sort exchanges from largest to lowest end amount
std::stable_sort(trades.begin() + nbTrades, trades.end(),
[](const auto &lhs, const auto &rhs) { return std::get<4>(lhs) > std::get<4>(rhs); });
[](const auto &lhs, const auto &rhs) { return lhs.endAmount > rhs.endAmount; });
int nbTradesToKeep = 0;
for (auto &[pExchange, startAmount, tradeToCurrency, conversionPath, tradeEndAmount] : trades) {
if (tradeEndAmount > remEndAmount) {
Expand Down
8 changes: 3 additions & 5 deletions src/main/src/main.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#include <cstdlib>
#include <exception>
#include <filesystem>
#include <iostream>

#include "cct_invalid_argument_exception.hpp"
Expand All @@ -13,19 +12,18 @@
#include "runmodes.hpp"

int main(int argc, const char* argv[]) {
using namespace cct;
try {
using namespace cct;
auto parser =
CommandLineOptionsParser<CoincenterCmdLineOptions>(CoincenterAllowedOptions<CoincenterCmdLineOptions>::value);
const auto cmdLineOptionsVector = ParseOptions(parser, argc, argv);
const auto [programName, cmdLineOptionsVector] = ParseOptions(parser, argc, argv);

if (!cmdLineOptionsVector.empty()) {
const CoincenterCommands coincenterCommands(cmdLineOptionsVector);
const auto programName = std::filesystem::path(argv[0]).filename().string();

ProcessCommandsFromCLI(programName, coincenterCommands, cmdLineOptionsVector.front(), settings::RunMode::kProd);
}
} catch (const cct::invalid_argument& e) {
} catch (const invalid_argument& e) {
std::cerr << "Invalid argument: " << e.what() << '\n';
return EXIT_FAILURE;
} catch (const std::exception& e) {
Expand Down
2 changes: 2 additions & 0 deletions src/objects/include/coincenterinfo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class CoincenterInfo {

AbstractMetricGateway *metricGatewayPtr() const { return _metricGatewayPtr.get(); }

const GeneralConfig &generalConfig() const { return _generalConfig; }

const LoggingInfo &loggingInfo() const { return _generalConfig.loggingInfo(); }

const RequestsConfig &requestsConfig() const { return _generalConfig.requestsConfig(); }
Expand Down
14 changes: 11 additions & 3 deletions src/tech/include/threadpool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,24 @@ class ThreadPool {

auto nbWorkers() const noexcept { return _workers.size(); }

// add new work item to the pool
// Add new work item to the pool
// By default, arguments will be copied for safety. If you want to pass arguments by reference,
// make sure that the reference lifetime is valid through the whole execution time of the future,
// and wrap the argument you want to pass by reference with 'std::ref'.
template <class Func, class... Args>
std::future<std::invoke_result_t<Func, Args...>> enqueue(Func&& func, Args&&... args);

// Parallel version of std::transform with unary operation.
// This function will first enqueue all the tasks at one, using waiting threads of the thread pool,
// and then retrieves and moves the results to 'out', as for std::transform.
// Note: the objects passed in argument from InputIt are not copied and passed by reference (through
// std::reference_wrapper)
template <class InputIt, class OutputIt, class UnaryOperation>
OutputIt parallelTransform(InputIt first, InputIt last, OutputIt out, UnaryOperation unary_op);

// Parallel version of std::transform with binary operation.
// Note: the objects passed in argument from InputIt are not copied and passed by reference (through
// std::reference_wrapper)
template <class InputIt1, class InputIt2, class OutputIt, class BinaryOperation>
OutputIt parallelTransform(InputIt1 first1, InputIt1 last1, InputIt2 first2, OutputIt out, BinaryOperation binary_op);

Expand Down Expand Up @@ -105,6 +112,7 @@ inline ThreadPool::~ThreadPool() {

template <class Func, class... Args>
inline std::future<std::invoke_result_t<Func, Args...>> ThreadPool::enqueue(Func&& func, Args&&... args) {
// std::bind copies the arguments. To avoid copies, you can use std::ref to copy reference instead.
using return_type = std::invoke_result_t<Func, Args...>;

auto task = std::make_shared<std::packaged_task<return_type()>>(
Expand All @@ -130,7 +138,7 @@ inline OutputIt ThreadPool::parallelTransform(InputIt first, InputIt last, Outpu
using FutureT = std::future<std::invoke_result_t<UnaryOperation, decltype(*first)>>;
SmallVector<FutureT, kTypicalNbPrivateAccounts> futures;
for (; first != last; ++first) {
futures.emplace_back(enqueue(unary_op, *first));
futures.emplace_back(enqueue(unary_op, std::ref(*first)));
}
return retrieveAllResults(futures, out);
}
Expand All @@ -141,7 +149,7 @@ inline OutputIt ThreadPool::parallelTransform(InputIt1 first1, InputIt1 last1, I
using FutureT = std::future<std::invoke_result_t<BinaryOperation, decltype(*first1), decltype(*first2)>>;
SmallVector<FutureT, kTypicalNbPrivateAccounts> futures;
for (; first1 != last1; ++first1, ++first2) {
futures.emplace_back(enqueue(binary_op, *first1, *first2));
futures.emplace_back(enqueue(binary_op, std::ref(*first1), std::ref(*first2)));
}
return retrieveAllResults(futures, out);
}
Expand Down
30 changes: 29 additions & 1 deletion src/tech/test/threadpool_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

#include <gtest/gtest.h>

#include <chrono>
#include <forward_list>
#include <future>
#include <numeric>
Expand All @@ -28,6 +27,19 @@ int SlowAdd(const int &lhs, const int &rhs) {
std::this_thread::sleep_for(10ms);
return lhs + rhs;
}

struct NonCopyable {
NonCopyable(int i = 0) : i(i) {}

NonCopyable(const NonCopyable &) = delete;

int i;
};

int SlowDoubleNonCopyable(const NonCopyable &val) {
std::this_thread::sleep_for(10ms);
return val.i * 2;
}
} // namespace

TEST(ThreadPoolTest, Enqueue) {
Expand All @@ -44,6 +56,22 @@ TEST(ThreadPoolTest, Enqueue) {
}
}

TEST(ThreadPoolTest, EnqueueNonCopyable) {
ThreadPool threadPool(2);
vector<std::future<int>> results;

constexpr int kNbElems = 4;
vector<NonCopyable> inputData(kNbElems);
for (int elem = 0; elem < kNbElems; ++elem) {
inputData[elem] = NonCopyable(elem);
results.push_back(threadPool.enqueue(SlowDoubleNonCopyable, std::ref(inputData[elem])));
}

for (int elem = 0; elem < kNbElems; ++elem) {
EXPECT_EQ(results[elem].get(), elem * 2);
}
}

TEST(ThreadPoolTest, ParallelTransformRandomInputIt) {
ThreadPool threadPool(4);
constexpr int kNbElems = 22;
Expand Down