-
Notifications
You must be signed in to change notification settings - Fork 159
/
update.jl
142 lines (123 loc) · 5.5 KB
/
update.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
    SwitchUpdateState{T}(weight, score, noise, prev_trace)

Mutable scratch state threaded through `process!` while updating a `Switch` trace.

Only `weight`, `score`, `noise` and `prev_trace` are assigned by the constructor; the
remaining fields stay undefined until `process!` fills them in.
"""
mutable struct SwitchUpdateState{T}
    weight::Float64         # update weight written by `process!`
    score::Float64          # score of the new branch trace
    noise::Float64          # difference of `project(_, EmptySelection())` between traces
    prev_trace::Trace       # the trace being updated
    trace::Trace            # new branch trace (assigned by `process!`)
    index::Int              # selected branch index (assigned by `process!`)
    discard::ChoiceMap      # discarded choices (assigned by `process!`)
    updated_retdiff::Diff   # return-value diff (assigned by `process!`)

    function SwitchUpdateState{T}(weight::Float64, score::Float64, noise::Float64,
                                  prev_trace::Trace) where {T}
        # Partial initialization: trailing fields are set later by `process!`.
        return new{T}(weight, score, noise, prev_trace)
    end
end
"""
    update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap)

Return a new `DynamicChoiceMap` holding every choice in `choices` (the constraints),
merged with every choice of `prev_choices` whose address is not also constrained.

Submaps present at the same address in both maps are merged recursively, so the
constraints win at every leaf they cover, while unconstrained previous choices are kept.
"""
function update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap)
    prev_choice_submap_iterator = get_submaps_shallow(prev_choices)
    prev_choice_value_iterator = get_values_shallow(prev_choices)
    choice_submap_iterator = get_submaps_shallow(choices)
    choice_value_iterator = get_values_shallow(choices)

    # Collect the constrained shallow addresses explicitly. Calling `keys` on the shallow
    # iterators only yields addresses when they are `AbstractDict`s; for generator-backed
    # choice maps `keys` yields positions, which silently broke the membership tests.
    constrained_value_addrs = Set{Any}(address for (address, _) in choice_value_iterator)
    constrained_submap_addrs = Set{Any}(address for (address, _) in choice_submap_iterator)
    # Shallow submap addresses of `prev_choices`, filled in by the submap loop below.
    prev_submap_addrs = Set{Any}()

    new_choices = DynamicChoiceMap()

    # Keep previous values whose address is not constrained.
    for (address, value) in prev_choice_value_iterator
        address in constrained_value_addrs && continue
        set_value!(new_choices, address, value)
    end

    # Keep previous submaps; where the constraints also have a submap at the same
    # address, merge the two recursively.
    for (address, node1) in prev_choice_submap_iterator
        push!(prev_submap_addrs, address)
        if address in constrained_submap_addrs
            node2 = get_submap(choices, address)
            set_submap!(new_choices, address, update_recurse_merge(node1, node2))
        else
            set_submap!(new_choices, address, node1)
        end
    end

    # Add every constrained value. Conflicting previous values were skipped above.
    for (address, value) in choice_value_iterator
        set_value!(new_choices, address, value)
    end

    # Add constrained submaps that have no previous submap at the same address (those
    # were already merged above). This replaces the former
    # `complement(select(zip(...)...))` construction, which splatted an unknown-length
    # iterator into varargs.
    # NOTE(review): if an address is a value on one side and a submap on the other, both
    # are written here; the result is whatever `DynamicChoiceMap` does on that conflict.
    for (address, node) in choice_submap_iterator
        address in prev_submap_addrs && continue
        set_submap!(new_choices, address, node)
    end

    return new_choices
end
"""
    update_discard(prev_choices::ChoiceMap, choices::ChoiceMap, new_choices::ChoiceMap)

Collect the choices of the previous trace that an update throws away. A previous value
is included when its address either
1. has no value in the new trace's choices (`new_choices`), or
2. has a value in the constraints (`choices`), i.e. it is being overwritten.

Submaps are handled by recursing on the matching submaps of `choices` and `new_choices`.
"""
function update_discard(prev_choices::ChoiceMap, choices::ChoiceMap, new_choices::ChoiceMap)
    result = choicemap()
    for (addr, prev_sub) in get_submaps_shallow(prev_choices)
        nested = update_discard(prev_sub,
                                get_submap(choices, addr),
                                get_submap(new_choices, addr))
        set_submap!(result, addr, nested)
    end
    for (addr, prev_val) in get_values_shallow(prev_choices)
        if !has_value(new_choices, addr) || has_value(choices, addr)
            set_value!(result, addr, prev_val)
        end
    end
    return result
end
# Trace-level convenience method: compute the discard from the two traces' choice maps.
@inline function update_discard(prev_trace::Trace, choices::ChoiceMap, new_trace::Trace)
    prev_choices = get_choices(prev_trace)
    new_choices = get_choices(new_trace)
    return update_discard(prev_choices, choices, new_choices)
end
# Index argdiff is `UnknownChange`: generate a fresh trace for branch `index`, using
# `choices` merged with the previous trace's choices as constraints, and record the
# result (weight, discard, noise, score, trace) in `state`.
function process!(gen_fn::Switch{C, N, K, T},
                  index::Int,
                  index_argdiff::UnknownChange,
                  args::Tuple,
                  kernel_argdiffs::Tuple,
                  choices::ChoiceMap,
                  state::SwitchUpdateState{T}) where {C, N, K, T}
    prev_trace = state.prev_trace

    # Constraints for the new branch: `choices`, backfilled with the previous choices.
    constraints = update_recurse_merge(get_choices(prev_trace), choices)
    branch = getfield(gen_fn.branches, index)
    trace, gen_weight = generate(branch, args, constraints)

    state.discard = update_discard(prev_trace, choices, trace)
    state.index = index
    # Weight reported by `generate`, relative to the previous trace's score.
    state.weight = gen_weight - get_score(prev_trace)
    state.noise = project(trace, EmptySelection()) - project(prev_trace, EmptySelection())
    state.score = get_score(trace)
    state.trace = trace

    retdiff = UnknownChange()
    state.updated_retdiff = retdiff
    return retdiff
end
# Index argdiff is `NoChange`: call `update` on the previous branch trace (the `branch`
# field of the previous switch trace) and copy its results into `state`.
function process!(gen_fn::Switch{C, N, K, T},
                  index::Int,
                  index_argdiff::NoChange, # TODO: Diffed wrapper?
                  args::Tuple,
                  kernel_argdiffs::Tuple,
                  choices::ChoiceMap,
                  state::SwitchUpdateState{T}) where {C, N, K, T}
    prev_trace = state.prev_trace
    prev_branch = getfield(prev_trace, :branch)
    updated = update(prev_branch, args, kernel_argdiffs, choices)
    new_trace, weight, retdiff, discard = updated

    state.index = index
    state.weight = weight
    state.noise = project(new_trace, EmptySelection()) - project(prev_trace, EmptySelection())
    state.score = get_score(new_trace)
    state.trace = new_trace
    state.updated_retdiff = retdiff
    state.discard = discard
    return discard
end
# `index` is a case key of type `C`: translate it through `gen_fn.cases` and re-dispatch
# with the looked-up value (presumably the integer branch position used by the
# `index::Int` methods).
@inline function process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff,
                          args::Tuple, kernel_argdiffs::Tuple, choices::ChoiceMap,
                          state::SwitchUpdateState{T}) where {C, N, K, T}
    branch_index = gen_fn.cases[index]
    return process!(gen_fn, branch_index, index_argdiff, args, kernel_argdiffs,
                    choices, state)
end
# Update a `Switch` trace. The first argument (and argdiff) selects the branch; the rest
# are forwarded to the branch via `process!`. Returns
# `(new_trace, weight, retdiff, discard)`.
function update(trace::SwitchTrace{A, T, U},
                args::Tuple,
                argdiffs::Tuple,
                choices::ChoiceMap) where {A, T, U}
    gen_fn = trace.gen_fn
    index = args[1]
    index_argdiff = argdiffs[1]
    branch_args = args[2:end]
    branch_argdiffs = argdiffs[2:end]

    state = SwitchUpdateState{T}(0.0, 0.0, 0.0, trace)
    process!(gen_fn, index, index_argdiff, branch_args, branch_argdiffs, choices, state)

    new_trace = SwitchTrace(gen_fn, state.trace, get_retval(state.trace), args,
                            state.score, state.noise)
    return new_trace, state.weight, state.updated_retdiff, state.discard
end