Skip to content
Merged
9 changes: 7 additions & 2 deletions src/DEM/APIPrivate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2142,8 +2142,13 @@ inline void DEMSolver::equipKernelIncludes(std::unordered_map<std::string, std::
// It's put here as ApiVersion.h.in (which sets DEME_CUDA_TOOLKIT_HEADERS) is a CMake-in configuration file, we don't
// want to include it anywhere in the h headers in case DEM-Engine is included by some parent project.
void DEMSolver::setDefaultSolverParams() {
m_jitify_options = {"-I" + (JitHelper::KERNEL_INCLUDE_DIR).string(), "-I" + (JitHelper::KERNEL_DIR).string(),
"-I" + std::string(DEME_CUDA_TOOLKIT_HEADERS), "-diag-suppress=550", "-diag-suppress=177"};
m_jitify_options = {"-I" + (JitHelper::KERNEL_INCLUDE_DIR).string(),
"-I" + (JitHelper::KERNEL_DIR).string(),
"-I" + std::string(DEME_CUDA_TOOLKIT_HEADERS),
"-diag-suppress=177",
"-diag-suppress=549",
"-diag-suppress=550",
"-std=c++17"};
}

} // namespace deme
71 changes: 56 additions & 15 deletions src/DEM/dT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2142,7 +2142,40 @@ inline void DEMDynamicThread::migrateEnduringContacts() {
granData.toDevice();
}

inline void DEMDynamicThread::calculateForces() {
// The argument is two maps: contact type -> (start offset, count), contact type -> list of [(program bundle name,
// kernel name)]
inline void DEMDynamicThread::dispatchCalcForceKernels(
const std::unordered_map<contact_t, std::pair<contactPairs_t, contactPairs_t>>& typeStartCountMap,
const std::unordered_map<contact_t, std::vector<std::pair<std::shared_ptr<jitify::Program>, std::string>>>&
typeKernelMap) {
// For each contact type that exists, call its corresponding kernel(s)
for (size_t i = 0; i < m_numExistingTypes; i++) {
contact_t contact_type = existingContactTypes[i];
const auto& start_count = typeStartCountMap.at(contact_type);
// Offset and count being contactPairs_t is very important, as CUDA kernel arguments cannot safely implicitly
// convert type (from size_t to unsigned int, for example)
contactPairs_t startOffset = start_count.first;
contactPairs_t count = start_count.second;

// For this contact type, get its list of (program bundle name, kernel name)
if (typeKernelMap.count(contact_type) == 0) {
DEME_ERROR("Contact type %d has no associated force kernel in to execute!", contact_type);
}
const auto& kernelList = typeKernelMap.at(contact_type);
for (const auto& [progName, kernelName] : kernelList) {
size_t blocks = (count + DT_FORCE_CALC_NTHREADS_PER_BLOCK - 1) / DT_FORCE_CALC_NTHREADS_PER_BLOCK;
if (blocks > 0) {
progName->kernel(kernelName)
.instantiate()
.configure(dim3(blocks), dim3(DT_FORCE_CALC_NTHREADS_PER_BLOCK), 0, streamInfo.stream)
.launch(&simParams, &granData, startOffset, count);
}
}
}
DEME_GPU_CALL(cudaStreamSynchronize(streamInfo.stream));
}

void DEMDynamicThread::calculateForces() {
// Reset force (acceleration) arrays for this time step
size_t nContactPairs = *solverScratchSpace.numContacts;

Expand Down Expand Up @@ -2171,18 +2204,14 @@ inline void DEMDynamicThread::calculateForces() {
}
timers.GetTimer("Clear force array").stop();

size_t blocks_needed_for_contacts =
(nContactPairs + DT_FORCE_CALC_NTHREADS_PER_BLOCK - 1) / DT_FORCE_CALC_NTHREADS_PER_BLOCK;
// If no contact then we don't have to calculate forces. Note there might still be forces, coming from prescription
// or other sources.
if (blocks_needed_for_contacts > 0) {
if (nContactPairs > 0) {
timers.GetTimer("Calculate contact forces").start();
// a custom kernel to compute forces
cal_force_kernels->kernel("calculateContactForces")
.instantiate()
.configure(dim3(blocks_needed_for_contacts), dim3(DT_FORCE_CALC_NTHREADS_PER_BLOCK), 0, streamInfo.stream)
.launch(&simParams, &granData, nContactPairs);
DEME_GPU_CALL(cudaStreamSynchronize(streamInfo.stream));

// Call specialized kernels for each contact type that exists
dispatchCalcForceKernels(typeStartCountMap, contactTypeKernelMap);

// displayDeviceFloat3(granData->contactForces, nContactPairs);
// displayDeviceArray<contact_t>(granData->contactType, nContactPairs);
// std::cout << "===========================" << std::endl;
Expand All @@ -2195,7 +2224,7 @@ inline void DEMDynamicThread::calculateForces() {
collectContactForcesThruCub(collect_force_kernels, granData, nContactPairs, simParams->nOwnerBodies,
contactPairArr_isFresh, streamInfo.stream, solverScratchSpace, timers);
} else {
blocks_needed_for_contacts =
size_t blocks_needed_for_contacts =
(nContactPairs + DEME_MAX_THREADS_PER_BLOCK - 1) / DEME_MAX_THREADS_PER_BLOCK;
// This does both acc and ang acc
collect_force_kernels->kernel("forceToAcc")
Expand Down Expand Up @@ -2283,10 +2312,10 @@ inline void DEMDynamicThread::unpack_impl() {
typeStartOffsets.toHost();
for (size_t i = 0; i < m_numExistingTypes; i++) {
DEME_DEBUG_PRINTF("Contact type %d starts at offset %u", existingContactTypes[i], typeStartOffsets[i]);
typeStartCountMap[existingContactTypes[i]] =
std::make_pair(typeStartOffsets[i],
(i + 1 < m_numExistingTypes ? typeStartOffsets[i + 1] : *solverScratchSpace.numContacts) -
typeStartOffsets[i]);
typeStartCountMap[existingContactTypes[i]] = std::make_pair(
typeStartOffsets[i],
(i + 1 < m_numExistingTypes ? typeStartOffsets[i + 1] : (contactPairs_t)*solverScratchSpace.numContacts) -
typeStartOffsets[i]);
}
// Debug output of the map
// for (const auto& entry : typeStartCountMap) {
Expand Down Expand Up @@ -2579,6 +2608,18 @@ void DEMDynamicThread::jitifyKernels(const std::unordered_map<std::string, std::
misc_kernels = std::make_shared<jitify::Program>(std::move(JitHelper::buildProgram(
"DEMMiscKernels", JitHelper::KERNEL_DIR / "DEMMiscKernels.cu", Subs, JitifyOptions)));
}

// For now, the contact type to kernel map is known and hard-coded after jitification
contactTypeKernelMap = {// Sphere-Sphere contact
{SPHERE_SPHERE_CONTACT, {{cal_force_kernels, "calculateContactForces_SphSph"}}},
// Sphere-Triangle contact
{SPHERE_TRIANGLE_CONTACT, {{cal_force_kernels, "calculateContactForces_SphTri"}}},
// Sphere-Analytical contact
{SPHERE_ANALYTICAL_CONTACT, {{cal_force_kernels, "calculateContactForces_SphAnal"}}},
// Triangle-Triangle contact
{TRIANGLE_TRIANGLE_CONTACT, {{cal_force_kernels, "calculateContactForces_TriTri"}}},
// Triangle-Analytical contact
{TRIANGLE_ANALYTICAL_CONTACT, {{cal_force_kernels, "calculateContactForces_TriAnal"}}}};
}

float* DEMDynamicThread::inspectCall(const std::shared_ptr<jitify::Program>& inspection_kernel,
Expand Down
10 changes: 9 additions & 1 deletion src/DEM/dT.h
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,9 @@ class DEMDynamicThread {
size_t m_numExistingTypes = 0;
// A map that records the contact <ID start, and count> for each contact type currently existing
std::unordered_map<contact_t, std::pair<contactPairs_t, contactPairs_t>> typeStartCountMap;
// A map that records the corresponding jitify program bundle and kernel name for each contact type
std::unordered_map<contact_t, std::vector<std::pair<std::shared_ptr<jitify::Program>, std::string>>>
contactTypeKernelMap;

// dT's timers
std::vector<std::string> timer_names = {"Clear force array", "Calculate contact forces", "Optional force reduction",
Expand Down Expand Up @@ -673,8 +676,13 @@ class DEMDynamicThread {
// Migrate contact history to fit the structure of the newly received contact array
inline void migrateEnduringContacts();

// Impl of calculateForces
inline void dispatchCalcForceKernels(
const std::unordered_map<contact_t, std::pair<contactPairs_t, contactPairs_t>>& typeStartCountMap,
const std::unordered_map<contact_t, std::vector<std::pair<std::shared_ptr<jitify::Program>, std::string>>>&
typeKernelMap);
// Update clump-based acceleration array based on sphere-based force array
inline void calculateForces();
void calculateForces();

// Update clump pos/oriQ and vel/omega based on acceleration
inline void integrateOwnerMotions();
Expand Down
5 changes: 3 additions & 2 deletions src/demo/DEMdemo_SingleSphereCollide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,9 @@ int main() {

// Testing modifying jitify options and force model prerequisites
auto jitify_options = DEMSim.GetJitifyOptions();
jitify_options.pop_back(); // Remove a warning suppression option
// DEMSim.SetJitifyOptions(jitify_options);
jitify_options.pop_back(); // Remove C++ std17 option
jitify_options.push_back("-std=c++20");
DEMSim.SetJitifyOptions(jitify_options); // Then set it
my_force_model->DefineCustomModelPrerequisites(
"float3 __device__ GetContactForce(float3 AOwner, float3 BOwner, float3 ALinVel, float3 BLinVel, "
"float3 ARotVel, float3 BRotVel, float delta_time, float delta_tan_x, float delta_tan_y, "
Expand Down
Loading