-
Notifications
You must be signed in to change notification settings - Fork 258
[spirv] Add CompositeExtractOp operation. #44
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is awesome, thanks denis0x0D! I like the extensive tests! :)
Yes because we need to set the extracted component's type, we need to duplicate the check that happens in the validator inside the parser. It's a bit annoying. The parser check is only used for parsing custom assembly forms; the verifier guards all paths: parsing custom assembly form, parsing generic assembly form, and building ops with C++ API, etc. So we need both. But by properly define them with functions I think we can share the same impl at least.
More comments inline.
include/mlir/SPIRV/SPIRVBase.td
Outdated
@@ -79,14 +80,15 @@ def SPV_Vector : VectorOf<[SPV_Bool, SPV_Integer, SPV_Float]>; | |||
def SPV_AnyPtr : Type<SPV_IsPtrType, "any SPIR-V pointer type">; | |||
def SPV_AnyArray : Type<SPV_IsArrayType, "any SPIR-V array type">; | |||
def SPV_AnyRTArray : Type<SPV_IsRTArrayType, "any SPIR-V runtime array type">; | |||
def SPV_AnyStruct : Type<SPV_IsStructType, "any SPIR-V struct type">; | |||
|
|||
def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>; | |||
def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>; | |||
def SPV_Aggregrate : AnyTypeOf<[SPV_AnyArray]>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should also add SPV_AnyStruct
to SPV_Aggregrate
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, added.
include/mlir/SPIRV/SPIRVOps.td
Outdated
let summary = "Extract a part of a composite object."; | ||
|
||
let description = [{ | ||
Extract a part of a composite object. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to duplicate the summary line. The summary and description will be concatenated together when generating the doc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
include/mlir/SPIRV/SPIRVOps.td
Outdated
I32ArrayAttr:$indices | ||
); | ||
|
||
let results = (outs SPV_Type); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please give the output a name, like $component
or something so that we will generate a nicer getter method for ti. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
lib/SPIRV/SPIRVOps.cpp
Outdated
@@ -740,6 +740,102 @@ static LogicalResult verify(spirv::VariableOp varOp) { | |||
return success(); | |||
} | |||
|
|||
//===----------------------------------------------------------------------===// | |||
// spv.CompositeExtractOp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Ops should be sorted alphabetically in this file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I see, thanks, fixed.
lib/SPIRV/SPIRVOps.cpp
Outdated
|
||
// Extracts the Result Type from Composite Type iterating all over the indices, | ||
// with check for out of bounds. | ||
static Type getElementTypeByIndices(OpAsmParser *parser, OperationState *state, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this element type extraction logic is generally useful for working with composite types. So it should be added to the type system. We can create a base class for composite types and let struct and arrays subclass from it. That way we have a uniform interface for querying element types. This involves touching custom types registered for the SPIR-V dialect though. You can use ShapedType
and VectorType
as an example of how to define base class and subclasses:
mlir/include/mlir/IR/StandardTypes.h
Lines 178 to 272 in 9422eb7
/// This is a common base class between Vector, UnrankedTensor, RankedTensor, | |
/// and MemRef types because they share behavior and semantics around shape, | |
/// rank, and fixed element type. Any type with these semantics should inherit | |
/// from ShapedType. | |
class ShapedType : public Type { | |
public: | |
using ImplType = detail::ShapedTypeStorage; | |
using Type::Type; | |
/// Return the element type. | |
Type getElementType() const; | |
/// If an element type is an integer or a float, return its width. Otherwise, | |
/// abort. | |
unsigned getElementTypeBitWidth() const; | |
/// If this is a ranked type, return the number of elements. Otherwise, abort. | |
int64_t getNumElements() const; | |
/// If this is a ranked type, return the rank. Otherwise, abort. | |
int64_t getRank() const; | |
/// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors | |
/// have a rank, while unranked tensors do not. | |
bool hasRank() const; | |
/// If this is a ranked type, return the shape. Otherwise, abort. | |
ArrayRef<int64_t> getShape() const; | |
/// If this is unranked type or any dimension has unknown size (<0), it | |
/// doesn't have static shape. If all dimensions have known size (>= 0), it | |
/// has static shape. | |
bool hasStaticShape() const; | |
/// If this is a ranked type, return the number of dimensions with dynamic | |
/// size. Otherwise, abort. | |
int64_t getNumDynamicDims() const; | |
/// If this is ranked type, return the size of the specified dimension. | |
/// Otherwise, abort. | |
int64_t getDimSize(int64_t i) const; | |
/// Get the total amount of bits occupied by a value of this type. This does | |
/// not take into account any memory layout or widening constraints, e.g. a | |
/// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice | |
/// it will likely be stored as in a 4xi64 vector register. Fail an assertion | |
/// if the size cannot be computed statically, i.e. if the type has a dynamic | |
/// shape or if its elemental type does not have a known bit width. | |
int64_t getSizeInBits() const; | |
/// Methods for support type inquiry through isa, cast, and dyn_cast. | |
static bool classof(Type type) { | |
return type.getKind() == StandardTypes::Vector || | |
type.getKind() == StandardTypes::RankedTensor || | |
type.getKind() == StandardTypes::UnrankedTensor || | |
type.getKind() == StandardTypes::MemRef; | |
} | |
/// Whether the given dimension size indicates a dynamic dimension. | |
static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; } | |
}; | |
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed | |
/// known constant shape with one or more dimension. | |
class VectorType | |
: public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> { | |
public: | |
using Base::Base; | |
/// Get or create a new VectorType of the provided shape and element type. | |
/// Assumes the arguments define a well-formed VectorType. | |
static VectorType get(ArrayRef<int64_t> shape, Type elementType); | |
/// Get or create a new VectorType of the provided shape and element type | |
/// declared at the given, potentially unknown, location. If the VectorType | |
/// defined by the arguments would be ill-formed, emit errors and return | |
/// nullptr-wrapping type. | |
static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType, | |
Location location); | |
/// Verify the construction of a vector type. | |
static LogicalResult | |
verifyConstructionInvariants(llvm::Optional<Location> loc, | |
MLIRContext *context, ArrayRef<int64_t> shape, | |
Type elementType); | |
/// Returns true of the given type can be used as an element of a vector type. | |
/// In particular, vectors can consist of integer or float primitives. | |
static bool isValidElementType(Type t) { return t.isIntOrFloat(); } | |
ArrayRef<int64_t> getShape() const; | |
/// Methods for support type inquiry through isa, cast, and dyn_cast. | |
static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; } | |
}; |
(VectorType
will be tricky given that we cannot directly modify it... But I think we can switch
case
for it in composite type.)
Please feel free to leave a TODO here if you'd like to defer the change in a subsequent CL or leaving it to me. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've added CompositeType as a base class between ArrayType, StructType. The CompositeTypeStorage does not have a static "construct" method, because we don't need to directly construct it and therefore it does not have a operator == for the hash function.
It provides a method getElementTypByIndices which consumes array of int32_t, as result it returns a pair<Type, StringRef> on success it would be a type selected by the last provided index, on failure a nullpr with an error message. I didn't come with better idea, how to handle different error messages and don't insert asserts.
lib/SPIRV/SPIRVOps.cpp
Outdated
static LogicalResult verify(spirv::CompositeExtractOp compositeExtractOp) { | ||
auto resultType = compositeExtractOp.getType(); | ||
if (resultType == compositeExtractOp.composite()->getType()) { | ||
return compositeExtractOp.emitOpError( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you still call the function to verify the result type agree with the extracted type? It's a bit unfortunate that we are duplicating the check logic but we need that in the parser to properly set the result type. And this verifier is used across all paths (parsing custom assembly form, parsing generic assembly form, and building ops with C++ API, etc.) so it would be nice to guarantee.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a verifier.
test/SPIRV/ops.mlir
Outdated
@@ -496,3 +496,139 @@ func @cannot_be_generic_storage_class(%arg0: f32) -> () { | |||
%0 = spv.Variable : !spv.ptr<f32, Generic> | |||
return | |||
} | |||
|
|||
// ----- |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome tests! :D
test/SPIRV/ops.mlir
Outdated
|
||
func @composite_extract_invalid_extracted_type(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () { | ||
// expected-error @+1 {{extracted element type can not be the same with composite type}} | ||
%0 = spv.CompositeExtract %arg0[] : !spv.array<4x!spv.array<4xf32>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, it's interesting. I think we need a different check for this. OpCompositeExtract should not allow empty indices. (Although it's not clearly written on the spec.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, fixed.
lib/SPIRV/SPIRVOps.cpp
Outdated
Type resultType = compositeType; | ||
for (Attribute indexAttr : indicesArrayAttr) { | ||
if (!resultType) { | ||
parser->emitError(attrLocation, "invalid type to extract from"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about ..., "cannnot index into ") << resultType;
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, added in some cases, but in this case seems like resultType == nullptr.
lib/SPIRV/SPIRVOps.cpp
Outdated
parser->emitError(attrLocation, "invalid type to extract from"); | ||
return nullptr; | ||
} | ||
// The type cast is save, because of the description in TableGen. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not true here... The TableGen'd checks are inside verify()
but this one is used for parsing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, got it, updated.
@antiagainst thanks a lot for the detailed review! I'll update it, regarding to your comments. |
c5a12c4
to
5511a89
Compare
@antiagainst I've updated the patch, regarding your comments, can you please take a look. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, almost there, just some final comments. Thanks again, Denis!
include/mlir/SPIRV/SPIRVTypes.h
Outdated
using Type::Type; | ||
|
||
static bool kindof(unsigned kind) { | ||
return (kind == TypeKind::Array || kind == TypeKind::Struct); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
... || kind == StandardTypes::Vector)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, I see, to be able to use kinda "cast" thing we have to provide "kindof" function. I was stuck for I a while trying to understand how to use Vector from Standard Types as kinda CompositeType, in case we can't directly modify it. Thanks a lot, now I'm understanding MLIR's type system more better. Sorry, for this kind of realization.
lib/SPIRV/SPIRVOps.cpp
Outdated
|
||
if (!indicesArrayAttr.size()) { | ||
return parser->emitError(attrLocation, "expexted at least one index for ") | ||
<< compositeType; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You mean spv.CompositeExtract
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, updated.
lib/SPIRV/SPIRVOps.cpp
Outdated
indexes.push_back(indexIntAttr.getInt()); | ||
} else { | ||
return parser->emitError(attrLocation, | ||
"expexted an 32-bit integer for index"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
... but found '") << indexAttr << "'";
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
lib/SPIRV/SPIRVOps.cpp
Outdated
|
||
if (!indicesArrayAttr.size()) { | ||
return compExOp.emitOpError("expexted at least one index for ") | ||
<< resultType; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You mean spv.CompositeExtract
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, fixed.
lib/SPIRV/SPIRVOps.cpp
Outdated
} | ||
|
||
std::pair<Type, StringRef> resultPair; | ||
switch (compositeType.getKind()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be simplified after letting CompositeType to include VectorType.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated.
return cast<StructType>().getMemberType(index); | ||
case spirv::TypeKind::Array: | ||
return cast<ArrayType>().getElementType(); | ||
default: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing VectorType
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated.
lib/SPIRV/SPIRVTypes.cpp
Outdated
case spirv::TypeKind::Array: | ||
return cast<ArrayType>().getElementType(); | ||
default: | ||
assert(0 && "invalid composite type"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use llvm_unreachable here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, fixed.
return cast<StructType>().getNumMembers(); | ||
case spirv::TypeKind::Array: | ||
return cast<ArrayType>().getElementCount(); | ||
default: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing VectorType
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, updated.
lib/SPIRV/SPIRVTypes.cpp
Outdated
case spirv::TypeKind::Array: | ||
return cast<ArrayType>().getElementCount(); | ||
default: | ||
assert(0 && "invalid composite type"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also please use llvm_unreachable here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, fixed.
test/SPIRV/ops.mlir
Outdated
@@ -496,3 +632,5 @@ func @cannot_be_generic_storage_class(%arg0: f32) -> () { | |||
%0 = spv.Variable : !spv.ptr<f32, Generic> | |||
return | |||
} | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Redundant empty lines :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, fixed.
5511a89
to
dee9692
Compare
@antiagainst thanks for the review! I've updated the patch, can you please take a look again :) Thanks! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome, thanks!
include/mlir/SPIRV/SPIRVTypes.h
Outdated
// StructType. | ||
class CompositeType : public Type { | ||
public: | ||
using ImplType = DefaultTypeStorage; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't need to specify this, ImplType already defaults to DefaultTypeStorage.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@River707 Oh, my fault, thanks! Trying to get more familiar with MLIR.
CompositeExtractOp allows to extract a part of a composite object.
dee9692
to
8ce9295
Compare
@antiagainst @River707 thanks for review! Updated regarding to the latest comment. |
This is now landed as b31675b. Thanks a lot for your contribution, @denis0x0D! |
@antiagainst thanks a lot! It would be nice, if you have some other tasks, not assigned to someone, which I can help with! Thanks. |
@denis0x0D, that's awesome! We certainly have lots to do and I really appreciate your help here! :) Let me create some other issues and assign to you. :) |
@antiagainst thanks a lot! BTW, I will be in travel in [15, 21) of July, so, if it possible, some issue which I can start from 21-th of July. |
@denis0x0D, sure! Thanks for letting me know! I wonder whether you can help to add |
@antiagainst thanks a lot for your time! It would be nice to implement OpAccessChain op, if it can wait until the next week, I'll take it wih a pleasure. |
No worries! Let's see whether we can get it done this week. If not, you are free to take it next week! Otherwise we can certainly use your help somewhere else. :) |
@antiagainst is it ok if I'll implement OpAccessChain? If noone works on it. Thanks. |
@denis0x0D, yup, feel free to take it! Thanks a lot! :) |
@antiagainst thanks! Can you please answer some quick question about implementation, I want to make sure, that I understand it in right way.
for example
The first pointer-type is a type of composite object, the second one is a result type, in this case we can easily resolve the base type and result type. |
We can get the accessed component's type statically; it's just that we are not guaranteed to be in bounds dynamically. An vector/array/runtime-array type contains elements of the same type, so we will just get its element type. For structure type which contains heterogeneous elements, the spec requires the index to be from an For the trailing types in the custom assembly form, I think only output the composite's type should suffice? |
@antiagainst oh, I see, we can get an Operation associated with SSA Value (getDefiningOp) and thereofore all attributes (actual value of ConstantOp) to deduce the type extracted from the Composite object, sorry for the stupid question. |
Hello, this patch is addressing #42 issue.
Mostly all verifying checks are added to parser, for some reasons:
BTW, I might miss something and do not see the possible side effects of this, please fix me if I'm wrong.
The verifier only checks that extracted type is not the same as composite type, to handle situation like this
%0 = spv.CompositeExtract %arg0[] : !spv.array<4x!spv.array<4xf32>>
The tests for different composite types added, also StructType is added to composite and spirv types.
@antiagainst can you please take a look. Thanks!