/
for.clj
64 lines (57 loc) · 2.83 KB
/
for.clj
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
(ns tech.parallel.for
(:import [java.util.concurrent ForkJoinPool Callable Future ExecutorService]
[java.util ArrayDeque PriorityQueue Comparator]))
(defmacro serial-for
  "Evaluate `body` once for each index 0, 1, ..., num-iters - 1, in increasing
  order, on the calling thread, with the current index bound to `idx-var`.

  `num-iters` is evaluated exactly once and coerced with `long`, so a value
  <= 0 runs nothing and a fractional value is truncated. Returns nil."
  [idx-var num-iters & body]
  ;; dotimes expands to the same shape: (let [n (long num-iters)] (loop [i 0] ...)).
  `(dotimes [~idx-var ~num-iters]
     ~@body))
(defn launch-parallel-for
  "Partition the index range [0, num-iters) into contiguous groups and call
  `parallel-for-fn` as (parallel-for-fn group-start group-len) for each group.

  Let P be ForkJoinPool/getCommonPoolParallelism.
  - If num-iters < 2*P, the whole range is handled by a single synchronous call
    (parallel-for-fn 0 num-iters) on the calling thread. That call's return
    value is returned, and any exception it throws propagates unwrapped.
  - Otherwise exactly P groups are submitted to ForkJoinPool/commonPool. Group
    lengths differ by at most one: the first (rem num-iters P) groups get one
    extra index. This function blocks until every group has finished and then
    returns nil.

  Failures on the parallel path: if any group throws, the remaining groups are
  still awaited before anything is rethrown, so no group is left running. The
  rethrown exception is the java.util.concurrent.ExecutionException that
  Future.get raised for the lowest-numbered failing group."
  [^long num-iters parallel-for-fn]
  ;; Read the pool parallelism once so the size check and the split agree.
  (let [parallelism (ForkJoinPool/getCommonPoolParallelism)]
    (if (< num-iters (* 2 parallelism))
      (parallel-for-fn 0 num-iters)
      (let [group-size (quot num-iters parallelism)
            ;; The first `overflow` groups each take one extra index.
            overflow (rem num-iters parallelism)
            overflow-size (inc group-size)
            ;; [start-idx len] for each of the `parallelism` groups, in order.
            groups (map (fn [^long group-idx]
                          (let [group-len (if (< group-idx overflow)
                                            overflow-size
                                            group-size)
                                group-start (+ (* overflow-size
                                                  (min overflow group-idx))
                                               (* group-size
                                                  (max 0 (- group-idx overflow))))]
                            [group-start group-len]))
                        (range parallelism))
            callables (map (fn [[start-idx len]]
                             (fn [] (parallel-for-fn start-idx len)))
                           groups)
            common-pool (ForkJoinPool/commonPool)
            ;; mapv is eager: every group is submitted before we start waiting.
            futures (mapv #(.submit common-pool ^Callable %) callables)
            ;; Await every future and remember only the first failure. Stopping
            ;; at the first .get that throws would leave later groups running
            ;; after this call has already exited.
            first-failure (reduce (fn [failure ^Future fut]
                                    (try
                                      (.get fut)
                                      failure
                                      (catch java.util.concurrent.ExecutionException e
                                        (or failure e))))
                                  nil
                                  futures)]
        (when first-failure
          (throw first-failure))))))
(defmacro parallel-for
  "Evaluate `body` once for every index in [0, num-iters), binding the index to
  `idx-var`. This is similar to clojure.core.matrix.macros/c-for, except the
  range always starts at 0 and steps by 1.

  `num-iters` is evaluated once and coerced with `long`. If it is smaller than
  twice ForkJoinPool/getCommonPoolParallelism, the indexes are visited in order
  on the calling thread via serial-for. Otherwise the range is handed to
  launch-parallel-for as a function of (group-start, group-len). That function
  walks indexes group-start .. group-start + group-len - 1 in order, so `body`
  may run concurrently for indexes in different groups.

  NOTE(review): on the parallel path `body` runs inside a fn submitted to the
  ForkJoinPool common pool. Thread-local dynamic bindings established by the
  caller are therefore presumably not visible to it; confirm before relying
  on `binding` around this macro."
  [idx-var num-iters & body]
  (let [n-sym (gensym "num-iters")
        start-sym (gensym "group-start")
        len-sym (gensym "group-len")
        end-sym (gensym "group-end")
        ;; Only the parameter positions carry the primitive hint. Hinting a
        ;; primitive local at a use site is a compile error.
        long-param (fn [sym] (vary-meta sym assoc :tag 'long))]
    `(let [~n-sym (long ~num-iters)]
       (if (>= ~n-sym (* 2 (ForkJoinPool/getCommonPoolParallelism)))
         ;; Large range: each group walks its own contiguous slice.
         (launch-parallel-for ~n-sym
                              (fn [~(long-param start-sym) ~(long-param len-sym)]
                                (let [~end-sym (+ ~start-sym ~len-sym)]
                                  (loop [~idx-var ~start-sym]
                                    (when (< ~idx-var ~end-sym)
                                      ~@body
                                      (recur (inc ~idx-var)))))))
         ;; Small range: run inline on the calling thread.
         (serial-for ~idx-var ~n-sym ~@body)))))