diff --git a/include/cppcoro/io_service.hpp b/include/cppcoro/io_service.hpp index 27867c58..464cf09e 100644 --- a/include/cppcoro/io_service.hpp +++ b/include/cppcoro/io_service.hpp @@ -79,6 +79,13 @@ namespace cppcoro const std::chrono::duration& delay, cancellation_token cancellationToken = {}) noexcept; + /// Process events until the task completes. + /// + /// \return + /// Result of the co_await task. + template + decltype(auto) process_events_until_complete(TASK&& task); + /// Process events until the io_service is stopped. /// /// \return @@ -178,6 +185,26 @@ namespace cppcoro }; + template + decltype(auto) io_service::process_events_until_complete(TASK&& task) + { + if (!task.is_ready()) + { + auto callback = [](void* io) noexcept + { + static_cast(io)->stop(); + }; + + auto starter = task.get_starter(); + starter.start(cppcoro::detail::continuation{ callback, this }); + + process_events(); + reset(); + } + + return std::forward(task).operator co_await().await_resume(); + } + class io_service::schedule_operation { public: diff --git a/test/io_service_tests.cpp b/test/io_service_tests.cpp index ddf5db31..a4cf2e43 100644 --- a/test/io_service_tests.cpp +++ b/test/io_service_tests.cpp @@ -226,4 +226,36 @@ TEST_CASE_FIXTURE(io_service_fixture_with_threads<1>, "Many concurrent timers") << "ms"); } +TEST_CASE("io_service::process_events_until_complete(task)") +{ + cppcoro::io_service ioService; + + auto makeTask = [](cppcoro::io_service& io) -> cppcoro::task + { + co_await io.schedule(); + co_return "foo"; + }; + + auto task = makeTask(ioService); + + CHECK(ioService.process_events_until_complete(task) == "foo"); + CHECK(ioService.process_events_until_complete(makeTask(ioService)) == "foo"); +} + +TEST_CASE("io_service::process_events_until_complete(shared_task)") +{ + cppcoro::io_service ioService; + + auto makeTask = [](cppcoro::io_service& io) -> cppcoro::shared_task + { + co_await io.schedule(); + co_return "foo"; + }; + + auto task = makeTask(ioService); + + CHECK(ioService.process_events_until_complete(task) == "foo"); + CHECK(ioService.process_events_until_complete(makeTask(ioService)) == "foo"); +} + TEST_SUITE_END();