diff --git a/include/mlir/SPIRV/SPIRVBase.td b/include/mlir/SPIRV/SPIRVBase.td index b862f71c3075..7377a0035e24 100644 --- a/include/mlir/SPIRV/SPIRVBase.td +++ b/include/mlir/SPIRV/SPIRVBase.td @@ -65,6 +65,7 @@ def SPV_Dialect : Dialect { def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">; def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">; def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">; +def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">; // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types // for the definition of the following types and type categories. @@ -79,14 +80,15 @@ def SPV_Vector : VectorOf<[SPV_Bool, SPV_Integer, SPV_Float]>; def SPV_AnyPtr : Type; def SPV_AnyArray : Type; def SPV_AnyRTArray : Type; +def SPV_AnyStruct : Type; def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>; def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>; -def SPV_Aggregrate : AnyTypeOf<[SPV_AnyArray]>; -def SPV_Composite: AnyTypeOf<[SPV_Vector, SPV_AnyArray]>; +def SPV_Aggregrate : AnyTypeOf<[SPV_AnyArray, SPV_AnyStruct]>; +def SPV_Composite: AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyStruct]>; def SPV_Type : AnyTypeOf<[ SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector, - SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray + SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct ]>; class SPV_ScalarOrVectorOf : diff --git a/include/mlir/SPIRV/SPIRVOps.td b/include/mlir/SPIRV/SPIRVOps.td index afab62ab4e41..6be0a949fc5f 100644 --- a/include/mlir/SPIRV/SPIRVOps.td +++ b/include/mlir/SPIRV/SPIRVOps.td @@ -324,4 +324,48 @@ def SPV_VariableOp : SPV_Op<"Variable"> { let opcode = 59; } +def SPV_CompositeExtractOp : SPV_Op<"CompositeExtract", [NoSideEffect]> { + let summary = "Extract a part of a composite object."; + + let description = [{ + Result Type must be the type of object selected by the last provided index. + The instruction result is the extracted object. + + Composite is the composite to extract from. + + Indexes walk the type hierarchy, potentially down to component granularity, + to select the part to extract. All indexes must be in bounds. + All composite constituents use zero-based numbering, as described by their + OpType… instruction. + + ### Custom assembly form + + ``` {.ebnf} + composite-extract-op ::= ssa-id `=` `spv.CompositeExtract` ssa-use + `[` integer-literal (',' integer-literal)* `]` + `:` composite-type + ``` + + For example: + + ``` + %0 = spv.Variable : !spv.ptr>, Function> + %1 = spv.Load "Function" %0 ["Volatile"] : !spv.array<4x!spv.array<4xf32>> + %2 = spv.CompositeExtract %1[1 : i32] : !spv.array<4x!spv.array<4xf32>> + ``` + + }]; + + let arguments = (ins + SPV_Composite:$composite, + I32ArrayAttr:$indices + ); + + let results = (outs + SPV_Type:$component + ); + + let opcode = 81; +} + #endif // SPIRV_OPS diff --git a/include/mlir/SPIRV/SPIRVTypes.h b/include/mlir/SPIRV/SPIRVTypes.h index 0bdb9f030234..430cd7c270ff 100644 --- a/include/mlir/SPIRV/SPIRVTypes.h +++ b/include/mlir/SPIRV/SPIRVTypes.h @@ -22,6 +22,7 @@ #ifndef MLIR_SPIRV_SPIRVTYPES_H_ #define MLIR_SPIRV_SPIRVTYPES_H_ +#include "mlir/IR/StandardTypes.h" #include "mlir/IR/TypeSupport.h" #include "mlir/IR/Types.h" @@ -52,9 +53,26 @@ enum Kind { }; } +// Provides a common base class between VectorType, SPIR-V ArrayType, SPIR-V +// StructType. +class CompositeType : public Type { +public: + using Type::Type; + + static bool classof(Type type) { + return (type.getKind() == TypeKind::Array || + type.getKind() == TypeKind::Struct || + type.getKind() == StandardTypes::Vector); + } + + uint64_t getNumMembers() const; + + Type getMemberType(uint64_t) const; +}; + // SPIR-V array type -class ArrayType - : public Type::TypeBase { +class ArrayType : public Type::TypeBase { public: using Base::Base; @@ -147,9 +165,8 @@ class RuntimeArrayType }; // SPIR-V struct type -class StructType - : public Type::TypeBase { - +class StructType : public Type::TypeBase { public: using Base::Base; diff --git a/lib/SPIRV/SPIRVOps.cpp b/lib/SPIRV/SPIRVOps.cpp index 5e275002e526..0cd29770f6df 100644 --- a/lib/SPIRV/SPIRVOps.cpp +++ b/lib/SPIRV/SPIRVOps.cpp @@ -38,6 +38,7 @@ using namespace mlir; static constexpr const char kAlignmentAttrName[] = "alignment"; static constexpr const char kBindingAttrName[] = "binding"; static constexpr const char kDescriptorSetAttrName[] = "descriptor_set"; +static constexpr const char kIndicesAttrName[] = "indices"; static constexpr const char kValueAttrName[] = "value"; static constexpr const char kValuesAttrName[] = "values"; static constexpr const char kFnNameAttrName[] = "fn"; @@ -197,6 +198,115 @@ static void printNoIOOp(Operation *op, OpAsmPrinter *printer) { printer->printOptionalAttrDict(op->getAttrs()); } +//===----------------------------------------------------------------------===// +// spv.CompositeExtractOp +//===----------------------------------------------------------------------===// + +static ParseResult parseCompositeExtractOp(OpAsmParser *parser, + OperationState *state) { + OpAsmParser::OperandType compositeInfo; + Attribute indicesAttr; + Type compositeType; + llvm::SMLoc attrLocation; + int32_t index; + + if (parser->parseOperand(compositeInfo) || + parser->getCurrentLocation(&attrLocation) || + parser->parseAttribute(indicesAttr, kIndicesAttrName, + state->attributes) || + parser->parseColonType(compositeType) || + parser->resolveOperand(compositeInfo, compositeType, state->operands)) { + return failure(); + } + + auto indicesArrayAttr = indicesAttr.dyn_cast(); + if (!indicesArrayAttr) { + return parser->emitError( + attrLocation, + "expected an 32-bit integer array attribute for 'indices'"); + } + + if (!indicesArrayAttr.size()) { + return parser->emitError( + attrLocation, "expected at least one index for spv.CompositeExtract"); + } + + Type resultType = compositeType; + for (auto indexAttr : indicesArrayAttr) { + if (auto indexIntAttr = indexAttr.dyn_cast()) { + index = indexIntAttr.getInt(); + } else { + return parser->emitError( + attrLocation, + "expexted an 32-bit integer for index, but found '") + << indexAttr << "'"; + } + if (resultType.isa()) { + if (index < 0 || + static_cast(index) >= + resultType.cast().getNumMembers()) { + return parser->emitError(attrLocation, "index ") + << index << " out of bounds for " << resultType; + } + resultType = resultType.cast().getMemberType(index); + } else { + return parser->emitError(attrLocation, "invalid type to extract from"); + } + } + + if (!resultType) { + return parser->emitError(attrLocation, + "can not extract from composite type"); + } + + state->addTypes(resultType); + return success(); +} + +static void print(spirv::CompositeExtractOp compositeExtractOp, + OpAsmPrinter *printer) { + *printer << spirv::CompositeExtractOp::getOperationName() << ' ' + << *compositeExtractOp.composite() << compositeExtractOp.indices() + << " : " << compositeExtractOp.composite()->getType(); +} + +static LogicalResult verify(spirv::CompositeExtractOp compExOp) { + auto resultType = compExOp.composite()->getType(); + auto indicesArrayAttr = compExOp.indices().dyn_cast(); + + if (!indicesArrayAttr.size()) { + return compExOp.emitOpError( + "expexted at least one index for spv.CompositeExtractOp"); + } + + int32_t index; + for (auto indexAttr : indicesArrayAttr) { + if (!resultType) { + return compExOp.emitError("invalid type to extract from"); + } + + index = indexAttr.dyn_cast().getInt(); + if (resultType.isa()) { + if (index < 0 || + static_cast(index) >= + resultType.cast().getNumMembers()) { + return compExOp.emitOpError("index ") + << index << " out of bounds for " << resultType; + } + resultType = resultType.cast().getMemberType(index); + } else { + return compExOp.emitOpError("invalid type to extract from"); + } + } + + if (resultType != compExOp.getType()) { + return compExOp.emitOpError("invalid result type, expected ") + << resultType << " provided " << compExOp.getType(); + } + + return success(); +} + //===----------------------------------------------------------------------===// // spv.constant //===----------------------------------------------------------------------===// diff --git a/lib/SPIRV/SPIRVTypes.cpp b/lib/SPIRV/SPIRVTypes.cpp index b2dad8122f13..f3515167988f 100644 --- a/lib/SPIRV/SPIRVTypes.cpp +++ b/lib/SPIRV/SPIRVTypes.cpp @@ -19,6 +19,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/StandardTypes.h" #include "mlir/SPIRV/SPIRVTypes.h" #include "llvm/ADT/StringSwitch.h" @@ -28,6 +29,36 @@ using namespace mlir::spirv; // Pull in all enum utility function definitions #include "mlir/SPIRV/SPIRVEnums.cpp.inc" +//===----------------------------------------------------------------------===// +// CompositeType +//===----------------------------------------------------------------------===// + +Type CompositeType::getMemberType(uint64_t index) const { + switch (getKind()) { + case spirv::TypeKind::Struct: + return cast().getMemberType(index); + case spirv::TypeKind::Array: + return cast().getElementType(); + case StandardTypes::Vector: + return cast().getElementType(); + default: + llvm_unreachable("invalid composite type"); + } +} + +uint64_t CompositeType::getNumMembers() const { + switch (getKind()) { + case spirv::TypeKind::Struct: + return cast().getNumMembers(); + case spirv::TypeKind::Array: + return cast().getElementCount(); + case StandardTypes::Vector: + return cast().getNumElements(); + default: + llvm_unreachable("invalid composite type"); + } +} + //===----------------------------------------------------------------------===// // ArrayType //===----------------------------------------------------------------------===// diff --git a/test/SPIRV/ops.mlir b/test/SPIRV/ops.mlir index 59d6121ad27c..de0e47e8b449 100644 --- a/test/SPIRV/ops.mlir +++ b/test/SPIRV/ops.mlir @@ -1,5 +1,141 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s +//===----------------------------------------------------------------------===// +// spv.CompositeExtractOp +//===----------------------------------------------------------------------===// + +func @composite_extract_f32_from_1D_array(%arg0: !spv.array<4xf32>) -> f32 { + // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4 x f32> + %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4xf32> + return %0: f32 +} + +// ----- + +func @composite_extract_f32_from_2D_array(%arg0: !spv.array<4x!spv.array<4xf32>>) -> f32 { + // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.array<4 x !spv.array<4 x f32>> + %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.array<4x!spv.array<4xf32>> + return %0: f32 +} + +// ----- + +func @composite_extract_1D_array_from_2D_array(%arg0: !spv.array<4x!spv.array<4xf32>>) -> !spv.array<4xf32> { + // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4 x !spv.array<4 x f32>> + %0 = spv.CompositeExtract %arg0[1 : i32] : !spv.array<4x!spv.array<4xf32>> + return %0 : !spv.array<4xf32> +} + +// ----- + +func @composite_extract_struct(%arg0 : !spv.struct>) -> f32 { + // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.struct> + %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32] : !spv.struct> + return %0 : f32 +} + +// ----- + +func @composite_extract_vector(%arg0 : vector<4xf32>) -> f32 { + // CHECK: %0 = spv.CompositeExtract %arg0[1 : i32] : vector<4xf32> + %0 = spv.CompositeExtract %arg0[1 : i32] : vector<4xf32> + return %0 : f32 +} + +// ----- + +func @composite_extract_no_ssa_operand() -> () { + // expected-error @+1 {{expected SSA operand}} + %0 = spv.CompositeExtract [4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_invalid_index_type_1() -> () { + %0 = spv.constant 10 : i32 + %1 = spv.Variable : !spv.ptr>, Function> + %2 = spv.Load "Function" %1 ["Volatile"] : !spv.array<4x!spv.array<4xf32>> + // expected-error @+1 {{expected non-function type}} + %3 = spv.CompositeExtract %2[%0] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_invalid_index_type_2(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () { + // expected-error @+1 {{op attribute 'indices' failed to satisfy constraint: 32-bit integer array attribute}} + %0 = spv.CompositeExtract %arg0[1] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_invalid_index_identifier(%arg0 : !spv.array<4x!spv.array<4xf32>>) -> () { + // expected-error @+1 {{expected bare identifier}} + %0 = spv.CompositeExtract %arg0(1 : i32) : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_2D_array_out_of_bounds_access_1(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () { + // expected-error @+1 {{index 4 out of bounds for '!spv.array<4 x !spv.array<4 x f32>>'}} + %0 = spv.CompositeExtract %arg0[4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_2D_array_out_of_bounds_access_2(%arg0: !spv.array<4x!spv.array<4xf32>> +) -> () { + // expected-error @+1 {{index 4 out of bounds for '!spv.array<4 x f32>'}} + %0 = spv.CompositeExtract %arg0[1 : i32, 4 : i32] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_struct_element_out_of_bounds_access(%arg0 : !spv.struct>) -> () { + // expected-error @+1 {{index 2 out of bounds for '!spv.struct>'}} + %0 = spv.CompositeExtract %arg0[2 : i32, 0 : i32] : !spv.struct> + return +} + +// ----- + +func @composite_extract_vector_out_of_bounds_access(%arg0: vector<4xf32>) -> () { + // expected-error @+1 {{index 4 out of bounds for 'vector<4xf32>'}} + %0 = spv.CompositeExtract %arg0[4 : i32] : vector<4xf32> + return +} + +// ----- + +func @composite_extract_invalid_types_1(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () { + // expected-error @+1 {{invalid type to extract from}} + %0 = spv.CompositeExtract %arg0[1 : i32, 2 : i32, 3 : i32] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + +func @composite_extract_invalid_types_2(%arg0: f32) -> () { + // expected-error @+1 {{invalid type to extract from}} + %0 = spv.CompositeExtract %arg0[1 : i32] : f32 + return +} + +// ----- + +func @composite_extract_invalid_extracted_type(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () { + // expected-error @+1 {{expected at least one index for spv.CompositeExtract}} + %0 = spv.CompositeExtract %arg0[] : !spv.array<4x!spv.array<4xf32>> + return +} + +// ----- + //===----------------------------------------------------------------------===// // spv.EntryPoint //===----------------------------------------------------------------------===//