Skip to content

Commit

Permalink
Refactor TrackSortUtils (celeritas-project#1047)
Browse files Browse the repository at this point in the history
  • Loading branch information
esseivaju committed Dec 5, 2023
1 parent d42f872 commit 2d6d983
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 147 deletions.
96 changes: 38 additions & 58 deletions src/celeritas/track/detail/TrackSortUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ using ThreadItems
using TrackSlots = ThreadItems<TrackSlotId::size_type>;

template<class F>
void partition_impl(TrackSlots const& track_slots, F&& func, StreamId)
void partition_impl(TrackSlots const& track_slots, F&& func)
{
auto* start = track_slots.data().get();
std::partition(start, start + track_slots.size(), std::forward<F>(func));
Expand All @@ -38,43 +38,12 @@ void partition_impl(TrackSlots const& track_slots, F&& func, StreamId)
//---------------------------------------------------------------------------//

template<class F>
void sort_impl(TrackSlots const& track_slots, F&& func, StreamId)
void sort_impl(TrackSlots const& track_slots, F&& func)
{
auto* start = track_slots.data().get();
std::sort(start, start + track_slots.size(), std::forward<F>(func));
}

/*!
 * Record, for each action, the index of the first thread that has it.
 *
 * \pre Actions are sorted by thread: i <= j implies
 * get_action(i) <= get_action(j).
 *
 * \param offsets output array indexed by an action's unchecked_get() value;
 *   every entry is reset to a default ThreadId before being filled
 * \param size upper bound (exclusive) of the thread indices that are scanned
 * \param get_action callable mapping a ThreadId to its ActionId
 */
template<class F>
void count_tracks_per_action_impl(Span<ThreadId> offsets,
                                  size_type size,
                                  F&& get_action)
{
    // Start from a clean slate: every offset is a default ThreadId
    std::fill(offsets.begin(), offsets.end(), ThreadId{});

    // Thread 0 has no predecessor to compare against, so the scan starts at
    // thread 1 and thread 0 is handled separately below
#pragma omp parallel for
    for (size_type i = 1; i < size; ++i)
    {
        ActionId action = get_action(ThreadId{i});
        // Skip false-valued action IDs; otherwise an action that differs
        // from the previous thread's marks the start of a new range
        if (action && action != get_action(ThreadId{i - 1}))
        {
            offsets[action.unchecked_get()] = ThreadId{i};
        }
    }

    // The range of the first thread's action begins at thread 0
    ActionId first_action = get_action(ThreadId{0});
    if (first_action)
    {
        offsets[first_action.unchecked_get()] = ThreadId{0};
    }

    // Fill in the offsets that were left unset above
    backfill_action_count(offsets, size);
}

//---------------------------------------------------------------------------//
} // namespace

Expand Down Expand Up @@ -111,18 +80,14 @@ void sort_tracks(HostRef<CoreStateData> const& states, TrackOrder order)
{
case TrackOrder::partition_status:
return partition_impl(states.track_slots,
alive_predicate{states.sim.status.data()},
states.stream_id);
AlivePredicate{states.sim.status.data()});
case TrackOrder::sort_along_step_action:
return sort_impl(
states.track_slots,
action_comparator{states.sim.along_step_action.data()},
states.stream_id);
case TrackOrder::sort_step_limit_action:
return sort_impl(
states.track_slots,
action_comparator{states.sim.post_step_action.data()},
states.stream_id);
return sort_impl(states.track_slots,
IdComparator{get_action_ptr(states, order)});
case TrackOrder::sort_particle_type:
return sort_impl(states.track_slots,
IdComparator{states.particles.particle_id.data()});
default:
CELER_ASSERT_UNREACHABLE();
}
Expand All @@ -140,25 +105,40 @@ void count_tracks_per_action(
Collection<ThreadId, Ownership::value, MemSpace::host, ActionId>&,
TrackOrder order)
{
switch (order)
CELER_ASSERT(order == TrackOrder::sort_along_step_action
|| order == TrackOrder::sort_step_limit_action);

ActionAccessor get_action{get_action_ptr(states, order),
states.track_slots.data()};

std::fill(offsets.begin(), offsets.end(), ThreadId{});
auto const size = states.size();
// The loop starts at i = 1 and only records changes of action, so the
// offset for get_action(0) is never set inside it
#pragma omp parallel for
for (size_type i = 1; i < size; ++i)
{
case TrackOrder::sort_along_step_action:
return count_tracks_per_action_impl(
offsets,
states.size(),
ActionAccessor{states.sim.along_step_action.data(),
states.track_slots.data()});
case TrackOrder::sort_step_limit_action:
return count_tracks_per_action_impl(
offsets,
states.size(),
ActionAccessor{states.sim.post_step_action.data(),
states.track_slots.data()});
default:
return;
ActionId current_action = get_action(ThreadId{i});
if (!current_action)
continue;

if (current_action != get_action(ThreadId{i - 1}))
{
offsets[current_action.unchecked_get()] = ThreadId{i};
}
}

// so make sure get_action(0) is initialized
if (ActionId first = get_action(ThreadId{0}))
{
offsets[first.unchecked_get()] = ThreadId{0};
}
backfill_action_count(offsets, size);
}

//---------------------------------------------------------------------------//
/*!
* Fill missing action offsets.
*/
void backfill_action_count(Span<ThreadId> offsets, size_type num_actions)
{
CELER_EXPECT(offsets.size() >= 2);
Expand Down
127 changes: 53 additions & 74 deletions src/celeritas/track/detail/TrackSortUtils.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ using ThreadItems
using TrackSlots = ThreadItems<TrackSlotId::size_type>;

//---------------------------------------------------------------------------//

/*!
* Partition track_slots based on predicate.
*/
template<class F>
void partition_impl(TrackSlots const& track_slots, F&& func, StreamId stream_id)
{
Expand All @@ -59,7 +61,10 @@ void partition_impl(TrackSlots const& track_slots, F&& func, StreamId stream_id)
}

//---------------------------------------------------------------------------//

/*!
* Reorder OpaqueIds based on track_slots so that track_slots[tid] corresponds
* to ids[tid] instead of ids[track_slots[tid]].
*/
template<class Id>
__global__ void
reorder_ids_kernel(ObserverPtr<TrackSlotId::size_type const> track_slots,
Expand All @@ -75,6 +80,10 @@ reorder_ids_kernel(ObserverPtr<TrackSlotId::size_type const> track_slots,
}
}

//---------------------------------------------------------------------------//
/*!
* Sort track slots using ids as keys.
*/
template<class Id, class IdT = typename Id::size_type>
void sort_impl(TrackSlots const& track_slots,
ObserverPtr<Id const> ids,
Expand All @@ -96,13 +105,19 @@ void sort_impl(TrackSlots const& track_slots,
CELER_DEVICE_CHECK_ERROR();
}

// PRE: get_action is sorted, i.e. i <= j ==> get_action(i) <=
// get_action(j)
template<class F>
__device__ void
tracks_per_action_impl(Span<ThreadId> offsets, size_type size, F&& get_action)
//---------------------------------------------------------------------------//
/*!
* Calculate thread boundaries based on action ID.
* \pre actions are sorted
*/
__global__ void
tracks_per_action_kernel(ObserverPtr<ActionId const> actions,
ObserverPtr<TrackSlotId::size_type const> track_slots,
Span<ThreadId> offsets,
size_type size)
{
ThreadId tid = celeritas::KernelParamCalculator::thread_id();
ActionAccessor get_action{actions, track_slots};

if ((tid < size) && tid != ThreadId{0})
{
Expand All @@ -123,30 +138,6 @@ tracks_per_action_impl(Span<ThreadId> offsets, size_type size, F&& get_action)
}
}

/*!
 * Kernel entry point that selects the action array matching \c order and
 * forwards to \c tracks_per_action_impl.
 *
 * Only \c sort_along_step_action and \c sort_step_limit_action are handled;
 * any other ordering reaches \c CELER_ASSERT_UNREACHABLE.
 *
 * \param states core state data providing the action arrays and track slots
 * \param offsets output array passed through to the implementation
 * \param size thread count passed through to the implementation
 * \param order track ordering that determines which action array is read
 */
__global__ void tracks_per_action_kernel(DeviceRef<CoreStateData> const states,
                                         Span<ThreadId> offsets,
                                         size_type size,
                                         TrackOrder order)
{
    if (order == TrackOrder::sort_along_step_action)
    {
        // Count using the along-step action of each track slot
        tracks_per_action_impl(
            offsets,
            size,
            ActionAccessor{states.sim.along_step_action.data(),
                           states.track_slots.data()});
        return;
    }

    if (order == TrackOrder::sort_step_limit_action)
    {
        // Count using the post-step action of each track slot
        tracks_per_action_impl(
            offsets,
            size,
            ActionAccessor{states.sim.post_step_action.data(),
                           states.track_slots.data()});
        return;
    }

    // No other ordering is supported by this kernel
    CELER_ASSERT_UNREACHABLE();
}

//---------------------------------------------------------------------------//
} // namespace

Expand Down Expand Up @@ -197,22 +188,13 @@ void sort_tracks(DeviceRef<CoreStateData> const& states, TrackOrder order)
{
case TrackOrder::partition_status:
return partition_impl(states.track_slots,
alive_predicate{states.sim.status.data()},
AlivePredicate{states.sim.status.data()},
states.stream_id);
case TrackOrder::sort_along_step_action: {
using Id =
typename decltype(states.sim.along_step_action)::value_type;
return sort_impl<Id>(states.track_slots,
states.sim.along_step_action.data(),
states.stream_id);
}
case TrackOrder::sort_step_limit_action: {
using Id =
typename decltype(states.sim.post_step_action)::value_type;
return sort_impl<Id>(states.track_slots,
states.sim.post_step_action.data(),
states.stream_id);
}
case TrackOrder::sort_along_step_action:
case TrackOrder::sort_step_limit_action:
return sort_impl(states.track_slots,
get_action_ptr(states, order),
states.stream_id);
case TrackOrder::sort_particle_type: {
using Id =
typename decltype(states.particles.particle_id)::value_type;
Expand All @@ -237,34 +219,31 @@ void count_tracks_per_action(
Collection<ThreadId, Ownership::value, MemSpace::mapped, ActionId>& out,
TrackOrder order)
{
if (order == TrackOrder::sort_along_step_action
|| order == TrackOrder::sort_step_limit_action)
{
// dispatch in the kernel since CELER_LAUNCH_KERNEL doesn't work
// with templated kernels
auto start = device_pointer_cast(make_observer(offsets.data()));
thrust::fill(thrust_execute_on(states.stream_id),
start,
start + offsets.size(),
ThreadId{});
CELER_DEVICE_CHECK_ERROR();
auto* stream = celeritas::device().stream(states.stream_id).get();
CELER_LAUNCH_KERNEL(tracks_per_action,
states.size(),
stream,
states,
offsets,
states.size(),
order);

Span<ThreadId> sout = out[AllItems<ThreadId, MemSpace::mapped>{}];
Copier<ThreadId, MemSpace::host> copy_to_host{sout, states.stream_id};
copy_to_host(MemSpace::device, offsets);

// Copies must be complete before backfilling
CELER_DEVICE_CALL_PREFIX(StreamSynchronize(stream));
backfill_action_count(sout, states.size());
}
CELER_ASSERT(order == TrackOrder::sort_along_step_action
|| order == TrackOrder::sort_step_limit_action);

auto start = device_pointer_cast(make_observer(offsets.data()));
thrust::fill(thrust_execute_on(states.stream_id),
start,
start + offsets.size(),
ThreadId{});
CELER_DEVICE_CHECK_ERROR();
auto* stream = celeritas::device().stream(states.stream_id).get();
CELER_LAUNCH_KERNEL(tracks_per_action,
states.size(),
stream,
get_action_ptr(states, order),
states.track_slots.data(),
offsets,
states.size());

Span<ThreadId> sout = out[AllItems<ThreadId, MemSpace::mapped>{}];
Copier<ThreadId, MemSpace::host> copy_to_host{sout, states.stream_id};
copy_to_host(MemSpace::device, offsets);

// Copies must be complete before backfilling
CELER_DEVICE_CALL_PREFIX(StreamSynchronize(stream));
backfill_action_count(sout, states.size());
}

//---------------------------------------------------------------------------//
Expand Down

0 comments on commit 2d6d983

Please sign in to comment.