Refactoring OPVI to support Normalizing Flows #2306
Conversation
I made this PR against the new 3.2_dev branch, but it seems there are failing tests, e.g.:
The file is
@twiecki this error can be caused by a session-scoped fixture that is itself a generator.
Happy to say "tests pass". I also achieved a 5% speed-up on the convolutional MNIST example using the GPU.
pymc3/variational/opvi.py
Outdated
TODO: fix docs here
pymc3/variational/approximations.py
Outdated
TODO: use node property
pymc3/variational/approximations.py
Outdated
TODO: use dict for shared params
(cherry picked from commit d181cea)
(cherry picked from commit c74337e)
(cherry picked from commit 48c38b2)
(cherry picked from commit 0f97ba3)
The LDA example runs at the same speed with the new refactoring, but SGD eventually gave poor results; Adam solved the problem.
@twiecki Can you confirm that we can merge this into master?
```python
order = ArrayOrdering(vars)
if inputvar is None:
    inputvar = tt.vector('flat_view', dtype=theano.config.floatX)
if vars:
```
Shouldn't `if vars:` have one less indent here?
It should not; I use the `if` there for the case of empty `vars`, or else I get an exception from `flatten`.
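A small sketch of the failure mode that guard avoids (the `flatten_list` helper here is hypothetical): concatenating an empty list of tensors raises, so the flat view is only built when `vars` is non-empty.

```python
import theano.tensor as tt

def flatten_list(tensors):
    # tt.concatenate cannot join an empty list and raises,
    # which is why the calling code guards with `if vars:`
    return tt.concatenate([t.flatten() for t in tensors])

x, y = tt.vector('x'), tt.vector('y')
flat = flatten_list([x, y])   # fine: one long flat vector
# flatten_list([])            # would raise inside tt.concatenate
```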
pymc3/distributions/dist_math.py
Outdated
```python
if delta.ndim > 1:
    result = tt.batched_dot(result, delta)
else:
    result = result.dot(delta.T)
```
Shouldn't this be `result = result.dot(delta)`?
I'd better delete this function, as it is not used.
Yep I agree
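A side note on why both spellings agree in the 1-D branch (a NumPy illustration; Theano follows the same convention): transposing a 1-D array is a no-op, so `.dot(delta.T)` and `.dot(delta)` compute the same thing there.

```python
import numpy as np

result = np.array([[1.0, 2.0], [3.0, 4.0]])
delta = np.array([5.0, 6.0])

# .T on a 1-D array is the identity, so both calls are equivalent
assert np.array_equal(result.dot(delta.T), result.dot(delta))
```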
pymc3/tests/conftest.py
Outdated
```diff
-@pytest.fixture(scope="session", autouse=True)
+@pytest.yield_fixture(scope="function", autouse=True)
```
Isn't `yield_fixture` deprecated? https://docs.pytest.org/en/latest/fixture.html
Also, why change the scope from "session" to "function"?
To drop the context after each test.
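A sketch of the non-deprecated spelling (fixture name and body hypothetical): since pytest 3.0 a plain `@pytest.fixture` may contain `yield`, and `scope="function"` reruns setup/teardown around every test, so no state leaks between tests.

```python
import pytest

@pytest.fixture(scope="function", autouse=True)
def fresh_context():
    # setup for a single test would go here
    yield  # run the test
    # teardown runs after every test, dropping any per-test context
```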
```python
-------
list
"""
if isinstance(params, list):
```
`cast_to_list()` raises a TypeError if a list or tuple is given, which contradicts the docstring.
I forgot to change it. Thanks
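A sketch of the reconciled behavior (the exact branches are hypothetical): accept lists and tuples instead of raising, matching the `list` return type that the docstring promises.

```python
def cast_to_list(params):
    """Cast a supported container of parameters to a list.

    Returns
    -------
    list
    """
    if isinstance(params, list):
        return params            # already a list: return as-is, don't raise
    elif isinstance(params, tuple):
        return list(params)
    elif params is None:
        return []
    else:
        raise TypeError('Unsupported type: %s' % type(params))
```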
pymc3/variational/opvi.py
Outdated
```python
[self.view_local(l_posterior, name)
 for name in self.local_names]
)
sample_fn = theano.function([theano.In(s, 'draws', 1)], sampled)
```
Could you tell me why you use `theano.In()` here? I think `theano.function([s], sampled)` might work, but there might be some reason.
There is no reason; I just decided to use it when I first implemented this.
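For context, a minimal sketch of what `theano.In` buys here (the graph is a stand-in): it attaches a name and a default value to the input, so the compiled function can be called with no arguments or with a keyword, whereas `theano.function([s], sampled)` would require a positional argument.

```python
import theano
import theano.tensor as tt

s = tt.iscalar()
sampled = s * 2  # stand-in for the real sampling graph

# theano.In(variable, name, value): name the input 'draws', default 1
f = theano.function([theano.In(s, 'draws', 1)], sampled)

print(f())          # uses the default: draws=1
print(f(draws=10))  # override by keyword
```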
```python
@memoize
def histogram_logp(self):
    """Symbolic logp for every point in trace
def histogram(self):
```
So `histogram` is still a property, but the user only interacts with it through `Empirical`, right?
As I remember, we decided that `Empirical` says what the class is about, while `histogram` says how it stores the samples. I see no harm in such a property.
Yep sounds good to me.
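For context, a hedged usage sketch (3.x-era API; details may differ after this refactor) of how a user reaches the property only through `Empirical`:

```python
import pymc3 as pm

with pm.Model():
    pm.Normal('x')
    trace = pm.sample(500)
    approx = pm.Empirical(trace)  # the class name says what it is ...
    samples = approx.histogram    # ... the property says how it is stored
```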
LGTM, I have no more comments.
@junpenglao Those are used for autoencoding variational Bayes. Global variables are relevant to the model as a whole, while local variables are associated with observations, often referred to as latent variables.
Local variables are conditioned on the data and thus encoded. So if we deal with a generative model, these variables are encoded during inference when computing logp. As the encoder maps directly to mu and sd, we do not need extra transformations there, and they need different treatment. Global variables, in contrast, are sampled independently of the data and need to be passed through transforms.
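A hedged sketch of that split (assuming pymc3's AEVB interface, where `local_rv` maps a variable to its encoded `(mu, rho)`; the encoder outputs below are stand-ins): local variables take their posterior parameters straight from the encoder, while globals are fitted as usual and pass through their transforms.

```python
import numpy as np
import theano
import pymc3 as pm

n, d = 100, 5
# Stand-ins for an encoder's outputs, one row per observation;
# a real AEVB setup would compute these symbolically from the data.
x_mu = theano.shared(np.zeros((n, d)))
x_rho = theano.shared(np.zeros((n, d)))

with pm.Model():
    z = pm.Normal('z', 0, 1)                # global: independent of the data
    x = pm.Normal('x', z, 1, shape=(n, d))  # local: tied to observations
    # locals get their posterior parameters from the encoder,
    # globals get the usual mean-field treatment
    approx = pm.ADVI(local_rv={x: (x_mu, x_rho)}).fit(1000)
```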
I'm not sure all those
I use node property for getting static symbolic references to nodes after they are created.
I just don't see why there is a
I'm not sure about the extent to which memoization makes debugging difficult. However, we can easily remove the memoize decorator if it is found not to be effective, so I'd say merging this PR is reasonable. Tests have already passed, and we can postpone checking the effectiveness of memoization. I think if
I agree that this is not critical for this PR.

```python
with pm.Model() as model:
    a = pm.Normal('a')

with model:
    trace = pm.sample()

with model:
    b = pm.Normal('b', mu=a)
    trace2 = pm.sample()
```

This currently fails with a disconnected input error, because
I half-agree with @taku-y. The conservative thing to do, however, would be to remove memoization for now and merge it back in once we've proven its worth.
Memoization is not for performance. Properties are called during the construction phase. Memoization is for creating things I can reuse and access with confidence.
The problem with disconnected inputs is more about how we do hashing. We just need to make it dependent on the vars.
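A sketch of the suggested fix as I read it (not pymc3's actual `memoize` implementation): key the cache on the arguments, so two calls with different `vars` no longer collide on a single cached graph.

```python
import functools

def memoize(fn):
    cache = {}

    @functools.wraps(fn)
    def wrapper(*args):
        key = tuple(map(id, args))  # identity-based key; a sketch only
        if key not in cache:
            cache[key] = fn(*args)
        return cache[key]

    return wrapper
```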
If it is about keeping the result unique, I think in most cases it is much better to just do that work in
@aseyboldt the idea behind that is the following
I see no alternative to memoization. About serialization: I've moved to using dicts as shared_params, as they are easy to serialize (with get/set value). This mechanism can be discussed.
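A small sketch of why a dict of shared variables serializes easily (names hypothetical): only the numeric state has to cross the pickle boundary, via `get_value`/`set_value`.

```python
import numpy as np
import theano

shared_params = {
    'mu': theano.shared(np.zeros(3), 'mu'),
    'rho': theano.shared(np.ones(3), 'rho'),
}

# Serialize: pull plain numpy arrays out (these pickle trivially) ...
state = {name: s.get_value() for name, s in shared_params.items()}

# ... deserialize: push the arrays back into a rebuilt dict
for name, value in state.items():
    shared_params[name].set_value(value)
```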
Do you know a flexible way to put all that stuff into `__init__`? Only a metaclass comes to mind for me, but that will for sure be much more complicated.
Some tests related to SVGD failed. It seems related to the last two commits; I remember that the tests passed before them.
(cherry picked from commit cc4291e)
I've wanted to take the time to read through the variational code properly for quite some time; I guess I'll do that this weekend. I'll let you know if I come up with an alternative.
@ferrine Ready to merge?
Yes
Thanks!
The architecture I created this winter for variational inference was supposed to be modular, so implementing a new method is a few lines of code with the math for computing samples from the initial distribution ($z_0$) and the probability ($q(z)$) of the generated posterior ($z$). So not much is left for a developer: there is an abstract method $f_{\theta} : z_0 \rightarrow z$ that generates $z$ and is parametrized with $\theta$. The result is the approximate posterior. Note that the intermediate $z_t$ are not available in that case, and we can't compute $q(z)$ for the flow.

This PR will solve the problem using a new internal architecture. I'll rely on symbolic entry points and theano.clone for achieving the necessary modularity and control.
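A hedged sketch of that abstract method (class and method names hypothetical, not the PR's actual interface): each flow implements $f_{\theta} : z_0 \rightarrow z$, and keeping $z_0$ as a symbolic entry point lets `theano.clone` re-route the flow onto new inputs without the subclass knowing.

```python
import theano
import theano.tensor as tt

class Flow:
    """One step of a normalizing flow: z = f_theta(z0)."""
    def __init__(self, z0):
        self.z0 = z0               # symbolic entry point
        self.z = self.forward(z0)  # the transformed sample

    def forward(self, z0):
        raise NotImplementedError  # subclasses supply f_theta

    def apply_to(self, new_input):
        # Swap the entry point for another symbolic input without
        # rebuilding the graph, via theano.clone
        return theano.clone(self.z, {self.z0: new_input})

class ShiftFlow(Flow):
    """Toy flow: f_theta(z0) = z0 + theta."""
    def forward(self, z0):
        self.theta = theano.shared(0.0, name='theta')
        return z0 + self.theta
```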