Add cutlass 3xTF32,DMMA based L2/cosine distance kernels for SM 8.0 or higher #939

Merged
merged 28 commits into from
Nov 16, 2022

@mdoijade mdoijade commented Oct 21, 2022

-- The CUTLASS-based 3xTF32 kernel for L2 expanded/cosine distance provides a 3.5x speedup for fp32 inputs over the existing pairwise distance kernel on Ampere or newer GPUs.
-- The CUTLASS-based DMMA implementation for L2 expanded/cosine distance provides a 2.6x speedup over the existing kernel based on the double-precision FMA pipeline.
-- Adds CUTLASS as a header-only dependency of RAFT.
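
For readers unfamiliar with the term: 3xTF32 generally refers to emulating one fp32 multiply with three TF32 multiplies. The block below is a minimal Python sketch of that arithmetic only, not the kernel added in this PR. The helper names (`to_tf32`, `split`, `mul_3xtf32`) are hypothetical, TF32 conversion is modeled by truncating the mantissa to 10 bits (hardware may round instead), and products are accumulated in Python doubles rather than fp32.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (double) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_tf32(x: float) -> float:
    """Model TF32: keep sign, 8 exponent bits and the top 10 mantissa bits.

    Assumption: truncation is used for simplicity; real conversion may round.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits &= 0xFFFFE000  # clear the low 13 of the 23 fp32 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def split(x: float):
    """Split an fp32 value into a TF32 'big' part and a TF32 'small' residual."""
    big = to_tf32(x)
    small = to_tf32(to_fp32(x - big))
    return big, small


def mul_1xtf32(a: float, b: float) -> float:
    """Single TF32 multiply: both operands lose their low mantissa bits."""
    return to_tf32(a) * to_tf32(b)


def mul_3xtf32(a: float, b: float) -> float:
    """Three TF32 multiplies; the small*small term is dropped."""
    a_big, a_small = split(a)
    b_big, b_small = split(b)
    return a_big * b_big + (a_big * b_small + a_small * b_big)


if __name__ == "__main__":
    a, b = to_fp32(1.2345678), to_fp32(7.6543210)
    exact = a * b
    print("1xTF32 abs error:", abs(mul_1xtf32(a, b) - exact))
    print("3xTF32 abs error:", abs(mul_3xtf32(a, b) - exact))
```

The two extra multiplies recover most of the precision lost when each operand is reduced to TF32, which is why this approach can accept fp32 inputs while still running on tensor cores.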

@cjnolet cjnolet added improvement Improvement / enhancement to an existing function non-breaking Non-breaking change labels Oct 27, 2022
@GPUtester
Contributor

Can one of the admins verify this patch?

Admins can comment ok to test to allow this one PR to run or add to allowlist to allow all future PRs from the same author to run.

cpp/cmake/thirdparty/get_cutlass.cmake (outdated, resolved)
Contributor

@robertmaynard robertmaynard left a comment

Judging by these lines, cutlass is now a hard requirement for using RAFT no matter the configuration, since the dependency is on the core raft library and not raft::distance.

I expect that is wrong, and this all needs to be moved to raft::distance.

cpp/CMakeLists.txt (outdated, resolved)
Contributor

@tfeher tfeher left a comment

Thanks @mdoijade for the PR! It is exciting to see these changes! The PR looks great overall. I have mostly requested documentation of where (and how) the CUTLASS code was adapted from. The goal is to make it easier to adapt to changes in future CUTLASS releases.

Please remove the [WIP] tag, add a PR title, and add a description: it would be great to have a few lines on how the normalization factors are handled using iterators and epilogue functions.
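
As background for the request above, here is a minimal Python sketch of the general idea of applying normalization factors in a GEMM epilogue: the matrix product supplies the dot products, and a per-element epilogue combines each one with precomputed row norms. It illustrates the arithmetic only and is not the CUTLASS code in this PR. All function names are hypothetical, the L2 result is assumed to be the squared distance, and cosine distance is assumed to be one minus cosine similarity.

```python
import math


def row_norms_sq(m):
    """Squared L2 norm of every row; these are the normalization factors."""
    return [sum(v * v for v in row) for row in m]


def gemm_xyT(x, y):
    """Plain x * y^T: entry (i, j) is the dot product of x[i] and y[j]."""
    return [[sum(a * b for a, b in zip(xr, yr)) for yr in y] for xr in x]


def l2_expanded_epilogue(dot, xn, yn):
    """Squared L2 via ||x||^2 + ||y||^2 - 2 x.y, clamped at zero."""
    return max(xn + yn - 2.0 * dot, 0.0)


def cosine_epilogue(dot, xn, yn):
    """Cosine distance, assumed here to be 1 - x.y / (||x|| * ||y||)."""
    return 1.0 - dot / math.sqrt(xn * yn)


def pairwise(x, y, epilogue):
    """Run the 'GEMM', then apply the epilogue to each output element."""
    dots = gemm_xyT(x, y)
    xn, yn = row_norms_sq(x), row_norms_sq(y)
    return [
        [epilogue(dots[i][j], xn[i], yn[j]) for j in range(len(y))]
        for i in range(len(x))
    ]


if __name__ == "__main__":
    x = [[1.0, 0.0], [0.0, 2.0]]
    y = [[1.0, 0.0], [3.0, 4.0]]
    print(pairwise(x, y, l2_expanded_epilogue))
    print(pairwise(x, y, cosine_epilogue))
```

Fusing the norm handling into the epilogue means the distance matrix is produced in one pass over the GEMM output instead of a separate kernel.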

cpp/include/raft/util/cudart_utils.hpp (outdated, resolved)
cpp/test/distance/dist_adj.cu (resolved)
cpp/include/raft/distance/detail/cosine.cuh (resolved)
cpp/include/raft/distance/detail/pairwise_distance_gemm.h (outdated, resolved)
cpp/cmake/thirdparty/get_cutlass.cmake (outdated, resolved)
@mdoijade mdoijade changed the title from "[WIP] Add cutlass 3xTF32,DMMA based L2/cosine distance kernels for SM 8.0 or higher" to "Add cutlass 3xTF32,DMMA based L2/cosine distance kernels for SM 8.0 or higher" on Nov 11, 2022
Contributor

@tfeher tfeher left a comment

Thanks Mahesh for addressing the issues. The C++ changes look good to me.

@@ -217,6 +218,7 @@ target_link_libraries(
   CUDA::cusolver${_ctk_static_suffix}
   CUDA::cusparse${_ctk_static_suffix}
   $<$<BOOL:${RAFT_ENABLE_thrust_DEPENDENCY}>:raft::Thrust>
+  nvidia::cutlass::cutlass
Contributor

This dependency needs to be on the raft-distance target and not the raft target

Contributor Author

@robertmaynard Several tests and core algorithms depend on the cosine/euclidean distance headers where we use cutlass, so I think it needs to be a dependency of the raft target. Without it I am seeing several build failures when those sources are built.
I've changed get_cutlass.cmake to use raft-exports instead of raft-distance-exports. Could this resolve the build issue in pylibraft?
I've submitted this change and am waiting to see if CI passes.

Contributor

@cjnolet Are you okay with cutlass being a hard requirement for raft?

Member

I'd prefer if we could make CUTLASS a dependency only of raft::distance (which pylibraft uses).

Member

@mdoijade any tests/benchmarks and downstream projects which depend on distances also specify and use the raft::distance target.

Contributor Author

I see. I've added a dependency on raft_distance wherever required and got the build working locally; please review whether it looks good now.
I've also enabled the cutlass path only up to CUDA 11.x.

@@ -422,7 +424,7 @@ if(RAFT_COMPILE_NN_LIBRARY)
     INTERFACE_POSITION_INDEPENDENT_CODE ON
   )

-  target_link_libraries(raft_nn_lib PUBLIC faiss::faiss raft::raft)
+  target_link_libraries(raft_nn_lib PUBLIC faiss::faiss raft::raft raft_distance)
Contributor

Why does raft_nn_lib need raft_distance?

Contributor Author

The raft_nn_lib sources use distance headers such as cosine.cuh/euclidean.cuh, so the build fails there with missing cutlass headers.

Contributor

If it is the raft_nn_lib sources that have this dependency, the target_link_libraries usage should be PRIVATE. Secondly, it should be a dependency on the cutlass target and not raft_distance.

This currently creates a hard dependency for downstream consumers of raft, which breaks requests like find_package(raft COMPONENTS nn) since the raft_distance target won't exist for those users.

Contributor Author

@mdoijade mdoijade Nov 16, 2022

So can I make it target_link_libraries(raft_nn_lib PUBLIC faiss::faiss raft::raft nvidia::cutlass::cutlass)?
And would I then need to modify get_cutlass.cmake to add rapids_export_package to raft-nn-lib-exports?

Contributor

You would want target_link_libraries(raft_nn_lib PUBLIC faiss::faiss raft::raft PRIVATE nvidia::cutlass::cutlass), and no other change would be required.

Member

I've made these changes and pushed them.

Contributor Author

Thanks @cjnolet and @robertmaynard for the quick help and fixes!

Contributor

@robertmaynard robertmaynard left a comment

CMake code looks good!

Member

@cjnolet cjnolet left a comment

LGTM!

@cjnolet
Member

cjnolet commented Nov 16, 2022

rerun tests

@cjnolet
Member

cjnolet commented Nov 16, 2022

@gpucibot merge

@rapids-bot rapids-bot bot merged commit 611abc7 into rapidsai:branch-22.12 Nov 16, 2022
rapids-bot bot pushed a commit that referenced this pull request Nov 18, 2022
PR #939 introduced CUTLASS dependency. When compiled in debug mode, this leads to the following error:

```
ptxas error   : Stack size for entry function '_ZN12raft_cutlass6KernelINS_...' cannot be statically determined
```

This would normally be just a warning, but we treat warnings as errors. This PR disables the warning in Debug mode.

Authors:
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1033
Labels
CMake cpp improvement Improvement / enhancement to an existing function non-breaking Non-breaking change
5 participants