/
staticpool.jl
60 lines (45 loc) · 1.46 KB
/
staticpool.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#############################
# Internal Structures
#############################
"""
    StaticPool

Thread pool described by a list of thread ids.

Each entry of `tids` is a thread id; `tmap` gives every entry one contiguous chunk of
its input and only performs work for loop iterations whose id appears in `tids`.
"""
struct StaticPool <: AbstractThreadPool
    tids::Vector{Int}
end
#############################
# Constructors
#############################
"""
    StaticPool(init_thrd=1, nthrds=Threads.nthreads())

Create a `StaticPool` over the contiguous thread ids `init_thrd:init_thrd+nthrds-1`.

The range is truncated at `Threads.nthreads()`. If `init_thrd` itself exceeds
`Threads.nthreads()`, the pool contains only the last thread.

# Throws
- `ArgumentError` if `init_thrd < 1` or `nthrds < 1`. Ids below 1 never match any
  thread in `tmap`, so their share of the work would silently never be computed, and
  an empty pool cannot split any work.
"""
function StaticPool(init_thrd::Integer=1, nthrds::Integer=Threads.nthreads())
    init_thrd >= 1 || throw(ArgumentError("init_thrd must be >= 1, got $init_thrd"))
    nthrds >= 1 || throw(ArgumentError("nthrds must be >= 1, got $nthrds"))
    maxthrd = Threads.nthreads()
    thrd0 = min(init_thrd, maxthrd)
    thrd1 = min(init_thrd + nthrds - 1, maxthrd)
    # The range is converted to the `Vector{Int}` field by the default constructor.
    return StaticPool(thrd0:thrd1)
end
#############################
# ThreadPool Interface
#############################
"""
    tmap(pool::StaticPool, fn::Function, itr)

Apply `fn` to every element of `itr` and return the results in an array with the same
size as `collect(itr)`.

The elements are split into one contiguous chunk per entry of `pool.tids`. Chunk sizes
differ by at most one, and the extra elements go to the last chunks. The chunk for
`pool.tids[k]` is processed in the `Threads.@threads` iteration whose loop index equals
that thread id. An empty `itr` returns an empty array without calling `fn`.

# Throws
- `ErrorException` if `fn` has no method applicable to the first element of `itr`.
- `ArgumentError` if `pool.tids` is empty, contains duplicates, or contains an id outside
  `1:Threads.nthreads()`. Any of these would otherwise leave parts of the result
  uninitialised.
"""
function tmap(pool::StaticPool, fn::Function, itr)
    data = collect(itr)
    # `map` on an empty array picks an element type without ever calling `fn`.
    isempty(data) && return map(fn, data)
    applicable(fn, first(data)) || error("function can't be applied to iterator contents")

    tids = pool.tids
    nthrds = length(tids)
    maxtid = Threads.nthreads()
    nthrds > 0 || throw(ArgumentError("StaticPool has no threads"))
    allunique(tids) || throw(ArgumentError("StaticPool thread ids must be unique, got $tids"))
    all(t -> 1 <= t <= maxtid, tids) ||
        throw(ArgumentError("StaticPool thread ids must lie in 1:$maxtid, got $tids"))

    N = length(data)
    result = Array{_detect_type(fn, data), ndims(data)}(undef, size(data))
    njobs, remjobs = divrem(N, nthrds)
    # Cumulative chunk ends, computed once. The last `remjobs` chunks get one extra element.
    stops = cumsum([njobs + (ind > nthrds - remjobs ? 1 : 0) for ind in 1:nthrds])
    chunk(ind) = (ind == 1 ? 1 : stops[ind - 1] + 1):stops[ind]

    # NOTE(review): pinning chunk `ind` to thread `tids[ind]` only holds if iteration `tid`
    # executes on thread `tid` (static scheduling). Since Julia 1.8 `@threads` defaults to
    # `:dynamic`; every chunk is still computed exactly once, but not necessarily on the
    # listed thread. Confirm the supported Julia versions before relying on pinning.
    Threads.@threads for tid in 1:maxtid
        ind = findfirst(t -> t == tid, tids)
        if !isnothing(ind)
            for i in chunk(ind)
                @inbounds result[i] = fn(data[i])
            end
        end
    end
    return result
end