-
Notifications
You must be signed in to change notification settings - Fork 0
/
sampling.py
37 lines (32 loc) · 1.14 KB
/
sampling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import random
def get_non_evidence_nodes(net, ev):
    """Return the nodes of ``net`` that do not appear in the evidence ``ev``.

    The result preserves the order of ``net._nodes``. ``ev`` may be any
    container supporting ``in`` (presumably a dict mapping node -> observed
    value; TODO confirm against callers).
    """
    return [n for n in net._nodes if n not in ev]


def _sweep(net, nodes, randomize):
    """Resample every node in ``nodes`` once via ``net.sample_node``.

    When ``randomize`` is true, ``nodes`` is shuffled in place first, so the
    visiting order varies from sweep to sweep.
    """
    if randomize:
        random.shuffle(nodes)
    for n in nodes:
        net.sample_node(n)


def gibbs_sample(net, evidence, sample_fn, M, burnin, randomize=True):
    """Run Gibbs sampling on ``net`` with ``evidence`` held fixed.

    Performs up to ``burnin + M`` sweeps. A sweep resamples every
    non-evidence node once, in shuffled order if ``randomize`` is true.

    After each of the last ``M`` sweeps, ``sample_fn(i, net)`` is called with
    ``i`` running from 0 to ``M - 1``. If it returns a truthy value, sampling
    stops immediately. ``sample_fn`` may be None, in which case the sweeps
    run without callbacks.

    Returns None.
    """
    net.evidence(evidence)
    non_evidence_nodes = get_non_evidence_nodes(net, evidence)
    # range (not xrange) so the module also runs on Python 3.
    for i in range(M + burnin):
        _sweep(net, non_evidence_nodes, randomize)
        # Sample only *after* the sweep. The callback then always sees a
        # freshly resampled state (even with burnin=0), and no trailing sweep
        # is wasted.
        if i >= burnin and sample_fn:
            if sample_fn(i - burnin, net):
                return


def gibbs_sample_dynamic_evidence(net, evidence_gen, sample_fn, M, burnin, randomize=True):
    """Same interface as gibbs_sample(), but the evidence changes each sweep.

    Before every sweep, the next evidence state is drawn from
    ``evidence_gen`` (any iterable or iterator) and applied to ``net``. The
    non-evidence nodes are recomputed for that state, then resampled. After
    each of the last ``M`` sweeps, ``sample_fn(i, net)`` is called, so the
    callback sees hidden nodes resampled under the evidence currently
    clamped. A truthy return from ``sample_fn`` halts sampling.

    Raises StopIteration if ``evidence_gen`` runs out before
    ``burnin + M`` states have been drawn.
    """
    # iter() is a no-op on iterators and lets callers pass a plain list.
    evidence_iter = iter(evidence_gen)
    for i in range(M + burnin):
        ev = next(evidence_iter)
        net.evidence(ev)
        non_evidence_nodes = get_non_evidence_nodes(net, ev)
        _sweep(net, non_evidence_nodes, randomize)
        if i >= burnin and sample_fn:
            if sample_fn(i - burnin, net):
                return