New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SVGD #1671
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks fine to me - let's see how the tests behave. (Just saw the WIP).
Needs more tests I imagine, and there might be some errors with the build. Not merging yet obviously.
This is still WIP. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would rather refactor pymc3.variational.updates
pymc3/variational/updates.py
Outdated
class Update(object): | ||
|
||
def __init__(self): | ||
self.__dict__.update(locals()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm still not satisfied about the current solution.
- Here you don't need init at all
- locals have
self
and create circular reference
pymc3/variational/updates.py
Outdated
|
||
def __init__(self, lr=0.01, *args, **kwargs): | ||
Update.__init__(self, *args, **kwargs) | ||
self.__dict__.update(locals()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You update dict twice, moreover passing any additional argument will cause an error, but signature allows it (e.g. SGD(foo=1)
will fail, but it is not obvious)
pymc3/variational/updates.py
Outdated
Update.__init__(self, *args, **kwargs) | ||
self.__dict__.update(locals()) | ||
|
||
def __call__(self, p, g): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Current implementation does not handle the case when list of shared variables is passed, moreover dict is better format of return value. The same comments are about other updates.
pymc3/variational/updates.py
Outdated
def __init__(self): | ||
self.__dict__.update(locals()) | ||
|
||
def __call__(self, params, grads): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not computing grads from objective wrt params(potentially list of params) on the fly? That forces user to do unnecessary work
Update method can be just a helper function overriden below. What call will do is computing gradients and creating updates.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW, signature(interface) changes in inherited classes, it must be the same according to OOP style
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you think we should instead use https://github.com/Lasagne/Lasagne/blob/master/lasagne/updates.py?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@twiecki I think Lasagne has good and reliable implementation. If they don't mind copy-paste that would be great. Otherwise we should probably change api so that it will be the same as Lasagne's api.
@ferrine I added updates.py from lasagne. Not sure it works correctly yet. |
Great! We can check it with simple advi problem. |
@twiecki is there any work left? I'm looking forward merging this PR as it will be useful in many cases |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there is any work, what about moving updates.py
in a separate PR?
@@ -0,0 +1 @@ | |||
docs/source/notebooks/ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New line
sumkxy = tt.sum(Kxy, axis=1).dimshuffle(0, 'x') | ||
dxkxy = tt.add(dxkxy, tt.mul(X, sumkxy)) / (h ** 2) | ||
|
||
return (Kxy, dxkxy) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Redundant parentheses
@ferrine I went ahead and merged to get things moving. I'll address the parentheses thing and add docs in a separate PR. |
No description provided.