-
Notifications
You must be signed in to change notification settings - Fork 222
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature: enabling oneDPL and sort primitive refactoring #3046
base: main
Are you sure you want to change the base?
feature: enabling oneDPL and sort primitive refactoring #3046
Conversation
Before merging, please remember to add this new dependency to the installation instructions in |
/intelci: run |
cpp/oneapi/dal/algo/decision_forest/backend/gpu/train_kernel_hist_impl_dpc.cpp
Outdated
Show resolved
Hide resolved
auto event = oneapi::dpl::experimental::kt::gpu::esimd::radix_sort_by_key<true, 8>( | ||
queue, | ||
val_in.get_mutable_data(), | ||
val_in.get_mutable_data() + val_in.get_count(), | ||
ind_in.get_mutable_data(), | ||
dpl::experimental::kt::kernel_param<256, 32>{}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Magic value 8, 256, 32 should be at least explained here.
Also, it might be beneficial not to hardcode them, but let the algorithmic kernel define them. I think different values can be chosen for different algorithms, or for different hardware platforms for better performance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At least write the comments that the parameters are chosen for PVC.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code looks good to me, I only have couple of minor comments.
@@ -19,7 +19,7 @@ | |||
#include "oneapi/dal/table/row_accessor.hpp" | |||
#include "oneapi/dal/detail/profiler.hpp" | |||
#include "oneapi/dal/algo/decision_forest/backend/gpu/train_helpers.hpp" | |||
|
|||
#include <iostream> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove this include.
return ::oneapi::dal::backend::primitives::engine_type::philox4x32x10; | ||
case df_engine_types::mt19937: | ||
return ::oneapi::dal::backend::primitives::engine_type::mt19937; | ||
default: throw std::invalid_argument("Unsupported engine type 2"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please change the message in the exception. "type 2" sounds strange.
@@ -133,6 +133,20 @@ enum class splitter_mode { | |||
random | |||
}; | |||
|
|||
/// Available engine methods for building trees | |||
enum class df_engine_types { | |||
/// mt2203 engine |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Those comments look excessive. Consider removing.
@@ -594,6 +611,17 @@ class descriptor : public detail::descriptor_base<Task> { | |||
return *this; | |||
} | |||
|
|||
/// Engine method for the random numbers generator used by the algorithm | |||
/// @remark default = df_engine_method::philox |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// @remark default = df_engine_method::philox | |
/// @remark default = df_engine_method::philox4x32x10 |
const auto model = train_result.get_model(); | ||
this->infer_base_checks(desc, data_test, this->get_homogen_table_id(), model, checker_list); | ||
} | ||
// DF_SPMD_CLS_TEST("df cls base check with default params and train weights") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please explain the reason of commenting out the test?
else { | ||
auto& device_engine = | ||
*(static_cast<gen_mt2203*>(engine_.get_device_engine_base_ptr().get()))->get(); | ||
return oneapi::mkl::rng::generate(distr, device_engine, count, dst, deps); | ||
} | ||
|
||
// default: throw std::runtime_error("Unsupported engine type in generate_rng"); | ||
//} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add error handling
else { | |
auto& device_engine = | |
*(static_cast<gen_mt2203*>(engine_.get_device_engine_base_ptr().get()))->get(); | |
return oneapi::mkl::rng::generate(distr, device_engine, count, dst, deps); | |
} | |
// default: throw std::runtime_error("Unsupported engine type in generate_rng"); | |
//} | |
else if (engine_type == engine_type::mt2203) { | |
auto& device_engine = | |
*(static_cast<gen_mt2203*>(engine_.get_device_engine_base_ptr().get()))->get(); | |
return oneapi::mkl::rng::generate(distr, device_engine, count, dst, deps); | |
} else { | |
throw std::runtime_error("Unsupported engine type in generate_rng"); | |
} |
@@ -19,6 +19,15 @@ | |||
#include "oneapi/dal/detail/profiler.hpp" | |||
|
|||
#include <sycl/ext/oneapi/experimental/builtins.hpp> | |||
#include <sycl/ext/oneapi/experimental/builtins.hpp> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove this line.
#include <sycl/ext/oneapi/experimental/builtins.hpp> |
auto event = oneapi::dpl::experimental::kt::gpu::esimd::radix_sort_by_key<true, 8>( | ||
queue, | ||
val_in.get_mutable_data(), | ||
val_in.get_mutable_data() + val_in.get_count(), | ||
ind_in.get_mutable_data(), | ||
dpl::experimental::kt::kernel_param<256, 32>{}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
At least write the comments that the parameters are chosen for PVC.
INSTANTIATE_RADIX_SORT(std::int32_t) | ||
INSTANTIATE_RADIX_SORT(std::uint32_t) | ||
INSTANTIATE_RADIX_SORT(std::int64_t) | ||
INSTANTIATE_RADIX_SORT(std::uint64_t) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need all those instantiations?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have these types for non dpl(for no pvc devices) radix sort, so its aligned
Description:
Feature: enabling oneDPL and sort primitive refactoring
Summary:
This PR introduces oneDPL enabling and radix sort replacement.
PR completeness and readability
Testing
Performance