Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Triplet sampling in NeighborLoader #6004

Merged
merged 14 commits into from Nov 23, 2022
Merged

Triplet sampling in NeighborLoader #6004

merged 14 commits into from Nov 23, 2022

Conversation

rusty1s
Copy link
Member

@rusty1s rusty1s commented Nov 18, 2022

This PR implements triplet sampling in link-level tasks, where we want only want to sample a negative destination node.
This PR got quite large, especially because I stumbled upon some weird things in the current link-level pipeline on temporal graphs, e.g., we don't want to sample nodes that do not yet exist for a given edge label timestamp.

When neg_sampling="triplet", then the returned Data object holds:

  • a src_index which denotes the index from left-hand-side embeddings
  • a dst_pos_index which denotes the index from positive right-hand-side embeddings
  • a dst_neg_index which denotes the index from negative right-hand-side embeddings

Overall, it introduces

  • NegativeSamplingConfig with strategy=binary/triplet
  • Deprecates neg_sampling_ratio in favor of neg_sampling
  • Introduces a new method called neg_sampling that respects the timestamp of nodes while sampling
  • Refactors and extends sample_from_edges => Overall, this function got quite large, and I will consider separating them into multiple functions in the future.

@codecov
Copy link

codecov bot commented Nov 18, 2022

Codecov Report

Merging #6004 (d344788) into master (d87c1ba) will decrease coverage by 0.06%.
The diff coverage is 55.68%.

❗ Current head d344788 differs from pull request most recent head 247e16d. Consider uploading reports for the commit 247e16d to get more accurate results

@@            Coverage Diff             @@
##           master    #6004      +/-   ##
==========================================
- Coverage   84.73%   84.66%   -0.07%     
==========================================
  Files         361      361              
  Lines       20215    20185      -30     
==========================================
- Hits        17129    17090      -39     
- Misses       3086     3095       +9     
Impacted Files Coverage Δ
torch_geometric/sampler/utils.py 85.71% <ø> (-0.44%) ⬇️
torch_geometric/sampler/neighbor_sampler.py 67.57% <48.80%> (-10.32%) ⬇️
torch_geometric/loader/link_loader.py 77.02% <57.14%> (-14.36%) ⬇️
torch_geometric/sampler/base.py 90.74% <90.90%> (-0.17%) ⬇️
torch_geometric/loader/link_neighbor_loader.py 100.00% <100.00%> (ø)
torch_geometric/nn/conv/pna_conv.py 92.77% <0.00%> (ø)
torch_geometric/utils/smiles.py 8.88% <0.00%> (+0.24%) ⬆️
torch_geometric/nn/aggr/basic.py 97.77% <0.00%> (+1.94%) ⬆️
torch_geometric/nn/aggr/multi.py 100.00% <0.00%> (+2.19%) ⬆️
torch_geometric/nn/aggr/fused.py 97.77% <0.00%> (+2.24%) ⬆️
... and 1 more

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more

Copy link
Member

@wsad1 wsad1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some initial comments.

torch_geometric/sampler/base.py Outdated Show resolved Hide resolved
torch_geometric/loader/link_loader.py Show resolved Hide resolved
torch_geometric/sampler/neighbor_sampler.py Outdated Show resolved Hide resolved
torch_geometric/sampler/neighbor_sampler.py Show resolved Hide resolved
torch_geometric/sampler/neighbor_sampler.py Show resolved Hide resolved
torch_geometric/loader/link_loader.py Outdated Show resolved Hide resolved
examples/graph_sage_unsup_ppi.py Show resolved Hide resolved
test/data/test_lightning_datamodule.py Show resolved Hide resolved
@rusty1s
Copy link
Member Author

rusty1s commented Nov 21, 2022

Incorporated your review. Thank you!

Copy link
Member

@wsad1 wsad1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some final comments.

torch_geometric/sampler/neighbor_sampler.py Outdated Show resolved Hide resolved
torch_geometric/loader/link_neighbor_loader.py Outdated Show resolved Hide resolved
test/loader/test_link_neighbor_loader.py Outdated Show resolved Hide resolved
assert int(batch['paper'].batch.max()) + 1 == 32 + 16
# Check if each seed edge has a different source and dstination node:
assert batch['paper'].num_nodes >= 2 * (32 + 16)
assert int(batch['paper'].batch.max()) + 1 == 32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why isn't this 64? Won't the negative samples also have a disjoint batch?

Copy link
Member Author

@rusty1s rusty1s Nov 23, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally think it makes more sense to group the negatives and positives to its own batch, in particular because they now share time information.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I see. But it seems weird that we now have more than 1 seed edge with the same batch number. Even though their computation graphs are in fact disjoint. But we can debate this later.

@wsad1 wsad1 self-requested a review November 22, 2022 08:55
@rusty1s
Copy link
Member Author

rusty1s commented Nov 23, 2022

@wsad1 Updated based on your comments.

Copy link
Member

@wsad1 wsad1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great. Thanks.

assert int(batch['paper'].batch.max()) + 1 == 32 + 16
# Check if each seed edge has a different source and dstination node:
assert batch['paper'].num_nodes >= 2 * (32 + 16)
assert int(batch['paper'].batch.max()) + 1 == 32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I see. But it seems weird that we now have more than 1 seed edge with the same batch number. Even though their computation graphs are in fact disjoint. But we can debate this later.

@rusty1s rusty1s enabled auto-merge (squash) November 23, 2022 16:21
@rusty1s rusty1s merged commit 7991c1d into master Nov 23, 2022
@rusty1s rusty1s deleted the triplet_sampling branch November 23, 2022 16:34
JakubPietrakIntel pushed a commit to JakubPietrakIntel/pytorch_geometric that referenced this pull request Nov 25, 2022
This PR implements triplet sampling in link-level tasks, where we want
only want to sample a negative destination node.
This PR got quite large, especially because I stumbled upon some weird
things in the current link-level pipeline on temporal graphs, e.g., we
don't want to sample nodes that do not yet exist for a given edge label
timestamp.

When `neg_sampling="triplet"`, then the returned `Data` object holds:
* a `src_index` which denotes the index from left-hand-side embeddings
* a `dst_pos_index` which denotes the index from positive
right-hand-side embeddings
* a `dst_neg_index` which denotes the index from negative
right-hand-side embeddings

Overall, it introduces
* `NegativeSamplingConfig` with `strategy=binary/triplet`
* Deprecates `neg_sampling_ratio` in favor of `neg_sampling`
* Introduces a new method called `neg_sampling` that respects the
timestamp of nodes while sampling
* Refactors and extends `sample_from_edges` => Overall, this function
got quite large, and I will consider separating them into multiple
functions in the future.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants