
EquilibriumAggregation global aggregation layer #4522

Merged
Padarn merged 27 commits into pyg-team:master from padarn/optim-embedding on Jul 27, 2022

Conversation

@Padarn (Contributor) commented Apr 24, 2022

Implements a simple version of the paper "Equilibrium Aggregation: Encoding Sets via Optimization". Note that the exact details of the paper have not been implemented yet, and I plan to leave specifics about the optimizer out of this MR; it should only add a new readout layer that implements the method.

TODO:

  • Example of the 'exact' methods converging (median)
  • Cleanup of the readout layer
  • Support batched datasets (similar to other glob layers)
  • Tests

Addresses #4447
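For reference, the core idea of the paper is to compute the set/graph embedding y as the minimizer of a learned energy, found by a few differentiable gradient-descent steps inside the forward pass. A minimal self-contained sketch of that idea in PyTorch (the class name, network sizes, and hyper-parameters below are illustrative, not the implementation in this PR):

import torch

class EquilibriumReadoutSketch(torch.nn.Module):
    """Illustrative sketch: embed a set as the minimizer of a learned energy."""
    def __init__(self, in_dim, out_dim, steps=5, alpha=0.1, lamb=0.1):
        super().__init__()
        self.potential = torch.nn.Sequential(       # F_theta(x_i, y) -> scalar
            torch.nn.Linear(in_dim + out_dim, 32), torch.nn.Softplus(),
            torch.nn.Linear(32, 1))
        self.out_dim, self.steps, self.alpha, self.lamb = out_dim, steps, alpha, lamb

    def energy(self, x, y):
        # E(x, y) = sum_i F_theta(x_i, y) + lambda * ||y||^2
        h = torch.cat([x, y.expand(x.size(0), -1)], dim=-1)
        return self.potential(h).sum() + self.lamb * y.square().sum()

    def forward(self, x):
        # Inner optimization: a few gradient-descent steps on y, kept differentiable
        # so the outer training loop can backpropagate through them.
        y = torch.zeros(1, self.out_dim, device=x.device, requires_grad=True)
        for _ in range(self.steps):
            grad = torch.autograd.grad(self.energy(x, y), y, create_graph=True)[0]
            y = y - self.alpha * grad
        return y  # shape [1, out_dim], the set-level embedding

The layer added in this PR follows the same pattern but additionally needs to support batched inputs via an index vector, as listed in the TODO above.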

@codecov bot commented Apr 24, 2022

Codecov Report

Merging #4522 (a3a3a96) into master (0c24277) will increase coverage by 0.07%.
The diff coverage is 97.80%.

@@            Coverage Diff             @@
##           master    #4522      +/-   ##
==========================================
+ Coverage   82.87%   82.95%   +0.07%     
==========================================
  Files         331      332       +1     
  Lines       18197    18288      +91     
==========================================
+ Hits        15081    15171      +90     
- Misses       3116     3117       +1     
Impacted Files                               Coverage Δ
torch_geometric/nn/aggr/equilibrium.py       97.77% <97.77%> (ø)
torch_geometric/nn/aggr/__init__.py         100.00% <100.00%> (ø)
torch_geometric/nn/aggr/base.py              95.74% <0.00%> (+2.12%) ⬆️


@Padarn (Contributor, Author) commented Apr 24, 2022

This is still pretty heavy WIP (and I'm not even sure if it is working 100% yet).

@rusty1s do you think there is a good example graph classification problem to try and compare accuracy on? The paper uses the MOLPCBA benchmark (which I could not see in the repo yet, though we could add it). I currently just modified the MUTAG example; there is no performance improvement (or degradation).

I also plan to include the median example from the paper (it acts as a reasonable test that the method works).
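For context, "modifying the MUTAG example" essentially means swapping the usual sum readout for the new aggregation layer. A rough sketch of that substitution (the constructor arguments and hidden sizes are placeholders, not necessarily the final API; the import path reflects where the layer eventually landed):

import torch
from torch_geometric.nn import GINConv, global_add_pool
from torch_geometric.nn.aggr import EquilibriumAggregation  # layer added in this PR

class GIN(torch.nn.Module):
    def __init__(self, in_channels, hidden, num_classes, use_equilibrium=True):
        super().__init__()
        self.conv = GINConv(torch.nn.Sequential(
            torch.nn.Linear(in_channels, hidden), torch.nn.ReLU(),
            torch.nn.Linear(hidden, hidden)))
        # Placeholder arguments: input/output embedding size and potential hidden sizes.
        self.readout = (EquilibriumAggregation(hidden, hidden, num_layers=[256, 256])
                        if use_equilibrium else global_add_pool)
        self.lin = torch.nn.Linear(hidden, num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv(x, edge_index).relu()
        x = self.readout(x, batch)   # [num_graphs, hidden] graph-level embedding
        return self.lin(x)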

@Padarn added the feature label on Apr 24, 2022
@Padarn self-assigned this on Apr 24, 2022
Outdated review threads (resolved):

  • examples/mutag_gin_equib.py (2)
  • torch_geometric/nn/glob/__init__.py (1)
  • torch_geometric/nn/glob/equilibrium_aggregation.py (6)
@Padarn changed the title from "[WIP] EquilibriumAggregation global aggregation layer" to "EquilibriumAggregation global aggregation layer" on Apr 29, 2022
@Padarn (Contributor, Author) commented Apr 29, 2022

I've updated this and just added a simple example of the median readout learning. I didn't have time to do a larger example yet, sorry.

@Padarn (Contributor, Author) commented May 6, 2022

Hey @rusty1s, this is mostly ready for review (though the example I use is the median aggregation).

I'm currently trying to play around to get this to converge well, but I haven't had much luck. I'm not sure if it's because of the parameters I've used or details missing from the paper (like the alpha used in the inner optimization). But perhaps convergence is just slow: they run for 10 million steps, which I'm trying to do now, but it will take quite some time on my resources :-)

@Padarn (Contributor, Author) commented May 6, 2022

After ~1 million iterations there are no signs of convergence, so I'm guessing I've got something wrong. Maybe @FabianFuchsML has some input?

momentum = torch.zeros_like(y)
for _ in range(iterations):
    val = func(x, y, batch)
    grad = torch.autograd.grad(val, y, create_graph=True,
@Padarn (Contributor, Author) commented May 14, 2022 (on the code above):

The authors of the paper suggest adding an auxiliary loss on the grad here (grad.square().sum() / iterations, to be precise).

The least intrusive way I could think to do this would be to accumulate the grad in a property like:

grad = ...
self.aux_loss += grad.square().sum() / iterations

and then call backward on it by using a hook? I'm not entirely sure this makes sense.

Any suggestions?
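For reference, one way to do this without a backward hook is to reset and accumulate the penalty on the module inside the inner loop and let the training loop add it to the task loss. A sketch of that pattern (the class, the aux_loss attribute, and the inner step size are placeholders, not the code eventually merged):

import torch

class ReadoutWithAuxLoss(torch.nn.Module):  # hypothetical wrapper, for illustration only
    def __init__(self, energy: torch.nn.Module, iterations: int = 10, alpha: float = 0.1):
        super().__init__()
        self.energy, self.iterations, self.alpha = energy, iterations, alpha
        self.aux_loss = torch.zeros(())

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # `y` must require grad, e.g. y = torch.zeros(..., requires_grad=True).
        self.aux_loss = x.new_zeros(())  # reset the accumulated gradient penalty each call
        for _ in range(self.iterations):
            val = self.energy(x, y)
            grad = torch.autograd.grad(val, y, create_graph=True)[0]
            # Penalize non-zero gradients at the returned solution, per the paper's suggestion.
            self.aux_loss = self.aux_loss + grad.square().sum() / self.iterations
            y = y - self.alpha * grad
        return y

# In the training loop the stored penalty is simply added to the task loss:
#   out = model(x, y0)
#   loss = criterion(out, target) + model.aux_loss
#   loss.backward()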

@Padarn (Contributor, Author) commented May 28, 2022

I've moved this to the aggregation module but there are a few things missing which I'll need opinions on. Leaving comments in the code.

def energy(self, x: Tensor, y: Tensor, index: Optional[Tensor]):
    return self.potential(x, y, index) + self.reg(y)

def forward(self, x: Tensor, index: Optional[Tensor] = None, *,
@Padarn (Contributor, Author) commented (on the code above):

Just a note: this was all written before the Aggregation module, so it's not using reduce yet. I need to rethink some of the implementation to make use of it.

A Member replied:

No need to make use of reduce. The LSTMAggregation module doesn't make use of it either.

if ptr is not None:
    raise ValueError(f"{self.__class__} doesn't support `ptr`")

if dim_size is not None:
@Padarn (Contributor, Author) commented (on the code above):

I could technically support this, but I don't know what you'd expect the behavior to be: unlike sum/mean etc., we can't just assume the input is zero (and if we do, the output would just be random).

Thoughts?

A Member replied:

Please have a look at https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html. The dim_size argument ideally is identical to index_size = index.max() + 1 in your case. If passed, there is no need to compute it in the first place.

@Padarn (Contributor, Author) replied:

I didn't fully get that, sorry. If dim_size is passed:

  • if dim_size < index_size, I'd expect an error;
  • if dim_size > index_size, I'm not sure what to expect; I guess zeros for those entries?

A Member replied:

Yes, that is correct.
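For reference, this matches the dim_size semantics of torch_scatter.scatter: output rows beyond index.max() + 1 are simply zero-filled, while a smaller dim_size is invalid. A small example:

import torch
from torch_scatter import scatter

x = torch.tensor([[1.], [2.], [3.]])
index = torch.tensor([0, 0, 1])          # index.max() + 1 == 2 groups

out = scatter(x, index, dim=0, dim_size=4, reduce='sum')
print(out)
# tensor([[3.],    # group 0: 1 + 2
#         [3.],    # group 1: 3
#         [0.],    # groups 2 and 3 receive no inputs and are zero-filled
#         [0.]])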

@lightaime (Contributor) left a review:

Great work @Padarn! Added some minor comments.

test/nn/aggr/test_equilibrium.py (2 outdated review threads, resolved)
`EquilibriumAggregation` to learn to take the median of
a set of numbers
"""

A Contributor review comment (on the docstring above):

It would be great to add a few comments about the convergence of EquilibriumAggregation.

torch_geometric/nn/aggr/__init__.py (outdated review thread, resolved)
dist = np.random.choice([norm, gamma, uniform])
x = dist.sample((input_size, 1))
y = model(x)
loss = (y - x.median()).norm(2)
A Contributor review comment (on the code above):

It seems the loss is not normalized by the input size.

Suggested change
loss = (y - x.median()).norm(2)
loss = (y - x.median()).norm(2) / input_size
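Putting the excerpt together with the suggested normalization, the median-learning example is roughly the following (the choice of distributions, the EquilibriumAggregation constructor arguments, and the step count are assumed for illustration):

import numpy as np
import torch
from torch.distributions import Gamma, Normal, Uniform

from torch_geometric.nn.aggr import EquilibriumAggregation

input_size = 100
# Placeholder constructor arguments: 1-d inputs, 1-d output, potential hidden sizes.
model = EquilibriumAggregation(1, 1, num_layers=[16, 16])
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

norm, gamma, uniform = Normal(0.5, 1.0), Gamma(1.0, 1.0), Uniform(0.0, 1.0)

for step in range(1000):
    optimizer.zero_grad()
    dist = np.random.choice([norm, gamma, uniform])   # pick a random input distribution
    x = dist.sample((input_size, 1))                  # a set of scalars
    y = model(x)                                      # learned aggregate of the set
    loss = (y - x.median()).norm(2) / input_size      # normalized, per the suggestion above
    loss.backward()
    optimizer.step()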

torch_geometric/nn/aggr/equilibrium.py (2 outdated review threads, resolved)
@Padarn (Contributor, Author) commented Jun 1, 2022

Thanks @lightaime. I've got a busy week; I'll address your comments ASAP :-)

@Padarn (Contributor, Author) commented Jun 4, 2022

Thanks for the reviews. I've updated based on most comments. I still have an uncertainty here: #4522 (comment), but it may be for another PR.

@rusty1s (Member) commented Jun 6, 2022

Thanks @Padarn. Can you also resolve the merge conflicts? I will take a look soon, and address the unresolved comment as well.

@Padarn (Contributor, Author) commented Jun 25, 2022

Hey @rusty1s, any thoughts on this one? Happy to keep working on it or break it up, but it might be good not to leave it here unfinished.

@rusty1s (Member) commented Jun 25, 2022

Yes, now that #4779 is merged, let's try to integrate this next :)

@Padarn (Contributor, Author) commented Jun 25, 2022

Cool. I'll make another pass to see if there is anything I can clean up based on #4779.

@Padarn (Contributor, Author) commented Jul 26, 2022

Hey @rusty1s and @lightaime, what do you think about merging this one?

@lightaime (Contributor) left a review:

Thanks, LGTM!

@Padarn (Contributor, Author) commented Jul 27, 2022

Thanks @lightaime. I'm going to merge this one, but I will try to build some more example use cases before we release the new version, to make sure it's good.

@Padarn merged commit 333d3d3 into pyg-team:master on Jul 27, 2022
@Padarn deleted the padarn/optim-embedding branch on July 27, 2022 at 07:29