Skip to content
This repository has been archived by the owner on Apr 23, 2021. It is now read-only.

Commit

Permalink
Start a Linalg dialect
Browse files Browse the repository at this point in the history
    This CL starts implementing a Linalg dialect with the objective of supporting
    optimizing compilation of loops and library calls for a subset of common linear
    algebra operations.

    This CL starts by simply adding a linalg.range type and an operation with the
    proper roundtripping test.

--

PiperOrigin-RevId: 244189468
  • Loading branch information
Nicolas Vasilache authored and joker-eph committed Apr 18, 2019
1 parent 25a2e45 commit 0f11538
Show file tree
Hide file tree
Showing 10 changed files with 276 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/Linalg/Linalg1/include/linalg1/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
namespace linalg {

enum LinalgTypes {
Range = mlir::Type::FIRST_LINALG_TYPE,
Range = mlir::Type::FIRST_PRIVATE_EXPERIMENTAL_0_TYPE,
View,
LAST_USED_LINALG_TYPE = View,
FIRST_PRIVATE_EXPERIMENTAL_0_TYPE = View,
};

} // namespace linalg
Expand Down
53 changes: 53 additions & 0 deletions include/mlir/Linalg/LinalgOps.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//===- LinalgOps.h - Linalg Operations --------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef MLIR_LINALG_LINALGOPS_H_
#define MLIR_LINALG_LINALGOPS_H_

#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"

namespace mlir {

/// A RangeOp is used to create a value of RangeType from 3 values of type index
/// that represent the min, max and step values of the range.
class RangeOp : public Op<RangeOp, OpTrait::NOperands<3>::Impl,
OpTrait::OneResult, OpTrait::HasNoSideEffect> {
public:
using Op::Op;

//////////////////////////////////////////////////////////////////////////////
// Hooks to customize the behavior of this op.
//////////////////////////////////////////////////////////////////////////////
static llvm::StringRef getOperationName() { return "linalg.range"; }
static void build(Builder *b, OperationState *result, Value *min, Value *max,
Value *step);
LogicalResult verify();
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p);

//////////////////////////////////////////////////////////////////////////////
// Op-specific functionality.
//////////////////////////////////////////////////////////////////////////////
Value *min() { return getOperand(0); }
Value *max() { return getOperand(1); }
Value *step() { return getOperand(2); }
};

} // namespace mlir

#endif // MLIR_LINALG_LINALGOPS_H_
59 changes: 59 additions & 0 deletions include/mlir/Linalg/LinalgTypes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
//===- LinalgTypes.h - Linalg Types ---------------------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#ifndef MLIR_LINALG_LINALGTYPES_H_
#define MLIR_LINALG_LINALGTYPES_H_

#include "mlir/IR/Dialect.h"
#include "mlir/IR/Types.h"

namespace mlir {
class MLIRContext;

enum LinalgTypes {
Range = Type::FIRST_LINALG_TYPE,
LAST_USED_LINALG_TYPE = Range,
};

class LinalgDialect : public Dialect {
public:
explicit LinalgDialect(MLIRContext *context);

/// Parse a type registered to this dialect.
Type parseType(llvm::StringRef spec, Location loc) const override;

/// Print a type registered to this dialect.
void printType(Type type, llvm::raw_ostream &os) const override;
};

/// A RangeType represents a minimal range abstraction (min, max, step).
class RangeType : public Type::TypeBase<RangeType, Type> {
public:
// Used for generic hooks in TypeBase.
using Base::Base;
/// Construction hook.
static RangeType get(MLIRContext *context) {
/// Custom, uniq'ed construction in the MLIRContext.
return Base::get(context, LinalgTypes::Range);
}
/// Used to implement llvm-style cast.
static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; }
};

} // namespace mlir

#endif // MLIR_LINALG_LINALGTYPES_H_
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ add_subdirectory(ExecutionEngine)
add_subdirectory(FxpMathOps)
add_subdirectory(IR)
add_subdirectory(LLVMIR)
add_subdirectory(Linalg)
add_subdirectory(Parser)
add_subdirectory(Pass)
add_subdirectory(Quantization)
Expand Down
8 changes: 8 additions & 0 deletions lib/Linalg/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
add_llvm_library(MLIRLinalg
LinalgOps.cpp
LinalgRegistration.cpp
LinalgTypes.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Linalg
)
67 changes: 67 additions & 0 deletions lib/Linalg/LinalgOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements a the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Linalg/LinalgOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Linalg/LinalgTypes.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

void mlir::RangeOp::build(Builder *b, OperationState *result, Value *min,
Value *max, Value *step) {
result->addOperands({min, max, step});
result->addTypes({RangeType::get(b->getContext())});
}

// Verification is simply that a RangeOp takes 3 index ssa-value.
mlir::LogicalResult mlir::RangeOp::verify() {
if (!min() || !min()->getType().isa<IndexType>())
return emitOpError("first operand should be of type index");
if (!max() || !max()->getType().isa<IndexType>())
return emitOpError("second operand should be of type index");
if (!step() || !step()->getType().isa<IndexType>())
return emitOpError("third operand should be of type index");
return mlir::success();
}

// A RangeOp prints as:
//
// ```{.mlir}
// linalg.range %0:%1:%2 : !linalg.range
// ```
void mlir::RangeOp::print(OpAsmPrinter *p) {
*p << getOperationName() << " " << *min() << ":" << *max() << ":" << *step()
<< " : " << getType();
}

bool mlir::RangeOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
RangeType type;
auto affineIntTy = parser->getBuilder().getIndexType();
return parser->parseOperand(rangeInfo[0]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[1]) || parser->parseColon() ||
parser->parseOperand(rangeInfo[2]) || parser->parseColonType(type) ||
parser->resolveOperands(rangeInfo, affineIntTy, result->operands) ||
parser->addTypeToList(type, result->types);
}
24 changes: 24 additions & 0 deletions lib/Linalg/LinalgRegistration.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//===- LinalgRegistration.cpp - Register the linalg dialect statically ----===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "mlir/Linalg/LinalgOps.h"
#include "mlir/Linalg/LinalgTypes.h"

using namespace mlir;

// Static initialization for LinalgOps dialect registration.
static DialectRegistration<LinalgDialect> LinalgOps;
53 changes: 53 additions & 0 deletions lib/Linalg/LinalgTypes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//===- Dialect.cpp - Implementation of the linalg dialect and types -------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the Linalg dialect types and dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Linalg/LinalgTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/Linalg/LinalgOps.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;

mlir::LinalgDialect::LinalgDialect(MLIRContext *context)
: Dialect("linalg", context) {
addTypes<RangeType>();
addOperations<RangeOp>();
}

Type mlir::LinalgDialect::parseType(StringRef spec, Location loc) const {
MLIRContext *context = getContext();
if (spec == "range")
return RangeType::get(getContext());
return (context->emitError(loc, "unknown Linalg type: " + spec), Type());
}

/// RangeType prints as just "range".
static void print(RangeType rt, raw_ostream &os) { os << "range"; }

void mlir::LinalgDialect::printType(Type type, raw_ostream &os) const {
switch (type.getKind()) {
default:
llvm_unreachable("Unhandled Linalg type");
case LinalgTypes::Range:
print(type.cast<RangeType>(), os);
break;
}
}
8 changes: 8 additions & 0 deletions test/Linalg/roundtrip.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// RUN: mlir-opt %s -verify | mlir-opt | FileCheck %s

func @range(%arg0: index, %arg1: index, %arg2: index) {
%0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
return
}
// CHECK-LABEL: func @range(%arg0: index, %arg1: index, %arg2: index) {
// CHECK-NEXT: %0 = linalg.range %arg0:%arg1:%arg2 : !linalg.range
1 change: 1 addition & 0 deletions tools/mlir-opt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ set(LIBS
MLIRAnalysis
MLIREDSC
MLIRFxpMathOps
MLIRLinalg
MLIRLLVMIR
MLIRParser
MLIRPass
Expand Down

0 comments on commit 0f11538

Please sign in to comment.