Skip to content
This repository has been archived by the owner on Dec 3, 2019. It is now read-only.

Commit

Permalink
Refactored conversion macros for flexiblity and clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
Tim Thatcher committed Jul 13, 2016
1 parent e8811f3 commit 24ee207
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 75 deletions.
4 changes: 4 additions & 0 deletions src/hyperparameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,10 @@ function convert{T<:Real}(::Type{HyperParameter{T}}, θ::HyperParameter)
HyperParameter{T}(convert(T, θ.value), convert(Interval{T}, θ.bounds), θ.isfixed)
end

function convert{T<:Real}(::Type{Variable{T}}, θ::HyperParameter)
Variable{T}(convert(T, θ.value), θ.isfixed)
end

function show{T}(io::IO, θ::HyperParameter{T})
print(io, "HyperParameter{" * string(T) * "}(", θ.value, ") ∈ ", θ.bounds)
end
Expand Down
130 changes: 55 additions & 75 deletions src/meta.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ function promote_type_int(U_i::DataType...)
U_max <: Signed ? U_max : Int64
end


# Checks to make sure the fields in the datatype are of type Parameter{T} or Parameter{U} where
# T<:AbstractFloat and U<:Integer
function fieldparameters(obj::DataType)
function checkfields(obj::DataType)
fields = fieldnames(obj)
0 < length(obj.parameters) <= 2 || error("Type must have one or two parameters (T & U)")
for param in obj.parameters
Expand Down Expand Up @@ -43,23 +42,18 @@ function fieldparameters(obj::DataType)
error("$field type must be ::Parameter{T<:AbstractFloat} or ::Parameter{U<:Integer}")
end
end
field_parameters = Symbol[fieldtype(obj, field).parameters[1].name for field in fields]
return (fields, field_parameters)
return fields
end

function promote_code_block(obj::DataType, cstr_params, field_params)
promote_T = Expr(:call, :promote_type_float, cstr_params[field_params .== :T]...)
if length(obj.parameters) == 1
return :(T = $promote_T)
else
promote_U = Expr(:call, :promote_type_int, cstr_params[field_params .== :U]...)
return Expr(:block, :(T = $promote_T), :(U = $promote_U))
end
function fieldparameters(obj::DataType)
fields = checkfields(obj)
field_parameters = Symbol[fieldtype(obj, field).parameters[1].name for field in fields]
return (fields, field_parameters)
end

function fieldparameters_constructor(obj::DataType)
fields, field_params = fieldparameters(obj)
n = length(fields)
# [:T, :U, :T, :U] -> ([:T1, :U1, :T2, :U2], [:Float64, :Int64, promote_type_float(T1), ...]
function constructorparameters(field_params)
n = length(field_params)
counter = [1,1] # T count, U count
constructor_params = Array(Symbol, n)
constructor_types = Array(Union{Symbol,Expr}, n)
Expand All @@ -76,13 +70,15 @@ function fieldparameters_constructor(obj::DataType)
end
counter[param_idx] += 1
end
(fields, field_params, constructor_params, constructor_types)
(constructor_params, constructor_types)
end

function generate_outer_constructor2(obj::DataType, default_values::Tuple{Vararg{Real}})
# (:a,:b,:c,:d), [:T, :U, :T, :U], [:T1, :U1, :T2, :U2],
# [:Float64, :Int64, promote_type_float(T1), promote_type_int(U1)]
fields, field_params, cstr_params, cstr_types = fieldparameters_constructor(obj)
function generate_outer_constructor(obj::DataType, default_values::Tuple{Vararg{Real}})
# (:a,:b,:c,:d), [:T, :U, :T, :U],
fields, field_params = fieldparameters(obj)

# [:T1, :U1, :T2, :U2], [:Float64, :Int64, promote_type_float(T1), promote_type_int(U1)]
cstr_params, cstr_types = constructorparameters(field_params)

if (n = length(default_values)) != length(fields)
error("Default count does not match field count")
Expand All @@ -98,71 +94,55 @@ function generate_outer_constructor2(obj::DataType, default_values::Tuple{Vararg
# (Variable{T}(a), Variable{T}(b), Variable{U}(c))
call_args = [:(Variable{$(field_params[i])}($(fields[i]))) for i = 1:n]

block_definition = Expr(:call, Expr(:curly, obj.name.name, defn_params...), defn_args...)
block_promotion = promote_code_block(obj, cstr_params, field_params)
block_call = Expr(:call, Expr(:curly, obj.name.name, [p.name for p in obj.parameters]...), call_args...)
sym_obj = obj.name.name

return Expr(:function, block_definition, Expr(:block, block_promotion, block_call))
end
ex_definition = Expr(:call, Expr(:curly, sym_obj, defn_params...), defn_args...)
ex_body = Expr(:block)

ex1 = Expr(:call, :promote_type_float, cstr_params[field_params .== :T]...)
push!(ex_body.args, :(T = $ex1))

function generate_conversions(obj::DataType)
fields, field_parameters = fieldparameters(obj)
obj_sym = obj.name.name
if length(obj.parameters) == 2
# convert(::Type{obj{T,U}}, ::obj) = ...
convert_target = Expr(:(::), Expr(:curly, :Type, Expr(:curly, obj_sym, :T, :U)))
converted_arguments =
Expr[:(Variable(convert($(field_parameters[i]), obj.$(fields[i]).value),
obj.$(fields[i]).isfixed)) for i in eachindex(fields)]
conversion1 = Expr(:(=), Expr(:call, Expr(:curly, :convert, :T, :U), convert_target,
Expr(:(::), :obj, obj_sym)),
Expr(:call, obj_sym, converted_arguments...))
# convert(::Type{obj{T}}, ::obj{_,U}) = ...
convert_target = Expr(:(::), Expr(:curly, :Type, Expr(:curly, obj_sym, :T)))
convert_source = Expr(:(::), Expr(:curly, :Type, Expr(:curly, obj_sym, :_, :U)))
conversion2 = Expr(:(=), Expr(:call, Expr(:curly, :convert, :T, :_, :U), convert_target,
Expr(:(::), :obj, Expr(:curly, obj_sym, :_, :U))),
Expr(:call, obj_sym, converted_arguments...))
return Expr(:block, conversion1, conversion2)
elseif length(obj.parameters) == 1
# convert(::Type{obj{T}}, ::obj) = ...
convert_target = Expr(:(::), Expr(:curly, :Type, Expr(:curly, obj_sym, :T)))
converted_arguments = Expr[:(Variable(convert(T, obj.$(fields[i]).value),
obj.$(fields[i]).isfixed)) for i in eachindex(fields)]
return Expr(:(=), Expr(:call, Expr(:curly, :convert, :T), convert_target,
Expr(:(::), :obj, obj_sym)),
Expr(:call, obj_sym, converted_arguments...))
else
error("Data type should have 1 or 2 parameters")
ex2 = Expr(:call, :promote_type_int, cstr_params[field_params .== :U]...)
push!(ex_body.args, :(U = $ex2))
end

ex3 = Expr(:call, Expr(:curly, sym_obj, [p.name for p in obj.parameters]...), call_args...)
push!(ex_body.args, ex3)

return Expr(:function, ex_definition, ex_body)
end

# Notes:
# Assumes arguments are organized by type parameter
# Assumes Variable{} is the argument type
function generate_outer_constructor(obj::DataType, defaults::Tuple{Vararg{Real}})
fields, parameters = fieldparameters(obj)
length(defaults) == length(fields) || error("Default count does not match field count")
first_idx = Dict(findfirst(parameters, :T) => :Float64, findfirst(parameters, :U) => :Int64)
# Produces [:Float64, T, T, ..., Int64, U, U, ...]
default_parameters = Symbol[get(first_idx, i, parameters[i]) for i in eachindex(parameters)]
# Produces [arg1::Argument{T} = convert(Float64, default1),
# arg2::Argument{T} = convert(T,default2)...]
constructor_params = [Expr(:(<:), p.name, p.ub.name.name) for p in obj.parameters]
constructor_args = [Expr(:kw, :($(fields[i])::Argument{$(parameters[i])}),
:(convert($(default_parameters[i]), $(defaults[i]))))
for i in eachindex(fields)]
type_params = [p.name for p in obj.parameters]
type_args = [:(Variable($(fields[i]))) for i in eachindex(fields)]

obj_sym = obj.name.name
Expr(:(=), Expr(:call, Expr(:curly, obj_sym, constructor_params...), constructor_args...),
Expr(:call, Expr(:curly, obj_sym, type_params...), type_args...))
function generate_conversion_TU(obj::DataType)
fields, field_params = fieldparameters(obj)
sym_obj = obj.name.name

ex1 = :(convert{T,U}(::Type{$sym_obj{T,U}}, f::$sym_obj))
ex2 = :(convert{T,_,U}(::Type{$sym_obj{T}}, f::$sym_obj{_,U}))

args = [:(Variable{$(field_params[i])}(f.$(fields[i]))) for i in eachindex(fields)]

ex3 = Expr(:call, sym_obj, args...)

Expr(:block, Expr(:(=), ex1, ex3), Expr(:(=), ex2, ex3))
end

function generate_conversion_T(obj::DataType)
fields = checkfields(obj)
sym_obj = obj.name.name

ex1 = :(convert{T}(::Type{$sym_obj{T}}, f::$sym_obj))

args = [:(Variable{T}(f.$(fields[i]))) for i in eachindex(fields)]

Expr(:(=), ex1, Expr(:call, sym_obj, args...))
end

function generate_conversions(obj::DataType)
length(obj.parameters) == 2 ? generate_conversion_TU(obj) : generate_conversion_T(obj)
end

macro outer_constructor(obj, defaults)
eval(generate_outer_constructor2(eval(obj), eval(defaults)))
eval(generate_outer_constructor(eval(obj), eval(defaults)))
eval(generate_conversions(eval(obj)))
end

0 comments on commit 24ee207

Please sign in to comment.