Skip to content

Commit

Permalink
rooted_trees -> RootedTreeIterator
Browse files Browse the repository at this point in the history
  • Loading branch information
ranocha committed Mar 6, 2019
1 parent e722f0f commit ad4cb3e
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 22 deletions.
31 changes: 10 additions & 21 deletions src/RootedTrees.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ using LinearAlgebra
import Base: show, isless, ==, iterate


export RootedTree
export RootedTree, RootedTreeIterator

export α, β, γ, σ, order, residual_order_condition

export rooted_trees, count_trees
export count_trees



Expand All @@ -34,7 +34,6 @@ function RootedTree(level_sequence::AbstractVector)
RootedTree{T,V}(level_sequence)
end
#TODO: Validate rooted tree in constructor?
#TODO: Allow other vector types?


# #function RootedTree(sequence::Vector{T}, valid::Bool)
Expand Down Expand Up @@ -89,8 +88,8 @@ end
"""
canonical_representation!(t::RootedTree)
Use the canonical representation of the rooted tree `t`, i.e. the one with
lexicographically biggest level sequence.
Change the representation of the rooted tree `t` to the canonical one, i.e. the
one with lexicographically biggest level sequence.
"""
function canonical_representation!(t::RootedTree)
subtr = subtrees(t)
Expand Down Expand Up @@ -126,27 +125,17 @@ end
Iterator over all rooted trees of given `order`.
"""
struct RootedTreeIterator{T<:Integer}
order::T
t::RootedTree{T,Vector{T}}

function RootedTreeIterator(level_sequence::AbstractVector{T}) where {T<:Integer}
new{T}(RootedTree(Vector{T}(level_sequence)))
function RootedTreeIterator(order::T) where {T<:Integer}
new{T}(order, RootedTree(Vector{T}(one(T):order)))
end
end
#TODO: change types?

"""
rooted_trees(order::Integer)
Returns an iterator over all rooted trees of given `order`.
"""
function rooted_trees(order::Integer)
order < 1 && throw(ArgumentError("The `order` must be at least one."))

RootedTreeIterator(Vector(one(order):order))
end


function iterate(iter::RootedTreeIterator)
function iterate(iter::RootedTreeIterator{T}) where {T}
iter.t.level_sequence[:] = one(T):iter.order
(iter.t, false)
end

Expand Down Expand Up @@ -188,7 +177,7 @@ function count_trees(order)
order < 1 && throw(ArgumentError("The `order` must be at least one."))

num = 0
for _ in rooted_trees(order)
for _ in RootedTreeIterator(order)
num += 1
end
num
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ t = RootedTree([1, 2, 3, 4, 5])
number_of_rooted_trees = [1, 1, 2, 4, 9, 20, 48, 115, 286, 719]
for order in 1:10
num = 0
for t in rooted_trees(order)
for t in RootedTreeIterator(order)
num += 1
end
@test num == number_of_rooted_trees[order] == count_trees(order)
Expand Down

0 comments on commit ad4cb3e

Please sign in to comment.