Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][OpenMP] Fix standalone distribute on the device #133094

Merged
merged 1 commit into from
Apr 3, 2025

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Mar 26, 2025

This patch updates the handling of target regions to set trip counts and kernel execution modes properly, based on clang's behavior. This fixes a race condition on target teams distribute constructs with no parallel do loop inside.

This is how kernels are classified, after changes introduced in this patch:

! Exec mode: SPMD.
! Trip count: Set.
!$omp target teams distribute parallel do
do i=...
end do

! Exec mode: Generic-SPMD.
! Trip count: Set (outer loop).
!$omp target teams distribute
do i=...
  !$omp parallel do private(idx, y)
  do j=...
  end do
end do

! Exec mode: Generic-SPMD.
! Trip count: Set (outer loop).
!$omp target teams distribute
do i=...
  !$omp parallel
    ...
  !$omp end parallel
end do

! Exec mode: Generic.
! Trip count: Set.
!$omp target teams distribute
do i=...
end do

! Exec mode: SPMD.
! Trip count: Not set.
!$omp target parallel do
do i=...
end do

! Exec mode: Generic.
! Trip count: Not set.
!$omp target
  ...
!$omp end target

For the split target teams distribute + parallel do case, clang produces a Generic kernel which gets promoted to Generic-SPMD by the openmp-opt pass. We can't currently replicate that behavior in flang because our codegen for these constructs results in the introduction of calls to the kmpc_distribute_static_loop family of functions, instead of kmpc_distribute_static_init, which currently prevent promotion of the kernel to Generic-SPMD.

For the time being, instead of relying on the openmp-opt pass, we look at the MLIR representation to find the Generic-SPMD pattern and directly tag the kernel as such during codegen. This is what we were already doing, but incorrectly matching other kinds of kernels as such in the process.

@llvmbot
Copy link
Member

llvmbot commented Mar 26, 2025

@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch updates the handling of target regions to set trip counts and kernel execution modes properly, based on clang's behavior. This fixes a race condition on target teams distribute constructs with no parallel do loop inside.

This is how kernels are classified, after changes introduced in this patch:

! Exec mode: SPMD.
! Trip count: Set.
!$omp target teams distribute parallel do
do i=...
end do

! Exec mode: Generic-SPMD.
! Trip count: Set (outer loop).
!$omp target teams distribute
do i=...
  !$omp parallel do private(idx, y)
  do j=...
  end do
end do

! Exec mode: Generic.
! Trip count: Set.
!$omp target teams distribute
do i=...
end do

! Exec mode: SPMD.
! Trip count: Not set.
!$omp target parallel do
do i=...
end do

! Exec mode: Generic.
! Trip count: Not set.
!$omp target
  ...
!$omp end target

For the split target teams distribute + parallel do case, clang produces a Generic kernel which gets promoted to Generic-SPMD by the openmp-opt pass. We can't currently replicate that behavior in flang because our codegen for these constructs results in the introduction of calls to the kmpc_distribute_static_loop family of functions, instead of kmpc_distribute_static_init, which currently prevent promotion of the kernel to Generic-SPMD.

For the time being, instead of relying on the openmp-opt pass, we look at the MLIR representation to find the Generic-SPMD pattern and directly tag the kernel as such during codegen. This is what we were already doing, but incorrectly matching other kinds of kernels as such in the process.


Patch is 22.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133094.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td (+18)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+1-1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+120-67)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+13-3)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+2-2)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+17)
  • (added) mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir (+111)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 690e3df1f685e..9dbe6897a3304 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -222,6 +222,24 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
 
 def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
 
+//===----------------------------------------------------------------------===//
+// target_region_flags enum.
+//===----------------------------------------------------------------------===//
+
+def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
+def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
+def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
+def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
+
+def TargetRegionFlags : OpenMP_BitEnumAttr<
+    "TargetRegionFlags",
+    "target region property flags", [
+      TargetRegionFlagsNone,
+      TargetRegionFlagsGeneric,
+      TargetRegionFlagsSpmd,
+      TargetRegionFlagsTripCount
+    ]>;
+
 //===----------------------------------------------------------------------===//
 // variable_capture_kind enum.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 65095932be627..11530c0fa3620 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1312,7 +1312,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
     ///
     /// \param capturedOp result of a still valid (no modifications made to any
     /// nested operations) previous call to `getInnermostCapturedOmpOp()`.
-    static llvm::omp::OMPTgtExecModeFlags
+    static ::mlir::omp::TargetRegionFlags
     getKernelExecFlags(Operation *capturedOp);
   }] # clausesExtraClassDeclaration;
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 882bc4071482f..5b46cab96dd88 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1908,8 +1908,8 @@ LogicalResult TargetOp::verifyRegions() {
     return emitError("target containing multiple 'omp.teams' nested ops");
 
   // Check that host_eval values are only used in legal ways.
-  llvm::omp::OMPTgtExecModeFlags execFlags =
-      getKernelExecFlags(getInnermostCapturedOmpOp());
+  Operation *capturedOp = getInnermostCapturedOmpOp();
+  TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
   for (Value hostEvalArg :
        cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
     for (Operation *user : hostEvalArg.getUsers()) {
@@ -1924,7 +1924,8 @@ LogicalResult TargetOp::verifyRegions() {
                                 "and 'thread_limit' in 'omp.teams'";
       }
       if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
-        if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
+        if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
+            parallelOp->isAncestor(capturedOp) &&
             hostEvalArg == parallelOp.getNumThreads())
           continue;
 
@@ -1933,15 +1934,16 @@ LogicalResult TargetOp::verifyRegions() {
                   "'omp.parallel' when representing target SPMD";
       }
       if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
-        if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
+        if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
+            loopNestOp.getOperation() == capturedOp &&
             (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
              llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
              llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
           continue;
 
         return emitOpError() << "host_eval argument only legal as loop bounds "
-                                "and steps in 'omp.loop_nest' when "
-                                "representing target SPMD or Generic-SPMD";
+                                "and steps in 'omp.loop_nest' when trip count "
+                                "must be evaluated in the host";
       }
 
       return emitOpError() << "host_eval argument illegal use in '"
@@ -1951,33 +1953,12 @@ LogicalResult TargetOp::verifyRegions() {
   return success();
 }
 
-/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
-/// effects, but don't include a memory write effect.
-static bool siblingAllowedInCapture(Operation *op) {
-  if (!op)
-    return false;
+static Operation *
+findCapturedOmpOp(Operation *rootOp,
+                  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
+  assert(rootOp && "expected valid operation");
 
-  bool isOmpDialect =
-      op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
-      op->getDialect();
-
-  if (isOmpDialect)
-    return op->hasTrait<OpTrait::IsTerminator>();
-
-  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
-    SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
-    memOp.getEffects(effects);
-    return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
-      return isa<MemoryEffects::Write>(effect.getEffect()) &&
-             isa<SideEffects::AutomaticAllocationScopeResource>(
-                 effect.getResource());
-    });
-  }
-  return true;
-}
-
-Operation *TargetOp::getInnermostCapturedOmpOp() {
-  Dialect *ompDialect = (*this)->getDialect();
+  Dialect *ompDialect = rootOp->getDialect();
   Operation *capturedOp = nullptr;
   DominanceInfo domInfo;
 
@@ -1985,8 +1966,8 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
   // ensuring we only enter the region of an operation if it meets the criteria
   // for being captured. We stop the exploration of nested operations as soon as
   // we process a region holding no operations to be captured.
-  walk<WalkOrder::PreOrder>([&](Operation *op) {
-    if (op == *this)
+  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (op == rootOp)
       return WalkResult::advance();
 
     // Ignore operations of other dialects or omp operations with no regions,
@@ -2016,7 +1997,7 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
     // Don't capture this op if it has a not-allowed sibling, and stop recursing
     // into nested operations.
     for (Operation &sibling : op->getParentRegion()->getOps())
-      if (&sibling != op && !siblingAllowedInCapture(&sibling))
+      if (&sibling != op && !siblingAllowedFn(&sibling))
         return WalkResult::interrupt();
 
     // Don't continue capturing nested operations if we reach an omp.loop_nest.
@@ -2029,10 +2010,33 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
   return capturedOp;
 }
 
-llvm::omp::OMPTgtExecModeFlags
-TargetOp::getKernelExecFlags(Operation *capturedOp) {
-  using namespace llvm::omp;
+Operation *TargetOp::getInnermostCapturedOmpOp() {
+  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
 
+  // Only allow OpenMP terminators and non-OpenMP ops that have known memory
+  // effects, but don't include a memory write effect.
+  return findCapturedOmpOp(*this, [&](Operation *sibling) {
+    if (!sibling)
+      return false;
+
+    if (ompDialect == sibling->getDialect())
+      return sibling->hasTrait<OpTrait::IsTerminator>();
+
+    if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
+      SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
+          effects;
+      memOp.getEffects(effects);
+      return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
+        return isa<MemoryEffects::Write>(effect.getEffect()) &&
+               isa<SideEffects::AutomaticAllocationScopeResource>(
+                   effect.getResource());
+      });
+    }
+    return true;
+  });
+}
+
+TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
   // A non-null captured op is only valid if it resides inside of a TargetOp
   // and is the result of calling getInnermostCapturedOmpOp() on it.
   TargetOp targetOp =
@@ -2041,57 +2045,106 @@ TargetOp::getKernelExecFlags(Operation *capturedOp) {
           (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
          "unexpected captured op");
 
-  // Make sure this region is capturing a loop. Otherwise, it's a generic
-  // kernel.
+  // If it's not capturing a loop, it's a default target region.
   if (!isa_and_present<LoopNestOp>(capturedOp))
-    return OMP_TGT_EXEC_MODE_GENERIC;
-
-  SmallVector<LoopWrapperInterface> wrappers;
-  cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
-  assert(!wrappers.empty());
+    return TargetRegionFlags::generic;
 
-  // Ignore optional SIMD leaf construct.
-  auto *innermostWrapper = wrappers.begin();
-  if (isa<SimdOp>(innermostWrapper))
-    innermostWrapper = std::next(innermostWrapper);
+  auto getInnermostWrapper = [](LoopNestOp loopOp, int &numWrappers) {
+    SmallVector<LoopWrapperInterface> wrappers;
+    loopOp.gatherWrappers(wrappers);
+    assert(!wrappers.empty());
 
-  long numWrappers = std::distance(innermostWrapper, wrappers.end());
+    // Ignore optional SIMD leaf construct.
+    auto *wrapper = wrappers.begin();
+    if (isa<SimdOp>(wrapper))
+      wrapper = std::next(wrapper);
 
-  // Detect Generic-SPMD: target-teams-distribute[-simd].
-  if (numWrappers == 1) {
-    if (!isa<DistributeOp>(innermostWrapper))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+    numWrappers = static_cast<int>(std::distance(wrapper, wrappers.end()));
+    return wrapper;
+  };
 
-    Operation *teamsOp = (*innermostWrapper)->getParentOp();
-    if (!isa_and_present<TeamsOp>(teamsOp))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+  int numWrappers;
+  LoopWrapperInterface *innermostWrapper =
+      getInnermostWrapper(cast<LoopNestOp>(capturedOp), numWrappers);
 
-    if (teamsOp->getParentOp() == targetOp.getOperation())
-      return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
-  }
+  if (numWrappers != 1 && numWrappers != 2)
+    return TargetRegionFlags::generic;
 
-  // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
+  // Detect target-teams-distribute-parallel-wsloop[-simd].
   if (numWrappers == 2) {
     if (!isa<WsloopOp>(innermostWrapper))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+      return TargetRegionFlags::generic;
 
     innermostWrapper = std::next(innermostWrapper);
     if (!isa<DistributeOp>(innermostWrapper))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+      return TargetRegionFlags::generic;
 
     Operation *parallelOp = (*innermostWrapper)->getParentOp();
     if (!isa_and_present<ParallelOp>(parallelOp))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+      return TargetRegionFlags::generic;
 
     Operation *teamsOp = parallelOp->getParentOp();
     if (!isa_and_present<TeamsOp>(teamsOp))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+      return TargetRegionFlags::generic;
 
     if (teamsOp->getParentOp() == targetOp.getOperation())
-      return OMP_TGT_EXEC_MODE_SPMD;
+      return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
+  }
+  // Detect target-teams-distribute[-simd].
+  else if (isa<DistributeOp>(innermostWrapper)) {
+    Operation *teamsOp = (*innermostWrapper)->getParentOp();
+    if (!isa_and_present<TeamsOp>(teamsOp))
+      return TargetRegionFlags::generic;
+
+    if (teamsOp->getParentOp() != targetOp.getOperation())
+      return TargetRegionFlags::generic;
+
+    TargetRegionFlags result =
+        TargetRegionFlags::generic | TargetRegionFlags::trip_count;
+
+    // Find single nested parallel-do and add spmd flag (generic-spmd case).
+    // TODO: This shouldn't have to be done here, as it is too easy to break.
+    // The openmp-opt pass should be updated to be able to promote kernels like
+    // this from "Generic" to "Generic-SPMD". However, the use of the
+    // `kmpc_distribute_static_loop` family of functions produced by the
+    // OMPIRBuilder for these kernels prevents that from working.
+    Dialect *ompDialect = targetOp->getDialect();
+    Operation *nestedCapture =
+        findCapturedOmpOp(capturedOp, [&](Operation *sibling) {
+          return sibling && (ompDialect != sibling->getDialect() ||
+                             sibling->hasTrait<OpTrait::IsTerminator>());
+        });
+
+    if (!isa_and_present<LoopNestOp>(nestedCapture))
+      return result;
+
+    int numNestedWrappers;
+    LoopWrapperInterface *nestedWrapper =
+        getInnermostWrapper(cast<LoopNestOp>(nestedCapture), numNestedWrappers);
+
+    if (numNestedWrappers != 1 || !isa<WsloopOp>(nestedWrapper))
+      return result;
+
+    Operation *parallelOp = (*nestedWrapper)->getParentOp();
+    if (!isa_and_present<ParallelOp>(parallelOp))
+      return result;
+
+    if (parallelOp->getParentOp() != capturedOp)
+      return result;
+
+    return result | TargetRegionFlags::spmd;
+  }
+  // Detect target-parallel-wsloop[-simd].
+  else if (isa<WsloopOp>(innermostWrapper)) {
+    Operation *parallelOp = (*innermostWrapper)->getParentOp();
+    if (!isa_and_present<ParallelOp>(parallelOp))
+      return TargetRegionFlags::generic;
+
+    if (parallelOp->getParentOp() == targetOp.getOperation())
+      return TargetRegionFlags::spmd;
   }
 
-  return OMP_TGT_EXEC_MODE_GENERIC;
+  return TargetRegionFlags::generic;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d41489921bd13..4d610d6e2656d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4646,7 +4646,17 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
     combinedMaxThreadsVal = maxThreadsVal;
 
   // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
-  attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp);
+  omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
+  assert(
+      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
+                                               omp::TargetRegionFlags::spmd) &&
+      "invalid kernel flags");
+  attrs.ExecFlags =
+      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
+          ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
+                ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
+                : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
+          : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
   attrs.MinTeams = minTeamsVal;
   attrs.MaxTeams.front() = maxTeamsVal;
   attrs.MinThreads = 1;
@@ -4691,8 +4701,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   if (numThreads)
     attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
 
-  if (targetOp.getKernelExecFlags(capturedOp) !=
-      llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
+  if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
+                              omp::TargetRegionFlags::trip_count)) {
     llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
     attrs.LoopTripCount = nullptr;
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 403128bb2300e..bd0541987339a 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2320,7 +2320,7 @@ func.func @omp_target_host_eval_parallel(%x : i32) {
 // -----
 
 func.func @omp_target_host_eval_loop1(%x : i32) {
-  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
+  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
   omp.target host_eval(%x -> %arg0 : i32) {
     omp.wsloop {
       omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
@@ -2335,7 +2335,7 @@ func.func @omp_target_host_eval_loop1(%x : i32) {
 // -----
 
 func.func @omp_target_host_eval_loop2(%x : i32) {
-  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
+  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
   omp.target host_eval(%x -> %arg0 : i32) {
     omp.teams {
     ^bb0:
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index a5cf789402726..e3d2f8bd01018 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2864,6 +2864,23 @@ func.func @omp_target_host_eval(%x : i32) {
     omp.terminator
   }
 
+  // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
+  // CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest
+  omp.target host_eval(%x -> %arg0 : i32) {
+    %y = arith.constant 2 : i32
+    omp.parallel num_threads(%arg0 : i32) {
+      omp.wsloop {
+        omp.loop_nest (%iv) : i32 = (%y) to (%y) step (%y) {
+          omp.yield
+        }
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+
   // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
   // CHECK: omp.teams {
   // CHECK: omp.distribute {
diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
new file mode 100644
index 0000000000000..8101660e571e4
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
@@ -0,0 +1,111 @@
+// RUN: split-file %s %t
+// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
+// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
+
+//--- host.mlir
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+  llvm.func @main(%arg0 : !llvm.ptr) {
+    %x = llvm.load %arg0 : !llvm.ptr -> i32
+    %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
+    omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) map_entries(%0 -> %ptr : !llvm.ptr) {
+      %x.map = llvm.load %ptr : !llvm.ptr -> i32
+      omp.teams {
+        omp.distribute {
+          omp.loop_nest (%iv1) : i32 = (%lb) to (%ub) step (%step) {
+            omp.parallel {
+              omp.wsloop {
+                omp.loop_nest (%iv2) : i32 = (%x.map) to (%x.map) step (%x.map) {
+                  omp.yield
+                }
+              }
+              omp.terminator
+            }
+            omp.yield
+          }
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// HOST-LABEL: define void @main
+// HOST:         %omp_loop.tripcount = {{.*}}
+// HOST-NEXT:    br label %[[ENTRY:.*]]
+// HOST:       [[ENTRY]]:
+// HOST:         %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// HOST:         %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// HOST-NEXT:    store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// HOST:         %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// HOST-NEXT:    %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// HOST-NEXT:    br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// HOST:       [[OFFLOAD_FAILED]]:
+// HOST:         call void @[[TARGET_OUTLINE:.*]]({{.*}})
+
+// HOST:       define internal void @[[TARGET_OUTLINE]]
+// HOST:         call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+
+// HOST:       define internal void @[[TEAMS_OUTLINE]]
+// HOST:         call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// HOST:       define internal void @[[DISTRIBUTE_OUTLINE]]
+// HOST:         call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 92, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+// HOST:         call void (ptr, i32, ptr, ...) @__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Mar 26, 2025

@llvm/pr-subscribers-flang-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch updates the handling of target regions to set trip counts and kernel execution modes properly, based on clang's behavior. This fixes a race condition on target teams distribute constructs with no parallel do loop inside.

This is how kernels are classified, after changes introduced in this patch:

! Exec mode: SPMD.
! Trip count: Set.
!$omp target teams distribute parallel do
do i=...
end do

! Exec mode: Generic-SPMD.
! Trip count: Set (outer loop).
!$omp target teams distribute
do i=...
  !$omp parallel do private(idx, y)
  do j=...
  end do
end do

! Exec mode: Generic.
! Trip count: Set.
!$omp target teams distribute
do i=...
end do

! Exec mode: SPMD.
! Trip count: Not set.
!$omp target parallel do
do i=...
end do

! Exec mode: Generic.
! Trip count: Not set.
!$omp target
  ...
!$omp end target

For the split target teams distribute + parallel do case, clang produces a Generic kernel which gets promoted to Generic-SPMD by the openmp-opt pass. We can't currently replicate that behavior in flang because our codegen for these constructs results in the introduction of calls to the kmpc_distribute_static_loop family of functions, instead of kmpc_distribute_static_init, which currently prevent promotion of the kernel to Generic-SPMD.

For the time being, instead of relying on the openmp-opt pass, we look at the MLIR representation to find the Generic-SPMD pattern and directly tag the kernel as such during codegen. This is what we were already doing, but incorrectly matching other kinds of kernels as such in the process.


Patch is 22.50 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/133094.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td (+18)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+1-1)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+120-67)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+13-3)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+2-2)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+17)
  • (added) mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir (+111)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
index 690e3df1f685e..9dbe6897a3304 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td
@@ -222,6 +222,24 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
 
 def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
 
+//===----------------------------------------------------------------------===//
+// target_region_flags enum.
+//===----------------------------------------------------------------------===//
+
+def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
+def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
+def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
+def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
+
+def TargetRegionFlags : OpenMP_BitEnumAttr<
+    "TargetRegionFlags",
+    "target region property flags", [
+      TargetRegionFlagsNone,
+      TargetRegionFlagsGeneric,
+      TargetRegionFlagsSpmd,
+      TargetRegionFlagsTripCount
+    ]>;
+
 //===----------------------------------------------------------------------===//
 // variable_capture_kind enum.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 65095932be627..11530c0fa3620 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1312,7 +1312,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
     ///
     /// \param capturedOp result of a still valid (no modifications made to any
     /// nested operations) previous call to `getInnermostCapturedOmpOp()`.
-    static llvm::omp::OMPTgtExecModeFlags
+    static ::mlir::omp::TargetRegionFlags
     getKernelExecFlags(Operation *capturedOp);
   }] # clausesExtraClassDeclaration;
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 882bc4071482f..5b46cab96dd88 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1908,8 +1908,8 @@ LogicalResult TargetOp::verifyRegions() {
     return emitError("target containing multiple 'omp.teams' nested ops");
 
   // Check that host_eval values are only used in legal ways.
-  llvm::omp::OMPTgtExecModeFlags execFlags =
-      getKernelExecFlags(getInnermostCapturedOmpOp());
+  Operation *capturedOp = getInnermostCapturedOmpOp();
+  TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
   for (Value hostEvalArg :
        cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
     for (Operation *user : hostEvalArg.getUsers()) {
@@ -1924,7 +1924,8 @@ LogicalResult TargetOp::verifyRegions() {
                                 "and 'thread_limit' in 'omp.teams'";
       }
       if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
-        if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
+        if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
+            parallelOp->isAncestor(capturedOp) &&
             hostEvalArg == parallelOp.getNumThreads())
           continue;
 
@@ -1933,15 +1934,16 @@ LogicalResult TargetOp::verifyRegions() {
                   "'omp.parallel' when representing target SPMD";
       }
       if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
-        if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
+        if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
+            loopNestOp.getOperation() == capturedOp &&
             (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
              llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
              llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
           continue;
 
         return emitOpError() << "host_eval argument only legal as loop bounds "
-                                "and steps in 'omp.loop_nest' when "
-                                "representing target SPMD or Generic-SPMD";
+                                "and steps in 'omp.loop_nest' when trip count "
+                                "must be evaluated in the host";
       }
 
       return emitOpError() << "host_eval argument illegal use in '"
@@ -1951,33 +1953,12 @@ LogicalResult TargetOp::verifyRegions() {
   return success();
 }
 
-/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
-/// effects, but don't include a memory write effect.
-static bool siblingAllowedInCapture(Operation *op) {
-  if (!op)
-    return false;
+static Operation *
+findCapturedOmpOp(Operation *rootOp,
+                  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
+  assert(rootOp && "expected valid operation");
 
-  bool isOmpDialect =
-      op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
-      op->getDialect();
-
-  if (isOmpDialect)
-    return op->hasTrait<OpTrait::IsTerminator>();
-
-  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
-    SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
-    memOp.getEffects(effects);
-    return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
-      return isa<MemoryEffects::Write>(effect.getEffect()) &&
-             isa<SideEffects::AutomaticAllocationScopeResource>(
-                 effect.getResource());
-    });
-  }
-  return true;
-}
-
-Operation *TargetOp::getInnermostCapturedOmpOp() {
-  Dialect *ompDialect = (*this)->getDialect();
+  Dialect *ompDialect = rootOp->getDialect();
   Operation *capturedOp = nullptr;
   DominanceInfo domInfo;
 
@@ -1985,8 +1966,8 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
   // ensuring we only enter the region of an operation if it meets the criteria
   // for being captured. We stop the exploration of nested operations as soon as
   // we process a region holding no operations to be captured.
-  walk<WalkOrder::PreOrder>([&](Operation *op) {
-    if (op == *this)
+  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (op == rootOp)
       return WalkResult::advance();
 
     // Ignore operations of other dialects or omp operations with no regions,
@@ -2016,7 +1997,7 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
     // Don't capture this op if it has a not-allowed sibling, and stop recursing
     // into nested operations.
     for (Operation &sibling : op->getParentRegion()->getOps())
-      if (&sibling != op && !siblingAllowedInCapture(&sibling))
+      if (&sibling != op && !siblingAllowedFn(&sibling))
         return WalkResult::interrupt();
 
     // Don't continue capturing nested operations if we reach an omp.loop_nest.
@@ -2029,10 +2010,33 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
   return capturedOp;
 }
 
-llvm::omp::OMPTgtExecModeFlags
-TargetOp::getKernelExecFlags(Operation *capturedOp) {
-  using namespace llvm::omp;
+Operation *TargetOp::getInnermostCapturedOmpOp() {
+  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
 
+  // Only allow OpenMP terminators and non-OpenMP ops that have known memory
+  // effects, but don't include a memory write effect.
+  return findCapturedOmpOp(*this, [&](Operation *sibling) {
+    if (!sibling)
+      return false;
+
+    if (ompDialect == sibling->getDialect())
+      return sibling->hasTrait<OpTrait::IsTerminator>();
+
+    if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
+      SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
+          effects;
+      memOp.getEffects(effects);
+      return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
+        return isa<MemoryEffects::Write>(effect.getEffect()) &&
+               isa<SideEffects::AutomaticAllocationScopeResource>(
+                   effect.getResource());
+      });
+    }
+    return true;
+  });
+}
+
+TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
   // A non-null captured op is only valid if it resides inside of a TargetOp
   // and is the result of calling getInnermostCapturedOmpOp() on it.
   TargetOp targetOp =
@@ -2041,57 +2045,106 @@ TargetOp::getKernelExecFlags(Operation *capturedOp) {
           (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
          "unexpected captured op");
 
-  // Make sure this region is capturing a loop. Otherwise, it's a generic
-  // kernel.
+  // If it's not capturing a loop, it's a default target region.
   if (!isa_and_present<LoopNestOp>(capturedOp))
-    return OMP_TGT_EXEC_MODE_GENERIC;
-
-  SmallVector<LoopWrapperInterface> wrappers;
-  cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
-  assert(!wrappers.empty());
+    return TargetRegionFlags::generic;
 
-  // Ignore optional SIMD leaf construct.
-  auto *innermostWrapper = wrappers.begin();
-  if (isa<SimdOp>(innermostWrapper))
-    innermostWrapper = std::next(innermostWrapper);
+  auto getInnermostWrapper = [](LoopNestOp loopOp, int &numWrappers) {
+    SmallVector<LoopWrapperInterface> wrappers;
+    loopOp.gatherWrappers(wrappers);
+    assert(!wrappers.empty());
 
-  long numWrappers = std::distance(innermostWrapper, wrappers.end());
+    // Ignore optional SIMD leaf construct.
+    auto *wrapper = wrappers.begin();
+    if (isa<SimdOp>(wrapper))
+      wrapper = std::next(wrapper);
 
-  // Detect Generic-SPMD: target-teams-distribute[-simd].
-  if (numWrappers == 1) {
-    if (!isa<DistributeOp>(innermostWrapper))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+    numWrappers = static_cast<int>(std::distance(wrapper, wrappers.end()));
+    return wrapper;
+  };
 
-    Operation *teamsOp = (*innermostWrapper)->getParentOp();
-    if (!isa_and_present<TeamsOp>(teamsOp))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+  int numWrappers;
+  LoopWrapperInterface *innermostWrapper =
+      getInnermostWrapper(cast<LoopNestOp>(capturedOp), numWrappers);
 
-    if (teamsOp->getParentOp() == targetOp.getOperation())
-      return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
-  }
+  if (numWrappers != 1 && numWrappers != 2)
+    return TargetRegionFlags::generic;
 
-  // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
+  // Detect target-teams-distribute-parallel-wsloop[-simd].
   if (numWrappers == 2) {
     if (!isa<WsloopOp>(innermostWrapper))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+      return TargetRegionFlags::generic;
 
     innermostWrapper = std::next(innermostWrapper);
     if (!isa<DistributeOp>(innermostWrapper))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+      return TargetRegionFlags::generic;
 
     Operation *parallelOp = (*innermostWrapper)->getParentOp();
     if (!isa_and_present<ParallelOp>(parallelOp))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+      return TargetRegionFlags::generic;
 
     Operation *teamsOp = parallelOp->getParentOp();
     if (!isa_and_present<TeamsOp>(teamsOp))
-      return OMP_TGT_EXEC_MODE_GENERIC;
+      return TargetRegionFlags::generic;
 
     if (teamsOp->getParentOp() == targetOp.getOperation())
-      return OMP_TGT_EXEC_MODE_SPMD;
+      return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
+  }
+  // Detect target-teams-distribute[-simd].
+  else if (isa<DistributeOp>(innermostWrapper)) {
+    Operation *teamsOp = (*innermostWrapper)->getParentOp();
+    if (!isa_and_present<TeamsOp>(teamsOp))
+      return TargetRegionFlags::generic;
+
+    if (teamsOp->getParentOp() != targetOp.getOperation())
+      return TargetRegionFlags::generic;
+
+    TargetRegionFlags result =
+        TargetRegionFlags::generic | TargetRegionFlags::trip_count;
+
+    // Find single nested parallel-do and add spmd flag (generic-spmd case).
+    // TODO: This shouldn't have to be done here, as it is too easy to break.
+    // The openmp-opt pass should be updated to be able to promote kernels like
+    // this from "Generic" to "Generic-SPMD". However, the use of the
+    // `kmpc_distribute_static_loop` family of functions produced by the
+    // OMPIRBuilder for these kernels prevents that from working.
+    Dialect *ompDialect = targetOp->getDialect();
+    Operation *nestedCapture =
+        findCapturedOmpOp(capturedOp, [&](Operation *sibling) {
+          return sibling && (ompDialect != sibling->getDialect() ||
+                             sibling->hasTrait<OpTrait::IsTerminator>());
+        });
+
+    if (!isa_and_present<LoopNestOp>(nestedCapture))
+      return result;
+
+    int numNestedWrappers;
+    LoopWrapperInterface *nestedWrapper =
+        getInnermostWrapper(cast<LoopNestOp>(nestedCapture), numNestedWrappers);
+
+    if (numNestedWrappers != 1 || !isa<WsloopOp>(nestedWrapper))
+      return result;
+
+    Operation *parallelOp = (*nestedWrapper)->getParentOp();
+    if (!isa_and_present<ParallelOp>(parallelOp))
+      return result;
+
+    if (parallelOp->getParentOp() != capturedOp)
+      return result;
+
+    return result | TargetRegionFlags::spmd;
+  }
+  // Detect target-parallel-wsloop[-simd].
+  else if (isa<WsloopOp>(innermostWrapper)) {
+    Operation *parallelOp = (*innermostWrapper)->getParentOp();
+    if (!isa_and_present<ParallelOp>(parallelOp))
+      return TargetRegionFlags::generic;
+
+    if (parallelOp->getParentOp() == targetOp.getOperation())
+      return TargetRegionFlags::spmd;
   }
 
-  return OMP_TGT_EXEC_MODE_GENERIC;
+  return TargetRegionFlags::generic;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d41489921bd13..4d610d6e2656d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4646,7 +4646,17 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
     combinedMaxThreadsVal = maxThreadsVal;
 
   // Update kernel bounds structure for the `OpenMPIRBuilder` to use.
-  attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp);
+  omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
+  assert(
+      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
+                                               omp::TargetRegionFlags::spmd) &&
+      "invalid kernel flags");
+  attrs.ExecFlags =
+      omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
+          ? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
+                ? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
+                : llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
+          : llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
   attrs.MinTeams = minTeamsVal;
   attrs.MaxTeams.front() = maxTeamsVal;
   attrs.MinThreads = 1;
@@ -4691,8 +4701,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   if (numThreads)
     attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
 
-  if (targetOp.getKernelExecFlags(capturedOp) !=
-      llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
+  if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
+                              omp::TargetRegionFlags::trip_count)) {
     llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
     attrs.LoopTripCount = nullptr;
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 403128bb2300e..bd0541987339a 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2320,7 +2320,7 @@ func.func @omp_target_host_eval_parallel(%x : i32) {
 // -----
 
 func.func @omp_target_host_eval_loop1(%x : i32) {
-  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
+  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
   omp.target host_eval(%x -> %arg0 : i32) {
     omp.wsloop {
       omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
@@ -2335,7 +2335,7 @@ func.func @omp_target_host_eval_loop1(%x : i32) {
 // -----
 
 func.func @omp_target_host_eval_loop2(%x : i32) {
-  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
+  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
   omp.target host_eval(%x -> %arg0 : i32) {
     omp.teams {
     ^bb0:
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index a5cf789402726..e3d2f8bd01018 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2864,6 +2864,23 @@ func.func @omp_target_host_eval(%x : i32) {
     omp.terminator
   }
 
+  // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
+  // CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest
+  omp.target host_eval(%x -> %arg0 : i32) {
+    %y = arith.constant 2 : i32
+    omp.parallel num_threads(%arg0 : i32) {
+      omp.wsloop {
+        omp.loop_nest (%iv) : i32 = (%y) to (%y) step (%y) {
+          omp.yield
+        }
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+
   // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
   // CHECK: omp.teams {
   // CHECK: omp.distribute {
diff --git a/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
new file mode 100644
index 0000000000000..8101660e571e4
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/openmp-target-generic-spmd.mlir
@@ -0,0 +1,111 @@
+// RUN: split-file %s %t
+// RUN: mlir-translate -mlir-to-llvmir %t/host.mlir | FileCheck %s --check-prefix=HOST
+// RUN: mlir-translate -mlir-to-llvmir %t/device.mlir | FileCheck %s --check-prefix=DEVICE
+
+//--- host.mlir
+
+module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
+  llvm.func @main(%arg0 : !llvm.ptr) {
+    %x = llvm.load %arg0 : !llvm.ptr -> i32
+    %0 = omp.map.info var_ptr(%arg0 : !llvm.ptr, i32) map_clauses(to) capture(ByCopy) -> !llvm.ptr
+    omp.target host_eval(%x -> %lb, %x -> %ub, %x -> %step : i32, i32, i32) map_entries(%0 -> %ptr : !llvm.ptr) {
+      %x.map = llvm.load %ptr : !llvm.ptr -> i32
+      omp.teams {
+        omp.distribute {
+          omp.loop_nest (%iv1) : i32 = (%lb) to (%ub) step (%step) {
+            omp.parallel {
+              omp.wsloop {
+                omp.loop_nest (%iv2) : i32 = (%x.map) to (%x.map) step (%x.map) {
+                  omp.yield
+                }
+              }
+              omp.terminator
+            }
+            omp.yield
+          }
+        }
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
+}
+
+// HOST-LABEL: define void @main
+// HOST:         %omp_loop.tripcount = {{.*}}
+// HOST-NEXT:    br label %[[ENTRY:.*]]
+// HOST:       [[ENTRY]]:
+// HOST:         %[[TRIPCOUNT:.*]] = zext i32 %omp_loop.tripcount to i64
+// HOST:         %[[TRIPCOUNT_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KARGS:.*]], i32 0, i32 8
+// HOST-NEXT:    store i64 %[[TRIPCOUNT]], ptr %[[TRIPCOUNT_KARG]]
+// HOST:         %[[RESULT:.*]] = call i32 @__tgt_target_kernel({{.*}}, ptr %[[KARGS]])
+// HOST-NEXT:    %[[CMP:.*]] = icmp ne i32 %[[RESULT]], 0
+// HOST-NEXT:    br i1 %[[CMP]], label %[[OFFLOAD_FAILED:.*]], label %{{.*}}
+// HOST:       [[OFFLOAD_FAILED]]:
+// HOST:         call void @[[TARGET_OUTLINE:.*]]({{.*}})
+
+// HOST:       define internal void @[[TARGET_OUTLINE]]
+// HOST:         call void{{.*}}@__kmpc_fork_teams({{.*}}, ptr @[[TEAMS_OUTLINE:.*]], {{.*}})
+
+// HOST:       define internal void @[[TEAMS_OUTLINE]]
+// HOST:         call void @[[DISTRIBUTE_OUTLINE:.*]]({{.*}})
+
+// HOST:       define internal void @[[DISTRIBUTE_OUTLINE]]
+// HOST:         call void @__kmpc_for_static_init{{.*}}(ptr {{.*}}, i32 {{.*}}, i32 92, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, ptr {{.*}}, i32 {{.*}}, i32 {{.*}})
+// HOST:         call void (ptr, i32, ptr, ...) @__kmpc_fork_call({{.*}}, ptr @[[PARALLEL_OUTLINE:.*]], {{.*...
[truncated]

Comment on lines +4650 to +4659
assert(
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
omp::TargetRegionFlags::spmd) &&
"invalid kernel flags");
attrs.ExecFlags =
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
: llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if a trip_count hits here (or why isn' it possible)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The trip_count flag does not impact the ExecFlags, so it could be present but is ignored here. It is checked by initTargetRuntimeAttrs to decide whether to evaluate and pass the trip count from the host in the kernel arguments structure to the kernel launch call.

@skatrak skatrak force-pushed the fix-distribute-device branch from f81e281 to dc1ac96 Compare March 28, 2025 16:17
SmallVector<LoopWrapperInterface> wrappers;
cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
assert(!wrappers.empty());
auto getInnermostWrapper = [](LoopNestOp loopOp, int &numWrappers) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not need to be a lambda function. It is better as a static function imo.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just updated this to actually remove the function altogether. I created it when I was expecting to use it multiple times, but in the end it only was necessary in one spot.

Copy link
Member

@Meinersbur Meinersbur left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@jsjodin jsjodin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

This patch updates the handling of target regions to set trip counts and kernel
execution modes properly, based on clang's behavior. This fixes a race
condition on `target teams distribute` constructs with no `parallel do` loop
inside.

This is how kernels are classified, after changes introduced in this patch:

```f90
! Exec mode: SPMD.
! Trip count: Set.
!$omp target teams distribute parallel do
do i=...
end do

! Exec mode: Generic-SPMD.
! Trip count: Set (outer loop).
!$omp target teams distribute
do i=...
  !$omp parallel do private(idx, y)
  do j=...
  end do
end do

! Exec mode: Generic-SPMD.
! Trip count: Set (outer loop).
!$omp target teams distribute
do i=...
  !$omp parallel
    ...
  !$omp end parallel
end do

! Exec mode: Generic.
! Trip count: Set.
!$omp target teams distribute
do i=...
end do

! Exec mode: SPMD.
! Trip count: Not set.
!$omp target parallel do
do i=...
end do

! Exec mode: Generic.
! Trip count: Not set.
!$omp target
  ...
!$omp end target
```

For the split `target teams distribute + parallel do` case, clang produces a
Generic kernel which gets promoted to Generic-SPMD by the openmp-opt pass. We
can't currently replicate that behavior in flang because our codegen for these
constructs results in the introduction of calls to the
`kmpc_distribute_static_loop` family of functions, instead of
`kmpc_distribute_static_init`, which currently prevent promotion of the kernel
to Generic-SPMD.

For the time being, instead of relying on the openmp-opt pass, we look at the
MLIR representation to find the Generic-SPMD pattern and directly tag the
kernel as such during codegen. This is what we were already doing, but
incorrectly matching other kinds of kernels as such in the process.
@skatrak skatrak force-pushed the fix-distribute-device branch from 5653ae0 to 4033b12 Compare April 3, 2025 14:18
@skatrak skatrak merged commit f59b5b8 into llvm:main Apr 3, 2025
8 of 11 checks passed
@skatrak skatrak deleted the fix-distribute-device branch April 3, 2025 14:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants