Skip to content
This repository has been archived by the owner on Apr 23, 2021. It is now read-only.

Commit

Permalink
[spirv] Add CompositeExtractOp operation.
Browse files Browse the repository at this point in the history
CompositeExtractOp allows to extract a part of a composite object.
  • Loading branch information
denis0x0D committed Jul 11, 2019
1 parent 5f2159d commit 8ce9295
Show file tree
Hide file tree
Showing 6 changed files with 348 additions and 8 deletions.
8 changes: 5 additions & 3 deletions include/mlir/SPIRV/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def SPV_Dialect : Dialect {
def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">;
def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">;
def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">;
def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">;

// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
// for the definition of the following types and type categories.
Expand All @@ -79,14 +80,15 @@ def SPV_Vector : VectorOf<[SPV_Bool, SPV_Integer, SPV_Float]>;
def SPV_AnyPtr : Type<SPV_IsPtrType, "any SPIR-V pointer type">;
def SPV_AnyArray : Type<SPV_IsArrayType, "any SPIR-V array type">;
def SPV_AnyRTArray : Type<SPV_IsRTArrayType, "any SPIR-V runtime array type">;
def SPV_AnyStruct : Type<SPV_IsStructType, "any SPIR-V struct type">;

def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
def SPV_Aggregrate : AnyTypeOf<[SPV_AnyArray]>;
def SPV_Composite: AnyTypeOf<[SPV_Vector, SPV_AnyArray]>;
def SPV_Aggregrate : AnyTypeOf<[SPV_AnyArray, SPV_AnyStruct]>;
def SPV_Composite: AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyStruct]>;
def SPV_Type : AnyTypeOf<[
SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray
SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct
]>;

class SPV_ScalarOrVectorOf<Type type> :
Expand Down
44 changes: 44 additions & 0 deletions include/mlir/SPIRV/SPIRVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -324,4 +324,48 @@ def SPV_VariableOp : SPV_Op<"Variable"> {
let opcode = 59;
}

def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> {
let summary = "Extract a part of a composite object.";

let description = [{
Result Type must be the type of object selected by the last provided index.
The instruction result is the extracted object.

Composite is the composite to extract from.

Indexes walk the type hierarchy, potentially down to component granularity,
to select the part to extract. All indexes must be in bounds.
All composite constituents use zero-based numbering, as described by their
OpType… instruction.

### Custom assembly form

``` {.ebnf}
composite-extract-op ::= ssa-id `=` `spv.CompositeExtract` ssa-use
`[` integer-literal (',' integer-literal)* `]`
`:` composite-type
```

For example:

```
%0 = spv.Variable : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
%1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>>
%2 = spv.CompositeExtract %1[1 : i32] : !spv.array<4x!spv.array<4xf32>>
```

}];

let arguments = (ins
SPV_Composite:$composite,
I32ArrayAttr:$indices
);

let results = (outs
SPV_Type:$component
);

let opcode = 81;
}

#endif // SPIRV_OPS
27 changes: 22 additions & 5 deletions include/mlir/SPIRV/SPIRVTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#ifndef MLIR_SPIRV_SPIRVTYPES_H_
#define MLIR_SPIRV_SPIRVTYPES_H_

#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeSupport.h"
#include "mlir/IR/Types.h"

Expand Down Expand Up @@ -52,9 +53,26 @@ enum Kind {
};
}

// Provides a common base class between VectorType, SPIR-V ArrayType, SPIR-V
// StructType.
class CompositeType : public Type {
public:
using Type::Type;

static bool classof(Type type) {
return (type.getKind() == TypeKind::Array ||
type.getKind() == TypeKind::Struct ||
type.getKind() == StandardTypes::Vector);
}

uint64_t getNumMembers() const;

Type getMemberType(uint64_t) const;
};

// SPIR-V array type
class ArrayType
: public Type::TypeBase<ArrayType, Type, detail::ArrayTypeStorage> {
class ArrayType : public Type::TypeBase<ArrayType, CompositeType,
detail::ArrayTypeStorage> {
public:
using Base::Base;

Expand Down Expand Up @@ -147,9 +165,8 @@ class RuntimeArrayType
};

// SPIR-V struct type
class StructType
: public Type::TypeBase<StructType, Type, detail::StructTypeStorage> {

class StructType : public Type::TypeBase<StructType, CompositeType,
detail::StructTypeStorage> {
public:
using Base::Base;

Expand Down
110 changes: 110 additions & 0 deletions lib/SPIRV/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ using namespace mlir;
static constexpr const char kAlignmentAttrName[] = "alignment";
static constexpr const char kBindingAttrName[] = "binding";
static constexpr const char kDescriptorSetAttrName[] = "descriptor_set";
static constexpr const char kIndicesAttrName[] = "indices";
static constexpr const char kValueAttrName[] = "value";
static constexpr const char kValuesAttrName[] = "values";
static constexpr const char kFnNameAttrName[] = "fn";
Expand Down Expand Up @@ -197,6 +198,115 @@ static void printNoIOOp(Operation *op, OpAsmPrinter *printer) {
printer->printOptionalAttrDict(op->getAttrs());
}

//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//

static ParseResult parseCompositeExtractOp(OpAsmParser *parser,
OperationState *state) {
OpAsmParser::OperandType compositeInfo;
Attribute indicesAttr;
Type compositeType;
llvm::SMLoc attrLocation;
int32_t index;

if (parser->parseOperand(compositeInfo) ||
parser->getCurrentLocation(&attrLocation) ||
parser->parseAttribute(indicesAttr, kIndicesAttrName,
state->attributes) ||
parser->parseColonType(compositeType) ||
parser->resolveOperand(compositeInfo, compositeType, state->operands)) {
return failure();
}

auto indicesArrayAttr = indicesAttr.dyn_cast<ArrayAttr>();
if (!indicesArrayAttr) {
return parser->emitError(
attrLocation,
"expected an 32-bit integer array attribute for 'indices'");
}

if (!indicesArrayAttr.size()) {
return parser->emitError(
attrLocation, "expected at least one index for spv.CompositeExtract");
}

Type resultType = compositeType;
for (auto indexAttr : indicesArrayAttr) {
if (auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>()) {
index = indexIntAttr.getInt();
} else {
return parser->emitError(
attrLocation,
"expexted an 32-bit integer for index, but found '")
<< indexAttr << "'";
}
if (resultType.isa<spirv::CompositeType>()) {
if (index < 0 ||
static_cast<uint32_t>(index) >=
resultType.cast<spirv::CompositeType>().getNumMembers()) {
return parser->emitError(attrLocation, "index ")
<< index << " out of bounds for " << resultType;
}
resultType = resultType.cast<spirv::CompositeType>().getMemberType(index);
} else {
return parser->emitError(attrLocation, "invalid type to extract from");
}
}

if (!resultType) {
return parser->emitError(attrLocation,
"can not extract from composite type");
}

state->addTypes(resultType);
return success();
}

static void print(spirv::CompositeExtractOp compositeExtractOp,
OpAsmPrinter *printer) {
*printer << spirv::CompositeExtractOp::getOperationName() << ' '
<< *compositeExtractOp.composite() << compositeExtractOp.indices()
<< " : " << compositeExtractOp.composite()->getType();
}

static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
auto resultType = compExOp.composite()->getType();
auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>();

if (!indicesArrayAttr.size()) {
return compExOp.emitOpError(
"expexted at least one index for spv.CompositeExtractOp");
}

int32_t index;
for (auto indexAttr : indicesArrayAttr) {
if (!resultType) {
return compExOp.emitError("invalid type to extract from");
}

index = indexAttr.dyn_cast<IntegerAttr>().getInt();
if (resultType.isa<spirv::CompositeType>()) {
if (index < 0 ||
static_cast<uint32_t>(index) >=
resultType.cast<spirv::CompositeType>().getNumMembers()) {
return compExOp.emitOpError("index ")
<< index << " out of bounds for " << resultType;
}
resultType = resultType.cast<spirv::CompositeType>().getMemberType(index);
} else {
return compExOp.emitOpError("invalid type to extract from");
}
}

if (resultType != compExOp.getType()) {
return compExOp.emitOpError("invalid result type, expected ")
<< resultType << " provided " << compExOp.getType();
}

return success();
}

//===----------------------------------------------------------------------===//
// spv.constant
//===----------------------------------------------------------------------===//
Expand Down
31 changes: 31 additions & 0 deletions lib/SPIRV/SPIRVTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/StandardTypes.h"
#include "mlir/SPIRV/SPIRVTypes.h"
#include "llvm/ADT/StringSwitch.h"

Expand All @@ -28,6 +29,36 @@ using namespace mlir::spirv;
// Pull in all enum utility function definitions
#include "mlir/SPIRV/SPIRVEnums.cpp.inc"

//===----------------------------------------------------------------------===//
// CompositeType
//===----------------------------------------------------------------------===//

Type CompositeType::getMemberType(uint64_t index) const {
switch (getKind()) {
case spirv::TypeKind::Struct:
return cast<StructType>().getMemberType(index);
case spirv::TypeKind::Array:
return cast<ArrayType>().getElementType();
case StandardTypes::Vector:
return cast<VectorType>().getElementType();
default:
llvm_unreachable("invalid composite type");
}
}

uint64_t CompositeType::getNumMembers() const {
switch (getKind()) {
case spirv::TypeKind::Struct:
return cast<StructType>().getNumMembers();
case spirv::TypeKind::Array:
return cast<ArrayType>().getElementCount();
case StandardTypes::Vector:
return cast<VectorType>().getNumElements();
default:
llvm_unreachable("invalid composite type");
}
}

//===----------------------------------------------------------------------===//
// ArrayType
//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 8ce9295

Please sign in to comment.