From d7ba565cba420068b35dbb667b100f21ad60597f Mon Sep 17 00:00:00 2001 From: Julien Esseiva Date: Sat, 2 Dec 2023 05:12:52 -0800 Subject: [PATCH] Support sorting tracks by particle types (#1044) --- src/celeritas/Types.cc | 1 + src/celeritas/Types.hh | 1 + src/celeritas/track/SortTracksAction.cc | 44 +++++++++--- src/celeritas/track/detail/TrackSortUtils.cu | 67 ++++++++++++------- .../sys/KernelParamCalculator.device.hh | 23 +++++++ 5 files changed, 99 insertions(+), 37 deletions(-) diff --git a/src/celeritas/Types.cc b/src/celeritas/Types.cc index 52c38eac1a..c6673a465b 100644 --- a/src/celeritas/Types.cc +++ b/src/celeritas/Types.cc @@ -61,6 +61,7 @@ char const* to_cstring(TrackOrder value) "sort_along_step_action", "sort_step_limit_action", "sort_action", + "sort_particle_type", }; return to_cstring_impl(value); } diff --git a/src/celeritas/Types.hh b/src/celeritas/Types.hh index 2de7a3d352..4a27dc5bce 100644 --- a/src/celeritas/Types.hh +++ b/src/celeritas/Types.hh @@ -144,6 +144,7 @@ enum class TrackOrder sort_along_step_action, //!< Sort only by the along-step action id sort_step_limit_action, //!< Sort only by the step limit action id sort_action, //!< Sort by along-step id, then post-step ID + sort_particle_type, //!< Sort by particle type size_ }; diff --git a/src/celeritas/track/SortTracksAction.cc b/src/celeritas/track/SortTracksAction.cc index 87b01d3d4c..5451ec233d 100644 --- a/src/celeritas/track/SortTracksAction.cc +++ b/src/celeritas/track/SortTracksAction.cc @@ -34,10 +34,21 @@ bool is_sort_trackorder(TrackOrder to) TrackOrder::sort_step_limit_action, TrackOrder::sort_along_step_action, TrackOrder::sort_action, + TrackOrder::sort_particle_type, }; return std::find(std::begin(allowed), std::end(allowed), to) != std::end(allowed); } + +/*! + * Checks whether the TrackOrder sort tracks using an ActionId. + */ +inline bool is_sort_by_action(TrackOrder to) +{ + return to == TrackOrder::sort_along_step_action + || to == TrackOrder::sort_step_limit_action + || to == TrackOrder::sort_action; +} //---------------------------------------------------------------------------// } // namespace @@ -66,6 +77,9 @@ SortTracksAction::SortTracksAction(ActionId id, TrackOrder track_order) // Sort *before* post-step action, i.e. *after* pre-post and // along-step return ActionOrder::sort_pre_post; + case TrackOrder::sort_particle_type: + // Sorth at the beginning of the step + return ActionOrder::sort_start; default: CELER_ASSERT_UNREACHABLE(); } @@ -86,6 +100,8 @@ std::string SortTracksAction::label() const return "sort-tracks-along-step"; case TrackOrder::sort_step_limit_action: return "sort-tracks-post-step"; + case TrackOrder::sort_particle_type: + return "sort-tracks-start"; default: CELER_ASSERT_UNREACHABLE(); } @@ -97,11 +113,14 @@ std::string SortTracksAction::label() const void SortTracksAction::execute(CoreParams const&, CoreStateHost& state) const { detail::sort_tracks(state.ref(), track_order_); - detail::count_tracks_per_action( - state.ref(), - state.action_thread_offsets()[AllItems{}], - state.action_thread_offsets(), - track_order_); + if (is_sort_by_action(track_order_)) + { + detail::count_tracks_per_action( + state.ref(), + state.action_thread_offsets()[AllItems{}], + state.action_thread_offsets(), + track_order_); + } } //---------------------------------------------------------------------------// @@ -111,12 +130,15 @@ void SortTracksAction::execute(CoreParams const&, CoreStateHost& state) const void SortTracksAction::execute(CoreParams const&, CoreStateDevice& state) const { detail::sort_tracks(state.ref(), track_order_); - detail::count_tracks_per_action( - state.ref(), - state.native_action_thread_offsets()[AllItems{}], - state.action_thread_offsets(), - track_order_); + if (is_sort_by_action(track_order_)) + { + detail::count_tracks_per_action( + state.ref(), + state.native_action_thread_offsets()[AllItems{}], + state.action_thread_offsets(), + track_order_); + } } //---------------------------------------------------------------------------// diff --git a/src/celeritas/track/detail/TrackSortUtils.cu b/src/celeritas/track/detail/TrackSortUtils.cu index ff56cb2f7a..32e95c02e2 100644 --- a/src/celeritas/track/detail/TrackSortUtils.cu +++ b/src/celeritas/track/detail/TrackSortUtils.cu @@ -60,36 +60,38 @@ void partition_impl(TrackSlots const& track_slots, F&& func, StreamId stream_id) //---------------------------------------------------------------------------// +template __global__ void -reorder_actions_kernel(ObserverPtr track_slots, - ObserverPtr actions, - ObserverPtr out_actions, - size_type size) +reorder_ids_kernel(ObserverPtr track_slots, + ObserverPtr ids, + ObserverPtr ids_out, + size_type size) { if (ThreadId tid = celeritas::KernelParamCalculator::thread_id(); tid < size) { - out_actions.get()[tid.get()] - = actions.get()[track_slots.get()[tid.get()]].unchecked_get(); + ids_out.get()[tid.get()] + = ids.get()[track_slots.get()[tid.get()]].unchecked_get(); } } +template void sort_impl(TrackSlots const& track_slots, - ObserverPtr actions, + ObserverPtr ids, StreamId stream_id) { - DeviceVector reordered_actions(track_slots.size(), - stream_id); - CELER_LAUNCH_KERNEL(reorder_actions, - track_slots.size(), - celeritas::device().stream(stream_id).get(), - track_slots.data(), - actions, - make_observer(reordered_actions.data()), - track_slots.size()); + DeviceVector reordered_ids(track_slots.size(), stream_id); + CELER_LAUNCH_KERNEL_TEMPLATE_1(reorder_ids, + Id, + track_slots.size(), + celeritas::device().stream(stream_id).get(), + track_slots.data(), + ids, + make_observer(reordered_ids.data()), + track_slots.size()); thrust::sort_by_key(thrust_execute_on(stream_id), - reordered_actions.data(), - reordered_actions.data() + reordered_actions.size(), + reordered_ids.data(), + reordered_ids.data() + reordered_ids.size(), device_pointer_cast(track_slots.data())); CELER_DEVICE_CHECK_ERROR(); } @@ -197,14 +199,27 @@ void sort_tracks(DeviceRef const& states, TrackOrder order) return partition_impl(states.track_slots, alive_predicate{states.sim.status.data()}, states.stream_id); - case TrackOrder::sort_along_step_action: - return sort_impl(states.track_slots, - states.sim.along_step_action.data(), - states.stream_id); - case TrackOrder::sort_step_limit_action: - return sort_impl(states.track_slots, - states.sim.post_step_action.data(), - states.stream_id); + case TrackOrder::sort_along_step_action: { + using Id = + typename decltype(states.sim.along_step_action)::value_type; + return sort_impl(states.track_slots, + states.sim.along_step_action.data(), + states.stream_id); + } + case TrackOrder::sort_step_limit_action: { + using Id = + typename decltype(states.sim.post_step_action)::value_type; + return sort_impl(states.track_slots, + states.sim.post_step_action.data(), + states.stream_id); + } + case TrackOrder::sort_particle_type: { + using Id = + typename decltype(states.particles.particle_id)::value_type; + return sort_impl(states.track_slots, + states.particles.particle_id.data(), + states.stream_id); + } default: CELER_ASSERT_UNREACHABLE(); } diff --git a/src/corecel/sys/KernelParamCalculator.device.hh b/src/corecel/sys/KernelParamCalculator.device.hh index 2e3e4f104f..a58ab8ed96 100644 --- a/src/corecel/sys/KernelParamCalculator.device.hh +++ b/src/corecel/sys/KernelParamCalculator.device.hh @@ -44,6 +44,29 @@ CELER_DEVICE_CHECK_ERROR(); \ } while (0) +/*! + * \def CELER_LAUNCH_KERNEL_TEMPLATE_1 + * + * Create a kernel param calculator with the given kernel with + * one template parameter, assuming the unction itself has a \c _kernel + * suffix, and launch with the given block/thread sizes and arguments list. + */ +#define CELER_LAUNCH_KERNEL_TEMPLATE_1(NAME, T1, THREADS, STREAM, ...) \ + do \ + { \ + static const ::celeritas::KernelParamCalculator calc_launch_params_( \ + #NAME, NAME##_kernel); \ + auto grid_ = calc_launch_params_(THREADS); \ + \ + CELER_LAUNCH_KERNEL_IMPL(NAME##_kernel, \ + grid_.blocks_per_grid, \ + grid_.threads_per_block, \ + 0, \ + STREAM, \ + __VA_ARGS__); \ + CELER_DEVICE_CHECK_ERROR(); \ + } while (0) + #if CELERITAS_USE_CUDA # define CELER_LAUNCH_KERNEL_IMPL(KERNEL, GRID, BLOCK, SHARED, STREAM, ...) \ KERNEL<<>>(__VA_ARGS__)