Skip to content

Commit

Permalink
Refactor ActionSequence to be templated on Params (celeritas-projec…
Browse files Browse the repository at this point in the history
  • Loading branch information
esseivaju committed Apr 28, 2024
1 parent d6720d6 commit 772d57c
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 36 deletions.
5 changes: 3 additions & 2 deletions src/celeritas/global/Stepper.hh
Expand Up @@ -30,6 +30,7 @@ struct Primary;

namespace detail
{
template<class Params>
class ActionSequence;
}

Expand Down Expand Up @@ -78,7 +79,7 @@ class StepperInterface
//!@{
//! \name Type aliases
using Input = StepperInput;
using ActionSequence = detail::ActionSequence;
using ActionSequence = detail::ActionSequence<CoreParams>;
using SpanConstPrimary = Span<Primary const>;
using result_type = StepperResult;
//!@}
Expand Down Expand Up @@ -157,7 +158,7 @@ class Stepper final : public StepperInterface
private:
// Params and call sequence
std::shared_ptr<CoreParams const> params_;
std::shared_ptr<detail::ActionSequence> actions_;
std::shared_ptr<ActionSequence> actions_;
// State data
CoreState<M> state_;
};
Expand Down
55 changes: 27 additions & 28 deletions src/celeritas/global/detail/ActionSequence.cc
Expand Up @@ -20,8 +20,8 @@
#include "corecel/sys/ScopedProfiling.hh"
#include "corecel/sys/Stopwatch.hh"
#include "corecel/sys/Stream.hh"
#include "celeritas/global/CoreParams.hh"

#include "ParamsTraits.hh"
#include "../ActionInterface.hh"
#include "../ActionRegistry.hh"
#include "../CoreState.hh"
Expand All @@ -34,22 +34,26 @@ namespace detail
/*!
* Construct from an action registry and sequence options.
*/
ActionSequence::ActionSequence(ActionRegistry const& reg, Options options)
template<class Params>
ActionSequence<Params>::ActionSequence(ActionRegistry const& reg,
Options options)
: options_(std::move(options))
{
actions_.reserve(reg.num_actions());
// Loop over all action IDs
for (auto aidx : range(reg.num_actions()))
{
// Get abstract action shared pointer and see if it's explicit
auto const& base = reg.action(ActionId{aidx});
if (auto expl
= std::dynamic_pointer_cast<ExplicitActionInterface const>(base))
using element_type = typename SPConstSpecializedExplicit::element_type;
if (auto expl = std::dynamic_pointer_cast<element_type>(base))
{
// Add explicit action to our array
actions_.push_back(std::move(expl));
}
}

begin_run_.reserve(reg.mutable_actions().size());
// Loop over all mutable actions
for (auto const& base : reg.mutable_actions())
{
Expand All @@ -63,7 +67,8 @@ ActionSequence::ActionSequence(ActionRegistry const& reg, Options options)
// Sort actions by increasing order (and secondarily, increasing IDs)
std::sort(actions_.begin(),
actions_.end(),
[](SPConstExplicit const& a, SPConstExplicit const& b) {
[](SPConstSpecializedExplicit const& a,
SPConstSpecializedExplicit const& b) {
return std::make_tuple(a->order(), a->action_id())
< std::make_tuple(b->order(), b->action_id());
});
Expand All @@ -78,8 +83,9 @@ ActionSequence::ActionSequence(ActionRegistry const& reg, Options options)
/*!
* Initialize actions and states.
*/
template<class Params>
template<MemSpace M>
void ActionSequence::begin_run(CoreParams const& params, CoreState<M>& state)
void ActionSequence<Params>::begin_run(Params const& params, State<M>& state)
{
for (auto const& sp_action : begin_run_)
{
Expand All @@ -92,15 +98,10 @@ void ActionSequence::begin_run(CoreParams const& params, CoreState<M>& state)
/*!
* Call all explicit actions with host or device data.
*/
template<typename Params, template<MemSpace M> class State, MemSpace M>
void ActionSequence::execute(Params const& params, State<M>& state)
template<class Params>
template<MemSpace M>
void ActionSequence<Params>::execute(Params const& params, State<M>& state)
{
using ExplicitAction = typename ParamsTraits<Params>::ExplicitAction;

static_assert(
std::is_same_v<State<M>, typename ParamsTraits<Params>::template State<M>>,
"The Params and State type are not matching.");

[[maybe_unused]] Stream::StreamT stream = nullptr;
if (M == MemSpace::device && options_.sync)
{
Expand All @@ -114,9 +115,7 @@ void ActionSequence::execute(Params const& params, State<M>& state)
{
ScopedProfiling profile_this{actions_[i]->label()};
Stopwatch get_time;
auto const& concrete_action
= dynamic_cast<ExplicitAction const&>(*actions_[i]);
concrete_action.execute(params, state);
actions_[i]->execute(params, state);
if (M == MemSpace::device)
{
CELER_DEVICE_CALL_PREFIX(StreamSynchronize(stream));
Expand All @@ -127,12 +126,10 @@ void ActionSequence::execute(Params const& params, State<M>& state)
else
{
// Just loop over the actions
for (SPConstExplicit const& sp_action : actions_)
for (auto const& sp_action : actions_)
{
ScopedProfiling profile_this{sp_action->label()};
auto const& concrete_action
= dynamic_cast<ExplicitAction const&>(*sp_action);
concrete_action.execute(params, state);
sp_action->execute(params, state);
}
}
}
Expand All @@ -141,15 +138,17 @@ void ActionSequence::execute(Params const& params, State<M>& state)
// Explicit template instantiation
//---------------------------------------------------------------------------//

template void
ActionSequence::begin_run(CoreParams const&, CoreState<MemSpace::host>&);
template void
ActionSequence::begin_run(CoreParams const&, CoreState<MemSpace::device>&);
template class ActionSequence<CoreParams>;

template void ActionSequence<CoreParams>::begin_run(CoreParams const&,
State<MemSpace::host>&);
template void ActionSequence<CoreParams>::begin_run(CoreParams const&,
State<MemSpace::device>&);

template void
ActionSequence::execute(CoreParams const&, CoreState<MemSpace::host>&);
template void
ActionSequence::execute(CoreParams const&, CoreState<MemSpace::device>&);
ActionSequence<CoreParams>::execute(CoreParams const&, State<MemSpace::host>&);
template void ActionSequence<CoreParams>::execute(CoreParams const&,
State<MemSpace::device>&);

// TODO: add explicit template instantiation of execute for optical data

Expand Down
27 changes: 21 additions & 6 deletions src/celeritas/global/detail/ActionSequence.hh
Expand Up @@ -8,10 +8,12 @@
#pragma once

#include <memory>
#include <type_traits>
#include <vector>

#include "corecel/Types.hh"

#include "ParamsTraits.hh"
#include "../ActionInterface.hh"
#include "../CoreTrackDataFwd.hh"

Expand All @@ -30,18 +32,31 @@ namespace detail
* TODO accessors here are used by diagnostic output from celer-sim etc.;
* perhaps make this public or add a diagnostic output for it?
*/
template<class Params>
class ActionSequence
{
public:
//!@{
//! \name Type aliases
template<MemSpace M>
using State = typename ParamsTraits<Params>::template State<M>;
using SpecializedExplicitAction =
typename ParamsTraits<Params>::ExplicitAction;
using SPBegin = std::shared_ptr<BeginRunActionInterface>;
using SPConstExplicit = std::shared_ptr<ExplicitActionInterface const>;
using SPConstSpecializedExplicit
= std::shared_ptr<SpecializedExplicitAction const>;
using VecBeginAction = std::vector<SPBegin>;
using VecExplicitAction = std::vector<SPConstExplicit>;
using VecSpecializedExplicitAction
= std::vector<SPConstSpecializedExplicit>;
using VecDouble = std::vector<double>;
//!@}

// Verify that we have a valid explicit action type for the given Params
static_assert(
std::is_base_of_v<ExplicitActionInterface, SpecializedExplicitAction>,
"ParamTraits<Params> explicit action must be derived from "
"ExplicitActionInterface");

//! Construction/execution options
struct Options
{
Expand All @@ -56,10 +71,10 @@ class ActionSequence

// Launch all actions with the given memory space.
template<MemSpace M>
void begin_run(CoreParams const& params, CoreState<M>& state);
void begin_run(Params const& params, State<M>& state);

// Launch all actions with the given memory space.
template<typename Params, template<MemSpace M> class State, MemSpace M>
template<MemSpace M>
void execute(Params const&, State<M>& state);

//// ACCESSORS ////
Expand All @@ -71,15 +86,15 @@ class ActionSequence
VecBeginAction const& begin_run_actions() const { return begin_run_; }

//! Get the ordered vector of actions in the sequence
VecExplicitAction const& actions() const { return actions_; }
VecSpecializedExplicitAction const& actions() const { return actions_; }

//! Get the corresponding accumulated time, if 'sync' or host called
VecDouble const& accum_time() const { return accum_time_; }

private:
Options options_;
VecBeginAction begin_run_;
VecExplicitAction actions_;
VecSpecializedExplicitAction actions_;
VecDouble accum_time_;
};

Expand Down

0 comments on commit 772d57c

Please sign in to comment.