Skip to content
This repository has been archived by the owner on Apr 23, 2021. It is now read-only.

[spirv] Add CompositeExtractOp operation. #44

Closed

Conversation

denis0x0D
Copy link
Contributor

Hello, this patch is addressing #42 issue.

Mostly all verifying checks are added to parser, for some reasons:

  1. To get the result type of operation, we have to iterate all over the indices, seems like the better way to do it once, instead two times in parser and verifier.
  2. To get the element type of the struct type we have to access it by index, to make it save the check for out of bounds access must be added and it seems like the better way do it once only in the parser.

BTW, I might miss something and do not see the possible side effects of this, please fix me if I'm wrong.

The verifier only checks that extracted type is not the same as composite type, to handle situation like this
%0 = spv.CompositeExtract %arg0[] : !spv.array<4x!spv.array<4xf32>>

The tests for different composite types added, also StructType is added to composite and spirv types.

@antiagainst can you please take a look. Thanks!

Copy link
Contributor

@antiagainst antiagainst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is awesome, thanks denis0x0D! I like the extensive tests! :)

Yes because we need to set the extracted component's type, we need to duplicate the check that happens in the validator inside the parser. It's a bit annoying. The parser check is only used for parsing custom assembly forms; the verifier guards all paths: parsing custom assembly form, parsing generic assembly form, and building ops with C++ API, etc. So we need both. But by properly define them with functions I think we can share the same impl at least.

More comments inline.

@@ -79,14 +80,15 @@ def SPV_Vector : VectorOf<[SPV_Bool, SPV_Integer, SPV_Float]>;
def SPV_AnyPtr : Type<SPV_IsPtrType, "any SPIR-V pointer type">;
def SPV_AnyArray : Type<SPV_IsArrayType, "any SPIR-V array type">;
def SPV_AnyRTArray : Type<SPV_IsRTArrayType, "any SPIR-V runtime array type">;
def SPV_AnyStruct : Type<SPV_IsStructType, "any SPIR-V struct type">;

def SPV_Numerical : AnyTypeOf<[SPV_Integer, SPV_Float]>;
def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
def SPV_Aggregrate : AnyTypeOf<[SPV_AnyArray]>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should also add SPV_AnyStruct to SPV_Aggregrate.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, added.

let summary = "Extract a part of a composite object.";

let description = [{
Extract a part of a composite object.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to duplicate the summary line. The summary and description will be concatenated together when generating the doc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

I32ArrayAttr:$indices
);

let results = (outs SPV_Type);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please give the output a name, like $component or something so that we will generate a nicer getter method for ti. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@@ -740,6 +740,102 @@ static LogicalResult verify(spirv::VariableOp varOp) {
return success();
}

//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Ops should be sorted alphabetically in this file.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see, thanks, fixed.


// Extracts the Result Type from Composite Type iterating all over the indices,
// with check for out of bounds.
static Type getElementTypeByIndices(OpAsmParser *parser, OperationState *state,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this element type extraction logic is generally useful for working with composite types. So it should be added to the type system. We can create a base class for composite types and let struct and arrays subclass from it. That way we have a uniform interface for querying element types. This involves touching custom types registered for the SPIR-V dialect though. You can use ShapedType and VectorType as an example of how to define base class and subclasses:

/// This is a common base class between Vector, UnrankedTensor, RankedTensor,
/// and MemRef types because they share behavior and semantics around shape,
/// rank, and fixed element type. Any type with these semantics should inherit
/// from ShapedType.
class ShapedType : public Type {
public:
using ImplType = detail::ShapedTypeStorage;
using Type::Type;
/// Return the element type.
Type getElementType() const;
/// If an element type is an integer or a float, return its width. Otherwise,
/// abort.
unsigned getElementTypeBitWidth() const;
/// If this is a ranked type, return the number of elements. Otherwise, abort.
int64_t getNumElements() const;
/// If this is a ranked type, return the rank. Otherwise, abort.
int64_t getRank() const;
/// Whether or not this is a ranked type. Memrefs, vectors and ranked tensors
/// have a rank, while unranked tensors do not.
bool hasRank() const;
/// If this is a ranked type, return the shape. Otherwise, abort.
ArrayRef<int64_t> getShape() const;
/// If this is unranked type or any dimension has unknown size (<0), it
/// doesn't have static shape. If all dimensions have known size (>= 0), it
/// has static shape.
bool hasStaticShape() const;
/// If this is a ranked type, return the number of dimensions with dynamic
/// size. Otherwise, abort.
int64_t getNumDynamicDims() const;
/// If this is ranked type, return the size of the specified dimension.
/// Otherwise, abort.
int64_t getDimSize(int64_t i) const;
/// Get the total amount of bits occupied by a value of this type. This does
/// not take into account any memory layout or widening constraints, e.g. a
/// vector<3xi57> is reported to occupy 3x57=171 bit, even though in practice
/// it will likely be stored as in a 4xi64 vector register. Fail an assertion
/// if the size cannot be computed statically, i.e. if the type has a dynamic
/// shape or if its elemental type does not have a known bit width.
int64_t getSizeInBits() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Type type) {
return type.getKind() == StandardTypes::Vector ||
type.getKind() == StandardTypes::RankedTensor ||
type.getKind() == StandardTypes::UnrankedTensor ||
type.getKind() == StandardTypes::MemRef;
}
/// Whether the given dimension size indicates a dynamic dimension.
static constexpr bool isDynamic(int64_t dSize) { return dSize < 0; }
};
/// Vector types represent multi-dimensional SIMD vectors, and have a fixed
/// known constant shape with one or more dimension.
class VectorType
: public Type::TypeBase<VectorType, ShapedType, detail::VectorTypeStorage> {
public:
using Base::Base;
/// Get or create a new VectorType of the provided shape and element type.
/// Assumes the arguments define a well-formed VectorType.
static VectorType get(ArrayRef<int64_t> shape, Type elementType);
/// Get or create a new VectorType of the provided shape and element type
/// declared at the given, potentially unknown, location. If the VectorType
/// defined by the arguments would be ill-formed, emit errors and return
/// nullptr-wrapping type.
static VectorType getChecked(ArrayRef<int64_t> shape, Type elementType,
Location location);
/// Verify the construction of a vector type.
static LogicalResult
verifyConstructionInvariants(llvm::Optional<Location> loc,
MLIRContext *context, ArrayRef<int64_t> shape,
Type elementType);
/// Returns true of the given type can be used as an element of a vector type.
/// In particular, vectors can consist of integer or float primitives.
static bool isValidElementType(Type t) { return t.isIntOrFloat(); }
ArrayRef<int64_t> getShape() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(unsigned kind) { return kind == StandardTypes::Vector; }
};

(VectorType will be tricky given that we cannot directly modify it... But I think we can switch case for it in composite type.)

Please feel free to leave a TODO here if you'd like to defer the change in a subsequent CL or leaving it to me. :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added CompositeType as a base class between ArrayType, StructType. The CompositeTypeStorage does not have a static "construct" method, because we don't need to directly construct it and therefore it does not have a operator == for the hash function.
It provides a method getElementTypByIndices which consumes array of int32_t, as result it returns a pair<Type, StringRef> on success it would be a type selected by the last provided index, on failure a nullpr with an error message. I didn't come with better idea, how to handle different error messages and don't insert asserts.

static LogicalResult verify(spirv::CompositeExtractOp compositeExtractOp) {
auto resultType = compositeExtractOp.getType();
if (resultType == compositeExtractOp.composite()->getType()) {
return compositeExtractOp.emitOpError(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you still call the function to verify the result type agree with the extracted type? It's a bit unfortunate that we are duplicating the check logic but we need that in the parser to properly set the result type. And this verifier is used across all paths (parsing custom assembly form, parsing generic assembly form, and building ops with C++ API, etc.) so it would be nice to guarantee.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a verifier.

@@ -496,3 +496,139 @@ func @cannot_be_generic_storage_class(%arg0: f32) -> () {
%0 = spv.Variable : !spv.ptr<f32, Generic>
return
}

// -----
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome tests! :D


func @composite_extract_invalid_extracted_type(%arg0: !spv.array<4x!spv.array<4xf32>>) -> () {
// expected-error @+1 {{extracted element type can not be the same with composite type}}
%0 = spv.CompositeExtract %arg0[] : !spv.array<4x!spv.array<4xf32>>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, it's interesting. I think we need a different check for this. OpCompositeExtract should not allow empty indices. (Although it's not clearly written on the spec.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, fixed.

Type resultType = compositeType;
for (Attribute indexAttr : indicesArrayAttr) {
if (!resultType) {
parser->emitError(attrLocation, "invalid type to extract from");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about ..., "cannnot index into ") << resultType;?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, added in some cases, but in this case seems like resultType == nullptr.

parser->emitError(attrLocation, "invalid type to extract from");
return nullptr;
}
// The type cast is save, because of the description in TableGen.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not true here... The TableGen'd checks are inside verify() but this one is used for parsing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, got it, updated.

@denis0x0D
Copy link
Contributor Author

@antiagainst thanks a lot for the detailed review! I'll update it, regarding to your comments.

@denis0x0D
Copy link
Contributor Author

@antiagainst I've updated the patch, regarding your comments, can you please take a look.
Thanks.

Copy link
Contributor

@antiagainst antiagainst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, almost there, just some final comments. Thanks again, Denis!

using Type::Type;

static bool kindof(unsigned kind) {
return (kind == TypeKind::Array || kind == TypeKind::Struct);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... || kind == StandardTypes::Vector)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see, to be able to use kinda "cast" thing we have to provide "kindof" function. I was stuck for I a while trying to understand how to use Vector from Standard Types as kinda CompositeType, in case we can't directly modify it. Thanks a lot, now I'm understanding MLIR's type system more better. Sorry, for this kind of realization.


if (!indicesArrayAttr.size()) {
return parser->emitError(attrLocation, "expexted at least one index for ")
<< compositeType;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean spv.CompositeExtract here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, updated.

indexes.push_back(indexIntAttr.getInt());
} else {
return parser->emitError(attrLocation,
"expexted an 32-bit integer for index");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... but found '") << indexAttr << "'";

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.


if (!indicesArrayAttr.size()) {
return compExOp.emitOpError("expexted at least one index for ")
<< resultType;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean spv.CompositeExtract here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, fixed.

}

std::pair<Type, StringRef> resultPair;
switch (compositeType.getKind()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be simplified after letting CompositeType to include VectorType.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

return cast<StructType>().getMemberType(index);
case spirv::TypeKind::Array:
return cast<ArrayType>().getElementType();
default:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing VectorType

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

case spirv::TypeKind::Array:
return cast<ArrayType>().getElementType();
default:
assert(0 && "invalid composite type");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use llvm_unreachable here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, fixed.

return cast<StructType>().getNumMembers();
case spirv::TypeKind::Array:
return cast<ArrayType>().getElementCount();
default:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing VectorType

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, updated.

case spirv::TypeKind::Array:
return cast<ArrayType>().getElementCount();
default:
assert(0 && "invalid composite type");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also please use llvm_unreachable here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, fixed.

@@ -496,3 +632,5 @@ func @cannot_be_generic_storage_class(%arg0: f32) -> () {
%0 = spv.Variable : !spv.ptr<f32, Generic>
return
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redundant empty lines :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, fixed.

@denis0x0D
Copy link
Contributor Author

@antiagainst thanks for the review! I've updated the patch, can you please take a look again :) Thanks!

Copy link
Contributor

@antiagainst antiagainst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks!

// StructType.
class CompositeType : public Type {
public:
using ImplType = DefaultTypeStorage;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to specify this, ImplType already defaults to DefaultTypeStorage.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@River707 Oh, my fault, thanks! Trying to get more familiar with MLIR.

CompositeExtractOp allows to extract a part of a composite object.
@denis0x0D
Copy link
Contributor Author

@antiagainst @River707 thanks for review! Updated regarding to the latest comment.

@jpienaar jpienaar closed this in b31675b Jul 12, 2019
@antiagainst
Copy link
Contributor

This is now landed as b31675b. Thanks a lot for your contribution, @denis0x0D!

@denis0x0D
Copy link
Contributor Author

@antiagainst thanks a lot! It would be nice, if you have some other tasks, not assigned to someone, which I can help with! Thanks.

@antiagainst
Copy link
Contributor

@denis0x0D, that's awesome! We certainly have lots to do and I really appreciate your help here! :) Let me create some other issues and assign to you. :)

@denis0x0D
Copy link
Contributor Author

@antiagainst thanks a lot! BTW, I will be in travel in [15, 21) of July, so, if it possible, some issue which I can start from 21-th of July.

@antiagainst
Copy link
Contributor

@denis0x0D, sure! Thanks for letting me know! I wonder whether you can help to add spv.AccessChain too. Given that you've done spv.CompositeExtract, it should be quite similar. I've just landed 9a8202e, so you can use it to inject the basic definition into SPIRVOps.td with ./define_inst.sh OpAccessChain and update according to the spec. But certainly no pressure! Just let me know. Thanks again! :)

@denis0x0D
Copy link
Contributor Author

@antiagainst thanks a lot for your time! It would be nice to implement OpAccessChain op, if it can wait until the next week, I'll take it wih a pleasure.
I see that helpful script gen_spirv_dialect.py, thanks!

@antiagainst
Copy link
Contributor

No worries! Let's see whether we can get it done this week. If not, you are free to take it next week! Otherwise we can certainly use your help somewhere else. :)

@denis0x0D
Copy link
Contributor Author

@antiagainst is it ok if I'll implement OpAccessChain? If noone works on it. Thanks.

@antiagainst
Copy link
Contributor

@denis0x0D, yup, feel free to take it! Thanks a lot! :)

@denis0x0D
Copy link
Contributor Author

@antiagainst thanks! Can you please answer some quick question about implementation, I want to make sure, that I understand it in right way.
In case of AccessChainOp it creates a pointer into composite type from pointer to composite type, the result pointer type must be provided by walking the Base's type hierarchy down to last provided index; the indexes are in SSA forms, so we can't walk down to last provided index during compile time, or I'm missing something. In parser we need to resolve the SSA values provided by Base (ssa-id) and get the type of result. to add it into list So, the ebnf form must look like this:

 op ::= ssa-id `=` `spv.AccessChain` ssa-use
           `[` ssa-use (`,` ssa-use)* `]`
           `:` pointer-type `,` pointer-type

for example

%result_ptr = spv.AccessChain %composite_ptr[%index0] : !spv.ptr<!spv.struct<f32, spv.array<4xf32>>, Function>, !spv.ptr<!spv.array<4xf32>, Function>

The first pointer-type is a type of composite object, the second one is a result type, in this case we can easily resolve the base type and result type.
But I might miss something, please correct me if I'm wrong.

@antiagainst
Copy link
Contributor

antiagainst commented Jul 22, 2019

@denis0x0D:

so we can't walk down to last provided index during compile time, or I'm missing something.

We can get the accessed component's type statically; it's just that we are not guaranteed to be in bounds dynamically. An vector/array/runtime-array type contains elements of the same type, so we will just get its element type. For structure type which contains heterogeneous elements, the spec requires the index to be from an OpConstant so we can statically deduce what type the accessed component is of.

For the trailing types in the custom assembly form, I think only output the composite's type should suffice?

@denis0x0D
Copy link
Contributor Author

@antiagainst oh, I see, we can get an Operation associated with SSA Value (getDefiningOp) and thereofore all attributes (actual value of ConstantOp) to deduce the type extracted from the Composite object, sorry for the stupid question.
Thanks.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants