Skip to content

Commit f81e281

Browse files
committed
Use Generic-SPMD for target teams distribute + parallel kernels
1 parent 48cf32c commit f81e281

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

Diff for: mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

+10 -18
Original file line number | Diff line number | Diff line change
@@ -2099,10 +2099,9 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
20992099
if (teamsOp->getParentOp() != targetOp.getOperation())
21002100
return TargetRegionFlags::generic;
21012101

2102-
TargetRegionFlags result =
2103-
TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2104-
2105-
// Find single nested parallel-do and add spmd flag (generic-spmd case).
2102+
// Find single immediately nested captured omp.parallel and add spmd flag
2103+
// (generic-spmd case).
2104+
//
21062105
// TODO: This shouldn't have to be done here, as it is too easy to break.
21072106
// The openmp-opt pass should be updated to be able to promote kernels like
21082107
// this from "Generic" to "Generic-SPMD". However, the use of the
@@ -2115,24 +2114,17 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21152114
sibling->hasTrait<OpTrait::IsTerminator>());
21162115
});
21172116

2118-
if (!isa_and_present<LoopNestOp>(nestedCapture))
2119-
return result;
2120-
2121-
int numNestedWrappers;
2122-
LoopWrapperInterface *nestedWrapper =
2123-
getInnermostWrapper(cast<LoopNestOp>(nestedCapture), numNestedWrappers);
2124-
2125-
if (numNestedWrappers != 1 || !isa<WsloopOp>(nestedWrapper))
2126-
return result;
2117+
TargetRegionFlags result =
2118+
TargetRegionFlags::generic | TargetRegionFlags::trip_count;
21272119

2128-
Operation *parallelOp = (*nestedWrapper)->getParentOp();
2129-
if (!isa_and_present<ParallelOp>(parallelOp))
2120+
if (!nestedCapture)
21302121
return result;
21312122

2132-
if (parallelOp->getParentOp() != capturedOp)
2133-
return result;
2123+
while (nestedCapture->getParentOp() != capturedOp)
2124+
nestedCapture = nestedCapture->getParentOp();
21342125

2135-
return result | TargetRegionFlags::spmd;
2126+
return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2127+
: result;
21362128
}
21372129
// Detect target-parallel-wsloop[-simd].
21382130
else if (isa<WsloopOp>(innermostWrapper)) {

0 commit comments

Comments
 (0)