Skip to content

Commit

Permalink
feature(fiber): reimplement pools
Browse files Browse the repository at this point in the history
* fix weird deadlocks
* add better validation for invariants
* make them a lot faster

Signed-off-by: Rudi Grinberg <me@rgrinberg.com>
  • Loading branch information
rgrinberg committed Jan 27, 2023
1 parent 0737782 commit 4fb2d4e
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 25 deletions.
5 changes: 4 additions & 1 deletion otherlibs/fiber/fiber.mli
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,10 @@ module Pool : sig
is called, [task pool ~f] will fail to submit new tasks.
Note that stopping the pool does not prevent already queued tasks from
running. *)
running.
[stop pool] subsequent calls to [stop] ignored. In other words, this
function is idempotent *)
val stop : t -> unit fiber

(** [run pool] Runs all tasks submitted to [pool] in parallel. Errors raised
Expand Down
86 changes: 65 additions & 21 deletions otherlibs/fiber/pool.ml
Original file line number Diff line number Diff line change
@@ -1,17 +1,28 @@
open Stdune
open Core
open Core.O

type mvar =
| Done
| Task of (unit -> unit t)

type status =
| Open
| Closed
| Open (* new tasks are allowed *)
| Closed (* new tasks are forbidden *)

type runner =
| Running (* Firing fibers inside the queue. *)
| Awaiting_resume of unit k
(* Ran out of work. Waiting to be resumed once work is added or
pool is closed. *)
| Awaiting_run (* Just created. [run] hasn't been called yet. *)

(* A pool consumes tasks from a queue in the context where [run] was executed.
type t =
{ mvar : mvar Mvar.t
It's implemented by a simple queue of thunks and a continuation to resume
[run] whenever it runs out of work.
To optimize this further, we can bake in the operation into [effect] in [Core]. *)

type nonrec t =
{ tasks : (unit -> unit t) Queue.t (* pending tasks *)
; mutable runner : runner
(* The continuation to resume the runner set by [run] *)
; mutable status : status
}

Expand All @@ -20,26 +31,59 @@ let running t k =
| Open -> k true
| Closed -> k false

let create () = { mvar = Mvar.create (); status = Open }
let create () =
{ tasks = Queue.create (); runner = Awaiting_run; status = Open }

let task t ~f k =
match t.status with
| Closed ->
Code_error.raise "pool is closed. new tasks may not be submitted" []
| Open -> Mvar.write t.mvar (Task f) k

let stream t =
Stream.In.create (fun () ->
let+ next = Mvar.read t.mvar in
match next with
| Done -> None
| Task task -> Some task)
| Open -> (
Queue.push t.tasks f;
match t.runner with
| Running | Awaiting_run -> k ()
| Awaiting_resume r ->
t.runner <- Running;
resume r () k)

let stop t k =
match t.status with
| Closed -> k ()
| Open ->
| Open -> (
t.status <- Closed;
Mvar.write t.mvar Done k
match t.runner with
| Running | Awaiting_run -> k ()
| Awaiting_resume r ->
t.runner <- Running;
resume r () k)

let run t = stream t |> Stream.In.parallel_iter ~f:(fun task -> task ())
let run t k =
match t.runner with
| Awaiting_resume _ | Running ->
Code_error.raise "Fiber.Pool.run: concurent calls to run aren't allowed" []
| Awaiting_run ->
t.runner <- Running;
(* The number of currently running fibers in the pool. Only when this
number reaches zero we may call the final continuation [k]. *)
let n = ref 1 in
let done_fiber () =
decr n;
if !n = 0 then k () else end_of_fiber
in
let rec read t =
match Queue.pop t.tasks with
| None -> finish_or_suspend t
| Some v ->
incr n;
fork (fun () -> v () done_fiber) read_delayed
and read_delayed () = read t
and suspend_k k =
(* we are suspending because we have no tasks *)
assert (Queue.is_empty t.tasks);
t.runner <- Awaiting_resume k
and finish_or_suspend t =
match t.status with
| Closed -> done_fiber ()
| Open -> suspend suspend_k read_delayed
in
read t
45 changes: 42 additions & 3 deletions test/expect-tests/fiber/fiber_tests.ml
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,8 @@ let%expect_test "double run a pool" =
let pool = Pool.create () in
Fiber.fork_and_join_unit (fun () -> Pool.run pool) (fun () -> Pool.run pool));
[%expect.unreachable]
[@@expect.uncaught_exn {| (Test_scheduler.Never) |}]
[@@expect.uncaught_exn
{| ("(\"Fiber.Pool.run: concurent calls to run aren't allowed\", {})") |}]

let%expect_test "run -> stop -> run a pool" =
(* We shouldn't be able to call [Pool.run] again after we already called
Expand All @@ -873,7 +874,8 @@ let%expect_test "run -> stop -> run a pool" =
in
Pool.run pool);
[%expect.unreachable]
[@@expect.uncaught_exn {| (Test_scheduler.Never) |}]
[@@expect.uncaught_exn
{| ("(\"Fiber.Pool.run: concurent calls to run aren't allowed\", {})") |}]

let%expect_test "stop a pool and then run it" =
(Scheduler.run
Expand All @@ -900,8 +902,45 @@ let%expect_test "pool - weird deadlock" =
let* () = Pool.task pool ~f:Fiber.return in
Fiber.fork_and_join_unit (fun () -> Pool.stop pool) (fun () -> Pool.run pool)
);
[%expect {||}]

let%expect_test "nested run in task" =
(Scheduler.run
@@
let pool = Pool.create () in
let* () = Pool.task pool ~f:(fun () -> Pool.run pool) in
Fiber.fork_and_join_unit (fun () -> Pool.stop pool) (fun () -> Pool.run pool)
);
[%expect.unreachable]
[@@expect.uncaught_exn {| (Test_scheduler.Never) |}]
[@@expect.uncaught_exn
{| ("(\"Fiber.Pool.run: concurent calls to run aren't allowed\", {})") |}]

let%expect_test "nested tasks" =
(Scheduler.run
@@
let pool = Pool.create () in
let* () =
Pool.task pool ~f:(fun () ->
print_endline "outer";
let* () =
Pool.task pool ~f:(fun () ->
let+ () = Fiber.return () in
print_endline "inner")
in
Pool.stop pool)
in
Pool.run pool);
[%expect {|
outer
inner |}]

let%expect_test "stopping inside a task" =
(Scheduler.run
@@
let pool = Pool.create () in
let* () = Pool.task pool ~f:(fun () -> Pool.stop pool) in
Pool.run pool);
[%expect {||}]

let%expect_test "stack usage with consecutive Ivar.fill" =
let stack_size () = (Gc.stat ()).stack_size in
Expand Down

0 comments on commit 4fb2d4e

Please sign in to comment.