From fda64570cdc8a3ea8e13d89001eebc476c4a2829 Mon Sep 17 00:00:00 2001 From: 4kangjc Date: Fri, 12 Jul 2024 20:06:10 +0800 Subject: [PATCH] fix: fiber calling with arguments see also https://github.com/Tencent/flare/pull/66, https://github.com/Tencent/flare/pull/74 --- trpc/coroutine/fiber.h | 4 +-- trpc/coroutine/fiber_test.cc | 69 ++++++++++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/trpc/coroutine/fiber.h b/trpc/coroutine/fiber.h index 16c268b8..1d392ed5 100644 --- a/trpc/coroutine/fiber.h +++ b/trpc/coroutine/fiber.h @@ -105,8 +105,8 @@ class Fiber { Fiber(const Attributes& attr, F&& f, Args&&... args) : Fiber(attr, [f = std::forward(f), // P0780R2 is not implemented as of now (GCC 8.2). - t = std::tuple(std::forward(args)...)] { - std::apply(std::forward(f)(std::move(t))); + t = std::make_tuple(std::forward(args)...)]() mutable { + std::apply(std::move(f), std::move(t)); }) {} /// @brief Special case if no parameter is passed to `F`, in this case we don't need diff --git a/trpc/coroutine/fiber_test.cc b/trpc/coroutine/fiber_test.cc index 8dc9c3e3..ae8531e7 100644 --- a/trpc/coroutine/fiber_test.cc +++ b/trpc/coroutine/fiber_test.cc @@ -298,4 +298,73 @@ TEST(Fiber, FiberMutexInMixedContext) { }); } +// Passing `c` as non-const reference to test `std::ref` (see below.). +void Product(int a, int b, int& c) { c = a * b; } + +TEST(Fiber, CallWithArgs) { + RunAsFiber([]() { + // Test lambda + Fiber([](const char* hello) { ASSERT_EQ(hello, "hello"); }, "hello").Join(); + + Fiber( + [](auto&& First, auto&&... other) { + auto ans = (First + ... + other); + ASSERT_EQ(ans, 55); + }, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + .Join(); + + // Test member method + struct Add { + void operator()(int a, int b, int c) const { ASSERT_EQ(a + b, c); } + }; + + const Add add; + Fiber(std::ref(add), 2, 3, 5).Join(); + Fiber(&Add::operator(), &add, 1, 2, 3).Join(); + + struct Worker { // Noncopyable + std::string s; + void work(std::string_view s) { ASSERT_EQ("work...", s); } + void operator()(const std::string& str) { s = str; } + Worker() = default; + Worker(Worker&&) = default; + Worker(const Worker&) = delete; + Worker& operator=(const Worker&) = delete; + }; + + Worker w; + Fiber(&Worker::work, &w, "work...").Join(); + Fiber(&Worker::operator(), &w, "Work Test").Join(); + ASSERT_EQ(w.s, "Work Test"); + Fiber(std::move(w), "Move Test").Join(); + + // Test template function + std::vector vec{5, 4, 3, 2, 1}; + ASSERT_FALSE(std::is_sorted(vec.begin(), vec.end())); + Fiber(&std::sort::iterator>, vec.begin(), vec.end()) + .Join(); + ASSERT_TRUE(std::is_sorted(vec.begin(), vec.end())); + + // Test function name + int res = 0; + Fiber(Product, 2, 5, std::ref(res)).Join(); + ASSERT_EQ(res, 10); + + // Test function address + Fiber(&Product, 6, 7, std::ref(res)).Join(); + ASSERT_EQ(res, 42); + + // Test bind + auto bind_function = + std::bind(Product, 3, std::placeholders::_1, std::placeholders::_2); + Fiber(bind_function, 5, std::ref(res)).Join(); + ASSERT_EQ(res, 15); + + // `std::pair` shouldn't be converted to `std::tuple` implicitly (by CTAD). + Fiber([&](auto&& p) { res = p.first; }, std::make_pair(1, 2)).Join(); + EXPECT_EQ(1, res); + }); +} + } // namespace trpc