Skip to content

Commit

Permalink
julia 07 fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
tawheeler committed Aug 7, 2018
1 parent aaccbb3 commit a1d4b4b
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 61 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Expand Up @@ -2,7 +2,7 @@ language: julia
sudo: required
dist: trusty
julia:
- 0.6
- 0.7
notifications:
email: false
before_install:
Expand Down
9 changes: 5 additions & 4 deletions REQUIRE
@@ -1,4 +1,5 @@
julia 0.6
TreeView 0.1
StatsBase
AbstractTrees
julia 0.7-beta2
AbstractTrees 0.2
StatsBase 0.24
Nullables 0.0.5
TreeView
100 changes: 52 additions & 48 deletions src/ExprRules.jl
Expand Up @@ -3,8 +3,10 @@ __precompile__()
module ExprRules

import TreeView: walk_tree
import Random: srand
using StatsBase
using AbstractTrees
using Nullables

export
Grammar,
Expand Down Expand Up @@ -59,7 +61,7 @@ iseval(rule::Expr) = (rule.head == :call && rule.args[1] == :_)
"""
get_childtypes(rule::Any, types::AbstractVector{Symbol})
Returns the child types of a production rule.
Returns the child types of a production rule.
"""
function get_childtypes(rule::Any, types::AbstractVector{Symbol})
retval = Symbol[]
Expand All @@ -76,8 +78,8 @@ end
"""
Grammar
Represents a grammar and its production rules.
Use the @grammar macro to create a Grammar object.
Represents a grammar and its production rules.
Use the @grammar macro to create a Grammar object.
"""
struct Grammar
rules::Vector{Any} # list of RHS of rules (subexpressions)
Expand All @@ -104,15 +106,17 @@ macro grammar(ex)
types = Symbol[]
bytype = Dict{Symbol,Vector{Int}}()
for e in ex.args
if e.head == :(=)
s = e.args[1] # name of return type
rule = e.args[2] # expression?
rvec = Any[]
_parse_rule!(rvec, rule)
for r in rvec
push!(rules, r)
push!(types, s)
bytype[s] = push!(get(bytype, s, Int[]), length(rules))
if isa(e, Expr)
if e.head == :(=)
s = e.args[1] # name of return type
rule = e.args[2] # expression?
rvec = Any[]
_parse_rule!(rvec, rule)
for r in rvec
push!(rules, r)
push!(types, s)
bytype[s] = push!(get(bytype, s, Int[]), length(rules))
end
end
end
end
Expand All @@ -126,7 +130,7 @@ _parse_rule!(v::Vector{Any}, r) = push!(v, r)
function _parse_rule!(v::Vector{Any}, ex::Expr)
if ex.head == :call && ex.args[1] == :|
terms = length(ex.args) == 2 ?
collect(eval(Main,ex.args[2])) : #|(a:c) case
collect(Core.eval(Main,ex.args[2])) : #|(a:c) case
ex.args[2:end] #a|b|c case
for t in terms
_parse_rule!(v, t)
Expand Down Expand Up @@ -162,7 +166,7 @@ child_types(grammar::Grammar, rule_index::Int) = grammar.childtypes[rule_index]
"""
isterminal(grammar::Grammar, rule_index::Int)
Returns true if the production rule at rule_index is terminal, i.e., does not contain any nonterminal symbols.
Returns true if the production rule at rule_index is terminal, i.e., does not contain any nonterminal symbols.
"""
isterminal(grammar::Grammar, rule_index::Int) = grammar.isterminal[rule_index]

Expand Down Expand Up @@ -190,7 +194,7 @@ nchildren(grammar::Grammar, rule_index::Int) = length(grammar.childtypes[rule_in
"""
max_arity(grammar::Grammar)
Returns the maximum arity (number of children) over all production rules in the grammar.
Returns the maximum arity (number of children) over all production rules in the grammar.
"""
max_arity(grammar::Grammar) = maximum(length(cs) for cs in grammar.childtypes)

Expand All @@ -212,35 +216,35 @@ RuleNode(ind::Int, _val::Any) = RuleNode(ind, Nullable{Any}(_val), RuleNode[])
"""
return_types(grammar::Grammar, node::RuleNode)
Returns the return type in the production rule used by node.
Returns the return type in the production rule used by node.
"""
return_type(grammar::Grammar, node::RuleNode) = grammar.types[node.ind]

"""
child_types(grammar::Grammar, node::RuleNode)
Returns the list of child types in the production rule used by node.
Returns the list of child types in the production rule used by node.
"""
child_types(grammar::Grammar, node::RuleNode) = grammar.childtypes[node.ind]

"""
isterminal(grammar::Grammar, node::RuleNode)
Returns true if the production rule used by node is terminal, i.e., does not contain any nonterminal symbols.
Returns true if the production rule used by node is terminal, i.e., does not contain any nonterminal symbols.
"""
isterminal(grammar::Grammar, node::RuleNode) = grammar.isterminal[node.ind]

"""
nchildren(grammar::Grammar, node::RuleNode)
Returns the number of children in the production rule used by node.
Returns the number of children in the production rule used by node.
"""
nchildren(grammar::Grammar, node::RuleNode) = length(child_types(grammar, node))

"""
contains_returntype(node::RuleNode, grammar::Grammar, sym::Symbol, maxdepth::Int=typemax(Int))
Returns true if the tree rooted at node contains at least one node at depth less than maxdepth
Returns true if the tree rooted at node contains at least one node at depth less than maxdepth
with the given return type.
"""
function contains_returntype(node::RuleNode, grammar::Grammar, sym::Symbol, maxdepth::Int=typemax(Int))
Expand Down Expand Up @@ -281,7 +285,7 @@ function Base.hash(node::RuleNode, h::UInt=zero(UInt))
end

function Base.show(io::IO, grammar::Grammar)
for i in eachindex(grammar.rules)
for i in eachindex(grammar.rules)
println(io, i, ": ", grammar.types[i], " = ", grammar.rules[i])
end
end
Expand Down Expand Up @@ -313,7 +317,7 @@ end
"""
depth(root::RuleNode)
Return the depth of the expression tree rooted at root.
Return the depth of the expression tree rooted at root.
"""
function depth(root::RuleNode)
retval = 1
Expand Down Expand Up @@ -383,8 +387,8 @@ end
Evaluate the expression tree with root rulenode.
"""
Core.eval(rulenode::RuleNode, grammar::Grammar) = eval(Main, get_executable(rulenode, grammar))
Core.eval(grammar::Grammar, index::Int) = eval(Main, grammar.rules[index].args[2])
Core.eval(rulenode::RuleNode, grammar::Grammar) = Core.eval(Main, get_executable(rulenode, grammar))
Core.eval(grammar::Grammar, index::Int) = Core.eval(Main, grammar.rules[index].args[2])
function Base.display(rulenode::RuleNode, grammar::Grammar)
root = get_executable(rulenode, grammar)
if isa(root, Expr)
Expand All @@ -406,7 +410,7 @@ function Base.rand(::Type{RuleNode}, grammar::Grammar, typ::Symbol, max_depth::I
StatsBase.sample([r for r in rules if isterminal(grammar, r)])

rulenode = iseval(grammar, rule_index) ?
RuleNode(rule_index, eval(grammar, rule_index)) :
RuleNode(rule_index, Core.eval(grammar, rule_index)) :
RuleNode(rule_index)

if !grammar.isterminal[rule_index]
Expand All @@ -426,7 +430,7 @@ function Base.rand(::Type{RuleNode}, grammar::Grammar, typ::Symbol, dmap::Abstra
rule_index = StatsBase.sample([r for r in rules if dmap[r] max_depth])

rulenode = iseval(grammar, rule_index) ?
RuleNode(rule_index, eval(grammar, rule_index)) :
RuleNode(rule_index, Core.eval(grammar, rule_index)) :
RuleNode(rule_index)

if !grammar.isterminal[rule_index]
Expand Down Expand Up @@ -454,23 +458,23 @@ function StatsBase.sample(root::RuleNode, maxdepth::Int=typemax(Int))
x.node
end
function _sample(node::RuleNode, x::RuleNodeAndCount, maxdepth::Int)
maxdepth < 1 && return
maxdepth < 1 && return
x.cnt += 1
if rand() <= 1/x.cnt
x.node = node
x.node = node
end
for child in node.children
_sample(child, x, maxdepth-1)
end
end

"""
sample(root::RuleNode, typ::Symbol, grammar::Grammar,
sample(root::RuleNode, typ::Symbol, grammar::Grammar,
maxdepth::Int=typemax(Int))
Selects a uniformly random node of the given return type, typ, limited to maxdepth.
"""
function StatsBase.sample(root::RuleNode, typ::Symbol, grammar::Grammar,
function StatsBase.sample(root::RuleNode, typ::Symbol, grammar::Grammar,
maxdepth::Int=typemax(Int))
x = RuleNodeAndCount(root, 0)
if grammar.types[root.ind] == typ
Expand All @@ -482,13 +486,13 @@ function StatsBase.sample(root::RuleNode, typ::Symbol, grammar::Grammar,
grammar.types[x.node.ind] == typ || error("type $typ not found in RuleNode")
x.node
end
function _sample(node::RuleNode, typ::Symbol, grammar::Grammar, x::RuleNodeAndCount,
function _sample(node::RuleNode, typ::Symbol, grammar::Grammar, x::RuleNodeAndCount,
maxdepth::Int)
maxdepth < 1 && return
maxdepth < 1 && return
if grammar.types[node.ind] == typ
x.cnt += 1
if rand() <= 1/x.cnt
x.node = node
x.node = node
end
end
for child in node.children
Expand Down Expand Up @@ -552,7 +556,7 @@ end
"""
sample(::Type{NodeLoc}, root::RuleNode, maxdepth::Int=typemax(Int))
Selects a uniformly random node in the tree no deeper than maxdepth using reservoir sampling.
Selects a uniformly random node in the tree no deeper than maxdepth using reservoir sampling.
Returns a NodeLoc that specifies the location using its parent so that the subtree can be replaced.
"""
function StatsBase.sample(::Type{NodeLoc}, root::RuleNode, maxdepth::Int=typemax(Int))
Expand All @@ -561,7 +565,7 @@ function StatsBase.sample(::Type{NodeLoc}, root::RuleNode, maxdepth::Int=typemax
x.loc
end
function _sample(::Type{NodeLoc}, node::RuleNode, x::NodeLocAndCount, maxdepth::Int)
maxdepth < 1 && return
maxdepth < 1 && return
for (j,child) in enumerate(node.children)
x.cnt += 1
if rand() <= 1/x.cnt
Expand All @@ -575,9 +579,9 @@ end
sample(::Type{NodeLoc}, root::RuleNode, typ::Symbol, grammar::Grammar)
Selects a uniformly random node in the tree of a given type, specified using its parent such that the subtree can be replaced.
Returns a NodeLoc.
Returns a NodeLoc.
"""
function StatsBase.sample(::Type{NodeLoc}, root::RuleNode, typ::Symbol, grammar::Grammar,
function StatsBase.sample(::Type{NodeLoc}, root::RuleNode, typ::Symbol, grammar::Grammar,
maxdepth::Int=typemax(Int))
x = NodeLocAndCount(root_node_loc(root), 0)
if grammar.types[root.ind] == typ
Expand All @@ -587,9 +591,9 @@ function StatsBase.sample(::Type{NodeLoc}, root::RuleNode, typ::Symbol, grammar:
grammar.types[get(root,x.loc).ind] == typ || error("type $typ not found in RuleNode")
x.loc
end
function _sample(::Type{NodeLoc}, node::RuleNode, typ::Symbol, grammar::Grammar,
function _sample(::Type{NodeLoc}, node::RuleNode, typ::Symbol, grammar::Grammar,
x::NodeLocAndCount, maxdepth::Int)
maxdepth < 1 && return
maxdepth < 1 && return
for (j,child) in enumerate(node.children)
if grammar.types[child.ind] == typ
x.cnt += 1
Expand Down Expand Up @@ -650,7 +654,7 @@ function _next_state!(node::RuleNode, grammar::Grammar, max_depth::Int)
while !child_worked
child_type = return_type(grammar, child)
child_rules = grammar[child_type]
i = findfirst(child_rules, child.ind)
i = something(findfirst(isequal(child.ind), child_rules), 0)
if i < length(child_rules)
child_worked = true
child = RuleNode(child_rules[i+1])
Expand Down Expand Up @@ -695,7 +699,7 @@ function _next_state!(node::RuleNode, grammar::Grammar, max_depth::Int)
end

"""
ExpressionIterator(grammar::Grammar, max_depth::Int, sym::Symbol)
ExpressionIterator(grammar::Grammar, max_depth::Int, sym::Symbol)
An iterator over all possible expressions of a grammar up to max_depth with start symbol sym.
"""
Expand All @@ -704,7 +708,7 @@ mutable struct ExpressionIterator
max_depth::Int
sym::Symbol
end
Base.iteratorsize(::ExpressionIterator) = Base.SizeUnknown()
Base.IteratorSize(::ExpressionIterator) = Base.SizeUnknown()
Base.eltype(::ExpressionIterator) = RuleNode
Base.done(iter::ExpressionIterator, state::Tuple{RuleNode,Bool}) = !state[2]
function Base.start(iter::ExpressionIterator)
Expand All @@ -717,7 +721,7 @@ function Base.start(iter::ExpressionIterator)
while !worked
# increment root's rule
rules = grammar[sym]
i = findfirst(rules, node.ind)
i = something(findfirst(isequal(node.ind), rules), 0)
if i < length(rules)
node, worked = RuleNode(rules[i+1]), true
if !isterminal(grammar, node)
Expand All @@ -738,7 +742,7 @@ function Base.next(iter::ExpressionIterator, state::Tuple{RuleNode,Bool})
while !worked
# increment root's rule
rules = grammar[iter.sym]
i = findfirst(rules, node.ind)
i = something(findfirst(isequal(node.ind), rules), 0)
if i < length(rules)
node, worked = RuleNode(rules[i+1]), true
if !isterminal(grammar, node)
Expand All @@ -754,7 +758,7 @@ function Base.next(iter::ExpressionIterator, state::Tuple{RuleNode,Bool})
end

"""
count_expressions(grammar::Grammar, max_depth::Int, sym::Symbol)
count_expressions(grammar::Grammar, max_depth::Int, sym::Symbol)
Count the number of possible expressions of a grammar up to max_depth with start symbol sym.
"""
Expand All @@ -776,7 +780,7 @@ function count_expressions(grammar::Grammar, max_depth::Int, sym::Symbol)
end

"""
count_expressions(iter::ExpressionIterator)
count_expressions(iter::ExpressionIterator)
Count the number of possible expressions in the expression iterator.
"""
Expand All @@ -793,7 +797,7 @@ Returns the minimum depth achievable for each production rule, dmap.
"""
function mindepth_map(grammar::Grammar)
dmap0 = Int[isterminal(grammar,i) ? 0 : typemax(Int)/2 for i in eachindex(grammar.rules)]
dmap1 = Vector{Int}(length(grammar.rules))
dmap1 = Vector{Int}(undef, length(grammar.rules))
while dmap0 != dmap1
for i in eachindex(grammar.rules)
dmap1[i] = _mindepth(grammar, i, dmap0)
Expand All @@ -809,7 +813,7 @@ end
"""
mindepth(grammar::Grammar, typ::Symbol, dmap::AbstractVector{Int})
Returns the minimum depth achievable for a given nonterminal symbol
Returns the minimum depth achievable for a given nonterminal symbol
"""
function mindepth(grammar::Grammar, typ::Symbol, dmap::AbstractVector{Int})
return minimum(dmap[grammar.bytype[typ]])
Expand Down

0 comments on commit a1d4b4b

Please sign in to comment.