-
Notifications
You must be signed in to change notification settings - Fork 159
/
importance.jl
124 lines (113 loc) · 5.2 KB
/
importance.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
    (traces, log_norm_weights, lml_est) = importance_sampling(model::GenerativeFunction,
        model_args::Tuple, observations::ChoiceMap, num_samples::Int; verbose=false)

    (traces, log_norm_weights, lml_est) = importance_sampling(model::GenerativeFunction,
        model_args::Tuple, observations::ChoiceMap,
        proposal::GenerativeFunction, proposal_args::Tuple,
        num_samples::Int; verbose=false)

Run importance sampling and return a vector of traces together with their
normalized log weights, plus an estimate of the log marginal likelihood of
the observations (`lml_est`).

The observations are addresses that must be sampled by the model under the
given model arguments. The first variant proposes from the model's internal
proposal distribution; the second proposes from a user-supplied generative
function, every address of which must also be sampled by the model.

Setting `verbose=true` prints a progress message for each sample.
"""
function importance_sampling(
        model::GenerativeFunction{T,U},
        model_args::Tuple,
        observations::ChoiceMap,
        num_samples::Int;
        verbose=false) where {T,U}
    # Preallocate storage so element types stay concrete.
    traces = Vector{U}(undef, num_samples)
    log_weights = Vector{Float64}(undef, num_samples)
    for idx in 1:num_samples
        verbose && println("sample: $idx of $num_samples")
        # generate() constrains the model to the observations and returns
        # the trace along with its importance weight.
        (traces[idx], log_weights[idx]) = generate(model, model_args, observations)
    end
    # Normalize in log space for numerical stability.
    total = logsumexp(log_weights)
    lml_estimate = total - log(num_samples)
    normalized = log_weights .- total
    return (traces, normalized, lml_estimate)
end
function importance_sampling(
        model::GenerativeFunction{T,U},
        model_args::Tuple,
        observations::ChoiceMap,
        proposal::GenerativeFunction,
        proposal_args::Tuple,
        num_samples::Int;
        verbose=false) where {T,U}
    # Preallocate so trace and weight vectors have concrete element types.
    traces = Vector{U}(undef, num_samples)
    log_weights = Vector{Float64}(undef, num_samples)
    for idx in 1:num_samples
        verbose && println("sample: $idx of $num_samples")
        # Draw a proposal, then score it under the model constrained to
        # both the observations and the proposed choices.
        (prop_choices, prop_weight, _) = propose(proposal, proposal_args)
        combined = merge(observations, prop_choices)
        (traces[idx], model_weight) = generate(model, model_args, combined)
        # Standard importance weight: model score minus proposal score.
        log_weights[idx] = model_weight - prop_weight
    end
    total = logsumexp(log_weights)
    lml_estimate = total - log(num_samples)
    normalized = log_weights .- total
    return (traces, normalized, lml_estimate)
end
"""
    (trace, lml_est) = importance_resampling(model::GenerativeFunction,
        model_args::Tuple, observations::ChoiceMap, num_samples::Int;
        verbose=false)

    (trace, lml_est) = importance_resampling(model::GenerativeFunction,
        model_args::Tuple, observations::ChoiceMap,
        proposal::GenerativeFunction, proposal_args::Tuple,
        num_samples::Int; verbose=false)

Run sampling importance resampling, returning a single trace and an estimate
of the log marginal likelihood of the observations (`lml_est`).

Unlike `importance_sampling`, the memory used is constant in the number of
samples: candidate traces are streamed and one is kept via online reservoir
resampling rather than storing all of them.

Setting `verbose=true` prints a progress message for each sample.
"""
function importance_resampling(
        model::GenerativeFunction{T,U},
        model_args::Tuple,
        observations::ChoiceMap,
        num_samples::Int;
        verbose=false) where {T,U}
    # Initialize the reservoir with the first weighted sample.
    (model_trace::U, log_weight) = generate(model, model_args, observations)
    log_total_weight = log_weight
    for i in 2:num_samples
        verbose && println("sample: $i of $num_samples")
        (cand_model_trace, log_weight) = generate(model, model_args, observations)
        # Accumulate total weight in log space for numerical stability.
        log_total_weight = logsumexp(log_total_weight, log_weight)
        # Replace the kept trace with probability proportional to the
        # candidate's share of the total weight (weighted reservoir sampling).
        if bernoulli(exp(log_weight - log_total_weight))
            model_trace = cand_model_trace
        end
    end
    log_ml_estimate = log_total_weight - log(num_samples)
    return (model_trace::U, log_ml_estimate::Float64)
end
function importance_resampling(
        model::GenerativeFunction{T,U},
        model_args::Tuple,
        observations::ChoiceMap,
        proposal::GenerativeFunction{V,W},
        proposal_args::Tuple,
        num_samples::Int;
        verbose=false) where {T,U,V,W}
    # Seed the reservoir with the first proposed sample and its weight.
    (prop_choices, prop_weight, _) = propose(proposal, proposal_args)
    merged = merge(observations, prop_choices)
    (model_trace::U, model_weight) = generate(model, model_args, merged)
    log_total_weight = model_weight - prop_weight
    for i in 2:num_samples
        verbose && println("sample: $i of $num_samples")
        # Propose, score under the model, and form the importance weight.
        (prop_choices, prop_weight, _) = propose(proposal, proposal_args)
        merged = merge(observations, prop_choices)
        (candidate_trace, model_weight) = generate(model, model_args, merged)
        candidate_log_weight = model_weight - prop_weight
        log_total_weight = logsumexp(log_total_weight, candidate_log_weight)
        # Keep the candidate with probability proportional to its weight share.
        if bernoulli(exp(candidate_log_weight - log_total_weight))
            model_trace = candidate_trace
        end
    end
    log_ml_estimate = log_total_weight - log(num_samples)
    return (model_trace::U, log_ml_estimate::Float64)
end
export importance_sampling, importance_resampling