-
Notifications
You must be signed in to change notification settings - Fork 9
/
spe_to_dict.py
139 lines (127 loc) · 4.67 KB
/
spe_to_dict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# Copyright 2020 MIT Probabilistic Computing Project.
# See LICENSE.txt
"""Convert SPE to JSON friendly dictionary."""
from fractions import Fraction
import scipy.stats
from sympy import E
from sympy import sqrt
from ..spe import AtomicLeaf
from ..spe import ContinuousLeaf
from ..spe import DiscreteLeaf
from ..spe import NominalLeaf
from ..spe import ProductSPE
from ..spe import SumSPE
# Needed for "eval"
from ..sets import *
from ..transforms import Id
from ..transforms import Identity
from ..transforms import Radical
from ..transforms import Exponential
from ..transforms import Logarithm
from ..transforms import Abs
from ..transforms import Reciprocal
from ..transforms import Poly
from ..transforms import Piecewise
from ..transforms import EventInterval
from ..transforms import EventFiniteReal
from ..transforms import EventFiniteNominal
from ..transforms import EventOr
from ..transforms import EventAnd
def env_from_dict(env):
    """Decode a serialized environment back into a mapping.

    Each key and each value of ``env`` is a source string (``env_to_dict``
    produces them with ``repr``) that is turned back into an object with
    ``eval``. Names inside those strings are resolved against this module's
    globals, which is why the module star-imports ``..sets`` and imports the
    transform classes.

    WARNING: ``eval`` executes arbitrary code. Never pass an environment
    that came from an untrusted source.

    Args:
        env: mapping of source string -> source string, or None.

    Returns:
        The decoded mapping, or None when ``env`` is None.
    """
    if env is not None:
        # Used in eval.
        return {eval(key): eval(val) for key, val in env.items()}
    return None
def env_to_dict(env):
    """Encode an environment mapping as a dict of ``repr`` strings.

    Both keys and values are converted with ``repr`` so that
    ``env_from_dict`` can ``eval`` them back.

    Args:
        env: mapping to encode (must support ``len`` and ``items``).

    Returns:
        None when ``env`` has exactly one entry. Otherwise a new dict mapping
        ``repr(key)`` to ``repr(value)``. A None result decodes back to
        ``env=None`` in ``env_from_dict``.
    """
    # NOTE(review): presumably a one-entry env is the default identity
    # environment that the leaf constructors recreate from env=None --
    # confirm; otherwise that single entry is dropped on round trip.
    if len(env) != 1:
        return {repr(key): repr(val) for key, val in env.items()}
    return None
def scipy_dist_from_dict(dist):
    """Rebuild a frozen ``scipy.stats`` distribution from its dict form.

    ``dist['name']`` is looked up as an attribute of ``scipy.stats``. The
    result is called with ``dist['args']`` as positional arguments and
    ``dist['kwds']`` as keyword arguments.

    Args:
        dist: dict with keys ``'name'``, ``'args'`` and ``'kwds'``, as
            produced by ``scipy_dist_to_dict``.

    Returns:
        Whatever that ``scipy.stats`` constructor returns for those arguments.
    """
    family = getattr(scipy.stats, dist['name'])
    positional, keyword = dist['args'], dist['kwds']
    return family(*positional, **keyword)
def scipy_dist_to_dict(dist):
    """Describe a frozen ``scipy.stats`` distribution as a plain dict.

    Args:
        dist: a frozen distribution object exposing ``.dist.name``,
            ``.args`` and ``.kwds``.

    Returns:
        ``{'name': ..., 'args': ..., 'kwds': ...}``, which
        ``scipy_dist_from_dict`` can turn back into a distribution.
    """
    family_name = dist.dist.name
    positional = dist.args
    keyword = dist.kwds
    return dict(name=family_name, args=positional, kwds=keyword)
def spe_from_dict(metadata):
    """Rebuild an SPE from a dictionary produced by ``spe_to_dict``.

    ``metadata['class']`` selects the SPE type to build. Leaf fields are
    decoded individually, and the children of ``SumSPE``/``ProductSPE`` are
    decoded recursively.

    WARNING: the ``'support'`` field (and, through ``env_from_dict``, the
    environment entries) is passed to ``eval``. Loading a dictionary from an
    untrusted source can therefore execute arbitrary code.

    Args:
        metadata: dict with a ``'class'`` key naming the SPE type, plus the
            class-specific fields written by ``spe_to_dict``.

    Returns:
        The reconstructed SPE.

    Raises:
        AssertionError: if ``metadata['class']`` is not a supported name.
    """
    kind = metadata['class']
    if kind == 'NominalLeaf':
        symbol = Id(metadata['symbol'])
        # Probabilities are stored as (numerator, denominator) pairs so the
        # exact Fraction survives a JSON round trip.
        dist = {x: Fraction(w[0], w[1]) for x, w in metadata['dist']}
        # NOTE(review): spe_to_dict also writes an 'env' entry for
        # NominalLeaf, but it is not read back here -- confirm intended.
        return NominalLeaf(symbol, dist)
    if kind == 'AtomicLeaf':
        symbol = Id(metadata['symbol'])
        value = float(metadata['value'])
        env = env_from_dict(metadata['env'])
        return AtomicLeaf(symbol, value, env=env)
    if kind in ('ContinuousLeaf', 'DiscreteLeaf'):
        # Both leaf types share an identical serialized layout and
        # constructor signature; only the class differs.
        leaf_class = ContinuousLeaf if kind == 'ContinuousLeaf' else DiscreteLeaf
        symbol = Id(metadata['symbol'])
        dist = scipy_dist_from_dict(metadata['dist'])
        support = eval(metadata['support'])
        conditioned = metadata['conditioned']
        env = env_from_dict(metadata['env'])
        return leaf_class(symbol, dist, support, conditioned, env=env)
    if kind == 'SumSPE':
        children = [spe_from_dict(c) for c in metadata['children']]
        weights = metadata['weights']
        return SumSPE(children, weights)
    if kind == 'ProductSPE':
        children = [spe_from_dict(c) for c in metadata['children']]
        return ProductSPE(children)
    # Explicit raise rather than ``assert False``: assert statements are
    # stripped under ``python -O``, which would silently return None here.
    raise AssertionError('Cannot convert %s to SPE' % (metadata,))
def spe_to_dict(spe):
    """Serialize an SPE into a JSON-friendly nested dictionary.

    This is the inverse of ``spe_from_dict``. Every result carries a
    ``'class'`` key naming the SPE type. Leaves also store their symbol
    token and distribution. ``SumSPE`` and ``ProductSPE`` store their
    serialized children recursively.

    Args:
        spe: a NominalLeaf, AtomicLeaf, ContinuousLeaf, DiscreteLeaf,
            SumSPE or ProductSPE.

    Returns:
        A dict that ``spe_from_dict`` can turn back into an SPE.

    Raises:
        AssertionError: if ``spe`` is not one of the supported types.
    """
    if isinstance(spe, NominalLeaf):
        return {
            'class' : 'NominalLeaf',
            'symbol' : spe.symbol.token,
            # Split each Fraction into (numerator, denominator) so the
            # probability stays exact through JSON.
            'dist' : [
                (str(x), (w.numerator, w.denominator))
                for x, w in spe.dist.items()
            ],
            'env' : env_to_dict(spe.env),
        }
    if isinstance(spe, AtomicLeaf):
        return {
            'class' : 'AtomicLeaf',
            'symbol' : spe.symbol.token,
            'value' : spe.value,
            'env' : env_to_dict(spe.env),
        }
    if isinstance(spe, (ContinuousLeaf, DiscreteLeaf)):
        # Both leaf types share one layout. ContinuousLeaf is checked
        # first so the tag matches the original check order.
        if isinstance(spe, ContinuousLeaf):
            leaf_name = 'ContinuousLeaf'
        else:
            leaf_name = 'DiscreteLeaf'
        return {
            'class' : leaf_name,
            'symbol' : spe.symbol.token,
            'dist' : scipy_dist_to_dict(spe.dist),
            # Stored as repr; spe_from_dict eval()s it back.
            'support' : repr(spe.support),
            'conditioned' : spe.conditioned,
            'env' : env_to_dict(spe.env),
        }
    if isinstance(spe, SumSPE):
        return {
            'class' : 'SumSPE',
            'children' : [spe_to_dict(c) for c in spe.children],
            'weights' : spe.weights,
        }
    if isinstance(spe, ProductSPE):
        return {
            'class' : 'ProductSPE',
            'children' : [spe_to_dict(c) for c in spe.children],
        }
    # Explicit raise rather than ``assert False``: assert statements are
    # stripped under ``python -O``, which would silently return None here.
    raise AssertionError('Cannot convert %s to JSON' % (spe,))