Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose linalg::dot in public API #968

Merged
merged 14 commits into from
Nov 10, 2022
Merged

Conversation

benfred
Copy link
Member

@benfred benfred commented Oct 31, 2022

Closes #805

@benfred benfred requested review from a team as code owners October 31, 2022 23:47
@benfred benfred added non-breaking Non-breaking change enhancement New feature or request and removed cpp CMake labels Oct 31, 2022
@benfred benfred added improvement Improvement / enhancement to an existing function and removed enhancement New feature or request labels Oct 31, 2022
Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks again for this PR! We are going to want to think about whether these (axpy, dot, etc...) should be accepting general mdspan or whether we should be constraining them to be vectors up front.

It would also be nice to see the current vector factory functions made more flexible to enable strided layouts rather than adding new functions.

These types of examples (using the existing device vector factory functions to create a strided vector) would be great to have in the quick start as well.

*/
template <typename ElementType, typename IndexType = int, typename LayoutPolicy = layout_stride>
auto make_strided_device_vector_view(ElementType* ptr, IndexType n, IndexType stride)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rather than adding another factory function for a strided vector, why not just allow a strided layout to be configured in the make_device_vector_view and make_host_vector_view?

Right now the make_*_vector_view automatically configures a row-major layout but the layout should really be configurable (and potentially strided, or col major if desired).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated make_device_vector_view to allow strided input here - let me know what you think.

cpp/include/raft/core/device_mdspan.hpp Outdated Show resolved Hide resolved
template <typename InputType1,
typename InputType2,
typename OutputType,
typename = raft::enable_if_input_device_mdspan<InputType1>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I brought this up with the axpy as well, but it seems weird to accept a general mdspan for this when what we are really looking for is a 1d vector. Do you see value in accepting a matrix or dense tensor with 3+ dimensional extents? If not, we should just accept the vector_view directly (which is aliased to be any mdspan with 1d extents.

If we accepted a device_vector_view directly, we wouldn't need the enable_if statements at all. I think we should go ahead and do the same for the axpy to keep things consistent.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed - made the changes here so that both axpy and dot take device_vector_view's


// Right now the inputs and outputs need to all have the same value_type (float/double etc).
// Try to output a meaningful compiler error if mismatched types are passed here.
// Note: In the future we could remove this restriction using the cublasDotEx function
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just go ahead and wrap the cublasEx functions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created an issue so we can discuss further #977 .

Reading the docs a little closer, and it looks like even w/ cublasDotEx having different dtypes for the input/outputs isn't currently supported: https://docs.nvidia.com/cuda/cublas/index.html#cublas-dotEx - so it won't have much value for the dot API (though I could see a use for it myself with the gemm api w/ implicit and the mixed precision work I was talking about last week)

Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes are looking great! Remaining things are very minor.

* @return raft::device_vector_view
*/
template <typename ElementType,
typename IndexType = std::uint32_t,
typename LayoutPolicy = layout_c_contiguous>
auto make_device_vector_view(ElementType* ptr, IndexType n)
auto make_device_vector_view(ElementType* ptr, IndexType n, IndexType stride = 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a little awkward. We accept a layout policy as a template argument, but then we also accept a function argument for a stride which essentially overrides the layout from the template.

Would it be achieving this same goal if a user were to just set a strided layout on the template argument directly? Perhaps we could provide a factory function to make said strided layout and provide the user with something like a statically sized object (eg. std::array) to set the strides for each dimension?

An of course, this is one of those things (the new strided factory function) that I think should have a usage example in the doxygen and perhaps even a subsection section in the mdspan tutorial markdown of the docs.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm understanding you correctly - you're thinking we can just pass the layout mapping to the make_device_vector_view function directly , and add a new factory function for creating this layout mapping?

I took a stab at that in the last commit - unfortunately, I couldn't get a single make_device_vector_view function to compile successfully with being passed both a IndexType with the number of elements and the Mapping with the strided layout (was getting compile errors in various other raft functions that I hadn't updated). However, I could get it to work with adding an overload - which is whats in the last commit. Do you have any suggestions on how to clean this up =) ?

I'll add something to the tutorial / docs once we're happy with the API -

cpp/include/raft/linalg/axpy.cuh Outdated Show resolved Hide resolved
cpp/include/raft/linalg/axpy.cuh Outdated Show resolved Hide resolved
cpp/include/raft/linalg/dot.cuh Outdated Show resolved Hide resolved
cpp/include/raft/linalg/dot.cuh Show resolved Hide resolved
* Remove default types,
* Try to fix up factory functions for creating strided vector views
* Add dot funcction that takes host scalar / host_scalar_view
void dot(const raft::handle_t& handle,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy1> x,
raft::device_vector_view<const ElementType, IndexType, LayoutPolicy2> y,
ElementType* out)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think for the host output, we probably should drop this overload. Sorry for being confusing here. I think it makes more sense to accept a host scalar by value for functions like axpy where the scalar is an input. For output on host, I think we should stick to the mdspan scalar wrappers.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed in latest commit

cpp/test/linalg/axpy.cu Outdated Show resolved Hide resolved
cpp/test/linalg/dot.cu Outdated Show resolved Hide resolved
cpp/include/raft/core/device_mdspan.hpp Outdated Show resolved Hide resolved
Copy link
Member

@cjnolet cjnolet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks again @benfred!

@cjnolet
Copy link
Member

cjnolet commented Nov 9, 2022

@gpucibot merge

@rapids-bot rapids-bot bot merged commit 7176d94 into rapidsai:branch-22.12 Nov 10, 2022
@benfred benfred deleted the linalg_dot branch November 10, 2022 06:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CMake cpp improvement Improvement / enhancement to an existing function non-breaking Non-breaking change
Projects
Development

Successfully merging this pull request may close these issues.

[FEA] Expose raft::linalg::dot through public API
2 participants