diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
index 51adb918a7fb..8a746524c07e 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/ScheduleAllocation.cpp
@@ -199,6 +199,26 @@ static LivenessIntervalList computeExecutionRegionLivenessIntervals(
   // Compute ranges for all values independently (ignoring aliasing).
   for (auto &op : *streamBlock) {
     int start = opOrdering[&op];
+    if (auto concurrentOp = dyn_cast<IREE::Stream::AsyncConcurrentOp>(op)) {
+      // HACK: allocation planning currently only works on the top-level
+      // execute op but sometimes we need to allocate locals inside of
+      // concurrent regions. The real fix here is to make allocation planning
+      // handle arbitrary nesting but for now we do a quick walk through the
+      // regions to see if there are any locals that need to be marked live for
+      // the duration of the region.
+      concurrentOp.walk([&](Operation *op) {
+        for (auto value : op->getResults()) {
+          if (!value.getType().isa<IREE::Stream::ResourceType>()) continue;
+          if (!value.use_empty()) continue;
+          LivenessInterval interval;
+          interval.start = start;
+          interval.end = start;
+          interval.value = value;
+          interval.ordinal = -1;
+          valueIntervals[value] = interval;
+        }
+      });
+    }
     for (auto value : op.getResults()) {
       if (!value.getType().isa<IREE::Stream::ResourceType>()) continue;
       LivenessInterval interval;
diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir
index 86193eb887ac..ca856373b622 100644
--- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir
+++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/schedule_allocation.mlir
@@ -471,6 +471,45 @@ func.func @applyAsyncDispatchOp(%operand: !stream.resource<external>, %size: in
 
 // -----
 
+// Tests that unused dispatch results nested in concurrent regions are still
+// allocated memory.
+
+// CHECK-LABEL: @applyAsyncDispatchUnusedOp
+// CHECK-SAME: (%[[OPERAND:.+]]: !stream.resource<external>, %[[SIZE:.+]]: index, %[[OFFSET:.+]]: index, %[[END:.+]]: index, %[[LENGTH:.+]]: index)
+func.func @applyAsyncDispatchUnusedOp(%operand: !stream.resource<external>, %size: index, %offset: index, %end: index, %length: index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c4 = arith.constant 4 : index
+  // CHECK: %[[PACK:.+]]:2 = stream.resource.pack
+  // CHECK: %[[ALLOCA:.+]], %[[ALLOCA_TIMEPOINT:.+]] = stream.resource.alloca uninitialized : !stream.resource<transient>{%[[PACK]]#0}
+  // CHECK: %[[TIMEPOINT:.+]] = stream.cmd.execute
+  // CHECK-SAME: await(%[[ALLOCA_TIMEPOINT]])
+  // CHECK-SAME: with(%[[OPERAND]] as %[[OPERAND_CAPTURE:.+]]: !stream.resource<external>{%[[SIZE]]},
+  // CHECK-SAME:      %[[ALLOCA]] as %[[ALLOCA_CAPTURE:.+]]: !stream.resource<transient>{%[[PACK]]#0})
+  %result, %result_timepoint = stream.async.execute with(%operand as %capture: !stream.resource<external>{%size}) -> (%operand as !stream.resource<external>{%size}) {
+    // CHECK: stream.cmd.concurrent
+    %concurrent = stream.async.concurrent with(%capture as %concurrent_capture: !stream.resource<external>{%size}) -> (%capture as !stream.resource<external>{%size}) {
+      // CHECK-NEXT: stream.cmd.dispatch @executable::@dispatch[%c1, %c1, %c1](%c4 : index) {
+      // CHECK-NEXT:   rw %[[OPERAND_CAPTURE]][%[[OFFSET]] for %[[LENGTH]]] : !stream.resource<external>{%[[SIZE]]},
+      // CHECK-NEXT:   wo %[[ALLOCA_CAPTURE]][%[[PACK]]#1 for %[[SIZE]]] : !stream.resource<transient>{%[[PACK]]#0}
+      // CHECK-NEXT: }
+      %0:2 = stream.async.dispatch @executable::@dispatch[%c1, %c1, %c1](%concurrent_capture[%offset to %end for %length], %c4) : (!stream.resource<external>{%size}, index) -> (%concurrent_capture{%size}, !stream.resource<transient>{%size})
+      // NOTE: %0#1 is unused.
+      stream.yield %0#0 : !stream.resource<external>{%size}
+    }
+    stream.yield %concurrent : !stream.resource<external>{%size}
+  } => !stream.timepoint
+  // CHECK: %[[DEALLOCA:.+]] = stream.resource.dealloca await(%[[TIMEPOINT]]) => %[[ALLOCA]]
+  // CHECK: %[[JOIN:.+]] = stream.timepoint.join max(%[[DEALLOCA]], %[[TIMEPOINT]])
+  // CHECK: util.optimization_barrier %[[JOIN]]
+  util.optimization_barrier %result_timepoint : !stream.timepoint
+  // CHECK: util.optimization_barrier %[[OPERAND]]
+  util.optimization_barrier %result : !stream.resource<external>
+  return
+}
+
+// -----
+
 // CHECK: stream.cmd.func private @asyncExtern(%arg0[%arg1 for %arg2]: !stream.resource<*>, %arg3: index, %arg4[%arg5 for %arg6]: !stream.resource<*>)
 stream.async.func private @asyncExtern(%arg0: !stream.resource<*>, %arg1: index) -> (%arg0, !stream.resource<*>)