Skip to content

Commit

Permalink
Improving iree-stream-emplace-allocations pass. (iree-org#13732)
Browse files Browse the repository at this point in the history
This allows for more aggressive placement into exports with storage by
simplifying the IR so analysis succeeds more often.

Progress on iree-org#13545.
  • Loading branch information
benvanik authored and nhasabni committed Aug 24, 2023
1 parent 769bfd2 commit 5dd4065
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,6 @@ namespace iree_compiler {

namespace {

// Returns the length in bytes of |storage|, which may be either a buffer or a
// buffer view. Buffer views are first unwrapped to their backing buffer so
// that the length query is always issued against a buffer value.
static Value getStorageSize(Value storage, OpBuilder &builder) {
  Value buffer = storage;
  if (buffer.getType().isa<IREE::HAL::BufferViewType>()) {
    // Buffer view: extract the underlying buffer before querying its length.
    auto bufferType = builder.getType<IREE::HAL::BufferType>();
    buffer = builder.create<IREE::HAL::BufferViewBufferOp>(storage.getLoc(),
                                                           bufferType, storage);
  }
  auto lengthOp = builder.create<IREE::HAL::BufferLengthOp>(
      buffer.getLoc(), builder.getIndexType(), buffer);
  return lengthOp.getResult();
}

// %1 = hal.tensor.import %0 : !hal.buffer_view -> tensor<4xf32>
// ->
// %1 = stream.tensor.import %0 : !hal.buffer_view ->
Expand Down Expand Up @@ -171,7 +159,10 @@ struct ConvertTensorExportOp
if (adaptor.getTargetStorage()) {
// Query the target storage buffer length; we will only populate up to
// what is required for the output.
auto storageSize = getStorageSize(op.getTargetStorage(), rewriter);
auto storageSize = rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
op.getLoc(), rewriter.getIndexType(),
TypeAttr::get(op.getSource().getType()), adaptor.getSourceDims(),
/*affinity=*/nullptr);

// Import the target storage as a resource that we can use as an update
// target. We overwrite the contents and just cast the storage to the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ func.func @exportBufferView(%tensor: tensor<?x?x4xf32>, %dim0: index, %dim1: ind
// CHECK-LABEL: @exportBufferViewInPlace
// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[STORAGE:.+]]: !hal.buffer)
func.func @exportBufferViewInPlace(%tensor: tensor<?x?x4xf32>, %dim0: index, %dim1: index, %storage: !hal.buffer) -> !hal.buffer_view {
// CHECK: %[[STORAGE_LENGTH:.+]] = hal.buffer.length<%[[STORAGE]]
// CHECK: %[[STORAGE_SIZE:.+]] = stream.tensor.sizeof tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} : index
// CHECK-NEXT: %[[STORAGE_IMPORT:.+]] = stream.tensor.import %[[STORAGE]]
// CHECK-SAME: : !hal.buffer -> tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[STORAGE_LENGTH]]}
// CHECK-SAME: : !hal.buffer -> tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[STORAGE_SIZE]]}
// CHECK-NEXT: %[[STORAGE_UPDATE:.+]] = stream.async.update %[[TENSOR]], %[[STORAGE_IMPORT]][%c0 to %[[SIZE]]]
// CHECK-SAME: : !stream.resource<*>{%[[SIZE]]} -> %[[STORAGE_IMPORT]] as !stream.resource<external>{%[[STORAGE_LENGTH]]}
// CHECK-SAME: : !stream.resource<*>{%[[SIZE]]} -> %[[STORAGE_IMPORT]] as !stream.resource<external>{%[[STORAGE_SIZE]]}
// CHECK-NEXT: %[[STORAGE_RESULT:.+]] = stream.tensor.export %[[STORAGE_UPDATE]] :
// CHECK-SAME: tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[STORAGE_LENGTH]]}
// CHECK-SAME: tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[STORAGE_SIZE]]}
// CHECK-SAME: -> !hal.buffer_view
%0 = hal.tensor.export %tensor into(%storage : !hal.buffer) : tensor<?x?x4xf32>{%dim0, %dim1} -> !hal.buffer_view
// CHECK: return %[[STORAGE_RESULT]]
Expand All @@ -92,10 +92,9 @@ func.func @exportBufferViewInPlace(%tensor: tensor<?x?x4xf32>, %dim0: index, %di
// Verifies in-place export of a dynamically shaped tensor into caller-provided
// storage of type !hal.buffer_view: the storage operand is expected to be
// imported with stream.tensor.import as an external stream resource.
// CHECK-LABEL: @exportBufferViewInPlaceToView
// CHECK-SAME: (%[[TENSOR:.+]]: !stream.resource<*>, %[[SIZE:.+]]: index, %[[DIM0:.+]]: index, %[[DIM1:.+]]: index, %[[STORAGE:.+]]: !hal.buffer_view)
func.func @exportBufferViewInPlaceToView(%tensor: tensor<?x?x4xf32>, %dim0: index, %dim1: index, %storage: !hal.buffer_view) -> !hal.buffer_view {
// CHECK: %[[STORAGE_BUFFER:.+]] = hal.buffer_view.buffer<%[[STORAGE]]
// CHECK: %[[STORAGE_LENGTH:.+]] = hal.buffer.length<%[[STORAGE_BUFFER]]
// CHECK: %[[STORAGE_SIZE:.+]] = stream.tensor.sizeof tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} : index
// CHECK-NEXT: %[[STORAGE_IMPORT:.+]] = stream.tensor.import %[[STORAGE]]
// CHECK-SAME: : !hal.buffer_view -> tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[STORAGE_LENGTH]]}
// CHECK-SAME: : !hal.buffer_view -> tensor<?x?x4xf32>{%[[DIM0]], %[[DIM1]]} in !stream.resource<external>{%[[STORAGE_SIZE]]}
%0 = hal.tensor.export %tensor into(%storage : !hal.buffer_view) : tensor<?x?x4xf32>{%dim0, %dim1} -> !hal.buffer_view
return %0 : !hal.buffer_view
}
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1452,7 +1452,7 @@ static void printCollectiveParam(OpAsmPrinter &p, Operation *op,
assert(keyword && "collective op must have a param keyword");
p << keyword << "(";
p.printOperand(paramValue);
p << ")";
p << ") ";
}
}

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Dialect/Stream/IR/StreamOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1777,7 +1777,7 @@ def Stream_AsyncCollectiveOp : Stream_Op<"async.collective", [
`` $op `` `[` $element_count `]`
(`on` `(` $affinity^ `)`)?
`channel` `(` $channel `)`
custom<CollectiveParam>(ref($op), $param)
custom<CollectiveParam>(ref($op), $param) ``
$source `[` $source_offset `to` $source_end `for` $source_length `]` `,`
$target `[` $target_offset `to` $target_end `for` $target_length `]` `:`
type($source) `` `{` $source_size `}` `->`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,25 @@ namespace {
// Emplacement
//===----------------------------------------------------------------------===//

// Redirects every use of |oldValue| to |newValue|. Both values must be stream
// resources. When their types match the uses are rewired directly; otherwise
// a stream.async.transfer producing |oldValue|'s type is inserted immediately
// after |newValue| and the uses are rewired to the transfer result instead.
static void replaceUsesAndTransfer(Value oldValue, Value newValue) {
  assert(oldValue.getType().isa<IREE::Stream::ResourceType>());
  assert(newValue.getType().isa<IREE::Stream::ResourceType>());

  if (oldValue.getType() != newValue.getType()) {
    // Types differ: bridge them with a transfer placed right after the new
    // value's definition.
    OpBuilder builder(newValue.getContext());
    builder.setInsertionPointAfterValue(newValue);
    auto loc = newValue.getLoc();
    Value resourceSize = IREE::Util::SizeAwareTypeInterface::queryValueSize(
        loc, newValue, builder);
    // No affinities are assigned to the transfer; both are left null.
    IREE::Stream::AffinityAttr sourceAffinity;
    IREE::Stream::AffinityAttr resultAffinity;
    // The transfer does not change the size, so the same size value is used
    // for both the source and the result.
    Value transferred = builder.create<IREE::Stream::AsyncTransferOp>(
        loc, oldValue.getType(), newValue, resourceSize, resourceSize,
        sourceAffinity, resultAffinity);
    oldValue.replaceAllUsesWith(transferred);
    return;
  }

  // Identical types: no conversion needed.
  oldValue.replaceAllUsesWith(newValue);
}

static bool tryEmplaceDispatchOp(IREE::Stream::AsyncDispatchOp dispatchOp) {
bool didChange = false;
for (auto [resultIndex, result] : llvm::enumerate(dispatchOp.getResults())) {
Expand Down Expand Up @@ -93,7 +112,7 @@ static bool tryEmplaceDispatchOp(IREE::Stream::AsyncDispatchOp dispatchOp) {
dispatchOp.getResultSizesMutable().assign(resultSizes);

// Replace users with the result of the dispatch op.
targetResult.replaceAllUsesWith(result);
replaceUsesAndTransfer(targetResult, result);
userOp->erase();

didChange = true;
Expand Down
30 changes: 25 additions & 5 deletions compiler/src/iree/compiler/Dialect/Util/IR/UtilTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ LogicalResult PtrType::verify(function_ref<InFlightDiagnostic()> emitError,
}

//===----------------------------------------------------------------------===//
// !util.ptr<T>
// !util.object
//===----------------------------------------------------------------------===//

// static
Expand Down Expand Up @@ -193,16 +193,36 @@ bool isValueUsableForOp(Value value, Operation *op) {

// Tries to ensure |value| can be used by |consumerOp|, moving the op that
// produces |value| when that is required and possible.
//
// Returns true when |value| is usable by |consumerOp| on return:
//  - |value| has no defining op (block argument),
//  - the producer is in the same block and already precedes |consumerOp|,
//  - the producer is in the same block, follows |consumerOp|, and was moved
//    directly before it (only done when every operand of the producer is
//    itself usable by |consumerOp| according to isValueUsableForOp), or
//  - the producer is in a different block and isValueUsableForOp reports the
//    value as usable (no movement is performed in this case).
// Returns false otherwise; the IR is left unmodified in that case.
bool tryMoveProducerBefore(Value value, Operation *consumerOp) {
  auto *producerOp = value.getDefiningOp();
  if (!producerOp) {
    return true;  // block arg, ok to use
  }

  // Producers and consumers in the same block are easy to check.
  if (producerOp->getBlock() == consumerOp->getBlock()) {
    if (producerOp->isBeforeInBlock(consumerOp)) {
      // Producer comes before the consumer in the same block and already
      // satisfies the request based on SSA dominance.
      return true;
    }
    // The producer can only be hoisted above the consumer if everything it
    // consumes is available at the consumer.
    for (auto operand : producerOp->getOperands()) {
      if (!isValueUsableForOp(operand, consumerOp)) {
        return false;
      }
    }
    producerOp->moveBefore(consumerOp);
    return true;
  }

  // If the value is directly usable from another block (dominates, etc) then
  // the condition is already satisfied. We don't move ops that don't satisfy
  // this yet.
  if (isValueUsableForOp(value, consumerOp)) {
    return true;
  }

  // Could support more cases of satisfaction checks or movement. Ops that exist
  // in ancestors (like those implicitly captured by nested scf.if/scf.for ops)
  // are good candidates to check for.
  return false;
}

Expand Down

0 comments on commit 5dd4065

Please sign in to comment.