[MLIR][OpenMP] Fix standalone distribute on the device #133094
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch updates the handling of target regions to set trip counts and kernel execution modes properly, based on clang's behavior. This fixes a race condition on

This is how kernels are classified after the changes introduced in this patch:

! Exec mode: SPMD.
! Trip count: Set.
!$omp target teams distribute parallel do
do i=...
end do
! Exec mode: Generic-SPMD.
! Trip count: Set (outer loop).
!$omp target teams distribute
do i=...
!$omp parallel do private(idx, y)
do j=...
end do
end do
! Exec mode: Generic.
! Trip count: Set.
!$omp target teams distribute
do i=...
end do
! Exec mode: SPMD.
! Trip count: Not set.
!$omp target parallel do
do i=...
end do
! Exec mode: Generic.
! Trip count: Not set.
!$omp target
...
!$omp end target

For the split

For the time being, instead of relying on the openmp-opt pass, we look at the MLIR representation to find the Generic-SPMD pattern and directly tag the kernel as such during codegen. This is what we were already doing, but we were incorrectly matching other kinds of kernels as such in the process.

Patch is 22.50 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/133094.diff

7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 690e3df1f685e..9dbe6897a3304 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -222,6 +222,24 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
+//===----------------------------------------------------------------------===//
+// target_region_flags enum.
+//===----------------------------------------------------------------------===//
+
+def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
+def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
+def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
+def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
+
+def TargetRegionFlags : OpenMP_BitEnumAttr<
+ "TargetRegionFlags",
+ "target region property flags", [
+ TargetRegionFlagsNone,
+ TargetRegionFlagsGeneric,
+ TargetRegionFlagsSpmd,
+ TargetRegionFlagsTripCount
+ ]>;
+
//===----------------------------------------------------------------------===//
// variable_capture_kind enum.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 65095932be627..11530c0fa3620 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1312,7 +1312,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
///
/// \param capturedOp result of a still valid (no modifications made to any
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
- static llvm::omp::OMPTgtExecModeFlags
+ static ::mlir::omp::TargetRegionFlags
getKernelExecFlags(Operation *capturedOp);
}] # clausesExtraClassDeclaration;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 882bc4071482f..5b46cab96dd88 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1908,8 +1908,8 @@ LogicalResult TargetOp::verifyRegions() {
return emitError("target containing multiple 'omp.teams' nested ops");
// Check that host_eval values are only used in legal ways.
- llvm::omp::OMPTgtExecModeFlags execFlags =
- getKernelExecFlags(getInnermostCapturedOmpOp());
+ Operation *capturedOp = getInnermostCapturedOmpOp();
+ TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
for (Value hostEvalArg :
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
for (Operation *user : hostEvalArg.getUsers()) {
@@ -1924,7 +1924,8 @@ LogicalResult TargetOp::verifyRegions() {
"and 'thread_limit' in 'omp.teams'";
}
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
- if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
+ if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
+ parallelOp->isAncestor(capturedOp) &&
hostEvalArg == parallelOp.getNumThreads())
continue;
@@ -1933,15 +1934,16 @@ LogicalResult TargetOp::verifyRegions() {
"'omp.parallel' when representing target SPMD";
}
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
- if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
+ if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
+ loopNestOp.getOperation() == capturedOp &&
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
continue;
return emitOpError() << "host_eval argument only legal as loop bounds "
- "and steps in 'omp.loop_nest' when "
- "representing target SPMD or Generic-SPMD";
+ "and steps in 'omp.loop_nest' when trip count "
+ "must be evaluated in the host";
}
return emitOpError() << "host_eval argument illegal use in '"
@@ -1951,33 +1953,12 @@ LogicalResult TargetOp::verifyRegions() {
return success();
}
-/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
-/// effects, but don't include a memory write effect.
-static bool siblingAllowedInCapture(Operation *op) {
- if (!op)
- return false;
+static Operation *
+findCapturedOmpOp(Operation *rootOp,
+ llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
+ assert(rootOp && "expected valid operation");
- bool isOmpDialect =
- op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
- op->getDialect();
-
- if (isOmpDialect)
- return op->hasTrait<OpTrait::IsTerminator>();
-
- if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
- SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
- memOp.getEffects(effects);
- return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
- return isa<MemoryEffects::Write>(effect.getEffect()) &&
- isa<SideEffects::AutomaticAllocationScopeResource>(
- effect.getResource());
- });
- }
- return true;
-}
-
-Operation *TargetOp::getInnermostCapturedOmpOp() {
- Dialect *ompDialect = (*this)->getDialect();
+ Dialect *ompDialect = rootOp->getDialect();
Operation *capturedOp = nullptr;
DominanceInfo domInfo;
@@ -1985,8 +1966,8 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
// ensuring we only enter the region of an operation if it meets the criteria
// for being captured. We stop the exploration of nested operations as soon as
// we process a region holding no operations to be captured.
- walk<WalkOrder::PreOrder>([&](Operation *op) {
- if (op == *this)
+ rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
+ if (op == rootOp)
return WalkResult::advance();
// Ignore operations of other dialects or omp operations with no regions,
@@ -2016,7 +1997,7 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
// Don't capture this op if it has a not-allowed sibling, and stop recursing
// into nested operations.
for (Operation &sibling : op->getParentRegion()->getOps())
- if (&sibling != op && !siblingAllowedInCapture(&sibling))
+ if (&sibling != op && !siblingAllowedFn(&sibling))
return WalkResult::interrupt();
// Don't continue capturing nested operations if we reach an omp.loop_nest.
@@ -2029,10 +2010,33 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
return capturedOp;
}
-llvm::omp::OMPTgtExecModeFlags
-TargetOp::getKernelExecFlags(Operation *capturedOp) {
- using namespace llvm::omp;
+Operation *TargetOp::getInnermostCapturedOmpOp() {
+ auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
+ // Only allow OpenMP terminators and non-OpenMP ops that have known memory
+ // effects, but don't include a memory write effect.
+ return findCapturedOmpOp(*this, [&](Operation *sibling) {
+ if (!sibling)
+ return false;
+
+ if (ompDialect == sibling->getDialect())
+ return sibling->hasTrait<OpTrait::IsTerminator>();
+
+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
+ effects;
+ memOp.getEffects(effects);
+ return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
+ return isa<MemoryEffects::Write>(effect.getEffect()) &&
+ isa<SideEffects::AutomaticAllocationScopeResource>(
+ effect.getResource());
+ });
+ }
+ return true;
+ });
+}
+
+TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
// A non-null captured op is only valid if it resides inside of a TargetOp
// and is the result of calling getInnermostCapturedOmpOp() on it.
TargetOp targetOp =
@@ -2041,57 +2045,106 @@ TargetOp::getKernelExecFlags(Operation *capturedOp) {
(targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
"unexpected captured op");
- // Make sure this region is capturing a loop. Otherwise, it's a generic
- // kernel.
+ // If it's not capturing a loop, it's a default target region.
if (!isa_and_present<LoopNestOp>(capturedOp))
- return OMP_TGT_EXEC_MODE_GENERIC;
-
- SmallVector<LoopWrapperInterface> wrappers;
- cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
- assert(!wrappers.empty());
+ return TargetRegionFlags::generic;
- // Ignore optional SIMD leaf construct.
- auto *innermostWrapper = wrappers.begin();
- if (isa<SimdOp>(innermostWrapper))
- innermostWrapper = std::next(innermostWrapper);
+ auto getInnermostWrapper = [](LoopNestOp loopOp, int &numWrappers) {
+ SmallVector<LoopWrapperInterface> wrappers;
+ loopOp.gatherWrappers(wrappers);
+ assert(!wrappers.empty());
- long numWrappers = std::distance(innermostWrapper, wrappers.end());
+ // Ignore optional SIMD leaf construct.
+ auto *wrapper = wrappers.begin();
+ if (isa<SimdOp>(wrapper))
+ wrapper = std::next(wrapper);
- // Detect Generic-SPMD: target-teams-distribute[-simd].
- if (numWrappers == 1) {
- if (!isa<DistributeOp>(innermostWrapper))
- return OMP_TGT_EXEC_MODE_GENERIC;
+ numWrappers = static_cast<int>(std::distance(wrapper, wrappers.end()));
+ return wrapper;
+ };
- Operation *teamsOp = (*innermostWrapper)->getParentOp();
- if (!isa_and_present<TeamsOp>(teamsOp))
- return OMP_TGT_EXEC_MODE_GENERIC;
+ int numWrappers;
+ LoopWrapperInterface *innermostWrapper =
+ getInnermostWrapper(cast<LoopNestOp>(capturedOp), numWrappers);
- if (teamsOp->getParentOp() == targetOp.getOperation())
- return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
- }
+ if (numWrappers != 1 && numWrappers != 2)
+ return TargetRegionFlags::generic;
- // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
+ // Detect target-teams-distribute-parallel-wsloop[-simd].
if (numWrappers == 2) {
if (!isa<WsloopOp>(innermostWrapper))
- return OMP_TGT_EXEC_MODE_GENERIC;
+ return TargetRegionFlags::generic;
innermostWrapper = std::next(innermostWrapper);
if (!isa<DistributeOp>(innermostWrapper))
- return OMP_TGT_EXEC_MODE_GENERIC;
+ return TargetRegionFlags::generic;
Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
- return OMP_TGT_EXEC_MODE_GENERIC;
+ return TargetRegionFlags::generic;
Operation *teamsOp = parallelOp->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
- return OMP_TGT_EXEC_MODE_GENERIC;
+ return TargetRegionFlags::generic;
if (teamsOp->getParentOp() == targetOp.getOperation())
- return OMP_TGT_EXEC_MODE_SPMD;
+ return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
+ }
+ // Detect target-teams-distribute[-simd].
+ else if (isa<DistributeOp>(innermostWrapper)) {
+ Operation *teamsOp = (*innermostWrapper)->getParentOp();
+ if (!isa_and_present<TeamsOp>(teamsOp))
+ return TargetRegionFlags::generic;
+
+ if (teamsOp->getParentOp() != targetOp.getOperation())
+ return TargetRegionFlags::generic;
+
+ TargetRegionFlags result =
+ TargetRegionFlags::generic | TargetRegionFlags::trip_count;
+
+ // Find single nested parallel-do and add spmd flag (generic-spmd case).
+ // TODO: This shouldn't have to be done here, as it is too easy to break.
+ // The openmp-opt pass should be updated to be able to promote kernels like
+ // this from "Generic" to "Generic-SPMD". However, the use of the
+ // `kmpc_distribute_static_loop` family of functions produced by the
+ // OMPIRBuilder for these kernels prevents that from working.
+ Dialect *ompDialect = targetOp->getDialect();
+ Operation *nestedCapture =
+ findCapturedOmpOp(capturedOp, [&](Operation *sibling) {
+ return sibling && (ompDialect != sibling->getDialect() ||
+ sibling->hasTrait<OpTrait::IsTerminator>());
+ });
+
+ if (!isa_and_present<LoopNestOp>(nestedCapture))
+ return result;
+
+ int numNestedWrappers;
+ LoopWrapperInterface *nestedWrapper =
+ getInnermostWrapper(cast<LoopNestOp>(nestedCapture), numNestedWrappers);
+
+ if (numNestedWrappers != 1 || !isa<WsloopOp>(nestedWrapper))
+ return result;
+
+ Operation *parallelOp = (*nestedWrapper)->getParentOp();
+ if (!isa_and_present<ParallelOp>(parallelOp))
+ return result;
+
+ if (parallelOp->getParentOp() != capturedOp)
+ return result;
+
+ return result | TargetRegionFlags::spmd;
+ }
+ // Detect target-parallel-wsloop[-simd].
+ else if (isa<WsloopOp>(innermostWrapper)) {
+ Operation *parallelOp = (*innermostWrapper)->getParentOp();
+ if (!isa_and_present<ParallelOp>(parallelOp))
+ return TargetRegionFlags::generic;
+
+ if (parallelOp->getParentOp() == targetOp.getOperation())
+ return TargetRegionFlags::spmd;
}
- return OMP_TGT_EXEC_MODE_GENERIC;
+ return TargetRegionFlags::generic;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d41489921bd13..4d610d6e2656d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4646,7 +4646,17 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
combinedMaxThreadsVal = maxThreadsVal;
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
- attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp);
+ omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
+ assert(
+ omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
+ omp::TargetRegionFlags::spmd) &&
+ "invalid kernel flags");
+ attrs.ExecFlags =
+ omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
+ ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
+ ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
+ : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
+ : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
@@ -4691,8 +4701,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
if (numThreads)
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
- if (targetOp.getKernelExecFlags(capturedOp) !=
- llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
+ if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
+ omp::TargetRegionFlags::trip_count)) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
attrs.LoopTripCount = nullptr;
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 403128bb2300e..bd0541987339a 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2320,7 +2320,7 @@ func.func @omp_target_host_eval_parallel(%x : i32) {
// -----
func.func @omp_target_host_eval_loop1(%x : i32) {
- // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
+ // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
omp.target host_eval(%x -> %arg0 : i32) {
omp.wsloop {
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
@@ -2335,7 +2335,7 @@ func.func @omp_target_host_eval_loop1(%x : i32) {
// -----
func.func @omp_target_host_eval_loop2(%x : i32) {
- // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
+ // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
omp.target host_eval(%x -> %arg0 : i32) {
omp.teams {
^bb0:
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index a5cf789402726..e3d2f8bd01018 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2864,6 +2864,23 @@ func.func @omp_target_host_eval(%x : i32) {
omp.terminator
}
+ // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
+ // CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) {
+ // CHECK: omp.wsloop {
+ // CHECK: omp.loop_nest
+ omp.target host_eval(%x -> %arg0 : i32) {
+ %y = arith.constant 2 : i32
+ omp.parallel num_threads(%arg0 : i32) {
+ omp.wsloop {
+ omp.loop_nest (%iv) : i32 = (%y) to (%y) step (%y) {
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+
// CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
// CHECK: omp.teams {
// CHECK: omp.distribute {
diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
new file mode 100644
index 0000000000000..8101660e571e4
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
@@ -0,0 +1,111 @@
+// RUN: split-file %s %t
+// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
+// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
+
+//--- host.mlir
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+ llvm.func @main(%arg0 : !llvm.ptr) {
+ %x = llvm.load %arg0 : !llvm.ptr -> i32
+ %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
+ omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) map_entries(%0 -> %ptr : !llvm.ptr) {
+ %x.map = llvm.load %ptr : !llvm.ptr -> i32
+ omp.teams {
+ omp.distribute {
+ omp.loop_nest (%iv1) : i32 = (%lb) to (%ub) step (%step) {
+ omp.parallel {
+ omp.wsloop {
+ omp.loop_nest (%iv2) : i32 = (%x.map) to (%x.map) step (%x.map) {
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+ }
+}
+
+// HOST-LABEL: define void @main
+// HOST: %omp_loop.tripcount = {{.*}}
+// HOST-NEXT: br label %[[ENTRY:.*]]
+// HOST: [[ENTRY]]:
+// HOST: %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// HOST: %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// HOST-NEXT: store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// HOST: %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// HOST-NEXT: %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// HOST-NEXT: br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// HOST: [[OFFLOAD_FAILED]]:
+// HOST: call void @[[TARGET_OUTLINE:.*]]({{.*}})
+
+// HOST: define internal void @[[TARGET_OUTLINE]]
+// HOST: call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+
+// HOST: define internal void @[[TEAMS_OUTLINE]]
+// HOST: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// HOST: define internal void @[[DISTRIBUTE_OUTLINE]]
+// HOST: call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 92, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+// HOST: call void (ptr, i32, ptr, ...) @__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*...
[truncated]
+ return TargetRegionFlags::generic;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d41489921bd13..4d610d6e2656d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4646,7 +4646,17 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
combinedMaxThreadsVal = maxThreadsVal;
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
- attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp);
+ omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
+ assert(
+ omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
+ omp::TargetRegionFlags::spmd) &&
+ "invalid kernel flags");
+ attrs.ExecFlags =
+ omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
+ ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
+ ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
+ : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
+ : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
@@ -4691,8 +4701,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
if (numThreads)
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
- if (targetOp.getKernelExecFlags(capturedOp) !=
- llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
+ if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
+ omp::TargetRegionFlags::trip_count)) {
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
attrs.LoopTripCount = nullptr;
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 403128bb2300e..bd0541987339a 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2320,7 +2320,7 @@ func.func @omp_target_host_eval_parallel(%x : i32) {
// -----
func.func @omp_target_host_eval_loop1(%x : i32) {
- // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
+ // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
omp.target host_eval(%x -> %arg0 : i32) {
omp.wsloop {
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
@@ -2335,7 +2335,7 @@ func.func @omp_target_host_eval_loop1(%x : i32) {
// -----
func.func @omp_target_host_eval_loop2(%x : i32) {
- // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
+ // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
omp.target host_eval(%x -> %arg0 : i32) {
omp.teams {
^bb0:
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index a5cf789402726..e3d2f8bd01018 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2864,6 +2864,23 @@ func.func @omp_target_host_eval(%x : i32) {
omp.terminator
}
+ // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
+ // CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) {
+ // CHECK: omp.wsloop {
+ // CHECK: omp.loop_nest
+ omp.target host_eval(%x -> %arg0 : i32) {
+ %y = arith.constant 2 : i32
+ omp.parallel num_threads(%arg0 : i32) {
+ omp.wsloop {
+ omp.loop_nest (%iv) : i32 = (%y) to (%y) step (%y) {
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+
// CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
// CHECK: omp.teams {
// CHECK: omp.distribute {
diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
new file mode 100644
index 0000000000000..8101660e571e4
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
@@ -0,0 +1,111 @@
+// RUN: split-file %s %t
+// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
+// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
+
+//--- host.mlir
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+ llvm.func @main(%arg0 : !llvm.ptr) {
+ %x = llvm.load %arg0 : !llvm.ptr -> i32
+ %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
+ omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) map_entries(%0 -> %ptr : !llvm.ptr) {
+ %x.map = llvm.load %ptr : !llvm.ptr -> i32
+ omp.teams {
+ omp.distribute {
+ omp.loop_nest (%iv1) : i32 = (%lb) to (%ub) step (%step) {
+ omp.parallel {
+ omp.wsloop {
+ omp.loop_nest (%iv2) : i32 = (%x.map) to (%x.map) step (%x.map) {
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+ }
+}
+
+// HOST-LABEL: define void @main
+// HOST: %omp_loop.tripcount = {{.*}}
+// HOST-NEXT: br label %[[ENTRY:.*]]
+// HOST: [[ENTRY]]:
+// HOST: %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// HOST: %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// HOST-NEXT: store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// HOST: %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// HOST-NEXT: %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// HOST-NEXT: br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// HOST: [[OFFLOAD_FAILED]]:
+// HOST: call void @[[TARGET_OUTLINE:.*]]({{.*}})
+
+// HOST: define internal void @[[TARGET_OUTLINE]]
+// HOST: call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+
+// HOST: define internal void @[[TEAMS_OUTLINE]]
+// HOST: call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// HOST: define internal void @[[DISTRIBUTE_OUTLINE]]
+// HOST: call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 92, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+// HOST: call void (ptr, i32, ptr, ...) @__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*...
[truncated]
```cpp
assert(
    omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
                                             omp::TargetRegionFlags::spmd) &&
    "invalid kernel flags");
attrs.ExecFlags =
    omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
        ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
              ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
              : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
        : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
```
What happens if a `trip_count` flag hits here (or why isn't that possible)?
The `trip_count` flag does not impact the `ExecFlags`, so it could be present but is ignored here. It is checked by `initTargetRuntimeAttrs` to decide whether to evaluate the trip count on the host and pass it in the kernel arguments structure to the kernel launch call.
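To make that mapping concrete, here is a small self-contained sketch of how the three flags resolve to an execution mode. The `Flags`/`ExecMode` enums and helper below are simplified stand-ins for `omp::TargetRegionFlags` and the `OMP_TGT_EXEC_MODE_*` constants, not the real MLIR/LLVM declarations; note the `trip_count` bit never changes the result:

```cpp
#include <cassert>
#include <cstdint>

// Illustrative stand-ins for omp::TargetRegionFlags and the
// OMP_TGT_EXEC_MODE_* constants.
enum class Flags : uint32_t { generic = 1, spmd = 2, trip_count = 4 };
enum class ExecMode { Generic, SPMD, GenericSPMD };

constexpr Flags operator|(Flags a, Flags b) {
  return static_cast<Flags>(static_cast<uint32_t>(a) |
                            static_cast<uint32_t>(b));
}
constexpr bool containsAny(Flags a, Flags b) {
  return (static_cast<uint32_t>(a) & static_cast<uint32_t>(b)) != 0;
}

// Mirrors the nested conditional in initTargetDefaultAttrs: generic+spmd
// means Generic-SPMD, and trip_count does not affect the chosen mode.
ExecMode toExecMode(Flags f) {
  assert(containsAny(f, Flags::generic | Flags::spmd) &&
         "invalid kernel flags");
  return containsAny(f, Flags::generic)
             ? (containsAny(f, Flags::spmd) ? ExecMode::GenericSPMD
                                            : ExecMode::Generic)
             : ExecMode::SPMD;
}
```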
```cpp
SmallVector<LoopWrapperInterface> wrappers;
cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
assert(!wrappers.empty());
auto getInnermostWrapper = [](LoopNestOp loopOp, int &numWrappers) {
```
This does not need to be a lambda function. It is better as a static function imo.
I just updated this to remove the function altogether. I created it when I was expecting to use it multiple times, but in the end it was only needed in one spot.
LGTM
LGTM
This patch updates the handling of target regions to set trip counts and kernel execution modes properly, based on clang's behavior. This fixes a race condition on `target teams distribute` constructs with no `parallel do` loop inside.

This is how kernels are classified, after changes introduced in this patch:

```f90
! Exec mode: SPMD.
! Trip count: Set.
!$omp target teams distribute parallel do
do i=...
end do

! Exec mode: Generic-SPMD.
! Trip count: Set (outer loop).
!$omp target teams distribute
do i=...
  !$omp parallel do private(idx, y)
  do j=...
  end do
end do

! Exec mode: Generic-SPMD.
! Trip count: Set (outer loop).
!$omp target teams distribute
do i=...
  !$omp parallel
  ...
  !$omp end parallel
end do

! Exec mode: Generic.
! Trip count: Set.
!$omp target teams distribute
do i=...
end do

! Exec mode: SPMD.
! Trip count: Not set.
!$omp target parallel do
do i=...
end do

! Exec mode: Generic.
! Trip count: Not set.
!$omp target
...
!$omp end target
```

For the split `target teams distribute + parallel do` case, clang produces a Generic kernel which gets promoted to Generic-SPMD by the openmp-opt pass. We can't currently replicate that behavior in flang because our codegen for these constructs introduces calls to the `kmpc_distribute_static_loop` family of functions, instead of `kmpc_distribute_static_init`, which currently prevents promotion of the kernel to Generic-SPMD.

For the time being, instead of relying on the openmp-opt pass, we look at the MLIR representation to find the Generic-SPMD pattern and directly tag the kernel as such during codegen. This is what we were already doing, but incorrectly matching other kinds of kernels as such in the process.
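The classification table above can be sketched as a toy decision function over construct nestings. This is an illustrative model only (plain strings standing in for the MLIR loop-wrapper walk in `getKernelExecFlags`), not the actual implementation:

```cpp
#include <cassert>
#include <string>
#include <vector>

enum class Mode { Generic, GenericSPMD, SPMD };

struct Kernel {
  Mode mode;
  bool hostTripCount; // host must evaluate and pass the loop trip count
};

// Toy classifier: `wrappers` lists the constructs wrapping the captured
// loop, outermost first; `nestedParallelDo` models a single parallel
// region nested inside a teams-distribute loop (the Generic-SPMD case).
Kernel classify(const std::vector<std::string> &wrappers,
                bool nestedParallelDo) {
  using V = std::vector<std::string>;
  if (wrappers == V{"teams", "distribute", "parallel", "wsloop"})
    return {Mode::SPMD, true};
  if (wrappers == V{"teams", "distribute"})
    return {nestedParallelDo ? Mode::GenericSPMD : Mode::Generic, true};
  if (wrappers == V{"parallel", "wsloop"})
    return {Mode::SPMD, false};
  return {Mode::Generic, false}; // e.g. a bare target region
}
```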