EquilibriumAggregation global aggregation layer #4522
Conversation
Codecov Report
@@ Coverage Diff @@
## master #4522 +/- ##
==========================================
+ Coverage 82.87% 82.95% +0.07%
==========================================
Files 331 332 +1
Lines 18197 18288 +91
==========================================
+ Hits 15081 15171 +90
- Misses 3116 3117 +1
This is still pretty heavy WIP (and I'm not even sure if it is working 100% yet). @rusty1s do you think there is a good example graph classification problem to try and compare accuracy to? The paper uses … I also plan to include the median task from the paper as an example (it acts as a reasonable test that it works).
I've updated this and just added a simple example of the median readout learning. I didn't have time to do a larger example yet, sorry.
Hey @rusty1s, this is mostly ready for review (though the example I use is the median aggregation). I'm currently trying to play around to get this to converge well, but I haven't had much luck. I'm not sure if it's because of the parameters I've used or details missing from the paper (like the alpha used in the inner optimization). But perhaps convergence is just slow - they run for 10 million steps, which I'm trying to do now but will take quite some time on my resources :-)
After ~1 million iterations there are no signs of convergence, so I'm guessing I've got something wrong. Maybe @FabianFuchsML has some input?
momentum = torch.zeros_like(y)
for _ in range(iterations):
    val = func(x, y, batch)
    grad = torch.autograd.grad(val, y, create_graph=True,
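For context, a minimal standalone sketch of the kind of inner gradient-descent loop this diff implements: the aggregation output y is found by minimizing a differentiable energy, with create_graph=True so the outer training loss can backpropagate through the inner steps. The energy function, step size, and momentum factor below are placeholders, not the PR's exact values.

import torch

def inner_optimization(x, energy_fn, iterations=50, alpha=0.05, beta=0.9):
    # Minimize energy_fn(x, y) over y by gradient descent with momentum.
    y = torch.zeros(1, x.size(-1), requires_grad=True)
    momentum = torch.zeros_like(y)
    for _ in range(iterations):
        val = energy_fn(x, y)
        # create_graph=True keeps the inner graph so the outer loss can
        # differentiate through the optimization trajectory.
        grad, = torch.autograd.grad(val, y, create_graph=True)
        momentum = beta * momentum - alpha * grad
        y = y + momentum
    return y

# Toy usage: a quadratic energy whose minimizer is the mean of the set.
x = torch.randn(5, 3)
energy_fn = lambda x, y: ((x - y) ** 2).sum()
print(inner_optimization(x, energy_fn))  # approaches x.mean(dim=0)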
The authors of the paper suggest adding an auxiliary loss on the grad here (`grad.square().sum() / iterations`, to be precise).
The least intrusive way I could think of to do this would be to accumulate the grad in a property like:
grad = ...
self.aux_loss += grad.square().sum() / iterations
and then call backward on it using a hook? I'm not entirely sure this makes sense.
Any suggestions?
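One way this could be wired up without hooks is to reset and accumulate the penalty inside the module and add it to the task loss explicitly, so a single backward() covers both terms. A rough sketch under that assumption (the class name, the aux_loss attribute, and the hyperparameters are hypothetical, not the PR's code):

import torch
import torch.nn as nn

class EnergyAggregation(nn.Module):
    """Hypothetical aggregation that tracks an auxiliary gradient penalty."""
    def __init__(self, energy_fn, iterations=10, alpha=0.1):
        super().__init__()
        self.energy_fn = energy_fn
        self.iterations = iterations
        self.alpha = alpha
        self.aux_loss = torch.tensor(0.0)

    def forward(self, x):
        self.aux_loss = torch.tensor(0.0)  # reset the penalty each call
        y = torch.zeros(1, x.size(-1), requires_grad=True)
        for _ in range(self.iterations):
            val = self.energy_fn(x, y)
            grad, = torch.autograd.grad(val, y, create_graph=True)
            # Accumulate the suggested grad.square().sum() / iterations term.
            self.aux_loss = self.aux_loss + grad.square().sum() / self.iterations
            y = y - self.alpha * grad
        return y

# Training step: add the accumulated penalty to the task loss explicitly.
model = EnergyAggregation(lambda x, y: ((x - y) ** 2).sum())
x = torch.randn(5, 3)
out = model(x)
loss = out.square().sum() + model.aux_loss  # placeholder task loss
loss.backward()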
I've moved this to the aggregation module but there are a few things missing which I'll need opinions on. Leaving comments in the code.
def energy(self, x: Tensor, y: Tensor, index: Optional[Tensor]):
    return self.potential(x, y, index) + self.reg(y)

def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
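As background for the energy above: the objective is a learned per-element potential summed over the set plus a regularizer on the output, matching the `self.potential(...) + self.reg(y)` split in the diff. A minimal sketch of that decomposition (the MLP sizes, the Softplus activation, and the L2 regularizer weight are assumptions, not the PR's exact choices):

import torch
import torch.nn as nn

class Energy(nn.Module):
    """Sketch of E(x, y) = sum_i potential(x_i, y) + lamb * ||y||^2."""
    def __init__(self, in_channels, out_channels, hidden=64, lamb=0.1):
        super().__init__()
        self.potential = nn.Sequential(
            nn.Linear(in_channels + out_channels, hidden),
            nn.Softplus(),
            nn.Linear(hidden, 1),
        )
        self.lamb = lamb

    def forward(self, x, y):
        # Broadcast y to every element of the set and sum the potentials.
        y_expanded = y.expand(x.size(0), -1)
        pot = self.potential(torch.cat([x, y_expanded], dim=-1)).sum()
        reg = self.lamb * y.square().sum()
        return pot + reg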
Just a note: this was all written before the `Aggregation` module, so it's not using `reduce` yet. I need to rethink some of the implementation to make use of it.
No need to make use of `reduce`. The `LSTMAggregation` module doesn't make use of it either.
if ptr is not None:
    raise ValueError(f"{self.__class__} doesn't support `ptr`")

if dim_size is not None:
I could technically support this, but I don't know what you'd expect the behavior to be: unlike sum/mean etc., we can't just assume the input is zero (or if we do, the output would just be random). Thoughts?
Please have a look at https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html. The `dim_size` argument ideally is identical to `index_size = index.max() + 1` in your case. If it is passed, there is no need to compute it in the first place.
I didn't fully get that, sorry. If `dim_size` is passed:
- if `dim_size < index_size`, I'd expect an error,
- if `dim_size > index_size`, I'm not sure what to expect - I guess zeros for those entries?
Yes, that is correct.
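To make the `dim_size` semantics concrete, a small `torch_scatter` example (values are made up); for reductions like sum, output rows beyond `index.max() + 1` simply stay zero, which is the part that has no obvious analogue for an optimization-based aggregation:

import torch
from torch_scatter import scatter

x = torch.tensor([[1.0], [2.0], [3.0]])
index = torch.tensor([0, 0, 1])  # index.max() + 1 == 2

# dim_size == index.max() + 1: nothing to pad.
print(scatter(x, index, dim=0, dim_size=2, reduce='sum'))
# tensor([[3.], [3.]])

# dim_size > index.max() + 1: the extra output rows are left as zeros.
print(scatter(x, index, dim=0, dim_size=4, reduce='sum'))
# tensor([[3.], [3.], [0.], [0.]])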
Great work @Padarn! Added some minor comments.
`EquilibriumAggregation` to learn to take the median of
a set of numbers
"""
It would be great to add a few comments about the convergence of `EquilibriumAggregation`.
examples/equilibrium_median.py
dist = np.random.choice([norm, gamma, uniform])
x = dist.sample((input_size, 1))
y = model(x)
loss = (y - x.median()).norm(2)
It seems the loss is not normalized by the input size.
Suggested change:
- loss = (y - x.median()).norm(2)
+ loss = (y - x.median()).norm(2) / input_size
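For reference, a minimal end-to-end sketch of the median-learning loop with the normalized loss applied. The `EquilibriumAggregation` constructor arguments, step count, and optimizer settings below are assumptions, not necessarily what the example script uses:

import torch
from torch.distributions import Normal, Gamma, Uniform
from torch_geometric.nn.aggr import EquilibriumAggregation

input_size = 100
# Assumed constructor signature: (in_channels, out_channels, num_layers, grad_iter).
model = EquilibriumAggregation(1, 1, num_layers=[256, 256], grad_iter=5)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

norm, gamma, uniform = Normal(0.0, 1.0), Gamma(2.0, 1.0), Uniform(0.0, 1.0)

for step in range(1000):
    dist = [norm, gamma, uniform][step % 3]  # cycle through the distributions
    x = dist.sample((input_size, 1))
    y = model(x)  # aggregate the whole set into a single output
    loss = (y - x.median()).norm(2) / input_size  # normalized loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()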
Thanks @lightaime - I've got a busy week, I'll address your comments ASAP :-)
Thanks for the reviews - I've updated based on most of the comments. I still have an uncertainty here: #4522 (comment), but it may be for another PR.
Thanks @Padarn. Can you also resolve the merge conflicts? I will take a look soon, and address the unresolved comment as well.
Hey @rusty1s, any thoughts on this one? Happy to keep working on it or break it up, but it might be good not to leave it here unfinished.
Yes, now that #4779 is merged, let's try to integrate this next :)
Cool. I'll make another pass to see if there is anything I can clean up based on #4779.
Co-authored-by: Guohao Li <lightaime@gmail.com>
Hey @rusty1s and @lightaime, what do you guys think about merging this one?
Thanks, LGTM!
Thanks @lightaime - I'm going to merge this one, but will try to build some more example use cases before we release the new version, to make sure it's good.
Implements a simple version of the paper "Equilibrium Aggregation: Encoding Sets via Optimization". Note that the exact details of the paper have not been implemented yet, and I plan to leave the specifics of the optimizer out of this PR; it should only add a new readout layer that implements the method.
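For readers unfamiliar with the paper, the core idea (a rough summary, not part of the PR text itself) is to define the aggregated output of a set as the solution of a learned optimization problem,

\hat{y}(\{x_1, \dots, x_n\}) = \operatorname*{argmin}_y \Big( F_\theta(y) + \sum_{i=1}^{n} E_\theta(x_i, y) \Big),

where E_\theta is a per-element potential, F_\theta is a regularizer on the output, and the argmin is approximated by a fixed number of gradient steps that the outer training loop differentiates through.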
TODO:
Addresses #4447