Skip to content

Commit af4e708

Browse files
committed
rename class
1 parent 61d42f9 commit af4e708

File tree

3 files changed

+16
-16
lines changed

3 files changed

+16
-16
lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pymc3 as pm
88
from pymc3 import Model, Normal
99
from pymc3.variational import (
10-
ADVI, FullRankADVI, SVGD, NF,
10+
ADVI, FullRankADVI, SVGD, NFVI,
1111
Empirical, ASVGD,
1212
MeanField, FullRank,
1313
fit, flows
@@ -161,13 +161,13 @@ def simple_model(simple_model_data):
161161

162162

163163
@pytest.fixture('module', params=[
164-
dict(cls=NF, init=dict(flow='scale-loc')),
164+
dict(cls=NFVI, init=dict(flow='scale-loc')),
165165
dict(cls=ADVI, init=dict()),
166166
dict(cls=FullRankADVI, init=dict()),
167167
dict(cls=SVGD, init=dict(n_particles=500, jitter=1)),
168168
dict(cls=ASVGD, init=dict(temperature=1.)),
169169
], ids=[
170-
'NF=scale-loc',
170+
'NFVI=scale-loc',
171171
'ADVI',
172172
'FullRankADVI',
173173
'SVGD',
@@ -202,11 +202,11 @@ def fit_kwargs(inference, using_minibatch):
202202
obj_optimizer=pm.adagrad_window(learning_rate=0.01, n_win=50),
203203
n=12000
204204
),
205-
(NF, 'full'): dict(
205+
(NFVI, 'full'): dict(
206206
obj_optimizer=pm.adagrad_window(learning_rate=0.01, n_win=50),
207207
n=12000
208208
),
209-
(NF, 'mini'): dict(
209+
(NFVI, 'mini'): dict(
210210
obj_optimizer=pm.adagrad_window(learning_rate=0.01, n_win=50),
211211
n=12000
212212
),
@@ -608,7 +608,7 @@ def test_hh_flow():
608608
cov = pm.floatX([[2, -1], [-1, 3]])
609609
with pm.Model():
610610
pm.MvNormal('mvN', mu=pm.floatX([0, 1]), cov=cov, shape=2)
611-
nf = NF('scale-hh*2-loc')
611+
nf = NFVI('scale-hh*2-loc')
612612
nf.fit(25000, obj_optimizer=pm.adam(learning_rate=0.001))
613613
trace = nf.approx.sample(10000)
614614
cov2 = pm.trace_cov(trace)

pymc3/variational/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
FullRankADVI,
2626
SVGD,
2727
ASVGD,
28-
NF,
28+
NFVI,
2929
Inference,
3030
fit
3131
)

pymc3/variational/inference.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -691,7 +691,7 @@ def run_profiling(self, n=1000, score=None, obj_n_mc=300, **kwargs):
691691
n=n, score=score, obj_n_mc=obj_n_mc, **kwargs)
692692

693693

694-
class NF(Inference):
694+
class NFVI(Inference):
695695
R"""
696696
Normalizing flow is a series of invertible transformations on initial distribution.
697697
@@ -753,7 +753,7 @@ def __init__(self, flow='planar*3',
753753
local_rv=None, model=None,
754754
scale_cost_to_minibatch=False,
755755
random_seed=None, start=None, jitter=.1):
756-
super(NF, self).__init__(
756+
super(NFVI, self).__init__(
757757
self.OP, self.APPROX, self.TF,
758758
flow=flow,
759759
local_rv=local_rv, model=model,
@@ -772,7 +772,7 @@ def from_flow(cls, flow):
772772
773773
Returns
774774
-------
775-
:class:`NF`
775+
:class:`NFVI`
776776
"""
777777
inference = object.__new__(cls)
778778
Inference.__init__(inference, KL, flow, None)
@@ -800,10 +800,10 @@ def fit(n=10000, local_rv=None, method='advi', model=None,
800800
- 'advi->fullrank_advi' for fitting ADVI first and then FullRankADVI
801801
- 'svgd' for Stein Variational Gradient Descent
802802
- 'asvgd' for Amortized Stein Variational Gradient Descent
803-
- 'nf' for Normalizing Flow
804-
- 'nf=formula' for Normalizing Flow using formula
803+
- 'nfvi' for Normalizing Flow
804+
- 'nfvi=formula' for Normalizing Flow using formula
805805
806-
model : :class:`pymc3.Model`
806+
model : :class:`Model`
807807
PyMC3 model for inference
808808
random_seed : None or int
809809
leave None to use package global RandomStream or other
@@ -833,7 +833,7 @@ def fit(n=10000, local_rv=None, method='advi', model=None,
833833
fullrank_advi=FullRankADVI,
834834
svgd=SVGD,
835835
asvgd=ASVGD,
836-
nf=NF
836+
nfvi=NFVI
837837
)
838838
if isinstance(method, str):
839839
method = method.lower()
@@ -853,9 +853,9 @@ def fit(n=10000, local_rv=None, method='advi', model=None,
853853
inference = FullRankADVI.from_advi(inference)
854854
logger.info('fitting fullrank advi ...')
855855
return inference.fit(n2, **kwargs)
856-
elif method.startswith('nf='):
856+
elif method.startswith('nfvi='):
857857
formula = method[3:]
858-
inference = NF(
858+
inference = NFVI(
859859
formula,
860860
local_rv=local_rv,
861861
model=model,

0 commit comments

Comments
 (0)