diff --git a/src/ThreadsX.jl b/src/ThreadsX.jl index 1840d681..1ea6a2d0 100644 --- a/src/ThreadsX.jl +++ b/src/ThreadsX.jl @@ -38,7 +38,8 @@ module Implementations import SplittablesBase using ArgCheck: @argcheck, @check using BangBang: SingletonVector, append!!, push!!, union!! -using Base: Ordering, add_sum, mapreduce_empty, mul_prod, reduce_empty +using Base: + HasShape, IteratorSize, Ordering, add_sum, mapreduce_empty, mul_prod, reduce_empty using ConstructionBase: setproperties using InitialValues: asmonoid using Referenceables: referenceable diff --git a/src/basesizes.jl b/src/basesizes.jl index fa4ec7d4..67d6fd13 100644 --- a/src/basesizes.jl +++ b/src/basesizes.jl @@ -1,3 +1,7 @@ +default_basesize(n::Integer) = max(1, cld(n, (8 * Threads.nthreads()))) +default_basesize(xs) = + default_basesize(SplittablesBase.amount(last(extract_transducer(xs)))) + default_basesize(_, _, xs) = default_basesize(xs::AbstractArray) # TODO: handle `Base.Fix2` etc. diff --git a/src/map.jl b/src/map.jl index 49a95eae..27b9b98a 100644 --- a/src/map.jl +++ b/src/map.jl @@ -3,10 +3,16 @@ __map(f, itr; kwargs...) = __map(f, itrs...; kwargs...) = tcollect(MapSplat(f), zip(itrs...); basesize = default_basesize(itrs[1]), kwargs...) +reshape_as(ys, xs) = reshape_as(ys, xs, IteratorSize(xs)) +reshape_as(ys, xs, ::IteratorSize) = ys +reshape_as(ys, xs, ::HasShape) = reshape(ys, size(xs)) +reshape_as(::Empty{T}, xs, isize::HasShape) where {T<:AbstractVector} = + reshape_as(T(undef, length(xs)), xs, isize) + function _map(f, itr, itrs...; kwargs...) ys = __map(f, itr, itrs...; kwargs...) isempty(ys) && return map(f, itr, itrs...) - return ys + return reshape_as(ys, itr) end ThreadsX.map(f, itr, itrs...; kwargs...) = _map(f, itr, itrs...; kwargs...) @@ -36,8 +42,10 @@ end struct ConvertTo{T} end (::ConvertTo{T})(x) where {T} = convert(T, x) -ThreadsX.collect(::Type{T}, itr; kwargs...) where {T} = - tcopy(Map(ConvertTo{T}()), Vector{T}, itr; basesize = default_basesize(itr), kwargs...) +ThreadsX.collect(::Type{T}, itr; kwargs...) where {T} = reshape_as( + tcopy(Map(ConvertTo{T}()), Vector{T}, itr; basesize = default_basesize(itr), kwargs...), + itr, +) ThreadsX.collect(itr; kwargs...) = - tcollect(itr; basesize = default_basesize(itr), kwargs...) + reshape_as(tcollect(itr; basesize = default_basesize(itr), kwargs...), itr) diff --git a/src/utils.jl b/src/utils.jl index b0b60cf0..311e721e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,3 @@ -default_basesize(n::Integer) = max(1, cld(n, (8 * Threads.nthreads()))) -default_basesize(xs) = default_basesize(length(xs)) - function adhoc_partition(xs, n) @check firstindex(xs) == 1 m = cld(length(xs), n) diff --git a/test/test_with_base.jl b/test/test_with_base.jl index 1c4f3bc3..d135dbb5 100644 --- a/test/test_with_base.jl +++ b/test/test_with_base.jl @@ -1,5 +1,6 @@ module TestWithBase +using Base: splat using Test using ThreadsX @@ -10,9 +11,16 @@ inc(x) = x + 1 raw_testdata = """ collect(1:10) collect(Float64, 1:10) +collect(x for x in 1:10 if isodd(x)) +collect(Float64, (x for x in 1:10 if isodd(x))) collect(inc(x) for x in 1:10) collect(Float64, (inc(x) for x in 1:10)) +collect(x * y for x in 1:10, y in 11:20) +collect(Float64, (x * y for x in 1:10, y in 11:20)) +collect(x * y for x in 1:0, y in 11:20) +collect(Float64, (x * y for x in 1:0, y in 11:20)) map(inc, 1:10) +map(inc, (x for x in 1:10 if isodd(x))) map(inc, Float64[]) map(inc, ones(3, 3)) map(inc, ones(3, 0)) @@ -21,6 +29,8 @@ map(*, 1:10, 11:20) map(*, ones(3, 3), ones(3, 3)) map(*, ones(3, 0), ones(3, 0)) map(*, ones(0, 3), ones(0, 3)) +map(splat(*), Iterators.product(1:10, 11:20)) +map(splat(*), Iterators.product(1:0, 11:20)) reduce(+, 1:10) reduce(+, 1:0) reduce(+, Bool[])