DGI FullBatch methods fix #1415
Conversation
Code Climate has analyzed commit 3ec998a and detected 0 issues on this pull request.
I think I agree that this is an appropriate fix for now, it seems like a reasonable generalisation of our current assumption, and doesn't require infecting too much of the rest of the library.
I guess the fully general form would be a method like `Generator.output_shape_for_batch(inputs)` that is used to query the appropriate output shape for each batch. Fortunately, we can add this in a backwards-compatible way, by adding the method to `Generator` with a default implementation that just does this `min(...)` call.
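A minimal sketch of what such a backwards-compatible default might look like. The class skeleton and the value `num_batch_dims = 2` are assumptions for illustration, and `output_shape_for_batch` is the reviewer's suggested name, not existing StellarGraph API:

```python
import numpy as np

class Generator:
    # Sketch only: the real StellarGraph Generator base class has more to it.
    # num_batch_dims = 2 is assumed, matching the "1 x num_outputs" leading
    # batch shape of full-batch generators.
    num_batch_dims = 2

    def output_shape_for_batch(self, inputs):
        # Proposed default: take the smallest leading "batch shape" among
        # the inputs as the one that matches the model's outputs.
        return min(inp.shape[: self.num_batch_dims] for inp in inputs)

# Full-batch style inputs: (features, target_indices, adjacency).
features = np.zeros((1, 100, 32))  # 1 x num_nodes x feature_size
indices = np.zeros((1, 5))         # 1 x num_outputs
adj = np.zeros((1, 100, 100))      # 1 x num_nodes x num_nodes

print(Generator().output_shape_for_batch([features, indices, adj]))  # (1, 5)
```

Subclasses with unusual input layouts could then override the method instead of relying on the `min(...)` heuristic.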
My only concern is that it's a bit subtle and could do with a little more explanation.
stellargraph/mapper/corrupted.py (outdated)

```diff
@@ -171,7 +171,7 @@ def corrupt_group(group_idx, group):
         ]

         # create the appropriate labels
-        batch_size = inputs[0].shape[: self.num_batch_dims]
+        batch_size = min(inp.shape[: self.num_batch_dims] for inp in inputs)
```
Hm, so this works because it's picking up the output indices tensor with shape `(1, num_outputs)`, right?
Maybe we could expand the comment to discuss this briefly? E.g.
```python
# create the appropriate labels: this needs to match the number of outputs in the batch.
# For full batch methods, this is the output indices tensor (batch shape = 1 × num_out),
# not the features tensor (1 × num_nodes). Rather than hard-code, we assume that the
# smallest "batch shape" is the right one.
```
We could potentially also rename this like the following, to hammer home the point:
```diff
-batch_size = min(inp.shape[: self.num_batch_dims] for inp in inputs)
+output_batch_shape = min(inp.shape[: self.num_batch_dims] for inp in inputs)
```
> Hm, so this works because it's picking up the output indices tensor with shape `(1, num_outputs)`, right?

Yep, that's right, and for the SAGE algorithms all inputs have the same batch shape anyway, so we're all good (for now...)

> Maybe we could expand the comment to discuss this briefly?

Yeah, I think that's a good point; it's a strange and subtle bit of code that will probably trip someone up in the future.
This PR fixes the bug that caused DGI to break on full-batch methods when not all nodes were specified. The bug was caused by the fact that full-batch generators pass in `(features, target_indices, adj)`, and DGI infers the batch shape from the first input, `features` in this case, causing the wrong batch shape to be inferred for full batch methods. There are two fixes (that I can think of); this PR makes the second change, inferring the batch shape as the minimum of the leading `num_batch_dims` dimensions across all inputs, as a minimal fix (minimal code and test changes) to get things working quickly, but it might be too hacky. Thoughts?
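To make the bug concrete, here is a small self-contained illustration of the old and fixed behaviour. The shapes (2708 nodes, 1433 features, 4 requested outputs) are illustrative numbers chosen for this sketch, not values from the PR:

```python
import numpy as np

# A full-batch generator yields (features, target_indices, adj); only
# target_indices has the batch shape that matches the model's outputs.
features = np.zeros((1, 2708, 1433))  # 1 x num_nodes x feature_size
target_indices = np.zeros((1, 4))     # 1 x num_outputs
adj = np.zeros((1, 2708, 2708))       # 1 x num_nodes x num_nodes

inputs = [features, target_indices, adj]
num_batch_dims = 2

# Old behaviour: batch shape inferred from the first input (features),
# which is wrong whenever fewer nodes than num_nodes are requested.
old_shape = inputs[0].shape[:num_batch_dims]                   # (1, 2708)

# Fixed behaviour: the smallest batch shape across all inputs, i.e. the
# target-indices tensor's (1, num_outputs).
new_shape = min(inp.shape[:num_batch_dims] for inp in inputs)  # (1, 4)
```

Note that `min` here relies on Python's lexicographic tuple comparison: `(1, 4)` sorts before `(1, 2708)`, so the output-indices batch shape wins.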
See #1349