Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ cc_library(
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu",
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_parallel_loops",
"//tensorflow/compiler/mlir/xla:tensor_linalg_to_buffer_linalg",
"//tensorflow/compiler/mlir/xla:xla_dialect_registration",
"//tensorflow/compiler/mlir/xla:xla_legalize_control_flow",
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
Expand Down
16 changes: 16 additions & 0 deletions tensorflow/compiler/mlir/xla/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,22 @@ cc_library(
alwayslink = 1,
)

cc_library(
name = "tensor_linalg_to_buffer_linalg",
srcs = ["transforms/tensor_linalg_to_buffer_linalg.cc"],
deps = [
":buffer_assignment",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)

cc_library(
name = "lhlo_fuse_linalg",
srcs = ["transforms/lhlo_fuse_linalg.cc"],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// RUN: tf-opt -tensor-linalg-to-buffer-linalg --buffer-assignment -split-input-file %s | FileCheck %s -dump-input-on-failure

#map0 = affine_map<(d0) -> (d0)>

module {
// CHECK-LABEL: func @muliple_results_generic_op
func @muliple_results_generic_op(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%0, %1 = linalg.generic {args_in = 1 : i64, args_out = 2 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} %arg0 {
^bb0(%arg1: f32):
%1 = exp %arg1 : f32
linalg.yield %1, %1 : f32, f32
}: tensor<4xf32> -> (tensor<4xf32>, tensor<4xf32>)
return %0, %1 : tensor<4xf32>, tensor<4xf32>
}
}
// CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]], %[[ARG1_RESULT:.*]]: [[TYPE]], %[[ARG2_RESULT:.*]]: [[TYPE]])
// CHECK: %[[FIRST_ALLOC:.*]] = alloc() : [[TYPE]]
// CHECK: %[[SECOND_ALLOC:.*]] = alloc() : [[TYPE]]
// CHECK: linalg.generic
// CHECK-SAME: %{{.*}}, %{{.*}}, %{{.*}}
// CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %[[ARG0:.*]]: f32, %{{.*}}: f32, %{{.*}}: f32
// CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = exp %[[ARG0]]
// CHECK: linalg.yield %[[RESULT]], %[[RESULT]]
// CHECK: [[TYPE]], [[TYPE]], [[TYPE]]
// CHECK-NEXT: linalg.copy(%[[FIRST_ALLOC]], %[[ARG1_RESULT]])
// CHECK-NEXT: dealloc %[[FIRST_ALLOC]]
// CHECK-NEXT: linalg.copy(%[[SECOND_ALLOC]], %[[ARG2_RESULT]])
// CHECK-NEXT: dealloc %[[SECOND_ALLOC]]
// CHECK-NEXT: return
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements logic for transforming Linalg operations with tensor
// types to memref type and allocate and deallocate buffers using the
// BufferAssignmentPlacer.

#include "absl/memory/memory.h"
#include "llvm/ADT/APInt.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // TF:llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // TF:llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
#include "mlir/IR/AffineExpr.h" // TF:llvm-project
#include "mlir/IR/Attributes.h" // TF:llvm-project
#include "mlir/IR/Builders.h" // TF:llvm-project
#include "mlir/IR/Function.h" // TF:llvm-project
#include "mlir/IR/Location.h" // TF:llvm-project
#include "mlir/IR/MLIRContext.h" // TF:llvm-project
#include "mlir/IR/Operation.h" // TF:llvm-project
#include "mlir/IR/PatternMatch.h" // TF:llvm-project
#include "mlir/IR/StandardTypes.h" // TF:llvm-project
#include "mlir/Pass/Pass.h" // TF:llvm-project
#include "mlir/Transforms/DialectConversion.h" // TF:llvm-project
#include "tensorflow/compiler/mlir/xla/ir/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/buffer_assignment.h"
#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h"

namespace mlir {
namespace xla {
namespace {

class GenericOpConverter
: public xla::BufferAssignmentOpConversionPattern<linalg::GenericOp> {
public:
using xla::BufferAssignmentOpConversionPattern<
linalg::GenericOp>::BufferAssignmentOpConversionPattern;

LogicalResult matchAndRewrite(
linalg::GenericOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
SmallVector<Value, 4> args(operands.begin(), operands.end());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will ValueRange work here?


// Update all types to memref types.
auto results = op.getOperation()->getResults();
for (auto result : results) {
auto type = result.getType().cast<ShapedType>();
if (!type)
op.emitOpError()
<< "tensor to buffer conversion expects ranked results";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe "expects shaped results"? Also this chekc is probably not needed at all, since Linalg verifiers check that:
in LinalgStructuredOps.td:

def LinalgOperand: Type<
  Or<[AnyRankedTensor.predicate, AnyStridedMemRef.predicate]>>;

class LinalgOperandOfRank<int rank>: Type<
  And<[
    LinalgOperand.predicate,
    CPred<"$_self.cast<ShapedType>().getRank() == " # rank>]
  >>;

class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
  let arguments = (ins Variadic<LinalgOperand>:$views,
                   I64Attr:$args_in,
                   I64Attr:$args_out,

auto memrefType = MemRefType::get(type.getShape(), type.getElementType());

// Compute alloc position and insert a custom allocation node.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.restoreInsertionPoint(
bufferAssignment->computeAllocPosition(result));
auto alloc = rewriter.create<AllocOp>(loc, memrefType);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will create invalid IR if the memrefType has any unknown dimensions. You need to check for hasStaticShape above and bail out for non-static-shape.

See also https://llvm.discourse.group/t/computing-output-shapes-of-structured-ops-on-tensors/866

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@silvasean Yes, you are right. This code doesn't take dynamically shaped types into account. We are going to insert an Assert here. Currently, Buffer Assignment doesn't support dynamically shaped types and it's one of our top priorities. We definitely use the link above. Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An assertion would not be a good choice. See https://mlir.llvm.org/getting_started/DeveloperGuide/#assertions-and-crashes-in-passes

You can instead say something like if (!rankedTensorType.hasStaticShape()) return rewriter.notifyMatchFailure(op, "dynamic shapes not currently supported")

result.replaceAllUsesWith(alloc);
args.push_back(alloc);
}

// Generate a new linalg operation that works on buffers.
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, llvm::None, args, rewriter.getI64IntegerAttr(operands.size()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm::None -> /*output_tensors=*/llvm::None

rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(),
op.iterator_types(), op.docAttr(), op.funAttr(), op.library_callAttr());

// Move regions from the old operation to the new one.
auto& region = linalgOp.region();
rewriter.inlineRegionBefore(op.region(), region, region.end());

// TODO(dfki): verify the internal memref-based linalg functionality.
auto& entryBlock = region.front();
for (auto result : results) {
auto type = result.getType().cast<ShapedType>();
entryBlock.addArgument(type.getElementType());
}

rewriter.eraseOp(op);
return success();
}
};

void populateTensorLinalgToBufferLinalgConversionPattern(
MLIRContext* context, xla::BufferAssignmentPlacer* placer,
OwningRewritePatternList* patterns) {
patterns->insert<xla::FunctionAndBlockSignatureConverter, GenericOpConverter,
xla::NonVoidToVoidReturnOpConverter<
mlir::ReturnOp, mlir::ReturnOp, linalg::CopyOp>>(context,
placer);
}

struct TensorLinalgToBufferLinalg
: public FunctionPass<TensorLinalgToBufferLinalg> {
void runOnFunction() override {
OwningRewritePatternList patterns;
auto& context = getContext();
ConversionTarget target(context);

// Make all linalg operations illegal as long as they work on tensors.
auto isLegalOperation = [](Operation* op) {
auto isIllegalValue = [](Value operand) {
return operand.getType().isa<TensorType>();
};
auto operands = op->getOperands();
auto results = op->getResults();
return std::none_of(operands.begin(), operands.end(), isIllegalValue) &&
std::none_of(results.begin(), results.end(), isIllegalValue);
};
target.addLegalDialect<StandardOpsDialect>();
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(
Optional<ConversionTarget::DynamicLegalityCallbackFn>(
isLegalOperation));

// Mark return operations illegal as long as they return values.
target.addDynamicallyLegalOp<mlir::ReturnOp>(
[](mlir::ReturnOp returnOp) { return returnOp.getNumOperands() == 0; });

auto function = getFunction();
xla::BufferAssignmentPlacer placer(function);
xla::FunctionAndBlockSignatureConverter::addDynamicallyLegalFuncOp(target);
populateTensorLinalgToBufferLinalgConversionPattern(function.getContext(),
&placer, &patterns);

// Do partial conversion so we can have unknown ops in tests.
if (failed(applyPartialConversion(function, target, patterns, nullptr))) {
signalPassFailure();
}
}
};
} // namespace

static PassRegistration<TensorLinalgToBufferLinalg>
tensor_linalg_to_buffer_linalg_pass(
"tensor-linalg-to-buffer-linalg",
"Legalize linalg operations with tensor type operands to memref type "
"ones");

} // namespace xla
} // namespace mlir