Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pyg-lib integration in HGTConv #6178

Merged
merged 343 commits into from
Mar 21, 2023
Merged

pyg-lib integration in HGTConv #6178

merged 343 commits into from
Mar 21, 2023

Conversation

puririshi98
Copy link
Contributor

@puririshi98 puririshi98 commented Dec 8, 2022

ready🎉

data = FakeHeteroDataset(num_node_types=16, num_edge_types=128).data.to('cuda')
master branch iterative algorithm fwd pass time: 0.0014522361755371093
average fwd pass time(pyg-lib w/ torch<1.14): 0.0004720401763916016
average fwd pass time(pyg-lib w/ torch>=1.14): 0.00046291828155517576

3x speedup. As part of my blog I will do a more extensive benchmark to see how runtime for for-loop, segment_matmul, and group_matmul backends of HGT compare agaisnt eachother

smaller heterograph:

data = FakeHeteroDataset(num_node_types=8, num_edge_types=16).data.to('cuda')
average fwd pass time: 0.00020465850830078124
average fwd pass time(pyg-lib w/ torch<1.14): 0.00013236522674560546
average fwd pass time(pyg-lib w/ torch>=1.14): 0.00013113975524902343

1.6x speedup

scaling:
image

For interactive graph: https://github.com/puririshi98/rgcn_pyg_lib_forward_bench/blob/main/3d_plot.py

torch_geometric/nn/conv/hgt_conv.py Outdated Show resolved Hide resolved
torch_geometric/nn/conv/hgt_conv.py Outdated Show resolved Hide resolved
torch_geometric/nn/conv/hgt_conv.py Outdated Show resolved Hide resolved
torch_geometric/nn/conv/hgt_conv.py Outdated Show resolved Hide resolved
@rusty1s rusty1s changed the title hgtconv pyg-lib grouped matmul integration pyg-lib integration in HGTConv Dec 8, 2022
rusty1s added a commit to pyg-team/pyg-lib that referenced this pull request Dec 9, 2022
for pyg-team/pytorch_geometric#6178

Co-authored-by: Rishi Puri <riship@riship-mlt.client.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
@puririshi98 puririshi98 changed the title pyg-lib integration in HGTConv Draft: pyg-lib integration in HGTConv Dec 9, 2022
@puririshi98
Copy link
Contributor Author

still need to test and benchmark it but done w/ an initial draft

Copy link
Member

@wsad1 wsad1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @puririshi98 for all the work.
@rusty1s appreciate a quick look.

@puririshi98
Copy link
Contributor Author

basic correctness passes: https://github.com/puririshi98/rgcn_pyg_lib_forward_bench/blob/main/hgt_correctness.py
@rusty1s merge when u can :)

@rusty1s rusty1s merged commit 805d2d8 into master Mar 21, 2023
@rusty1s rusty1s deleted the hgtconv_pyglib branch March 21, 2023 15:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants