Skip to content

[MLIR][OpenMP] Fix standalone distribute on the device #133094

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 3, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
Original file line number Diff line number Diff line change
@@ -222,6 +222,24 @@ def ScheduleModifier : OpenMP_I32EnumAttr<

def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;

//===----------------------------------------------------------------------===//
// target_region_flags enum.
//===----------------------------------------------------------------------===//

def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;

def TargetRegionFlags : OpenMP_BitEnumAttr<
"TargetRegionFlags",
"target region property flags", [
TargetRegionFlagsNone,
TargetRegionFlagsGeneric,
TargetRegionFlagsSpmd,
TargetRegionFlagsTripCount
]>;

//===----------------------------------------------------------------------===//
// variable_capture_kind enum.
//===----------------------------------------------------------------------===//
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
@@ -1312,7 +1312,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
///
/// \param capturedOp result of a still valid (no modifications made to any
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
static llvm::omp::OMPTgtExecModeFlags
static ::mlir::omp::TargetRegionFlags
getKernelExecFlags(Operation *capturedOp);
}] # clausesExtraClassDeclaration;

204 changes: 123 additions & 81 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
@@ -1908,8 +1908,8 @@ LogicalResult TargetOp::verifyRegions() {
return emitError("target containing multiple 'omp.teams' nested ops");

// Check that host_eval values are only used in legal ways.
llvm::omp::OMPTgtExecModeFlags execFlags =
getKernelExecFlags(getInnermostCapturedOmpOp());
Operation *capturedOp = getInnermostCapturedOmpOp();
TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
for (Value hostEvalArg :
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
for (Operation *user : hostEvalArg.getUsers()) {
@@ -1924,7 +1924,8 @@ LogicalResult TargetOp::verifyRegions() {
"and 'thread_limit' in 'omp.teams'";
}
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
parallelOp->isAncestor(capturedOp) &&
hostEvalArg == parallelOp.getNumThreads())
continue;

@@ -1933,15 +1934,16 @@ LogicalResult TargetOp::verifyRegions() {
"'omp.parallel' when representing target SPMD";
}
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
loopNestOp.getOperation() == capturedOp &&
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
continue;

return emitOpError() << "host_eval argument only legal as loop bounds "
"and steps in 'omp.loop_nest' when "
"representing target SPMD or Generic-SPMD";
"and steps in 'omp.loop_nest' when trip count "
"must be evaluated in the host";
}

return emitOpError() << "host_eval argument illegal use in '"
@@ -1951,42 +1953,21 @@ LogicalResult TargetOp::verifyRegions() {
return success();
}

/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
/// effects, but don't include a memory write effect.
static bool siblingAllowedInCapture(Operation *op) {
if (!op)
return false;
static Operation *
findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
assert(rootOp && "expected valid operation");

bool isOmpDialect =
op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
op->getDialect();

if (isOmpDialect)
return op->hasTrait<OpTrait::IsTerminator>();

if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
memOp.getEffects(effects);
return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
return isa<MemoryEffects::Write>(effect.getEffect()) &&
isa<SideEffects::AutomaticAllocationScopeResource>(
effect.getResource());
});
}
return true;
}

Operation *TargetOp::getInnermostCapturedOmpOp() {
Dialect *ompDialect = (*this)->getDialect();
Dialect *ompDialect = rootOp->getDialect();
Operation *capturedOp = nullptr;
DominanceInfo domInfo;

// Process in pre-order to check operations from outermost to innermost,
// ensuring we only enter the region of an operation if it meets the criteria
// for being captured. We stop the exploration of nested operations as soon as
// we process a region holding no operations to be captured.
walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op == *this)
rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op == rootOp)
return WalkResult::advance();

// Ignore operations of other dialects or omp operations with no regions,
@@ -2001,22 +1982,24 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
// (i.e. its block's successors can reach it) or if it's not guaranteed to
// be executed before all exits of the region (i.e. it doesn't dominate all
// blocks with no successors reachable from the entry block).
Region *parentRegion = op->getParentRegion();
Block *parentBlock = op->getBlock();

for (Block *successor : parentBlock->getSuccessors())
if (successor->isReachable(parentBlock))
return WalkResult::interrupt();

for (Block &block : *parentRegion)
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
!domInfo.dominates(parentBlock, &block))
return WalkResult::interrupt();
if (checkSingleMandatoryExec) {
Region *parentRegion = op->getParentRegion();
Block *parentBlock = op->getBlock();

for (Block *successor : parentBlock->getSuccessors())
if (successor->isReachable(parentBlock))
return WalkResult::interrupt();

for (Block &block : *parentRegion)
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
!domInfo.dominates(parentBlock, &block))
return WalkResult::interrupt();
}

// Don't capture this op if it has a not-allowed sibling, and stop recursing
// into nested operations.
for (Operation &sibling : op->getParentRegion()->getOps())
if (&sibling != op && !siblingAllowedInCapture(&sibling))
if (&sibling != op && !siblingAllowedFn(&sibling))
return WalkResult::interrupt();

// Don't continue capturing nested operations if we reach an omp.loop_nest.
@@ -2029,10 +2012,35 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
return capturedOp;
}

llvm::omp::OMPTgtExecModeFlags
TargetOp::getKernelExecFlags(Operation *capturedOp) {
using namespace llvm::omp;
Operation *TargetOp::getInnermostCapturedOmpOp() {
auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();

// Only allow OpenMP terminators and non-OpenMP ops that have known memory
// effects, but don't include a memory write effect.
return findCapturedOmpOp(
*this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
if (!sibling)
return false;

if (ompDialect == sibling->getDialect())
return sibling->hasTrait<OpTrait::IsTerminator>();

if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
effects;
memOp.getEffects(effects);
return !llvm::any_of(
effects, [&](MemoryEffects::EffectInstance &effect) {
return isa<MemoryEffects::Write>(effect.getEffect()) &&
isa<SideEffects::AutomaticAllocationScopeResource>(
effect.getResource());
});
}
return true;
});
}

TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
// A non-null captured op is only valid if it resides inside of a TargetOp
// and is the result of calling getInnermostCapturedOmpOp() on it.
TargetOp targetOp =
@@ -2041,60 +2049,94 @@ TargetOp::getKernelExecFlags(Operation *capturedOp) {
(targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
"unexpected captured op");

// Make sure this region is capturing a loop. Otherwise, it's a generic
// kernel.
// If it's not capturing a loop, it's a default target region.
if (!isa_and_present<LoopNestOp>(capturedOp))
return OMP_TGT_EXEC_MODE_GENERIC;
return TargetRegionFlags::generic;

SmallVector<LoopWrapperInterface> wrappers;
cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
assert(!wrappers.empty());
// Get the innermost non-simd loop wrapper.
SmallVector<LoopWrapperInterface> loopWrappers;
cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
assert(!loopWrappers.empty());

// Ignore optional SIMD leaf construct.
auto *innermostWrapper = wrappers.begin();
LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
if (isa<SimdOp>(innermostWrapper))
innermostWrapper = std::next(innermostWrapper);

long numWrappers = std::distance(innermostWrapper, wrappers.end());

// Detect Generic-SPMD: target-teams-distribute[-simd].
// Detect SPMD: target-teams-loop.
if (numWrappers == 1) {
if (!isa<DistributeOp, LoopOp>(innermostWrapper))
return OMP_TGT_EXEC_MODE_GENERIC;

Operation *teamsOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
return OMP_TGT_EXEC_MODE_GENERIC;
auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
if (numWrappers != 1 && numWrappers != 2)
return TargetRegionFlags::generic;

if (teamsOp->getParentOp() == targetOp.getOperation())
return isa<DistributeOp>(innermostWrapper)
? OMP_TGT_EXEC_MODE_GENERIC_SPMD
: OMP_TGT_EXEC_MODE_SPMD;
}

// Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
// Detect target-teams-distribute-parallel-wsloop[-simd].
if (numWrappers == 2) {
if (!isa<WsloopOp>(innermostWrapper))
return OMP_TGT_EXEC_MODE_GENERIC;
return TargetRegionFlags::generic;

innermostWrapper = std::next(innermostWrapper);
if (!isa<DistributeOp>(innermostWrapper))
return OMP_TGT_EXEC_MODE_GENERIC;
return TargetRegionFlags::generic;

Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
return OMP_TGT_EXEC_MODE_GENERIC;
return TargetRegionFlags::generic;

Operation *teamsOp = parallelOp->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
return OMP_TGT_EXEC_MODE_GENERIC;
return TargetRegionFlags::generic;

if (teamsOp->getParentOp() == targetOp.getOperation())
return OMP_TGT_EXEC_MODE_SPMD;
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
}
// Detect target-teams-distribute[-simd] and target-teams-loop.
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
Operation *teamsOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
return TargetRegionFlags::generic;

if (teamsOp->getParentOp() != targetOp.getOperation())
return TargetRegionFlags::generic;

if (isa<LoopOp>(innermostWrapper))
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;

// Find single immediately nested captured omp.parallel and add spmd flag
// (generic-spmd case).
//
// TODO: This shouldn't have to be done here, as it is too easy to break.
// The openmp-opt pass should be updated to be able to promote kernels like
// this from "Generic" to "Generic-SPMD". However, the use of the
// `kmpc_distribute_static_loop` family of functions produced by the
// OMPIRBuilder for these kernels prevents that from working.
Dialect *ompDialect = targetOp->getDialect();
Operation *nestedCapture = findCapturedOmpOp(
capturedOp, /*checkSingleMandatoryExec=*/false,
[&](Operation *sibling) {
return sibling && (ompDialect != sibling->getDialect() ||
sibling->hasTrait<OpTrait::IsTerminator>());
});

TargetRegionFlags result =
TargetRegionFlags::generic | TargetRegionFlags::trip_count;

if (!nestedCapture)
return result;

while (nestedCapture->getParentOp() != capturedOp)
nestedCapture = nestedCapture->getParentOp();

return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
: result;
}
// Detect target-parallel-wsloop[-simd].
else if (isa<WsloopOp>(innermostWrapper)) {
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
return TargetRegionFlags::generic;

if (parallelOp->getParentOp() == targetOp.getOperation())
return TargetRegionFlags::spmd;
}

return OMP_TGT_EXEC_MODE_GENERIC;
return TargetRegionFlags::generic;
}

//===----------------------------------------------------------------------===//
Original file line number Diff line number Diff line change
@@ -4646,7 +4646,17 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
combinedMaxThreadsVal = maxThreadsVal;

// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp);
omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
assert(
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
omp::TargetRegionFlags::spmd) &&
"invalid kernel flags");
attrs.ExecFlags =
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
: llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
Comment on lines +4650 to +4659
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if a trip_count hits here (or why isn' it possible)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The trip_count flag does not impact the ExecFlags, so it could be present but is ignored here. It is checked by initTargetRuntimeAttrs to decide whether to evaluate and pass the trip count from the host in the kernel arguments structure to the kernel launch call.

attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
@@ -4691,8 +4701,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
if (numThreads)
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);

if (targetOp.getKernelExecFlags(capturedOp) !=
llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
omp::TargetRegionFlags::trip_count)) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
attrs.LoopTripCount = nullptr;

4 changes: 2 additions & 2 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
@@ -2320,7 +2320,7 @@ func.func @omp_target_host_eval_parallel(%x : i32) {
// -----

func.func @omp_target_host_eval_loop1(%x : i32) {
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
omp.target host_eval(%x -> %arg0 : i32) {
omp.wsloop {
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
@@ -2335,7 +2335,7 @@ func.func @omp_target_host_eval_loop1(%x : i32) {
// -----

func.func @omp_target_host_eval_loop2(%x : i32) {
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
omp.target host_eval(%x -> %arg0 : i32) {
omp.teams {
^bb0:
17 changes: 17 additions & 0 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
@@ -2864,6 +2864,23 @@ func.func @omp_target_host_eval(%x : i32) {
omp.terminator
}

// CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
// CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) {
// CHECK: omp.wsloop {
// CHECK: omp.loop_nest
omp.target host_eval(%x -> %arg0 : i32) {
%y = arith.constant 2 : i32
omp.parallel num_threads(%arg0 : i32) {
omp.wsloop {
omp.loop_nest (%iv) : i32 = (%y) to (%y) step (%y) {
omp.yield
}
}
omp.terminator
}
omp.terminator
}

// CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
// CHECK: omp.teams {
// CHECK: omp.distribute {
Loading
Oops, something went wrong.
Loading
Oops, something went wrong.