
DGI FullBatch methods fix #1415

Merged
merged 3 commits into develop from bugfix/1349-dgi-fullbatch-partial-nodes on May 1, 2020

Conversation

kieranricardo (Contributor)

This PR fixes the bug that caused DGI to break on fullbatch when not all nodes were specified. The bug was caused by the fact that fullbatch generators pass in:

(features, target_indices, adj)

and DGI infers the batch shape from the first input (features in this case), causing the wrong batch shape to be inferred for full-batch methods. There are two fixes (that I can think of):

  1. Specify which input to infer the batch shape from for each generator along with num_batch_dims
  2. Take the minimum batch shape along all inputs

This PR makes the second change as a minimal fix (minimal code and test changes) to get things working quickly, but it might be too hacky. Thoughts?

See #1349
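To make the bug and the fix concrete, here is a minimal sketch (not the library's actual code; the shapes, the `num_batch_dims = 2` value, and the Cora-like sizes are illustrative assumptions) of how inferring the batch shape from the first input goes wrong for full-batch inputs, and how taking the minimum over all inputs recovers the right one:

```python
import numpy as np

# Illustrative full-batch sizes (assumed, roughly Cora-shaped)
num_nodes, num_out, num_features = 2708, 140, 1433
num_batch_dims = 2  # assumed: full-batch tensors have a leading batch dim of 1

# The three tensors a full-batch generator passes in
features = np.zeros((1, num_nodes, num_features))
target_indices = np.zeros((1, num_out), dtype=int)
adj = np.zeros((1, num_nodes, num_nodes))
inputs = [features, target_indices, adj]

# Old behaviour: infer from the first input -> wrong when num_out < num_nodes
old_batch_shape = inputs[0].shape[:num_batch_dims]

# Fixed behaviour: the smallest batch shape across all inputs is the
# target-indices shape, which matches the number of outputs
new_batch_shape = min(inp.shape[:num_batch_dims] for inp in inputs)

print(old_batch_shape)  # (1, 2708)
print(new_batch_shape)  # (1, 140)
```

Note that `min` over shape tuples compares lexicographically, so with a shared leading batch dim of 1 it picks the input with the fewest entries in the second dimension.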


codeclimate bot commented Apr 30, 2020

Code Climate has analyzed commit 3ec998a and detected 0 issues on this pull request.


@huonw (Member) left a comment

I think I agree that this is an appropriate fix for now, it seems like a reasonable generalisation of our current assumption, and doesn't require infecting too much of the rest of the library.

I guess the fully general form would be a method like Generator.output_shape_for_batch(inputs) that is used to query the appropriate output shape for each batch. Fortunately, we can add this in a backwards compatible way by adding the method to Generator with a default implementation that just does this min(...) call.

My only concern is that it's a bit subtle and could do with a little more explanation.
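The backwards-compatible generalisation suggested above could look roughly like the following. This is a hypothetical sketch, not StellarGraph's API: the method name `output_shape_for_batch` comes from the comment, while the `num_batch_dims` attribute and the example shapes are assumptions.

```python
import numpy as np

class Generator:
    """Hypothetical base class sketching the suggested default method."""

    num_batch_dims = 2  # assumed attribute, as referenced in the PR's fix

    def output_shape_for_batch(self, inputs):
        # Default implementation: the min(...) call from this PR, so every
        # existing generator keeps working without overriding this method.
        # Generators with unusual input layouts could override it instead.
        return min(inp.shape[: self.num_batch_dims] for inp in inputs)

# Full-batch style query: features, target indices, adjacency
inputs = [
    np.zeros((1, 2708, 1433)),
    np.zeros((1, 140), dtype=int),
    np.zeros((1, 2708, 2708)),
]
shape = Generator().output_shape_for_batch(inputs)
print(shape)  # (1, 140)
```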

@@ -171,7 +171,7 @@ def corrupt_group(group_idx, group):
         ]

         # create the appropriate labels
-        batch_size = inputs[0].shape[: self.num_batch_dims]
+        batch_size = min(inp.shape[: self.num_batch_dims] for inp in inputs)
@huonw (Member) commented:

Hm, so this works because it's picking up the output indices tensor with shape (1, num_outputs), right?

Maybe we could expand the comment to discuss this briefly? E.g.

# create the appropriate labels: this needs to match the number of outputs in the batch. 
# For full batch methods, this is the output indices tensor (batch shape = 1 × num_out), 
# not the features tensor (1 × num_nodes). Rather than hard-code, we assume that the
# smallest "batch shape" is the right one.

We could potentially also rename this like the following, to hammer home the point:

Suggested change
-        batch_size = min(inp.shape[: self.num_batch_dims] for inp in inputs)
+        output_batch_shape = min(inp.shape[: self.num_batch_dims] for inp in inputs)

@kieranricardo (Contributor, Author) replied:

Hm, so this works because it's picking up the output indices tensor with shape (1, num_outputs)

Yep that's right, and for the SAGE algos all inputs have the same batch shape anyway, so we're all good (for now...)
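A quick sanity check of that claim (an illustrative sketch with assumed shapes, not the library's code): when every input shares the same leading batch shape, as for SAGE-style minibatch generators, taking the minimum changes nothing.

```python
import numpy as np

num_batch_dims = 1  # assumed: minibatch tensors have one leading batch dim
# Two inputs with the same batch size but different trailing shapes
inputs = [np.zeros((32, 10, 4)), np.zeros((32, 5, 4))]

batch_shape = min(inp.shape[:num_batch_dims] for inp in inputs)
print(batch_shape)  # (32,) -- identical to inputs[0].shape[:num_batch_dims]
```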

Maybe we could expand the comment to discuss this briefly?

Yeah I think that's a good point, it's a strange and subtle bit of code that will probably trip someone up in the future.

@kieranricardo kieranricardo merged commit 3f7b0c3 into develop May 1, 2020
@kieranricardo kieranricardo deleted the bugfix/1349-dgi-fullbatch-partial-nodes branch May 1, 2020 00:21