Skip to content

Commit

Permalink
iteration over all expressions up to maximum depth
Browse files Browse the repository at this point in the history
  • Loading branch information
tawheeler committed Aug 22, 2017
1 parent 163c60a commit 574d681
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 0 deletions.
131 changes: 131 additions & 0 deletions src/ExprRules.jl
Expand Up @@ -10,6 +10,8 @@ export
RuleNode,
NodeLoc,

ExpressionIterator,

@ruleset,
@digits,
max_arity,
Expand Down Expand Up @@ -431,4 +433,133 @@ function StatsBase.sample(::Type{NodeLoc}, root::RuleNode, typ::Symbol, ruleset:
return selected
end

###

function _next_state!(node::RuleNode, ruleset::RuleSet, max_depth::Int)

if max_depth < 1
return (node, false) # did not work
elseif isterminal(ruleset, node)
# do nothing
return (node, false) # cannot change leaves
else # !isterminal
if isempty(node.children)
if max_depth 1
return (node,false) # cannot expand
end

# build out the node
for c in child_types(ruleset, node)

worked = false
i = 0
child = RuleNode(0)
child_rules = ruleset[c]
while !worked && i < length(child_rules)
i += 1
child = RuleNode(child_rules[i])
worked = true
if !isterminal(ruleset, child)
child, worked = _next_state!(child, ruleset, max_depth-1)
end
end
if !worked
return (node, false) # did not work
end
push!(node.children, child)
end

return (node, true)
else # not empty
# make one change, starting with rightmost child
worked = false
child_index = length(node.children) + 1
while !worked && child_index > 1
child_index -= 1
child = node.children[child_index]

child, child_worked = _next_state!(child, ruleset, max_depth-1)
while !child_worked
child_type = return_type(ruleset, child)
child_rules = ruleset[child_type]
i = findfirst(child_rules, child.ind)
if i < length(child_rules)
child_worked = true
child = RuleNode(child_rules[i+1])
if !isterminal(ruleset, child)
child, child_worked = _next_state!(child, ruleset, max_depth-1)
end
node.children[child_index] = child
else
break
end
end

if child_worked
worked = true

# reset remaining children
for child_index2 in child_index+1 : length(node.children)
c = child_types(ruleset, node)[child_index2]
worked = false
i = 0
child = RuleNode(0)
child_rules = ruleset[c]
while !worked && i < length(child_rules)
i += 1
child = RuleNode(child_rules[i])
worked = true
if !isterminal(ruleset, child)
child, worked = _next_state!(child, ruleset, max_depth-1)
end
end
if !worked
break
end
node.children[child_index2] = child
end
end
end

return (node, worked)
end
end
end

mutable struct ExpressionIterator
ruleset::RuleSet
max_depth::Int
sym::Symbol
end
Base.iteratorsize(::ExpressionIterator) = Base.SizeUnknown()
Base.eltype(::ExpressionIterator) = RuleNode
Base.done(iter::ExpressionIterator, state::Tuple{RuleNode,Bool}) = !state[2]
function Base.start(iter::ExpressionIterator)
node = RuleNode(iter.ruleset[iter.sym][1])
return _next_state!(node, iter.ruleset, iter.max_depth)
end
function Base.next(iter::ExpressionIterator, state::Tuple{RuleNode,Bool})
ruleset, max_depth = iter.ruleset, iter.max_depth
item = deepcopy(state[1])
node, worked = _next_state!(state[1], ruleset, max_depth)

while !worked
# increment root's rule
rules = ruleset[iter.sym]
i = findfirst(rules, node.ind)
if i < length(rules)
node, worked = RuleNode(rules[i+1]), true
if !isterminal(ruleset, node)
node, worked = _next_state!(node, ruleset, max_depth)
end
else
break
end
end

state = (node, worked)
return (item, state)
end


end # module
90 changes: 90 additions & 0 deletions test/runtests.jl
Expand Up @@ -136,3 +136,93 @@ let
loc = sample(NodeLoc, rulenode, :Real, ruleset)
end
end

let
ruleset = @ruleset begin
R = R + R
R = 1
R = 2
end


node = RuleNode(1)
node, worked = ExprRules._next_state!(node, ruleset, 2)
@test worked
@test isequal(node, RuleNode(1, [RuleNode(2), RuleNode(2)]))

node = RuleNode(1, [RuleNode(2), RuleNode(2)])
node, worked = ExprRules._next_state!(node, ruleset, 2)
@test worked
@test isequal(node, RuleNode(1, [RuleNode(2), RuleNode(3)]))

node = RuleNode(1, [RuleNode(2), RuleNode(3)])
node, worked = ExprRules._next_state!(node, ruleset, 2)
@test worked
@test isequal(node, RuleNode(1, [RuleNode(3), RuleNode(2)]))

node = RuleNode(1, [RuleNode(3), RuleNode(2)])
node, worked = ExprRules._next_state!(node, ruleset, 2)
@test worked
@test isequal(node, RuleNode(1, [RuleNode(3), RuleNode(3)]))

node = RuleNode(1, [RuleNode(3), RuleNode(3)])
node, worked = ExprRules._next_state!(node, ruleset, 2)
@test !worked

###

node = RuleNode(1)
for testnode in [
RuleNode(1, [RuleNode(1, [RuleNode(2), RuleNode(2)]), RuleNode(1, [RuleNode(2), RuleNode(2)])]),
RuleNode(1, [RuleNode(1, [RuleNode(2), RuleNode(2)]), RuleNode(1, [RuleNode(2), RuleNode(3)])]),
RuleNode(1, [RuleNode(1, [RuleNode(2), RuleNode(2)]), RuleNode(1, [RuleNode(3), RuleNode(2)])]),
RuleNode(1, [RuleNode(1, [RuleNode(2), RuleNode(2)]), RuleNode(1, [RuleNode(3), RuleNode(3)])]),
RuleNode(1, [RuleNode(1, [RuleNode(2), RuleNode(2)]), RuleNode(2)]),
RuleNode(1, [RuleNode(1, [RuleNode(2), RuleNode(2)]), RuleNode(3)]),
RuleNode(1, [RuleNode(1, [RuleNode(2), RuleNode(3)]), RuleNode(1, [RuleNode(2), RuleNode(2)])]),
RuleNode(1, [RuleNode(1, [RuleNode(2), RuleNode(3)]), RuleNode(1, [RuleNode(2), RuleNode(3)])]),
]
node, worked = ExprRules._next_state!(node, ruleset, 3)
@test worked
@test isequal(node, testnode)
end

###

iter = ExpressionIterator(ruleset, 2, :R)
state = start(iter)
@test !done(iter, state)
@test isequal(first(iter), RuleNode(1, [RuleNode(2), RuleNode(2)]))
@test all(isequal(a,b) for (a,b) in zip(collect(iter), [
RuleNode(1, [RuleNode(2), RuleNode(2)]),
RuleNode(1, [RuleNode(2), RuleNode(3)]),
RuleNode(1, [RuleNode(3), RuleNode(2)]),
RuleNode(1, [RuleNode(3), RuleNode(3)]),
RuleNode(2),
RuleNode(3),
]))
end

let
ruleset = @ruleset begin
R = I | F
I = 1 | 2
F = F + F
F = 1.5
end

iter = ExpressionIterator(ruleset, 2, :R)
@test all(isequal(a,b) for (a,b) in zip(collect(iter), [
RuleNode(1, [RuleNode(3)]),
RuleNode(1, [RuleNode(4)]),
RuleNode(2, [RuleNode(6)]),
]))

iter = ExpressionIterator(ruleset, 3, :R)
@test all(isequal(a,b) for (a,b) in zip(collect(iter), [
RuleNode(1, [RuleNode(3)]),
RuleNode(1, [RuleNode(4)]),
RuleNode(2, [RuleNode(5, [RuleNode(6), RuleNode(6)])]),
RuleNode(2, [RuleNode(6)]),
]))
end

0 comments on commit 574d681

Please sign in to comment.