-
Notifications
You must be signed in to change notification settings - Fork 159
/
dynamic.jl
187 lines (157 loc) · 5.74 KB
/
dynamic.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
include("trace.jl")
"""
DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
A generative function based on a shallowly embedding modeling language based on Julia functions.
Constructed using the `@gen` keyword.
Most methods in the generative function interface involve a end-to-end execution of the function.
"""
struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
params_grad::Dict{Symbol,Any}
params::Dict{Symbol,Any}
arg_types::Vector{Type}
has_defaults::Bool
arg_defaults::Vector{Union{Some{Any},Nothing}}
julia_function::Function
has_argument_grads::Vector{Bool}
accepts_output_grad::Bool
end
"""
    DynamicDSLFunction(arg_types, arg_defaults, julia_function,
                       has_argument_grads, ::Type{T}, accepts_output_grad) where {T}

Construct a `DynamicDSLFunction{T}` with empty parameter and parameter-gradient
dictionaries. `has_defaults` is set to `true` when at least one entry of
`arg_defaults` is a `Some(value)` rather than `nothing`.
"""
function DynamicDSLFunction(arg_types::Vector{Type},
                            arg_defaults::Vector{Union{Some{Any},Nothing}},
                            julia_function::Function,
                            has_argument_grads, ::Type{T},
                            accepts_output_grad::Bool) where {T}
    params_grad = Dict{Symbol,Any}()
    params = Dict{Symbol,Any}()
    # `isnothing` is the idiomatic identity test against `nothing` (instead of `!=`).
    has_defaults = any(!isnothing, arg_defaults)
    return DynamicDSLFunction{T}(params_grad, params, arg_types,
                                 has_defaults, arg_defaults,
                                 julia_function,
                                 has_argument_grads, accepts_output_grad)
end
"""
    DynamicDSLTrace(gen_fn::DynamicDSLFunction, args)

Create a `DynamicDSLTrace` for `gen_fn` called with `args`.

If `gen_fn` has defaults and the caller supplied fewer arguments than there are
entries in `gen_fn.arg_defaults`, the omitted trailing arguments are filled in
from those defaults. Caller-supplied argument values are passed through unchanged.

Throws an `ArgumentError` if an omitted argument has no default value.
"""
function DynamicDSLTrace(gen_fn::T, args) where {T<:DynamicDSLFunction}
    num_given = length(args)
    if gen_fn.has_defaults && num_given < length(gen_fn.arg_defaults)
        trailing = gen_fn.arg_defaults[num_given+1:end]
        for (offset, default) in enumerate(trailing)
            if isnothing(default)
                throw(ArgumentError(
                    "argument $(num_given + offset) was not supplied and has no default value"))
            end
        end
        # Splat into a tuple instead of `vcat`-ing vectors: `vcat` promotes element
        # types (e.g. `vcat([1], [1.0]) == [1.0, 1.0]`), which would silently convert
        # the caller's arguments to a different type.
        args = (args..., map(something, trailing)...)
    end
    return DynamicDSLTrace{T}(gen_fn, args)
end
# Return the `accepts_output_grad` flag stored when `gen_fn` was constructed
# (presumably: whether gradient methods accept a gradient w.r.t. the return value).
function accepts_output_grad(gen_fn::DynamicDSLFunction)
    return gen_fn.accepts_output_grad
end
"""
State passed as the first argument to a generative function's underlying Julia
function when the generative function is called directly, e.g. `gen_fn(args...)`.
It carries only the parameter dictionary that `@param` lookups read from; the
`traceat`/`splice` methods for this state record nothing and simply run the callee.
"""
mutable struct GFUntracedState
    # parameter values keyed by name (the generative function's own `params` dict)
    params::Dict{Symbol,Any}
end
# Calling a `DynamicDSLFunction` directly runs its body untraced: the body receives a
# `GFUntracedState` wrapping the function's own parameter store, followed by `args`.
(gen_fn::DynamicDSLFunction)(args...) =
    gen_fn.julia_function(GFUntracedState(gen_fn.params), args...)
# Run the body of `gen_fn` with the given execution `state` (presumably one of the
# GFI states defined in the included files) and the argument tuple `args`.
exec(gen_fn::DynamicDSLFunction, state, args::Tuple) =
    gen_fn.julia_function(state, args...)
# Per-argument gradient flags: element i is `true` when the score has a gradient with
# respect to argument i, and `false` otherwise (a `Vector{Bool}`, never `nothing`).
function has_argument_grads(gen::DynamicDSLFunction)
    return gen.has_argument_grads
end
"Global reference to the GFI state for the dynamic modeling language."
const state = gensym("state")
"Implementation of @trace for the dynamic modeling language."
function dynamic_trace_impl(expr::Expr)
@assert expr.head == :gentrace "Not a Gen trace expression."
call, addr = expr.args[1], expr.args[2]
if (call.head != :call) error("syntax error in @trace at $(call)") end
fn = call.args[1]
args = Expr(:tuple, call.args[2:end]...)
if addr != nothing
addr = something(addr)
return Expr(:call, GlobalRef(@__MODULE__, :traceat), state, fn, args, addr)
else
return Expr(:call, GlobalRef(@__MODULE__, :splice), state, fn, args)
end
end
# Untraced execution: under a `GFUntracedState` nothing is recorded, so a traced call
# just runs the callee and the address `key` is ignored.
@inline function traceat(state::GFUntracedState, gen_fn::GenerativeFunction, args, key)
    return gen_fn(args...)
end

@inline function traceat(state::GFUntracedState, dist::Distribution, args, key)
    return random(dist, args...)
end

# Splicing a dynamic generative function untraced is likewise a plain call.
@inline function splice(state::GFUntracedState, gen_fn::DynamicDSLFunction, args::Tuple)
    return gen_fn(args...)
end
########################
# trainable parameters #
########################
"Implementation of @param for the dynamic modeling language."
function dynamic_param_impl(expr::Expr)
@assert expr.head == :genparam "Not a Gen param expression."
name = expr.args[1]
Expr(:(=), name, Expr(:call, GlobalRef(@__MODULE__, :read_param), state, QuoteNode(name)))
end
# Look up the value of parameter `name` in `state.params`. An unknown name is reported
# the way an undefined variable would be, with an `UndefVarError`.
function read_param(state, name::Symbol)
    params = state.params
    haskey(params, name) || throw(UndefVarError(name))
    return params[name]
end
##################
# AddressVisitor #
##################
"""
Records the addresses visited during an execution, so that a second visit to the
same address can be reported as an error by `visit!`.
"""
struct AddressVisitor
    # addresses visited so far
    visited::DynamicSelection
end

# Create a visitor starting from a fresh `DynamicSelection()` (presumably empty).
AddressVisitor() = AddressVisitor(DynamicSelection())
# Mark `addr` as visited by adding it to `visitor.visited`; raises an error if the
# address had already been visited. Returns the result of `push!`.
function visit!(visitor::AddressVisitor, addr)
    seen = visitor.visited
    addr in seen && error("Attempted to visit address $addr, but it was already visited")
    return push!(seen, addr)
end
# Return `true` iff every address in `choices` is covered by the selection `visited`.
# A top-level value is covered when its key is in `visited`. A submap is covered when
# its key is in `visited`, or when everything inside it is covered by `visited[key]`.
function all_visited(visited::Selection, choices::ChoiceMap)
    leaves_covered = all(key in visited for (key, _) in get_values_shallow(choices))
    leaves_covered || return false
    return all(get_submaps_shallow(choices)) do (key, submap)
        key in visited || all_visited(visited[key], submap)
    end
end
# Build a new choice map containing the parts of `choices` not covered by `visited`.
# Values whose key is in `visited` are dropped. A submap whose key is in `visited` is
# dropped entirely; any other submap is filtered recursively against `visited[key]`.
function get_unvisited(visited::Selection, choices::ChoiceMap)
    remaining = choicemap()
    for (key, value) in get_values_shallow(choices)
        key in visited && continue
        set_value!(remaining, key, value)
    end
    for (key, submap) in get_submaps_shallow(choices)
        key in visited && continue
        set_submap!(remaining, key, get_unvisited(visited[key], submap))
    end
    return remaining
end
# Return the `visited` field of `visitor` (for an `AddressVisitor`, its selection of
# visited addresses).
function get_visited(visitor)
    return visitor.visited
end
# Ensure `constraints` has no sub-assignment at `addr` (a value is expected there).
# Returns `nothing` on success and raises an error otherwise.
function check_no_submap(constraints::ChoiceMap, addr)
    isempty(get_submap(constraints, addr)) && return nothing
    error("Expected a value at address $addr but found a sub-assignment")
end
# Ensure `constraints` has no value at `addr` (a sub-assignment is expected there).
# Returns `nothing` on success and raises an error otherwise.
function check_no_value(constraints::ChoiceMap, addr)
    has_value(constraints, addr) || return nothing
    error("Expected a sub-assignment at address $addr but found a value")
end
# Raise an `ErrorException` reporting that the generative function at `addr` changed.
gen_fn_changed_error(addr) = error("Generative function changed at address: $addr")
include("simulate.jl")
include("generate.jl")
include("propose.jl")
include("assess.jl")
include("project.jl")
include("update.jl")
include("regenerate.jl")
include("backprop.jl")
export DynamicDSLFunction