Skip to content

Commit

Permalink
better sweep
Browse files Browse the repository at this point in the history
  • Loading branch information
panlanfeng committed Sep 23, 2016
1 parent a081927 commit dd9efb8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 23 deletions.
48 changes: 25 additions & 23 deletions src/sweep.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
function recyletimes(long::Integer, short::Integer)
if short == 1
return long
elseif short == long
return 1
else
throw(DimensionMismatch("Dimension Mismatch and the size of the short one is not 1!"))
end
end

"""
sweep(x, MARGIN, STATS, FUN)
Expand All @@ -11,32 +21,24 @@ Amedian=mapslices(median, A, 3)
sweep(A, 1:2, Amedian)
```
"""
function sweep(x::AbstractArray, MARGIN::AbstractVector, STATS::AbstractArray, FUN::Function=.-; kwargs...)
sz = collect(size(x))
function sweep(x::AbstractArray, MARGIN::Union{AbstractVector, Integer}, STATS::AbstractArray, FUN::Function=.-; kwargs...)
sz = size(x)
dimmargin = sz[MARGIN]
dimstats = size(STATS)
lstats = length(STATS)
lmargin = prod(dimmargin)
if lstats > lmargin
error("STATS is longer than extent of dimension $(sz[MARGIN])")
else
if mod(lmargin, lstats) != 0
error("STATS does not recycle exactly across MARGIN")
elseif lmargin != lstats
warn("STATS is recycled to match the extend of dimension $(sz[MARGIN])")
STATS=reshape(rep(vec(STATS), times=div(lmargin, lstats)), (sz[MARGIN]...))
if length(dimstats) == length(sz)
outertimes = Int[recyletimes(sz[i], dimstats[i]) for i in 1:length(dimstats)]
elseif length(dimstats) <= length(dimmargin)
outertimes = collect(sz)
for i in 1:length(dimstats)
outertimes[MARGIN[i]] = recyletimes(dimmargin[i], dimstats[i])
end
else
throw(DimensionMismatch("Dimension of STATS should match the MARGIN dimension $(dimmargin)"))
end
dimorder=seq_along(sz)
pe = [MARGIN, setdiff(dimorder, MARGIN);]
outer = sz[pe]
outer[1:length(MARGIN)]=1
xstat=Compat.repeat(STATS, outer=outer)
xstat=permutedims(xstat, (sortperm(pe)...))
FUN(x, xstat; kwargs...)
xstat = repeat(STATS, outer=outertimes)
return FUN(x, xstat; kwargs...)
end

function sweep(x::AbstractArray, MARGIN::Integer, STATS::AbstractArray, FUN::Function=.-; kwargs...)
sweep(x, [MARGIN], STATS, FUN; kwargs...)
end
sweep(x::AbstractArray, MARGIN::Tuple, STATS::AbstractArray, FUN::Function=.-; kwargs...)=sweep(x, [MARGIN...], STATS, FUN; kwargs...)
sweep(x::AbstractArray, MARGIN::Union{AbstractVector, Integer}, STATS::Real, FUN::Function=.-; kwargs...)=sweep(x, MARGIN, [STATS], FUN; kwargs...)

sweep(x::AbstractArray, MARGIN::Tuple, STATS::Union{AbstractArray, Real}, FUN::Function=.-; kwargs...)=sweep(x, [MARGIN...], STATS, FUN; kwargs...)
4 changes: 4 additions & 0 deletions test/sweeptest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ Amedian=mapslices(median, A, 3)
@test sweep(A, 1:2, Amedian)[:,:,1]==[-6.0 -6.0 -6.0; -6.0 -6.0 -6.0; -6.0 -6.0 -6.0; -6.0 -6.0 -6.0]

@test sweep(A, (1,2), Amedian)[:,:,1]==[-6.0 -6.0 -6.0; -6.0 -6.0 -6.0; -6.0 -6.0 -6.0; -6.0 -6.0 -6.0]
x = hcat(ones(10), collect(1:10))
@test RFlavor.sweep(x, 2, mean(x,1)) == [0.0 -4.5; 0.0 -3.5; 0.0 -2.5; 0.0 -1.5; 0.0 -0.5; 0.0 0.5; 0.0 1.5; 0.0 2.5; 0.0 3.5; 0.0 4.5]

@test RFlavor.sweep(x, 2, 2.0) == [-1.0 -1.0; -1.0 0.0; -1.0 1.0; -1.0 2.0; -1.0 3.0; -1.0 4.0; -1.0 5.0; -1.0 6.0; -1.0 7.0; -1.0 8.0]

0 comments on commit dd9efb8

Please sign in to comment.