Skip to content

Commit

Permalink
[Linalg] Change attribute n_loop_types to iterator
Browse files Browse the repository at this point in the history
This addresses issue #270. Linalg is updated to take the same form
of iterator_types than vector contraction.

Closes #280

COPYBARA_INTEGRATE_REVIEW=tensorflow/mlir#280 from tetuante:PRissue270 d26d88d090d3765d3b9884bfabdd023143f27287
PiperOrigin-RevId: 282905396
Change-Id: I1c55a92690dd31c28f9123b08dd482b52745681c
  • Loading branch information
tetuante authored and tensorflower-gardener committed Nov 28, 2019
1 parent 31013d5 commit 6596a60
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 43 deletions.
2 changes: 1 addition & 1 deletion tensorflow/compiler/mlir/xla/tests/lhlo-fuse-linalg.mlir
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: tf-opt -lhlo-fuse-linalg %s -o - | FileCheck %s

#map0 = (d0, d1) -> (d0, d1)
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], n_loop_types = [2, 0, 0], n_views = [2, 1]}
#pointwise_2d_trait = {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"], n_views = [2, 1]}
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
%temp_result = alloc() {temp = true} : memref<2x2xf32>
Expand Down
35 changes: 16 additions & 19 deletions tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_linalg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ namespace mlir {
namespace xla_lhlo {
namespace {

ArrayAttr GetNParallelLoopsAttrs(unsigned nParallelLoops, Builder b) {
auto parallelLoopTypeAttr = b.getStringAttr("parallel");
SmallVector<Attribute, 3> iteratorTypes;
for (int i = 0; i < nParallelLoops; ++i) {
iteratorTypes.push_back(parallelLoopTypeAttr);
}
return b.getArrayAttr(iteratorTypes);
}

template <typename LhloOp>
class PointwiseToLinalgConverter : public OpConversionPattern<LhloOp> {
public:
Expand Down Expand Up @@ -78,19 +87,15 @@ class PointwiseToLinalgConverter : public OpConversionPattern<LhloOp> {
result_or_body_arg.emplace_back(memrefType.getElementType());
}

// Pointwise-ops have all surrounding loops parallel, so the loop triple is
// [argDim, 0, 0].
SmallVector<Attribute, 3> loop_types{rewriter.getI64IntegerAttr(nloops),
rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(0)};
// Define the number of input memref/output memrefs.
SmallVector<Attribute, 2> nmemrefs{
rewriter.getI64IntegerAttr(bodyArgTypes.size()),
rewriter.getI64IntegerAttr(bodyResultTypes.size())};

auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, args, rewriter.getArrayAttr(indexingMaps),
rewriter.getArrayAttr(loop_types), rewriter.getArrayAttr(nmemrefs),
GetNParallelLoopsAttrs(nloops, rewriter),
rewriter.getArrayAttr(nmemrefs),
/*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr);

// Add a block to the region.
Expand Down Expand Up @@ -158,11 +163,6 @@ class BroadcastInDimConverter : public OpConversionPattern<BroadcastInDimOp> {
indexingMaps.emplace_back(
AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops)));

// Broadcast op has all surrounding loops parallel, so the loop triple is
// [argDim, 0, 0].
SmallVector<Attribute, 3> loop_types{rewriter.getI64IntegerAttr(nloops),
rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(0)};
// Define the number of input memref/output memrefs.
SmallVector<Attribute, 2> nmemrefs{
rewriter.getI64IntegerAttr(bodyArgTypes.size()),
Expand All @@ -171,7 +171,8 @@ class BroadcastInDimConverter : public OpConversionPattern<BroadcastInDimOp> {
auto loc = broadcastOp.getLoc();
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, args, rewriter.getArrayAttr(indexingMaps),
rewriter.getArrayAttr(loop_types), rewriter.getArrayAttr(nmemrefs),
GetNParallelLoopsAttrs(nloops, rewriter),
rewriter.getArrayAttr(nmemrefs),
/*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr);

// Add a block to the region.
Expand Down Expand Up @@ -207,19 +208,15 @@ class IotaConverter : public OpConversionPattern<IotaOp> {
indexingMaps.emplace_back(
AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops)));

// Pointwise-ops have all surrounding loops parallel, so the loop triple is
// [argDim, 0, 0].
SmallVector<Attribute, 3> loop_types{rewriter.getI64IntegerAttr(nloops),
rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(0)};
// Define the number of input memref/output memrefs.
SmallVector<Attribute, 2> nmemrefs{rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(1)};

auto loc = iotaOp.getLoc();
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
loc, args, rewriter.getArrayAttr(indexingMaps),
rewriter.getArrayAttr(loop_types), rewriter.getArrayAttr(nmemrefs),
GetNParallelLoopsAttrs(nloops, rewriter),
rewriter.getArrayAttr(nmemrefs),
/*doc=*/nullptr, /*fun=*/nullptr, /*library_call=*/nullptr);

// Add a block to the region.
Expand Down Expand Up @@ -277,7 +274,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
// "linalg.yield"(%0) : (f32) -> ()
// }) {
// indexing_maps = [#map0, #map0, #map0],
// n_loop_types = [2, 0, 0],
// iterator_types = ["parallel", "parallel"],
// n_views = [2, 1]
// } : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// }
Expand Down
57 changes: 34 additions & 23 deletions third_party/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
class GenericOpBase<string mnemonic> : LinalgLibraryBase_Op<mnemonic, []> {
let arguments = (ins Variadic<AnyStridedMemRef>:$views,
AffineMapArrayAttr:$indexing_maps,
I64ArrayAttr:$n_loop_types,
ArrayAttr:$iterator_types,
I64ArrayAttr:$n_views,
OptionalAttr<StrAttr>:$doc,
OptionalAttr<FlatSymbolRefAttr>:$fun,
Expand All @@ -377,7 +377,7 @@ class GenericOpBase<string mnemonic> : LinalgLibraryBase_Op<mnemonic, []> {
let extraClassDeclaration = [{
SmallVector<StringRef, 8> linalgTraitAttrNames() {
return SmallVector<StringRef, 8>{
"doc", "fun", "indexing_maps", "library_call", "n_loop_types", "n_views"
"doc", "fun", "indexing_maps", "library_call", "iterator_types", "n_views"
};
}
unsigned getNumInputs() {
Expand All @@ -395,26 +395,35 @@ class GenericOpBase<string mnemonic> : LinalgLibraryBase_Op<mnemonic, []> {
return val.getZExtValue();
}
unsigned getNumParallelLoops() {
if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
return 0;
auto val = n_loop_types().getValue()[0].cast<IntegerAttr>().getValue();
assert(val.getSExtValue() >= 0);
return val.getZExtValue();
unsigned nPar = 0;
for (auto ty : iterator_types()) {
if (ty.cast<StringAttr>().getValue() == "parallel")
nPar++;
}
return nPar;
}
unsigned getNumReductionLoops() {
if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
return 0;
auto val = n_loop_types().getValue()[1].cast<IntegerAttr>().getValue();
assert(val.getSExtValue() >= 0);
return val.getZExtValue();
}
unsigned getNumWindowLoops() {
if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
unsigned nRed = 0;
for (auto ty : iterator_types()) {
if (ty.cast<StringAttr>().getValue() == "reduction")
nRed++;
}
return nRed;
}
unsigned getNumWindowLoops() {
if (!getAttr("iterator_types") || iterator_types().getValue().size() == 0)
return 0;
auto val = n_loop_types().getValue()[2].cast<IntegerAttr>().getValue();
assert(val.getSExtValue() >= 0);
return val.getZExtValue();
}
unsigned nWin = 0;
for (auto ty : iterator_types()) {
if (ty.cast<StringAttr>().getValue() == "window")
nWin++;
}
return nWin;
}
unsigned getNumLoops() {
return getNumParallelLoops() + getNumReductionLoops() +
getNumWindowLoops();
Expand Down Expand Up @@ -474,8 +483,9 @@ def GenericOp : GenericOpBase<"generic"> {
The external library is assumed to be dynamically linked and no strong
compile-time guarantees are provided. In the absence of such a library
call, linalg.generic will always lower to loops.
- n_loops: a triple of I64Attr representing the number of enclosing
[parallel, reduction, window] loops respectively.
- iterator_types: an ArrayAttr they type of the enclosing loops; Each element of
the list represents and iterator of one of the following types:
parallel, reduction, window
- n_views: a pair of I64Attr representing the number of input (readonly)
and output (readwrite) views.

Expand All @@ -498,7 +508,7 @@ def GenericOp : GenericOpBase<"generic"> {
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul",
n_views = [2, 1],
n_loop_types = [2, 1, 0]
iterator_types = ["parallel", "parallel", "reduction"]
}
```

Expand Down Expand Up @@ -568,8 +578,9 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
maps to. The external library is assumed to be dynamically linked and
no strong compile-time guarantees are provided. In the absence of such
a library call, linalg.indexed_generic will always lower to loops.
- n_loops: a triple of I64Attr representing the number of enclosing
[parallel, reduction, window] loops respectively.
- iterator_types: an ArrayAttr they type of the enclosing loops; Each element of
the list represents and iterator of one of the following types:
parallel, reduction, window
- n_views: a pair of I64Attr representing the number of input (readonly)
and output (readwrite) views.

Expand All @@ -592,7 +603,7 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
indexing_maps = #matmul_accesses,
library_call = "linalg_matmul",
n_views = [2, 1],
n_loop_types = [2, 1, 0]
iterator_types = ["parallel", "parallel", "reduction"]
}
```

Expand Down

0 comments on commit 6596a60

Please sign in to comment.