Skip to content

Commit

Permalink
Make axisfor inferrable
Browse files Browse the repository at this point in the history
  • Loading branch information
tkf committed Jul 19, 2019
1 parent df6611b commit 28d7fb4
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ end

axisfor(patterns, i) =
foldlargs(nothing, patterns...) do _, p::AccessPattern
foldlargs(nothing, ntuple(identity, length(p.indices))...) do _, n
if p.indices[n] === i
foldlargs(1, p.indices...) do n, j
if j === i
reduced(axes(p.indexable)[n])
else
nothing
n + 1
end
end
end |> unreduced
Expand Down
2 changes: 1 addition & 1 deletion test/preamble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using Base.Broadcast: broadcasted, instantiate
using Test
using MacroTools
using NDReducibles
using NDReducibles: Index, plan
using NDReducibles: AccessPattern, Index, plan, axisfor

_indices(idxs::String) = Index.(Tuple(Symbol.(collect(idxs))))

Expand Down
11 changes: 10 additions & 1 deletion test/test_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ f2() = plan(
ND(2) => (:i, :j),
)

@testset begin
@testset "plan" begin
if VERSION < v"1.2-"
@test_broken_inferred f1()
@test_broken_inferred f2()
Expand All @@ -22,4 +22,13 @@ f2() = plan(
end
end

@testset "axisfor" begin
patterns = AccessPattern.((
ND(1) => (:i,),
ND(2) => (:i, :j),
))
@test_inferred axisfor(patterns, Index(:i))
@test_inferred axisfor(patterns, Index(:j))
end

end # module

0 comments on commit 28d7fb4

Please sign in to comment.