In [2]:
using Rewrite
using Rewrite: Rule, PatternRule, EvalRule, Associative, Commutative

using MacroTools

using Swizzles
using Swizzles.Antennae
using Swizzles.RecognizeStyles: reprexpr

using Base.Broadcast: Broadcasted, broadcasted

### `termtransform` Experiments

As of now termtransform does not recognize when two parts of an expression are equal.

In [3]:
function termtransform(ex, sym_to_ex::Dict{Symbolic, Any}) :: Symbolic
    s = Symbolic(gensym())
    sym_to_ex[s] = ex
    return s
end

function termtransform(ex::Expr, sym_to_ex::Dict{Symbolic, Any})
    if @capture(ex, f_(args__))
        return :($f(
            $(map(arg->termtransform(arg, sym_to_ex), args)...)
        ))
    elseif @capture(ex, arg_::T_) && !(T isa Union{Symbol, Expr})
        s = Symbolic(gensym())
        sym_to_ex[s] = arg
        return s
    else
        s = Symbolic(gensym())
        sym_to_ex[s] = ex
        return s
    end
end

function termtransform(ex) :: Tuple{Term, Dict{Symbolic, Any}}
    sym_to_ex = Dict{Symbolic, Any}()
    ex = termtransform(ex, sym_to_ex)
    (Term(ex), sym_to_ex)
end;

In [4]:
A = [1 2 3; 4 5 6]
B = [300 200 100]
C = [1000]
bd = broadcasted(*, A, broadcasted(+, B, C))
r = reprexpr(:hi, typeof(bd))

:((Antenna{typeof(*)}(*))(hi.args[1]::Array{Int64,2}, (Antenna{typeof(+)}(+))((hi.args[2]).args[1]::Array{Int64,2}, (hi.args[2]).args[2]::Array{Int64,1})))

In [5]:
termtransform(r)

(@term((Antenna{typeof(*)}(*))(##363, (Antenna{typeof(+)}(+))(##364, ##365))), Dict{Symbolic,Any}(##364=>:((hi.args[2]).args[1]),##365=>:((hi.args[2]).args[2]),##363=>:(hi.args[1])))

### Simplification Experiments

In [6]:
struct SimplificationSpec
    rules::Rules
    context::Context
end

In [7]:
function default_spec()
    antenna_term(t::Term) = Term(antenna_expr(t.ex))
    antenna_expr(v::Variable) = v
    function antenna_expr(ex::Expr) :: Expr
        if @capture(ex, f_(args__))
            if f isa Union{Symbolic, Variable} # return f(b_t(arg1), b_t(arg2), b_t(arg3), ...)
                return Expr(:call, f, map(antenna_term, args)...)
            elseif f isa Expr
                throw(ArgumentError("nonbroadcastable term: $ex"))
            elseif f isa Symbol
                throw(ArgumentError("nonbroadcastable term: $ex"))
            else # return Antenna(f)(b_t(arg1), b_t(arg2), b_t(arg3))
                return Expr(:call, Antenna(f), map(antenna_expr, args)...)
            end
        end
        throw(ArgumentError("can't convert to antenna version: $ex"))
    end
    
    function antenna_rule(rule::PatternRule) :: PatternRule
        if !isempty(rule.ps)
            throw(ArgumentError("can't convert rule with properties: $r"))
        end
        return PatternRule(antenna_term(rule.left), antenna_term(rule.right))
    end
    
    @vars x y z
    equalities = Array{Rule, 1}([        
        PatternRule(@term(x * (y + z)), @term(x * y + x * z))
        #=PatternRule(
            @term(Antenna(*)(x, Antenna(+)(y, z))),
            @term(Antenna(+)(Antenna(*)(x, y), Antenna(*)(x, z)))
        )=#
    ])
    append!(equalities, map(antenna_rule, equalities))
    
    properties = []

    return SimplificationSpec(Rules(equalities), Context(properties))
end

SPEC = default_spec()

SimplificationSpec(Rules(Rule[PatternRule(@term(x * (y + z)), @term(x * y + x * z), Function[]), PatternRule(@term((Antenna{typeof(*)}(*))(x, (Antenna{typeof(+)}(+))(y, z))), @term((Antenna{typeof(+)}(+))((Antenna{typeof(*)}(*))(x, y), (Antenna{typeof(*)}(*))(x, z))), Function[])]), Context(Rewrite.Property[]))

In [8]:
function simplify(term::Term)
    Rewrite.with_context(SPEC.context) do
       normalize(term, SPEC.rules)
    end
end;

In [9]:
@syms x y z;
simplify(@term(x * (y + z)))

@term(x * y + x * z)

In [10]:
A = [1 2 3; 4 5 6]
B = [300 200 100]
C = [1000]
bd = broadcasted(*, A, broadcasted(+, B, C))
r = reprexpr(:test, typeof(bd))

:((Antenna{typeof(*)}(*))(test.args[1]::Array{Int64,2}, (Antenna{typeof(+)}(+))((test.args[2]).args[1]::Array{Int64,2}, (test.args[2]).args[2]::Array{Int64,1})))

In [11]:
simplify(termtransform(r)[1])

@term((Antenna{typeof(+)}(+))((Antenna{typeof(*)}(*))(##366, ##367), (Antenna{typeof(*)}(*))(##366, ##368)))

### `exprtransform` Experiments

In [17]:
function exprtransform(s::Symbolic, sym_to_ex::Dict{Symbolic, Any})
    return sym_to_ex[s]
end

function exprtransform(ex::Expr, sym_to_ex::Dict{Symbolic, Any}) :: Expr
    if @capture(ex, f_(args__))
        return :($f(
            $(map(arg->exprtransform(arg, sym_to_ex), args)...)
        ))
    end
    throw(ArgumentError("non expr transformable: $ex"))
end

function exprtransform(term::Term, sym_to_ex::Dict{Symbolic, Any}) :: Expr
    return exprtransform(term.ex, sym_to_ex)
end

exprtransform (generic function with 3 methods)

In [18]:
t, d = termtransform(r)
exprtransform(simplify(t), d)

:((Antenna{typeof(+)}(+))((Antenna{typeof(*)}(*))(test.args[1], (test.args[2]).args[1]), (Antenna{typeof(*)}(*))(test.args[1], (test.args[2]).args[2])))

### Dynamic `reprexpr`

In [15]:
function dynamic_reprexpr(leaf,
                  obj_to_sym::IdDict{Any, Symbolic},
                  sym_to_obj::Dict{Symbolic, Any})
    if !haskey(obj_to_sym, leaf)
        sym = Symbolic(gensym())
        obj_to_sym[leaf] = sym
        sym_to_obj[sym] = leaf
    end
     
    return obj_to_sym[leaf]
end


function dynamic_reprexpr(bd::Broadcasted,
                  obj_to_sym::IdDict{Any, Symbolic},
                  sym_to_obj::Dict{Symbolic, Any})
    arg_exprs = map(arg -> dynamic_reprexpr(arg, obj_to_sym, sym_to_obj),
                    bd.args)
    
    :( $Antenna( $(bd.f) )( $(arg_exprs...) ) )
end


function dynamic_reprexpr(obj) :: Tuple{Term, Dict{Symbolic, Any}}
    sym_to_obj = Dict{Symbolic, Any}()
    expr = dynamic_reprexpr(obj, IdDict{Any, Symbolic}(), sym_to_obj)
    (expr, sym_to_obj)
end;

In [16]:
A = [1 2 3; 4 5 6]
B = [300 200 100]
C = [1000]
D = zeros((1000, 1000))
bd = broadcasted(+, broadcasted(*, A, B), broadcasted(*, A, C))
#bd = broadcasted(+, A)

Broadcasted(+, (Broadcasted(*, ([1 2 3; 4 5 6], [300 200 100])), Broadcasted(*, ([1 2 3; 4 5 6], [1000]))))

In [17]:
reprexpr(bd)

(@term(((Antenna)(+))(((Antenna)(*))(##369, ##370), ((Antenna)(*))(##369, ##371))), Dict{Symbolic,Any}(##370=>[300 200 100],##371=>[1000],##369=>[1 2 3; 4 5 6]))