diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index d0286e5acff9ce..40aa584e12f3ab 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -37,6 +37,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 8675a411788344..4a8833a865bdd3 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -310,6 +310,7 @@ cc_library( "@llvm-project//mlir:Dialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -911,6 +912,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], ) @@ -1037,6 +1039,7 @@ cc_library( "@llvm-project//llvm:Analysis", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@local_tsl//tsl/platform:status", "@local_xla//xla:statusor", @@ -1210,6 +1213,7 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/lite/schema:schema_fbs", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_xla//xla:statusor", ], ) diff --git a/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc b/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc index 5f77797b9aa8a7..59cc28f9fa0608 100644 --- a/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc +++ b/tensorflow/compiler/mlir/lite/experimental/common/outline_operations.cc @@ -44,7 +44,7 @@ namespace common { bool IsConstantOrNone(Operation* op) { return (op->getNumResults() == 1 && - op->getResult(0).getType().isa()) || + mlir::isa(op->getResult(0).getType())) || matchPattern(op, m_Constant()) || isa(op); } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD index 248a55c7fe17e1..8f70532a06977a 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD @@ -42,6 +42,7 @@ cc_library( "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -88,6 +89,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h b/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h index 40f4902e655bcd..88382e8cf6f27b 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/common/utils.h @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" @@ -52,7 +53,7 @@ bool NotTFLQuantDequantizeOp(Operation* op); // Returns true if it is a shaped type of f32 elements. inline bool IsF32ShapedType(Type t) { - if (auto shaped_type = t.dyn_cast_or_null()) { + if (auto shaped_type = mlir::dyn_cast_or_null(t)) { return shaped_type.getElementType().isF32(); } return false; diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.cc b/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.cc index 9b2458571f0c34..11a1b31e5102de 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/execution_metadata_exporter.cc @@ -29,6 +29,7 @@ #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h" #include "tensorflow/compiler/mlir/lite/experimental/tac/runtime_metadata_generated.h" @@ -82,8 +83,7 @@ std::optional> GetPerDeviceCosts( for (const auto& kv : hardware_map) { auto cost_attr = device_costs_attr.getNamed(kv.first); if (!cost_attr.has_value()) return std::nullopt; - float cost = cost_attr->getValue() - .dyn_cast_or_null() + float cost = mlir::dyn_cast_or_null(cost_attr->getValue()) .getValueAsDouble(); device_costs[kv.second] = cost; } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc index 4efdd053eec5c2..701d9cad1c34c1 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/cost_model.cc @@ -59,13 +59,14 @@ int64_t GetTransferredTensorBytes(func::CallOp from_graph, for (auto input : to_graph.getOperands()) { Operation* input_op = input.getDefiningOp(); if (input_op && input_op == from_graph.getOperation()) { - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = + mlir::dyn_cast_or_null(input.getType()); if (input_type == nullptr || !input_type.hasStaticShape()) continue; // Quantized type does not support getSizeInBits. if (IsQUI8Type(input_type) || IsQI8Type(input_type)) { total_size_transferred += input_type.getNumElements() * 8; } else { - auto s_type = input_type.cast(); + auto s_type = mlir::cast(input_type); total_size_transferred += s_type.getNumElements() * s_type.getElementTypeBitWidth(); } @@ -81,7 +82,8 @@ int64_t GetTransferredElementCount(func::CallOp from_graph, for (auto input : to_graph.getOperands()) { Operation* input_op = input.getDefiningOp(); if (input_op && input_op == from_graph.getOperation()) { - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = + mlir::dyn_cast_or_null(input.getType()); if (input_type == nullptr || !input_type.hasStaticShape()) continue; total_element_count += input_type.getNumElements(); } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc index ea1c299fd546c1..0c37a8da20575f 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform.cc @@ -156,13 +156,13 @@ struct FoldQuantizedI32ToFloat : public OpRewritePattern { if (!IsQI32Type(input_dequant.getType())) return failure(); auto output_type = - dequant_op.getOutput().getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(dequant_op.getOutput().getType()); if (!output_type || !output_type.getElementType().isF32()) return failure(); - auto input_type = input_dequant.getType().dyn_cast(); + auto input_type = mlir::dyn_cast(input_dequant.getType()); // TODO(renjieliu): support UniformQuantizedPerAxisType. - auto q_type = input_type.getElementType() - .dyn_cast_or_null(); + auto q_type = mlir::dyn_cast_or_null( + input_type.getElementType()); if (!q_type) return failure(); const float scale = q_type.getScale(); @@ -183,9 +183,9 @@ struct FoldQuantizedI32ToFloat : public OpRewritePattern { }; auto dequant_values = - input_values.cast().mapValues( - FloatType::getF32(rewriter.getContext()), - llvm::function_ref(dequantize_func)); + mlir::cast(input_values) + .mapValues(FloatType::getF32(rewriter.getContext()), + llvm::function_ref(dequantize_func)); rewriter.replaceOpWithNewOp(dequant_op, dequant_op.getType(), dequant_values); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc index baf25aa54c109b..278c54e8805f3d 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/device_transform_patterns.cc @@ -96,11 +96,11 @@ LogicalResult EnsureBias(Operation* op, int bias_idx, PatternRewriter& rewriter) { auto bias = op->getOperand(bias_idx); - if (!bias.getType().isa()) return failure(); + if (!mlir::isa(bias.getType())) return failure(); // Proceed to create a zero bias. auto output = op->getResult(0); - auto output_type = output.getType().dyn_cast_or_null(); + auto output_type = mlir::dyn_cast_or_null(output.getType()); if (!output_type) return failure(); // bias should be a vector sized of the last output dim. @@ -163,7 +163,7 @@ SmallVector SliceOutputs(Operation* split_op, Value input, SmallVector slice_size; auto current_output = split_op->getResult(i); auto current_output_type = - current_output.getType().cast(); + mlir::cast(current_output.getType()); for (int d = 0; d < input_type.getRank(); ++d) { if (d == split_dim) { // Split dimension. @@ -208,7 +208,7 @@ LogicalResult LowerPackIntoConcatReshape::matchAndRewrite( TFL::PackOp pack_op, PatternRewriter& rewriter) const { // Pack op should have same shape type. SmallVector pack_inputs(pack_op.getValues()); - auto input_type = pack_inputs[0].getType().dyn_cast(); + auto input_type = mlir::dyn_cast(pack_inputs[0].getType()); if (!input_type) return failure(); // Figure out output shapes. @@ -266,8 +266,8 @@ LogicalResult SquaredDifference::matchAndRewrite( TFL::SquaredDifferenceOp squared_diff_op, PatternRewriter& rewriter) const { auto x = squared_diff_op.getLhs(); auto y = squared_diff_op.getRhs(); - auto x_type = x.getType().dyn_cast(); - auto y_type = y.getType().dyn_cast(); + auto x_type = mlir::dyn_cast(x.getType()); + auto y_type = mlir::dyn_cast(y.getType()); if (!x_type || !y_type) return failure(); if (x_type.getShape() != y_type.getShape()) return failure(); @@ -290,16 +290,16 @@ LogicalResult UnrollSplit::matchAndRewrite(TFL::SplitOp split_op, PatternRewriter& rewriter) const { auto num_splits = split_op.getNumSplits(); auto input = split_op.getValue(); - auto input_type = input.getType().dyn_cast(); + auto input_type = mlir::dyn_cast(input.getType()); if (input_type == nullptr || !input_type.hasStaticShape()) return failure(); for (auto result : split_op.getResults()) { - auto result_type = result.getType().dyn_cast(); + auto result_type = mlir::dyn_cast(result.getType()); if (result_type == nullptr) return failure(); } auto output = split_op.getResult(0); - auto output_type = output.getType().cast(); + auto output_type = mlir::cast(output.getType()); // TODO(renjieliu): change to use split_dim when we raise the constants // as well. @@ -330,11 +330,11 @@ LogicalResult UnrollSplitV::matchAndRewrite(TFL::SplitVOp splitv_op, return failure(); auto input = splitv_op.getValue(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type || !input_type.hasRank()) return failure(); for (auto result : splitv_op.getResults()) { - auto result_type = result.getType().dyn_cast(); + auto result_type = mlir::dyn_cast(result.getType()); if (result_type == nullptr) return failure(); } @@ -371,20 +371,21 @@ LogicalResult PadSlice::matchAndRewrite(TFL::SliceOp slice_op, // We have to know the shape of the input, as well as the begin/size. // also, begin and size have to be constants. auto input = slice_op.getInput(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type || !input_type.hasStaticShape()) return failure(); if (input_type.getRank() >= 4) return failure(); auto begin = slice_op.getBegin(); - auto begin_type = begin.getType().dyn_cast_or_null(); + auto begin_type = mlir::dyn_cast_or_null(begin.getType()); if (!begin_type || !begin_type.hasStaticShape()) return failure(); auto size = slice_op.getSize(); - auto size_type = size.getType().dyn_cast_or_null(); + auto size_type = mlir::dyn_cast_or_null(size.getType()); if (!size_type || !size_type.hasStaticShape()) return failure(); - auto output_type = slice_op.getType().dyn_cast_or_null(); + auto output_type = + mlir::dyn_cast_or_null(slice_op.getType()); if (!output_type || !output_type.hasStaticShape()) return failure(); // Pad 0s in front of the begin. @@ -472,17 +473,17 @@ LogicalResult FullyConnectedToConv::matchAndRewrite( TFL::FullyConnectedOp fc_op, PatternRewriter& rewriter) const { // We have to know the shape of the input. auto input = fc_op.getInput(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type || !input_type.hasStaticShape()) return failure(); // We have to know the shape of the weight. auto weight = fc_op.getFilter(); - auto weight_type = weight.getType().dyn_cast_or_null(); + auto weight_type = mlir::dyn_cast_or_null(weight.getType()); if (!weight_type || !weight_type.hasStaticShape()) return failure(); // We have to know the shape of the output as well. auto output = fc_op.getResult(0); - auto output_type = output.getType().dyn_cast_or_null(); + auto output_type = mlir::dyn_cast_or_null(output.getType()); if (!output_type || !output_type.hasStaticShape()) return failure(); // Insert a reshape after the input. @@ -532,13 +533,14 @@ LogicalResult PadConcat::matchAndRewrite(TFL::ConcatenationOp concat_op, PatternRewriter& rewriter) const { int rank = -1; for (auto input : concat_op.getValues()) { - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type || !input_type.hasStaticShape()) return failure(); rank = input_type.getRank(); } - auto output_type = concat_op.getType().dyn_cast_or_null(); + auto output_type = + mlir::dyn_cast_or_null(concat_op.getType()); if (!output_type || !output_type.hasStaticShape()) return failure(); if (rank >= 4) return failure(); @@ -547,7 +549,7 @@ LogicalResult PadConcat::matchAndRewrite(TFL::ConcatenationOp concat_op, // We will insert a reshape op after every input. SmallVector reshape_ops; for (auto input : concat_op.getValues()) { - auto input_type = input.getType().cast(); + auto input_type = mlir::cast(input.getType()); // Get the new shape. SmallVector new_shape; for (int i = 0; i < 4 - rank; ++i) { @@ -603,7 +605,7 @@ LogicalResult PadConcat::matchAndRewrite(TFL::ConcatenationOp concat_op, LogicalResult ReduceMeanToAvgPool::matchAndRewrite( TFL::MeanOp mean_op, PatternRewriter& rewriter) const { auto input = mean_op.getInput(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); // Only 4d is supported here. if (!input_type || input_type.getRank() != 4) return failure(); @@ -619,7 +621,7 @@ LogicalResult ReduceMeanToAvgPool::matchAndRewrite( } auto output = mean_op.getOutput(); - auto output_type = output.getType().dyn_cast_or_null(); + auto output_type = mlir::dyn_cast_or_null(output.getType()); if (!output_type) return failure(); auto input_quantized_type = @@ -669,7 +671,7 @@ LogicalResult ReduceMeanToAvgPool::matchAndRewrite( LogicalResult InsertRequantForReduceMean::matchAndRewrite( TFL::MeanOp mean_op, PatternRewriter& rewriter) const { auto input = mean_op.getInput(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) return failure(); // Only need to do this for quantized input. @@ -678,7 +680,7 @@ LogicalResult InsertRequantForReduceMean::matchAndRewrite( if (!input_quantized_type) return failure(); auto output = mean_op.getOutput(); - auto output_type = output.getType().dyn_cast_or_null(); + auto output_type = mlir::dyn_cast_or_null(output.getType()); if (!output_type) return failure(); auto output_quantized_type = quant::QuantizedType::getQuantizedElementType(output_type); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/fold_constants_to_subgraph.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/fold_constants_to_subgraph.cc index b6c544a8f69c9b..e4985f2b5700d5 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/fold_constants_to_subgraph.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/fold_constants_to_subgraph.cc @@ -107,11 +107,12 @@ bool IsConstOrQConstInt(Operation* op) { if (auto arith_const_op = dyn_cast_or_null(op)) { // arith ConstOp path. - auto type = arith_const_op.getType().cast().getElementType(); + auto type = + mlir::cast(arith_const_op.getType()).getElementType(); if (!type.isInteger(32) && !type.isInteger(64)) return false; } else if (auto const_op = dyn_cast_or_null(op)) { // ConstOp path. - auto type = const_op.getType().cast().getElementType(); + auto type = mlir::cast(const_op.getType()).getElementType(); if (!type.isInteger(32) && !type.isInteger(64)) return false; } else { // QConstOp path. diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc index 4fd9f945764b3a..1ff585f6c71cb6 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/raise_target_subgraphs.cc @@ -113,18 +113,11 @@ void AddAttrs(OpsAdded& ops_added, OpBuilder& builder, int func_count) { added_func_op->setAttr(kInterfaceNameAttr, interface_name); added_call_op->setAttr(kInterfaceNameAttr, interface_name); - StringAttr device = added_func_op->getRegion(0) - .getBlocks() - .front() - .front() - .getAttr(kDevice) - .cast(); - StringAttr inference_type = added_func_op->getRegion(0) - .getBlocks() - .front() - .front() - .getAttr(kInferenceType) - .cast(); + StringAttr device = mlir::cast( + added_func_op->getRegion(0).getBlocks().front().front().getAttr(kDevice)); + StringAttr inference_type = mlir::cast( + added_func_op->getRegion(0).getBlocks().front().front().getAttr( + kInferenceType)); added_call_op->setAttr(kDevice, device); added_call_op->setAttr(kInferenceType, inference_type); added_func_op->setAttr(kDevice, device); diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc index a2f7441cc170b1..05cadcbb26b1e7 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc @@ -110,7 +110,7 @@ void ApplyTacFilter(ModuleOp module, const TacFilter& tac_filter, llvm::Regex op_regex(tac_filter.op_filter().op_name_pattern()); module.walk([&](Operation* op) { - auto named_loc = op->getLoc().dyn_cast(); + auto named_loc = mlir::dyn_cast(op->getLoc()); if (!named_loc) { return; } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index acdbb79201a3d7..84bc7ab2daae19 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -190,11 +190,11 @@ static StatusOr GetTFLiteType(Type type, return tflite::TensorType_BFLOAT16; } else if (type.isF64()) { return tflite::TensorType_FLOAT64; - } else if (type.isa()) { + } else if (mlir::isa(type)) { return tflite::TensorType_STRING; - } else if (type.isa()) { + } else if (mlir::isa(type)) { return tflite::TensorType_UINT8; - } else if (auto complex_type = type.dyn_cast()) { + } else if (auto complex_type = mlir::dyn_cast(type)) { auto ftype = complex_type.getElementType(); if (ftype.isF32()) { return tflite::TensorType_COMPLEX64; @@ -203,7 +203,7 @@ static StatusOr GetTFLiteType(Type type, return tflite::TensorType_COMPLEX128; } return Status(absl::StatusCode::kInvalidArgument, "Unsupported type"); - } else if (auto itype = type.dyn_cast()) { + } else if (auto itype = mlir::dyn_cast(type)) { switch (itype.getWidth()) { case 1: return tflite::TensorType_BOOL; @@ -228,19 +228,20 @@ static StatusOr GetTFLiteType(Type type, : tflite::TensorType_INT64; } } else if (auto q_uniform_type = - type.dyn_cast()) { + mlir::dyn_cast(type)) { return GetTFLiteType(q_uniform_type.getStorageType(), q_uniform_type.isSigned()); } else if (auto q_peraxis_type = - type.dyn_cast()) { + mlir::dyn_cast( + type)) { return GetTFLiteType(q_peraxis_type.getStorageType(), q_peraxis_type.isSigned()); } else if (auto q_calibrated_type = - type.dyn_cast()) { + mlir::dyn_cast(type)) { return GetTFLiteType(q_calibrated_type.getExpressedType()); - } else if (type.isa()) { + } else if (mlir::isa(type)) { return tflite::TensorType_RESOURCE; - } else if (type.isa()) { + } else if (mlir::isa(type)) { return tflite::TensorType_VARIANT; } // TFLite export fills FLOAT32 for unknown data types. Returning an error @@ -258,13 +259,13 @@ static bool IsConst(Operation* op) { static bool IsTFResourceOp(Operation* op) { for (const auto& operand : op->getOperands()) { auto elementType = getElementTypeOrSelf(operand.getType()); - if (elementType.isa()) { + if (mlir::isa(elementType)) { return true; } } for (const auto& result : op->getResults()) { auto elementType = getElementTypeOrSelf(result.getType()); - if (elementType.isa()) { + if (mlir::isa(elementType)) { return true; } } @@ -310,7 +311,8 @@ static std::string GetOpDescriptionForDebug(Operation* inst) { os << (!first ? ", " : ""); first = false; os << named_attr.getName().getValue() << " = "; - if (auto element_attr = named_attr.getValue().dyn_cast()) { + if (auto element_attr = + mlir::dyn_cast(named_attr.getValue())) { if (element_attr.getNumElements() <= kLargeElementsAttr) { element_attr.print(os); } else { @@ -355,9 +357,9 @@ static std::string GetOpsSummary( template static bool HasValidTFLiteType(Value value, T& error_handler) { // None type is allowed to represent unspecified operands. - if (value.getType().isa()) return true; + if (mlir::isa(value.getType())) return true; - auto type = value.getType().dyn_cast(); + auto type = mlir::dyn_cast(value.getType()); if (!type) { if (auto op = value.getDefiningOp()) { error_handler.emitError() @@ -416,7 +418,7 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) { for (auto arg : bb.getArguments()) { if (!HasValidTFLiteType(arg, fn)) { auto elementType = getElementTypeOrSelf(arg.getType()); - if (elementType.isa()) { + if (mlir::isa(elementType)) { return fn.emitError( "function argument uses variant type. Currently, the " "variant type is not natively supported in TFLite. Please " @@ -435,10 +437,10 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) { if (inst.hasTrait()) break; for (auto result : inst.getResults()) { - if (result.getType().isa()) continue; + if (mlir::isa(result.getType())) continue; if (!HasValidTFLiteType(result, inst)) { auto elementType = getElementTypeOrSelf(result.getType()); - if (elementType.isa()) { + if (mlir::isa(elementType)) { return inst.emitError( "operand result uses variant type. Currently, the " "variant type is not natively supported in TFLite. " @@ -919,7 +921,7 @@ std::optional> Translator::BuildBuffer( if (auto cst = dyn_cast(inst)) { // arith::ConstantOp have ElementAttr at this point due to validation of the // TFLite module. - attr = cst.getValue().cast(); + attr = mlir::cast(cst.getValue()); } else if (auto cst = dyn_cast(inst)) { attr = cst.getValue(); } else if (auto cst = dyn_cast(inst)) { @@ -930,10 +932,10 @@ std::optional> Translator::BuildBuffer( attr = cst.getValue(); } else if (auto cst = dyn_cast(inst)) { mlir::VhloToStablehloTypeConverter vhlo_type_converter; - auto tensor_v1_attr = cst.getValue().cast(); + auto tensor_v1_attr = mlir::cast(cst.getValue()); attr = mlir::DenseIntOrFPElementsAttr::getFromRawBuffer( - vhlo_type_converter.convertType(tensor_v1_attr.getType()) - .cast(), + mlir::cast( + vhlo_type_converter.convertType(tensor_v1_attr.getType())), tensor_v1_attr.getData()); } else if (auto cst = dyn_cast(inst)) { attr = cst.getCompressedData(); @@ -956,7 +958,7 @@ std::optional> Translator::BuildBuffer( // trouble calling ConvertToTensor(). For now, extract the tensor data from // ElementsAttr directly in this and read type from tflite::TensorType instead // of tensorflow::DataType. - auto type = value.getType().cast(); + auto type = mlir::cast(value.getType()); tflite::TensorType tflite_element_type = GetTFLiteType(type.getElementType()).value(); if (tflite_element_type == tflite::TensorType_INT4) { @@ -1052,7 +1054,7 @@ int32_t Translator::UnnamedRegionToSubgraph( std::optional>> Translator::BuildTFVariantType(mlir::Type element_type) { std::vector> variant_params; - auto variant_type = element_type.dyn_cast(); + auto variant_type = mlir::dyn_cast(element_type); if (!variant_type) { return variant_params; } @@ -1081,7 +1083,7 @@ Translator::BuildTFVariantType(mlir::Type element_type) { std::optional> Translator::BuildTensorFromType( mlir::Type type, const std::string& name) { - auto tensor_type = type.cast(); + auto tensor_type = mlir::cast(type); llvm::ArrayRef shape_ref; std::vector shape; @@ -1104,15 +1106,15 @@ std::optional> Translator::BuildTensorFromType( return std::nullopt; } BufferOffset q_params = 0; - if (auto qtype = element_type.dyn_cast()) { + if (auto qtype = + mlir::dyn_cast(element_type)) { std::vector scales = {static_cast(qtype.getScale())}; std::vector zero_points = {qtype.getZeroPoint()}; q_params = tflite::CreateQuantizationParameters( builder_, /*min=*/0, /*max=*/0, builder_.CreateVector(scales), builder_.CreateVector(zero_points)); - } else if (auto qtype = - element_type - .dyn_cast()) { + } else if (auto qtype = mlir::dyn_cast( + element_type)) { std::vector mins = {static_cast(qtype.getMin())}; std::vector maxs = {static_cast(qtype.getMax())}; q_params = tflite::CreateQuantizationParameters( @@ -1131,7 +1133,7 @@ std::optional> Translator::BuildTensor( Value value, const std::string& name, unsigned buffer_idx, const std::optional>& quant_parameters) { - auto type = value.getType().cast(); + auto type = mlir::cast(value.getType()); // TFLite requires tensor shape only for the inputs and constants. // However, we output all known shapes for better round-tripping @@ -1161,9 +1163,9 @@ std::optional> Translator::BuildTensor( // Const op can have a result of dynamic shaped type (e.g. due to constant // folding), but we can still derive the shape of a constant tensor for // its attribute type. - auto tensor_attr = inst->getAttr("value").cast(); + auto tensor_attr = mlir::cast(inst->getAttr("value")); llvm::ArrayRef shape_ref = - tensor_attr.getType().cast().getShape(); + mlir::cast(tensor_attr.getType()).getShape(); if (mlir::failed(check_shape(shape_ref))) return std::nullopt; shape = std::vector(shape_ref.begin(), shape_ref.end()); @@ -1202,7 +1204,8 @@ std::optional> Translator::BuildTensor( } BufferOffset q_params; - if (auto qtype = element_type.dyn_cast()) { + if (auto qtype = + mlir::dyn_cast(element_type)) { std::vector scales = {static_cast(qtype.getScale())}; std::vector zero_points = {qtype.getZeroPoint()}; q_params = tflite::CreateQuantizationParameters( @@ -1211,8 +1214,8 @@ std::optional> Translator::BuildTensor( builder_, /*min=*/0, /*max=*/0, builder_.CreateVector(scales), builder_.CreateVector(zero_points)); } else if (auto qtype = - element_type - .dyn_cast()) { + mlir::dyn_cast( + element_type)) { std::vector scales(qtype.getScales().begin(), qtype.getScales().end()); std::vector zero_points(qtype.getZeroPoints().begin(), @@ -1350,7 +1353,9 @@ BufferOffset Translator::BuildCustomOperator( Operation* inst, mlir::TFL::CustomOp op, const std::vector& operands, const std::vector& results) { const std::string attrs = - op.getCustomOption().cast().getValue().str(); + mlir::cast(op.getCustomOption()) + .getValue() + .str(); std::vector custom_option_vector(attrs.size(), 0); memcpy(custom_option_vector.data(), attrs.data(), attrs.size()); auto opcode_index = @@ -1559,7 +1564,7 @@ Translator::BuildStablehloPrecisionConfig(::mlir::ArrayAttr precisionConfig) { for (auto it = precisionConfig.begin(); it != precisionConfig.end(); it++) { precision_config_vec.push_back(static_cast( - (it->cast()).getValue())); + (mlir::cast(*it)).getValue())); } return builder_.CreateVector(precision_config_vec); } @@ -1571,7 +1576,7 @@ Translator::BuildVhloPrecisionConfigV1( auto values = precisionConfig.getValue(); for (auto it = values.begin(); it != values.end(); it++) { precision_config_vec.push_back(static_cast( - (it->cast()).getValue())); + (mlir::cast(*it)).getValue())); } return builder_.CreateVector(precision_config_vec); } @@ -1852,27 +1857,25 @@ std::optional> Translator::BuildVhloGatherV1Op( GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_GATHER); auto offset_dims = builder_.CreateVector(mlir::GetVector( - gather_op.getOffsetDims().cast(), + mlir::cast(gather_op.getOffsetDims()), vhlo_type_converter)); auto collapsed_slice_dims = builder_.CreateVector(mlir::GetVector( - gather_op.getCollapsedSliceDims().cast(), + mlir::cast(gather_op.getCollapsedSliceDims()), vhlo_type_converter)); auto start_index_map = builder_.CreateVector(mlir::GetVector( - gather_op.getStartIndexMap().cast(), + mlir::cast(gather_op.getStartIndexMap()), vhlo_type_converter)); auto slice_sizes = builder_.CreateVector(mlir::GetVector( - gather_op.getSliceSizes().cast(), + mlir::cast(gather_op.getSliceSizes()), vhlo_type_converter)); auto gather_option = tflite::CreateStablehloGatherOptions( builder_, offset_dims, collapsed_slice_dims, start_index_map, - gather_op.getIndexVectorDim() - .cast() + mlir::cast(gather_op.getIndexVectorDim()) .getValue() .getSExtValue(), slice_sizes, - gather_op.getIndicesAreSorted() - .cast() + mlir::cast(gather_op.getIndicesAreSorted()) .getValue()); return tflite::CreateOperator( @@ -1899,26 +1902,26 @@ std::optional> Translator::BuildVhloScatterV1Op( UnnamedRegionToSubgraph(&body, tflite::BuiltinOperator_STABLEHLO_SCATTER); if (subgraph_index < 0) return std::nullopt; - int64_t index_vector_dim = scatter_op.getIndexVectorDim() - .cast() - .getValue() - .getSExtValue(); - bool unique_indices = scatter_op.getUniqueIndices() - .cast() - .getValue(); - bool indices_are_sorted = scatter_op.getIndicesAreSorted() - .cast() - .getValue(); + int64_t index_vector_dim = + mlir::cast(scatter_op.getIndexVectorDim()) + .getValue() + .getSExtValue(); + bool unique_indices = + mlir::cast(scatter_op.getUniqueIndices()) + .getValue(); + bool indices_are_sorted = + mlir::cast(scatter_op.getIndicesAreSorted()) + .getValue(); auto update_window_dims = builder_.CreateVector(mlir::GetVector( - scatter_op.getUpdateWindowDims().cast(), + mlir::cast(scatter_op.getUpdateWindowDims()), vhlo_type_converter)); auto inserted_window_dims = builder_.CreateVector(mlir::GetVector( - scatter_op.getInsertedWindowDims().cast(), + mlir::cast(scatter_op.getInsertedWindowDims()), vhlo_type_converter)); auto scatter_dims_to_operand_dims = builder_.CreateVector( - mlir::GetVector(scatter_op.getScatterDimsToOperandDims() - .cast(), + mlir::GetVector(mlir::cast( + scatter_op.getScatterDimsToOperandDims()), vhlo_type_converter)); auto options = tflite::CreateStablehloScatterOptions( @@ -1946,20 +1949,22 @@ Translator::BuildVhloReduceWindowV1Op( uint32_t opcode_index = GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_REDUCE_WINDOW); - auto window_dimensions = builder_.CreateVector(mlir::GetVector( - reduce_window_op.getWindowDimensions().cast(), - vhlo_type_converter)); + auto window_dimensions = builder_.CreateVector( + mlir::GetVector(mlir::cast( + reduce_window_op.getWindowDimensions()), + vhlo_type_converter)); auto window_strides = builder_.CreateVector(mlir::GetVector( - reduce_window_op.getWindowStrides().cast(), + mlir::cast(reduce_window_op.getWindowStrides()), vhlo_type_converter)); auto base_dilations = builder_.CreateVector(mlir::GetVector( - reduce_window_op.getBaseDilations().cast(), - vhlo_type_converter)); - auto window_dilations = builder_.CreateVector(mlir::GetVector( - reduce_window_op.getWindowDilations().cast(), + mlir::cast(reduce_window_op.getBaseDilations()), vhlo_type_converter)); + auto window_dilations = builder_.CreateVector( + mlir::GetVector(mlir::cast( + reduce_window_op.getWindowDilations()), + vhlo_type_converter)); auto padding = builder_.CreateVector(mlir::GetVector( - reduce_window_op.getPadding().cast(), + mlir::cast(reduce_window_op.getPadding()), vhlo_type_converter)); auto& body = reduce_window_op.getBody(); int32_t subgraph_index = UnnamedRegionToSubgraph( @@ -1990,8 +1995,7 @@ Translator::BuildVhloRngBitGeneratorV1Op( uint32_t opcode_index = GetOpcodeIndex( op_name, tflite::BuiltinOperator_STABLEHLO_RNG_BIT_GENERATOR); tflite::RngAlgorithm algorithm = tflite::RngAlgorithm_DEFAULT; - switch (rng_op.getRngAlgorithm() - .cast() + switch (mlir::cast(rng_op.getRngAlgorithm()) .getValue()) { case mlir::vhlo::RngAlgorithmV1::THREE_FRY: algorithm = tflite::RngAlgorithm_THREEFRY; @@ -2024,13 +2028,13 @@ std::optional> Translator::BuildVhloPadV1Op( GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_PAD); auto edge_padding_low = builder_.CreateVector(mlir::GetVector( - pad_op.getEdgePaddingLow().cast(), + mlir::cast(pad_op.getEdgePaddingLow()), vhlo_type_converter)); auto edge_padding_high = builder_.CreateVector(mlir::GetVector( - pad_op.getEdgePaddingHigh().cast(), + mlir::cast(pad_op.getEdgePaddingHigh()), vhlo_type_converter)); auto interior_padding = builder_.CreateVector(mlir::GetVector( - pad_op.getInteriorPadding().cast(), + mlir::cast(pad_op.getInteriorPadding()), vhlo_type_converter)); auto pad_option = tflite::CreateStablehloPadOptions( @@ -2263,10 +2267,10 @@ std::optional> Translator::BuildOperator( GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_IOTA); auto iota_option = tflite::CreateStablehloIotaOptions( - builder_, vhlo_op.getIotaDimension() - .cast() - .getValue() - .getSExtValue()); + builder_, + mlir::cast(vhlo_op.getIotaDimension()) + .getValue() + .getSExtValue()); return tflite::CreateOperator( builder_, opcode_index, builder_.CreateVector(operands), @@ -2280,7 +2284,7 @@ std::optional> Translator::BuildOperator( op_name, tflite::BuiltinOperator_STABLEHLO_DYNAMIC_SLICE); auto slice_sizes = builder_.CreateVector(mlir::GetVector( - vhlo_op.getSliceSizes().cast(), + mlir::cast(vhlo_op.getSliceSizes()), vhlo_type_converter)); auto dynamic_slice_option = @@ -2303,13 +2307,13 @@ std::optional> Translator::BuildOperator( tflite::StablehloComparisonType_STABLEHLO_COMPARISON_TYPE_NOTYPE; if (compare_type_attr) compare_type = static_cast( - compare_type_attr.cast() + mlir::cast(compare_type_attr) .getValue()); auto compare_option = tflite::CreateStablehloCompareOptions( builder_, static_cast( - vhlo_op.getComparisonDirection() - .cast() + mlir::cast( + vhlo_op.getComparisonDirection()) .getValue()), compare_type); @@ -2326,10 +2330,10 @@ std::optional> Translator::BuildOperator( op_name, tflite::BuiltinOperator_STABLEHLO_CONCATENATE); auto concat_option = tflite::CreateStablehloConcatenateOptions( - builder_, vhlo_op.getDimension() - .cast() - .getValue() - .getSExtValue()); + builder_, + mlir::cast(vhlo_op.getDimension()) + .getValue() + .getSExtValue()); return tflite::CreateOperator( builder_, opcode_index, builder_.CreateVector(operands), @@ -2344,13 +2348,13 @@ std::optional> Translator::BuildOperator( GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_SLICE); auto start_indices = builder_.CreateVector((mlir::GetVector( - vhlo_op.getStartIndicesAttr().cast(), + mlir::cast(vhlo_op.getStartIndicesAttr()), vhlo_type_converter))); auto limit_indices = builder_.CreateVector(mlir::GetVector( - vhlo_op.getLimitIndicesAttr().cast(), + mlir::cast(vhlo_op.getLimitIndicesAttr()), vhlo_type_converter)); auto strides = builder_.CreateVector(mlir::GetVector( - vhlo_op.getStridesAttr().cast(), + mlir::cast(vhlo_op.getStridesAttr()), vhlo_type_converter)); auto slice_option = tflite::CreateStablehloSliceOptions( @@ -2369,63 +2373,64 @@ std::optional> Translator::BuildOperator( op_name, tflite::BuiltinOperator_STABLEHLO_CONVOLUTION); auto window_strides = builder_.CreateVector(mlir::GetVector( - vhlo_op.getWindowStrides().cast(), + mlir::cast(vhlo_op.getWindowStrides()), vhlo_type_converter)); auto padding = builder_.CreateVector(mlir::GetVector( - vhlo_op.getPadding().cast(), + mlir::cast(vhlo_op.getPadding()), vhlo_type_converter)); auto lhs_dialation = builder_.CreateVector(mlir::GetVector( - vhlo_op.getLhsDilation().cast(), + mlir::cast(vhlo_op.getLhsDilation()), vhlo_type_converter)); auto rhs_dialation = builder_.CreateVector(mlir::GetVector( - vhlo_op.getRhsDilation().cast(), + mlir::cast(vhlo_op.getRhsDilation()), vhlo_type_converter)); auto window_reversal = builder_.CreateVector(mlir::GetVector( - vhlo_op.getWindowReversal().cast(), + mlir::cast(vhlo_op.getWindowReversal()), vhlo_type_converter)); - auto input_batch_dimension = vhlo_op.getInputBatchDimension() - .cast() + auto input_batch_dimension = mlir::cast( + vhlo_op.getInputBatchDimension()) .getValue() .getSExtValue(); - auto input_feature_dimension = vhlo_op.getInputFeatureDimension() - .cast() + auto input_feature_dimension = mlir::cast( + vhlo_op.getInputFeatureDimension()) .getValue() .getSExtValue(); auto kernel_input_feature_dimension = - vhlo_op.getKernelInputFeatureDimension() - .cast() + mlir::cast( + vhlo_op.getKernelInputFeatureDimension()) .getValue() .getSExtValue(); auto kernel_output_feature_dimension = - vhlo_op.getKernelOutputFeatureDimension() - .cast() + mlir::cast( + vhlo_op.getKernelOutputFeatureDimension()) .getValue() .getSExtValue(); - auto output_batch_dimension = vhlo_op.getOutputBatchDimension() - .cast() + auto output_batch_dimension = mlir::cast( + vhlo_op.getOutputBatchDimension()) .getValue() .getSExtValue(); - auto output_feature_dimension = vhlo_op.getOutputFeatureDimension() - .cast() + auto output_feature_dimension = mlir::cast( + vhlo_op.getOutputFeatureDimension()) .getValue() .getSExtValue(); auto kernel_spatial_dimensions = builder_.CreateVector( - mlir::GetVector(vhlo_op.getKernelSpatialDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getKernelSpatialDimensions()), vhlo_type_converter)); auto output_spatial_dimension = builder_.CreateVector( - mlir::GetVector(vhlo_op.getOutputSpatialDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getOutputSpatialDimensions()), vhlo_type_converter)); auto input_spatial_dimension = builder_.CreateVector( - mlir::GetVector(vhlo_op.getInputSpatialDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getInputSpatialDimensions()), vhlo_type_converter)); BufferOffset> precision_config = 0; if (vhlo_op.getPrecisionConfig()) { precision_config = BuildVhloPrecisionConfigV1( - vhlo_op.getPrecisionConfig().dyn_cast()); + mlir::dyn_cast( + vhlo_op.getPrecisionConfig())); } auto convolution_option = tflite::CreateStablehloConvolutionOptions( @@ -2435,12 +2440,11 @@ std::optional> Translator::BuildOperator( kernel_output_feature_dimension, kernel_spatial_dimensions, output_batch_dimension, output_feature_dimension, output_spatial_dimension, - vhlo_op.getFeatureGroupCount() - .cast() + mlir::cast( + vhlo_op.getFeatureGroupCount()) .getValue() .getSExtValue(), - vhlo_op.getBatchGroupCount() - .cast() + mlir::cast(vhlo_op.getBatchGroupCount()) .getValue() .getSExtValue(), precision_config); @@ -2458,8 +2462,8 @@ std::optional> Translator::BuildOperator( op_name, tflite::BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM); auto broadcast_dimensions = builder_.CreateVector( - mlir::GetVector(vhlo_op.getBroadcastDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getBroadcastDimensions()), vhlo_type_converter)); auto broadcast_option = tflite::CreateStablehloBroadcastInDimOptions( @@ -2478,8 +2482,8 @@ std::optional> Translator::BuildOperator( uint32_t opcode_index = GetOpcodeIndex( op_name, tflite::BuiltinOperator_STABLEHLO_CUSTOM_CALL); auto op_api_version = - vhlo_op.getApiVersion() - .cast() + mlir::cast( + vhlo_op.getApiVersion()) .getValue(); int32_t api_version = 0; if (op_api_version == @@ -2495,16 +2499,14 @@ std::optional> Translator::BuildOperator( API_VERSION_STATUS_RETURNING_UNIFIED) api_version = 3; - auto call_target_name = - builder_.CreateString(vhlo_op.getCallTargetName() - .cast() - .getValue() - .str()); - auto backend_config = - builder_.CreateString(vhlo_op.getBackendConfig() - .cast() - .getValue() - .str()); + auto call_target_name = builder_.CreateString( + mlir::cast(vhlo_op.getCallTargetName()) + .getValue() + .str()); + auto backend_config = builder_.CreateString( + mlir::cast(vhlo_op.getBackendConfig()) + .getValue() + .str()); // building the computation info auto flex_builder = std::make_unique(); size_t map_start = flex_builder->StartMap(); @@ -2517,25 +2519,25 @@ std::optional> Translator::BuildOperator( if (name == "call_target_name" || name == "backend_config") continue; if (llvm::isa(attr)) flex_builder->Bool(name.c_str(), - attr.cast().getValue()); + mlir::cast(attr).getValue()); if (llvm::isa(attr)) flex_builder->String( - name.c_str(), attr.cast().getValue().str()); + name.c_str(), + mlir::cast(attr).getValue().str()); if (llvm::isa(attr)) flex_builder->Bool( name.c_str(), - attr.cast().getValue()); + mlir::cast(attr).getValue()); if (llvm::isa(attr)) flex_builder->String( name.c_str(), - attr.cast().getValue().str()); + mlir::cast(attr).getValue().str()); } flex_builder->EndMap(map_start); flex_builder->Finish(); auto custom_call_option = tflite::CreateStablehloCustomCallOptions( builder_, call_target_name, - vhlo_op.getHasSideEffect() - .cast<::mlir::vhlo::BooleanV1Attr>() + mlir::cast<::mlir::vhlo::BooleanV1Attr>(vhlo_op.getHasSideEffect()) .getValue(), backend_config, api_version, 0, builder_.CreateVector(flex_builder->GetBuffer())); @@ -2553,7 +2555,7 @@ std::optional> Translator::BuildOperator( GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_REDUCE); auto dimension = builder_.CreateVector(mlir::GetVector( - vhlo_op.getDimensions().cast(), + mlir::cast(vhlo_op.getDimensions()), vhlo_type_converter)); auto& body = vhlo_op.getBody(); int32_t subgraph_index = UnnamedRegionToSubgraph( @@ -2576,26 +2578,27 @@ std::optional> Translator::BuildOperator( op_name, tflite::BuiltinOperator_STABLEHLO_DOT_GENERAL); auto lhs_batching_dimensions = builder_.CreateVector( - mlir::GetVector(vhlo_op.getLhsBatchingDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getLhsBatchingDimensions()), vhlo_type_converter)); auto rhs_batching_dimensions = builder_.CreateVector( - mlir::GetVector(vhlo_op.getRhsBatchingDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getRhsBatchingDimensions()), vhlo_type_converter)); auto lhs_contracting_dimensions = builder_.CreateVector( - mlir::GetVector(vhlo_op.getLhsContractingDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getLhsContractingDimensions()), vhlo_type_converter)); auto rhs_contracting_dimensions = builder_.CreateVector( - mlir::GetVector(vhlo_op.getRhsContractingDimensions() - .cast(), + mlir::GetVector(mlir::cast( + vhlo_op.getRhsContractingDimensions()), vhlo_type_converter)); BufferOffset> precision_config = 0; if (vhlo_op.getPrecisionConfig()) { - precision_config = BuildVhloPrecisionConfigV1( - vhlo_op.getPrecisionConfig().cast()); + precision_config = + BuildVhloPrecisionConfigV1(mlir::cast( + vhlo_op.getPrecisionConfig())); } auto dot_geneoral_option = tflite::CreateStablehloDotGeneralOptions( @@ -2621,11 +2624,11 @@ std::optional> Translator::BuildOperator( auto sort_option = tflite::CreateStablehloSortOptions( builder_, - vhlo_op.getDimension() - .cast() + mlir::cast(vhlo_op.getDimension()) .getValue() .getSExtValue(), - vhlo_op.getIsStable().cast().getValue(), + mlir::cast(vhlo_op.getIsStable()) + .getValue(), comparator_subgraph_index); return tflite::CreateOperator( @@ -2667,7 +2670,7 @@ std::optional> Translator::BuildOperator( auto transpose_option = tflite::CreateStablehloTransposeOptions( builder_, builder_.CreateVector(mlir::GetVector( - vhlo_op.getPermutation().cast(), + mlir::cast(vhlo_op.getPermutation()), vhlo_type_converter))); return tflite::CreateOperator( @@ -2793,7 +2796,8 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { llvm::SmallVector input_names; llvm::SmallVector output_names; - if (auto str = dict_attr.get("inputs").dyn_cast_or_null()) { + if (auto str = + mlir::dyn_cast_or_null(dict_attr.get("inputs"))) { str.getValue().split(input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); if (input_names.size() != fn.getNumArguments()) { @@ -2807,7 +2811,7 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { } if (auto str = - dict_attr.get("outputs").dyn_cast_or_null()) { + mlir::dyn_cast_or_null(dict_attr.get("outputs"))) { str.getValue().split(output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); auto term = fn.back().getTerminator(); @@ -2832,13 +2836,14 @@ bool Translator::IsStatefulOperand(mlir::Operation* op, int operand_index) { BufferOffset Translator::GetQuantizationForQuantStatsOpOutput( mlir::quantfork::StatisticsOp stats_op) { - auto layer_stats = stats_op.getLayerStats().cast(); + auto layer_stats = + mlir::cast(stats_op.getLayerStats()); std::optional axis_stats = stats_op.getAxisStats(); std::optional axis = stats_op.getAxis(); std::vector mins, maxs; mlir::DenseFPElementsAttr min_max_attr = axis_stats.has_value() - ? axis_stats.value().cast() + ? mlir::cast(axis_stats.value()) : layer_stats; for (const auto& index_and_value : @@ -2873,7 +2878,7 @@ std::optional> Translator::BuildSubGraph( auto build_tensor_and_buffer = [&](Value value, const int subgraph_index, const std::string& tensor_name) { // NoneType represents optional and may be skipped here. - if (value.getType().isa()) { + if (mlir::isa(value.getType())) { return true; } @@ -2957,7 +2962,8 @@ std::optional> Translator::BuildSubGraph( "effective_hidden_scale_intermediate"}; for (const std::string& intermediate : intermediate_names) { auto intermediate_attr = inst.getAttr(intermediate); - if (auto attr = intermediate_attr.dyn_cast_or_null()) { + if (auto attr = + mlir::dyn_cast_or_null(intermediate_attr)) { Type qtype = attr.getValue(); auto tensor_or = BuildTensorFromType( qtype, name_mapper_.GetUniqueName(intermediate).str()); @@ -3003,7 +3009,7 @@ std::optional> Translator::BuildSubGraph( std::vector operands; operands.reserve(real_inst->getNumOperands()); for (auto operand : real_inst->getOperands()) { - if (operand.getType().isa()) + if (mlir::isa(operand.getType())) operands.push_back(kTfLiteOptionalTensor); else if (auto stats_op = llvm::dyn_cast_or_null( @@ -3084,7 +3090,7 @@ Translator::CreateMetadataVector() { for (const auto& named_attr : dict_attr) { StringRef name = named_attr.getName(); mlir::Attribute attr = named_attr.getValue(); - if (auto content = attr.dyn_cast()) { + if (auto content = mlir::dyn_cast(attr)) { metadata.push_back(BuildMetadata(name, content.getValue())); } else { module_.emitError( @@ -3132,7 +3138,7 @@ Translator::CreateMetadataVector() { llvm::SmallVector GetStringsFromAttrWithSeparator( mlir::DictionaryAttr attr, const std::string& attr_key) { llvm::SmallVector result; - if (auto str = attr.get(attr_key).dyn_cast_or_null()) { + if (auto str = mlir::dyn_cast_or_null(attr.get(attr_key))) { str.getValue().split(result, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); } @@ -3151,9 +3157,11 @@ std::vector GetStringsFromDictionaryAttr( auto attrs = arg_attr.getValue(); for (const auto attr : attrs) { if (attr.getName() == attr_name) { - auto array_attr = attr.getValue().dyn_cast_or_null(); + auto array_attr = + mlir::dyn_cast_or_null(attr.getValue()); if (!array_attr || array_attr.empty()) continue; - auto string_attr = array_attr[0].dyn_cast_or_null(); + auto string_attr = + mlir::dyn_cast_or_null(array_attr[0]); if (!string_attr) continue; result.push_back(string_attr.getValue().str()); } @@ -3236,7 +3244,7 @@ std::vector BuildSignaturedef( auto unique_name = std::string(name_mapper.GetUniqueName(operand.get())); result[0].outputs[sig_def_outputs[i]] = unique_name; } - if (auto name_attr = exported_name[0].dyn_cast_or_null()) + if (auto name_attr = mlir::dyn_cast_or_null(exported_name[0])) result[0].signature_key = name_attr.getValue().str(); result[0].subgraph_index = subgraph_index; return result; @@ -3722,8 +3730,8 @@ BufferOffset Translator::BuildSparsityParameters( std::vector> fb_dim_metadata( dim_size); for (int i = 0; i < dim_size; i++) { - const auto dim_metadata = - s_attr.getDimMetadata()[i].dyn_cast(); + const auto dim_metadata = mlir::dyn_cast( + s_attr.getDimMetadata()[i]); if (dim_metadata.getFormat().getValue() == mlir::TFL::DimensionType::DENSE) { fb_dim_metadata[i] = tflite::CreateDimensionMetadata( diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index 6a10262989930f..fb90b631ea1680 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -42,6 +42,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -177,7 +178,7 @@ ConvertI64ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray, std::vector intVec; intVec.reserve(attrArray.getValue().size()); for (auto attr : attrArray.getValue()) { - intVec.push_back(attr.cast().getInt()); + intVec.push_back(mlir::cast(attr).getInt()); } return builder->CreateVector(intVec); } @@ -189,7 +190,7 @@ ConvertF32ArrayAttrForOptionWriter(mlir::ArrayAttr attrArray, floatVec.reserve(attrArray.getValue().size()); for (auto attr : attrArray.getValue()) { floatVec.push_back( - attr.cast().getValue().convertToFloat()); + mlir::cast(attr).getValue().convertToFloat()); } return builder->CreateVector(floatVec); } @@ -341,8 +342,8 @@ static mlir::Attribute BuildVhloTensorV1Attr(std::vector shape, std::vector value, mlir::Builder builder) { mlir::StablehloVhloTypeConverter type_converter; - auto builtin_attr = BuildRankedTensorAttr(shape, value, builder) - .dyn_cast(); + auto builtin_attr = mlir::dyn_cast( + BuildRankedTensorAttr(shape, value, builder)); auto vhlo_type = type_converter.convertType(builtin_attr.getType()); return mlir::vhlo::TensorV1Attr::get(builder.getContext(), vhlo_type, builtin_attr.getRawData()); @@ -352,8 +353,8 @@ static mlir::Attribute BuildVhloTensorV1Attr(std::vector shape, std::vector value, mlir::Builder builder) { mlir::StablehloVhloTypeConverter type_converter; - auto builtin_attr = BuildRankedTensorAttr(shape, value, builder) - .dyn_cast(); + auto builtin_attr = mlir::dyn_cast( + BuildRankedTensorAttr(shape, value, builder)); auto vhlo_type = type_converter.convertType(builtin_attr.getType()); return mlir::vhlo::TensorV1Attr::get(builder.getContext(), vhlo_type, builtin_attr.getRawData()); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h index 381f2a4c024549..4a54264b6244f2 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.h +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.h @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloOps.h" // from @stablehlo #include "stablehlo/dialect/VhloTypes.h" // from @stablehlo @@ -65,7 +66,7 @@ class StablehloVhloTypeConverter : public mlir::vhlo::VhloTypeConverter { return attr; if (auto stablehloAttr = - attr.dyn_cast_or_null()) { + mlir::dyn_cast_or_null(attr)) { return mlir::vhlo::TypeExtensionsV1Attr::get(stablehloAttr.getContext(), stablehloAttr.getBounds()); } @@ -88,7 +89,8 @@ class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter { } Attribute convertEncoding(Attribute attr) const final { - if (auto vhloAttr = attr.dyn_cast_or_null()) { + if (auto vhloAttr = + mlir::dyn_cast_or_null(attr)) { return stablehlo::TypeExtensionsAttr::get(vhloAttr.getContext(), vhloAttr.getBounds()); } @@ -296,8 +298,8 @@ static inline std::vector GetVector( vhlo::TensorV1Attr elements, mlir::vhlo::VhloTypeConverter &vhlo_type_converter) { return GetOptionalVector(mlir::DenseIntElementsAttr::getFromRawBuffer( - vhlo_type_converter.convertType(elements.getType()) - .cast(), + mlir::cast( + vhlo_type_converter.convertType(elements.getType())), elements.getData())); } diff --git a/tensorflow/compiler/mlir/lite/metrics/BUILD b/tensorflow/compiler/mlir/lite/metrics/BUILD index 6218a2fb30a829..464cd8f33822b7 100644 --- a/tensorflow/compiler/mlir/lite/metrics/BUILD +++ b/tensorflow/compiler/mlir/lite/metrics/BUILD @@ -72,5 +72,6 @@ cc_library( "//tensorflow/lite/python/metrics:converter_error_data_proto_cc", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/lite/metrics/types_util.cc b/tensorflow/compiler/mlir/lite/metrics/types_util.cc index b47347ceb03827..7dd658e54dd12e 100644 --- a/tensorflow/compiler/mlir/lite/metrics/types_util.cc +++ b/tensorflow/compiler/mlir/lite/metrics/types_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "llvm/ADT/TypeSwitch.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/lite/python/metrics/converter_error_data.pb.h" namespace mlir { @@ -67,8 +68,8 @@ class LocationExtractor : public Location { new_call->set_name(loc.getName().str()); // Add child as the source location. auto child_loc = loc.getChildLoc(); - if (child_loc.isa()) { - auto typed_child_loc = child_loc.dyn_cast(); + if (mlir::isa(child_loc)) { + auto typed_child_loc = mlir::dyn_cast(child_loc); ExtractFileLine(typed_child_loc, new_call->mutable_source()); } }) @@ -83,7 +84,7 @@ class LocationExtractor : public Location { // Skip the first location if it stores information for propagating // op_type metadata. if (num_locs > 0) { - if (auto name_loc = locations[0].dyn_cast()) { + if (auto name_loc = mlir::dyn_cast(locations[0])) { if (name_loc.getName().strref().ends_with(":")) { if (num_locs == 2) { return LocationExtractor(locations[1]).Extract(error_data); diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 16e12bbb6da04d..1518690c9880f0 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" @@ -80,7 +81,7 @@ Status HandleInputOutputArraysWithModule( if (!input_attr) { return errors::InvalidArgument("no inputs attribute found"); } - auto input_names = input_attr.cast().getValue(); + auto input_names = mlir::cast(input_attr).getValue(); input_names.split(function_input_names, ",", /*MaxSplit=*/-1, /*KeepEmpty=*/false); const int function_input_names_size = function_input_names.size(); @@ -106,7 +107,7 @@ Status HandleInputOutputArraysWithModule( if (!output_attr) { return errors::InvalidArgument("no outputs attribute found"); } - auto output_names = output_attr.cast().getValue(); + auto output_names = mlir::cast(output_attr).getValue(); output_names.split(function_output_names, ",", /*MaxSplit=*/-1, /*KeepEmpty=*/false); const int function_output_names_size = function_output_names.size(); diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc index 3de159a1414429..f6eac3e90ec8bd 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertConst.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/Passes.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" @@ -80,7 +81,7 @@ LogicalResult QuantizedConstRewrite::matchAndRewrite( } // Is the constant value a type expressed in a way that we support? - if (!value.isa()) { + if (!mlir::isa(value)) { return failure(); } diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc index e99addc5b5f8a5..f0aa7caec0be8d 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/Passes.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" @@ -121,9 +122,9 @@ class ConstFakeQuantPerAxisRewrite min.reserve(fqOp.getMin().size()); max.reserve(fqOp.getMax().size()); for (auto m : fqOp.getMin()) - min.push_back(m.cast().getValueAsDouble()); + min.push_back(mlir::cast(m).getValueAsDouble()); for (auto m : fqOp.getMax()) - max.push_back(m.cast().getValueAsDouble()); + max.push_back(mlir::cast(m).getValueAsDouble()); return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.getNumBits(), fqOp.getAxis(), min, max, fqOp.getNarrowRange(), diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc index d111141958c403..8aa6475b888702 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project using namespace mlir; using namespace mlir::quantfork; @@ -51,20 +52,20 @@ OpFoldResult StorageCastOp::fold(FoldAdaptor) { /// The quantization specification should match the expressed type. static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) { - if (auto typeAttr = quantSpec.dyn_cast()) { + if (auto typeAttr = mlir::dyn_cast(quantSpec)) { Type spec = typeAttr.getValue(); - if (spec.isa()) return false; + if (mlir::isa(spec)) return false; // The spec should be either a quantized type which is compatible to the // expressed type, or a primitive type which is as same as the // (element type of) the expressed type. - if (auto quantizedType = spec.dyn_cast()) + if (auto quantizedType = mlir::dyn_cast(spec)) return quantizedType.isCompatibleExpressedType(expressed); - if (auto tensorType = expressed.dyn_cast()) + if (auto tensorType = mlir::dyn_cast(expressed)) return spec == tensorType.getElementType(); - if (auto vectorType = expressed.dyn_cast()) + if (auto vectorType = mlir::dyn_cast(expressed)) return spec == vectorType.getElementType(); } return false; @@ -99,13 +100,13 @@ LogicalResult QuantizeRegionOp::verify() { } LogicalResult StatisticsOp::verify() { - auto tensorArg = getArg().getType().dyn_cast(); + auto tensorArg = mlir::dyn_cast(getArg().getType()); if (!tensorArg) return emitOpError("arg needs to be tensor type."); // Verify layerStats attribute. { auto layerStatsType = getLayerStats().getShapedType(); - if (!layerStatsType.getElementType().isa()) { + if (!mlir::isa(layerStatsType.getElementType())) { return emitOpError("layerStats must have a floating point element type"); } if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { @@ -122,7 +123,7 @@ LogicalResult StatisticsOp::verify() { std::multiplies()); auto axisStatsType = getAxisStats()->getShapedType(); - if (!axisStatsType.getElementType().isa()) { + if (!mlir::isa(axisStatsType.getElementType())) { return emitOpError("axisStats must have a floating point element type"); } if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 || diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc b/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc index 919c711272b2c1..2ad06f77de8866 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc @@ -32,8 +32,8 @@ using namespace mlir::quantfork; static Attribute convertPrimitiveValueAttr( Attribute origRealValue, quant::QuantizedType quantizedElementType, const UniformQuantizedValueConverter &converter, Type &outConvertedType) { - if (origRealValue.isa()) { - FloatAttr floatAttr = origRealValue.cast(); + if (mlir::isa(origRealValue)) { + FloatAttr floatAttr = mlir::cast(origRealValue); outConvertedType = quantizedElementType.getStorageType(); return IntegerAttr::get(quantizedElementType.getStorageType(), converter.quantizeFloatToInt(floatAttr.getValue())); @@ -64,11 +64,11 @@ static SparseElementsAttr convertSparseElementsAttr( quant::QuantizedType quantizedElementType, const UniformQuantizedValueConverter &converter) { DenseElementsAttr realDenseAttr = realSparseAttr.getValues(); - if (!realDenseAttr.isa()) { + if (!mlir::isa(realDenseAttr)) { return nullptr; } DenseElementsAttr quantDenseAttr = - convertDenseFPElementsAttr(realDenseAttr.cast(), + convertDenseFPElementsAttr(mlir::cast(realDenseAttr), quantizedElementType, converter); if (!quantDenseAttr) { return nullptr; @@ -76,9 +76,9 @@ static SparseElementsAttr convertSparseElementsAttr( // Cast from an expressed-type-based type to storage-type-based type, // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>). - ShapedType newSparseType = - quantizedElementType.castExpressedToStorageType(realSparseAttr.getType()) - .dyn_cast_or_null(); + ShapedType newSparseType = mlir::dyn_cast_or_null( + quantizedElementType.castExpressedToStorageType( + realSparseAttr.getType())); if (!newSparseType) { return nullptr; } @@ -93,17 +93,19 @@ Attribute mlir::quantfork::quantizeAttrUniform( Attribute realValue, quant::UniformQuantizedType quantizedElementType, const UniformQuantizedValueConverter &converter, Type &outConvertedType) { // Fork to handle different variants of constants supported. - if (realValue.isa()) { + if (mlir::isa(realValue)) { // Dense tensor or vector constant. - auto converted = convertDenseFPElementsAttr( - realValue.cast(), quantizedElementType, converter); + auto converted = + convertDenseFPElementsAttr(mlir::cast(realValue), + quantizedElementType, converter); outConvertedType = converted.getType(); return converted; } - if (realValue.isa()) { + if (mlir::isa(realValue)) { // Sparse tensor or vector constant. - auto converted = convertSparseElementsAttr( - realValue.cast(), quantizedElementType, converter); + auto converted = + convertSparseElementsAttr(mlir::cast(realValue), + quantizedElementType, converter); outConvertedType = converted.getType(); return converted; } @@ -121,13 +123,14 @@ Attribute mlir::quantfork::quantizeAttr( Attribute realValue, quant::QuantizedType quantizedElementType, Type &outConvertedType) { if (auto uniformQuantized = - quantizedElementType.dyn_cast()) { + mlir::dyn_cast(quantizedElementType)) { UniformQuantizedValueConverter converter(uniformQuantized); return quantizeAttrUniform(realValue, uniformQuantized, converter, outConvertedType); } if (auto uniformQuantizedPerAxis = - quantizedElementType.dyn_cast()) { + mlir::dyn_cast( + quantizedElementType)) { UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis); auto converted = converter.convert(realValue); // TODO: why we need this outConvertedType? remove it? diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc index 9a25d849ea7c8a..e6284d273e50d0 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/fallback_to_flex_ops.cc @@ -268,7 +268,7 @@ Value SetNoFallbackAttr(PatternRewriter &rewriter, Value val) { // Returns true if the attr is a float attribute and be equal to value. static bool FloatValueEquals(const Attribute &attr, double value) { - auto fp_attr = attr.dyn_cast_or_null(); + auto fp_attr = mlir::dyn_cast_or_null(attr); if (fp_attr == nullptr) return false; if (fp_attr.isSplat()) { @@ -281,7 +281,7 @@ static bool FloatValueEquals(const Attribute &attr, double value) { // Returns true if the rank of the value equals to the given rank. bool RankEquals(Value value, int rank) { - auto rank_type = value.getType().template dyn_cast(); + auto rank_type = mlir::dyn_cast(value.getType()); return (rank_type && rank_type.getRank() == rank); } diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index 55790c40509946..b4015181886788 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -133,7 +133,7 @@ struct InsertQuantOpsAfterTFFakeQuantOp if (PerAxis) { // This is a special case that the quant_dim is the last dimensions // according to the tf.FakeQuantWithMinMaxPerChannel. - quant_dim = res.getType().template cast().getRank() - 1; + quant_dim = mlir::cast(res.getType()).getRank() - 1; } // Use the min/max from the operands and the num_bits and narrow_range // attribute to create the quantization parameter for the new quantize op. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index 179b20d27d462f..5f26983232a618 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -68,6 +68,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], alwayslink = 1, ) @@ -546,6 +547,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc index 15481b9a0a1ad2..52f2c4be02a3aa 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc @@ -73,7 +73,7 @@ bool IsI8ToF32Cast(stablehlo::ConvertOp convert_op) { const bool is_i8_operand = convert_op.getOperand().getType().getElementType().isInteger(/*width=*/8); const bool is_f32_result = - convert_op.getResult().getType().getElementType().isa(); + mlir::isa(convert_op.getResult().getType().getElementType()); return is_i8_operand && is_f32_result; } @@ -92,7 +92,7 @@ bool IsI32ToF32Cast(stablehlo::ConvertOp convert_op) { convert_op.getOperand().getType().getElementType().isInteger( /*width=*/32); const bool is_f32_result = - convert_op.getResult().getType().getElementType().isa(); + mlir::isa(convert_op.getResult().getType().getElementType()); return is_i32_operand && is_f32_result; } @@ -104,7 +104,8 @@ LogicalResult MatchZeroPointsOperand(Value zero_points) { return failure(); } - auto zero_points_type = zero_points.getType().dyn_cast_or_null(); + auto zero_points_type = + mlir::dyn_cast_or_null(zero_points.getType()); if (!zero_points_type) { LLVM_DEBUG(llvm::dbgs() << "Zero point value should be a tensor type. Got: " << zero_points_type << ".\n"); @@ -112,7 +113,7 @@ LogicalResult MatchZeroPointsOperand(Value zero_points) { } if (Type zero_points_element_type = zero_points_type.getElementType(); - !zero_points_element_type.isa()) { + !mlir::isa(zero_points_element_type)) { LLVM_DEBUG(llvm::dbgs() << "Zero point should be an integer type. Got: " << zero_points_element_type << ".\n"); return failure(); @@ -146,7 +147,7 @@ LogicalResult MatchInverseScalesOperand(Value inverse_scales) { } auto inverse_scales_type = - inverse_scales.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(inverse_scales.getType()); if (!inverse_scales_type) { LLVM_DEBUG(llvm::dbgs() << "Inverse scales should be a tensor type. Got: " << inverse_scales_type << ".\n"); @@ -154,7 +155,7 @@ LogicalResult MatchInverseScalesOperand(Value inverse_scales) { } if (Type inverse_scales_element_type = inverse_scales_type.getElementType(); - !inverse_scales_element_type.isa()) { + !mlir::isa(inverse_scales_element_type)) { LLVM_DEBUG(llvm::dbgs() << "Inverse scales element should be a float type. Got: " << inverse_scales_element_type << ".\n"); @@ -207,7 +208,7 @@ class UniformQuantizeFunctionCallPattern { } auto input_value_type = - input_value.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(input_value.getType()); if (!input_value_type) { LLVM_DEBUG(llvm::dbgs() << "Failed to match @uniform_quantize function call pattern. " @@ -216,7 +217,7 @@ class UniformQuantizeFunctionCallPattern { } if (Type input_element_type = input_value_type.getElementType(); - !input_element_type.isa()) { + !mlir::isa(input_element_type)) { LLVM_DEBUG(llvm::dbgs() << "Failed to match @uniform_quantize function call pattern. " "Input value's element type must be a float. Got: " @@ -299,7 +300,7 @@ class UniformDequantizeFunctionCallPattern { } auto input_value_type = - input_value.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(input_value.getType()); if (!input_value_type) { LLVM_DEBUG(llvm::dbgs() << "Failed to match @uniform_dequantize call pattern. Input " @@ -309,7 +310,7 @@ class UniformDequantizeFunctionCallPattern { } if (Type input_element_type = input_value_type.getElementType(); - !input_element_type.isa()) { + !mlir::isa(input_element_type)) { LLVM_DEBUG(llvm::dbgs() << "Failed to match @uniform_dequantize call pattern. Input " "value's element type must be integer. Got: " @@ -433,8 +434,9 @@ class ComposeUniformQuantizedConvolutionOp LogicalResult match(stablehlo::ConvolutionOp op) const final { // Verify operands' types. for (Type operand_type : op.getOperandTypes()) { - if (Type element_type = operand_type.cast().getElementType(); - !element_type.isa()) { + if (Type element_type = + mlir::cast(operand_type).getElementType(); + !mlir::isa(element_type)) { LLVM_DEBUG(llvm::dbgs() << "Failed to match. The operand type must be a float. Got: " << element_type << ".\n"); @@ -477,8 +479,9 @@ class ComposeUniformQuantizedConvolutionOp // Match the subgraph that receives the convolution output. Value conv_output_value = op.getResult(); if (auto output_element_type = - conv_output_value.getType().cast().getElementType(); - !output_element_type.isa()) { + mlir::cast(conv_output_value.getType()) + .getElementType(); + !mlir::isa(output_element_type)) { LLVM_DEBUG( llvm::dbgs() << "Failed to match. Output type is expected to be a float. Got: " @@ -530,14 +533,12 @@ class ComposeUniformQuantizedConvolutionOp return failure(); } - if (!(other_zp_i8_to_f32_convert_op.getResult() - .getType() - .getElementType() - .isa() && - other_zp_i8_to_f32_convert_op.getOperand() - .getType() - .getElementType() - .isa())) { + if (!(mlir::isa(other_zp_i8_to_f32_convert_op.getResult() + .getType() + .getElementType()) && + mlir::isa(other_zp_i8_to_f32_convert_op.getOperand() + .getType() + .getElementType()))) { LLVM_DEBUG( llvm::dbgs() << "Failed to match. The ConvertOp is not an i8->f32 type cast.\n"); @@ -671,8 +672,8 @@ class ComposeUniformQuantizedConvolutionOp rewriter.create( uniform_quantize_call_op.getLoc(), /*result=*/ - input_value.getType().cast().clone( - input_quantized_element_type), + mlir::cast(input_value.getType()) + .clone(input_quantized_element_type), /*operand=*/input_value); rewriter.replaceAllUsesWith(input_i8_to_f32_convert_op.getResult(), @@ -689,20 +690,21 @@ class ComposeUniformQuantizedConvolutionOp // This is i8 values disguised as f32 (due to the upcast trick). Simply // cast them to i8. ElementsAttr filter_value = filter_constant_op.getValue(); - filter_i8_value_attr = filter_value.cast().mapValues( - rewriter.getI8Type(), [](const APFloat& val) -> APInt { - APSInt convertedInt(/*BitWidth=*/8, /*isUnsigned=*/false); - bool ignored; - val.convertToInteger(convertedInt, APFloat::rmTowardZero, &ignored); - return convertedInt; - }); + filter_i8_value_attr = + mlir::cast(filter_value) + .mapValues(rewriter.getI8Type(), [](const APFloat& val) -> APInt { + APSInt convertedInt(/*BitWidth=*/8, /*isUnsigned=*/false); + bool ignored; + val.convertToInteger(convertedInt, APFloat::rmTowardZero, + &ignored); + return convertedInt; + }); } else if (isa(filter_op) && isa( filter_op->getOperand(0).getDefiningOp())) { - filter_i8_value_attr = + filter_i8_value_attr = mlir::cast( cast(filter_op->getOperand(0).getDefiningOp()) - .getValue() - .cast(); + .getValue()); } // Create Uniform Quantized constant for the filter. @@ -719,9 +721,9 @@ class ComposeUniformQuantizedConvolutionOp scale_combined_broadcast_in_dim_op.getOperand().getDefiningOp()); SmallVector filter_scale_values; - for (const auto combined_scale_value : combined_scale_constant_op.getValue() - .cast() - .getValues()) { + for (const auto combined_scale_value : + mlir::cast(combined_scale_constant_op.getValue()) + .getValues()) { // UniformQuantizedPerAxisType requires scales to have double dtype. const double filter_scale_value = static_cast( combined_scale_value * input_inverse_scales_value); @@ -780,7 +782,8 @@ class ComposeUniformQuantizedConvolutionOp Value conv_output_value = op.getResult(); auto output_uniform_quantized_tensor_type = RankedTensorType::getChecked( rewriter.getUnknownLoc(), - /*shape=*/conv_output_value.getType().cast().getShape(), + /*shape=*/ + mlir::cast(conv_output_value.getType()).getShape(), output_uniform_quantized_type); SmallVector new_conv_output_types = { @@ -1017,8 +1020,8 @@ class ComposeUniformQuantizedDotGeneralOp rewriter.create( input_i8_to_f32_convert_op.getLoc(), /*result=*/ - input_value.getType().cast().clone( - input_uniform_quantized_type), + mlir::cast(input_value.getType()) + .clone(input_uniform_quantized_type), /*operand=*/input_value); rewriter.replaceAllUsesWith(input_i8_to_f32_convert_op.getResult(), @@ -1029,13 +1032,13 @@ class ComposeUniformQuantizedDotGeneralOp stablehlo::ConstantOp filter_constant_op = GetFilterConstantOp(filter_value); auto filter_value_attr = - filter_constant_op.getValue().cast(); + mlir::cast(filter_constant_op.getValue()); if (filter_value_attr.getElementType().isF32()) { // This is i8 values disguised as f32 (due to the upcast trick). Simply // cast them to i8. filter_value_attr = - filter_value_attr.cast().mapValues( - rewriter.getI8Type(), [](const APFloat& val) -> APInt { + mlir::cast(filter_value_attr) + .mapValues(rewriter.getI8Type(), [](const APFloat& val) -> APInt { APSInt converted_int(/*BitWidth=*/8, /*isUnsigned=*/false); bool ignored; val.convertToInteger(converted_int, APFloat::rmTowardZero, @@ -1072,9 +1075,9 @@ class ComposeUniformQuantizedDotGeneralOp auto merged_scale_constant_op = cast(multiply_op_second_operand.getDefiningOp()); SmallVector filter_scale_values; - for (const auto merged_scale : merged_scale_constant_op.getValue() - .cast() - .getValues()) { + for (const auto merged_scale : + mlir::cast(merged_scale_constant_op.getValue()) + .getValues()) { // (s1 * s2) * (1 / s1) = s2 // UniformQuantizedPerAxisType requires scales to have double dtype. filter_scale_values.push_back( @@ -1086,7 +1089,7 @@ class ComposeUniformQuantizedDotGeneralOp const int quantization_dimension = GetFilterQuantizationDimension( op.getDotDimensionNumbers(), - filter_value_attr.getType().cast().getRank()); + mlir::cast(filter_value_attr.getType()).getRank()); const UniformQuantizedPerAxisType filter_uniform_quantized_type = CreateI8F32UniformQuantizedPerAxisType( filter_constant_op.getLoc(), *rewriter.getContext(), @@ -1097,8 +1100,8 @@ class ComposeUniformQuantizedDotGeneralOp auto quantized_filter_constant_op = rewriter.create( filter_constant_op.getLoc(), /*output=*/ - filter_constant_op.getResult().getType().cast().clone( - filter_uniform_quantized_type), + mlir::cast(filter_constant_op.getResult().getType()) + .clone(filter_uniform_quantized_type), /*value=*/filter_value_attr); rewriter.replaceAllUsesWith(filter_value, @@ -1137,8 +1140,8 @@ class ComposeUniformQuantizedDotGeneralOp auto new_dot_general_op = rewriter.create( op.getLoc(), /*resultType0=*/ - op.getResult().getType().cast().clone( - output_uniform_quantized_type), + mlir::cast(op.getResult().getType()) + .clone(output_uniform_quantized_type), /*lhs=*/op.getLhs(), /*rhs=*/op.getRhs(), /*dot_dimension_numbers=*/op.getDotDimensionNumbers(), /*precision_config=*/op.getPrecisionConfigAttr()); @@ -1395,8 +1398,8 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations rewriter.create( input1_uniform_quantize_call_op.getLoc(), /*result=*/ - input1_value.getType().cast().clone( - input1_uniform_quantized_type), + mlir::cast(input1_value.getType()) + .clone(input1_uniform_quantized_type), /*operand=*/input1_value); rewriter.replaceAllUsesWith(input1_zero_point_subtract_op.getResult(), @@ -1434,8 +1437,8 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations rewriter.create( input2_uniform_quantize_call_op.getLoc(), /*result=*/ - input2_value.getType().cast().clone( - input2_uniform_quantized_type), + mlir::cast(input2_value.getType()) + .clone(input2_uniform_quantized_type), /*operand=*/input2_value); rewriter.replaceAllUsesWith(input2_zero_point_subtract_op.getResult(), @@ -1482,8 +1485,8 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations auto new_dot_general_op = rewriter.create( op.getLoc(), /*resultType0=*/ - op.getResult().getType().cast().clone( - output_uniform_quantized_type), + mlir::cast(op.getResult().getType()) + .clone(output_uniform_quantized_type), /*lhs=*/op.getLhs(), /*rhs=*/op.getRhs(), /*dot_dimension_numbers=*/op.getDotDimensionNumbers(), /*precision_config=*/op.getPrecisionConfigAttr()); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.cc index 801c8775682cbd..8c28f2e5e5df4b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_avg_pool.cc @@ -55,7 +55,7 @@ DenseIntElementsAttr GetPaddingArrayAttr(Builder& builder, Operation* old_op) { } ShapedType GetPaddedType(Operation* old_op) { - auto input_type = old_op->getOperand(0).getType().cast(); + auto input_type = mlir::cast(old_op->getOperand(0).getType()); auto input_shape = input_type.getShape(); // NCHW int64_t batch_size = input_shape[0]; int64_t channel_size = input_shape[1]; @@ -124,7 +124,7 @@ StringAttr GetPaddingStringAttr(Builder& builder, Operation* old_op) { auto composite_attrs = composite_op.getCompositeAttributes(); auto operand_shape = - composite_op.getOperand(0).getType().cast().getShape(); + mlir::cast(composite_op.getOperand(0).getType()).getShape(); // NC(H)(W) std::vector spatial_dim_sizes = { static_cast(operand_shape[2]), diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.cc index 403bf9968a9acd..2809c81458918c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace odml { @@ -65,8 +66,8 @@ bool GetI32VectorFromDenseI64CompositeAttr( bool IsSupportedNchwUpsampleBlinear( Value input, Value output, const DenseIntElementsAttr& output_size_attr) { - auto input_shape = input.getType().cast().getShape(); - auto output_shape = output.getType().cast().getShape(); + auto input_shape = mlir::cast(input.getType()).getShape(); + auto output_shape = mlir::cast(output.getType()).getShape(); // Only support 4D tensor. if (input_shape.size() != 4 || output_shape.size() != 4) { @@ -89,7 +90,7 @@ bool IsSupportedNchwUpsampleBlinear( ShapedType GetNhwcReturnTypeFromNchw(Operation* old_op) { auto composite_result_shape = - old_op->getResults().front().getType().cast().getShape(); + mlir::cast(old_op->getResults().front().getType()).getShape(); std::array output_shape; // NHWC <- NCHW output_shape[0] = composite_result_shape[0]; @@ -97,7 +98,7 @@ ShapedType GetNhwcReturnTypeFromNchw(Operation* old_op) { output_shape[2] = composite_result_shape[3]; output_shape[3] = composite_result_shape[1]; - auto input_type = old_op->getOperand(0).getType().cast(); + auto input_type = mlir::cast(old_op->getOperand(0).getType()); return RankedTensorType::get(output_shape, input_type.getElementType()); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h index 0691dc74997212..79d0910bce18a4 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.h @@ -38,10 +38,10 @@ template bool EnsureAttribute(const DictionaryAttr& composite_attributes, const std::string& attr_name, AttrType* out_attr) { Attribute attr = composite_attributes.get(attr_name); - if (!attr.isa_and_nonnull()) { + if (!mlir::isa_and_nonnull(attr)) { return false; } - if (AttrType content = attr.dyn_cast()) { + if (AttrType content = mlir::dyn_cast(attr)) { *out_attr = content; return true; } else { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc index 847738e5cc7cbe..c2b31aeb540720 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/fold_broadcast_pass.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" @@ -117,7 +118,7 @@ static Attribute BinaryFolder(Op *op) { auto rhs = dyn_cast_or_null(rhs_op.getValue()); if (!lhs || !rhs) return {}; - ShapedType type = op->getType().template cast(); + ShapedType type = mlir::cast(op->getType()); if (!type.hasStaticShape()) { return {}; } @@ -125,15 +126,15 @@ static Attribute BinaryFolder(Op *op) { Type etype = type.getElementType(); // Evaluate for element types. - if (!etype.isa()) { + if (!mlir::isa(etype)) { return {}; } // Special case for folding splats no matter how large. // Only covers the case of both attrs being splats; operation-specific cases // like adding a zero or multiplying by one are handled elsewhere. - SplatElementsAttr splatLhs = lhs.template dyn_cast(); - SplatElementsAttr splatRhs = rhs.template dyn_cast(); + SplatElementsAttr splatLhs = mlir::dyn_cast(lhs); + SplatElementsAttr splatRhs = mlir::dyn_cast(rhs); if (splatLhs && splatRhs) { auto signedLhs = addSign(splatLhs.getSplatValue(), etype); auto signedRhs = addSign(splatRhs.getSplatValue(), etype); @@ -195,10 +196,10 @@ class FoldBroadcastInDimBeforeBinaryElementwiseOp auto bcast_dims = bcast_op.getBroadcastDimensions(); auto elem_type = const_val.getElementType(); Attribute result; - if (elem_type.template isa()) { + if (mlir::isa(elem_type)) { result = ConstFoldBroadcastInDim(result_type, const_val, bcast_dims); - } else if (elem_type.template isa()) { + } else if (mlir::isa(elem_type)) { result = ConstFoldBroadcastInDim(result_type, const_val, bcast_dims); } else { @@ -217,14 +218,14 @@ using FoldBroadcastInDimBeforeMulOp = // Constant folds mhlo.mul, this folder doesn't have an upper limit on how many // elements can be folded. LogicalResult ConstantFoldMul(mhlo::MulOp op, PatternRewriter &rewriter) { - ShapedType type = op.getType().dyn_cast(); + ShapedType type = mlir::dyn_cast(op.getType()); Type etype = type.getElementType(); Attribute result = {}; - if (etype.isa()) { + if (mlir::isa(etype)) { result = BinaryFolder>( &op); - } else if (etype.isa()) { + } else if (mlir::isa(etype)) { result = BinaryFolder>( &op); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.cc index 2d9308b05cb47b..9e7a5d424a2ecc 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/hlo_matchers.cc @@ -237,9 +237,9 @@ bool MatchReshapedIota(DenseIntElementsAttr dimensions, Value iota) { auto reshape_op = dyn_cast_or_null(iota.getDefiningOp()); if (!reshape_op) return false; auto operand_type = - reshape_op.getOperand().getType().dyn_cast(); + mlir::dyn_cast(reshape_op.getOperand().getType()); if (!operand_type || !operand_type.hasStaticShape()) return false; - auto reshape_type = reshape_op.getType().cast(); + auto reshape_type = mlir::cast(reshape_op.getType()); // Reshape can take a 1-D iota input and add extra dims of size one. if (operand_type.getRank() != 1) return false; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc index 3a483f44568ce2..96081a2b2b1bd8 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc @@ -162,10 +162,8 @@ class ConvertNdConvOp : public OpConversionPattern { } // tf Convolution doesn't support quantized type. - if (conv_op.getRhs() - .getType() - .getElementType() - .isa()) { + if (mlir::isa( + conv_op.getRhs().getType().getElementType())) { return failure(); } @@ -193,11 +191,11 @@ class ConvertNdConvOp : public OpConversionPattern { const int kernel_input_feature_dimension = dnums.getKernelInputFeatureDimension(); const int input_channels = - conv_op.getLhs().getType().cast().getDimSize( - input_feature_dimension); + mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_feature_dimension); const int kernel_input_channels = - conv_op.getRhs().getType().cast().getDimSize( - kernel_input_feature_dimension); + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_input_feature_dimension); int feature_group_count = conv_op.getFeatureGroupCount(); // check if group count is valid @@ -238,14 +236,14 @@ class ConvertNdConvOp : public OpConversionPattern { }; static bool IsSupportedConvOp(mhlo::ConvolutionOp conv_op) { - if (!conv_op.getRhs().getType().cast().hasStaticShape()) { + if (!mlir::cast(conv_op.getRhs().getType()).hasStaticShape()) { return false; } - if (!conv_op.getLhs().getType().cast().hasStaticShape() && - !conv_op.getType().cast().hasStaticShape()) { + if (!mlir::cast(conv_op.getLhs().getType()).hasStaticShape() && + !mlir::cast(conv_op.getType()).hasStaticShape()) { auto dnums = conv_op.getDimensionNumbers(); - auto lhs_type = conv_op.getLhs().getType().cast(); - auto out_type = conv_op.getType().cast(); + auto lhs_type = mlir::cast(conv_op.getLhs().getType()); + auto out_type = mlir::cast(conv_op.getType()); int64_t input_batch_dim = dnums.getInputBatchDimension(); int64_t out_batch_dim = dnums.getOutputBatchDimension(); for (size_t i = 0; i < lhs_type.getRank(); ++i) { @@ -263,10 +261,7 @@ class ConvertNdConvOp : public OpConversionPattern { if (!lhs_dilation.isSplat() || lhs_dilation.getSplatValue() != 1) return false; - if (conv_op.getWindowStrides() - .value() - .getType() - .cast() + if (mlir::cast(conv_op.getWindowStrides().value().getType()) .getRank() != 1) return false; @@ -290,10 +285,10 @@ class ConvertNdConvOp : public OpConversionPattern { int64_t pad_low_int64; int64_t pad_high_int64; tensorflow::Status status = tensorflow::GetWindowedOutputSizeVerbose( - conv_op.getLhs().getType().cast().getDimSize( - input_spatial_dim[i]), - conv_op.getRhs().getType().cast().getDimSize( - kernel_spatial_dim[i]), + mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_spatial_dim[i]), + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_spatial_dim[i]), dilation[dim], strides[dim], tensorflow::Padding::SAME, &output_size, &pad_low_int64, &pad_high_int64); if (!status.ok()) return false; @@ -314,7 +309,7 @@ class ConvertNdConvOp : public OpConversionPattern { return value; } - auto input_type = value.getType().cast(); + auto input_type = mlir::cast(value.getType()); auto input_shape = input_type.getShape(); llvm::SmallVector start; @@ -380,7 +375,7 @@ class ConvertNdConvOp : public OpConversionPattern { // Convolution. This is needed because TF.Conv3DOp doesn't support EXPLICIT. if (padding == "EXPLICIT" && num_spatial_dims == 3) { auto lhs_type = - conv_op.getLhs().getType().template dyn_cast(); + mlir::dyn_cast(conv_op.getLhs().getType()); RankedTensorType padding_attr_type = mlir::RankedTensorType::get( {lhs_type.getRank(), 2}, rewriter.getIntegerType(64)); auto padding_const = rewriter.create( @@ -394,7 +389,7 @@ class ConvertNdConvOp : public OpConversionPattern { padding = "VALID"; } - auto conv_output_type = conv_op.getType().cast(); + auto conv_output_type = mlir::cast(conv_op.getType()); DenseIntElementsAttr permutation; const bool need_transpose_output = NeedsReformatTypeAndPermutation( dnums.getOutputBatchDimension(), dnums.getOutputFeatureDimension(), @@ -418,7 +413,7 @@ class ConvertNdConvOp : public OpConversionPattern { // Reshapes filter format to [filter_height, filter_width, in_channels, // channel_multiplier] from HLO's [filter_height, filter_width, 1, // in_channels * channel_multiplier] format. - auto filter_type = rhs.getType().cast(); + auto filter_type = mlir::cast(rhs.getType()); llvm::ArrayRef hlo_filter_shape = filter_type.getShape(); llvm::SmallVector tf_filter_shape(hlo_filter_shape.begin(), hlo_filter_shape.end()); @@ -491,13 +486,13 @@ class Convert1DConvOp : public OpConversionPattern { // Group convolution is not supported yet. const int64_t input_feature_dimension = dnums.getInputFeatureDimension(); const int64_t input_channels = - conv_op.getLhs().getType().cast().getDimSize( - input_feature_dimension); + mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_feature_dimension); const int kernel_input_feature_dimension = dnums.getKernelInputFeatureDimension(); const int kernel_input_channels = - conv_op.getRhs().getType().cast().getDimSize( - kernel_input_feature_dimension); + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_input_feature_dimension); const int64_t feature_group_count = conv_op.getFeatureGroupCount(); if (feature_group_count != input_channels / kernel_input_channels || input_channels % kernel_input_channels != 0) @@ -508,7 +503,7 @@ class Convert1DConvOp : public OpConversionPattern { // // Reshape input image to add a new spatial dimension. - auto image_type = conv_op.getLhs().getType().cast(); + auto image_type = mlir::cast(conv_op.getLhs().getType()); SmallVector image_2d_shape(image_type.getShape().begin(), image_type.getShape().end()); image_2d_shape.push_back(1); @@ -530,7 +525,7 @@ class Convert1DConvOp : public OpConversionPattern { image_permutation_and_shape.permutation); // Reshape kernel to add a new spatial dimension. - auto kernel_type = conv_op.getRhs().getType().cast(); + auto kernel_type = mlir::cast(conv_op.getRhs().getType()); SmallVector kernel_2d_shape; for (int64_t dim : kernel_type.getShape()) { kernel_2d_shape.push_back(dim); @@ -623,7 +618,7 @@ class Convert1DConvOp : public OpConversionPattern { // // Determine the 2-D convolution output shape. - auto output_type = conv_op->getResult(0).getType().cast(); + auto output_type = mlir::cast(conv_op->getResult(0).getType()); SmallVector output_2d_shape; for (int64_t dim : output_type.getShape()) { output_2d_shape.push_back(dim); @@ -648,7 +643,7 @@ class Convert1DConvOp : public OpConversionPattern { conv_op.getPrecisionConfigAttr()); OpResult conv2d_output = conv2d_op->getResult(0); - auto conv2d_output_type = conv2d_output.getType().cast(); + auto conv2d_output_type = mlir::cast(conv2d_output.getType()); // // Transpose and reshape the output @@ -676,9 +671,9 @@ using Convert3DConvOp = ConvertNdConvOp<3>; // lhs_dilation>1 and window_strides=1. LogicalResult IsSupportedNonTrivialConvOp(mhlo::ConvolutionOp conv_op, ConversionPatternRewriter& rewriter) { - if (!conv_op.getLhs().getType().cast().hasStaticShape() || - !conv_op.getRhs().getType().cast().hasStaticShape() || - !conv_op.getType().cast().hasStaticShape()) + if (!mlir::cast(conv_op.getLhs().getType()).hasStaticShape() || + !mlir::cast(conv_op.getRhs().getType()).hasStaticShape() || + !mlir::cast(conv_op.getType()).hasStaticShape()) return rewriter.notifyMatchFailure(conv_op, "requires static shape"); mhlo::ConvDimensionNumbersAttr dnums = conv_op.getDimensionNumbers(); @@ -687,10 +682,7 @@ LogicalResult IsSupportedNonTrivialConvOp(mhlo::ConvolutionOp conv_op, return rewriter.notifyMatchFailure(conv_op, "requires non-trivial lhs_dilation"); - if (conv_op.getWindowStrides() - .value() - .getType() - .cast() + if (mlir::cast(conv_op.getWindowStrides().value().getType()) .getRank() != 1) return rewriter.notifyMatchFailure( conv_op, "requires window_strides to equal to one"); @@ -746,19 +738,19 @@ class ConvertToResizeBilinearOpOrDepthwiseTransposedConvOp mhlo::ConvDimensionNumbersAttr dnums = conv_op.getDimensionNumbers(); const int input_feature_dimension = dnums.getInputFeatureDimension(); const int input_channels = - conv_op.getLhs().getType().cast().getDimSize( - input_feature_dimension); + mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_feature_dimension); int feature_group_count = conv_op.getFeatureGroupCount(); const int kernel_input_feature_dimension = dnums.getKernelInputFeatureDimension(); const int kernel_input_channels = - conv_op.getRhs().getType().cast().getDimSize( - kernel_input_feature_dimension); + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_input_feature_dimension); const int kernel_output_feature_dimension = dnums.getKernelOutputFeatureDimension(); const int kernel_output_channels = - conv_op.getRhs().getType().cast().getDimSize( - kernel_output_feature_dimension); + mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_output_feature_dimension); // To support a depthwise convolution, we need- // 1. feature_group_count != 1 (except when input_channels==1) @@ -795,7 +787,7 @@ class ConvertToResizeBilinearOpOrDepthwiseTransposedConvOp auto create_slice = [&](mlir::Value tensor, int depth_idx, int channel_idx, bool is_kernel = false) -> mlir::Value { std::vector tensor_shape = - tensor.getType().cast().getShape().vec(); + mlir::cast(tensor.getType()).getShape().vec(); // Calculate offsets based on depth_idx, channel_idx and tensor_shape std::vector start_indices(tensor_shape.size(), 0); @@ -828,7 +820,8 @@ class ConvertToResizeBilinearOpOrDepthwiseTransposedConvOp // Calculate convolution output_type based on sliced_input and // sliced_kernel - auto output_type = conv_op->getResult(0).getType().cast(); + auto output_type = + mlir::cast(conv_op->getResult(0).getType()); std::vector new_output_shape = output_type.getShape().vec(); new_output_shape[dnums.getOutputFeatureDimension()] /= feature_group_count; @@ -884,8 +877,8 @@ class ConvertToResizeBilinearOpOrDepthwiseTransposedConvOp int feature_group_count = conv_op.getFeatureGroupCount(); const int input_feature_dimension = dnums.getInputFeatureDimension(); const int input_channels = - conv_op.getLhs().getType().cast().getDimSize( - input_feature_dimension); + mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_feature_dimension); // Check for Group Convolution parameters if (feature_group_count != 1 && feature_group_count != input_channels) { @@ -919,7 +912,7 @@ class ConvertToResizeBilinearOpOrDepthwiseTransposedConvOp auto padding_values = padding.getValues(); // Cast the dimension sizes to int. - auto lhs_type = conv_op.getLhs().getType().cast(); + auto lhs_type = mlir::cast(conv_op.getLhs().getType()); llvm::SmallVector input_sizes = { static_cast(lhs_type.getDimSize(input_spatial_dimensions[0])), static_cast(lhs_type.getDimSize(input_spatial_dimensions[1]))}; @@ -1101,7 +1094,8 @@ class ConvertNonTrivialConvOp transpose_order[dnums.getOutputSpatialDimensions().data()[i]] = i + 1; } auto output_shape = - conv_op.getResult().getType().cast().getShape(); + mlir::cast(conv_op.getResult().getType()) + .getShape(); SmallVector transposed_output_shape = { output_shape[dnums.getOutputBatchDimension()], output_shape[dnums.getOutputSpatialDimensions().data()[0]], @@ -1114,7 +1108,7 @@ class ConvertNonTrivialConvOp } auto output_type = RankedTensorType::get( transposed_output_shape, - conv_op.getRhs().getType().cast().getElementType()); + mlir::cast(conv_op.getRhs().getType()).getElementType()); auto output_sizes = rewriter.create( conv_op.getLoc(), DenseIntElementsAttr::get( @@ -1138,7 +1132,8 @@ class ConvertNonTrivialConvOp } else { SmallVector output_shape_i32; for (int64_t dim : - conv_op.getResult().getType().cast().getShape()) { + mlir::cast(conv_op.getResult().getType()) + .getShape()) { output_shape_i32.push_back(dim); } auto output_sizes = rewriter.create( @@ -1176,14 +1171,12 @@ class ConvertNonTrivialConvOp for (size_t i = 1; i <= num_spatial_dims; ++i) { int64_t stride = strides[i]; - int64_t input_size = - conv_op.getLhs().getType().cast().getDimSize( - input_spatial_dims[i - 1]); - int64_t kernel_size = - conv_op.getRhs().getType().cast().getDimSize( - kernel_spatial_dims[i - 1]); - int64_t output_size = conv_op.getType().cast().getDimSize( - output_spatial_dims[i - 1]); + int64_t input_size = mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_spatial_dims[i - 1]); + int64_t kernel_size = mlir::cast(conv_op.getRhs().getType()) + .getDimSize(kernel_spatial_dims[i - 1]); + int64_t output_size = mlir::cast(conv_op.getType()) + .getDimSize(output_spatial_dims[i - 1]); // stablehlo.convolution op needs explicit padding to be set to model any // Transposed-Convolution in JAX/PT. Checking to see if- @@ -1225,11 +1218,10 @@ class ConvertNonTrivialConvOp return false; } int64_t stride = strides[i + 1]; - int64_t input_size = - conv_op.getLhs().getType().cast().getDimSize( - input_spatial_dims[i]); - int64_t output_size = conv_op.getType().cast().getDimSize( - output_spatial_dims[i]); + int64_t input_size = mlir::cast(conv_op.getLhs().getType()) + .getDimSize(input_spatial_dims[i]); + int64_t output_size = mlir::cast(conv_op.getType()) + .getDimSize(output_spatial_dims[i]); // The reason for the below check is as follows: // When computing the output, we have the following relation between // o - output dim size, i - input dim size, s - stride, P - total pads @@ -1280,13 +1272,11 @@ class ConvertDynamicSliceOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::DynamicSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { - ShapedType input_type = op.getOperand().getType().cast(); + ShapedType input_type = mlir::cast(op.getOperand().getType()); if (!input_type.hasStaticShape()) return failure(); - Type start_indices_element_type = op.getStartIndices() - .front() - .getType() - .cast() - .getElementType(); + Type start_indices_element_type = + mlir::cast(op.getStartIndices().front().getType()) + .getElementType(); // The mhlo dynamic_slice's start_indices can be either signed/unsigned // int32/int64. However, TF only takes in either i32 or i64 types for begin, @@ -1307,8 +1297,8 @@ class ConvertDynamicSliceOp : public OpConversionPattern { for (uint64_t i = 0, e = op.getStartIndices().size(); i < e; ++i) { // Always put a cast there. auto start = op.getStartIndices()[i]; - auto cast_type = start.getType().cast().clone( - signed_start_indices_element_type); + auto cast_type = mlir::cast(start.getType()) + .clone(signed_start_indices_element_type); auto cast_op = rewriter.create(op.getLoc(), cast_type, start); Value clamp_max = rewriter.create( op.getLoc(), rewriter.getIntegerAttr( @@ -1409,11 +1399,11 @@ class ConvertDynamicUpdateSliceOp LogicalResult matchAndRewrite( mhlo::DynamicUpdateSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { - ShapedType operand_type = op.getOperand().getType().cast(); + ShapedType operand_type = mlir::cast(op.getOperand().getType()); ShapedType update_type = - op.getUpdate().getType().dyn_cast_or_null(); - ShapedType start_indices_type = - op.getStartIndices().front().getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op.getUpdate().getType()); + ShapedType start_indices_type = mlir::dyn_cast_or_null( + op.getStartIndices().front().getType()); if (update_type == nullptr || start_indices_type == nullptr) return rewriter.notifyMatchFailure( op, "update and start_indices should have ShapedType"); @@ -1474,8 +1464,8 @@ class ConvertSortToTfTopk : public OpConversionPattern { op, "only match for the case where operands is of size 2"); auto keys = op.getInputs()[0]; auto indices = op.getInputs()[1]; - auto keys_ty = keys.getType().dyn_cast_or_null(); - auto indices_ty = indices.getType().dyn_cast_or_null(); + auto keys_ty = mlir::dyn_cast_or_null(keys.getType()); + auto indices_ty = mlir::dyn_cast_or_null(indices.getType()); if (!keys_ty || !keys_ty.hasStaticShape() || !keys_ty.getElementType().isIntOrFloat()) return rewriter.notifyMatchFailure( @@ -1589,7 +1579,7 @@ Value BuildDotOperandFlattenedShapeOp(Value operand, DotDimensionsInfo dot_dimensions_info, ImplicitLocOpBuilder& builder, bool is_lhs) { - auto operand_type = operand.getType().cast(); + auto operand_type = mlir::cast(operand.getType()); BoolAttr true_attr = builder.getBoolAttr(true); auto operand_shape = builder.create(operand, true_attr); const int64_t operand_rank = operand_type.getRank(); @@ -1665,8 +1655,8 @@ Value BuildDotOperandFlattenedShapeOp(Value operand, Value ConvertDot(PatternRewriter& rewriter, Value lhs, Value rhs, DotDimensionNumbersAttr dot_dimension_numbers, ShapedType result_type, mlir::Location loc) { - auto lhs_type = lhs.getType().cast(); - auto rhs_type = rhs.getType().cast(); + auto lhs_type = mlir::cast(lhs.getType()); + auto rhs_type = mlir::cast(rhs.getType()); const int lhs_rank = lhs_type.getRank(); const int rhs_rank = rhs_type.getRank(); ImplicitLocOpBuilder builder(loc, rewriter); @@ -1821,7 +1811,7 @@ Value ConvertDot(PatternRewriter& rewriter, Value lhs, Value rhs, // necessary. Value ConvertDotOp(PatternRewriter& rewriter, Operation* old_op) { auto dot_op = cast(old_op); - auto lhs_rank = dot_op.getLhs().getType().cast().getRank(); + auto lhs_rank = mlir::cast(dot_op.getLhs().getType()).getRank(); auto dot_dimension_numbers = DotDimensionNumbersAttr::get(rewriter.getContext(), /*lhs_batching_dimensions=*/{}, @@ -1831,17 +1821,18 @@ Value ConvertDotOp(PatternRewriter& rewriter, Operation* old_op) { /*rhs_contracting_dimensions=*/{0}); return ConvertDot( rewriter, dot_op.getLhs(), dot_op.getRhs(), dot_dimension_numbers, - dot_op.getResult().getType().cast(), dot_op.getLoc()); + mlir::cast(dot_op.getResult().getType()), dot_op.getLoc()); } // Converts mhlo.dot to tf.BatchMatMul. Reshape or Transpose ops will also be // inserted to convert to well-formed matrix multiply. Value ConvertDotGeneralOp(PatternRewriter& rewriter, Operation* old_op) { auto dot_general_op = cast(old_op); - return ConvertDot(rewriter, dot_general_op.getLhs(), dot_general_op.getRhs(), - dot_general_op.getDotDimensionNumbers(), - dot_general_op.getResult().getType().cast(), - dot_general_op.getLoc()); + return ConvertDot( + rewriter, dot_general_op.getLhs(), dot_general_op.getRhs(), + dot_general_op.getDotDimensionNumbers(), + mlir::cast(dot_general_op.getResult().getType()), + dot_general_op.getLoc()); } // Replace BinaryOp with a combination of TfBinaryOp and TfReduceOp if the @@ -1940,9 +1931,9 @@ class ConvertReduceOpToTfOp : public OpConversionPattern { reduce_op.getResults().size() != 1) return failure(); - if (!reduce_op.getInputs()[0].getType().isa()) + if (!mlir::isa(reduce_op.getInputs()[0].getType())) return failure(); - if (!reduce_op.getType(0).isa()) return failure(); + if (!mlir::isa(reduce_op.getType(0))) return failure(); return success(); } }; @@ -1953,13 +1944,13 @@ class ConvertReduceOpToTfProd using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp; LogicalResult MatchInitValue(Value init_value) const override { - auto type = init_value.getType().cast().getElementType(); - if (type.isa()) { + auto type = mlir::cast(init_value.getType()).getElementType(); + if (mlir::isa(type)) { float const_value; if (failed(GetConstantSplatValue(init_value, const_value)) || const_value != 1.0) return failure(); - } else if (type.isa() && type.isSignlessInteger()) { + } else if (mlir::isa(type) && type.isSignlessInteger()) { int32_t const_value; if (failed(GetConstantSplatValue(init_value, const_value)) || const_value != 1) @@ -1978,13 +1969,13 @@ class ConvertReduceOpToTfSum using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp; LogicalResult MatchInitValue(Value init_value) const override { - auto type = init_value.getType().cast().getElementType(); - if (type.isa()) { + auto type = mlir::cast(init_value.getType()).getElementType(); + if (mlir::isa(type)) { APFloat const_value(.0); if (failed(GetConstantSplatValue(init_value, const_value)) || !const_value.isZero()) return failure(); - } else if (type.isa() && type.isSignlessInteger()) { + } else if (mlir::isa(type) && type.isSignlessInteger()) { APInt const_value; if (failed(GetConstantSplatValue(init_value, const_value)) || !const_value.isZero()) @@ -2003,13 +1994,13 @@ class ConvertReduceOpToTfMax using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp; LogicalResult MatchInitValue(Value init_value) const override { - auto type = init_value.getType().cast().getElementType(); - if (type.isa()) { + auto type = mlir::cast(init_value.getType()).getElementType(); + if (mlir::isa(type)) { APFloat const_value(.0); if (failed(GetConstantSplatValue(init_value, const_value)) || !const_value.isInfinity() || !const_value.isNegative()) return failure(); - } else if (type.isa() && type.isSignlessInteger()) { + } else if (mlir::isa(type) && type.isSignlessInteger()) { APInt const_value; if (failed(GetConstantSplatValue(init_value, const_value)) || !const_value.isMinSignedValue()) @@ -2027,14 +2018,14 @@ class ConvertReduceOpToTfMin using ConvertReduceOpToTfOp::ConvertReduceOpToTfOp; LogicalResult MatchInitValue(Value init_value) const override { - auto type = init_value.getType().cast().getElementType(); + auto type = mlir::cast(init_value.getType()).getElementType(); - if (type.isa()) { + if (mlir::isa(type)) { APFloat const_value(.0); if (failed(GetConstantSplatValue(init_value, const_value)) || !const_value.isInfinity() || const_value.isNegative()) return failure(); - } else if (type.isa() && type.isSignlessInteger()) { + } else if (mlir::isa(type) && type.isSignlessInteger()) { APInt const_value; if (failed(GetConstantSplatValue(init_value, const_value)) || !const_value.isMaxSignedValue()) @@ -2088,7 +2079,7 @@ class ConvertReduceOpToTfArgmax auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto value = *attr.value_begin(); return value.isNegative() && value.isInfinity(); } else if (element_type.isInteger(1)) { @@ -2112,7 +2103,7 @@ class ConvertReduceOpToTfArgmin auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto value = *attr.value_begin(); return !value.isNegative() && value.isInfinity(); } else if (element_type.isInteger(1)) { @@ -2134,18 +2125,18 @@ class ConvertIotaOpToTfRange : public OpConversionPattern { mhlo::IotaOp iota_op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { RankedTensorType type = - iota_op.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(iota_op.getType()); // TF::RangeOp doesn't support UI16. if (!type || type.getElementType().isUnsignedInteger(16)) return failure(); const uint64_t dimension = iota_op.getIotaDimension(); Type element_type = type.getElementType(); Attribute start, limit, delta; - if (element_type.isa()) { + if (mlir::isa(element_type)) { start = rewriter.getFloatAttr(element_type, 0.0); limit = rewriter.getFloatAttr(element_type, type.getShape()[dimension]); delta = rewriter.getFloatAttr(element_type, 1.0); - } else if (element_type.isa()) { + } else if (mlir::isa(element_type)) { start = rewriter.getIntegerAttr(element_type, 0); limit = rewriter.getIntegerAttr(element_type, type.getShape()[dimension]); delta = rewriter.getIntegerAttr(element_type, 1); @@ -2249,9 +2240,10 @@ bool IsSpatialPoolingWithoutDilation( // Check that the individual padding values are corresponding to SAME // padding from TensorFlow. - auto operand_type = rw.getInputs()[0].getType().dyn_cast(); + auto operand_type = + mlir::dyn_cast(rw.getInputs()[0].getType()); RankedTensorType output_type = - rw.getResult(0).getType().dyn_cast(); + mlir::dyn_cast(rw.getResult(0).getType()); if (!operand_type || !output_type) return false; for (uint64_t i = 1; i < rank - 1; ++i) { @@ -2293,12 +2285,13 @@ class ConvertLoweredCumOp : public OpConversionPattern { auto const_op = llvm::dyn_cast_or_null( rw.getInitValues()[0].getDefiningOp()); if (!const_op) return failure(); - auto const_op_dense_value = const_op.getValue().cast(); + auto const_op_dense_value = + mlir::cast(const_op.getValue()); if (!const_op_dense_value || !IsInitValue(const_op_dense_value)) { return failure(); } - auto operand_type = rw.getInputs()[0].getType().cast(); + auto operand_type = mlir::cast(rw.getInputs()[0].getType()); // For a cumulative op, require a tensor of 1s for each dimension in // operand. @@ -2383,7 +2376,7 @@ class ConvertLoweredCumSumOp auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto value = *attr.value_begin(); return value.isZero(); } @@ -2399,7 +2392,7 @@ class ConvertLoweredCumProdOp auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto value = *attr.value_begin(); return value.isExactlyValue(1.0); } @@ -2431,8 +2424,8 @@ class ConvertAvgPoolOp : public OpConversionPattern { // Check that this is a floating point reduce window with a rank of 4 or 5. const RankedTensorType rw_type = - rw.getResult(0).getType().dyn_cast(); - if (!rw_type || !rw_type.getElementType().isa() || + mlir::dyn_cast(rw.getResult(0).getType()); + if (!rw_type || !mlir::isa(rw_type.getElementType()) || rw_type.getRank() <= 3 || rw_type.getRank() > 5) return failure(); @@ -2568,8 +2561,8 @@ class ConvertMaxPoolOp : public OpConversionPattern { // Check that this is a floating point reduce window with a rank of 4 or 5. const RankedTensorType rw_type = - rw.getResult(0).getType().dyn_cast(); - if (!rw_type || !rw_type.getElementType().isa() || + mlir::dyn_cast(rw.getResult(0).getType()); + if (!rw_type || !mlir::isa(rw_type.getElementType()) || rw_type.getRank() <= 3 || rw_type.getRank() > 5) return failure(); @@ -2639,7 +2632,7 @@ class ConvertMaxPoolOp : public OpConversionPattern { // Returns the shape of the given value in a Constant Op. arith::ConstantOp ShapeToConst(PatternRewriter& rewriter, Value value) { - ArrayRef shape = value.getType().cast().getShape(); + ArrayRef shape = mlir::cast(value.getType()).getShape(); auto attr_type = RankedTensorType::get({static_cast(shape.size())}, rewriter.getIntegerType(64)); auto attr = DenseElementsAttr::get(attr_type, shape); @@ -2659,36 +2652,37 @@ bool IsSign(APFloat a, APFloat sign) { } bool IsDenseSplatIntAttr(ElementsAttr float_or_int) { - return float_or_int.isa() && - float_or_int.isa(); + return mlir::isa(float_or_int) && + mlir::isa(float_or_int); } bool IsDenseSplatFloatAttr(ElementsAttr float_or_int) { - return float_or_int.isa() && - float_or_int.isa(); + return mlir::isa(float_or_int) && + mlir::isa(float_or_int); } bool ValueIsReciprocal(ElementsAttr float_or_int, ElementsAttr rhs) { if (IsDenseSplatFloatAttr(float_or_int) && IsDenseSplatFloatAttr(float_or_int)) { - return (float_or_int.cast().getSplatValue() * - rhs.cast().getSplatValue()) + return (mlir::cast(float_or_int) + .getSplatValue() * + mlir::cast(rhs).getSplatValue()) .isExactlyValue(1.0); } else if (IsDenseSplatIntAttr(float_or_int) && IsDenseSplatIntAttr(float_or_int)) { - return (float_or_int.cast().getSplatValue() * - rhs.cast().getSplatValue()) == 1; + return (mlir::cast(float_or_int).getSplatValue() * + mlir::cast(rhs).getSplatValue()) == 1; } return false; } bool ValueEquals(ElementsAttr float_or_int, double rhs) { if (IsDenseSplatFloatAttr(float_or_int)) { - return float_or_int.cast() + return mlir::cast(float_or_int) .getSplatValue() .isExactlyValue(rhs); } else if (IsDenseSplatIntAttr(float_or_int)) { - return float_or_int.cast().getSplatValue() == + return mlir::cast(float_or_int).getSplatValue() == static_cast(rhs); } return false; @@ -2696,11 +2690,12 @@ bool ValueEquals(ElementsAttr float_or_int, double rhs) { bool ValueGreaterThanZero(ElementsAttr float_or_int) { if (IsDenseSplatIntAttr(float_or_int)) { - auto value = float_or_int.cast().getSplatValue(); + auto value = + mlir::cast(float_or_int).getSplatValue(); return !value.isNegative() && !value.isZero(); } else if (IsDenseSplatFloatAttr(float_or_int)) { auto value = - float_or_int.cast().getSplatValue(); + mlir::cast(float_or_int).getSplatValue(); return !value.isNaN() && !value.isNegative() && !value.isZero(); } return false; @@ -2723,13 +2718,13 @@ bool TensorIsSign(PatternRewriter& rewriter, ElementsAttr float_or_int, int_spl && sgn_cst_spl) { return IsSign(int_spl.getValue(), sgn_cst_spl.getValue()); } - if (float_or_int.isa()) { + if (mlir::isa(float_or_int)) { auto sgn_splat_value = sgn_splat.getSplatValue(); return llvm::all_of(float_or_int.getValues(), [&](APFloat value) { return IsSign(value, sgn_splat_value); }); } - if (float_or_int.isa()) { + if (mlir::isa(float_or_int)) { auto sgn_splat_value = sgn_splat.getSplatValue(); return llvm::all_of(float_or_int.getValues(), [&](APInt value) { return IsSign(value, sgn_splat_value); @@ -2778,9 +2773,11 @@ class ConvertGatherOp : public OpConversionPattern { Value start_indices = gather_op.getStartIndices(); // Can only convert with static shaped gather. - ShapedType operand_type = operand.getType().cast(); - ShapedType start_indices_type = start_indices.getType().cast(); - ShapedType result_type = gather_op.getResult().getType().cast(); + ShapedType operand_type = mlir::cast(operand.getType()); + ShapedType start_indices_type = + mlir::cast(start_indices.getType()); + ShapedType result_type = + mlir::cast(gather_op.getResult().getType()); if (!operand_type.hasStaticShape()) { gather_op.emitOpError() << "Dynamic shaped operand is not supported."; return failure(); @@ -2917,9 +2914,11 @@ class ConvertGatherOp : public OpConversionPattern { static const int max_batch_size = 50; // Can only convert with static shaped gather. - ShapedType operand_type = operand.getType().cast(); - ShapedType start_indices_type = start_indices.getType().cast(); - ShapedType result_type = gather_op.getResult().getType().cast(); + ShapedType operand_type = mlir::cast(operand.getType()); + ShapedType start_indices_type = + mlir::cast(start_indices.getType()); + ShapedType result_type = + mlir::cast(gather_op.getResult().getType()); if (!operand_type.hasStaticShape() || !start_indices_type.hasStaticShape() || !result_type.hasStaticShape()) { return rewriter.notifyMatchFailure( @@ -3140,7 +3139,7 @@ class ConvertWhileOp : public OpConversionPattern { // This rule doesn't support mhlo::WhileOp with tuple inputs. for (auto type : while_op->getOperandTypes()) { - if (type.isa()) return failure(); + if (mlir::isa(type)) return failure(); } // Creates a TF::WhileRegionOp to replace the mhlo::WhileOp. HLO WhileOp @@ -3296,7 +3295,7 @@ class ConvertCustomCallWithApproxTopK } } auto backend_config = - op.getBackendConfigAttr().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op.getBackendConfigAttr()); if (!backend_config) { return op.emitOpError() << "Missing backend_config attribute"; } @@ -3385,12 +3384,13 @@ class ConvertCustomCallWithApproxTopK << "ApproxTopK takes exactly 1 called_computation."; } mlir::func::FuncOp callee = module_op_->lookupSymbol( - op.getCalledComputations()[0].cast()); + mlir::cast(op.getCalledComputations()[0])); mlir::FunctionType callee_type = callee.getFunctionType(); SmallVector expected_callee_input_types; auto num_inputs = op.getInputs().size() / 2; for (unsigned i = 0; i < num_inputs; ++i) { - auto input_type = op.getOperand(i).getType().dyn_cast(); + auto input_type = + mlir::dyn_cast(op.getOperand(i).getType()); auto scalar = RankedTensorType::get({}, input_type.getElementType()); expected_callee_input_types.push_back(scalar); expected_callee_input_types.push_back(scalar); @@ -3491,12 +3491,10 @@ class ConvertRealDynamicSliceOp LogicalResult matchAndRewrite( mhlo::RealDynamicSliceOp real_dynamic_slice_op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { - auto start_indices_type = real_dynamic_slice_op.getStartIndices() - .getType() - .cast(); - auto end_indices_type = real_dynamic_slice_op.getLimitIndices() - .getType() - .cast(); + auto start_indices_type = mlir::cast( + real_dynamic_slice_op.getStartIndices().getType()); + auto end_indices_type = mlir::cast( + real_dynamic_slice_op.getLimitIndices().getType()); if (start_indices_type.getNumDynamicDims() != 0 || end_indices_type.getNumDynamicDims() != 0) { @@ -3522,7 +3520,7 @@ class ConvertDynamicIotaOp : public OpConversionPattern { mhlo::DynamicIotaOp dynamic_iota_op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { RankedTensorType type = - dynamic_iota_op.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(dynamic_iota_op.getType()); if (!type || type.getElementType().isUnsignedInteger(64)) { return rewriter.notifyMatchFailure(dynamic_iota_op, "TF::RangeOp doesn't support UI64"); @@ -3538,19 +3536,19 @@ class ConvertDynamicIotaOp : public OpConversionPattern { const uint64_t dimension = dynamic_iota_op.getIotaDimension(); Type element_type = type.getElementType(); Attribute start, delta; - if (element_type.isa()) { + if (mlir::isa(element_type)) { start = rewriter.getFloatAttr(element_type, 0.0); delta = rewriter.getFloatAttr(element_type, 1.0); - } else if (element_type.isa()) { + } else if (mlir::isa(element_type)) { start = rewriter.getIntegerAttr(element_type, 0); delta = rewriter.getIntegerAttr(element_type, 1); } else { return failure(); } auto output_shape = dynamic_iota_op.getOperand(); - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto cast_type = - output_shape.getType().cast().clone(element_type); + mlir::cast(output_shape.getType()).clone(element_type); output_shape = rewriter.create(dynamic_iota_op.getLoc(), cast_type, output_shape); } @@ -3581,7 +3579,7 @@ bool IsTFStyleBroadcast(DenseIntElementsAttr broadcast_dimensions, // broadcast_dimensions is an increasing list by definition, thus it suffices // to check the first element. int64_t input_rank = broadcast_dimensions.getNumElements(); - int64_t output_rank = output.getType().cast().getRank(); + int64_t output_rank = mlir::cast(output.getType()).getRank(); return input_rank == 0 || (broadcast_dimensions.getValues()[0].getSExtValue() == output_rank - input_rank); @@ -3606,11 +3604,12 @@ arith::ConstantOp ExpandedShape(PatternRewriter& rewriter, Value input, Value output) { // Initialize expanded shape with output rank and dimensions of 1. SmallVector expanded_shape( - output.getType().cast().getRank(), + mlir::cast(output.getType()).getRank(), /*Value=*/rewriter.getI64IntegerAttr(1)); // Set dimension sizes specified by broadcast_dimensions. - ArrayRef input_shape = input.getType().cast().getShape(); + ArrayRef input_shape = + mlir::cast(input.getType()).getShape(); for (auto x : llvm::enumerate(broadcast_dimensions)) { expanded_shape[x.value().getSExtValue()] = rewriter.getI64IntegerAttr(input_shape[x.index()]); @@ -3627,9 +3626,9 @@ arith::ConstantOp ExpandedShape(PatternRewriter& rewriter, Value input, Value ExpandedDynamicShape(PatternRewriter& rewriter, Value input, DenseIntElementsAttr broadcast_dimensions, Value output) { - assert(output.getType().cast() && + assert(mlir::cast(output.getType()) && "output type must be of ShapedType"); - int64_t output_rank = output.getType().cast().getRank(); + int64_t output_rank = mlir::cast(output.getType()).getRank(); llvm::SmallVector expanded_dimensions; llvm::SmallSet broadcast_dimensions_values; for (auto x : llvm::enumerate(broadcast_dimensions)) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.cc index 9d52ee30dd3ce7..520cff8681156a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.cc @@ -37,7 +37,7 @@ LogicalResult ConvertCustomCallOp::matchAndRewrite( rewriter.getStringAttr(mhlo_custom_call.getCallTargetName())); if (auto bc = mhlo_custom_call.getBackendConfig()) { - if (auto stringattr = bc->dyn_cast_or_null()) { + if (auto stringattr = mlir::dyn_cast_or_null(*bc)) { tfl_custom.setCustomOptionAttr( TFL::ConstBytesAttr::get(rewriter.getContext(), stringattr)); } @@ -53,7 +53,7 @@ LogicalResult ConvertCustomCallOp::matchAndRewrite( std::optional IsCustomCallLegal(mhlo::CustomCallOp op) { if (op.getCallTargetName().starts_with("custom_call.")) { auto bc = op.getBackendConfig(); - if (!bc || bc->isa()) { + if (!bc || mlir::isa(*bc)) { return false; } } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc index ef3337cbca27cd..ccd726d2737f84 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/dot_general.cc @@ -169,7 +169,7 @@ Value BuildDotOperandFlattenedShapeOp(Value operand, DotDimensionsInfo dot_dimensions_info, ImplicitLocOpBuilder& builder, bool is_lhs) { - auto operand_type = operand.getType().cast(); + auto operand_type = mlir::cast(operand.getType()); auto operand_shape = builder.create( RankedTensorType::get(static_cast(operand_type.getRank()), builder.getIntegerType(32)), @@ -248,8 +248,8 @@ Value BuildDotOperandFlattenedShapeOp(Value operand, Value ConvertDot(PatternRewriter& rewriter, Value lhs, Value rhs, mhlo::DotDimensionNumbersAttr dot_dimension_numbers, ShapedType result_type, mlir::Location loc) { - auto lhs_type = lhs.getType().cast(); - auto rhs_type = rhs.getType().cast(); + auto lhs_type = mlir::cast(lhs.getType()); + auto rhs_type = mlir::cast(rhs.getType()); const int lhs_rank = lhs_type.getRank(); const int rhs_rank = rhs_type.getRank(); ImplicitLocOpBuilder builder(loc, rewriter); @@ -412,7 +412,7 @@ Value ConvertDot(PatternRewriter& rewriter, Value lhs, Value rhs, // be inserted when necessary. See ConvertDotGeneralOp for additional notes. Value ConvertDotOp(PatternRewriter& rewriter, Operation* old_op) { auto dot_op = cast(old_op); - auto lhs_rank = dot_op.getLhs().getType().cast().getRank(); + auto lhs_rank = mlir::cast(dot_op.getLhs().getType()).getRank(); auto dot_dimension_numbers = mhlo::DotDimensionNumbersAttr::get(rewriter.getContext(), /*lhsBatchingDimensions=*/{}, @@ -422,15 +422,16 @@ Value ConvertDotOp(PatternRewriter& rewriter, Operation* old_op) { /*rhsContractingDimensions=*/{0}); return ConvertDot( rewriter, dot_op.getLhs(), dot_op.getRhs(), dot_dimension_numbers, - dot_op.getResult().getType().cast(), dot_op.getLoc()); + mlir::cast(dot_op.getResult().getType()), dot_op.getLoc()); } Value ConvertDotGeneralOp(PatternRewriter& rewriter, Operation* old_op) { auto dot_general_op = cast(old_op); - return ConvertDot(rewriter, dot_general_op.getLhs(), dot_general_op.getRhs(), - dot_general_op.getDotDimensionNumbers(), - dot_general_op.getResult().getType().cast(), - dot_general_op.getLoc()); + return ConvertDot( + rewriter, dot_general_op.getLhs(), dot_general_op.getRhs(), + dot_general_op.getDotDimensionNumbers(), + mlir::cast(dot_general_op.getResult().getType()), + dot_general_op.getLoc()); } } // namespace odml } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h index bfb705d00813d5..157cb82ce8e94e 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/reduce.h @@ -74,7 +74,7 @@ class ConvertReduceOpToArgMinMax : public OpConversionPattern { if (!MatchIota(reduce_op.getDimensions(), iota)) return failure(); // Match the reduction computation. - const bool is_float = operand_init.getElementType().isa(); + const bool is_float = mlir::isa(operand_init.getElementType()); if (failed(MatchReduceToArgMinMaxType1(reduce_op, is_float, is_argmax)) && failed(MatchReduceToArgMinMaxType2(reduce_op, is_argmax))) return rewriter.notifyMatchFailure( @@ -91,8 +91,8 @@ class ConvertReduceOpToArgMinMax : public OpConversionPattern { // Generate a Max and an ArgMax of as the mhlo op returns both while in TF // we have separate ops for them. If only one of them is used then the other // one will be garbage collected later. - if (!operand.getType().isa()) return failure(); - auto operand_type = operand.getType().cast(); + if (!mlir::isa(operand.getType())) return failure(); + auto operand_type = mlir::cast(operand.getType()); if (operand_type.getElementType().isInteger(1)) { // TF does not support min or max on boolean (int1) arguments. // Use AnyOp for MaxOp and AllOp for MinOp. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h index fb0e0d80a4eb9b..f8ea8227137617 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/scatter.h @@ -95,9 +95,9 @@ class ConvertScatterOp : public OpConversionPattern { OperandRange updates = scatter_op.getUpdates(); if (operands.size() != 1 || updates.size() != 1) return failure(); - ShapedType operand_type = operands[0].getType().cast(); - ShapedType indices_type = indices.getType().cast(); - ShapedType updates_type = updates[0].getType().cast(); + ShapedType operand_type = mlir::cast(operands[0].getType()); + ShapedType indices_type = mlir::cast(indices.getType()); + ShapedType updates_type = mlir::cast(updates[0].getType()); Value new_updates = updates[0]; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc index c2f533776d0408..783f0431e9b964 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/util.cc @@ -203,7 +203,7 @@ Value InsertTranspose(Value value, int batch_dim, int feature_dim, int default_batch_dim, int default_feature_dim, int default_spatial_dim_start, int num_spatial_dims, ConversionPatternRewriter& rewriter) { - auto type = value.getType().cast(); + auto type = mlir::cast(value.getType()); DenseIntElementsAttr permutation; const int spatial_dim_start = spatial_dimensions.front(); if (!NeedsReformatTypeAndPermutation( @@ -224,7 +224,7 @@ Value InsertTranspose(Value value, int batch_dim, int feature_dim, Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) { IntegerType new_ele_type = rewriter.getIntegerType(32); - if (auto shaped_type = val.getType().dyn_cast()) { + if (auto shaped_type = mlir::dyn_cast(val.getType())) { ShapedType new_type = RankedTensorType::get(shaped_type.getShape(), new_ele_type); return rewriter.create(loc, new_type, val); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc index 6e0a3325460b7a..50521a02c7b907 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc @@ -63,13 +63,13 @@ LogicalResult BuildOption(flexbuffers::Builder* fbb, Operation* op, const char* key = pair.getName().data(); const auto attr = pair.getValue(); - if (attr.isa<::mlir::IntegerAttr>()) { - fbb->Int(key, attr.dyn_cast().getInt()); + if (mlir::isa<::mlir::IntegerAttr>(attr)) { + fbb->Int(key, mlir::dyn_cast(attr).getInt()); return success(); } - if (attr.isa<::mlir::FloatAttr>()) { - fbb->Double(key, attr.dyn_cast().getValueAsDouble()); + if (mlir::isa<::mlir::FloatAttr>(attr)) { + fbb->Double(key, mlir::dyn_cast(attr).getValueAsDouble()); return success(); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc index 4cfb0e04e96af4..e699c303bbaac2 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_custom_call_to_composite.cc @@ -45,18 +45,19 @@ struct ReplaceCustomCallWithComposite final LogicalResult matchAndRewrite(mlir::stablehlo::CustomCallOp op, PatternRewriter &rewriter) const override { auto backendConfig = - op->getAttr("composite.backend_config").dyn_cast(); + mlir::dyn_cast(op->getAttr("composite.backend_config")); if (!backendConfig) return op->emitError( "custom_call has no 'composite.backend_config' attribute or the " "attribute is not a dictionary"); - auto name = backendConfig.get("name").dyn_cast(); + auto name = mlir::dyn_cast(backendConfig.get("name")); if (!name) return op->emitError( "backend_config has no 'name' key or the name value is not a string"); - auto attrs = backendConfig.get("attributes").dyn_cast(); + auto attrs = + mlir::dyn_cast(backendConfig.get("attributes")); if (!attrs) return op->emitError( "backend_config has no 'attributes' key or the attributes value is " @@ -66,7 +67,7 @@ struct ReplaceCustomCallWithComposite final if (!calledComputations || calledComputations.size() != 1) return op->emitError("expected exactly one called_computation"); - auto decomposition = calledComputations[0].cast(); + auto decomposition = mlir::cast(calledComputations[0]); auto composite = rewriter.create( op.getLoc(), op.getResultTypes(), op.getOperands(), name.str(), attrs, diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc index 141983965ff7ce..e091e8cb3201c1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc @@ -93,7 +93,7 @@ class StablehloToOdmlTypeConverter : public vhlo::VhloTypeConverter { return attr; if (auto stablehlo_attr = - attr.dyn_cast_or_null()) { + mlir::dyn_cast_or_null(attr)) { return vhlo::TypeExtensionsV1Attr::get(stablehlo_attr.getContext(), stablehlo_attr.getBounds()); } @@ -119,7 +119,8 @@ class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter { } Attribute convertEncoding(Attribute attr) const final { - if (auto vhlo_attr = attr.dyn_cast_or_null()) { + if (auto vhlo_attr = + mlir::dyn_cast_or_null(attr)) { return stablehlo::TypeExtensionsAttr::get(vhlo_attr.getContext(), vhlo_attr.getBounds()); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc index f7a136f2259ad2..82c7a4b4687055 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc @@ -105,8 +105,8 @@ void PrintOpStatsPass::runOnOperation() { (dyn_cast_or_null(op)) ? op->getOperand(1) : op->getResult(0); - ShapedType value_shaped_type = - value_for_deducing_op_type.getType().dyn_cast_or_null(); + ShapedType value_shaped_type = mlir::dyn_cast_or_null( + value_for_deducing_op_type.getType()); if (value_shaped_type != nullptr) { auto operand_or_result = value_shaped_type.getElementType(); std::string dtype; @@ -122,15 +122,16 @@ void PrintOpStatsPass::runOnOperation() { }) .Case([&](Type) { auto uniform_quantized_dtype = - operand_or_result.dyn_cast_or_null() + mlir::dyn_cast_or_null( + operand_or_result) .getStorageType(); dtype = absl::StrCat( "uq_", uniform_quantized_dtype.getIntOrFloatBitWidth()); }) .Case([&](Type) { auto uniform_quantized_dtype = - operand_or_result - .dyn_cast_or_null() + mlir::dyn_cast_or_null( + operand_or_result) .getStorageType(); dtype = absl::StrCat( "uq_", uniform_quantized_dtype.getIntOrFloatBitWidth()); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc index b0797521798994..d9c23dfa12b8ae 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -37,8 +38,8 @@ namespace odml { // Convert mhlo.dot to mhlo.dot_general. LogicalResult ConvertDotToDotGeneral(mhlo::DotOp op, PatternRewriter &rewriter) { - auto lhs_type = op.getLhs().getType().cast(); - auto rhs_type = op.getRhs().getType().cast(); + auto lhs_type = mlir::cast(op.getLhs().getType()); + auto rhs_type = mlir::cast(op.getRhs().getType()); if (!lhs_type.hasRank() || !rhs_type.hasRank()) { return rewriter.notifyMatchFailure(op, "unsupported unranked input type"); } @@ -264,7 +265,7 @@ LogicalResult LiftDotConcatLHS(mhlo::ConcatenateOp concat, new_concat_shape[new_concat_dim] = 0; for (auto v : all_dot_lhs) { new_concat_shape[new_concat_dim] += - v.getType().dyn_cast().getShape()[new_concat_dim]; + mlir::dyn_cast(v.getType()).getShape()[new_concat_dim]; } auto new_concat = rewriter.create( @@ -353,7 +354,7 @@ LogicalResult LiftDotConcatLHSAndRHS(mhlo::ConcatenateOp concat, lhs_new_concat_shape[lhs_batch_dim] = 0; for (auto v : all_dot_lhs) { lhs_new_concat_shape[lhs_batch_dim] += - v.getType().dyn_cast().getShape()[lhs_batch_dim]; + mlir::dyn_cast(v.getType()).getShape()[lhs_batch_dim]; } const int64_t rhs_batch_dim = first_dot.getDotDimensionNumbers().getRhsBatchingDimensions()[0]; @@ -362,7 +363,7 @@ LogicalResult LiftDotConcatLHSAndRHS(mhlo::ConcatenateOp concat, rhs_new_concat_shape[rhs_batch_dim] = 0; for (auto v : all_dot_rhs) { rhs_new_concat_shape[rhs_batch_dim] += - v.getType().dyn_cast().getShape()[rhs_batch_dim]; + mlir::dyn_cast(v.getType()).getShape()[rhs_batch_dim]; } auto lhs_new_concat = rewriter.create( diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc index 8234380f0f4182..81c6fc47473d43 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" namespace mlir { @@ -62,7 +63,7 @@ class RenameEntrypointToMainPass // } // clang-format on for (auto attr : session_initializer.getInitializers()) { - auto sym_attr = attr.dyn_cast(); + auto sym_attr = mlir::dyn_cast(attr); if (!sym_attr) break; entrypoints.erase(sym_attr.getValue()); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc index 4304d34f4743ec..f86b78275fb951 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo.cc @@ -66,7 +66,7 @@ class ConvertReduceOpToTFLiteArgmax auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto value = *attr.value_begin(); return value.isNegative() && value.isInfinity(); } else if (element_type.isInteger(1)) { @@ -90,7 +90,7 @@ class ConvertReduceOpToTFLiteArgmin auto element_type = attr.getType().getElementType(); if (attr.getNumElements() != 1 || !element_type.isIntOrFloat()) return false; - if (element_type.isa()) { + if (mlir::isa(element_type)) { auto value = *attr.value_begin(); return !value.isNegative() && value.isInfinity(); } else if (element_type.isInteger(1)) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfold_splat_constant_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfold_splat_constant_pass.cc index e7f86a022d2274..7a3abd35d0d376 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfold_splat_constant_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfold_splat_constant_pass.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { @@ -60,7 +61,7 @@ class UnfoldSplatConstantPass void UnfoldSplatConstant(mlir::OpBuilder* op_builder, mhlo::ConstantOp const_op) const { auto splat_elements_attr = - const_op.getValue().dyn_cast(); + mlir::dyn_cast(const_op.getValue()); if (!splat_elements_attr) { return; } @@ -68,8 +69,8 @@ class UnfoldSplatConstantPass return; } auto element_type = splat_elements_attr.getType().getElementType(); - if (element_type.isa() || - element_type.isa()) { + if (mlir::isa(element_type) || + mlir::isa(element_type)) { return; } op_builder->setInsertionPoint(const_op); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc index f4cdad00b79774..dadcabc55a5e57 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/unfuse_batch_norm_pass.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h" @@ -60,7 +61,8 @@ Value broadcastToFeatureDim(Location loc, RankedTensorType result_type, // Gets the shape of operand, assuming it is a dynamic shape with static rank. Value getShapeValue(Location loc, Value operand, PatternRewriter &rewriter) { - RankedTensorType resultType = operand.getType().dyn_cast(); + RankedTensorType resultType = + mlir::dyn_cast(operand.getType()); return rewriter.create( loc, RankedTensorType::get(/*shape=*/{resultType.getRank()}, @@ -92,8 +94,8 @@ Value materializeEpsilon(Operation *op, FloatAttr epsilon_attr, } auto scalar_type = RankedTensorType::get(/*shape=*/{}, fp_type); - auto epsilon_tensor_attr = - DenseElementsAttr::get(scalar_type, {epsilon_attr.cast()}); + auto epsilon_tensor_attr = DenseElementsAttr::get( + scalar_type, {mlir::cast(epsilon_attr)}); Value epsilon = b.create(epsilon_tensor_attr); auto dims_type = RankedTensorType::get(/*shape=*/{0}, b.getIntegerType(64)); auto dims = DenseIntElementsAttr::get(dims_type, SmallVector{}); @@ -113,7 +115,7 @@ class UnfuseBatchNormTrainingPattern LogicalResult matchAndRewrite(mhlo::BatchNormTrainingOp bn_op, PatternRewriter &rewriter) const override { auto inputs = bn_op.getOperand(); - auto input_type = inputs.getType().dyn_cast(); + auto input_type = mlir::dyn_cast(inputs.getType()); if (!input_type) { return failure(); } @@ -172,13 +174,14 @@ class UnfuseBatchNormInferencePattern // Enforce type invariants. // Note that we deduce the actual element type from the variance, // which should not be subject to quantization at a higher level. - auto input_type = bn_op.getOperand().getType().dyn_cast(); + auto input_type = + mlir::dyn_cast(bn_op.getOperand().getType()); auto variance_type = - bn_op.getVariance().getType().dyn_cast(); + mlir::dyn_cast(bn_op.getVariance().getType()); if (!input_type || !variance_type) { return failure(); } - auto fp_type = variance_type.getElementType().dyn_cast(); + auto fp_type = mlir::dyn_cast(variance_type.getElementType()); if (!fp_type) { return failure(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/analyze_variables.cc b/tensorflow/compiler/mlir/lite/transforms/analyze_variables.cc index 6fd0278bf909e4..39afd416ab1aa2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/analyze_variables.cc +++ b/tensorflow/compiler/mlir/lite/transforms/analyze_variables.cc @@ -16,6 +16,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -92,7 +93,8 @@ void AnalyzeVariablesPass::runOnOperation() { // Note: this might disable native variables in more than needed cases. // TODO(b/189370197): Enhance variable analysis. for (auto operand : op->getOperands()) { - if (getElementTypeOrSelf(operand.getType()).isa()) { + if (mlir::isa( + getElementTypeOrSelf(operand.getType()))) { legalize_to_tfl = false; return WalkResult::interrupt(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc b/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc index 5329274271c55c..3fcd82ef033938 100644 --- a/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc +++ b/tensorflow/compiler/mlir/lite/transforms/decompose_hybrid_quantization.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -62,20 +63,20 @@ class DequantizeConverter : public OpRewritePattern { bool allTypesFp = true; bool allTypesQuantizedOrInt = true; for (auto operand : op->getOperands()) { - ShapedType type = operand.getType().template dyn_cast(); + ShapedType type = mlir::dyn_cast(operand.getType()); if (!type) continue; - allTypesFp &= !type.getElementType().isa(); + allTypesFp &= !mlir::isa(type.getElementType()); allTypesQuantizedOrInt &= - (type.getElementType().isa() || - type.getElementType().isa()); + (mlir::isa(type.getElementType()) || + mlir::isa(type.getElementType())); } for (auto result : op->getResults()) { - ShapedType type = result.getType().template cast(); - allTypesFp &= !type.getElementType().isa(); + ShapedType type = mlir::cast(result.getType()); + allTypesFp &= !mlir::isa(type.getElementType()); allTypesQuantizedOrInt &= - (type.getElementType().isa() || - type.getElementType().isa()); + (mlir::isa(type.getElementType()) || + mlir::isa(type.getElementType())); } // If all quantized or floating point then types are consistent. diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index 2f015e61d58fe6..94ed4b1e0340a5 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -152,7 +153,7 @@ void DefaultQuantParamsPass::AddToWorkListIfUnquantized( Value value, std::vector *values) { // If the result isn't with float type, this result is an integer tensor and // doesn't require quantization. - auto tensor_type = value.getType().dyn_cast(); + auto tensor_type = mlir::dyn_cast(value.getType()); if (!tensor_type) { // There are none type values. return; @@ -202,9 +203,9 @@ quant::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias( for (int non_bias : non_biases) { Operation *non_bias_define = op->getOperand(non_bias).getDefiningOp(); if (auto dequant = llvm::dyn_cast(non_bias_define)) { - auto non_bias_type = dequant.getInput().getType().cast(); + auto non_bias_type = mlir::cast(dequant.getInput().getType()); auto non_bias_ele_type = - non_bias_type.getElementType().cast(); + mlir::cast(non_bias_type.getElementType()); non_bias_types.push_back(non_bias_ele_type); } else { // The non-bias hasn't been quantized, let's skip this bias. diff --git a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc index 8a3abc94e2af57..5cac14867482bb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc +++ b/tensorflow/compiler/mlir/lite/transforms/dense_to_sparse.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/lite/kernels/internal/utils/sparsity_format_converter.h" @@ -92,13 +93,13 @@ float CalculateRandomSparsity(const ElementsAttr& attr, int num_elements = type.getNumElements(); int num_zeros = 0; - if (type.getElementType().isa()) { + if (mlir::isa(type.getElementType())) { for (const auto val : attr.getValues()) { if (val.isZero()) { num_zeros++; } } - } else if (type.getElementType().isa()) { + } else if (mlir::isa(type.getElementType())) { for (const auto val : attr.getValues()) { if (val == 0) { num_zeros++; @@ -144,7 +145,7 @@ float CalculateBlockSparsity(const ElementsAttr& attr, const ShapedType& type, sparsity = GetSparsity(type.getNumElements() - format_converter.GetData().size(), type.getNumElements()); - } else if (type.getElementType().isa()) { + } else if (mlir::isa(type.getElementType())) { tflite::internal::sparsity::FormatConverter format_converter( shape, traversal_order, format, b_size, b_map); std::vector data; @@ -179,10 +180,10 @@ InspectResult InspectWeight( InspectResult result = {}; if (auto cst = dyn_cast(inst)) { attr = cst.getValue(); - type = cst.getType().cast(); + type = mlir::cast(cst.getType()); } else if (auto cst = dyn_cast(inst)) { attr = cst.getValue(); - type = cst.getType().cast(); + type = mlir::cast(cst.getType()); } else { result.can_compress = false; return result; @@ -229,10 +230,10 @@ std::vector BuildSparsityParameterAttribute( ShapedType type; if (auto cst = dyn_cast(inst)) { attr = cst.getValue(); - type = cst.getType().cast(); + type = mlir::cast(cst.getType()); } else if (auto cst = dyn_cast(inst)) { attr = cst.getValue(); - type = cst.getType().cast(); + type = mlir::cast(cst.getType()); } else { assert(false && "Expected a constant-like op"); } @@ -317,10 +318,10 @@ void DenseToSparsePass::runOnOperation() { float ratio_threshold = kBlockOverRandomSparsityRatio; if (isa(inst)) { supported_block_size = sparse_op.GetFloatBlockSize(); - type = dyn_cast(inst).getType().cast(); + type = mlir::cast(dyn_cast(inst).getType()); } else if (isa(inst)) { supported_block_size = sparse_op.GetQuantizedBlockSize(); - type = dyn_cast(inst).getType().cast(); + type = mlir::cast(dyn_cast(inst).getType()); ratio_threshold = kBlockOverRandomSparsityRatioQuant; } else { continue; @@ -341,7 +342,7 @@ void DenseToSparsePass::runOnOperation() { SparsityParameterAttr s_param; if (auto cst = dyn_cast(inst)) { auto attr = cst.getValue(); - auto type = cst.getType().cast(); + auto type = mlir::cast(cst.getType()); if (type.getElementType().isF32()) { std::vector dense_data; dense_data.reserve(type.getNumElements()); @@ -385,7 +386,7 @@ void DenseToSparsePass::runOnOperation() { } } else if (auto cst = dyn_cast(inst)) { auto attr = cst.getValue(); - auto type = cst.getType().cast(); + auto type = mlir::cast(cst.getType()); std::vector dense_data; dense_data.reserve(type.getNumElements()); for (const auto& val : attr.getValues()) diff --git a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h index 51068fcf4ac67c..fe8bb7d2ca177f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h +++ b/tensorflow/compiler/mlir/lite/transforms/dilated_conv.h @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -110,7 +111,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( } // Allow dynamic width and height dimensions only. - auto result_ty = op.getResult().getType().template cast(); + auto result_ty = mlir::cast(op.getResult().getType()); if (!result_ty.hasRank() || result_ty.getRank() != 4 || result_ty.isDynamicDim(0) || result_ty.isDynamicDim(3)) { return rewriter.notifyMatchFailure( @@ -187,8 +188,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( // Make sure that the axis in `expand_op` is constant. if (auto const_op = llvm::dyn_cast(expand_op.getDim().getDefiningOp())) { - expand_axis = (*const_op.getValue() - .cast() + expand_axis = (*mlir::cast(const_op.getValue()) .getValues() .begin()) .getSExtValue(); @@ -208,7 +208,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( return rewriter.notifyMatchFailure( squeeze_op, "squeeze dims should have exactly 1 dimension specified"); } - int64_t squeeze_axis = squeeze_dims[0].cast().getInt(); + int64_t squeeze_axis = mlir::cast(squeeze_dims[0]).getInt(); if (squeeze_axis < 0) { // Always squeeze 4D input to 3D input. squeeze_axis += 4; @@ -318,7 +318,8 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( } if (expand_op) { - if (stb_op.getInput().getType().dyn_cast() == nullptr) { + if (mlir::dyn_cast(stb_op.getInput().getType()) == + nullptr) { return rewriter.notifyMatchFailure( stb_op, "SpaceToBatchND op's input should have RankedTensorType"); } @@ -401,7 +402,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( expand_op.setOperand(0, stb_op.getInput()); // Calculate the shape for expand. auto input_shape = - stb_op.getInput().getType().cast().getShape(); + mlir::cast(stb_op.getInput().getType()).getShape(); SmallVector expand_shape(input_shape.begin(), input_shape.end()); expand_shape.insert(expand_shape.begin() + expand_axis, 1); @@ -412,7 +413,7 @@ LogicalResult ConvertTFDilatedConvOp::matchAndRewrite( // Update the conv op's output shape. auto bts_output_shape = - bts_op.getOutput().getType().cast().getShape(); + mlir::cast(bts_op.getOutput().getType()).getShape(); SmallVector conv_result_shape(bts_output_shape.begin(), bts_output_shape.end()); conv_result_shape.insert(conv_result_shape.begin() + expand_axis, 1); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_hashtables.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_hashtables.cc index 252e18e191aea4..5e88048d775532 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_hashtables.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_hashtables.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -142,10 +143,12 @@ bool checkWhetherGraphHasValidStaticLookupTables(ModuleOp module) { // Only allow string -> int64 and int64 -> string mappings due to kernel // capability. - if (!((key_dtype.isa() && value_dtype.isa() && - value_dtype.cast().getWidth() == 64) || - (value_dtype.isa() && key_dtype.isa() && - key_dtype.cast().getWidth() == 64))) { + if (!((mlir::isa(key_dtype) && + mlir::isa(value_dtype) && + mlir::cast(value_dtype).getWidth() == 64) || + (mlir::isa(value_dtype) && + mlir::isa(key_dtype) && + mlir::cast(key_dtype).getWidth() == 64))) { return false; } diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc index e8bae6eb64280f..9b0a80a4f92a71 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_jax_random.cc @@ -84,10 +84,10 @@ void LegalizeJaxRandomPass::runOnOperation() { auto func = getOperation(); if (!IsJaxRandomUniform(func) && !IsJaxRandomNormal(func)) return; auto result_tuple_ty = - func.getFunctionType().getResult(0).dyn_cast_or_null(); + mlir::dyn_cast_or_null(func.getFunctionType().getResult(0)); if (!result_tuple_ty) return; if (result_tuple_ty.size() != 1) return; - auto result_ty = result_tuple_ty.getType(0).dyn_cast(); + auto result_ty = mlir::dyn_cast(result_tuple_ty.getType(0)); func.eraseBody(); func.addEntryBlock(); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 38a8bffd87bb03..b18227bdddc1ee 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -96,7 +96,7 @@ class LegalizeTFPass : public impl::LegalizeTFPassBase { // Util that casts 'val' to Int32 by adding a cast Op. Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) { IntegerType new_ele_type = rewriter.getIntegerType(32); - if (auto shaped_type = val.getType().dyn_cast()) { + if (auto shaped_type = mlir::dyn_cast(val.getType())) { ShapedType new_type = RankedTensorType::get(shaped_type.getShape(), new_ele_type); return rewriter.createOrFold(loc, new_type, val, @@ -114,7 +114,7 @@ Value CreateCastToInt32(Value val, Location loc, PatternRewriter& rewriter) { // 2. In the default case, cast the `Value` to an int32_t. Value CreateInt32ConstOrCast(Value val, Location loc, PatternRewriter& rewriter) { - if (val.getType().cast().hasStaticShape()) { + if (mlir::cast(val.getType()).hasStaticShape()) { DenseElementsAttr shape_value_attr; if (matchPattern(val, m_Constant(&shape_value_attr))) { SmallVector new_shape_array_i32; @@ -137,7 +137,7 @@ Value CreateInt32ConstOrCast(Value val, Location loc, // Get shape of an operand or result, support both dynamic and static shape. Value GetShape(Value input, Location loc, PatternRewriter& rewriter) { - auto shaped_type = input.getType().cast(); + auto shaped_type = mlir::cast(input.getType()); if (shaped_type.hasStaticShape()) { auto static_shape = shaped_type.getShape(); auto static_shape_type = @@ -271,7 +271,7 @@ bool ConvertTFBatchMatMulOp2TFLFullyConnectedOp(Operation* bmm_op, // Create a tfl.transpose op that performs ZX transpose on `input`. auto create_z_x_transpose_op = [&](Value input) -> Value { - RankedTensorType input_type = input.getType().cast(); + RankedTensorType input_type = mlir::cast(input.getType()); const int input_rank = input_type.getRank(); // Create a 1D I32 tensor for representing the dimension permutation. @@ -364,7 +364,7 @@ LogicalResult ConvertTFMatMulOp::matchAndRewrite( auto rhs = op->getOperand(1); auto transpose = [&](Value input) -> std::pair { RankedTensorType type = - input.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(input.getType()); if (!type || type.getRank() != 2) return {failure(), nullptr}; auto permute_attr = DenseIntElementsAttr::get( @@ -583,15 +583,15 @@ bool ConvertTFMatrixDiagV2orV3(Operation* op, PatternRewriter* rewriter) { // Verify padding_value is a tensor with all 0s. mlir::Value padding_value = tf_matrix_diag_v2_or_v3_op.getPaddingValue(); mlir::Type element_type = - padding_value.getType().cast().getElementType(); - if (element_type.isa()) { + mlir::cast(padding_value.getType()).getElementType(); + if (mlir::isa(element_type)) { DenseFPElementsAttr padding_attr; if (!matchPattern(padding_value, m_Constant(&padding_attr)) || !padding_attr.isSplat() || !padding_attr.getSplatValue().isZero()) { return false; } - } else if (element_type.isa()) { + } else if (mlir::isa(element_type)) { DenseIntElementsAttr padding_attr; if (!matchPattern(padding_value, m_Constant(&padding_attr)) || !padding_attr.isSplat() || @@ -642,7 +642,7 @@ struct LegalizeUnidirectionalSequenceLstm : public RewritePattern { SmallVector tflite_indices; for (auto index_attr : tflite_indices_attr.getValue()) { - IntegerAttr index = index_attr.cast(); + IntegerAttr index = mlir::cast(index_attr); tflite_indices.push_back(index.getInt()); } @@ -773,13 +773,13 @@ class ApplyExplicitBroadcasting : public OpRewritePattern { SmallVector symbolic_broadcast_shape; // Matches fail when lhs or rhs is unranked tensor. // TODO(b/176202543): Support unranked tensor. - if (!lhs.getType().cast().hasRank() || - !rhs.getType().cast().hasRank()) { + if (!mlir::cast(lhs.getType()).hasRank() || + !mlir::cast(rhs.getType()).hasRank()) { return failure(); } if (!OpTrait::util::getBroadcastedShape( - lhs.getType().cast().getShape(), - rhs.getType().cast().getShape(), + mlir::cast(lhs.getType()).getShape(), + mlir::cast(rhs.getType()).getShape(), symbolic_broadcast_shape)) { return failure(); } @@ -824,13 +824,13 @@ class ApplyExplicitBroadcasting : public OpRewritePattern { auto lhs = op->getOperand(0); auto rhs = op->getOperand(1); - if (!lhs.getType().cast().hasStaticShape() || - !rhs.getType().cast().hasStaticShape()) { + if (!mlir::cast(lhs.getType()).hasStaticShape() || + !mlir::cast(rhs.getType()).hasStaticShape()) { return rewriteOpWithDynamicInput(op, rewriter); } - auto lhs_shape = lhs.getType().cast().getShape(); - auto rhs_shape = rhs.getType().cast().getShape(); + auto lhs_shape = mlir::cast(lhs.getType()).getShape(); + auto rhs_shape = mlir::cast(rhs.getType()).getShape(); if (lhs_shape == rhs_shape) { return failure(); @@ -892,23 +892,23 @@ class ApplyExplicitBroadcasting // Matches fail when lhs|rhs|cond is unranked tensor. // TODO(b/176202543): Support unranked tensor. - if (!lhs.getType().cast().hasRank() || - !rhs.getType().cast().hasRank() || - !cond.getType().cast().hasRank()) { + if (!mlir::cast(lhs.getType()).hasRank() || + !mlir::cast(rhs.getType()).hasRank() || + !mlir::cast(cond.getType()).hasRank()) { return failure(); } // Calculates symbolic broadcast shape that is only used in types. SmallVector symbolic_broadcast_lhs_rhs_shape; if (!OpTrait::util::getBroadcastedShape( - lhs.getType().cast().getShape(), - rhs.getType().cast().getShape(), + mlir::cast(lhs.getType()).getShape(), + mlir::cast(rhs.getType()).getShape(), symbolic_broadcast_lhs_rhs_shape)) { return failure(); } SmallVector symbolic_broadcast_shape; if (!OpTrait::util::getBroadcastedShape( - cond.getType().cast().getShape(), + mlir::cast(cond.getType()).getShape(), symbolic_broadcast_lhs_rhs_shape, symbolic_broadcast_shape)) { return failure(); } @@ -964,15 +964,15 @@ class ApplyExplicitBroadcasting auto rhs = op->getOperand(2); // Should have static shapes to calculate the broadcasted shape. - if (!lhs.getType().cast().hasStaticShape() || - !rhs.getType().cast().hasStaticShape() || - !cond.getType().cast().hasStaticShape()) { + if (!mlir::cast(lhs.getType()).hasStaticShape() || + !mlir::cast(rhs.getType()).hasStaticShape() || + !mlir::cast(cond.getType()).hasStaticShape()) { return rewriteOpWithDynamicInput(op, rewriter); } - auto lhs_shape = lhs.getType().cast().getShape(); - auto rhs_shape = rhs.getType().cast().getShape(); - auto cond_shape = cond.getType().cast().getShape(); + auto lhs_shape = mlir::cast(lhs.getType()).getShape(); + auto rhs_shape = mlir::cast(rhs.getType()).getShape(); + auto cond_shape = mlir::cast(cond.getType()).getShape(); if (lhs_shape == rhs_shape && cond_shape == lhs_shape) { return failure(); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_variables.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_variables.cc index 7098b2f75157da..7742ea06976c00 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_variables.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_variables.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -67,7 +68,7 @@ class LegalizeVariablesPass // If TFLite variable legalization is not allowed, then we skip this pass. if (auto legalize_tfl_variables_attr = module->getAttr(kLegalizeTflVariables)) { - if (!legalize_tfl_variables_attr.cast().getValue()) return; + if (!mlir::cast(legalize_tfl_variables_attr).getValue()) return; } RewritePatternSet patterns(&getContext()); diff --git a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc index e212ce16ee6ccd..747e96d40b6850 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lift_tflite_flex_ops.cc @@ -117,7 +117,7 @@ class LiftFlexCustomOp : public OpRewritePattern { // TODO(b/146131919): correct handling of resource type if (auto tensor_array_v3_op = dyn_cast(tf_op)) { Value handle = tensor_array_v3_op.getHandle(); - auto handle_type = handle.getType().cast(); + auto handle_type = mlir::cast(handle.getType()); if (handle_type.getElementType().isInteger(/*width=*/32)) { Type resource_tensor_type = handle_type.clone(TF::ResourceType::get(rewriter.getContext())); @@ -225,8 +225,8 @@ class LiftFlexCustomOp : public OpRewritePattern { return emitError(loc, mlir_attr.status().message()); } if (absl::StrContains(op_name, "Dataset") && - mlir_attr->isa()) { - mlir_attr = mlir_attr->cast().getName(); + mlir::isa(*mlir_attr)) { + mlir_attr = mlir::cast(*mlir_attr).getName(); } attributes.push_back(builder.getNamedAttr(attr_name, *mlir_attr)); } diff --git a/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc b/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc index a8adac41229277..7fea1e395ea209 100644 --- a/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc +++ b/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc @@ -94,7 +94,7 @@ LogicalResult ModifyIONodesPass::SetupInputOutputTypesIfNull( LogicalResult ModifyIONodesPass::ModifyInputNodes( func::FuncOp func, llvm::SmallVectorImpl& new_input_types, OpBuilder builder) { - if (input_type.isa()) { + if (mlir::isa(input_type)) { return success(); } @@ -151,7 +151,7 @@ LogicalResult ModifyIONodesPass::ModifyOutputNodes( auto* terminator = block.getTerminator(); builder.setInsertionPoint(terminator); - if (output_type.isa()) { + if (mlir::isa(output_type)) { return success(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index 1fc84007a64cce..0ecb04f82b5952 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -81,7 +81,7 @@ constexpr char kRelu6[] = "RELU6"; constexpr char kRelu1[] = "RELU_N1_TO_1"; ElementsAttr FlattenTo1D(Attribute a) { - auto elements = a.cast(); + auto elements = mlir::cast(a); const std::array flattened_shape = {elements.getNumElements()}; auto new_type = RankedTensorType::get(flattened_shape, elements.getType().getElementType()); @@ -91,8 +91,8 @@ ElementsAttr FlattenTo1D(Attribute a) { // This assumes that the bias is of shape NxCx1x1 and doesn't require transpose // Its corresponding constraint is optimize_patterns.td:IsBiasShape() ElementsAttr ReshapeNCHWBiasToNHWC(Value v, Attribute a) { - auto elements = a.cast(); - auto shape = v.getType().cast().getShape(); + auto elements = mlir::cast(a); + auto shape = mlir::cast(v.getType()).getShape(); if (shape.size() != 4 || shape[2] != 1 || shape[3] != 1) return elements; const std::array new_shape = {shape[0], shape[2], shape[3], shape[1]}; @@ -105,15 +105,16 @@ bool L2NormalizeReduceAxis(Value sq_op, DenseElementsAttr axis) { if (axis.getNumElements() == 0) { return false; } - if (sq_op.getType().cast().getRank() - 1 == + if (mlir::cast(sq_op.getType()).getRank() - 1 == *axis.getValues().begin() || *axis.getValues().begin() == -1) { return true; } - if (sq_op.getType().cast().getRank() != axis.getNumElements()) { + if (mlir::cast(sq_op.getType()).getRank() != + axis.getNumElements()) { return false; } - auto shape = sq_op.getType().cast(); + auto shape = mlir::cast(sq_op.getType()); SmallVector elems{axis.getValues().begin(), axis.getValues().end()}; for (int i = 0; i < shape.getRank(); ++i) { @@ -144,9 +145,10 @@ class OptimizePass : public impl::OptimizePassBase { // is equal to the non-contracting dimension after a reshape bool BroadcastDimsProductEqual(Value input, Value output, size_t agg_start_idx) { - ArrayRef input_shape = input.getType().cast().getShape(); + ArrayRef input_shape = + mlir::cast(input.getType()).getShape(); ArrayRef output_shape = - output.getType().cast().getShape(); + mlir::cast(output.getType()).getShape(); int64_t agg_value = 1; for (size_t i = agg_start_idx; i < input_shape.size() - 1; ++i) { @@ -166,7 +168,7 @@ bool IsBroadcastableElementsAttrAndType(Type a, Type b) { // broadcast-compatible with `b`. bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) { Type output_element_type = - expected_output.cast().getElementType(); + mlir::cast(expected_output).getElementType(); Type broadcasted_type = OpTrait::util::getBroadcastedType(a, b, output_element_type); return broadcasted_type != Type() && broadcasted_type == expected_output; @@ -175,8 +177,8 @@ bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) { // Returns whether if `type1` dimensions are the same as the ending dimensions // of `type2`. This is more restricted than broadcastable. bool IsTailOfShape(Type type1, Type type2) { - auto tail_type = type1.dyn_cast(); - auto full_type = type2.dyn_cast(); + auto tail_type = mlir::dyn_cast(type1); + auto full_type = mlir::dyn_cast(type2); if (!tail_type || !full_type || !tail_type.hasRank() || !full_type.hasRank() || tail_type.getRank() > full_type.getRank()) return false; @@ -189,8 +191,8 @@ bool IsTailOfShape(Type type1, Type type2) { // the reduced `type1` dimensions are the same as the ending dimensions // of `type2`. bool IsReducedTailOfShape(Type type1, Type type2) { - auto tail_type = type1.dyn_cast(); - auto full_type = type2.dyn_cast(); + auto tail_type = mlir::dyn_cast(type1); + auto full_type = mlir::dyn_cast(type2); if (!tail_type || !full_type || !tail_type.hasRank() || !full_type.hasRank()) return false; @@ -211,10 +213,10 @@ bool IsReducedTailOfShape(Type type1, Type type2) { // elements in type2. This is a required condition to flatten type2 to form a // 1D array and allow the binaryOp handle the broadcasting implicitly. bool IsLastDimEqualToNumElements(Type type1, Type type2) { - return (type1.cast().getRank() >= 1 && - type1.cast().getDimSize( - type1.cast().getRank() - 1) == - type2.cast().getNumElements()); + return (mlir::cast(type1).getRank() >= 1 && + mlir::cast(type1).getDimSize( + mlir::cast(type1).getRank() - 1) == + mlir::cast(type2).getNumElements()); } bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef filter_shape, @@ -249,20 +251,21 @@ bool CanFuseConvOrDepthwiseConvShapes(const ArrayRef filter_shape, bool CanFuseConvOrDepthwiseConv(Value filter, Attribute val, bool is_depthwise) { - const auto elements = val.dyn_cast(); + const auto elements = mlir::dyn_cast(val); if (!elements) { return false; } const auto elements_shape = elements.getType().getShape(); - const auto filter_shape = filter.getType().cast().getShape(); + const auto filter_shape = mlir::cast(filter.getType()).getShape(); return CanFuseConvOrDepthwiseConvShapes(filter_shape, elements_shape, is_depthwise); } bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val, bool is_depthwise) { - if (const auto elements = val.dyn_cast()) { - if (const auto filter_elements = filter.dyn_cast()) { + if (const auto elements = mlir::dyn_cast(val)) { + if (const auto filter_elements = + mlir::dyn_cast(filter)) { return CanFuseConvOrDepthwiseConvShapes( filter_elements.getType().getShape(), elements.getType().getShape(), is_depthwise); @@ -277,8 +280,8 @@ bool CanFuseConvOrDepthwiseConv(Attribute filter, Attribute val, bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params, DenseIntElementsAttr indices, Type output_type) { - auto params_type = params.getType().dyn_cast(); - auto indices_type = indices.getType().dyn_cast(); + auto params_type = mlir::dyn_cast(params.getType()); + auto indices_type = mlir::dyn_cast(indices.getType()); // Checks the shape of `params` is [n, ...], shape of `indices` is [n, 1]. 2D // `indices` means it gets the first row of `params`. As long as indices // iterate the first row of `params`, the output is identical to input. @@ -306,8 +309,8 @@ bool CanOptimizeIdentityGatherNdOrScatterNdOp(Value params, // for each dim i, the output tensor is identical to `input`. bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) { // Checks if `begin` and `size` are i32 or i64. - auto begin_attr = begin.dyn_cast(); - auto size_attr = size.dyn_cast(); + auto begin_attr = mlir::dyn_cast(begin); + auto size_attr = mlir::dyn_cast(size); if (!begin_attr || !size_attr) { return false; } @@ -323,7 +326,7 @@ bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) { // Checks if `input` is ranked and its rank is equal to number of elements in // `begin` and `size`. - auto input_ty = input.getType().cast(); + auto input_ty = mlir::cast(input.getType()); if (!input_ty.hasRank()) { return false; } @@ -348,7 +351,7 @@ bool CanOptimizeIdentitySliceOp(Value input, Attribute begin, Attribute size) { // Expand Attribute 'a' to 4D with all 1s except 1 dimension. // Which dimension depends on 'is_depthwise' is true or false. ElementsAttr ExpandTo4DForConvImpl(Attribute a, bool is_depthwise) { - auto elements = a.dyn_cast(); + auto elements = mlir::dyn_cast(a); auto shape = elements.getType().getShape(); if (!shape.empty()) { // Checks that elements are essentially 1d. @@ -410,13 +413,14 @@ DenseElementsAttr RemapPermutation(Value permutation1, Value permutation2) { static bool ShapeMatchesReduceWithKeepAxes(Value input, const mlir::Attribute &axes, const mlir::Attribute &shape) { - RankedTensorType type = input.getType().dyn_cast_or_null(); + RankedTensorType type = + mlir::dyn_cast_or_null(input.getType()); if (!type) return false; DenseIntElementsAttr axes_attr = - axes.dyn_cast_or_null(); + mlir::dyn_cast_or_null(axes); DenseIntElementsAttr shape_attr = - shape.dyn_cast_or_null(); + mlir::dyn_cast_or_null(shape); if (!axes_attr || !shape_attr) return false; if (shape_attr.getNumElements() != type.getRank()) return false; @@ -441,12 +445,12 @@ static bool ShapeMatchesReduceWithKeepAxes(Value input, static bool AreInputDimensionsOneInAxes(Value input, const mlir::Attribute &axes) { RankedTensorType input_type = - input.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(input.getType()); if (!input_type) return false; auto type_shape = input_type.getShape(); DenseIntElementsAttr axes_attr = - axes.dyn_cast_or_null(); + mlir::dyn_cast_or_null(axes); if (!axes_attr) return false; for (auto a : axes_attr.getValues()) { @@ -467,7 +471,7 @@ static bool AreInputDimensionsOneInAxes(Value input, } static bool FloatValueEquals(const Attribute &attr, double value) { - auto fp_attr = attr.dyn_cast_or_null(); + auto fp_attr = mlir::dyn_cast_or_null(attr); if (!fp_attr) return false; if (fp_attr.isSplat()) { @@ -482,12 +486,12 @@ static bool FloatValueEquals(const Attribute &attr, double value) { // to `raw_value`. template bool IsConstantValueOf(mlir::TypedAttr value, T raw_value) { - auto element_type = value.getType().cast().getElementType(); + auto element_type = mlir::cast(value.getType()).getElementType(); - if (element_type.isa()) { + if (mlir::isa(element_type)) { return FloatValueEquals(value, raw_value); - } else if (element_type.isa()) { - auto int_attr = value.dyn_cast_or_null(); + } else if (mlir::isa(element_type)) { + auto int_attr = mlir::dyn_cast_or_null(value); if (!int_attr) return false; if (int_attr.isSplat()) { @@ -502,13 +506,13 @@ bool IsConstantValueOf(mlir::TypedAttr value, T raw_value) { // Returns true if the value's element type is F32. bool IsF32Value(Value value) { - return value.getType().cast().getElementType().isF32(); + return mlir::cast(value.getType()).getElementType().isF32(); } // Returns the number of elements in attr if it is a static shape, 1 otherwise, // as an unranked int32 Attribute. TypedAttr GetNumElementsOrOne(Type type) { - auto shaped_type = type.cast(); + auto shaped_type = mlir::cast(type); int32_t num_elements = shaped_type.hasStaticShape() ? shaped_type.getNumElements() : 1; @@ -523,7 +527,7 @@ TypedAttr GetNumElementsOrOne(Type type) { Value ReshapeValueDroppingLastDim(OpBuilder &builder, Value value) { // This function is always guarded with HasTrivialShapeExceptSecondLastDim(), // so we could cast safely here. - auto type = value.getType().cast(); + auto type = mlir::cast(value.getType()); SmallVector new_shape; if (type.hasStaticShape()) { for (int64_t dim : type.getShape().drop_back()) { @@ -543,7 +547,7 @@ Value ReshapeValueDroppingLastDim(OpBuilder &builder, Value value) { // Returns true if val has a static shape and the last dimension equals 1. bool IsLastDimensionEqualOne(Value val) { - const auto val_type = val.getType().cast(); + const auto val_type = mlir::cast(val.getType()); if (!val_type.hasStaticShape()) return false; const auto val_shape = val_type.getShape(); if (val_shape.empty()) return false; @@ -577,7 +581,7 @@ bool HasOneUseOrUsedByOnlyBinaryOps(Value out_value) { // // If such a value is used in an Equal operator, it can be replaced with OneHot. bool IsOneHotIndexAttribute(Attribute attr) { - const auto dense_attr = attr.dyn_cast_or_null(); + const auto dense_attr = mlir::dyn_cast_or_null(attr); if (!dense_attr) { return false; } @@ -602,7 +606,7 @@ bool IsOneHotIndexAttribute(Attribute attr) { } Value Get1DShapeValue(OpBuilder &builder, Value value) { - auto type = value.getType().cast(); + auto type = mlir::cast(value.getType()); if (!type.hasStaticShape()) { return nullptr; } @@ -614,11 +618,11 @@ Value Get1DShapeValue(OpBuilder &builder, Value value) { } Type GetEmbeddingLookupShape(Value lookup, Value value) { - auto lookup_type = lookup.getType().cast(); + auto lookup_type = mlir::cast(lookup.getType()); if (!lookup_type.hasStaticShape()) { return nullptr; } - auto value_type = value.getType().cast(); + auto value_type = mlir::cast(value.getType()); if (!value_type.hasStaticShape() || value_type.getRank() != 2) { return nullptr; } @@ -665,7 +669,7 @@ bool IsF32Splat(Attribute input_splat) { // Attribute holding a single value of float type. If attr has no elements, the // result is 0.0f. TypedAttr ConvertSingleElementAttrToFloatAttr(Attribute attr) { - const auto dense_fp_attr = attr.dyn_cast_or_null(); + const auto dense_fp_attr = mlir::dyn_cast_or_null(attr); if (dense_fp_attr) { // Already float => return return dense_fp_attr; @@ -673,7 +677,7 @@ TypedAttr ConvertSingleElementAttrToFloatAttr(Attribute attr) { OpBuilder builder(attr.getContext()); - const auto dense_int_attr = attr.dyn_cast(); + const auto dense_int_attr = mlir::dyn_cast(attr); const auto int_values = dense_int_attr.getValues(); float float_val = 0.0f; if (!int_values.empty()) { @@ -793,9 +797,7 @@ struct SqueezeReshapesAroundBroadcastOp // Pattern is applied only if the broadcast_to shape has more than 5 // dimensions. - if (tfl_broadcast_to_op.getShape() - .getType() - .cast() + if (mlir::cast(tfl_broadcast_to_op.getShape().getType()) .getNumElements() < 6) { return rewriter.notifyMatchFailure(loc, "Not supported broadcast_to shape"); @@ -831,7 +833,7 @@ struct SqueezeReshapesAroundBroadcastOp // Calculate the number of extra leading and trailing 1s in the // broadcast_op output. auto broadcast_output_shapetype = - tfl_broadcast_to_op.getOutput().getType().cast(); + mlir::cast(tfl_broadcast_to_op.getOutput().getType()); int num_leading_broadcast_dims = GetNumLeadingOnes(broadcast_output_shapetype); int num_trailing_broadcast_dims = @@ -839,9 +841,7 @@ struct SqueezeReshapesAroundBroadcastOp // Get the new shape for the inner reshape_op after removing the extra 1s. llvm::SmallVector new_reshape_shape_i32{ - inner_reshape_op.getOutput() - .getType() - .cast() + mlir::cast(inner_reshape_op.getOutput().getType()) .getShape() .drop_back(num_trailing_broadcast_dims) .drop_front(num_leading_broadcast_dims)}; @@ -886,11 +886,11 @@ struct ConvertTFLBroadcastToMulOp LogicalResult matchAndRewrite(TFL::BroadcastToOp tfl_broadcast_to_op, PatternRewriter &rewriter) const override { auto input_type = - tfl_broadcast_to_op.getInput().getType().cast(); + mlir::cast(tfl_broadcast_to_op.getInput().getType()); auto output_type = - tfl_broadcast_to_op.getOutput().getType().cast(); + mlir::cast(tfl_broadcast_to_op.getOutput().getType()); auto shape_type = - tfl_broadcast_to_op.getShape().getType().cast(); + mlir::cast(tfl_broadcast_to_op.getShape().getType()); Type element_type = input_type.getElementType(); auto loc = tfl_broadcast_to_op->getLoc(); @@ -909,7 +909,7 @@ struct ConvertTFLBroadcastToMulOp // Allow lowering when the input's elements type is F32, BFloat16, I32 or // I16. - if (!(element_type.isa() || + if (!(mlir::isa(element_type) || element_type.isInteger(32) || element_type.isInteger(16))) return rewriter.notifyMatchFailure(loc, "element_type_not_supported"); @@ -986,7 +986,7 @@ struct FuseAddAndStridedSlice : public OpRewritePattern { return failure(); mlir::TensorType constant_val_type = - constant_val.getType().cast(); + mlir::cast(constant_val.getType()); // If it's not 1D or 0D (which can be broadcasted to 1D), reject the // matching. if (constant_val_type.getRank() > 1) { @@ -994,14 +994,14 @@ struct FuseAddAndStridedSlice : public OpRewritePattern { } mlir::RankedTensorType end_type = - strided_slice_op.getEnd().getType().dyn_cast(); + mlir::dyn_cast(strided_slice_op.getEnd().getType()); // begin, end and strides are Rank 1 tensors with one element per dimension // of input. int64_t num_dims = end_type.getShape()[0]; DenseElementsAttr new_added_value = added_value.reshape(RankedTensorType::get( {num_dims}, - added_value.getType().cast().getElementType())); + mlir::cast(added_value.getType()).getElementType())); ::mlir::arith::ConstantOp new_end = rewriter.create( strided_slice_op.getEnd().getLoc(), new_added_value); @@ -1183,7 +1183,7 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { add_op.getLhs().getDefiningOp()); if (!fc_op) return failure(); - auto constant_val_type = constant_val.getType().cast(); + auto constant_val_type = mlir::cast(constant_val.getType()); // In TFLite FullyConnect definition, bias must be a 1D tensor where // the number of elements is equal to the number of channels. @@ -1199,7 +1199,7 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { Value filter = fc_op.getFilter(); Value bias = fc_op.getBias(); ElementsAttr bias_value; - const bool is_none_bias = bias.getType().isa(); + const bool is_none_bias = mlir::isa(bias.getType()); if (fc_op.getFusedActivationFunction() != "NONE") return failure(); if (!is_none_bias && !matchPattern(bias, m_Constant(&bias_value))) @@ -1212,7 +1212,7 @@ struct FuseFullyConnectedAndAdd : public OpRewritePattern { // to properly broadcast the scalar to `{num_channels}` shape. // Get the number of channels if possible. - auto filter_type = filter.getType().dyn_cast(); + auto filter_type = mlir::dyn_cast(filter.getType()); // Filter must be a `2D` tensor with `{num_channels, num_features}` // shape. The following check is rejecting unknown rank (-1). if (filter_type == nullptr || filter_type.getRank() != 2) { @@ -1287,14 +1287,14 @@ struct FuseAddAndFullyConnected // Don't match adds where the added constant is not 1D. { - auto addend_shape = add_op.getRhs().getType().cast(); + auto addend_shape = mlir::cast(add_op.getRhs().getType()); if (!addend_shape.hasStaticShape()) return failure(); if (addend_shape.getShape().size() != 1) return failure(); } // Calculate new bias. Generate a new FC; it will be constant folded. auto old_bias = fc_op.getBias(); - if (!old_bias || old_bias.getType().isa()) { + if (!old_bias || mlir::isa(old_bias.getType())) { // TODO(b/180752069): Figure out new bias' type when old bias is empty. return failure(); } @@ -1358,7 +1358,7 @@ struct FuseMulAndFullyConnected // Don't match muls where the multiplier constant is not 1D. { - auto multiplier_shape = mul_op.getRhs().getType().cast(); + auto multiplier_shape = mlir::cast(mul_op.getRhs().getType()); if (!multiplier_shape.hasStaticShape()) return failure(); if (multiplier_shape.getShape().size() != 1) return failure(); } @@ -1464,7 +1464,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { Value bias = fc_op.getBias(); ElementsAttr cst_tmp; if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure(); - if (!bias.getType().isa() && + if (!mlir::isa(bias.getType()) && !matchPattern(bias, m_Constant(&cst_tmp))) return failure(); if (fc_op.getFusedActivationFunction() != "NONE") return failure(); @@ -1494,7 +1494,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { // Rewrite. Since the folder of TFL::MulOp couldn't broadcast the operands, // TF::MulOp is used to fold the constant. // TODO(b/139192933): switch to the TFL constant folding - auto filter_type = filter.getType().cast(); + auto filter_type = mlir::cast(filter.getType()); if (filter_type.hasStaticShape()) { auto size = filter_type.getNumElements() * filter_type.getElementTypeBitWidth(); @@ -1506,7 +1506,7 @@ struct FuseFullyConnectedAndMul : public OpRewritePattern { rewriter.create(mul_op.getLoc(), filter, new_const_val) .getZ(); // If bias isn't None, it needs to be multiplied as well. - if (!bias.getType().isa()) { + if (!mlir::isa(bias.getType())) { bias = rewriter.create(mul_op.getLoc(), bias, constant_val) .getZ(); } @@ -1585,7 +1585,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern { // weight constant ElementsAttr cst_tmp; if (!matchPattern(filter, m_Constant(&cst_tmp))) return failure(); - if (!bias.getType().isa() && + if (!mlir::isa(bias.getType()) && !matchPattern(bias, m_Constant(&cst_tmp))) return failure(); if (fc_op.getFusedActivationFunction() != "NONE") return failure(); @@ -1607,7 +1607,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern { } // Make sure that the fused bias will be a 1D tensor. - auto gamma_shape = gamma.getType().cast(); + auto gamma_shape = mlir::cast(gamma.getType()); if (!gamma_shape.hasRank() || gamma_shape.getRank() != 1) { return failure(); } @@ -1623,7 +1623,7 @@ struct FuseAffinOpAndMulWithQDQs : public OpRewritePattern { new_filter, new_qtype); // If bias isn't None, it needs to be multiplied as well. - if (!bias.getType().isa()) { + if (!mlir::isa(bias.getType())) { rewriter.setInsertionPoint(fc_op); auto new_bias = rewriter.create(loc, bias, gamma); fc_op.getOperation()->replaceUsesOfWith(bias, new_bias); @@ -1674,7 +1674,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { } filter = q.getInput(); } - if (!bias.getType().isa() && + if (!mlir::isa(bias.getType()) && !matchPattern(bias, m_Constant(&bias_cst))) return failure(); auto binary_op_activation_func = @@ -1705,7 +1705,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { // The new bias should be a 1-D tensor with length equals to the bias // dimension of the weight. SmallVector new_bias_values; - if (bias.getType().isa()) { // none bias, a list of zeros + if (mlir::isa(bias.getType())) { // none bias, a list of zeros new_bias_values.resize(bias_size, APFloat::getZero(cst_value.getSemantics())); } else if (bias_cst.getNumElements() == 1) { // scalar bias, broadcast it @@ -1806,12 +1806,11 @@ struct ScalarizeSplatConstantForBroadcastableOps } constexpr int kSplatOperandIndex = 1; - auto result_type = - binary_op.getResult().getType().template cast(); + auto result_type = mlir::cast(binary_op.getResult().getType()); mlir::Value non_splat_operand = binary_op.getOperand(1 - kSplatOperandIndex); auto non_splat_operand_type = - non_splat_operand.getType().cast(); + mlir::cast(non_splat_operand.getType()); // If the other operand's shape does not equal to the result shape, then we // cannot scalarize the splat constant because the result shape relies on // the splat constant op's shape for broadcasting. @@ -1850,10 +1849,11 @@ struct ScalarizeSplatConstantForBroadcastableOps if (!matchPattern(value, m_Constant(elements_attr))) { return false; } - auto element_type = value.getType().cast().getElementType(); + auto element_type = + mlir::cast(value.getType()).getElementType(); // Ignore per-axis quantized constants because after converting to scalar, // we will lose per-axis qantization parameter. - if (element_type.isa()) { + if (mlir::isa(element_type)) { return false; } if (IsScalar(value)) { @@ -1864,7 +1864,7 @@ struct ScalarizeSplatConstantForBroadcastableOps // If this type is a scalar shaped type. bool IsScalar(mlir::Value value) const { - auto type = value.getType().dyn_cast(); + auto type = mlir::dyn_cast(value.getType()); if (!type) { return false; } @@ -1883,7 +1883,7 @@ struct ScalarizeSplatConstantForBroadcastableOps DenseElementsAttr value; // Check that bias are constants if not none. Value bias = affine_op->getOperand(2); - if (!bias.getType().isa() && + if (!mlir::isa(bias.getType()) && !matchPattern(bias, m_Constant(&value))) { return false; } @@ -1896,7 +1896,7 @@ struct ScalarizeSplatConstantForBroadcastableOps // We can only fuse F32/BF16. auto is_fusable_type = [](Type t) { Type element_type = t; - if (auto shaped_type = t.dyn_cast()) { + if (auto shaped_type = mlir::dyn_cast(t)) { element_type = shaped_type.getElementType(); } return element_type.isBF16() || element_type.isF32(); @@ -1926,8 +1926,9 @@ struct ConvertTrivialTransposeOpToReshapeOp LogicalResult matchAndRewrite(TFL::TransposeOp transpose_op, PatternRewriter &rewriter) const override { - auto input_type = transpose_op.getInput().getType().cast(); - auto output_type = transpose_op.getOutput().getType().cast(); + auto input_type = mlir::cast(transpose_op.getInput().getType()); + auto output_type = + mlir::cast(transpose_op.getOutput().getType()); // It's possible to know if the transformation is safe only if the input // & output shapes are fully known and permutation is a constant. if (!input_type.hasStaticShape() || !output_type.hasStaticShape()) @@ -2002,10 +2003,9 @@ struct RemoveReshapeBeforeFullyConnected LogicalResult matchAndRewrite(TFL::FullyConnectedOp fully_connected_op, PatternRewriter &) const override { auto input = fully_connected_op.getInput(); - auto input_ty = input.getType().dyn_cast(); - auto output_ty = fully_connected_op.getOutput()[0] - .getType() - .template dyn_cast(); + auto input_ty = mlir::dyn_cast(input.getType()); + auto output_ty = + mlir::dyn_cast(fully_connected_op.getOutput()[0].getType()); if (!input_ty.hasStaticShape() || fully_connected_op.getWeightsFormat() != "DEFAULT" || fully_connected_op.getKeepNumDims() || !output_ty.hasStaticShape() || @@ -2018,7 +2018,7 @@ struct RemoveReshapeBeforeFullyConnected // Check if the last dimension does not change after reshape. auto reshape_input = reshape_op.getInput(); - auto reshape_input_ty = reshape_input.getType().dyn_cast(); + auto reshape_input_ty = mlir::dyn_cast(reshape_input.getType()); if (!reshape_input_ty.hasStaticShape() || input_ty.getRank() == 0 || reshape_input_ty.getRank() == 0 || input_ty.getDimSize(input_ty.getRank() - 1) != @@ -2061,9 +2061,9 @@ struct RemoveReshapeAfterFullyConnected if (!reshape_op.getInput().hasOneUse()) return failure(); auto input_shape = - fully_connected_op.getInput().getType().cast(); - auto output_shape = fully_connected_op.getType(0).cast(); - auto reshape_shape = reshape_op.getType().cast(); + mlir::cast(fully_connected_op.getInput().getType()); + auto output_shape = mlir::cast(fully_connected_op.getType(0)); + auto reshape_shape = mlir::cast(reshape_op.getType()); if (!input_shape.hasStaticShape() || !output_shape.hasStaticShape() || !reshape_shape.hasStaticShape()) return failure(); @@ -2128,7 +2128,7 @@ struct FuseUnpackAndConcatToReshape } } - auto output_type = concat_op.getType().cast(); + auto output_type = mlir::cast(concat_op.getType()); if (!output_type.hasStaticShape()) { return failure(); } @@ -2188,8 +2188,8 @@ struct OptimizeTopK : public OpRewritePattern { // for last dimension. // It can be done by verifying the number of elements: // i.e., num_input/input_last_dim = num_result/k - auto input_ty = value.getType().dyn_cast_or_null(); - auto result_ty = slice_op.getType().dyn_cast(); + auto input_ty = mlir::dyn_cast_or_null(value.getType()); + auto result_ty = mlir::dyn_cast(slice_op.getType()); if (!input_ty || !result_ty) return std::nullopt; if (!input_ty.hasStaticShape() || !result_ty.hasStaticShape()) return std::nullopt; @@ -2230,8 +2230,8 @@ struct OptimizeTopK : public OpRewritePattern { Value k_cst = rewriter.create( op.getLoc(), DenseElementsAttr::get(k_ty, k)); // Compute new result types. - auto values_ty = values.getType().dyn_cast(); - auto indices_ty = indices.getType().dyn_cast(); + auto values_ty = mlir::dyn_cast(values.getType()); + auto indices_ty = mlir::dyn_cast(indices.getType()); auto shape = std::vector(); for (auto d : values_ty.getShape().drop_back()) { shape.push_back(d); @@ -2439,7 +2439,7 @@ struct FuseLogSoftmax : public OpRewritePattern { if (!sum_op || !sum_op.getKeepDims() || !isSupportedAxis( sum_op.getAxes(), - sum_op.getOperand(0).getType().cast().getRank())) { + mlir::cast(sum_op.getOperand(0).getType()).getRank())) { return failure(); } if (!sum_op->hasOneUse()) { @@ -2466,10 +2466,10 @@ struct FuseLogSoftmax : public OpRewritePattern { parent_sub_op.getRhs().getDefiningOp()); if (!reduce_max_op || !reduce_max_op->hasOneUse() || !reduce_max_op.getKeepDims() || - !isSupportedAxis(reduce_max_op.getAxes(), reduce_max_op.getOperand(0) - .getType() - .cast() - .getRank())) { + !isSupportedAxis( + reduce_max_op.getAxes(), + mlir::cast(reduce_max_op.getOperand(0).getType()) + .getRank())) { return failure(); } @@ -2562,7 +2562,7 @@ struct UndoBroadcastFullyConnectedBiasAddWithQDQs } auto bias_type = bias_op.getType(); - auto bias_rank = bias_type.cast().getRank(); + auto bias_rank = mlir::cast(bias_type).getRank(); if (bias_rank > 4 || bias_rank < 2) { return failure(); } @@ -2587,8 +2587,8 @@ struct UndoBroadcastFullyConnectedBiasAddWithQDQs q_op.setOperand(new_bias_op); auto new_q_op_type = RankedTensorType::Builder( - q_op.getResult().getType().cast()) - .setShape(new_bias_type.cast().getShape()); + mlir::cast(q_op.getResult().getType())) + .setShape(mlir::cast(new_bias_type).getShape()); q_op.getResult().setType(new_q_op_type); auto attr = TypeAttr::get(q_op.getResult().getType()); q_op.setQtypeAttr(attr); @@ -2596,8 +2596,8 @@ struct UndoBroadcastFullyConnectedBiasAddWithQDQs // Update DequantizeOp's output shape auto new_dq_op_type = RankedTensorType::Builder( - dq_op.getResult().getType().cast()) - .setShape(new_bias_type.cast().getShape()); + mlir::cast(dq_op.getResult().getType())) + .setShape(mlir::cast(new_bias_type).getShape()); dq_op.getResult().setType(new_dq_op_type); // Remove old bias diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.cc index 5b696b52db4b2e..0eacfcb8ef09f0 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.cc @@ -94,7 +94,8 @@ struct ConvertBatchMatMulOp2FullyConnectedOp // Create a tfl.transpose op that performs ZX transpose on `input`. auto create_z_x_transpose_op = [&](Value input) -> Value { - RankedTensorType input_type = input.getType().cast(); + RankedTensorType input_type = + mlir::cast(input.getType()); const int input_rank = input_type.getRank(); // Create a 1D I32 tensor for representing the dimension permutation. @@ -176,7 +177,7 @@ struct ConvertBatchMatMulOpToReduceSum // the adj(X|Y) attribute, respectively. // So adjX == True indicates [..., c_x, r_x == 1]. llvm::ArrayRef lhs_shape = - bmm_op.getX().getType().cast().getShape(); + mlir::cast(bmm_op.getX().getType()).getShape(); int rX = lhs_shape.size() - 2; int cX = lhs_shape.size() - 1; if (bmm_op.getAdjX()) { @@ -189,7 +190,7 @@ struct ConvertBatchMatMulOpToReduceSum } llvm::ArrayRef rhs_shape = - bmm_op.getY().getType().cast().getShape(); + mlir::cast(bmm_op.getY().getType()).getShape(); int rY = rhs_shape.size() - 1; int cY = rhs_shape.size() - 2; if (bmm_op.getAdjX()) { @@ -210,11 +211,11 @@ struct ConvertBatchMatMulOpToReduceSum private: bool SplatValueEquals(SplatElementsAttr float_or_int, double rhs) const { - if (float_or_int.isa()) { - return float_or_int.cast() + if (mlir::isa(float_or_int)) { + return mlir::cast(float_or_int) .getSplatValue() .isExactlyValue(rhs); - } else if (float_or_int.cast()) { + } else if (mlir::cast(float_or_int)) { return float_or_int.getSplatValue() == static_cast(rhs); } return false; diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc index 7d7ab4b5acd33d..69137210b48ffc 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_functional_ops.cc @@ -21,12 +21,13 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -110,7 +111,7 @@ class FoldIfOp : public OpRewritePattern { if (!matchPattern(op.getCond(), m_Constant(&cond))) return failure(); // TODO(hinsu): Handle constants that are not scalar booleans. - auto cond_type = cond.getType().dyn_cast(); + auto cond_type = mlir::dyn_cast(cond.getType()); if (!cond_type || !cond_type.getShape().equals({}) || !cond_type.getElementType().isInteger(/*width=*/1)) return failure(); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc index 4ce0a3b8c43225..62c2c43778e254 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -66,9 +67,9 @@ struct PushDownDequantize : public OpRewritePattern { // Only push down the dequantize op when the output is smaller, so that it // can have smaller memory usage. auto input_type = - dequantize_op.getOutput().getType().dyn_cast(); - auto output_type = - passthrough_op->getResult(0).getType().dyn_cast(); + mlir::dyn_cast(dequantize_op.getOutput().getType()); + auto output_type = mlir::dyn_cast( + passthrough_op->getResult(0).getType()); if (!input_type || !output_type || get_num_elements(input_type) <= get_num_elements(output_type)) { return failure(); @@ -85,7 +86,7 @@ struct PushDownDequantize : public OpRewritePattern { // Set the input type of the passthrough op and pull it up. Type new_output_type; - if (input_element_type.isa()) { + if (mlir::isa(input_element_type)) { new_output_type = QuantizedType::getQuantizedElementType( dequantize_op.getInput().getType()) .castFromExpressedType(output_type); diff --git a/tensorflow/compiler/mlir/lite/transforms/pin_ops_with_side_effects.cc b/tensorflow/compiler/mlir/lite/transforms/pin_ops_with_side_effects.cc index 1d0cd497b052f3..7baa0136f1c33c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/pin_ops_with_side_effects.cc +++ b/tensorflow/compiler/mlir/lite/transforms/pin_ops_with_side_effects.cc @@ -37,9 +37,9 @@ namespace { #include "tensorflow/compiler/mlir/lite/transforms/passes.h.inc" bool IsResourceTensor(Value value) { - const auto tensor_type = value.getType().dyn_cast(); + const auto tensor_type = mlir::dyn_cast(value.getType()); return tensor_type && - tensor_type.getElementType().isa(); + mlir::isa(tensor_type.getElementType()); } // The default criterion for operations being considered as causing or being diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 80d7ab24c23316..867eecff15818f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -233,8 +234,8 @@ struct FoldTransposeOp : public OpRewritePattern { DenseIntElementsAttr perm_tensor; if (!matchPattern(op.getPerm(), m_Constant(&perm_tensor))) return failure(); - if (!(getElementTypeOrSelf(op.getOutput().getType())) - .isa()) + if (!mlir::isa( + (getElementTypeOrSelf(op.getOutput().getType())))) return failure(); ElementsAttr input_tensor = qconst_op.getValue(); @@ -244,7 +245,7 @@ struct FoldTransposeOp : public OpRewritePattern { assert(perm_tensor.getType().getNumElements() == num_dimensions); ArrayRef input_shape = input_tensor.getShapedType().getShape(); - auto output_type = op.getOutput().getType().cast(); + auto output_type = mlir::cast(op.getOutput().getType()); SmallVector perm; SmallVector output_shape; @@ -265,9 +266,9 @@ struct FoldTransposeOp : public OpRewritePattern { auto result_type = RankedTensorType::get(output_shape, output_type.getElementType()); auto values_type = RankedTensorType::get( - output_shape, output_type.getElementType() - .cast() - .getStorageType()); + output_shape, + mlir::cast(output_type.getElementType()) + .getStorageType()); rewriter.replaceOpWithNewOp( op, TypeAttr::get(result_type), DenseIntElementsAttr::get(values_type, new_values)); @@ -289,18 +290,18 @@ struct FoldReshapeOp : public OpRewritePattern { if (qconst_op == nullptr) return failure(); auto dense_elements = - qconst_op.getValue().dyn_cast_or_null(); + mlir::dyn_cast_or_null(qconst_op.getValue()); if (dense_elements == nullptr) return failure(); // Handle per tensor cases only. - if (!(getElementTypeOrSelf(op.getType())) - .isa()) { + if (!mlir::isa( + (getElementTypeOrSelf(op.getType())))) { return failure(); } // Remove identity reshape with both static result and input shape. - auto result_type = op.getType().cast(); - auto input_type = op.getInput().getType().cast(); + auto result_type = mlir::cast(op.getType()); + auto input_type = mlir::cast(op.getInput().getType()); // Constant folding // If the result type isn't static, tries to derive the result type from @@ -318,9 +319,9 @@ struct FoldReshapeOp : public OpRewritePattern { RankedTensorType::get(shape_data, input_type.getElementType()); } auto values_type = RankedTensorType::get( - result_type.getShape(), result_type.getElementType() - .cast() - .getStorageType()); + result_type.getShape(), + mlir::cast(result_type.getElementType()) + .getStorageType()); DenseElementsAttr reshaped_elements = dense_elements.reshape(values_type); rewriter.replaceOpWithNewOp(op, TypeAttr::get(result_type), diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc index 873efde6a290d6..9ed32a1b9a674e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc @@ -80,13 +80,15 @@ LogicalResult CreateTflFusableOpCustomOptions( size_t start_map = fbb.StartMap(); for (auto attr : attrs) { - if (auto float_attr = attr.second.dyn_cast_or_null()) { + if (auto float_attr = mlir::dyn_cast_or_null(attr.second)) { fbb.Float(attr.first.data(), float_attr.getValue().convertToFloat()); - } else if (auto int_attr = attr.second.dyn_cast_or_null()) { + } else if (auto int_attr = + mlir::dyn_cast_or_null(attr.second)) { fbb.Int(attr.first.data(), int_attr.getInt()); - } else if (auto bool_attr = attr.second.dyn_cast_or_null()) { + } else if (auto bool_attr = mlir::dyn_cast_or_null(attr.second)) { fbb.Bool(attr.first.data(), bool_attr.getValue()); - } else if (auto string_attr = attr.second.dyn_cast_or_null()) { + } else if (auto string_attr = + mlir::dyn_cast_or_null(attr.second)) { fbb.String(attr.first.data(), string_attr.getValue().str()); } else { // TODO(b/201482289): support other data types. @@ -182,7 +184,7 @@ LogicalResult CheckFusableLayerNormalizedLstmCellSimple( func::FuncOp lstm_func) { for (int i = 0; i < 5; ++i) { auto input = lstm_func.getArgument(i); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) { lstm_func.emitWarning( "we cannot fuse this lstm func because all the inputs have not " @@ -197,7 +199,7 @@ LogicalResult CheckFusableLayerNormalizedLstmCellSimple( LogicalResult CheckFusableLstmCellSimple(func::FuncOp lstm_func) { for (int i = 0; i < 4; ++i) { auto input = lstm_func.getArgument(i); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) { lstm_func.emitWarning( "we cannot fuse this lstm func because all the inputs have not " @@ -250,7 +252,7 @@ LogicalResult CheckFusableKerasLstm(func::FuncOp lstm_func, ModuleOp module) { // types. for (int i = 0; i < 6; ++i) { auto input = lstm_func.getArgument(i); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) { lstm_func.emitWarning( "we cannot fuse this lstm func because all the inputs have not " @@ -368,7 +370,7 @@ void PrepareCompositeFunctionsPass::ConvertTFImplementsWithAttributes( for (auto attr_item : dict_attr) { // Push other attributes except the TFLFusableOp. if (attr_item.getName() == kTFLFusableOp && - attr_item.getValue().dyn_cast().getValue()) { + mlir::dyn_cast(attr_item.getValue()).getValue()) { tfl_fusable_op = true; } else { attributes.push_back({attr_item.getName(), attr_item.getValue()}); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index ce11ca73970136..9f76ad1f6e9098 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -153,8 +153,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { bool need_to_set_input_nodes_quantization_params = false; for (const BlockArgument arg : func.getArguments()) { - auto shaped = arg.getType().dyn_cast(); - if (shaped && shaped.getElementType().isa() && + auto shaped = mlir::dyn_cast(arg.getType()); + if (shaped && mlir::isa(shaped.getElementType()) && !has_quantize_op(arg)) { need_to_set_input_nodes_quantization_params = true; break; @@ -179,8 +179,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { auto add_quantize_op = [&](Location loc, Type input_type, Block* block, Block::iterator insertion_point, Value arg, int i) { - if (auto shaped = input_type.dyn_cast()) { - if (shaped.getElementType().isa()) { + if (auto shaped = mlir::dyn_cast(input_type)) { + if (mlir::isa(shaped.getElementType())) { // If there are existing quantize ops, they are from training and we // should respect them. if (has_quantize_op(arg)) { diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc index f0fd79ff207f39..0b823844aa4a58 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h" @@ -193,7 +194,7 @@ class PrepareDynamicRangeQuantizableOp continue; } - if (attr.dyn_cast().size() >= + if (mlir::dyn_cast(attr).size() >= quant_specs_.minimum_elements_for_weights) { continue; } @@ -205,7 +206,7 @@ class PrepareDynamicRangeQuantizableOp "supported. The operand ") << const_op->getName().getStringRef().str() << " at index " << qi << " was not quantized because it has " - << attr.dyn_cast().size() + << mlir::dyn_cast(attr).size() << " elements which is fewer than the " "`minimum_elements_for_weights` threshold of " << quant_specs_.minimum_elements_for_weights; @@ -233,7 +234,7 @@ class PrepareDynamicRangeQuantizableOp // Get types TensorType old_result_type = - op.getResult().getType().template dyn_cast(); + mlir::dyn_cast(op.getResult().getType()); FloatType quantized_type = FloatType::getF16(op.getContext()); ShapedType new_result_type = old_result_type.clone(quantized_type); @@ -287,27 +288,27 @@ class PrepareDynamicRangeQuantizableOp DenseFPElementsAttr attr; if (!matchPattern(op->getResult(0), m_Constant(&attr))) return false; - if (attr.dyn_cast().size() < + if (mlir::dyn_cast(attr).size() < quant_specs_.minimum_elements_for_weights) { op->emitRemark("Quantization is skipped for ") << quantize_op->getName().getStringRef().str() << " because it has " - << attr.dyn_cast().size() + << mlir::dyn_cast(attr).size() << " elements which is fewer than the threshold(" << quant_specs_.minimum_elements_for_weights << " elements)."; return false; } if (op_with_per_axis_support) { - quant_type = quant::GetUniformQuantizedPerAxisTypeForWeight( - attr, affine_user.GetQuantizationDimIndex(), - /*symmetric=*/true, bit_width, is_signed, - is_narrow_range, is_legacy_float) - .template dyn_cast(); + quant_type = mlir::dyn_cast( + quant::GetUniformQuantizedPerAxisTypeForWeight( + attr, affine_user.GetQuantizationDimIndex(), + /*symmetric=*/true, bit_width, is_signed, is_narrow_range, + is_legacy_float)); } else { - quant_type = quant::GetUniformQuantizedTypeForWeight( - attr, is_narrow_range && is_signed, bit_width, is_signed, - is_narrow_range, is_legacy_float) - .template dyn_cast(); + quant_type = mlir::dyn_cast( + quant::GetUniformQuantizedTypeForWeight( + attr, is_narrow_range && is_signed, bit_width, is_signed, + is_narrow_range, is_legacy_float)); } return insertQDQ(rewriter, op, quant_type, quant_op); } @@ -346,7 +347,7 @@ class PrepareDynamicRangeQuantizableOp bool getQuantizableOps(arith::ConstantOp op, QuantizationUnits& quantizable_ops) const { // Non-float tensors do not need quantization. - auto type = op.getType().dyn_cast(); + auto type = mlir::dyn_cast(op.getType()); if (!type || !type.getElementType().isF32()) return false; Value value = op.getResult(); @@ -420,7 +421,7 @@ class PrepareDynamicRangeQuantizableOp // Get types Type old_result_type = op.getResult().getType(); ShapedType new_result_type = - cast_op.getType().template dyn_cast(); + mlir::dyn_cast(cast_op.getType()); // Proceeds only if the casting is to float16 if (!new_result_type.getElementType().isF16()) continue; @@ -428,7 +429,7 @@ class PrepareDynamicRangeQuantizableOp // Cast values std::vector new_values; DenseFPElementsAttr value_attr = - op.getValue().cast(); + mlir::cast(op.getValue()); new_values.reserve(value_attr.getNumElements()); constexpr float kMaxFloat16Value = 65504.f; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h index e102c6bedd4328..68404404926775 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h @@ -36,6 +36,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" @@ -100,19 +101,18 @@ LogicalResult GetLstmProperty(LstmOp op, return failure(); } lstm_variant->use_projection = - !op.getProjectionWeights().getType().template isa(); + !mlir::isa(op.getProjectionWeights().getType()); lstm_variant->use_peephole = - !op.getCellToOutputWeights().getType().template isa(); + !mlir::isa(op.getCellToOutputWeights().getType()); lstm_variant->use_layer_norm = - !op.getForgetLayerNormCoefficients().getType().template isa(); + !mlir::isa(op.getForgetLayerNormCoefficients().getType()); *op_property = operator_property::GetOperatorProperty( *lstm_variant, activation_number_of_bits); // TODO(b/176258587) move this to operator_property.cc if this is needed in // other components, too. - bool use_cifg = - op.getInputToInputWeights().getType().template isa(); + bool use_cifg = mlir::isa(op.getInputToInputWeights().getType()); if (use_cifg) { const absl::flat_hash_set cifg_non_inputs = {1, 5, 9, 12, 20}; const int cifg_non_intermediate = 0; @@ -197,9 +197,9 @@ class PrepareLstmOutputScale : public OpRewritePattern { llvm::SmallVector min_max_values; for (auto& stats_op : stats_ops) { - auto values = stats_op.getLayerStats() - .dyn_cast() - .getValues(); + auto values = + mlir::dyn_cast(stats_op.getLayerStats()) + .getValues(); min_max_values.insert(min_max_values.end(), values.begin(), values.end()); } @@ -285,8 +285,8 @@ class ConvertOpStatsToQDQs : public OpRewritePattern { const operator_property::TensorProperty& tensor_property, PatternRewriter& rewriter) const { // Non-float tensors are neither weights nor require quantization. - auto type = const_op->getResult(0).getType().dyn_cast(); - if (!type || !type.getElementType().isa()) return success(); + auto type = mlir::dyn_cast(const_op->getResult(0).getType()); + if (!type || !mlir::isa(type.getElementType())) return success(); DenseFPElementsAttr attr; if (!matchPattern(const_op->getResult(0), m_Constant(&attr))) { @@ -312,12 +312,12 @@ class ConvertOpStatsToQDQs : public OpRewritePattern { rewriter.getIntegerType(16), attr.getType().getElementType(), scale, /*zeroPoint=*/0, llvm::minIntN(10), -llvm::minIntN(10)); } else { - quant_type = quant::GetUniformQuantizedTypeForWeight( - attr, /*symmetric=*/true, - /*num_bits=*/tensor_property.number_of_bits, - /*is_signed=*/true, - /*narrow_range=*/true, quant_specs_.legacy_float_scale) - .template dyn_cast(); + quant_type = mlir::dyn_cast( + quant::GetUniformQuantizedTypeForWeight( + attr, /*symmetric=*/true, + /*num_bits=*/tensor_property.number_of_bits, + /*is_signed=*/true, + /*narrow_range=*/true, quant_specs_.legacy_float_scale)); } if (!quant_type) { const_op->emitError("Failed to get quantized type"); @@ -346,7 +346,7 @@ class ConvertOpStatsToQDQs : public OpRewritePattern { << "] is a state tensor, but has more than one use."; return failure(); } - auto stats = stats_op.getLayerStats().dyn_cast(); + auto stats = mlir::dyn_cast(stats_op.getLayerStats()); if (!stats || stats.getNumElements() != 2) { stats_op.emitError("Stats should have 2 values."); return failure(); @@ -454,7 +454,7 @@ class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs { return failure(); } auto calibrated_type = - quant_type.template dyn_cast(); + mlir::dyn_cast(quant_type); if (!calibrated_type) { int num_storage_bits = quant_type.getStorageTypeIntegralWidth(); if (tensor_property.number_of_bits != num_storage_bits) { @@ -474,9 +474,9 @@ class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs { /*narrowRange=*/false, calibrated_type.getExpressedType(), /*isSigned=*/this->quant_specs_.IsSignedInferenceType()); if (this->quant_specs_.legacy_float_scale) { - qtype = quant::DownCastScale(qtype, calibrated_type.getMin(), - calibrated_type.getMax(), op.getLoc()) - .template cast(); + qtype = mlir::cast( + quant::DownCastScale(qtype, calibrated_type.getMin(), + calibrated_type.getMax(), op.getLoc())); } } else if (tensor_property.number_of_bits == 16) { double max = std::max(std::abs(calibrated_type.getMin()), @@ -508,9 +508,9 @@ inline quant::AccumulatorScaleFunc GetUniformQuantizedTypeForBiasWithScale( return [=](const std::vector& quant_params, const int adjusted_quant_dim, const bool legacy_float_scale) -> quant::QuantParams { - if (auto qtype = quant::GetUniformQuantizedTypeForBias( - quant_params, legacy_float_scale, adjusted_quant_dim) - .dyn_cast_or_null()) { + if (auto qtype = mlir::dyn_cast_or_null( + quant::GetUniformQuantizedTypeForBias( + quant_params, legacy_float_scale, adjusted_quant_dim))) { return quant::UniformQuantizedType::get( qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), qtype.getScale() * scale, qtype.getZeroPoint(), @@ -540,14 +540,14 @@ std::unique_ptr GetLstmOpQuantSpec(LstmOp op) { tensor_property.derived_scale.intermediate_tensors) { auto quant_type = GetIntermediateElementType(op, tensor_index); if (!quant_type || - !quant_type.template isa()) { + !mlir::isa(quant_type)) { op->emitError() << "While processing derived scale, intermediate " << intermediate_attributes[tensor_index] << " is not quantized."; return nullptr; } - scale *= quant_type.template dyn_cast() - .getScale(); + scale *= + mlir::dyn_cast(quant_type).getScale(); } for (float factor : tensor_property.derived_scale.factors) { scale *= factor; @@ -590,7 +590,8 @@ class PropagateTransposedPerAxisQuantDim auto q_op = dyn_cast_or_null( dq_op.getOperand().getDefiningOp()); if (!q_op) return failure(); - auto qtype = dq_op.getArg().getType().cast().getElementType(); + auto qtype = + mlir::cast(dq_op.getArg().getType()).getElementType(); auto aqtype = dyn_cast_or_null(qtype); if (!aqtype) return failure(); @@ -599,8 +600,8 @@ class PropagateTransposedPerAxisQuantDim auto next_op = *transpose_op.getResult().getUsers().begin(); if (dyn_cast_or_null(next_op)) return failure(); - auto input_type = transpose_op.getInput().getType().cast(); - auto perm_type = transpose_op.getPerm().getType().cast(); + auto input_type = mlir::cast(transpose_op.getInput().getType()); + auto perm_type = mlir::cast(transpose_op.getPerm().getType()); if (input_type.hasStaticShape() && perm_type.hasStaticShape()) { if (perm_type.getNumElements() != input_type.getRank()) { return transpose_op.emitOpError( diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 9f0a7fbafff450..09e7f080fa83fb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -89,7 +89,7 @@ namespace { // Preconditions: The given value must have a ShapedType. static Value CreateTFCastOpI32(OpBuilder *builder, Location loc, Value x, BoolAttr truncate) { - auto x_type = x.getType().dyn_cast_or_null(); + auto x_type = mlir::dyn_cast_or_null(x.getType()); if (!x_type) llvm_unreachable("unsupported type"); Type type = x_type.clone(builder->getI32Type()); return builder->create(loc, type, x, truncate); @@ -200,14 +200,14 @@ class ConvertTFConvOp : public RewritePattern { // that we can extract info from the shape (e.g., for constructing bias // tensor, for setting depth_multiplier attribute, etc.). auto filter = tf_op.getFilter(); - auto filter_type = filter.getType().template dyn_cast(); + auto filter_type = mlir::dyn_cast(filter.getType()); if (!filter_type || filter_type.getRank() != 4 || !filter_type.hasStaticShape()) return failure(); Value input = tf_op.getInput(); RankedTensorType input_type = - input.getType().template dyn_cast(); + mlir::dyn_cast(input.getType()); // Only rank size four input will be only available by the tf.Conv2D // operator verification. if (!input_type || input_type.isDynamicDim(3)) { @@ -244,7 +244,7 @@ class ConvertTFConvOp : public RewritePattern { op->getAttrOfType("explicit_paddings").getValue(); auto get_int = [](Attribute attr) { - return attr.template cast().getInt(); + return mlir::cast(attr).getInt(); }; SmallVector padding_values(padding_attr_array.size()); @@ -324,7 +324,7 @@ class ConvertTFConv2D : public ConvertTFConvOp { auto perm_op = rewriter.create(loc, perm_type, perm_attr); // Create tensor type for the transpose result. - auto filter_type = filter.getType().cast(); + auto filter_type = mlir::cast(filter.getType()); auto result_shape = llvm::to_vector<4>(llvm::map_range(perm, [filter_type](int64_t dim) { return filter_type.getDimSize(dim); @@ -361,7 +361,8 @@ class ConvertTFDepthwiseConv2dNative // have a corresponding 'depth_multiplier' attribute; the multiplier is the // fourth dimension in the 4-D filter tensor. We query the multiplier from // tf.DepthwiseConv2dNative and set it as the attribute value accordingly. - auto multiplier = filter.getType().cast().getDimSize(3); + auto multiplier = + mlir::cast(filter.getType()).getDimSize(3); filter = legalizeFilter(rewriter, loc, filter); return rewriter.create( @@ -385,7 +386,7 @@ class ConvertTFDepthwiseConv2dNative /// RankedTensorType. Value legalizeFilter(PatternRewriter &rewriter, Location loc, Value filter) const { - auto filter_type = filter.getType().cast(); + auto filter_type = mlir::cast(filter.getType()); auto filterShape = filter_type.getShape(); SmallVector result_shape = {1, filterShape[0], filterShape[1], filterShape[2] * filterShape[3]}; @@ -443,7 +444,7 @@ struct ConvertTFStridedSlice : public RewritePattern { // Insert a new reshape op. Value original_input = strided_slice_op.getInput(); RankedTensorType original_input_type = - original_input.getType().dyn_cast(); + mlir::dyn_cast(original_input.getType()); if (!original_input_type) { return failure(); } @@ -522,7 +523,8 @@ struct ConvertTFStridedSlice : public RewritePattern { DenseIntElementsAttr begin_dense_elem_attr; Value begin = strided_slice_op.getBegin(); - auto begin_ranked_attr_type = begin.getType().dyn_cast(); + auto begin_ranked_attr_type = + mlir::dyn_cast(begin.getType()); if (!begin_ranked_attr_type || !matchPattern(begin, m_Constant(&begin_dense_elem_attr))) { return failure(); @@ -530,7 +532,7 @@ struct ConvertTFStridedSlice : public RewritePattern { DenseIntElementsAttr end_dense_elem_attr; Value end = strided_slice_op.getEnd(); - auto end_ranked_attr_type = end.getType().dyn_cast(); + auto end_ranked_attr_type = mlir::dyn_cast(end.getType()); if (!end_ranked_attr_type || !matchPattern(end, m_Constant(&end_dense_elem_attr))) { return failure(); @@ -539,14 +541,15 @@ struct ConvertTFStridedSlice : public RewritePattern { DenseIntElementsAttr stride_dense_elem_attr; Value stride = strided_slice_op.getStrides(); auto stride_ranked_attr_type = - stride.getType().dyn_cast(); + mlir::dyn_cast(stride.getType()); if (!stride_ranked_attr_type || !matchPattern(stride, m_Constant(&stride_dense_elem_attr))) { return failure(); } Value input = strided_slice_op.getInput(); - RankedTensorType input_type = input.getType().dyn_cast(); + RankedTensorType input_type = + mlir::dyn_cast(input.getType()); if (!input_type) { return failure(); } @@ -554,7 +557,7 @@ struct ConvertTFStridedSlice : public RewritePattern { const int input_size = input_shape.size(); - RankedTensorType begin_type = begin.getType().cast(); + RankedTensorType begin_type = mlir::cast(begin.getType()); const ArrayRef begin_shape = begin_type.getShape(); const int begin_dim = begin_shape.size(); @@ -688,7 +691,7 @@ struct ConvertTFStridedSlice : public RewritePattern { } auto ranked_input_type = - strided_slice_op.getInput().getType().dyn_cast(); + mlir::dyn_cast(strided_slice_op.getInput().getType()); if (!ranked_input_type) { return failure(); } @@ -697,10 +700,11 @@ struct ConvertTFStridedSlice : public RewritePattern { auto end_attr = strided_slice_op.getEnd(); auto strides_attr = strided_slice_op.getStrides(); - auto begin_attr_type = begin_attr.getType().dyn_cast(); - auto end_attr_type = end_attr.getType().dyn_cast(); + auto begin_attr_type = + mlir::dyn_cast(begin_attr.getType()); + auto end_attr_type = mlir::dyn_cast(end_attr.getType()); auto strides_attr_type = - strides_attr.getType().dyn_cast(); + mlir::dyn_cast(strides_attr.getType()); DenseIntElementsAttr begin_elem_attr; DenseIntElementsAttr end_elem_attr; @@ -899,8 +903,8 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern { if (!epsilon) epsilon = rewriter.getFloatAttr(rewriter.getF32Type(), 0.0001f); - if (!(((epsilon.isa<::mlir::FloatAttr>())) && - ((epsilon.cast<::mlir::FloatAttr>().getType().isF32())))) { + if (!(((mlir::isa<::mlir::FloatAttr>(epsilon))) && + ((mlir::cast<::mlir::FloatAttr>(epsilon).getType().isF32())))) { return rewriter.notifyMatchFailure( fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { diag << "op 'tf.FusedBatchNormV3' attribute 'epsilon' failed to " @@ -963,7 +967,7 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern { int64_t last_dim = ShapedType::kDynamic; { auto is_last_dim_compatible = [](const Value &v, int64_t &last_dim) { - auto v_type = v.getType().dyn_cast_or_null(); + auto v_type = mlir::dyn_cast_or_null(v.getType()); if (!v_type) return true; int64_t v_last_dim = v_type.getDimSize(v_type.getRank() - 1); if (v_last_dim == ShapedType::kDynamic) return true; @@ -1007,9 +1011,8 @@ struct FusedBatchNormV3Pat : public ::mlir::RewritePattern { // For training, mean and variance is calculated from input values. if (is_training.getValue()) { - auto input_type = fused_batch_norm_op.getX() - .getType() - .dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null( + fused_batch_norm_op.getX().getType()); if (!input_type || input_type.getRank() != 4) { return rewriter.notifyMatchFailure( fused_batch_norm_op, [&](::mlir::Diagnostic &diag) { @@ -1383,14 +1386,14 @@ struct ConvertRfftToRfft2d : public RewritePattern { auto rfft_op = dyn_cast(op); auto input = rfft_op.getInput(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) return failure(); auto fft_len = rfft_op.getFftLength(); - auto fft_len_type = fft_len.getType().dyn_cast_or_null(); + auto fft_len_type = mlir::dyn_cast_or_null(fft_len.getType()); if (!fft_len_type) return failure(); auto output_type = - rfft_op.getResult().getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(rfft_op.getResult().getType()); if (!output_type) return failure(); // Expanded inputs. diff --git a/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc b/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc index 363c30ab0b818c..7a8b35e4be7cde 100644 --- a/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc +++ b/tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project @@ -280,7 +281,7 @@ class CommuteTransposeWithEwiseOps : public RewritePattern { } auto other_input_type = - cst_arg->getResult(0).getType().cast(); + mlir::cast(cst_arg->getResult(0).getType()); Operation *tposed_const; if (other_input_type.getNumElements() == 1) { diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc index 0d9db051ef27ff..96412f20633f6a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_variables.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -169,7 +170,7 @@ void QuantizeVariablesPass::QuantizeVariable( for (VarHandleOp var_handle_op : var_handle_ops) { builder.setInsertionPoint(var_handle_op); auto output_type = UnrankedTensorType::get(TF::ResourceType::get( - {ref_qtype.cast()}, builder.getContext())); + {mlir::cast(ref_qtype)}, builder.getContext())); auto new_var_handle_op = builder.create( var_handle_op.getLoc(), output_type, var_handle_op.getContainer(), var_handle_op.getSharedName()); diff --git a/tensorflow/compiler/mlir/lite/transforms/reduce_type_precision.cc b/tensorflow/compiler/mlir/lite/transforms/reduce_type_precision.cc index bee14272020446..659c5aceb39c04 100644 --- a/tensorflow/compiler/mlir/lite/transforms/reduce_type_precision.cc +++ b/tensorflow/compiler/mlir/lite/transforms/reduce_type_precision.cc @@ -62,12 +62,12 @@ class CheckRangeAndConvertI8ToI4 : public OpRewritePattern { LogicalResult matchAndRewrite(arith::ConstantOp op, PatternRewriter &rewriter) const override { - auto const_type = op.getType().dyn_cast(); + auto const_type = mlir::dyn_cast(op.getType()); if (!const_type || !const_type.getElementType().isSignlessInteger(8)) { return failure(); } - auto attr = op.getValue().cast(); + auto attr = mlir::cast(op.getValue()); for (mlir::APInt v : attr.getValues()) { auto v_int = static_cast(*(v.getRawData())); if (v_int > 7 || v_int < -8) { @@ -79,7 +79,7 @@ class CheckRangeAndConvertI8ToI4 : public OpRewritePattern { auto shaped_type = mlir::RankedTensorType::get(const_type.getShape(), builder.getI4Type()); auto newAttr = DenseElementsAttr::getFromRawBuffer( - shaped_type, op.getValue().cast().getRawData()); + shaped_type, mlir::cast(op.getValue()).getRawData()); rewriter.replaceOpWithNewOp(op, newAttr); return success(); @@ -92,8 +92,8 @@ class SanitizeGatherOpOutputToI4 : public OpRewritePattern { LogicalResult matchAndRewrite(TFL::GatherOp op, PatternRewriter &rewriter) const override { - auto const_type = op.getOperand(0).getType().dyn_cast(); - auto result_type = op.getResult().getType().dyn_cast(); + auto const_type = mlir::dyn_cast(op.getOperand(0).getType()); + auto result_type = mlir::dyn_cast(op.getResult().getType()); if (!const_type || !const_type.getElementType().isSignlessInteger(4) || !result_type || !result_type.getElementType().isSignlessInteger(8)) { return failure(); @@ -109,7 +109,8 @@ class SanitizeGatherOpOutputToI4 : public OpRewritePattern { auto new_gather_op = rewriter.create( op.getLoc(), /*result=*/ - op.getResult().getType().cast().clone(builder.getI4Type()), + mlir::cast(op.getResult().getType()) + .clone(builder.getI4Type()), /*operand=*/op.getOperands(), op->getAttrs()); rewriter.replaceAllUsesWith(op.getResult(), new_gather_op.getResult()); diff --git a/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc b/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc index c8999216c8054b..ab03af3a4c062a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc +++ b/tensorflow/compiler/mlir/lite/transforms/reduce_while_operands.cc @@ -104,7 +104,7 @@ void FindProducers(Value start_node, std::vector &neighbors) { while (!queue.empty()) { auto node = queue.back(); queue.pop_back(); - if (auto arg = node.dyn_cast_or_null()) { + if (auto arg = mlir::dyn_cast_or_null(node)) { neighbors.push_back(arg.getArgNumber()); continue; } @@ -149,7 +149,7 @@ bool AllOperationSafe(Block &block) { // Fact: if every op's operands are defined in the same block as op, // then no operation has implicit arugments (constant doesn't count). for (auto operand : op->getOperands()) { - if (operand.dyn_cast_or_null()) continue; + if (mlir::dyn_cast_or_null(operand)) continue; auto operand_op = operand.getDefiningOp(); if (IsConstant(operand_op)) continue; if (operand_op->getBlock() != op->getBlock()) { diff --git a/tensorflow/compiler/mlir/lite/transforms/unfold_large_splat_constant.cc b/tensorflow/compiler/mlir/lite/transforms/unfold_large_splat_constant.cc index 1def97523cd668..2669159b0206bb 100644 --- a/tensorflow/compiler/mlir/lite/transforms/unfold_large_splat_constant.cc +++ b/tensorflow/compiler/mlir/lite/transforms/unfold_large_splat_constant.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -62,7 +63,7 @@ class UnfoldLargeSplatConstantPass void MaybeUnfoldLargeSplatConstant(mlir::OpBuilder* op_builder, mlir::arith::ConstantOp const_op) const { auto splat_elements_attr = - const_op.getValue().dyn_cast(); + mlir::dyn_cast(const_op.getValue()); if (!splat_elements_attr) { return; } diff --git a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc index a3c3ece3dc94a1..013abb6ec0ea80 100644 --- a/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc +++ b/tensorflow/compiler/mlir/lite/transforms/while_loop_outline.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -92,13 +93,13 @@ bool IsCompatibleTypeWithTFLCastOp(Type type) { return true; // Complex> is allowed. - if (elemType.isa() && - elemType.cast().getElementType().isF32()) + if (mlir::isa(elemType) && + mlir::cast(elemType).getElementType().isF32()) return true; // QUINT8 and UI8 are allowed. - if (elemType.isa() || - (elemType.isInteger(8) && elemType.cast().isUnsigned())) + if (mlir::isa(elemType) || + (elemType.isInteger(8) && mlir::cast(elemType).isUnsigned())) return true; return false; diff --git a/tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h b/tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h index ac170f33d9ba85..c851d73b03290d 100644 --- a/tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h +++ b/tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h @@ -17,6 +17,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TFL { @@ -27,7 +28,7 @@ class ArithmeticCountUtilHelper { static bool GetFirstOutputCount(mlir::Operation* op, int64_t* count) { auto output = op->getResult(0); auto output_type = - output.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(output.getType()); if (!output_type || !output_type.hasStaticShape()) return false; *count = output_type.getNumElements(); @@ -38,7 +39,7 @@ class ArithmeticCountUtilHelper { int64_t total_count = 0; for (auto input : op->getOperands()) { auto input_type = - input.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(input.getType()); if (!input_type || !input_type.hasStaticShape()) { return false; } @@ -54,12 +55,12 @@ class ArithmeticCountUtilHelper { int64_t* count) { auto weight = op->getOperand(1); auto weight_type = - weight.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(weight.getType()); if (weight_type == nullptr || !weight_type.hasStaticShape()) return false; auto output = op->getResult(0); auto output_type = - output.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(output.getType()); if (output_type == nullptr || !output_type.hasStaticShape()) return false; int64_t cols = 1; @@ -73,7 +74,7 @@ class ArithmeticCountUtilHelper { auto bias = op->getOperand(2); if (bias) { auto bias_type = - bias.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(bias.getType()); if (bias_type && bias_type.hasStaticShape()) { *count += output_type.getNumElements(); } diff --git a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc index 20336080cc20d6..1629000ff181df 100644 --- a/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/attribute_utils.cc @@ -15,23 +15,24 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TFL { FloatAttr ExtractSingleElementAsFloat(ElementsAttr attr) { if (attr.getShapedType().getNumElements() != 1 || - !attr.getShapedType().getElementType().isa()) { + !mlir::isa(attr.getShapedType().getElementType())) { return {}; } return attr.getSplatValue(); } FloatAttr GetSingleElementAsFloatOrSelf(Attribute attr) { - if (auto m = attr.dyn_cast_or_null()) { + if (auto m = mlir::dyn_cast_or_null(attr)) { return ExtractSingleElementAsFloat(m); } else { - return attr.dyn_cast_or_null(); + return mlir::dyn_cast_or_null(attr); } } diff --git a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc index 96d75cca30a48d..d00fec3182525c 100644 --- a/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/const_tensor_utils.cc @@ -131,11 +131,11 @@ StatusOr GetQuantizedType(const TensorT& tensor, Builder builder, if (!storage_type) { const mlir::Type raw_elem_type = ConvertElementType(tensor.type, builder); - if (!raw_elem_type.isa()) { + if (!mlir::isa(raw_elem_type)) { return absl::InvalidArgumentError( "Quantized tensors must be stored as integers"); } - storage_type = raw_elem_type.cast(); + storage_type = mlir::cast(raw_elem_type); } // TFlite uses narrow-range [u]int8 for constant buffers of quantized weights. @@ -254,11 +254,11 @@ mlir::ElementsAttr GetSplat(RankedTensorType type, int unique_index, return DenseElementsAttr::get( type, builder.getIntegerAttr(element_ty, unique_index)); - if (element_ty.isa()) + if (mlir::isa(element_ty)) return DenseElementsAttr::get( type, builder.getFloatAttr(element_ty, unique_index)); - if (auto qtype = element_ty.dyn_cast()) { + if (auto qtype = mlir::dyn_cast(element_ty)) { mlir::RankedTensorType new_type = tensorflow::GetTypeFromTFTensorShape( type.getShape(), qtype.getStorageType()); return DenseElementsAttr::get( @@ -272,9 +272,10 @@ StatusOr ConvertIntBuffer( bool truncate) { mlir::Type elem_type = shaped_type.getElementType(); unsigned bit_width; - if (auto itype = elem_type.dyn_cast()) { + if (auto itype = mlir::dyn_cast(elem_type)) { bit_width = itype.getWidth(); - } else if (auto qtype = elem_type.dyn_cast()) { + } else if (auto qtype = + mlir::dyn_cast(elem_type)) { bit_width = qtype.getStorageTypeIntegralWidth(); shaped_type = tensorflow::GetTypeFromTFTensorShape(shaped_type.getShape(), qtype.getStorageType()); diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc index 8bf3b4f0106604..6a4dbf3e505ba6 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc @@ -56,7 +56,8 @@ absl::StatusOr CreateTypedAttr(ShapedType shaped_type, int value) { } else if (element_type.isF32()) { return DenseElementsAttr::get(shaped_type, static_cast(value)); - } else if (auto complex_type = element_type.dyn_cast()) { + } else if (auto complex_type = + mlir::dyn_cast(element_type)) { auto etype = complex_type.getElementType(); if (etype.isF32()) { tensorflow::TensorProto repr; @@ -77,7 +78,7 @@ absl::StatusOr CreateTypedAttr(ShapedType shaped_type, int value) { return tensorflow::Status(absl::StatusCode::kInvalidArgument, "Unsupported type"); } - } else if (auto itype = element_type.dyn_cast()) { + } else if (auto itype = mlir::dyn_cast(element_type)) { if (element_type.isSignedInteger()) { switch (itype.getWidth()) { case 8: diff --git a/tensorflow/compiler/mlir/lite/utils/convert_type.cc b/tensorflow/compiler/mlir/lite/utils/convert_type.cc index f2e659b9aea9ce..7091e5ad155431 100644 --- a/tensorflow/compiler/mlir/lite/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/lite/utils/convert_type.cc @@ -18,6 +18,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "xla/statusor.h" @@ -40,16 +41,16 @@ tflite::TensorType ConvertTypeToTensorType(mlir::Type type) { return tflite::TensorType_FLOAT32; } else if (type.isF64()) { return tflite::TensorType_FLOAT64; - } else if (type.isa()) { + } else if (mlir::isa(type)) { return tflite::TensorType_STRING; - } else if (auto complex_type = type.dyn_cast()) { + } else if (auto complex_type = mlir::dyn_cast(type)) { if (complex_type.getElementType().isF32()) { return tflite::TensorType_COMPLEX64; } else if (complex_type.getElementType().isF64()) { return tflite::TensorType_COMPLEX128; } llvm_unreachable("invalid complex Type in conversion"); - } else if (auto itype = type.dyn_cast()) { + } else if (auto itype = mlir::dyn_cast(type)) { switch (itype.getWidth()) { case 1: return tflite::TensorType_BOOL; @@ -209,7 +210,7 @@ absl::StatusOr TfTypeToTflType(tensorflow::DataType type) { mlir::Type GetShapeStrippedType(mlir::TypeAttr type_attr) { auto type = type_attr.getValue(); - auto shaped_type = type.dyn_cast(); + auto shaped_type = mlir::dyn_cast(type); if (shaped_type) { return shaped_type.getElementType(); } else { diff --git a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h index 77b047f68c6bf2..d1dcf8c304b0a9 100644 --- a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h @@ -123,7 +123,7 @@ class InsertTFLQuantOpsAfterTFFakeQuantOp { int quant_dim = -1; if (PerAxis) { // This is a special case that the quant_dim is the last dimensions. - quant_dim = res.getType().template cast().getRank() - 1; + quant_dim = mlir::cast(res.getType()).getRank() - 1; } // Use the min/max from the operands and the num_bits and narrow_range // attribute to create the quantization parameter for the new quantize op. diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc index 0a563238635d20..bada49a68a9e55 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc @@ -127,7 +127,7 @@ Value Reverse(OpBuilder* builder, Value value_to_reverse, int axis, } ArrayRef GetRankedTensorShape(Value value) { - return value.getType().cast().getShape(); + return mlir::cast(value.getType()).getShape(); } Value SliceRankedTensor(OpBuilder* builder, Value input, @@ -159,7 +159,7 @@ Value SliceRankedTensor(OpBuilder* builder, Value input, location, RankedTensorType::get( size_values, - input.getType().cast().getElementType()), + mlir::cast(input.getType()).getElementType()), input, slice_i2c_begin, slice_i2c_size); } @@ -170,7 +170,8 @@ Value CreateStridedSliceOp(mlir::Location loc, ArrayRef output_shape, int64_t ellipsis_mask, int64_t new_axis_mask, int64_t shrink_axis_mask, OpBuilder* builder) { auto output_type = RankedTensorType::get( - output_shape, input.getType().cast().getElementType()); + output_shape, + mlir::cast(input.getType()).getElementType()); auto begin_tensor = CreateI32DenseConst(builder, begin, loc); auto end_tensor = CreateI32DenseConst(builder, end, loc); auto strides_tensor = CreateI32DenseConst(builder, strides, loc); @@ -387,7 +388,8 @@ void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() { SmallVector output_shape{1, tensorflow::kTFDynamicSize}; auto input_types = fused_func_op_.getFunctionType().getInputs(); auto output_type = tensorflow::GetTypeFromTFTensorShape( - output_shape, input_.getType().cast().getElementType()); + output_shape, + mlir::cast(input_.getType()).getElementType()); fused_func_op_.setType(mlir::FunctionType::get(fused_func_op_.getContext(), input_types, output_type)); } @@ -410,7 +412,8 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() { // Create the fused LSTM op. SmallVector output_shape = {1, n_output_}; auto result_type = mlir::RankedTensorType::get( - output_shape, input_.getType().cast().getElementType()); + output_shape, + mlir::cast(input_.getType()).getElementType()); lstm_ = builder_.create( fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_, input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_, @@ -436,7 +439,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() { SmallVector func_output_shape = {1, tensorflow::kTFDynamicSize}; auto func_result_type = tensorflow::GetTypeFromTFTensorShape( func_output_shape, - input_.getType().cast().getElementType()); + mlir::cast(input_.getType()).getElementType()); auto tensor_cast = builder_.create( fused_func_op_.getLoc(), func_result_type, lstm_.getResult()); @@ -491,7 +494,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() { bias_ = fused_func_op_.getArgument(2); weight_ = fused_func_op_.getArgument(1); - weight_type_ = weight_.getType().cast(); + weight_type_ = mlir::cast(weight_.getType()); if (weight_type_.getRank() != 2) { return fused_func_op_.emitError() << "The weight tensor was not of rank 2"; @@ -505,7 +508,7 @@ LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() { n_cell_ = weight_type_.getDimSize(1) / num_gates_; projection_ = fused_func_op_.getArgument(3); - projection_type_ = projection_.getType().cast(); + projection_type_ = mlir::cast(projection_.getType()); if (projection_type_.getRank() != 2) { n_output_ = n_cell_; } else { @@ -532,7 +535,8 @@ LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() { } layer_norm_scale_ = fused_func_op_.getArgument(4); - layer_norm_scale_type_ = layer_norm_scale_.getType().cast(); + layer_norm_scale_type_ = + mlir::cast(layer_norm_scale_.getType()); if (layer_norm_scale_type_.getRank() != 1) { return fused_func_op_.emitError() << "The layer_norm_scale tensor was not of rank 1"; @@ -607,7 +611,7 @@ TF::ReshapeOp CreateFlattenOP(const Value& input, Location loc, LogicalResult CreateEqualSizeSplitVOp(Value input, int axis, int splits, Location loc, OpBuilder* builder, Operation** result) { - auto input_type = input.getType().cast(); + auto input_type = mlir::cast(input.getType()); SmallVector output_shape; int size_of_splits; if (input_type.getRank() < axis || axis < 0) return failure(); @@ -666,7 +670,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, if (time_major_attr == nullptr) return failure(); bool time_majored = time_major_attr.getValue(); - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) { func_op.emitError() << "Input type is not a ranked tensor type"; return failure(); @@ -692,7 +696,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, // Setup correct weights. RankedTensorType weight_type = - weight_kernel.getType().cast(); + mlir::cast(weight_kernel.getType()); if (weight_type.getRank() != 2) return func_op.emitError() << "The weight should be rank of 2"; @@ -700,7 +704,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, Transpose2D(builder, weight_kernel, weight_type, func_op.getLoc()); RankedTensorType recurrent_kernel_type = - recurrent_kernel.getType().cast(); + mlir::cast(recurrent_kernel.getType()); const int64_t n_output = recurrent_kernel_type.getDimSize(0); Value transpose_recurrent_kernel = Transpose2D( @@ -726,28 +730,28 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, // IndyLSTMs are a LSTM variant with diagonal recurrent weight // matrices. For optimization purposes these are provided as vectors. Value recurrent_to_input_weights = - indy ? CreateFlattenOP(recurrent_weights_array->getResult(0), - func_op.getLoc(), builder) - .getResult() - .cast() + indy ? mlir::cast( + CreateFlattenOP(recurrent_weights_array->getResult(0), + func_op.getLoc(), builder) + .getResult()) : recurrent_weights_array->getResult(0); Value recurrent_to_forget_weights = - indy ? CreateFlattenOP(recurrent_weights_array->getResult(1), - func_op.getLoc(), builder) - .getResult() - .cast() + indy ? mlir::cast( + CreateFlattenOP(recurrent_weights_array->getResult(1), + func_op.getLoc(), builder) + .getResult()) : recurrent_weights_array->getResult(1); Value recurrent_to_cell_weights = - indy ? CreateFlattenOP(recurrent_weights_array->getResult(2), - func_op.getLoc(), builder) - .getResult() - .cast() + indy ? mlir::cast( + CreateFlattenOP(recurrent_weights_array->getResult(2), + func_op.getLoc(), builder) + .getResult()) : recurrent_weights_array->getResult(2); Value recurrent_to_output_weights = - indy ? CreateFlattenOP(recurrent_weights_array->getResult(3), - func_op.getLoc(), builder) - .getResult() - .cast() + indy ? mlir::cast( + CreateFlattenOP(recurrent_weights_array->getResult(3), + func_op.getLoc(), builder) + .getResult()) : recurrent_weights_array->getResult(3); // Splits the bias into 4: @@ -765,7 +769,7 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, } auto result_type = mlir::RankedTensorType::get( output_shape, - final_inputs.getType().cast().getElementType()); + mlir::cast(final_inputs.getType()).getElementType()); Value none = CreateNoneValue(builder, func_op.getLoc()); auto lstm = builder->create( @@ -866,7 +870,8 @@ LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op, // All the rest: states, device. for (int i = 2; i < 5; ++i) { - auto result_type = func_op.getResultTypes()[i].dyn_cast(); + auto result_type = + mlir::dyn_cast(func_op.getResultTypes()[i]); outputs.push_back(CreatTfF32ConstOp(builder, result_type.getShape(), 0.0f, func_op.getLoc())); output_types.push_back(result_type); diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index 342bbb5c7fe382..7fe7ae8404137c 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -134,22 +134,18 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) { auto transpose_op = fused_lstm_func_.getBody().front().begin(); transpose_op++; - EXPECT_EQ( - transpose_op->getOperand(0).getType().cast().getDimSize( - 0), - 3); - EXPECT_EQ( - transpose_op->getOperand(0).getType().cast().getDimSize( - 1), - 12); - EXPECT_EQ( - transpose_op->getResult(0).getType().cast().getDimSize( - 0), - 12); - EXPECT_EQ( - transpose_op->getResult(0).getType().cast().getDimSize( - 1), - 3); + EXPECT_EQ(mlir::cast(transpose_op->getOperand(0).getType()) + .getDimSize(0), + 3); + EXPECT_EQ(mlir::cast(transpose_op->getOperand(0).getType()) + .getDimSize(1), + 12); + EXPECT_EQ(mlir::cast(transpose_op->getResult(0).getType()) + .getDimSize(0), + 12); + EXPECT_EQ(mlir::cast(transpose_op->getResult(0).getType()) + .getDimSize(1), + 3); auto it = fused_lstm_func_.getBody().back().rbegin(); EXPECT_EQ(it->getName().getStringRef(), @@ -161,33 +157,31 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimple) { EXPECT_EQ(it->getNumOperands(), 24); EXPECT_EQ(it->getNumResults(), 1); // cifg = false, so input2input is not None. - EXPECT_FALSE(it->getOperand(1).getType().isa()); + EXPECT_FALSE(mlir::isa(it->getOperand(1).getType())); // input layer norm is None - EXPECT_TRUE(it->getOperand(20).getType().isa()); + EXPECT_TRUE(mlir::isa(it->getOperand(20).getType())); // proj_bias is F32 - EXPECT_TRUE(it->getOperand(17) - .getType() - .cast() + EXPECT_TRUE(mlir::cast(it->getOperand(17).getType()) .getElementType() .isF32()); // output gate bias is 0 since it is out of bounds of the bias tensor, so // we set its value as a const tensor of specified size and value 0. - EXPECT_TRUE(mlir::cast( - it->getOpOperand(15).get().getDefiningOp()) - .getValue() - .cast() - .getValues()[0] - .getValue() - .isExactlyValue(0.0f)); + EXPECT_TRUE( + mlir::cast(mlir::cast( + it->getOpOperand(15).get().getDefiningOp()) + .getValue()) + .getValues()[0] + .getValue() + .isExactlyValue(0.0f)); EXPECT_EQ(fused_lstm_func_.getFunctionType().getNumResults(), 1); auto output_types = fused_lstm_func_.getFunctionType().getResults(); SmallVector output_shape{1, mlir::ShapedType::kDynamic}; - EXPECT_EQ(output_types[0].cast().getShape().size(), + EXPECT_EQ(mlir::cast(output_types[0]).getShape().size(), output_shape.size()); for (int i = 0; i < output_shape.size(); i++) { - EXPECT_EQ(output_types[0].cast().getDimSize(i), + EXPECT_EQ(mlir::cast(output_types[0]).getDimSize(i), output_shape[i]); } } @@ -215,7 +209,7 @@ TEST_F(LstmUtilsTest, ConvertLSTMCellSimpleToFusedLSTMCoupleInputForget) { EXPECT_EQ(it->getNumOperands(), 24); EXPECT_EQ(it->getNumResults(), 1); // cifg = true, so input2input is None. - EXPECT_TRUE(it->getOperand(1).getType().isa()); + EXPECT_TRUE(mlir::isa(it->getOperand(1).getType())); } TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) { @@ -242,23 +236,25 @@ TEST_F(LstmUtilsTest, ConvertLayerNormLSTMCellSimpleToFusedLSTM) { EXPECT_EQ(it->getNumOperands(), 24); EXPECT_EQ(it->getNumResults(), 1); // cifg = false, so input2input is not None. - EXPECT_FALSE(it->getOperand(1).getType().isa()); + EXPECT_FALSE(mlir::isa(it->getOperand(1).getType())); // input layer norm - EXPECT_FALSE(it->getOperand(20).getType().isa()); + EXPECT_FALSE(mlir::isa(it->getOperand(20).getType())); + EXPECT_EQ(mlir::cast(it->getOperand(20).getType()) + .getShape() + .size(), + 1); EXPECT_EQ( - it->getOperand(20).getType().cast().getShape().size(), - 1); - EXPECT_EQ(it->getOperand(20).getType().cast().getDimSize(0), - 3); + mlir::cast(it->getOperand(20).getType()).getDimSize(0), + 3); EXPECT_EQ(fused_ln_lstm_func_.getFunctionType().getNumResults(), 1); auto output_types = fused_ln_lstm_func_.getFunctionType().getResults(); SmallVector output_shape{1, mlir::ShapedType::kDynamic}; - EXPECT_EQ(output_types[0].cast().getShape().size(), + EXPECT_EQ(mlir::cast(output_types[0]).getShape().size(), output_shape.size()); for (int i = 0; i < output_shape.size(); i++) { - EXPECT_EQ(output_types[0].cast().getDimSize(i), + EXPECT_EQ(mlir::cast(output_types[0]).getDimSize(i), output_shape[i]); } } diff --git a/tensorflow/compiler/mlir/lite/utils/nms_utils.cc b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc index 5633068509faf4..cab3df456c0e00 100644 --- a/tensorflow/compiler/mlir/lite/utils/nms_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/nms_utils.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" namespace mlir { @@ -74,7 +75,7 @@ LogicalResult ConvertNMSPaddedFunc::VerifySignature() { // The TFLite fused op does not support batching yet. // TODO(b/158709815): Add support for batches with padded NMS. auto boxes_type = - func_.getFunctionType().getInput(0).dyn_cast(); + mlir::dyn_cast(func_.getFunctionType().getInput(0)); if (boxes_type == nullptr || !boxes_type.hasRank() || boxes_type.getRank() != 2) { return func_.emitWarning() << "TFLite does not support batched input for " @@ -121,7 +122,7 @@ LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions( failed(AddFloatAttr(func, attrs, "w_scale", &fbb))) return failure(); auto use_regular_nms = - attrs.get("use_regular_nms").dyn_cast_or_null(); + mlir::dyn_cast_or_null(attrs.get("use_regular_nms")); if (!use_regular_nms) { return func.emitError() << "use_regular_nms attribute is not set or not a bool"; @@ -137,7 +138,7 @@ LogicalResult ConvertSSDPostProcessFunc::CreateNMSCustomOptions( LogicalResult ConvertSSDPostProcessFunc::AddIntAttr( func::FuncOp func, DictionaryAttr attrs, const std::string& attribute, flexbuffers::Builder* builder) { - auto int_attr = attrs.get(attribute).dyn_cast_or_null(); + auto int_attr = mlir::dyn_cast_or_null(attrs.get(attribute)); if (!int_attr) { return func.emitError() << attribute.c_str() << " attribute is not set or not an integer"; @@ -149,7 +150,7 @@ LogicalResult ConvertSSDPostProcessFunc::AddIntAttr( LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr( func::FuncOp func, DictionaryAttr attrs, const std::string& attribute, flexbuffers::Builder* builder) { - auto float_attr = attrs.get(attribute).dyn_cast_or_null(); + auto float_attr = mlir::dyn_cast_or_null(attrs.get(attribute)); if (!float_attr) { return func.emitError() << attribute.c_str() << " attribute is not set or not a float"; @@ -160,7 +161,7 @@ LogicalResult ConvertSSDPostProcessFunc::AddFloatAttr( LogicalResult ConvertSSDPostProcessFunc::HasIntAttr( func::FuncOp func, DictionaryAttr attrs, const std::string& attribute) { - auto int_attr = attrs.get(attribute).dyn_cast_or_null(); + auto int_attr = mlir::dyn_cast_or_null(attrs.get(attribute)); if (!int_attr) { return func.emitWarning() << attribute.c_str() << " attribute is not set or not an integer"; @@ -170,7 +171,7 @@ LogicalResult ConvertSSDPostProcessFunc::HasIntAttr( LogicalResult ConvertSSDPostProcessFunc::HasFloatAttr( func::FuncOp func, DictionaryAttr attrs, const std::string& attribute) { - auto float_attr = attrs.get(attribute).dyn_cast_or_null(); + auto float_attr = mlir::dyn_cast_or_null(attrs.get(attribute)); if (!float_attr) { return func.emitWarning() << attribute.c_str() << " attribute is not set or not a float"; diff --git a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc index c7944b67406907..98cc1048bced90 100644 --- a/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/perception_ops_utils.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -45,14 +46,15 @@ inline LogicalResult HasIntegerArrayWithSize(func::FuncOp* func, const DictionaryAttr& attrs, const std::string& attr_name, int N) { - ArrayAttr array_attr = attrs.get(attr_name).dyn_cast_or_null(); + ArrayAttr array_attr = + mlir::dyn_cast_or_null(attrs.get(attr_name)); if (array_attr == nullptr || array_attr.size() != N) { return func->emitWarning() << "'" << attr_name << "' attribute for " << kMaxUnpooling << " must be set and has size of " << N; } for (Attribute integer_attr : array_attr.getValue()) { - IntegerAttr value = integer_attr.dyn_cast(); + IntegerAttr value = mlir::dyn_cast(integer_attr); if (!value) { return func->emitWarning() << "'" << attr_name << "' attribute for " << kMaxUnpooling @@ -66,7 +68,8 @@ inline LogicalResult GetIntegerArraySafe( func::FuncOp* func, const DictionaryAttr& attrs, const std::string& attr_name, llvm::SmallVectorImpl* results, int N) { - ArrayAttr array_attr = attrs.get(attr_name).dyn_cast_or_null(); + ArrayAttr array_attr = + mlir::dyn_cast_or_null(attrs.get(attr_name)); if (array_attr == nullptr || array_attr.size() != N) { return func->emitError() << "'" << attr_name << "' attribute for " << kMaxUnpooling @@ -75,7 +78,7 @@ inline LogicalResult GetIntegerArraySafe( results->reserve(N); for (Attribute integer_attr : array_attr.getValue()) { - IntegerAttr value = integer_attr.dyn_cast(); + IntegerAttr value = mlir::dyn_cast(integer_attr); if (!value) { return func->emitError() << "'" << attr_name << "' attribute for " << kMaxUnpooling @@ -132,7 +135,7 @@ LogicalResult ConvertMaxUnpoolingFunc::VerifySignature() { } // Retrieves padding. - auto padding = attrs.get("padding").dyn_cast_or_null(); + auto padding = mlir::dyn_cast_or_null(attrs.get("padding")); if (!padding) { return func_.emitWarning() << "'padding' attribute for " << kMaxUnpooling << " is not set or not a string"; @@ -166,7 +169,7 @@ LogicalResult ConvertMaxUnpoolingFunc::CreateCustomOptions( pool_params.stride_width = strides[1]; // Retrieves padding. - auto padding = attrs.get("padding").dyn_cast_or_null(); + auto padding = mlir::dyn_cast_or_null(attrs.get("padding")); if (!padding) { return func_.emitError() << "'padding' attribute for " << kMaxUnpooling << " is not set or not a string"; @@ -224,22 +227,22 @@ LogicalResult ConvertDenseImageWarpFunc::VerifySignature() { } // Check types and shapes. - auto image_type = - func_.getFunctionType().getInput(0).dyn_cast_or_null(); + auto image_type = mlir::dyn_cast_or_null( + func_.getFunctionType().getInput(0)); if (!image_type || !image_type.getElementType().isF32() || image_type.getRank() != 4) { return func_.emitWarning() << "Image should be a 4D float tensor"; } - auto flow_type = - func_.getFunctionType().getInput(1).dyn_cast_or_null(); + auto flow_type = mlir::dyn_cast_or_null( + func_.getFunctionType().getInput(1)); if (!flow_type || !flow_type.getElementType().isF32() || flow_type.getRank() != 4) { return func_.emitWarning() << "Flow should be a 4D float tensor"; } - auto output_type = - func_.getFunctionType().getResult(0).dyn_cast_or_null(); + auto output_type = mlir::dyn_cast_or_null( + func_.getFunctionType().getResult(0)); if (!output_type || !output_type.getElementType().isF32() || output_type.getRank() != 4) { return func_.emitWarning() << "Output should be a 4D float tensor"; diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc index 7ce9c56086e691..5e9bcc16d27537 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc @@ -62,11 +62,13 @@ inline ConstBytesAttr CustomOption(OpBuilder* builder, } inline TensorType GetInputType(func::FuncOp func, int idx) { - return func.getFunctionType().getInput(idx).dyn_cast_or_null(); + return mlir::dyn_cast_or_null( + func.getFunctionType().getInput(idx)); } inline TensorType GetResultType(func::FuncOp func, int idx) { - return func.getFunctionType().getResult(idx).dyn_cast_or_null(); + return mlir::dyn_cast_or_null( + func.getFunctionType().getResult(idx)); } inline bool RankEquals(const TensorType& type, int rank) { @@ -89,7 +91,7 @@ LogicalResult VerifyWhitespaceTokenizer(func::FuncOp func) { // * 2nd output is the inner offset; // * 3rd output is the outer offset. auto input_type = GetInputType(func, 0); - if (!input_type || !input_type.getElementType().isa() || + if (!input_type || !mlir::isa(input_type.getElementType()) || !input_type.hasRank()) { return func.emitError() << "Input should be a string tensor"; } @@ -107,7 +109,7 @@ LogicalResult VerifyWhitespaceTokenizer(func::FuncOp func) { auto value_type = GetResultType(func, 0); if (!RankEquals(value_type, 1) || - !value_type.getElementType().isa()) { + !mlir::isa(value_type.getElementType())) { return func.emitError() << "1st output should be string tensor"; } if (func.getNumResults() > 1) { @@ -157,12 +159,14 @@ LogicalResult VerifyNgrams(func::FuncOp func) { int row_splits = func.getFunctionType().getInputs().size() - kRowSplits; if (row_splits == 0) { auto input_values = GetInputType(func, kValues); - if (!input_values || !input_values.getElementType().isa()) { + if (!input_values || + !mlir::isa(input_values.getElementType())) { return func.emitError() << "Input " << kValues << " should be a string tensor"; } auto output_values = GetResultType(func, kValues); - if (!output_values || !output_values.getElementType().isa()) { + if (!output_values || + !mlir::isa(output_values.getElementType())) { return func.emitError() << "Output " << kValues << " should be a string tensor"; } @@ -175,13 +179,13 @@ LogicalResult VerifyNgrams(func::FuncOp func) { } else { auto input_values = GetInputType(func, kValues); if (!RankEquals(input_values, 1) || - !input_values.getElementType().isa()) { + !mlir::isa(input_values.getElementType())) { return func.emitError() << "Input " << kValues << " should be a 1D string tensor"; } auto output_values = GetResultType(func, kValues); if (!RankEquals(output_values, 1) || - !output_values.getElementType().isa()) { + !mlir::isa(output_values.getElementType())) { return func.emitError() << "Output " << kValues << " should be a 1D string tensor"; } @@ -211,14 +215,14 @@ LogicalResult CreateNgramsCustomOption(func::FuncOp func, DictionaryAttr attrs, flexbuffers::Builder fbb; size_t start_map = fbb.StartMap(); - auto width = attrs.get("width").dyn_cast_or_null(); + auto width = mlir::dyn_cast_or_null(attrs.get("width")); if (!width) { return func.emitError() << "'width' attribute is not set or not an integer"; } fbb.Int("width", width.getInt()); auto string_separator = - attrs.get("string_separator").dyn_cast_or_null(); + mlir::dyn_cast_or_null(attrs.get("string_separator")); if (!string_separator) { return func.emitError() << "'string_separator' attribute is not set or not a string"; @@ -229,14 +233,14 @@ LogicalResult CreateNgramsCustomOption(func::FuncOp func, DictionaryAttr attrs, string_separator.getValue().size()); fbb.String("string_separator", string_separator_str); - auto axis = attrs.get("axis").dyn_cast_or_null(); + auto axis = mlir::dyn_cast_or_null(attrs.get("axis")); if (!axis) { return func.emitError() << "'axis' attribute is not set or not an integer"; } fbb.Int("axis", axis.getInt()); auto reduction_type = - attrs.get("reduction_type").dyn_cast_or_null(); + mlir::dyn_cast_or_null(attrs.get("reduction_type")); if (!reduction_type) { return func.emitError() << "'reduction_type' attribute is not set or not a string"; @@ -277,23 +281,23 @@ LogicalResult VerifySgnnProjection(func::FuncOp func, FuncAttr attr) { return func.emitError() << "Mismatched number of inputs and outputs."; } auto values_type = GetInputType(func, 0); - if (!values_type || !values_type.getElementType().isa()) { + if (!values_type || !mlir::isa(values_type.getElementType())) { return func.emitError() << "First input should be a string tensor"; } auto row_splits_type = GetInputType(func, 1); if (!row_splits_type || - !row_splits_type.getElementType().isa()) { + !mlir::isa(row_splits_type.getElementType())) { return func.emitError() << "Second input should be an integer tensor"; } auto hash_seed = - attr.getAttrs().get("hash_seed").dyn_cast_or_null(); + mlir::dyn_cast_or_null(attr.getAttrs().get("hash_seed")); if (!hash_seed) { return func.emitError() << "'hash_seed' attribute is not set or not an array"; } auto output_type = GetResultType(func, 0); - if (!output_type || !output_type.getElementType().isa() || + if (!output_type || !mlir::isa(output_type.getElementType()) || !RankEquals(output_type, 2)) { return func.emitError() << "Output should be a 2D float tensor."; } @@ -302,7 +306,8 @@ LogicalResult VerifySgnnProjection(func::FuncOp func, FuncAttr attr) { << "Output 2nd dimension should be the num of hash seeds."; } - auto buckets = attr.getAttrs().get("buckets").dyn_cast_or_null(); + auto buckets = + mlir::dyn_cast_or_null(attr.getAttrs().get("buckets")); if (!buckets) { return func.emitError() << "'buckets' attribute is not set or not int"; } @@ -316,15 +321,16 @@ LogicalResult CreateSgnnProjectionCustomOption( flexbuffers::Builder fbb; size_t start_map = fbb.StartMap(); - auto hash_seed = attrs.get("hash_seed").dyn_cast_or_null(); + auto hash_seed = mlir::dyn_cast_or_null(attrs.get("hash_seed")); auto vector_start = fbb.StartVector("hash_seed"); for (int i = 0; i < hash_seed.size(); i++) { fbb.Add(static_cast( - (hash_seed.getValue().data() + i)->dyn_cast().getInt())); + mlir::dyn_cast(*(hash_seed.getValue().data() + i)) + .getInt())); } fbb.EndVector(vector_start, /*typed=*/true, /*fixed=*/false); - auto buckets = attrs.get("buckets").dyn_cast_or_null(); + auto buckets = mlir::dyn_cast_or_null(attrs.get("buckets")); fbb.Int("buckets", buckets.getInt()); fbb.EndMap(start_map); diff --git a/tensorflow/compiler/mlir/lite/utils/utils.h b/tensorflow/compiler/mlir/lite/utils/utils.h index 9fce1bc44387c3..524cff0556ea76 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.h +++ b/tensorflow/compiler/mlir/lite/utils/utils.h @@ -46,7 +46,7 @@ inline bool OpHasSameStaticShapes(Operation* op) { int operand_num = 0; ArrayRef shape; for (Value value : values) { - auto shaped_type = value.getType().dyn_cast(); + auto shaped_type = mlir::dyn_cast(value.getType()); if (!shaped_type || !shaped_type.hasStaticShape()) { return false; } @@ -64,14 +64,14 @@ inline bool OpHasSameStaticShapes(Operation* op) { // Checks if all elements in the constant attribute value are 1. inline bool IsAllOnesConstant(Attribute value) { - auto values = value.cast().getValues(); + auto values = mlir::cast(value).getValues(); return !std::any_of(values.begin(), values.end(), [](int32_t element_value) { return element_value != 1; }); } // Checks if all elements in the constant attribute value are non-negative. inline bool HasNonNegativeValues(Attribute value) { - auto values = value.cast().getValues(); + auto values = mlir::cast(value).getValues(); return !std::any_of( values.begin(), values.end(), [](const APInt& element_value) { return element_value.isNegative(); }); @@ -79,8 +79,8 @@ inline bool HasNonNegativeValues(Attribute value) { // Utility function to get the offset between two dense attribute values. inline TypedAttr GetOffSet(Attribute begin, Attribute end) { - auto begin_values = begin.cast().getValues(); - auto end_values = end.cast().getValues(); + auto begin_values = mlir::cast(begin).getValues(); + auto end_values = mlir::cast(end).getValues(); SmallVector offsets; if (begin_values.size() == end_values.size()) { @@ -118,7 +118,7 @@ inline bool AreLastTwoDimsTransposed(Value permutation) { // Gets the new type after transposing the last 2 dimensions. inline Type TransposeLastTwoDims(Type type) { - auto shaped_type = type.dyn_cast(); + auto shaped_type = mlir::dyn_cast(type); if (!shaped_type.hasStaticShape() || shaped_type.getRank() < 2) { return nullptr; } @@ -136,7 +136,7 @@ inline Type TransposeLastTwoDims(Type type) { // applying the permutation to the given shape through a transpose. inline ShapedType GetTransposedType(Value input, llvm::ArrayRef permutation_array) { - auto input_type = input.getType().cast(); + auto input_type = mlir::cast(input.getType()); if (permutation_array.size() != input_type.getRank()) { return nullptr; } @@ -153,7 +153,8 @@ inline ShapedType GetTransposedType(Value input, // Precondition: output_val's is ranked tensor. // Returns a truncated shape when `truncate` is set to true. inline DenseElementsAttr GetShape(Value output_val, bool truncate = false) { - auto output_shape = output_val.getType().dyn_cast().getShape(); + auto output_shape = + mlir::dyn_cast(output_val.getType()).getShape(); SmallVector shape; shape.reserve(output_shape.size()); diff --git a/tensorflow/compiler/mlir/lite/utils/validators.cc b/tensorflow/compiler/mlir/lite/utils/validators.cc index f4714e00e5f2a4..902d7b144ba69d 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.cc +++ b/tensorflow/compiler/mlir/lite/utils/validators.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TFL { @@ -36,45 +37,45 @@ bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x, auto elements = attr.getValue(); if (elements.size() != 4 || std::any_of(elements.begin(), elements.end(), - [](Attribute e) { return !e.isa(); })) + [](Attribute e) { return !mlir::isa(e); })) return false; - if (elements.front().cast().getInt() != 1 || - elements.back().cast().getInt() != 1) + if (mlir::cast(elements.front()).getInt() != 1 || + mlir::cast(elements.back()).getInt() != 1) return false; Builder b(op->getContext()); - *x = b.getI32IntegerAttr(elements[1].cast().getInt()); - *y = b.getI32IntegerAttr(elements[2].cast().getInt()); + *x = b.getI32IntegerAttr(mlir::cast(elements[1]).getInt()); + *y = b.getI32IntegerAttr(mlir::cast(elements[2]).getInt()); return true; } // Returns true if the attribute is an integer list of the form [1, X, Y, 1]. bool TFIntListIs1XY1(const Attribute attr) { - const auto &elements = attr.cast().getValue(); + const auto &elements = mlir::cast(attr).getValue(); if (elements.size() != 4 || std::any_of(elements.begin(), elements.end(), - [](Attribute e) { return !e.isa(); })) + [](Attribute e) { return !mlir::isa(e); })) return false; - if (elements.front().cast().getValue() != 1 || - elements.back().cast().getValue() != 1) + if (mlir::cast(elements.front()).getValue() != 1 || + mlir::cast(elements.back()).getValue() != 1) return false; return true; } // Returns true if the attribute is an integer list of the form [1, 1, X, Y]. bool TFIntListIs11XY(const Attribute attr) { - const auto &elements = attr.cast().getValue(); + const auto &elements = mlir::cast(attr).getValue(); if (elements.size() != 4 || std::any_of(elements.begin(), elements.end(), - [](Attribute e) { return !e.isa(); })) + [](Attribute e) { return !mlir::isa(e); })) return false; const Attribute *data = elements.data(); - if (data[0].cast().getValue() != 1 || - data[1].cast().getValue() != 1) + if (mlir::cast(data[0]).getValue() != 1 || + mlir::cast(data[1]).getValue() != 1) return false; return true; } @@ -91,17 +92,17 @@ bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x, auto elements = attr.getValue(); if (elements.size() != 5 || std::any_of(elements.begin(), elements.end(), - [](Attribute e) { return !e.isa(); })) + [](Attribute e) { return !mlir::isa(e); })) return false; - if (elements.front().cast().getInt() != 1 || - elements.back().cast().getInt() != 1) + if (mlir::cast(elements.front()).getInt() != 1 || + mlir::cast(elements.back()).getInt() != 1) return false; Builder b(op->getContext()); - *x = b.getI32IntegerAttr(elements[1].cast().getInt()); - *y = b.getI32IntegerAttr(elements[2].cast().getInt()); - *z = b.getI32IntegerAttr(elements[3].cast().getInt()); + *x = b.getI32IntegerAttr(mlir::cast(elements[1]).getInt()); + *y = b.getI32IntegerAttr(mlir::cast(elements[2]).getInt()); + *z = b.getI32IntegerAttr(mlir::cast(elements[3]).getInt()); return true; } @@ -109,10 +110,10 @@ bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x, // Returns true if every element of the attribute is 1. All elements of `attr` // must be `IntegerAttr`. bool TFIntListIsAllOnes(const Attribute attr) { - const auto &elements = attr.cast().getValue(); + const auto &elements = mlir::cast(attr).getValue(); return !std::any_of(elements.begin(), elements.end(), [](Attribute e) { - return e.cast().getValue() != 1; + return mlir::cast(e).getValue() != 1; }); } @@ -133,7 +134,7 @@ bool IsDimensionsDegenerateExceptLastOne(ArrayRef elements_shape) { } bool IsDimensionsDegenerateExceptLastOne(TypedAttr val) { - if (auto ranked_type = val.getType().dyn_cast()) { + if (auto ranked_type = mlir::dyn_cast(val.getType())) { return IsDimensionsDegenerateExceptLastOne(ranked_type.getShape()); } return false; diff --git a/tensorflow/compiler/mlir/lite/utils/validators.h b/tensorflow/compiler/mlir/lite/utils/validators.h index 08d2e7b068b4be..0e7370c5fa499b 100644 --- a/tensorflow/compiler/mlir/lite/utils/validators.h +++ b/tensorflow/compiler/mlir/lite/utils/validators.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TFL { @@ -70,21 +71,21 @@ bool TFIntListIsAllOnes(Attribute attr); // Returns true iff the given value is a float32 tensor. // is "DT_FLOAT". inline bool TFTypeIsFloat32Tensor(Value value) { - auto tensorType = value.getType().dyn_cast(); + auto tensorType = mlir::dyn_cast(value.getType()); if (!tensorType) return false; return tensorType.getElementType().isF32(); } // Returns true iff the given value is a bf16 tensor. inline bool TFTypeIsBFloat16Tensor(Value value) { - auto tensorType = value.getType().dyn_cast(); + auto tensorType = mlir::dyn_cast(value.getType()); if (!tensorType) return false; return tensorType.getElementType().isBF16(); } // Returns true iff the given value is a f16 tensor. inline bool TFTypeIsHalfTensor(Value value) { - auto tensorType = value.getType().dyn_cast(); + auto tensorType = mlir::dyn_cast(value.getType()); if (!tensorType) return false; return tensorType.getElementType().isF16(); } diff --git a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc index f5912553f10dbe..8f3261f6574ff7 100644 --- a/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc +++ b/tensorflow/compiler/mlir/op_or_arg_name_mapper.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/utils/name_utils.h" static inline absl::string_view StringRefToView(llvm::StringRef ref) { @@ -123,7 +124,7 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) { // If the location is none of the expected types, then simply use name // generated using the op type. Follow TF convention and append the result // index unless 0. - if (auto result = val.dyn_cast()) { + if (auto result = mlir::dyn_cast(val)) { if (result.getResultNumber() > 0) return llvm::formatv("{0}:{1}", result.getOwner()->getName().getStringRef(), @@ -131,7 +132,7 @@ std::string OpOrArgLocNameMapper::GetName(OpOrVal op_or_val) { return std::string(result.getOwner()->getName().getStringRef()); } // Use the ASM syntax for BlockArgument - if (auto arg = val.dyn_cast()) { + if (auto arg = mlir::dyn_cast(val)) { return "arg" + std::to_string(arg.getArgNumber()); } return ""; diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc index a5b8cb487e6359..1367e7e5eaa175 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.cc @@ -42,14 +42,14 @@ namespace mlir::quant { using ::mlir::stablehlo::DotGeneralOp; bool HasStaticShape(Value value) { - auto shaped_type = value.getType().dyn_cast(); + auto shaped_type = mlir::dyn_cast(value.getType()); if (!shaped_type) return false; return shaped_type.hasStaticShape(); } bool HasStaticShapeAtDims(Value value, const ArrayRef dims) { - auto shaped_type = value.getType().dyn_cast(); + auto shaped_type = mlir::dyn_cast(value.getType()); if (!shaped_type || !shaped_type.hasRank()) return false; for (auto dim : dims) { @@ -59,9 +59,9 @@ bool HasStaticShapeAtDims(Value value, const ArrayRef dims) { } Type CloneTypeWithNewElementType(Type old_type, Type element_type) { - if (!old_type.isa()) return {}; + if (!mlir::isa(old_type)) return {}; - return old_type.cast().clone(element_type); + return mlir::cast(old_type).clone(element_type); } SmallVector CloneOpWithReplacedOperands( @@ -133,9 +133,11 @@ absl::StatusOr IsDotGeneralFullyConnected(DotGeneralOp dot_general_op) { const ArrayRef rhs_contracting_dims = dot_dimension_numbers.getRhsContractingDimensions(); const int64_t input_rank = - dot_general_op.getOperand(0).getType().dyn_cast().getRank(); + mlir::dyn_cast(dot_general_op.getOperand(0).getType()) + .getRank(); const int64_t filter_rank = - dot_general_op.getOperand(1).getType().dyn_cast().getRank(); + mlir::dyn_cast(dot_general_op.getOperand(1).getType()) + .getRank(); // The following conditions are such requirements: // - rank(lhs) is 1 or 2 // - rank(rhs) = 2 @@ -164,7 +166,8 @@ std::optional GetDotGeneralQuantizationDim( DotGeneralOp dot_general_op) { if (dot_general_op == nullptr) return std::nullopt; const int64_t filter_rank = - dot_general_op.getOperand(1).getType().dyn_cast().getRank(); + mlir::dyn_cast(dot_general_op.getOperand(1).getType()) + .getRank(); // To quantize rhs per-channel, we currently only consider the case where // `stablehlo.dot_general` is legalizable to `tfl.fully_connected`. diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h index c80b43dd6baaaf..e94f9359d6fad2 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h @@ -69,7 +69,7 @@ bool HasStaticShapeAtDims(Value value, ArrayRef dims); // Whether `value` has known rank of `rank`. Returns false when it is not a // `ShapedType` or its rank is unknown. inline bool HasRankOf(Value value, const int64_t rank) { - auto shaped_type = value.getType().dyn_cast_or_null(); + auto shaped_type = mlir::dyn_cast_or_null(value.getType()); return shaped_type && shaped_type.hasRank() && shaped_type.getRank() == rank; } @@ -219,7 +219,7 @@ Operation* FindOperandOfType(Operation* op) { // Returns the function attribute for the given call op which is lifted for // quantization. inline FlatSymbolRefAttr GetFuncAttr(TF::PartitionedCallOp call_op) { - return call_op.getFAttr().template dyn_cast(); + return mlir::dyn_cast(call_op.getFAttr()); } inline FlatSymbolRefAttr GetFuncAttr(TF::XlaCallModuleOp call_op) { diff --git a/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.cc b/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.cc index 7bd7424e4d1c6a..6ddebac1ff00f9 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.cc +++ b/tensorflow/compiler/mlir/quantization/common/ir/QuantOps.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/ir/QuantOpsDialect.cc.inc" namespace mlir::quant::ir { @@ -49,20 +50,20 @@ OpFoldResult StorageCastOp::fold(FoldAdaptor) { /// The quantization specification should match the expressed type. static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) { - if (auto typeAttr = quantSpec.dyn_cast()) { + if (auto typeAttr = mlir::dyn_cast(quantSpec)) { Type spec = typeAttr.getValue(); - if (spec.isa()) return false; + if (mlir::isa(spec)) return false; // The spec should be either a quantized type which is compatible to the // expressed type, or a primitive type which is as same as the // (element type of) the expressed type. - if (auto quantizedType = spec.dyn_cast()) + if (auto quantizedType = mlir::dyn_cast(spec)) return quantizedType.isCompatibleExpressedType(expressed); - if (auto tensorType = expressed.dyn_cast()) + if (auto tensorType = mlir::dyn_cast(expressed)) return spec == tensorType.getElementType(); - if (auto vectorType = expressed.dyn_cast()) + if (auto vectorType = mlir::dyn_cast(expressed)) return spec == vectorType.getElementType(); } return false; @@ -97,13 +98,13 @@ LogicalResult QuantizeRegionOp::verify() { } LogicalResult StatisticsOp::verify() { - auto tensorArg = getArg().getType().dyn_cast(); + auto tensorArg = mlir::dyn_cast(getArg().getType()); if (!tensorArg) return emitOpError("arg needs to be tensor type."); // Verify layerStats attribute. { auto layerStatsType = getLayerStats().getShapedType(); - if (!layerStatsType.getElementType().isa()) { + if (!mlir::isa(layerStatsType.getElementType())) { return emitOpError("layerStats must have a floating point element type"); } if (layerStatsType.getRank() != 1 || layerStatsType.getDimSize(0) != 2) { @@ -120,7 +121,7 @@ LogicalResult StatisticsOp::verify() { std::multiplies()); auto axisStatsType = getAxisStats()->getShapedType(); - if (!axisStatsType.getElementType().isa()) { + if (!mlir::isa(axisStatsType.getElementType())) { return emitOpError("axisStats must have a floating point element type"); } if (axisStatsType.getRank() != 2 || axisStatsType.getDimSize(1) != 2 || diff --git a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc index 5a200241af00dd..4677054dc6c765 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc +++ b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc @@ -18,18 +18,19 @@ limitations under the License. #include #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project using namespace mlir; using namespace mlir::quantfork; static bool isQuantizablePrimitiveType(Type inputType) { - return inputType.isa(); + return mlir::isa(inputType); } ExpressedToQuantizedConverter ExpressedToQuantizedConverter::forInputType( Type inputType) { - if (inputType.isa()) { - Type elementType = inputType.cast().getElementType(); + if (mlir::isa(inputType)) { + Type elementType = mlir::cast(inputType).getElementType(); if (!isQuantizablePrimitiveType(elementType)) return ExpressedToQuantizedConverter{inputType, nullptr}; return ExpressedToQuantizedConverter{inputType, elementType}; @@ -44,11 +45,11 @@ ExpressedToQuantizedConverter ExpressedToQuantizedConverter::forInputType( Type ExpressedToQuantizedConverter::convert( quant::QuantizedType elementalType) const { assert(expressedType && "convert() on unsupported conversion"); - if (auto tensorType = inputType.dyn_cast()) + if (auto tensorType = mlir::dyn_cast(inputType)) return RankedTensorType::get(tensorType.getShape(), elementalType); - if (auto tensorType = inputType.dyn_cast()) + if (auto tensorType = mlir::dyn_cast(inputType)) return UnrankedTensorType::get(elementalType); - if (auto vectorType = inputType.dyn_cast()) + if (auto vectorType = mlir::dyn_cast(inputType)) return VectorType::get(vectorType.getShape(), elementalType); // If the expressed types match, just use the new elemental type. @@ -59,7 +60,7 @@ Type ExpressedToQuantizedConverter::convert( ElementsAttr UniformQuantizedPerAxisValueConverter::convert( Attribute realValue) { - if (auto attr = realValue.dyn_cast()) { + if (auto attr = mlir::dyn_cast(realValue)) { return convert(attr); } // TODO: handles sparse elements attribute diff --git a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h index b6f65e455d0c09..0bd2814017be20 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h +++ b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace quantfork { @@ -75,7 +76,7 @@ class UniformQuantizedValueConverter { static_cast(uniformType.getStorageTypeMin()), static_cast(uniformType.getStorageTypeMax()), uniformType.getStorageTypeIntegralWidth(), uniformType.isSigned()) { - assert(uniformType.getExpressedType().isa()); + assert(mlir::isa(uniformType.getExpressedType())); assert(uniformType.getStorageType().isSignlessInteger()); } @@ -203,7 +204,7 @@ class UniformQuantizedPerAxisValueConverter { storageBitWidth(uniformType.getStorageTypeIntegralWidth()), isSigned(uniformType.isSigned()), quantizationDim(uniformType.getQuantizedDimension()) { - assert(uniformType.getExpressedType().isa()); + assert(mlir::isa(uniformType.getExpressedType())); assert(uniformType.getStorageType().isSignlessInteger()); assert(scales.size() == zeroPoints.size()); } diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc index 5492219f938228..cc4f3deb7da015 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc @@ -137,7 +137,7 @@ ValueRange CreateTFXlaCallModuleOp(OpBuilder& builder, const Location location, SmallVector shape_attrs; for (const Type result_type : output_types) { shape_attrs.push_back( - tf_type::ShapeAttr::get(ctx, result_type.cast())); + tf_type::ShapeAttr::get(ctx, mlir::cast(result_type))); } auto empty_array_attr = ArrayAttr::get(ctx, {}); auto platforms = ArrayAttr::get(ctx, {StringAttr::get(ctx, kPlatformCpu)}); @@ -267,7 +267,7 @@ LogicalResult SetAttributeMap(MLIRContext& context, const NamedAttribute& attribute = attributes[idx]; // Skip the following steps if the attribute value is `NullAttribute`. if (const auto string_attr = - attribute.getValue().dyn_cast_or_null(); + mlir::dyn_cast_or_null(attribute.getValue()); string_attr != nullptr && string_attr.getValue().equals(kNullAttributeValue)) { continue; diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc index c6f858815e140f..2911521d174300 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call_test.cc @@ -118,10 +118,11 @@ TEST_F(LiftAsFunctionCallTest, FunctionLiftedAsXlaCallModuleOp) { FindOperationOfType(entry_func); EXPECT_TRUE(isa(lifted_op)); - EXPECT_EQ(lifted_op->getAttr("_original_entry_function").cast(), - "composite_dot_general_fn_1"); EXPECT_EQ( - lifted_dot_general_op->getAttr("precision_config").cast(), + mlir::cast(lifted_op->getAttr("_original_entry_function")), + "composite_dot_general_fn_1"); + EXPECT_EQ( + mlir::cast(lifted_dot_general_op->getAttr("precision_config")), builder_.getArrayAttr(SmallVector( 1, mlir::stablehlo::PrecisionAttr::get( ctx_.get(), mlir::stablehlo::Precision::DEFAULT)))); @@ -144,8 +145,9 @@ TEST_F(LiftAsFunctionCallTest, FunctionNoAttrLiftedAsXlaCallModuleOp) { "composite_dot_general_fn", operands, results)[0] .getDefiningOp(); EXPECT_TRUE(isa(lifted_op)); - EXPECT_EQ(lifted_op->getAttr("_original_entry_function").cast(), - "composite_dot_general_fn_1"); + EXPECT_EQ( + mlir::cast(lifted_op->getAttr("_original_entry_function")), + "composite_dot_general_fn_1"); } TEST_F(LiftAsFunctionCallTest, EinsumSupportedForXlaDotV2Succeeds) { diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc index 216a4a2b3d58e9..7645177160fc62 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.cc @@ -91,10 +91,11 @@ bool HasPerAxisQuantizedOperand(Operation* op) { for (int i = 0; i < op->getNumOperands(); ++i) { if (auto dq_op = dyn_cast_or_null( op->getOperand(i).getDefiningOp())) { - auto type = dq_op.getArg().getType().cast().getElementType(); + auto type = + mlir::cast(dq_op.getArg().getType()).getElementType(); if (auto per_axis_qtype = - QuantizedType::getQuantizedElementType(type) - .dyn_cast_or_null()) { + mlir::dyn_cast_or_null( + QuantizedType::getQuantizedElementType(type))) { return true; } } @@ -179,7 +180,7 @@ bool QuantizationDriver::SetConstantResultParams(Operation* op) { /*num_bits=*/8, is_signed_, /*narrow_range=*/is_weight, legacy_float_scale_); } - if (const auto quant_type = final_type.dyn_cast_or_null(); + if (const auto quant_type = mlir::dyn_cast_or_null(final_type); quant_type != nullptr) { return SetResultParams(op, /*result_index=*/0, quant_type); } @@ -225,7 +226,7 @@ QuantizedType QuantizationDriver::GetBiasParams( if (bias_op != nullptr) { Type bias_type = bias_op->getResult(0).getType(); if (bias_type != builder_.getNoneType()) { - const int bias_rank = bias_type.dyn_cast().getRank(); + const int bias_rank = mlir::dyn_cast(bias_type).getRank(); adjusted_quant_dim = bias_rank > 1 ? bias_rank - 1 : 0; } } @@ -489,12 +490,12 @@ QuantizedType QuantizationDriver::GetQuantParamsForSameScaleConstraint( void QuantizationDriver::PreprocessConstantOps() { fn_.walk([&](arith::ConstantOp cst) { // Non-float tensors are neither weights nor require quantization. - const auto type = cst.getType().dyn_cast(); - if (!type || !type.getElementType().isa()) return; + const auto type = mlir::dyn_cast(cst.getType()); + if (!type || !mlir::isa(type.getElementType())) return; // Skip if the value is NaN or INF. // Otherwise the illegal scale/zp will be calculated. - auto float_attr = cst.getValueAttr().dyn_cast(); + auto float_attr = mlir::dyn_cast(cst.getValueAttr()); if (float_attr && (float_attr.getValues().empty() || !float_attr.getValues()[0].isFinite())) { return; @@ -620,7 +621,7 @@ bool QuantizationDriver::ShouldCheckBiasScale( auto affine_op = dyn_cast(op); auto bias_op = op->getOperand(bias_index).getDefiningOp(); if (!affine_op || !bias_op || input_indices.size() != 2) return false; - if (!bias_op.getValue().isa()) return false; + if (!mlir::isa(bias_op.getValue())) return false; filter_index = affine_op.GetAffineOperandIndex(); if (!op->getOperand(filter_index).getDefiningOp()) { return false; @@ -658,12 +659,12 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( QuantState filter_state = GetOperandQuantState(op, filter_index); auto bias_op = op->getOperand(bias_index).getDefiningOp(); const double input_scale = - input_state.params.cast().getScale(); + mlir::cast(input_state.params).getScale(); - auto bias_values = bias_op.getValue().cast(); + auto bias_values = mlir::cast(bias_op.getValue()); // Restrict maximum absolute value of bias within INT_MAX / 2, to make some // room for accumulator. - if (auto bias_quantized_type = params.dyn_cast(); + if (auto bias_quantized_type = mlir::dyn_cast(params); bias_quantized_type != nullptr) { double bias_half_range = 0.0f; for (auto bias : bias_values.getValues()) { @@ -691,7 +692,7 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( } const auto filter_quantized_type = - filter_state.params.cast(); + mlir::cast(filter_state.params); changed |= SetOperandParams( op, filter_index, UniformQuantizedType::getChecked( @@ -703,10 +704,10 @@ bool QuantizationDriver::SetBiasParamsWithAdjustments( filter_quantized_type.getStorageTypeMax()), /*override=*/true); } else if (auto bias_quantized_type = - params.dyn_cast(); + mlir::dyn_cast(params); bias_quantized_type != nullptr) { const auto filter_quantized_type = - filter_state.params.cast(); + mlir::cast(filter_state.params); std::vector new_bias_scales = bias_quantized_type.getScales().vec(); std::vector new_filter_scales = filter_quantized_type.getScales().vec(); @@ -822,21 +823,22 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { // Use the final state to set all the operands' parameters. for (int i = 0; i < op->getNumOperands(); ++i) { - if (auto type = op->getOperand(i).getType().dyn_cast()) { + if (auto type = + mlir::dyn_cast(op->getOperand(i).getType())) { // Without this check, it will accidentally propagate the quantization // information by the shared non-float tensors. - if (type.getElementType().isa()) + if (mlir::isa(type.getElementType())) changed |= SetOperandParams(op, i, params); } } // Use the final state to set all the results' parameters. for (int i = 0; i < op->getNumResults(); ++i) - if (auto type = op->getResult(i).getType().dyn_cast(); + if (auto type = mlir::dyn_cast(op->getResult(i).getType()); type != nullptr) { // Without this check, it will accidentally propagate the quantization // information by the shared non-float-tensors. - if (type.getElementType().isa()) + if (mlir::isa(type.getElementType())) changed |= SetResultParams(op, i, params); } } diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc index cc82c09894b46b..f017054cbe7044 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver_test.cc @@ -159,10 +159,9 @@ TEST_F(ApplyQuantizationParamsPropagationTest, FinalizeInsertsQDQOps) { ASSERT_NE(filter_qcast_op, nullptr); EXPECT_TRUE(isa(filter_qcast_op)); EXPECT_TRUE(isa(filter_dcast_op)); - EXPECT_TRUE(isa(filter_qcast_op->getResult(0) - .getType() - .cast() - .getElementType())); + EXPECT_TRUE(isa( + mlir::cast(filter_qcast_op->getResult(0).getType()) + .getElementType())); } } // namespace diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc index 13822637f6887d..7bfda7c392ea42 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc @@ -125,14 +125,14 @@ QuantizedType ResetMinMaxFromNumBits(const QuantizedType type, const auto& recalculate_zero_point = [&](int64_t zero_point) -> int64_t { return qmax - std::round((storage_type_max - zero_point) / rate); }; - if (auto q_type = type.dyn_cast()) { + if (auto q_type = mlir::dyn_cast(type)) { const double scale = recalculate_scale(q_type.getScale()); const double zero_point = recalculate_zero_point(q_type.getZeroPoint()); return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(), scale, zero_point, qmin, qmax); } else if (auto q_type = - type.dyn_cast()) { + mlir::dyn_cast(type)) { const int size = q_type.getScales().size(); SmallVector scales(size); SmallVector zero_points(size); @@ -155,7 +155,7 @@ quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast( const ArrayRef shape, const quant::UniformQuantizedPerAxisType qtype, const Type target, const int quant_dim) { - const auto shaped = target.dyn_cast(); + const auto shaped = mlir::dyn_cast(target); if (!shaped) return {}; const ArrayRef new_shape = shaped.getShape(); @@ -247,7 +247,7 @@ Type GetQuantizedType(Builder builder, const Type input_type, builder.getUnknownLoc()); } } else if (min.size() == max.size()) { - auto shape = input_type.dyn_cast(); + auto shape = mlir::dyn_cast(input_type); if (!shape || shape.getRank() <= quant_dim || static_cast(min.size()) != shape.getDimSize(quant_dim)) { return {}; @@ -277,11 +277,13 @@ Type GetQuantizedType(Builder builder, const Type input_type, // TODO(fengliuai): promote this utility method to mlir QuantOps. TypeAttr RescaleQuantizedType(const Type input, const Attribute factor) { - const auto factor_values = factor.dyn_cast_or_null(); + const auto factor_values = + mlir::dyn_cast_or_null(factor); if (!factor_values) return {}; const auto ele_type = quant::QuantizedType::getQuantizedElementType(input); if (!ele_type) return {}; - if (auto qtype = ele_type.dyn_cast()) { + if (auto qtype = + mlir::dyn_cast(ele_type)) { const ArrayRef scales = qtype.getScales(); // Broadcasting hasn't been implemented yet. if (static_cast(scales.size()) != factor_values.getNumElements()) @@ -315,8 +317,8 @@ TypeAttr GetQuantizedTypeAttr(const Builder builder, const Type input_type, const bool legacy_float_scale, const bool use_fake_quant_num_bits) { SmallVector min_value, max_value; - const auto mins = min.dyn_cast(); - const auto maxs = max.dyn_cast(); + const auto mins = mlir::dyn_cast(min); + const auto maxs = mlir::dyn_cast(max); if (mins && maxs) { min_value.reserve(mins.getNumElements()); max_value.reserve(maxs.getNumElements()); @@ -327,8 +329,8 @@ TypeAttr GetQuantizedTypeAttr(const Builder builder, const Type input_type, max_value.push_back(FloatAttr::getValueAsDouble(*it)); } } else { - const auto fmin = min.dyn_cast(); - const auto fmax = max.dyn_cast(); + const auto fmin = mlir::dyn_cast(min); + const auto fmax = mlir::dyn_cast(max); if (fmin && fmax) { min_value.push_back(fmin.getValueAsDouble()); max_value.push_back(fmax.getValueAsDouble()); @@ -348,14 +350,15 @@ TypeAttr CastQuantizedTypeAttrFromExpressedType(const Builder builder, const TypeAttr source, const Type target, const int axis) { - const auto source_type = source.getValue().dyn_cast_or_null(); + const auto source_type = + mlir::dyn_cast_or_null(source.getValue()); if (!source_type) return {}; const auto src_ele_type = source_type.getElementType(); - auto qtype = src_ele_type.dyn_cast(); + auto qtype = mlir::dyn_cast(src_ele_type); // Reset the quantization dimensions if it is per-axis. if (const auto per_axis = - qtype.dyn_cast_or_null()) { + mlir::dyn_cast_or_null(qtype)) { // For the pass-through ops, we don't know which the dimension will be the // new quantization dimension. Only if the new quantization dimension can // be inferred, it is safe to reset the per-axis quantized type. @@ -429,7 +432,7 @@ Type GetUniformQuantizedTypeForWeight( SmallVector mins(1, std::numeric_limits::max()); SmallVector maxs(1, std::numeric_limits::min()); - const auto fp = attr.dyn_cast(); + const auto fp = mlir::dyn_cast(attr); if (!fp) return {}; // Computes the effective min/max values of the attribute values. @@ -440,7 +443,7 @@ Type GetUniformQuantizedTypeForWeight( GetQuantizedType(builder, attr.getType(), mins[0], maxs[0], /*quant_dim=*/-1, num_bits, narrow_range, is_signed, legacy_float_scale, use_fake_quant_num_bits); - if (const auto ele_type = type.dyn_cast_or_null()) + if (const auto ele_type = mlir::dyn_cast_or_null(type)) return ele_type.getElementType(); return {}; @@ -451,7 +454,7 @@ Type GetUniformQuantizedPerAxisTypeForWeight( const unsigned num_bits, const bool is_signed, const bool narrow_range, const bool legacy_float_scale, const bool use_fake_quant_num_bits) { const Builder builder(attr.getContext()); - const auto shape = attr.getType().cast().getShape(); + const auto shape = mlir::cast(attr.getType()).getShape(); if (static_cast(shape.size()) <= quant_dim) return {}; // `symmetric` can only be used when it is `signed` and `narrow_range`. if (symmetric && (!is_signed || !narrow_range)) return {}; @@ -462,7 +465,7 @@ Type GetUniformQuantizedPerAxisTypeForWeight( std::multiplies()); SmallVector mins(dim_size, std::numeric_limits::max()); SmallVector maxs(dim_size, std::numeric_limits::min()); - const auto fp = attr.dyn_cast(); + const auto fp = mlir::dyn_cast(attr); if (!fp) return {}; // Computes the effective min/max values of the attribute values. @@ -471,7 +474,7 @@ Type GetUniformQuantizedPerAxisTypeForWeight( const auto type = GetQuantizedType( builder, attr.getType(), mins, maxs, quant_dim, num_bits, narrow_range, is_signed, legacy_float_scale, use_fake_quant_num_bits); - if (auto ele_type = type.dyn_cast_or_null()) + if (auto ele_type = mlir::dyn_cast_or_null(type)) return ele_type.getElementType(); return {}; @@ -497,13 +500,13 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( expressed_type = op_type.getExpressedType(); if (const auto type = - op_type.dyn_cast()) { + mlir::dyn_cast(op_type)) { if (axis_size != 1 && axis_size != type.getScales().size()) return {}; if (quant_dim != -1 && quant_dim != type.getQuantizedDimension()) return {}; axis_size = type.getScales().size(); quant_dim = type.getQuantizedDimension(); - } else if (!op_type.isa()) { + } else if (!mlir::isa(op_type)) { return {}; } } @@ -513,12 +516,12 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( llvm::SmallVector scales(axis_size, 1.0); for (const auto op_type : op_types) { if (const auto type = - op_type.dyn_cast()) { + mlir::dyn_cast(op_type)) { for (const auto& index_scale : llvm::enumerate(type.getScales())) { scales[index_scale.index()] *= index_scale.value(); } } else if (const auto type = - op_type.dyn_cast()) { + mlir::dyn_cast(op_type)) { for (int index = 0; index < axis_size; ++index) { scales[index] *= type.getScale(); } @@ -557,11 +560,11 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( ElementsAttr QuantizeLegacy(const Attribute real_value, const Type tensor_type) { - if (!real_value.isa() || + if (!mlir::isa(real_value) || !quant::QuantizedType::getQuantizedElementType(tensor_type)) { return {}; } - const auto real_values_attr = real_value.cast(); + const auto real_values_attr = mlir::cast(real_value); auto q_type = quant::QuantizedType::getQuantizedElementType(tensor_type); std::vector real_values; llvm::SmallVector quantized_attr; @@ -571,16 +574,15 @@ ElementsAttr QuantizeLegacy(const Attribute real_value, std::back_inserter(real_values), [&](APFloat value) -> float { return value.convertToFloat(); }); - const ShapedType new_dense_type = - q_type.castExpressedToStorageType(real_values_attr.getType()) - .dyn_cast_or_null(); + const ShapedType new_dense_type = mlir::dyn_cast_or_null( + q_type.castExpressedToStorageType(real_values_attr.getType())); const int width = - q_type.getStorageType().dyn_cast().getWidth(); + mlir::dyn_cast(q_type.getStorageType()).getWidth(); if (width == 8 && q_type.getStorageTypeMax() == 127 && q_type.getStorageTypeMin() == -127) { std::vector quantized_values(real_values_attr.getNumElements()); - if (auto uniform_type = q_type.dyn_cast()) { + if (auto uniform_type = mlir::dyn_cast(q_type)) { float min, max, scale; tflite::tensor_utils::SymmetricQuantizeFloats( real_values.data(), real_values.size(), quantized_values.data(), &min, @@ -590,7 +592,7 @@ ElementsAttr QuantizeLegacy(const Attribute real_value, return Quantize(real_value, tensor_type); } } else if (auto uniform_type = - q_type.dyn_cast()) { + mlir::dyn_cast(q_type)) { std::vector scales_inv; std::vector dimension; dimension.insert(dimension.end(), new_dense_type.getShape().begin(), @@ -619,7 +621,8 @@ ElementsAttr QuantizeLegacy(const Attribute real_value, // not correctly quantized by legacy quantizer so call the new Quantize. return Quantize(real_value, tensor_type); } else if (width == 16) { - if (const auto uniform_type = q_type.dyn_cast()) { + if (const auto uniform_type = + mlir::dyn_cast(q_type)) { const auto quantized_values = tflite::optimize::utils::SymmetricQuantizeFloatsToInt16( real_values.data(), real_values.size(), uniform_type.getScale()); @@ -632,10 +635,11 @@ ElementsAttr QuantizeLegacy(const Attribute real_value, } } else if (width == 32) { std::vector scales; - if (const auto uniform_type = q_type.dyn_cast()) { + if (const auto uniform_type = + mlir::dyn_cast(q_type)) { scales.push_back(uniform_type.getScale()); } else if (const auto uniform_type = - q_type.dyn_cast()) { + mlir::dyn_cast(q_type)) { scales.insert(scales.end(), uniform_type.getScales().begin(), uniform_type.getScales().end()); } else { @@ -658,8 +662,8 @@ ElementsAttr Quantize(const Attribute real_value, const Type tensor_type) { if (const auto q_type = quant::QuantizedType::getQuantizedElementType(tensor_type)) { Type converted_type; - return quantfork::quantizeAttr(real_value, q_type, converted_type) - .dyn_cast_or_null(); + return mlir::dyn_cast_or_null( + quantfork::quantizeAttr(real_value, q_type, converted_type)); } return {}; } @@ -680,10 +684,10 @@ quant::QuantizedType DownCastScale(QuantizedType type, if (!type) return type; SmallVector scales(mins.size()); SmallVector zero_points(mins.size()); - if (auto q_type = type.dyn_cast()) { + if (auto q_type = mlir::dyn_cast(type)) { zero_points.push_back(q_type.getZeroPoint()); } else if (auto q_type = - type.dyn_cast()) { + mlir::dyn_cast(type)) { zero_points = {q_type.getZeroPoints().begin(), q_type.getZeroPoints().end()}; } @@ -703,13 +707,13 @@ quant::QuantizedType DownCastScale(QuantizedType type, } } } - if (auto q_type = type.dyn_cast()) { + if (auto q_type = mlir::dyn_cast(type)) { return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(), scales[0], zero_points[0], q_type.getStorageTypeMin(), q_type.getStorageTypeMax()); } else if (auto q_type = - type.dyn_cast()) { + mlir::dyn_cast(type)) { return quant::UniformQuantizedPerAxisType::get( q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(), scales, zero_points, q_type.getQuantizedDimension(), @@ -724,8 +728,8 @@ quant::QuantizedType DownCastScale(QuantizedType type, static bool PreferResultScale(Operation* op) { int float_operands = 0; for (auto operand : op->getOperands()) { - if (auto operand_type = operand.getType().dyn_cast()) { - if (operand_type.getElementType().isa()) { + if (auto operand_type = mlir::dyn_cast(operand.getType())) { + if (mlir::isa(operand_type.getElementType())) { if (++float_operands > 1) return true; } } @@ -903,9 +907,9 @@ LogicalResult VerifySameScales(Operation* op) { // method. if (!same_scale_op.RequiredSameQuantizedAxes()) { const auto expected_per_axis_qtype = - expected_params.dyn_cast(); + mlir::dyn_cast(expected_params); const auto compared_per_axis_qtype = - compared_params.dyn_cast(); + mlir::dyn_cast(compared_params); if (expected_per_axis_qtype && compared_per_axis_qtype && llvm::equal(expected_per_axis_qtype.getScales(), compared_per_axis_qtype.getScales()) && @@ -947,8 +951,8 @@ quant::UniformQuantizedType GetFixedOutputRange( const bool is_signed, const int bit_width, const Type tensor_type, const double scale, int64_t zero_point, int64_t storage_min, int64_t storage_max) { - const auto result_type = tensor_type.cast(); - if (!result_type.getElementType().isa()) return {}; + const auto result_type = mlir::cast(tensor_type); + if (!mlir::isa(result_type.getElementType())) return {}; Builder builder(result_type.getContext()); // Only support 8-bits and 16-bits @@ -990,14 +994,14 @@ Type ConvertSignedQuantizedToUnsigned(const Type signed_tensor_type, const auto flags = !quant::QuantizationFlags::Signed; QType new_qtype; - if (auto uqtype = qtype.dyn_cast()) { + if (auto uqtype = mlir::dyn_cast(qtype)) { new_qtype = quant::UniformQuantizedType::getChecked( loc, flags, qtype.getStorageType(), qtype.getExpressedType(), uqtype.getScale(), uqtype.getZeroPoint() - offset, uqtype.getStorageTypeMin() - offset, uqtype.getStorageTypeMax() - offset); } else if (auto aqtype = - qtype.dyn_cast()) { + mlir::dyn_cast(qtype)) { const auto zero_points = aqtype.getZeroPoints(); llvm::SmallVector new_zero_points(zero_points.begin(), zero_points.end()); diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc index a64ba201250727..7f66d76798acfa 100644 --- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.cc @@ -94,12 +94,12 @@ bool IsStorageTypeI32(const QuantizedType quantized_type) { bool IsExpressedTypeF32(const QuantizedType quantized_type) { const Type expressed_type = quantized_type.getExpressedType(); - return expressed_type.isa(); + return mlir::isa(expressed_type); } bool IsI8F32UniformQuantizedType(const Type type) { const UniformQuantizedType quantized_type = - type.dyn_cast_or_null(); + mlir::dyn_cast_or_null(type); if (!quantized_type) { LLVM_DEBUG(llvm::dbgs() << "Expected a uniform quantized type. Got: " << type << ".\n"); @@ -123,7 +123,7 @@ bool IsI8F32UniformQuantizedType(const Type type) { bool IsI8F32UniformQuantizedPerAxisType(const Type type) { const UniformQuantizedPerAxisType quantized_per_axis_type = - type.dyn_cast_or_null(); + mlir::dyn_cast_or_null(type); if (!quantized_per_axis_type) { LLVM_DEBUG(llvm::dbgs() << "Expected a uniform quantized type. Got: " << type << ".\n"); @@ -147,7 +147,7 @@ bool IsI8F32UniformQuantizedPerAxisType(const Type type) { bool IsI32F32UniformQuantizedType(const Type type) { const UniformQuantizedType quantized_type = - type.dyn_cast_or_null(); + mlir::dyn_cast_or_null(type); if (!quantized_type) { LLVM_DEBUG(llvm::dbgs() << "Expected a uniform quantized type. Got: " << type << ".\n"); @@ -171,7 +171,7 @@ bool IsI32F32UniformQuantizedType(const Type type) { bool IsI32F32UniformQuantizedPerAxisType(const Type type) { const UniformQuantizedPerAxisType quantized_per_axis_type = - type.dyn_cast_or_null(); + mlir::dyn_cast_or_null(type); if (!quantized_per_axis_type) { LLVM_DEBUG(llvm::dbgs() << "Expected a uniform quantized type. Got: " << type << ".\n"); @@ -208,11 +208,11 @@ bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type) { } bool IsQuantizedTensorType(Type type) { - if (!type.isa()) { + if (!mlir::isa(type)) { return false; } - Type element_type = type.cast().getElementType(); - return element_type.isa(); + Type element_type = mlir::cast(type).getElementType(); + return mlir::isa(element_type); } bool IsOpFullyQuantized(Operation* op) { diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h index ab850c878ff0dd..e30db98a9616de 100644 --- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h @@ -82,7 +82,7 @@ bool IsExpressedTypeF32(QuantizedType quantized_type); // Given a value, extract the `ElementType`. // `value` should be a non-null `TensorType`. inline Type GetElementType(const Value value) { - return value.getType().cast().getElementType(); + return mlir::cast(value.getType()).getElementType(); } // Returns true iff `type` is a uniform quantized type whose storage type is diff --git a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc index e9443a667fcef3..d4055b1732b1d8 100644 --- a/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc +++ b/tensorflow/compiler/mlir/quantization/common/uniform_quantized_types_test.cc @@ -348,7 +348,8 @@ TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedTypeSucceeds) { /*flags=*/QuantizationFlags::Signed, builder_.getI8Type(), builder_.getF32Type(), /*scale=*/1.0, /*zeroPoint=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); - EXPECT_THAT(qi8_type.dyn_cast_or_null(), NotNull()); + EXPECT_THAT(mlir::dyn_cast_or_null(qi8_type), + NotNull()); } TEST_F(IsI8F32UniformQuantizedTypeTest, StorageTypeI8Succeeds) { @@ -398,8 +399,9 @@ TEST_F(IsI8F32UniformQuantizedTypeTest, UniformQuantizedPerAxisTypeSucceeds) { /*scales=*/{1.0}, /*zeroPoints=*/{0}, /*quantizedDimension=*/0, /*storageTypeMin=*/-128, /*storageTypeMax=*/127); - EXPECT_THAT(qi8_per_axis_type.dyn_cast_or_null(), - NotNull()); + EXPECT_THAT( + mlir::dyn_cast_or_null(qi8_per_axis_type), + NotNull()); } TEST_F(IsI8F32UniformQuantizedPerAxisTypeTest, StorageTypeI8Succeeds) { @@ -452,7 +454,8 @@ TEST_F(IsI32F32UniformQuantizedTypeTest, UniformQuantizedTypeSucceeds) { /*zeroPoint=*/0, /*storageTypeMin=*/-2147483647, /*storageTypeMax=*/2147483646); EXPECT_TRUE(IsI32F32UniformQuantizedType(qi32_type)); - EXPECT_THAT(qi32_type.dyn_cast_or_null(), NotNull()); + EXPECT_THAT(mlir::dyn_cast_or_null(qi32_type), + NotNull()); } TEST_F(IsI32F32UniformQuantizedTypeTest, StorageTypeI32Succeeds) { @@ -509,7 +512,7 @@ TEST_F(IsI32F32UniformQuantizedPerAxisTypeTest, /*storageTypeMax=*/127); EXPECT_FALSE(IsI32F32UniformQuantizedPerAxisType(qi8_type)); EXPECT_FALSE(IsStorageTypeI32(qi8_type)); - EXPECT_THAT(qi8_type.dyn_cast_or_null(), + EXPECT_THAT(mlir::dyn_cast_or_null(qi8_type), IsNull()); } @@ -523,7 +526,7 @@ TEST_F(IsI32F32UniformQuantizedTypeTest, UniformQuantizedPerAxisTypeSucceeds) { /*storageTypeMin=*/-2147483647, /*storageTypeMax=*/2147483646); EXPECT_THAT( - qi32_per_axis_type.dyn_cast_or_null(), + mlir::dyn_cast_or_null(qi32_per_axis_type), NotNull()); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc index a619197c27e135..f8181deca51a0e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/report.cc @@ -85,9 +85,8 @@ std::optional GetQuantizationResult(func::CallOp call_op) { std::optional GetQuantizationResult( TF::XlaCallModuleOp xla_call_module_op) { const StringAttr callee_name_attr = - xla_call_module_op - ->getDiscardableAttr(kOriginalStablehloEntryFunctionAttrName) - .dyn_cast_or_null(); + mlir::dyn_cast_or_null(xla_call_module_op->getDiscardableAttr( + kOriginalStablehloEntryFunctionAttrName)); // `TF::XlaCallModuleOp` without the `_original_entry_function` means it is // not a quantizable unit. diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc index cd861d934e75f8..965da4ff998635 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_ops_to_mhlo.cc @@ -88,7 +88,7 @@ FailureOr GetUniformQuantizedType( } auto original_element_type = getElementTypeOrSelf(original_type); - if (!original_element_type.isa()) { + if (!mlir::isa(original_element_type)) { return rewriter.notifyMatchFailure( op, "Quantized type must be qint8 or qint32."); } @@ -112,7 +112,7 @@ FailureOr GetUniformQuantizedType( quantized_dimension, storage_type_min, storage_type_max); } - return original_type.cast().clone(elem_ty); + return mlir::cast(original_type).clone(elem_ty); } // If operand is TF const op, create MHLO constant op from the contents. @@ -178,8 +178,8 @@ FailureOr ConvertPaddingAttr( const xla::ConvolutionDimensionNumbers &dnums, PatternRewriter &rewriter) { StringAttr conv_padding = op.getPaddingAttr(); SmallVector padding_nums; - ShapedType lhs_shape = op.getLhs().getType().template cast(); - ShapedType rhs_shape = op.getRhs().getType().template cast(); + ShapedType lhs_shape = mlir::cast(op.getLhs().getType()); + ShapedType rhs_shape = mlir::cast(op.getRhs().getType()); // Handle only static shape cases. // TODO(b/260284866): Handle dynamic shape cases. @@ -203,15 +203,15 @@ FailureOr ConvertPaddingAttr( padding_nums.resize(padding_nums_size); for (int i = 0; i < dnums.input_spatial_dimensions_size(); ++i) { const int64_t stride = - op.getWindowStridesAttr()[i].template cast().getInt(); + mlir::cast(op.getWindowStridesAttr()[i]).getInt(); const int64_t lhs_size_dilated = ::tensorflow::UniformQuantizedConvolutionParams::DilatedSize( lhs_shape.getDimSize(dnums.input_spatial_dimensions(i)), - op.getLhsDilationAttr()[i].template cast().getInt()); + mlir::cast(op.getLhsDilationAttr()[i]).getInt()); const int64_t rhs_size_dilated = ::tensorflow::UniformQuantizedConvolutionParams::DilatedSize( rhs_shape.getDimSize(dnums.kernel_spatial_dimensions(i)), - op.getRhsDilationAttr()[i].template cast().getInt()); + mlir::cast(op.getRhsDilationAttr()[i]).getInt()); const int64_t output_size = (lhs_size_dilated + stride - 1) / stride; const int64_t total_padding = std::max( @@ -262,7 +262,7 @@ FailureOr> ConvertToMhloConvolutionOpAttrs( attr.getName() == op.getLhsDilationAttrName() || attr.getName() == op.getRhsDilationAttrName()) { attr.setValue(ConvertToDenseElementsAttr( - attr.getValue().template cast(), rewriter)); + mlir::cast(attr.getValue()), rewriter)); converted_attrs.push_back(attr); } } @@ -362,9 +362,9 @@ class ConvertUniformQuantizeOp op->getLoc(), *output_type, op.getInput()); rewriter.replaceOpWithNewOp( op, - output_type->clone(output_type->getElementType() - .dyn_cast() - .getStorageType()), + output_type->clone( + mlir::dyn_cast(output_type->getElementType()) + .getStorageType()), result); return success(); @@ -438,9 +438,9 @@ class ConvertUniformRequantizeOp op->getLoc(), *output_type, input_quant); rewriter.replaceOpWithNewOp( op, - output_type->clone(output_type->getElementType() - .dyn_cast() - .getStorageType()), + output_type->clone( + mlir::dyn_cast(output_type->getElementType()) + .getStorageType()), result); return success(); } @@ -502,9 +502,9 @@ class ConvertUniformQuantizedDotOp /*precision_config=*/nullptr); rewriter.replaceOpWithNewOp( op, - output_type->clone(output_type->getElementType() - .dyn_cast() - .getStorageType()), + output_type->clone( + mlir::dyn_cast(output_type->getElementType()) + .getStorageType()), result); return success(); } @@ -564,9 +564,9 @@ class ConvertUniformQuantizedConvolutionOp op->getLoc(), *output_type, operands, *converted_attrs_or); rewriter.replaceOpWithNewOp( op, - output_type->clone(output_type->getElementType() - .dyn_cast() - .getStorageType()), + output_type->clone( + mlir::dyn_cast(output_type->getElementType()) + .getStorageType()), result); return success(); } @@ -582,7 +582,7 @@ class ConvertUniformQuantizedAddOp ConversionPatternRewriter &rewriter) const override { Value lhs = adaptor.getLhs(); - auto lhs_type = lhs.getType().cast(); + auto lhs_type = mlir::cast(lhs.getType()); if (!lhs_type.hasRank()) { return rewriter.notifyMatchFailure( op, "Legalization supports cases where only lhs rank known."); @@ -632,9 +632,9 @@ class ConvertUniformQuantizedAddOp op->getLoc(), *output_type, lhs, *rhs_or, broadcast_dims); rewriter.replaceOpWithNewOp( op, - output_type->clone(output_type->getElementType() - .dyn_cast() - .getStorageType()), + output_type->clone( + mlir::dyn_cast(output_type->getElementType()) + .getStorageType()), result); return success(); } @@ -692,9 +692,9 @@ class ConvertUniformQuantizedClipByValueOp op->getLoc(), *output_type, res_min_clipped, *max_or, broadcast_dims); rewriter.replaceOpWithNewOp( op, - output_type->clone(output_type->getElementType() - .dyn_cast() - .getStorageType()), + output_type->clone( + mlir::dyn_cast(output_type->getElementType()) + .getStorageType()), res_max_clipped); return success(); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types.cc index 65192fc1117673..f07097a109a0af 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_types.cc @@ -71,7 +71,7 @@ bool IsIllegalType(Type type) { // If input is not TF qint types, returns the original type. Type ToLegalType(Type type) { if (IsTFQintType(type)) return GetIntTypeFromTFQint(type); - if (auto shaped = type.dyn_cast()) { + if (auto shaped = mlir::dyn_cast(type)) { Type elem = shaped.getElementType(); if (IsTFQintType(elem)) return shaped.clone(ToLegalType(elem)); } @@ -289,7 +289,7 @@ class TFConstOpQuantToIntPattern : public OpConversionPattern { } auto dense_attr_or = GetDenseAttrFromTensorProtoAttr( tensor_proto_attr.getValue(), - ToLegalType(op.getOutput().getType()).dyn_cast()); + mlir::dyn_cast(ToLegalType(op.getOutput().getType()))); if (failed(dense_attr_or)) { op->emitError("failed to get DenseElementAttr."); return failure(); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc index 2825195addea12..7484ed89aa51b1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/verify_quant_legalization.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -53,7 +54,7 @@ class VerifyQuantLegalization bool IsQuantType(Type type) { auto element_type = getElementTypeOrSelf(type); - return element_type.isa() || + return mlir::isa(element_type) || IsTFQintType(element_type); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc index 0204a19452bb0d..4a85786dc94937 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_func_to_bfloat16.cc @@ -143,8 +143,8 @@ class BFloat16TypePattern : public ConversionPattern { state.attributes.set( const_op.getValueAttrName(), DenseFPElementsAttr::get( - const_op.getValue().getType().dyn_cast().clone( - rewriter.getBF16Type()), + mlir::dyn_cast(const_op.getValue().getType()) + .clone(rewriter.getBF16Type()), bfloat16_values)); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc index 2716672ba8bceb..686204030c1fdc 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc @@ -155,7 +155,7 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp PatternRewriter& rewriter) const override { auto transpose_op = cast(op.getOperand(0).getDefiningOp()); - const auto result_type = op.getResult(0).getType().cast(); + const auto result_type = mlir::cast(op.getResult(0).getType()); const SmallVector new_result_shape = Permute(result_type.getShape(), kNchwToNhwcPermutation); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc index 051745c0d6792b..06e38c3935c417 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc @@ -127,12 +127,13 @@ class FoldTransposedConstantOp if (!const_op) return failure(); // Only support float tensors. - auto tensor_type = const_op.getType().dyn_cast_or_null(); + auto tensor_type = mlir::dyn_cast_or_null(const_op.getType()); if (!tensor_type || !tensor_type.getElementType().isF32()) { return failure(); } - return success(const_op.getValue().isa_and_nonnull()); + return success( + mlir::isa_and_nonnull(const_op.getValue())); } void rewrite(mlir::stablehlo::TransposeOp op, @@ -140,7 +141,8 @@ class FoldTransposedConstantOp auto const_op = cast(op.getOperand().getDefiningOp()); - const auto value_attr = const_op.getValue().cast(); + const auto value_attr = + mlir::cast(const_op.getValue()); const ArrayRef original_shape = value_attr.getShapedType().getShape(); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc index 9c6ad5fcc6e5ae..85e6f0f655de37 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc @@ -90,7 +90,7 @@ class InsertWeightParamPattern if (op->getNumResults() != 1) { return failure(); } - auto type = op->getResult(0).getType().cast(); + auto type = mlir::cast(op->getResult(0).getType()); if (!type || !type.getElementType().isF32()) { return failure(); } @@ -124,11 +124,10 @@ class InsertWeightParamPattern Type weight_type; if (IsPerTensor(weight_only_ptq)) { - weight_type = + weight_type = dyn_cast( quant::GetUniformQuantizedTypeForWeight( attr, /*symmetric=*/false, /*num_bits=*/8, /*is_signed=*/true, - /*narrow_range=*/false, /*legacy_float_scale=*/false) - .template dyn_cast(); + /*narrow_range=*/false, /*legacy_float_scale=*/false)); } else { int quantization_dimension = GetQuantizationDimension( weight_only_ptq, cast(quantizable_op)); @@ -138,7 +137,7 @@ class InsertWeightParamPattern /*narrow_range=*/false, /*legacy_float_scale=*/false); } - auto quant_type = weight_type.template dyn_cast(); + auto quant_type = dyn_cast(weight_type); if (!quant_type) { op->emitError( "Failed to get weight quantization parameters for weight-only " diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc index 6577666ab90f10..8ec0ff211e75c2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions.cc @@ -66,7 +66,7 @@ Attribute DefaultOrNullAttr(OpBuilder& builder, const Attribute& attr) { // Checks whether the value of a constant equals the given float, regardless // of the tensor dimension. bool FloatValueEquals(const Attribute& attr, const double value) { - const auto fp_attr = attr.dyn_cast_or_null(); + const auto fp_attr = mlir::dyn_cast_or_null(attr); if (!fp_attr) return false; if (fp_attr.isSplat()) { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc index acfe3cfd6fc6b2..9a0d8fb2a25b2b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc @@ -69,8 +69,8 @@ class MergeFusionWithUniformDequantizePattern auto func_name = call_op.getCallee(); if (!func_name.starts_with("quantized_")) return failure(); if (call_op->getNumResults() != 1) return failure(); - if (!getElementTypeOrSelf(call_op->getResult(0).getType()) - .isa()) + if (!mlir::isa( + getElementTypeOrSelf(call_op->getResult(0).getType()))) return failure(); // Fetch the callee function. @@ -89,8 +89,8 @@ class MergeFusionWithUniformDequantizePattern // Create a new func.call op with f32 output. auto new_call_op = call_op.clone(); new_call_op->getResult(0).setType( - call_op.getResult(0).getType().cast().clone( - rewriter.getF32Type())); + mlir::cast(call_op.getResult(0).getType()) + .clone(rewriter.getF32Type())); rewriter.setInsertionPoint(call_op); rewriter.insert(new_call_op); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc index 521f701598fb0a..ed2da6ed103273 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc @@ -73,7 +73,7 @@ class RewriteNchwConvolutionToNhwc // Transpose the input tensor: [b, f, 0, 1] => [b, 0, 1, f] Value input = op->getOperand(0); const TensorType new_input_tensor_type = GetTransposedTensorType( - input.getType().cast(), kNchwToNhwcPermutation); + mlir::cast(input.getType()), kNchwToNhwcPermutation); auto input_transpose_op = rewriter.create( op.getLoc(), /*resultType0=*/new_input_tensor_type, /*operand=*/input, @@ -82,7 +82,7 @@ class RewriteNchwConvolutionToNhwc // Transpose the filter tensor: [o, i, 0, 1] => [0, 1, i, o] Value filter = op->getOperand(1); const TensorType new_filter_tensor_type = GetTransposedTensorType( - filter.getType().cast(), kOihwToHwioPermutation); + mlir::cast(filter.getType()), kOihwToHwioPermutation); auto filter_transpose_op = rewriter.create( op.getLoc(), /*resultType0=*/new_filter_tensor_type, /*operand=*/filter, @@ -98,7 +98,8 @@ class RewriteNchwConvolutionToNhwc /*outputSpatialDimensions=*/SmallVector{1, 2}); // Determine the shape of the output tensor: [b, f, 0, 1] => [b, 0, 1, f] - auto output_tensor_type = op->getResult(0).getType().cast(); + auto output_tensor_type = + mlir::cast(op->getResult(0).getType()); const TensorType new_conv_output_tensor_type = GetTransposedTensorType(output_tensor_type, kNchwToNhwcPermutation); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc index b4e12ddc9e0607..8d3ee0717e469e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -238,7 +238,7 @@ void CreateAndReturnQuantizedBiasPattern( if (succeeded(bcast_op)) { Value bcast_op_result = (*bcast_op)->getResult(0); auto bcast_op_result_type = - bcast_op_result.getType().cast(); + mlir::cast(bcast_op_result.getType()); const ArrayRef bcast_shape = bcast_op_result_type.getShape(); const TensorType new_bcast_op_result_type = bcast_op_result_type.cloneWith( bcast_shape, accumulation_quantized_element_type); @@ -246,7 +246,7 @@ void CreateAndReturnQuantizedBiasPattern( } const auto add_op_result_type = - add_op_result.getType().cast(); + mlir::cast(add_op_result.getType()); const ArrayRef add_op_shape = add_op_result_type.getShape(); // For quantized bias add case, lhs, rhs, and result have the same types. const TensorType new_add_op_result_type = add_op_result_type.cloneWith( @@ -320,7 +320,7 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter, Value gemm_style_op_result = gemm_style_op->getResult(0); const auto gemm_style_op_result_type = - gemm_style_op_result.getType().cast(); + mlir::cast(gemm_style_op_result.getType()); const ArrayRef gemm_style_shape = gemm_style_op_result_type.getShape(); @@ -328,11 +328,12 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter, TensorType new_gemm_style_op_result_type; const double input_scale = - getElementTypeOrSelf(input_type).cast().getScale(); + mlir::cast(getElementTypeOrSelf(input_type)) + .getScale(); if (enable_per_channel_quantized_weight) { - ArrayRef filter_scales = getElementTypeOrSelf(filter_type) - .cast() + ArrayRef filter_scales = mlir::cast( + getElementTypeOrSelf(filter_type)) .getScales(); std::vector result_scales; result_scales.reserve(filter_scales.size()); @@ -342,8 +343,8 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter, } const ArrayRef zero_points = - getElementTypeOrSelf(filter_type) - .cast() + mlir::cast( + getElementTypeOrSelf(filter_type)) .getZeroPoints(); // `stablehlo.convolution` assumes the following format: @@ -353,7 +354,7 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter, // `stablehlo.dot_general` legalizable to `tfl.fully_connected` has a // filter rank of 2 with the last dimension as the channel dimension. const int64_t quantization_dimension = - filter_type.cast().getShape().size() - 1; + mlir::cast(filter_type).getShape().size() - 1; accumulation_quantized_element_type = CreateI32F32UniformQuantizedPerAxisType( gemm_style_op->getLoc(), *rewriter.getContext(), result_scales, @@ -362,9 +363,9 @@ void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter, new_gemm_style_op_result_type = gemm_style_op_result_type.cloneWith( gemm_style_shape, accumulation_quantized_element_type); } else { - const double filter_scale = getElementTypeOrSelf(filter_type) - .cast() - .getScale(); + const double filter_scale = + mlir::cast(getElementTypeOrSelf(filter_type)) + .getScale(); const double result_scale = input_scale * filter_scale; accumulation_quantized_element_type = CreateI32F32UniformQuantizedType( @@ -557,13 +558,13 @@ class QuantizeSingularOpPattern : public EntryFuncBodyQuantizationPattern { // Get the quantized tensor manipulation op's output type and update. const auto singular_op_result_type = - singular_op_result.getType().cast(); + mlir::cast(singular_op_result.getType()); const ArrayRef singular_op_shape = singular_op_result_type.getShape(); const TensorType new_singular_op_result_type = singular_op_result_type.cloneWith( - singular_op_shape, - getElementTypeOrSelf(operand_type).cast()); + singular_op_shape, mlir::cast( + getElementTypeOrSelf(operand_type))); singular_op_result.setType(new_singular_op_result_type); // Create requantization op and return. @@ -757,13 +758,13 @@ class QuantizeOpWithRegionPattern inputs.reserve(op_with_region->getNumOperands()); for (Value operand : op_with_region->getOperands()) { const Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (mlir::isa(operand_type)) { inputs.push_back(operand); continue; } const Type element_type = - operand.getType().cast().getElementType(); + mlir::cast(operand.getType()).getElementType(); if (auto dq_op = dyn_cast_or_null( operand.getDefiningOp())) { inputs.push_back(dq_op.getOperand()); @@ -783,13 +784,13 @@ class QuantizeOpWithRegionPattern output_types.reserve(op_with_region->getNumResults()); for (const Value result : op_with_region->getResults()) { const Type result_type = result.getType(); - if (result_type.isa()) { + if (mlir::isa(result_type)) { outputs_replaced.push_back(result); output_types.push_back(result_type); continue; } const Type result_element_type = - result.getType().cast().getElementType(); + mlir::cast(result.getType()).getElementType(); // If the user is the QuantizeOp, it must be the only user. if (result.hasOneUse() && isa(*result.user_begin())) { @@ -823,7 +824,7 @@ class QuantizeOpWithRegionPattern const Type operand_type = quantized_op->getOperandTypes()[0]; const Type element_type = - operand_type.cast().getElementType(); + mlir::cast(operand_type).getElementType(); for (Region& region : quantized_op->getRegions()) { ReplaceTypesInNestedRegion(region, element_type); } @@ -880,7 +881,7 @@ class QuantizeOpWithRegionPattern // Replaces element type of the given tensor type while preserving shape of // the given type. If the given type is not tensor type, just return itself. Type ReplaceElementType(const Type type, const Type element_type) const { - if (TensorType tensor_type = type.dyn_cast()) { + if (TensorType tensor_type = mlir::dyn_cast(type)) { return tensor_type.clone(element_type); } return type; @@ -898,23 +899,23 @@ bool IsQuantizedCompositeFunction(func::CallOp call_op) { bool has_quantized_types = false; for (Value operand : call_op.getOperands()) { - if (const TensorType type = operand.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (const TensorType type = mlir::dyn_cast(operand.getType())) { + if (mlir::isa(type.getElementType())) { return false; } - if (type.getElementType() - .isa()) { + if (mlir::isa( + type.getElementType())) { has_quantized_types = true; } } } for (const Value result : call_op.getResults()) { - if (const auto type = result.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (const auto type = mlir::dyn_cast(result.getType())) { + if (mlir::isa(type.getElementType())) { return false; } - if (type.getElementType() - .isa()) { + if (mlir::isa( + type.getElementType())) { has_quantized_types = true; } } @@ -943,7 +944,7 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) { ->has_same_scale_requirement) { for (const OpResult result : preceding_op->getResults()) { const Type element_type = getElementTypeOrSelf(result.getType()); - if (element_type.isa()) { + if (mlir::isa(element_type)) { return true; } } @@ -971,7 +972,7 @@ bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) { ->has_same_scale_requirement) { for (Value operand : following_op->getOperands()) { const Type element_type = getElementTypeOrSelf(operand.getType()); - if (element_type.isa()) { + if (mlir::isa(element_type)) { return true; } } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h index 3a735f15da9723..c07314d6cff6cf 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.h @@ -159,12 +159,13 @@ class StableHloQuantizationPattern : public OpRewritePattern { inputs.reserve(candidate_op->getNumOperands()); for (auto operand : candidate_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (mlir::isa(operand_type)) { inputs.push_back(operand); continue; } - auto ele_type = operand.getType().cast().getElementType(); + auto ele_type = + mlir::cast(operand.getType()).getElementType(); if (auto dq_op = dyn_cast_or_null(operand.getDefiningOp())) { inputs.push_back(dq_op.getOperand()); @@ -190,13 +191,13 @@ class StableHloQuantizationPattern : public OpRewritePattern { Type result_type = result.getType(); // Add this to the test coverage once we create test ops with none type // results. - if (result_type.isa()) { + if (mlir::isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; } Type result_ele_type = - result.getType().cast().getElementType(); + mlir::cast(result.getType()).getElementType(); // If the user is the QuantizeOp, it must be the only user. if (result.hasOneUse() && isa(*result.user_begin())) { auto user = cast(*result.user_begin()); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc index 95f150d683c57b..e0469cc8d14032 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantize_weight.cc @@ -111,7 +111,7 @@ class QuantizeWeight : public OpRewritePattern { QuantizationUnits GetQuantizableOps(ConstantOp op) const { // Non-float tensors do not need quantization. QuantizationUnits quantizable_ops; - const ShapedType type = op.getType().dyn_cast(); + const ShapedType type = mlir::dyn_cast(op.getType()); if (!type || !type.getElementType().isF32()) return quantizable_ops; const Value value = op.getResult(); @@ -150,7 +150,7 @@ class QuantizeWeight : public OpRewritePattern { } TensorType old_result_type = - op.getResult().getType().dyn_cast(); + mlir::dyn_cast(op.getResult().getType()); const FloatType quantized_type = FloatType::getF16(op.getContext()); const ShapedType new_result_type = old_result_type.clone(quantized_type); @@ -184,7 +184,7 @@ class QuantizeWeight : public OpRewritePattern { // Get types. const Type old_result_type = op.getResult().getType(); const ShapedType new_result_type = - convert_op.getType().dyn_cast(); + mlir::dyn_cast(convert_op.getType()); // Proceeds only if the converting is to float16. if (!new_result_type.getElementType().isF16()) continue; @@ -192,7 +192,7 @@ class QuantizeWeight : public OpRewritePattern { // Convert values. std::vector new_values; const DenseFPElementsAttr value_attr = - op.getValue().cast(); + mlir::cast(op.getValue()); new_values.reserve(value_attr.getNumElements()); for (const float value : value_attr.getValues()) { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc index 6ed82c125b0be9..e1b4adb013684c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc @@ -163,7 +163,7 @@ void CreateXlaCallModuleOp(ValueRange inputs, ValueRange outputs, SmallVector shape_attrs; for (const Type result_type : result_types) { shape_attrs.push_back( - tf_type::ShapeAttr::get(ctx, result_type.cast())); + tf_type::ShapeAttr::get(ctx, mlir::cast(result_type))); } const auto empty_array_attr = ArrayAttr::get(ctx, {}); // TODO: b/310291615 - find a better way for platform support. @@ -502,7 +502,7 @@ void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass:: SymbolTable symbol_table(module_op); for (auto call_op : main_func.getOps()) { func_ops.push_back(dyn_cast_or_null(symbol_table.lookup( - call_op.getFAttr().cast().getValue()))); + mlir::cast(call_op.getFAttr()).getValue()))); } for (auto call_op : main_func.getOps()) { func_ops.push_back( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.cc index efda4282b2cbec..640f0ebc5c5061 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.cc @@ -29,7 +29,7 @@ bool IsLargeFloatType(Type type) { } Type ToBfloat16Type(Type type) { - if (auto shaped = type.dyn_cast()) { + if (auto shaped = mlir::dyn_cast(type)) { const Type elem = shaped.getElementType(); if (IsLargeFloatType(elem)) { return shaped.clone(BFloat16Type::get(type.getContext())); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.cc index 2f801565b93a1f..555e8af25b374f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils.cc @@ -37,8 +37,8 @@ limitations under the License. namespace mlir::quant::tensorflow { bool IsTFQintType(const Type type) { - return type.isa(); + return mlir::isa(type); } Type GetIntTypeFromTFQint(const Type type) { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc index 87d71438cf4e7c..e1fbe1917d9780 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/utils/tf_type_utils_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -188,31 +189,31 @@ TEST(GetIntTypeFromTFQintTest, ChecksIntTypesFromTFQint) { auto type = GetIntTypeFromTFQint(TF::Qint8Type::get(context.get())); EXPECT_TRUE(llvm::isa(type)); - EXPECT_EQ(type.dyn_cast().getWidth(), 8); - EXPECT_FALSE(type.dyn_cast().isSigned()); - EXPECT_FALSE(type.dyn_cast().isUnsigned()); + EXPECT_EQ(mlir::dyn_cast(type).getWidth(), 8); + EXPECT_FALSE(mlir::dyn_cast(type).isSigned()); + EXPECT_FALSE(mlir::dyn_cast(type).isUnsigned()); type = GetIntTypeFromTFQint(TF::Qint16Type::get(context.get())); EXPECT_TRUE(llvm::isa(type)); - EXPECT_EQ(type.dyn_cast().getWidth(), 16); - EXPECT_FALSE(type.dyn_cast().isSigned()); - EXPECT_FALSE(type.dyn_cast().isUnsigned()); + EXPECT_EQ(mlir::dyn_cast(type).getWidth(), 16); + EXPECT_FALSE(mlir::dyn_cast(type).isSigned()); + EXPECT_FALSE(mlir::dyn_cast(type).isUnsigned()); type = GetIntTypeFromTFQint(TF::Qint32Type::get(context.get())); EXPECT_TRUE(llvm::isa(type)); - EXPECT_EQ(type.dyn_cast().getWidth(), 32); - EXPECT_FALSE(type.dyn_cast().isSigned()); - EXPECT_FALSE(type.dyn_cast().isUnsigned()); + EXPECT_EQ(mlir::dyn_cast(type).getWidth(), 32); + EXPECT_FALSE(mlir::dyn_cast(type).isSigned()); + EXPECT_FALSE(mlir::dyn_cast(type).isUnsigned()); type = GetIntTypeFromTFQint(TF::Quint8Type::get(context.get())); EXPECT_TRUE(llvm::isa(type)); - EXPECT_EQ(type.dyn_cast().getWidth(), 8); - EXPECT_TRUE(type.dyn_cast().isUnsigned()); + EXPECT_EQ(mlir::dyn_cast(type).getWidth(), 8); + EXPECT_TRUE(mlir::dyn_cast(type).isUnsigned()); type = GetIntTypeFromTFQint(TF::Quint16Type::get(context.get())); EXPECT_TRUE(llvm::isa(type)); - EXPECT_EQ(type.dyn_cast().getWidth(), 16); - EXPECT_TRUE(type.dyn_cast().isUnsigned()); + EXPECT_EQ(mlir::dyn_cast(type).getWidth(), 16); + EXPECT_TRUE(mlir::dyn_cast(type).isUnsigned()); // Non qint types are returned as is. EXPECT_EQ(GetIntTypeFromTFQint(IntegerType::get(type.getContext(), 32)), diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc index c2445456339fb9..60d2c07bdab8ea 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc @@ -93,8 +93,7 @@ SmallVector GetEntryFunctionInputs(func::FuncOp func_op) { func_op->getAttrOfType("tf.entry_function"); SmallVector inputs; - entry_function_attr.get("inputs") - .dyn_cast_or_null() + mlir::dyn_cast_or_null(entry_function_attr.get("inputs")) .strref() .split(inputs, /*Separator=*/","); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args_test.cc index 238b1bb8ef8955..d77859a67c9dca 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args_test.cc @@ -94,9 +94,11 @@ TEST_F(ConvertAssetArgsTest, ConvertsSingleAssetArg) { EXPECT_THAT(arg_attrs.get("tf_saved_model.bound_input"), IsNull()); const ArrayRef index_path_attrs = - arg_attrs.get("tf_saved_model.index_path").cast().getValue(); + mlir::cast(arg_attrs.get("tf_saved_model.index_path")) + .getValue(); EXPECT_THAT(index_path_attrs, SizeIs(1)); - StringAttr index_path = index_path_attrs[0].dyn_cast_or_null(); + StringAttr index_path = + mlir::dyn_cast_or_null(index_path_attrs[0]); EXPECT_THAT(index_path, NotNull()); EXPECT_THAT(index_path, Eq("arg_0:0")); } @@ -122,9 +124,11 @@ TEST_F(ConvertAssetArgsTest, NonBoundedArgsNotModified) { EXPECT_THAT(arg_attrs.get("tf_saved_model.bound_input"), IsNull()); const ArrayRef index_path_attrs = - arg_attrs.get("tf_saved_model.index_path").cast().getValue(); + mlir::cast(arg_attrs.get("tf_saved_model.index_path")) + .getValue(); EXPECT_THAT(index_path_attrs, SizeIs(1)); - StringAttr index_path = index_path_attrs[0].dyn_cast_or_null(); + StringAttr index_path = + mlir::dyn_cast_or_null(index_path_attrs[0]); EXPECT_THAT(index_path, NotNull()); EXPECT_THAT(index_path, Eq("arg_0:0")); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc index 7be369e7947ced..8ba632b66ae0f3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc @@ -65,9 +65,9 @@ bool QuantizationUnitLoc::classof(Attribute attr) { if (!llvm::isa(attr)) return false; auto callsite_loc = llvm::dyn_cast(attr); - if (!callsite_loc.getCaller().isa()) return false; + if (!mlir::isa(callsite_loc.getCaller())) return false; StringRef caller_name = - callsite_loc.getCaller().cast().getName().strref(); + mlir::cast(callsite_loc.getCaller()).getName().strref(); return caller_name.starts_with(kQuantizationUnitPrefix) && caller_name.ends_with(kQuantizationUnitSuffix); } @@ -75,8 +75,8 @@ bool QuantizationUnitLoc::classof(Attribute attr) { std::optional FindQuantizationUnitFromLoc(Location loc) { if (isa(loc)) { - Location caller = loc.cast().getCaller(); - StringRef caller_name = caller.cast().getName().strref(); + Location caller = mlir::cast(loc).getCaller(); + StringRef caller_name = mlir::cast(caller).getName().strref(); const size_t start_index = kQuantizationUnitPrefix.size(); const size_t end_index = caller_name.rfind(kQuantizationUnitSuffix); std::string serialized_proto = @@ -87,14 +87,15 @@ FindQuantizationUnitFromLoc(Location loc) { } } else if (isa(loc)) { // If the op is rewritten, FusedLoc can be created. - for (Location child_loc : loc.cast().getLocations()) { + for (Location child_loc : mlir::cast(loc).getLocations()) { std::optional found_unit = FindQuantizationUnitFromLoc(child_loc); if (found_unit.has_value()) return found_unit; } } else if (isa(loc)) { // If the graph is inlined, CallSiteLoc can be created. - return FindQuantizationUnitFromLoc(loc.cast().getCallee()); + return FindQuantizationUnitFromLoc( + mlir::cast(loc).getCallee()); } return std::nullopt; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc index 52ca3722a12bd5..9630b20b32d571 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc @@ -56,7 +56,7 @@ bool IsOpWithInt8TypeOperand(Operation* op) { } bool IsValueWithQuantizablePrecision(Value val) { - auto type = val.getType().dyn_cast(); + auto type = mlir::dyn_cast(val.getType()); if (!type) return false; // Supported original tensor data types. if (type.getElementType().isF32() || type.getElementType().isBF16()) @@ -82,7 +82,7 @@ std::unique_ptr GetTFOpQuantSpec(Operation* op) { auto spec = std::make_unique(); if (auto call_op = dyn_cast(op)) { StringRef function_name = - call_op.getFAttr().cast().getValue(); + mlir::cast(call_op.getFAttr()).getValue(); if (!function_name.starts_with("composite_")) { return spec; } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc index 723adde447e546..47beb9e0c2636f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc @@ -153,11 +153,10 @@ QuantizedType CalculateUniformQuantParams( DenseFPElementsAttr attr; if (!matchPattern(op->getResult(0), m_Constant(&attr))) return nullptr; - QuantizedType quant_type = + QuantizedType quant_type = mlir::dyn_cast( quant::GetUniformQuantizedTypeForWeight( attr, /*symmetric=*/kIsNarrowRange && kIsSigned, kBitWidth, kIsSigned, - kIsNarrowRange, /*is_legacy_float*/ false) - .template dyn_cast(); + kIsNarrowRange, /*is_legacy_float*/ false)); return quant_type; } @@ -172,16 +171,16 @@ std::optional AddUniformQuantizeOps(PatternRewriter& rewriter, } Type expressed_type = op.getResult().getType(); Type quantized_type = quant_type.castFromExpressedType(expressed_type); - ShapedType shaped_quantized_type = quantized_type.cast(); + ShapedType shaped_quantized_type = mlir::cast(quantized_type); DenseElementsAttr tensor_proto_attr = - Quantize(attr, shaped_quantized_type).dyn_cast(); + mlir::dyn_cast(Quantize(attr, shaped_quantized_type)); if (!tensor_proto_attr) { return nullptr; } - Type storage_type = shaped_quantized_type.getElementType() - .cast() - .getStorageType(); + Type storage_type = + mlir::cast(shaped_quantized_type.getElementType()) + .getStorageType(); ShapedType new_type = shaped_quantized_type.clone(storage_type); rewriter.setInsertionPointAfter(op); @@ -205,7 +204,7 @@ Operation* LogicsForUniformDequanization(PatternRewriter& rewriter, auto new_cast_op = rewriter.create(loc, create_unknown_input_shape, input_val); // TODO - b/278949920: Enable Per-Channel Quantization for XLA Opset - auto qtype = quant_type.dyn_cast(); + auto qtype = mlir::dyn_cast(quant_type); TensorType scale_type = RankedTensorType::get({}, rewriter.getF32Type()); Value scale_op = rewriter.create( loc, scale_type, @@ -253,7 +252,7 @@ std::optional ApplyUniformQuantization( std::optional dequantized_val = AddUniformDequantizeOps(rewriter, quant_type, quantized_val.value(), - op.getType().cast()); + mlir::cast(op.getType())); return dequantized_val; } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_quantization_unit_loc.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_quantization_unit_loc.cc index d390ac6d548e78..109fa943f9334b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_quantization_unit_loc.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_quantization_unit_loc.cc @@ -69,17 +69,17 @@ class AddQuantizationUnitLocPass // tensorflow/compiler/mlir/tensorflow/translate/import_model.cc for more // details. bool IsImportLocPattern(FusedLoc loc) { - ArrayRef locations = loc.cast().getLocations(); + ArrayRef locations = mlir::cast(loc).getLocations(); if (locations.size() < 2 || !isa(locations.front())) return false; StringRef op_type_with_suffix = - locations.front().cast().getName().strref(); + mlir::cast(locations.front()).getName().strref(); if (!op_type_with_suffix.ends_with(":")) return false; return absl::c_all_of(locations, [](Location loc) { return isa(loc) || (isa(loc) && - isa(loc.cast().getCallee())); + isa(mlir::cast(loc).getCallee())); }); } @@ -99,23 +99,23 @@ void FindQuantizationUnitsRecursively(Location loc, } }; - ArrayRef locations = loc.cast().getLocations(); - if (IsImportLocPattern(loc.cast())) { + ArrayRef locations = mlir::cast(loc).getLocations(); + if (IsImportLocPattern(mlir::cast(loc))) { QuantizationUnit new_unit; // Op type is a NameLoc with the ":" suffix. StringRef op_type_with_suffix = - locations.front().cast().getName().strref(); + mlir::cast(locations.front()).getName().strref(); StringRef op_type = op_type_with_suffix.substr(0, op_type_with_suffix.size() - 1); new_unit.set_op_type(op_type.str()); if (isa(locations.back())) { StringRef name_loc_id = - locations.back().cast().getName().strref(); + mlir::cast(locations.back()).getName().strref(); set_node_and_func_name(new_unit, name_loc_id); } else { - Location callee = locations.back().cast().getCallee(); - StringRef name_loc_id = callee.cast().getName().strref(); + Location callee = mlir::cast(locations.back()).getCallee(); + StringRef name_loc_id = mlir::cast(callee).getName().strref(); set_node_and_func_name(new_unit, name_loc_id); } units.push_back(new_unit); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc index e4229cb97bf45a..8c02ace87d8001 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_custom_aggregation_op_to_quant_stats.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" @@ -75,8 +76,8 @@ class ConvertCustomAggregationOpToQuantStats LogicalResult matchAndRewrite(TF::CustomAggregatorOp op, PatternRewriter &rewriter) const override { - FloatAttr min = op->getAttr("min").dyn_cast_or_null(); - FloatAttr max = op->getAttr("max").dyn_cast_or_null(); + FloatAttr min = mlir::dyn_cast_or_null(op->getAttr("min")); + FloatAttr max = mlir::dyn_cast_or_null(op->getAttr("max")); // When there are no min and max attributes, remove op. if (min == nullptr || max == nullptr) { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.cc index d23a0f8d3a7af2..c39492f0efe709 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tf_xla_op_to_tf_op.cc @@ -158,10 +158,8 @@ Value CreateEinsumOpFromXlaDotV2Op(OpBuilder& builder, const Location loc, xla::DotDimensionNumbers dot_dimension_numbers; dot_dimension_numbers.ParseFromString(dot_dimension_numbers_str.str()); SmallVector input_arguments = {lhs, rhs}; - const int lhs_rank = - lhs.getType().template cast().getShape().size(); - const int rhs_rank = - rhs.getType().template cast().getShape().size(); + const int lhs_rank = mlir::cast(lhs.getType()).getShape().size(); + const int rhs_rank = mlir::cast(rhs.getType()).getShape().size(); const std::string einsum_equation = CreateEinsumEquation(dot_dimension_numbers, lhs_rank, rhs_rank); @@ -218,7 +216,7 @@ RankedTensorType RestoreCollapsedDimensions( Type GetSliceOpOutputType(Type xla_gather_op_output_type, const absl::flat_hash_set& collapsed_dims) { if (auto ranked_output_type = - xla_gather_op_output_type.dyn_cast(); + mlir::dyn_cast(xla_gather_op_output_type); ranked_output_type) { return RestoreCollapsedDimensions(ranked_output_type, collapsed_dims); } @@ -228,9 +226,9 @@ Type GetSliceOpOutputType(Type xla_gather_op_output_type, // TODO (b/275225582): Supports Xla Gather op in general case. bool IsXlaGatherWithoutBatch(Value operand, Value start_indices) { - auto operand_type = operand.getType().dyn_cast_or_null(); + auto operand_type = mlir::dyn_cast_or_null(operand.getType()); auto start_indices_type = - start_indices.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(start_indices.getType()); if (start_indices_type == nullptr || operand_type == nullptr) return false; return start_indices_type.getShape().size() == 1; } @@ -245,7 +243,7 @@ Value CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch( // Construct full start_indices with given start_indices and // start_index_map. const ArrayRef operand_shape = - operand.getType().cast().getShape(); + mlir::cast(operand.getType()).getShape(); const int64_t operand_rank = operand_shape.size(); // Fills zeros if start_index is not given in start_indices. @@ -273,7 +271,7 @@ Value CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch( builder.create( loc, RankedTensorType::get( - start_indices.getType().template cast().getShape(), + mlir::cast(start_indices.getType()).getShape(), builder.getI64Type()), start_indices)); @@ -289,7 +287,7 @@ Value CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch( builder.create( loc, RankedTensorType::get( - slice_sizes.getType().template cast().getShape(), + mlir::cast(slice_sizes.getType()).getShape(), builder.getI64Type()), slice_sizes)); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.cc index b4cdcd8f771a21..b3fc6207842469 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.cc @@ -17,6 +17,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project @@ -92,7 +93,7 @@ class ReplaceTpuPartitionedCallOpWithPartitionedCallOp private: LogicalResult matchAndRewrite(TF::TPUPartitionedCallOp call_op, PatternRewriter& rewriter) const override { - auto f_attr = call_op.getFAttr().dyn_cast(); + auto f_attr = mlir::dyn_cast(call_op.getFAttr()); auto module_op = call_op->getParentOfType(); SymbolTable symbol_table(module_op); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc index 8790c908ba88a6..5450ae8442ca43 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc @@ -240,7 +240,7 @@ class AddCustomAggregationOp : public RewritePattern { if (auto call_op = dyn_cast_or_null(defining_op)) { StringRef function_name = - call_op.getFAttr().cast().getValue(); + mlir::cast(call_op.getFAttr()).getValue(); if (function_name.contains("gather")) continue; } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc index 682889917c112e..0f855088d17943 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_main_function.cc @@ -154,7 +154,7 @@ void GetUniqueInputOutputNodeNames(ModuleOp module_op, if (auto inputs_attr = tf_attrs.get("inputs")) { const std::string inputs_attr_str = - inputs_attr.cast().getValue().str(); + mlir::cast(inputs_attr).getValue().str(); std::vector fn_input_names = absl::StrSplit(inputs_attr_str, ',', absl::SkipEmpty()); @@ -174,7 +174,7 @@ void GetUniqueInputOutputNodeNames(ModuleOp module_op, if (auto outputs_attr = tf_attrs.get("outputs")) { const std::string outputs_attr_str = - outputs_attr.cast().getValue().str(); + mlir::cast(outputs_attr).getValue().str(); std::vector fn_output_names = absl::StrSplit(outputs_attr_str, ',', absl::SkipEmpty()); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc index 63fb3bd94005ee..80691bb57e789c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.cc @@ -174,7 +174,7 @@ class CheckQuantizableOps LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, PatternRewriter& rewriter) const override { StringRef function_name = - call_op.getFAttr().cast().getValue(); + mlir::cast(call_op.getFAttr()).getValue(); if (!function_name.starts_with("composite_") || !call_op->hasAttr(kQuantTraitAttrName)) { return failure(); @@ -193,11 +193,10 @@ class CheckQuantizableOps } // Only the composite functions with f32 inputs are quantizable. - if (call_op.getResults().size() == 1 && !call_op->getResult(0) - .getType() - .cast() - .getElementType() - .isF32()) { + if (call_op.getResults().size() == 1 && + !mlir::cast(call_op->getResult(0).getType()) + .getElementType() + .isF32()) { check_status.Update(absl::InternalError( "Composite functions for quantization should be f32 type.")); } @@ -274,7 +273,7 @@ class CheckQuantizableOps // For BatchMatMul, the input must be ranked to determine the batch // dimensions. ShapedType shaped_type = - call_op->getOperand(0).getType().dyn_cast(); + mlir::dyn_cast(call_op->getOperand(0).getType()); if (!shaped_type || !shaped_type.hasRank()) { return absl::InternalError("The input of BatchMatMul must have rank."); } @@ -282,7 +281,8 @@ class CheckQuantizableOps // This op is guaranteed to be a constant as ODS checks IsConstTensor. // Check if the number of elements meets the requirement. int64_t num_elements = - call_op.getOperand(0).getType().cast().getNumElements(); + mlir::cast(call_op.getOperand(0).getType()) + .getNumElements(); if (num_elements < quant_options_.min_num_elements_for_weights()) { return absl::InternalError( "The params of Gather have fewer number of elements than " diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc index 0acb2e56ea617e..a75bef5f842746 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.cc @@ -137,7 +137,8 @@ class CheckQuantizableOps // This op is guaranteed to be a constant as ODS checks IsConstTensor. // Check if the number of elements meets the requirement. int current_num_elements = - call_op.getOperand(idx).getType().cast().getNumElements(); + mlir::cast(call_op.getOperand(idx).getType()) + .getNumElements(); if (current_num_elements < min_num_elements_for_weights_) { call_op.emitRemark("Quantization is skipped for ") << call_op->getName().getStringRef().str() << " because it has " @@ -149,7 +150,7 @@ class CheckQuantizableOps } StringRef function_name = - call_op.getFAttr().cast().getValue(); + mlir::cast(call_op.getFAttr()).getValue(); if ((quantization_method_ == tensorflow::quantization::QuantizationMethod:: METHOD_DYNAMIC_RANGE_INT8) && (function_name.contains("batch_matmul") || diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc index f1f65a1a183371..85acaeb9603f2e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/TypeRange.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/func.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h" @@ -153,7 +154,7 @@ LogicalResult ValidateInitFunc(func::FuncOp init_func_op) { FetchOp fetch_op = graph_op.GetFetch(); for (const Value fetch : fetch_op.getFetches()) { - if (!fetch.getType().isa()) { + if (!mlir::isa(fetch.getType())) { fetch_op.emitError(absl::StrFormat( "Validation failed for the initializer function: %s. " "All initializer function's fetches should be " diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_save_function_ops_to_main.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_save_function_ops_to_main.cc index e092352dc52c29..6f42c9fcaba7c5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_save_function_ops_to_main.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_save_function_ops_to_main.cc @@ -143,7 +143,7 @@ BlockArgument GetFilePrefixArg(func::FuncOp main_func_op) { auto index_path_attr = main_func_op.getArgAttrOfType(i, kTfSavedModelIndexPathAttr); if (index_path_attr && !index_path_attr.empty() && - index_path_attr[0].cast() == kTfFilePrefix) { + mlir::cast(index_path_attr[0]) == kTfFilePrefix) { return main_func_op.getArgument(i); } } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc index 38075bb67b7010..b0a84d71c84182 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.cc @@ -98,8 +98,8 @@ class PrepareLiftingPass // indices in `val2`. bool HasEqualElementSize(Value val1, Value val2, ArrayRef val1_indices, ArrayRef val2_indices) { - ShapedType val1_shape = val1.getType().cast(); - ShapedType val2_shape = val2.getType().cast(); + ShapedType val1_shape = mlir::cast(val1.getType()); + ShapedType val2_shape = mlir::cast(val2.getType()); if (!val1_shape.hasRank() || !val2_shape.hasRank()) return false; int val1_result = 1; @@ -134,7 +134,7 @@ bool ReshapableTo1DTensor(ShapedType rhs_shape) { } Value ReshapeTo1DTensor(OpBuilder& builder, Location loc, Value value) { - auto shape = value.getType().cast(); + auto shape = mlir::cast(value.getType()); if (shape.getRank() != 1) { SmallVector new_shape; new_shape.push_back(shape.getNumElements()); @@ -182,7 +182,7 @@ LogicalResult MatchSupportedAffineOp(Operation* op, Value& binding_output, // Makes the 1D value broadcastable with the `rhs_shape`. Value MakeOneDimValueBroadcastable(OpBuilder& builder, Location loc, Value value, ShapedType rhs_shape) { - ShapedType value_shape = value.getType().dyn_cast_or_null(); + ShapedType value_shape = mlir::dyn_cast_or_null(value.getType()); if (!value_shape || value_shape.getRank() != 1 || !value_shape.hasStaticShape() || !rhs_shape.hasStaticShape()) { return {}; @@ -211,7 +211,8 @@ bool CanBeSymmetricallyQuantized(Value weight) { auto dq_op = weight.getDefiningOp(); if (!dq_op) return true; - auto qtype = dq_op.getArg().getType().cast().getElementType(); + auto qtype = + mlir::cast(dq_op.getArg().getType()).getElementType(); if (auto uniform_type = llvm::dyn_cast_or_null(qtype)) { return uniform_type.getZeroPoint() == 0; } else if (auto per_axis_type = @@ -252,12 +253,12 @@ Value MultiplyFakeQuantValue(OpBuilder& builder, Location loc, Value value, Value float_value = q_op.getArg(); Value new_value = builder.create(loc, float_value, multiplier); - auto new_value_type = new_value.getType().cast(); + auto new_value_type = mlir::cast(new_value.getType()); // Get multiplier value in double. DenseFPElementsAttr multiplier_attr; if (!matchPattern(multiplier, m_Constant(&multiplier_attr)) || - multiplier_attr.getType().cast().getRank() > 1) { + mlir::cast(multiplier_attr.getType()).getRank() > 1) { return {}; } std::vector multiplier_values; @@ -268,7 +269,7 @@ Value MultiplyFakeQuantValue(OpBuilder& builder, Location loc, Value value, // Multiply the quantization parameters by the multiplier. QuantizedType new_qtype; - auto element_type = q_op.getType().cast().getElementType(); + auto element_type = mlir::cast(q_op.getType()).getElementType(); if (auto uniform_type = llvm::dyn_cast(element_type)) { if (multiplier_attr.isSplat()) { double new_scale = multiplier_array.front() * uniform_type.getScale(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc index fe38ed8dc0f634..cad8c1686eb67b 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -171,8 +172,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { bool need_to_set_input_nodes_quantization_params = false; for (const BlockArgument arg : func.getArguments()) { - auto shaped = arg.getType().dyn_cast(); - if (shaped && shaped.getElementType().isa() && + auto shaped = mlir::dyn_cast(arg.getType()); + if (shaped && mlir::isa(shaped.getElementType()) && !has_quantize_op(arg)) { need_to_set_input_nodes_quantization_params = true; break; @@ -197,8 +198,8 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { auto add_quantize_op = [&](Location loc, Type input_type, Block* block, Block::iterator insertion_point, Value arg, int i) { - if (auto shaped = input_type.dyn_cast()) { - if (shaped.getElementType().isa()) { + if (auto shaped = mlir::dyn_cast(input_type)) { + if (mlir::isa(shaped.getElementType())) { // If there are existing quantize ops, they are from training and we // should respect them. if (has_quantize_op(arg)) { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc index 71587390580406..b2c0ceb205ca99 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize_drq.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" @@ -142,7 +143,7 @@ class PrepareDRQQuantizableOp : public OpRewritePattern { bool getQuantizableOps(arith::ConstantOp op, QuantizationUnits& quantizable_ops) const { // Non-float tensors do not need quantization. - auto type = op.getType().dyn_cast(); + auto type = mlir::dyn_cast(op.getType()); if (!type || !type.getElementType().isF32()) return false; Value value = op.getResult(); @@ -183,23 +184,23 @@ class PrepareDRQQuantizableOp : public OpRewritePattern { if (attr.size() < quant_specs_.minimum_elements_for_weights) { op->emitRemark("Quantization is skipped for ") << quantized_op->getName().getStringRef().str() << " because it has " - << attr.dyn_cast().size() + << mlir::dyn_cast(attr).size() << " elements which is fewer than the threshold(" << quant_specs_.minimum_elements_for_weights << " elements)."; return false; } if (is_per_channel_quantization) { - quant_type = quant::GetUniformQuantizedPerAxisTypeForWeight( - attr, quant_dim, - /*symmetric=*/true, bit_width, is_signed, - is_narrow_range, is_legacy_float) - .template dyn_cast(); + quant_type = mlir::dyn_cast( + quant::GetUniformQuantizedPerAxisTypeForWeight( + attr, quant_dim, + /*symmetric=*/true, bit_width, is_signed, is_narrow_range, + is_legacy_float)); } else { - quant_type = quant::GetUniformQuantizedTypeForWeight( - attr, is_narrow_range && is_signed, bit_width, is_signed, - is_narrow_range, is_legacy_float) - .template dyn_cast(); + quant_type = mlir::dyn_cast( + quant::GetUniformQuantizedTypeForWeight( + attr, is_narrow_range && is_signed, bit_width, is_signed, + is_narrow_range, is_legacy_float)); } return insertQDQ(rewriter, op, quant_type, quant_op); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc index 3f54fe580fe1c4..08b2faadacd3d5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.cc @@ -202,7 +202,7 @@ class PreprocessConstantOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::PartitionedCallOp op, PatternRewriter& rewriter) const override { - const auto f_attr = op.getFAttr().dyn_cast(); + const auto f_attr = mlir::dyn_cast(op.getFAttr()); // Non-quantizable op if (!op->hasAttr(kQuantTraitAttrName)) return failure(); StringRef function_name = f_attr.getValue(); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc index 8570652b4019e7..0d2edd5bacd6c1 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc @@ -100,7 +100,7 @@ class PropagateDequantizeOpIfAllowed LogicalResult matchAndRewrite(TF::PartitionedCallOp op, PatternRewriter& rewriter) const override { - const auto f_attr = op.getFAttr().dyn_cast(); + const auto f_attr = mlir::dyn_cast(op.getFAttr()); StringRef function_name = f_attr.getValue(); if (!function_name.starts_with(kDequantizeFunctionName)) return failure(); @@ -127,7 +127,8 @@ class PropagateDequantizeOpIfAllowed auto original_result_type = user_op->getResult(0).getType(); auto new_user_op_type = CloneTypeWithNewElementType( original_result_type, - op_before_dequantize.getType().cast().getElementType()); + mlir::cast(op_before_dequantize.getType()) + .getElementType()); createNewDequantizeOp(rewriter, op, user_op, user_idx, new_user_op_type); } else { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc index 0b3c89c56f60bb..50409709d44854 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc @@ -213,11 +213,11 @@ LogicalResult CreateQuantizationParams(QuantizedType elem_type, Location loc, if (!elem_type) { return failure(); } - if (auto qtype = elem_type.dyn_cast()) { + if (auto qtype = mlir::dyn_cast(elem_type)) { return CreateUniformQuantizedTypeParams(qtype, loc, rewriter, scale, zero_point); - } else if (auto qtype = - elem_type.dyn_cast()) { + } else if (auto qtype = mlir::dyn_cast( + elem_type)) { return CreateUniformQuantizedPerAxisTypeParams(qtype, loc, rewriter, scale, zero_point); } @@ -235,7 +235,7 @@ ShapedType ConvertIntToQint(ShapedType input_type, MLIRContext* ctx) { if (ele_type.isIntOrFloat()) { bit_width = ele_type.getIntOrFloatBitWidth(); is_signed = ele_type.isSignlessIntOrFloat() || ele_type.isSignedInteger(); - } else if (QuantizedType qtype = ele_type.dyn_cast()) { + } else if (QuantizedType qtype = mlir::dyn_cast(ele_type)) { bit_width = qtype.getStorageTypeIntegralWidth(); is_signed = qtype.isSigned(); } else { @@ -275,8 +275,9 @@ class ReplaceQuantizePattern LogicalResult matchAndRewrite(quantfork::QuantizeCastOp q_op, PatternRewriter& rewriter) const override { - auto output_type = q_op.getType().cast(); - auto elem_type = output_type.getElementType().dyn_cast(); + auto output_type = mlir::cast(q_op.getType()); + auto elem_type = + mlir::dyn_cast(output_type.getElementType()); const Location loc = q_op->getLoc(); Value scale, zero_point; @@ -289,7 +290,7 @@ class ReplaceQuantizePattern if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { ShapedType new_output_type = ConvertIntToQint( - output_type.cast(), rewriter.getContext()); + mlir::cast(output_type), rewriter.getContext()); if (!new_output_type) { q_op->emitError( "Failed to convert the type to the corresponding qtype."); @@ -327,8 +328,8 @@ class ReplaceDequantizePattern LogicalResult matchAndRewrite(quantfork::DequantizeCastOp dq_op, PatternRewriter& rewriter) const override { - auto input_type = dq_op.getArg().getType().cast(); - auto elem_type = input_type.getElementType().dyn_cast(); + auto input_type = mlir::cast(dq_op.getArg().getType()); + auto elem_type = mlir::dyn_cast(input_type.getElementType()); const Location loc = dq_op->getLoc(); Value scale, zero_point; @@ -340,13 +341,13 @@ class ReplaceDequantizePattern TensorType output_type = input_type.clone(elem_type.getStorageType()); if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { ShapedType new_output_type = ConvertIntToQint( - output_type.cast(), rewriter.getContext()); + mlir::cast(output_type), rewriter.getContext()); if (!new_output_type) { dq_op->emitError( "Failed to convert the type to the corresponding qtype."); return failure(); } - output_type = new_output_type.cast(); + output_type = mlir::cast(new_output_type); } auto scast_op = rewriter.create(loc, output_type, @@ -376,8 +377,8 @@ bool IsQuantizedCallforDynamicRange(TF::PartitionedCallOp call_op) { return false; } else if (cur_op) { // Check if the QuantizeCastOp has element type of quantized type. - if (!getElementTypeOrSelf(cur_op.getResult().getType()) - .isa()) { + if (!mlir::isa( + getElementTypeOrSelf(cur_op.getResult().getType()))) { return false; } // Satisfies the input condition. @@ -385,8 +386,8 @@ bool IsQuantizedCallforDynamicRange(TF::PartitionedCallOp call_op) { } } for (Value output : call_op.getOutput()) { - if (auto type = output.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (auto type = mlir::dyn_cast(output.getType())) { + if (mlir::isa(type.getElementType())) { return false; } } @@ -398,15 +399,15 @@ bool IsQuantizedCallforDynamicRange(TF::PartitionedCallOp call_op) { bool IsQuantizedCallforStaticRange(TF::PartitionedCallOp call_op) { bool has_quantized_types = false; for (Value input : call_op.getArgs()) { - if (auto type = input.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (auto type = mlir::dyn_cast(input.getType())) { + if (mlir::isa(type.getElementType())) { has_quantized_types = true; } } } for (Value output : call_op.getOutput()) { - if (auto type = output.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (auto type = mlir::dyn_cast(output.getType())) { + if (mlir::isa(type.getElementType())) { has_quantized_types = true; } } @@ -616,7 +617,7 @@ std::string GetQuantizedFunctionName(StringRef func_name, bool ContainsFloatResultType(ArrayRef result_types) { for (auto current_type : result_types) { - if (current_type.dyn_cast().getElementType().isF32()) + if (mlir::dyn_cast(current_type).getElementType().isF32()) return true; } return false; @@ -644,7 +645,7 @@ class QuantizeFunctionPattern LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, PatternRewriter& rewriter) const override { - const auto f_attr = call_op.getFAttr().dyn_cast(); + const auto f_attr = mlir::dyn_cast(call_op.getFAttr()); // removeAttr will return nullptr if no attribute was removed. if (!call_op->removeAttr(kQuantTraitAttrName) || !f_attr) { return failure(); @@ -671,12 +672,12 @@ class QuantizeFunctionPattern SmallVector args; SmallVector qparam_args; for (Value arg : call_op.getArgs()) { - if (const auto arg_type = arg.getType().dyn_cast()) { + if (const auto arg_type = mlir::dyn_cast(arg.getType())) { QuantizedType qtype = - arg_type.getElementType().dyn_cast(); + mlir::dyn_cast(arg_type.getElementType()); if (!qtype) continue; - if (!qtype.isa()) { + if (!mlir::isa(qtype)) { return failure(); } Value scale, zero_point; @@ -693,12 +694,12 @@ class QuantizeFunctionPattern } for (Value result : call_op->getResults()) { - if (auto result_type = result.getType().dyn_cast()) { + if (auto result_type = mlir::dyn_cast(result.getType())) { QuantizedType qtype = - result_type.getElementType().dyn_cast(); + mlir::dyn_cast(result_type.getElementType()); if (!qtype) continue; - if (!qtype.isa()) { + if (!mlir::isa(qtype)) { return failure(); } Value scale, zero_point; @@ -717,12 +718,13 @@ class QuantizeFunctionPattern rewriter.setInsertionPoint(call_op); for (Value arg : call_op.getArgs()) { - TensorType arg_type = arg.getType().dyn_cast(); + TensorType arg_type = mlir::dyn_cast(arg.getType()); if (!arg_type) { args.push_back(arg); continue; } - QuantizedType qtype = arg_type.getElementType().dyn_cast(); + QuantizedType qtype = + mlir::dyn_cast(arg_type.getElementType()); if (!qtype) { args.push_back(arg); continue; @@ -730,15 +732,15 @@ class QuantizeFunctionPattern quantfork::StorageCastOp scast_op; if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { - ShapedType new_arg_type = ConvertIntToQint(arg_type.cast(), - rewriter.getContext()); + ShapedType new_arg_type = ConvertIntToQint( + mlir::cast(arg_type), rewriter.getContext()); if (!new_arg_type) { call_op->emitError( "Failed to convert the type to the corresponding qtype."); return failure(); } scast_op = rewriter.create( - arg.getLoc(), new_arg_type.cast(), arg); + arg.getLoc(), mlir::cast(new_arg_type), arg); } else { scast_op = rewriter.create( arg.getLoc(), arg_type.clone(qtype.getStorageType()), arg); @@ -761,20 +763,20 @@ class QuantizeFunctionPattern SmallVector result_types; for (Value result : call_op->getResults()) { - TensorType result_type = result.getType().dyn_cast(); + TensorType result_type = mlir::dyn_cast(result.getType()); if (!result_type) { result_types.push_back(result.getType()); continue; } QuantizedType qtype = - result_type.getElementType().dyn_cast(); + mlir::dyn_cast(result_type.getElementType()); if (!qtype) { result_types.push_back(result_type); continue; } if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { ShapedType new_result_type = ConvertIntToQint( - result_type.cast(), rewriter.getContext()); + mlir::cast(result_type), rewriter.getContext()); result_types.push_back(new_result_type); } else { result_types.push_back(result_type.clone(qtype.getStorageType())); @@ -871,13 +873,13 @@ class QuantizeFunctionPattern rewriter.setInsertionPointAfter(call_op); SmallVector result_types; for (Value result : call_op->getResults()) { - TensorType result_type = result.getType().dyn_cast(); + TensorType result_type = mlir::dyn_cast(result.getType()); if (!result_type) { result_types.push_back(result.getType()); continue; } QuantizedType qtype = - result_type.getElementType().dyn_cast(); + mlir::dyn_cast(result_type.getElementType()); if (!qtype) { result_types.push_back(result_type); continue; @@ -890,7 +892,7 @@ class QuantizeFunctionPattern auto module = call_op->getParentOfType(); SymbolTable symbol_table(module); - const auto f_attr = call_op.getFAttr().dyn_cast(); + const auto f_attr = mlir::dyn_cast(call_op.getFAttr()); const auto float_func = dyn_cast(symbol_table.lookup(f_attr.getValue())); rewriter.setInsertionPointAfter(float_func); @@ -973,14 +975,15 @@ class QuantizeConstPattern return failure(); } - ShapedType tensor_qtype = q_op.getResult().getType().cast(); + ShapedType tensor_qtype = + mlir::cast(q_op.getResult().getType()); Attribute tensor_proto_attr = Quantize(attr, tensor_qtype); if (!tensor_proto_attr) { return failure(); } - Type storage_type = - tensor_qtype.getElementType().cast().getStorageType(); + Type storage_type = mlir::cast(tensor_qtype.getElementType()) + .getStorageType(); ShapedType new_type = tensor_qtype.clone(storage_type); Location loc = q_op.getArg().getLoc(); @@ -991,14 +994,14 @@ class QuantizeConstPattern // workaround. tensorflow::TensorProto tensor_proto; if (!mlir::tfg::ConvertToTensorProto( - tensor_proto_attr.cast(), &tensor_proto) + mlir::cast(tensor_proto_attr), &tensor_proto) .ok()) { return failure(); } - const int bit_width = tensor_qtype.getElementType() - .dyn_cast() - .getStorageTypeIntegralWidth(); + const int bit_width = + mlir::dyn_cast(tensor_qtype.getElementType()) + .getStorageTypeIntegralWidth(); tensor_proto.set_dtype((bit_width == 8) ? tensorflow::DT_QINT8 : tensorflow::DT_QINT32); @@ -1033,8 +1036,9 @@ class RestoreWeightShapePattern int weight_operand_idx = 1; Operation* weight_op = op.getOperand(weight_operand_idx).getDefiningOp(); - auto weight_type = weight_op->getResult(0).getType().dyn_cast(); - auto input_type = op.getOperand(0).getType().dyn_cast(); + auto weight_type = + mlir::dyn_cast(weight_op->getResult(0).getType()); + auto input_type = mlir::dyn_cast(op.getOperand(0).getType()); llvm::ArrayRef weight_shape = weight_type.getShape(); llvm::ArrayRef input_shape = input_type.getShape(); @@ -1073,7 +1077,7 @@ class RestoreWeightShapePattern LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, PatternRewriter& rewriter) const override { - const auto f_attr = call_op.getFAttr().dyn_cast(); + const auto f_attr = mlir::dyn_cast(call_op.getFAttr()); StringRef function_name = f_attr.getValue(); // TODO(b/228928859): Improve the getter function to match attributes rather // than function name. @@ -1106,7 +1110,8 @@ class QuantizationSummary { module_.walk([&](Operation* op) { if (auto call_op = llvm::dyn_cast_or_null(op)) { - const auto f_attr = call_op.getFAttr().dyn_cast(); + const auto f_attr = + mlir::dyn_cast(call_op.getFAttr()); if (!f_attr) return; StringRef func_name = f_attr.getValue(); if (func_name.starts_with(kQuantizedFuncPrefix)) { @@ -1227,7 +1232,7 @@ class QuantizationSummary { } // Use the first op as the representative name. - return quantized_ops.front().cast().getValue(); + return mlir::cast(quantized_ops.front()).getValue(); } bool IsInCompsiteFunction(Operation* op) { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc index 374d687428ee3e..b202798dffe9d0 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -77,8 +78,9 @@ void PrepareXlaConvParams(OpBuilder &builder, Location loc, ArrayAttr strides, SmallVector lhs_dilation_values(num_dims - 2, 1); SmallVector stride_values, rhs_dilation_values; for (int64_t i : llvm::seq(1, num_dims - 1)) { - stride_values.push_back(strides[i].cast().getInt()); - rhs_dilation_values.push_back(dilations[i].cast().getInt()); + stride_values.push_back(mlir::cast(strides[i]).getInt()); + rhs_dilation_values.push_back( + mlir::cast(dilations[i]).getInt()); } window_strides = Create1DConstValue(builder, loc, stride_values); lhs_dilation = Create1DConstValue(builder, loc, lhs_dilation_values); @@ -96,7 +98,7 @@ Value CreateZeroPointPartialOffset(OpBuilder &builder, Location loc, return CreateScalarConstValue(builder, loc, 0); } - auto shape = tensor.getType().template cast(); + auto shape = mlir::cast(tensor.getType()); SmallVector non_output_indices; for (int64_t i : llvm::seq(0, shape.getRank())) { if (absl::c_count(output_dims, i) == 0) { @@ -108,7 +110,7 @@ Value CreateZeroPointPartialOffset(OpBuilder &builder, Location loc, Create1DConstValue(builder, loc, non_output_indices); auto zp = CreateScalarConstValue(builder, loc, other_tensor_zp); - TensorType tensor_type = tensor.getType().dyn_cast(); + TensorType tensor_type = mlir::dyn_cast(tensor.getType()); Value tensor_i32 = builder.create( loc, tensor_type.clone(builder.getIntegerType(32)), tensor); auto reduced = @@ -136,7 +138,7 @@ Value MergeZeroPointOffset(OpBuilder &builder, Location loc, Value weight, int8_t input_zp, int8_t weight_zp, Value zp_input_contribution, Value zp_weight_contribution) { - auto weight_shape = weight.getType().template cast(); + auto weight_shape = mlir::cast(weight.getType()); SmallVector weight_non_output_indices; for (auto i : llvm::seq(0, weight_shape.getRank())) { if (absl::c_count(weight_output_dims, i) == 0) { @@ -498,7 +500,7 @@ Value CreateZeroPointPartialOffsetXlaDotV2( return CreateScalarConstValue(builder, loc, 0); } - auto shape = tensor.getType().template cast(); + auto shape = mlir::cast(tensor.getType()); SmallVector tensor_shape; for (auto v : shape.getShape()) { tensor_shape.push_back(v); @@ -506,7 +508,7 @@ Value CreateZeroPointPartialOffsetXlaDotV2( auto zp = CreateScalarConstValue(builder, loc, other_tensor_zp); - TensorType tensor_type = tensor.getType().dyn_cast(); + TensorType tensor_type = mlir::dyn_cast(tensor.getType()); Value tensor_i32 = builder.create( loc, tensor_type.clone(builder.getIntegerType(32)), tensor); @@ -596,7 +598,7 @@ Value CalculateZeroPointOffsetXLADotV2(OpBuilder &builder, Location loc, Value zp_weight_contribution = CreateZeroPointPartialOffsetXlaDotV2( builder, loc, weight, input_zp, dnums, /*is_lhs=*/false, output_rank); - auto weight_shape = weight.getType().template cast(); + auto weight_shape = mlir::cast(weight.getType()); absl::flat_hash_set rhs_contracting_dims; for (auto dim : dnums.rhs_contracting_dimensions()) { @@ -711,8 +713,8 @@ Value CreateXlaConvOpFromTfConv2dOp(OpBuilder &builder, Location loc, ArrayAttr dilations, StringAttr conv_padding, ArrayAttr explicit_paddings) { - auto input_shape = input.getType().template cast(); - auto filter_shape = filter.getType().template cast(); + auto input_shape = mlir::cast(input.getType()); + auto filter_shape = mlir::cast(filter.getType()); if (!input_shape.hasRank() || input_shape.getRank() != 4 || !filter_shape.hasRank() || filter_shape.getRank() != 4) { emitError(loc, "input and filter are expected to be 4D tensors"); @@ -731,8 +733,8 @@ Value CreateXlaConvOpFromTfDepthwiseConv2dOp( OpBuilder &builder, Location loc, Value input, Value filter, Value input_zp, Value conv_output, ArrayAttr strides, ArrayAttr dilations, StringAttr conv_padding, ArrayAttr explicit_paddings) { - auto input_shape = input.getType().template cast(); - auto filter_shape = filter.getType().template cast(); + auto input_shape = mlir::cast(input.getType()); + auto filter_shape = mlir::cast(filter.getType()); if (!input_shape.hasRank() || input_shape.getRank() != 4 || !filter_shape.hasRank() || filter_shape.getRank() != 4) { emitError(loc, "input and filter are expected to be 4D tensors"); @@ -759,8 +761,8 @@ Value CreateXlaConvOpFromTfConv3dOp(OpBuilder &builder, Location loc, Value conv_output, ArrayAttr strides, ArrayAttr dilations, StringAttr conv_padding) { - auto input_shape = input.getType().template cast(); - auto filter_shape = filter.getType().template cast(); + auto input_shape = mlir::cast(input.getType()); + auto filter_shape = mlir::cast(filter.getType()); if (!input_shape.hasRank() || input_shape.getRank() != 5 || !filter_shape.hasRank() || filter_shape.getRank() != 5) { emitError(loc, "input and filter are expected to be 5D tensors"); @@ -819,7 +821,7 @@ Value CreateXlaDotV2Op(OpBuilder &builder, Location loc, Value input, Value zp_offset = CalculateZeroPointOffsetXLADotV2( builder, loc, input, weight, input_zp_value, weight_zp_value, dnums, - output.getType().template cast().getRank()); + mlir::cast(output.getType()).getRank()); return builder.create(loc, dot_result, zp_offset); } @@ -891,8 +893,8 @@ GetBroadcastShapesForBatchMatmul(ShapedType input_type, // function, except BroadcastTo, are expected to be folded. void BroadcastBatchDimensionsForBatchMatMul(OpBuilder &builder, Location loc, Value &input, Value &weight) { - ShapedType input_type = input.getType().template cast(); - ShapedType weight_type = weight.getType().template cast(); + ShapedType input_type = mlir::cast(input.getType()); + ShapedType weight_type = mlir::cast(weight.getType()); const int32_t input_rank = input_type.getRank(); const int32_t weight_rank = weight_type.getRank(); const int32_t broadcasted_rank = std::max(input_rank, weight_rank); @@ -984,7 +986,7 @@ Value CreateXlaDotV2OpFromTfBatchMatMulOp(OpBuilder &builder, Location loc, BroadcastBatchDimensionsForBatchMatMul(builder, loc, input, weight); // Both input and weight have the same rank after broadcasting. - ShapedType weight_shape = weight.getType().template cast(); + ShapedType weight_shape = mlir::cast(weight.getType()); int num_batch_dim = weight_shape.getRank() - 2; // Transpose and constant-fold the weight if needed. @@ -1016,7 +1018,7 @@ Value CreateXlaDotV2OpFromTfBatchMatMulOp(OpBuilder &builder, Location loc, // Check if the given value is a ranked type with specified integer width. bool IsRankedInt(Value value, const int integer_width) { - ShapedType value_type = value.getType().template cast(); + ShapedType value_type = mlir::cast(value.getType()); if (!value_type.hasRank()) return false; if (!value_type.getElementType().isInteger(integer_width)) return false; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD index 21600fb78083a5..4397b4fc5a3f2d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD @@ -30,6 +30,7 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -85,6 +86,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_xla//xla:xla_data_proto_cc", ], ) @@ -100,5 +102,6 @@ tf_cc_test( "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h index 5a1734bf6bf026..702e19506d2fd6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/fake_quant_utils.h @@ -112,7 +112,7 @@ class ConvertFakeQuantOpToQuantOps { Value input = tf_op.getInputs(); int quant_dim = -1; - auto input_type = input.getType().template cast(); + auto input_type = mlir::cast(input.getType()); if (PerAxis) { if (!input_type.hasRank()) { tf_op.emitError("The input should have known rank for per-channel op."); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_quantize_op_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_quantize_op_utils.cc index 264c6c508a60f7..1392bf4de2a92f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_quantize_op_utils.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_quantize_op_utils.cc @@ -16,14 +16,15 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace quant { UnrankedTensorType CreateUnknownShapeFromElementType(Type tensor_type) { - if (!tensor_type.cast()) return UnrankedTensorType(); + if (!mlir::cast(tensor_type)) return UnrankedTensorType(); return UnrankedTensorType::get( - tensor_type.cast().getElementType()); + mlir::cast(tensor_type).getElementType()); } } // namespace quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc index 967af993c0bcf7..430d5ff6ba2047 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_uniform_attribute_utils.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/uniform_op_quant_spec.h" @@ -66,9 +67,9 @@ constexpr std::array kSuffixes = {"_min_val", "_max_val"}; Attribute GetWindowStridesValue( PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { - ArrayAttr stride = identifier_to_attr["strides"].dyn_cast(); - const int stride_h = stride[1].cast().getInt(); - const int stride_w = stride[2].cast().getInt(); + ArrayAttr stride = mlir::dyn_cast(identifier_to_attr["strides"]); + const int stride_h = mlir::cast(stride[1]).getInt(); + const int stride_w = mlir::cast(stride[2]).getInt(); return rewriter.getI64ArrayAttr({stride_h, stride_w}); } @@ -79,23 +80,24 @@ Attribute GetLhsDilationValue(PatternRewriter& rewriter, Attribute GetRhsDilationValue(PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { - ArrayAttr dilations = identifier_to_attr["dilations"].dyn_cast(); - const int dilation_h = dilations[1].cast().getInt(); - const int dilation_w = dilations[2].cast().getInt(); + ArrayAttr dilations = + mlir::dyn_cast(identifier_to_attr["dilations"]); + const int dilation_h = mlir::cast(dilations[1]).getInt(); + const int dilation_w = mlir::cast(dilations[2]).getInt(); return rewriter.getI64ArrayAttr({dilation_h, dilation_w}); } Attribute GetPaddingValue(PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { llvm::StringRef padding = - identifier_to_attr["padding"].dyn_cast().getValue(); + mlir::dyn_cast(identifier_to_attr["padding"]).getValue(); return rewriter.getStringAttr(padding); } Attribute GetExplicitPaddingValue( PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { ArrayAttr explicit_padding = - identifier_to_attr["explicit_paddings"].dyn_cast(); + mlir::dyn_cast(identifier_to_attr["explicit_paddings"]); return explicit_padding; } @@ -167,7 +169,7 @@ LogicalResult CheckIfAttrIs8Bit(const std::string& attr, Operation* op, element_type = getElementTypeOrSelf(op->getOpResult(0).getType()); } if (element_type) { - is_8_bit = element_type.isa(); + is_8_bit = mlir::isa(element_type); return success(); } return failure(); @@ -295,7 +297,8 @@ LogicalResult FillAttributesForUniformQuantizedConvolutionOp( auto feature_group_cnt_attr = llvm::StringRef("feature_group_count"); int feature_group_cnt = 1; - ShapedType input_shape = op->getOperand(0).getType().dyn_cast(); + ShapedType input_shape = + mlir::dyn_cast(op->getOperand(0).getType()); if (!input_shape) { return op->emitError( "Only input with known shape is supported for Uniform Quantized " @@ -425,7 +428,8 @@ LogicalResult FillAttributesForUniformRequantizeOp( activation_quantization_axis = GetQuantizationAxis(rewriter, op, /*operand_index=*/0); - auto output_scale_type = op->getOperand(3).getType().dyn_cast(); + auto output_scale_type = + mlir::dyn_cast(op->getOperand(3).getType()); if (!output_scale_type) { return failure(); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc index f1d7a6ae576c7b..7054f049d1369c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/strings/str_format.h" #include "llvm/ADT/ArrayRef.h" +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h" #include "xla/xla_data.pb.h" @@ -34,8 +35,7 @@ Value GetDimValue(OpBuilder &builder, Location loc, Value shape_value, return builder.create( loc, RankedTensorType::get( - {}, - shape_value.getType().template cast().getElementType()), + {}, mlir::cast(shape_value.getType()).getElementType()), /*input=*/shape_value, /*begin=*/Create1DConstValue(builder, loc, {dim}), /*end=*/Create1DConstValue(builder, loc, {dim + 1}), @@ -109,14 +109,14 @@ Value PadForDynamicShapedInputSamePadding( CreateConstValue(builder, loc, {rank}, shape)); }; - ShapedType filter_shape = filter.getType().template cast(); + ShapedType filter_shape = mlir::cast(filter.getType()); Value input_shape_value = builder.create( loc, RankedTensorType::get({num_dims}, builder.getI32Type()), input); auto scalar_to_rank1 = [&](Value value) { return reshape_op(value, {1}); }; for (int i : llvm::seq(1, num_dims - 1)) { Value input_size_i = GetDimValue(builder, loc, input_shape_value, i); - const int stride_i = strides[i].cast().getInt(); - const int dilation_i = dilations[i].cast().getInt(); + const int stride_i = mlir::cast(strides[i]).getInt(); + const int dilation_i = mlir::cast(dilations[i]).getInt(); const int filter_i = filter_shape.getDimSize(i - 1); Value pad_i_low, pad_i_high; GetSamePaddingValues(builder, loc, input_size_i, filter_i, dilation_i, @@ -154,7 +154,7 @@ Value CalculatePaddingAndPadIfNeeded(OpBuilder &builder, Location loc, StringAttr conv_padding, ArrayAttr explicit_paddings, Value &padding, int num_dims) { - ShapedType input_shape = input.getType().template cast(); + ShapedType input_shape = mlir::cast(input.getType()); SmallVector spatial_dims(num_dims - 2); absl::c_iota(spatial_dims, 1); bool has_dynamic_spatial_dim = absl::c_any_of( @@ -166,7 +166,7 @@ Value CalculatePaddingAndPadIfNeeded(OpBuilder &builder, Location loc, conv_padding, padding, num_dims); } - ShapedType filter_shape = filter.getType().template cast(); + ShapedType filter_shape = mlir::cast(filter.getType()); SmallVector padding_values(2 * num_dims, 0); if (conv_padding.strref().equals("EXPLICIT")) { if (explicit_paddings.size() != 2 * num_dims) { @@ -178,16 +178,16 @@ Value CalculatePaddingAndPadIfNeeded(OpBuilder &builder, Location loc, } for (int i : spatial_dims) { padding_values[2 * i] = - explicit_paddings[2 * i].cast().getInt(); + mlir::cast(explicit_paddings[2 * i]).getInt(); padding_values[2 * i + 1] = - explicit_paddings[2 * i + 1].cast().getInt(); + mlir::cast(explicit_paddings[2 * i + 1]).getInt(); } } else if (conv_padding.strref().equals("SAME")) { for (int i : spatial_dims) { int input_size = input_shape.getDimSize(i); int filter_size = filter_shape.getDimSize(i - 1); - int stride_i = strides[i].cast().getInt(); - int dilation_i = dilations[i].cast().getInt(); + int stride_i = mlir::cast(strides[i]).getInt(); + int dilation_i = mlir::cast(dilations[i]).getInt(); int out_size = tflite::ComputeOutSize(kTfLitePaddingSame, input_size, filter_size, stride_i, dilation_i); @@ -243,7 +243,7 @@ Value CalculatePaddingAndPadIfNeeded(OpBuilder &builder, Location loc, // // packed_value = bitwise_or(packed_low, packed_high) Value PackOperand(OpBuilder &builder, Location loc, Value value, int pack_dim) { - ShapedType value_type = value.getType().cast(); + ShapedType value_type = mlir::cast(value.getType()); const int rank = value_type.getRank(); SmallVector packed_shape(value_type.getShape().begin(), diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils_test.cc index cc4bbb344026da..cbcda677b87733 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils_test.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" @@ -51,7 +52,8 @@ void PackOperandTestHelper( DenseIntElementsAttr packed_value_attr; ASSERT_TRUE(matchPattern(packed_value, m_Constant(&packed_value_attr))); - ShapedType packed_shape_type = packed_value.getType().dyn_cast(); + ShapedType packed_shape_type = + mlir::dyn_cast(packed_value.getType()); llvm::SmallVector packed_shape(packed_shape_type.getShape().begin(), packed_shape_type.getShape().end()); EXPECT_THAT(packed_shape, testing::ElementsAreArray(expected_packed_shape)); diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 26d5e4d52b41d7..7fda5b22e08ca2 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -644,6 +644,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -778,6 +779,7 @@ cc_library( hdrs = ["utils/location_utils.h"], deps = [ "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -908,6 +910,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:ml_dtypes", "@local_xla//xla:test", ], @@ -938,6 +941,7 @@ cc_library( "//tensorflow/core/util:managed_stack_trace", "@com_google_absl//absl/status", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_xla//xla/mlir/utils:error_util", ], ) @@ -1401,6 +1405,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -1473,6 +1478,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -1505,6 +1511,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -1516,6 +1523,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc index 348316e2648ccb..cab89bb10b5fb9 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc @@ -74,7 +74,7 @@ class BacktrackAnalysisInfo { // the result cannot be backtracked to a region argument, returns // std::nullopt. std::optional GetArg(int result_index) const { - if (auto arg = GetValue(result_index).dyn_cast()) + if (auto arg = mlir::dyn_cast(GetValue(result_index))) if (arg.getParentBlock() == ®ion_->front()) return arg.getArgNumber(); return std::nullopt; } @@ -191,7 +191,7 @@ BacktrackAnalysis::BacktrackAnalysis( // possible. Value BacktrackAnalysis::BacktrackValue(Value value) { while (Operation* op = value.getDefiningOp()) { - int res_index = value.cast().getResultNumber(); + int res_index = mlir::cast(value).getResultNumber(); if (auto graph = dyn_cast(op)) { value = graph.GetFetch().getOperand(res_index); } else if (auto island = dyn_cast(op)) { diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc index 5ceda80490f688..e27d0405d7e8f1 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_dataflow.cc @@ -46,7 +46,7 @@ ResourceConstructingOps ResourceConstructingOps::EntryState( return ResourceConstructingOps(); } ResourceConstructingOps ResourceConstructingOps::EntryState(Value value) { - if (auto barg = value.dyn_cast()) { + if (auto barg = mlir::dyn_cast(value)) { if (func::FuncOp func = dyn_cast(barg.getOwner()->getParentOp())) { SymbolTable symbol_table(func->getParentOfType()); @@ -87,7 +87,7 @@ IsComposite IsComposite::EntryState(MLIRContext *context) { IsComposite IsComposite::EntryState(Value value) { IsComposite result; - if (auto barg = value.dyn_cast()) { + if (auto barg = mlir::dyn_cast(value)) { if (func::FuncOp func = dyn_cast(barg.getOwner()->getParentOp())) { if (func.getArgAttr(barg.getArgNumber(), kCompositeDevice)) { diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc index e1a984ea69bc67..abace5111184ff 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc @@ -29,8 +29,8 @@ namespace TF { namespace { bool IsResourceType(Type type) { - if (auto tensor_type = type.dyn_cast()) { - return tensor_type.getElementType().isa(); + if (auto tensor_type = mlir::dyn_cast(type)) { + return mlir::isa(tensor_type.getElementType()); } return false; } @@ -44,10 +44,9 @@ func::FuncOp GetSessionInitializerFunc(ModuleOp module) { auto session_init_op = tf_saved_model::GetSessionInitializerOp(module); if (session_init_op && !session_init_op.getInitializers().empty()) { SymbolTable symbol_table(module); - func::FuncOp init_func_op = - symbol_table.lookup(session_init_op.getInitializers()[0] - .cast() - .getValue()); + func::FuncOp init_func_op = symbol_table.lookup( + mlir::cast(session_init_op.getInitializers()[0]) + .getValue()); return init_func_op; } return nullptr; diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index c95dd020497385..df0138a20a0c74 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -73,7 +73,7 @@ const ResourceIdSet& UnknownResourceSet() { const ResourceIdSet& GetResourceUniqueIdsOrUnknown( Value value, const ResourceAliasAnalysis::Info& alias_analysis) { - if (!getElementTypeOrSelf(value.getType()).isa() || + if (!mlir::isa(getElementTypeOrSelf(value.getType())) || alias_analysis.IsUnknownResource(value)) return UnknownResourceSet(); return alias_analysis.GetResourceUniqueIds(value); } @@ -145,7 +145,7 @@ bool MayHaveSideEffect(Operation* op) { bool ShouldUseResourceAliasAnalysis( const MemoryEffects::EffectInstance& effect) { Value value = effect.getValue(); - if (value && getElementTypeOrSelf(value.getType()).isa()) { + if (value && mlir::isa(getElementTypeOrSelf(value.getType()))) { // For value-based effects on resource values we can use resource alias // analysis. return true; diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index e9a35b1221c2a4..7275aee19e49f4 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -121,7 +121,7 @@ class MlirTensor : public TracingTensorHandle { Value getValue() { return value_; } Type getElementType() { - return value_.getType().cast().getElementType(); + return mlir::cast(value_.getType()).getElementType(); } // For LLVM style RTTI. @@ -340,11 +340,11 @@ Status MlirAbstractOp::SetOpName(const char* const op_name) { Status MlirAbstractOp::AddRef(Type type, Type* output_type) { Type elt_type = getElementTypeOrSelf(type); - if (elt_type.isa()) { + if (mlir::isa(elt_type)) { return InvalidArgument("Requested reference to a reference type"); } elt_type = TensorFlowRefType::get(elt_type); - if (RankedTensorType tensor_type = type.dyn_cast()) { + if (RankedTensorType tensor_type = mlir::dyn_cast(type)) { *output_type = RankedTensorType::get(tensor_type.getShape(), elt_type); } *output_type = UnrankedTensorType::get(elt_type); @@ -373,11 +373,11 @@ Status MlirAbstractOp::Create(ArrayRef operands, return InvalidArgument("Missing attribute '", output_arg.number_attr(), "' required for output list '", output_arg.name(), "'"); - if (!repeats_attr.isa()) + if (!mlir::isa(repeats_attr)) return InvalidArgument("Attribute '", output_arg.number_attr(), "' required for output list '", output_arg.name(), "' isn't an integer"); - int64_t repeats = repeats_attr.cast().getInt(); + int64_t repeats = mlir::cast(repeats_attr).getInt(); if (!output_arg.type_attr().empty()) { // Same type repeated "repeats" times. @@ -386,7 +386,7 @@ Status MlirAbstractOp::Create(ArrayRef operands, return InvalidArgument("Missing attribute '", output_arg.type_attr(), "' required for output '", output_arg.name(), "'"); - TypedAttr type_attr = attr.dyn_cast(); + TypedAttr type_attr = mlir::dyn_cast(attr); if (!type_attr) return InvalidArgument("Attribute '", output_arg.type_attr(), "' required for output '", output_arg.name(), @@ -410,7 +410,7 @@ Status MlirAbstractOp::Create(ArrayRef operands, return InvalidArgument("Missing attribute '", output_arg.type_attr(), "' required for output '", output_arg.name(), "'"); - TypeAttr type_attr = attr.dyn_cast(); + TypeAttr type_attr = mlir::dyn_cast(attr); if (!type_attr) return InvalidArgument("Attribute '", output_arg.type_attr(), "' required for output '", output_arg.name(), @@ -423,13 +423,13 @@ Status MlirAbstractOp::Create(ArrayRef operands, return InvalidArgument( "Missing attribute '", output_arg.type_list_attr(), "' required for output '", output_arg.name(), "'"); - ArrayAttr array_attr = attr.dyn_cast(); + ArrayAttr array_attr = mlir::dyn_cast(attr); if (!array_attr) return InvalidArgument("Attribute '", output_arg.type_list_attr(), "' required for output '", output_arg.name(), "' isn't an array attribute"); for (Attribute attr : array_attr) { - TypeAttr type_attr = attr.dyn_cast(); + TypeAttr type_attr = mlir::dyn_cast(attr); if (!type_attr) return InvalidArgument("Array Attribute '", output_arg.type_list_attr(), diff --git a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.cc index dba58f17ccb029..9a1db50ff6b732 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -45,11 +46,12 @@ _TfrtGetResourceOp::GetResourceHandleValueAndIdList( for (const auto &iter : llvm::enumerate(getResults())) { auto index = iter.index(); - if (getElementTypeOrSelf(iter.value().getType()).isa()) { + if (mlir::isa( + getElementTypeOrSelf(iter.value().getType()))) { resource_vec.push_back(GetResourceHandleValueAndIdBase( - getContainer()[index].cast().getValue(), - getSharedName()[index].cast().getValue(), device, - getResults()[index], resource_handle_id_map, next_id)); + mlir::cast(getContainer()[index]).getValue(), + mlir::cast(getSharedName()[index]).getValue(), + device, getResults()[index], resource_handle_id_map, next_id)); } } return resource_vec; @@ -100,16 +102,16 @@ mlir::LogicalResult IfrtCallOp::verify() { } for (mlir::Value arg : getArgs()) { - if (mlir::getElementTypeOrSelf(arg.getType()) - .isa()) { + if (mlir::isa( + mlir::getElementTypeOrSelf(arg.getType()))) { return emitOpError() << "does not support passing '!tf.resource' values as arguments"; } } for (mlir::Value result : getResults()) { - if (mlir::getElementTypeOrSelf(result.getType()) - .isa()) { + if (mlir::isa( + mlir::getElementTypeOrSelf(result.getType()))) { return emitOpError() << "does not support returning '!tf.resource' values as results"; } @@ -118,12 +120,13 @@ mlir::LogicalResult IfrtCallOp::verify() { // Verify variable_arg_indices is sorted in ascending order. int64_t prev_index = -1; for (auto arg_index_attr : getVariableArgIndicesAttr()) { - if (!arg_index_attr.isa_and_nonnull()) { + if (!mlir::isa_and_nonnull(arg_index_attr)) { return emitOpError() << "variable_arg_indices must be an integer"; } - int64_t index = - arg_index_attr.dyn_cast().getValue().getSExtValue(); + int64_t index = mlir::dyn_cast(arg_index_attr) + .getValue() + .getSExtValue(); if (index < 0) { return emitOpError() << "variable_arg_indices must be positive"; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.cc index f5284a0ef3cf96..9a78a1a83ae214 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" namespace mlir { @@ -27,12 +28,12 @@ namespace TF { // Verifies an reduction op's `input` and reduction `dims`. LogicalResult VerifyReductionInputAndDims(Value input, Value dims, Location loc) { - auto dims_type = dims.getType().dyn_cast(); + auto dims_type = mlir::dyn_cast(dims.getType()); if (!dims_type) return success(); if (dims_type.getRank() > 1) return emitError(loc, "dimensions can only be 0D or 1D tensor"); - auto input_type = input.getType().dyn_cast(); + auto input_type = mlir::dyn_cast(input.getType()); if (!input_type) return success(); int64_t rank = input_type.getRank(); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h index aa0f84eb122e2b..64b5d2e141f13d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/TypeRange.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { @@ -60,10 +61,10 @@ template < OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr> OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, ArrayRef operands) { - auto lhs_type = arithmetic_op.getX().getType().template cast(); - auto rhs_type = arithmetic_op.getY().getType().template cast(); + auto lhs_type = mlir::cast(arithmetic_op.getX().getType()); + auto rhs_type = mlir::cast(arithmetic_op.getY().getType()); auto result_type = - arithmetic_op.getResult().getType().template cast(); + mlir::cast(arithmetic_op.getResult().getType()); // We can fold arithmetic operation only of we can prove that we will not // accidentally hide a broadcasting error. @@ -86,8 +87,8 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, // Check that we have a constant operand on one side (candidate for identity). const bool is_commutative = (std::is_same::value || std::is_same::value); - auto lhs_attr = operands[0].dyn_cast_or_null(); - auto rhs_attr = operands[1].dyn_cast_or_null(); + auto lhs_attr = mlir::dyn_cast_or_null(operands[0]); + auto rhs_attr = mlir::dyn_cast_or_null(operands[1]); if (!rhs_attr && !(is_commutative && lhs_attr)) return {}; // Mul and Div ops have identity value one while AddV2 and SubOp have identity @@ -100,9 +101,9 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, Type element_ty = lhs_type.getElementType(); Attribute identity_attr; - if (auto ty = element_ty.template dyn_cast()) { + if (auto ty = mlir::dyn_cast(element_ty)) { identity_attr = FloatAttr::get(ty, static_cast(identity)); - } else if (auto ty = element_ty.template dyn_cast()) { + } else if (auto ty = mlir::dyn_cast(element_ty)) { identity_attr = IntegerAttr::get(ty, static_cast(identity)); } else { return {}; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index 5d145c85a68a06..df887ce453b8ea 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -99,7 +99,8 @@ struct TFInlinerInterface : public DialectInlinerInterface { Operation* materializeCallConversion(OpBuilder& builder, Value input, Type result_type, Location conversion_loc) const final { - if (!result_type.isa() || !input.getType().isa()) + if (!mlir::isa(result_type) || + !mlir::isa(input.getType())) return nullptr; return builder.create(conversion_loc, result_type, input, /*truncate=*/builder.getBoolAttr(false)); @@ -307,7 +308,7 @@ ParseResult SetReplicateOpOperands( llvm::ArrayRef region_arg_types, int32_t* n) { for (const auto& attr : state->attributes) if (attr.getName().strref() == "n") - if (auto n_attr = attr.getValue().dyn_cast()) + if (auto n_attr = mlir::dyn_cast(attr.getValue())) *n = n_attr.getInt(); if (*n < 2) @@ -507,13 +508,14 @@ LogicalResult ReplicateOp::verify() { // Check number of devices, if set, matches `n`. if (op.getDevices().has_value()) { for (auto device_attr : op.getDevices().value().getValue()) { - auto device_list = device_attr.getValue().dyn_cast_or_null(); + auto device_list = + mlir::dyn_cast_or_null(device_attr.getValue()); if (!device_list) return op.emitError() << "expects 'devices' to be a map alias and device name list."; bool is_device_string = llvm::all_of(device_list, [](Attribute attr) { - return attr.dyn_cast_or_null(); + return mlir::dyn_cast_or_null(attr); }); if (!is_device_string) return op.emitOpError() << "expects 'devices' to be a consists of " @@ -747,8 +749,8 @@ static LogicalResult EliminatePassThroughResults(ClusterOp op, // Old bridge only removes unsupported TPU types (only string for now) // during outside compilation extraction so this should be enough for // the parity. - bool is_unsupported_type = getElementTypeOrSelf(operand.get().getType()) - .isa(); + bool is_unsupported_type = mlir::isa( + getElementTypeOrSelf(operand.get().getType())); Value result = operand.get(); if (is_unsupported_type && result.getParentBlock() != &body && !is_used_for_resource_write) { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index f7c35420c22b4a..f48e1570933cf9 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/FoldUtils.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project @@ -119,11 +120,11 @@ Type TensorFlowExecutorDialect::parseType(DialectAsmParser &parser) const { void TensorFlowExecutorDialect::printType(Type type, DialectAsmPrinter &os) const { - if (type.isa()) { + if (mlir::isa(type)) { os << "control"; return; } - if (type.isa()) { + if (mlir::isa(type)) { os << "token"; return; } @@ -141,7 +142,7 @@ namespace { LogicalResult VerifyControlOperandsAfterAllData(Operation *op) { bool found_control = false; for (int operand_idx : llvm::seq(0, op->getNumOperands())) { - if (op->getOperand(operand_idx).getType().isa()) { + if (mlir::isa(op->getOperand(operand_idx).getType())) { found_control = true; continue; } @@ -192,7 +193,7 @@ LogicalResult GraphOp::verify() { Value operand = fetch.getOperand(i); // Break out of the loop at the first control operand encountered. const int64_t num_results = graph.getNumResults(); - if (operand.getType().isa()) { + if (mlir::isa(operand.getType())) { if (i != num_results) return fetch.emitOpError() << "operand #" << i @@ -241,7 +242,7 @@ ParseResult GraphOp::parse(OpAsmParser &parser, OperationState &result) { // the fetch operation. result.types.reserve(fetch.getNumOperands()); for (Type type : fetch.getOperandTypes()) { - if (type.isa()) break; + if (mlir::isa(type)) break; result.types.push_back(type); } @@ -403,8 +404,8 @@ ParseResult SwitchOp::parse(OpAsmParser &parser, OperationState &result) { // fully qualified) or a short form with a single type (in which case the data // input and the outputs are all using this type and predicate is tensor // type). - if (types.front().isa()) { - FunctionType type = types.front().cast(); + if (mlir::isa(types.front())) { + FunctionType type = mlir::cast(types.front()); if (type.getNumInputs() < 2) return parser.emitError(parser.getNameLoc()) << " expects a single data type and a predicate"; @@ -439,7 +440,7 @@ void SwitchOp::print(OpAsmPrinter &p) { p << " : "; if (getTrueOutput().getType() != data_operand_ty || getFalseOutput().getType() != data_operand_ty || - getPredicate().getType().isa()) { + mlir::isa(getPredicate().getType())) { p.printFunctionalType(getOperation()); } else { p << getType(0); @@ -465,16 +466,16 @@ LogicalResult SwitchNOp::verify() { // Check that operand can be broadcasted to each output type. auto operand0_type = switchn.getOperand(0).getType(); - TensorType operand0_tensor_type = operand0_type.dyn_cast(); + TensorType operand0_tensor_type = mlir::dyn_cast(operand0_type); if (!operand0_tensor_type) { return switchn.emitOpError() << "expects data operand to have tensor type but got " << operand0_type; } for (Type output_type : switchn.getResultTypes()) { - if (output_type.isa()) break; + if (mlir::isa(output_type)) break; - TensorType output_tensor_type = output_type.dyn_cast(); + TensorType output_tensor_type = mlir::dyn_cast(output_type); if (!output_tensor_type) { return switchn.emitOpError() << "expects outputs to have tensor type but got " << output_type; @@ -483,10 +484,10 @@ LogicalResult SwitchNOp::verify() { // If the output type is a ref type, then the operand type should also be of // the same ref type. However, if the output type is a non-ref type T, then // the operand can be tensor of type T or T_REF. - bool is_output_ref = - output_tensor_type.getElementType().isa(); - if (is_output_ref && !operand0_tensor_type.getElementType() - .isa()) { + bool is_output_ref = mlir::isa( + output_tensor_type.getElementType()); + if (is_output_ref && !mlir::isa( + operand0_tensor_type.getElementType())) { return switchn.emitOpError() << "expects same operand and output element type but got " << operand0_tensor_type << " vs " << output_tensor_type; @@ -573,24 +574,24 @@ LogicalResult MergeOp::verify() { return merge.emitOpError() << "expects at least one operand"; Type data_type = merge.getOperand(0).getType(); - if (data_type.isa()) + if (mlir::isa(data_type)) return merge.emitOpError() << "expects a non-control input"; // Check that each operand can be individually broadcasted to the output type. Type output_type = merge.getOutput().getType(); - TensorType output_tensor_ty = output_type.dyn_cast(); + TensorType output_tensor_ty = mlir::dyn_cast(output_type); if (!output_tensor_ty) { return merge.emitOpError() << "expects output to have tensor type but got " << output_type; } bool is_output_ref = - output_tensor_ty.getElementType().isa(); + mlir::isa(output_tensor_ty.getElementType()); for (Type operand_type : merge.getOperandTypes()) { - if (operand_type.isa()) break; + if (mlir::isa(operand_type)) break; // TODO(hinsu): Update ControlOperandsAfterAllData trait to verify this // constraint. - TensorType operand_tensor_ty = operand_type.dyn_cast(); + TensorType operand_tensor_ty = mlir::dyn_cast(operand_type); if (!operand_tensor_ty) return merge.emitOpError() << "expects data operands to have tensor type but got " @@ -599,8 +600,8 @@ LogicalResult MergeOp::verify() { // If output type is a ref type then all operand types should also be of the // same ref type. However, if the output type is a non-ref type T, operands // can be tensor of type T or T_REF. - if (is_output_ref && - !operand_tensor_ty.getElementType().isa()) { + if (is_output_ref && !mlir::isa( + operand_tensor_ty.getElementType())) { return merge.emitOpError() << "expects same operand and output element type but got " << operand_tensor_ty << " vs " << output_tensor_ty; @@ -624,7 +625,7 @@ void MergeOp::print(OpAsmPrinter &p) { Type output_type = getOutput().getType(); for (Type operand_type : getOperandTypes()) { - if (operand_type.isa()) break; + if (mlir::isa(operand_type)) break; num_data_operands++; if (operand_type != output_type) { @@ -660,7 +661,7 @@ ParseResult MergeOp::parse(OpAsmParser &parser, OperationState &result) { // Support parsing either a functional type (in which case all the types are // fully qualified) or a short form with a single type (in which case the data // inputs and the output are all using this type). - if (FunctionType type = types.front().dyn_cast()) { + if (FunctionType type = mlir::dyn_cast(types.front())) { result.types.assign(type.getResults().begin(), type.getResults().end()); types.assign(type.getInputs().begin(), type.getInputs().end()); } else { @@ -747,7 +748,7 @@ ParseResult EnterOp::parse(OpAsmParser &parser, OperationState &result) { // Support parsing either a functional type (in which case all the types are // fully qualified) or a short form with a single type (in which case the data // input and the outputs are all using this type). - if (FunctionType type = types.front().dyn_cast()) { + if (FunctionType type = mlir::dyn_cast(types.front())) { // One data input, and any number of control inputs. if (type.getNumInputs() >= 1) { result.types.assign(type.getResults().begin(), type.getResults().end()); @@ -876,7 +877,7 @@ ParseResult LoopCondOp::parse(OpAsmParser &parser, OperationState &result) { // fully qualified) or a short form with a single type (in which case the data // input and the outputs are all using this type). Type control_type = ControlType::get(parser.getBuilder().getContext()); - if (FunctionType type = types.front().dyn_cast()) { + if (FunctionType type = mlir::dyn_cast(types.front())) { if (llvm::count_if(type.getInputs(), [=](Type type) { return type != control_type; }) != 1) return parser.emitError(parser.getNameLoc()) @@ -959,14 +960,14 @@ struct HoistInnerOpsSingleIslandGraph : public OpRewritePattern { llvm::SmallVector new_rets; for (Value operand : fetch_op.getFetches()) { // Control results should not be propagated out. - if (operand.getType().isa()) break; + if (mlir::isa(operand.getType())) break; if (operand.getDefiningOp() != island_op) { // Operand is not from island, simply propagate it out. new_rets.push_back(operand); } else { // Lookup yield operand in island for inner op result. - auto result = operand.cast(); + auto result = mlir::cast(operand); new_rets.push_back(yield_op.getOperand(result.getResultNumber())); } } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index d3026b02878741..373586ae837a3f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -191,7 +191,8 @@ struct TFInlinerInterface : public DialectInlinerInterface { Operation *materializeCallConversion(OpBuilder &builder, Value input, Type result_type, Location conversion_loc) const final { - if (!result_type.isa() || !input.getType().isa()) + if (!mlir::isa(result_type) || + !mlir::isa(input.getType())) return nullptr; return builder.create(conversion_loc, result_type, input, /*truncate=*/builder.getBoolAttr(false)); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index 988c749adb8cc6..36fb36a3d451c6 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -160,12 +160,12 @@ OpFoldResult AddNOp::fold(FoldAdaptor adaptor) { int non_zero_index = -1; auto IsKnownZero = [](Attribute attr) { if (!attr) return false; - auto splat = attr.dyn_cast(); + auto splat = mlir::dyn_cast(attr); if (!splat) return false; Type element_ty = splat.getType().getElementType(); - if (element_ty.isa()) + if (mlir::isa(element_ty)) return splat.getSplatValue().isZero(); - if (element_ty.isa()) + if (mlir::isa(element_ty)) return splat.getSplatValue().getSExtValue() == 0; return false; }; @@ -180,13 +180,13 @@ OpFoldResult AddNOp::fold(FoldAdaptor adaptor) { } // Only fold when the result shape is fully static. - auto result_ty = getType().dyn_cast(); + auto result_ty = mlir::dyn_cast(getType()); if (!result_ty || !result_ty.hasStaticShape()) return {}; if (non_zero_index == -1) { return SplatElementsAttr::get( - result_ty, - operands.begin()->cast().getSplatValue()); + result_ty, mlir::cast(*operands.begin()) + .getSplatValue()); } // Check the non-zero operand's shape matches the result shape. @@ -423,7 +423,7 @@ LogicalResult BatchToSpaceOp::verify() { int64_t block_size = op.getBlockSize(); llvm::SmallVector input_shape(4, ShapedType::kDynamic); - auto input_type = op.getInput().getType().cast(); + auto input_type = mlir::cast(op.getInput().getType()); if (input_type.hasRank()) { if (input_type.getRank() != 4) return op.emitOpError() @@ -442,7 +442,7 @@ LogicalResult BatchToSpaceOp::verify() { input_type.getShape().end()); } - auto crops_type = op.getCrops().getType().cast(); + auto crops_type = mlir::cast(op.getCrops().getType()); if (crops_type.hasRank()) { if (crops_type.getRank() != 2) return op.emitOpError() @@ -477,7 +477,7 @@ LogicalResult BatchToSpaceOp::verify() { } } - auto output_type = op.getOutput().getType().cast(); + auto output_type = mlir::cast(op.getOutput().getType()); if (output_type.hasRank()) { if (output_type.getRank() != 4) return op.emitOpError() @@ -567,8 +567,8 @@ void BatchToSpaceOp::getCanonicalizationPatterns(RewritePatternSet& results, LogicalResult BatchToSpaceNDOp::verify() { BatchToSpaceNDOp op = *this; - auto block_shape_ty = op.getBlockShape().getType().cast(); - auto crops_ty = op.getCrops().getType().cast(); + auto block_shape_ty = mlir::cast(op.getBlockShape().getType()); + auto crops_ty = mlir::cast(op.getCrops().getType()); if (block_shape_ty.hasStaticShape() && crops_ty.hasStaticShape()) { const int block_rank = block_shape_ty.getShape().front(); @@ -617,9 +617,9 @@ LogicalResult BiasAddOp::verify() { return op.emitOpError("requires bias operand to have rank exactly one"); RankedTensorType value_ty = - op.getValue().getType().dyn_cast(); + mlir::dyn_cast(op.getValue().getType()); RankedTensorType bias_ty = - op.getBias().getType().dyn_cast(); + mlir::dyn_cast(op.getBias().getType()); if (!bias_ty || !value_ty) return success(); int64_t feature_dim_idx = @@ -716,7 +716,7 @@ OpFoldResult BroadcastToOp::fold(FoldAdaptor) { // Fold broadcast if operand and result types are the same and all dimensions // are statically known (no-op broadcast). - auto result_ty = getType().dyn_cast(); + auto result_ty = mlir::dyn_cast(getType()); if (!result_ty || !result_ty.hasStaticShape()) return {}; if (result_ty == input.getType()) return input; @@ -818,8 +818,8 @@ LogicalResult BroadcastGradientArgsOp::verify() { // Verify that output types are of rank one and matches the computed result // shape. - auto r0_ty = op.getR0().getType().dyn_cast(); - auto r1_ty = op.getR1().getType().dyn_cast(); + auto r0_ty = mlir::dyn_cast(op.getR0().getType()); + auto r1_ty = mlir::dyn_cast(op.getR1().getType()); if (r0_ty && r0_ty.hasStaticShape() && r0_ty.getDimSize(0) != r0.size()) return op.emitOpError() << "requires dimension 0 size of 'r0' to be " << r0.size() << " but got " << r0_ty.getShape()[0]; @@ -852,7 +852,8 @@ LogicalResult BroadcastGradientArgsOp::fold( auto build_out_dense_element = [](SmallVectorImpl& shape, Type input_type) { - Type element_type = input_type.cast().getElementType(); + Type element_type = + mlir::cast(input_type).getElementType(); RankedTensorType type = tensorflow::GetTypeFromTFTensorShape( {static_cast(shape.size())}, element_type); // Input could only be i32 or i64. For i32, downcast to int32_t array. @@ -893,7 +894,7 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite( int index = *branch.getValues().begin(); if (index < 0 || index >= op.num_branches()) index = op.num_branches() - 1; - auto func = op.getBranches()[index].cast(); + auto func = mlir::cast(op.getBranches()[index]); auto empty = rewriter.getStringAttr(""); ReplaceTfOpWithNewOp( rewriter, op, op.getResultTypes(), op.getOperands().drop_front(), func, @@ -932,7 +933,7 @@ static LogicalResult VerifyCaseOrIfOpBranchFunctions( for (const auto& branch : llvm::enumerate(branches)) { auto branch_func = symbol_table.lookupNearestSymbolFrom( - op, branch.value().cast()); + op, mlir::cast(branch.value())); if (!branch_func) return op->emitOpError() << "expects " << branch_name(branch.index()) << " (" @@ -1347,12 +1348,10 @@ LogicalResult HoistCwiseBinaryOutOfConcat::matchAndRewrite( else return failure(); DenseElementsAttr const_attr; - auto scalar_tensor_type = - first_arg_op->getOperand(hoist_params->scalar_operand_idx) - .getType() - .dyn_cast(); + auto scalar_tensor_type = mlir::dyn_cast( + first_arg_op->getOperand(hoist_params->scalar_operand_idx).getType()); Type scalar_dtype = scalar_tensor_type.getElementType(); - if (scalar_dtype.isa()) + if (mlir::isa(scalar_dtype)) const_attr = DenseElementsAttr::get(scalar_tensor_type, static_cast(identity_val)); else @@ -1450,7 +1449,7 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams( } else { operand = arg.getDefiningOp()->getOperand(operand_idx); } - auto ranked = operand.getType().dyn_cast(); + auto ranked = mlir::dyn_cast(operand.getType()); return ranked && ranked.getRank() == (axis + 1) && ranked.getShape()[axis] == 1; }); @@ -1461,13 +1460,13 @@ HoistCwiseBinaryOutOfConcat::GetHoistParams( return llvm::all_of(op.getValues(), [&](Value arg) -> bool { if (exceptions.count(arg)) return true; auto operand = arg.getDefiningOp()->getOperand(operand_idx); - auto ranked = operand.getType().dyn_cast(); + auto ranked = mlir::dyn_cast(operand.getType()); return ranked && ranked.hasRank() && ranked.getRank() == 0; }); }; // Concat result type must be a ranked tensor. - auto ranked = op.getType().dyn_cast(); + auto ranked = mlir::dyn_cast(op.getType()); if (!ranked) return std::nullopt; // TODO(ezhulenev): Add support for more valid concat patterns. @@ -1527,7 +1526,7 @@ static LogicalResult Verify(OpT op) { DenseIntElementsAttr axis_attr; if (matchPattern(op.getAxis(), m_Constant(&axis_attr))) { - auto input_ty = op.getX().getType().template dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getX().getType()); if (input_ty) { int64_t rank = input_ty.getRank(); assert(axis_attr.getNumElements() == 1 && @@ -1561,7 +1560,8 @@ LogicalResult ConcatOffsetOp::verify() { << "requires sizes of shapes and offsets to be the same, got sizes " << op.getShape().size() << " and " << op.getOffset().size(); - auto ranked_dim = op.getConcatDim().getType().dyn_cast(); + auto ranked_dim = + mlir::dyn_cast(op.getConcatDim().getType()); if (ranked_dim && ranked_dim.getRank() != 0) return op.emitOpError() << "requires concat_dim to be a scalar, got tensor of rank " @@ -1578,7 +1578,7 @@ LogicalResult ConcatOffsetOp::verify() { return op.emitOpError() << "requires operand and result " << idx << " to have compatible shapes"; - auto ranked_shape = shape.getType().dyn_cast(); + auto ranked_shape = mlir::dyn_cast(shape.getType()); if (!ranked_shape) continue; if (ranked_shape.getRank() != 1) @@ -1609,14 +1609,15 @@ LogicalResult ConcatOffsetOp::fold(FoldAdaptor adaptor, if (operands.size() < 3) return failure(); // Check concat_dim is a scalar. - auto concat_dim_attr = operands[0].dyn_cast_or_null(); + auto concat_dim_attr = + mlir::dyn_cast_or_null(operands[0]); if (!concat_dim_attr || concat_dim_attr.getType().getRank() != 0) return failure(); llvm::SmallVector shapes; shapes.reserve(operands.size() - 1); for (Attribute shape : llvm::drop_begin(operands, 1)) - if (auto shape_attr = shape.dyn_cast_or_null()) + if (auto shape_attr = mlir::dyn_cast_or_null(shape)) shapes.push_back(shape_attr); else return failure(); @@ -1685,14 +1686,14 @@ OpFoldResult ConstOp::fold(FoldAdaptor adaptor) { void ConstOp::build(OpBuilder& builder, OperationState& result, Attribute value) { ShapedType type; - if (auto elem_attr = value.dyn_cast()) { + if (auto elem_attr = mlir::dyn_cast(value)) { return ConstOp::build(builder, result, elem_attr); - } else if (value.isa()) { + } else if (mlir::isa(value)) { // All TensorFlow types must be tensor types. In the build() method, // we want to provide more flexibility by allowing attributes of scalar // types. But we need to wrap it up with ElementsAttr to construct // valid TensorFlow constants. - auto typed_attr = value.cast(); + auto typed_attr = mlir::cast(value); type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, typed_attr.getType()); return ConstOp::build(builder, result, DenseElementsAttr::get(type, value)); @@ -1704,7 +1705,7 @@ void ConstOp::build(OpBuilder& builder, OperationState& result, void ConstOp::build(OpBuilder& builder, OperationState& result, Type type, Attribute value) { // Handle the case where the type and value are already tensors. - if (type.isa() && value.isa()) { + if (mlir::isa(type) && mlir::isa(value)) { result.addTypes(type); result.addAttribute("value", value); return; @@ -1722,7 +1723,7 @@ LogicalResult ConstOp::inferReturnTypes( ConstOpAdaptor adaptor(operands, attributes, properties, regions); auto value = adaptor.getValue(); if (!value) return emitOptionalError(location, "missing attribute 'value'"); - if (auto elem_attr = value.dyn_cast()) { + if (auto elem_attr = mlir::dyn_cast(value)) { inferredReturnTypes.assign({elem_attr.getType()}); return success(); } @@ -1743,7 +1744,7 @@ static LogicalResult VerifyConvOpAttributes( return emitOptionalError( location, "requires strides attribute length to be ", num_dims); auto is_not_positive = [](Attribute val) { - return val.cast().getValue().getSExtValue() <= 0; + return mlir::cast(val).getValue().getSExtValue() <= 0; }; if (llvm::any_of(strides, is_not_positive)) return emitOptionalError(location, "requires positive strides"); @@ -1793,9 +1794,8 @@ static LogicalResult Verify(OpT op) { if (padding == tensorflow::Padding::EXPLICIT) { ArrayRef explicit_padding; - ArrayAttr explicit_pad = - op->getAttr("explicit_paddings") - .template dyn_cast_or_null<::mlir::ArrayAttr>(); + ArrayAttr explicit_pad = mlir::dyn_cast_or_null<::mlir::ArrayAttr>( + op->getAttr("explicit_paddings")); if (!explicit_pad) { explicit_pad = ::mlir::Builder(op->getContext()).getI64ArrayAttr({}); } @@ -1812,7 +1812,7 @@ static LogicalResult Verify(OpT op) { num_dims * 2); } auto is_negative = [](Attribute val) { - return val.cast().getValue().getSExtValue() < 0; + return mlir::cast(val).getValue().getSExtValue() < 0; }; if (llvm::any_of(explicit_padding, is_negative)) return emitOptionalError(op.getLoc(), @@ -1827,7 +1827,7 @@ static LogicalResult Verify(OpT op) { } int64_t input_channels = ShapedType::kDynamic; - if (auto ty = op.getInput().getType().template dyn_cast()) { + if (auto ty = mlir::dyn_cast(op.getInput().getType())) { absl::string_view data_format(op.getDataFormat().data(), op.getDataFormat().size()); tensorflow::TensorFormat format; @@ -1838,8 +1838,7 @@ static LogicalResult Verify(OpT op) { } int64_t filter_channels = ShapedType::kDynamic; - if (auto ty = - op.getFilter().getType().template dyn_cast()) { + if (auto ty = mlir::dyn_cast(op.getFilter().getType())) { int idx = tensorflow::GetFilterTensorInputChannelsDimIndex( num_dims, tensorflow::FORMAT_HWIO); filter_channels = ty.getDimSize(idx); @@ -1891,8 +1890,8 @@ static LogicalResult inferConvReturnTypeComponents( const int64_t num_dims = 2 + num_spatial_dims; const Value input = op.getInput(); const Value filter = op.getFilter(); - const TensorType input_ty = input.getType().template cast(); - const TensorType filter_ty = filter.getType().template cast(); + const TensorType input_ty = mlir::cast(input.getType()); + const TensorType filter_ty = mlir::cast(filter.getType()); ArrayRef strides = op.getStrides().getValue(); StringRef data_format = op.getDataFormat(); @@ -1910,7 +1909,7 @@ static LogicalResult inferConvReturnTypeComponents( (void)padding_is_valid; auto get_int = [](Attribute attr) { - return attr.template cast().getInt(); + return mlir::cast(attr).getInt(); }; // Output always have `num_dims` rank. All dimensions are initialized to @@ -1967,7 +1966,7 @@ LogicalResult Conv2DOp::inferReturnTypeComponents( Conv2DOpAdaptor op(operands.getValues(), attributes, properties, regions); ArrayRef explicit_padding; ArrayAttr explicit_pad = - op.getExplicitPaddings().dyn_cast_or_null<::mlir::ArrayAttr>(); + mlir::dyn_cast_or_null<::mlir::ArrayAttr>(op.getExplicitPaddings()); if (!explicit_pad) { explicit_pad = ::mlir::Builder(context).getI64ArrayAttr({}); } @@ -1984,7 +1983,7 @@ StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices& devices) { return getDataFormat(); // Input must be a tensor. - auto input_ty = getInput().getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(getInput().getType()); if (!input_ty) return getDataFormat(); // For f16 data type on devices with Tensor Cores support NHWC data format @@ -1998,7 +1997,7 @@ StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices& devices) { return getDataFormat(); // Keep current data format if filter rank is unknown or not equal to 4. - auto filter_ty = getFilter().getType().dyn_cast(); + auto filter_ty = mlir::dyn_cast(getFilter().getType()); if (!filter_ty || filter_ty.getRank() != 4) return getDataFormat(); const int64_t d0 = filter_ty.getDimSize(0); @@ -2006,7 +2005,7 @@ StringRef Conv2DOp::GetOptimalLayout(const RuntimeDevices& devices) { auto all_ones = [](ArrayAttr arr) -> bool { return llvm::all_of(arr, [](Attribute attr) -> bool { - return attr.cast().getInt() == 1; + return mlir::cast(attr).getInt() == 1; }); }; @@ -2068,7 +2067,7 @@ StringRef Conv2DBackpropFilterOp::GetOptimalLayout( return getDataFormat(); // Input must be a tensor. - auto input_ty = getInput().getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(getInput().getType()); if (!input_ty) return getDataFormat(); // For f16 data type on devices with Tensor Cores support NHWC data format @@ -2142,7 +2141,7 @@ StringRef Conv2DBackpropInputOp::GetOptimalLayout( return getDataFormat(); // Filter must be a tensor. - auto filter_ty = getFilter().getType().dyn_cast(); + auto filter_ty = mlir::dyn_cast(getFilter().getType()); if (!filter_ty) return getDataFormat(); // For f16 data type on devices with Tensor Cores support NHWC data format @@ -2177,7 +2176,7 @@ LogicalResult Conv3DOp::inferReturnTypeComponents( LogicalResult DataFormatVecPermuteOp::verify() { DataFormatVecPermuteOp op = *this; - auto input_ty = op.getX().getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getX().getType()); if (!input_ty) return success(); int rank = input_ty.getRank(); @@ -2285,12 +2284,12 @@ class DivNoNanOrMulNoNanConstantY : public OpRewritePattern { if (auto yDefOp = dyn_cast_or_null(y.getDefiningOp())) { Type typeOfElementsInY = getElementTypeOrSelf(y.getType()); ElementsAttr attr = yDefOp.getValue(); - bool yHasComplexElements = typeOfElementsInY.isa(); + bool yHasComplexElements = mlir::isa(typeOfElementsInY); // If `y` is a splat constant, then the op will definitely get replaced. // We check for a splat constant first, in order to optimize the // performance of this canonicalization because this check will be O(1). - if (auto splatAttr = attr.dyn_cast()) { + if (auto splatAttr = mlir::dyn_cast(attr)) { bool splatAttrIsZero = false; if (!yHasComplexElements) { if (splatAttr.getSplatValue().isZero()) @@ -2356,7 +2355,8 @@ LogicalResult DynamicStitchOp::verify() { if (op.getN() < 1) return op.emitOpError("requires attribute N with value >= 1"); - if (RankedTensorType out_ty = op.getType().dyn_cast()) { + if (RankedTensorType out_ty = + mlir::dyn_cast(op.getType())) { if (out_ty.getRank() == 0) { return op.emitOpError("requires non scalar output"); } @@ -2383,8 +2383,9 @@ LogicalResult DynamicStitchOp::verify() { } Value data = std::get<1>(it); - RankedTensorType index_ty = index.getType().dyn_cast(); - RankedTensorType data_ty = data.getType().dyn_cast(); + RankedTensorType index_ty = + mlir::dyn_cast(index.getType()); + RankedTensorType data_ty = mlir::dyn_cast(data.getType()); if (!index_ty || !data_ty) continue; int64_t index_rank = index_ty.getRank(); @@ -2429,7 +2430,7 @@ LogicalResult DynamicStitchOp::verify() { expected_shape.append(inferred_item_shape->begin(), inferred_item_shape->end()); - auto out_ty = op.getType().cast(); + auto out_ty = mlir::cast(op.getType()); auto expected_out_ty = tensorflow::GetTypeFromTFTensorShape( expected_shape, out_ty.getElementType()); @@ -2471,25 +2472,25 @@ OpFoldResult EmptyOp::fold(FoldAdaptor adaptor) { Attribute attr = operands.front(); if (!attr) return {}; - auto int_attr = attr.cast(); + auto int_attr = mlir::cast(attr); SmallVector out_shape; for (const auto val : int_attr.getValues()) { out_shape.push_back(val); } - auto type = getResult().getType().cast(); + auto type = mlir::cast(getResult().getType()); auto etype = type.getElementType(); // We can not fold if the result is not static. if (!type.hasStaticShape()) return {}; - if (auto float_type = etype.dyn_cast()) { + if (auto float_type = mlir::dyn_cast(etype)) { auto out_type = tensorflow::GetTypeFromTFTensorShape(out_shape, float_type); return DenseElementsAttr::get(out_type, {APFloat(float_type.getFloatSemantics())}); } - if (auto int_type = etype.dyn_cast()) { + if (auto int_type = mlir::dyn_cast(etype)) { auto out_type = tensorflow::GetTypeFromTFTensorShape(out_shape, etype); APInt val(int_type.getWidth(), 0, int_type.getSignedness()); return DenseElementsAttr::get(out_type, val); @@ -2580,7 +2581,7 @@ EnqueueTPUEmbeddingSparseTensorBatchOp::GetResourceInstanceStr() { //===----------------------------------------------------------------------===// OpFoldResult EnsureShapeOp::fold(FoldAdaptor) { - ShapedType type = getInput().getType().dyn_cast(); + ShapedType type = mlir::dyn_cast(getInput().getType()); if (!type || !type.hasRank()) return {}; // If shape attribute equals input operand's type's shape, fold it to input. std::optional> shape_constraint = getShape(); @@ -2639,15 +2640,15 @@ static LogicalResult flipComatibleShapeError(Ty op, PatternRewriter& rewriter) { // we don't know which one it is. TF shape inference turns unranked outputs // into ranked ones if it can statically evaluate the broadcast, see the shape // function of tf.Equal. - auto ty = op.getType().template dyn_cast(); + auto ty = mlir::dyn_cast(op.getType()); if (!ty) { return rewriter.notifyMatchFailure(op, "requires a ranked output shape"); } // Unless this is a scalar compare, a scalar output indicates that this will // always fail. - auto x_ty = op.getX().getType().template dyn_cast(); - auto y_ty = op.getY().getType().template dyn_cast(); + auto x_ty = mlir::dyn_cast(op.getX().getType()); + auto y_ty = mlir::dyn_cast(op.getY().getType()); if (ty.getRank() == 0 && (!x_ty || x_ty.getRank() != 0 || !y_ty || y_ty.getRank() != 0)) { return rewriter.notifyMatchFailure(op, "output rank must match input rank"); @@ -2675,10 +2676,10 @@ void NotEqualOp::getCanonicalizationPatterns(RewritePatternSet& results, //===----------------------------------------------------------------------===// Type InferExpandDimsOpType(Value input, Value dim) { - Type element_ty = input.getType().cast().getElementType(); + Type element_ty = mlir::cast(input.getType()).getElementType(); auto unranked_ty = UnrankedTensorType::get(element_ty); - auto input_ty = input.getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(input.getType()); if (!input_ty) return unranked_ty; DenseIntElementsAttr dim_attr; @@ -2773,7 +2774,7 @@ LogicalResult FakeQuantWithMinMaxVarsPerChannelOp::verify() { "requires num_bits to be between 2 and 16, inclusive"); } - auto inputs_type = inputs.getType().dyn_cast(); + auto inputs_type = mlir::dyn_cast(inputs.getType()); if (!inputs_type) return success(); int depth = inputs_type.getDimSize(inputs_type.getRank() - 1); if ((min && min.getDimSize(0) != depth) || @@ -2800,7 +2801,7 @@ LogicalResult FillOp::verify() { } static ShapedType InferFillOpType(Value dims, Value value) { - Type etype = value.getType().cast().getElementType(); + Type etype = mlir::cast(value.getType()).getElementType(); DenseIntElementsAttr dims_attr; if (matchPattern(dims, m_Constant(&dims_attr))) { @@ -2813,7 +2814,7 @@ static ShapedType InferFillOpType(Value dims, Value value) { } if (auto shape_op = dims.getDefiningOp()) { - if (auto t = shape_op.getInput().getType().dyn_cast()) { + if (auto t = mlir::dyn_cast(shape_op.getInput().getType())) { return t; } } @@ -2830,20 +2831,20 @@ OpFoldResult FillOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); assert(operands.size() == 2 && "fill op has two operand"); - auto type = getType().cast(); + auto type = mlir::cast(getType()); // DenseElementsAttr that is used in this folder only supports int and float // types. // TODO(hinsu): Handle complex types once there is a attribute kind for // complex. if (!type.getElementType().isIntOrFloat()) return {}; - auto value = operands[1].dyn_cast_or_null(); + auto value = mlir::dyn_cast_or_null(operands[1]); if (!value) return {}; if (type.hasStaticShape()) return DenseElementsAttr::get(type, value.getValues()[0]); - auto dims = operands[0].dyn_cast_or_null(); + auto dims = mlir::dyn_cast_or_null(operands[0]); if (!dims) return {}; llvm::SmallVector shape; @@ -2876,7 +2877,7 @@ StringRef FusedBatchNormGradV3Op::GetOptimalLayout( // For f16 data type on devices with Tensor Cores support NHWC data format // is up to ~2x faster. - auto x_ty = getX().getType().cast(); + auto x_ty = mlir::cast(getX().getType()); const bool is_f16 = x_ty.getElementType().isF16(); if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; @@ -2940,7 +2941,7 @@ static StringRef GetOptimalLayout(const RuntimeDevices& devices, Op* op) { // For f16 data type on devices with Tensor Cores support NHWC data format // is up to ~2x faster. - auto x_ty = op->getX().getType().template cast(); + auto x_ty = mlir::cast(op->getX().getType()); const bool is_f16 = x_ty.getElementType().isF16(); if (is_f16 && CanUseTensorCores(devices)) return "NHWC"; @@ -3045,7 +3046,7 @@ void GeneratorDatasetRegionOp::getSuccessorRegions( LogicalResult GatherV2Op::verify() { GatherV2Op op = *this; int64_t batch_dims = op.getBatchDims(); - if (auto ty = op.getIndices().getType().dyn_cast()) { + if (auto ty = mlir::dyn_cast(op.getIndices().getType())) { int64_t rank = ty.getRank(); if (batch_dims > rank || batch_dims < -rank) return op.emitOpError() @@ -3060,7 +3061,7 @@ LogicalResult GatherV2Op::verify() { DenseIntElementsAttr axis_attr; if (matchPattern(op.getAxis(), m_Constant(&axis_attr))) { int64_t axis = (*axis_attr.begin()).getSExtValue(); - if (auto ty = op.getParams().getType().dyn_cast()) { + if (auto ty = mlir::dyn_cast(op.getParams().getType())) { int64_t rank = ty.getRank(); if (axis >= rank || axis < -rank) return op.emitOpError() << "axis (" << axis << ") must be in range [" @@ -3283,7 +3284,7 @@ void IfRegionOp::getSuccessorRegions( // Verifies that the input is 1D. LogicalResult InvertPermutationOp::verify() { InvertPermutationOp op = *this; - auto x_type = op.getX().getType().cast(); + auto x_type = mlir::cast(op.getX().getType()); if (!x_type.hasRank()) return success(); if (x_type.getShape().size() != 1) return op.emitOpError() << "requires input x to be 1-dimensional"; @@ -3310,10 +3311,12 @@ OpFoldResult LeakyReluOp::fold(FoldAdaptor adaptor) { return FloatAttr::get(arg.getType(), val); }; - if (auto arg = operands[0].dyn_cast_or_null()) { + if (auto arg = mlir::dyn_cast_or_null(operands[0])) { return calculate(arg); - } else if (auto arg = operands[0].dyn_cast_or_null()) { - if (auto elementAttr = arg.getSplatValue().dyn_cast()) + } else if (auto arg = + mlir::dyn_cast_or_null(operands[0])) { + if (auto elementAttr = + mlir::dyn_cast(arg.getSplatValue())) return DenseElementsAttr::get(arg.getType(), calculate(elementAttr)); } return {}; @@ -3378,7 +3381,7 @@ OpFoldResult LogicalAndOp::fold(FoldAdaptor adaptor) { auto result_type = getType(); for (const auto& operand : operands) { - auto splat_attr = operand.dyn_cast_or_null(); + auto splat_attr = mlir::dyn_cast_or_null(operand); if (!splat_attr) continue; if (splat_attr.getType() != result_type) continue; @@ -3540,7 +3543,8 @@ LogicalResult MeanOp::FoldOperandsPermutation(ArrayRef permutation) { dyn_cast_or_null(getReductionIndices().getDefiningOp()); if (!reduction_op) return failure(); - auto reductions_value = reduction_op.getValue().dyn_cast(); + auto reductions_value = + mlir::dyn_cast(reduction_op.getValue()); if (!reductions_value) return failure(); // Prepare new reduction indices according to operand permutation. @@ -3597,8 +3601,8 @@ void HashTableOp::getCanonicalizationPatterns(RewritePatternSet& results, LogicalResult BitcastOp::verify() { BitcastOp op = *this; - auto input_type = op.getInput().getType().cast(); - auto output_type = op.getOutput().getType().cast(); + auto input_type = mlir::cast(op.getInput().getType()); + auto output_type = mlir::cast(op.getOutput().getType()); auto input_element_type = input_type.getElementType(); auto output_element_type = output_type.getElementType(); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.cc index d67c1da227d1c6..b3ce501c1c08d1 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h" +#include "mlir/Support/LLVM.h" // from @llvm-project + namespace mlir { namespace TF { @@ -60,7 +62,7 @@ ArrayAttr ShuffleArrayAttr(ArrayAttr attr, ArrayRef permutation, // Shuffle ranked tensor dimensions according to the permutation. Type ShuffleRankedTensorType(Type type, ArrayRef permutation) { - if (auto ranked_type = type.dyn_cast()) { + if (auto ranked_type = mlir::dyn_cast(type)) { ArrayRef shape = ranked_type.getShape(); assert(permutation.size() == shape.size()); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.cc index 24036d17b588e6..ca8f27a1489c06 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TF { @@ -33,9 +34,9 @@ class IdentityNOp; RankedTensorType GetRankedTensorTypeForOperand(Value operand) { DenseElementsAttr attr; if (matchPattern(operand, m_Constant(&attr))) { - return attr.getType().dyn_cast(); + return mlir::dyn_cast(attr.getType()); } - return operand.getType().dyn_cast(); + return mlir::dyn_cast(operand.getType()); } // Returns the tf.Equal/tf.NotEqual result type given `x` and `y` and inputs. If @@ -53,7 +54,7 @@ Type DeduceEqualCmpOpType(Builder *builder, Location loc, Value x, Value y, } } - auto ranked_type = result_type.dyn_cast(); + auto ranked_type = mlir::dyn_cast(result_type); if (!ranked_type) return UnrankedTensorType::get(builder->getI1Type()); return RankedTensorType::get(ranked_type.getShape(), builder->getI1Type()); @@ -65,7 +66,7 @@ Type InferReductionOpType(Value input, Value reduction_indices, Type element_ty = getElementTypeOrSelf(input_ty); // Output type is unranked if input type is not ranked. - auto ranked_ty = input_ty.dyn_cast(); + auto ranked_ty = mlir::dyn_cast(input_ty); if (!ranked_ty) return UnrankedTensorType::get(element_ty); int64_t rank = ranked_ty.getRank(); @@ -124,7 +125,7 @@ LogicalResult VerifyTypesCompatibility(Operation::operand_type_range types, // the dimension index on the first mismatch and ignore dimension at that // index in following types. for (Type ty : types) { - RankedTensorType ranked_ty = ty.dyn_cast(); + RankedTensorType ranked_ty = mlir::dyn_cast(ty); if (!ranked_ty) continue; int64_t rank = ranked_ty.getRank(); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h index aaf795afd72917..e77ea7d77deef0 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { @@ -36,7 +37,7 @@ RankedTensorType GetRankedTensorTypeForOperand(Value operand); // given `rank`. inline bool IsOfRankedFloatTensorType(RankedTensorType type, int rank) { return type && type.getRank() == rank && - type.getElementType().isa(); + mlir::isa(type.getElementType()); } // Returns true if the given `value` has the specified rank or has unranked diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index e5ecef28a38377..45717471e373a2 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -46,11 +47,11 @@ namespace tf_saved_model { //===----------------------------------------------------------------------===// static bool IsStrArrayAttr(Attribute attr) { - auto array = attr.dyn_cast(); + auto array = mlir::dyn_cast(attr); if (!array) return false; - return llvm::all_of(array, - [](Attribute attr) { return attr.isa(); }); + return llvm::all_of( + array, [](Attribute attr) { return mlir::isa(attr); }); } //===----------------------------------------------------------------------===// @@ -58,10 +59,11 @@ static bool IsStrArrayAttr(Attribute attr) { //===----------------------------------------------------------------------===// LogicalResult VerifyTensorTypesCompatible(Type t1, Type t2) { - if (!t1.isa() || !t2.isa()) { + if (!mlir::isa(t1) || !mlir::isa(t2)) { return failure(); } - return verifyCompatibleShape(t1.cast(), t2.cast()); + return verifyCompatibleShape(mlir::cast(t1), + mlir::cast(t2)); } LogicalResult GlobalTensorOp::verify() { @@ -75,7 +77,7 @@ LogicalResult GlobalTensorOp::verify() { } } if (!global_tensor.getIsMutable()) { - if (!global_tensor.getType().cast().hasStaticShape()) { + if (!mlir::cast(global_tensor.getType()).hasStaticShape()) { return global_tensor.emitError() << "'type' attribute for immutable 'tf_saved_model.global_tensor' " "should have a static shape"; @@ -91,7 +93,7 @@ LogicalResult SessionInitializerOp::verify() { for (auto sym_ref : session_initializer.getInitializers()) { auto init_func_op = symbol_table.lookup( - sym_ref.cast().getValue()); + mlir::cast(sym_ref).getValue()); if (!init_func_op) return session_initializer.emitOpError() @@ -143,16 +145,16 @@ TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context) } static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) { - auto attr = named_attr.getValue().dyn_cast(); + auto attr = mlir::dyn_cast(named_attr.getValue()); if (!attr) { return op->emitError() << "'" << kTfSavedModelIndexPathAttr << "' attribute should be an ArrayAttr"; } for (auto element : attr) { - if (element.isa()) { + if (mlir::isa(element)) { continue; } - if (auto integer = element.dyn_cast()) { + if (auto integer = mlir::dyn_cast(element)) { if (integer.getValue().getBitWidth() == 64) { continue; } @@ -165,7 +167,7 @@ static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) { Type GetBoundInputArgTypeFor(mlir::Operation *op) { if (auto global_tensor = llvm::dyn_cast(op)) { - auto type = global_tensor.getType().cast(); + auto type = mlir::cast(global_tensor.getType()); return RankedTensorType::get( {}, TF::ResourceType::get({type}, type.getContext())); } @@ -196,12 +198,12 @@ LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute( Operation *op, unsigned region_index, unsigned arg_index, NamedAttribute named_attr) { if (named_attr.getName() == "tf_saved_model.bound_input") { - if (!named_attr.getValue().isa()) { + if (!mlir::isa(named_attr.getValue())) { return op->emitError() << "'tf_saved_model.bound_input' attribute should " "be a FlatSymbolRefAttr"; } auto symbol_name = - named_attr.getValue().cast().getValue(); + mlir::cast(named_attr.getValue()).getValue(); auto module = op->getParentOfType(); mlir::Operation *symbol_op = module.lookupSymbol(symbol_name); if (!symbol_op) { @@ -292,8 +294,8 @@ static LogicalResult VerifySavedModelModule( &op, {exported_names_ident, attr}))) { return failure(); } - for (auto str : attr.cast()) { - auto exported_name = str.cast().getValue(); + for (auto str : mlir::cast(attr)) { + auto exported_name = mlir::cast(str).getValue(); auto p = exported_name_to_op.insert({exported_name, &op}); if (!p.second) { return op.emitError() @@ -341,7 +343,8 @@ static LogicalResult VerifySavedModelModule( auto init_syms = (*session_initializers.begin()).getInitializers(); return std::any_of( init_syms.begin(), init_syms.end(), [&](Attribute sym_ref) { - return sym_ref.cast().getValue() == func.getName(); + return mlir::cast(sym_ref).getValue() == + func.getName(); }); }; @@ -439,7 +442,7 @@ LogicalResult VerifyInitializerTypeAttr(Operation *op, // Validate the attribute value. auto initializer_type_attr_value = - named_attr.getValue().dyn_cast_or_null(); + mlir::dyn_cast_or_null(named_attr.getValue()); if (!initializer_type_attr_value) { return op->emitError() << "Attribute tf_saved_model.initializer_type " << "should be a StringAttr."; @@ -504,7 +507,7 @@ SmallVector GetExportedNames(Operation *op) { op->getAttrOfType(kTfSavedModelExportedNamesAttr); if (exported_names) { for (auto name : exported_names) { - ret.push_back(name.cast().getValue()); + ret.push_back(mlir::cast(name).getValue()); } } return ret; @@ -547,7 +550,7 @@ class OptimizeSessionInitializerPattern SmallVector to_keep; for (auto sym_ref : op.getInitializers()) { auto init_func_op = symbol_table.lookup( - sym_ref.cast().getValue()); + mlir::cast(sym_ref).getValue()); // The init function can only be referenced from the SessionInitializerOp. // And there is at most one SessionInitializerOp in the module. So if both @@ -590,7 +593,7 @@ SmallVector GetSessionInitializerExportedName(ModuleOp op) { SmallVector results; for (auto sym_ref : session_initializer_op.getInitializers()) { auto init_func_op = symbol_table.lookup( - sym_ref.cast().getValue()); + mlir::cast(sym_ref).getValue()); auto exported_names = GetExportedNames(init_func_op); assert(exported_names.size() == 1); results.push_back(exported_names[0]); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h index 62f6192c1f84f0..c6abd7689beddc 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -38,7 +39,7 @@ namespace TF { static inline LogicalResult VerifyRefTypeMatch(mlir::Type type, mlir::Type maybe_ref_type) { if (auto ref_type = - maybe_ref_type.dyn_cast()) + mlir::dyn_cast(maybe_ref_type)) return success(ref_type.RemoveRef().getTypeID() == type.getTypeID()); return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index 00fd3934676c86..825430513060fc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -147,6 +147,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/core:framework", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -298,6 +299,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], ) @@ -1034,6 +1036,7 @@ cc_library( "//tensorflow/core:framework", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -1045,6 +1048,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_xla//xla:shape_util", "@local_xla//xla/mlir_hlo", "@local_xla//xla/stream_executor/tpu:c_api_conversions", @@ -1067,6 +1071,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:path", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc index 996686eb525d03..52765fb5657eba 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -70,14 +71,14 @@ void AnnotateParameterReplicationPass::runOnOperation() { if (mirrored_variable_indices_attr) { for (const auto& mirrored_index : mirrored_variable_indices_attr) { mirrored_replicate_args.insert( - mirrored_index.cast().getInt()); + mlir::cast(mirrored_index).getInt()); } } auto func = llvm::cast(m.lookupSymbol(cluster_func.getFunc())); for (auto entry : llvm::enumerate(cluster_func.getOperands())) { auto operand = SkipIdentityAndReadVariable(entry.value()); - auto block_arg = operand.dyn_cast(); + auto block_arg = mlir::dyn_cast(operand); if (block_arg && block_arg.getOwner() == &replicate.GetBody()) { // Only mirrored args of ReplicateOp can be annotated. if (mirrored_replicate_args.count(block_arg.getArgNumber()) == 0) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc index 14ab17be0fdee5..c6e21cb1e03054 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc @@ -53,8 +53,8 @@ class ConvertTFBatchMatMulToEinsumOp Value input_rhs = op.getY(); // LHS and RHS must be a ranked tensor type - auto lhs_type = input_lhs.getType().dyn_cast(); - auto rhs_type = input_rhs.getType().dyn_cast(); + auto lhs_type = mlir::dyn_cast(input_lhs.getType()); + auto rhs_type = mlir::dyn_cast(input_rhs.getType()); if (!lhs_type || !rhs_type) return failure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc index 4b409ffe1f614f..3ce5fb5bcb8379 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #define DEBUG_TYPE "cluster-ops-by-policy" @@ -44,7 +45,7 @@ ValueConstraint Merge(ValueConstraint a, ValueConstraint b) { LogicalResult IsStaticallyResolved(Value value, ValueConstraint constraint) { // Resolve constraints inferred from the tensor type. - if (auto tensor = value.getType().dyn_cast()) { + if (auto tensor = mlir::dyn_cast(value.getType())) { if (constraint == ValueConstraint::kRank && tensor.hasRank()) return success(); if (constraint == ValueConstraint::kShape && tensor.hasStaticShape()) @@ -710,7 +711,7 @@ void EmitValueConstraintsRemarks(const ValuesConstraintSet &constraints) { void EmitInputsConstraintsRemarks(func::FuncOp func, const ValuesConstraintSet &constraints) { constraints.Walk([&](Value value, ValueConstraint constraint) { - if (auto arg = value.dyn_cast()) + if (auto arg = mlir::dyn_cast(value)) if (arg.getOwner() == &func.getBody().front()) func.emitRemark(llvm::formatv("input #{0} constrained to: {1}", arg.getArgNumber(), constraint)); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc index 5d9f5f9718446f..3d3e1305993a30 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc @@ -33,6 +33,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/util/device_name_utils.h" @@ -124,7 +125,7 @@ std::optional> GetFunctionMetadatas( // If the value is defined as an argument of the func_op, adds it to // the argument list of the function that uses this op. - if (BlockArgument block_arg = value.dyn_cast()) { + if (BlockArgument block_arg = mlir::dyn_cast(value)) { if (StringAttr attr = func_op.getArgAttrOfType( block_arg.getArgNumber(), kTFDeviceAttr)) { value_device = attr.getValue().str(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc index 59faa220521f0b..5a83e75e9eedf4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc @@ -62,7 +62,7 @@ Value GetR1Const(ArrayRef r1, OpBuilder builder, Location loc, Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder, Location loc) { - auto buffer_type = buffer.getType().cast(); + auto buffer_type = mlir::cast(buffer.getType()); if (buffer_type.getShape().size() == 1) return index; // Create a concat of index and trailing zeros. llvm::SmallVector zeros(buffer_type.getShape().size() - 1, 0); @@ -77,7 +77,7 @@ Value GetIndicesForElement(Value index, Value buffer, OpBuilder builder, Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc, bool keep_slice_shape) { - auto buffer_type = buffer.getType().cast(); + auto buffer_type = mlir::cast(buffer.getType()); // Create a slice then reshape to remove the leading trivial dimension of // size 1. llvm::SmallVector slice_size = @@ -102,7 +102,7 @@ Value GetElement(Value index, Value buffer, OpBuilder builder, Location loc, Value SetElement(Value index, Value buffer, Value element, OpBuilder builder, Location loc) { - auto buffer_type = buffer.getType().cast(); + auto buffer_type = mlir::cast(buffer.getType()); // Reshape the element to add a leading dimension of size 1 if th element does // not have that dimension, then perform a dynamic update slice. auto slice_shape = llvm::to_vector<8>(buffer_type.getShape()); @@ -208,7 +208,7 @@ std::optional GetElementTypeFromAccess( if (type_from_alias.has_value()) return type_from_alias; } else if (auto type = infer_from_op(use.getOwner())) { if (!type) continue; - auto elem_type = type->dyn_cast(); + auto elem_type = mlir::dyn_cast(*type); if (elem_type && elem_type.hasStaticShape()) return elem_type; } } @@ -220,8 +220,8 @@ Value ReadLocalVariable(Value local_var, OpBuilder builder, Location loc) { return builder .create( loc, - ArrayRef{getElementTypeOrSelf(local_var.getType()) - .cast() + ArrayRef{mlir::cast( + getElementTypeOrSelf(local_var.getType())) .getSubtypes()[0]}, ArrayRef{local_var}) .getValue(); @@ -246,7 +246,7 @@ Value AccumulateBuffers(Value a, Value b, OpBuilder builder, Location loc) { namespace { int64_t GetFirstIfIndicesAreContiguous(Value indices) { - auto type = indices.getType().dyn_cast(); + auto type = mlir::dyn_cast(indices.getType()); if (!type) return -1; auto indices_op = indices.getDefiningOp(); if (!indices_op) return -1; @@ -270,9 +270,10 @@ int64_t GetFirstIfIndicesAreContiguous(Value indices) { Value GatherElements(Value indices, Value buffer, OpBuilder builder, Location loc) { - auto buffer_type = buffer.getType().cast(); + auto buffer_type = mlir::cast(buffer.getType()); auto result_shape = llvm::to_vector<8>(buffer_type.getShape()); - result_shape[0] = indices.getType().cast().getDimSize(0); + result_shape[0] = + mlir::cast(indices.getType()).getDimSize(0); int64_t maybe_contiguous_start = GetFirstIfIndicesAreContiguous(indices); if (maybe_contiguous_start >= 0) { llvm::SmallVector slice_starts(result_shape.size(), 0); @@ -293,8 +294,8 @@ Value GatherElements(Value indices, Value buffer, OpBuilder builder, Value ScatterAccumulateElements(Value indices, Value updates, Value buffer, OpBuilder builder, Location loc) { - auto buffer_type = buffer.getType().cast(); - auto updates_type = updates.getType().cast(); + auto buffer_type = mlir::cast(buffer.getType()); + auto updates_type = mlir::cast(updates.getType()); int64_t maybe_contiguous_start = GetFirstIfIndicesAreContiguous(indices); if (maybe_contiguous_start == 0 && buffer_type == updates_type) { return AccumulateBuffers(buffer, updates, builder, loc); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 880dfa837e881c..ca5eb4bc737b99 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -47,7 +48,7 @@ static bool IsFoldedByDefaultPolicy(Operation* inst) { auto get_size = [&](TypeRange types) { int64_t size = 0; for (auto t : types) { - auto tensor_type = t.cast(); + auto tensor_type = mlir::cast(t); // Ignore types with undefined bit widths. if (!tensor_type.getElementType().isIntOrFloat()) continue; if (!tensor_type.hasStaticShape()) { @@ -93,7 +94,7 @@ LogicalResult ConstantFoldFallbackHook( // propagation. bool has_empty_numerical_results = llvm::all_of(inst->getResultTypes(), [](Type ty) { - ShapedType shaped_ty = ty.cast(); + ShapedType shaped_ty = mlir::cast(ty); Type element_ty = shaped_ty.getElementType(); return shaped_ty.hasStaticShape() && shaped_ty.getNumElements() == 0 && element_ty.isIntOrFloat(); @@ -103,7 +104,7 @@ LogicalResult ConstantFoldFallbackHook( // addressed. inst->isRegistered()) { for (Type ty : inst->getResultTypes()) { - auto shaped_ty = ty.cast(); + auto shaped_ty = mlir::cast(ty); results.push_back( DenseElementsAttr::get(shaped_ty, llvm::ArrayRef())); } @@ -112,14 +113,14 @@ LogicalResult ConstantFoldFallbackHook( // Returns directly if any of the operands is not an elements attributes. if (std::any_of(operands.begin(), operands.end(), [](Attribute attr) { - return !attr || !attr.isa(); + return !attr || !mlir::isa(attr); })) return failure(); SmallVector inputs; inputs.reserve(operands.size()); for (auto input : operands) { - inputs.push_back(input.cast()); + inputs.push_back(mlir::cast(input)); } SmallVector constants; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc index a41dbdefa6f520..84c96590910243 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" @@ -65,8 +66,8 @@ bool CanBeFolded(Operation* inst) { // This creates opaque variant constants which lose information and would // require "raising" later. for (const Type type : inst->getResultTypes()) { - if (const TensorType tensor_type = type.dyn_cast()) { - if (tensor_type.getElementType().isa()) { + if (const TensorType tensor_type = mlir::dyn_cast(type)) { + if (mlir::isa(tensor_type.getElementType())) { return false; } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc index 4de43317677f63..6262cad26ca6e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc @@ -114,7 +114,7 @@ SmallVector GetWhileCallers(func::FuncOp func, } bool IsResourceType(Type type) { - return getElementTypeOrSelf(type).isa(); + return mlir::isa(getElementTypeOrSelf(type)); } bool OnlyOperatesOnCompositeDevices( @@ -124,11 +124,11 @@ bool OnlyOperatesOnCompositeDevices( auto& alias_analysis = side_effect_analysis.GetAliasAnalysis(); llvm::SmallSet read_array; for (const Attribute& attr : op.getDeviceVarReadsIndices()) { - read_array.insert(attr.cast().getInt()); + read_array.insert(mlir::cast(attr).getInt()); } llvm::SmallSet update_array; for (const Attribute& attr : op.getDeviceVarUpdatesIndices()) { - update_array.insert(attr.cast().getInt()); + update_array.insert(mlir::cast(attr).getInt()); } for (auto& arg : op->getOpOperands()) { @@ -270,7 +270,7 @@ void CollectChainResources( // // Checks if the value `control` is a NoOp control barrier. bool IsNoOpControlBarrier(Value control) { - if (!control.getType().isa()) return false; + if (!mlir::isa(control.getType())) return false; auto control_island = dyn_cast_or_null(control.getDefiningOp()); if (!control_island) return false; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc index 8e89f3988dd8d4..4af1246d5a72b6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc @@ -38,6 +38,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" @@ -76,7 +77,7 @@ AnonymousIteratorV3Op CreateIterator(OpBuilder builder, llvm::SmallVector type_attrs; for (Type type : dataset_types) { shape_attrs.push_back( - TF::ShapeAttr::get(builder.getContext(), type.cast())); + TF::ShapeAttr::get(builder.getContext(), mlir::cast(type))); type_attrs.push_back(TypeAttr::get(getElementTypeOrSelf(type))); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index 51438ac4901b9d..4cdc90376c2317 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -90,7 +90,7 @@ TF::SumOp createSumOp(Value value, Location loc, PatternRewriter* rewriter) { Value redux_op = createI32ConstantOp(redux_axes, loc, rewriter); - auto value_type = value.getType().cast(); + auto value_type = mlir::cast(value.getType()); auto shape = value_type.getShape(); llvm::SmallVector sum_shape; for (int i = 0; i < shape.size(); ++i) { @@ -108,7 +108,7 @@ TF::TransposeOp createTransposeOp(Value value, Location loc, llvm::ArrayRef permutation, PatternRewriter* rewriter) { auto perm_op = createI32ConstantOp(permutation, loc, rewriter); - auto value_type = value.getType().cast(); + auto value_type = mlir::cast(value.getType()); auto shape = value_type.getShape(); SmallVector transposed_shape(shape.begin(), shape.end()); for (int i = 0, end = shape.size(); i < end; ++i) { @@ -529,7 +529,7 @@ LogicalResult rewriteToReduceSumAndTranspose(TF::EinsumOp op, bool needs_transpose = false; for (int64_t i = 0; i < dnums.lhs_out.size(); ++i) { if (std::get<0>(dnums.lhs_out[i]) > - lhs.getType().cast().getRank() - 1) { + mlir::cast(lhs.getType()).getRank() - 1) { continue; } @@ -637,8 +637,8 @@ LogicalResult reshapeForBatchMatmul(const Location& loc, Value* rhs, SmallVectorImpl* out_shape, PatternRewriter* rewriter) { - RankedTensorType lhs_type = lhs->getType().cast(); - RankedTensorType rhs_type = rhs->getType().cast(); + RankedTensorType lhs_type = mlir::cast(lhs->getType()); + RankedTensorType rhs_type = mlir::cast(rhs->getType()); int32_t num_lhs_reshape_segids = 0; int32_t num_rhs_reshape_segids = 0; @@ -776,7 +776,7 @@ LogicalResult rewriteToBatchMatmul(TF::EinsumOp op, EinsumDimensionNumbers original_dnums = dnums; RankedTensorType original_type = - op.getResult().getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op.getResult().getType()); if (!original_type) return failure(); std::vector out_transpose; @@ -822,7 +822,7 @@ LogicalResult matchAndRewriteUnaryEinsumOp(TF::EinsumOp op, op, "Function only supports unary einsum op"); } RankedTensorType lhs = - op.getOperand(0).getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op.getOperand(0).getType()); if (!lhs) { return failure(); } @@ -862,9 +862,9 @@ LogicalResult ConvertTFEinsumOp::matchAndRewrite( } RankedTensorType lhs = - op.getOperand(0).getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op.getOperand(0).getType()); RankedTensorType rhs = - op.getOperand(1).getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op.getOperand(1).getType()); if (!lhs || !rhs) { return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc index 51afea6d84671e..e1611432f36e8c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/core/platform/logging.h" @@ -443,7 +444,7 @@ void InsertDummyIslandForFetch(FetchOp fetch) { control_fetches.reserve(data_fetches.capacity()); for (auto value : fetch.getFetches()) { - if (value.getType().isa()) { + if (mlir::isa(value.getType())) { control_fetches.push_back(value); } else { data_fetches.push_back(value); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc index 9567278d98dc9c..2f456248c381af 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc @@ -624,7 +624,7 @@ void TpuV1BridgeExecutorIslandCoarsening::runOnOperation() { assert(!funcs_for_cluster->second.empty()); if (funcs_for_cluster->second.size() == 1) return false; for (NamedAttribute attr : op->getAttrs()) { - auto symbol_ref = attr.getValue().dyn_cast(); + auto symbol_ref = mlir::dyn_cast(attr.getValue()); if (!symbol_ref) continue; func::FuncOp callee = symbol_table.lookup(symbol_ref.getValue()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc index 0106d149d3d343..19603170b89e20 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc @@ -178,13 +178,14 @@ void TPUBridgeExecutorIslandOutlining::runOnOperation() { for (func::FuncOp func : outlined_module.getOps()) { func.walk([&](Operation *op) { for (NamedAttribute attr : op->getAttrs()) { - if (auto symbol_ref = attr.getValue().dyn_cast()) { + if (auto symbol_ref = + mlir::dyn_cast(attr.getValue())) { MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table); continue; } - if (auto array_attr = attr.getValue().dyn_cast()) { + if (auto array_attr = mlir::dyn_cast(attr.getValue())) { for (const Attribute &attribute : array_attr) { - auto symbol_ref = attribute.dyn_cast(); + auto symbol_ref = mlir::dyn_cast(attribute); if (!symbol_ref) continue; MoveFuncOp(symbol_ref, symbol_table, outlined_symbol_table); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/extract_tpu_copy_with_dynamic_shape_op.cc b/tensorflow/compiler/mlir/tensorflow/transforms/extract_tpu_copy_with_dynamic_shape_op.cc index dfd20d8dd0e07a..18480fbd772fa9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/extract_tpu_copy_with_dynamic_shape_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/extract_tpu_copy_with_dynamic_shape_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -47,7 +48,7 @@ class ExtractTPUCopyWithDynamicShapeOpPass // Finds op that created a given value. If the value is a BlockArgument, this // returns the owner of the Block. Operation* GetOpOfValue(Value value) { - if (auto block_arg = value.dyn_cast()) + if (auto block_arg = mlir::dyn_cast(value)) return block_arg.getOwner()->getParentOp(); return value.getDefiningOp(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc index d755696c74607b..6547b6f168c3bf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -144,7 +145,7 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the result shape is fully defined. auto result_type = - op->getResultTypes().front().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op->getResultTypes().front()); if (!result_type || !result_type.hasStaticShape()) return failure(); bool changed = false; @@ -155,15 +156,13 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( if (!broadcast) continue; // Check that the operand of the broadcast has fully defined shape. - auto broadcast_arg_type = - broadcast.getInput().getType().dyn_cast_or_null(); + auto broadcast_arg_type = mlir::dyn_cast_or_null( + broadcast.getInput().getType()); if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue; // Check that the other argument has fully defined shape. - auto argument_type = op->getOpOperand(1 - i) - .get() - .getType() - .dyn_cast_or_null(); + auto argument_type = mlir::dyn_cast_or_null( + op->getOpOperand(1 - i).get().getType()); if (!argument_type || !argument_type.hasStaticShape()) continue; // Get the unbroadcasted shapes in the operand order. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc index 6a1a4852e68a1c..9f9da90bf76594 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc @@ -86,7 +86,7 @@ void FreezeGlobalTensorsPass::runOnOperation() { DenseMap freezeable; for (auto func : module.getOps()) { for (BlockArgument val : func.getArguments()) { - if (!getElementTypeOrSelf(val.getType()).isa()) + if (!mlir::isa(getElementTypeOrSelf(val.getType()))) continue; // Check that there is only a single global tensor associated with arg. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc index 0cff8946687dcb..11be79869f4fd2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -101,7 +102,7 @@ YieldOp CreateCall(Operation* op, func::FuncOp func, Region& caller_region, // Converts the condition for an IfOp/WhileOp to a boolean value. Value ConvertConditionToBoolean(Operation* op, Value cond) { - if (auto ranked_type = cond.getType().dyn_cast()) + if (auto ranked_type = mlir::dyn_cast(cond.getType())) if (ranked_type.getRank() == 0 && ranked_type.getElementType().isSignlessInteger(1)) return cond; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc index 1c0a125598cdbe..4eb791a909022d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc @@ -332,9 +332,9 @@ class FuseMatMulBiasAdd } // FusedMatMul kernel does not support grad_a/grad_b attrs if ((matmul->hasAttr("grad_a") && - matmul->getAttr("grad_a").cast().getValue()) || + mlir::cast(matmul->getAttr("grad_a")).getValue()) || (matmul->hasAttr("grad_b") && - matmul->getAttr("grad_b").cast().getValue())) { + mlir::cast(matmul->getAttr("grad_b")).getValue())) { (void)rewriter.notifyMatchFailure(matmul, [&](Diagnostic &diag) { diag << "FusedMatMul kernel does not support grad_a/grad_b attrs"; }); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc index 78fb6aad3abdde..91f14794494de7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -135,7 +136,7 @@ void HoistLoopInvariantPass::runOnOperation() { // Skip the pass if the function inputs contain any resource. for (const auto &type : func.getArgumentTypes()) { - if (getElementTypeOrSelf(type).isa()) return; + if (mlir::isa(getElementTypeOrSelf(type))) return; } llvm::DenseSet read_only_vars = GetReadOnlyVariables(func); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD index f8e75d9032f3e5..3d046b4c41c51f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD @@ -252,6 +252,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", ], diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.cc index 767d5cf7f0cf8c..a21c78a9e3ca82 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_metadata_utils.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" @@ -93,7 +94,7 @@ LogicalResult SetMetadataProtoStepMarkerLocation( // Parses a xla::OpSharding from a string attribute. LogicalResult SetOpSharding(Operation* op, Attribute attr, llvm::StringRef name, int index, xla::OpSharding* sharding_ptr) { - auto sharding_attr = attr.dyn_cast(); + auto sharding_attr = mlir::dyn_cast(attr); if (!sharding_attr) return op->emitOpError( llvm::formatv(kBadStringArrayElementMsg, name, index)); @@ -130,7 +131,7 @@ LogicalResult SetMetadataProtoArgs( llvm::SmallSet dynamic_arg_idx_set; if (dynamic_arg_idx) { for (auto idx : dynamic_arg_idx.getValue()) { - dynamic_arg_idx_set.insert(idx.dyn_cast().getInt()); + dynamic_arg_idx_set.insert(mlir::dyn_cast(idx).getInt()); } } @@ -155,7 +156,8 @@ LogicalResult SetMetadataProtoArgs( // Populate argument shapes. *arg->mutable_shape() = tensorflow::TensorShapeProto(); - if (auto ranked_tensor_type = operand_type.dyn_cast()) { + if (auto ranked_tensor_type = + mlir::dyn_cast(operand_type)) { tensorflow::TensorShapeProto shape_proto; ConvertToTensorShapeProto(ranked_tensor_type.getShape(), &shape_proto); *arg->mutable_shape() = std::move(shape_proto); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_rewrite_pass.cc index ed1e0549dfb769..72521680142248 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_rewrite_pass.cc @@ -43,6 +43,7 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -139,7 +140,7 @@ LogicalResult EncapsulateFuncAndSerialize(const std::string& module_name, assert(uses && "expected to be able to collect symbol uses"); for (SymbolTable::SymbolUse use : *uses) { func::FuncOp referenced_func = entry_module_table.lookup( - use.getSymbolRef().cast().getValue()); + mlir::cast(use.getSymbolRef()).getValue()); // Skip Symbols that do not map to a function. if (!referenced_func) continue; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_variable_runtime_reformatting.cc index 271e525ef8c7ae..4e87c10b1b7ac6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_variable_runtime_reformatting.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -73,7 +74,7 @@ struct TPUVariableRuntimeReformattingPass // provided, it will be used to store the identity nodes skipped. Value SkipIdentity(Value v, bool allow_other_use, llvm::SmallPtrSet* skipped = nullptr) { - while (auto result = v.dyn_cast()) { + while (auto result = mlir::dyn_cast(v)) { if (!(allow_other_use || v.hasOneUse())) break; auto op = result.getDefiningOp(); if (!llvm::isa(op)) { @@ -108,10 +109,10 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( for (auto index_and_arg : llvm::enumerate(execute.getArgs())) { auto arg = SkipIdentity(index_and_arg.value(), /*allow_other_use=*/false); if (!arg.hasOneUse() || - !getElementTypeOrSelf(arg.getType()).isa()) { + !mlir::isa(getElementTypeOrSelf(arg.getType()))) { continue; } - auto block_arg = arg.dyn_cast(); + auto block_arg = mlir::dyn_cast(arg); if (!block_arg || block_arg.getOwner() != &replicate.GetBody()) continue; assert(replicate_arg_to_execute_arg.count(block_arg.getArgNumber()) == 0 && "Found duplicate use of a resource in the execute op."); @@ -131,13 +132,13 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( // variables (arguments of `replicate`), and must be pass-throughs from while // operands. for (const auto& mirrored_index : mirrored_variable_indices_attr) { - int64_t replicate_arg = mirrored_index.cast().getInt(); + int64_t replicate_arg = mlir::cast(mirrored_index).getInt(); // Check if the mirrored variable is an input to `execute`. auto it = replicate_arg_to_execute_arg.find(replicate_arg); if (it == replicate_arg_to_execute_arg.end()) continue; // Get the data type of the resource. - auto subtypes = getElementTypeOrSelf(execute.getOperand(it->second)) - .cast() + auto subtypes = mlir::cast( + getElementTypeOrSelf(execute.getOperand(it->second))) .getSubtypes(); if (subtypes.size() != 1) continue; auto data_type = getElementTypeOrSelf(subtypes[0]); @@ -198,14 +199,14 @@ AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping( llvm::sort(mapping, llvm::less_first()); // Populate the `retval_index_for_sharding` field of the argument metadate. for (auto entry : llvm::enumerate(execute.getDeviceVarReadsIndices())) { - int64_t arg_index = entry.value().cast().getInt(); + int64_t arg_index = mlir::cast(entry.value()).getInt(); auto arg_metadata = metadata.mutable_args(arg_index); if (arg_metadata->enable_xla_sharding() == ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED) { - int64_t ret_index = execute.getDeviceVarUpdatesIndices() - .getValue()[entry.index()] - .cast() - .getInt(); + int64_t ret_index = + mlir::cast( + execute.getDeviceVarUpdatesIndices().getValue()[entry.index()]) + .getInt(); arg_metadata->set_retval_index_for_sharding(ret_index); } } @@ -379,12 +380,13 @@ bool HandleReplicateOp(TF::WhileRegionOp while_op, for (auto it : device_map) { auto device_alias = it.getName().strref(); - auto device_list = it.getValue().cast(); + auto device_list = mlir::cast(it.getValue()); llvm::SmallVector device_list_for_alias; device_list_for_alias.reserve(device_list.size()); for (auto device : device_list) - device_list_for_alias.emplace_back(device.cast().getValue()); + device_list_for_alias.emplace_back( + mlir::cast(device).getValue()); devices.insert({device_alias, device_list_for_alias}); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc b/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc index 3b974c395706aa..6e7fe42ef4dfab 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/initialize_variables_in_session_init.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" @@ -99,8 +100,7 @@ func::FuncOp GetOrCreateSessionInitFunc(ModuleOp module) { // tf_saved_model.initializer_type attribute was introduced. SymbolTable symbol_table(module); return symbol_table.lookup( - session_init_op.getInitializers()[0] - .cast() + mlir::cast(session_init_op.getInitializers()[0]) .getValue()); } else { return CreateSessionInitFunc(module); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc index aa1efc6837eee6..015499c6996f38 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -66,7 +67,7 @@ LogicalResult AssignDevicesInRegion(const Dialect* tf_dialect, return WalkResult::advance(); } - if (auto device_str_attr = device_attr.dyn_cast()) { + if (auto device_str_attr = mlir::dyn_cast(device_attr)) { if (device_str_attr.getValue().empty()) { op->setAttr(kDeviceAttr, launch.getDeviceAttr()); return WalkResult::advance(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc index 0fad3c019ea432..e8c1d1997e195e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/layout_optimization.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h" @@ -49,7 +50,7 @@ TransposeOp ReuseExistingTranspose(const OpOperand* operand, auto tranpose_op = *it; for (auto tranpose_operand : tranpose_op.getOperands()) { auto ranked_tranpose_type = - tranpose_operand.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(tranpose_operand.getType()); if (!ranked_tranpose_type) continue; if (ranked_tranpose_type.getRank() == permutation.size() && operand->get().getType() == @@ -201,7 +202,7 @@ void MoveTransposeBefore(Operation* op, SmallVector* work_list) { if (!perm) return; // With the same permutation indices. - auto dense_elem_attr = perm.getValue().dyn_cast(); + auto dense_elem_attr = mlir::dyn_cast(perm.getValue()); if (!dense_elem_attr) return; if (!permutation_op) permutation_op = perm; @@ -217,7 +218,7 @@ void MoveTransposeBefore(Operation* op, SmallVector* work_list) { // Nothing to do here. if (!permutation_op || transpose_ops.empty()) return; SmallVector permutation; - auto perm_attr = permutation_op.getValue().cast(); + auto perm_attr = mlir::cast(permutation_op.getValue()); for (const auto& value : perm_attr.getValues()) permutation.push_back(value.getSExtValue()); @@ -227,10 +228,11 @@ void MoveTransposeBefore(Operation* op, SmallVector* work_list) { if (op->hasTrait()) { auto transpose_op = *transpose_ops.begin(); auto result_type = - transpose_op.getResult().getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(transpose_op.getResult().getType()); auto is_valid_move = llvm::all_of(op->getOperands(), [result_type](Value operand) -> bool { - auto operand_type = operand.getType().dyn_cast_or_null(); + auto operand_type = + mlir::dyn_cast_or_null(operand.getType()); return result_type && operand_type && result_type.hasRank() && operand_type.hasRank() && result_type.getRank() == operand_type.getRank(); @@ -343,7 +345,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, if (!perm) return; // With the same permutation indices. - auto dense_elem_attr = perm.getValue().dyn_cast(); + auto dense_elem_attr = mlir::dyn_cast(perm.getValue()); if (!dense_elem_attr) return; if (!permutation_op) permutation_op = perm; @@ -365,7 +367,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, SmallVector permutation; - auto attr = permutation_op.getValue().cast(); + auto attr = mlir::cast(permutation_op.getValue()); for (const auto& value : attr.getValues()) permutation.push_back(value.getSExtValue()); @@ -373,7 +375,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, if (fold_operands && fold_transpose_in_ops) { SmallVector permutation; - auto attr = permutation_op.getValue().cast(); + auto attr = mlir::cast(permutation_op.getValue()); for (const auto& value : attr.getValues()) permutation.push_back(value.getSExtValue()); @@ -408,7 +410,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, // update the result type in `FoldOperandsPermutation`. if (layout_agnostic) result.setType(ReversePermuteShapedType( - result.getType().cast(), permutation)); + mlir::cast(result.getType()), permutation)); // Try to push transpose further down. for (Operation* user : result.getUsers()) { @@ -422,7 +424,7 @@ void MoveTransposeAfter(Operation* op, SmallVector* work_list, transpose.getOperation()->moveBefore(op->getNextNode()); transpose.setOperand(0, result); transpose.setOperand(1, permutation_op); - transpose.getResult().setType(original_type[idx].cast()); + transpose.getResult().setType(mlir::cast(original_type[idx])); } else { transpose = builder.create(loc, result, permutation_op); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc index bc0534fdb0bb84..c4ea84d8b0948c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc @@ -189,10 +189,10 @@ LogicalResult LiftVariables(ModuleOp module, Session* session) { func, arg_number, symbol_table); if (!global_tensor) continue; - auto arg_type = arg.getType().cast(); + auto arg_type = mlir::cast(arg.getType()); assert(arg_type.getRank() == 0); llvm::ArrayRef underlying_type = - arg_type.getElementType().cast().getSubtypes(); + mlir::cast(arg_type.getElementType()).getSubtypes(); // If the arg type already matches the global_tensor type, we don't need // to do anything. @@ -206,7 +206,7 @@ LogicalResult LiftVariables(ModuleOp module, Session* session) { auto new_arg_type = mlir::RankedTensorType::get( /*shape=*/{}, mlir::TF::ResourceType::get( - /*subtypes=*/{global_tensor.getType().cast()}, + /*subtypes=*/{mlir::cast(global_tensor.getType())}, module.getContext())); arg.setType(new_arg_type); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc index cdb256ab25f177..8d58b8177b33c8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc @@ -47,7 +47,7 @@ static LogicalResult traceUpwardsToArgument(Value v, llvm::DenseSet seen, } seen.insert(v); - if (auto blockArg = v.dyn_cast()) { + if (auto blockArg = mlir::dyn_cast(v)) { Operation *op = blockArg.getOwner()->getParentOp(); // If we're in the first block, then the argument to that block is the diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index 1c9b1e03a663c6..da565f00b45b99 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeRange.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" @@ -79,7 +80,7 @@ static DenseElementsAttr GetF32Scalar(OpBuilder *builder, float value) { // Preconditions: The given value must have a ShapedType. static Value CreateTFCastOpF32(OpBuilder *builder, Location loc, Value x, BoolAttr truncate) { - auto x_type = x.getType().dyn_cast_or_null(); + auto x_type = mlir::dyn_cast_or_null(x.getType()); if (!x_type) llvm_unreachable("unsupported type"); Type type = x_type.clone(builder->getF32Type()); return builder->create(loc, type, x, truncate); @@ -92,7 +93,7 @@ static Value CreateTFCastOpF32(OpBuilder *builder, Location loc, Value x, // Preconditions: The given value must have a ShapedType. static Value CreateTFCastOpI32(OpBuilder *builder, Location loc, Value x, BoolAttr truncate) { - auto x_type = x.getType().dyn_cast_or_null(); + auto x_type = mlir::dyn_cast_or_null(x.getType()); if (!x_type) llvm_unreachable("unsupported type"); Type type = x_type.clone(builder->getI32Type()); return builder->create(loc, type, x, truncate); @@ -109,7 +110,8 @@ static APFloat ConvertToAPFloat(double val, Type type) { // Performs the operation of `Shape(input)[idx]`. static Value GetDimensionSize(OpBuilder *builder, Location loc, Value input, int32_t idx, BoolAttr use_32bit) { - if (auto ranked_ty = input.getType().dyn_cast_or_null()) { + if (auto ranked_ty = + mlir::dyn_cast_or_null(input.getType())) { // Canonicalize negative index. if (idx < 0) { idx += ranked_ty.getRank(); @@ -154,7 +156,7 @@ bool QuantizedTypeIsUnsigned(Type type) { // to offset the quantized representation before it gets scaled. In the case // of negative quantize types, this offset is half the type's range. static DenseElementsAttr DequantizeHalfRange(OpBuilder *builder, Value input) { - auto input_type = input.getType().dyn_cast_or_null(); + auto input_type = mlir::dyn_cast_or_null(input.getType()); if (!input_type) llvm_unreachable("DequantizeHalfRange: not a ShapedType"); bool is_unsigned = QuantizedTypeIsUnsigned(input_type.getElementType()); float half_range = is_unsigned ? 0 : 128; @@ -183,7 +185,7 @@ DenseIntElementsAttr GetBiasAddGradReductionIndices(int64_t rank, // Infers ExpandDims op output type for the given input type `ty` and dimension // to expand at the given `axis`. Type InferExpandDimsType(Type ty, int64_t axis, Builder *builder) { - auto ranked_ty = ty.dyn_cast(); + auto ranked_ty = mlir::dyn_cast(ty); // Unranked type. if (!ranked_ty) return ty; @@ -258,7 +260,7 @@ class LowerAddNOp : public RewritePattern { // TODO(hinsu): Support variant with TensorList type. tf.AddV2 doesn't // support variant type so variant types require special handling. - if (getElementTypeOrSelf(addn_op.getType()).isa()) + if (mlir::isa(getElementTypeOrSelf(addn_op.getType()))) return failure(); llvm::SmallVector operands(addn_op.getInputs().begin(), addn_op.getInputs().end()); @@ -324,8 +326,7 @@ class LowerDynamicStitchOp : public RewritePattern { // Static output type is used to compute intermediate values. Note that the // output type doesn't have to be static but if input types and indices are // constant, then the output type can be statically determined. - RankedTensorType out_ty = - op.getType().template dyn_cast(); + RankedTensorType out_ty = mlir::dyn_cast(op.getType()); if (!out_ty || !out_ty.hasStaticShape()) return failure(); // Extract out all the constant indices' attributes and verify that data @@ -341,7 +342,7 @@ class LowerDynamicStitchOp : public RewritePattern { indices.push_back(index_attr); RankedTensorType data_ty = - data.getType().template dyn_cast(); + mlir::dyn_cast(data.getType()); if (!data_ty || !data_ty.hasStaticShape()) return failure(); } @@ -367,9 +368,8 @@ class LowerDynamicStitchOp : public RewritePattern { auto reshaped_data = rewriter.create(loc, data, packed_shape_val); - auto num_items = reshaped_data.getType() - .template cast() - .getShape()[0]; + auto num_items = + mlir::cast(reshaped_data.getType()).getShape()[0]; auto items = rewriter.create( loc, SmallVector(num_items, item_ty), reshaped_data, /*axis=*/0); @@ -407,7 +407,7 @@ class ConvertFakeQuantWithMinMaxVarsOp : public RewritePattern { auto op = cast(src_op); auto input = op.getInputs(); - auto input_ty = input.getType().cast(); + auto input_ty = mlir::cast(input.getType()); auto element_ty = input_ty.getElementType(); auto scalar_ty = tensorflow::GetTypeFromTFTensorShape({}, element_ty); @@ -534,7 +534,7 @@ class LowerInvertPermutationOp : public RewritePattern { auto op = cast(src_op); Location loc = op.getLoc(); - auto x_type = op.getX().getType().dyn_cast(); + auto x_type = mlir::dyn_cast(op.getX().getType()); // x input must have static shape. if (!x_type || !x_type.hasStaticShape()) { return failure(); @@ -617,12 +617,13 @@ class LowerLgammaOp : public RewritePattern { Location loc = op.getLoc(); Value input = op.getX(); - TensorType original_tensor_type = op.getX().getType().cast(); + TensorType original_tensor_type = + mlir::cast(op.getX().getType()); // The approximation is not precise enough for float16. Do the computation // in float32 for that case. TensorType tensor_type = original_tensor_type; - FloatType float_type = tensor_type.getElementType().cast(); + FloatType float_type = mlir::cast(tensor_type.getElementType()); bool needs_cast = float_type.getWidth() < 32; if (needs_cast) { MLIRContext *context = rewriter.getContext(); @@ -887,17 +888,18 @@ class LowerSpaceToBatchNDOp : public RewritePattern { auto op = cast(src_op); Location loc = op.getLoc(); - auto input_type = op.getInput().getType().cast(); + auto input_type = mlir::cast(op.getInput().getType()); auto element_type = input_type.getElementType(); if (!input_type.hasStaticShape()) { return failure(); } ArrayRef input_shape = input_type.getShape(); - auto block_shape_type = op.getBlockShape().getType().cast(); + auto block_shape_type = + mlir::cast(op.getBlockShape().getType()); if (!block_shape_type.hasStaticShape()) { return failure(); } - auto paddings_type = op.getPaddings().getType().cast(); + auto paddings_type = mlir::cast(op.getPaddings().getType()); if (!paddings_type.hasRank()) { return failure(); } @@ -1100,7 +1102,7 @@ class LowerBatchToSpaceND : public RewritePattern { PatternRewriter &rewriter) const override { auto op = cast(src_op); auto input = op.getInput(); - auto input_ty = input.getType().cast(); + auto input_ty = mlir::cast(input.getType()); auto element_ty = input_ty.getElementType(); if (!input_ty.hasStaticShape()) { return failure(); @@ -1279,9 +1281,7 @@ class LowerSparseMatMulOp : public RewritePattern { // Result type must be f32 for applying the pattern (currently this is // required by the op anyway but this might change). - if (!op.getProduct() - .getType() - .cast() + if (!mlir::cast(op.getProduct().getType()) .getElementType() .isF32()) { return failure(); @@ -1289,7 +1289,7 @@ class LowerSparseMatMulOp : public RewritePattern { MLIRContext *context = rewriter.getContext(); llvm::SmallVector operands{op.getA(), op.getB()}; for (Value &operand : operands) { - TensorType tensor_type = operand.getType().cast(); + TensorType tensor_type = mlir::cast(operand.getType()); Type element_type = tensor_type.getElementType(); if (element_type.isF32()) continue; // Element type can either be f32 or bf16 for `SparseMatMulOp` so it @@ -1374,13 +1374,13 @@ class LowerResizeNearestNeighbor : public RewritePattern { PatternRewriter &rewriter) const override { auto op = cast(src_op); auto loc = op.getLoc(); - auto result_ty = op.getType().cast(); + auto result_ty = mlir::cast(op.getType()); auto input = op.getImages(); - auto input_ty = input.getType().cast(); + auto input_ty = mlir::cast(input.getType()); auto input_element_ty = input_ty.getElementType(); auto out_size = op.getSize(); - auto out_size_ty = out_size.getType().cast(); + auto out_size_ty = mlir::cast(out_size.getType()); auto out_size_element_ty = out_size_ty.getElementType(); // Input should be rank 4. @@ -1620,7 +1620,7 @@ struct LowerRollOp : public RewritePattern { auto tf_roll_op = cast(op); auto input_ty = - tf_roll_op.getInput().getType().dyn_cast(); + mlir::dyn_cast(tf_roll_op.getInput().getType()); if (!input_ty || !input_ty.hasStaticShape()) { return rewriter.notifyMatchFailure( op, "require the type of input to have static shapes"); @@ -1628,7 +1628,8 @@ struct LowerRollOp : public RewritePattern { DenseIntElementsAttr shift_attr; Value shift = tf_roll_op.getShift(); - auto shift_ranked_attr_type = shift.getType().dyn_cast(); + auto shift_ranked_attr_type = + mlir::dyn_cast(shift.getType()); if (!shift_ranked_attr_type || !matchPattern(shift, m_Constant(&shift_attr))) { return failure(); @@ -1636,7 +1637,8 @@ struct LowerRollOp : public RewritePattern { DenseIntElementsAttr axis_attr; Value axis = tf_roll_op.getAxis(); - auto axis_ranked_attr_type = axis.getType().dyn_cast(); + auto axis_ranked_attr_type = + mlir::dyn_cast(axis.getType()); if (!axis_ranked_attr_type || !matchPattern(axis, m_Constant(&axis_attr))) { return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index cd608bdf269ad7..80e7cd3991c727 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/utils/validators.h" @@ -57,13 +58,14 @@ class SimplifyBroadcastReshape : public OpRewritePattern { auto reshape_op = llvm::dyn_cast_or_null(user); if (!reshape_op) return failure(); - auto reshape_type = reshape_op.getOutput().getType().cast(); + auto reshape_type = + mlir::cast(reshape_op.getOutput().getType()); if (!reshape_type.hasStaticShape()) return failure(); ArrayRef reshape_shape = reshape_type.getShape(); - auto input_type = op.getInput().getType().cast(); - auto output_type = op.getOutput().getType().cast(); + auto input_type = mlir::cast(op.getInput().getType()); + auto output_type = mlir::cast(op.getOutput().getType()); if (!input_type.hasRank() || !output_type.hasRank()) return failure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index eaf881c43df95e..bfed05448bd25a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -94,7 +94,7 @@ GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) { continue; } auto global_tensor = symbol_table.lookup( - sym.cast().getValue()); + mlir::cast(sym).getValue()); if (!global_tensor) { continue; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc index 661dafe2a2f327..b968923089cb8f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" @@ -79,8 +80,8 @@ class RewriteXlaHostComputeMlir llvm::SmallVector shape_attrs; shape_attrs.reserve(op.getNumResults()); for (Type ty : op.getResultTypes()) { - shape_attrs.push_back( - TF::ShapeAttr::get(rewriter.getContext(), ty.cast())); + shape_attrs.push_back(TF::ShapeAttr::get(rewriter.getContext(), + mlir::cast(ty))); } // Clone the `host_func` in the `host_mlir_module` attribute if it exists diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index a7226b39ebe380..bc64c48c81a596 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -196,8 +197,8 @@ LogicalResult PromoteResourcesToArguments( auto func_args = function.getArguments().take_front( function.getNumArguments() - var_handle_shared_names.size()); for (BlockArgument& func_arg : func_args) { - auto resource_type = - getElementTypeOrSelf(func_arg.getType()).dyn_cast(); + auto resource_type = mlir::dyn_cast( + getElementTypeOrSelf(func_arg.getType())); if (!resource_type) continue; if (failed(ValidateResourceArgument(function, func_arg, resource_type))) return failure(); @@ -212,8 +213,8 @@ LogicalResult PromoteResourcesToArguments( auto var_handle_args = function.getArguments().take_back(var_handle_shared_names.size()); for (BlockArgument& var_handle_arg : var_handle_args) { - auto resource_type = - getElementTypeOrSelf(var_handle_arg.getType()).cast(); + auto resource_type = mlir::cast( + getElementTypeOrSelf(var_handle_arg.getType())); add_resource_argument(var_handle_arg, resource_type); } @@ -226,7 +227,8 @@ LogicalResult PromoteResourcesToArguments( // live value. for (Operation& op : llvm::make_early_inc_range(block)) { if (auto read_op = llvm::dyn_cast(&op)) { - if (auto func_arg = read_op.getResource().dyn_cast()) { + if (auto func_arg = + mlir::dyn_cast(read_op.getResource())) { if (func_arg.getOwner() != &block) return read_op.emitOpError(kResourceFunctionMsg); @@ -239,7 +241,8 @@ LogicalResult PromoteResourcesToArguments( read_op.erase(); } else if (auto write_op = llvm::dyn_cast(&op)) { - if (auto func_arg = write_op.getResource().dyn_cast()) { + if (auto func_arg = + mlir::dyn_cast(write_op.getResource())) { if (func_arg.getOwner() != &block) return write_op.emitOpError(kResourceFunctionMsg); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc index 975a1484d6984a..7c488b8992d2cb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/readonly_references_to_resources.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -91,7 +92,7 @@ StringRef GetNodeNameFromClassAttrOrSharedNameAttr(Operation *op) { StringRef result; for (Attribute class_attr : classes_attr) { - StringRef node_name = class_attr.cast().getValue(); + StringRef node_name = mlir::cast(class_attr).getValue(); if (!node_name.starts_with(kLocationPrefix)) { continue; } @@ -150,8 +151,8 @@ void ConvertReadonlyReferenceVariablesToResourceVariablesPass:: for (VariableV2Op variable_v2_op : variable_v2s_to_replace) { builder.setInsertionPoint(variable_v2_op); ShapedType shaped_type = - variable_v2_op.getResult().getType().cast(); - TensorType tensor_type = DropRefType(shaped_type).cast(); + mlir::cast(variable_v2_op.getResult().getType()); + TensorType tensor_type = mlir::cast(DropRefType(shaped_type)); StringAttr device_attr = variable_v2_op->getAttrOfType("device"); if (!device_attr) device_attr = builder.getStringAttr(""); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc index a669276e35a175..b740e667dabe84 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc @@ -508,10 +508,10 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( // existing function as is. auto while_arg_matcher = [](Value first, Region& first_region, Value second, Region& second_region) { - if (!first.isa() || !second.isa()) + if (!mlir::isa(first) || !mlir::isa(second)) return false; - BlockArgument first_block_arg = first.cast(); - BlockArgument second_block_arg = second.cast(); + BlockArgument first_block_arg = mlir::cast(first); + BlockArgument second_block_arg = mlir::cast(second); // 2 block arguments will match if they are the same argument number, and // are block arguments of the corresponding containing regions. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc index 6aa3d161c0e121..18f54d6b5826d3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project #include "mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TF { @@ -137,7 +138,7 @@ void RemoveUnusedArgumentsPass::runOnOperation() { // SymbolUserOpInterface doesn't tell us which attributes contain // the symbols, so we have to scan through all of them. for (auto attr : op->getAttrs()) { - if (auto sym = attr.getValue().dyn_cast()) { + if (auto sym = mlir::dyn_cast(attr.getValue())) { Operation* func = mlir::SymbolTable::lookupNearestSymbolFrom(op, sym); if (func) { do_not_touch.insert(func); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc b/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc index 9e85c5f9ed6fda..3a6377a3bb63e1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/remove_vars_in_session_initializer.cc @@ -55,7 +55,7 @@ void RecursiveRemove(Operation* op, erase_list.push_back(op); for (auto& use : op->getOpOperands()) { - if (auto op_result = use.get().dyn_cast()) { + if (auto op_result = mlir::dyn_cast(use.get())) { Operation* def = op_result.getDefiningOp(); if (!dead_ops.insert(def).second) continue; RecursiveRemove(def, erase_list, dead_ops); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc index 1c9558eecda702..803f135af624d7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_invariant_op_hoisting.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -90,7 +91,7 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas, Value input = shape_op.getInput(); // If ShapeOp operand is replicate tensor block argument, replace with the // associated first replica operand. - if (auto block_arg = input.dyn_cast()) { + if (auto block_arg = mlir::dyn_cast(input)) { if (block_arg.getOwner() != replicate_block) return; shape_op.setOperand(replicate_op.GetReplicaOperandForBlockArgument( @@ -112,7 +113,8 @@ void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas, // shape has not changed in replicate prior to read. Currently after both // ResourceOpLiftingPass and TPURewritePass, there should not be any updates // to resources prior to their respective ReadVariableOp. - if (auto block_arg = read_var_op.getResource().dyn_cast()) { + if (auto block_arg = + mlir::dyn_cast(read_var_op.getResource())) { if (block_arg.getOwner() != replicate_block) return; OpBuilder builder(shape_op); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index e03eb9a9228f35..90397e7f8237c9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -74,14 +74,14 @@ struct ResourceOpLiftingPass }; bool IsResource(Value value) { - return getElementTypeOrSelf(value.getType()).isa(); + return mlir::isa(getElementTypeOrSelf(value.getType())); } // Get the type of the data contained in a resource. Returns null if there is // no single type in the resource. Type GetResourceSubtype(Value value) { auto resource_type = - getElementTypeOrSelf(value.getType()).dyn_cast(); + mlir::dyn_cast(getElementTypeOrSelf(value.getType())); auto subtypes = resource_type.getSubtypes(); if (subtypes.size() == 1) return subtypes[0]; return nullptr; @@ -691,7 +691,7 @@ void RemoveUnusedResourceArgumentsAndForwardedRetvals( int64_t skipped_retvals = 0; for (auto entry : llvm::enumerate(old_return_vals)) { auto return_val = entry.value(); - if (auto arg = return_val.dyn_cast()) { + if (auto arg = mlir::dyn_cast(return_val)) { auto it = infos.find(arg.getArgNumber()); if (it != infos.end() && !it->getSecond().used) { return_op->eraseOperand(entry.index() - skipped_retvals++); @@ -747,7 +747,7 @@ LogicalResult LiftArgRetResourcesForFunction( // with type replaced. llvm::SmallVector skipped_args; for (auto& it : hoister.GetResources()) { - BlockArgument arg = it.first.dyn_cast(); + BlockArgument arg = mlir::dyn_cast(it.first); assert(arg && "Expect resources for FuncOp to be its arguments"); auto type_iter = resource_data_types.find(arg.getArgNumber()); if (type_iter == resource_data_types.end()) { @@ -772,7 +772,7 @@ LogicalResult LiftArgRetResourcesForFunction( Value resource = assign_variable_op.getResource(); if (!hoister.Contains(resource)) continue; - auto arg = resource.dyn_cast(); + auto arg = mlir::dyn_cast(resource); handle_updated_arg_value(arg.getArgNumber(), assign_variable_op.getValue()); assign_variable_op.erase(); } @@ -1018,11 +1018,11 @@ LogicalResult HandlePartitionedCallOpCallee( for (auto entry : llvm::enumerate(callee.front().getTerminator()->getOperands())) { auto retval = entry.value(); - if (!getElementTypeOrSelf(retval.getType()).isa()) { + if (!mlir::isa(getElementTypeOrSelf(retval.getType()))) { result->old_to_new_output_indices.push_back(non_resource_results++); continue; } - auto aliasing_arg = retval.dyn_cast(); + auto aliasing_arg = mlir::dyn_cast(retval); if (!aliasing_arg) { return callee.emitOpError("unsupported function call: ") << "resource return value does not alias an input."; @@ -1063,7 +1063,7 @@ LogicalResult HandlePartitionedCallOpCallee( llvm::SmallVector retval_indices_to_preserve; for (auto& val : callee.front().getTerminator()->getOpOperands()) { // Store indices of results that are not resources. - if (!getElementTypeOrSelf(val.get().getType()).isa()) + if (!mlir::isa(getElementTypeOrSelf(val.get().getType()))) retval_indices_to_preserve.push_back(val.getOperandNumber()); } int64_t num_retvals = retval_indices_to_preserve.size(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc index 2f1c675b305516..303e5aa2b6ddeb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -32,7 +33,7 @@ namespace mlir { namespace { bool IsResource(Value value) { - return getElementTypeOrSelf(value.getType()).isa(); + return mlir::isa(getElementTypeOrSelf(value.getType())); } // Checks if a cast op is casting a resource -> resource. @@ -182,7 +183,7 @@ void EliminateUnusedResultsForIfCase(Operation *op, if (cloned == func) continue; // Patch up the op attribute to point to the new function. for (NamedAttribute attr : op->getAttrs()) { - auto symref = attr.getValue().dyn_cast(); + auto symref = mlir::dyn_cast(attr.getValue()); if (!symref) continue; if (symref.getValue() != func.getName()) continue; op->setAttr(attr.getName(), @@ -301,7 +302,8 @@ LogicalResult ForwardCommonArgToOutput(Operation *op, std::optional common_arg_index; for (func::FuncOp func : branches) { auto ret = func.front().getTerminator(); - auto block_arg = ret->getOperand(result_idx).dyn_cast(); + auto block_arg = + mlir::dyn_cast(ret->getOperand(result_idx)); if (!block_arg) { return op->emitOpError("result #") << result_idx << " not tied to function argument for branch @" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h index c86c4383cc602f..4dd6ae7c8e4a7d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h @@ -18,6 +18,7 @@ limitations under the License. #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { namespace TF { @@ -27,13 +28,13 @@ namespace TF { template DenseElementsAttr GetScalarOfType(Type ty, T raw_value) { RankedTensorType scalar_ty = RankedTensorType::get({}, ty); - if (auto float_ty = ty.dyn_cast()) { + if (auto float_ty = mlir::dyn_cast(ty)) { FloatAttr attr = FloatAttr::get(float_ty, raw_value); return DenseElementsAttr::get(scalar_ty, attr); - } else if (auto int_ty = ty.dyn_cast()) { + } else if (auto int_ty = mlir::dyn_cast(ty)) { IntegerAttr attr = IntegerAttr::get(int_ty, raw_value); return DenseElementsAttr::get(scalar_ty, attr); - } else if (auto complex_ty = ty.dyn_cast()) { + } else if (auto complex_ty = mlir::dyn_cast(ty)) { Type complex_element_ty = complex_ty.getElementType(); if (complex_element_ty.isF32()) { return DenseElementsAttr::get( @@ -50,13 +51,13 @@ DenseElementsAttr GetScalarOfType(Type ty, T raw_value) { // to `raw_value`. template bool IsConstantValueOf(Value value, T raw_value) { - auto element_type = value.getType().cast().getElementType(); - if (element_type.isa()) { + auto element_type = mlir::cast(value.getType()).getElementType(); + if (mlir::isa(element_type)) { DenseFPElementsAttr float_attr; if (matchPattern(value, m_Constant(&float_attr)) && float_attr.isSplat() && float_attr.getSplatValue().isExactlyValue(raw_value)) return true; - } else if (element_type.isa()) { + } else if (mlir::isa(element_type)) { DenseIntElementsAttr int_attr; if (matchPattern(value, m_Constant(&int_attr)) && int_attr.isSplat() && int_attr.getSplatValue() == raw_value) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc b/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc index 4948aa68e13039..0eb552208194e1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/set_tpu_infeed_layout.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "xla/layout.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -64,27 +65,27 @@ FailureOr GetTPUInfeedLayout(const ArrayRef types, llvm::SmallVector v; v.reserve(types.size()); for (const mlir::Type &t : types) { - if (t.isa()) continue; + if (mlir::isa(t)) continue; auto layout = GetTPUInfeedLayout({t}, rewriter); if (failed(layout)) return failure(); v.push_back(layout.value()); } ArrayRef shape(v); return rewriter.getArrayAttr(shape); - } else if (types[0].isa()) { - auto tuple_type = types[0].dyn_cast(); + } else if (mlir::isa(types[0])) { + auto tuple_type = mlir::dyn_cast(types[0]); const auto &types = tuple_type.getTypes(); llvm::SmallVector v; v.reserve(types.size()); for (const mlir::Type &t : types) { - if (t.isa()) continue; + if (mlir::isa(t)) continue; auto layout = GetTPUInfeedLayout({t}, rewriter); if (failed(layout)) return failure(); v.push_back(layout.value()); } ArrayRef shape(v); return rewriter.getArrayAttr(shape); - } else if (auto t = types[0].dyn_cast()) { + } else if (auto t = mlir::dyn_cast(types[0])) { if (!t.hasStaticShape()) return failure(); auto layout = GetTPUInfeedLayoutFromAPI(t); std::vector minor_to_major; @@ -129,7 +130,7 @@ bool SetTPUInfeedLayout(mlir::OwningOpRef &mlir_module) { std::vector result_types; for (mlir::Type t : op.getResultTypes()) { - auto ty = t.cast(); + auto ty = mlir::cast(t); if (!ty.hasStaticShape()) return mlir::WalkResult::interrupt(); result_types.push_back(t); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 9d1d8d599e8224..6a9527aea26b3f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -125,9 +125,9 @@ Type TypeMeet(Type lhs, Type rhs) { DCOMMENT("RefineTypeWith : " << lhs << " : " << rhs); if (lhs == rhs) return lhs; - auto rhs_shape_type = rhs.dyn_cast(); + auto rhs_shape_type = mlir::dyn_cast(rhs); if (!rhs_shape_type) return lhs; - auto lhs_shape_type = lhs.cast(); + auto lhs_shape_type = mlir::cast(lhs); if (lhs_shape_type.hasRank() && rhs_shape_type.hasRank() && lhs_shape_type.getRank() != rhs_shape_type.getRank()) { DCOMMENT("Unexpected rank mismatch: " << lhs << " vs " << rhs); @@ -167,7 +167,8 @@ Type TypeMeet(Type lhs, Type rhs) { // returned type. auto lhs_element_type = lhs_shape_type.getElementType(); auto rhs_element_type_with_subtype = - rhs_shape_type.getElementType().dyn_cast(); + mlir::dyn_cast( + rhs_shape_type.getElementType()); // Look for resource or variant element type and ensure we refine the subtype. // We only support a single subtype at the moment, we won't handle something // like: @@ -175,7 +176,7 @@ Type TypeMeet(Type lhs, Type rhs) { if (rhs_element_type_with_subtype && rhs_element_type_with_subtype.GetSubtypes().size() == 1) { auto lhs_element_type_with_subtype = - lhs_element_type.dyn_cast(); + mlir::dyn_cast(lhs_element_type); TensorType subtype; if (!lhs_element_type_with_subtype) { DCOMMENT( @@ -193,10 +194,9 @@ Type TypeMeet(Type lhs, Type rhs) { // and: // tensor>> // we'll try here to refine tensor with tensor<10x8xf32>. - auto refined_subtype = + auto refined_subtype = mlir::cast( TypeMeet(lhs_element_type_with_subtype.GetSubtypes().front(), - rhs_element_type_with_subtype.GetSubtypes().front()) - .cast(); + rhs_element_type_with_subtype.GetSubtypes().front())); if (refined_subtype != lhs_element_type_with_subtype.GetSubtypes().front()) subtype = refined_subtype; @@ -272,7 +272,7 @@ Value GetElementShapeOperand(Operation* op) { // Utility function to create a ranked tensor type after dropping the first // dimension from the input type. RankedTensorType DropFirstDimension(Type type) { - RankedTensorType ranked_type = type.dyn_cast(); + RankedTensorType ranked_type = mlir::dyn_cast(type); if (!ranked_type) return {}; llvm::ArrayRef dims_except_first = ranked_type.getShape().drop_front(); @@ -282,7 +282,7 @@ RankedTensorType DropFirstDimension(Type type) { Operation* InsertCast(OpBuilder& b, Location loc, Type dst_type, Value input) { Type element_type = getElementTypeOrSelf(dst_type); - if (element_type.isa()) + if (mlir::isa(element_type)) return b.create(loc, dst_type, input); if (isa(element_type.getDialect())) return b.create(loc, dst_type, input, @@ -342,7 +342,7 @@ bool CanInferTensorListElementType(Value tensorlist, for (auto& use : tensorlist.getUses()) { if (auto push = llvm::dyn_cast(use.getOwner())) { auto element_type = - push.getTensor().getType().dyn_cast(); + mlir::dyn_cast(push.getTensor().getType()); if (!verify_and_update_potential_element_type(element_type)) return false; add_to_worklist(push.getOutputHandle()); @@ -361,7 +361,7 @@ bool CanInferTensorListElementType(Value tensorlist, } if (auto set_item = llvm::dyn_cast(use.getOwner())) { auto element_type = - set_item.getItem().getType().dyn_cast(); + mlir::dyn_cast(set_item.getItem().getType()); DCOMMENT("\tTensorListSetItemOp " << element_type); if (!verify_and_update_potential_element_type(element_type)) return false; @@ -433,8 +433,8 @@ bool CanInferTensorListElementType(Value tensorlist, // Returns the tensor type created from the `shape_attr` and `type_attr` // attributes. Type GetType(Attribute shape_attr, Attribute type_attr) { - auto shape = shape_attr.cast(); - auto type = type_attr.cast(); + auto shape = mlir::cast(shape_attr); + auto type = mlir::cast(type_attr); if (shape.hasRank()) return tensorflow::GetTypeFromTFTensorShape(shape.getShape(), type.getValue()); @@ -445,7 +445,7 @@ Type GetType(Attribute shape_attr, Attribute type_attr) { // Returns whether type can be further refined. bool CanBeRefined(Type type) { - auto shape_type = type.dyn_cast(); + auto shape_type = mlir::dyn_cast(type); if (!shape_type) return false; // Returns whether type with subtypes can be further refined. @@ -453,8 +453,8 @@ bool CanBeRefined(Type type) { return tws.GetSubtypes().empty() || llvm::any_of(tws.GetSubtypes(), CanBeRefined); }; - auto type_with_subtype = - shape_type.getElementType().dyn_cast(); + auto type_with_subtype = mlir::dyn_cast( + shape_type.getElementType()); if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true; return !shape_type.hasStaticShape(); @@ -467,7 +467,7 @@ Type GetNewArgType(Type old_arg_type, ArrayRef shape, Type element_type, mlir::MLIRContext* context) { Type new_arg_type = tensorflow::GetTypeFromTFTensorShape(shape, element_type); - if (auto input_ty = old_arg_type.dyn_cast()) { + if (auto input_ty = mlir::dyn_cast(old_arg_type)) { ArrayRef bounds = hlo::encodingToBounds(input_ty.getEncoding()); // The input type has bounded dynamic dimension. if (!bounds.empty()) { @@ -505,12 +505,12 @@ struct ValuePort { // Convert output value to ValuePort. explicit ValuePort(Value v) { - OpResult opr = v.dyn_cast(); + OpResult opr = mlir::dyn_cast(v); if (opr) { producer = opr.getOwner(); port = {opr.getResultNumber()}; } else { - producer = v.cast(); + producer = mlir::cast(v); port = {0}; } } @@ -549,7 +549,7 @@ using ValuePortInputs = SmallVectorImpl; // Maps the specified component in the `port` of the given op's result to one of // the element in the input. ValuePort ComputeInputComponentFor(PackOp op, ArrayRef port) { - auto type = op.getType().cast(); + auto type = mlir::cast(op.getType()); if (!type.hasRank() || type.getRank() != 1) return {}; if (port.size() != 2) return {}; assert(port[0] == 0); @@ -562,7 +562,7 @@ ValuePort ComputeInputComponentFor(ConcatV2Op op, ArrayRef port) { int64_t element_idx = port[1]; for (Value val : op.getValues()) { - auto val_ty = val.getType().cast(); + auto val_ty = mlir::cast(val.getType()); if (!val_ty.hasStaticShape() || val_ty.getRank() != 1) return {}; int64_t dim_size = val_ty.getNumElements(); @@ -583,7 +583,7 @@ ValuePort ComputeInputComponentFor(GatherV2Op op, ArrayRef port) { assert(port[0] == 0); auto params = op.getParams(); - auto params_ty = params.getType().dyn_cast(); + auto params_ty = mlir::dyn_cast(params.getType()); if (!params_ty || !params_ty.hasStaticShape() || params_ty.getRank() != 1 || op.getBatchDims() != 0) { return {}; @@ -687,7 +687,7 @@ Attribute ComputeOutputComponent(const ValuePort& value_port, if (auto shape_op = dyn_cast(op)) { // No shape available in an unranked tensor type. auto operand_ty = - shape_op.getOperand().getType().dyn_cast(); + mlir::dyn_cast(shape_op.getOperand().getType()); if (!operand_ty) return nullptr; // Shape op has a single output so the first element should always be zero @@ -1134,14 +1134,14 @@ bool ShapeInference::InferShapeForCast(Operation* op) { if (!new_type) { // Combine shape information when leaf element types are not the same, not // including shape info in subtypes. - auto ranked_operand_type = operand_type.dyn_cast(); + auto ranked_operand_type = mlir::dyn_cast(operand_type); if (!ranked_operand_type) return false; - auto ranked_res_type = result.getType().dyn_cast(); + auto ranked_res_type = mlir::dyn_cast(result.getType()); if (ranked_res_type && ranked_operand_type.getShape() == ranked_res_type.getShape()) return false; - auto shaped_res_type = result_type.dyn_cast(); + auto shaped_res_type = mlir::dyn_cast(result_type); if (!shaped_res_type) return false; new_type = tensorflow::GetTypeFromTFTensorShape( ranked_operand_type.getShape(), shaped_res_type.getElementType()); @@ -1296,7 +1296,7 @@ bool ShapeInference::InferShapeForXlaCallModule(XlaCallModuleOp op) { int next_op_result = 0; for (auto output_type : main_output_types) { if (tensorflow::IsTokenType(output_type)) continue; - auto output_type_ranked = output_type.dyn_cast(); + auto output_type_ranked = mlir::dyn_cast(output_type); if (output_type_ranked == nullptr) { llvm::errs() << "Unsupported XlaCallModule result type: " << output_type << "\n"; @@ -1422,20 +1422,20 @@ bool ShapeInference::InferShapeForRestore(Operation* op) { if (!assign_op) { continue; } - auto subtypes = getElementTypeOrSelf(assign_op.getResource()) - .cast() + auto subtypes = mlir::cast( + getElementTypeOrSelf(assign_op.getResource())) .getSubtypes(); if (subtypes.empty()) { continue; } - auto subtype = subtypes.front().dyn_cast(); + auto subtype = mlir::dyn_cast(subtypes.front()); if (subtype == nullptr) { continue; } // Preserve the dtype from the restore op even if `AssignVariableOp` uses a // different dtype, which is possible when there's a `CastOp` between them. subtype = subtype.clone( - op->getResult(0).getType().cast().getElementType()); + mlir::cast(op->getResult(0).getType()).getElementType()); // Update the result type of this op with the resource's type. We only use // the resource subtype of the first user since shapes from all the users // should be equal or compatible. @@ -1460,7 +1460,7 @@ DatasetInput GetDatasetInput(Value value) { while ( llvm::isa_and_nonnull(value.getDefiningOp())) { value = value.getDefiningOp()->getOperand( - value.cast().getResultNumber()); + mlir::cast(value).getResultNumber()); } Operation* op = value.getDefiningOp(); @@ -1668,14 +1668,14 @@ bool ShapeInference::InferShapeForTensorListPopBackOp(TensorListPopBackOp op) { DCOMMENT_OP(op, "Inferring shape for TensorListPopBackOp."); auto src_list_handle_t = - op.getOperand(0).getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(op.getOperand(0).getType()); if (!src_list_handle_t) return false; // Copy of operand tensorlist type. TensorType dst_list_handle_t = src_list_handle_t.clone(src_list_handle_t.getElementType()); auto variant_element_t = - dst_list_handle_t.getElementType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(dst_list_handle_t.getElementType()); if (!variant_element_t || variant_element_t.getSubtypes().size() != 1) return false; @@ -1726,7 +1726,7 @@ bool ShapeInference::InferShapeForVarHandleOp(VarHandleOp op) { llvm_unreachable("unexpected operator type"); } - TensorType resource_subtype = value.getType().cast(); + TensorType resource_subtype = mlir::cast(value.getType()); ResourceType resource_type = ResourceType::get({resource_subtype}, op.getContext()); UnrankedTensorType new_resource_type = @@ -1858,7 +1858,7 @@ bool ShapeInference::InferShapeForXlaReduceWindowOp(XlaReduceWindowOp op) { bool changed = false; - auto input_ty = op.getInput().getType().cast(); + auto input_ty = mlir::cast(op.getInput().getType()); DenseElementsAttr window_dimensions, window_strides, base_dilations, window_dilations, padding; if (input_ty.hasStaticShape() && @@ -1905,7 +1905,7 @@ bool ShapeInference::InferShapeForXlaReduceWindowOp(XlaReduceWindowOp op) { } auto output_shape = InferWindowOutputShape( input_ty, window.value(), - op.getInitValue().getType().cast().getElementType()); + mlir::cast(op.getInitValue().getType()).getElementType()); if (!output_shape) { op->emitOpError("failed to infer output shape"); @@ -1922,8 +1922,8 @@ bool ShapeInference::InferShapeForXlaSelectAndScatterOp( XlaSelectAndScatterOp op) { DCOMMENT_OP(op, "Inferring shape for XlaSelectAndScatterOp"); - auto operand_shape = op.getOperand().getType().cast(); - auto source_shape = op.getSource().getType().cast(); + auto operand_shape = mlir::cast(op.getOperand().getType()); + auto source_shape = mlir::cast(op.getSource().getType()); DenseElementsAttr window_dimensions, window_strides, padding; if (operand_shape.hasRank() && source_shape.hasRank() && matchPattern(op.getWindowDimensions(), m_Constant(&window_dimensions)) && @@ -2085,13 +2085,14 @@ LogicalResult PrecheckForXlaConvV2Op(XlaConvV2Op op) { int64_t batch_group_count = op.getBatchGroupCount(); auto input_args_have_static_shape = [&]() -> bool { - return input_tensor.getType().cast().hasStaticShape() && - kernel_tensor.getType().cast().hasStaticShape() && - window_strides.getType().cast().hasStaticShape() && - padding.getType().cast().hasStaticShape() && - lhs_dilation.getType().cast().hasStaticShape() && - rhs_dilation.getType().cast().hasStaticShape() && - feature_group_count.getType().cast().hasStaticShape(); + return mlir::cast(input_tensor.getType()).hasStaticShape() && + mlir::cast(kernel_tensor.getType()).hasStaticShape() && + mlir::cast(window_strides.getType()).hasStaticShape() && + mlir::cast(padding.getType()).hasStaticShape() && + mlir::cast(lhs_dilation.getType()).hasStaticShape() && + mlir::cast(rhs_dilation.getType()).hasStaticShape() && + mlir::cast(feature_group_count.getType()) + .hasStaticShape(); }; // Return failure when one of the input args has not a static shape @@ -2100,9 +2101,9 @@ LogicalResult PrecheckForXlaConvV2Op(XlaConvV2Op op) { } auto input_tensor_shape = - input_tensor.getType().cast().getShape(); + mlir::cast(input_tensor.getType()).getShape(); auto kernel_tensor_shape = - kernel_tensor.getType().cast().getShape(); + mlir::cast(kernel_tensor.getType()).getShape(); if (input_tensor_shape.size() <= 2) { return op.emitOpError() @@ -2229,14 +2230,16 @@ bool ShapeInference::InferShapeForXlaConvV2Op(XlaConvV2Op op) { xla::ConvolutionDimensionNumbers dnums; dnums.ParseFromString(op.getDimensionNumbersAttr().getValue().str()); - auto input_tensor_shape = input_tensor.getType().cast(); + auto input_tensor_shape = + mlir::cast(input_tensor.getType()); for (auto i = 0; i < input_tensor_shape.getShape().size(); ++i) { DCOMMENT("Input Tensor Shape " << i << "th is " << input_tensor_shape.getShape()[i]); input_tensor_dims_vec.push_back(input_tensor_shape.getShape()[i]); } - auto kernel_tensor_shape = kernel_tensor.getType().cast(); + auto kernel_tensor_shape = + mlir::cast(kernel_tensor.getType()); for (auto i = 0; i < kernel_tensor_shape.getShape().size(); ++i) { DCOMMENT("Kernel tensor Shape" << i << "th is " << kernel_tensor_shape.getShape()[i]); @@ -2319,7 +2322,7 @@ bool ShapeInference::RefineWithInferTypeOpInterface( ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result, InferenceContext* ic) { LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially ")); - auto rt = result.getType().dyn_cast(); + auto rt = mlir::dyn_cast(result.getType()); if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {}; int dim_size = rt.getDimSize(0); @@ -2366,7 +2369,7 @@ ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result, // If worklist is empty, then this is the root query op. if (worklist.empty()) { LLVM_DEBUG(llvm::dbgs() << "[root node]\n"); - if (auto dea = ret.dyn_cast()) { + if (auto dea = mlir::dyn_cast(ret)) { if (dea.getNumElements() != 1) { LLVM_DEBUG(llvm::dbgs() << "Unexpected number of elements\n"); return {}; @@ -2404,7 +2407,7 @@ bool ShapeInference::RefineTypeForPassThroughOperands(Operation* op, for (auto entry : llvm::zip(operands, results)) { Type operand_type = std::get<0>(entry).getType(); Value result = std::get<1>(entry); - TensorType result_type = result.getType().cast(); + TensorType result_type = mlir::cast(result.getType()); Type inferred_type = TypeMeet(result_type, operand_type); if (result_type == inferred_type) continue; @@ -2470,10 +2473,10 @@ bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) { Type GetElementTypeFromOperand(TensorType operand_type, TensorType result_type) { auto operand_handle_type = - operand_type.getElementType().dyn_cast(); + mlir::dyn_cast(operand_type.getElementType()); if (!operand_handle_type) return result_type.getElementType(); auto result_handle_type = - result_type.getElementType().cast(); + mlir::cast(result_type.getElementType()); if (operand_handle_type.GetSubtypes().empty() || !result_handle_type.GetSubtypes().empty()) return result_type.getElementType(); @@ -2509,9 +2512,8 @@ bool ShapeInference::InferShapeForWhile(WhileOpTy op, for (auto entry : zip(op.getInput().getTypes(), op.getOutput(), body_result_types)) { Value result = std::get<1>(entry); - TensorType body_result_type = - std::get<2>(entry).template cast(); - auto result_type = result.getType().cast(); + TensorType body_result_type = mlir::cast(std::get<2>(entry)); + auto result_type = mlir::cast(result.getType()); Type potential_refined_type; if (CanWhileTypeBeRefinedWith(result_type, body_result_type)) { @@ -2522,7 +2524,7 @@ bool ShapeInference::InferShapeForWhile(WhileOpTy op, : std::optional>(), element_type); } else { - TensorType operand_type = std::get<0>(entry).template cast(); + TensorType operand_type = mlir::cast(std::get<0>(entry)); Type element_type = GetElementTypeFromOperand(operand_type, result_type); potential_refined_type = CreateTensorType( result_type.hasRank() ? result_type.getShape() @@ -2675,7 +2677,8 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op, // Return result element type at `index`. auto result_element_type_fn = [&](int index) { - return op->getResult(index).getType().cast().getElementType(); + return mlir::cast(op->getResult(index).getType()) + .getElementType(); }; llvm::SmallVector inferred_return_shapes; @@ -2702,7 +2705,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op, inferred_type = UnrankedTensorType::get(inferred.getElementType()); } inferred_type = - TypeMeet(op_result.getType(), inferred_type).cast(); + mlir::cast(TypeMeet(op_result.getType(), inferred_type)); if (op_result.getType() == inferred_type) continue; if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, op_result)) continue; @@ -2883,19 +2886,19 @@ llvm::SmallVector GetWhileCompatibleTypes( types.reserve(operand_types.size()); for (auto entry : llvm::zip(operand_types, result_types, region_argument_types)) { - auto operand_type = std::get<0>(entry).cast(); - auto result_type = std::get<1>(entry).cast(); + auto operand_type = mlir::cast(std::get<0>(entry)); + auto result_type = mlir::cast(std::get<1>(entry)); if (operand_type == result_type) { types.push_back(operand_type); } else if (RankedAndSameRank(operand_type, result_type)) { - auto potential_refined_type = - GetCompatibleRankedTensorType(operand_type.cast(), - result_type.cast()); + auto potential_refined_type = GetCompatibleRankedTensorType( + mlir::cast(operand_type), + mlir::cast(result_type)); types.push_back(potential_refined_type); } else { - auto region_argument_type = std::get<2>(entry).cast(); + auto region_argument_type = mlir::cast(std::get<2>(entry)); Type element_type = GetElementTypeFromOperand( - operand_type.cast(), region_argument_type); + mlir::cast(operand_type), region_argument_type); Type potential_refined_type = CreateTensorType( region_argument_type.hasRank() ? region_argument_type.getShape() : std::optional>(), @@ -3068,7 +3071,7 @@ LogicalResult ShapeInference::TryToFold(Operation* op) { } } - if (ElementsAttr eattr = attr.dyn_cast_or_null()) { + if (ElementsAttr eattr = mlir::dyn_cast_or_null(attr)) { if (std::get<0>(result).getType() == eattr.getType()) continue; (void)UpdateTypeAndInsertIncompatibleUseCasts(eattr.getType(), @@ -3268,13 +3271,15 @@ FailureOr InferShapeForFunction(func::FuncOp func, for (size_t i = 0; i < func_type.getNumInputs(); ++i) { ArrayRef shape = arg_shapes[i]; Type element_type; - if (auto input_ty = func_type.getInput(i).dyn_cast()) { + if (auto input_ty = + mlir::dyn_cast(func_type.getInput(i))) { if (input_ty.getRank() != shape.size()) { return failure(); } element_type = input_ty.getElementType(); } else { - auto unranked_input_ty = func_type.getInput(i).dyn_cast(); + auto unranked_input_ty = + mlir::dyn_cast(func_type.getInput(i)); if (!unranked_input_ty) { return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc index e565d50660558c..abef8ee04f2212 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc @@ -174,8 +174,8 @@ namespace TFDevice { namespace { bool IsResourceType(Type val_type) { - if (auto tensor_type = val_type.dyn_cast()) { - if (tensor_type.getElementType().isa()) { + if (auto tensor_type = mlir::dyn_cast(val_type)) { + if (mlir::isa(tensor_type.getElementType())) { return true; } } @@ -588,7 +588,7 @@ void GatherOpsForExtraction(mlir::SetVector* operations, if (predecessors) { for (Value operand : op->getOperands()) { // Stop at the block boundary. - if (operand.isa()) continue; + if (mlir::isa(operand)) continue; Operation* predecessor = operand.getDefiningOp(); if (!operations->contains(predecessor) && @@ -1867,7 +1867,7 @@ void EmbeddingPipeliningPass::runOnOperation() { for (int ret_pos = 0; ret_pos < orig_return_op->getNumOperands(); ++ret_pos) { auto operand = orig_return_op->getOperand(ret_pos); auto def_op = operand.getDefiningOp(); - auto result = operand.dyn_cast(); + auto result = mlir::dyn_cast(operand); if (def_op == non_tpu_caller) { loop_arg_update_map_non_tpu[result.getResultNumber()] = ret_pos; } else if (def_op == core_tpu_caller) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc index 3e41762feb16c2..1e7958660fd8c4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_program_key.cc @@ -314,7 +314,7 @@ void CreateReducedLaunchOp(OpBuilder* builder, Block* old_block, // Handle pass through block arguments. for (OpOperand& operand : original_launch_op.GetBody().getTerminator()->getOpOperands()) { - if (operand.get().isa()) { + if (mlir::isa(operand.get())) { original_launch_op.getResult(operand.getOperandNumber()) .replaceAllUsesWith(operand.get()); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc index 577b374a43847d..b224b723cda50d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_sequencing.cc @@ -95,8 +95,8 @@ std::vector GetValueTypes(const InputContainer& input) { } bool IsResourceType(Type val_type) { - if (auto tensor_type = val_type.dyn_cast()) { - if (tensor_type.getElementType().isa()) { + if (auto tensor_type = mlir::dyn_cast(val_type)) { + if (mlir::isa(tensor_type.getElementType())) { return true; } } @@ -139,7 +139,7 @@ void GatherOpsForExtraction(mlir::SetVector* operations, if (predecessors) { for (Value operand : op->getOperands()) { // Stop at the block boundary. - if (operand.isa()) continue; + if (mlir::isa(operand)) continue; Operation* predecessor = operand.getDefiningOp(); if (!operations->contains(predecessor) && diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index fb9848dbaeac47..476a67b496355f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -73,7 +73,7 @@ Type GetSizeVarType(OpBuilder builder) { // forwards the argument. Otherwise, returns -1. int64_t FindAliasedInput(func::FuncOp func, int64_t return_index) { Value return_val = func.front().getTerminator()->getOperand(return_index); - auto maybe_arg = return_val.dyn_cast(); + auto maybe_arg = mlir::dyn_cast(return_val); if (!maybe_arg) return -1; return maybe_arg.getArgNumber(); } @@ -180,8 +180,8 @@ LogicalResult HandleWhileOp( while_op.getLoc(), body.getFunctionType().getInputs(), new_while_operands, while_op->getAttrs()); for (int64_t i = 0; i < while_op.getNumResults(); ++i) { - if (!getElementTypeOrSelf(while_op.getOperand(i).getType()) - .isa()) { + if (!mlir::isa( + getElementTypeOrSelf(while_op.getOperand(i).getType()))) { continue; } int64_t aliased_input = FindAliasedInput(body, i); @@ -233,7 +233,7 @@ LogicalResult HandleIfOp( if_op.getLoc(), then_func.getFunctionType().getResults(), new_if_operands, if_op->getAttrs()); for (auto result : if_op.getResults()) { - if (!getElementTypeOrSelf(result.getType()).isa()) { + if (!mlir::isa(getElementTypeOrSelf(result.getType()))) { continue; } int64_t then_aliased_input = @@ -287,8 +287,8 @@ LogicalResult HandlePartitionedCallOp( const_cast(info.decomposed_callee).getName())); for (int64_t i = 0; i < call.getNumResults(); ++i) { auto result = call.getResult(i); - if (!getElementTypeOrSelf(result.getType()) - .template isa()) { + if (!mlir::isa( + getElementTypeOrSelf(result.getType()))) { continue; } int64_t aliased_input = FindAliasedInput(info.decomposed_callee, i); @@ -328,9 +328,9 @@ LogicalResult HandlePartitionedCallOp( } else { info.decomposed_callee = lowered_callee; for (auto& entry : callee_map) { - info.stack_var_arg_to_size_arg - [entry.getFirst().cast().getArgNumber()] = - entry.getSecond().cast().getArgNumber(); + info.stack_var_arg_to_size_arg[mlir::cast(entry.getFirst()) + .getArgNumber()] = + mlir::cast(entry.getSecond()).getArgNumber(); } if (lowered_callee != callee) { // Add the clone with a new name. @@ -372,7 +372,7 @@ LogicalResult HandleStackV2Op( auto size_var_type = GetSizeVarType(builder); auto var_type = RankedTensorType::get( {}, TF::ResourceType::get( - ArrayRef{buffer.getType().cast()}, + ArrayRef{mlir::cast(buffer.getType())}, stack.getContext())); auto local_var = builder.create( stack.getLoc(), ArrayRef{var_type}, ArrayRef{}); @@ -446,7 +446,8 @@ LogicalResult HandleRegionControlFlowOps( llvm::StringMap* decomposed_partitioned_call_callees) { for (OpOperand& operand : op.getOpOperands()) { - if (getElementTypeOrSelf(operand.get().getType()).isa()) { + if (mlir::isa( + getElementTypeOrSelf(operand.get().getType()))) { return op.emitOpError() << "found unexpected type " << operand.get().getType() << " of operand #" << operand.getOperandNumber() @@ -455,7 +456,7 @@ LogicalResult HandleRegionControlFlowOps( } } for (OpResult result : op.getResults()) { - if (getElementTypeOrSelf(result.getType()).isa()) { + if (mlir::isa(getElementTypeOrSelf(result.getType()))) { return op.emitOpError() << "found unexpected type " << result.getType() << " of result #" << result.getResultNumber() diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc index b18a6a3496649a..267f32daa9f6e6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" @@ -67,7 +68,7 @@ void TensorDeviceCopyConversionPass::runOnOperation() { (isa(def_op))) { return true; } - if (BlockArgument block_arg = arg.dyn_cast()) { + if (BlockArgument block_arg = mlir::dyn_cast(arg)) { // Skip the folding logic if the block argument is not from the function // arguments. This can happen when the argument is from a while loop. if (block_arg.getParentRegion() != &func_op.getRegion()) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index 278ba1f7fdf65b..a9ad31a28461f7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -510,10 +511,10 @@ LogicalResult HandlePartitionedCallOp( } else { info.signature_change = true; for (auto& entry : callee_map) { - auto buffer_arg = entry.getFirst().dyn_cast(); + auto buffer_arg = mlir::dyn_cast(entry.getFirst()); if (!buffer_arg) continue; info.buffer_arg_to_size_arg[buffer_arg.getArgNumber()] = - entry.getSecond().size.cast().getArgNumber(); + mlir::cast(entry.getSecond().size).getArgNumber(); } if (lowered_callee != callee) { // Add the clone with a new name. @@ -549,7 +550,8 @@ LogicalResult GetConstShapeValue(Value shape_value, // return error. LogicalResult GetElementShapeFromResultType( Type type, llvm::SmallVector* shape) { - auto variant_type = getElementTypeOrSelf(type).dyn_cast(); + auto variant_type = + mlir::dyn_cast(getElementTypeOrSelf(type)); if (!variant_type || variant_type.getSubtypes().size() != 1) return failure(); TensorType tensor_type = variant_type.getSubtypes().front(); if (!tensor_type.hasStaticShape()) return failure(); @@ -619,7 +621,7 @@ LogicalResult HandleTensorListFromTensorOp( Value buffer = builder.create( list.getLoc(), ArrayRef{list.getTensor().getType()}, ArrayRef{list.getTensor()}); - auto type = buffer.getType().cast(); + auto type = mlir::cast(buffer.getType()); if (!type.hasStaticShape()) { return list.emitOpError("TensorListFromTensorOp input has unknown shape."); } @@ -733,8 +735,8 @@ LogicalResult HandleTensorListLengthOp( OpBuilder builder(length); if (it->getSecond().fixed) { auto dim = cutil::CreateScalarConst( - length.getInputHandle().getType().cast().getDimSize( - 0), + mlir::cast(length.getInputHandle().getType()) + .getDimSize(0), builder, length.getLoc()); length.getLength().replaceAllUsesWith(dim); } else { @@ -760,7 +762,7 @@ LogicalResult HandleTensorListElementShapeOp( } auto buffer = elem_shape.getInputHandle(); auto result = cutil::GetR1Const( - buffer.getType().cast().getShape().drop_front(), + mlir::cast(buffer.getType()).getShape().drop_front(), OpBuilder(elem_shape), elem_shape.getLoc(), elem_shape.getShapeType().getIntOrFloatBitWidth()); elem_shape.getElementShape().replaceAllUsesWith(result); @@ -792,7 +794,8 @@ LogicalResult HandleTensorListScatterIntoExistingListOp( } auto buffer = scatter.getInputHandle(); OpBuilder builder(scatter); - auto indices_type = scatter.getIndices().getType().cast(); + auto indices_type = + mlir::cast(scatter.getIndices().getType()); if (!indices_type) return scatter.emitOpError("unranked indices shape"); auto shape_type = RankedTensorType::get({2}, builder.getIntegerType(32)); auto shape = builder.create( @@ -874,7 +877,8 @@ LogicalResult DecomposeTensorListOpsInternal( } else if (auto addn = llvm::dyn_cast(&op)) { auto it = buffer_to_size->find(addn.getOperand(0)); if (it != buffer_to_size->end()) { - addn.getSum().setType(addn.getOperand(0).getType().cast()); + addn.getSum().setType( + mlir::cast(addn.getOperand(0).getType())); auto size = it->getSecond(); (*buffer_to_size)[addn.getSum()] = size; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc index 0bc7b47377fa3f..40d9032b499ff6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -70,7 +71,7 @@ class AssetSinkingPass : public impl::AssetSinkingPassBase { SymbolTable symbol_table(module); for (auto initializer : init_op.getInitializers()) { auto func = symbol_table.lookup( - initializer.cast().getValue()); + mlir::cast(initializer).getValue()); RewriteFunction(symbol_table, func); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc index 7f449520030876..8180b4116ef21b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tfg-to-tfe.cc @@ -18,12 +18,13 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -58,7 +59,7 @@ static mlir::LogicalResult FilterTfgSpecificArgResultAttributes( llvm::SmallVector &output_attrs) { for (auto it : llvm::zip( types, array_attr.template getAsRange())) { - if (std::get<0>(it).isa()) continue; + if (mlir::isa(std::get<0>(it))) continue; output_types.push_back(std::get<0>(it)); mlir::NamedAttrList list; @@ -80,7 +81,7 @@ static mlir::LogicalResult ReformatOpAttributes( mlir::tfg::TFGraphDialect::getDeviceAttrKey())) { tensorflow::DeviceNameUtils::ParsedName parsed_name; if (!tensorflow::DeviceNameUtils::ParseFullName( - attr.getValue().cast().getValue().str(), + mlir::cast(attr.getValue()).getValue().str(), &parsed_name)) return mlir::failure(); if (!parsed_name.has_type) { @@ -106,7 +107,7 @@ static mlir::LogicalResult ReformatOpAttributes( static void FilterOutBlockArgControlDep( ValueRange operands, llvm::SmallVectorImpl &filtered) { for (Value value : operands) - if (!value.isa()) filtered.push_back(value); + if (!mlir::isa(value)) filtered.push_back(value); } // Split the tfg.NextIteration into tf_executor::NextIterationSourceOp and @@ -218,7 +219,7 @@ class ConvertGraphFuncOp : public OpConversionPattern { Block &block = graph_func.getBody().front(); for (auto iter = block.args_begin(), end_iter = block.args_end(); iter != end_iter; ++iter) { - if (!iter->getType().isa()) + if (!mlir::isa(iter->getType())) iter->replaceAllUsesWith(func.getBody().getArgument(idx++)); } @@ -412,9 +413,9 @@ class ConvertGeneralOp : public ConversionPattern { for (Value value : operands) { // Because of the property of graph region, the control operands may // not have been converted to tf_executor::ControlType. - if (value.getType().isa() || - value.getType().isa()) { - if (!value.isa()) + if (mlir::isa(value.getType()) || + mlir::isa(value.getType())) { + if (!mlir::isa(value)) island_control_operands.push_back(value); } else { inner_op_operands.push_back(value); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc index d1a244b7f2ec2a..b4a98605a34ac2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_annotate_dynamic_shape_inputs.cc @@ -53,7 +53,7 @@ class TPUAnnotateDynamicShapeInputsPass // Finds op that created a given value. If the value is a BlockArgument, this // returns the owner of the Block. Operation* GetOpOfValue(Value value) { - if (auto block_arg = value.dyn_cast()) + if (auto block_arg = mlir::dyn_cast(value)) return block_arg.getOwner()->getParentOp(); return value.getDefiningOp(); @@ -98,7 +98,7 @@ void TPUAnnotateDynamicShapeInputsPass::runOnOperation() { // Update the marked argument with dynamic shapes. for (int index : dynamic_shape_arg_index) { BlockArgument arg = func.getArgument(index); - auto inputType = arg.getType().dyn_cast(); + auto inputType = mlir::dyn_cast(arg.getType()); // Only rank 1 tensor is supported for now. if (!inputType || inputType.getRank() != 1) continue; auto shape = llvm::to_vector<4>(inputType.getShape()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc index a6f9d7d4c63f01..e2b9c62ee8e6bc 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_device_propagation.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/UseDefLists.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h" @@ -94,8 +95,8 @@ void PopulateDeviceForOpResults( op_to_update = op_to_update->getParentOp(); for (Value result : op_to_update->getResults()) { - if (result.getType().isa()) continue; - if (result.getType().isa()) break; + if (mlir::isa(result.getType())) continue; + if (mlir::isa(result.getType())) break; value_to_device.insert({result, device}); } @@ -118,8 +119,8 @@ llvm::StringRef FindDeviceFromOperands( llvm::StringRef new_device; const bool is_switch = llvm::isa(op); for (Value operand : op.getOperands()) { - if (operand.getType().isa()) continue; - if (operand.getType().isa()) break; + if (mlir::isa(operand.getType())) continue; + if (mlir::isa(operand.getType())) break; if (is_switch && llvm::isa_and_nonnull(operand.getDefiningOp())) @@ -230,7 +231,7 @@ void PropagateDevicesToResults( mlir::Builder builder(func.getOperation()); for (OpOperand& operand : fetch.getOperation()->getOpOperands()) { - if (operand.get().getType().isa()) break; + if (mlir::isa(operand.get().getType())) break; auto it = value_to_device.find(operand.get()); if (it != value_to_device.end()) { auto device_attr = func.getResultAttrOfType( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc index 04b488a38048fd..2281658efc5ed1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_dynamic_layout_pass.cc @@ -86,7 +86,7 @@ bool IsSupportedInputOp( resource_alias_analysis.GetResourceAliases(resource_iterator); auto is_generator = [](Value val) { - if (val.isa()) return true; + if (mlir::isa(val)) return true; Operation* definition = val.getDefiningOp(); return definition->getNumOperands() == 0 && definition->getNumResults() == 1; @@ -99,7 +99,7 @@ bool IsSupportedInputOp( if (!is_generator(alias)) return true; StringAttr device; - if (auto arg = alias.dyn_cast()) { + if (auto arg = mlir::dyn_cast(alias)) { device = func.getArgAttrOfType(arg.getArgNumber(), kFuncDeviceAttr); } else { @@ -186,10 +186,8 @@ bool HandleReplicatedInputs( BuildCopyWithLayout(execute_launch, compile_launch, get_layout, entry.value().get(), &builder); - auto device_list = replicate.getDevices() - .value() - .get(execute_launch.getDevice()) - .cast(); + auto device_list = mlir::cast( + replicate.getDevices().value().get(execute_launch.getDevice())); copy_with_layout->setAttr(kDeviceAttr, device_list.getValue()[entry.index()]); @@ -225,7 +223,7 @@ void HandleCompileAndExecutes( for (const auto& input_and_idx : llvm::enumerate(execute.getArgs())) { Value input = input_and_idx.value(); const int64_t execute_arg_index = input_and_idx.index(); - if (auto block_arg = input.dyn_cast()) { + if (auto block_arg = mlir::dyn_cast(input)) { // For a block argument, consider transforms only when it is a // replicated input (defining ops will be outside the replicate node). if (maybe_replicate != block_arg.getParentRegion()->getParentOp() || diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc index fdea45957eb7d8..b2a3b81f63a1a9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_host_computation_expansion.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -41,7 +42,7 @@ bool HasOutsideCompilationAttribute(Operation* op) { // Finds op that created a given value. If the value is a BlockArgument, this // returns the owner of the Block. Operation* GetOpOfValue(Value value) { - if (auto block_arg = value.dyn_cast()) + if (auto block_arg = mlir::dyn_cast(value)) return block_arg.getOwner()->getParentOp(); return value.getDefiningOp(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc index a2232f9f33bf2a..08165fb1435ff2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_partitioned_op_conversion.cc @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -61,12 +62,12 @@ LogicalResult ReplacePartitionedOp(IntegerAttr num_cores_per_replica, T op) { } auto element_type = getElementTypeOrSelf(first_operand_type); - if (element_type.isa()) { + if (mlir::isa(element_type)) { first_operand_type = - element_type.cast().getSubtypes().front(); + mlir::cast(element_type).getSubtypes().front(); } - auto tensor_type = first_operand_type.dyn_cast_or_null(); + auto tensor_type = mlir::dyn_cast_or_null(first_operand_type); if (!(tensor_type && tensor_type.hasRank())) { return op->emitError() << "cannot convert op with unranked or non-tensor input type " diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index fb2588f50631e8..9b2beeb26dfb30 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -176,7 +176,7 @@ LogicalResult VerifySharding(mlir::Type type, // verify shardings that actually break a tensor apart. return success(); } - if (RankedTensorType ranked_type = type.dyn_cast()) { + if (RankedTensorType ranked_type = mlir::dyn_cast(type)) { const int64_t tensor_rank = ranked_type.getRank(); int tile_assignment_rank = sharding->tile_assignment_dimensions_size(); @@ -461,13 +461,13 @@ std::optional GetXlaShardingFromRetval( llvm::dyn_cast(call_op.resolveCallable()); if (!func) continue; value_to_visit = func.front().getTerminator()->getOperand( - value_to_visit.cast().getResultNumber()); + mlir::cast(value_to_visit).getResultNumber()); values_to_visit.push_back(value_to_visit); continue; } if (auto while_op = llvm::dyn_cast(def)) { - if (auto op_result = value_to_visit.cast()) { + if (auto op_result = mlir::cast(value_to_visit)) { int result_idx = op_result.getResultNumber(); if (auto yield_op = llvm::dyn_cast( while_op.getBody().front().getTerminator())) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc index c6ce1428bfb3e4..ef16273e9eea45 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_space_to_depth_pass.cc @@ -90,7 +90,7 @@ LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef new_shape) { auto transform_result_type = RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input)); cast_input.setType(transform_result_type); - auto block_arg = cast_input.dyn_cast(); + auto block_arg = mlir::dyn_cast(cast_input); auto cast_op_input = dyn_cast_or_null(cast_input.getDefiningOp()); while (block_arg || cast_op_input) { if (block_arg) { @@ -105,7 +105,7 @@ LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef new_shape) { RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input)); cast_input.setType(transform_result_type); // Update block arg and cast_op_input. - block_arg = cast_input.dyn_cast(); + block_arg = mlir::dyn_cast(cast_input); cast_op_input = dyn_cast_or_null(cast_input.getDefiningOp()); } } @@ -114,7 +114,7 @@ LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef new_shape) { // Handles padding before convolution for space to depth transform. LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) { - auto ranked_type = op.getInput().getType().dyn_cast(); + auto ranked_type = mlir::dyn_cast(op.getInput().getType()); if (!ranked_type) return failure(); auto pad_input_shape = ranked_type.getShape(); Location loc = op.getLoc(); @@ -164,7 +164,7 @@ void HandleConv2DStride(TF::Conv2DOp conv2d) { // Transforms input shape for the first convolution. void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) { auto input = conv2d.getInput(); - auto input_shape = input.getType().cast().getShape(); + auto input_shape = mlir::cast(input.getType()).getShape(); SmallVector transform_shape = { input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size, input_shape[3] * block_size * block_size}; @@ -228,7 +228,7 @@ void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) { OpBuilder builder(conv2d); builder.setInsertionPoint(conv2d); // Book keeping filter information. - auto filter_shape = filter.getType().cast().getShape(); + auto filter_shape = mlir::cast(filter.getType()).getShape(); int64_t height = filter_shape[0]; int64_t width = filter_shape[1]; int64_t channel = filter_shape[2]; @@ -422,7 +422,7 @@ bool HandleHostReplicatedInputs(int64_t index, } for (auto entry : llvm::enumerate(inputs)) { Value input = entry.value().get(); - auto ranked_type = input.getType().dyn_cast(); + auto ranked_type = mlir::dyn_cast(input.getType()); if (!ranked_type) return false; auto input_shape = ranked_type.getShape(); auto space_to_depth = @@ -442,7 +442,7 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size, llvm::SmallVector transform_input_indices; for (const auto& input : llvm::enumerate(cluster_func.getOperands())) { - if (auto block_arg = input.value().dyn_cast()) { + if (auto block_arg = mlir::dyn_cast(input.value())) { if (block_arg.getArgNumber() != arg_num) continue; // For a block argument, consider transforms only when it is a replicated // input (defining ops will be outside the replicate node). @@ -462,7 +462,8 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size, continue; } if (!IsSupportedHostInputOp(input_op)) continue; - auto ranked_type = input.value().getType().dyn_cast(); + auto ranked_type = + mlir::dyn_cast(input.value().getType()); if (!ranked_type) continue; auto input_shape = ranked_type.getShape(); HandleHostInput(input.value(), input.index(), cluster_func, block_size, @@ -473,7 +474,7 @@ void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size, // Checks if input shape of convolution is good for space to depth transform. bool Conv2DInputShapeCanTransform(Value input) { - auto ranked_type = input.getType().dyn_cast(); + auto ranked_type = mlir::dyn_cast(input.getType()); if (!ranked_type) return false; auto input_shape = ranked_type.getShape(); int32_t batch_size = input_shape[0]; @@ -486,7 +487,7 @@ bool Conv2DInputShapeCanTransform(Value input) { // Get block argument id and number of users for the input arg. std::optional GetBlockArgNum(Value arg) { - if (auto block_arg = arg.dyn_cast()) { + if (auto block_arg = mlir::dyn_cast(arg)) { if (!Conv2DInputShapeCanTransform(arg)) return std::nullopt; unsigned num_users = std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end()); @@ -540,9 +541,9 @@ std::optional GetConv2DInputArgNum(TF::Conv2DOp conv2d) { void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { // Check if input and filter type are RankedTensorType. auto input_tensor_type = - conv2d.getInput().getType().dyn_cast(); + mlir::dyn_cast(conv2d.getInput().getType()); auto filter_tensor_type = - conv2d.getFilter().getType().dyn_cast(); + mlir::dyn_cast(conv2d.getFilter().getType()); if (!input_tensor_type || !filter_tensor_type) return; // Book keeping filter shape for padding and backprop filter rewrite. auto filter_shape = filter_tensor_type.getShape(); @@ -550,7 +551,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { filter_shape.end()); // Handles input. auto conv2d_input = conv2d.getInput(); - if (auto block_arg = conv2d_input.dyn_cast()) { + if (auto block_arg = mlir::dyn_cast(conv2d_input)) { // Change on device function type/shape. HandleFuncOp(block_arg.getOwner()->getParentOp()); } @@ -559,7 +560,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { // Rewrite pad_op before Convolutioin. if (failed(HandlePad(pad_op, filter_shape[0], block_size))) return; auto pad_input = pad_op.getInput(); - if (auto block_arg = pad_input.dyn_cast()) { + if (auto block_arg = mlir::dyn_cast(pad_input)) { // Change on device function type/shape. HandleFuncOp(block_arg.getOwner()->getParentOp()); } @@ -573,7 +574,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { // Book keeping new filter shape for backprop filter rewrite. // Filter shape is defined in HandleConv2DFilter, thus it is RankedTensorType. filter_shape = - conv2d.getFilter().getType().cast().getShape(); + mlir::cast(conv2d.getFilter().getType()).getShape(); SmallVector new_filter_shape(filter_shape.begin(), filter_shape.end()); @@ -593,7 +594,7 @@ void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) { int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) { SmallVector strides(4, 1); for (int i = 0; i < 3; ++i) { - strides[i] = conv2d.getStrides()[i].cast().getInt(); + strides[i] = mlir::cast(conv2d.getStrides()[i]).getInt(); } // Space to depth only supports striding at spatial dimension. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc index 4dc9daa6c705ee..bfc5437b140e32 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_validate_inputs.cc @@ -372,7 +372,8 @@ bool CheckOpsClusterIO(Operation* op, MetadataMap& metadata_map) { bool TypeMustBeNonXLA(const Type& type) { const Type elem = getElementTypeOrSelf(type); - return !elem.isa() && !tensorflow::TypeValidForXLA(type); + return !mlir::isa(elem) && + !tensorflow::TypeValidForXLA(type); } // Check if the op cannot be XLA compiled. If the op does not satisfy this diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc index abdd1a83d516eb..ff8ac1ad7cacd1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.cc @@ -96,7 +96,7 @@ TF::ReshapeOp ConvertTFBatchMatMulOp::createReshapeOp( template std::vector ConvertTFBatchMatMulOp::sliceInput( Value value, int batch_size, Location loc, PatternRewriter& rewriter) { - RankedTensorType tensorType = value.getType().cast(); + RankedTensorType tensorType = mlir::cast(value.getType()); Type element_type = tensorType.getElementType(); int rank = tensorType.getShape().size(); @@ -150,17 +150,17 @@ LogicalResult ConvertTFBatchMatMulOp::matchAndRewrite( Value input_lhs = op.getX(); Value input_rhs = op.getY(); - if (!input_lhs.getType().isa()) { + if (!mlir::isa(input_lhs.getType())) { // LHS must be a ranked tensor type return failure(); } - if (!input_rhs.getType().isa()) { + if (!mlir::isa(input_rhs.getType())) { // RHS must be a ranked tensor type return failure(); } - auto lhs_type = input_lhs.getType().cast(); - auto rhs_type = input_rhs.getType().cast(); + auto lhs_type = mlir::cast(input_lhs.getType()); + auto rhs_type = mlir::cast(input_rhs.getType()); // Skip int8 x int8 => int32. if (lhs_type.getElementType().isInteger(8) && diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc index 20dcdb8b034c97..9237ff8d5b69dd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_deserialization.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo // IWYU pragma: keep #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep @@ -156,8 +157,8 @@ LogicalResult SymbolizeCustomCallCalledIndex( return WalkResult::interrupt(); } - auto called_index_attr = backend_config.get(kCalledIndexAttrName) - .dyn_cast_or_null(); + auto called_index_attr = mlir::dyn_cast_or_null( + backend_config.get(kCalledIndexAttrName)); if (!called_index_attr) { op->emitOpError() << "is missing attribute '" << kCalledIndexAttrName << "'"; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_serialization.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_serialization.cc index a75bf4c75d8033..6ab5da6bdb2e3c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_serialization.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_call_module_serialization.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "stablehlo/api/PortableApi.h" // from @stablehlo #include "stablehlo/dialect/Serialization.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep @@ -66,8 +67,8 @@ FailureOr DesymbolizeCustomCallCalledIndex(ModuleOp module) { << "'"; return WalkResult::interrupt(); } - auto called_func = backend_config.get(kCalledFuncAttrName) - .dyn_cast_or_null(); + auto called_func = mlir::dyn_cast_or_null( + backend_config.get(kCalledFuncAttrName)); if (!called_func) { op->emitOpError() << "is missing attribute '" << kCalledFuncAttrName << "'"; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc index 1992f43a951184..8ce264b47b57d4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/xla_rewrite.cc @@ -18,6 +18,7 @@ limitations under the License. #include +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -44,8 +45,7 @@ void MoveResourceArgsToEnd(func::FuncOp callee) { // Copy the resource-type parameters to the end. for (unsigned i = 0; i < num_params; ++i) { BlockArgument param = callee.getArgument(i); - if (getElementTypeOrSelf(param.getType()) - .template isa()) { + if (mlir::isa(getElementTypeOrSelf(param.getType()))) { removed_params.set(i); callee.getBody().addArgument(param.getType(), param.getLoc()); param.replaceAllUsesWith(callee.getArguments().back()); @@ -65,7 +65,7 @@ void RewriteCall(tf_device::ClusterFuncOp cluster_func_op, SymbolTable &symtab, llvm::SmallVector non_resource_args, resource_args; bool has_resources = false, in_order = true; for (const Value &arg : cluster_func_op.getOperands()) { - if (!getElementTypeOrSelf(arg.getType()).template isa()) { + if (!mlir::isa(getElementTypeOrSelf(arg.getType()))) { non_resource_args.push_back(arg); if (has_resources) in_order = false; } else { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/BUILD b/tensorflow/compiler/mlir/tensorflow/translate/BUILD index 59d7cfd7081106..6e3a8377b3d81e 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/translate/BUILD @@ -107,6 +107,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_xla//xla:status_macros", ], ) @@ -209,6 +210,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 523048cd7cd582..893b5b3ada6b94 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -42,6 +42,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -205,9 +206,9 @@ StatusOr> Exporter::GetArgumentNode( node_def->set_op(FunctionLibraryDefinition::kArgOp); - mlir::TensorType arg_type = arg.getType().cast(); + mlir::TensorType arg_type = mlir::cast(arg.getType()); if (auto resource_type = - arg_type.getElementType().dyn_cast()) { + mlir::dyn_cast(arg_type.getElementType())) { llvm::ArrayRef subtypes = resource_type.getSubtypes(); if (!subtypes.empty()) { AttrValue handle_dtypes_attr; @@ -266,7 +267,8 @@ StatusOr> Exporter::GetReturnNode( node_def->set_op(FunctionLibraryDefinition::kRetOp); DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType( - operand.getType().cast().getElementType(), &dtype)); + mlir::cast(operand.getType()).getElementType(), + &dtype)); AttrValue type_attr; type_attr.set_type(dtype); (*node_def->mutable_attr())["T"] = type_attr; @@ -290,7 +292,7 @@ StatusOr> Exporter::GetReturnNode( Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node, unsigned dst_index) { - if (auto input_result = src.dyn_cast()) { + if (auto input_result = mlir::dyn_cast(src)) { auto* input_inst = GetIslandInnerOpOrSelf(input_result.getOwner()); // Replaces the input node with NextIteration sink if it is a NextIteration // source. @@ -302,7 +304,7 @@ Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node, auto node_it = nodes_.find(input_inst); TF_RET_CHECK(node_it != nodes_.end()) << "Use of OpResult encountered before def!"; - if (input_result.getType().isa()) { + if (mlir::isa(input_result.getType())) { graph_->AddControlEdge(node_it->second, dst_node, /*allow_duplicates=*/true); } else { @@ -312,7 +314,7 @@ Status Exporter::AddEdgeBetweenNodes(Value src, Node* dst_node, return OkStatus(); } - auto input_arg = src.cast(); + auto input_arg = mlir::cast(src); auto input_node_it = args_.find(input_arg); TF_RET_CHECK(input_node_it != args_.end()) << "Use of BlockArgument encounted before def!"; @@ -327,7 +329,7 @@ Status Exporter::AddEdge(Operation* inst) { if (auto fetch = llvm::dyn_cast(inst)) { for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) { Value operand = operand_and_idx.value(); - if (operand.getType().isa()) break; + if (mlir::isa(operand.getType())) break; auto* dst_node = returns_[fetch][operand_and_idx.index()]; TF_RETURN_IF_ERROR(AddEdgeBetweenNodes(operand, dst_node, 0)); @@ -447,7 +449,8 @@ Status Exporter::AddFetchNode(FuncOp function, mlir::tf_executor::FetchOp fetch, llvm::ArrayRef names) { auto& return_nodes = returns_[fetch]; for (auto operand_and_idx : llvm::enumerate(fetch.getOperands())) { - if (operand_and_idx.value().getType().isa()) + if (mlir::isa( + operand_and_idx.value().getType())) break; TF_ASSIGN_OR_RETURN( @@ -467,7 +470,7 @@ Status Exporter::GetControlRetNodes( mlir::tf_executor::FetchOp fetch, absl::flat_hash_set* control_ret_nodes) { for (Value fetch_operand : fetch.getOperands()) { - if (fetch_operand.getType().isa()) { + if (mlir::isa(fetch_operand.getType())) { Operation* defining_op = GetIslandInnerOpOrSelf(fetch_operand.getDefiningOp()); auto node_it = nodes_.find(defining_op); @@ -509,14 +512,16 @@ StatusOr> Exporter::Convert( auto dict_attr = function->getAttrOfType(kEntryFuncAttr); if (dict_attr) { - TF_RET_CHECK(dict_attr.get("inputs").isa()) + TF_RET_CHECK(mlir::isa(dict_attr.get("inputs"))) << "inputs missing in entry function attribute"; - TF_RET_CHECK(dict_attr.get("outputs").isa()) + TF_RET_CHECK(mlir::isa(dict_attr.get("outputs"))) << "outputs missing in entry function attribute"; - dict_attr.get("inputs").cast().getValue().split( - input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); - dict_attr.get("outputs").cast().getValue().split( - output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); + mlir::cast(dict_attr.get("inputs")) + .getValue() + .split(input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); + mlir::cast(dict_attr.get("outputs")) + .getValue() + .split(output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); } auto graph = std::make_unique(OpRegistry::Global()); @@ -582,7 +587,7 @@ StatusOr> Exporter::Convert( int index = it.index(); auto arg = it.value(); mlir::Type type = arg.getType(); - if (!type.isa()) { + if (!mlir::isa(type)) { return errors::InvalidArgument( "FuncOps arguments must have tensor types. Found ", mlir::debugString(type), " in function ", function.getName().str()); @@ -607,8 +612,8 @@ StatusOr> Exporter::Convert( // Adds nodes for operations. for (Operation& inst : graph_op.GetBody()) { for (auto type : inst.getResultTypes()) - if (!type.isa()) + if (!mlir::isa(type)) return errors::InvalidArgument( "Values must be of tensor type, TensorFlow control type, or " "TensorFlow token type. Found ", diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc index 6ce83519a0fe6d..1734755ee1fa90 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/export_utils.h" @@ -183,7 +184,7 @@ Status PopulateDerivedAttributes(mlir::Operation* inst, llvm::StringRef name, auto values = inst->getResults(); auto begin = values.begin(); auto end = values.begin(); - while (end != values.end() && (*end).getType().isa()) + while (end != values.end() && mlir::isa((*end).getType())) end++; if (begin != end) { mlir::TF::ResultShapeRange output_shapes = { diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 42b059cbd0a527..7809c2bd69fab6 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -68,6 +68,7 @@ limitations under the License. #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" @@ -1228,7 +1229,7 @@ absl::StatusOr ImporterBase::InferOutputType( TF_ASSIGN_OR_RETURN( auto etype, ConvertToMlirTensorType(shape_proto, dtype, &builder)); return mlir::UnrankedTensorType::get(mlir::TF::ResourceType::get( - {etype.cast()}, builder.getContext())); + {mlir::cast(etype)}, builder.getContext())); } else { return mlir::UnrankedTensorType::get( mlir::TF::ResourceType::get(builder.getContext())); @@ -2000,7 +2001,7 @@ mlir::Operation* ImporterBase::CreateOperation( record_resource = [&](mlir::Type type) { type.walk([&](mlir::Type t) { if (resource) return mlir::WalkResult::interrupt(); - if (type.isa()) { + if (mlir::isa(type)) { resource = true; return mlir::WalkResult::interrupt(); } @@ -3187,10 +3188,10 @@ void StructuredValueLinearizer::RecursivelyFindLeaves( << " at index path: "; for (auto path_element : current_index_path_) { os << "."; - if (auto integer = path_element.dyn_cast()) { + if (auto integer = mlir::dyn_cast(path_element)) { os << integer.getValue(); } else { - auto str = path_element.cast(); + auto str = mlir::cast(path_element); os << str.getValue(); } } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index b0759da88e4ced..ed604a9290a4a4 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" @@ -263,7 +264,7 @@ GraphdefToSplattedMlirTranslateFunction( if (auto attr = inst.getAttrOfType(attr_id)) { mlir::Attribute rand_val; mlir::Type element_type = attr.getShapedType().getElementType(); - if (element_type.isa()) { + if (mlir::isa(element_type)) { rand_val = mlir::IntegerAttr::get(element_type, std::rand()); } else if (element_type.isF16() || element_type.isF32() || element_type.isF64()) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h index 5a99806d4295f3..0771b529465a94 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/tf2xla/tf2xla_defs.h" namespace mlir { @@ -167,7 +168,7 @@ class IdentityNOp; // as an attribute. template bool GetValueAsConstant(Value val, AttrT &attr) { - while (auto result = val.dyn_cast()) { + while (auto result = mlir::dyn_cast(val)) { Operation *op = result.getOwner(); if (!isa(op) && !isa(op)) break; val = op->getOperand(result.getResultNumber()); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.cc index fd3c00a3873e5c..030b8ae7575a40 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/call_graph_util.cc @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" namespace mlir { @@ -54,7 +55,7 @@ llvm::SmallVector GetEntryFunctions(ModuleOp module) { LogicalResult GetCallees(SymbolUserOpInterface op, SymbolTable &symtab, llvm::SmallVector &callees) { for (auto attr : op->getAttrs()) { - auto sym = attr.getValue().dyn_cast(); + auto sym = mlir::dyn_cast(attr.getValue()); if (!sym) continue; auto callee = symtab.lookup(sym.getRootReference()); if (!callee) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index 10e882192cfdf3..d1f68c1cd9ddd6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -242,12 +243,12 @@ void ConvertToTensorShapeProto(ArrayRef shape, } PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) { - if (type.isa()) { + if (mlir::isa(type)) { // An empty PartialTensorShape indicates an unranked tensor. return PartialTensorShape(); } - if (auto tensor_type = type.dyn_cast()) { + if (auto tensor_type = mlir::dyn_cast(type)) { TensorShapeProto tensor_shape_proto; ConvertToTensorShapeProto(tensor_type.getShape(), &tensor_shape_proto); return PartialTensorShape(tensor_shape_proto); @@ -259,11 +260,11 @@ PartialTensorShape ConvertTypeToTensorShape(const mlir::Type& type) { } mlir::TF::ShapeAttr ConvertTypeToTensorShapeAttr(const mlir::Type& type) { - if (type.isa()) { + if (mlir::isa(type)) { return mlir::TF::ShapeAttr::get(type.getContext(), std::nullopt); } - if (auto tensor_type = type.dyn_cast()) { + if (auto tensor_type = mlir::dyn_cast(type)) { return mlir::TF::ShapeAttr::get(type.getContext(), tensor_type.getShape()); } @@ -427,10 +428,10 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { output->set_dtype(output_dtype); ConvertToTensorShapeProto(shape, output->mutable_tensor_shape()); - if (auto tensor_attr = attr.dyn_cast()) + if (auto tensor_attr = mlir::dyn_cast(attr)) return ConvertTensorProtoAttr(tensor_attr, output); - auto dense_attr = attr.dyn_cast(); + auto dense_attr = mlir::dyn_cast(attr); if (!dense_attr) return errors::InvalidArgument("Unsupported elements attr"); switch (output_dtype) { @@ -496,7 +497,7 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { output->mutable_tensor_content()); break; case DT_STRING: - ConvertStringElementsAttr(dense_attr.cast(), + ConvertStringElementsAttr(mlir::cast(dense_attr), output->mutable_string_val()); break; case DT_UINT8: diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index f3c51f88fc7630..3feed8904fab0e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "xla/test.h" @@ -97,8 +98,8 @@ TEST(ConvertTypeToTensorTypeTest, ConvertStringTensor) { ASSERT_TRUE(value_or_status.ok()); auto attr = value_or_status.value(); - EXPECT_TRUE(attr.isa()); - auto string_attr = attr.cast(); + EXPECT_TRUE(mlir::isa(attr)); + auto string_attr = mlir::cast(attr); auto string_values = string_attr.getRawStringData(); ASSERT_EQ(string_values.size(), 4); EXPECT_EQ(string_values[0], mlir::StringRef("one")); @@ -191,7 +192,7 @@ TEST_F(ConvertTensorTest, Simple) { } bool IsSplat(mlir::ElementsAttr attr) { - return attr.cast().isSplat(); + return mlir::cast(attr).isSplat(); } TEST(ConvertTensorProtoTest, SplatTensor) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index 880501c3e89554..cc7179ffb3968c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/core/framework/types.h" @@ -124,7 +125,7 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { } else if (type.isFloat8E5M2()) { *dtype = DT_FLOAT8_E5M2; return OkStatus(); - } else if (auto itype = type.dyn_cast()) { + } else if (auto itype = mlir::dyn_cast(type)) { switch (itype.getWidth()) { case 1: *dtype = DT_BOOL; @@ -148,7 +149,7 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { return errors::Unimplemented( absl::StrCat("Converting ", debugString(type), " to DataType")); } - } else if (auto complex_type = type.dyn_cast()) { + } else if (auto complex_type = mlir::dyn_cast(type)) { auto etype = complex_type.getElementType(); if (etype.isF32()) { *dtype = DT_COMPLEX64; @@ -174,7 +175,7 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { } Status ConvertToDataType(Type type, DataType* dtype) { - if (auto stype = type.dyn_cast()) { + if (auto stype = mlir::dyn_cast(type)) { TF_RETURN_IF_ERROR( ConvertScalarTypeToDataType(stype.getElementType(), dtype)); } else { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc index 51db1be0820761..d9249d472b334c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" @@ -67,7 +68,7 @@ mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, for (const auto& kv : llvm::enumerate(array_attr)) { const int idx = kv.index(); - auto string_attr = kv.value().dyn_cast(); + auto string_attr = mlir::dyn_cast(kv.value()); if (!string_attr) return op->emitOpError(llvm::formatv( "bad '{0}' attribute at index {1}, not a string", kDevicesAttr, idx)); @@ -100,7 +101,7 @@ mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, llvm::formatv("bad '{0}' attribute, '{1}', not a valid device", kDevicesAttr, name.strref())); - if (auto gpu_metadata = attr.dyn_cast()) { + if (auto gpu_metadata = mlir::dyn_cast(attr)) { devices->AddGpuDevice(device, gpu_metadata); } else { devices->AddDevice(device); @@ -144,10 +145,11 @@ mlir::LogicalResult GetDevicesFromOp(mlir::Operation* op, auto devices_attr = op->getAttr(kDevicesAttr); if (!devices_attr) return mlir::success(); - if (auto array_attr = devices_attr.dyn_cast()) { + if (auto array_attr = mlir::dyn_cast(devices_attr)) { return GetDevicesFromOp(op, array_attr, devices); - } else if (auto dict_attr = devices_attr.dyn_cast()) { + } else if (auto dict_attr = + mlir::dyn_cast(devices_attr)) { return GetDevicesFromOp(op, dict_attr, devices); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc index 326dbbb4781602..f089ec111991e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_set.h" @@ -87,18 +88,18 @@ TEST(DeviceUtilTest, AddDeviceToOp) { ASSERT_EQ(devices_attr.size(), 3); // CPU device added with an empty metadata. - auto device_meta_0 = devices_attr.get(cpu0).dyn_cast(); + auto device_meta_0 = mlir::dyn_cast(devices_attr.get(cpu0)); ASSERT_NE(device_meta_0, nullptr); // GPU device successfully parsed compute capability from description. auto device_meta_1 = - devices_attr.get(gpu0).dyn_cast(); + mlir::dyn_cast(devices_attr.get(gpu0)); ASSERT_NE(device_meta_1, nullptr); ASSERT_EQ(device_meta_1.getCcMajor(), 7); ASSERT_EQ(device_meta_1.getCcMinor(), 0); // If description is empty GPU devices added with an empty metadata. - auto device_meta_2 = devices_attr.get(gpu1).dyn_cast(); + auto device_meta_2 = mlir::dyn_cast(devices_attr.get(gpu1)); ASSERT_NE(device_meta_2, nullptr); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc index 6a66067920fdcb..f0dd8f1c748a25 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/util/managed_stack_trace.h" @@ -33,7 +34,7 @@ StatusScopedDiagnosticHandler::StatusScopedDiagnosticHandler( this->shouldShowLocFn = [](Location loc) -> bool { // For a Location to be surfaced in the stack, it must evaluate to true. // For any Location that is a FileLineColLoc: - if (FileLineColLoc fileLoc = loc.dyn_cast()) { + if (FileLineColLoc fileLoc = mlir::dyn_cast(loc)) { return !tensorflow::IsInternalFrameForFilename( fileLoc.getFilename().str()); } else { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index f01a3f0e09d19b..cd6b2e0f8fa8b3 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -81,9 +82,9 @@ Status ConvertLocation(mlir::Location inst_loc, llvm::StringRef node_name, NodeDef::ExperimentalDebugInfo* debug_info) { mlir::Location unwrapped_inst_loc = GetLocationWithoutOpType(inst_loc); - if (auto call_site = unwrapped_inst_loc.dyn_cast()) { - if (auto name_loc = GetLocationWithoutOpType(call_site.getCallee()) - .dyn_cast()) { + if (auto call_site = mlir::dyn_cast(unwrapped_inst_loc)) { + if (auto name_loc = mlir::dyn_cast( + GetLocationWithoutOpType(call_site.getCallee()))) { llvm::StringRef original_node_name, original_func_name; std::tie(original_node_name, original_func_name) = name_loc.getName().strref().split('@'); @@ -96,7 +97,7 @@ Status ConvertLocation(mlir::Location inst_loc, llvm::StringRef node_name, debug_info->add_original_func_names(original_func_name.str()); } } - } else if (auto fused = unwrapped_inst_loc.dyn_cast()) { + } else if (auto fused = mlir::dyn_cast(unwrapped_inst_loc)) { auto locations = fused.getLocations(); if (locations.size() <= 1) return errors::InvalidArgument("expected experimental debuf info."); @@ -145,8 +146,8 @@ Status ConvertAttribute(const mlir::FlatSymbolRefAttr& attr, AttrValue* value) { Status ConvertAttribute(const mlir::TF::FuncAttr& attr, bool remove_ref_type, AttrValue* value) { - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.getName().cast(), value)); + TF_RETURN_IF_ERROR(ConvertAttribute( + mlir::cast(attr.getName()), value)); TF_RETURN_IF_ERROR(ConvertAttributes(attr.getAttrs().getValue(), /*attrs_to_ignore=*/{}, remove_ref_type, value->mutable_func()->mutable_attr())); @@ -199,13 +200,13 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type, AttrValue* value) { auto* list = value->mutable_list(); for (mlir::Attribute a : attr.getValue()) { - if (auto attr = a.dyn_cast()) { + if (auto attr = mlir::dyn_cast(a)) { list->add_b(attr.getValue()); - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { list->add_i(attr.getInt()); - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { list->add_f(attr.getValueAsDouble()); - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { AttrValue nested_value; TF_RETURN_IF_ERROR(ConvertAttribute(attr, &nested_value)); switch (nested_value.value_case()) { @@ -221,32 +222,32 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type, default: return errors::Unimplemented("Unhandled nested attribute!"); } - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { TensorProto tensor; TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor)); *list->add_tensor() = tensor; - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { AttrValue attr_val; TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val)); *list->add_func() = attr_val.func(); - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { AttrValue attr_val; // For type attributes, we only propagate the element type. mlir::Type elt_type = attr.getValue(); - if (auto shaped_type = elt_type.dyn_cast()) { + if (auto shaped_type = mlir::dyn_cast(elt_type)) { elt_type = shaped_type.getElementType(); } TF_RETURN_IF_ERROR( ConvertAttribute(elt_type, remove_ref_type, &attr_val)); list->add_type(attr_val.type()); - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { AttrValue attr_val; TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val)); *list->add_shape() = attr_val.shape(); - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { std::vector vals; for (mlir::Attribute a : attr.getValue()) { - auto i = a.dyn_cast(); + auto i = mlir::dyn_cast(a); if (!i) return errors::Unimplemented( "Expected 64-bit integer array attributes!"); @@ -274,21 +275,21 @@ Status ConvertAttribute(const mlir::ArrayAttr& attr, bool remove_ref_type, static bool IsRefTypeControlOp(mlir::Operation* op) { if (auto next_iter_sink = llvm::dyn_cast(op)) - return mlir::getElementTypeOrSelf(next_iter_sink.getInput().getType()) - .isa(); + return mlir::isa( + mlir::getElementTypeOrSelf(next_iter_sink.getInput().getType())); auto op_name_or_status = GetTensorFlowOpName(op->getName().getStringRef()); if (!op_name_or_status.ok()) return false; auto op_name = std::move(op_name_or_status).value(); if (op_name.equals("NextIteration")) - return mlir::getElementTypeOrSelf(op->getOperand(0).getType()) - .isa(); + return mlir::isa( + mlir::getElementTypeOrSelf(op->getOperand(0).getType())); if (op_name.equals("Enter") || op_name.equals("Exit") || op_name.equals("Switch") || op_name.equals("Merge")) { - return getElementTypeOrSelf(op->getResult(0).getType()) - .isa(); + return mlir::isa( + getElementTypeOrSelf(op->getResult(0).getType())); } return false; } @@ -393,18 +394,18 @@ Status ConvertAttributes( name = mangling_util::DemangleAttributeName(name); } AttrValue value; - if (auto symbol_ref = attr.dyn_cast()) { - TF_RETURN_IF_ERROR( - ConvertAttribute(symbol_ref.cast(), &value)); + if (auto symbol_ref = mlir::dyn_cast(attr)) { + TF_RETURN_IF_ERROR(ConvertAttribute( + mlir::cast(symbol_ref), &value)); func_call_attrs[string(name)] = std::move(value); continue; } - if (auto func_attr = attr.dyn_cast()) { + if (auto func_attr = mlir::dyn_cast(attr)) { TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, remove_ref_type, &value)); func_call_attrs[string(name)] = std::move(value); continue; } - if (attr.isa()) { + if (mlir::isa(attr)) { // AffineMapAttr is not implemented. return errors::Unimplemented("AffineMap attribute (needed for '", name_strref, "') unimplemented"); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/location_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/location_utils.cc index 2a6ff2921a4ad5..afaa78640af3d9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/location_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/location_utils.cc @@ -17,15 +17,16 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace tensorflow { mlir::Location GetLocationWithoutOpType(mlir::Location loc) { - if (auto fused_loc = loc.dyn_cast()) { + if (auto fused_loc = mlir::dyn_cast(loc)) { auto locations = fused_loc.getLocations(); if (!locations.empty()) { // Skip locations for propagating op_type metadata. - if (auto name_loc = locations[0].dyn_cast()) { + if (auto name_loc = mlir::dyn_cast(locations[0])) { if (name_loc.getName().strref().ends_with(":")) { if (locations.size() == 2) return locations[1]; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc index 2895ebdc9c6424..9e8db314f51b0d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/session_utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/utils/string_container_utils.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/device.h" @@ -32,7 +33,7 @@ std::string GetVariableName(TF::VarHandleOp var_handle_op) { // In some cases the shared_name attribute doesn't have the same // tensor name in the model, so we first try to use the location // then fallback to shared_name attribute. - if (auto loc = var_handle_op->getLoc().dyn_cast()) + if (auto loc = mlir::dyn_cast(var_handle_op->getLoc())) return loc.getName().str(); return var_handle_op.getSharedName().str(); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc b/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc index 549b665f044314..6ab4aa64a89070 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/stablehlo_custom_call.cc @@ -17,6 +17,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project namespace mlir { @@ -63,7 +64,7 @@ FailureOr GetTfFuncCustomCallFuncName( return failure(); } - if (auto attr = f.dyn_cast()) { + if (auto attr = mlir::dyn_cast(f)) { return attr; } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc index 97f1093fe3d56b..5a29bae67afe01 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo @@ -396,11 +397,11 @@ SerializedMlirStringAttrToMlirModuleTranslate(llvm::StringRef input, // an output parameter is provided for returning the number of chars read. size_t numRead; mlir::Attribute attr = mlir::parseAttribute(input, context, {}, &numRead); - if (!attr || !attr.isa()) { + if (!attr || !mlir::isa(attr)) { LOG(ERROR) << "Input is not parsable as a MLIR StringAttr."; return nullptr; } - auto str_attr = attr.cast(); + auto str_attr = mlir::cast(attr); mlir::DialectRegistry registry; RegisterMlirInputDialects(registry); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc index c6ff5f5c93c6ef..690263319cf51e 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.cc @@ -36,6 +36,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -609,7 +610,7 @@ StatusOr> GetDeviceCoordinates( for (auto device_coordinate_and_idx : llvm::enumerate(device_assignment_attr)) { auto device_coordinate = - device_coordinate_and_idx.value().dyn_cast(); + mlir::dyn_cast(device_coordinate_and_idx.value()); if (!device_coordinate) return absl::InvalidArgumentError( llvm::formatv(kBadIntArrayElementMsg, kDeviceAssignmentAttr, @@ -733,8 +734,8 @@ bool IsTPUReplicatedCore(llvm::StringRef device) { bool TypeValidForXLA(const mlir::Type& type) { const mlir::Type elem = getElementTypeOrSelf(type); - return !elem.isa() && - !elem.isa(); + return !mlir::isa(elem) && + !mlir::isa(elem); } mlir::LogicalResult GetDeviceToHostMap( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc index 988950389edf8b..74dd7803a48a20 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/translate_utils.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" #include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/platform/errors.h" namespace tensorflow { @@ -44,21 +45,21 @@ mlir::LogicalResult ExtractTfVersions(mlir::ModuleOp module, if (!version_attr) return mlir::failure(); auto producer = - version_attr.get("producer").dyn_cast_or_null(); + mlir::dyn_cast_or_null(version_attr.get("producer")); if (!producer) return mlir::failure(); versions->set_producer(producer.getInt()); - auto min_consumer = - version_attr.get("min_consumer").dyn_cast_or_null(); + auto min_consumer = mlir::dyn_cast_or_null( + version_attr.get("min_consumer")); if (min_consumer) versions->set_min_consumer(min_consumer.getInt()); - auto bad_consumers = - version_attr.get("bad_consumers").dyn_cast_or_null(); + auto bad_consumers = mlir::dyn_cast_or_null( + version_attr.get("bad_consumers")); if (!bad_consumers) return mlir::success(); for (auto bad_consumer : bad_consumers) { auto bad_consumer_int_attr = - bad_consumer.dyn_cast_or_null(); + mlir::dyn_cast_or_null(bad_consumer); if (!bad_consumer_int_attr) return mlir::failure(); versions->mutable_bad_consumers()->Add(bad_consumer_int_attr.getInt()); @@ -72,7 +73,7 @@ ::tsl::StatusOr GetTfGraphProducerVersion(mlir::ModuleOp module) { return errors::Internal( "Missing 'tf.versions' attribute on the module, abort.\n"); } - auto producer = versions.get("producer").dyn_cast(); + auto producer = mlir::dyn_cast(versions.get("producer")); if (!producer) { return errors::Internal( "Missing 'producer' attribute on the module, abort.\n"); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index ea76adb284b7e2..334cca591cf569 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -70,7 +70,7 @@ mlir::LogicalResult CreateSplitOp(const int num_split, // Correctly set output shapes of split op output if input shape is statically // known. mlir::Type output_type; - auto input_type = src_input.getType().cast(); + auto input_type = mlir::cast(src_input.getType()); if (input_type.hasRank()) { if (input_type.getShape()[split_dimension] == mlir::ShapedType::kDynamic) { @@ -122,7 +122,7 @@ mlir::TF::ConcatOp CreateConcatOp(const int concat_dimension, // across logical devices, we refer to the shape of 0th logical device // computation output. mlir::Type output_type; - auto input_type = inputs[0].getType().cast(); + auto input_type = mlir::cast(inputs[0].getType()); if (input_type.hasRank()) { if (input_type.getShape()[concat_dimension] == mlir::ShapedType::kDynamic) { @@ -294,9 +294,9 @@ mlir::LogicalResult DecodeShardingAttribute(const std::string& shard_str, mlir::LogicalResult DecodeShardingAttribute(mlir::Attribute shard_attr, xla::OpSharding& sharding, bool report_error) { - if (!shard_attr.isa()) return mlir::failure(); + if (!mlir::isa(shard_attr)) return mlir::failure(); - auto shard_str = shard_attr.cast().getValue().str(); + auto shard_str = mlir::cast(shard_attr).getValue().str(); return DecodeShardingAttribute(shard_str, sharding, report_error); } @@ -350,7 +350,8 @@ mlir::LogicalResult ExtractInputsForLogicalDevices( xla::OpSharding sharding; if (DecodeShardingAttribute( - sharding_attr.cast().getValue().str(), sharding) + mlir::cast(sharding_attr).getValue().str(), + sharding) .failed()) { return cluster_func.emitError("incorrect sharding format for inputs"); } @@ -443,13 +444,14 @@ mlir::LogicalResult ParseAndValidateOutputSharding( llvm::enumerate(output_sharding_attrs)) { const auto& output_sharding = output_sharding_and_index.value(); const int sharding_index = output_sharding_and_index.index(); - if (!output_sharding.isa()) + if (!mlir::isa(output_sharding)) return cluster_func.emitError(llvm::formatv( "non-string output sharding at index {0}", sharding_index)); xla::OpSharding sharding; if (DecodeShardingAttribute( - output_sharding.cast().getValue().str(), sharding) + mlir::cast(output_sharding).getValue().str(), + sharding) .failed()) { return cluster_func.emitError("incorrect sharding format for outputs"); } @@ -661,7 +663,7 @@ mlir::LogicalResult GetOutputTypesForLogicalDeviceComputation( const auto output_index = result_and_index.index(); const auto& output_sharding = output_sharding_config[output_index]; const auto cluster_func_output_type = - result_and_index.value().getType().cast(); + mlir::cast(result_and_index.value().getType()); // If output shape of cluster func is statically known and output is tiled // sharded, then the corresponding output shape of cluster func must be diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc index 322862828e63b3..7358b97971e0fe 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc @@ -175,13 +175,13 @@ Status GetXlaInputShapes( // bounded type by using the bounds as dimension sizes. Returns null if is // neither. mlir::RankedTensorType GetBufferType(mlir::Type ty) { - auto ranked_ty = ty.dyn_cast_or_null(); + auto ranked_ty = mlir::dyn_cast_or_null(ty); if (!ranked_ty) return {}; int64_t rank = ranked_ty.getRank(); llvm::SmallVector dims = llvm::to_vector<4>(ranked_ty.getShape()); - auto encoding = ranked_ty.getEncoding() - .dyn_cast_or_null(); + auto encoding = mlir::dyn_cast_or_null( + ranked_ty.getEncoding()); if (encoding && !encoding.getBounds().empty()) { for (int64_t dim = 0; dim < rank; ++dim) { if (dims[dim] == mlir::ShapedType::kDynamic) { @@ -234,7 +234,7 @@ Status GetOutputInfo( auto return_op = main_func.begin()->getTerminator(); for (const auto& type_and_idx : llvm::enumerate(func_type.getResults())) { size_t idx = type_and_idx.index(); - auto result_ty = type_and_idx.value().cast(); + auto result_ty = mlir::cast(type_and_idx.value()); // If the result type isn't static, then the owner of the result may be a // cast op from a more specific bounded type to an unbounded dynamic type. @@ -275,7 +275,8 @@ Status GetOutputInfo( TF_RETURN_IF_ERROR(MaybeRewriteLayoutWithShardedShape( sharding, shape_determination_fns, &shape)); - auto tensor_type = type_and_idx.value().dyn_cast(); + auto tensor_type = + mlir::dyn_cast(type_and_idx.value()); shapes.push_back(shape); auto it = output_to_input_alias.find(type_and_idx.index()); @@ -872,7 +873,7 @@ static absl::StatusOr> RewriteWithArgs( auto resource_type = mlir::TF::ResourceType::get({resource_subtype}, builder.getContext()); - auto tensor_type = mlir_arg.getType().cast(); + auto tensor_type = mlir::cast(mlir_arg.getType()); if (tensor_type.hasRank()) { mlir_arg.setType( GetTypeFromTFTensorShape(tensor_type.getShape(), resource_type)); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_head_tail_outside_compilation.cc index ad85310291c146..e0dc7bda1f9c86 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_head_tail_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_head_tail_outside_compilation.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" @@ -78,7 +79,7 @@ bool HasOutsideCompilationAttribute(Operation* op) { // Finds op that created a given value. If the value is a BlockArgument, this // returns the owner of the Block. Operation* GetOpOfValue(Value value) { - if (auto block_arg = value.dyn_cast()) + if (auto block_arg = mlir::dyn_cast(value)) return block_arg.getOwner()->getParentOp(); return value.getDefiningOp(); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_outside_compilation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_outside_compilation.cc index 6bc3468a2729e3..10fbd371ca81b2 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/extract_outside_compilation.cc @@ -386,7 +386,7 @@ llvm::SmallSetVector GetStaticExternalOperands( } continue; } - auto block_arg = v.cast(); + auto block_arg = mlir::cast(v); if (block_arg.getParentRegion() == op->getParentRegion()) external_values.insert(v); } @@ -475,7 +475,7 @@ void GetExternalOutputs(const llvm::SmallSetVector& cluster_ops, LogicalResult GetShardShapedType(Operation* context_op, int num_cores_per_replica, Type full_type, Type& shard_type) { - RankedTensorType ranked_type = full_type.dyn_cast(); + RankedTensorType ranked_type = mlir::dyn_cast(full_type); if (!ranked_type) return context_op->emitOpError() << "A map_outside_compilation op's input and output types must be " @@ -587,7 +587,8 @@ LogicalResult CreateHostComputeMap( // Convert MANUAL sharded outputs to split sharded outputs. for (auto [full_type, out] : llvm::zip(full_output_types, host_compute.getResults())) { - RankedTensorType full_type_ranked = full_type.dyn_cast(); + RankedTensorType full_type_ranked = + mlir::dyn_cast(full_type); if (!full_type_ranked) return original_op->emitOpError() << "map_outside_compilation must have ranked outputs"; diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc index 732bae8c67b018..f16df445439084 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/hoist_broadcast_read.cc @@ -72,7 +72,7 @@ Operation* GetAncestorBelow(Operation* descendant, Operation* ancestor) { // `is_cpu_read` is set to `true` iff `read` is on a resource with device type // CPU. LogicalResult IsCpuRead(FuncOp func, ReadVariableOp read, bool& is_cpu_read) { - if (auto arg = read->getOperand(0).dyn_cast()) { + if (auto arg = mlir::dyn_cast(read->getOperand(0))) { if (arg.getOwner() != &(func.front())) { is_cpu_read = false; return success(); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc index 617185360a9936..7308669b6359cb 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc @@ -38,6 +38,7 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Rewrite/PatternApplicator.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -238,13 +239,13 @@ void AddRewrittenCompositeOps(MLIRContext* context, } bool IsStringType(Type type) { - if (type.isa()) return true; + if (mlir::isa(type)) return true; - auto sub_type = type.dyn_cast(); + auto sub_type = mlir::dyn_cast(type); if (!sub_type) return false; bool has_string = llvm::any_of(sub_type.GetSubtypes(), [](TensorType type) { - return type.getElementType().isa(); + return mlir::isa(type.getElementType()); }); return has_string; } @@ -290,7 +291,8 @@ bool IsSupportedOp(Operation& op, } bool IsVariant(Value value) { - return getElementTypeOrSelf(value.getType()).isa(); + return mlir::isa( + getElementTypeOrSelf(value.getType())); } bool HasOutsideCompiledAncestor(Operation* op) { diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc index b600c865661d58..5f752cdedc5c82 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc @@ -51,6 +51,7 @@ limitations under the License. #include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" @@ -142,7 +143,7 @@ LogicalResult CollectMetadata(Block* block, MetadataMap* metadata_map) { return metadata_op.emitError() << kBadReplicateInfoAttrMsg; auto replication_info_attr_str = - replication_info_attr.dyn_cast(); + mlir::dyn_cast(replication_info_attr); if (!replication_info_attr_str || replication_info_attr_str.getValue().empty()) return metadata_op.emitError() << kBadReplicateInfoAttrMsg; @@ -991,17 +992,16 @@ LogicalResult FormClustersInBlock( // Determine `num_replicas`. auto num_replicas_attr = cluster_metadata->getSecond().get(kNumReplicasAttr); - if (!num_replicas_attr || !num_replicas_attr.isa()) + if (!num_replicas_attr || !mlir::isa(num_replicas_attr)) return cluster.emitError() << "requires '" << kNumReplicasAttr << "' int attribute"; - int num_replicas = num_replicas_attr.cast().getInt(); + int num_replicas = + mlir::cast(num_replicas_attr).getInt(); // Determine `num_cores_per_replica`. int num_cores_per_replica = 1; - auto num_cores_per_replica_attr = - cluster_metadata->getSecond() - .get(kNumCoresPerReplicaAttr) - .dyn_cast_or_null(); + auto num_cores_per_replica_attr = mlir::dyn_cast_or_null( + cluster_metadata->getSecond().get(kNumCoresPerReplicaAttr)); if (num_cores_per_replica_attr) num_cores_per_replica = num_cores_per_replica_attr.getInt(); if (failed(ReplicateCluster(cluster, num_replicas, num_cores_per_replica))) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index edf0b96b569fea..c63cae83079a02 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -51,6 +51,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -89,10 +90,10 @@ static size_t GetFeatureDimension(tensorflow::TensorFormat format, // Gets all integer values from the given attribute and push them to `values`. void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl *values) { - auto array_attr = attr.cast(); + auto array_attr = mlir::cast(attr); values->reserve(array_attr.getValue().size()); for (Attribute val : array_attr.getValue()) - values->push_back(val.cast().getValue().getSExtValue()); + values->push_back(mlir::cast(val).getValue().getSExtValue()); } // Returns 1D 32-bit dense elements attribute with the given values. @@ -142,8 +143,8 @@ Type GetSumAccumulationType(Type input_type) { // format supports negative indexing unlike HLO. static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank, Builder *b) { - IntegerAttr intAttr = attr.dyn_cast_or_null(); - if (auto elementAttr = attr.dyn_cast_or_null()) { + IntegerAttr intAttr = mlir::dyn_cast_or_null(attr); + if (auto elementAttr = mlir::dyn_cast_or_null(attr)) { SmallVector index(elementAttr.getShapedType().getRank(), 0); intAttr = elementAttr.getValues()[index]; } @@ -198,7 +199,7 @@ static ConvertOp CastValueToI64(Location loc, Value value, // must be a ranked tensor. static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value, PatternRewriter *rewriter) { - auto indices_type = value.getType().cast(); + auto indices_type = mlir::cast(value.getType()); int num_outputs = indices_type.getShape().front(); SmallVector unpacked_indices_type( num_outputs, @@ -214,7 +215,7 @@ static TF::UnpackOp UnpackTensorAlongZeroDim(Location loc, Value value, // // Aborts if the type is ranked but doesn't have the dimension. int64_t GetDimSize(Type ty, int64_t index) { - RankedTensorType ranked_ty = ty.dyn_cast(); + RankedTensorType ranked_ty = mlir::dyn_cast(ty); if (!ranked_ty) return -1; return ranked_ty.getDimSize(index); @@ -298,8 +299,8 @@ template static Value StaticBinaryBroadcast(Location loc, Value x, Value y, DenseIntElementsAttr broadcast_dims, OpBuilder &builder) { - auto x_type = x.getType().cast(); - auto y_type = y.getType().cast(); + auto x_type = mlir::cast(x.getType()); + auto y_type = mlir::cast(y.getType()); auto result_type = GetStaticBroadcastType(x_type, y_type, broadcast_dims); if (!result_type) { emitError(loc) << "could not binary broadcast " << x_type << ", " << y_type @@ -353,7 +354,7 @@ static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to, Value broadcast_from, int64_t feature_dim, OpBuilder &builder) { auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder); - auto to_type = broadcast_to.getType().cast(); + auto to_type = mlir::cast(broadcast_to.getType()); auto result_shape = builder.create(loc, broadcast_to); auto result_extents_type = GetExtentsTensorTypeFor(to_type); auto result_extents = builder.create( @@ -372,11 +373,11 @@ static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to, static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to, OpBuilder &builder) { auto result_shape = builder.create(loc, broadcast_to); - auto to_type = broadcast_to.getType().cast(); + auto to_type = mlir::cast(broadcast_to.getType()); auto result_extents_type = GetExtentsTensorTypeFor(to_type); auto result_extents = builder.create( loc, result_extents_type, result_shape); - int64_t rank = input.getType().cast().getRank(); + int64_t rank = mlir::cast(input.getType()).getRank(); auto broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder); return builder.create( loc, to_type, input, result_extents, broadcast_dims); @@ -520,8 +521,8 @@ static void CreateWhile32(Location loc, int num_iterations, static IntegerAttr getFeatureDimensionAttr(Builder &b, tensorflow::TensorFormat format, Value input) { - return b.getI64IntegerAttr( - GetFeatureDimension(format, input.getType().cast())); + return b.getI64IntegerAttr(GetFeatureDimension( + format, mlir::cast(input.getType()))); } //===----------------------------------------------------------------------===// @@ -567,7 +568,7 @@ static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) { // attribute. static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( ElementsAttr input, int column) { - auto int_attr = input.cast(); + auto int_attr = mlir::cast(input); auto shaped_type = int_attr.getType(); auto shape = shaped_type.getShape(); @@ -605,8 +606,8 @@ static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) { // must be broadcasted with a size 1 tensor or another dynamic dimension. // Returns false on rankless. static bool AreBroadcastCompatible(Value x, Value y) { - auto x_rankless = x.getType().dyn_cast(); - auto y_rankless = y.getType().dyn_cast(); + auto x_rankless = mlir::dyn_cast(x.getType()); + auto y_rankless = mlir::dyn_cast(y.getType()); if (!x_rankless || !y_rankless) { return false; } @@ -634,7 +635,7 @@ static bool AreBroadcastCompatible(Value x, Value y) { // updated element type. static Type ChangeTensorElementType(Builder *b, Type tensor_type, Type element_type) { - RankedTensorType ranked_type = tensor_type.dyn_cast(); + RankedTensorType ranked_type = mlir::dyn_cast(tensor_type); if (ranked_type) { return tensorflow::GetTypeFromTFTensorShape(ranked_type.getShape(), element_type); @@ -659,7 +660,7 @@ static Type GetAccumulationType(Type ty) { //===----------------------------------------------------------------------===// static DenseElementsAttr GetEpsilonValue(Type ty) { - auto element_ty = ty.cast().getElementType(); + auto element_ty = mlir::cast(ty).getElementType(); auto scalar_ty = tensorflow::GetTypeFromTFTensorShape({}, element_ty); if (element_ty.isF16()) { uint16_t raw_epsilon = Eigen::numext::bit_cast( @@ -750,9 +751,10 @@ static bool ArgTypesMatchCallee(mlir::Operation *op, OperandRange args, static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices, DenseIntElementsAttr slice_sizes) { - auto input_ty = input.getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(input.getType()); if (!input_ty) return false; - auto start_indices_ty = start_indices.getType().dyn_cast(); + auto start_indices_ty = + mlir::dyn_cast(start_indices.getType()); if (!start_indices_ty) return false; int64_t input_rank = input_ty.getRank(); @@ -780,11 +782,11 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( Builder *builder) { DenseIntElementsAttr constant_start_indices; if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) { - return hlo::convertElementsAttr(slice_sizes, builder->getIntegerType(64)) - .cast(); + return mlir::cast( + hlo::convertElementsAttr(slice_sizes, builder->getIntegerType(64))); } - auto input_ty = input.getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(input.getType()); int64_t input_rank = input_ty.getRank(); ArrayRef input_shape = input_ty.getShape(); SmallVector normalized_sizes; @@ -906,7 +908,7 @@ class ConvertBiasAddOp : public OpRewritePattern { if (!FormatFromString(op.getDataFormat().str(), &data_format)) return op.emitOpError("invalid data format"); - auto value_type = op.getValue().getType().dyn_cast(); + auto value_type = mlir::dyn_cast(op.getValue().getType()); if (!value_type) return failure(); auto feature_dim = GetFeatureDimension(data_format, value_type); auto bias_broadcast = Broadcast1DToFeatureDim( @@ -1008,11 +1010,9 @@ class ConvertConvDynamic : public OpRewritePattern { if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) return failure(); - auto input_ty = - op.getInput().getType().template dyn_cast(); - auto filter_ty = - op.getFilter().getType().template dyn_cast(); - auto result_ty = op.getType().template dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); + auto result_ty = mlir::dyn_cast(op.getType()); if (!input_ty || !filter_ty || !result_ty) return failure(); // TODO(disc): Remove this constraint once fold and canonicalization // implemented. @@ -1035,7 +1035,7 @@ class ConvertConvDynamic : public OpRewritePattern { SmallVector paddings; auto get_int = [](Attribute attr) { - return attr.template cast().getInt(); + return mlir::cast(attr).getInt(); }; constexpr int num_dims = num_spatial_dims + 2; @@ -1177,10 +1177,8 @@ class ConvertConvOp : public OpRewritePattern { if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) return failure(); - auto input_ty = - op.getInput().getType().template dyn_cast(); - auto filter_ty = - op.getFilter().getType().template dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); + auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); // With the exception of input's batch dimension, input and filter need to // have static shape for calculation of HLO paddings and feature group count @@ -1205,7 +1203,7 @@ class ConvertConvOp : public OpRewritePattern { SmallVector paddings; auto get_int = [](Attribute attr) { - return attr.template cast().getInt(); + return mlir::cast(attr).getInt(); }; constexpr int num_dims = num_spatial_dims + 2; @@ -1318,8 +1316,8 @@ class ConvertPadOpDynamic : public OpRewritePattern { auto input = op.getInput(); auto paddings = op.getPaddings(); auto constant_values = op.getConstantValues(); - auto input_type = input.getType().dyn_cast(); - auto paddings_type = paddings.getType().dyn_cast(); + auto input_type = mlir::dyn_cast(input.getType()); + auto paddings_type = mlir::dyn_cast(paddings.getType()); if (!input_type || !paddings_type || !paddings_type.hasStaticShape()) return failure(); @@ -1385,9 +1383,9 @@ class ConvertGatherNdOpDynamic : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); auto params = op.getParams(); - auto params_ty = params.getType().dyn_cast(); + auto params_ty = mlir::dyn_cast(params.getType()); auto indices = op.getIndices(); - auto indices_ty = indices.getType().dyn_cast(); + auto indices_ty = mlir::dyn_cast(indices.getType()); auto params_rank = params_ty.getRank(); auto indices_rank = indices_ty.getRank(); int64_t num_index_dims = indices_ty.getDimSize(indices_rank - 1); @@ -1485,8 +1483,8 @@ class ConvertBF16FloorDivOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::FloorDivOp op, PatternRewriter &rewriter) const override { - auto l = op.getX().dyn_cast>(); - auto r = op.getY().dyn_cast>(); + auto l = mlir::dyn_cast>(op.getX()); + auto r = mlir::dyn_cast>(op.getY()); if (!l || !r) return failure(); auto element_type = getElementTypeOrSelf(l.getType()); @@ -1515,14 +1513,14 @@ class ConvertBroadcastToOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::BroadcastToOp op, PatternRewriter &rewriter) const override { - auto input_type = op.getInput().getType().dyn_cast(); + auto input_type = mlir::dyn_cast(op.getInput().getType()); auto output_type = op.getOutput().getType(); if (!input_type) { return rewriter.notifyMatchFailure(op, "requires ranked input shape"); } llvm::SmallVector broadcast_dimensions; if (input_type.getRank() > 0) { - auto ranked_output_type = output_type.dyn_cast(); + auto ranked_output_type = mlir::dyn_cast(output_type); if (!ranked_output_type) { return rewriter.notifyMatchFailure(op, "requires ranked output shape"); } @@ -1546,7 +1544,7 @@ class ConvertRollOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TF::RollOp op, PatternRewriter &rewriter) const override { - auto shift_ty = op.getShift().getType().dyn_cast(); + auto shift_ty = mlir::dyn_cast(op.getShift().getType()); if (!shift_ty || shift_ty.getRank() != 0) { return rewriter.notifyMatchFailure( op, "require the type of shift to be 0D tensor"); @@ -1558,7 +1556,7 @@ class ConvertRollOp : public OpRewritePattern { } int axis = val.getSExtValue(); - auto input_ty = op.getInput().getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); if (!input_ty || !input_ty.hasStaticShape()) { return rewriter.notifyMatchFailure( op, "require the type of input to have static shapes"); @@ -1674,7 +1672,7 @@ class ConvertDiagPartOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::DiagPartOp op, PatternRewriter &rewriter) const override { - auto input_type = op.getInput().getType().dyn_cast(); + auto input_type = mlir::dyn_cast(op.getInput().getType()); if (!input_type || !input_type.hasStaticShape()) return failure(); int64_t num_dims = input_type.getRank(); if (num_dims < 2 || num_dims % 2 != 0) return failure(); @@ -1771,7 +1769,7 @@ class ConvertMatrixDiagPartV3Op LogicalResult matchAndRewrite(TF::MatrixDiagPartV3Op op, PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - ShapedType input_type = op.getInput().getType().dyn_cast(); + ShapedType input_type = mlir::dyn_cast(op.getInput().getType()); // Align is a string specifying how superdiagonals and subdiagonals should // be aligned/padded for diagonals that are shorter than max_diag_len. The @@ -2035,7 +2033,7 @@ class ConvertFFTOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - auto input_ty = op.getInput().getType().template cast(); + auto input_ty = mlir::cast(op.getInput().getType()); if (!input_ty.hasRank()) { return failure(); } @@ -2131,14 +2129,12 @@ class ConvertFusedBatchNormGradBase // TODO(b/141785544): Update this to not require static shapes. // activation shape needs to be static to convert negative indices in // TensorFlow to absolute indices required by HLO. - RankedTensorType act_type = - act.getType().template dyn_cast(); + RankedTensorType act_type = mlir::dyn_cast(act.getType()); if (!act_type) return failure(); Type act_ele_type = act_type.getElementType(); // To support mixed precision, the statistics type, which maybe more // precise than the input types, are used for this op. - Type kernel_type = - scale.getType().template cast().getElementType(); + Type kernel_type = mlir::cast(scale.getType()).getElementType(); grad = rewriter.create(loc, grad, kernel_type); act = rewriter.create(loc, act, kernel_type); @@ -2260,14 +2256,13 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { auto feature_dim = getFeatureDimensionAttr(rewriter, data_format, op.getX()); - auto input_type_tensor = op.getX().getType().template cast(); + auto input_type_tensor = mlir::cast(op.getX().getType()); auto input_element_type = input_type_tensor.getElementType(); - auto scale_type_tensor = - op.getScale().getType().template cast(); + auto scale_type_tensor = mlir::cast(op.getScale().getType()); auto scale_element_type = scale_type_tensor.getElementType(); - auto mean_type_tensor = op.getMean().getType().template cast(); + auto mean_type_tensor = mlir::cast(op.getMean().getType()); auto mean_element_type = mean_type_tensor.getElementType(); // In the training case, dimensions of input tensors must be static. if (op.getIsTraining() && (!input_type_tensor.hasStaticShape() || @@ -2281,7 +2276,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { Value bn_train_input = rewriter.create( op.getLoc(), op.getX(), scale_element_type); TensorType bn_train_input_type_tensor = - bn_train_input.getType().template cast(); + mlir::cast(bn_train_input.getType()); if (op.getIsTraining()) { // Training case. @@ -2372,7 +2367,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // For FusedBatchNormV3Op, also create a constant tensor to forward to // last reserve_space_3 output. auto reserve_space_3_type = - op.getResult(5).getType().template cast(); + mlir::cast(op.getResult(5).getType()); int num_elements = reserve_space_3_type.hasStaticShape() ? reserve_space_3_type.getNumElements() : 0; @@ -2416,7 +2411,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // For FusedBatchNormV3Op, also create a constant tensor to forward to // last reserve_space_3 output. auto reserve_space_3_type = - op.getResult(5).getType().template cast(); + mlir::cast(op.getResult(5).getType()); int num_elements = reserve_space_3_type.hasStaticShape() ? reserve_space_3_type.getNumElements() : 0; @@ -2465,9 +2460,9 @@ static PaddingArray GetReduceWindowPaddingAsArray( for (const auto &dim : input_dims) input_shape.push_back(dim); for (Attribute attr : window_dims) - window_shape.push_back(attr.cast().getInt()); + window_shape.push_back(mlir::cast(attr).getInt()); for (Attribute attr : window_strides) - strides.push_back(attr.cast().getInt()); + strides.push_back(mlir::cast(attr).getInt()); PaddingArray paddings = ::xla::MakePadding(input_shape, window_shape, strides, ::xla::Padding::kSame); @@ -2509,8 +2504,7 @@ Operation *AvgPoolDivideByCount( const SmallVector &strides, OpTy op, Value zero, PatternRewriter &rewriter) { Location loc = op.getLoc(); - RankedTensorType pooled_type = - pooled.getType().template cast(); + RankedTensorType pooled_type = mlir::cast(pooled.getType()); Type element_type = pooled_type.getElementType(); Operation *result = nullptr; RankedTensorType orig_input_type = @@ -2577,8 +2571,7 @@ class ConvertAvgPoolOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Value input_value = GetAvgPoolInput(op); - auto input_type = - input_value.getType().template dyn_cast(); + auto input_type = mlir::dyn_cast(input_value.getType()); if (!input_type) return failure(); // We will do accumulation first; use a larger bitwidth if suitable. @@ -2587,7 +2580,7 @@ class ConvertAvgPoolOp : public OpRewritePattern { Type result_type; // The result type for reduction and division with the proper element type. - if (auto ranked_type = op.getType().template dyn_cast()) + if (auto ranked_type = mlir::dyn_cast(op.getType())) result_type = tensorflow::GetTypeFromTFTensorShape(ranked_type.getShape(), sum_element_type); else @@ -2695,8 +2688,7 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { // `out_grad` is the gradient that was propagated via backpropagation from // the output layer. Value out_grad = op.getGrad(); - auto out_grad_type = - out_grad.getType().template dyn_cast(); + auto out_grad_type = mlir::dyn_cast(out_grad.getType()); if (!out_grad_type) { return failure(); } @@ -2833,7 +2825,7 @@ class ConvertMaxPoolOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { Type element_type = - op.getInput().getType().template cast().getElementType(); + mlir::cast(op.getInput().getType()).getElementType(); if (!element_type.isSignlessIntOrFloat()) return failure(); tensorflow::Padding padding; if (!GetPaddingFromString(op.getPadding().str(), &padding).ok()) @@ -2845,8 +2837,7 @@ class ConvertMaxPoolOp : public OpRewritePattern { ConstantOp init = GetScalarLimitConstOfType( element_type, loc, hlo::kInfinityLowest, &rewriter); - auto input_ty = - op.getInput().getType().template dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); if (!input_ty) return failure(); DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), @@ -2875,9 +2866,12 @@ class ConvertSelectOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SelectOp op, PatternRewriter &rewriter) const override { // This lowering only works on ranked types. - auto cond_type = op.getCondition().getType().dyn_cast(); - auto then_type = op.getThenValue().getType().dyn_cast(); - auto else_type = op.getElseValue().getType().dyn_cast(); + auto cond_type = + mlir::dyn_cast(op.getCondition().getType()); + auto then_type = + mlir::dyn_cast(op.getThenValue().getType()); + auto else_type = + mlir::dyn_cast(op.getElseValue().getType()); if (!cond_type || !then_type || !else_type) { return failure(); } @@ -2913,7 +2907,7 @@ class ConvertSelectOp : public OpRewritePattern { assumption = b.createOrFold( witness, ValueRange{assumption, eq_cstr}); } - auto result_type = op.getResult().getType().cast(); + auto result_type = mlir::cast(op.getResult().getType()); auto assuming_op = b.create(ArrayRef{result_type}, assumption); @@ -2978,7 +2972,7 @@ class ConvertSigmoidOp : public RewritePattern { // Create constant half with shape and element type same as the operand. Value operand = op.getOperand(); - auto operand_ty = operand.getType().cast(); + auto operand_ty = mlir::cast(operand.getType()); auto scalar_ty = tensorflow::GetTypeFromTFTensorShape({}, operand_ty.getElementType()); ElementsAttr attr = mlir::hlo::getSplat(&rewriter, scalar_ty, 0.5); @@ -3009,9 +3003,9 @@ class ConvertSliceOpDynamic : public OpRewritePattern { Value begin_indices = op.getBegin(); Value sizes = op.getSize(); - auto input_ty = input.getType().dyn_cast(); - auto begin_type = begin_indices.getType().dyn_cast(); - auto size_type = sizes.getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(input.getType()); + auto begin_type = mlir::dyn_cast(begin_indices.getType()); + auto size_type = mlir::dyn_cast(sizes.getType()); if (!input_ty || !begin_type || !size_type || !begin_type.hasStaticShape() || !size_type.hasStaticShape() || @@ -3112,8 +3106,8 @@ static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc, loc, TypeRange{shape_type, shape_type}, lhs_shape, const_neg2); auto rhs_splitted = rewriter->create( loc, TypeRange{shape_type, shape_type}, rhs_shape, const_neg2); - auto lhs_type = lhs.getType().cast(); - auto rhs_type = rhs.getType().cast(); + auto lhs_type = mlir::cast(lhs.getType()); + auto rhs_type = mlir::cast(rhs.getType()); // The last two dimensions are the matrix row/col dimensions. Don't broadcast // them. SmallVector result_batch_shape_compile_time_extents; @@ -3166,21 +3160,21 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { PatternRewriter &rewriter) const override { Value lhs = op.getX(); Value rhs = op.getY(); - auto lhs_type = lhs.getType().dyn_cast(); - auto rhs_type = rhs.getType().dyn_cast(); + auto lhs_type = mlir::dyn_cast(lhs.getType()); + auto rhs_type = mlir::dyn_cast(rhs.getType()); if (!lhs_type || !rhs_type) return failure(); - if (lhs_type.getElementType().isa() && op.getAdjX()) { + if (mlir::isa(lhs_type.getElementType()) && op.getAdjX()) { lhs = rewriter.create(op.getLoc(), lhs_type, lhs); } - if (rhs_type.getElementType().isa() && op.getAdjY()) { + if (mlir::isa(rhs_type.getElementType()) && op.getAdjY()) { rhs = rewriter.create(op.getLoc(), rhs_type, rhs); } // Broadcast both operands. BroadcastBatchMatMulV2Operands(lhs, rhs, op.getLoc(), &lhs, &rhs, &rewriter); - lhs_type = lhs.getType().cast(); - rhs_type = rhs.getType().cast(); + lhs_type = mlir::cast(lhs.getType()); + rhs_type = mlir::cast(rhs.getType()); assert(lhs_type.getRank() == rhs_type.getRank()); int64_t rank = lhs_type.getRank(); auto batch_dimensions = llvm::to_vector<4>(llvm::seq(0, rank - 2)); @@ -3243,7 +3237,7 @@ class ConvertSplitOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SplitOp op, PatternRewriter &rewriter) const override { // We can only split inputs that have fully static shape. - auto input_type = op.getValue().getType().dyn_cast(); + auto input_type = mlir::dyn_cast(op.getValue().getType()); if (!input_type || !input_type.hasStaticShape()) return failure(); // We can only match when the split dimension is a constant scalar. @@ -3304,7 +3298,7 @@ class ConvertSplitOpDynamic : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value input = op.getValue(); - auto input_type = input.getType().dyn_cast(); + auto input_type = mlir::dyn_cast(input.getType()); if (!input_type) return failure(); // TODO(disc): remove static shape check once folding/canonicalization func @@ -3419,7 +3413,7 @@ class ConvertSplitVOp : public OpRewritePattern { PatternRewriter &rewriter) const override { // We can only split inputs that have fully static shape. // TODO(b/145731001): enhance to support dynamic-shaped inputs. - auto input_type = op.getValue().getType().dyn_cast(); + auto input_type = mlir::dyn_cast(op.getValue().getType()); if (!input_type || !input_type.hasStaticShape()) return failure(); // We can only match when the split dimension is a constant scalar. @@ -3438,7 +3432,7 @@ class ConvertSplitVOp : public OpRewritePattern { int64_t total_dim_size = 0; // Total dimension size assigned to splits std::optional dynamic_dim_index; split_sizes.reserve( - split_sizes_attr.getType().cast().getNumElements()); + mlir::cast(split_sizes_attr.getType()).getNumElements()); for (const auto &dim : llvm::enumerate(split_sizes_attr)) { int64_t dim_val = dim.value().getSExtValue(); split_sizes.push_back(dim_val); @@ -3620,7 +3614,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // Begin must be a ranked, 1-dimensional tensor: This is checked by the // verifier. int64_t slicing_dim_size = - op.getBegin().getType().cast().getDimSize(0); + mlir::cast(op.getBegin().getType()).getDimSize(0); uint64_t begin_mask = op.getBeginMask(); uint64_t end_mask = op.getEndMask(); const int input_rank = input_shape.size(); @@ -3642,7 +3636,7 @@ class ConvertStridedSliceOp : public OpRewritePattern { // For the dimensions that are to be sliced, all have slice sizes of 1. SmallVector slice_sizes; auto begin_element_ty = - op.getBegin().getType().cast().getElementType(); + mlir::cast(op.getBegin().getType()).getElementType(); // Scalar tensor type. TensorType type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, begin_element_ty); @@ -3696,14 +3690,14 @@ class ConvertStridedSliceOp : public OpRewritePattern { // // TODO(hinsu): Relax this constraint for ops without negative indices and // strides. - auto input_ty = op.getInput().getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); if (!input_ty || !input_ty.hasStaticShape()) return failure(); // Output shape needs to be static to apply 'new_axis_mask' or // 'shrink_axis_mask' by reshaping tensor after slice. // // TODO(hinsu): Relax this constraint for ops without the above masks. - auto result_ty = op.getType().dyn_cast(); + auto result_ty = mlir::dyn_cast(op.getType()); if (!result_ty || !result_ty.hasStaticShape()) return failure(); DenseIntElementsAttr sparse_begin_attr, sparse_end_attr; @@ -3750,7 +3744,7 @@ class ConvertStridedSliceGradOp return failure(); Value grad = op.getDy(); - Type element_type = grad.getType().cast().getElementType(); + Type element_type = mlir::cast(grad.getType()).getElementType(); // Perform reshape to undo any new/shrink axes done by strided slice. grad = rewriter.create( @@ -3830,7 +3824,7 @@ class ConvertRangeOp : public OpRewritePattern { PatternRewriter &rewriter) const override { auto result = op.getResult(); auto result_type = result.getType(); - if (!result_type.cast().hasStaticShape()) { + if (!mlir::cast(result_type).hasStaticShape()) { return failure(); } @@ -3863,7 +3857,7 @@ class ConvertDynamicRangeOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::RangeOp op, PatternRewriter &rewriter) const override { auto result = op.getResult(); - auto result_type = result.getType().cast(); + auto result_type = mlir::cast(result.getType()); if (result_type.hasStaticShape()) { return failure(); } @@ -3875,11 +3869,12 @@ class ConvertDynamicRangeOp : public OpRewritePattern { // To compute the length we need to use floating point calculations so that // ceil can be computed for the number of steps. auto compute_element_type = - getElementTypeOrSelf(start.getType()).isa() + mlir::isa(getElementTypeOrSelf(start.getType())) ? getElementTypeOrSelf(start.getType()) : rewriter.getF64Type(); auto compute_type = tensorflow::GetTypeFromTFTensorShape( - limit.getType().cast().getShape(), compute_element_type); + mlir::cast(limit.getType()).getShape(), + compute_element_type); // Compute the length of the sequence we are going to need. This includes // some conversion to float for the operations. @@ -3930,8 +3925,8 @@ class ConvertDynamicRangeOp : public OpRewritePattern { }; ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) { - auto int_attr = attr.cast(); - auto type = val.getType().cast(); + auto int_attr = mlir::cast(attr); + auto type = mlir::cast(val.getType()); SmallVector axis; axis.reserve(int_attr.getNumElements()); @@ -3954,7 +3949,7 @@ class ConvertLinSpaceOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::LinSpaceOp op, PatternRewriter &rewriter) const override { auto result = op.getResult(); - auto result_type = result.getType().dyn_cast(); + auto result_type = mlir::dyn_cast(result.getType()); if (!result_type || !result_type.hasStaticShape()) { return failure(); } @@ -4023,8 +4018,7 @@ class GenericConvertReductionOp : public OpRewritePattern { // TODO(b/141785544): Update this to not require ranked shapes. // Input shape needs to be ranked to convert negative indices in TensorFlow // to absolute indices required by HLO. - auto input_ty = - op.getInput().getType().template dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); if (!input_ty) return failure(); ArrayRef input_shape = input_ty.getShape(); @@ -4049,8 +4043,9 @@ class GenericConvertReductionOp : public OpRewritePattern { Type element_type = input_ty.getElementType(); // Only float, int, and complex types are currently supported. - if (!element_type.isa() && !element_type.isa() && - !element_type.isa()) { + if (!mlir::isa(element_type) && + !mlir::isa(element_type) && + !mlir::isa(element_type)) { return rewriter.notifyMatchFailure( op, "element type must be float, int, or complex type"); } @@ -4252,7 +4247,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { RankedTensorType input_type = - op.getInput().getType().template dyn_cast(); + mlir::dyn_cast(op.getInput().getType()); if (!input_type) { return failure(); } @@ -4267,7 +4262,7 @@ class ConvertArgMinMaxOp : public OpRewritePattern { Derived::GetInitialValue(input_element_type, loc, rewriter); RankedTensorType output_type = - op.getOutput().getType().template dyn_cast(); + mlir::dyn_cast(op.getOutput().getType()); if (!output_type) { return rewriter.notifyMatchFailure(op, "requires known rank"); } @@ -4364,12 +4359,11 @@ class ConvertTensorScatterOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - auto tensor_ty = - op.getTensor().getType().template dyn_cast(); + auto tensor_ty = mlir::dyn_cast(op.getTensor().getType()); auto indices_ty = - op.getIndices().getType().template dyn_cast(); + mlir::dyn_cast(op.getIndices().getType()); auto updates_ty = - op.getUpdates().getType().template dyn_cast(); + mlir::dyn_cast(op.getUpdates().getType()); if (!tensor_ty || !indices_ty || !updates_ty) return failure(); // Last dimension of the indices needs to known at compile time for @@ -4421,13 +4415,13 @@ class ConvertTensorScatterOp : public OpRewritePattern { updates = rewriter.create( op->getLoc(), broadcast_to_type, op.getUpdates(), const_op); - updates_ty = updates.getType().template dyn_cast(); + updates_ty = mlir::dyn_cast(updates.getType()); } int64_t tensor_rank = tensor_ty.getRank(); int64_t indices_rank = indices_ty.getRank(); int64_t updates_rank = - updates.getType().template dyn_cast().getRank(); + mlir::dyn_cast(updates.getType()).getRank(); int64_t window_dims = tensor_rank - num_index_dims; auto dims_attr = ScatterDimensionNumbersAttr::get( @@ -4558,7 +4552,7 @@ class ConvertTileOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::TileOp op, PatternRewriter &rewriter) const override { - auto input_ty = op.getInput().getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); if (!input_ty || !input_ty.hasStaticShape()) return failure(); ArrayRef input_shape = input_ty.getShape(); Type element_type = input_ty.getElementType(); @@ -4639,7 +4633,7 @@ class ConvertTileOpDynamic : public OpRewritePattern { Location loc = op.getLoc(); Value input = op.getInput(); Value multiples = op.getMultiples(); - auto input_ty = input.getType().dyn_cast(); + auto input_ty = mlir::dyn_cast(input.getType()); if (!input_ty) return failure(); // TODO(disc): Remove this constraint once fold and canonicalization // implemented. @@ -4659,7 +4653,7 @@ class ConvertTileOpDynamic : public OpRewritePattern { } } - auto multiples_ty = multiples.getType().dyn_cast(); + auto multiples_ty = mlir::dyn_cast(multiples.getType()); int64_t multiples_rank = multiples_ty.getRank(); // rank of multiples input of tf.TileOp must be 1 if (multiples_rank != 1) return failure(); @@ -4728,16 +4722,14 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Type element_type = op.getOrigInput() - .getType() - .template cast() - .getElementType(); + Type element_type = + mlir::cast(op.getOrigInput().getType()).getElementType(); // Compute paddings using the original input and kernel shape and strides. // Here, ReduceWindow op as used as the MaxPool op is lowered to the // ReduceWindow op. auto input_ty = - op.getOrigInput().getType().template dyn_cast(); + mlir::dyn_cast(op.getOrigInput().getType()); if (!input_ty) return failure(); DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), @@ -4798,9 +4790,8 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { return failure(); auto out_backprop_ty = - op.getOutBackprop().getType().template dyn_cast(); - auto filter_ty = - op.getFilter().getType().template dyn_cast(); + mlir::dyn_cast(op.getOutBackprop().getType()); + auto filter_ty = mlir::dyn_cast(op.getFilter().getType()); // With the exception of out_backprop's batch dimension, out_backprop and // filter need to have static shape. Filter is validated here, out_backprop @@ -4824,7 +4815,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { } else { auto pack = op.getInputSizes().template getDefiningOp(); if (!pack || pack.getAxis() != 0) return failure(); - auto pack_ty = pack.getType().template dyn_cast(); + auto pack_ty = mlir::dyn_cast(pack.getType()); if (!pack_ty || pack_ty.getRank() != 1) return failure(); for (auto i = 0; i < pack_ty.getDimSize(0); ++i) { if (i == batch_dim) { @@ -4862,7 +4853,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { explicit_paddings.reserve(explicit_paddings_attr.size()); for (Attribute explicit_padding : explicit_paddings_attr) explicit_paddings.push_back( - explicit_padding.cast().getInt()); + mlir::cast(explicit_padding).getInt()); } ArrayRef filter_shape = filter_ty.getShape(); @@ -5029,9 +5020,8 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { return failure(); auto out_backprop_ty = - op.getOutBackprop().getType().template dyn_cast(); - auto input_ty = - op.getInput().getType().template dyn_cast(); + mlir::dyn_cast(op.getOutBackprop().getType()); + auto input_ty = mlir::dyn_cast(op.getInput().getType()); for (RankedTensorType ty : {out_backprop_ty, input_ty}) if (!ty || !ty.hasStaticShape()) return failure(); @@ -5063,7 +5053,7 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { explicit_paddings.reserve(explicit_paddings_attr.size()); for (Attribute explicit_padding : explicit_paddings_attr) explicit_paddings.push_back( - explicit_padding.cast().getInt()); + mlir::cast(explicit_padding).getInt()); } constexpr int num_dims = num_spatial_dims + 2; @@ -5223,7 +5213,8 @@ class ConvertOneHotOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::OneHotOp op, PatternRewriter &rewriter) const override { - auto indices_ty = op.getIndices().getType().dyn_cast(); + auto indices_ty = + mlir::dyn_cast(op.getIndices().getType()); if (!indices_ty || !indices_ty.hasStaticShape()) return failure(); ArrayRef indices_shape = indices_ty.getShape(); Type element_type = indices_ty.getElementType(); @@ -5307,7 +5298,7 @@ class ConvertInfeedDequeueTupleOp result_types.reserve(op.getOutputs().size() + 1); for (const auto &output : op.getOutputs()) { Type ty = output.getType(); - if (auto tensor_ty = ty.dyn_cast()) { + if (auto tensor_ty = mlir::dyn_cast(ty)) { if (!tensor_ty.hasStaticShape()) return failure(); } result_types.push_back(ty); @@ -5412,7 +5403,7 @@ class ConvertTopKV2Op : public OpRewritePattern { if (!matchPattern(op.getK(), m_Constant(&k_attr))) return failure(); int64_t k = (*k_attr.begin()).getSExtValue(); - TensorType input_type = op.getInput().getType().cast(); + TensorType input_type = mlir::cast(op.getInput().getType()); if (!input_type.hasRank()) return failure(); int64_t input_rank = input_type.getRank(); int64_t last_dim_index = input_rank - 1; @@ -5436,7 +5427,7 @@ class ConvertUnpackOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::UnpackOp op, PatternRewriter &rewriter) const override { - auto value_type = op.getValue().getType().dyn_cast(); + auto value_type = mlir::dyn_cast(op.getValue().getType()); if (!value_type) return failure(); int64_t value_rank = value_type.getRank(); @@ -5482,7 +5473,7 @@ class ConvertUnpackOpDynamic : public OpRewritePattern { LogicalResult matchAndRewrite(TF::UnpackOp op, PatternRewriter &rewriter) const override { - auto value_type = op.getValue().getType().dyn_cast(); + auto value_type = mlir::dyn_cast(op.getValue().getType()); if (!value_type) return failure(); // TODO(disc): Remove this constraint once fold and canonicalization // implemented. @@ -5585,8 +5576,8 @@ class ConvertSigmoidGradOpDynamic : public OpRewritePattern { Location loc = op.getLoc(); Value y = op.getY(); Value dy = op.getDy(); - auto tp_y = y.getType().dyn_cast(); - auto tp_dy = dy.getType().dyn_cast(); + auto tp_y = mlir::dyn_cast(y.getType()); + auto tp_dy = mlir::dyn_cast(dy.getType()); if (!tp_y || !tp_dy) return failure(); // TODO(disc): Remove this constraint once fold and canonicalization @@ -5598,7 +5589,7 @@ class ConvertSigmoidGradOpDynamic : public OpRewritePattern { if (elem_tp.isSignlessInteger()) { attr = rewriter.getIntegerAttr(elem_tp, 1); } else { - assert(elem_tp.isa()); + assert(mlir::isa(elem_tp)); attr = rewriter.getFloatAttr(elem_tp, 1); } Value one = rewriter.create( @@ -5640,13 +5631,12 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { - auto data_type = - op.getData().getType().template dyn_cast(); + auto data_type = mlir::dyn_cast(op.getData().getType()); if (!data_type) return failure(); int64_t data_rank = data_type.getRank(); auto segment_ids_type = - op.getSegmentIds().getType().template dyn_cast(); + mlir::dyn_cast(op.getSegmentIds().getType()); if (!segment_ids_type) return failure(); int64_t segment_ids_rank = segment_ids_type.getRank(); @@ -5766,7 +5756,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { return success(); }; - auto input_type = op.getValue().getType().dyn_cast(); + auto input_type = mlir::dyn_cast(op.getValue().getType()); if (!input_type) return failure(); if (input_type.hasStaticShape() && input_type.getNumElements() <= 1) // No shuffling is required, so copy input directly to output. @@ -5966,16 +5956,16 @@ class ConvertInplaceUpdateOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::InplaceUpdateOp op, PatternRewriter &rewriter) const override { - auto input = op.getX().dyn_cast>(); + auto input = mlir::dyn_cast>(op.getX()); if (!input) return failure(); auto indices = op.getI(); auto updates = op.getV(); // Slice each row of `i` and `v` to perform a separate dynamic-update-slice // on the contents of `x`. - auto input_type = input.getType().cast(); - auto updates_type = updates.getType().cast(); - auto indices_type = indices.getType().cast(); + auto input_type = mlir::cast(input.getType()); + auto updates_type = mlir::cast(updates.getType()); + auto indices_type = mlir::cast(indices.getType()); if (!input_type.hasRank()) return failure(); if (!updates_type.hasRank() || updates_type.isDynamicDim(0)) return failure(); @@ -6033,7 +6023,8 @@ class ConvertXlaDynamicUpdateSliceOp LogicalResult matchAndRewrite(TF::XlaDynamicUpdateSliceOp op, PatternRewriter &rewriter) const override { - auto indices_type = op.getIndices().getType().dyn_cast(); + auto indices_type = + mlir::dyn_cast(op.getIndices().getType()); if (!indices_type || !indices_type.hasStaticShape() || indices_type.getShape().size() != 1) return failure(); @@ -6062,8 +6053,8 @@ class ConvertXlaReduceScatterOp if (!matchPattern(op.getGroupAssignment(), m_Constant(&group_assignment))) return failure(); auto replica_groups = - hlo::convertElementsAttr(group_assignment, rewriter.getIntegerType(64)) - .cast(); + mlir::cast(hlo::convertElementsAttr( + group_assignment, rewriter.getIntegerType(64))); if (replica_groups.getType().getRank() != 2) return failure(); APInt scatter_dimension; @@ -6141,16 +6132,16 @@ class ConvertXlaReduceWindowOp // Create the mhlo.SelectAndScatter op. auto reduce_window_op = rewriter.create( loc, result_types, op.getInput(), op.getInitValue(), - hlo::convertElementsAttr(window_dimensions, rewriter.getIntegerType(64)) - .cast(), - hlo::convertElementsAttr(window_strides, rewriter.getIntegerType(64)) - .cast(), - hlo::convertElementsAttr(base_dilations, rewriter.getIntegerType(64)) - .cast(), - hlo::convertElementsAttr(window_dilations, rewriter.getIntegerType(64)) - .cast(), - hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)) - .cast()); + mlir::cast(hlo::convertElementsAttr( + window_dimensions, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + window_strides, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + base_dilations, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + window_dilations, rewriter.getIntegerType(64))), + mlir::cast( + hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); // Insert a call to the reducer in the region of the mhlo op. mlir::SymbolRefAttr func = op.getComputation(); auto func_op = cast(SymbolTable::lookupSymbolIn( @@ -6177,9 +6168,9 @@ class ConvertClipByValueOp : public OpRewritePattern { Value min = op.getClipValueMin(); Value max = op.getClipValueMax(); - auto input_ty = input.getType().cast(); - auto min_ty = min.getType().cast(); - auto max_ty = max.getType().cast(); + auto input_ty = mlir::cast(input.getType()); + auto min_ty = mlir::cast(min.getType()); + auto max_ty = mlir::cast(max.getType()); if (!input_ty.hasRank() || !min_ty.hasRank() || !max_ty.hasRank()) { return failure(); @@ -6215,8 +6206,9 @@ class ConvertConstOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::ConstOp op, PatternRewriter &rewriter) const override { // Convert only for valid HLO tensors. - auto ty = op.getType().dyn_cast(); - if (!ty || !ty.getElementType().isa()) + auto ty = mlir::dyn_cast(op.getType()); + if (!ty || + !mlir::isa(ty.getElementType())) return failure(); Location loc = op.getLoc(); @@ -6239,9 +6231,9 @@ class ConvertCumOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const override { - auto input = op.getX().template dyn_cast>(); + auto input = mlir::dyn_cast>(op.getX()); if (!input) return failure(); - auto input_type = input.getType().template dyn_cast(); + auto input_type = mlir::dyn_cast(input.getType()); if (!input_type || !input_type.hasStaticShape()) { return failure(); } @@ -6352,7 +6344,7 @@ class ConvertShapeOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Value input = op.getInput(); - auto result_ty = op.getResult().getType().dyn_cast(); + auto result_ty = mlir::dyn_cast(op.getResult().getType()); if (!result_ty) { return failure(); } @@ -6373,8 +6365,8 @@ class ConvertDynamicExpandDimsOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::ExpandDimsOp op, PatternRewriter &rewriter) const override { auto input = op.getInput(); - auto input_ty = input.getType().cast(); - auto result_ty = op.getType().cast(); + auto input_ty = mlir::cast(input.getType()); + auto result_ty = mlir::cast(op.getType()); if (!result_ty.hasRank() || !input_ty.hasRank() || result_ty.hasStaticShape()) { return failure(); @@ -6431,8 +6423,8 @@ class ConvertDynamicSqueezeOp : public OpRewritePattern { LogicalResult matchAndRewrite(TF::SqueezeOp op, PatternRewriter &rewriter) const override { auto input = op.getInput(); - auto input_ty = input.getType().cast(); - auto result_ty = op.getType().cast(); + auto input_ty = mlir::cast(input.getType()); + auto result_ty = mlir::cast(op.getType()); if (!result_ty.hasRank() || !input_ty.hasRank() || result_ty.hasStaticShape()) { return failure(); @@ -6492,24 +6484,23 @@ class ConvertXlaConvV2Op : public OpRewritePattern { return failure(); auto window_strides_named_attr = rewriter.getNamedAttr( - "window_strides", hlo::convertElementsAttr(window_strides_attr, - rewriter.getIntegerType(64)) - .cast()); + "window_strides", + mlir::cast(hlo::convertElementsAttr( + window_strides_attr, rewriter.getIntegerType(64)))); auto padding_named_attr = rewriter.getNamedAttr( - "padding", - hlo::convertElementsAttr(padding_attr, rewriter.getIntegerType(64)) - .cast()); + "padding", mlir::cast(hlo::convertElementsAttr( + padding_attr, rewriter.getIntegerType(64)))); auto lhs_dilation_named_attr = rewriter.getNamedAttr( "lhs_dilation", - hlo::convertElementsAttr(lhs_dilation_attr, rewriter.getIntegerType(64)) - .cast()); + mlir::cast(hlo::convertElementsAttr( + lhs_dilation_attr, rewriter.getIntegerType(64)))); auto rhs_dilation_named_attr = rewriter.getNamedAttr( "rhs_dilation", - hlo::convertElementsAttr(rhs_dilation_attr, rewriter.getIntegerType(64)) - .cast()); + mlir::cast(hlo::convertElementsAttr( + rhs_dilation_attr, rewriter.getIntegerType(64)))); int64_t feature_group_count_val = feature_group_count_attr.getValues()[0].getInt(); @@ -6566,12 +6557,12 @@ class ConvertXlaSelectAndScatterOp // Create the mhlo.SelectAndScatter op. auto select_and_scatter_op = rewriter.create( loc, result_types, op.getOperand(), op.getSource(), op.getInitValue(), - hlo::convertElementsAttr(window_dimensions, rewriter.getIntegerType(64)) - .cast(), - hlo::convertElementsAttr(window_strides, rewriter.getIntegerType(64)) - .cast(), - hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)) - .cast()); + mlir::cast(hlo::convertElementsAttr( + window_dimensions, rewriter.getIntegerType(64))), + mlir::cast(hlo::convertElementsAttr( + window_strides, rewriter.getIntegerType(64))), + mlir::cast( + hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); auto insert_call_to = [&](const mlir::SymbolRefAttr &func, Region *region) { auto func_op = cast(SymbolTable::lookupSymbolIn( @@ -6671,7 +6662,7 @@ class ConvertXlaVariadicReduceV2Op auto func_ty = func_op.getFunctionType(); SmallVector elementTypes{llvm::map_range( func_ty.getResults(), - [](Type ty) { return ty.cast().getElementType(); })}; + [](Type ty) { return mlir::cast(ty).getElementType(); })}; // Create the mhlo.reduce op. auto reduce_op = rewriter.create( @@ -6754,7 +6745,7 @@ class LowerYieldOp : public OpConversionPattern { // Returns a new tensor type from the given type with element type updated to // the given type. TensorType UpdateElementTypeTo(Type ty, Type element_ty) { - auto ranked_ty = ty.dyn_cast(); + auto ranked_ty = mlir::dyn_cast(ty); if (!ranked_ty) { return UnrankedTensorType::get(element_ty); } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc index 54bd5812644488..34df8fc9759a5c 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_collective.cc @@ -113,9 +113,8 @@ LogicalResult ConvertReplicaGroups(OpBuilder& builder, if (!matchPattern(group_assignment_value, m_Constant(&group_assignment))) { return op->emitOpError() << "expects constant group_assignment"; } - replica_groups = - hlo::convertElementsAttr(group_assignment, builder.getIntegerType(64)) - .cast(); + replica_groups = mlir::cast( + hlo::convertElementsAttr(group_assignment, builder.getIntegerType(64))); if (replica_groups.getType().getRank() != 2) { return op->emitOpError() << "group_assignment should have rank 2, got " << replica_groups.getType().getRank(); diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc index 3e8dd5b58ed2f1..68c412f79ff393 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_communication.cc @@ -458,7 +458,7 @@ SmallVector GetValueWithToken( return new_result; }; - auto tuple_type = value.getType().dyn_cast(); + auto tuple_type = mlir::dyn_cast(value.getType()); // `value` is not a tuple, create a new tuple. if (!tuple_type) return {create_tuple({value, token})}; @@ -499,7 +499,7 @@ SmallVector GetTypeWithToken(OpBuilder& builder, ArrayRef types, } auto type = types[0]; - if (auto tuple_type = type.dyn_cast()) { + if (auto tuple_type = mlir::dyn_cast(type)) { auto result_types = llvm::to_vector(tuple_type.getTypes()); result_types.push_back(token_type); return {builder.getTupleType(result_types)}; @@ -536,7 +536,7 @@ void ReplaceWithTupleResult(OpBuilder& builder, ValueRange values, auto value = values[0]; auto replacement = replacements[0]; - auto tuple_type = value.getType().dyn_cast(); + auto tuple_type = mlir::dyn_cast(value.getType()); if (!tuple_type) { if (!value.use_empty()) { auto new_element = builder.create(replacement.getLoc(), diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc index d5560f2481b00f..ce8b46708d2f52 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -75,13 +76,13 @@ namespace { // Returns true if the given type is a ranked tensor type with static or bounded // dimensions. bool IsBounded(Type ty) { - auto ranked_ty = ty.dyn_cast(); + auto ranked_ty = mlir::dyn_cast(ty); if (!ranked_ty) return false; if (ranked_ty.hasStaticShape()) return true; auto encoding = - ranked_ty.getEncoding().dyn_cast_or_null(); + mlir::dyn_cast_or_null(ranked_ty.getEncoding()); if (!encoding) return false; for (int i = 0; i < ranked_ty.getRank(); ++i) { @@ -96,10 +97,11 @@ bool IsBounded(Type ty) { bool HasSymbolRefAttr(Operation* op) { for (const auto& attr : op->getAttrs()) { Attribute attr_value = attr.getValue(); - if (attr_value.isa()) { + if (mlir::isa(attr_value)) { return true; - } else if (auto array_attr = attr_value.dyn_cast()) { - if (!array_attr.empty() && array_attr.begin()->isa()) { + } else if (auto array_attr = mlir::dyn_cast(attr_value)) { + if (!array_attr.empty() && + mlir::isa(*array_attr.begin())) { return true; } } @@ -146,8 +148,8 @@ class Tf2XlaRewritePattern : public ConversionPattern { }; bool ShouldRefineTypeTo(Type original_ty, Type updated_ty) { - auto updated = updated_ty.dyn_cast(); - auto original = original_ty.dyn_cast(); + auto updated = mlir::dyn_cast(updated_ty); + auto original = mlir::dyn_cast(original_ty); // Both types must be shaped types. if (!original || !updated) return false; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc index b17d474f85a652..927e59e0195bb2 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc @@ -232,13 +232,13 @@ LogicalResult Tf2XlaRewriter::PrepareParams() { // Returns true if the given type is a ranked tensor type with static or // bounded dimensions. bool IsBounded(Type ty) { - auto ranked_ty = ty.dyn_cast(); + auto ranked_ty = mlir::dyn_cast(ty); if (!ranked_ty) return false; if (ranked_ty.hasStaticShape()) return true; auto encoding = - ranked_ty.getEncoding().dyn_cast_or_null(); + mlir::dyn_cast_or_null(ranked_ty.getEncoding()); if (!encoding) return false; for (int i = 0; i < ranked_ty.getRank(); ++i) { @@ -253,10 +253,11 @@ bool IsBounded(Type ty) { bool HasSymbolRefAttr(Operation* op) { for (const auto& attr : op->getAttrs()) { Attribute attr_value = attr.getValue(); - if (attr_value.isa()) { + if (mlir::isa(attr_value)) { return true; - } else if (auto array_attr = attr_value.dyn_cast()) { - if (!array_attr.empty() && array_attr.begin()->isa()) { + } else if (auto array_attr = mlir::dyn_cast(attr_value)) { + if (!array_attr.empty() && + mlir::isa(*array_attr.begin())) { return true; } } @@ -305,7 +306,7 @@ LogicalResult Tf2XlaRewriter::PrepareKernelInputs( LogicalResult Tf2XlaRewriter::LegalizeOp() { for (Type ty : op_->getOperandTypes()) { - auto ranked_ty = ty.dyn_cast(); + auto ranked_ty = mlir::dyn_cast(ty); // Only bounded operands are supported in the XLA builders. if (!IsBounded(ranked_ty)) { return op_->emitRemark() diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc index 7938fc4684ce2b..a6435081820880 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -89,18 +90,18 @@ static void IncrementCounterFor(tensorflow::monitoring::Counter<1>* counter, } bool HasBounds(RankedTensorType type) { - auto encoding = - type.getEncoding().dyn_cast_or_null(); + auto encoding = mlir::dyn_cast_or_null( + type.getEncoding()); return (encoding && !encoding.getBounds().empty()); } bool HasStaticShapeOrBounded(Value val) { auto type = val.getType(); - if (type.isa()) { + if (mlir::isa(type)) { return false; } - if (type.isa()) { - auto ranked_tensor = type.dyn_cast(); + if (mlir::isa(type)) { + auto ranked_tensor = mlir::dyn_cast(type); if (ranked_tensor.hasStaticShape()) { return true; } diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose.cc b/tensorflow/compiler/mlir/tfr/passes/decompose.cc index 5d59d958d3e7c9..988dc9e612b9c3 100644 --- a/tensorflow/compiler/mlir/tfr/passes/decompose.cc +++ b/tensorflow/compiler/mlir/tfr/passes/decompose.cc @@ -84,8 +84,8 @@ namespace { // Quantize the float value based on given scale and zero point attributes. IntegerAttr Quantize(float value, Attribute scale_attr, Attribute zp_attr, OpBuilder builder) { - double scale = scale_attr.cast().getValueAsDouble(); - int64_t zp = zp_attr.cast().getInt(); + double scale = mlir::cast(scale_attr).getValueAsDouble(); + int64_t zp = mlir::cast(zp_attr).getInt(); int quantized = static_cast(std::round(value / scale) + zp); quantized = @@ -187,11 +187,12 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { // default value in the argument attribute. llvm::SmallVector new_operands; for (auto arg : llvm::enumerate(compose_func_type.getInputs())) { - if (auto tensor_type = arg.value().dyn_cast()) { + if (auto tensor_type = mlir::dyn_cast(arg.value())) { auto casted = builder.create(op->getLoc(), tensor_type, op->getOperand(arg.index())); new_operands.push_back(casted); - } else if (auto list_type = arg.value().dyn_cast()) { + } else if (auto list_type = + mlir::dyn_cast(arg.value())) { llvm::SmallVector variadic_operands; for (int i = arg.index(); i < op->getNumOperands(); i++) { auto casted = builder.create( @@ -211,8 +212,8 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { } if (!attribute && attr_name.getValue() == "out_type") { auto type = op->getResult(0).getType(); - if (type.isa()) { - type = type.cast().getElementType(); + if (mlir::isa(type)) { + type = mlir::cast(type).getElementType(); } attribute = TypeAttr::get(type); } @@ -220,8 +221,9 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { // Wrap these special attributes as a special TFR constant, so the SSA // value has a valid type to be used as TFR function argument. These // attributes are not expected to be manipulated by the lowering passes. - if (attribute.isa() || attribute.isa() || - attribute.isa() || attribute.isa()) { + if (mlir::isa(attribute) || mlir::isa(attribute) || + mlir::isa(attribute) || + mlir::isa(attribute)) { TFRAttrType output_type = TFRAttrType::get(builder.getContext()); attr_cst = builder.create(op->getLoc(), output_type, attribute); @@ -245,9 +247,10 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { // op result. llvm::SmallVector new_results; for (auto res : llvm::enumerate(compose_func_type.getResults())) { - if (res.value().dyn_cast()) { + if (mlir::dyn_cast(res.value())) { new_results.push_back(new_op.getResult(res.index())); - } else if (auto list_type = res.value().dyn_cast()) { + } else if (auto list_type = + mlir::dyn_cast(res.value())) { for (int i = res.index(), j = 0; i < op->getNumResults(); i++, j++) { auto index = builder.create( op->getLoc(), builder.getIndexAttr(j)); diff --git a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc index dd85565cfed88e..61aa404847ee07 100644 --- a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc +++ b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc @@ -136,7 +136,7 @@ class RewriteTFRCallOp : public OpRewritePattern { // by the frontend correctly. Value CastToNonDerivedType(PatternRewriter& rewriter, Location loc, CastOp cast_op, Type input_tfr_type) const { - auto tensor_type = input_tfr_type.dyn_cast(); + auto tensor_type = mlir::dyn_cast(input_tfr_type); if (!tensor_type) return cast_op.getArg(); auto attr_names = tensor_type.getAttrKeys(); @@ -150,7 +150,7 @@ class RewriteTFRCallOp : public OpRewritePattern { } Type original_input_type = - cast_op.getInputElementType().cast().getValue(); + mlir::cast(cast_op.getInputElementType()).getValue(); if (result_elt_type != original_input_type) { UnrankedTensorType result_type = UnrankedTensorType::get(result_elt_type); return rewriter.create(loc, result_type, cast_op.getArg()); @@ -166,10 +166,10 @@ class RewriteTFRCallOp : public OpRewritePattern { llvm::SmallVectorImpl& input_values) const { if (input_types.size() <= 1) return; - Type target_input_type = input_types[0].cast().getValue(); + Type target_input_type = mlir::cast(input_types[0]).getValue(); auto result_type = UnrankedTensorType::get(target_input_type); for (auto i = 1; i < input_types.size(); ++i) { - Type current_input_type = input_types[i].cast().getValue(); + Type current_input_type = mlir::cast(input_types[i]).getValue(); if (current_input_type != target_input_type) { input_values[i] = rewriter.create(loc, result_type, input_values[i]); @@ -189,7 +189,7 @@ LogicalResult RewriteTFRCallOp::AddDerivedAttrs( llvm::StringMap* derived_attrs) const { // If there is an attribute associated to the input in the signature, we // store it as an derived attribute. - if (auto tensor_type = input_tfr_type.dyn_cast()) { + if (auto tensor_type = mlir::dyn_cast(input_tfr_type)) { auto attr_names = tensor_type.getAttrKeys(); if (attr_names.empty()) return success(); @@ -201,7 +201,7 @@ LogicalResult RewriteTFRCallOp::AddDerivedAttrs( // If there is an attribute associated to the input in the signature, // we store it as an derived attribute. - if (auto list_type = input_tfr_type.dyn_cast()) { + if (auto list_type = mlir::dyn_cast(input_tfr_type)) { auto attr_names = list_type.getAttrKeys(); if (attr_names.empty()) return success(); @@ -314,7 +314,7 @@ Attribute RewriteTFRCallOp::ProcessAttributeValue(Attribute attr, if (!attr_type) return attr; if (attr_type.getValue() == "tensor") { - if (auto f = attr.dyn_cast()) { + if (auto f = mlir::dyn_cast(attr)) { RankedTensorType type = RankedTensorType::get({}, f.getType()); return DenseFPElementsAttr::get(type, attr); } @@ -332,13 +332,13 @@ LogicalResult RewriteTFRCallOp::DeriveOutputTypes( const llvm::StringMap& attrs, SmallVectorImpl* output_types) const { for (auto res : llvm::enumerate(signature.getResults())) { - if (auto tensor_type = res.value().dyn_cast()) { + if (auto tensor_type = mlir::dyn_cast(res.value())) { // tfr.tensor should only have one attribute attached. auto attr_key = tensor_type.getAttrKeys().front(); Builder builder(signature.getContext()); if (auto attr = attrs.lookup(attr_key.getValue())) { output_types->push_back( - UnrankedTensorType::get(attr.cast().getValue())); + UnrankedTensorType::get(mlir::cast(attr).getValue())); } else if (Type element_type = GetFixedElementType(attr_key.getValue(), builder)) { output_types->push_back(UnrankedTensorType::get(element_type)); @@ -350,16 +350,18 @@ LogicalResult RewriteTFRCallOp::DeriveOutputTypes( continue; } - if (auto list_type = res.value().dyn_cast()) { + if (auto list_type = mlir::dyn_cast(res.value())) { // There are two cases: N*T or list(dtype) auto attr_keys = list_type.getAttrKeys(); // N*T case if (attr_keys.size() == 2) { // The first one is N, and the second one is T int list_size = - attrs.lookup(attr_keys[0].getValue()).cast().getInt(); + mlir::cast(attrs.lookup(attr_keys[0].getValue())) + .getInt(); Type list_type = - attrs.lookup(attr_keys[1].getValue()).cast().getValue(); + mlir::cast(attrs.lookup(attr_keys[1].getValue())) + .getValue(); for (int i = 0; i < list_size; ++i) { output_types->push_back(UnrankedTensorType::get(list_type)); } @@ -398,11 +400,12 @@ LogicalResult RewriteTFRCallOp::CreateAndReplaceOp( SmallVector new_results; for (auto res : llvm::enumerate(call_op.getResultTypes())) { Type res_type = res.value(); - if (res_type.dyn_cast()) { + if (mlir::dyn_cast(res_type)) { Value new_res = new_op->getResult(res.index()); auto casted = rewriter.create(loc, res_type, new_res); new_results.push_back(casted.getOut()); - } else if (auto list_type = res.value().dyn_cast()) { + } else if (auto list_type = + mlir::dyn_cast(res.value())) { SmallVector tensor_list; for (int i = res.index(); i < new_op->getNumResults(); i++) { Value new_res = new_op->getResult(i); diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 2343fd84c07d32..3642cf67ce74a0 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -145,6 +145,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", @@ -166,6 +167,7 @@ cc_library( "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_async_opdefs", "//tensorflow/compiler/mlir/tfrt/ir:tfrt_fallback_opdefs", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@tf_runtime//:basic_kernels_opdefs", @@ -331,6 +333,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:errors", "@tf_runtime//:bef", "@tf_runtime//:core_runtime", @@ -426,6 +429,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@tf_runtime//:compiler_tfrt_op_interfaces", ], ) @@ -621,6 +625,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@tf_runtime//:core_runtime_opdefs", ], ) @@ -641,6 +646,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@tf_runtime//:basic_kernels_opdefs", "@tf_runtime//:core_runtime_opdefs", ], diff --git a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc index 5573e7c2d46866..28f582723c8b2f 100644 --- a/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc +++ b/tensorflow/compiler/mlir/tfrt/analysis/cost_analysis.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/container/flat_hash_map.h" +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/constants.h" #include "tensorflow/core/tfrt/fallback/cost_recorder.h" @@ -59,14 +60,14 @@ int64_t InferLookupTableFindV2Cost(const CostContext& context, constexpr int64_t kLookupTableFindCostScale = 8; constexpr int64_t kLookupTableFindStringKeyCostScale = 16; - auto value_type = op.getValues().getType().cast(); - auto key_type = op.getKeys().getType().cast(); + auto value_type = mlir::cast(op.getValues().getType()); + auto key_type = mlir::cast(op.getKeys().getType()); int64_t output_size = InferTensorSize(context, value_type); int64_t cost = kLookupTableFindCostScale * output_size; - if (key_type.getElementType().isa()) + if (mlir::isa(key_type.getElementType())) cost *= kLookupTableFindStringKeyCostScale; return cost; @@ -74,15 +75,15 @@ int64_t InferLookupTableFindV2Cost(const CostContext& context, // The cost function for tf.GatherV2. int64_t InferGatherV2Cost(const CostContext& context, mlir::TF::GatherV2Op op) { - return InferTensorSize(context, - op.getOutput().getType().cast()); + return InferTensorSize( + context, mlir::cast(op.getOutput().getType())); } // The cost function for tf.SparseSegmentSumOp. template int64_t InferSparseSegmentOpCost(const CostContext& context, OpType op) { return InferTensorSize( - context, op.getOutput().getType().template cast()); + context, mlir::cast(op.getOutput().getType())); } // CostFunctionRegistry is a map from op names to their cost functions. @@ -145,8 +146,8 @@ void CostAnalysis::AnalyzeArguments(mlir::func::FuncOp func_op) { // Use the max size among function inputs as the default size of dynamic // shaped tensors in the function. for (auto arg : func_op.getArguments()) { - if (!arg.getType().isa()) continue; - auto type = arg.getType().cast(); + if (!mlir::isa(arg.getType())) continue; + auto type = mlir::cast(arg.getType()); if (type.hasRank()) { max_arg_size_ = std::max(max_arg_size_, GetRankedTensorSize(type)); } @@ -204,7 +205,7 @@ void CostAnalysis::EvaluateCost(mlir::Operation* op) { // For other ops, use the sum of input sizes as its cost. int64_t cost = kDefaultCheapCost; for (auto operand : op->getOperands()) { - auto type = operand.getType().cast(); + auto type = mlir::cast(operand.getType()); if (type.hasRank()) { cost += GetRankedTensorSize(type); } else { diff --git a/tensorflow/compiler/mlir/tfrt/ir/BUILD b/tensorflow/compiler/mlir/tfrt/ir/BUILD index 68e9624e118453..cbdcc602ec3397 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/BUILD @@ -29,6 +29,7 @@ cc_library( ":tfrt_fallback_opdefs_inc_gen", "@llvm-project//mlir:IR", "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", ], ) @@ -98,6 +99,7 @@ cc_library( deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@tf_runtime//:basic_kernels_opdefs", ], ) diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD index bfc93b9252ccbf..5bc69ce7a70758 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD @@ -59,6 +59,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InliningUtils", + "@llvm-project//mlir:Support", ], ) @@ -167,6 +168,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:InliningUtils", "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@tf_runtime//:compiler_tfrt_op_interfaces", "@tf_runtime//:compiler_tfrt_traits", diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.cc b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.cc index 50d4cb1214250b..b4e337f328b27b 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/Region.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_ops.h" @@ -73,17 +74,17 @@ mlir::Type MlrtDialect::parseType(mlir::DialectAsmParser &parser) const { // Print a type registered to this dialect. void MlrtDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &os) const { - if (type.isa()) { + if (mlir::isa(type)) { os << "future"; return; } - if (type.isa()) { + if (mlir::isa(type)) { os << "promise"; return; } - if (type.isa()) { + if (mlir::isa(type)) { os << "async_handle"; return; } diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.cc b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.cc index fc4cb6a93a28ea..d6ddc8f96fd901 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/DialectImplementation.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.h" @@ -74,7 +75,7 @@ mlir::Type TensorflowMlrtDialect::parseType( // Print a type registered to this dialect. void TensorflowMlrtDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &os) const { - if (type.isa()) { + if (mlir::isa(type)) { os << "tensor"; return; } diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.cc b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.cc index 4bc8a6842bffe1..dd47e81ee1a6df 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.cc @@ -18,6 +18,7 @@ limitations under the License. #include "mlir/IR/DialectImplementation.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace tfrt { namespace fallback { @@ -47,12 +48,12 @@ Type FallbackDialect::parseType(DialectAsmParser &parser) const { /// Print a type registered to this dialect. void FallbackDialect::printType(Type type, DialectAsmPrinter &os) const { - if (type.isa()) { + if (mlir::isa(type)) { os << "tf_tensor"; return; } - if (type.isa()) { + if (mlir::isa(type)) { os << "tf_allocator"; return; } diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc index 3a835b3796962d..30f6aa234a2d59 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.cc @@ -18,6 +18,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace tfrt { namespace fallback_common { @@ -31,8 +32,8 @@ void GetExecuteOpAttrsCommon( mlir::Builder builder(context); for (auto iter : op_attr_array) { - auto key_value = iter.cast().getValue(); - llvm::StringRef key = key_value[0].cast().getValue(); + auto key_value = mlir::cast(iter).getValue(); + llvm::StringRef key = mlir::cast(key_value[0]).getValue(); mlir::Attribute value = key_value[1]; op_attrs->push_back({key, value}); } diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h index e78d247c038c64..0cddb1017a33d8 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_common.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tfrt/basic_kernels/opdefs/types.h" // from @tf_runtime namespace tfrt { @@ -30,9 +31,9 @@ template mlir::LogicalResult VerifyExecuteOpCommon(OpTy op) { auto op_attr_array = op.getOpAttrs().getValue(); for (auto op_attr : op_attr_array) { - auto key_value = op_attr.template dyn_cast(); + auto key_value = mlir::dyn_cast(op_attr); if (!key_value || key_value.getValue().size() != 2 || - !key_value.getValue()[0].template isa()) + !mlir::isa(key_value.getValue()[0])) return op.emitOpError() << "each op_attr should be a key-value pair, " "where the key is a string"; } @@ -47,10 +48,10 @@ mlir::LogicalResult VerifyFallbackExecuteOp(OpTy op) { // Verify function attributes. auto op_func_attr_array = op.getOpFuncAttrs().getValue(); for (auto op_attr : op_func_attr_array) { - auto key_value = op_attr.template dyn_cast(); + auto key_value = mlir::dyn_cast(op_attr); if (!key_value || key_value.getValue().size() != 2 || - !key_value.getValue()[0].template isa() || - !key_value.getValue()[1].template isa()) + !mlir::isa(key_value.getValue()[0]) || + !mlir::isa(key_value.getValue()[1])) return op.emitOpError() << "each op_func_attr should be a key-value " "pair, where both the key and the value are " "strings"; @@ -63,11 +64,11 @@ void PrintExecuteOpFuncAttribute(mlir::OpAsmPrinter &p, OpTy op) { auto op_func_attrs = op.getOpFuncAttrs(); if (!op_func_attrs.empty()) { auto print_key_value = [&](mlir::Attribute attr) { - auto key_value = attr.cast().getValue(); + auto key_value = mlir::cast(attr).getValue(); auto key = key_value[0]; auto value = key_value[1]; - p << key.cast().getValue(); + p << mlir::cast(key).getValue(); p << " = "; p << value; }; @@ -84,11 +85,11 @@ void PrintExecuteOpCommon(mlir::OpAsmPrinter &p, OpTy op) { auto op_attrs = op.getOpAttrs(); if (!op_attrs.empty()) { auto print_key_value = [&](mlir::Attribute attr) { - auto key_value = attr.cast().getValue(); + auto key_value = mlir::cast(attr).getValue(); auto key = key_value[0]; auto value = key_value[1]; - p << key.cast().getValue(); + p << mlir::cast(key).getValue(); p << " = "; p << value; }; diff --git a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc index e3acecf75e3073..93d50a012a6fed 100644 --- a/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc +++ b/tensorflow/compiler/mlir/tfrt/saved_model/saved_model.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -42,9 +43,9 @@ namespace { using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; llvm::StringRef ProcessIndexPath(mlir::ArrayAttr index_path) { - if (index_path.size() == 1 && index_path[0].isa()) { + if (index_path.size() == 1 && mlir::isa(index_path[0])) { // TODO(chky): Support cases where index_path is not a single string. - return index_path[0].cast().getValue(); + return mlir::cast(index_path[0]).getValue(); } return ""; } @@ -92,8 +93,8 @@ Status MapFunctionSignaturesFromTFSavedModelMLIR( if (auto input_index_path = func.getArgAttrOfType( i, kTfSavedModelIndexPathAttr)) { input_names.push_back(ProcessIndexPath(input_index_path)); - auto statusor_spec = - ProcessTensorSpec(func_type.getInput(i).cast()); + auto statusor_spec = ProcessTensorSpec( + mlir::cast(func_type.getInput(i))); if (!statusor_spec.ok()) { status = std::move(statusor_spec).status(); return mlir::WalkResult::interrupt(); @@ -120,8 +121,8 @@ Status MapFunctionSignaturesFromTFSavedModelMLIR( if (auto output_index_path = func.getResultAttrOfType( i, kTfSavedModelIndexPathAttr)) { output_names.push_back(ProcessIndexPath(output_index_path)); - auto statusor_spec = - ProcessTensorSpec(func_type.getResult(i).cast()); + auto statusor_spec = ProcessTensorSpec( + mlir::cast(func_type.getResult(i))); if (!statusor_spec.ok()) { status = std::move(statusor_spec).status(); return mlir::WalkResult::interrupt(); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.cc b/tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.cc index efedf36452dc12..a0ae8cb06e45df 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/attr_lowering_utils.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Types.h" +#include "mlir/Support/LLVM.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" @@ -37,39 +38,39 @@ mlir::TypeAttr ConvertTypeAttribute(mlir::TypeAttr type_attr, if (IsSupportedTfrtNumericDType(type)) return type_attr; // For TF custom types, we convert it to custom corert types. - if (type.isa()) + if (mlir::isa(type)) return mlir::TypeAttr::get( tfrt::corert::StringType::get(builder.getContext())); - if (type.isa()) + if (mlir::isa(type)) return mlir::TypeAttr::get( tfrt::corert::ResourceType::get(builder.getContext())); - if (type.isa()) + if (mlir::isa(type)) return mlir::TypeAttr::get( tfrt::corert::VariantType::get(builder.getContext())); - if (type.isa()) { + if (mlir::isa(type)) { return mlir::TypeAttr::get( tfrt::corert::Quint8Type::get(builder.getContext())); } - if (type.isa()) { + if (mlir::isa(type)) { return mlir::TypeAttr::get( tfrt::corert::Quint16Type::get(builder.getContext())); } - if (type.isa()) { + if (mlir::isa(type)) { return mlir::TypeAttr::get( tfrt::corert::Qint8Type::get(builder.getContext())); } - if (type.isa()) { + if (mlir::isa(type)) { return mlir::TypeAttr::get( tfrt::corert::Qint16Type::get(builder.getContext())); } - if (type.isa()) { + if (mlir::isa(type)) { return mlir::TypeAttr::get( tfrt::corert::Qint32Type::get(builder.getContext())); } @@ -86,14 +87,15 @@ mlir::Attribute ConvertAttribute(mlir::Attribute attr, mlir::Builder& builder) { // attributes are not supported yet. // Return directly if the attribute is already supported. - if (attr.isa()) + if (mlir::isa(attr)) return attr; // For type attributes, we convert non-standard MLIR types to corresponding // corert types. - if (auto type_attr = attr.dyn_cast()) { - if (auto shape_type = type_attr.getValue().dyn_cast()) { + if (auto type_attr = mlir::dyn_cast(attr)) { + if (auto shape_type = + mlir::dyn_cast(type_attr.getValue())) { if (!shape_type.hasRank()) return tfrt::corert::ShapeAttr::get(builder.getContext()); @@ -106,7 +108,7 @@ mlir::Attribute ConvertAttribute(mlir::Attribute attr, mlir::Builder& builder) { // Convert the attribute to the corresponding format in TFRT dialect if // needed. - if (auto shape_attr = attr.dyn_cast()) { + if (auto shape_attr = mlir::dyn_cast(attr)) { if (!shape_attr.hasRank()) return tfrt::corert::ShapeAttr::get(builder.getContext()); return tfrt::corert::ShapeAttr::get(builder.getContext(), @@ -114,7 +116,7 @@ mlir::Attribute ConvertAttribute(mlir::Attribute attr, mlir::Builder& builder) { } // For arrays, we recursively convert the elements. - if (auto array_attr = attr.dyn_cast()) { + if (auto array_attr = mlir::dyn_cast(attr)) { llvm::SmallVector attrs; attrs.reserve(array_attr.size()); for (auto attr : array_attr) { @@ -140,7 +142,7 @@ bool IsSupportedTfrtNumericDType(mlir::Type type) { type.isUnsignedInteger(64)) return true; - if (auto complex_type = type.dyn_cast()) { + if (auto complex_type = mlir::dyn_cast(type)) { auto element_type = complex_type.getElementType(); if (element_type.isF32() || element_type.isF64()) return true; } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc b/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc index 48d9f755c16c7b..910f7a83a9f7af 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/corert_converter.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Types.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -46,7 +47,7 @@ CoreRTConverter::CoreRTConverter( addConversion([](tfrt::corert::TensorHandleType type) { return type; }); addConversion([=](mlir::TensorType type) -> std::optional { // Ref types are not supported in both compiler and runtime. - if (type.getElementType().isa()) + if (mlir::isa(type.getElementType())) return std::nullopt; return tensor_handle_type(); }); @@ -74,8 +75,8 @@ mlir::ArrayAttr CoreRTConverter::CreateOpFuncAttrs( auto attr_key = key_and_value.getName(); auto attr_value = key_and_value.getValue(); if (!IsUnusedTfrtAttribute(attr_key) && - attr_value.isa()) { - auto func_attr = attr_value.dyn_cast(); + mlir::isa(attr_value)) { + auto func_attr = mlir::dyn_cast(attr_value); auto converted = CanonicalizeTensorflowFunctionName( symbol_table, func_attr.getValue().str(), use_mlir_func_name); if (!converted) return {}; @@ -126,7 +127,7 @@ std::optional CoreRTConverter::ParseDeviceName( } auto parsed_device_name = - ParseDeviceName(device_attr.cast().getValue()); + ParseDeviceName(mlir::cast(device_attr).getValue()); if (!parsed_device_name) op->emitWarning("failed to parse device name."); return parsed_device_name; } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/cross_device_transfer.cc b/tensorflow/compiler/mlir/tfrt/transforms/cross_device_transfer.cc index 2b1e29c5347096..5f539b8c520e65 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/cross_device_transfer.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/cross_device_transfer.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/core/util/device_name_utils.h" #include "tfrt/basic_kernels/opdefs/basic_kernels.h" // from @tf_runtime @@ -81,8 +82,8 @@ static std::string GetDevice(Operation *op) { SmallVector, 4> attrs; execute_op.getOpAttrs(&attrs); for (std::pair entry : attrs) { - if (entry.first == kDeviceAttr && entry.second.isa()) { - device = entry.second.cast().getValue().str(); + if (entry.first == kDeviceAttr && mlir::isa(entry.second)) { + device = mlir::cast(entry.second).getValue().str(); break; } } @@ -94,7 +95,7 @@ static std::string GetDevice(Operation *op) { // Return the device of the given value. static std::string GetDevice(mlir::Value value, func::FuncOp parent_func_op) { std::string device = ""; - if (BlockArgument block_arg = value.dyn_cast()) { + if (BlockArgument block_arg = mlir::dyn_cast(value)) { if (StringAttr device_attr = parent_func_op.getArgAttrOfType( block_arg.getArgNumber(), kTFRTDeviceAttr)) { device = device_attr.getValue().str(); @@ -140,10 +141,10 @@ void CrossDeviceTransferPass::runOnOperation() { for (mlir::Value arg : op->getOperands()) { // Do not transfer non-TensorHandle values. - if (!arg.getType().isa()) continue; + if (!mlir::isa(arg.getType())) continue; // Do not transfer the result of corert.transfer op. - if (OpResult op_result = arg.dyn_cast()) { + if (OpResult op_result = mlir::dyn_cast(arg)) { Operation *defining_op = arg.getDefiningOp(); if (llvm::isa(defining_op)) continue; } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.cc b/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.cc index ef8c2ec38ce64b..77759a631f177f 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/fallback_converter.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h" #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" @@ -31,7 +32,7 @@ FallbackConverter::FallbackConverter(mlir::MLIRContext *context) addConversion([](tfrt::fallback::TFTensorType type) { return type; }); addConversion([=](mlir::TensorType type) -> std::optional { // Ref types are not supported in both compiler and runtime. - if (type.getElementType().isa()) { + if (mlir::isa(type.getElementType())) { return std::nullopt; } @@ -46,9 +47,9 @@ FallbackConverter::FallbackConverter(mlir::MLIRContext *context) mlir::Value ConvertCoreRTTensorHandleToFallbackTensor( mlir::Location loc, llvm::StringRef device, mlir::Value value, mlir::ConversionPatternRewriter &rewriter) { - if (value.getType().isa()) return value; + if (mlir::isa(value.getType())) return value; - if (!value.getType().isa()) return {}; + if (!mlir::isa(value.getType())) return {}; mlir::OpBuilder::InsertionGuard guard(rewriter); @@ -82,9 +83,9 @@ mlir::Value ConvertCoreRTTensorHandleToFallbackTensor( mlir::Value ConvertFallbackTensorToCoreRTTensorHandle( mlir::Location loc, mlir::Value value, mlir::ConversionPatternRewriter &rewriter) { - if (value.getType().isa()) return value; + if (mlir::isa(value.getType())) return value; - if (!value.getType().isa()) return {}; + if (!mlir::isa(value.getType())) return {}; // Use CPU device by default if no device is specified. llvm::StringRef device = GetDefaultCpuDeviceName(); @@ -134,7 +135,7 @@ mlir::LogicalResult ConvertFallbackOperands( llvm::SmallVectorImpl *new_operands, mlir::ConversionPatternRewriter &rewriter) { for (auto operand : operands) { - if (!operand.getType().isa()) { + if (!mlir::isa(operand.getType())) { auto new_operand = ConvertCoreRTTensorHandleToFallbackTensor( op->getLoc(), device, operand, rewriter); if (!new_operand) diff --git a/tensorflow/compiler/mlir/tfrt/transforms/insert_tensor_copy.cc b/tensorflow/compiler/mlir/tfrt/transforms/insert_tensor_copy.cc index 2b2cbbcf318d15..d6c87abeedd54a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/insert_tensor_copy.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/insert_tensor_copy.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.h" #include "tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback_async.h" @@ -69,7 +70,7 @@ class InsertFallbackTensorCopy // Process function arguments first. for (auto arg : func_op.getArguments()) { - if (!arg.getType().isa()) continue; + if (!mlir::isa(arg.getType())) continue; InsertFallbackTensorCopyForValue(arg, func_op->getLoc(), builder, stream_analysis); } @@ -91,7 +92,7 @@ class InsertFallbackTensorCopy // Process each result value. for (auto result : op->getResults()) { - if (!result.getType().isa()) continue; + if (!mlir::isa(result.getType())) continue; InsertFallbackTensorCopyForValue(result, op->getLoc(), builder, stream_analysis); } @@ -147,7 +148,7 @@ class InsertFallbackTensorCopy // For each stream, we will create one new value that replaces the uses in // that stream. - assert(value.getType().isa()); + assert(mlir::isa(value.getType())); // The number of results is the number candidate streams. llvm::SmallVector result_types(copies.size(), diff --git a/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc b/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc index 17e3d8be95204d..01ae5811b46b9a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/lower_saved_model.cc @@ -44,6 +44,7 @@ limitations under the License. #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" @@ -231,7 +232,7 @@ void FindCalleesRecursiveForOp(const mlir::SymbolTable &symbol_table, llvm::StringSet<> &callees) { for (const auto &named_attr : op->getAttrs()) { if (auto symbol_attr = - named_attr.getValue().dyn_cast()) { + mlir::dyn_cast(named_attr.getValue())) { auto symbol = symbol_attr.getValue(); if (!callees.contains(symbol)) { callees.insert(symbol); @@ -337,7 +338,8 @@ class LowerTFSavedModelPass func_op->removeAttr(kTfSavedModelExportedNamesAttr); for (auto exported_name : exported_names) { auto exported_func_op = func_op.clone(); - exported_func_op.setName(exported_name.cast()); + exported_func_op.setName( + mlir::cast(exported_name)); // If it is a session initializer, we want to maximize parallelism // and do not perform any stream merge, to minimize latency. @@ -631,8 +633,8 @@ class ConvertReferenceVariableToResourceVariablePass mlir::LogicalResult ConvertReferenceVariableToResourceVariable( mlir::TF::VariableV2Op var_op) { - auto tensor_type = - mlir::TF::DropRefType(var_op.getRef().getType()).cast(); + auto tensor_type = mlir::cast( + mlir::TF::DropRefType(var_op.getRef().getType())); llvm::SmallVector identity_ops; llvm::SmallVector assign_ops; diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc index 5942fe2ddb816f..dc43f3bc213f6d 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/host_runtime/tfrt_ops.h.inc" @@ -380,11 +381,11 @@ class IfrtRestoreVariableOpConversion }; std::optional DecodeLongName(mlir::Location loc) { - if (auto name_loc = loc.dyn_cast()) { + if (auto name_loc = mlir::dyn_cast(loc)) { return name_loc.getName().str(); } - if (auto fused_loc = loc.dyn_cast()) { + if (auto fused_loc = mlir::dyn_cast(loc)) { std::string fused_name; for (auto l : fused_loc.getLocations()) { if (auto n = DecodeLongName(l)) { @@ -1027,7 +1028,7 @@ class TfToMlrtConversionPass type_converter_.addConversion( [=](mlir::TensorType type) -> std::optional { // Ref types are not supported in both compiler and runtime. - if (type.getElementType().isa()) + if (mlir::isa(type.getElementType())) return std::nullopt; return tf_mlrt::TFTensorType::get(context); }); @@ -1037,8 +1038,8 @@ class TfToMlrtConversionPass mlir::ValueRange inputs, mlir::Location loc) -> mlir::Value { if (inputs.size() != 1) return mlir::Value(); - if (inputs[0].getType().isa()) { - if (desired_type.isa()) { + if (mlir::isa(inputs[0].getType())) { + if (mlir::isa(desired_type)) { return builder.create(loc, desired_type, inputs[0]); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc b/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc index e13e8f36b1a436..0e47fad312c7cc 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/optimize.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfrt/transforms/passes.h" @@ -41,7 +42,8 @@ class FoldDeviceIndex : public mlir::OpRewritePattern { int32_t i = 0; mlir::ArrayAttr device_names = op.getDeviceNames(); for (; i < device_names.size(); ++i) { - auto device_name = device_names[i].cast().getValue(); + auto device_name = + mlir::cast(device_names[i]).getValue(); if (device_name == parsed_name.type) break; } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc b/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc index 848498c68ba71c..8bdb39c913bf75 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/sink_in_invariant_ops.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -82,7 +83,7 @@ llvm::SmallVector FindValueInCallees( llvm::SmallDenseSet callees; for (const auto &named_attr : caller->getAttrs()) { if (auto symbol_attr = - named_attr.getValue().dyn_cast()) { + mlir::dyn_cast(named_attr.getValue())) { auto symbol = symbol_attr.getValue(); auto callee = symbol_table.lookup(symbol); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc index f893ec7f2b3f33..48aee8f98f314d 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/tf_to_tfrt.cc @@ -41,6 +41,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -181,7 +182,7 @@ class GpuCompileAndExecuteOpConversion if (!xla_function) { return op->emitWarning("failed to find 'function' attribute"); } - auto func_attr = xla_function.dyn_cast(); + auto func_attr = mlir::dyn_cast(xla_function); if (!func_attr || func_attr.getValue().empty()) { return op->emitWarning("failed to find a non-empty 'function' attribute"); } @@ -512,7 +513,7 @@ class FallbackConstOpConversion mlir::ConversionPatternRewriter &rewriter) const override { // Some data types are handled separately using a fast path. if (IsSupportedTfrtNumericDType(op.getDtype()) || - op.getDtype().isa()) + mlir::isa(op.getDtype())) return failure(); // For other data types that do not have a fast path (eg. quantized types), @@ -757,7 +758,7 @@ class CoreRTConstDenseTensorOpConversion auto new_op = rewriter.create( op.getLoc(), corert_converter_.tensor_handle_type(), - op.getValue().cast()); + mlir::cast(op.getValue())); rewriter.replaceOp(op, new_op->getResult(0)); return success(); } @@ -870,10 +871,10 @@ class CoreRTConstStringTensorOpConversion LogicalResult matchAndRewrite( mlir::TF::ConstOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // NOLINT - if (!op.getDtype().isa()) return failure(); + if (!mlir::isa(op.getDtype())) return failure(); DenseStringElementsAttr attr = - op.getValue().cast(); + mlir::cast(op.getValue()); llvm::SmallVector values; values.reserve(attr.getNumElements()); @@ -1067,7 +1068,7 @@ class TFRTCaseOpConversion : public mlir::OpConversionPattern { mlir::Value index_operand = adaptor.getOperands()[0]; // TODO(b/182233401): Support TF tensor; remove the conversion op here. - if (index_operand.getType().isa()) { + if (mlir::isa(index_operand.getType())) { // TODO(b/182232457): Support other devices. index_operand = rewriter @@ -1079,7 +1080,7 @@ class TFRTCaseOpConversion : public mlir::OpConversionPattern { tfrt_compiler::GetDefaultCpuDeviceName()) .getResult(0); } - if (!index_operand.getType().isa()) + if (!mlir::isa(index_operand.getType())) return op.emitError( "branch index operand is expected to be a TensorHandle."); mlir::Value index_value = @@ -1101,7 +1102,7 @@ class TFRTCaseOpConversion : public mlir::OpConversionPattern { static mlir::Value GetPredicate(mlir::Operation *op, mlir::Value cond_operand, mlir::ConversionPatternRewriter &rewriter) { - if (!cond_operand.getType().isa()) { + if (!mlir::isa(cond_operand.getType())) { cond_operand = tfrt_compiler::ConvertCoreRTTensorHandleToFallbackTensor( op->getLoc(), tfrt_compiler::GetDefaultCpuDeviceName(), cond_operand, rewriter); @@ -1721,7 +1722,7 @@ class TfToTfrtConversionPass auto return_op = llvm::cast(block.getTerminator()); auto chain = return_op->getOperand(0); - assert(chain.getType().isa()); + assert(mlir::isa(chain.getType())); dangling_values.push_back(chain); mlir::OpBuilder builder(return_op); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/utils.cc b/tensorflow/compiler/mlir/tfrt/transforms/utils.cc index 9b602babeafe22..711438f21d13f9 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/utils.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/utils.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tfrt/basic_kernels/opdefs/tfrt_base.h" // from @tf_runtime @@ -34,7 +35,7 @@ limitations under the License. namespace tensorflow { bool IsResourceArgument(mlir::Value value) { - auto arg = value.dyn_cast(); + auto arg = mlir::dyn_cast(value); if (!arg) return false; auto func = llvm::cast(arg.getOwner()->getParentOp()); @@ -44,7 +45,7 @@ bool IsResourceArgument(mlir::Value value) { bool IsResultVariable(const mlir::Value &original_operand, const mlir::Value &operand) { - if (original_operand.isa()) { + if (mlir::isa(original_operand)) { auto defining_op = original_operand.getDefiningOp(); // TODO(b/174753886): When device assignment is properly done, we @@ -99,7 +100,8 @@ bool IsSessionInitializer(mlir::func::FuncOp op) { if (!session_initializer_op) return false; for (auto sym_ref : session_initializer_op.getInitializers()) { - if (op.getSymName() == sym_ref.cast().getValue()) + if (op.getSymName() == + mlir::cast(sym_ref).getValue()) return true; } diff --git a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc index 7917cf15391e39..f68f0a2bcc80cc 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/import_model.cc @@ -123,7 +123,7 @@ absl::StatusOr> ExportXlaFunctions( func_op->walk([&](mlir::Operation* op) { for (const mlir::NamedAttribute& attr : op->getAttrs()) { if (const auto sym = - attr.getValue().dyn_cast()) { + mlir::dyn_cast(attr.getValue())) { mlir::Operation* func = mlir::SymbolTable::lookupNearestSymbolFrom(op, sym); if (func) { diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD index bc887cdfc966f9..79caaccb155101 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/BUILD @@ -29,6 +29,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -43,6 +44,7 @@ tf_cc_test( "@com_google_googletest//:gtest_main", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:resource_loader", "@local_tsl//tsl/platform:status_matchers", "@local_tsl//tsl/platform:test", diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc index 98cb26acdba8fa..06606c6fff345e 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.cc @@ -29,6 +29,7 @@ limitations under the License. #include "llvm/ADT/TypeSwitch.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" namespace mlrt { @@ -37,8 +38,8 @@ namespace { // LINT.IfChange(mlrt_attributes) bool CanBeInlined(mlir::Attribute attr, absl::string_view data) { // FlatSymbolRefAttr is a special case as we are emitting it as integer. - return attr.isa() && + return mlir::isa( + attr) && data.size() <= sizeof(uint32_t); } // LINT.ThenChange(../../../../../core/tfrt/mlrt/interpreter/attribute_span.h:mlrt_attributes) @@ -64,7 +65,7 @@ std::optional EncodeListOfInteger(mlir::ArrayAttr array) { mlir::Type type; for (int i = 0; i < array.size(); ++i) { - if (auto integer_attr = array[i].dyn_cast()) { + if (auto integer_attr = mlir::dyn_cast(array[i])) { if (type && integer_attr.getType() != type) return std::nullopt; type = integer_attr.getType(); llvm::APInt value = integer_attr.getValue(); @@ -85,7 +86,7 @@ std::optional EncodeListOfSymbolRef( auto ctor = bc::New>(&allocator, array.size()); for (int i = 0; i < array.size(); ++i) { - if (auto symbol_ref = array[i].dyn_cast()) { + if (auto symbol_ref = mlir::dyn_cast(array[i])) { ctor.ConstructAt(i, module_context.GetFunctionId(symbol_ref.getValue())); } else { return std::nullopt; @@ -117,7 +118,7 @@ std::optional EncodeListOfString(mlir::ArrayAttr array) { auto ctor = bc::New>(&allocator, array.size()); for (int i = 0; i < array.size(); ++i) { - if (auto string_attr = array[i].dyn_cast()) { + if (auto string_attr = mlir::dyn_cast(array[i])) { ctor.ConstructAt(i, string_attr.getValue().str()); } else { return std::nullopt; diff --git a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc index 8214d1d6deb3b3..07f1fbfdb0c0c1 100644 --- a/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc +++ b/tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/tfrt/mlrt/bytecode/executable.h" #include "tensorflow/core/tfrt/mlrt/interpreter/attribute_span.h" #include "tsl/platform/resource_loader.h" @@ -299,13 +300,13 @@ class CustomDense { absl::StatusOr EncodeCustomDense(const ModuleEmitterContext&, mlir::Attribute attr) { - auto dense_int_attr = attr.dyn_cast(); + auto dense_int_attr = mlir::dyn_cast(attr); if (!dense_int_attr) return absl::InvalidArgumentError( "The element of the custom dense attribute must be an integer."); - if (dense_int_attr.getElementType().cast().getWidth() != - 32) { + if (mlir::cast(dense_int_attr.getElementType()) + .getWidth() != 32) { return absl::InvalidArgumentError( "The element of the custom dense attribute must be an i32 integer."); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index 42d679c35d0173..2862b79475a930 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -92,6 +92,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", "@local_tsl//tsl/protobuf:error_codes_proto_impl_cc", ], ) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc index a3b8c07cc1bb66..ee295c19335ff5 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/DialectImplementation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_status.cc.inc" // Generated dialect definitions. @@ -61,11 +62,11 @@ Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const { /// Print a type registered to this dialect. void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const { - if (type.isa()) { + if (mlir::isa(type)) { os << "op_kernel_context"; return; } - if (type.isa()) { + if (mlir::isa(type)) { os << "jit_callable"; return; } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index 277511fed098e0..e863d056b2fb04 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -52,6 +52,7 @@ limitations under the License. #include "mlir/Interfaces/DataLayoutInterfaces.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" // from @llvm-project #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" // from @llvm-project @@ -91,7 +92,7 @@ bool IsSmallAlloc(Value alloc) { constexpr unsigned kMaximumSizeInBytes = 64; constexpr unsigned kMaxRankOfAllocatedMemRef = 1; - auto type = alloc.getType().dyn_cast(); + auto type = mlir::dyn_cast(alloc.getType()); if (!type || !alloc.getDefiningOp()) return false; if (!type.hasStaticShape()) { // Check if the dynamic shape dimension of the alloc is produced by RankOp diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index d1c3af0b9a6191..489e13d172c059 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -39,6 +39,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMCommonConversion", "@llvm-project//mlir:LLVMDialect", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], @@ -75,6 +76,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc index 37999960cd69e7..45dbbf993bb6be 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/buffer_reuse_pass.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" @@ -115,7 +116,7 @@ class BufferReuseAnalysis { // Find reuse candidates for the regarded allocation. SmallVector local_reuse_candidates; for (BlockArgument old_buffer : arguments) { - if (!old_buffer.getType().isa()) continue; + if (!mlir::isa(old_buffer.getType())) continue; // Lifetime criterion: Only reuse buffers that are no longer used on // first reuse, i.e. they are no longer alive. @@ -177,15 +178,16 @@ class BufferReuseAnalysis { std::vector get_buffer_arguments(func::FuncOp &f) { std::vector buffer_arguments; for (BlockArgument arg : f.getArguments()) { - if (arg.getType().isa()) buffer_arguments.push_back(arg); + if (mlir::isa(arg.getType())) + buffer_arguments.push_back(arg); } return buffer_arguments; } bool can_reuse_locally(Operation *op, Value old_buffer, Value new_buffer) { // For now, we support only memrefs with the same memory layout. - auto old_buffer_ty = old_buffer.getType().dyn_cast(); - auto new_buffer_ty = old_buffer.getType().dyn_cast(); + auto old_buffer_ty = mlir::dyn_cast(old_buffer.getType()); + auto new_buffer_ty = mlir::dyn_cast(old_buffer.getType()); if (!old_buffer_ty || !new_buffer_ty || old_buffer_ty.getLayout() != new_buffer_ty.getLayout()) return false; @@ -205,7 +207,7 @@ class BufferReuseAnalysis { // Allow dropping dimensions but no permutations. int64_t i = -1; for (AffineExpr expr : map.getResults()) { - auto dim_expr = expr.dyn_cast(); + auto dim_expr = mlir::dyn_cast(expr); if (!dim_expr || dim_expr.getPosition() <= i) return false; i = dim_expr.getPosition(); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc index 32faed506e52b4..9f41b399e2fd7f 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/copy_cleanup_pass.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" namespace mlir { @@ -135,7 +136,7 @@ void RemoveCopyIfTargetIsFunctionArg(func::FuncOp func) { Block &body = func.getBody().front(); for (auto &op : llvm::reverse(body.without_terminator())) { if (auto copy = dyn_cast(op)) { - auto block_arg = copy.getTarget().dyn_cast(); + auto block_arg = mlir::dyn_cast(copy.getTarget()); if (!block_arg) break; if (!isa(block_arg.getOwner()->getParentOp()) || !block_arg.hasOneUse()) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc index b7ad2d4d28b129..a6f23f1ad43aa8 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" @@ -64,7 +65,7 @@ std::optional FindOpKernelContext(Operation *op) { return std::nullopt; } Value ctx = func.getArgument(0); - if (!ctx.getType().isa()) { + if (!mlir::isa(ctx.getType())) { return std::nullopt; } return ctx; @@ -114,7 +115,8 @@ struct DeallocOpConverter : public OpConversionPattern { if (!ctx) return failure(); // Operand with no layout is expected. - auto operand_memref_type = dealloc.getMemref().getType().cast(); + auto operand_memref_type = + mlir::cast(dealloc.getMemref().getType()); if (!operand_memref_type.getLayout().isIdentity()) { return failure(); } diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc index ed1138849e5a06..b5b22008dcb951 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/embed_tf_framework_pass.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/MemRef/IR/MemRef.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h" @@ -71,7 +72,7 @@ class EmbedTFFrameworkPass } FunctionType func_type = op.getFunctionType(); return func_type.getNumInputs() > 0 && - func_type.getInput(0).isa(); + mlir::isa(func_type.getInput(0)); }); target.addDynamicallyLegalOp(IsNotInsideTfEntryFunction); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc index a6b24b1a3afcc3..fa6ba2491d5906 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/same_shape_propagation.cc @@ -217,7 +217,7 @@ class ShapeEqualityKnowledge { } if (auto alloc = dyn_cast(op)) { SmallVector shape; - ShapedType type = alloc.getResult().getType().cast(); + ShapedType type = mlir::cast(alloc.getResult().getType()); fillShapeFromAllocLike(alloc.getDynamicSizes(), type, shape); registerAssociation(ShapeValue{shape}, alloc.getResult()); return; @@ -225,7 +225,7 @@ class ShapeEqualityKnowledge { if (auto alloc = dyn_cast(op)) { // Construct a symbol representing the allocated shape. SmallVector shape; - ShapedType type = alloc.getResult().getType().cast(); + ShapedType type = mlir::cast(alloc.getResult().getType()); fillShapeFromAllocLike(alloc.getDynSizes(), type, shape); registerAssociation(ShapeValue{shape}, alloc.getResult()); return; @@ -331,7 +331,7 @@ struct PropagateShapeKnowledgeToKernels // Position of the kernel argument we are currently at. int kernel_p = 0; for (auto operand : launch.getKernelOperands()) { - auto memref = operand.getType().dyn_cast(); + auto memref = mlir::dyn_cast(operand.getType()); if (!memref) { // Scalar argument, advance kernel position by one. kernel_p++; @@ -341,7 +341,7 @@ struct PropagateShapeKnowledgeToKernels if (!knowledge.haveSameShape(operand, previous.first)) { continue; } - auto previous_type = previous.first.getType().cast(); + auto previous_type = mlir::cast(previous.first.getType()); // We use the first equality found and replace uses of corresponding // size and (potentially) stride information here. auto args_to_replace = memref.getRank(); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc index 89ecd6da13be74..a7d26813239571 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tensorflow_abi_knowledge_propagation.cc @@ -56,7 +56,7 @@ struct PropagateTfAbiKnowledgeToKernelsPass // the inner stride is one. // TODO(herhut): Insert asserts in debug mode to check this. for (auto argument : function.getArguments()) { - if (argument.getType().isa()) { + if (mlir::isa(argument.getType())) { worklist.push_back(argument); allocated_by_tf_runtime.insert(argument); offset_is_zero.insert(argument); @@ -95,7 +95,7 @@ struct PropagateTfAbiKnowledgeToKernelsPass llvm::SmallDenseMap constants; auto loc = kernel.getLoc(); for (auto operand : launch.getKernelOperands()) { - auto memref = operand.getType().dyn_cast(); + auto memref = mlir::dyn_cast(operand.getType()); if (!memref) { // Scalar argument, advance kernel position by one. kernel_p++; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc index cffa5e7b44691e..8748b188f35dfa 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_framework_legalize_to_llvm.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/rewriters.h" @@ -96,14 +97,14 @@ class ConvertToLLVMCallOpPattern : public ConvertOpToLLVMPattern { Location loc, Type size_ty, Type element_ty, std::optional attr, ConversionPatternRewriter *rewriter) const { - assert(size_ty.isa() && "expect integer size type"); - assert(element_ty.isa() && "expect integer element type"); + assert(mlir::isa(size_ty) && "expect integer size type"); + assert(mlir::isa(element_ty) && "expect integer element type"); return ConvertArrayAttrToStackAllocatedArray( loc, size_ty, element_ty, attr, rewriter, [&](Attribute attr) { return rewriter->create( loc, element_ty, rewriter->getIntegerAttr(element_ty, - attr.cast().getInt())); + mlir::cast(attr).getInt())); }); } }; @@ -227,7 +228,7 @@ class TFDeallocOpConverter : public ConvertToLLVMCallOpPattern { TFDeallocOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // TODO(herhut) Support unranked memrefs. - if (!op.getMemref().getType().isa()) return failure(); + if (!mlir::isa(op.getMemref().getType())) return failure(); MemRefDescriptor memref(adaptor.getMemref()); Value allocated_bytes_ptr = memref.allocatedPtr(rewriter, op.getLoc()); @@ -429,7 +430,7 @@ class ReportErrorOpConverter std::string err_str; llvm::raw_string_ostream err_stream(err_str); err_stream << message; - if (!loc.isa()) { + if (!mlir::isa(loc)) { err_stream << " at "; loc.print(err_stream); } @@ -465,16 +466,18 @@ class NullMemRefOpConverter : public ConvertOpToLLVMPattern { MLIRContext *ctx = null_memref_op.getContext(); mlir::Operation *op = null_memref_op.getOperation(); - auto shaped_result_type = null_memref_op.getType().cast(); - auto mem_space = - shaped_result_type.getMemorySpace().dyn_cast_or_null(); + auto shaped_result_type = + mlir::cast(null_memref_op.getType()); + auto mem_space = mlir::dyn_cast_or_null( + shaped_result_type.getMemorySpace()); unsigned address_space = static_cast(mem_space ? mem_space.getInt() : 0); LLVM::LLVMPointerType llvm_ptr_type = LLVM::LLVMPointerType::get(ctx, address_space); Value zero = createIndexAttrConstant(rewriter, loc, getIndexType(), 0); - if (auto result_type = null_memref_op.getType().dyn_cast()) { + if (auto result_type = + mlir::dyn_cast(null_memref_op.getType())) { // Set all dynamic sizes to 1 and compute fake strides. SmallVector dyn_sizes( result_type.getNumDynamicDims(), @@ -497,7 +500,7 @@ class NullMemRefOpConverter : public ConvertOpToLLVMPattern { return success(); } - auto result_type = null_memref_op.getType().cast(); + auto result_type = mlir::cast(null_memref_op.getType()); Type llvm_result_type = type_converter.convertType(result_type); auto desc = @@ -506,7 +509,7 @@ class NullMemRefOpConverter : public ConvertOpToLLVMPattern { // Extract address space and element type. auto targetType = - null_memref_op.getResult().getType().cast(); + mlir::cast(null_memref_op.getResult().getType()); unsigned addressSpace = *getTypeConverter()->getMemRefAddressSpace(targetType); @@ -549,7 +552,7 @@ class IsValidMemRefOpConverter MemRefDescriptor desc(adaptor.getArg()); // Compare every size in the descriptor to 0 to check num_elements == 0. - int64_t rank = op.getArg().getType().cast().getRank(); + int64_t rank = mlir::cast(op.getArg().getType()).getRank(); Value is_empty_shape = rewriter.create( loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); Value zero = createIndexAttrConstant(rewriter, loc, getIndexType(), 0); diff --git a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc index 350d9e47545fb0..6523824611a603 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -90,7 +91,7 @@ struct ConvertUint8QConstOp : public RewritePattern { } mlir::DenseElementsAttr src_dense_attr = - tfl_qconst_op.getValue().cast(); + mlir::cast(tfl_qconst_op.getValue()); double type_range_min = static_cast(output_element_type.getStorageTypeMin() - diff --git a/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc b/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc index b64e4eda6d5e37..ba194e3e81c964 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/dequantize_tfl_softmax.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" @@ -52,8 +53,8 @@ LogicalResult TosaDequantizeTFLSoftmaxPattern::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { TFL::SoftmaxOp tfl_softmax_op = cast(op); RankedTensorType input_type = - tfl_softmax_op.getInput().getType().cast(); - if (!input_type.getElementType().isa()) { + mlir::cast(tfl_softmax_op.getInput().getType()); + if (!mlir::isa(input_type.getElementType())) { return failure(); } Location loc = tfl_softmax_op.getLoc(); diff --git a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc index efee9aa9e9b9c2..ff07b9d6f91039 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/fuse_bias_tf.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" @@ -80,7 +81,7 @@ LogicalResult ConvertTFBiasAddOp::matchAndRewrite( auto value = tf_biasadd_op.getValue(); auto bias = tf_biasadd_op.getBias(); - auto bias_shape = bias.getType().cast().getShape(); + auto bias_shape = mlir::cast(bias.getType()).getShape(); if (bias_shape.size() != 1) { return rewriter.notifyMatchFailure(op, "bias tensor must be rank 1"); } @@ -89,7 +90,8 @@ LogicalResult ConvertTFBiasAddOp::matchAndRewrite( llvm::dyn_cast_if_present(value.getDefiningOp())) { // Sanity check to confirm rhs() has the expected shape of bias auto filter_shape = - tf_conv2d_op.getFilter().getType().cast().getShape(); + mlir::cast(tf_conv2d_op.getFilter().getType()) + .getShape(); // Assume the filter shape is [H, W, I, O] if (filter_shape.back() != bias_shape.back()) { @@ -114,7 +116,8 @@ LogicalResult ConvertTFBiasAddOp::matchAndRewrite( llvm::dyn_cast_if_present(value.getDefiningOp())) { // Sanity check to confirm rhs() has the expected shape of bias auto filter_shape = - tf_conv3d_op.getFilter().getType().cast().getShape(); + mlir::cast(tf_conv3d_op.getFilter().getType()) + .getShape(); // Assume the filter shape is [D, H, W, I, O] if (filter_shape.back() != bias_shape.back()) { diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 3b461b8b36ae42..7b631ff418056e 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -56,6 +56,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h" @@ -571,7 +572,7 @@ std::optional convertZerosLikeOp(PatternRewriter& rewriter, Attribute zero_attr = rewriter.getZeroAttr(zero_type); return CreateOpAndInfer(rewriter, op->getLoc(), zero_type, - zero_attr.cast()) + mlir::cast(zero_attr)) .getResult(); } @@ -586,12 +587,12 @@ std::optional convertMultiplyOp(PatternRewriter& rewriter, Operation* op, // Not a shaped tensor output if (!input_lhs_type || !input_rhs_type || !output_type) return std::nullopt; - bool input_lhs_is_qtype = - input_lhs_type.getElementType().isa(); - bool input_rhs_is_qtype = - input_rhs_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + bool input_lhs_is_qtype = mlir::isa( + input_lhs_type.getElementType()); + bool input_rhs_is_qtype = mlir::isa( + input_rhs_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_lhs_is_qtype != output_is_qtype || input_rhs_is_qtype != output_is_qtype) { @@ -603,12 +604,12 @@ std::optional convertMultiplyOp(PatternRewriter& rewriter, Operation* op, if (output_is_qtype) { ShapedType rescale_type = output_type.clone(rewriter.getI32Type()); - auto input_lhs_qtype = input_lhs_type.getElementType() - .cast(); - auto input_rhs_qtype = input_rhs_type.getElementType() - .cast(); - auto output_qtype = - output_type.getElementType().cast(); + auto input_lhs_qtype = mlir::cast( + input_lhs_type.getElementType()); + auto input_rhs_qtype = mlir::cast( + input_rhs_type.getElementType()); + auto output_qtype = mlir::cast( + output_type.getElementType()); // MLIR store scale as double, but TFLite store scale as float // Downcasting from double to float to match TFLite behavior @@ -661,11 +662,11 @@ std::optional convertSquaredDifferenceOp(PatternRewriter& rewriter, } bool x_is_qtype = - x_type.getElementType().isa(); + mlir::isa(x_type.getElementType()); bool y_is_qtype = - y_type.getElementType().isa(); - bool result_is_qtype = - result_type.getElementType().isa(); + mlir::isa(y_type.getElementType()); + bool result_is_qtype = mlir::isa( + result_type.getElementType()); if (x_is_qtype != result_is_qtype || y_is_qtype != result_is_qtype) { (void)rewriter.notifyMatchFailure( @@ -678,11 +679,11 @@ std::optional convertSquaredDifferenceOp(PatternRewriter& rewriter, // Then scale back to I8 if (result_is_qtype) { auto x_qtype = - x_type.getElementType().cast(); + mlir::cast(x_type.getElementType()); auto y_qtype = - y_type.getElementType().cast(); - auto result_qtype = - result_type.getElementType().cast(); + mlir::cast(y_type.getElementType()); + auto result_qtype = mlir::cast( + result_type.getElementType()); uint32_t result_bits = result_qtype.getStorageTypeIntegralWidth(); @@ -779,16 +780,16 @@ std::optional convertConcatV2Op(PatternRewriter& rewriter, Operation* op, } mlir::quant::UniformQuantizedType result_quant_type = - result_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + result_type.getElementType()); SmallVector values_rescaled; for (auto v : values) { RankedTensorType operand_type = dyn_cast(v.getType()); mlir::quant::UniformQuantizedType operand_quant_type = - operand_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + operand_type.getElementType()); // tfl.concat currently allows different scales for each input tensor, which // TFlite team will fix in: @@ -818,7 +819,8 @@ std::optional convertConcatV2Op(PatternRewriter& rewriter, Operation* op, } } - int32_t tensor_rank = values[0].getType().cast().getRank(); + int32_t tensor_rank = + mlir::cast(values[0].getType()).getRank(); if (axis < 0) axis += tensor_rank; if ((axis < 0) || (axis > tensor_rank)) { @@ -1046,7 +1048,8 @@ std::optional convertSpaceToBatchNDOp(PatternRewriter& rewriter, // [padded_shape[M] / block_shape[M-1]] + // remaining_shape int32_t a2_reshape_a1_rank = - a2_reshape_a1_op.getResult().getType().cast().getRank(); + mlir::cast(a2_reshape_a1_op.getResult().getType()) + .getRank(); SmallVector a3_perm(a2_reshape_a1_rank); SmallVector a3_transpose_shape(a2_reshape_a1_rank); @@ -1579,17 +1582,19 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, int32_t input_rank = input_type.getShape().size(); ArrayRef logits_shape = output_type.getShape(); - if (input_type.getElementType().isa() && - output_type.getElementType().isa()) { + if (mlir::isa(input_type.getElementType()) && + mlir::isa(output_type.getElementType())) { SmallVector rsum_shape_v(input_type.getShape().begin(), input_type.getShape().end() - 1); rsum_shape_v.push_back(1); ArrayRef rsum_shape(rsum_shape_v); // The if condition already checks if these are UQTs mlir::quant::UniformQuantizedType in_quant_type = - input_type.getElementType().cast(); + mlir::cast( + input_type.getElementType()); mlir::quant::UniformQuantizedType out_quant_type = - output_type.getElementType().cast(); + mlir::cast( + output_type.getElementType()); auto int16_element_qtype = mlir::quant::UniformQuantizedType::get( true, rewriter.getIntegerType(16), rewriter.getF32Type(), 1.0f, 0, @@ -2005,11 +2010,11 @@ std::optional convertLogSoftmaxOp(PatternRewriter& rewriter, } mlir::quant::UniformQuantizedType in_quant_type = - input_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + input_type.getElementType()); mlir::quant::UniformQuantizedType out_quant_type = - output_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + output_type.getElementType()); if (in_quant_type || out_quant_type) { (void)rewriter.notifyMatchFailure( op, "quantized log_softmax lowering not implemented yet"); @@ -2271,7 +2276,8 @@ std::optional> convertSplitOp( tensorflow::ConvertMlirShapeToTF(new_shape))); } - RankedTensorType slice_type = slice_value.getType().cast(); + RankedTensorType slice_type = + mlir::cast(slice_value.getType()); assert((slice_type.getDimSize(axis) % num_split) == 0); // Each slice has a different beginning point. @@ -2442,7 +2448,7 @@ std::optional convertStridedSliceOp( // Limitations: // * This implementation only supports ellipsis_mask=0 for now auto input_type = dyn_cast(input_value.getType()); - ShapedType result_type = result_value.getType().cast(); + ShapedType result_type = mlir::cast(result_value.getType()); if (ellipsis_mask != 0) { (void)rewriter.notifyMatchFailure(op, "ellipses mask not supported yet"); @@ -2586,7 +2592,7 @@ std::optional convertStridedSliceOp( if (all_strides_one) { auto reversed = reverseNegativeStride(rewriter, op, a1_slice_op.getResult(), strides); - auto shape = reversed.getType().cast().getShape(); + auto shape = mlir::cast(reversed.getType()).getShape(); SmallVector new_shape; for (int i = 0; i < input_rank; ++i) { @@ -2684,7 +2690,7 @@ std::optional convertFloorDivOp(PatternRewriter& rewriter, Operation* op, Type element_type = output_type.getElementType(); - if (element_type.isa()) { + if (mlir::isa(element_type)) { return CreateOpAndInfer(rewriter, op->getLoc(), output_type, lhs_value, rhs_value) .getResult(); @@ -2738,14 +2744,14 @@ std::optional convertFusedActivation(PatternRewriter& rewriter, if (!input_type) return std::nullopt; bool input_is_qtype = - input_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); if (input_is_qtype) { // We can always make output/input tensor's scale/zp always be the same // when legalizing fused_activation_function, as it's generated during // legalization. - auto input_qtype = - input_type.getElementType().cast(); + auto input_qtype = mlir::cast( + input_type.getElementType()); if (fused_activation_fn.getValue() == "NONE") { return input_value; @@ -3079,9 +3085,9 @@ std::optional convertReduceProdOp(PatternRewriter& rewriter, if (!input_type) return std::nullopt; bool input_is_qtype = - input_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_is_qtype || output_is_qtype) { (void)rewriter.notifyMatchFailure( @@ -3105,9 +3111,9 @@ std::optional convertReduceSumOp(PatternRewriter& rewriter, if (!input_type) return std::nullopt; bool input_is_qtype = - input_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_is_qtype != output_is_qtype) { (void)rewriter.notifyMatchFailure( @@ -3123,10 +3129,10 @@ std::optional convertReduceSumOp(PatternRewriter& rewriter, Type reduce_element_type = input_type.getElementType(); if (input_is_qtype) { - auto input_qtype = - input_type.getElementType().cast(); - auto output_qtype = - output_type.getElementType().cast(); + auto input_qtype = mlir::cast( + input_type.getElementType()); + auto output_qtype = mlir::cast( + output_type.getElementType()); int32_t input_shift = 20; @@ -3164,9 +3170,9 @@ std::optional convertReduceMeanOp(PatternRewriter& rewriter, if (!input_type) return std::nullopt; bool input_is_qtype = - input_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_is_qtype != output_is_qtype) { (void)rewriter.notifyMatchFailure( @@ -3176,7 +3182,8 @@ std::optional convertReduceMeanOp(PatternRewriter& rewriter, } // Only supports float type mean() if it's non-quantized - if (!input_is_qtype && !output_type.getElementType().isa()) { + if (!input_is_qtype && + !mlir::isa(output_type.getElementType())) { op->emitWarning("input unquantized type but output element not FloatType"); return std::nullopt; } @@ -3206,10 +3213,10 @@ std::optional convertReduceMeanOp(PatternRewriter& rewriter, int32_t output_scale_shift = 0; if (input_is_qtype) { - auto input_qtype = - input_type.getElementType().cast(); - auto output_qtype = - output_type.getElementType().cast(); + auto input_qtype = mlir::cast( + input_type.getElementType()); + auto output_qtype = mlir::cast( + output_type.getElementType()); const int32_t scale_width = 32; computeMultiplierAndShift(1.0f, input_scale_multiplier, input_scale_shift, @@ -3275,9 +3282,9 @@ std::optional convertResizeOp(PatternRewriter& rewriter, Operation* op, } bool input_is_qtype = - input_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_is_qtype != output_is_qtype) { (void)rewriter.notifyMatchFailure( @@ -3287,7 +3294,7 @@ std::optional convertResizeOp(PatternRewriter& rewriter, Operation* op, } if (!input_is_qtype) { - if (!input_type.getElementType().isa()) { + if (!mlir::isa(input_type.getElementType())) { (void)rewriter.notifyMatchFailure( op, "only quantized or float types supported"); return std::nullopt; @@ -3406,8 +3413,8 @@ std::optional convertResizeOp(PatternRewriter& rewriter, Operation* op, // If quantized bilinear mode, need to lower to RESIZE + RESCALE pair. if (is_bilinear) { RankedTensorType output_acc_type; - auto input_element_qtype = - input_type.getElementType().cast(); + auto input_element_qtype = mlir::cast( + input_type.getElementType()); bool is_scale32; @@ -3505,7 +3512,7 @@ std::optional convertQuantizeOp(PatternRewriter& rewriter, Operation* op, auto output_element_type = output_type.getElementType(); // output element type could only be quantized integer - if (!output_element_type.isa()) { + if (!mlir::isa(output_element_type)) { (void)rewriter.notifyMatchFailure( op, "lowering quantizeOp but output element type not quantized"); return std::nullopt; @@ -3546,7 +3553,7 @@ std::optional convertDequantizeOp(PatternRewriter& rewriter, if (!input_type) return std::nullopt; // input element type could only be quantized integer - if (!input_type.getElementType().isa()) + if (!mlir::isa(input_type.getElementType())) return std::nullopt; std::optional zp_val; @@ -3839,8 +3846,8 @@ std::optional convertTFConv2DCommon( stride = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now - int64_t stride_h = strides_attr[1].cast().getInt(); - int64_t stride_w = strides_attr[2].cast().getInt(); + int64_t stride_h = mlir::cast(strides_attr[1]).getInt(); + int64_t stride_w = mlir::cast(strides_attr[2]).getInt(); stride = rewriter.getDenseI64ArrayAttr({stride_h, stride_w}); } } @@ -3849,8 +3856,8 @@ std::optional convertTFConv2DCommon( dilation = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now - int64_t dilation_h = dilations_attr[1].cast().getInt(); - int64_t dilation_w = dilations_attr[2].cast().getInt(); + int64_t dilation_h = mlir::cast(dilations_attr[1]).getInt(); + int64_t dilation_w = mlir::cast(dilations_attr[2]).getInt(); dilation = rewriter.getDenseI64ArrayAttr({dilation_h, dilation_w}); } } @@ -3915,8 +3922,8 @@ std::optional convertConv3DCommon(PatternRewriter& rewriter, DenseI64ArrayAttr strides_attr = rewriter.getDenseI64ArrayAttr(strides); DenseI64ArrayAttr dilations_attr = rewriter.getDenseI64ArrayAttr(dilations); - RankedTensorType input_type = input.getType().cast(); - RankedTensorType filter_type = filter.getType().cast(); + RankedTensorType input_type = mlir::cast(input.getType()); + RankedTensorType filter_type = mlir::cast(filter.getType()); DenseI64ArrayAttr pads_attr; if (!getPaddingValuesFromPadType(tf_pad, data_format_tf, 0, input_type, @@ -3963,9 +3970,9 @@ std::optional convertTFConv3DCommon( // Defaults to [1, 1, 1]. strides = {1, 1, 1}; } else { - int64_t stride_d = strides_attr[1].cast().getInt(); - int64_t stride_h = strides_attr[2].cast().getInt(); - int64_t stride_w = strides_attr[3].cast().getInt(); + int64_t stride_d = mlir::cast(strides_attr[1]).getInt(); + int64_t stride_h = mlir::cast(strides_attr[2]).getInt(); + int64_t stride_w = mlir::cast(strides_attr[3]).getInt(); strides = {stride_d, stride_h, stride_w}; } @@ -3974,9 +3981,9 @@ std::optional convertTFConv3DCommon( // Defaults to [1, 1, 1]. dilations = {1, 1, 1}; } else { - int64_t dilation_d = dilations_attr[1].cast().getInt(); - int64_t dilation_h = dilations_attr[2].cast().getInt(); - int64_t dilation_w = dilations_attr[3].cast().getInt(); + int64_t dilation_d = mlir::cast(dilations_attr[1]).getInt(); + int64_t dilation_h = mlir::cast(dilations_attr[2]).getInt(); + int64_t dilation_w = mlir::cast(dilations_attr[3]).getInt(); dilations = {dilation_d, dilation_h, dilation_w}; } @@ -4686,7 +4693,7 @@ std::optional convertSinOp(PatternRewriter& rewriter, Operation* op, std::optional convertSignOp(PatternRewriter& rewriter, Operation* op, Value input, RankedTensorType output_type) { auto output_elem_type = output_type.getElementType(); - if (output_elem_type.isa()) { + if (mlir::isa(output_elem_type)) { (void)rewriter.notifyMatchFailure(op, "tfl quantization not yet supported"); return std::nullopt; } @@ -4695,7 +4702,7 @@ std::optional convertSignOp(PatternRewriter& rewriter, Operation* op, // one element. Value pos_one, neg_one, zero; ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - if (output_elem_type.isa()) { + if (mlir::isa(output_elem_type)) { pos_one = getTosaConstTensorSingleF32(rewriter, op, 1.0f); neg_one = getTosaConstTensorSingleF32(rewriter, op, -1.0f); zero = getTosaConstTensorSingleF32(rewriter, op, 0.0f); @@ -4733,7 +4740,7 @@ std::optional convertBroadcastToOp(PatternRewriter& rewriter, } Type element_type = input_type.getElementType(); - if (element_type.isa()) { + if (mlir::isa(element_type)) { (void)rewriter.notifyMatchFailure(op, "input element type is complex"); return std::nullopt; } @@ -4816,7 +4823,7 @@ std::optional convertBroadcastToOp(PatternRewriter& rewriter, RankedTensorType output_type = tensorflow::GetTypeFromTFTensorShape(new_shape, element_type); - if (element_type.isa()) { + if (mlir::isa(element_type)) { // F32: legalize to broadcastable Add with (-0.f), instead of 0.f. // This is to preserve original values: // for corner case where x = -0.f diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc index 0ab48cc417fc98..10be9422c60762 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc @@ -257,7 +257,7 @@ LogicalResult ConvertTFSignOp::matchAndRewrite( auto tf_sign_op = cast(op); RankedTensorType output_type = - tf_sign_op.getResult().getType().cast(); + mlir::cast(tf_sign_op.getResult().getType()); std::optional result = convertSignOp(rewriter, op, tf_sign_op.getX(), output_type); @@ -270,7 +270,8 @@ LogicalResult ConvertTFSignOp::matchAndRewrite( LogicalResult ConvertTFSinOp::matchAndRewrite(Operation* op, PatternRewriter& rewriter) const { auto tf_sin_op = cast(op); - ShapedType output_type = tf_sin_op.getResult().getType().cast(); + ShapedType output_type = + mlir::cast(tf_sin_op.getResult().getType()); std::optional result = convertSinOp(rewriter, op, tf_sin_op.getX(), output_type); @@ -289,8 +290,8 @@ LogicalResult ConvertTFCosOp::matchAndRewrite(Operation* op, if (!input_ty || !output_ty) return failure(); - bool input_is_fp = input_ty.getElementType().isa(); - bool output_is_fp = output_ty.getElementType().isa(); + bool input_is_fp = mlir::isa(input_ty.getElementType()); + bool output_is_fp = mlir::isa(output_ty.getElementType()); if (!input_is_fp || !output_is_fp) { return rewriter.notifyMatchFailure( @@ -427,7 +428,7 @@ LogicalResult ConvertTFRoundOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "input not tensor type"); } - if (input_type.getElementType().isa()) { + if (mlir::isa(input_type.getElementType())) { std::optional result = convertRoundOp( rewriter, op, tf_round_op.getResult(), tf_round_op.getX()); @@ -519,7 +520,7 @@ LogicalResult ConvertTFRealDivOp::matchAndRewrite( Type element_type = output_type.getElementType(); - if (element_type.isa()) { + if (mlir::isa(element_type)) { CreateReplaceOpAndInfer(rewriter, op, output_type, tf_div_op.getX(), tf_div_op.getY()); return success(); @@ -717,7 +718,8 @@ LogicalResult ConvertTFMaxPoolOp::matchAndRewrite( LogicalResult ConvertTFConcatV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_concatv2_op = cast(op); - auto result_type = tf_concatv2_op.getResult().getType().cast(); + auto result_type = + mlir::cast(tf_concatv2_op.getResult().getType()); SmallVector values(tf_concatv2_op.getValues()); ElementsAttr axis_elems; @@ -877,7 +879,7 @@ LogicalResult ConvertTFFillOp::matchAndRewrite( DenseArrayAttr fill_attr; // Convert to a compatible zero type - if (value_elem.getShapedType().getElementType().isa()) { + if (mlir::isa(value_elem.getShapedType().getElementType())) { SmallVector fill_arr( total_size, value_elem.getValues()[0].getValue().convertToFloat()); @@ -891,7 +893,7 @@ LogicalResult ConvertTFFillOp::matchAndRewrite( DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(fill_arr)); } auto fill_const_op = CreateOpAndInfer( - rewriter, op->getLoc(), fill_type, fill_attr.cast()); + rewriter, op->getLoc(), fill_type, mlir::cast(fill_attr)); rewriter.replaceOp(op, {fill_const_op.getResult()}); return success(); @@ -911,8 +913,8 @@ LogicalResult ConvertTFConv2DOp::matchAndRewrite( RankedTensorType bias_type = tensorflow::GetTypeFromTFTensorShape( {bias_dim}, filter_type.getElementType()); auto bias_attr = rewriter.getZeroAttr(bias_type); - auto bias = CreateOpAndInfer(rewriter, op->getLoc(), bias_type, - bias_attr.cast()); + auto bias = CreateOpAndInfer( + rewriter, op->getLoc(), bias_type, mlir::cast(bias_attr)); std::optional result = convertTFConv2DCommon( rewriter, op, output_type, tf_conv2d_op.getInput(), @@ -946,8 +948,8 @@ LogicalResult ConvertTFConv3DOp::matchAndRewrite( RankedTensorType bias_type = RankedTensorType::get({bias_dim}, filter_type.getElementType()); auto bias_attr = rewriter.getZeroAttr(bias_type); - auto bias = CreateOpAndInfer(rewriter, op->getLoc(), bias_type, - bias_attr.cast()); + auto bias = CreateOpAndInfer( + rewriter, op->getLoc(), bias_type, mlir::cast(bias_attr)); std::optional result = convertTFConv3DCommon( rewriter, op, output_type, tf_conv3d_op.getInput(), @@ -1036,8 +1038,8 @@ LogicalResult ConvertTFDepthwiseConv2dNativeOp::matchAndRewrite( RankedTensorType bias_type = tensorflow::GetTypeFromTFTensorShape( {bias_dim}, filter_type.getElementType()); auto bias_attr = rewriter.getZeroAttr(bias_type); - auto bias = CreateOpAndInfer(rewriter, op->getLoc(), bias_type, - bias_attr.cast()); + auto bias = CreateOpAndInfer( + rewriter, op->getLoc(), bias_type, mlir::cast(bias_attr)); CreateReplaceOpAndInfer( rewriter, op, output_type, tf_dwconv2d_op.getInput(), diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index 024fdd88af43c1..39755b1add0f00 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -321,9 +321,9 @@ LogicalResult ConvertTFLReluOp::matchAndRewrite( if (!input_type || !output_type) return failure(); bool input_is_qtype = - input_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_is_qtype != output_is_qtype) { return rewriter.notifyMatchFailure( @@ -373,9 +373,9 @@ LogicalResult ConvertTFLRelu1Op::matchAndRewrite( if (!input_type || !output_type) return failure(); bool input_is_qtype = - input_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_is_qtype != output_is_qtype) { return rewriter.notifyMatchFailure( @@ -423,14 +423,15 @@ LogicalResult ConvertTFLRelu0To1Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_relu0to1_op = cast(op); - ShapedType input_type = tfl_relu0to1_op.getX().getType().cast(); + ShapedType input_type = + mlir::cast(tfl_relu0to1_op.getX().getType()); ShapedType output_type = - tfl_relu0to1_op.getResult().getType().cast(); + mlir::cast(tfl_relu0to1_op.getResult().getType()); bool input_is_qtype = - input_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_is_qtype != output_is_qtype) { return rewriter.notifyMatchFailure( @@ -444,9 +445,11 @@ LogicalResult ConvertTFLRelu0To1Op::matchAndRewrite( if (output_is_qtype && input_is_qtype) { UniformQuantizedType input_qtype = - input_type.getElementType().cast(); + mlir::cast( + input_type.getElementType()); UniformQuantizedType output_qtype = - output_type.getElementType().cast(); + mlir::cast( + output_type.getElementType()); clamp_min = output_qtype.getZeroPoint(); @@ -482,9 +485,9 @@ LogicalResult ConvertTFLRelu6Op::matchAndRewrite( if (!input_type || !output_type) return failure(); bool input_is_qtype = - input_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_is_qtype != output_is_qtype) { return rewriter.notifyMatchFailure( @@ -539,12 +542,12 @@ static LogicalResult prepareMatchAndRewriteComparison( // Not a shaped tensor output if (!input_x_type || !input_y_type || !output_type) return failure(); - bool input_x_is_qtype = - input_x_type.getElementType().isa(); - bool input_y_is_qtype = - input_y_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + bool input_x_is_qtype = mlir::isa( + input_x_type.getElementType()); + bool input_y_is_qtype = mlir::isa( + input_y_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_x_is_qtype != input_y_is_qtype || input_y_is_qtype != output_is_qtype) { @@ -671,20 +674,20 @@ static LogicalResult matchAndRewriteAddSub(Operation* op, auto tfl_add_op = cast(op); ShapedType input_lhs_type = - tfl_add_op.getLhs().getType().template dyn_cast(); + mlir::dyn_cast(tfl_add_op.getLhs().getType()); ShapedType input_rhs_type = - tfl_add_op.getRhs().getType().template dyn_cast(); + mlir::dyn_cast(tfl_add_op.getRhs().getType()); ShapedType output_type = - tfl_add_op.getResult().getType().template dyn_cast(); + mlir::dyn_cast(tfl_add_op.getResult().getType()); // Not a ranked tensor output if (!input_lhs_type || !input_rhs_type || !output_type) return failure(); - bool input_lhs_is_qtype = - input_lhs_type.getElementType().isa(); - bool input_rhs_is_qtype = - input_rhs_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + bool input_lhs_is_qtype = mlir::isa( + input_lhs_type.getElementType()); + bool input_rhs_is_qtype = mlir::isa( + input_rhs_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_lhs_is_qtype != output_is_qtype || input_rhs_is_qtype != output_is_qtype) { @@ -847,7 +850,7 @@ LogicalResult ConvertTFLSignOp::matchAndRewrite( auto tfl_sign_op = cast(op); RankedTensorType output_type = - tfl_sign_op.getResult().getType().cast(); + mlir::cast(tfl_sign_op.getResult().getType()); std::optional result = convertSignOp(rewriter, op, tfl_sign_op.getX(), output_type); @@ -932,7 +935,7 @@ LogicalResult ConvertTFLRoundOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "input not shaped tensor type"); } - if (input_type.getElementType().isa()) { + if (mlir::isa(input_type.getElementType())) { std::optional result = convertRoundOp( rewriter, op, tfl_round_op.getResult(), tfl_round_op.getX()); @@ -962,7 +965,7 @@ LogicalResult ConvertTFLDivOp::matchAndRewrite( Type element_type = output_type.getElementType(); Value div_op; - if (element_type.isa()) { + if (mlir::isa(element_type)) { div_op = CreateOpAndInfer(rewriter, op->getLoc(), output_type, tfl_div_op.getLhs(), tfl_div_op.getRhs()) @@ -1006,12 +1009,12 @@ LogicalResult ConvertTFLMaximumOp::matchAndRewrite( // Not a shaped tensor output if (!input_lhs_type || !input_rhs_type || !output_type) return failure(); - bool input_lhs_is_qtype = - input_lhs_type.getElementType().isa(); - bool input_rhs_is_qtype = - input_rhs_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + bool input_lhs_is_qtype = mlir::isa( + input_lhs_type.getElementType()); + bool input_rhs_is_qtype = mlir::isa( + input_rhs_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_lhs_is_qtype != output_is_qtype || input_rhs_is_qtype != output_is_qtype) { @@ -1062,12 +1065,12 @@ LogicalResult ConvertTFLMinimumOp::matchAndRewrite( // Not a shaped tensor output if (!input_lhs_type || !input_rhs_type || !output_type) return failure(); - bool input_lhs_is_qtype = - input_lhs_type.getElementType().isa(); - bool input_rhs_is_qtype = - input_rhs_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + bool input_lhs_is_qtype = mlir::isa( + input_lhs_type.getElementType()); + bool input_rhs_is_qtype = mlir::isa( + input_rhs_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_lhs_is_qtype != output_is_qtype || input_rhs_is_qtype != output_is_qtype) { @@ -1215,12 +1218,12 @@ LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite( // Tosa supports FP16 and FP32 accumulator type for FP16 input. When the time // FP16 is supported, the accumulator type can be selected based on trade-off // between performance and accuracy. Set to FP32 by default. - TypeAttr acc_attr = average_etype.isa() + TypeAttr acc_attr = mlir::isa(average_etype) ? mlir::TypeAttr::get(rewriter.getF32Type()) : mlir::TypeAttr::get(rewriter.getIntegerType(32)); Value result; - if (average_etype.isa()) { + if (mlir::isa(average_etype)) { // TensorFlow Lite doesn't use the zero point when calculating // quantized average pool, while TOSA does. Force the TOSA // zero_points to zero to ensure that the calculations match @@ -1445,11 +1448,11 @@ LogicalResult ConvertTFLConv2DOp::matchAndRewrite( if (!filter_type) return failure(); bool input_is_qtype = - input_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); bool filter_is_qtype = - filter_type.getElementType().isa(); + mlir::isa(filter_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(output_type.getElementType()); if ((input_is_qtype != filter_is_qtype) || (input_is_qtype != output_is_qtype)) { @@ -1499,7 +1502,7 @@ LogicalResult ConvertTFLConv2DOp::matchAndRewrite( output_is_qtype ? rewriter.getI32Type() : output_type.getElementType(); if (unquantized_bias) { Type new_bias_ety = getElementTypeOrSelf(unquantized_bias.getType()); - if (auto qtype = new_bias_ety.dyn_cast()) { + if (auto qtype = mlir::dyn_cast(new_bias_ety)) { new_bias_ety = qtype.getStorageType(); } if (new_bias_ety.getIntOrFloatBitWidth() > @@ -1555,11 +1558,11 @@ LogicalResult ConvertTFLConv3DOp::matchAndRewrite( } bool input_is_qtype = - input_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); bool filter_is_qtype = - filter_type.getElementType().isa(); + mlir::isa(filter_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(output_type.getElementType()); if ((input_is_qtype != filter_is_qtype) || (input_is_qtype != output_is_qtype)) { @@ -1578,7 +1581,7 @@ LogicalResult ConvertTFLConv3DOp::matchAndRewrite( RankedTensorType::get({bias_dim}, filter_type.getElementType()); auto bias_attr = rewriter.getZeroAttr(bias_type); unquantized_bias = CreateOpAndInfer( - rewriter, op->getLoc(), bias_type, bias_attr.cast()); + rewriter, op->getLoc(), bias_type, mlir::cast(bias_attr)); } SmallVector strides({tfl_conv3d_op.getStrideD(), @@ -1588,7 +1591,7 @@ LogicalResult ConvertTFLConv3DOp::matchAndRewrite( tfl_conv3d_op.getDilationHFactor(), tfl_conv3d_op.getDilationWFactor()}); Type bias_ety = - unquantized_bias.getType().cast().getElementType(); + mlir::cast(unquantized_bias.getType()).getElementType(); std::optional a1_conv3d_op = convertConv3DCommon( rewriter, op, output_type.clone(bias_ety), tfl_conv3d_op.getInput(), tfl_conv3d_op.getFilter(), unquantized_bias, strides, dilations, @@ -1634,11 +1637,11 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( if (!filter_type) return failure(); bool input_is_qtype = - input_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); bool filter_is_qtype = - filter_type.getElementType().isa(); + mlir::isa(filter_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(output_type.getElementType()); if ((input_is_qtype != filter_is_qtype) || (input_is_qtype != output_is_qtype)) { @@ -1721,7 +1724,7 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( } if (!zero_bias) return failure(); - Type bias_ety = zero_bias->getType().cast().getElementType(); + Type bias_ety = mlir::cast(zero_bias->getType()).getElementType(); auto a1_conv2d_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type.clone(bias_ety), @@ -1770,11 +1773,11 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( if (!filter_type) return failure(); bool input_is_qtype = - input_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); bool filter_is_qtype = - filter_type.getElementType().isa(); + mlir::isa(filter_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(output_type.getElementType()); if ((input_is_qtype != filter_is_qtype) || (input_is_qtype != output_is_qtype)) { @@ -1863,7 +1866,7 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( Value unquantized_bias = tfl_conv2d_op.getBias(); if (unquantized_bias) { Type new_bias_ety = getElementTypeOrSelf(unquantized_bias.getType()); - if (auto qtype = new_bias_ety.dyn_cast()) { + if (auto qtype = mlir::dyn_cast(new_bias_ety)) { new_bias_ety = qtype.getStorageType(); } if (new_bias_ety.getIntOrFloatBitWidth() > @@ -1906,7 +1909,7 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_mm_op = cast(op); - auto result_ty = tfl_mm_op.getType().cast(); + auto result_ty = mlir::cast(tfl_mm_op.getType()); Value lhs = tfl_mm_op.getX(); Value rhs = tfl_mm_op.getY(); RankedTensorType lhs_ty = dyn_cast(lhs.getType()); @@ -1916,10 +1919,12 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite( if (!lhs_ty || !rhs_ty) return failure(); - bool lhs_is_qtype = lhs_ty.getElementType().isa(); - bool rhs_is_qtype = rhs_ty.getElementType().isa(); + bool lhs_is_qtype = + mlir::isa(lhs_ty.getElementType()); + bool rhs_is_qtype = + mlir::isa(rhs_ty.getElementType()); bool result_is_qtype = - result_ty.getElementType().isa(); + mlir::isa(result_ty.getElementType()); if ((lhs_is_qtype != rhs_is_qtype) || (lhs_is_qtype != result_is_qtype)) { return rewriter.notifyMatchFailure( @@ -1951,8 +1956,8 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite( rewriter, op->getLoc(), UnrankedTensorType::get(rhs_ty.getElementType()), rhs, rewriter.getDenseI64ArrayAttr(new_rhs_shape)); - lhs_ty = lhs.getType().cast(); - rhs_ty = rhs.getType().cast(); + lhs_ty = mlir::cast(lhs.getType()); + rhs_ty = mlir::cast(rhs.getType()); } if (transpose_lhs) { @@ -1977,12 +1982,12 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite( Type output_ety; if (result_is_qtype) { - auto lhs_qty_width = lhs_ty.getElementType() - .cast() - .getStorageTypeIntegralWidth(); - auto rhs_qty_width = rhs_ty.getElementType() - .cast() - .getStorageTypeIntegralWidth(); + auto lhs_qty_width = + mlir::cast(lhs_ty.getElementType()) + .getStorageTypeIntegralWidth(); + auto rhs_qty_width = + mlir::cast(rhs_ty.getElementType()) + .getStorageTypeIntegralWidth(); if (lhs_qty_width != rhs_qty_width) { return rewriter.notifyMatchFailure( @@ -2007,7 +2012,7 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite( .getResult(); // Conditionally reshape rank back to expected rank. - auto matmul_ty = matmul.getType().cast(); + auto matmul_ty = mlir::cast(matmul.getType()); if (batch_dims.size() != 1) { llvm::SmallVector new_shape{}; for (auto d : batch_dims) { @@ -2052,11 +2057,11 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( if (!input_type || !filter_type) return failure(); bool input_is_qtype = - input_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); bool filter_is_qtype = - filter_type.getElementType().isa(); + mlir::isa(filter_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(output_type.getElementType()); if ((input_is_qtype != filter_is_qtype) || (input_is_qtype != output_is_qtype)) { @@ -2099,7 +2104,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( RankedTensorType new_bias_type; DenseElementsAttr bias_attr; - if (input_type.getElementType().isa()) { + if (mlir::isa(input_type.getElementType())) { SmallVector bias_arr(bias_shape[0]); for (int i = 0; i < bias_shape[0]; i++) { @@ -2120,7 +2125,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( op, "input must be quantized type if it's not float type"); } auto input_qtype = - input_type.getElementType().cast(); + mlir::cast(input_type.getElementType()); Type new_bias_ety = input_qtype.getStorageTypeIntegralWidth() == 16 ? rewriter.getIntegerType(48) : rewriter.getI32Type(); @@ -2136,7 +2141,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( bias_val = tfl_fc_op.getBias(); } - Type bias_ety = bias_val.getType().cast().getElementType(); + Type bias_ety = mlir::cast(bias_val.getType()).getElementType(); auto fc_op = CreateOpAndInfer( rewriter, op->getLoc(), UnrankedTensorType::get(bias_ety), input_val, @@ -2152,7 +2157,7 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( } // If we know the output rank, we need to ensure the output shape is correct. - ShapedType fc_type = fc_output.getType().cast(); + ShapedType fc_type = mlir::cast(fc_output.getType()); if (output_type.hasRank()) { llvm::SmallVector output_shape; @@ -2270,7 +2275,7 @@ LogicalResult ConvertTFLRankOp::matchAndRewrite( RankedTensorType::get({1}, rewriter.getIntegerType(32)); auto rank_attr = DenseI32ArrayAttr::get(rewriter.getContext(), {rank}); auto rank_const = CreateOpAndInfer( - rewriter, op->getLoc(), rank_type, rank_attr.cast()); + rewriter, op->getLoc(), rank_type, mlir::cast(rank_attr)); rewriter.replaceOp(op, {rank_const.getResult()}); @@ -2303,7 +2308,7 @@ LogicalResult ConvertTFLShapeOp::matchAndRewrite( auto shape_attr = DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(shape_arr)); auto shape_const = CreateOpAndInfer( - rewriter, op->getLoc(), shape_type, shape_attr.cast()); + rewriter, op->getLoc(), shape_type, mlir::cast(shape_attr)); rewriter.replaceOp(op, {shape_const.getResult()}); @@ -2376,7 +2381,7 @@ LogicalResult ConvertTFLFillOp::matchAndRewrite( DenseArrayAttr fill_attr; // Convert to a compatible zero type. - if (value_elem.getShapedType().getElementType().isa()) { + if (mlir::isa(value_elem.getShapedType().getElementType())) { SmallVector fill_arr( total_size, value_elem.getValues()[0].convertToFloat()); fill_attr = @@ -2388,7 +2393,7 @@ LogicalResult ConvertTFLFillOp::matchAndRewrite( DenseI32ArrayAttr::get(rewriter.getContext(), llvm::ArrayRef(fill_arr)); } auto fill_const_op = CreateOpAndInfer( - rewriter, op->getLoc(), fill_type, fill_attr.cast()); + rewriter, op->getLoc(), fill_type, mlir::cast(fill_attr)); rewriter.replaceOp(op, {fill_const_op.getResult()}); return success(); @@ -2589,11 +2594,11 @@ LogicalResult ConvertTFLRsqrtOp::matchAndRewrite( dyn_cast(tfl_rsqrt_op.getX().getType()); mlir::quant::UniformQuantizedType input_qtype = - input_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + input_type.getElementType()); mlir::quant::UniformQuantizedType output_qtype = - output_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + output_type.getElementType()); // Quantization case if (input_qtype && output_qtype) { @@ -2636,7 +2641,7 @@ LogicalResult ConvertTFLL2NormalizationOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_l2norm_op = cast(op); auto input = tfl_l2norm_op.getInput(); - auto input_ty = input.getType().cast(); + auto input_ty = mlir::cast(input.getType()); auto loc = op->getLoc(); if (!input_ty.hasRank()) return failure(); @@ -3200,15 +3205,15 @@ LogicalResult ConvertTFLHardSwishOp::matchAndRewrite( // TFL hardswish: f(x) -> (x * relu6(x+3))/6 - if (input_type.getElementType().isa() && - output_type.getElementType().isa()) { + if (mlir::isa(input_type.getElementType()) && + mlir::isa(output_type.getElementType())) { // Should match TFLite reference numerical behavior mlir::quant::UniformQuantizedType input_qtype = - input_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + input_type.getElementType()); mlir::quant::UniformQuantizedType output_qtype = - output_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + output_type.getElementType()); auto hardswish_func = [](double v) -> double { double w = v + 3.0; @@ -3286,8 +3291,8 @@ LogicalResult ConvertTFLCosOp::matchAndRewrite( if (!input_ty || !output_ty) return failure(); - bool input_is_fp = input_ty.getElementType().isa(); - bool output_is_fp = output_ty.getElementType().isa(); + bool input_is_fp = mlir::isa(input_ty.getElementType()); + bool output_is_fp = mlir::isa(output_ty.getElementType()); if (!input_is_fp || !output_is_fp) { return rewriter.notifyMatchFailure(op, "input/result must be fp"); @@ -3440,9 +3445,9 @@ LogicalResult ConvertTFLLogisticOp::matchAndRewrite( if (!input_type || !output_type) return failure(); bool input_is_qtype = - input_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_is_qtype != output_is_qtype) { return rewriter.notifyMatchFailure( @@ -3453,11 +3458,11 @@ LogicalResult ConvertTFLLogisticOp::matchAndRewrite( if (input_is_qtype) { ShapedType int32_type = output_type.clone(rewriter.getIntegerType(32)); mlir::quant::UniformQuantizedType input_qtype = - input_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + input_type.getElementType()); mlir::quant::UniformQuantizedType output_qtype = - output_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + output_type.getElementType()); auto sigmoid_func = [](double x) -> double { return 1.0 / (1.0 + std::exp(-x)); @@ -3511,9 +3516,9 @@ LogicalResult ConvertTFLTanhOp::matchAndRewrite( if (!input_type || !output_type) return failure(); bool input_is_qtype = - input_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_is_qtype != output_is_qtype) { return rewriter.notifyMatchFailure( @@ -3524,11 +3529,11 @@ LogicalResult ConvertTFLTanhOp::matchAndRewrite( if (input_is_qtype) { ShapedType int32_type = output_type.clone(rewriter.getIntegerType(32)); mlir::quant::UniformQuantizedType input_qtype = - input_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + input_type.getElementType()); mlir::quant::UniformQuantizedType output_qtype = - output_type.getElementType() - .dyn_cast_or_null(); + mlir::dyn_cast_or_null( + output_type.getElementType()); auto tanh_func = [](double x) -> double { x = std::exp(-2.0 * x); @@ -3644,9 +3649,9 @@ static LogicalResult LegalizeQuantizedPrelu(Operation* op, // Perform an element-wise multiplication on rescaled alpha and input for // PReLU. Value alpha = tfl_prelu_op.getAlpha(); - ShapedType alpha_type = alpha.getType().cast(); + ShapedType alpha_type = mlir::cast(alpha.getType()); UniformQuantizedType alpha_qtype = - alpha_type.getElementType().cast(); + mlir::cast(alpha_type.getElementType()); Value op_rescale_alpha = removeZeroPointAndCastToInt32( rewriter, op, alpha, alpha_qtype.getZeroPoint()); @@ -3698,7 +3703,7 @@ static LogicalResult LegalizeQuantizedLeakyRelu(Operation* op, PatternRewriter& rewriter, Value input, double alpha, ShapedType output_type) { - ShapedType input_type = input.getType().cast(); + ShapedType input_type = mlir::cast(input.getType()); ShapedType rescale_type = input_type.clone(rewriter.getI32Type()); UniformQuantizedType input_qtype = @@ -3784,9 +3789,9 @@ LogicalResult ConvertTFLLeakyReluOp::matchAndRewrite( "input or output is not a ShapedType"); bool input_is_qtype = - input_type.getElementType().isa(); - bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); + bool output_is_qtype = mlir::isa( + output_type.getElementType()); if (input_is_qtype != output_is_qtype) { return rewriter.notifyMatchFailure( @@ -3846,8 +3851,7 @@ LogicalResult ConvertTFLCustomOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, op->getResultTypes(), tfl_custom_op.getCustomCode(), rewriter.getStringAttr("TFL"), - tfl_custom_op.getCustomOption() - .cast() + mlir::cast(tfl_custom_op.getCustomOption()) .getValue() .str(), op->getOperands()); @@ -3966,7 +3970,7 @@ LogicalResult ConvertTFLDequantizeOp::matchAndRewrite( if (!qtype) return failure(); Type element_type = qtype.getElementType(); - if (element_type.isa()) { + if (mlir::isa(element_type)) { CreateReplaceOpAndInfer(rewriter, op, output_type, tfl_dequantize_op.getInput()); return success(); @@ -4023,7 +4027,7 @@ LogicalResult ConvertTFLConstOp::matchAndRewrite( ElementsAttr elements = tfl_const_op.getValue(); Type element_type = elements.getShapedType().getElementType(); - if (output_type.getElementType().isa()) { + if (mlir::isa(output_type.getElementType())) { output_type = RankedTensorType::get(output_type.getShape(), element_type); } @@ -4031,7 +4035,8 @@ LogicalResult ConvertTFLConstOp::matchAndRewrite( // attribute shape. This occurs as some TFLite folders create constants with // unranked shapes. if (!output_type.hasRank()) { - output_type = elements.getType().cast().clone(element_type); + output_type = + mlir::cast(elements.getType()).clone(element_type); } rewriter.replaceOpWithNewOp(op, output_type, elements); @@ -4053,8 +4058,8 @@ LogicalResult ConvertTFLQConstOp::matchAndRewrite( // attribute shape. This occurs as some TFLite folders create constants with // unranked shapes. if (!output_type.hasRank()) { - output_type = elements.getType().cast().clone( - output_type.getElementType()); + output_type = mlir::cast(elements.getType()) + .clone(output_type.getElementType()); } rewriter.replaceOpWithNewOp(op, output_type, elements); @@ -4079,7 +4084,7 @@ LogicalResult ConvertConstantOp::matchAndRewrite( // For data type like 64 bits, we need to truncate them into 48 bits. if (e_type.isInteger(64)) { e_type = rewriter.getIntegerType(48); - attr = attr.cast().mapValues( + attr = mlir::cast(attr).mapValues( e_type, [](const APInt& x) -> APInt { return x.trunc(48); }); } @@ -4136,11 +4141,11 @@ LogicalResult ConvertTFLSparseToDenseOp::matchAndRewrite( auto indices = tfl_sparse_to_dense_op.getSparseIndices(); auto values = tfl_sparse_to_dense_op.getSparseValues(); auto default_value = tfl_sparse_to_dense_op.getDefaultValue(); - auto indices_ty = indices.getType().cast(); + auto indices_ty = mlir::cast(indices.getType()); auto indices_ety = indices_ty.getElementType(); - auto values_ty = values.getType().cast(); + auto values_ty = mlir::cast(values.getType()); auto result_ty = - tfl_sparse_to_dense_op.getResult().getType().cast(); + mlir::cast(tfl_sparse_to_dense_op.getResult().getType()); auto result_ety = result_ty.getElementType(); auto loc = op->getLoc(); @@ -4262,7 +4267,7 @@ LogicalResult ConvertTFLArgMinOp::matchAndRewrite( auto arg_max_op = cast(op); auto loc = arg_max_op.getLoc(); auto input = arg_max_op.getInput(); - auto input_ty = input.getType().cast(); + auto input_ty = mlir::cast(input.getType()); Type input_ety = input_ty.getElementType(); if (auto quantized_ty = dyn_cast(input_ety)) { @@ -4281,9 +4286,9 @@ LogicalResult ConvertTFLArgMinOp::matchAndRewrite( int32_t dim = dim_elems.getValues()[0].getSExtValue(); if (dim < 0) dim += input_ty.getRank(); - if (input_ety.isa()) { + if (mlir::isa(input_ety)) { input = CreateOpAndInfer(rewriter, loc, input_ty, input); - } else if (input_ety.isa()) { + } else if (mlir::isa(input_ety)) { auto reverse_ty = RankedTensorType::get({}, input_ety); Value reverse_val = rewriter.create( loc, reverse_ty, @@ -4370,12 +4375,12 @@ LogicalResult ConvertTFLRealOp::matchAndRewrite( Type input_ety = input_ty.getElementType(); // For non-complex inputs, return the original tensor. - if (!input_ety.isa()) { + if (!mlir::isa(input_ety)) { CreateReplaceOpAndInfer(rewriter, op, input_ty, input); return success(); } - if (!input_ety.cast().getElementType().isF32()) { + if (!mlir::cast(input_ety).getElementType().isF32()) { return rewriter.notifyMatchFailure( op, "complex input must be of type complex64"); } @@ -4425,13 +4430,13 @@ LogicalResult ConvertTFLImagOp::matchAndRewrite( Type input_ety = input_ty.getElementType(); // For non-complex inputs return all zero's. - if (!input_ety.isa()) { + if (!mlir::isa(input_ety)) { CreateReplaceOpAndInfer( rewriter, op, input_ty, DenseElementsAttr::get(input_ty, {0.0f})); return success(); } - if (!input_ety.cast().getElementType().isF32()) { + if (!mlir::cast(input_ety).getElementType().isF32()) { return rewriter.notifyMatchFailure( op, "complex input must be of type complex64"); } diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index de8e777a7d558e..8571995d719484 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" @@ -78,7 +79,7 @@ std::optional buildReshapeWithDynamicDims(PatternRewriter& rewriter, Value input_value, ShapedType output_type, llvm::ArrayRef dims) { - auto e_ty = input_value.getType().cast().getElementType(); + auto e_ty = mlir::cast(input_value.getType()).getElementType(); llvm::SmallVector static_dims; if (output_type.hasRank()) { @@ -92,7 +93,7 @@ std::optional buildReshapeWithDynamicDims(PatternRewriter& rewriter, auto dim = dims[i]; SplatElementsAttr dim_attr; if (matchPattern(dim, m_Constant(&dim_attr))) { - if (dim_attr.getType().cast().getRank() != 0) { + if (mlir::cast(dim_attr.getType()).getRank() != 0) { (void)rewriter.notifyMatchFailure( op, "dim for building tosa::ReshapeOp should be rank-0"); return std::nullopt; @@ -643,8 +644,8 @@ DenseI64ArrayAttr getPaddingValuesFromExplicitPadAttr( for (int i = 0; i < 2; i++) { // Two spatial dimensions X&Y int64_t dim = GetTensorSpatialDimIndex(4, data_format_tf, i); // 4D tensor, NHWC/NCHW format - pad_before = explicit_pad[dim * 2].template cast().getInt(); - pad_after = explicit_pad[dim * 2 + 1].template cast().getInt(); + pad_before = mlir::cast(explicit_pad[dim * 2]).getInt(); + pad_after = mlir::cast(explicit_pad[dim * 2 + 1]).getInt(); computed_paddings.push_back(pad_before); computed_paddings.push_back(pad_after); } @@ -801,11 +802,11 @@ LogicalResult ApplyPatternsWithShapeResolution( // This should be investigate for whether it is still necessary due to quant // type stripping changing. func.walk([&](tosa::ConstOp op) { - if (op.getType().getElementType().isa()) { + if (mlir::isa(op.getType().getElementType())) { return; } auto ety = op.getValue().getShapedType().getElementType(); - auto new_ty = op.getType().cast().clone(ety); + auto new_ty = mlir::cast(op.getType()).clone(ety); op.getResult().setType(new_ty); }); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h index d2e04ac869ae48..acb9dff2a4a8ff 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h @@ -202,7 +202,7 @@ TosaOp CreateOpAndInfer(ImplicitLocOpBuilder& builder, Type result_ty, // Compute the knowledge based on the inferred type. auto inferredKnowledge = ValueKnowledge::getPessimisticValueState(); - inferredKnowledge.dtype = result_ty.cast().getElementType(); + inferredKnowledge.dtype = mlir::cast(result_ty).getElementType(); inferredKnowledge.hasRank = predictedShape.hasRank(); if (predictedShape.hasRank()) { for (auto dim : predictedShape.getDims()) { diff --git a/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc b/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc index 987ac5deb7479f..765cf33aa08812 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/lower_complex_types.cc @@ -42,6 +42,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project @@ -115,7 +116,7 @@ class GenericTypeConvert : public ConversionPattern { static bool isIllegalType(Type type) { if (auto shapedType = dyn_cast(type)) { - return shapedType.getElementType().isa(); + return mlir::isa(shapedType.getElementType()); } return false; } diff --git a/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc b/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc index 85df18855769fc..11857f3b1c3404 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/strip_quant_types.cc @@ -38,6 +38,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" @@ -121,7 +122,7 @@ class GenericTypeConvert : public ConversionPattern { }; static bool isIllegalType(Type type) { - if (type.isa()) return true; + if (mlir::isa(type)) return true; if (auto shapedType = dyn_cast(type)) { return isIllegalType(shapedType.getElementType()); } diff --git a/tensorflow/compiler/mlir/utils/BUILD b/tensorflow/compiler/mlir/utils/BUILD index e34e9cf7be7cca..2256c421b45717 100644 --- a/tensorflow/compiler/mlir/utils/BUILD +++ b/tensorflow/compiler/mlir/utils/BUILD @@ -16,6 +16,7 @@ cc_library( deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/utils/name_utils.cc b/tensorflow/compiler/mlir/utils/name_utils.cc index 6ca366fc9d64d5..7ce1c46861c2bb 100644 --- a/tensorflow/compiler/mlir/utils/name_utils.cc +++ b/tensorflow/compiler/mlir/utils/name_utils.cc @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project namespace mlir { @@ -63,7 +64,7 @@ std::string GetNameFromLoc(Location loc) { while (!locs.empty()) { Location curr_loc = locs.pop_back_val(); - if (auto name_loc = curr_loc.dyn_cast()) { + if (auto name_loc = mlir::dyn_cast(curr_loc)) { // Add name in NameLoc. For NameLoc we also account for names due to ops // in functions where the op's name is first. auto name = name_loc.getName().strref().split('@').first; @@ -73,11 +74,11 @@ std::string GetNameFromLoc(Location loc) { if (!name.empty()) names_is_nonempty = true; } continue; - } else if (auto call_loc = curr_loc.dyn_cast()) { + } else if (auto call_loc = mlir::dyn_cast(curr_loc)) { // Use location of the Callee to generate the name. locs.push_back(call_loc.getCallee()); continue; - } else if (auto fused_loc = curr_loc.dyn_cast()) { + } else if (auto fused_loc = mlir::dyn_cast(curr_loc)) { // Push all locations in FusedLoc in reverse order, so locations are // visited based on order in FusedLoc. auto reversed_fused_locs = llvm::reverse(fused_loc.getLocations()); diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index 5c19b9fe1014d3..39d4b086788ffe 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -116,8 +116,8 @@ constexpr llvm::StringRef kCustomCallShimTarget = } // namespace bool IsTokenType(mlir::Type type) { - return type.isa() || - type.isa(); + return mlir::isa(type) || + mlir::isa(type); } absl::StatusOr> @@ -174,7 +174,7 @@ absl::Status XlaCallModuleLoader::SetPlatformIndex( op_builder.setInsertionPointToStart(&main_body); mlir::BlockArgument platform_index_arg = main_body.getArgument(0); mlir::RankedTensorType arg_ranked_type = - platform_index_arg.getType().dyn_cast(); + mlir::dyn_cast(platform_index_arg.getType()); if (!arg_ranked_type || arg_ranked_type.getRank() != 0 || !(arg_ranked_type.getElementType().isSignlessInteger(32) || arg_ranked_type.getElementType().isSignlessInteger(64))) { @@ -301,7 +301,7 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes( << mlir::debugString(type) << " for argument type " << mlir::debugString(arg_type); mlir::TensorType arg_type = - main_body.getArgument(i).getType().dyn_cast(); + mlir::dyn_cast(main_body.getArgument(i).getType()); if (arg_type == nullptr) { return absl::InvalidArgumentError(absl::StrCat( "Argument ", i, " passed to XlaCallModule is not a tensor, ", @@ -316,7 +316,8 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes( mlir::debugString(arg_type), ", got ", mlir::debugString(type))); } - if (auto ranked_arg_type = arg_type.dyn_cast()) { + if (auto ranked_arg_type = + mlir::dyn_cast(arg_type)) { if (mlir::failed(mlir::verifyCompatibleShape(ranked_arg_type.getShape(), type.getShape()))) { return absl::InvalidArgumentError(absl::StrCat( @@ -380,9 +381,10 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes( if (IsTokenType(arg_type) || is_input_refined) { continue; } - auto ranked_arg_type = arg_type.dyn_cast(); + auto ranked_arg_type = mlir::dyn_cast(arg_type); if (!ranked_arg_type || !ranked_arg_type.hasStaticShape()) { - auto type = static_array_input_types[i].cast(); + auto type = + mlir::cast(static_array_input_types[i]); auto custom_call = MakeShapeRefinementOperandWrapper(op_builder, arg, type.getShape()); auto call_result = custom_call.getResult(0); @@ -409,8 +411,8 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes( // Clean up custom_call shims. for (auto call : llvm::make_early_inc_range( main_body.getOps())) { - if (call->getAttr("call_target_name").cast().strref() == - kCustomCallShimTarget) { + if (mlir::cast(call->getAttr("call_target_name")) + .strref() == kCustomCallShimTarget) { auto operand = call->getOperand(0); auto result = call->getResult(0); if (operand.getType() != result.getType()) { diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc index dbba8e7d8afd6a..529f27a0f7b25d 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h" @@ -403,7 +404,7 @@ class XlaCallModuleOp : public XlaOpKernel { mlir::TypeRange input_types(custom_call->getOperandTypes()); if (custom_call_has_token_input_output) { if (input_types.empty() || - !input_types.front().isa()) { + !mlir::isa(input_types.front())) { return absl::InvalidArgumentError(absl::StrCat( "stablehlo.custom_call with has_token_input_output = true is " "expected to take !stablehlo.token as the first argument, but " @@ -422,7 +423,7 @@ class XlaCallModuleOp : public XlaOpKernel { mlir::TypeRange result_types(custom_call->getResultTypes()); if (custom_call_has_token_input_output) { if (result_types.empty() || - !result_types.front().isa()) { + !mlir::isa(result_types.front())) { return absl::InvalidArgumentError(absl::StrCat( "stablehlo.custom_call with has_token_input_output = true is " "expected to return !stablehlo.token as the first result, but " diff --git a/tensorflow/core/ir/BUILD b/tensorflow/core/ir/BUILD index 5fc887ce213d76..9369b912467852 100644 --- a/tensorflow/core/ir/BUILD +++ b/tensorflow/core/ir/BUILD @@ -192,6 +192,7 @@ tf_cc_test( "//tensorflow/core:test_main", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", ], ) @@ -206,6 +207,7 @@ tf_cc_test( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/core/ir/importexport/BUILD b/tensorflow/core/ir/importexport/BUILD index 5694f228b32b7b..007fd3de65aa3b 100644 --- a/tensorflow/core/ir/importexport/BUILD +++ b/tensorflow/core/ir/importexport/BUILD @@ -109,6 +109,7 @@ cc_library( "//tensorflow/core/platform:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:ml_dtypes", ], ) diff --git a/tensorflow/core/ir/importexport/convert_attributes.cc b/tensorflow/core/ir/importexport/convert_attributes.cc index a591d77afe4977..6e1fd89c355bdd 100644 --- a/tensorflow/core/ir/importexport/convert_attributes.cc +++ b/tensorflow/core/ir/importexport/convert_attributes.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/status_macros.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/full_type.pb.h" @@ -55,11 +56,11 @@ namespace { // Converts a location to the debug information for the node def. Status ConvertLocation(Location inst_loc, NodeDef::ExperimentalDebugInfo* debug_info) { - if (auto call_site = inst_loc.dyn_cast()) { - if (auto name_loc = call_site.getCallee().dyn_cast()) { + if (auto call_site = mlir::dyn_cast(inst_loc)) { + if (auto name_loc = mlir::dyn_cast(call_site.getCallee())) { debug_info->add_original_node_names(name_loc.getName().data()); } - } else if (auto fused = inst_loc.dyn_cast()) { + } else if (auto fused = mlir::dyn_cast(inst_loc)) { auto locations = fused.getLocations(); if (locations.size() <= 1) return InvalidArgument("Expected experimental debug info."); @@ -107,7 +108,7 @@ Status ConvertAttribute(FlatSymbolRefAttr attr, AttrValue* value) { Status ConvertAttribute(FuncAttr attr, bool remove_ref_type, AttrValue* value) { TF_RETURN_IF_ERROR( - ConvertAttribute(attr.getName().cast(), value)); + ConvertAttribute(mlir::cast(attr.getName()), value)); TF_RETURN_IF_ERROR(ConvertAttributes(attr.getAttrs().getValue(), /*attrs_to_ignore=*/{}, remove_ref_type, value->mutable_func()->mutable_attr())); @@ -141,13 +142,13 @@ Status ConvertAttribute(const ArrayAttr& attr, bool remove_ref_type, AttrValue* value) { auto* list = value->mutable_list(); for (Attribute a : attr.getValue()) { - if (auto attr = a.dyn_cast()) { + if (auto attr = mlir::dyn_cast(a)) { list->add_b(attr.getValue()); - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { list->add_i(attr.getInt()); - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { list->add_f(attr.getValueAsDouble()); - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { AttrValue nested_value; TF_RETURN_IF_ERROR(ConvertAttribute(attr, &nested_value)); switch (nested_value.value_case()) { @@ -163,29 +164,29 @@ Status ConvertAttribute(const ArrayAttr& attr, bool remove_ref_type, default: return Unimplemented("Unhandled nested attribute!"); } - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { TensorProto tensor; TF_RETURN_IF_ERROR(ConvertToTensorProto(attr, &tensor)); *list->add_tensor() = tensor; - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { AttrValue attr_val; TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val)); *list->add_func() = attr_val.func(); - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { AttrValue attr_val; TF_RETURN_IF_ERROR(ConvertAttribute(attr, remove_ref_type, &attr_val)); *list->add_func() = attr_val.func(); - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { AttrValue attr_val; // For type attributes, we only propagate the element type. Type elt_type = attr.getValue(); - if (auto shaped_type = elt_type.dyn_cast()) { + if (auto shaped_type = mlir::dyn_cast(elt_type)) { elt_type = shaped_type.getElementType(); } TF_RETURN_IF_ERROR( ConvertAttribute(elt_type, remove_ref_type, &attr_val)); list->add_type(attr_val.type()); - } else if (auto attr = a.dyn_cast()) { + } else if (auto attr = mlir::dyn_cast(a)) { AttrValue attr_val; TF_RETURN_IF_ERROR(ConvertAttribute(attr, &attr_val)); *list->add_shape() = attr_val.shape(); @@ -200,17 +201,17 @@ Status ConvertAttribute(const ArrayAttr& attr, bool remove_ref_type, absl::StatusOr ConvertAttribute(Attribute attr) { AttrValue value; - if (auto symbol_ref = attr.dyn_cast()) { + if (auto symbol_ref = mlir::dyn_cast(attr)) { TF_RETURN_IF_ERROR( - ConvertAttribute(symbol_ref.cast(), &value)); + ConvertAttribute(mlir::cast(symbol_ref), &value)); return value; } - if (auto func_attr = attr.dyn_cast()) { + if (auto func_attr = mlir::dyn_cast(attr)) { TF_RETURN_IF_ERROR( ConvertAttribute(func_attr, /*remove_ref_type=*/false, &value)); return value; } - if (attr.isa()) + if (mlir::isa(attr)) return Unimplemented("AffineMap attribute unimplemented"); TF_RETURN_IF_ERROR( llvm::TypeSwitch(attr) @@ -251,11 +252,11 @@ Status ConvertAttributes(ArrayRef attrs, name = mangling_util::DemangleAttributeName(name); } TF_ASSIGN_OR_RETURN(AttrValue value, ConvertAttribute(attr)); - if (attr.isa()) { + if (mlir::isa(attr)) { func_call_attrs[std::string(name)] = value; continue; } - if (attr.isa()) { + if (mlir::isa(attr)) { func_call_attrs[std::string(name)] = value; continue; } @@ -479,7 +480,8 @@ Status ConvertHandleData(ArrayAttr handle_data_arr, tensorflow::OpDef::ArgDef* arg) { if (!handle_data_arr) return {}; for (auto handle_data_attr : handle_data_arr.getAsRange()) { - TensorType handle_type = handle_data_attr.getValue().dyn_cast(); + TensorType handle_type = + mlir::dyn_cast(handle_data_attr.getValue()); if (!handle_type) { return InvalidArgument("Expected an array of tensor types, but got ", debugString(handle_data_arr)); diff --git a/tensorflow/core/ir/importexport/convert_tensor.cc b/tensorflow/core/ir/importexport/convert_tensor.cc index e3b801dc52860c..752bb8b44ba5a7 100644 --- a/tensorflow/core/ir/importexport/convert_tensor.cc +++ b/tensorflow/core/ir/importexport/convert_tensor.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" @@ -248,12 +249,12 @@ void ConvertToTensorShapeProto(ArrayRef shape, } PartialTensorShape ConvertTypeToTensorShape(const Type& type) { - if (type.isa()) { + if (mlir::isa(type)) { // An empty PartialTensorShape indicates an unranked tensor. return PartialTensorShape(); } - if (auto tensor_type = type.dyn_cast()) { + if (auto tensor_type = mlir::dyn_cast(type)) { TensorShapeProto tensor_shape_proto; ConvertToTensorShapeProto(ConvertMlirShapeToTF(tensor_type.getShape()), &tensor_shape_proto); @@ -266,11 +267,11 @@ PartialTensorShape ConvertTypeToTensorShape(const Type& type) { } ShapeAttr ConvertTypeToTensorShapeAttr(const Type& type) { - if (type.isa()) { + if (mlir::isa(type)) { return ShapeAttr::get(type.getContext(), std::nullopt); } - if (auto tensor_type = type.dyn_cast()) { + if (auto tensor_type = mlir::dyn_cast(type)) { return ShapeAttr::get( type.getContext(), llvm::ArrayRef(ConvertMlirShapeToTF(tensor_type.getShape()))); @@ -439,10 +440,10 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { output->set_dtype(output_dtype); ConvertToTensorShapeProto(shape, output->mutable_tensor_shape()); - if (auto tensor_attr = attr.dyn_cast()) + if (auto tensor_attr = mlir::dyn_cast(attr)) return ConvertTensorProtoAttr(tensor_attr, output); - auto dense_attr = attr.dyn_cast(); + auto dense_attr = mlir::dyn_cast(attr); if (!dense_attr) return InvalidArgument("Unsupported elements attr"); switch (output_dtype) { @@ -508,7 +509,7 @@ Status ConvertToTensorProto(const ElementsAttr attr, TensorProto* output) { output->mutable_tensor_content()); break; case tensorflow::DT_STRING: - ConvertStringElementsAttr(dense_attr.cast(), + ConvertStringElementsAttr(mlir::cast(dense_attr), output->mutable_string_val()); break; case tensorflow::DT_UINT8: diff --git a/tensorflow/core/ir/importexport/convert_types.cc b/tensorflow/core/ir/importexport/convert_types.cc index 26e21205460512..f733f1025ae3e1 100644 --- a/tensorflow/core/ir/importexport/convert_types.cc +++ b/tensorflow/core/ir/importexport/convert_types.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/ir/dialect.h" @@ -127,7 +128,7 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { } else if (type.isFloat8E5M2()) { *dtype = ::tensorflow::DT_FLOAT8_E5M2; return ::tensorflow::OkStatus(); - } else if (auto itype = type.dyn_cast()) { + } else if (auto itype = mlir::dyn_cast(type)) { switch (itype.getWidth()) { case 1: *dtype = tensorflow::DT_BOOL; @@ -156,7 +157,7 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { return Unimplemented( absl::StrCat("Converting ", debugString(type), " to DataType")); } - } else if (auto complex_type = type.dyn_cast()) { + } else if (auto complex_type = mlir::dyn_cast(type)) { auto etype = complex_type.getElementType(); if (etype.isF32()) { *dtype = tensorflow::DT_COMPLEX64; @@ -182,7 +183,7 @@ Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { } Status ConvertToDataType(Type type, DataType* dtype) { - if (auto stype = type.dyn_cast()) { + if (auto stype = mlir::dyn_cast(type)) { TF_RETURN_IF_ERROR( ConvertScalarTypeToDataType(stype.getElementType(), dtype)); } else { diff --git a/tensorflow/core/ir/importexport/functiondef_export.cc b/tensorflow/core/ir/importexport/functiondef_export.cc index a274aac1b8ab27..6aa9f815e52885 100644 --- a/tensorflow/core/ir/importexport/functiondef_export.cc +++ b/tensorflow/core/ir/importexport/functiondef_export.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" @@ -51,9 +52,9 @@ namespace tfg { static absl::StatusOr GetValueName(Value operand, Type control_ty) { bool is_control = (operand.getType() == control_ty); - OpResult op_result = operand.dyn_cast(); + OpResult op_result = mlir::dyn_cast(operand); if (!op_result) { - BlockArgument block_operand = operand.dyn_cast(); + BlockArgument block_operand = mlir::dyn_cast(operand); int arg_num = block_operand.getArgNumber(); // Function arguments are coming as pair: the even are the actual tensors @@ -174,7 +175,8 @@ absl::StatusOr ConvertGenericFunctionToFunctionDef( for (NamedAttribute attr : attrs) { OpDef_AttrDef *func_attr = signature->add_attr(); func_attr->set_name(attr.getName().str()); - DictionaryAttr dict_attr = attr.getValue().dyn_cast(); + DictionaryAttr dict_attr = + mlir::dyn_cast(attr.getValue()); if (!dict_attr) return InvalidArgument("Expects dict attribute"); if (StringAttr type = dict_attr.getAs("function_type")) func_attr->set_type(type.getValue().str()); @@ -198,7 +200,7 @@ absl::StatusOr ConvertGenericFunctionToFunctionDef( if (auto control_outputs = func_op->getAttrOfType("control_output")) { for (Attribute attr : control_outputs) { - StringAttr output = attr.dyn_cast(); + StringAttr output = mlir::dyn_cast(attr); if (!output) return InvalidArgument( "Can't export function with non-string \"control_output\" " @@ -216,7 +218,7 @@ absl::StatusOr ConvertGenericFunctionToFunctionDef( if (arg_num >= args_attr.size()) return InvalidArgument("Can't export function ", func_op.getName().str(), " because missing attributes for arg #", arg_num); - DictionaryAttr arg_attrs = args_attr[arg_num].cast(); + DictionaryAttr arg_attrs = mlir::cast(args_attr[arg_num]); FunctionDef::ArgAttrs func_def_arg_attrs; TF_RETURN_WITH_CONTEXT_IF_ERROR( ExportArgDef(arg, arg_attrs, &func_def_arg_attrs), @@ -242,7 +244,7 @@ absl::StatusOr ConvertGenericFunctionToFunctionDef( return InvalidArgument("Can't export function ", func_op.getName().str(), " because missing attributes for result #", res_num); - auto res_attrs = results_attr[res_num].cast(); + auto res_attrs = mlir::cast(results_attr[res_num]); auto name = res_attrs.getAs("tfg.name"); if (!name) return InvalidArgument( diff --git a/tensorflow/core/ir/importexport/functiondef_import.cc b/tensorflow/core/ir/importexport/functiondef_import.cc index 74bc67fcf077d5..7246cd35378eb3 100644 --- a/tensorflow/core/ir/importexport/functiondef_import.cc +++ b/tensorflow/core/ir/importexport/functiondef_import.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" @@ -504,7 +505,7 @@ Status ImportGenericFunction( } TF_ASSIGN_OR_RETURN(Value result, value_manager.GetValueOrCreatePlaceholder( (Twine("^") + ret_val.second).str())); - if (!result.getType().isa()) + if (!mlir::isa(result.getType())) return InvalidArgument("failed to map returned value ", ret_val.second, ", isn't a control output"); ret_vals[func.ret_size() + position->second] = result; diff --git a/tensorflow/core/ir/importexport/graphdef_export.cc b/tensorflow/core/ir/importexport/graphdef_export.cc index bc0df8d0160628..54da202ddfdf2f 100644 --- a/tensorflow/core/ir/importexport/graphdef_export.cc +++ b/tensorflow/core/ir/importexport/graphdef_export.cc @@ -290,12 +290,12 @@ absl::StatusOr> GraphDefExporter::ExportFunction( // Convert the arguments. for (int i = 0, e = func.getNumArguments(); i < e; i += 2) { - auto attrs = func.getArgAttrs().value()[i].cast(); + auto attrs = mlir::cast(func.getArgAttrs().value()[i]); TF_ASSIGN_OR_RETURN(OpDef::ArgDef &arg = *signature->add_input_arg(), ConvertArgumentAttributes(attrs)); DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType( - func.getArgument(i).getType().cast().getElementType(), + mlir::cast(func.getArgument(i).getType()).getElementType(), &dtype)); arg.set_type(dtype); // Convert the attributes. @@ -317,7 +317,7 @@ absl::StatusOr> GraphDefExporter::ExportFunction( ConvertArgumentAttributes(std::get<1>(it))); DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType( - std::get<0>(it).cast().getElementType(), &dtype)); + mlir::cast(std::get<0>(it)).getElementType(), &dtype)); arg.set_type(dtype); // Map the result. TF_ASSIGN_OR_RETURN((*def->mutable_ret())[arg.name()], @@ -333,7 +333,7 @@ absl::StatusOr> GraphDefExporter::ExportFunction( if (attrs.empty()) return InvalidArgument("Control result is missing 'tfg.name'"); assert(attrs.begin()->getName() == dialect_->getTfgNameAttrIdentifier()); - std::string name = attrs.begin()->getValue().cast().str(); + std::string name = mlir::cast(attrs.begin()->getValue()).str(); signature->add_control_output(name); // Map the control result. TF_ASSIGN_OR_RETURN(std::string value_name, @@ -383,12 +383,13 @@ static void ExtractExperimentalDebugInfoFromLocation( debug_info->add_original_node_names(node.str()); if (!func.empty()) debug_info->add_original_func_names(func.str()); }; - if (auto fused = inst_loc.dyn_cast()) { + if (auto fused = mlir::dyn_cast(inst_loc)) { for (Location loc : fused.getLocations()) - if (auto name_loc = loc.dyn_cast()) add_name_loc(name_loc); + if (auto name_loc = mlir::dyn_cast(loc)) + add_name_loc(name_loc); return; } - if (auto name_loc = inst_loc.dyn_cast()) + if (auto name_loc = mlir::dyn_cast(inst_loc)) add_name_loc(name_loc); } @@ -437,7 +438,7 @@ Status ConvertToNodeDef( } // Export the location as debug info. - if (!op->getLoc().isa()) { + if (!mlir::isa(op->getLoc())) { ExtractExperimentalDebugInfoFromLocation( op->getLoc(), node->mutable_experimental_debug_info()); if (node->experimental_debug_info().original_node_names().empty()) @@ -464,15 +465,14 @@ static absl::StatusOr GetValueName( std::string name; bool is_control = value.getType() == dialect->getControlType(); - if (auto arg = value.dyn_cast()) { + if (auto arg = mlir::dyn_cast(value)) { auto func = dyn_cast(arg.getOwner()->getParentOp()); if (!func) return InvalidArgument("Expected block argument owner to be tfg.func"); // If the block argument is a control token, use the attributes of the // associated data argument (which preceeds it). - auto attrs = func.getArgAttrs() - .value()[arg.getArgNumber() - is_control] - .cast(); + auto attrs = mlir::cast( + func.getArgAttrs().value()[arg.getArgNumber() - is_control]); auto name_attr = attrs.getAs(dialect->getTfgNameAttrIdentifier()); if (!name_attr) { @@ -486,7 +486,7 @@ static absl::StatusOr GetValueName( return name; } - auto result = value.cast(); + auto result = mlir::cast(value); auto name_attr = result.getOwner()->getAttrOfType( dialect->getNameAttrIdentifier()); if (!name_attr) @@ -535,12 +535,12 @@ absl::StatusOr GraphDefExporter::GetEdgeName(Value value, static absl::StatusOr GetOutputSegmentSize( Operation *op, const OpDef::ArgDef &arg) { if (!arg.type_list_attr().empty()) { - if (auto v = op->getAttr(arg.type_list_attr()).dyn_cast()) + if (auto v = mlir::dyn_cast(op->getAttr(arg.type_list_attr()))) return v.size(); return InvalidArgument("Type attr not found: ", arg.type_list_attr()); } if (arg.number_attr().empty()) return 1; - if (auto v = op->getAttr(arg.number_attr()).dyn_cast()) + if (auto v = mlir::dyn_cast(op->getAttr(arg.number_attr()))) return v.getValue().getZExtValue(); return InvalidArgument("Type attr not found: ", arg.number_attr()); } diff --git a/tensorflow/core/ir/importexport/graphdef_import.cc b/tensorflow/core/ir/importexport/graphdef_import.cc index da2ebcf2d9f0c3..9e8f244cd9c9f3 100644 --- a/tensorflow/core/ir/importexport/graphdef_import.cc +++ b/tensorflow/core/ir/importexport/graphdef_import.cc @@ -710,10 +710,10 @@ absl::StatusOr GraphDefImporter::ArgNumType( SmallVectorImpl &types) { // Check whether a type list attribute is specified. if (!arg_def.type_list_attr().empty()) { - if (auto v = - attrs.get(arg_def.type_list_attr()).dyn_cast_or_null()) { + if (auto v = mlir::dyn_cast_or_null( + attrs.get(arg_def.type_list_attr()))) { for (Attribute attr : v) { - if (auto dtype = attr.dyn_cast()) { + if (auto dtype = mlir::dyn_cast(attr)) { types.push_back(UnrankedTensorType::get(dtype.getValue())); } else { return InvalidArgument("Expected '", arg_def.type_list_attr(), @@ -728,8 +728,8 @@ absl::StatusOr GraphDefImporter::ArgNumType( unsigned num = 1; // Check whether a number attribute is specified. if (!arg_def.number_attr().empty()) { - if (auto v = - attrs.get(arg_def.number_attr()).dyn_cast_or_null()) { + if (auto v = mlir::dyn_cast_or_null( + attrs.get(arg_def.number_attr()))) { num = v.getValue().getZExtValue(); } else { return NotFound("Type attr not found: ", arg_def.number_attr()); @@ -744,7 +744,8 @@ absl::StatusOr GraphDefImporter::ArgNumType( return InvalidArgument("Arg '", arg_def.name(), "' has invalid type and no type attribute"); } else { - if (auto v = attrs.get(arg_def.type_attr()).dyn_cast_or_null()) { + if (auto v = + mlir::dyn_cast_or_null(attrs.get(arg_def.type_attr()))) { dtype = v.getValue(); } else { return NotFound("Type attr not found: ", arg_def.type_attr()); diff --git a/tensorflow/core/ir/interfaces.cc b/tensorflow/core/ir/interfaces.cc index a8b84af5e1eafb..a9f69eba6bbff0 100644 --- a/tensorflow/core/ir/interfaces.cc +++ b/tensorflow/core/ir/interfaces.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/Region.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/ir/ops.h" #include "tensorflow/core/ir/types/dialect.h" @@ -30,7 +31,7 @@ LogicalResult ControlArgumentInterface::verifyRegion(Operation *op, Region ®ion) { unsigned num_ctl = 0, num_data = 0; for (BlockArgument arg : region.getArguments()) { - bool is_ctl = arg.getType().isa(); + bool is_ctl = mlir::isa(arg.getType()); num_ctl += is_ctl; num_data += !is_ctl; } diff --git a/tensorflow/core/ir/ops.cc b/tensorflow/core/ir/ops.cc index 0814e1e4af4b35..830fb59e8e3fdd 100644 --- a/tensorflow/core/ir/ops.cc +++ b/tensorflow/core/ir/ops.cc @@ -69,10 +69,11 @@ static void GenericGetAsmResultNames(Operation* op, // We only name the results when there are results to name, an op like `print` // which does not have results will just use the `ctl` name for the control // output. - if (op->getNumResults() > 1 && !op->getResult(0).getType().isa()) + if (op->getNumResults() > 1 && + !mlir::isa(op->getResult(0).getType())) set_name_fn(op->getResult(0), op->getName().stripDialect()); for (Value result : op->getResults()) { - if (result.getType().isa()) { + if (mlir::isa(result.getType())) { set_name_fn(op->getResult(op->getNumResults() - 1), "ctl"); break; } @@ -310,8 +311,8 @@ static ParseResult ParseCustomTfOp(OpAsmParser& parser, llvm::SMLoc loc = parser.getCurrentLocation(); Type control_type = ControlType::get(context); if (failed(parser.parseOptionalColonTypeList(arg_types))) return failure(); - if (arg_types.size() == 1 && arg_types.front().isa()) { - auto funcType = arg_types.front().cast(); + if (arg_types.size() == 1 && mlir::isa(arg_types.front())) { + auto funcType = mlir::cast(arg_types.front()); if (funcType.getNumInputs() != numNonControlOperands) return parser.emitError(loc) << "got " << numNonControlOperands @@ -398,8 +399,9 @@ bool GraphFuncOp::isMarkedForCompilation() { auto is_enabled = [this](StringRef attr_name) -> bool { Attribute attr = (*this)->getAttr(attr_name); if (!attr) return false; - if (auto bool_attr = attr.dyn_cast()) return bool_attr.getValue(); - if (auto str_attr = attr.dyn_cast()) + if (auto bool_attr = mlir::dyn_cast(attr)) + return bool_attr.getValue(); + if (auto str_attr = mlir::dyn_cast(attr)) return !str_attr.getValue().empty(); return false; }; @@ -673,7 +675,7 @@ void GraphFuncOp::print(OpAsmPrinter& p) { p.printOperand(getArgument(i)); p << ": "; p.printType(arg_types[i]); - if (auto arg_attrs = args_attr[i].dyn_cast()) + if (auto arg_attrs = mlir::dyn_cast(args_attr[i])) p.printOptionalAttrDict(arg_attrs.getValue()); if (i != e - 2) { p << ", "; @@ -691,7 +693,7 @@ void GraphFuncOp::print(OpAsmPrinter& p) { ArrayAttr results_attr = getAllResultAttrs(); for (int i = 0, e = result_types.size(); i < e; ++i) { p.printType(result_types[i]); - if (auto result_attrs = results_attr[i].dyn_cast()) + if (auto result_attrs = mlir::dyn_cast(results_attr[i])) p.printOptionalAttrDict(result_attrs.getValue()); if (i != e - 1) { p << ", "; @@ -761,7 +763,8 @@ void GraphFuncOp::getAsmBlockArgumentNames(Region& region, ArrayAttr args_attr = getAllArgAttrs(); if (!args_attr || args_attr.size() != args.size()) return; for (int arg_num = 0, e = args.size(); arg_num < e; arg_num += 2) { - DictionaryAttr arg_attrs = args_attr[arg_num].dyn_cast(); + DictionaryAttr arg_attrs = + mlir::dyn_cast(args_attr[arg_num]); if (!arg_attrs) continue; if (auto strAttr = arg_attrs.getAs("tfg.name")) { set_name_fn(args[arg_num], strAttr.getValue()); @@ -1053,7 +1056,7 @@ static LogicalResult VerifyCaseLikeOp(CaseLikeOp op, TypeRange func_args = ins->drop_front(); for (const auto& it : llvm::enumerate(op.getBranches())) { - SymbolRefAttr func_name = it.value().template cast().getName(); + SymbolRefAttr func_name = mlir::cast(it.value()).getName(); auto func = symbol_table.lookupNearestSymbolFrom(op, func_name); if (func && failed(VerifySignature(func, op, func_args, *outs, @@ -1126,7 +1129,7 @@ static LogicalResult VerifyPreservedAttrs(Operation* op, assert(op->getNumRegions() == preserved_attrs.size()); for (auto it : llvm::zip(preserved_attrs, op->getRegions())) { // Preserved attributes for a particular region may not exist. - auto attrs = std::get<0>(it).dyn_cast_or_null(); + auto attrs = mlir::dyn_cast_or_null(std::get<0>(it)); if (!attrs) continue; Region& region = std::get<1>(it); @@ -1195,7 +1198,7 @@ static LogicalResult VerifyIfLikeRegionOp(IfLikeRegionOp op) { // TODO(jeffniu): Incorporate the other cases of `tf.ToBool`. static std::optional GetStaticallyKnownBranch(Attribute cond_attr) { // Only handle the case of a scalar tensor of i1. - auto cond = cond_attr.dyn_cast_or_null(); + auto cond = mlir::dyn_cast_or_null(cond_attr); if (cond && cond.getNumElements() == 1 && cond.getElementType().isSignlessInteger(1)) return cond.getSplatValue(); @@ -1275,7 +1278,7 @@ static LogicalResult VerifyCaseLikeRegionOp(CaseLikeRegionOp op) { // try to narrow it to a statically known branch index. static std::optional GetStaticallyKnownCaseBranch( Attribute branch_attr) { - auto branch = branch_attr.dyn_cast_or_null(); + auto branch = mlir::dyn_cast_or_null(branch_attr); if (branch && branch.getNumElements() == 1 && branch.getElementType().isSignlessInteger(32)) return branch.getSplatValue(); @@ -1344,7 +1347,7 @@ static LogicalResult VerifyLoopRegionArgs(Operation* op, Region& region) { // the first half of the arguments are not control tokens, then we know for // sure that the second half is only control tokens. for (BlockArgument data : GetLoopRegionDataArgs(region)) - if (data.getType().isa()) + if (mlir::isa(data.getType())) return arg_error(data) << "should not be a control token"; return success(); } @@ -1412,7 +1415,7 @@ LogicalResult ForRegionOp::verify() { "expected the body block to have at least have the loop index as an " "argument"); } - auto index = args.front().getType().dyn_cast(); + auto index = mlir::dyn_cast(args.front().getType()); if (!index || !index.getElementType().isSignlessInteger(32)) { return emitOpError( "expected first body block argument to be an i32 tensor"); @@ -1467,8 +1470,9 @@ bool FunctionTable::MayBeCall(Operation* op) const { if (IsLegacyCall(op)) return true; // The operation might be a call if it references a symbol. bool references_symbol = false; - op->getAttrDictionary().walk( - [&](Attribute attr) { references_symbol |= attr.isa(); }); + op->getAttrDictionary().walk([&](Attribute attr) { + references_symbol |= mlir::isa(attr); + }); return references_symbol; } diff --git a/tensorflow/core/ir/tf_op_names.cc b/tensorflow/core/ir/tf_op_names.cc index 0aadaf9288863a..bd43c99d84b397 100644 --- a/tensorflow/core/ir/tf_op_names.cc +++ b/tensorflow/core/ir/tf_op_names.cc @@ -15,6 +15,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/ir/dialect.h" #include "tensorflow/core/ir/tf_op_wrapper.h" @@ -26,7 +27,7 @@ bool TFGraphDialect::IsAdd(TFOp op) const { if (op_name == add_v2_) return true; if (op_name == add_) - return !op->getAttrOfType("T").getValue().isa(); + return !mlir::isa(op->getAttrOfType("T").getValue()); return false; } diff --git a/tensorflow/core/ir/tf_op_wrapper.h b/tensorflow/core/ir/tf_op_wrapper.h index 1bf02d90877c66..36a3964d45ea7d 100644 --- a/tensorflow/core/ir/tf_op_wrapper.h +++ b/tensorflow/core/ir/tf_op_wrapper.h @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/ir/dialect.h" #include "tensorflow/core/ir/types/dialect.h" #include "tensorflow/core/ir/utility.h" @@ -38,7 +39,7 @@ class ControlRetIterator final ValueIteratorT, Value>::mapped_iterator_base; Value mapElement(Value value) const { - return value.getType().isa() + return mlir::isa(value.getType()) ? value : tfg::LookupControlDependency(value); } diff --git a/tensorflow/core/ir/tf_op_wrapper_test.cc b/tensorflow/core/ir/tf_op_wrapper_test.cc index 1eaf6731e670fc..b8bc7bfd55c86d 100644 --- a/tensorflow/core/ir/tf_op_wrapper_test.cc +++ b/tensorflow/core/ir/tf_op_wrapper_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/ir/dialect.h" #include "tensorflow/core/ir/ops.h" #include "tensorflow/core/platform/test.h" @@ -84,7 +85,7 @@ TEST(TFOpWrapper, ControlOperands) { EXPECT_EQ(ctls.size(), 2u); OperandRange::iterator ctl_it = llvm::find_if(operands, [](Value operand) { - return operand.getType().isa(); + return mlir::isa(operand.getType()); }); EXPECT_NE(ctl_it, operands.end()); EXPECT_EQ(data.end(), ctl_it); @@ -184,7 +185,7 @@ TEST(TFOpWrapper, ValueControlRet) { // Value with ControlType will be the same. EXPECT_EQ(ret_range[2], const_op.controlRet()); - for (Value v : ret_range) EXPECT_TRUE(v.getType().isa()); + for (Value v : ret_range) EXPECT_TRUE(mlir::isa(v.getType())); } } // namespace diff --git a/tensorflow/core/ir/types/BUILD b/tensorflow/core/ir/types/BUILD index 2181219addb51f..18bb1541d76662 100644 --- a/tensorflow/core/ir/types/BUILD +++ b/tensorflow/core/ir/types/BUILD @@ -123,5 +123,6 @@ tf_cc_test( "//tensorflow/core:test_main", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/core/ir/types/dialect.cc b/tensorflow/core/ir/types/dialect.cc index 80fd703bdf1c29..bd693656cf6fbe 100644 --- a/tensorflow/core/ir/types/dialect.cc +++ b/tensorflow/core/ir/types/dialect.cc @@ -38,6 +38,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #define GET_ATTRDEF_CLASSES @@ -274,7 +275,7 @@ static void RawFullTypeAttrPrint(FullTypeAttr tfattr, AsmPrinter& printer) { if (!tfattr.getArgs().empty()) { printer << "<"; llvm::interleaveComma(tfattr.getArgs(), printer, [&](Attribute arg) { - if (auto t = arg.dyn_cast()) + if (auto t = mlir::dyn_cast(arg)) RawFullTypeAttrPrint(t, printer); else printer << "<>"; @@ -320,7 +321,7 @@ Attribute FuncAttr::parse(AsmParser& parser, Type type) { parser.emitError(loc) << "expected symbol while parsing tf.func attribute"; return {}; } - if (auto func_name_str = name.dyn_cast()) { + if (auto func_name_str = mlir::dyn_cast(name)) { if (!func_name_str.getValue().empty()) { parser.emitError(loc) << "expected empty string or symbol while parsing tf.func " @@ -329,20 +330,20 @@ Attribute FuncAttr::parse(AsmParser& parser, Type type) { } name = SymbolRefAttr::get(parser.getContext(), ""); } - if (!name.isa()) { + if (!mlir::isa(name)) { parser.emitError(loc) << "expected symbol while parsing tf.func attribute"; return {}; } if (failed(parser.parseComma())) return {}; loc = parser.getCurrentLocation(); - if (failed(parser.parseAttribute(dict)) || !dict.isa()) { + if (failed(parser.parseAttribute(dict)) || !mlir::isa(dict)) { parser.emitError(loc) << "expected Dictionary attribute while parsing tf.func attribute"; return {}; } if (failed(parser.parseGreater())) return {}; - return FuncAttr::get(parser.getContext(), name.cast(), - dict.cast()); + return FuncAttr::get(parser.getContext(), mlir::cast(name), + mlir::cast(dict)); } void PlaceholderAttr::print(AsmPrinter& os) const { @@ -455,7 +456,7 @@ namespace { // Returns the shape of the given value if it's ranked; returns std::nullopt // otherwise. std::optional> GetShape(Value value) { - auto shaped_type = value.getType().cast(); + auto shaped_type = mlir::cast(value.getType()); if (shaped_type.hasRank()) return shaped_type.getShape(); return std::nullopt; } @@ -516,13 +517,13 @@ bool TensorFlowType::classof(Type type) { return llvm::isa(type.getDialect()); } bool TensorFlowRefType::classof(Type type) { - return type.isa< + return mlir::isa< #define HANDLE_TF_TYPE(tftype, enumerant, name) #define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type, #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type // NOLINTNEXTLINE #include "tensorflow/core/ir/types/types.def" - >(); + >(type); } TensorFlowType TensorFlowRefType::get(Type type) { @@ -540,7 +541,7 @@ TensorFlowType TensorFlowRefType::get(Type type) { return Float8E4M3FNRefType::get(ctx); } else if (type.isFloat8E5M2()) { return Float8E5M2RefType::get(ctx); - } else if (auto complex_type = type.dyn_cast()) { + } else if (auto complex_type = mlir::dyn_cast(type)) { Type etype = complex_type.getElementType(); if (etype.isF32()) { return Complex64RefType::get(ctx); @@ -548,7 +549,7 @@ TensorFlowType TensorFlowRefType::get(Type type) { return Complex128RefType::get(ctx); } llvm_unreachable("unexpected complex type"); - } else if (auto itype = type.dyn_cast()) { + } else if (auto itype = mlir::dyn_cast(type)) { switch (itype.getWidth()) { case 1: return BoolRefType::get(ctx); @@ -583,30 +584,34 @@ TensorFlowType TensorFlowRefType::get(Type type) { Type TensorFlowRefType::RemoveRef() { MLIRContext* ctx = getContext(); - if (isa()) return FloatType::getF16(ctx); - if (isa()) return FloatType::getF32(ctx); - if (isa()) return FloatType::getF64(ctx); - if (isa()) return FloatType::getBF16(ctx); - if (isa()) return FloatType::getFloat8E4M3FN(ctx); - if (isa()) return FloatType::getFloat8E5M2(ctx); - if (isa()) return IntegerType::get(ctx, 1); - if (isa()) return IntegerType::get(ctx, 4, IntegerType::Signed); - if (isa()) return IntegerType::get(ctx, 8); - if (isa()) return IntegerType::get(ctx, 16); - if (isa()) return IntegerType::get(ctx, 32); - if (isa()) return IntegerType::get(ctx, 64); - if (isa()) + if (mlir::isa(*this)) return FloatType::getF16(ctx); + if (mlir::isa(*this)) return FloatType::getF32(ctx); + if (mlir::isa(*this)) return FloatType::getF64(ctx); + if (mlir::isa(*this)) return FloatType::getBF16(ctx); + if (mlir::isa(*this)) + return FloatType::getFloat8E4M3FN(ctx); + if (mlir::isa(*this)) return FloatType::getFloat8E5M2(ctx); + if (mlir::isa(*this)) return IntegerType::get(ctx, 1); + if (mlir::isa(*this)) + return IntegerType::get(ctx, 4, IntegerType::Signed); + if (mlir::isa(*this)) return IntegerType::get(ctx, 8); + if (mlir::isa(*this)) return IntegerType::get(ctx, 16); + if (mlir::isa(*this)) return IntegerType::get(ctx, 32); + if (mlir::isa(*this)) return IntegerType::get(ctx, 64); + if (mlir::isa(*this)) return IntegerType::get(ctx, 4, IntegerType::Unsigned); - if (isa()) + if (mlir::isa(*this)) return IntegerType::get(ctx, 8, IntegerType::Unsigned); - if (isa()) + if (mlir::isa(*this)) return IntegerType::get(ctx, 16, IntegerType::Unsigned); - if (isa()) + if (mlir::isa(*this)) return IntegerType::get(ctx, 32, IntegerType::Unsigned); - if (isa()) + if (mlir::isa(*this)) return IntegerType::get(ctx, 64, IntegerType::Unsigned); - if (isa()) return ComplexType::get(FloatType::getF32(ctx)); - if (isa()) return ComplexType::get(FloatType::getF64(ctx)); + if (mlir::isa(*this)) + return ComplexType::get(FloatType::getF32(ctx)); + if (mlir::isa(*this)) + return ComplexType::get(FloatType::getF64(ctx)); #define HANDLE_TF_TYPE(tftype, enumerant, name) \ if (isa()) return tftype##Type::get(ctx); @@ -617,32 +622,32 @@ Type TensorFlowRefType::RemoveRef() { } bool TensorFlowTypeWithSubtype::classof(Type type) { - return type.isa(); + return mlir::isa(type); } Type TensorFlowTypeWithSubtype::RemoveSubtypes() { MLIRContext* ctx = getContext(); - if (isa()) return VariantType::get(ctx); - if (isa()) return ResourceType::get(ctx); + if (mlir::isa(*this)) return VariantType::get(ctx); + if (mlir::isa(*this)) return ResourceType::get(ctx); llvm_unreachable("unexpected tensorflow type with subtypes kind"); } TensorFlowTypeWithSubtype TensorFlowTypeWithSubtype::clone( ArrayRef new_subtypes) { MLIRContext* ctx = getContext(); - if (isa()) - return VariantType::get(new_subtypes, ctx) - .cast(); - if (isa()) - return ResourceType::get(new_subtypes, ctx) - .cast(); + if (mlir::isa(*this)) + return mlir::cast( + VariantType::get(new_subtypes, ctx)); + if (mlir::isa(*this)) + return mlir::cast( + ResourceType::get(new_subtypes, ctx)); llvm_unreachable("unexpected tensorflow type with subtypes kind"); } ArrayRef TensorFlowTypeWithSubtype::GetSubtypes() { - if (auto variant_type = dyn_cast()) + if (auto variant_type = mlir::dyn_cast(*this)) return variant_type.getSubtypes(); - if (auto resource_type = dyn_cast()) + if (auto resource_type = mlir::dyn_cast(*this)) return resource_type.getSubtypes(); llvm_unreachable("unexpected tensorflow type with subtypes kind"); } @@ -659,8 +664,8 @@ bool BroadcastCompatible(TypeRange lhs, TypeRange rhs) { auto rhs_type = DropRefType(std::get<1>(types)); // This should be true for all TF ops: - auto lhs_tt = lhs_type.dyn_cast(); - auto rhs_tt = rhs_type.dyn_cast(); + auto lhs_tt = mlir::dyn_cast(lhs_type); + auto rhs_tt = mlir::dyn_cast(rhs_type); if (!lhs_tt || !rhs_tt) { if (lhs_type != rhs_type) return false; continue; @@ -673,8 +678,8 @@ bool BroadcastCompatible(TypeRange lhs, TypeRange rhs) { auto rhs_et = rhs_tt.getElementType(); if (lhs_et != rhs_et) { // If either does not have subtypes, then the element types don't match. - auto lhs_wst = lhs_et.dyn_cast(); - auto rhs_wst = rhs_et.dyn_cast(); + auto lhs_wst = mlir::dyn_cast(lhs_et); + auto rhs_wst = mlir::dyn_cast(rhs_et); if (!lhs_wst || !rhs_wst) return false; // Consider the subtype of variant types. @@ -689,8 +694,8 @@ bool BroadcastCompatible(TypeRange lhs, TypeRange rhs) { } } - auto lhs_rt = lhs_type.dyn_cast(); - auto rhs_rt = rhs_type.dyn_cast(); + auto lhs_rt = mlir::dyn_cast(lhs_type); + auto rhs_rt = mlir::dyn_cast(rhs_type); if (!lhs_rt || !rhs_rt) return true; SmallVector shape; return OpTrait::util::getBroadcastedShape(lhs_rt.getShape(), @@ -721,8 +726,8 @@ Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a) { // Fast path if everything is equal. if (a == b) return b; - auto a_tt = a.dyn_cast(); - auto b_tt = b.dyn_cast(); + auto a_tt = mlir::dyn_cast(a); + auto b_tt = mlir::dyn_cast(b); // If only one of a or b is a tensor type, they are incompatible. if (static_cast(a_tt) ^ static_cast(b_tt)) return nullptr; @@ -732,7 +737,7 @@ Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a) { if (!a_tt && !b_tt) { // Remove ref types. if (may_ignore_ref_type_a) { - if (auto ref_type = a.dyn_cast()) { + if (auto ref_type = mlir::dyn_cast(a)) { a = ref_type.RemoveRef(); if (a == b) return a; } @@ -741,8 +746,8 @@ Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a) { // If either is not a type that contain subtypes then the types are not cast // compatible. - auto a_wst = a.dyn_cast(); - auto b_wst = b.dyn_cast(); + auto a_wst = mlir::dyn_cast(a); + auto b_wst = mlir::dyn_cast(b); if (!a_wst || !b_wst) return nullptr; // For Variant types we are more permissive right now and accept all pairs @@ -752,8 +757,8 @@ Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a) { // one, so we should only assign it one when we know the subtype. Then we // can be more constrained and check subtypes for cast compatibility as // well. - if (a.isa()) return a; - if (b.isa()) return b; + if (mlir::isa(a)) return a; + if (mlir::isa(b)) return b; // For Resource types, we recursively check the subtypes for cast // compatibility, if possible. Otherwise treat them as compatible. @@ -768,7 +773,7 @@ Type GetCastCompatibleType(Type a, Type b, bool may_ignore_ref_type_a) { GetCastCompatibleType(std::get<0>(subtypes), std::get<1>(subtypes), /*may_ignore_ref_type_a=*/false); if (!refined_st) return nullptr; - refined_subtypes.push_back(refined_st.cast()); + refined_subtypes.push_back(mlir::cast(refined_st)); } return ResourceType::get(refined_subtypes, a.getContext()); @@ -833,13 +838,13 @@ static Type GetDefaultTypeOf(TensorFlowRefType type) { template Type DropTypeHelper(Type ty) { Type element_ty = getElementTypeOrSelf(ty); - auto composed_type = element_ty.dyn_cast(); + auto composed_type = mlir::dyn_cast(element_ty); if (!composed_type) return ty; Type default_ty = GetDefaultTypeOf(composed_type); - if (auto ranked_ty = ty.dyn_cast()) { + if (auto ranked_ty = mlir::dyn_cast(ty)) { return RankedTensorType::get(ranked_ty.getShape(), default_ty); - } else if (ty.dyn_cast()) { + } else if (mlir::dyn_cast(ty)) { return UnrankedTensorType::get(default_ty); } else { return default_ty; @@ -867,7 +872,7 @@ Attribute TensorProtoAttr::parse(AsmParser& parser, Type type) { parser.emitError(parser.getNameLoc(), "Hex string doesn't start with `0x`"); return nullptr; } - auto shapedType = type.dyn_cast(); + auto shapedType = mlir::dyn_cast(type); if (!shapedType) return nullptr; std::string bytes_data = absl::HexStringToBytes(data.substr(2)); diff --git a/tensorflow/core/ir/types/dialect.h b/tensorflow/core/ir/types/dialect.h index e2dc8bef70a5d1..7c1a1cda1bec94 100644 --- a/tensorflow/core/ir/types/dialect.h +++ b/tensorflow/core/ir/types/dialect.h @@ -28,6 +28,7 @@ limitations under the License. // Include the dialect class generated from dialect.td. // The constructor and the printing/parsing of dialect types are manually // implemented (see ops.cpp). +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/ir/types/dialect.h.inc" // Include the Type classes declaration generated from types.td @@ -52,15 +53,15 @@ class TensorFlowType : public Type { // Returns true if the specified type is a valid TensorFlow element type. inline bool IsValidTFElementType(Type type) { - return type.isa(); + return mlir::isa(type); } // Returns true if this is a valid TensorFlow tensor type. inline bool IsValidTFTensorType(Type type) { // TensorFlow types should be tensors of one of the valid TensorFlow element // types. - if (auto tensor_ty = type.dyn_cast()) + if (auto tensor_ty = mlir::dyn_cast(type)) return IsValidTFElementType(tensor_ty.getElementType()); return false; } @@ -329,7 +330,7 @@ using ResultShapeRange = iterator_range; template auto filter_resources(RangeT&& range) { return llvm::make_filter_range(std::forward(range), [](Value val) { - return getElementTypeOrSelf(val.getType()).isa(); + return mlir::isa(getElementTypeOrSelf(val.getType())); }); } @@ -338,7 +339,7 @@ auto filter_resources(RangeT&& range) { // standard type if necessary. inline Type GetElementTypeOrSelfResolveRef(Type type) { Type element_type = getElementTypeOrSelf(type); - if (auto ref_type = element_type.dyn_cast()) { + if (auto ref_type = mlir::dyn_cast(element_type)) { element_type = ref_type.RemoveRef(); } return element_type; diff --git a/tensorflow/core/ir/types/dialect_test.cc b/tensorflow/core/ir/types/dialect_test.cc index db5d1cfed6d502..1fb6537b4684f7 100644 --- a/tensorflow/core/ir/types/dialect_test.cc +++ b/tensorflow/core/ir/types/dialect_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/platform/test.h" namespace mlir { @@ -46,7 +47,7 @@ TEST(TFTypesDialect, TestFuncAttrSubElement) { ASSERT_TRUE(succeeded(SymbolTable::replaceAllSymbolUses( b.getStringAttr("foo"), baz, test_op.getParentRegion()))); - auto func_attr = test_op.getAttr("func").dyn_cast(); + auto func_attr = mlir::dyn_cast(test_op.getAttr("func")); ASSERT_TRUE(func_attr); auto sym_ref = FlatSymbolRefAttr::get(baz); EXPECT_TRUE(func_attr.getName() == sym_ref); diff --git a/tensorflow/core/ir/utility.cc b/tensorflow/core/ir/utility.cc index e15eb5799f8f3d..168a71e76bc310 100644 --- a/tensorflow/core/ir/utility.cc +++ b/tensorflow/core/ir/utility.cc @@ -20,6 +20,7 @@ limitations under the License. #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/ir/dialect.h" #include "tensorflow/core/ir/interfaces.h" #include "tensorflow/core/ir/types/dialect.h" @@ -47,35 +48,36 @@ BlockArgument GetLoopRegionDataOf(BlockArgument ctl) { } Value LookupControlDependency(Value data) { - assert(!data.getType().isa() && "expected a data type"); + assert(!mlir::isa(data.getType()) && "expected a data type"); // If the value is defined by an op, then the last result is the control // dependency. Value control_dep; - if (auto result = data.dyn_cast()) { + if (auto result = mlir::dyn_cast(data)) { control_dep = *std::prev(result.getOwner()->result_end()); } else { - auto arg = data.cast(); + auto arg = mlir::cast(data); control_dep = cast(arg.getOwner()->getParentOp()) .getControlTokenOf(arg); } - assert(control_dep.getType().isa() && "expected a control type"); + assert(mlir::isa(control_dep.getType()) && + "expected a control type"); return control_dep; } std::optional LookupDataValue(Value ctl) { - assert(ctl.getType().isa() && "expected a control type"); + assert(mlir::isa(ctl.getType()) && "expected a control type"); // If the value is defined by an op, then return the first result. Value data; - if (auto result = ctl.dyn_cast()) { + if (auto result = mlir::dyn_cast(ctl)) { // If the op only has a control result, then there is no data value. if (result.getOwner()->getNumResults() == 1) return {}; data = *result.getOwner()->result_begin(); } else { - auto arg = ctl.cast(); + auto arg = mlir::cast(ctl); data = cast(arg.getOwner()->getParentOp()) .getDataValueOf(arg); } - assert(!data.getType().isa() && "expected a data type"); + assert(!mlir::isa(data.getType()) && "expected a data type"); return data; } diff --git a/tensorflow/core/ir/utility_test.cc b/tensorflow/core/ir/utility_test.cc index 97c3184f228535..200d1103c7d7fb 100644 --- a/tensorflow/core/ir/utility_test.cc +++ b/tensorflow/core/ir/utility_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/ir/dialect.h" #include "tensorflow/core/ir/ops.h" #include "tensorflow/core/platform/test.h" @@ -49,7 +50,7 @@ TEST(DialectUtilityTest, TestLookupControlDependency) { Value copy = ret_op.getOperand(0); Value ctl = LookupControlDependency(copy); ASSERT_TRUE(ctl); - OpResult ctl_result = ctl.dyn_cast(); + OpResult ctl_result = mlir::dyn_cast(ctl); ASSERT_TRUE(ctl_result); EXPECT_EQ(ctl_result.getResultNumber(), 1); EXPECT_EQ(copy, ctl_result.getOwner()->getResult(0)); @@ -58,7 +59,7 @@ TEST(DialectUtilityTest, TestLookupControlDependency) { Value arg = ctl_result.getOwner()->getOperand(0); Value arg_ctl = LookupControlDependency(arg); ASSERT_TRUE(arg_ctl); - BlockArgument ctl_arg = arg_ctl.dyn_cast(); + BlockArgument ctl_arg = mlir::dyn_cast(arg_ctl); ASSERT_TRUE(ctl_arg); EXPECT_EQ(ctl_arg.getArgNumber(), 1); EXPECT_EQ(arg, ctl_arg.getOwner()->getArgument(0)); @@ -84,7 +85,7 @@ TEST(DialectUtilityTest, TestLookupDataValue) { Value ctl = ret_op.getOperand(1); std::optional produce = LookupDataValue(ctl); ASSERT_TRUE(produce); - OpResult produce_result = produce->dyn_cast(); + OpResult produce_result = mlir::dyn_cast(*produce); ASSERT_TRUE(produce_result); ASSERT_EQ(produce_result.getResultNumber(), 0); ASSERT_EQ(produce_result.getOwner()->getName().getStringRef(), "tfg.Produce"); @@ -93,7 +94,7 @@ TEST(DialectUtilityTest, TestLookupDataValue) { Value arg_ctl = produce_result.getOwner()->getOperand(0); std::optional arg = LookupDataValue(arg_ctl); ASSERT_TRUE(arg); - BlockArgument arg_arg = arg->dyn_cast(); + BlockArgument arg_arg = mlir::dyn_cast(*arg); ASSERT_TRUE(arg_arg); ASSERT_EQ(arg_arg.getArgNumber(), 0); ASSERT_EQ(arg_arg.getOwner()->getArgument(1), arg_ctl); diff --git a/tensorflow/core/ir/utils/shape_inference_utils_test.cc b/tensorflow/core/ir/utils/shape_inference_utils_test.cc index 96399a90f88d3f..c5359b19048099 100644 --- a/tensorflow/core/ir/utils/shape_inference_utils_test.cc +++ b/tensorflow/core/ir/utils/shape_inference_utils_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/ir/dialect.h" @@ -82,7 +83,7 @@ class ShapeInferenceTest : public ::testing::Test { EXPECT_EQ(op.getNumResults() - 1, info.size()); for (int i = 0; i < op.getNumResults() - 1; ++i) { - ShapedType shape = op.getResultTypes()[i].cast(); + ShapedType shape = mlir::cast(op.getResultTypes()[i]); EXPECT_EQ(shape.hasRank(), info[i].hasRank()); if (shape.hasRank()) EXPECT_EQ(shape.getShape(), info[i].getDims()); if (check_type) @@ -114,7 +115,7 @@ TEST_F(ShapeInferenceTest, TestShapeAndTypeInference) { // `value` attr contains the tensor information and it's a DenseElementAttr. auto op_result_as_shape_fn = [](InferenceContext &ic, OpResult op_result) -> ShapeHandle { - auto rt = op_result.getType().dyn_cast(); + auto rt = mlir::dyn_cast(op_result.getType()); if (!rt || rt.getRank() != 1 || !rt.hasStaticShape()) return {}; std::vector dims(rt.getDimSize(0), ic.UnknownDim()); @@ -136,7 +137,8 @@ TEST_F(ShapeInferenceTest, TestShapeAndTypeInference) { // `InferReturnTypeComponentsForTFOp`uses this callback to get the type // information. auto result_element_type_fn = [&](int idx) -> Type { - return op.getResult(idx).getType().cast().getElementType(); + return mlir::cast(op.getResult(idx).getType()) + .getElementType(); }; // We use TFG operation so that we don't need to provide @@ -178,7 +180,8 @@ TEST_F(ShapeInferenceTest, TestShapeAndTypeInference) { all_results.clear(); for (Operation &op : block.without_terminator()) { auto result_element_type_fn = [&](int idx) -> Type { - return op.getResult(idx).getType().cast().getElementType(); + return mlir::cast(op.getResult(idx).getType()) + .getElementType(); }; SmallVector results; diff --git a/tensorflow/core/tfrt/mlrt/attribute/BUILD b/tensorflow/core/tfrt/mlrt/attribute/BUILD index 5bebee19f89926..9a9d33042fcb75 100644 --- a/tensorflow/core/tfrt/mlrt/attribute/BUILD +++ b/tensorflow/core/tfrt/mlrt/attribute/BUILD @@ -28,6 +28,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:errors", ], ) diff --git a/tensorflow/core/tfrt/mlrt/attribute/attribute.cc b/tensorflow/core/tfrt/mlrt/attribute/attribute.cc index daa92fcaca5953..a9749e92871b14 100644 --- a/tensorflow/core/tfrt/mlrt/attribute/attribute.cc +++ b/tensorflow/core/tfrt/mlrt/attribute/attribute.cc @@ -26,6 +26,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tfrt/translate/mlrt/mlir_to_bytecode.h" @@ -42,7 +43,7 @@ absl::StatusOr EncodeTensorflowAttribute( return std::move(*result); } - if (auto dense_attr = attr.dyn_cast()) { + if (auto dense_attr = mlir::dyn_cast(attr)) { auto element_type = dense_attr.getElementType(); tensorflow::DataType dtype; @@ -95,7 +96,7 @@ absl::StatusOr EncodeTensorflowAttribute( } // Handle dtype attrs - if (auto type_attr = attr.dyn_cast()) { + if (auto type_attr = mlir::dyn_cast(attr)) { tensorflow::DataType dtype; TF_RETURN_IF_ERROR( tensorflow::ConvertToDataType(type_attr.getValue(), &dtype)); @@ -105,7 +106,7 @@ absl::StatusOr EncodeTensorflowAttribute( } // Handle shape attrs - if (auto shape_attr = attr.dyn_cast()) { + if (auto shape_attr = mlir::dyn_cast(attr)) { llvm::ArrayRef shape; if (!shape_attr.getUnranked()) { auto shape_or = shape_attr.getValue(); @@ -131,7 +132,7 @@ absl::StatusOr EncodeTensorflowAttribute( } // Handle attribute arrays. - if (auto array_attr = attr.dyn_cast()) { + if (auto array_attr = mlir::dyn_cast(attr)) { mlrt::bc::Buffer buffer; mlrt::bc::Allocator allocator(&buffer); auto ctor = mlrt::bc::New>( @@ -139,7 +140,7 @@ absl::StatusOr EncodeTensorflowAttribute( int i; for (i = 0; i < array_attr.size(); ++i) { - if (auto type_attr = array_attr[i].dyn_cast()) { + if (auto type_attr = mlir::dyn_cast(array_attr[i])) { tensorflow::DataType dtype; TF_RETURN_IF_ERROR( tensorflow::ConvertToDataType(type_attr.getValue(), &dtype)); diff --git a/tensorflow/core/transforms/BUILD b/tensorflow/core/transforms/BUILD index 54850bc90c587b..1968db2a555270 100644 --- a/tensorflow/core/transforms/BUILD +++ b/tensorflow/core/transforms/BUILD @@ -105,6 +105,7 @@ cc_library( "//tensorflow/core/ir:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/core/transforms/consolidate_attrs/pass.cc b/tensorflow/core/transforms/consolidate_attrs/pass.cc index 2f8b629b2d2988..3c58d7b8ab9402 100644 --- a/tensorflow/core/transforms/consolidate_attrs/pass.cc +++ b/tensorflow/core/transforms/consolidate_attrs/pass.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/core/ir/dialect.h" @@ -47,13 +48,13 @@ static const char *kRegenerateOutputShapes = "tfg.regenerate_output_shapes"; // Returns true if an attribute is an array of shapes; static bool IsArrayOfShapes(ArrayAttr array) { - return llvm::all_of(array, - [](Attribute attr) { return attr.isa(); }); + return llvm::all_of( + array, [](Attribute attr) { return mlir::isa(attr); }); } // Given a tensor type and shape information, try to refine the type. static Type GetReifiedType(Type orig, ShapeAttr shape) { - Type element_type = orig.cast().getElementType(); + Type element_type = mlir::cast(orig).getElementType(); TensorType inferred; if (shape.hasRank()) { // Replace dimensions less than -1 with ? @@ -137,11 +138,11 @@ Type ConsolidateAttributesPassImpl::refineTypeWithOutputShapes( // Get the output shapes attribute. If the attribute is not an array of // exactly one shape, ignore it. if (auto output_shapes = - attrs.get(output_shapes_id_).dyn_cast_or_null()) { + mlir::dyn_cast_or_null(attrs.get(output_shapes_id_))) { if (output_shapes.size() == 1 && IsArrayOfShapes(output_shapes)) { attrs.erase(output_shapes_id_); attrs.set(regenerate_output_shapes_id_, UnitAttr::get(&getContext())); - return GetReifiedType(type, output_shapes[0].cast()); + return GetReifiedType(type, mlir::cast(output_shapes[0])); } } return type; @@ -153,8 +154,9 @@ Type ConsolidateAttributesPassImpl::refineTypeWithHandleData( SmallVector subtypes; // Because `tfg.handle_data` is a TFG internal attribute, it will be // well-formed. - for (Type type : handle_data.cast().getAsValueRange()) - subtypes.push_back(type.cast()); + for (Type type : + mlir::cast(handle_data).getAsValueRange()) + subtypes.push_back(mlir::cast(type)); auto resource = UnrankedTensorType::get(ResourceType::get(subtypes, &getContext())); Type reified = tf_type::GetCastCompatibleType(resource, type); @@ -167,7 +169,7 @@ ArrayAttr ConsolidateAttributesPassImpl::reifyAndDropFunctionArgumentAttributes( // we will ignore it. If it isn't an array of shapes or has an inconsistent // number of shapes, ignore it. ArrayAttr input_shapes = - func->getAttr(input_shapes_id_).dyn_cast_or_null(); + mlir::dyn_cast_or_null(func->getAttr(input_shapes_id_)); unsigned num_args = func.getNumArguments() / 2; if (input_shapes) { if (input_shapes.size() != num_args || !IsArrayOfShapes(input_shapes)) { @@ -188,7 +190,8 @@ ArrayAttr ConsolidateAttributesPassImpl::reifyAndDropFunctionArgumentAttributes( arg_type = refineTypeWithOutputShapes(arg_type, attrs); arg_type = refineTypeWithHandleData(arg_type, attrs.erase(handle_data_id_)); if (input_shapes) - arg_type = GetReifiedType(arg_type, input_shapes[i].cast()); + arg_type = + GetReifiedType(arg_type, mlir::cast(input_shapes[i])); arg.setType(arg_type); attrs.erase(dtype_id_); attrs.erase(is_ref_id_); @@ -242,7 +245,7 @@ class ReifyOperationOutputShapes : public RewritePattern { // attribute, if it has an inconsistent number of shapes, or if it is not // an array of shapes. ArrayAttr output_shapes = - op->getAttr(output_shapes_id_).dyn_cast_or_null(); + mlir::dyn_cast_or_null(op->getAttr(output_shapes_id_)); if (!output_shapes || results.size() != output_shapes.size() || !IsArrayOfShapes(output_shapes)) return failure(); @@ -422,7 +425,7 @@ void PrepareAttributesForExportPassImpl::prepareFunctionAttributes( continue; } arg_attrs.push_back(prepareAttributesFor(type, attrs)); - if (auto ranked = type.dyn_cast()) { + if (auto ranked = mlir::dyn_cast(type)) { input_shapes.push_back(ShapeAttr::get(&getContext(), ranked.getShape())); } else { input_shapes.push_back(ShapeAttr::get(&getContext(), std::nullopt)); @@ -450,14 +453,14 @@ DictionaryAttr PrepareAttributesForExportPassImpl::prepareAttributesFor( NamedAttrList attrs(attr_dict); // Add shape data if requested. if (attrs.erase(regenerate_output_shapes_id_)) { - auto shape = ShapeAttr::get(&getContext(), - type.isa() - ? type.cast().getShape() - : std::optional>()); + auto shape = ShapeAttr::get( + &getContext(), mlir::isa(type) + ? mlir::cast(type).getShape() + : std::optional>()); attrs.set(output_shapes_id_, ArrayAttr::get(&getContext(), {shape})); } - auto element_type = type.cast().getElementType(); - if (auto resource = element_type.dyn_cast()) { + auto element_type = mlir::cast(type).getElementType(); + if (auto resource = mlir::dyn_cast(element_type)) { SmallVector handle_data; for (TensorType subtype : resource.getSubtypes()) handle_data.push_back(TypeAttr::get(subtype)); @@ -465,7 +468,7 @@ DictionaryAttr PrepareAttributesForExportPassImpl::prepareAttributesFor( if (!handle_data.empty()) attrs.set(handle_data_id_, ArrayAttr::get(&getContext(), handle_data)); } - if (element_type.isa()) + if (mlir::isa(element_type)) attrs.set(is_ref_id_, UnitAttr::get(&getContext())); return attrs.getDictionary(&getContext()); } @@ -475,8 +478,8 @@ static ArrayAttr GetElementTypesAttr(PatternRewriter &rewriter, ValueRange values) { SmallVector types; for (Value value : values) { - types.push_back( - TypeAttr::get(value.getType().cast().getElementType())); + types.push_back(TypeAttr::get( + mlir::cast(value.getType()).getElementType())); } return rewriter.getArrayAttr(types); } @@ -515,11 +518,10 @@ struct MaterializeIfAttrs : public MaterializeAttrsPattern { PatternRewriter &rewriter) const override { if (op.getTcond() && op.getTin() && op.getTout()) return failure(); NamedAttrList attrs(op->getAttrDictionary()); - attrs.set(op.getTcondAttrName(), - TypeAttr::get(op.getCond() - .getType() - .template cast() - .getElementType())); + attrs.set( + op.getTcondAttrName(), + TypeAttr::get( + mlir::cast(op.getCond().getType()).getElementType())); attrs.set(op.getTinAttrName(), this->getArgumentElementTypesAttr(rewriter, op)); attrs.set(op.getToutAttrName(), @@ -583,7 +585,7 @@ class MaterializeOutputShapesBase : public RewritePattern { SmallVector shapes; for (Value result : results) { - if (auto ranked = result.getType().dyn_cast()) { + if (auto ranked = mlir::dyn_cast(result.getType())) { shapes.push_back(ShapeAttr::get(op->getContext(), ranked.getShape())); } else { shapes.push_back(ShapeAttr::get(op->getContext(), std::nullopt)); diff --git a/tensorflow/core/transforms/const_dedupe_hoist/BUILD b/tensorflow/core/transforms/const_dedupe_hoist/BUILD index 381b666a80a711..54d58e8c7ee09f 100644 --- a/tensorflow/core/transforms/const_dedupe_hoist/BUILD +++ b/tensorflow/core/transforms/const_dedupe_hoist/BUILD @@ -21,6 +21,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/core/transforms/const_dedupe_hoist/pass.cc b/tensorflow/core/transforms/const_dedupe_hoist/pass.cc index 8358bd338d374c..6a793cedea3b13 100644 --- a/tensorflow/core/transforms/const_dedupe_hoist/pass.cc +++ b/tensorflow/core/transforms/const_dedupe_hoist/pass.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/ir/dialect.h" #include "tensorflow/core/ir/ops.h" #include "tensorflow/core/ir/utility.h" @@ -189,7 +190,7 @@ void DedupeAndHoistConstantPass::RunOnGraphOrFuncOp(Operation* op) { op->walk([&](Operation* inner_op) { if (inner_op->getName().getIdentifier() != tfg_const) return; - ElementsAttr val = inner_op->getAttr(value_id).cast(); + ElementsAttr val = mlir::cast(inner_op->getAttr(value_id)); if (val.getNumElements() > max_size_) return; constant_ops[inner_op].push_back(inner_op); }); diff --git a/tensorflow/core/transforms/constant_folding/pass.cc b/tensorflow/core/transforms/constant_folding/pass.cc index 07e831317891c4..a168283bbb7190 100644 --- a/tensorflow/core/transforms/constant_folding/pass.cc +++ b/tensorflow/core/transforms/constant_folding/pass.cc @@ -107,7 +107,7 @@ static FailureOr CreateConstantTensorOp( OpBuilder &builder, Location loc, StringRef name_prefix, Type type, ValueRange control_operands, TypedAttr tensor_value, ArrayRef other_attrs = std::nullopt) { - if (type.isa()) return failure(); + if (mlir::isa(type)) return failure(); // TODO(chiahungduan): Reuse ConstOp Like // OperationFolder::tryGetOrCreateConstant. OperationState state(loc, "tfg.Const"); @@ -116,8 +116,9 @@ static FailureOr CreateConstantTensorOp( state.attributes = other_attrs; util::EraseRegularNodeAttributes(state.attributes); state.attributes.set( - "dtype", TypeAttr::get( - tensor_value.getType().cast().getElementType())); + "dtype", + TypeAttr::get( + mlir::cast(tensor_value.getType()).getElementType())); state.attributes.set("value", tensor_value); if (!name_prefix.empty()) { state.attributes.set( @@ -170,7 +171,7 @@ static TFOp GetControlAnchorForSwitchResult( if (StringAttr device_attr = switch_op.deviceAttr()) identity_op.setRequestedDevice(device_attr); identity_op.setName(Twine(switch_op.name(), "/ControlDependencyCtrl_") + - Twine(value.cast().getResultNumber())); + Twine(mlir::cast(value).getResultNumber())); return identity_op; } @@ -179,12 +180,12 @@ static TFOp GetControlAnchorForSwitchResult( // the output does not necessarily activate when the switch op activates. We // add a "control anchor" in the form of an identity op instead. static Value GetControlDependency(OpBuilder &builder, Value value) { - if (value.getType().isa()) return value; + if (mlir::isa(value.getType())) return value; TFGraphDialect *dialect = builder.getContext()->getLoadedDialect(); assert(dialect); - if (OpResult result = value.dyn_cast(); + if (OpResult result = mlir::dyn_cast(value); result && dialect->IsSwitch(result.getOwner())) { return GetControlAnchorForSwitchResult(builder, result, dialect) .controlRet(); @@ -196,7 +197,7 @@ static Value GetControlDependency(OpBuilder &builder, Value value) { // Add control operand to `op` if it doesn't exist. static void AddControlOperand(Operation *op, Value control, PatternRewriter &rewriter) { - assert(control.getType().isa()); + assert(mlir::isa(control.getType())); if (llvm::is_contained(op->getOperands(), control)) return; rewriter.startOpModification(op); op->insertOperands(op->getNumOperands(), control); @@ -271,7 +272,7 @@ static FailureOr ReplaceOpWithNoOp(OpBuilder &builder, TFOp op) { static FailureOr ReplaceOpWithConstant(OpBuilder &builder, Operation *op, double constant_value) { - auto res = (*op->result_type_begin()).cast(); + auto res = mlir::cast((*op->result_type_begin())); Type dtype = GetDataTypeFromOp(builder, op); Attribute value_attr; if (dtype.isIntOrIndex()) @@ -315,7 +316,7 @@ static FailureOr ReplaceOpWithSnapshot(OpBuilder &builder, TFOp op, static FailureOr ReplaceOpWithBroadcastTo(OpBuilder &builder, TFOp op, int idx_to_replace) { - ShapedType tensor_type = (*op->result_type_begin()).cast(); + ShapedType tensor_type = mlir::cast((*op->result_type_begin())); if (!tensor_type.hasStaticShape()) return failure(); ElementsAttr const_attr = ConvertShapeToAttr(tensor_type); @@ -551,7 +552,8 @@ bool OpPropertyHelper::IsFoldableUncached(TFOp op) { TFOp operand_op = operand.getDefiningOp(); if (operand_op && dialect_->IsConstant(operand_op)) { auto dtype = operand_op->getAttrOfType("dtype"); - if (!dtype || dtype.getValue().isa()) return false; + if (!dtype || mlir::isa(dtype.getValue())) + return false; // Special case: If a Merge node has at least one constant input that // does not depend on a control input, we can fold it. @@ -572,7 +574,7 @@ bool OpPropertyHelper::IsFoldableUncached(TFOp op) { // to materialize. int64_t input_size_bytes = 0; for (Value operand : operands) { - auto shape = operand.getType().dyn_cast(); + auto shape = mlir::dyn_cast(operand.getType()); if (!shape || !shape.hasStaticShape()) continue; auto element_type = shape.getElementType(); @@ -581,7 +583,7 @@ bool OpPropertyHelper::IsFoldableUncached(TFOp op) { input_size_bytes += shape.getNumElements() * DataTypeSize(dtype); } for (Value res : op->getResults().drop_back()) { - auto shape = res.getType().dyn_cast(); + auto shape = mlir::dyn_cast(res.getType()); if (!shape || !shape.hasStaticShape()) continue; auto element_type = shape.getElementType(); @@ -742,7 +744,7 @@ class EvaluateConstant : public FolderPatternBase { // TODO(tlongeri): Is CreateConstantTensorNode check correct? Shouldn't it // always be a ShapedType? for (TypedAttr r : result) - if (r && r.getType().isa()) return failure(); + if (r && mlir::isa(r.getType())) return failure(); StringAttr name_attr = static_cast(op->getDialect()) ->getNameAttrIdentifier(); @@ -824,7 +826,7 @@ class MaterializeShapeOp : public FolderPatternBase { PatternRewriter &rewriter) const override { Value input = op->getOperand(0); - auto input_shape = input.getType().cast(); + auto input_shape = mlir::cast(input.getType()); if (!input_shape.hasStaticShape()) return failure(); // TODO(rmlarsen): Remove this workaround for b/150861569 @@ -834,7 +836,7 @@ class MaterializeShapeOp : public FolderPatternBase { return failure(); Type output_dtype = - op->getResult(0).getType().cast().getElementType(); + mlir::cast(op->getResult(0).getType()).getElementType(); ElementsAttr const_attr = CreateElementsAttrOfTypeValues( output_dtype, {input_shape.getRank()}, input_shape.getShape()); @@ -863,10 +865,10 @@ class MaterializeSizeOp : public FolderPatternBase { PatternRewriter &rewriter) const override { Value input = op->getOperand(0); - auto input_shape = input.getType().cast(); + auto input_shape = mlir::cast(input.getType()); if (!input_shape.hasStaticShape()) return failure(); - ShapedType result_type = (*op->result_type_begin()).cast(); + ShapedType result_type = mlir::cast((*op->result_type_begin())); if (!result_type.getElementType().isIntOrIndexOrFloat()) return failure(); ElementsAttr const_attr = CreateElementsAttrOfTypeValues( @@ -898,10 +900,10 @@ class MaterializeRankOp : public FolderPatternBase { PatternRewriter &rewriter) const override { Value input = op->getOperand(0); - auto input_shape = input.getType().cast(); + auto input_shape = mlir::cast(input.getType()); if (!input_shape.hasRank()) return failure(); - ShapedType result_type = (*op->result_type_begin()).cast(); + ShapedType result_type = mlir::cast((*op->result_type_begin())); if (!result_type.getElementType().isIntOrIndexOrFloat()) return failure(); ElementsAttr const_attr = CreateElementsAttrOfTypeValues( @@ -976,7 +978,7 @@ class MaterializeShapeNOp : public FolderPatternBase { for (const auto &it : llvm::enumerate(TFOp(op).getNonControlOperands())) { Value operand = op->getOperand(it.index()); - auto operand_shape = operand.getType().cast(); + auto operand_shape = mlir::cast(operand.getType()); if (!operand_shape.hasStaticShape()) continue; if (op->getResults()[it.index()].use_empty()) continue; @@ -1033,7 +1035,7 @@ class MaterializeBroadcastGradientArgsOp auto get_shape = [this](Operation *op, SmallVector &shape) -> bool { if (dialect_->IsShape(op)) { - auto type = op->getOperand(0).getType().cast(); + auto type = mlir::cast(op->getOperand(0).getType()); if (!type.hasRank()) return false; llvm::append_range(shape, type.getShape()); @@ -1139,18 +1141,19 @@ class MaterializeReductionIndices // The reduction indices are already constant, there's nothing to do. if (!indices || dialect_->IsConstant(indices)) return failure(); - auto indices_shape = indices->getResult(0).getType().cast(); + auto indices_shape = + mlir::cast(indices->getResult(0).getType()); if (!indices_shape.hasRank()) return failure(); if (!indices_shape.getElementType().isInteger(32) && !indices_shape.getElementType().isInteger(64)) { return failure(); } - auto input_shape = op->getOperand(0).getType().cast(); + auto input_shape = mlir::cast(op->getOperand(0).getType()); // Unexpected graph, don't try to change it. if (!input_shape.hasRank() || input_shape.getRank() < 1) return failure(); - auto output_shape = op->getResult(0).getType().cast(); + auto output_shape = mlir::cast(op->getResult(0).getType()); const int output_rank = output_shape.hasRank() ? output_shape.getRank() : -1; @@ -1167,7 +1170,7 @@ class MaterializeReductionIndices full_reduction = false; if (!dialect_->IsReshape(user)) return failure(); - auto shape = user->getResult(0).getType().cast(); + auto shape = mlir::cast(user->getResult(0).getType()); if (!shape.hasStaticShape() || shape.getNumElements() != 1) return failure(); else @@ -1214,7 +1217,7 @@ class MaterializeFillNode : public FolderPatternBase { // Only handles single result op. Note that another result is control ret. if (op->getNumResults() != 2) return failure(); - auto output_type = op->getResult(0).getType().cast(); + auto output_type = mlir::cast(op->getResult(0).getType()); if (!output_type.hasStaticShape()) return failure(); if (!output_type.isIntOrIndexOrFloat()) return failure(); @@ -1262,7 +1265,7 @@ class MaterializeConstantValuedNode // TODO(chiahungduan): If op->getOperand(0) has static shape, can we use // that to materialize? - auto output_type = op->getResult(0).getType().cast(); + auto output_type = mlir::cast(op->getResult(0).getType()); if (!output_type.hasStaticShape()) return failure(); int value = is_zeros_like ? 0 : 1; @@ -1277,8 +1280,9 @@ class MaterializeConstantValuedNode } else { const_attr = SplatElementsAttr::get( output_type, - APFloat(output_element_type.cast().getFloatSemantics(), - value)); + APFloat( + mlir::cast(output_element_type).getFloatSemantics(), + value)); } FailureOr const_op = @@ -1457,7 +1461,7 @@ class RemoveShuffleOp : public FolderPatternBase { ElementsAttr perm_tensor = perm_op->getAttrOfType("value"); if (!perm_tensor) return failure(); - ShapedType x_shape = op->getOperand(0).getType().cast(); + ShapedType x_shape = mlir::cast(op->getOperand(0).getType()); if (!x_shape.hasRank()) return failure(); if (perm_tensor.getNumElements() != x_shape.getRank()) return failure(); @@ -1489,7 +1493,7 @@ class RemoveTransposeOp : public FolderPatternBase { ElementsAttr perm_tensor = perm_op->getAttrOfType("value"); if (!perm_tensor) return failure(); - ShapedType x_shape = op->getOperand(0).getType().cast(); + ShapedType x_shape = mlir::cast(op->getOperand(0).getType()); if (!x_shape.hasRank()) return failure(); if (perm_tensor.getNumElements() != x_shape.getRank()) return failure(); @@ -1516,7 +1520,7 @@ class RemoveRandomShuffleOp : public FolderPatternBase { : FolderPatternBase("tfg.RandomShuffle", helper) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - auto shape = op->getOperand(0).getType().cast(); + auto shape = mlir::cast(op->getOperand(0).getType()); if (!shape.hasRank()) return failure(); if (shape.getRank() != 0 && shape.getShape()[0] != 1) return failure(); @@ -1536,7 +1540,8 @@ class RemoveReverse : public FolderPatternBase { : FolderPatternBase("tfg.ReverseV2", helper) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - ShapedType tensor_type = op->getOperand(0).getType().cast(); + ShapedType tensor_type = + mlir::cast(op->getOperand(0).getType()); if (!tensor_type.hasRank()) return failure(); Operation *dim_op = op->getOperand(1).getDefiningOp(); @@ -1588,7 +1593,7 @@ class SimplifySliceOp : public FolderPatternBase { auto begin_attr = begin_op->getAttrOfType("value"); auto size_attr = size_op->getAttrOfType("value"); - ShapedType input_type = op->getOperand(0).getType().cast(); + ShapedType input_type = mlir::cast(op->getOperand(0).getType()); if (!input_type.hasRank()) return failure(); for (unsigned i = 0; i < input_type.getRank(); ++i) { @@ -1643,7 +1648,7 @@ class SimplifyStridedSlice : public FolderPatternBase { if (!begin_mask_attr || !end_mask_attr || !ellipsis_mask_attr) return failure(); - ShapedType input_type = op->getOperand(0).getType().cast(); + ShapedType input_type = mlir::cast(op->getOperand(0).getType()); if (!input_type.hasStaticShape()) return failure(); Operation *begin_op = op->getOperand(1).getDefiningOp(); @@ -1805,7 +1810,7 @@ class SimplifySqueezeOp : public FolderPatternBase { : FolderPatternBase("tfg.Squeeze", helper) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - auto shape_type = op->getOperand(0).getType().cast(); + auto shape_type = mlir::cast(op->getOperand(0).getType()); if (!shape_type.hasRank()) return failure(); if (llvm::any_of(shape_type.getShape(), [](int64_t s) { return s <= 1; })) return failure(); @@ -1836,13 +1841,13 @@ class SimplifyPackOp : public FolderPatternBase { // protos, e.g. there is DT_RESOURCE). // TODO(tlongeri): is there a reason ExpandDims does not support DT_VARIANT? if (ShapedType values_type = - non_control_operands[0].getType().dyn_cast(); - !values_type || values_type.getElementType().isa()) + mlir::dyn_cast(non_control_operands[0].getType()); + !values_type || mlir::isa(values_type.getElementType())) return failure(); // It's unsafe to add a control dependency on the feed node, because it // might have been never executed otherwiwise. - if (non_control_operands[0].isa()) return failure(); + if (mlir::isa(non_control_operands[0])) return failure(); IntegerAttr axis = op->getAttrOfType("axis"); ElementsAttr const_attr = CreateElementsAttrOfTypeValues( @@ -2033,8 +2038,8 @@ class SimplifyReductionOp : public FolderPatternBase { } // Check `IsReductionCandidateForSimplification` - auto input_type = op->getOperand(0).getType().cast(); - auto op_type = (*op->result_type_begin()).cast(); + auto input_type = mlir::cast(op->getOperand(0).getType()); + auto op_type = mlir::cast((*op->result_type_begin())); if (!input_type.hasStaticShape() || !op_type.hasStaticShape()) return failure(); @@ -2096,7 +2101,7 @@ class SimplifyReductionOp : public FolderPatternBase { Operation *ReplaceReductionWithReshape(OpBuilder &builder, Operation *op, Operation *reduction_indices) const { const int new_num_dimensions = - (*op->result_type_begin()).cast().getRank(); + mlir::cast((*op->result_type_begin())).getRank(); SmallVector elements(new_num_dimensions); std::iota(elements.begin(), elements.end(), 1); ElementsAttr const_attr = CreateElementsAttrOfTypeValues( @@ -2164,7 +2169,7 @@ class SimplifyReshapeOp : public FolderPatternBase { PatternRewriter &rewriter) const override { if (!dialect_->IsReshape(op) || !op->hasAttr("T")) return failure(); - auto input_shape = op->getOperand(0).getType().cast(); + auto input_shape = mlir::cast(op->getOperand(0).getType()); if (!input_shape.hasStaticShape()) return failure(); Operation *shape_op = op->getOperand(1).getDefiningOp(); @@ -2227,9 +2232,9 @@ class SimplifyArithmeticOp Operation *y = op->getOperand(1).getDefiningOp(); if (!x || !y) return failure(); - ShapedType op_type = (*op->result_type_begin()).cast(); - ShapedType x_type = (*x->result_type_begin()).cast(); - ShapedType y_type = (*y->result_type_begin()).cast(); + ShapedType op_type = mlir::cast((*op->result_type_begin())); + ShapedType x_type = mlir::cast((*x->result_type_begin())); + ShapedType y_type = mlir::cast((*y->result_type_begin())); const bool y_matches_output_shape = op_type.hasStaticShape() && y_type.hasStaticShape() && @@ -2277,8 +2282,8 @@ class SimplifyArithmeticOp TypeAttr type_attr = op->getAttrOfType("T"); if (!type_attr) return failure(); - if (type_attr.getValue().isa() || - type_attr.getValue().isa()) { + if (mlir::isa(type_attr.getValue()) || + mlir::isa(type_attr.getValue())) { OperationState state(op->getLoc(), "tfg.Reciprocal"); state.addOperands({op->getOperand(1), GetControlDependency(rewriter, op->getOperand(0))}); @@ -2401,8 +2406,9 @@ class ReduceDivToReciprocalMul if (!type_attr) return failure(); // Skip integer division. - if (dialect_->IsDiv(op) && !(type_attr.getValue().isa() || - type_attr.getValue().isa())) { + if (dialect_->IsDiv(op) && + !(mlir::isa(type_attr.getValue()) || + mlir::isa(type_attr.getValue()))) { return failure(); } @@ -2572,8 +2578,8 @@ class ConstantPushDown : public ConstantPushDownBase { // Dimensions of X must be smaller than or equal than those of C. // This also avoids having to increase the size of the child op's result // to match the broadcast with a bigger operand. - auto c_shape = const_op->getResult(0).getType().cast(); - auto x_shape = x_value.getType().cast(); + auto c_shape = mlir::cast(const_op->getResult(0).getType()); + auto x_shape = mlir::cast(x_value.getType()); if (c_shape.hasStaticShape() && x_shape.hasStaticShape() && c_shape.getNumElements() > x_shape.getNumElements()) { @@ -2677,7 +2683,7 @@ class PartialConstPropThroughIdentityN SmallVector control_operands; for (OpOperand &operand : op->getOpOperands()) { Value v = operand.get(); - if (v.getType().isa()) break; + if (mlir::isa(v.getType())) break; Operation *v_op = v.getDefiningOp(); if (!v_op || !dialect_->IsIdentityN(v_op) || @@ -2685,7 +2691,7 @@ class PartialConstPropThroughIdentityN continue; } - int res_index = v.cast().getResultNumber(); + int res_index = mlir::cast(v).getResultNumber(); Value value_to_forward = v_op->getOperand(res_index); if (!value_to_forward.getDefiningOp() || !dialect_->IsConstant(value_to_forward.getDefiningOp())) { @@ -2965,21 +2971,22 @@ class MulConvPushDown : public ConstantPatternBaseresult_type_begin()).cast(); + ShapedType mul_shape = mlir::cast((*op->result_type_begin())); ShapedType conv_shape = - (*conv_node->result_type_begin()).cast(); + mlir::cast((*conv_node->result_type_begin())); // TODO(chiahungduan): Symbolic shape equivalence is acceptable. if (!mul_shape.hasStaticShape() || !conv_shape.hasStaticShape() || mul_shape != conv_shape) { return failure(); } - auto filter_shape = conv_node->getOperand(1).getType().cast(); + auto filter_shape = + mlir::cast(conv_node->getOperand(1).getType()); Operation *const_node = left_child_is_constant ? mul_left_child : mul_right_child; auto const_node_shape = - (*const_node->result_type_begin()).cast(); + mlir::cast((*const_node->result_type_begin())); if (!IsValidConstShapeForMulConvPushDown( conv_node->getAttrOfType("data_format"), filter_shape, const_node_shape)) { @@ -3235,7 +3242,7 @@ class ConstantPushDownBiasAdd if (!IsOperandsSafeToMove(add_child, const_child)) return failure(); auto hasRank = [&](Value value) { - return value.getType().cast().hasRank(); + return mlir::cast(value.getType()).hasRank(); }; if (!hasRank(op->getOperand(0)) || !hasRank(op->getOperand(1)) || @@ -3246,19 +3253,19 @@ class ConstantPushDownBiasAdd // Now get the ranks and types of the 3 leaf nodes. const int left_leaf_rank = - add_child->getOperand(0).getType().cast().getRank(); + mlir::cast(add_child->getOperand(0).getType()).getRank(); const int right_leaf_rank = - add_child->getOperand(1).getType().cast().getRank(); + mlir::cast(add_child->getOperand(1).getType()).getRank(); // At least one leaf must be a vector. if (left_leaf_rank != 1 && right_leaf_rank != 1) return failure(); const int vector_idx = left_leaf_rank == 1 ? 0 : 1; auto vector_type = - add_child->getOperand(vector_idx).getType().cast(); + mlir::cast(add_child->getOperand(vector_idx).getType()); Type vector_d_type = vector_type.getElementType(); - auto const_type = const_child->getResultTypes()[0].cast(); + auto const_type = mlir::cast(const_child->getResultTypes()[0]); const int const_rank = const_type.getRank(); Type const_d_type = const_type.getElementType(); @@ -3336,7 +3343,7 @@ class ConstantPushDownAdd : public ConstantPushDownBase { if (!child_is_bias_add && !dialect_->IsAdd(add_child)) return failure(); auto hasRank = [&](Value value) { - return value.getType().cast().hasRank(); + return mlir::cast(value.getType()).hasRank(); }; if (!hasRank(op->getOperand(0)) || !hasRank(op->getOperand(1)) || @@ -3347,9 +3354,9 @@ class ConstantPushDownAdd : public ConstantPushDownBase { // Now get the ranks and types of the 3 leaf nodes. const int left_leaf_rank = - add_child->getOperand(0).getType().cast().getRank(); + mlir::cast(add_child->getOperand(0).getType()).getRank(); const int right_leaf_rank = - add_child->getOperand(1).getType().cast().getRank(); + mlir::cast(add_child->getOperand(1).getType()).getRank(); // At least one leaf must be a vector. if (left_leaf_rank != 1 && right_leaf_rank != 1) return failure(); @@ -3357,18 +3364,18 @@ class ConstantPushDownAdd : public ConstantPushDownBase { const int matrix_idx = 1 - vector_idx; ShapedType vector_type = - add_child->getOperand(vector_idx).getType().cast(); + mlir::cast(add_child->getOperand(vector_idx).getType()); Type vector_d_type = vector_type.getElementType(); ShapedType matrix_type = - add_child->getOperand(matrix_idx).getType().cast(); + mlir::cast(add_child->getOperand(matrix_idx).getType()); const int matrix_rank = matrix_type.getRank(); Type matrix_d_type = matrix_type.getElementType(); const int const_index = op->getOperand(0).getDefiningOp() == const_child ? 0 : 1; ShapedType const_type = - const_child->getResult(0).getType().cast(); + mlir::cast(const_child->getResult(0).getType()); const int const_rank = const_type.getRank(); Type const_d_type = const_type.getElementType(); @@ -3518,9 +3525,9 @@ class SimplifySelectOpBase : public FolderPatternBase { bool is_all_false = this->helper_.IsZeros(condition_op); if (!is_all_true && !is_all_false) return failure(); - auto condition_type = op->getOperand(0).getType().cast(); - auto t_type = op->getOperand(1).getType().cast(); - auto e_type = op->getOperand(2).getType().cast(); + auto condition_type = mlir::cast(op->getOperand(0).getType()); + auto t_type = mlir::cast(op->getOperand(1).getType()); + auto e_type = mlir::cast(op->getOperand(2).getType()); if (!condition_type.hasStaticShape() || !t_type.hasStaticShape() || !e_type.hasStaticShape()) { return failure(); diff --git a/tensorflow/core/transforms/func_to_graph/BUILD b/tensorflow/core/transforms/func_to_graph/BUILD index 0c62a5a2f90894..8463483b01d80f 100644 --- a/tensorflow/core/transforms/func_to_graph/BUILD +++ b/tensorflow/core/transforms/func_to_graph/BUILD @@ -19,6 +19,7 @@ cc_library( "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/core/transforms/func_to_graph/func_to_graph.cc b/tensorflow/core/transforms/func_to_graph/func_to_graph.cc index 15bd5bc00252a6..dcedaed70f4c1a 100644 --- a/tensorflow/core/transforms/func_to_graph/func_to_graph.cc +++ b/tensorflow/core/transforms/func_to_graph/func_to_graph.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/transforms/func_to_graph/func_to_graph.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/ir/ops.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/status.h" @@ -47,8 +48,9 @@ tensorflow::Status FuncToGraph(GraphFuncOp func) { // Init the entry with nullptr and it'll be updated with associated op // later. - referred_ops.insert({lifted_value_attr[0].cast().getValue(), - /*Operation=*/nullptr}); + referred_ops.insert( + {mlir::cast(lifted_value_attr[0]).getValue(), + /*Operation=*/nullptr}); } } @@ -59,7 +61,7 @@ tensorflow::Status FuncToGraph(GraphFuncOp func) { } for (const auto &it : llvm::enumerate(func.getArguments())) { - if (it.value().getType().isa()) continue; + if (mlir::isa(it.value().getType())) continue; auto lifted_value_attr = func.getArgAttrOfType(it.index(), lifted_value_attr_name); @@ -70,7 +72,7 @@ tensorflow::Status FuncToGraph(GraphFuncOp func) { } StringRef value_defining_op_name = - lifted_value_attr[0].cast().getValue(); + mlir::cast(lifted_value_attr[0]).getValue(); Operation *op = referred_ops[value_defining_op_name]; if (!op) { return tensorflow::errors::InvalidArgument( @@ -79,7 +81,7 @@ tensorflow::Status FuncToGraph(GraphFuncOp func) { } uint64_t result_index = - lifted_value_attr[1].cast().getValue().getZExtValue(); + mlir::cast(lifted_value_attr[1]).getValue().getZExtValue(); if (result_index >= op->getNumResults()) { return tensorflow::errors::InvalidArgument( "result index out of bound: seeing index ", result_index, diff --git a/tensorflow/core/transforms/functional_to_region/impl.cc b/tensorflow/core/transforms/functional_to_region/impl.cc index ffdffd14d494e5..046f64493eb4a7 100644 --- a/tensorflow/core/transforms/functional_to_region/impl.cc +++ b/tensorflow/core/transforms/functional_to_region/impl.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/ir/dialect.h" #include "tensorflow/core/ir/ops.h" @@ -383,7 +384,7 @@ LogicalResult ConvertCaseLikeOp::matchAndRewrite( } ArrayAttr region_attrs = nullptr; if (!llvm::all_of(preserved_attrs, [](Attribute attr) { - return AreRegionAttrsEmpty(attr.cast()); + return AreRegionAttrsEmpty(mlir::cast(attr)); })) region_attrs = rewriter.getArrayAttr(preserved_attrs); diff --git a/tensorflow/core/transforms/graph_compactor/pass.cc b/tensorflow/core/transforms/graph_compactor/pass.cc index bf5eb785c8f181..883bc36891e9d7 100644 --- a/tensorflow/core/transforms/graph_compactor/pass.cc +++ b/tensorflow/core/transforms/graph_compactor/pass.cc @@ -110,7 +110,8 @@ class NameCompressPass : public impl::NameCompressBase { arg_attrs.reserve(func.getNumArguments()); // Iterate over the function arguments, skipping the control tokens. for (int i = 0, e = func.getNumArguments(); i != e; i += 2) { - NamedAttrList attrs = func.getArgAttrsAttr()[i].cast(); + NamedAttrList attrs = + mlir::cast(func.getArgAttrsAttr()[i]); attrs.set(dialect_->getTfgNameAttrIdentifier(), encode_new_name()); arg_attrs.append({attrs.getDictionary(&getContext()), empty_dict_}); } diff --git a/tensorflow/core/transforms/graph_to_func/graph_to_func.cc b/tensorflow/core/transforms/graph_to_func/graph_to_func.cc index 83c2230afa40e0..4db1cc15ff310c 100644 --- a/tensorflow/core/transforms/graph_to_func/graph_to_func.cc +++ b/tensorflow/core/transforms/graph_to_func/graph_to_func.cc @@ -86,11 +86,11 @@ tensorflow::Status GraphToFunc(GraphOp graph, ArrayRef feeds, feed.replaceAllUsesWith(body->addArgument(feed.getType(), loc)); body->addArgument(control_ty, loc); llvm::SmallVector arg_attrs; - std::string slot = OpResultToSlotName(feed.cast()); + std::string slot = OpResultToSlotName(mlir::cast(feed)); arg_attrs.push_back(NamedAttribute(tfg_name, builder.getStringAttr(slot))); - arg_attrs.push_back( - NamedAttribute(lifted_value_name, - createLiftedValueAttr(builder, feed.cast()))); + arg_attrs.push_back(NamedAttribute( + lifted_value_name, + createLiftedValueAttr(builder, mlir::cast(feed)))); args_rets_attrs.push_back(builder.getDictionaryAttr(arg_attrs)); args_rets_attrs.push_back(Attribute{}); } @@ -99,7 +99,7 @@ tensorflow::Status GraphToFunc(GraphOp graph, ArrayRef feeds, args_rets_attrs.clear(); for (Value fetch : fetches) { llvm::SmallVector arg_attrs; - std::string slot = OpResultToSlotName(fetch.cast()); + std::string slot = OpResultToSlotName(mlir::cast(fetch)); arg_attrs.push_back(NamedAttribute(tfg_name, builder.getStringAttr(slot))); args_rets_attrs.push_back(builder.getDictionaryAttr(arg_attrs)); } diff --git a/tensorflow/core/transforms/region_to_functional/impl.cc b/tensorflow/core/transforms/region_to_functional/impl.cc index a4d8b07f2594e2..439f622240b5aa 100644 --- a/tensorflow/core/transforms/region_to_functional/impl.cc +++ b/tensorflow/core/transforms/region_to_functional/impl.cc @@ -518,7 +518,7 @@ NamedAttrList BasePattern::BuildAttributes(RegionAttr preserved, // For each argument and result, lookup a name and regenerate output shapes. const auto build_attrs = [&](ArrayAttr attr, auto &it, std::optional args) { - NamedAttrList attrs(attr ? attr[it.index()].template cast() + NamedAttrList attrs(attr ? mlir::cast(attr[it.index()]) : DictionaryAttr()); // If no name was preserved, try to find one. if (!attrs.get(ids_.tfg_name)) { @@ -548,7 +548,7 @@ NamedAttrList BasePattern::BuildAttributes(RegionAttr preserved, StringAttr BasePattern::TryFindName(Value value, std::optional args) const { // If this is an op result, return the op's name. - if (auto result = value.dyn_cast()) { + if (auto result = mlir::dyn_cast(value)) { Operation *op = result.getOwner(); if (auto name = op->getAttrOfType(dialect_.getNameAttrIdentifier())) { @@ -558,7 +558,7 @@ StringAttr BasePattern::TryFindName(Value value, return {}; } - auto arg = value.cast(); + auto arg = mlir::cast(value); Operation *parent = arg.getOwner()->getParentOp(); auto iface = dyn_cast(parent); if (!iface) return {}; @@ -904,12 +904,11 @@ LogicalResult ConvertCaseLikeOp::matchAndRewrite( // Get the preserved attributes, if there are any. RegionAttr preserved = op.getRegionAttrs() - ? op.getRegionAttrsAttr()[idx].template cast() + ? mlir::cast(op.getRegionAttrsAttr()[idx]) : nullptr; DictionaryAttr attrs = - branch_func_attrs - ? branch_func_attrs[idx].template cast() - : nullptr; + branch_func_attrs ? mlir::cast(branch_func_attrs[idx]) + : nullptr; branch_regions.push_back(BasePattern::RegionFunction{ it.value(), preserved, attrs, ("case_function_" + Twine(idx)).str()}); } diff --git a/tensorflow/core/transforms/remapper/BUILD b/tensorflow/core/transforms/remapper/BUILD index 38b09b8bf601cb..e75b654b69a59f 100644 --- a/tensorflow/core/transforms/remapper/BUILD +++ b/tensorflow/core/transforms/remapper/BUILD @@ -48,6 +48,7 @@ cc_library( "@llvm-project//mlir:PDLInterpDialect", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", ], diff --git a/tensorflow/core/transforms/remapper/pass.cc b/tensorflow/core/transforms/remapper/pass.cc index 7adc8d6d84a069..f881fce665da32 100644 --- a/tensorflow/core/transforms/remapper/pass.cc +++ b/tensorflow/core/transforms/remapper/pass.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/core/framework/types.h" #include "tensorflow/core/ir/dialect.h" @@ -55,8 +56,8 @@ class MatchMulSigmoid : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { TypeAttr dtype_attr = op->getAttrOfType("T"); - if (!dtype_attr.getValue().isa() && - !dtype_attr.getValue().isa()) { + if (!mlir::isa(dtype_attr.getValue()) && + !mlir::isa(dtype_attr.getValue())) { return failure(); } @@ -217,7 +218,7 @@ class MatchStringToHashBucket : public RemapperPatternBase { TypeAttr dtype_attr = as_string_op->getAttrOfType("T"); if (!dtype_attr) return failure(); Type dtype = dtype_attr.getValue(); - if (!dtype.isa()) return failure(); + if (!mlir::isa(dtype)) return failure(); // width/fill attributes must be default values auto width_attr = as_string_op->getAttrOfType("width"); @@ -270,7 +271,7 @@ class MatchSoftplusTanhMul : public RemapperPatternBase { auto attr = op->getAttrOfType("T"); if (!attr) return failure(); Type dtype = attr.getValue(); - if (!dtype.isa()) return failure(); + if (!mlir::isa(dtype)) return failure(); TFOp mul_wrapper(op); @@ -511,12 +512,12 @@ class FusedBatchNormExRewriter : public RemapperPatternBase { // GPU supports float and half. // Put this condition before check `isOneDNNEnabled()` because this node // should be processed when it's on GPU and oneDNN CPU is enabled. - if (!dtype_T.isa()) return false; + if (!mlir::isa(dtype_T)) return false; } else { // Bfloat16 is available only with oneDNN. // Half is not available with oneDNN. if (this->helper_.isOneDNNEnabled() && - !dtype_T.isa()) { + !mlir::isa(dtype_T)) { return false; } } @@ -543,11 +544,11 @@ class FusedBatchNormExRewriter : public RemapperPatternBase { if (data_format != "NHWC") return false; // Data type must be Float16. - if (!dtype_T.isa()) return false; + if (!mlir::isa(dtype_T)) return false; // Channel dimension must be a multiple of 4. auto fbn_input0_shape = - fused_batch_norm_op->getOperand(0).getType().cast(); + mlir::cast(fused_batch_norm_op->getOperand(0).getType()); auto fbn_input0_shape_dims = fbn_input0_shape.getShape(); const bool valid_channel_dim = (fbn_input0_shape.getRank() == 4) && @@ -562,7 +563,7 @@ class FusedBatchNormExRewriter : public RemapperPatternBase { // FusedBatchNormV2 and V3 have an extra type parameter. if (fused_batch_norm_op->getName().getStringRef() != "tfg.FusedBatchNorm") { auto attr = fused_batch_norm_op->getAttrOfType("U"); - if (attr && !attr.getValue().isa()) { + if (attr && !mlir::isa(attr.getValue())) { return false; } } @@ -618,9 +619,9 @@ class FusedBatchNormExRewriter : public RemapperPatternBase { auto add_input1_op = activation_input_op->getOperand(1).getDefiningOp(); if (add_input0_op == nullptr || add_input1_op == nullptr) return false; auto add_input0_shape = - activation_input_op->getOperand(0).getType().cast(); + mlir::cast(activation_input_op->getOperand(0).getType()); auto add_input1_shape = - activation_input_op->getOperand(1).getType().cast(); + mlir::cast(activation_input_op->getOperand(1).getType()); if (add_input0_shape.getShape() != add_input1_shape.getShape()) { return false; } diff --git a/tensorflow/core/transforms/remapper/remapping_helper.h b/tensorflow/core/transforms/remapper/remapping_helper.h index 0e751f3d550629..1d8db8fc459a8e 100644 --- a/tensorflow/core/transforms/remapper/remapping_helper.h +++ b/tensorflow/core/transforms/remapper/remapping_helper.h @@ -17,6 +17,7 @@ limitations under the License. #include +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/framework/types.h" #include "tensorflow/core/transforms/utils/op_cat_helper.h" #include "tensorflow/core/transforms/utils/utils.h" @@ -111,9 +112,9 @@ class OpPropertyHelper : public OpCatHelper { if (!attr) return false; Type dtype = attr.getValue(); if (dialect_->IsConv2D(contraction_op)) { - return dtype.isa(); + return mlir::isa(dtype); } else if (dialect_->IsMatMul(contraction_op)) { - return dtype.isa(); + return mlir::isa(dtype); } else { return false; } @@ -131,13 +132,13 @@ class OpPropertyHelper : public OpCatHelper { // fusions are handled differently than contraction ops. bool is_supported = IsContraction(contraction_op) || dialect_->IsAnyBatchMatMul(contraction_op); - return is_supported && dtype.isa(); + return is_supported && mlir::isa(dtype); } if (dialect_->IsConv2D(contraction_op)) { - return dtype.isa(); + return mlir::isa(dtype); } else if (dialect_->IsMatMul(contraction_op)) { - return dtype.isa(); + return mlir::isa(dtype); } else { return false; } diff --git a/tensorflow/core/transforms/shape_inference/pass.cc b/tensorflow/core/transforms/shape_inference/pass.cc index 80a28e0b53dbfb..3b95924476adfb 100644 --- a/tensorflow/core/transforms/shape_inference/pass.cc +++ b/tensorflow/core/transforms/shape_inference/pass.cc @@ -52,7 +52,7 @@ using tensorflow::shape_inference::ShapeHandle; // Only non-static shape or type with subtype can be refined. static bool CanBeRefined(Type type) { - auto shape_type = type.dyn_cast(); + auto shape_type = mlir::dyn_cast(type); if (!shape_type) return false; // Returns whether type with subtypes can be further refined. @@ -60,8 +60,8 @@ static bool CanBeRefined(Type type) { return tws.GetSubtypes().empty() || llvm::any_of(tws.GetSubtypes(), CanBeRefined); }; - auto type_with_subtype = shape_type.getElementType() - .dyn_cast(); + auto type_with_subtype = mlir::dyn_cast( + shape_type.getElementType()); if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true; return !shape_type.hasStaticShape(); @@ -85,7 +85,7 @@ class ShapeInference : public impl::ShapeInferenceBase { // Get the tensor value if possible, return nullptr otherwise. DenseElementsAttr GetTensorValue(Value result) { - OpResult op_result = result.dyn_cast(); + OpResult op_result = mlir::dyn_cast(result); if (op_result) { auto it = cached_tensor_values_.find(op_result); if (it != cached_tensor_values_.end()) return it->second; @@ -99,7 +99,7 @@ class ShapeInference : public impl::ShapeInferenceBase { void ShapeInference::TryToCacheResultsTensorValue(Operation *op) { // Only op with static shape is able to construct the tensor value. if (llvm::all_of(op->getResults().drop_back(), [this](Value value) { - auto shape = value.getType().cast(); + auto shape = mlir::cast(value.getType()); /// NOMUTANTS -- shape.hasStaticShape is a cheaper operation than /// GetTensorValue return (!shape.hasStaticShape() || GetTensorValue(value) != nullptr); @@ -117,9 +117,10 @@ void ShapeInference::TryToCacheResultsTensorValue(Operation *op) { if (!operand_tensor_value) return; cached_tensor_values_[op->getResult(0)] = operand_tensor_value; } else if (op_name == "Rank") { - ShapedType operand_shape = op->getOperand(0).getType().cast(); + ShapedType operand_shape = + mlir::cast(op->getOperand(0).getType()); if (!operand_shape.hasRank()) return; - ShapedType return_shape = op->getResultTypes()[0].cast(); + ShapedType return_shape = mlir::cast(op->getResultTypes()[0]); DenseElementsAttr tensor_value; if (return_shape.getElementType().isInteger(32)) { tensor_value = DenseElementsAttr::get( @@ -130,9 +131,10 @@ void ShapeInference::TryToCacheResultsTensorValue(Operation *op) { } cached_tensor_values_[op->getResult(0)] = tensor_value; } else if (op_name == "Size") { - ShapedType operand_shape = op->getOperand(0).getType().cast(); + ShapedType operand_shape = + mlir::cast(op->getOperand(0).getType()); if (!operand_shape.hasStaticShape()) return; - ShapedType return_shape = op->getResultTypes()[0].cast(); + ShapedType return_shape = mlir::cast(op->getResultTypes()[0]); DenseElementsAttr tensor_value; if (return_shape.getElementType().isInteger(32)) { tensor_value = @@ -147,13 +149,14 @@ void ShapeInference::TryToCacheResultsTensorValue(Operation *op) { } else if (op_name == "Shape" || op_name == "ShapeN") { for (OpOperand &operand : op->getOpOperands()) { Type operand_type = operand.get().getType(); - if (operand_type.isa()) break; + if (mlir::isa(operand_type)) break; - auto operand_shape = operand_type.cast(); + auto operand_shape = mlir::cast(operand_type); if (!operand_shape.hasStaticShape()) continue; int idx = operand.getOperandNumber(); - ShapedType return_shape = op->getResultTypes()[idx].cast(); + ShapedType return_shape = + mlir::cast(op->getResultTypes()[idx]); DenseElementsAttr tensor_value; if (return_shape.getElementType().isInteger(32)) { tensor_value = DenseElementsAttr::get( @@ -182,7 +185,7 @@ void ShapeInference::runOnOperation() { auto op_result_as_shape_fn = [this](InferenceContext &ic, OpResult op_result) -> ShapeHandle { - auto rt = op_result.getType().dyn_cast(); + auto rt = mlir::dyn_cast(op_result.getType()); // NOMUTANTS -- TODO(chiahungduan): Review this condition to see if shape // with known rank but unknown dimension is acceptable. if (!rt || rt.getRank() != 1 || !rt.hasStaticShape()) return {}; @@ -199,7 +202,8 @@ void ShapeInference::runOnOperation() { auto infer_and_update_shapes = [&](Operation *op) -> bool { auto result_element_type_fn = [&](int idx) -> Type { - return op->getResult(idx).getType().cast().getElementType(); + return mlir::cast(op->getResult(idx).getType()) + .getElementType(); }; SmallVector results; @@ -223,7 +227,7 @@ void ShapeInference::runOnOperation() { } Type refined_type = tf_type::GetCastCompatibleType( - op_result.getType().cast(), inferred_type); + mlir::cast(op_result.getType()), inferred_type); // Certain attributes like _output_shapes may have incorrect shape // information. When it's incompatible, use the result of shape inference diff --git a/tensorflow/core/transforms/utils/eval_utils_test.cc b/tensorflow/core/transforms/utils/eval_utils_test.cc index 78b71a9683f876..072e88f2db7465 100644 --- a/tensorflow/core/transforms/utils/eval_utils_test.cc +++ b/tensorflow/core/transforms/utils/eval_utils_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -103,8 +104,8 @@ TEST(EvalUtilsTest, EvaluateOperation) { {const_0->getAttrOfType("value")}, result))); ASSERT_EQ(result.size(), 1); - ASSERT_TRUE(result[0].isa()); - EXPECT_EQ(result[0].cast().getValues()[0], 1); + ASSERT_TRUE(mlir::isa(result[0])); + EXPECT_EQ(mlir::cast(result[0]).getValues()[0], 1); result.clear(); @@ -113,8 +114,8 @@ TEST(EvalUtilsTest, EvaluateOperation) { {const_1->getAttrOfType("value")}, result))); ASSERT_EQ(result.size(), 1); - ASSERT_TRUE(result[0].isa()); - EXPECT_EQ(result[0].cast().getValues()[0], 2); + ASSERT_TRUE(mlir::isa(result[0])); + EXPECT_EQ(mlir::cast(result[0]).getValues()[0], 2); result.clear(); @@ -125,8 +126,8 @@ TEST(EvalUtilsTest, EvaluateOperation) { result))); ASSERT_EQ(result.size(), 1); - ASSERT_TRUE(result[0].isa()); - EXPECT_EQ(result[0].cast().getValues()[0], 3); + ASSERT_TRUE(mlir::isa(result[0])); + EXPECT_EQ(mlir::cast(result[0]).getValues()[0], 3); } TEST(EvalUtilsTest, OutputInvalidation) { @@ -170,7 +171,7 @@ TEST(EvalUtilsTest, OutputInvalidation) { ASSERT_EQ(result.size(), 2); // Output 0 is invalidated. EXPECT_EQ(result[0], nullptr); - EXPECT_EQ(result[1].cast().getValues()[0], 2); + EXPECT_EQ(mlir::cast(result[1]).getValues()[0], 2); } } // namespace tfg diff --git a/tensorflow/core/transforms/utils/op_cat_helper.cc b/tensorflow/core/transforms/utils/op_cat_helper.cc index 5107c2e7b7c94e..1347072cd87676 100644 --- a/tensorflow/core/transforms/utils/op_cat_helper.cc +++ b/tensorflow/core/transforms/utils/op_cat_helper.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/transforms/utils/op_cat_helper.h" +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/ir/dialect.h" @@ -55,22 +56,22 @@ bool SplatElementsAttrHasValue(SplatElementsAttr attr, float v) { IF_SPLAT_VALUE_IS(tensorflow::DT_DOUBLE, v); } else if (type.isBF16()) { IF_SPLAT_VALUE_IS(tensorflow::DT_BFLOAT16, v); - } else if (type.isa()) { - ComplexType complex_type = type.cast(); + } else if (mlir::isa(type)) { + ComplexType complex_type = mlir::cast(type); if (complex_type.getElementType().isF32()) { IF_SPLAT_VALUE_IS(tensorflow::DT_COMPLEX64, v); } else if (complex_type.getElementType().isF64()) { IF_SPLAT_VALUE_IS(tensorflow::DT_COMPLEX128, v); } - } else if (type.isa()) { + } else if (mlir::isa(type)) { IF_SPLAT_VALUE_IS(tensorflow::DT_QINT8, v); - } else if (type.isa()) { + } else if (mlir::isa(type)) { IF_SPLAT_VALUE_IS(tensorflow::DT_QINT16, v); - } else if (type.isa()) { + } else if (mlir::isa(type)) { IF_SPLAT_VALUE_IS(tensorflow::DT_QINT32, v); - } else if (type.isa()) { + } else if (mlir::isa(type)) { IF_SPLAT_VALUE_IS(tensorflow::DT_QUINT8, v); - } else if (type.isa()) { + } else if (mlir::isa(type)) { IF_SPLAT_VALUE_IS(tensorflow::DT_QUINT16, v); } #undef IF_SPLAT_VALUE_IS @@ -82,7 +83,7 @@ bool SplatElementsAttrHasValue(SplatElementsAttr attr, float v) { bool OpCatHelper::IsAggregate(TFOp op) { if (dialect_->IsAdd(op)) { auto attr = op->getAttrOfType("T"); - return !attr || !attr.getValue().isa(); + return !attr || !mlir::isa(attr.getValue()); } const tensorflow::OpDef *op_def = nullptr; tensorflow::Status status = tensorflow::OpRegistry::Global()->LookUpOpDef( @@ -93,7 +94,7 @@ bool OpCatHelper::IsAggregate(TFOp op) { bool OpCatHelper::IsCommutative(TFOp op) { if (dialect_->IsAdd(op)) { auto attr = op->getAttrOfType("T"); - return !attr || !attr.getValue().isa(); + return !attr || !mlir::isa(attr.getValue()); } const tensorflow::OpDef *op_def = nullptr; tensorflow::Status status = tensorflow::OpRegistry::Global()->LookUpOpDef( diff --git a/tensorflow/dtensor/cc/BUILD b/tensorflow/dtensor/cc/BUILD index 3d28e474d680ea..a1ba57eb3f104c 100644 --- a/tensorflow/dtensor/cc/BUILD +++ b/tensorflow/dtensor/cc/BUILD @@ -204,6 +204,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/dtensor/cc/save_restore_util.cc b/tensorflow/dtensor/cc/save_restore_util.cc index b6f040a5f24fc6..dcaf41baf5f1e6 100644 --- a/tensorflow/dtensor/cc/save_restore_util.cc +++ b/tensorflow/dtensor/cc/save_restore_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/dtensor/mlir/value_utils.h" namespace tensorflow { @@ -158,7 +159,8 @@ SaveOpSpecs BuildPerDeviceSave( builder .create( prefix.getLoc(), - prefix.getType().dyn_cast(), prefix, + mlir::dyn_cast(prefix.getType()), + prefix, StringScalarConst(builder, prefix.getLoc(), DeviceSuffix(device_id, total_devices))) .getZ(); diff --git a/tensorflow/dtensor/mlir/BUILD b/tensorflow/dtensor/mlir/BUILD index 80531f9c04036a..6d531565f03f88 100644 --- a/tensorflow/dtensor/mlir/BUILD +++ b/tensorflow/dtensor/mlir/BUILD @@ -108,6 +108,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -131,6 +132,7 @@ cc_library( "//tensorflow/dtensor/cc:tensor_layout", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], alwayslink = True, ) @@ -326,6 +328,7 @@ cc_library( "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], alwayslink = 1, ) diff --git a/tensorflow/dtensor/mlir/annotate_global_shape.cc b/tensorflow/dtensor/mlir/annotate_global_shape.cc index e251254e1d38cf..462d9b3a3f8e96 100644 --- a/tensorflow/dtensor/mlir/annotate_global_shape.cc +++ b/tensorflow/dtensor/mlir/annotate_global_shape.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" @@ -51,8 +52,8 @@ void AnnotateFunctionArgRetvalGlobalShapes(mlir::func::FuncOp function, const auto& argument_type = argument_type_and_index.value(); // Extract TensorType from element of resource type to allow setting proper // global shape of resource types. - if (auto resource_type = mlir::getElementTypeOrSelf(argument_type) - .dyn_cast()) { + if (auto resource_type = mlir::dyn_cast( + mlir::getElementTypeOrSelf(argument_type))) { auto subtype = resource_type.getSubtypes(); if (subtype.size() == 1) { // subtype returns a Array of TensorType -- if it contains more than one diff --git a/tensorflow/dtensor/mlir/cluster_function_conversion.cc b/tensorflow/dtensor/mlir/cluster_function_conversion.cc index a1a40fc0801165..11aa6121c4339b 100644 --- a/tensorflow/dtensor/mlir/cluster_function_conversion.cc +++ b/tensorflow/dtensor/mlir/cluster_function_conversion.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -80,7 +81,7 @@ mlir::LogicalResult AttachRetvalLayouts( // operations. In that case, query the input layouts for function callsite // operations for layout information. if (!result_layout) { - if (auto block_arg = operand.dyn_cast()) { + if (auto block_arg = mlir::dyn_cast(operand)) { auto layout_or_status = ExtractLayoutFromOperand( sp_call_op.getOperand(block_arg.getArgNumber())); if (!layout_or_status.ok()) diff --git a/tensorflow/dtensor/mlir/collectives.cc b/tensorflow/dtensor/mlir/collectives.cc index 696fd52923b36e..8a766d63912b89 100644 --- a/tensorflow/dtensor/mlir/collectives.cc +++ b/tensorflow/dtensor/mlir/collectives.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" #include "tensorflow/core/platform/errors.h" @@ -79,7 +80,7 @@ StatusOr EmitAllGather( // For convenience, operate on explicit input shapes. This isn't necessary, // as we could instead generate operations on top of the dynamic shape. const mlir::TensorType input_type = - input.getType().dyn_cast(); + mlir::dyn_cast(input.getType()); if (!input_type) { return errors::Internal( llvm::formatv( @@ -133,7 +134,7 @@ StatusOr EmitAllScatter( } const mlir::TensorType input_type = - original_value.getType().dyn_cast(); + mlir::dyn_cast(original_value.getType()); if (!input_type) return errors::InvalidArgument( "input to EmitAllScatter does not have a TensorType"); @@ -199,7 +200,7 @@ StatusOr EmitAllToAll( // For convenience, operate on explicit input shapes. This isn't necessary, // as we could instead generate operations on top of the dynamic shape. const mlir::TensorType input_type = - input.getType().dyn_cast(); + mlir::dyn_cast(input.getType()); if (!input_type) { return errors::Internal( llvm::formatv( @@ -260,7 +261,7 @@ StatusOr EmitDenseToSparseToDense( mlir::Value zero_scalar, CreateZeroScalarConst( builder, input.getLoc(), - input.getType().cast().getElementType())); + mlir::cast(input.getType()).getElementType())); auto dense = builder.create( input.getLoc(), input.getType(), @@ -311,7 +312,7 @@ StatusOr EmitRelayout( // Save whether the input is from a SparseToDenseOp. If it is, then we will // emit a DenseToSparse and a SparseToDense op. bool is_sparse = IsSparseValue(input); - if (!input.getType().isa()) + if (!mlir::isa(input.getType())) return errors::Internal( "attempting to relayout a tensor that does not " "have a rank"); @@ -389,7 +390,7 @@ StatusOr EmitRelayout( mlir::Operation* EmitTransposeOp(mlir::OpBuilder& builder, const mlir::Location& loc, mlir::Value input, std::vector& perm_arr) { - auto tr_input_type = input.getType().cast(); + auto tr_input_type = mlir::cast(input.getType()); auto shape = tr_input_type.getShape(); auto perm_type = mlir::RankedTensorType::get( @@ -591,7 +592,8 @@ StatusOr EmitHaloExchange(mlir::OpBuilder& builder, int halo_size, if (!mesh.is_tpu_mesh()) return errors::InvalidArgument("Halo exchange is only supported on TPU."); - auto input_tensor_type = tensor.getType().dyn_cast(); + auto input_tensor_type = + mlir::dyn_cast(tensor.getType()); if (!input_tensor_type || !input_tensor_type.hasStaticShape()) return errors::InvalidArgument( "Static shape of input tensor must be known for halo exchange."); diff --git a/tensorflow/dtensor/mlir/device_utils.cc b/tensorflow/dtensor/mlir/device_utils.cc index f32ae0f0beff91..16d0ff7871613f 100644 --- a/tensorflow/dtensor/mlir/device_utils.cc +++ b/tensorflow/dtensor/mlir/device_utils.cc @@ -17,6 +17,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/platform/errors.h" namespace tensorflow { @@ -38,9 +39,10 @@ StatusOr DeviceId(mlir::Operation* op) { "enclosing function must contain device id as argument"); auto device_id = function.getArgument(0); - auto device_id_type = device_id.getType().dyn_cast(); + auto device_id_type = + mlir::dyn_cast(device_id.getType()); if (!device_id_type || - !device_id_type.getElementType().isa()) + !mlir::isa(device_id_type.getElementType())) return errors::InvalidArgument( "0-th argument of the enclosing function should be integer device id."); @@ -48,12 +50,12 @@ StatusOr DeviceId(mlir::Operation* op) { } StatusOr DeviceId(mlir::Value val) { - if (auto block_arg = val.dyn_cast()) { + if (auto block_arg = mlir::dyn_cast(val)) { auto device_id = block_arg.getOwner()->getArgument(0); auto device_id_type = - device_id.getType().dyn_cast(); + mlir::dyn_cast(device_id.getType()); if (!device_id_type || - !device_id_type.getElementType().isa()) + !mlir::isa(device_id_type.getElementType())) return errors::InvalidArgument( "0-th argument of the enclosing block should be integer device id."); return device_id; diff --git a/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc b/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc index 8d9e34b6aa77a5..a4dec16409fdac 100644 --- a/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc +++ b/tensorflow/dtensor/mlir/dtensor_allreduce_combine_optimization.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/TopologicalSortUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h" @@ -112,7 +113,7 @@ mlir::LogicalResult MergeAllReduceGroup( all_reduce_shapes.reserve(num_all_reduces); for (mlir::TF::DTensorAllReduceOp& all_reduce : all_reduce_group) { auto all_reduce_ranked_type = - all_reduce.getType().dyn_cast(); + mlir::dyn_cast(all_reduce.getType()); if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) { return all_reduce.emitOpError(llvm::formatv( "requires static shape for DTensorAllReduceOp, but got : {0}", @@ -152,7 +153,7 @@ mlir::LogicalResult MergeAllReduceGroup( mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i]; mlir::Location loc = all_reduce.getLoc(); auto all_reduce_ranked_type = - all_reduce.getType().dyn_cast(); + mlir::dyn_cast(all_reduce.getType()); if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) { return all_reduce.emitOpError(llvm::formatv( "requires static shape for DTensorAllReduceOp, but got : {0}", @@ -201,7 +202,7 @@ mlir::LogicalResult MergeAllReduceGroup( mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i]; mlir::Location loc = all_reduce.getLoc(); auto all_reduce_ranked_type = - all_reduce.getType().dyn_cast(); + mlir::dyn_cast(all_reduce.getType()); if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) { return all_reduce.emitOpError(llvm::formatv( "requires static shape for DTensorAllReduceOp, but got : {0}", @@ -676,7 +677,7 @@ struct DTensorAllReduceCombineOptimization if (!all_reduce.getDeviceType().contains("TPU")) { // Only combine all reduces for GPU and CPU mlir::RankedTensorType all_reduce_ranked_type = - all_reduce.getType().dyn_cast(); + mlir::dyn_cast(all_reduce.getType()); if (all_reduce_ranked_type && all_reduce_ranked_type.hasStaticShape()) { diff --git a/tensorflow/dtensor/mlir/dtensor_allreduce_sum_optimization.cc b/tensorflow/dtensor/mlir/dtensor_allreduce_sum_optimization.cc index 78a67ea1b9f404..c7b21f2508fc3d 100644 --- a/tensorflow/dtensor/mlir/dtensor_allreduce_sum_optimization.cc +++ b/tensorflow/dtensor/mlir/dtensor_allreduce_sum_optimization.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/UseDefLists.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -56,7 +57,7 @@ bool IsZeroConstant(mlir::Value val) { GetIdentitySkippedInputs(val).getDefiningOp()); if (!const_input) return false; mlir::DenseFPElementsAttr attr = - const_input.getValue().dyn_cast(); + mlir::dyn_cast(const_input.getValue()); // This uses the fact that constant Attrs becomes splats, so we only need to // check one value. if (!attr || !attr.isSplat()) return false; @@ -255,9 +256,9 @@ void OptimizeIdentityLikeOps(mlir::Operation* op, bool* changed) { mlir::Value op_output = op->getResult(0); dtensor_all_reduce.setOperand(0, op_output); dtensor_all_reduce.getInput().setType( - op_output.getType().cast()); + mlir::cast(op_output.getType())); dtensor_all_reduce.getOutput().setType( - op_output.getType().cast()); + mlir::cast(op_output.getType())); llvm::SmallPtrSet exceptions{dtensor_all_reduce}; op_output.replaceAllUsesExcept(dtensor_all_reduce.getOutput(), exceptions); @@ -296,12 +297,12 @@ bool CheckWhileLoopOptimizationCriteria( llvm::dyn_cast_or_null( first_operand.get().getDefiningOp()); if (all_reduce) { - block_arg = second_operand.get().dyn_cast(); + block_arg = mlir::dyn_cast(second_operand.get()); *add_input = &second_operand; } else { all_reduce = llvm::dyn_cast_or_null( second_operand.get().getDefiningOp()); - block_arg = first_operand.get().dyn_cast(); + block_arg = mlir::dyn_cast(first_operand.get()); *add_input = &first_operand; } if (!block_arg || !all_reduce) return false; diff --git a/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc b/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc index 9c3171cbf0131d..2634c993b62014 100644 --- a/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc +++ b/tensorflow/dtensor/mlir/dtensor_collective_type_lowering.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/types.h" @@ -99,9 +100,9 @@ mlir::LogicalResult ConvertShortIntReduce(ReduceOpType reduce_op) { // Handle bools by first casting to int32 and swapping All/Any for Min/Max. const mlir::TensorType& tensor_input_type = - input_type.dyn_cast(); + mlir::dyn_cast(input_type); const mlir::TensorType& tensor_output_type = - output_type.dyn_cast(); + mlir::dyn_cast(output_type); if (!tensor_input_type) return mlir::success(); if (!tensor_output_type) return mlir::success(); @@ -166,12 +167,12 @@ mlir::LogicalResult ConvertComplexReduce(ReduceOpType reduce_op) { const mlir::Value tensor_input = reduce_op.getInput(); const mlir::Value tensor_result = reduce_op.getResult(); const mlir::TensorType complex_input_tensor_type = - tensor_input.getType().dyn_cast(); + mlir::dyn_cast(tensor_input.getType()); if (!complex_input_tensor_type) { return mlir::success(); } const mlir::TensorType complex_result_tensor_type = - tensor_result.getType().dyn_cast(); + mlir::dyn_cast(tensor_result.getType()); if (!complex_result_tensor_type) { return mlir::success(); } @@ -222,12 +223,12 @@ mlir::LogicalResult ConvertComplexCollectives(CollectiveType op) { const mlir::Value tensor_input = op.getInput(); const mlir::Value tensor_result = op.getResult(); const mlir::TensorType complex_input_tensor_type = - tensor_input.getType().dyn_cast(); + mlir::dyn_cast(tensor_input.getType()); if (!complex_input_tensor_type) { return mlir::success(); } const mlir::TensorType& complex_result_tensor_type = - tensor_result.getType().dyn_cast(); + mlir::dyn_cast(tensor_result.getType()); if (!complex_result_tensor_type) { return mlir::success(); } diff --git a/tensorflow/dtensor/mlir/dtensor_dialect/BUILD b/tensorflow/dtensor/mlir/dtensor_dialect/BUILD index d484322d42fea4..534638eed7b6a6 100644 --- a/tensorflow/dtensor/mlir/dtensor_dialect/BUILD +++ b/tensorflow/dtensor/mlir/dtensor_dialect/BUILD @@ -71,6 +71,7 @@ cc_library( "@com_google_absl//absl/strings", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/dtensor/mlir/dtensor_dialect/ir/ops.cc b/tensorflow/dtensor/mlir/dtensor_dialect/ir/ops.cc index 7fa35ca75f11d7..80a9a91f61b645 100644 --- a/tensorflow/dtensor/mlir/dtensor_dialect/ir/ops.cc +++ b/tensorflow/dtensor/mlir/dtensor_dialect/ir/ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/DialectImplementation.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/utils/string_container_utils.h" #include "tensorflow/dtensor/cc/tensor_layout.h" #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h" @@ -144,9 +145,10 @@ static void printLayoutAttr(LayoutAttr attr, DialectAsmPrinter &os) { void DTensorDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const { // Cast into correct attribute and print - if (auto mesh_attr = attr.dyn_cast()) printMeshAttr(mesh_attr, os); + if (auto mesh_attr = mlir::dyn_cast(attr)) + printMeshAttr(mesh_attr, os); - if (auto layout_attr = attr.dyn_cast()) + if (auto layout_attr = mlir::dyn_cast(attr)) printLayoutAttr(layout_attr, os); } } // namespace dtensor diff --git a/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc b/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc index 0f633c3707a459..b9e9b0648a7220 100644 --- a/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc +++ b/tensorflow/dtensor/mlir/dtensor_layout_to_xla_sharding_op.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -65,7 +66,7 @@ class DTensorLayoutToXlaShardingOpPass mlir::LogicalResult RemoveDTensorLayoutAfterConstOrBlockArgPattern::match( DTensorLayout layout_op) const { auto input = layout_op.getInput(); - if (input.isa()) { + if (mlir::isa(input)) { return mlir::success(); } mlir::Operation* input_op = input.getDefiningOp(); diff --git a/tensorflow/dtensor/mlir/dtensor_location.cc b/tensorflow/dtensor/mlir/dtensor_location.cc index 23be51a4c96695..a129889c5e51c0 100644 --- a/tensorflow/dtensor/mlir/dtensor_location.cc +++ b/tensorflow/dtensor/mlir/dtensor_location.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/utils/name_utils.h" namespace tensorflow { @@ -66,12 +67,12 @@ std::string DTensorLocationToString(mlir::Location loc) { while (!queue.empty()) { mlir::Location& front = queue.front(); - if (auto name_loc = front.dyn_cast()) { + if (auto name_loc = mlir::dyn_cast(front)) { queue.push(name_loc.getChildLoc()); - } else if (auto callsite_loc = front.dyn_cast()) { + } else if (auto callsite_loc = mlir::dyn_cast(front)) { queue.push(callsite_loc.getCallee()); queue.push(callsite_loc.getCaller()); - } else if (auto line_loc = front.dyn_cast()) { + } else if (auto line_loc = mlir::dyn_cast(front)) { stack.push_back(CreateLocalLocationString(line_loc)); } queue.pop(); diff --git a/tensorflow/dtensor/mlir/dtensor_location_test.cc b/tensorflow/dtensor/mlir/dtensor_location_test.cc index 3989c3623b3901..773f6af8db4d82 100644 --- a/tensorflow/dtensor/mlir/dtensor_location_test.cc +++ b/tensorflow/dtensor/mlir/dtensor_location_test.cc @@ -25,8 +25,8 @@ namespace { void CheckFileLineColLocation(mlir::Location loc, unsigned line, unsigned column) { - ASSERT_TRUE(loc.isa()); - auto file_line_col_loc = loc.cast(); + ASSERT_TRUE(mlir::isa(loc)); + auto file_line_col_loc = mlir::cast(loc); EXPECT_EQ(file_line_col_loc.getFilename(), "test.cc"); EXPECT_EQ(file_line_col_loc.getLine(), line); EXPECT_EQ(file_line_col_loc.getColumn(), column); @@ -37,8 +37,8 @@ TEST(DTensorLocationTest, HandlesEmptyLocation) { mlir::Location loc = mlir::FileLineColLoc::get(&ctx, "test.cc", 10, 20); loc = tensorflow::dtensor::DTensorLocation(loc, "test.cc", 21); - ASSERT_TRUE(loc.isa()); - auto callsite_loc = loc.cast(); + ASSERT_TRUE(mlir::isa(loc)); + auto callsite_loc = mlir::cast(loc); CheckFileLineColLocation(callsite_loc.getCallee(), 21, 0); CheckFileLineColLocation(callsite_loc.getCaller(), 10, 20); @@ -57,8 +57,8 @@ TEST(DTensorLocationTest, HandlesMultipleCalls) { auto verify_loc = test_loc; for (int i = 0; i < 4; ++i) { - ASSERT_TRUE(verify_loc.isa()); - auto callsite_loc = verify_loc.cast(); + ASSERT_TRUE(mlir::isa(verify_loc)); + auto callsite_loc = mlir::cast(verify_loc); auto callee_loc = callsite_loc.getCallee(); CheckFileLineColLocation(callee_loc, 24 - i, 0); verify_loc = callsite_loc.getCaller(); @@ -80,17 +80,18 @@ TEST(DTensorLocationTest, HandlesNameLoc) { mlir::FileLineColLoc::get(&ctx, "test.cc", 10, 20)); test_loc = tensorflow::dtensor::DTensorLocation(test_loc, "test.cc", 21); ASSERT_EQ(mlir::GetNameFromLoc(test_loc), "op"); - ASSERT_TRUE(test_loc.isa()); - auto callsite_loc = test_loc.cast(); - mlir::Location caller_loc = test_loc.cast().getCaller(); - ASSERT_TRUE(caller_loc.isa()); - CheckFileLineColLocation(caller_loc.cast().getChildLoc(), 10, - 20); + ASSERT_TRUE(mlir::isa(test_loc)); + auto callsite_loc = mlir::cast(test_loc); + mlir::Location caller_loc = + mlir::cast(test_loc).getCaller(); + ASSERT_TRUE(mlir::isa(caller_loc)); + CheckFileLineColLocation(mlir::cast(caller_loc).getChildLoc(), + 10, 20); mlir::Location callee_loc = callsite_loc.getCallee(); - ASSERT_TRUE(callee_loc.isa()); - CheckFileLineColLocation(callee_loc.cast().getChildLoc(), 21, - 0); + ASSERT_TRUE(mlir::isa(callee_loc)); + CheckFileLineColLocation(mlir::cast(callee_loc).getChildLoc(), + 21, 0); constexpr char stack[] = R"stack(>> test.cc:10:20 >> test.cc:21:0)stack"; diff --git a/tensorflow/dtensor/mlir/dtensor_mixed_precision_reduce.cc b/tensorflow/dtensor/mlir/dtensor_mixed_precision_reduce.cc index 5ecb38b3e97590..c5fbe20ac067f3 100644 --- a/tensorflow/dtensor/mlir/dtensor_mixed_precision_reduce.cc +++ b/tensorflow/dtensor/mlir/dtensor_mixed_precision_reduce.cc @@ -69,9 +69,7 @@ template mlir::LogicalResult MaybeUpcastForReduction(ReduceOpType reduce_op, bool* changed) { const mlir::RankedTensorType& input_type = - reduce_op.getInput() - .getType() - .template dyn_cast(); + mlir::dyn_cast(reduce_op.getInput().getType()); if (!input_type.getElementType().isBF16()) { // Upcast only applies for bfloat16 input. return mlir::success(); @@ -96,9 +94,7 @@ mlir::LogicalResult MaybeUpcastForReduction(ReduceOpType reduce_op, // The original output tensor type that would have been used by all users of // the reduce op. const mlir::RankedTensorType& output_type = - reduce_op.getOutput() - .getType() - .template dyn_cast(); + mlir::dyn_cast(reduce_op.getOutput().getType()); mlir::TF::CastOp upcast = builder.create( loc, diff --git a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc index 1f47934230e3bc..15a334fccb9d61 100644 --- a/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc +++ b/tensorflow/dtensor/mlir/dtensor_multi_device_expansion.cc @@ -160,7 +160,7 @@ StatusOr>> GetResourceLayouts( std::vector layouts; layouts.reserve(attrs.size()); for (mlir::Attribute attr : attrs) { - auto string_attr = attr.cast(); + auto string_attr = mlir::cast(attr); auto layout = Layout::FromString(string_attr.str()); if (layout.ok()) { layouts.emplace_back(std::move(layout.value())); @@ -175,7 +175,8 @@ StatusOr>> GetResourceLayouts( } bool IsResource(mlir::Value value) { - return getElementTypeOrSelf(value.getType()).isa(); + return mlir::isa( + getElementTypeOrSelf(value.getType())); } StatusOr> FindResourceLayout(mlir::BlockArgument arg) { @@ -417,7 +418,7 @@ mlir::LogicalResult ExpandTPUOperation( llvm::SmallVector operands; for (const mlir::Value& operand : op->getOperands()) { - if (const auto arg = operand.dyn_cast_or_null()) { + if (const auto arg = mlir::dyn_cast_or_null(operand)) { const StatusOr> new_args = GetExpandedArguments( builder, target_func, expanded_arguments, arg, &target_mesh); if (!new_args.ok()) { @@ -481,7 +482,8 @@ mlir::LogicalResult ExpandOperation( for (size_t i = 0; i < num_devices; ++i) { llvm::SmallVector operands; for (const mlir::Value& operand : op->getOperands()) { - if (const auto arg = operand.dyn_cast_or_null()) { + if (const auto arg = + mlir::dyn_cast_or_null(operand)) { const StatusOr> new_args = GetExpandedArguments( builder, target_func, expanded_arguments, arg, &target_mesh); if (!new_args.ok()) { @@ -609,7 +611,7 @@ StatusOr> GetExpandedArguments( } } else { mlir::TensorType tensor_type = - arg.getType().dyn_cast_or_null(); + mlir::dyn_cast_or_null(arg.getType()); if (!tensor_type) { return errors::InvalidArgument("Could not determine tensor type."); } diff --git a/tensorflow/dtensor/mlir/dtensor_send_recv.cc b/tensorflow/dtensor/mlir/dtensor_send_recv.cc index ed4b636081fc93..fe5885012ef471 100644 --- a/tensorflow/dtensor/mlir/dtensor_send_recv.cc +++ b/tensorflow/dtensor/mlir/dtensor_send_recv.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -48,14 +49,14 @@ namespace dtensor { namespace { bool IsStringType(mlir::Type type) { - if (type.isa()) return true; + if (mlir::isa(type)) return true; - auto sub_type = type.dyn_cast(); + auto sub_type = mlir::dyn_cast(type); if (!sub_type) return false; bool has_string = llvm::any_of(sub_type.GetSubtypes(), [](mlir::TensorType type) { - return type.getElementType().isa(); + return mlir::isa(type.getElementType()); }); return has_string; } @@ -421,7 +422,7 @@ StatusOr LowerOneToOneDTensorSendToTFHostSend( op_builder.getStringAttr(send_layout.ToString())); mlir::Value val = arg; if (i32_copy) { - auto val_type = val.getType().cast(); + auto val_type = mlir::cast(val.getType()); val = op_builder .create( loc, @@ -673,7 +674,7 @@ StatusOr LowerDTensorSend(mlir::Operation* send_op, dtensor_send)); } else { mlir::TensorType send_type = - send_input.getType().cast(); + mlir::cast(send_input.getType()); if (!recv_mesh.is_cpu_mesh() && send_type.getElementType().isInteger(32)) { builder.setInsertionPointAfter(send_input.getDefiningOp()); @@ -745,7 +746,7 @@ StatusOr LowerDTensorRecv(mlir::Operation* send_op, TF_ASSIGN_OR_RETURN( mlir::TensorType local_output_type, LocalTypeFromGlobalType( - recv_layout, dtensor_recv.getType().cast())); + recv_layout, mlir::cast(dtensor_recv.getType()))); TF_ASSIGN_OR_RETURN( lowered_recv, LowerDTensorRecvToXlaOp(dtensor_recv, local_output_type)); dtensor_recv->replaceAllUsesWith(lowered_recv); diff --git a/tensorflow/dtensor/mlir/expansions/control_flow_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/control_flow_spmd_expander.cc index 6c618def560dfe..1f9711220824d2 100644 --- a/tensorflow/dtensor/mlir/expansions/control_flow_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/control_flow_spmd_expander.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/dtensor/mlir/expansions/control_flow_spmd_expander.h" +#include "mlir/Support/LLVM.h" // from @llvm-project + namespace tensorflow { namespace dtensor { @@ -79,7 +81,7 @@ StatusOr IfRegionSPMDExpander::ExpandOp(mlir::Operation* op) { result.setType(mlir::RankedTensorType::get( layout.LocalShapeFromGlobalShape(*global_shape), - result.getType().cast().getElementType())); + mlir::cast(result.getType()).getElementType())); } } return op; diff --git a/tensorflow/dtensor/mlir/expansions/conv_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/conv_spmd_expander.cc index 2602a16eb9f824..fa6dee1683936c 100644 --- a/tensorflow/dtensor/mlir/expansions/conv_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/conv_spmd_expander.cc @@ -22,6 +22,7 @@ limitations under the License. #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/errors.h" @@ -68,7 +69,7 @@ Status VerifyConvLayout(const Layout& input_layout, const Layout& filter_layout, const int num_non_default_dilations = llvm::count_if(conv_op.getDilations(), [](mlir::Attribute dilation) { - return dilation.cast().getInt() != 1; + return mlir::cast(dilation).getInt() != 1; }); if (num_non_default_dilations > 0) return errors::InvalidArgument( @@ -78,21 +79,21 @@ Status VerifyConvLayout(const Layout& input_layout, const Layout& filter_layout, // TODO(b/208700444): support convolution with strides greater than 1. const int num_non_default_strides = llvm::count_if(conv_op.getStrides(), [](mlir::Attribute stride) { - return stride.cast().getInt() != 1; + return mlir::cast(stride).getInt() != 1; }); if (num_non_default_strides > 0) return errors::InvalidArgument( "Only stride 1 is supported for convolution with spatial partitions."); mlir::Value input = conv_op.getInput(); - auto input_type = input.getType().dyn_cast(); + auto input_type = mlir::dyn_cast(input.getType()); if (!input_type || !input_type.hasStaticShape()) return errors::InvalidArgument( "Input must have static shapes for convolution with spatial " "partitions."); mlir::Value filter = conv_op.getFilter(); - auto filter_type = filter.getType().dyn_cast(); + auto filter_type = mlir::dyn_cast(filter.getType()); if (!filter_type || !filter_type.hasStaticShape()) return errors::InvalidArgument( "Filter must have static shapes for convolution with spatial " @@ -114,7 +115,7 @@ mlir::Value PadInputOnUnshardedDim(mlir::OpBuilder& builder, mlir::Value input_tensor, int curr_input_dim, int64_t curr_filter_dim_size) { auto input_tensor_type = - input_tensor.getType().dyn_cast(); + mlir::dyn_cast(input_tensor.getType()); auto input_tensor_shape = input_tensor_type.getShape(); const size_t paddings_flat_length = input_tensor_type.getRank() * 2; @@ -171,7 +172,7 @@ StatusOr HandleConv(ConvOp conv_op) { const auto output_num_shards = output_layout.num_shards(); auto filter_type = - conv_op.getFilter().getType().template dyn_cast(); + mlir::dyn_cast(conv_op.getFilter().getType()); auto filter_shape = filter_type.getShape(); int begin_input_dim = -1, end_input_dim = -1; @@ -194,9 +195,8 @@ StatusOr HandleConv(ConvOp conv_op) { ++curr_input_dim) { int curr_filter_dim = curr_input_dim - begin_input_dim; - auto input_type = conv_op.getInput() - .getType() - .template dyn_cast(); + auto input_type = + mlir::dyn_cast(conv_op.getInput().getType()); auto input_shape = input_type.getShape(); if (input_sharding_spec[curr_input_dim] == Layout::kUnshardedDim) { @@ -410,7 +410,7 @@ StatusOr HandleConvBackpropInputTensor( // HandleConvBackpropInput which expects there to be a layout here. mlir::TF::ShapeAttr global_input_shape_shape = mlir::TF::ShapeAttr::get( builder.getContext(), - global_input_shape.getType().cast()); + mlir::cast(global_input_shape.getType())); mlir::TF::DTensorLayout global_input_shape_with_layout = builder.create( conv_op->getLoc(), global_input_shape, @@ -536,7 +536,7 @@ StatusOr HandleConvBackpropFilterTensor( // HandleConvBackpropInput which expects there to be a layout here. mlir::TF::ShapeAttr global_filter_shape_shape = mlir::TF::ShapeAttr::get( builder.getContext(), - global_filter_shape_const.getType().cast()); + mlir::cast(global_filter_shape_const.getType())); mlir::TF::DTensorLayout global_filter_shape_with_layout = builder.create( conv_op->getLoc(), global_filter_shape_const, diff --git a/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc index d8de0e07b3be25..2893128e1f01c1 100644 --- a/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/fft_spmd_expander.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/cc/framework/ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/errors.h" @@ -74,7 +75,7 @@ bool IsDistributedFFTN(int num_transform_axes, const Layout& layout) { bool IsComplexFFT(mlir::Value input) { auto data_type = mlir::dyn_cast(input.getType()).getElementType(); - return data_type.isa(); + return mlir::isa(data_type); } Status IsProperFFTLength(mlir::Operation* op, diff --git a/tensorflow/dtensor/mlir/expansions/fill_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/fill_spmd_expander.cc index 8aeb53f5e8314e..75dde58edee01c 100644 --- a/tensorflow/dtensor/mlir/expansions/fill_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/fill_spmd_expander.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "xla/mlir_hlo/utils/convert_op_folder.h" @@ -68,11 +69,9 @@ StatusOr FillSPMDExpander::ExpandOp(mlir::Operation* op) { auto int_type = mlir::RankedTensorType::get( static_cast(shard_values.size()), builder.getIntegerType(32)); auto int_attr = mlir::DenseIntElementsAttr::get(int_type, shard_values); - auto target_type_attr = - mlir::hlo::convertElementsAttr(int_attr, original_fill.getDims() - .getType() - .cast() - .getElementType()); + auto target_type_attr = mlir::hlo::convertElementsAttr( + int_attr, mlir::cast(original_fill.getDims().getType()) + .getElementType()); auto location = DT_LOC(op); auto num_shards = @@ -82,7 +81,7 @@ StatusOr FillSPMDExpander::ExpandOp(mlir::Operation* op) { num_shards.getResult()); // Copy over static shape information if available auto global_output_type = - original_fill.getResult().getType().cast(); + mlir::cast(original_fill.getResult().getType()); TF_ASSIGN_OR_RETURN( auto local_type, LocalTypeFromGlobalType(output_layout.value(), global_output_type)); diff --git a/tensorflow/dtensor/mlir/expansions/io_op_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/io_op_spmd_expander.cc index 6cd2f10c4c97aa..64c55363d3451d 100644 --- a/tensorflow/dtensor/mlir/expansions/io_op_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/io_op_spmd_expander.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "llvm/Support/FormatVariadic.h" +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/dtensor/cc/tensor_layout.h" #include "tensorflow/dtensor/mlir/collectives.h" @@ -99,7 +100,7 @@ StatusOr Expand(mlir::Operation* op) { mlir::Value zero_scalar, CreateZeroScalarConst( builder, location, - device_id.getType().cast().getElementType())); + mlir::cast(device_id.getType()).getElementType())); mlir::TF::NotEqualOp not_equal = builder.create( location, device_id, zero_scalar, diff --git a/tensorflow/dtensor/mlir/expansions/iterator_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/iterator_spmd_expander.cc index cbe59bdb8557a0..b0c84ce679733b 100644 --- a/tensorflow/dtensor/mlir/expansions/iterator_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/iterator_spmd_expander.cc @@ -50,7 +50,7 @@ StatusOr IteratorGetNextSPMDExpander::ExpandOp( for (int i = 0; i < original_op->getNumResults(); ++i) { mlir::TensorType global_output_type = - original_op.getResult(i).getType().cast(); + mlir::cast(original_op.getResult(i).getType()); std::vector local_shape = output_layouts[i].LocalShapeFromGlobalShape( global_output_type.getShape()); @@ -111,10 +111,9 @@ StatusOr IteratorGetNextAsOptionalSPMDExpander::ExpandOp( for (int i = 0; i < array_attr.size(); ++i) { std::vector local_shape = output_layouts[i].LocalShapeFromGlobalShape( - array_attr[i].cast().getShape()); - output_shape_attrs[i] = - mlir::TF::ShapeAttr::get(op->getContext(), {local_shape}) - .cast(); + mlir::cast(array_attr[i]).getShape()); + output_shape_attrs[i] = mlir::cast( + mlir::TF::ShapeAttr::get(op->getContext(), {local_shape})); } // Update the `output_shapes` attribute on the op to match the local shape diff --git a/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc index cc82fa42901fa7..0ef04e88833e00 100644 --- a/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/meta_spmd_expander.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" @@ -244,8 +245,9 @@ namespace { Status VerifyPaddedDimensionNotSharded(const Layout& layout, mlir::Value pad_input, mlir::Value pad_output) { - auto input_type = pad_input.getType().dyn_cast(); - auto output_type = pad_output.getType().dyn_cast(); + auto input_type = mlir::dyn_cast(pad_input.getType()); + auto output_type = + mlir::dyn_cast(pad_output.getType()); if (!input_type || !output_type) return errors::InvalidArgument( "pad op input/output should have statically known shape for SPMD."); @@ -435,7 +437,7 @@ StatusOr TileSPMDExpander::ExpandOp(mlir::Operation* op) { auto multiples_const = IntConst(builder, location, local_tile_multiples); auto global_output_type = - tile_op.getResult().getType().cast(); + mlir::cast(tile_op.getResult().getType()); TF_ASSIGN_OR_RETURN( auto local_type, LocalTypeFromGlobalType(output_layout.value(), global_output_type)); @@ -458,7 +460,7 @@ StatusOr> TileSPMDExpander::ComputeLayoutForward( auto tile_op = llvm::cast(op); auto output_ranked_type = - tile_op.getOutput().getType().dyn_cast(); + mlir::dyn_cast(tile_op.getOutput().getType()); if (!output_ranked_type || !output_ranked_type.hasStaticShape()) { return errors::InvalidArgument( llvm::formatv( @@ -503,7 +505,7 @@ StatusOr> TileSPMDExpander::ComputeLayoutBackward( // Retrieve operand/output shapes of tile op. auto input_ranked_type = - tile_op.getInput().getType().dyn_cast(); + mlir::dyn_cast(tile_op.getInput().getType()); if (!input_ranked_type || !input_ranked_type.hasStaticShape()) { return errors::InvalidArgument( llvm::formatv( @@ -516,11 +518,9 @@ StatusOr> TileSPMDExpander::ComputeLayoutBackward( llvm::DenseMap input_layouts(op->getNumOperands()); // `multiples` operand is always set to have replicated layout. - input_layouts[1] = - Layout::ReplicatedOnMesh(mesh, tile_op.getMultiples() - .getType() - .cast() - .getRank()); + input_layouts[1] = Layout::ReplicatedOnMesh( + mesh, mlir::cast(tile_op.getMultiples().getType()) + .getRank()); llvm::SmallVector static_multiple; auto status = @@ -1043,9 +1043,9 @@ StatusOr OneHotSPMDExpander::ExpandOp(mlir::Operation* op) { mlir::TF::SliceOp selected_sharding_at_dimension = builder.create< mlir::TF::SliceOp>( one_hot_op.getLoc(), - mlir::RankedTensorType::get({1, 1}, mesh_coordinates.getType() - .cast() - .getElementType()), + mlir::RankedTensorType::get( + {1, 1}, mlir::cast(mesh_coordinates.getType()) + .getElementType()), /*input=*/mesh_coordinates, /*begin=*/IntConst(builder, one_hot_op.getLoc(), {0, mesh_dim_index}), /*size=*/IntConst(builder, one_hot_op.getLoc(), {1, 1})); diff --git a/tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.cc index 20ca961d55c8f4..b6c0f573382ea2 100644 --- a/tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/nullary_spmd_expander.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/utils/array_container_utils.h" @@ -53,7 +54,8 @@ StatusOr NullarySPMDExpander::ExpandOp(mlir::Operation* op) { if (all_operands_fully_replicated) return op; if (auto const_op = mlir::dyn_cast(op)) { - if (auto dense = const_op.getValue().dyn_cast()) { + if (auto dense = + mlir::dyn_cast(const_op.getValue())) { if (dense.isSplat()) { // A 'splat' value for a DenseElementsAttr, has a single value for // all its elements. For these inputs, we don't need to slice. We just @@ -120,7 +122,7 @@ StatusOr> NullarySPMDExpander::ComputeLayoutForward( // Nullary ops always output replicated layout for output values. for (auto i = 0; i < op->getNumResults(); ++i) { auto output_ranked_type = - op->getResult(i).getType().dyn_cast(); + mlir::dyn_cast(op->getResult(i).getType()); if (!output_ranked_type) { return errors::InvalidArgument( llvm::formatv("requires output type to have statically known rank, " diff --git a/tensorflow/dtensor/mlir/expansions/optional_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/optional_spmd_expander.cc index d0f5efac876b1f..b2c082d3bb71b0 100644 --- a/tensorflow/dtensor/mlir/expansions/optional_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/optional_spmd_expander.cc @@ -17,6 +17,7 @@ limitations under the License. #include "llvm/ADT/SmallVector.h" #include "llvm/Support/FormatVariadic.h" +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/dtensor/cc/constants.h" #include "tensorflow/dtensor/cc/dstatus.h" #include "tensorflow/dtensor/mlir/dtensor_location.h" @@ -39,7 +40,7 @@ StatusOr OptionalGetValueSPMDExpander::ExpandOp( for (int i = 0; i < original_op->getNumResults(); ++i) { mlir::TensorType global_output_type = - original_op.getResult(i).getType().cast(); + mlir::cast(original_op.getResult(i).getType()); TF_ASSIGN_OR_RETURN( mlir::TensorType local_type, LocalTypeFromGlobalType(output_layouts[i], global_output_type)); diff --git a/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc index 5156e09e417c1f..fb4e5302c9dfae 100644 --- a/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/random_op_spmd_expander.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/IntegerSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" #include "tensorflow/dtensor/cc/constants.h" #include "tensorflow/dtensor/cc/tensor_layout.h" @@ -192,7 +193,7 @@ StatusOr ComputeNewSeed(mlir::OpBuilder& builder, mlir::Value op_seed) { TF_ASSIGN_OR_RETURN(auto device_id_seed, GetDeviceSeed(layout, op)); mlir::Type seed_type = - op_seed.getType().cast().getElementType(); + mlir::cast(op_seed.getType()).getElementType(); device_id_seed = builder.create( location, mlir::RankedTensorType::get({}, seed_type), device_id_seed); @@ -222,8 +223,8 @@ StatusOr CreatedShardedLocalRandomOpV1(const Layout& layout, // StatelessRandom op is used to make random op SPMD expansion // deterministic. mlir::Type new_random_type = mlir::RankedTensorType::get( - new_random_shape, - op->getResult(0).getType().cast().getElementType()); + new_random_shape, mlir::cast(op->getResult(0).getType()) + .getElementType()); auto new_shape_value = Int64Const(builder, location, new_random_shape); // TODO(zhonglinhan) : check different input for StatelessRandomUniformInt @@ -254,8 +255,8 @@ StatusOr CreatedShardedLocalRandomOpV2(const Layout& layout, // StatelessRandom op is used to make random op SPMD expansion // deterministic. mlir::Type new_random_type = mlir::RankedTensorType::get( - new_random_shape, - op->getResult(0).getType().cast().getElementType()); + new_random_shape, mlir::cast(op->getResult(0).getType()) + .getElementType()); auto new_shape_value = Int64Const(builder, location, new_random_shape); @@ -287,8 +288,8 @@ StatusOr CreatedShardedLocalRandomOpV2Range( // StatelessRandom op is used to make random op SPMD expansion // deterministic. mlir::Type new_random_type = mlir::RankedTensorType::get( - new_random_shape, - op->getResult(0).getType().cast().getElementType()); + new_random_shape, mlir::cast(op->getResult(0).getType()) + .getElementType()); auto new_shape_value = Int64Const(builder, location, new_random_shape); diff --git a/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc index 9d40da3173cee4..e3bb17d1829efe 100644 --- a/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/reduce_spmd_expander.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" @@ -218,7 +219,7 @@ StatusOr ReduceSPMDExpander::ExpandOp(mlir::Operation* op) { // Generate an error message for TPU int64. if (input_layout->mesh().is_tpu_mesh()) { if (auto tensor_type = - op->getOperand(0).getType().dyn_cast()) { + mlir::dyn_cast(op->getOperand(0).getType())) { if (tensor_type.getElementType().isInteger(64)) { return errors::InvalidArgument( "ReduceOp on TPU does not support int64 as dtype."); diff --git a/tensorflow/dtensor/mlir/expansions/resource_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/resource_spmd_expander.cc index c6f5042dfb8034..63d1bf4fd7e198 100644 --- a/tensorflow/dtensor/mlir/expansions/resource_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/resource_spmd_expander.cc @@ -28,6 +28,7 @@ limitations under the License. #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" @@ -238,7 +239,7 @@ StatusOr ResourceSPMDExpander::ExpandOp(mlir::Operation* op) { TF_RETURN_WITH_CONTEXT(errors::Internal( "if both resource and value layout are set they must be equal")); - auto block_arg = input_resource_value.dyn_cast(); + auto block_arg = mlir::dyn_cast(input_resource_value); auto enclosing_device_cluster = op->getParentOfType(); diff --git a/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc index 4df082defd1817..ea25519f13b7c3 100644 --- a/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/save_restore_spmd_expander.cc @@ -85,7 +85,7 @@ StatusOr GetAllCandidateCheckpointPrefixes( builder .create( prefix.getLoc(), - prefix.getType().dyn_cast(), prefix, + mlir::dyn_cast(prefix.getType()), prefix, StringConst(builder, prefix.getLoc(), llvm::SmallVector( {DeviceSuffix(0, mesh.num_devices())}))) @@ -96,7 +96,8 @@ StatusOr GetAllCandidateCheckpointPrefixes( builder .create( prefix.getLoc(), - prefix.getType().dyn_cast(), prefix, + mlir::dyn_cast(prefix.getType()), + prefix, StringConst(builder, prefix.getLoc(), llvm::SmallVector( {DeviceSuffix(device_id, mesh.num_devices())}))) @@ -529,7 +530,7 @@ StatusOr ExpandMergeV2Op(mlir::Operation* op) { mlir::Value zero_scalar, CreateZeroScalarConst( builder, location, - device_id.getType().cast().getElementType())); + mlir::cast(device_id.getType()).getElementType())); mlir::TF::NotEqualOp not_equal = builder.create( location, device_id, zero_scalar, @@ -697,7 +698,7 @@ StatusOr ExpandDTensorRestoreV2Op(mlir::Operation* op) { std::vector> input_shapes; input_shapes.reserve(input_shapes_attr.size()); for (const auto& shape : input_shapes_attr) { - mlir::TF::ShapeAttr shape_attr = shape.cast(); + mlir::TF::ShapeAttr shape_attr = mlir::cast(shape); if (!shape_attr.hasStaticShape()) { return absl::InvalidArgumentError( llvm::formatv("DTensorRestoreV2Op requires statically known input " @@ -718,7 +719,8 @@ StatusOr ExpandDTensorRestoreV2Op(mlir::Operation* op) { input_layouts.reserve(input_layouts_attr.size()); for (const auto& layout : input_layouts_attr.getValue().vec()) { input_layouts.push_back( - Layout::FromString(layout.cast().getValue().str()) + Layout::FromString( + mlir::cast(layout).getValue().str()) .value()); } @@ -778,7 +780,7 @@ StatusOr ExpandRestoreV2Op(mlir::Operation* op) { Layout& layout = std::get<2>(it); new_types.push_back(mlir::RankedTensorType::get( layout.LocalShapeFromGlobalShape(shape), - type.dyn_cast().getElementType())); + mlir::dyn_cast(type).getElementType())); } return ExpandRestoreV2OpHelper( @@ -910,7 +912,7 @@ SaveRestoreSPMDExpander::ComputeLayoutForward( TF_ASSIGN_OR_RETURN( Layout layout, Layout::FromString( - it.value().cast().getValue().str())); + mlir::cast(it.value()).getValue().str())); output_layouts[it.index()] = layout; } return output_layouts; diff --git a/tensorflow/dtensor/mlir/expansions/slice_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/slice_spmd_expander.cc index 13768bc95419b5..275a1bbe1af07e 100644 --- a/tensorflow/dtensor/mlir/expansions/slice_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/slice_spmd_expander.cc @@ -24,6 +24,7 @@ limitations under the License. #include "llvm/Support/FormatVariadic.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/dtensor/cc/dstatus.h" @@ -116,7 +117,7 @@ StatusOr SliceSPMDExpander::ExpandOp(mlir::Operation* op) { // The dyn_cast will never be nullptr as it is checked in // GetLayoutFromOperands. auto input_type = - slice_op.getInput().getType().dyn_cast(); + mlir::dyn_cast(slice_op.getInput().getType()); if (!input_type) return errors::InvalidArgument( "rank of input tensor must be statically known for slice op."); @@ -172,10 +173,10 @@ StatusOr SliceSPMDExpander::ExpandOp(mlir::Operation* op) { auto loc = op->getLoc(); // Both begin and size need to be the same type, so we must match the new // size input with the type of begin. - if (!slice_op.getBegin().getType().isa()) + if (!mlir::isa(slice_op.getBegin().getType())) return errors::Internal("type of begin is not a ShapedType"); mlir::ShapedType type = - slice_op.getBegin().getType().cast(); + mlir::cast(slice_op.getBegin().getType()); if (type.getElementType().isInteger(32)) new_size = IntConst( builder, loc, llvm::SmallVector(sizes.begin(), sizes.end())); diff --git a/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc index 8ccdf0ee8168ae..2a2b03273f28dc 100644 --- a/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/softmax_spmd_expander.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/IRMapping.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/dtensor/cc/tensor_layout.h" #include "tensorflow/dtensor/mlir/collectives.h" @@ -79,10 +80,8 @@ StatusOr ComputeGlobalReduce( local_reduce, reduce_op)); if (!keep_dims) { - mlir::RankedTensorType output_type = - global_reduce->getResult(0) - .getType() - .dyn_cast(); + mlir::RankedTensorType output_type = mlir::dyn_cast( + global_reduce->getResult(0).getType()); if (!output_type) return errors::Internal( "output of EmitAllReduce is not a RankedTensorType"); @@ -209,7 +208,8 @@ StatusOr GetBroadcastedLayout(llvm::ArrayRef global_shape, // value. Assumes builder's insertion point is after input. StatusOr GetFPConstOfType(mlir::OpBuilder& builder, const mlir::Value& input, float value) { - if (mlir::TensorType type = input.getType().dyn_cast()) { + if (mlir::TensorType type = + mlir::dyn_cast(input.getType())) { return builder .create( input.getLoc(), @@ -239,7 +239,7 @@ StatusOr ComputeOneHot(mlir::OpBuilder& builder, // Get the number of classes for this onehot. The number of classes is the // global size of the last dimension of features. mlir::RankedTensorType features_type = - features.getType().dyn_cast(); + mlir::dyn_cast(features.getType()); if (!features_type) return errors::InvalidArgument( "feature input shape must be statically known"); @@ -297,7 +297,8 @@ StatusOr ComputeOneHot(mlir::OpBuilder& builder, // Note that the type of id_offset (int32) may not match the type of input. // So we insert a cast in this case. - mlir::TensorType input_type = input.getType().dyn_cast(); + mlir::TensorType input_type = + mlir::dyn_cast(input.getType()); if (!input_type) return errors::InvalidArgument("input is not a TensorType"); if (!input_type.getElementType().isInteger(32)) id_offset = diff --git a/tensorflow/dtensor/mlir/expansions/sparse_to_dense_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/sparse_to_dense_spmd_expander.cc index 1e9b003a14d46f..ee7ee35bbee5aa 100644 --- a/tensorflow/dtensor/mlir/expansions/sparse_to_dense_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/sparse_to_dense_spmd_expander.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/dtensor/mlir/layout_parsing.h" #include "tensorflow/dtensor/mlir/shape_utils.h" #include "tensorflow/dtensor/mlir/value_utils.h" @@ -35,7 +36,7 @@ StatusOr SparseToDenseSPMDExpander::ExpandOp( auto op_result = op->getResult(0); const auto element_type = - op_result.getType().cast().getElementType(); + mlir::cast(op_result.getType()).getElementType(); op_result.setType(mlir::RankedTensorType::get(local_shape, element_type)); // No-op return op; diff --git a/tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.cc index c40e08814c36a1..5df9f648ffabfd 100644 --- a/tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/squeeze_spmd_expander.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/core/platform/errors.h" #include "tensorflow/dtensor/cc/dstatus.h" #include "tensorflow/dtensor/mlir/layout_parsing.h" @@ -37,7 +38,8 @@ std::set GetSqueezeDims(mlir::Operation* op, int64_t rank) { if (array_attribute) { auto attr_list = array_attribute.getValue().vec(); for (const auto& attr : attr_list) { - int64_t dim = attr.cast().getValue().getSExtValue(); + int64_t dim = + mlir::cast(attr).getValue().getSExtValue(); // Offset the negative indices to positive range. squeeze_dims.insert((dim + rank) % rank); } diff --git a/tensorflow/dtensor/mlir/expansions/tensorlist_reserve_spmd_expander.cc b/tensorflow/dtensor/mlir/expansions/tensorlist_reserve_spmd_expander.cc index 0ef5ca777e8446..ce82bae63e930b 100644 --- a/tensorflow/dtensor/mlir/expansions/tensorlist_reserve_spmd_expander.cc +++ b/tensorflow/dtensor/mlir/expansions/tensorlist_reserve_spmd_expander.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/dtensor/mlir/expansions/tensorlist_reserve_spmd_expander.h" #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/dtensor/mlir/dtensor_location.h" @@ -38,9 +39,9 @@ StatusOr TensorListReserveSPMDExpander::ExpandOp( llvm::dyn_cast(op); mlir::OpBuilder builder(op); - mlir::Type element_type = GetSubtypeOrSelf(op->getOpResult(0)) - .cast() - .getElementType(); + mlir::Type element_type = + mlir::cast(GetSubtypeOrSelf(op->getOpResult(0))) + .getElementType(); mlir::RankedTensorType new_output_type = mlir::RankedTensorType::get( {}, mlir::TF::VariantType::get( diff --git a/tensorflow/dtensor/mlir/handle_cross_cluster_dependencies.cc b/tensorflow/dtensor/mlir/handle_cross_cluster_dependencies.cc index 422a4d4a8cafd1..6877f77dd25f37 100644 --- a/tensorflow/dtensor/mlir/handle_cross_cluster_dependencies.cc +++ b/tensorflow/dtensor/mlir/handle_cross_cluster_dependencies.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/UseDefLists.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project @@ -109,7 +110,7 @@ mlir::LogicalResult CloneOpToCluster(mlir::Operation* const_op, mlir::LogicalResult GetInputProducingValue(mlir::OpOperand& operand, mlir::Value* val_output) { - auto input_value = operand.get().dyn_cast(); + auto input_value = mlir::dyn_cast(operand.get()); if (!input_value) return mlir::success(); auto input_cluster = @@ -216,7 +217,7 @@ mlir::LogicalResult LowerToSendRecv(mlir::TF::CopyToMeshOp copy_to_mesh, mlir::MLIRContext* context, int* send_recv_counter) { const mlir::OpResult copied_value = - copy_to_mesh.getInput().cast(); + mlir::cast(copy_to_mesh.getInput()); const int result_index = copied_value.getResultNumber(); auto src_cluster = llvm::cast(copied_value.getDefiningOp()); @@ -243,7 +244,7 @@ mlir::LogicalResult LowerToSendRecv(mlir::TF::CopyToMeshOp copy_to_mesh, mlir::dtensor::MeshAttr::get(context, target_mesh)); // Create recv op that recvs data from send op. - auto tensor_type = value_to_send.getType().dyn_cast(); + auto tensor_type = mlir::dyn_cast(value_to_send.getType()); if (!tensor_type) return copy_to_mesh.emitOpError( "found CopyToMesh sending value with unknown shape. Inputs to " diff --git a/tensorflow/dtensor/mlir/handle_sparsetensors.cc b/tensorflow/dtensor/mlir/handle_sparsetensors.cc index 30c18accca8ff0..678d8a41841114 100644 --- a/tensorflow/dtensor/mlir/handle_sparsetensors.cc +++ b/tensorflow/dtensor/mlir/handle_sparsetensors.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -95,11 +96,12 @@ mlir::LogicalResult UpdateFunctionInputAttributes( auto dict_attr = main_func->getAttrOfType(kEntryFuncAttr); if (dict_attr) { - if (!dict_attr.get("inputs").isa()) + if (!mlir::isa(dict_attr.get("inputs"))) return main_func.emitOpError("Missing attribute inputs in main FuncOp."); - dict_attr.get("inputs").cast().getValue().split( - input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); + mlir::cast(dict_attr.get("inputs")) + .getValue() + .split(input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); llvm::SmallVector new_input_names; @@ -148,10 +150,10 @@ void CreateComponentTensorsFromSparseTensors( {mlir::ShapedType::kDynamic, ValueRank(block_arg)}, builder.getI64Type()), /*values=*/ - mlir::RankedTensorType::get({mlir::ShapedType::kDynamic}, - block_arg.getType() - .dyn_cast() - .getElementType()), + mlir::RankedTensorType::get( + {mlir::ShapedType::kDynamic}, + mlir::dyn_cast(block_arg.getType()) + .getElementType()), /*dense_shapes=*/ mlir::RankedTensorType::get({ValueRank(block_arg)}, builder.getI64Type()), @@ -214,11 +216,10 @@ struct DTensorSparseTensorToDenseTensor // Emit a SparseToDenseOp and replace the SparseTensor with the result of // this new op. - StatusOr zero_scalar = - CreateZeroScalarConst(builder, front_op->getLoc(), - sparse_tensor_value.getType() - .cast() - .getElementType()); + StatusOr zero_scalar = CreateZeroScalarConst( + builder, front_op->getLoc(), + mlir::cast(sparse_tensor_value.getType()) + .getElementType()); if (!zero_scalar.ok()) return signalPassFailure(); mlir::TF::SparseToDenseOp sparse_to_dense_op = builder.create( diff --git a/tensorflow/dtensor/mlir/ir/tf_dtensor.cc b/tensorflow/dtensor/mlir/ir/tf_dtensor.cc index 14e66c98b6cb2c..d119f034eecc38 100644 --- a/tensorflow/dtensor/mlir/ir/tf_dtensor.cc +++ b/tensorflow/dtensor/mlir/ir/tf_dtensor.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -38,13 +39,13 @@ namespace { RankedTensorType GetRankedTensorType(mlir::Value val) { mlir::Type type = val.getType(); if (auto type_with_subtype = - mlir::getElementTypeOrSelf(val) - .dyn_cast()) { + mlir::dyn_cast( + mlir::getElementTypeOrSelf(val))) { if (type_with_subtype.GetSubtypes().size() == 1) { type = type_with_subtype.GetSubtypes().front(); } } - return type.dyn_cast_or_null(); + return mlir::dyn_cast_or_null(type); } } // namespace @@ -110,7 +111,7 @@ mlir::LogicalResult DTensorAllGatherOp::verify() { } RankedTensorType input_type = - op.getInput().getType().dyn_cast(); + mlir::dyn_cast(op.getInput().getType()); if (!input_type) return mlir::success(); if (input_type.getRank() != input_layout.rank()) @@ -119,7 +120,7 @@ mlir::LogicalResult DTensorAllGatherOp::verify() { << " is not equal to input rank " << input_type.getRank(); RankedTensorType output_type = - op.getOutput().getType().dyn_cast(); + mlir::dyn_cast(op.getOutput().getType()); if (!output_type) return mlir::success(); if (output_type.getRank() != output_layout.rank()) @@ -166,7 +167,7 @@ mlir::LogicalResult DTensorAllScatterOp::verify() { } RankedTensorType input_type = - op.getInput().getType().dyn_cast(); + mlir::dyn_cast(op.getInput().getType()); if (!input_type) return mlir::success(); if (input_type.getRank() != input_layout.rank()) @@ -175,7 +176,7 @@ mlir::LogicalResult DTensorAllScatterOp::verify() { << " is not equal to input rank " << input_type.getRank(); RankedTensorType output_type = - op.getOutput().getType().dyn_cast(); + mlir::dyn_cast(op.getOutput().getType()); if (!output_type) return mlir::success(); if (output_type.getRank() != output_layout.rank()) @@ -237,7 +238,7 @@ mlir::LogicalResult DTensorAllToAllOp::verify() { } RankedTensorType input_type = - op.getInput().getType().dyn_cast(); + mlir::dyn_cast(op.getInput().getType()); if (!input_type) return mlir::success(); if (input_type.getRank() != input_layout.rank()) @@ -246,7 +247,7 @@ mlir::LogicalResult DTensorAllToAllOp::verify() { << " is not equal to input rank " << input_type.getRank(); RankedTensorType output_type = - op.getOutput().getType().dyn_cast(); + mlir::dyn_cast(op.getOutput().getType()); if (!output_type) return mlir::success(); if (output_type.getRank() != output_layout.rank()) diff --git a/tensorflow/dtensor/mlir/layout_parsing.cc b/tensorflow/dtensor/mlir/layout_parsing.cc index 441ab2b86ede67..4b9011f42b0af0 100644 --- a/tensorflow/dtensor/mlir/layout_parsing.cc +++ b/tensorflow/dtensor/mlir/layout_parsing.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/mutex.h" @@ -105,7 +106,7 @@ StatusOr>> ExtractLayoutFromOp( if (!serialized_layouts) return outs; for (auto const& attr : serialized_layouts) { - auto attr_str = attr.cast().getValue().str(); + auto attr_str = mlir::cast(attr).getValue().str(); if (!attr_str.empty()) { TF_ASSIGN_OR_RETURN(auto layout, Layout::FromString(attr_str)); outs.emplace_back(std::move(layout)); @@ -162,7 +163,7 @@ StatusOr> ExtractDeviceMeshFromOp(mlir::Operation* op) { } StatusOr> ExtractLayoutFromOperand(mlir::Value operand) { - if (auto op_result = operand.dyn_cast()) { + if (auto op_result = mlir::dyn_cast(operand)) { mlir::Operation* op = op_result.getDefiningOp(); std::optional out; if (auto layout_op = llvm::dyn_cast(op)) { @@ -185,7 +186,7 @@ StatusOr> ExtractLayoutFromOperand(mlir::Value operand) { return out; } - auto block_arg = operand.dyn_cast(); + auto block_arg = mlir::dyn_cast(operand); if (!block_arg) return errors::Internal( "Operand is not either a OpResult or a BlockArgument. This should not " @@ -293,7 +294,7 @@ StatusOr> ExtractElementLayoutsFromOperand( operand_index, op->getName()) .str()); - auto block_arg = input_value.get().dyn_cast(); + auto block_arg = mlir::dyn_cast(input_value.get()); auto array_attr = enclosing_function.getArgAttrOfType( block_arg.getArgNumber(), kIteratorElementLayouts); if (!array_attr) @@ -305,9 +306,10 @@ StatusOr> ExtractElementLayoutsFromOperand( llvm::SmallVector layouts(array_attr.size()); for (int i = 0; i < array_attr.size(); ++i) { - layouts[i] = Layout::FromString( - array_attr[i].cast().getValue().str()) - .value(); + layouts[i] = + Layout::FromString( + mlir::cast(array_attr[i]).getValue().str()) + .value(); } return layouts; diff --git a/tensorflow/dtensor/mlir/layout_propagation_v2.cc b/tensorflow/dtensor/mlir/layout_propagation_v2.cc index a7f089eb1726ec..54b798881e5b09 100644 --- a/tensorflow/dtensor/mlir/layout_propagation_v2.cc +++ b/tensorflow/dtensor/mlir/layout_propagation_v2.cc @@ -113,7 +113,7 @@ void UpdateLayoutForSkippedOps( llvm::SmallVector skipped_values; TraceUseToNextTFOp(&operand, func_to_caller, &skipped_values); for (const mlir::Value& skipped_value : skipped_values) - if ((!skipped_value.isa() || + if ((!mlir::isa(skipped_value) || !mlir::isa( skipped_value.getDefiningOp())) && layouts.find(skipped_value) == layouts.end()) @@ -724,7 +724,7 @@ mlir::LogicalResult InsertDTensorLayoutOps( // resource type elements. mlir::Type value_type = GetSubtypeOrSelf(merged_layout.first); - if (auto type = value_type.dyn_cast()) { + if (auto type = mlir::dyn_cast(value_type)) { auto layout_op = builder.create( merged_layout.first.getLoc(), merged_layout.first, layout_attr, mlir::TF::ShapeAttr::get(builder.getContext(), type)); @@ -810,7 +810,7 @@ void GetOperationsNeedingUpdate( if (!mlir::isa(use->getOwner())) operations.insert(use->getOwner()); // If this is an OpResult, also add the op that produces it. - if (value.isa() && + if (mlir::isa(value) && !mlir::isa(value.getDefiningOp())) operations.insert(value.getDefiningOp()); } @@ -933,7 +933,7 @@ class LayoutPrinter : public mlir::OpAsmPrinter { // Print an operand, this could be both the OpResult or a BlockArgument. // We also print the layout if it exists and the type. void printOperand(mlir::Value value, llvm::raw_ostream& os) override { - if (auto result = value.dyn_cast()) { + if (auto result = mlir::dyn_cast(value)) { // If DTensorLayout ops are already in the module, we need to skip them // since we aren't printing them out. if (mlir::isa(result.getDefiningOp())) { @@ -946,7 +946,7 @@ class LayoutPrinter : public mlir::OpAsmPrinter { os << "%" << location_[result.getDefiningOp()]; if (result.getDefiningOp()->getNumResults() > 1) os << ":" << result.getResultNumber(); - } else if (auto argument = value.dyn_cast()) { + } else if (auto argument = mlir::dyn_cast(value)) { if (arguments_.find(argument) == arguments_.end()) arguments_[argument] = next_argument_++; os << "%arg" << arguments_[argument]; @@ -1203,7 +1203,7 @@ mlir::LogicalResult InsertRelayoutForWhileLoops( mlir::cast(input_layout_op).getLayout(); // Inputs to Yield should also be a DTensorLayout op. - if (!yield_op->getOperand(i).isa() || + if (!mlir::isa(yield_op->getOperand(i)) || !mlir::isa( yield_op->getOperand(i).getDefiningOp())) return yield_op->emitOpError() @@ -1220,12 +1220,12 @@ mlir::LogicalResult InsertRelayoutForWhileLoops( // Insert the first Relayout op (in the loop body). builder.setInsertionPointAfter(output_layout_op); - if (!yield_op->getOperand(i).getType().isa()) + if (!mlir::isa(yield_op->getOperand(i).getType())) return yield_op->emitOpError() << "operand " << i << " does not have TensorType"; mlir::TF::ShapeAttr global_shape = mlir::TF::ShapeAttr::get( builder.getContext(), - yield_op->getOperand(i).getType().cast()); + mlir::cast(yield_op->getOperand(i).getType())); mlir::TF::RelayoutOp first_relayout = builder.create( op.getLoc(), yield_op->getOperand(i).getType(), diff --git a/tensorflow/dtensor/mlir/merge_clusters.cc b/tensorflow/dtensor/mlir/merge_clusters.cc index 3aa323ec1d4cc1..575a4e643b331d 100644 --- a/tensorflow/dtensor/mlir/merge_clusters.cc +++ b/tensorflow/dtensor/mlir/merge_clusters.cc @@ -39,6 +39,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project @@ -278,7 +279,7 @@ void CloneEmptyIfWithPredicate(mlir::TF::IfRegionOp if_region, const Mesh& mesh, // DTensorSend op sends the predicate to `mesh` cluster with replicated // layout. mlir::TensorType predicate_tensor_type = - if_region.getCond().getType().cast(); + mlir::cast(if_region.getCond().getType()); const std::string send_recv_key = absl::StrCat(kSendRecvKeyPrefix, *num_send_recvs); *num_send_recvs += 1; @@ -341,7 +342,7 @@ mlir::LogicalResult VerifyClusterInputOutput( mlir::LogicalResult result = mlir::success(); mlir::visitUsedValuesDefinedAbove( cluster.getBody(), cluster.getBody(), [&](mlir::OpOperand* input) { - if (!input->get().isa()) { + if (!mlir::isa(input->get())) { result = cluster.emitOpError( "found nested tf_device.Cluster op with inputs. Nested cluster " "must use send/recv instead."); diff --git a/tensorflow/dtensor/mlir/mesh_propagation.cc b/tensorflow/dtensor/mlir/mesh_propagation.cc index 6d57b068bd5062..bb5369a530c2cb 100644 --- a/tensorflow/dtensor/mlir/mesh_propagation.cc +++ b/tensorflow/dtensor/mlir/mesh_propagation.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -101,7 +102,7 @@ mlir::LogicalResult ExtractMeshFromBlockArgumentWhile( } return mlir::success(); } else if (auto func_block_arg = - while_op_operand.dyn_cast()) { + mlir::dyn_cast(while_op_operand)) { // The while op operand is a block argument of the function, then follow the // same routine of getting mesh from function argument. auto function_op = mlir::dyn_cast_or_null( @@ -183,7 +184,7 @@ mlir::LogicalResult ExtractMeshFromOperand( // If `operand` is a block argument then extract mesh from `tf._mesh` // attribute of the corresponding function argument. - if (auto block_arg = operand_value.dyn_cast()) { + if (auto block_arg = mlir::dyn_cast(operand_value)) { if (mlir::failed(ExtractMeshFromBlockArgument(block_arg, out))) return mlir::failure(); @@ -193,7 +194,7 @@ mlir::LogicalResult ExtractMeshFromOperand( auto producer_values = it->getSecond(); std::optional operand_mesh; for (mlir::Value producer_value : producer_values) { - if (auto arg = producer_value.dyn_cast()) { + if (auto arg = mlir::dyn_cast(producer_value)) { std::optional mesh; if (mlir::failed(ExtractMeshFromBlockArgument(arg, &mesh))) return mlir::failure(); @@ -206,7 +207,7 @@ mlir::LogicalResult ExtractMeshFromOperand( producer_value.getDefiningOp() ->getParentOfType(); auto output_from_producing_op = input_cluster.getResult( - producer_value.cast().getResultNumber()); + mlir::cast(producer_value).getResultNumber()); std::optional mesh; if (mlir::failed( diff --git a/tensorflow/dtensor/mlir/move_compilation_to_host.cc b/tensorflow/dtensor/mlir/move_compilation_to_host.cc index d10a65f7d6b465..2fd20105388bf4 100644 --- a/tensorflow/dtensor/mlir/move_compilation_to_host.cc +++ b/tensorflow/dtensor/mlir/move_compilation_to_host.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/dtensor/cc/tensor_layout.h" @@ -141,9 +142,9 @@ mlir::LogicalResult CreateSendRecvOpsToTransferProgramKey( mlir::OpBuilder fn_builder = mlir::OpBuilder::atBlockEnd(fn_block); auto recv = fn_builder.create( compile_op->getLoc(), - compilation_key.getType().cast(), device_key_map[i], - compile_op_launch.getDevice(), /*send_device_incarnation=*/0, - local_devices[i]); + mlir::cast(compilation_key.getType()), + device_key_map[i], compile_op_launch.getDevice(), + /*send_device_incarnation=*/0, local_devices[i]); recv->setAttr("device", builder.getStringAttr(local_devices[i])); fn_builder.create(recv_select_fn.getLoc(), diff --git a/tensorflow/dtensor/mlir/propagate_default_layout.cc b/tensorflow/dtensor/mlir/propagate_default_layout.cc index 39aba3a8716b9b..030e8af63aa1d8 100644 --- a/tensorflow/dtensor/mlir/propagate_default_layout.cc +++ b/tensorflow/dtensor/mlir/propagate_default_layout.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/dtensor/cc/constants.h" @@ -83,7 +84,7 @@ mlir::LogicalResult PropagateDTensorLayoutForRelayout( mlir::OpBuilder builder(relayout->getBlock(), ++mlir::Block::iterator(relayout)); - mlir::TensorType type = relayout.getType().dyn_cast(); + mlir::TensorType type = mlir::dyn_cast(relayout.getType()); if (!type) return relayout.emitOpError("type required for Relayout op"); CreateDTensorLayoutOp(layout, relayout.getOutput(), type, relayout.getLoc(), @@ -110,7 +111,7 @@ mlir::LogicalResult PropagateFunctionArgAttrToLayoutOp( mlir::OpBuilder builder(function.getBody()); auto arg = function.getArgument(arg_index); mlir::Type tensor_type = GetSubtypeOrSelf(arg); - if (auto type = tensor_type.dyn_cast()) { + if (auto type = mlir::dyn_cast(tensor_type)) { CreateDTensorLayoutOp(layout_or_status.value(), arg, type, function.getLoc(), builder.getI64IntegerAttr(arg_index), &builder, &c); @@ -149,7 +150,7 @@ mlir::LogicalResult PropagateFunctionDefaultLayoutAttrToLayoutOp( mlir::OpBuilder builder(function_terminator); auto return_value = function_terminator->getOperand(ret_index); - if (auto type = return_value.getType().dyn_cast()) + if (auto type = mlir::dyn_cast(return_value.getType())) CreateDTensorLayoutOp(result_layout_or_status.value(), return_value, type, function.getLoc(), nullptr, &builder, &c); else @@ -187,7 +188,8 @@ mlir::LogicalResult PropagateOpAttrToLayoutOp(mlir::MLIRContext& context, if (!layout || layout->IsEmpty()) continue; auto op_output = op->getResult(index); - if (auto type = op_output.getType().dyn_cast()) { + if (auto type = + mlir::dyn_cast(op_output.getType())) { CreateDTensorLayoutOp(*layout, op_output, type, function.getLoc(), arg_index, &builder, &context); } else { diff --git a/tensorflow/dtensor/mlir/propagate_device_id_to_function_args.cc b/tensorflow/dtensor/mlir/propagate_device_id_to_function_args.cc index 404be2b38e61fc..34648f9f407f2c 100644 --- a/tensorflow/dtensor/mlir/propagate_device_id_to_function_args.cc +++ b/tensorflow/dtensor/mlir/propagate_device_id_to_function_args.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -69,7 +70,7 @@ llvm::SmallVector FindFunctionsToRewrite( symbol = call_op.getF(); } else { auto symbol_ref = llvm::dyn_cast(op).getF(); - if (!symbol_ref.isa()) return; + if (!mlir::isa(symbol_ref)) return; symbol = symbol_ref.getRootReference().getValue(); } diff --git a/tensorflow/dtensor/mlir/restore_shape_inference.cc b/tensorflow/dtensor/mlir/restore_shape_inference.cc index abbda8cbdd80ae..4a400ecb9f341e 100644 --- a/tensorflow/dtensor/mlir/restore_shape_inference.cc +++ b/tensorflow/dtensor/mlir/restore_shape_inference.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/dtensor/mlir/dtensor_send_recv.h" @@ -81,9 +82,7 @@ mlir::LogicalResult BackwardShapeInferenceToRestoreOp(mlir::ModuleOp module, // the type to the operand element type. mlir::RankedTensorType new_type = mlir::RankedTensorType::get( GetShapeOfValue(new_cast_op.getResult()).value(), - new_cast_op.getOperand() - .getType() - .cast() + mlir::cast(new_cast_op.getOperand().getType()) .getElementType()); // Recursively shape inference to the input of the cast op with the @@ -120,7 +119,7 @@ mlir::LogicalResult BackwardShapeInferenceToRestoreOp(mlir::ModuleOp module, auto new_recv_op = builder->create( recv_op.getLoc(), type, builder->getStringAttr(recv_op.getKey()), mlir::TF::ShapeAttr::get(builder->getContext(), - type.dyn_cast()), + mlir::dyn_cast(type)), mlir::dtensor::MeshAttr::get(builder->getContext(), recv_op.getMesh())); recv_op.replaceAllUsesWith(new_recv_op.getOutput()); diff --git a/tensorflow/dtensor/mlir/sparse_expansions/dynamic_enqueue_sparse_expander.cc b/tensorflow/dtensor/mlir/sparse_expansions/dynamic_enqueue_sparse_expander.cc index 3cc8d9e3ec513c..f609c1576f72fd 100644 --- a/tensorflow/dtensor/mlir/sparse_expansions/dynamic_enqueue_sparse_expander.cc +++ b/tensorflow/dtensor/mlir/sparse_expansions/dynamic_enqueue_sparse_expander.cc @@ -39,7 +39,7 @@ namespace { StatusOr ExpandIndices(mlir::OpBuilder& builder, mlir::Value indices) { int64_t num_dim = - indices.getType().dyn_cast().getDimSize(1); + mlir::dyn_cast(indices.getType()).getDimSize(1); if (num_dim != 2) return errors::Unimplemented( "Sparse tensors with dense rank not equal to 2 is not yet supported in " @@ -47,7 +47,8 @@ StatusOr ExpandIndices(mlir::OpBuilder& builder, mlir::Location loc = indices.getLoc(); auto indices_padded_type = mlir::RankedTensorType::get( {mlir::ShapedType::kDynamic, 3}, - indices.getType().dyn_cast().getElementType()); + mlir::dyn_cast(indices.getType()) + .getElementType()); // Little trick to make a rank-2 tensor of [[0,0], [0,1]] using rank 1 // constants. mlir::Value indices_padding = builder.create( diff --git a/tensorflow/dtensor/mlir/spmd_expander.cc b/tensorflow/dtensor/mlir/spmd_expander.cc index 4e63f87970777f..5538afe18c1a94 100644 --- a/tensorflow/dtensor/mlir/spmd_expander.cc +++ b/tensorflow/dtensor/mlir/spmd_expander.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/errors.h" @@ -177,7 +178,7 @@ Status SPMDExpanderBase::ExpandOpAndSetLayout(mlir::Operation* op, global_output_shapes.reserve(op->getNumResults()); for (auto output_value : op->getResults()) { auto maybe_ranked = - output_value.getType().dyn_cast(); + mlir::dyn_cast(output_value.getType()); // Do not extract global shape if the shape isn't statically known. // // This is a bit subtle and relies on the check of static shape of output diff --git a/tensorflow/dtensor/mlir/spmd_expander_common.cc b/tensorflow/dtensor/mlir/spmd_expander_common.cc index b3f823ae7e9fc4..1cee400f1a283f 100644 --- a/tensorflow/dtensor/mlir/spmd_expander_common.cc +++ b/tensorflow/dtensor/mlir/spmd_expander_common.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/errors.h" @@ -113,7 +114,7 @@ Status CreateSplitOp(const int num_split, const int split_dimension, // Correctly set output shapes of split op output if input shape is statically // known. mlir::Type output_type; - auto input_type = src_input.getType().cast(); + auto input_type = mlir::cast(src_input.getType()); if (input_type.hasRank()) { if (input_type.getShape()[split_dimension] == mlir::ShapedType::kDynamic) { @@ -680,7 +681,7 @@ mlir::StringAttr GetUniqueControlflowFnName(const std::string& prefix, Status SetBuilderInsertionAfterValue(mlir::Value value, mlir::OpBuilder& builder) { - if (value.isa()) { + if (mlir::isa(value)) { builder.setInsertionPointAfterValue(value); return absl::OkStatus(); } @@ -719,7 +720,7 @@ Status PrintTensor(mlir::Value value, const std::string& format_string = "%s") { Status ExtractConstStringVectorFromValue( mlir::Value value, llvm::SmallVectorImpl& out_vector) { value = GetForwardedDTensorLayoutInput(value); - if (value.isa()) + if (mlir::isa(value)) return errors::Internal("Unable get constant value from block argument."); mlir::DenseStringElementsAttr attr; if (!matchPattern(value, m_Constant(&attr))) { @@ -736,7 +737,7 @@ Status ExtractConstStringVectorFromValue( StatusOr ExtractConstScalarStringFromValue(mlir::Value value) { value = GetForwardedDTensorLayoutInput(value); - if (value.isa()) + if (mlir::isa(value)) return errors::Internal("Unable get constant value from block argument."); mlir::DenseStringElementsAttr attr; if (!matchPattern(value, m_Constant(&attr))) { diff --git a/tensorflow/dtensor/mlir/tpu_add_resource_device_attribute.cc b/tensorflow/dtensor/mlir/tpu_add_resource_device_attribute.cc index 7fcf1d3b3e2c20..7094adde2ea977 100644 --- a/tensorflow/dtensor/mlir/tpu_add_resource_device_attribute.cc +++ b/tensorflow/dtensor/mlir/tpu_add_resource_device_attribute.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -92,10 +93,10 @@ struct DTensorTpuAddResourceDeviceAttribute mlir::WalkResult walk_result = module.walk([](mlir::TF::TPUExecuteOp tpu_execute) { for (mlir::Value tpu_input : tpu_execute.getOperands()) { - if (tpu_input.isa() && + if (mlir::isa(tpu_input) && IsResourceType(tpu_input)) AddPlaceholderDeviceAttributeToResource( - tpu_input.cast(), tpu_execute); + mlir::cast(tpu_input), tpu_execute); mlir::Operation* input_op = tpu_input.getDefiningOp(); auto read_variable_op = @@ -103,7 +104,7 @@ struct DTensorTpuAddResourceDeviceAttribute if (!read_variable_op) continue; AddPlaceholderDeviceAttributeToResource( - read_variable_op.getResource().cast(), + mlir::cast(read_variable_op.getResource()), tpu_execute); } @@ -113,9 +114,9 @@ struct DTensorTpuAddResourceDeviceAttribute if (assign_variable == nullptr) continue; AddPlaceholderDeviceAttributeToResource( - llvm::cast(assign_variable) - .getResource() - .cast(), + mlir::cast( + llvm::cast(assign_variable) + .getResource()), tpu_execute); } diff --git a/tensorflow/dtensor/mlir/utils/collective_lowering.cc b/tensorflow/dtensor/mlir/utils/collective_lowering.cc index 5a12d4de95dcc0..fb71b092d3c695 100644 --- a/tensorflow/dtensor/mlir/utils/collective_lowering.cc +++ b/tensorflow/dtensor/mlir/utils/collective_lowering.cc @@ -42,6 +42,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" #include "xla/tsl/util/env_var.h" #include "tensorflow/core/platform/errors.h" @@ -269,7 +270,8 @@ mlir::Operation* EmitCollectiveReduceScatter( const mlir::DenseIntElementsAttr& group_assignment, int32 scatter_dimension, int32 key_base, mlir::Value device_id, int32 host_group_size, const mlir::StringRef device_type) { - mlir::TensorType input_type = input.getType().dyn_cast(); + mlir::TensorType input_type = + mlir::dyn_cast(input.getType()); const bool need_transpose = scatter_dimension != 0; std::vector perm_for_transpose; @@ -282,9 +284,10 @@ mlir::Operation* EmitCollectiveReduceScatter( auto pre_transpose_op = EmitTransposeOp(builder, loc, input, perm_for_transpose); input = pre_transpose_op->getResult(0); - input_type = input.getType().dyn_cast(); + input_type = mlir::dyn_cast(input.getType()); // Compute transposed output type for CollectiveReduceScatter - auto output_shape = output_type.dyn_cast().getShape(); + auto output_shape = + mlir::dyn_cast(output_type).getShape(); std::vector transposed_shape(output_shape.begin(), output_shape.end()); for (int i = 0; i < output_shape.size(); i++) { @@ -338,8 +341,8 @@ mlir::Operation* EmitCollectiveAllToAll( // data correctly. An example relayout that requires this is [y, unsharded, x] // -> [y, x, unsharded]. const mlir::TensorType input_type = - input.getType().dyn_cast(); - auto input_shape = input_type.dyn_cast().getShape(); + mlir::dyn_cast(input.getType()); + auto input_shape = mlir::dyn_cast(input_type).getShape(); // TODO(trevor-m): One of the transpose pairs created when requires_transpose // is true can be combined with the transpose in permute_data() that lies on @@ -478,7 +481,7 @@ mlir::Operation* EmitCollectiveGather( auto shape = group_assignment.getType().getShape(); const int32 group_size = shape[1]; const mlir::TensorType input_type = - input.getType().dyn_cast(); + mlir::dyn_cast(input.getType()); auto input_shape = input_type.getShape(); auto dim_0_shape = input_shape[0]; std::vector output_shape = {input_shape.begin(), input_shape.end()}; @@ -758,9 +761,9 @@ mlir::LogicalResult LowerAllGatherOpToCollective( const std::string device_type = device_type_or_status.value(); const mlir::RankedTensorType input_type = - all_gather.getInput().getType().dyn_cast(); + mlir::dyn_cast(all_gather.getInput().getType()); const mlir::RankedTensorType output_type = - all_gather.getOutput().getType().dyn_cast(); + mlir::dyn_cast(all_gather.getOutput().getType()); if (!input_type) return all_gather.emitOpError() << "input type is not a RankedTensorType"; @@ -901,9 +904,9 @@ mlir::LogicalResult LowerAllGatherOp(mlir::TF::DTensorAllGatherOp all_gather) { } const mlir::RankedTensorType input_type = - all_gather.getInput().getType().dyn_cast(); + mlir::dyn_cast(all_gather.getInput().getType()); const mlir::RankedTensorType output_type = - all_gather.getOutput().getType().dyn_cast(); + mlir::dyn_cast(all_gather.getOutput().getType()); if (!input_type) return all_gather.emitOpError() << "input type is not a RankedTensorType"; @@ -1048,7 +1051,7 @@ mlir::LogicalResult LowerAllGatherOp(mlir::TF::DTensorAllGatherOp all_gather) { // position in the tensor, only one task in the reduction group can have a 1. // This is sufficient. const mlir::TensorType type = - update_result.getType().dyn_cast(); + mlir::dyn_cast(update_result.getType()); absl::string_view reduce_type = kReduceOpAdd; if (type && type.getElementType().isInteger(1)) reduce_type = kReduceOpAny; mlir::TF::DTensorAllReduceOp all_reduce = @@ -1090,7 +1093,7 @@ mlir::LogicalResult LowerAllScatterOp( // sharding_spec[j]=i and this is a dimension with split and 0 otherwise. mlir::RankedTensorType output_type = - all_scatter.getOutput().getType().dyn_cast(); + mlir::dyn_cast(all_scatter.getOutput().getType()); if (!output_type) return all_scatter.emitOpError() << "input must have static rank"; diff --git a/third_party/xla/xla/mlir/backends/cpu/transforms/BUILD b/third_party/xla/xla/mlir/backends/cpu/transforms/BUILD index bae1e833302a91..dd3602bd4f7725 100644 --- a/third_party/xla/xla/mlir/backends/cpu/transforms/BUILD +++ b/third_party/xla/xla/mlir/backends/cpu/transforms/BUILD @@ -58,6 +58,7 @@ cc_library( "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:VectorDialect", diff --git a/third_party/xla/xla/mlir/backends/cpu/transforms/legalize_i1_vector_transfers.cc b/third_party/xla/xla/mlir/backends/cpu/transforms/legalize_i1_vector_transfers.cc index 09d4293af32dc3..359efa920d8451 100644 --- a/third_party/xla/xla/mlir/backends/cpu/transforms/legalize_i1_vector_transfers.cc +++ b/third_party/xla/xla/mlir/backends/cpu/transforms/legalize_i1_vector_transfers.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "xla/mlir/backends/cpu/transforms/passes.h" #include "xla/mlir/xla_cpu/ir/xla_cpu.h" @@ -51,11 +52,11 @@ Value CastToI8(Value in, ImplicitLocOpBuilder& b, bool optional = false) { return {}; } - if (auto vec_ty = ty.dyn_cast()) { + if (auto vec_ty = mlir::dyn_cast(ty)) { return b.create( vec_ty.cloneWith(std::nullopt, b.getI8Type()), in); } - if (auto memref_ty = ty.dyn_cast()) { + if (auto memref_ty = mlir::dyn_cast(ty)) { auto cast_ty = memref_ty.cloneWith(std::nullopt, b.getI8Type()); return b.create(cast_ty, in); } diff --git a/third_party/xla/xla/mlir/backends/cpu/transforms/legalize_library_ops.cc b/third_party/xla/xla/mlir/backends/cpu/transforms/legalize_library_ops.cc index c3231a92625385..e85b7b82849c68 100644 --- a/third_party/xla/xla/mlir/backends/cpu/transforms/legalize_library_ops.cc +++ b/third_party/xla/xla/mlir/backends/cpu/transforms/legalize_library_ops.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "xla/mlir/backends/cpu/transforms/passes.h" #include "xla/mlir/xla_cpu/ir/xla_cpu.h" @@ -80,7 +81,8 @@ std::optional MatchReductionComputation( return xla_cpu::ReductionKind::ALL_REDUCE_MAX; } - auto type = computation->getOperandTypes().front().dyn_cast(); + auto type = + mlir::dyn_cast(computation->getOperandTypes().front()); if (!type || !type.getElementType().isInteger(1)) { return std::nullopt; } @@ -97,7 +99,7 @@ std::optional MatchReductionComputation( // Returns a `tensor.empty` with the same shape as `tensor`. Value CreateEmptyLike(OpBuilder& b, Location loc, Value tensor) { - auto ty = tensor.getType().cast(); + auto ty = mlir::cast(tensor.getType()); auto sizes = tensor::getMixedSizes(b, loc, tensor); return b.create(loc, sizes, ty.getElementType()); } @@ -197,7 +199,7 @@ class AllToAllLowering : public OpRewritePattern { dsts.push_back(rewriter.create( op.getLoc(), getAsOpFoldResult(sizes), - op->getResultTypes()[0].cast().getElementType())); + mlir::cast(op->getResultTypes()[0]).getElementType())); } rewriter.replaceOpWithNewOp( @@ -243,7 +245,7 @@ class InfeedLowering : public OpRewritePattern { llvm::SmallVector dsts; for (const auto& type : op.getResultTypes()) { - if (auto ranked_type = type.dyn_cast()) { + if (auto ranked_type = mlir::dyn_cast(type)) { dsts.push_back(b.create( op.getLoc(), ranked_type.getShape(), ranked_type.getElementType())); } else { @@ -266,8 +268,8 @@ class OutfeedLowering : public OpRewritePattern { PatternRewriter& rewriter) const override { SmallVector result_types; for (auto operand : op.getInputs()) { - result_types.push_back( - TypeAttr::get(operand.getType().cast().getElementType())); + result_types.push_back(TypeAttr::get( + mlir::cast(operand.getType()).getElementType())); } rewriter.create( op.getLoc(), std::nullopt, op.getInputs(), op.getOutfeedConfigAttr(), @@ -327,9 +329,9 @@ class ConvolutionLowering : public OpRewritePattern { PatternRewriter& rewriter) const override { ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto input_shape = op.getLhs().getType().dyn_cast(); - auto kernel_shape = op.getRhs().getType().dyn_cast(); - auto output_shape = op.getResult().getType().dyn_cast(); + auto input_shape = mlir::dyn_cast(op.getLhs().getType()); + auto kernel_shape = mlir::dyn_cast(op.getRhs().getType()); + auto output_shape = mlir::dyn_cast(op.getResult().getType()); auto dnums = op.getDimensionNumbers(); auto reversals = op.getWindowReversal(); diff --git a/third_party/xla/xla/mlir/backends/cpu/transforms/xla_abi_legalization.cc b/third_party/xla/xla/mlir/backends/cpu/transforms/xla_abi_legalization.cc index 03298f338e7980..145f7b41200e83 100644 --- a/third_party/xla/xla/mlir/backends/cpu/transforms/xla_abi_legalization.cc +++ b/third_party/xla/xla/mlir/backends/cpu/transforms/xla_abi_legalization.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "xla/mlir/backends/cpu/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -78,7 +79,7 @@ Value NormalizeTensor(ImplicitLocOpBuilder& b, TypedValue tensor, void NormalizeInputInPlace(ImplicitLocOpBuilder& b, Value tensor, ArrayRef layout) { - auto typedTensor = tensor.dyn_cast>(); + auto typedTensor = mlir::dyn_cast>(tensor); if (!typedTensor || IsDefaultLayout(layout)) { return; } @@ -92,12 +93,13 @@ SmallVector> FlattenLayoutAttribute(Attribute attr) { SmallVector> layouts; auto visit_attr = [&](mlir::Attribute attr) { - if (attr.isa()) { - layouts.emplace_back(attr.cast().getValues()); + if (mlir::isa(attr)) { + layouts.emplace_back( + mlir::cast(attr).getValues()); } }; - if (auto array = attr.dyn_cast()) { + if (auto array = mlir::dyn_cast(attr)) { for (int64_t i = 0; i < array.size(); ++i) { visit_attr(array[i]); } @@ -164,7 +166,7 @@ struct RewriteReturnArgs : OpRewritePattern { results.push_back( IsDefaultLayout(layout) ? result - : NormalizeTensor(b, result.cast>(), + : NormalizeTensor(b, mlir::cast>(result), layout, /*isInput=*/false)); } @@ -176,8 +178,8 @@ struct RewriteReturnArgs : OpRewritePattern { }; bool IsI1Tensor(Type ty) { - return ty.isa() && - ty.cast().getElementType().isInteger(1); + return mlir::isa(ty) && + mlir::cast(ty).getElementType().isInteger(1); } struct RewriteI1Results : OpRewritePattern { @@ -239,7 +241,8 @@ struct RewriteCustomCalls : OpRewritePattern { const auto& layout = operand_layouts[index]; if (!IsDefaultLayout(layout)) { Value normalized = NormalizeTensor( - b, op.getOperand(index).cast>(), layout, + b, mlir::cast>(op.getOperand(index)), + layout, /*isInput=*/false); op.setOperand(index, normalized); } diff --git a/third_party/xla/xla/mlir/backends/cpu/transforms/xla_cpu_memref_element_cast_to_llvm.cc b/third_party/xla/xla/mlir/backends/cpu/transforms/xla_cpu_memref_element_cast_to_llvm.cc index 8bedb9be524aef..0b65d4ae477c1e 100644 --- a/third_party/xla/xla/mlir/backends/cpu/transforms/xla_cpu_memref_element_cast_to_llvm.cc +++ b/third_party/xla/xla/mlir/backends/cpu/transforms/xla_cpu_memref_element_cast_to_llvm.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/mlir/backends/cpu/transforms/passes.h" #include "xla/mlir/xla_cpu/ir/xla_cpu.h" @@ -47,11 +48,11 @@ struct MemRefElementCastOpLowering LogicalResult matchAndRewrite( xla_cpu::MemRefElementCastOp cast_op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto target_memref_ty = cast_op.getDst().getType().cast(); + auto target_memref_ty = mlir::cast(cast_op.getDst().getType()); const LLVMTypeConverter &type_converter = *getTypeConverter(); - auto target_desc_ty = type_converter.convertType(target_memref_ty) - .dyn_cast_or_null(); + auto target_desc_ty = mlir::dyn_cast_or_null( + type_converter.convertType(target_memref_ty)); if (!target_desc_ty) { return failure(); } @@ -62,7 +63,7 @@ struct MemRefElementCastOpLowering SmallVector desc_fields; MemRefDescriptor::unpack(rewriter, loc, adaptor.getSrc(), - src_type.cast(), desc_fields); + mlir::cast(src_type), desc_fields); // Create descriptor. auto dst_desc = MemRefDescriptor::pack(rewriter, loc, type_converter, diff --git a/third_party/xla/xla/mlir/backends/cpu/transforms/xla_cpu_to_cpu_runtime.cc b/third_party/xla/xla/mlir/backends/cpu/transforms/xla_cpu_to_cpu_runtime.cc index fb3bb71548c6f9..a4582aa29ec9f3 100644 --- a/third_party/xla/xla/mlir/backends/cpu/transforms/xla_cpu_to_cpu_runtime.cc +++ b/third_party/xla/xla/mlir/backends/cpu/transforms/xla_cpu_to_cpu_runtime.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "xla/mlir/backends/cpu/transforms/passes.h" #include "xla/mlir/runtime/transforms/type_converter.h" @@ -71,7 +72,7 @@ SmallVector EnsureFlatMemrefs(ValueRange values, ImplicitLocOpBuilder& b) { SmallVector out; for (Value value : values) { - auto ty = value.getType().dyn_cast(); + auto ty = mlir::dyn_cast(value.getType()); if (!ty || ty.getLayout().isIdentity()) { out.push_back(value); } else { @@ -148,7 +149,7 @@ class AllReduceLowering : public OpRewritePattern { LogicalResult matchAndRewrite(xla_cpu::AllReduceOp op, PatternRewriter& rewriter) const override { - if (!op.getOperandTypes().front().isa()) { + if (!mlir::isa(op.getOperandTypes().front())) { return failure(); } @@ -207,7 +208,7 @@ class CollectivePermuteLowering LogicalResult matchAndRewrite(xla_cpu::CollectivePermuteOp op, PatternRewriter& rewriter) const override { - if (!op.getOperandTypes().front().isa()) { + if (!mlir::isa(op.getOperandTypes().front())) { return failure(); } @@ -274,7 +275,7 @@ class RngBitGeneratorLowering LogicalResult matchAndRewrite(xla_cpu::RngBitGeneratorOp op, PatternRewriter& rewriter) const override { auto algorithm = - op.getRngAlgorithmAttr().cast().getValue(); + mlir::cast(op.getRngAlgorithmAttr()).getValue(); op->removeAttr("rng_algorithm"); CreateCallForDpsCollectiveOp(op.getOperation(), custom_calls_, @@ -308,7 +309,7 @@ class InfeedLowering : public OpRewritePattern { // For infeed with empty tuples, bufferizer does not run, thus the token is // left as the only operand. Remove it. - if (operands.back().getType().isa()) { + if (mlir::isa(operands.back().getType())) { assert(operands.size() == 1 && "Expect token only with empty tuples"); operands.pop_back(); } diff --git a/third_party/xla/xla/mlir/framework/transforms/outline_with_xla_framework.cc b/third_party/xla/xla/mlir/framework/transforms/outline_with_xla_framework.cc index 8efd1726ed05cf..b13299d69a6149 100644 --- a/third_party/xla/xla/mlir/framework/transforms/outline_with_xla_framework.cc +++ b/third_party/xla/xla/mlir/framework/transforms/outline_with_xla_framework.cc @@ -79,7 +79,7 @@ struct OutlineXLAFunc : public RewritePattern { if (!func) return failure(); if (func.getSymName() != "main") return failure(); if (llvm::any_of(op->getOperandTypes(), - [](Type t) { return !t.isa(); }) || + [](Type t) { return !mlir::isa(t); }) || op->getNumResults() != 0) return failure(); if (func->hasAttr("outlined")) return failure(); diff --git a/third_party/xla/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc b/third_party/xla/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc index dc609381152f67..62ae86bce849ed 100644 --- a/third_party/xla/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc +++ b/third_party/xla/xla/mlir/framework/transforms/xla_framework_to_llvm_pass.cc @@ -138,11 +138,11 @@ struct BarePtrFuncOpConversion : public ConvertOpToLLVMPattern { Value inner_index = rewriter.create( loc, typeConverter->convertType(rewriter.getIntegerType(32)), rewriter.getI32IntegerAttr(static_cast( - funcOp - ->getAttrOfType( - "xla_framework.result_inner_mapping") - .getValue()[current_index] - .cast() + mlir::cast( + funcOp + ->getAttrOfType( + "xla_framework.result_inner_mapping") + .getValue()[current_index]) .getValue() .getSExtValue()))); @@ -227,10 +227,11 @@ class LegalizeXLAFrameworkToLLVMPass target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); target.addIllegalDialect(); target.addDynamicallyLegalOp([](func::FuncOp op) { - if (llvm::any_of( - llvm::concat(op.getArgumentTypes(), - op.getResultTypes()), - [](Type type) { return type.isa(); })) + if (llvm::any_of(llvm::concat(op.getArgumentTypes(), + op.getResultTypes()), + [](Type type) { + return mlir::isa(type); + })) return false; return true; }); diff --git a/third_party/xla/xla/mlir/runtime/ir/BUILD b/third_party/xla/xla/mlir/runtime/ir/BUILD index f044bf6f4b5d94..bf05a86d524deb 100644 --- a/third_party/xla/xla/mlir/runtime/ir/BUILD +++ b/third_party/xla/xla/mlir/runtime/ir/BUILD @@ -107,5 +107,6 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", ], ) diff --git a/third_party/xla/xla/mlir/runtime/ir/rt_dialect.cc b/third_party/xla/xla/mlir/runtime/ir/rt_dialect.cc index 45b52b3f7f0b18..85af7d69cf46e7 100644 --- a/third_party/xla/xla/mlir/runtime/ir/rt_dialect.cc +++ b/third_party/xla/xla/mlir/runtime/ir/rt_dialect.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/DialectImplementation.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/mlir/runtime/ir/rt_interfaces.h" #include "xla/mlir/runtime/ir/rt_ops.h" #include "xla/runtime/constraints.h" @@ -35,7 +36,7 @@ namespace runtime { static bool IsRtConstraintAttr(mlir::Attribute attr) { // If attribute is not defined it means that there is no constraint if (!attr) return true; - auto str = attr.dyn_cast_or_null(); + auto str = mlir::dyn_cast_or_null(attr); absl::StatusOr constraint = ParseArgumentConstraint(str.getValue()); return constraint.ok(); @@ -83,7 +84,7 @@ mlir::LogicalResult RuntimeDialect::verifyOperationAttribute( // Custom call attribute can be defined only on a function declaration. if (attribute.getName() == "rt.custom_call") { - if (!(attribute.getValue().isa())) { + if (!(mlir::isa(attribute.getValue()))) { return op->emitOpError() << "requires " << attribute.getName() << " to only accept string value"; } @@ -111,7 +112,7 @@ mlir::LogicalResult RuntimeDialect::verifyOperationAttribute( // Trace annotation should implement an attribute interface. if (attribute.getName() == "rt.trace") { - if (!attribute.getValue().isa()) { + if (!mlir::isa(attribute.getValue())) { return op->emitOpError() << " requires " << attribute.getName() << " to be a trace annotation attribute"; } diff --git a/third_party/xla/xla/mlir/runtime/transforms/BUILD b/third_party/xla/xla/mlir/runtime/transforms/BUILD index 10fd62ad03d721..7be003f0f4ccec 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/BUILD +++ b/third_party/xla/xla/mlir/runtime/transforms/BUILD @@ -96,6 +96,7 @@ xla_cc_test( "//xla/mlir/runtime/ir:rt", "@llvm-project//mlir:IR", "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@local_tsl//tsl/platform:test", "@local_tsl//tsl/platform:test_main", diff --git a/third_party/xla/xla/mlir/runtime/transforms/calling_convention_test.cc b/third_party/xla/xla/mlir/runtime/transforms/calling_convention_test.cc index 6d326264d3be15..9c586951c312ae 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/calling_convention_test.cc +++ b/third_party/xla/xla/mlir/runtime/transforms/calling_convention_test.cc @@ -20,6 +20,7 @@ #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "xla/mlir/runtime/ir/rt_dialect.h" #include "tsl/platform/test.h" @@ -47,8 +48,8 @@ TEST(CallingConventionTest, DefaultCallingConvention) { auto converted = calling_convention(signature); EXPECT_EQ(converted.getNumInputs(), 2); - EXPECT_TRUE(converted.getInput(0).isa()); - EXPECT_TRUE(converted.getInput(1).isa()); + EXPECT_TRUE(mlir::isa(converted.getInput(0))); + EXPECT_TRUE(mlir::isa(converted.getInput(1))); } TEST(CallingConventionTest, DefaultCallingConventionWithTypeConverter) { @@ -68,8 +69,8 @@ TEST(CallingConventionTest, DefaultCallingConventionWithTypeConverter) { auto converted = calling_convention(signature); EXPECT_EQ(converted.getNumInputs(), 2); - EXPECT_TRUE(converted.getInput(0).isa()); - EXPECT_TRUE(converted.getInput(1).isa()); + EXPECT_TRUE(mlir::isa(converted.getInput(0))); + EXPECT_TRUE(mlir::isa(converted.getInput(1))); } TEST(CallingConventionTest, ResultsToOutsCallingConvention) { @@ -89,8 +90,8 @@ TEST(CallingConventionTest, ResultsToOutsCallingConvention) { auto converted = calling_convention(signature); EXPECT_EQ(converted.getNumInputs(), 2); - EXPECT_TRUE(converted.getInput(0).isa()); - EXPECT_TRUE(converted.getInput(1).isa()); + EXPECT_TRUE(mlir::isa(converted.getInput(0))); + EXPECT_TRUE(mlir::isa(converted.getInput(1))); EXPECT_EQ(converted.getNumResults(), 0); } diff --git a/third_party/xla/xla/mlir/runtime/transforms/custom_call_encoding.cc b/third_party/xla/xla/mlir/runtime/transforms/custom_call_encoding.cc index 7a97554822b193..2bba0cceb1e8dc 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/custom_call_encoding.cc +++ b/third_party/xla/xla/mlir/runtime/transforms/custom_call_encoding.cc @@ -193,7 +193,7 @@ static LLVM::GlobalOp EncodeDenseElementsAttribute( Globals &g, ImplicitLocOpBuilder &b, Attribute value, std::string_view symbol_base) { MLIRContext *ctx = b.getContext(); - DenseIntOrFPElementsAttr dense = value.cast(); + DenseIntOrFPElementsAttr dense = mlir::cast(value); Type ptr = LLVM::LLVMPointerType::get(ctx); @@ -362,7 +362,7 @@ static LLVM::GlobalOp EncodeDenseArrayAttribute(Globals &g, std::string_view symbol_base) { MLIRContext *ctx = b.getContext(); - DenseArrayAttr base_array = value.cast(); + DenseArrayAttr base_array = mlir::cast(value); int64_t size = base_array.size(); Type ptr = LLVM::LLVMPointerType::get(ctx); @@ -593,20 +593,20 @@ static bool IsAnyOf(unsigned width, ArrayRef supported) { } static bool IsSupportedScalarType(Type type) { - if (auto idx = type.dyn_cast()) return true; + if (auto idx = mlir::dyn_cast(type)) return true; - if (auto i = type.dyn_cast()) + if (auto i = mlir::dyn_cast(type)) return i.isUnsigned() ? IsAnyOf(i.getWidth(), {8, 16, 32, 64}) : IsAnyOf(i.getWidth(), {1, 8, 16, 32, 64}); - if (auto fp = type.dyn_cast()) + if (auto fp = mlir::dyn_cast(type)) return IsAnyOf(fp.getWidth(), {16, 32, 64}); return false; } static bool IsSupportedScalarAttribute(Attribute attr) { - if (auto typed = attr.dyn_cast()) + if (auto typed = mlir::dyn_cast(attr)) return IsSupportedScalarType(typed.getType()); return false; } @@ -637,7 +637,7 @@ static TypeID ScalarRuntimeTypeId(Type type) { static PrimitiveType ScalarPrimitiveType(Type type) { // Integer types. if (type.isInteger(1)) return PrimitiveType::PRED; - if (auto int_type = type.dyn_cast()) { + if (auto int_type = mlir::dyn_cast(type)) { unsigned int width = int_type.getWidth(); if (auto primitive_type = int_type.isUnsigned() @@ -662,7 +662,7 @@ static PrimitiveType ScalarPrimitiveType(Type type) { if (type.isBF16()) return PrimitiveType::BF16; // Complex types. - if (auto complex = type.dyn_cast()) { + if (auto complex = mlir::dyn_cast(type)) { return primitive_util::ComplexType( ScalarPrimitiveType(complex.getElementType())); } @@ -703,7 +703,7 @@ static TypeID AsyncValueRuntimeTypeId(Type elem_type) { return TypeID::get>>(); if (elem_type.isF64()) return TypeID::get>>(); - if (elem_type.isa()) + if (mlir::isa(elem_type)) return TypeID::get>>(); assert(false && "unsupported type id"); @@ -731,7 +731,7 @@ static TypeID DenseElementsRuntimeTypeId(Type elem_type) { LogicalResult StringAttrEncoding::Match(mlir::SymbolTable &, std::string_view name, Attribute attr) const { - return success(attr.isa()); + return success(mlir::isa(attr)); } FailureOr StringAttrEncoding::Encode(mlir::SymbolTable &, @@ -739,7 +739,7 @@ FailureOr StringAttrEncoding::Encode(mlir::SymbolTable &, ImplicitLocOpBuilder &b, std::string_view name, Attribute attr) const { - auto str = attr.cast(); + auto str = mlir::cast(attr); Encoded encoded; encoded.name = EncodeString(g, b, name, kAttrName); @@ -761,7 +761,7 @@ FailureOr ScalarAttrEncoding::Encode(mlir::SymbolTable &, ImplicitLocOpBuilder &b, std::string_view name, Attribute attr) const { - Type type = attr.cast().getType(); + Type type = mlir::cast(attr).getType(); Encoded encoded; encoded.name = EncodeString(g, b, name, kAttrName); @@ -776,7 +776,7 @@ FailureOr ScalarAttrEncoding::Encode(mlir::SymbolTable &, LogicalResult DenseElementsAttrEncoding::Match(mlir::SymbolTable &, std::string_view name, Attribute attr) const { - if (auto dense = attr.dyn_cast()) + if (auto dense = mlir::dyn_cast(attr)) return success(IsSupportedScalarType(dense.getElementType())); return failure(); } @@ -784,7 +784,7 @@ LogicalResult DenseElementsAttrEncoding::Match(mlir::SymbolTable &, FailureOr DenseElementsAttrEncoding::Encode( mlir::SymbolTable &, Globals &g, ImplicitLocOpBuilder &b, std::string_view name, Attribute attr) const { - auto dense = attr.cast(); + auto dense = mlir::cast(attr); Type elem_type = dense.getType().getElementType(); Encoded encoded; @@ -800,8 +800,8 @@ FailureOr DenseElementsAttrEncoding::Encode( LogicalResult ArrayAttrEncoding::Match(mlir::SymbolTable &, std::string_view name, Attribute attr) const { - if (auto array = attr.dyn_cast(); - array && !array.empty() && array[0].isa()) { + if (auto array = mlir::dyn_cast(attr); + array && !array.empty() && mlir::isa(array[0])) { return success(IsSupportedScalarAttribute(array[0])); } return failure(); @@ -812,12 +812,12 @@ FailureOr ArrayAttrEncoding::Encode(mlir::SymbolTable &, ImplicitLocOpBuilder &b, std::string_view name, Attribute attr) const { - ArrayAttr array = attr.dyn_cast(); - Type elem_type = array[0].cast().getType(); + ArrayAttr array = mlir::dyn_cast(attr); + Type elem_type = mlir::cast(array[0]).getType(); // We only support array attributes with elements of same type. bool all_of_same_type = llvm::all_of(array, [&](Attribute attr) { - auto typed = attr.dyn_cast(); + auto typed = mlir::dyn_cast(attr); return typed && typed.getType() == elem_type; }); if (!all_of_same_type) return failure(); @@ -835,7 +835,7 @@ FailureOr ArrayAttrEncoding::Encode(mlir::SymbolTable &, LogicalResult DenseArrayAttrEncoding::Match(mlir::SymbolTable &, std::string_view name, Attribute attr) const { - if (auto array = attr.dyn_cast()) { + if (auto array = mlir::dyn_cast(attr)) { return success(); } return failure(); @@ -846,7 +846,7 @@ FailureOr DenseArrayAttrEncoding::Encode(mlir::SymbolTable &, ImplicitLocOpBuilder &b, std::string_view name, Attribute attr) const { - Type elem_type = attr.cast().getElementType(); + Type elem_type = mlir::cast(attr).getElementType(); Encoded encoded; encoded.name = EncodeString(g, b, name, kAttrName); @@ -861,7 +861,7 @@ FailureOr DenseArrayAttrEncoding::Encode(mlir::SymbolTable &, LogicalResult EmptyArrayAttrEncoding::Match(mlir::SymbolTable &, std::string_view name, Attribute attr) const { - if (auto array = attr.dyn_cast(); array && array.empty()) { + if (auto array = mlir::dyn_cast(attr); array && array.empty()) { return success(); } return failure(); @@ -885,7 +885,7 @@ FailureOr EmptyArrayAttrEncoding::Encode(mlir::SymbolTable &, LogicalResult SymbolRefAttrEncoding::Match(mlir::SymbolTable &sym_table, std::string_view name, Attribute attr) const { - if (auto ref = attr.dyn_cast()) { + if (auto ref = mlir::dyn_cast(attr)) { auto exported = sym_table.lookup(ref.getValue()); return success(exported && exported->hasAttr(kExportedAttrName)); } @@ -896,7 +896,7 @@ FailureOr SymbolRefAttrEncoding::Encode( mlir::SymbolTable &sym_table, Globals &g, ImplicitLocOpBuilder &b, std::string_view name, Attribute attr) const { // Get the exported function ordinal. - auto ref = attr.cast(); + auto ref = mlir::cast(attr); auto func = sym_table.lookup(ref.getValue()); auto ordinal = func->getAttrOfType(kExportedAttrName); assert(ordinal.getType().isSignlessInteger(32)); @@ -917,7 +917,7 @@ FailureOr SymbolRefAttrEncoding::Encode( LogicalResult UnitAttrEncoding::Match(mlir::SymbolTable &, std::string_view, Attribute attr) const { - return success(attr.isa()); + return success(mlir::isa(attr)); } FailureOr UnitAttrEncoding::Encode(mlir::SymbolTable &, Globals &g, @@ -937,7 +937,7 @@ FailureOr UnitAttrEncoding::Encode(mlir::SymbolTable &, Globals &g, LogicalResult DictionaryAttrEncoding::Match(mlir::SymbolTable &, std::string_view, Attribute attr) const { - return success(attr.isa()); + return success(mlir::isa(attr)); } FailureOr DictionaryAttrEncoding::Encode( @@ -1068,7 +1068,7 @@ FailureOr ScalarArgEncoding::Encode(Globals &g, Allocas &a, //===----------------------------------------------------------------------===// static bool IsOpaqueValue(Value value) { - return value.getType().isa(); + return mlir::isa(value.getType()); } OpaqueArgEncoding::OpaqueArgEncoding() @@ -1079,7 +1079,7 @@ OpaqueArgEncoding::OpaqueArgEncoding(std::function match, : match_(std::move(match)), type_id_(type_id) {} LogicalResult OpaqueArgEncoding::Match(Value value, Value converted) const { - if (auto ptr = converted.getType().dyn_cast()) + if (auto ptr = mlir::dyn_cast(converted.getType())) return success(match_(value)); return failure(); } @@ -1190,14 +1190,14 @@ static Value EncodeMemRef(ImplicitLocOpBuilder &b, MemRefType memref_ty, } LogicalResult MemrefArgEncoding::Match(Value value, Value converted) const { - return success(value.getType().isa()); + return success(mlir::isa(value.getType())); } FailureOr MemrefArgEncoding::Encode(Globals &g, Allocas &a, ImplicitLocOpBuilder &b, Value value, Value converted) const { - auto memref_type = value.getType().cast(); + auto memref_type = mlir::cast(value.getType()); // If memref has non-identity layout we use `StridedMemrefView` to // distinguish it from the default row-major memref. @@ -1242,7 +1242,7 @@ FailureOr ScalarRetEncoding::Decode(ImplicitLocOpBuilder &b, Type type, //===----------------------------------------------------------------------===// -static bool IsOpaqueType(Type type) { return type.isa(); } +static bool IsOpaqueType(Type type) { return mlir::isa(type); } OpaqueRetEncoding::OpaqueRetEncoding() : OpaqueRetEncoding(IsOpaqueType, TypeID::get>()) {} @@ -1252,7 +1252,7 @@ OpaqueRetEncoding::OpaqueRetEncoding(std::function match, : match_(std::move(match)), type_id_(type_id) {} LogicalResult OpaqueRetEncoding::Match(Type type, Type converted) const { - if (auto ptr = converted.dyn_cast()) + if (auto ptr = mlir::dyn_cast(converted)) return success(match_(type)); return failure(); } @@ -1280,15 +1280,15 @@ FailureOr OpaqueRetEncoding::Decode(ImplicitLocOpBuilder &b, Type type, //===----------------------------------------------------------------------===// LogicalResult MemrefRetEncoding::Match(Type type, Type converted) const { - return success(type.isa() && - converted.isa()); + return success(mlir::isa(type) && + mlir::isa(converted)); } FailureOr MemrefRetEncoding::Encode(Globals &g, Allocas &a, ImplicitLocOpBuilder &b, Type type, Type converted) const { - auto memref_ty = type.cast(); + auto memref_ty = mlir::cast(type); // We assume custom calls can only return row-major memrefs, may need to add // PermutedMemref support in the future. @@ -1353,9 +1353,9 @@ FailureOr MemrefRetEncoding::Decode(ImplicitLocOpBuilder &b, Type type, //===----------------------------------------------------------------------===// LogicalResult AsyncValueRetEncoding::Match(Type type, Type converted) const { - return success( - (type.isa() || type.isa()) && - converted.isa()); + return success((mlir::isa(type) || + mlir::isa(type)) && + mlir::isa(converted)); } FailureOr AsyncValueRetEncoding::Encode(Globals &g, Allocas &a, @@ -1365,9 +1365,9 @@ FailureOr AsyncValueRetEncoding::Encode(Globals &g, Allocas &a, Type ptr = LLVM::LLVMPointerType::get(b.getContext()); Value one = b.create(b.getI32IntegerAttr(1)); - auto type_id = type.isa() + auto type_id = mlir::isa(type) ? AsyncValueRuntimeTypeId( - type.cast().getValueType()) + mlir::cast(type).getValueType()) : TypeID::get>>(); Encoded encoded; @@ -1375,8 +1375,8 @@ FailureOr AsyncValueRetEncoding::Encode(Globals &g, Allocas &a, // for !async.value encoding its dtype, rank and dims with // EncodedMemRef struct; we use its data field to store async value ptr. - if (auto value_ty = type.dyn_cast()) { - if (auto memref_ty = value_ty.getValueType().dyn_cast()) { + if (auto value_ty = mlir::dyn_cast(type)) { + if (auto memref_ty = mlir::dyn_cast(value_ty.getValueType())) { encoded.value = PackValue(b, a, EncodeMemRef(b, memref_ty, /*descriptor=*/nullptr)); return encoded; @@ -1391,8 +1391,8 @@ FailureOr AsyncValueRetEncoding::Encode(Globals &g, Allocas &a, FailureOr AsyncValueRetEncoding::Decode(ImplicitLocOpBuilder &b, Type type, Type converted, LLVM::AllocaOp alloca) const { - if (auto value_ty = type.dyn_cast()) { - if (auto memref_ty = value_ty.getValueType().dyn_cast()) { + if (auto value_ty = mlir::dyn_cast(type)) { + if (auto memref_ty = mlir::dyn_cast(value_ty.getValueType())) { // TODO(ezhulenev): Add support for returning dynamically shaped memref. if (!memref_ty.hasStaticShape()) return failure(); diff --git a/third_party/xla/xla/mlir/runtime/transforms/custom_call_encoding.h b/third_party/xla/xla/mlir/runtime/transforms/custom_call_encoding.h index ea9fac26865ed5..7f1674d7542cf3 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/custom_call_encoding.h +++ b/third_party/xla/xla/mlir/runtime/transforms/custom_call_encoding.h @@ -41,6 +41,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "xla/runtime/custom_call.h" @@ -619,7 +620,7 @@ struct AggregateAttrEncoding : public CustomCallAttrEncoding { mlir::LogicalResult Match(mlir::SymbolTable &, std::string_view, mlir::Attribute attr) const final { - return mlir::success(attr.isa()); + return mlir::success(mlir::isa(attr)); } mlir::FailureOr Encode(mlir::SymbolTable &sym_table, Globals &g, @@ -629,7 +630,7 @@ struct AggregateAttrEncoding : public CustomCallAttrEncoding { // Extract aggregate attributes from the user-defined attributes. llvm::SmallVector attrs; for (auto &bind : attrdef.bindings) - attrs.emplace_back(bind(attr.cast(), b)); + attrs.emplace_back(bind(mlir::cast(attr), b)); // Encode extracted attributes as an aggregate. auto type_id = TypeID::get>(); diff --git a/third_party/xla/xla/mlir/runtime/transforms/rt_to_llvm.cc b/third_party/xla/xla/mlir/runtime/transforms/rt_to_llvm.cc index 8af1cfda92803d..2f285e97698630 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/rt_to_llvm.cc +++ b/third_party/xla/xla/mlir/runtime/transforms/rt_to_llvm.cc @@ -706,7 +706,7 @@ void ConvertRuntimeToLLVMPass::runOnOperation() { // Convert all async types to opaque pointers. llvm_converter.addConversion([&](Type type) -> std::optional { - if (type.isa()) + if (mlir::isa(type)) return LLVM::LLVMPointerType::get(ctx); return std::nullopt; }); diff --git a/third_party/xla/xla/mlir/runtime/transforms/specialization.cc b/third_party/xla/xla/mlir/runtime/transforms/specialization.cc index 3109917e6fb2e1..fc96a2c1584a4e 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/specialization.cc +++ b/third_party/xla/xla/mlir/runtime/transforms/specialization.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/mlir/runtime/transforms/type_converter.h" #include "xla/mlir/runtime/utils/constraints.h" #include "xla/runtime/arguments.h" @@ -102,19 +103,19 @@ static StatusOr SpecializeOperandType( // Replace all symbolic dimensions with dynamic dimension. auto shape = SymbolicShapesResolver::Normalize(symbolic_shape); - if (auto memref = type.dyn_cast()) { + if (auto memref = mlir::dyn_cast(type)) { if (auto st = VerifyMemrefOperand(index, memref, *memref_arg); !st.ok()) return st; return mlir::MemRefType::get(shape, memref.getElementType()); } - if (auto tensor = type.dyn_cast()) { + if (auto tensor = mlir::dyn_cast(type)) { if (auto st = VerifyMemrefOperand(index, tensor, *memref_arg); !st.ok()) return st; return mlir::RankedTensorType::get(shape, tensor.getElementType()); } - if (auto tensor = type.dyn_cast()) { + if (auto tensor = mlir::dyn_cast(type)) { if (auto st = VerifyMemrefOperand(index, tensor, *memref_arg); !st.ok()) return st; return mlir::RankedTensorType::get(shape, tensor.getElementType()); @@ -236,7 +237,7 @@ Status SpecializeFunction(mlir::FunctionOpInterface func, // We only support sinking of Tensor arguments into the function body. mlir::Type input = llvm::cast(func.getFunctionType()).getInput(i); - mlir::TensorType tensor = input.dyn_cast(); + mlir::TensorType tensor = mlir::dyn_cast(input); if (!tensor || !SupportsValueSpecialization(tensor)) { return InvalidArgumentError(StrCat( "non-sinkable operand was marked for sinking: ", debugString(input))); diff --git a/third_party/xla/xla/mlir/runtime/transforms/type_converter.cc b/third_party/xla/xla/mlir/runtime/transforms/type_converter.cc index 0c0303bdf864c2..1d30f300f33397 100644 --- a/third_party/xla/xla/mlir/runtime/transforms/type_converter.cc +++ b/third_party/xla/xla/mlir/runtime/transforms/type_converter.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/mlir/runtime/ir/rt_dialect.h" #include "xla/primitive_util.h" #include "xla/runtime/types.h" @@ -47,60 +48,60 @@ using absl::StrFormat; static std::unique_ptr ConvertCanonicalType( mlir::Type type, const TypeConverter& convert) { // ExecutionContextType -> ExecutionContextOperandType (both in xla::runtime). - if (auto ctx = type.dyn_cast()) + if (auto ctx = mlir::dyn_cast(type)) return std::make_unique(); // OpaqueType -> OpaqueOperandType (both in xla::runtime). - if (auto ctx = type.dyn_cast()) + if (auto ctx = mlir::dyn_cast(type)) return std::make_unique(); // mlir::async::TokenType -> xla::runtime::AsyncTokenType - if (type.isa()) + if (mlir::isa(type)) return std::make_unique(); // mlir::async::ValueType -> xla::runtime::AsyncValueType - if (auto value = type.dyn_cast()) { + if (auto value = mlir::dyn_cast(type)) { if (auto value_type = convert.Convert(value.getValueType()); value_type.ok()) return std::make_unique(std::move(*value_type)); } // mlir::{IndexType, IntegerType, FloatType} -> xla::runtime::ScalarType - if (type.isa()) { + if (mlir::isa(type)) { if (auto dtype = TypeConverter::ConvertElementType(type); dtype.ok()) return std::make_unique(*dtype); } // mlir::RankedTensorType -> xla::runtime::RankedTensorType - if (auto tensor = type.dyn_cast()) { + if (auto tensor = mlir::dyn_cast(type)) { if (auto dtype = TypeConverter::ConvertElementType(tensor.getElementType()); dtype.ok()) return std::make_unique(tensor.getShape(), *dtype); } // mlir::UnrankedTensorType -> xla::runtime::UnrankedTensorType - if (auto tensor = type.dyn_cast()) { + if (auto tensor = mlir::dyn_cast(type)) { if (auto dtype = TypeConverter::ConvertElementType(tensor.getElementType()); dtype.ok()) return std::make_unique(*dtype); } // mlir::MemrefType -> xla::runtime::MemrefType - if (auto memref = type.dyn_cast()) { + if (auto memref = mlir::dyn_cast(type)) { if (auto dtype = TypeConverter::ConvertElementType(memref.getElementType()); dtype.ok()) return std::make_unique(memref.getShape(), *dtype); } // mlir::UnrankedMemrefType -> xla::runtime::UnrankedMemrefType - if (auto memref = type.dyn_cast()) { + if (auto memref = mlir::dyn_cast(type)) { if (auto dtype = TypeConverter::ConvertElementType(memref.getElementType()); dtype.ok()) return std::make_unique(*dtype); } // mlir::TupleType -> xla::runtime::TupleType - if (auto tuple = type.dyn_cast()) { + if (auto tuple = mlir::dyn_cast(type)) { llvm::SmallVector> conv_elems; llvm::transform(tuple, std::back_inserter(conv_elems), [&convert](mlir::Type type) { @@ -126,7 +127,7 @@ static std::unique_ptr ConvertCanonicalType( if (type.isF32()) return PrimitiveType::F32; if (type.isF64()) return PrimitiveType::F64; if (type.isInteger(1)) return PrimitiveType::PRED; - if (auto int_type = type.dyn_cast()) { + if (auto int_type = mlir::dyn_cast(type)) { unsigned int width = int_type.getWidth(); if (auto primitive_type = int_type.isUnsigned() @@ -136,7 +137,7 @@ static std::unique_ptr ConvertCanonicalType( return primitive_type; } } - if (auto complex_type = type.dyn_cast()) { + if (auto complex_type = mlir::dyn_cast(type)) { auto element_type = complex_type.getElementType(); TF_ASSIGN_OR_RETURN(auto element_primitive_type, ConvertElementType(element_type)); diff --git a/third_party/xla/xla/mlir/runtime/utils/constraints.cc b/third_party/xla/xla/mlir/runtime/utils/constraints.cc index da0023ea62b4de..907225b813adaf 100644 --- a/third_party/xla/xla/mlir/runtime/utils/constraints.cc +++ b/third_party/xla/xla/mlir/runtime/utils/constraints.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/runtime/constraints.h" namespace xla { @@ -52,7 +53,7 @@ StatusOr> GetArgumentsConstraints( if (!attr) return ArgumentConstraint::kResolved; // Otherwise try to parse constraint from the string attribute. - auto str = attr.dyn_cast_or_null(); + auto str = mlir::dyn_cast_or_null(attr); if (!str) return InvalidArgumentError( StrCat("unexpected ", kArgumentConstraintAttrName, " attribute")); @@ -81,7 +82,7 @@ StatusOr ResolveArgumentConstraint( if (constraint == ArgumentConstraint::kResolved) return constraint; // Operand must be a shaped type: memref or tensor. - auto shaped = type.dyn_cast(); + auto shaped = mlir::dyn_cast(type); if (!shaped) return InvalidArgumentError( StrCat("unsupported operand type: ", debugString(type))); diff --git a/third_party/xla/xla/mlir/runtime/utils/constraints.h b/third_party/xla/xla/mlir/runtime/utils/constraints.h index f88c3f267e0d38..29ea88c5145791 100644 --- a/third_party/xla/xla/mlir/runtime/utils/constraints.h +++ b/third_party/xla/xla/mlir/runtime/utils/constraints.h @@ -21,6 +21,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/runtime/constraints.h" namespace xla { @@ -39,7 +40,7 @@ absl::StatusOr ResolveArgumentConstraint( inline bool SupportsValueSpecialization(mlir::Type type) { // TODO(ezhulenev): Add support for sinking `memref` values once the value // specialization will support it. - mlir::TensorType tensor = type.dyn_cast(); + mlir::TensorType tensor = mlir::dyn_cast(type); return tensor && (tensor.getRank() == 0 || tensor.getRank() == 1) && (tensor.getElementType().isInteger(32) || tensor.getElementType().isInteger(64)); diff --git a/third_party/xla/xla/mlir/utils/BUILD b/third_party/xla/xla/mlir/utils/BUILD index 9652b2fbf3dbb0..4267e525a8aaf1 100644 --- a/third_party/xla/xla/mlir/utils/BUILD +++ b/third_party/xla/xla/mlir/utils/BUILD @@ -51,6 +51,7 @@ cc_library( "//xla:xla_data_proto_cc", "@com_google_absl//absl/status:statusor", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_tsl//tsl/platform:statusor", ], ) diff --git a/third_party/xla/xla/mlir/utils/type_util.cc b/third_party/xla/xla/mlir/utils/type_util.cc index 873d9609083270..29072d09f79174 100644 --- a/third_party/xla/xla/mlir/utils/type_util.cc +++ b/third_party/xla/xla/mlir/utils/type_util.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/primitive_util.h" #include "xla/util.h" #include "xla/xla_data.pb.h" @@ -91,11 +92,11 @@ xla::PrimitiveType ConvertMlirTypeToPrimitiveType(mlir::Type type) { return xla::PrimitiveType::F32; } else if (type.isF64()) { return xla::PrimitiveType::F64; - } else if (auto complex_type = type.dyn_cast()) { + } else if (auto complex_type = mlir::dyn_cast(type)) { mlir::Type element_ty = complex_type.getElementType(); return xla::primitive_util::ComplexType( ConvertMlirTypeToPrimitiveType(element_ty)); - } else if (auto integer_type = type.dyn_cast()) { + } else if (auto integer_type = mlir::dyn_cast(type)) { bool is_unsigned = integer_type.isUnsigned(); if (integer_type.getWidth() == 1) { return xla::PrimitiveType::PRED; diff --git a/third_party/xla/xla/mlir/xla_cpu/ir/BUILD b/third_party/xla/xla/mlir/xla_cpu/ir/BUILD index b5933c23b96da4..776416c54ed095 100644 --- a/third_party/xla/xla/mlir/xla_cpu/ir/BUILD +++ b/third_party/xla/xla/mlir/xla_cpu/ir/BUILD @@ -107,5 +107,6 @@ cc_library( "@llvm-project//mlir:BufferizationInterfaces", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Support", ], ) diff --git a/third_party/xla/xla/mlir/xla_cpu/ir/xla_cpu.cc b/third_party/xla/xla/mlir/xla_cpu/ir/xla_cpu.cc index 3e977dc7f2e7cc..49e13054573b81 100644 --- a/third_party/xla/xla/mlir/xla_cpu/ir/xla_cpu.cc +++ b/third_party/xla/xla/mlir/xla_cpu/ir/xla_cpu.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "xla/mlir/xla_cpu/ir/xla_cpu_dialect.cc.inc" #include "xla/mlir/xla_cpu/ir/xla_cpu_enums.cc.inc" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" @@ -48,13 +49,13 @@ template LogicalResult BufferizeOp(Op op, RewriterBase &rewriter, const bufferization::BufferizationOptions &options, int64_t num_inputs) { - if (op.getOperands().front().getType().template isa()) { + if (mlir::isa(op.getOperands().front().getType())) { return success(); } SmallVector new_operands; std::optional token = std::nullopt; for (auto operand : op.getOperands()) { - if (operand.getType().template isa()) { + if (mlir::isa(operand.getType())) { assert(operand == op.getOperands().back() && "Expect token type only for last operand"); assert(!token && "Expect at most only one token-typed operand"); @@ -156,8 +157,8 @@ LogicalResult AddDependencyOp::bufferize( } LogicalResult MemRefElementCastOp::verify() { - auto src_memref_ty = getSrc().getType().cast(); - auto dst_memref_ty = getDst().getType().cast(); + auto src_memref_ty = mlir::cast(getSrc().getType()); + auto dst_memref_ty = mlir::cast(getDst().getType()); if (src_memref_ty.getShape() != dst_memref_ty.getShape()) { return emitOpError() << "expects matching shapes"; } diff --git a/third_party/xla/xla/mlir_hlo/BUILD b/third_party/xla/xla/mlir_hlo/BUILD index 761446ff70f3e1..14dd88206dd29b 100644 --- a/third_party/xla/xla/mlir_hlo/BUILD +++ b/third_party/xla/xla/mlir_hlo/BUILD @@ -262,6 +262,7 @@ cc_library( deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -404,6 +405,7 @@ cc_library( deps = [ "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", ], ) @@ -492,6 +494,7 @@ cc_library( "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:SideEffectInterfaces", + "@llvm-project//mlir:Support", "@llvm-project//mlir:ViewLikeInterface", "@stablehlo//:stablehlo_type_inference", ], @@ -641,6 +644,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FuncTransforms", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", @@ -670,6 +674,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:MathDialect", "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:Support", ], ) @@ -764,6 +769,7 @@ cc_library( "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:MemRefDialect", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", ], ) @@ -829,6 +835,7 @@ cc_library( deps = [ ":mlir_hlo", "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", ], ) @@ -883,6 +890,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", @@ -1209,6 +1217,7 @@ cc_library( ":all_passes", ":mlir_hlo", "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:Support", ], ) @@ -1230,6 +1239,7 @@ cc_library( ":all_passes", ":mlir_hlo", "@llvm-project//mlir:CAPIIRObjects", + "@llvm-project//mlir:Support", ], alwayslink = True, ) diff --git a/third_party/xla/xla/mlir_hlo/bindings/c/Attributes.cc b/third_party/xla/xla/mlir_hlo/bindings/c/Attributes.cc index d5183742094a3e..e5c3f4500efbb4 100644 --- a/third_party/xla/xla/mlir_hlo/bindings/c/Attributes.cc +++ b/third_party/xla/xla/mlir_hlo/bindings/c/Attributes.cc @@ -17,6 +17,7 @@ limitations under the License. #include "mhlo/IR/hlo_ops.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" +#include "mlir/Support/LLVM.h" // // ScatterDimensionNumbersAttr. @@ -35,57 +36,50 @@ MlirAttribute mlirMhloScatterDimensionNumbersGet( } bool mlirMhloAttributeIsAScatterDimensionNumbers(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } intptr_t mlirMhloScatterDimensionNumbersGetUpdateWindowDimsSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getUpdateWindowDims() .size(); } int64_t mlirMhloScatterDimensionNumbersGetUpdateWindowDimsElem( MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getUpdateWindowDims()[pos]; } intptr_t mlirMhloScatterDimensionNumbersGetInsertedWindowDimsSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getInsertedWindowDims() .size(); } int64_t mlirMhloScatterDimensionNumbersGetInsertedWindowDimsElem( MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getInsertedWindowDims()[pos]; } intptr_t mlirMhloScatterDimensionNumbersGetScatteredDimsToOperandDimsSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getScatterDimsToOperandDims() .size(); } int64_t mlirMhloScatterDimensionNumbersGetScatteredDimsToOperandDimsElem( MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getScatterDimsToOperandDims()[pos]; } int64_t mlirMhloDimensionNumbersGetIndexVectorDim(MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getIndexVectorDim(); } @@ -105,56 +99,49 @@ MlirAttribute mlirMhloGatherDimensionNumbersGet( } bool mlirMhloAttributeIsAGatherDimensionNumbers(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } intptr_t mlirMhloGatherDimensionNumbersGetOffsetDimsSize(MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getOffsetDims() .size(); } int64_t mlirMhloGatherDimensionNumbersGetOffsetDimsElem(MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getOffsetDims()[pos]; } intptr_t mlirMhloGatherDimensionNumbersGetCollapsedSliceDimsSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getCollapsedSliceDims() .size(); } int64_t mlirMhloGatherDimensionNumbersGetCollapsedSliceDimsElem( MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getCollapsedSliceDims()[pos]; } intptr_t mlirMhloGatherDimensionNumbersGetStartIndexMapSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getStartIndexMap() .size(); } int64_t mlirMhloGatherDimensionNumbersGetStartIndexMapElem(MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getStartIndexMap()[pos]; } int64_t mlirMhloGatherDimensionNumbersGetIndexVectorDim(MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getIndexVectorDim(); } @@ -177,66 +164,58 @@ MlirAttribute mlirMhloDotDimensionNumbersGet( } bool mlirMhloAttributeIsADotDimensionNumbers(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } intptr_t mlirMhloDotDimensionNumbersGetLhsBatchingDimensionsSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getLhsBatchingDimensions() .size(); } int64_t mlirMhloDotDimensionNumbersGetLhsBatchingDimensionsElem( MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getLhsBatchingDimensions()[pos]; } intptr_t mlirMhloDotDimensionNumbersGetRhsBatchingDimensionsSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getRhsBatchingDimensions() .size(); } int64_t mlirMhloDotDimensionNumbersGetRhsBatchingDimensionsElem( MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getRhsBatchingDimensions()[pos]; } intptr_t mlirMhloDotDimensionNumbersGetLhsContractingDimensionsSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getLhsContractingDimensions() .size(); } int64_t mlirMhloDotDimensionNumbersGetLhsContractingDimensionsElem( MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getLhsContractingDimensions()[pos]; } intptr_t mlirMhloDotDimensionNumbersGetRhsContractingDimensionsSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getRhsContractingDimensions() .size(); } int64_t mlirMhloDotDimensionNumbersGetRhsContractingDimensionsElem( MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getRhsContractingDimensions()[pos]; } @@ -261,92 +240,80 @@ MlirAttribute mlirMhloConvDimensionNumbersGet( } bool mlirMhloAttributeIsAConvDimensionNumbers(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } int64_t mlirMhloConvDimensionNumbersGetInputBatchDimension(MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getInputBatchDimension(); } int64_t mlirMhloConvDimensionNumbersGetInputFeatureDimension( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getInputFeatureDimension(); } intptr_t mlirMhloConvDimensionNumbersGetInputSpatialDimensionsSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getInputSpatialDimensions() .size(); } int64_t mlirMhloConvDimensionNumbersGetInputSpatialDimensionsElem( MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getInputSpatialDimensions()[pos]; } int64_t mlirMhloConvDimensionNumbersGetKernelInputFeatureDimension( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getKernelInputFeatureDimension(); } int64_t mlirMhloConvDimensionNumbersGetKernelOutputFeatureDimension( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getKernelOutputFeatureDimension(); } intptr_t mlirMhloConvDimensionNumbersGetKernelSpatialDimensionsSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getKernelSpatialDimensions() .size(); } int64_t mlirMhloConvDimensionNumbersGetKernelSpatialDimensionsElem( MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getKernelSpatialDimensions()[pos]; } int64_t mlirMhloConvDimensionNumbersGetOutputBatchDimension( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getOutputBatchDimension(); } int64_t mlirMhloConvDimensionNumbersGetOutputFeatureDimension( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getOutputFeatureDimension(); } intptr_t mlirMhloConvDimensionNumbersGetOutputSpatialDimensionsSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getOutputSpatialDimensions() .size(); } int64_t mlirMhloConvDimensionNumbersGetOutputSpatialDimensionsElem( MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getOutputSpatialDimensions()[pos]; } @@ -364,42 +331,37 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirMhloOutputOperandAliasGet( } bool mlirMhloAttributeIsAOutputOperandAlias(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } intptr_t mlirMhloOutputOperandAliasGetOutputTupleIndicesSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getOutputTupleIndices() .size(); } int64_t mlirMhloOutputOperandAliasGetOutputTupleIndicesElem(MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getOutputTupleIndices()[pos]; } int64_t mlirMhloOutputOperandAliasGetOperandIndex(MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getOperandIndex(); } intptr_t mlirMhloOutputOperandAliasGetOperandTupleIndicesSize( MlirAttribute attr) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getOperandTupleIndices() .size(); } int64_t mlirMhloOutputOperandAliasGetOperandTupleIndicesElem(MlirAttribute attr, intptr_t pos) { - return unwrap(attr) - .cast() + return mlir::cast(unwrap(attr)) .getOperandTupleIndices()[pos]; } @@ -416,12 +378,13 @@ MlirAttribute mlirMhloComparisonDirectionAttrGet(MlirContext ctx, } bool mlirMhloAttributeIsAComparisonDirectionAttr(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } MlirStringRef mlirMhloComparisonDirectionAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyComparisonDirection( - unwrap(attr).cast().getValue())); + mlir::cast(unwrap(attr)) + .getValue())); } // @@ -438,12 +401,12 @@ MlirAttribute mlirMhloComparisonTypeAttrGet(MlirContext ctx, } bool mlirMhloAttributeIsAComparisonTypeAttr(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } MlirStringRef mlirMhloComparisonTypeAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyComparisonType( - unwrap(attr).cast().getValue())); + mlir::cast(unwrap(attr)).getValue())); } // @@ -458,12 +421,12 @@ MlirAttribute mlirMhloDomainKindAttrGet(MlirContext ctx, MlirStringRef value) { } bool mlirMhloAttributeIsADomainKindAttr(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } MlirStringRef mlirMhloDomainKindAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyDomainKind( - unwrap(attr).cast().getValue())); + mlir::cast(unwrap(attr)).getValue())); } // @@ -478,12 +441,12 @@ MlirAttribute mlirMhloPrecisionAttrGet(MlirContext ctx, MlirStringRef value) { } bool mlirMhloAttributeIsAPrecisionAttr(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } MlirStringRef mlirMhloPrecisionAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyPrecision( - unwrap(attr).cast().getValue())); + mlir::cast(unwrap(attr)).getValue())); } // @@ -498,12 +461,12 @@ MlirAttribute mlirMhloFftTypeAttrGet(MlirContext ctx, MlirStringRef value) { } bool mlirMhloAttributeIsAFftTypeAttr(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } MlirStringRef mlirMhloFftTypeAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyFftType( - unwrap(attr).cast().getValue())); + mlir::cast(unwrap(attr)).getValue())); } // @@ -520,12 +483,12 @@ MlirAttribute mlirMhloDequantizeModeAttrGet(MlirContext ctx, } bool mlirMhloAttributeIsADequantizeModeAttr(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } MlirStringRef mlirMhloDequantizeModeAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyDequantizeMode( - unwrap(attr).cast().getValue())); + mlir::cast(unwrap(attr)).getValue())); } // @@ -540,12 +503,12 @@ MlirAttribute mlirMhloTransposeAttrGet(MlirContext ctx, MlirStringRef value) { } bool mlirMhloAttributeIsATransposeAttr(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } MlirStringRef mlirMhloTransposeAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyTranspose( - unwrap(attr).cast().getValue())); + mlir::cast(unwrap(attr)).getValue())); } // @@ -560,12 +523,12 @@ MlirAttribute mlirMhloFusionKindAttrGet(MlirContext ctx, MlirStringRef value) { } bool mlirMhloAttributeIsAFusionKindAttr(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } MlirStringRef mlirMhloFusionKindAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyFusionKind( - unwrap(attr).cast().getValue())); + mlir::cast(unwrap(attr)).getValue())); } // @@ -582,12 +545,12 @@ MlirAttribute mlirMhloRngDistributionAttrGet(MlirContext ctx, } bool mlirMhloAttributeIsARngDistributionAttr(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } MlirStringRef mlirMhloRngDistributionAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyRngDistribution( - unwrap(attr).cast().getValue())); + mlir::cast(unwrap(attr)).getValue())); } // @@ -604,12 +567,12 @@ MlirAttribute mlirMhloRngAlgorithmAttrGet(MlirContext ctx, } bool mlirMhloAttributeIsARngAlgorithmAttr(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } MlirStringRef mlirMhloRngAlgorithmAttrGetValue(MlirAttribute attr) { return wrap(mlir::mhlo::stringifyRngAlgorithm( - unwrap(attr).cast().getValue())); + mlir::cast(unwrap(attr)).getValue())); } // @@ -622,15 +585,15 @@ MlirAttribute mlirMhloChannelHandleGet(MlirContext ctx, int64_t handle, } bool mlirMhloAttributeIsChannelHandle(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } int64_t mlirMhloChannelHandleGetHandle(MlirAttribute attr) { - return unwrap(attr).cast().getHandle(); + return mlir::cast(unwrap(attr)).getHandle(); } int64_t mlirMhloChannelHandleGetType(MlirAttribute attr) { - return unwrap(attr).cast().getType(); + return mlir::cast(unwrap(attr)).getType(); } // @@ -644,15 +607,18 @@ MlirAttribute mlirMhloTypeExtensionsGet(MlirContext ctx, intptr_t nBounds, } bool mlirMhloAttributeIsTypeExtensions(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } intptr_t mlirMhloTypeExtensionsGetBoundsSize(MlirAttribute attr) { - return unwrap(attr).cast().getBounds().size(); + return mlir::cast(unwrap(attr)) + .getBounds() + .size(); } int64_t mlirMhloTypeExtensionsGetBoundsElem(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getBounds()[pos]; + return mlir::cast(unwrap(attr)) + .getBounds()[pos]; } // @@ -666,17 +632,18 @@ MlirAttribute mlirMhloSparsityDescriptorGet(MlirContext ctx, int64_t dimension, } bool mlirMhloAttributeIsASparsityDescriptor(MlirAttribute attr) { - return unwrap(attr).isa(); + return mlir::isa(unwrap(attr)); } int64_t mlirMhloSparsityDescriptorGetDimension(MlirAttribute attr) { - return unwrap(attr).cast().getDimension(); + return mlir::cast(unwrap(attr)) + .getDimension(); } int64_t mlirMhloSparsityDescriptorGetN(MlirAttribute attr) { - return unwrap(attr).cast().getN(); + return mlir::cast(unwrap(attr)).getN(); } int64_t mlirMhloSparsityDescriptorGetM(MlirAttribute attr) { - return unwrap(attr).cast().getM(); + return mlir::cast(unwrap(attr)).getM(); } diff --git a/third_party/xla/xla/mlir_hlo/bindings/c/Types.cc b/third_party/xla/xla/mlir_hlo/bindings/c/Types.cc index 0be0e34c02069e..4ac9bdd75c1825 100644 --- a/third_party/xla/xla/mlir_hlo/bindings/c/Types.cc +++ b/third_party/xla/xla/mlir_hlo/bindings/c/Types.cc @@ -14,11 +14,12 @@ limitations under the License. #include "mhlo/IR/hlo_ops.h" #include "mlir/CAPI/IR.h" +#include "mlir/Support/LLVM.h" MlirType mlirMhloTokenTypeGet(MlirContext ctx) { return wrap(mlir::mhlo::TokenType::get(unwrap(ctx))); } bool mlirMhloTypeIsAToken(MlirType type) { - return unwrap(type).isa(); + return mlir::isa(unwrap(type)); } diff --git a/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc b/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc index 951dc240e9b53a..4b8930edc961bf 100644 --- a/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc +++ b/third_party/xla/xla/mlir_hlo/deallocation/transforms/buffer_reuse.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -372,7 +373,7 @@ void promoteToStack(memref::DeallocOp dealloc) { auto alloc = dealloc.getMemref().getDefiningOp(); OpBuilder b(alloc); auto alloca = b.create( - alloc->getLoc(), alloc->getResultTypes()[0].cast(), + alloc->getLoc(), mlir::cast(alloc->getResultTypes()[0]), alloc.getAlignmentAttr()); alloc->replaceAllUsesWith(ValueRange{alloca.getResult()}); alloc->erase(); diff --git a/third_party/xla/xla/mlir_hlo/deallocation/utils/util.h b/third_party/xla/xla/mlir_hlo/deallocation/utils/util.h index ce6be44f99d962..c18c8f9dcd2485 100644 --- a/third_party/xla/xla/mlir_hlo/deallocation/utils/util.h +++ b/third_party/xla/xla/mlir_hlo/deallocation/utils/util.h @@ -81,8 +81,8 @@ struct ValueComparator { if (lhs == rhs) return false; // Block arguments are less than results. - bool lhsIsBBArg = lhs.isa(); - if (lhsIsBBArg != rhs.isa()) { + bool lhsIsBBArg = isa(lhs); + if (lhsIsBBArg != isa(rhs)) { return lhsIsBBArg; } diff --git a/third_party/xla/xla/mlir_hlo/lhlo/IR/lhlo_ops.cc b/third_party/xla/xla/mlir_hlo/lhlo/IR/lhlo_ops.cc index 952e3c94751d3b..6adc93683fedf9 100644 --- a/third_party/xla/xla/mlir_hlo/lhlo/IR/lhlo_ops.cc +++ b/third_party/xla/xla/mlir_hlo/lhlo/IR/lhlo_ops.cc @@ -54,6 +54,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" +#include "mlir/Support/LLVM.h" #define GET_ATTRDEF_CLASSES #include "lhlo/IR/lhlo_ops_structs.cc.inc" @@ -104,7 +105,7 @@ LogicalResult AbsOp::verify() { AbsOp op = *this; auto operandType = getElementTypeOrSelf(op.getInput().getType()); auto outputType = getElementTypeOrSelf(op.getOutput().getType()); - if (auto complexType = operandType.dyn_cast()) { + if (auto complexType = mlir::dyn_cast(operandType)) { if (complexType.getElementType() != outputType) { return op.emitOpError( "requires output type to be the same as the element type of the " @@ -225,8 +226,8 @@ LogicalResult CustomCallOp::verify() { // configurations are applied to the operand. LogicalResult PadOp::verify() { PadOp op = *this; - auto operandType = op.getOperand().getType().dyn_cast(); - auto outputType = op.getOutput().getType().dyn_cast(); + auto operandType = mlir::dyn_cast(op.getOperand().getType()); + auto outputType = mlir::dyn_cast(op.getOutput().getType()); if (!(operandType && outputType && operandType.hasRank() && outputType.hasRank())) { return success(); diff --git a/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc b/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc index 45e44cc1161827..1a5c5099924a00 100644 --- a/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc +++ b/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_affine/lhlo_legalize_to_affine.cc @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Location.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -55,8 +56,8 @@ struct DotOpConverter : public OpRewritePattern { PatternRewriter& rewriter) const override { Value lhs = op.getLhs(); Value rhs = op.getRhs(); - MemRefType lhsType = lhs.getType().cast(); - MemRefType rhsType = rhs.getType().cast(); + MemRefType lhsType = mlir::cast(lhs.getType()); + MemRefType rhsType = mlir::cast(rhs.getType()); Type elementType = lhsType.getElementType(); ArrayRef shapeLhs = lhsType.getShape(); ArrayRef shapeRhs = rhsType.getShape(); @@ -130,7 +131,7 @@ struct ConcatOpConverter : public OpRewritePattern { PatternRewriter& rewriter) const override { Location loc = op.getLoc(); Value output = op.getOutput(); - MemRefType outputType = output.getType().cast(); + MemRefType outputType = mlir::cast(output.getType()); unsigned outputRank = outputType.getRank(); ArrayRef outputShape = outputType.getShape(); @@ -139,7 +140,7 @@ struct ConcatOpConverter : public OpRewritePattern { int64_t prevBound = 0; for (Value operand : operands) { - MemRefType operandType = operand.getType().cast(); + MemRefType operandType = mlir::cast(operand.getType()); ArrayRef operandShape = operandType.getShape(); // TODO(pashu123): Extend support for dynamic dimensions. @@ -189,11 +190,11 @@ struct ConcatOpConverter : public OpRewritePattern { static Value getZeroValue(Type type, Location loc, PatternRewriter& rewriter) { assert(type.isIntOrFloat() && "Expected int or float"); - if (IntegerType intType = type.dyn_cast()) + if (IntegerType intType = mlir::dyn_cast(type)) return rewriter.create(loc, 0, intType.getWidth()); - FloatType floatType = type.cast(); + FloatType floatType = mlir::cast(type); return rewriter.create( loc, APFloat::getZero(floatType.getFloatSemantics()), floatType); } @@ -202,7 +203,7 @@ static Value getZeroValue(Type type, Location loc, PatternRewriter& rewriter) { static void fillBuffer(Location loc, Value buffer, Value fillValue, PatternRewriter& builder) { OpBuilder::InsertionGuard guard(builder); - MemRefType bufType = buffer.getType().cast(); + MemRefType bufType = mlir::cast(buffer.getType()); unsigned rank = bufType.getRank(); SmallVector dimSizes; dimSizes.reserve(rank); @@ -220,7 +221,7 @@ static void fillBuffer(Location loc, Value buffer, Value fillValue, ivs[i] = forOp.getInductionVar(); } Type fillValueType = fillValue.getType(); - auto fillMemRefType = fillValueType.dyn_cast(); + auto fillMemRefType = mlir::dyn_cast(fillValueType); assert(((fillMemRefType && fillMemRefType.getRank() == 0) || fillValueType.isIntOrFloat()) && "init value has to be a 0-d memref or int or fp"); @@ -252,19 +253,20 @@ class GatherOpConverter : public OpRewritePattern { // Operand array. Value operand = op.getOperand(); - MemRefType operandType = operand.getType().cast(); + MemRefType operandType = mlir::cast(operand.getType()); unsigned operandRank = operandType.getRank(); ArrayRef operandShape = operandType.getShape(); // Start_indices array. Value startIndices = op.getStartIndices(); - MemRefType startIndicesType = startIndices.getType().cast(); + MemRefType startIndicesType = + mlir::cast(startIndices.getType()); unsigned startIndicesRank = startIndicesType.getRank(); ArrayRef startIndicesShape = startIndicesType.getShape(); // Output array. Value output = op.getOutput(); - MemRefType outputType = output.getType().cast(); + MemRefType outputType = mlir::cast(output.getType()); ArrayRef outputShape = outputType.getShape(); if (!operandType.hasStaticShape() || !startIndicesType.hasStaticShape() || @@ -450,7 +452,7 @@ class GatherOpConverter : public OpRewritePattern { rewriter.create(loc, output, outputInductionVars); // The selected value is added to the previous value stored in output array. - if (elementType.isa()) + if (mlir::isa(elementType)) outputValue = rewriter.create(loc, elementType, selectLoad, outputValue); else @@ -486,8 +488,8 @@ struct PadOpConverter : public OpRewritePattern { Value paddingValue = op.getPaddingValue(); Value output = op.getOutput(); - auto operandType = operand.getType().dyn_cast(); - auto outputType = output.getType().dyn_cast(); + auto operandType = mlir::dyn_cast(operand.getType()); + auto outputType = mlir::dyn_cast(output.getType()); // We allow lowering for only ranked input/output. if (!(operandType && outputType && operandType.hasRank() && outputType.hasRank())) @@ -574,8 +576,8 @@ struct BinaryOpConverter : public OpRewritePattern { PatternRewriter& rewriter) const override { Value lhs = op.getLhs(); Value rhs = op.getRhs(); - const auto& lhsType = lhs.getType().template cast(); - const auto& rhsType = rhs.getType().template cast(); + const auto& lhsType = mlir::cast(lhs.getType()); + const auto& rhsType = mlir::cast(rhs.getType()); const auto& elementType = lhsType.getElementType(); if (lhsType.getShape() != rhsType.getShape()) { @@ -611,7 +613,7 @@ struct UnaryOpConverter : public OpRewritePattern { LogicalResult matchAndRewrite(LhloOpTy op, PatternRewriter& rewriter) const override { Value input = op.getInput(); - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); auto elementType = inputType.getElementType(); ArrayRef shape = inputType.getShape(); diff --git a/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc b/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc index 7572655977e21d..492f75011e5da0 100644 --- a/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc +++ b/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_gpu/lhlo_legalize_to_gpu.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -64,7 +65,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { // Only support 1d reductions for now. int64_t size = 0; for (auto result : reduceOp.getOut()) { - auto shapedType = result.getType().dyn_cast(); + auto shapedType = mlir::dyn_cast(result.getType()); if (!shapedType || shapedType.getRank() != 1) { return failure(); } @@ -80,7 +81,7 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { // Require all inputs to have the same shape. int64_t reduceDimSize = 0; for (auto input : reduceOp.getInputs()) { - auto shapedType = input.getType().dyn_cast(); + auto shapedType = mlir::dyn_cast(input.getType()); if (!shapedType || !shapedType.hasStaticShape()) { return failure(); } @@ -139,7 +140,8 @@ class LhloReduceToGPULaunchConverter : public OpConversionPattern { loc, resType, output, offset, size, stride); llvm::SmallVector indexings; Value inputBuffer = reduceOp.getInputs().front(); - auto inputTypeRank = inputBuffer.getType().cast().getRank(); + auto inputTypeRank = + mlir::cast(inputBuffer.getType()).getRank(); Value input = *reduceOp.operand_begin(); SmallVector offsets = llvm::to_vector<4>(llvm::map_range( diff --git a/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc b/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc index f3c6b3940aced5..3b4568c16d8206 100644 --- a/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc +++ b/third_party/xla/xla/mlir_hlo/lhlo/transforms/lhlo_legalize_to_parallel_loops/lhlo_legalize_to_parallel_loops.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -55,7 +56,7 @@ Value applySingleResultLhloCode(Location loc, ValueRange operands, SmallVector argBufs; for (auto argType : lhloBlock->getArgumentTypes()) { argBufs.push_back( - b->create(loc, argType.cast())); + b->create(loc, mlir::cast(argType))); } for (const auto& operand : llvm::enumerate(operands)) { b->create(loc, operand.value(), argBufs[operand.index()]); @@ -116,7 +117,7 @@ MappedIvs mapWindowIvsToInput(OpTy op, Value operand, ValueRange ivs, auto padding = op.getPadding().value(); auto loc = op.getLoc(); - auto operandShape = operand.getType().template cast().getShape(); + auto operandShape = mlir::cast(operand.getType()).getShape(); // `in_bounds` is false when the mapped indices are in the padding area. mappedIvs.inBounds = b->create( @@ -154,7 +155,8 @@ scf::ParallelOp makeLoopOverShape(Location loc, Value shapedValue, Value zero = b->create(loc, 0); Value one = b->create(loc, 1); - ArrayRef shape = shapedValue.getType().cast().getShape(); + ArrayRef shape = + mlir::cast(shapedValue.getType()).getShape(); SmallVector lower, upper, step; for (const auto& dim : llvm::enumerate(shape)) { upper.push_back( @@ -249,7 +251,7 @@ class ReduceOpConverter : public OpConversionPattern { Value out = reduceOp.getOut().front(); SmallVector parallelLower, parallelUpper, parallelStep; SmallVector reduceLower, reduceUpper, reduceStep; - auto operandShape = operand.getType().cast().getShape(); + auto operandShape = mlir::cast(operand.getType()).getShape(); for (const auto& dim : llvm::enumerate(operandShape)) { const bool isReducingDim = reducingDims.count(dim.index()); @@ -442,7 +444,7 @@ class ReduceWindowOpConverter } Value input = reduceWindowOp.getInputs()[0]; - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); // Compute ivs in 'arg' buffer and whether these ivs are in pad area or not. MappedIvs mappedIvs = mapWindowIvsToInput( @@ -555,7 +557,7 @@ class SelectAndScatterOpConverter Value one = b->create(loc, 1); auto elementType = - sAndSOp.getOut().getType().cast().getElementType(); + mlir::cast(sAndSOp.getOut().getType()).getElementType(); auto rank = loopOverSrc.getNumLoops(); // `iter_args` = [iv_1, ..., iv_N, selected_value, is_initialized] diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc index a6a38886841a76..e42f21fc0dfb62 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops.cc @@ -3350,8 +3350,8 @@ Operation* ReduceWindowOp::getReductionOp(int resultIndex) { auto returnOp = cast(getBody().front().getTerminator()); Operation* computeOp = returnOp.getResults()[resultIndex].getDefiningOp(); if (computeOp->getNumOperands() != 2) return nullptr; - auto arg0 = computeOp->getOperand(0).dyn_cast(); - auto arg1 = computeOp->getOperand(1).dyn_cast(); + auto arg0 = dyn_cast(computeOp->getOperand(0)); + auto arg1 = dyn_cast(computeOp->getOperand(1)); if (!arg0 || !arg1) return nullptr; int64_t arg0Num = arg0.getArgNumber(); int64_t arg1Num = arg1.getArgNumber(); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_common.cc b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_common.cc index ff69d84d08879f..073707b75c093b 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_common.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/IR/hlo_ops_common.cc @@ -25,13 +25,14 @@ limitations under the License. #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/Support/LLVM.h" namespace mlir { namespace hlo { // Verifies the source target pairs attached to collective permute. LogicalResult verifyCollectivePermuteSourceTargetPairs( Operation *op, DenseIntElementsAttr attr) { - auto type = attr.getType().cast(); + auto type = mlir::cast(attr.getType()); if (type.getRank() != 2) return op->emitError() << "expect source_target_pairs attribute to be of " "rank 2, but got rank " @@ -73,8 +74,8 @@ LogicalResult verifyReduceScatter(Operation *op, TypeRange operandTypes, } for (auto it : llvm::zip(operandTypes, resultTypes)) { - auto operandType = std::get<0>(it).cast(); - auto resultType = std::get<1>(it).cast(); + auto operandType = mlir::cast(std::get<0>(it)); + auto resultType = mlir::cast(std::get<1>(it)); if (!operandType.hasRank() || !resultType.hasRank()) continue; if (operandType.getRank() != resultType.getRank()) return op->emitOpError() << "operand and result should have same rank"; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc b/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc index 4d21d8aecf6b0d..94eee2a46d644d 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/analysis/shape_component_analysis.cc @@ -83,7 +83,8 @@ struct ShapeVisitor { // Skip irrelevant cases early. Value value = transitivelyRequestedInfo.value(); Type ty = value.getType(); - if (!ty.isIntOrIndexOrFloat() && !ty.isa()) continue; + if (!ty.isIntOrIndexOrFloat() && !mlir::isa(ty)) + continue; // Handle shapes. if (transitivelyRequestedInfo.isShapeInfo()) { @@ -101,7 +102,7 @@ struct ShapeVisitor { backwardTransposeShape(transpose); } else if (auto select = value.getDefiningOp()) { backwardSelectShape(select); - } else if (auto arg = value.dyn_cast()) { + } else if (auto arg = mlir::dyn_cast(value)) { backwardBlockArgumentShape(arg); } else if (value.getDefiningOp() && value.getDefiningOp() @@ -114,7 +115,7 @@ struct ShapeVisitor { } // Skip irrelevant cases early. - auto rankedTy = ty.dyn_cast(); + auto rankedTy = mlir::dyn_cast(ty); bool isPossiblyInterestingScalar = ty.isIntOrIndex(); bool isPossiblyInterestingTensor = rankedTy && rankedTy.getRank() <= 1 && rankedTy.hasStaticShape(); @@ -245,7 +246,7 @@ struct ShapeVisitor { void backwardAssumingShape(Value op) { auto assumingOp = op.getDefiningOp(); - auto number = op.cast().getResultNumber(); + auto number = mlir::cast(op).getResultNumber(); forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op)); backwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf( cast( @@ -254,7 +255,7 @@ struct ShapeVisitor { } void forwardAssumingShape(Value op) { auto assumingOp = op.getDefiningOp(); - auto number = op.cast().getResultNumber(); + auto number = mlir::cast(op).getResultNumber(); auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op)); dims = lookup(ShapeOrValueInfo::getShapeInfoOf( cast( @@ -338,7 +339,7 @@ struct ShapeVisitor { ShapeOrValueInfo::getValueInfoOf(op.getOutputShape())); } void forwardDynamicReshapeShape(mhlo::DynamicReshapeOp op) { - auto rankedTy = op.getResult().getType().cast(); + auto rankedTy = mlir::cast(op.getResult().getType()); auto shapeDims = lookup(ShapeOrValueInfo::getValueInfoOf(op.getOutputShape())); auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op)); @@ -370,7 +371,7 @@ struct ShapeVisitor { void forwardTransposeShape(mhlo::TransposeOp op) { auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(op)); auto in = lookup(ShapeOrValueInfo::getShapeInfoOf(op.getOperand())); - auto elem = op.getPermutation().cast(); + auto elem = mlir::cast(op.getPermutation()); for (const auto &val : elem) dims.push_back(in[val.getZExtValue()]); } void backwardSelectShape(mhlo::SelectOp op) { @@ -440,7 +441,7 @@ struct ShapeVisitor { forwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(v)); } void forwardUnknownShape(Value v) { - auto rankedTy = v.getType().dyn_cast(); + auto rankedTy = mlir::dyn_cast(v.getType()); if (!rankedTy) return; auto id = getAffineSymbolExpr(0, v.getContext()); auto &dims = insert(ShapeOrValueInfo::getShapeInfoOf(v)); @@ -465,7 +466,7 @@ struct ShapeVisitor { backwardsWorklist.push_back(ShapeOrValueInfo::getShapeInfoOf(op.getArg())); } void forwardShapeOf(shape::ShapeOfOp op) { - auto rankedTy = op.getArg().getType().cast(); + auto rankedTy = mlir::cast(op.getArg().getType()); auto arg = lookup(ShapeOrValueInfo::getShapeInfoOf(op.getArg())); auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); return dimsFromStaticShape(rankedTy, arg, &dims); @@ -521,7 +522,7 @@ struct ShapeVisitor { void forwardDim(tensor::DimOp op) { auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); if (auto index = op.getIndex().getDefiningOp()) { - int64_t i = index.getValue().cast().getInt(); + int64_t i = mlir::cast(index.getValue()).getInt(); auto in = lookup(ShapeOrValueInfo::getShapeInfoOf(op.getSource())); if (i >= static_cast(in.size()) || i < 0) llvm::report_fatal_error("tensor dim out of bounds"); @@ -591,7 +592,7 @@ struct ShapeVisitor { assert(op.getIndices().size() == 1); if (auto index = op.getIndices().front().getDefiningOp()) { - int64_t i = index.getValue().cast().getInt(); + int64_t i = mlir::cast(index.getValue()).getInt(); // We asssume this is in bounds. auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.getTensor())); dims.push_back({in[i].symbols, in[i].expr}); @@ -661,7 +662,7 @@ struct ShapeVisitor { } auto &dims = insert(ShapeOrValueInfo::getValueInfoOf(op)); auto in = lookup(ShapeOrValueInfo::getValueInfoOf(op.getOperand())); - auto elem = op.getStartIndices().cast(); + auto elem = mlir::cast(op.getStartIndices()); auto i = (*elem.begin()).getZExtValue(); if (i >= in.size()) { // Bounds check. return forwardUnknown(op); @@ -711,7 +712,7 @@ struct ShapeVisitor { // Return the size of the first dimension. Returns 1 for scalars. static int64_t dim0size(Type type) { - if (auto rankedType = type.dyn_cast()) + if (auto rankedType = mlir::dyn_cast(type)) return rankedType.getRank() == 0 ? 1 : rankedType.getDimSize(0); return 1; } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc index 56ea0a42e79c19..da27173913f81e 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/broadcast_propagation/broadcast_propagation.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -152,7 +153,7 @@ void findBroadcastIntents( // Derive the broadcast intent associated with the root broadcast operation. // Add it to the worklist to seed the analysis. - rootBcastIntent = {root.getResult().getType().cast(), + rootBcastIntent = {mlir::cast(root.getResult().getType()), root.getOperand(), root.getOutputDimensions(), root.getBroadcastDimensions()}; addToWorklistIfNew(rootBcastIntent); @@ -177,7 +178,7 @@ void findBroadcastIntents( llvm::dyn_cast(producerOp)) { DenseIntElementsAttr composedBcastDims = composeBroadcastDimensionsAttr( builder, producerBcastOp.getBroadcastDimensions(), - it.broadcastDimensions.cast()); + mlir::cast(it.broadcastDimensions)); BroadcastIntent bcastedOperandIntent = { it.resultType, producerBcastOp.getOperand(), it.outputDimensions, composedBcastDims}; @@ -194,7 +195,7 @@ void findBroadcastIntents( assert(allowsForElementwiseBroadcastPropagation(producerOp)); bcastIntentDependencies[it] = {}; for (auto operand : producerOp->getOperands()) { - auto operandTy = operand.getType().cast(); + auto operandTy = mlir::cast(operand.getType()); auto operandBcastDims = operandTy.getRank() == 0 ? builder.getI64TensorAttr({}) : it.broadcastDimensions; @@ -272,7 +273,7 @@ DenseMap realizeBroadcastIntents( realizations[it] = rewriter.create( it.targetValue.getLoc(), it.resultType, it.targetValue, it.outputDimensions, - it.broadcastDimensions.cast()); + mlir::cast(it.broadcastDimensions)); continue; } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc index 7f1e7de597b535..60fcd198853911 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/collapse_elementwise_map/collapse_elementwise_map.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -73,7 +74,7 @@ struct ConvertMapOfElementwiseOps : public OpRewritePattern { operands.push_back(blockAndValueMap.lookup(value)); auto *newOp = rewriter.create( op.getLoc(), op.getName().getIdentifier(), operands, - op.getResultTypes()[0].cast().clone(shape)); + mlir::cast(op.getResultTypes()[0]).clone(shape)); // Maps the result. blockAndValueMap.map(op.getResult(0), newOp->getResult(0)); } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/constraint_fusion/constraint_fusion_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/constraint_fusion/constraint_fusion_pass.cc index 7dc6080c07ccb4..f0d9b2442619dc 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/constraint_fusion/constraint_fusion_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/constraint_fusion/constraint_fusion_pass.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" namespace mlir { namespace mhlo { @@ -286,7 +287,7 @@ LogicalResult analyzeBroadcastableConstraint( // For shapes without a definition, expect them to be an argument of the // regarded block. if (def == nullptr) { - auto barg = shape.dyn_cast(); + auto barg = mlir::dyn_cast(shape); if (!barg || barg.getParentBlock() != theBlock) return failure(); transitiveBcastableCstrOperands.push_back( CstrBroadcastableOperand::valueOf(barg)); @@ -299,7 +300,7 @@ LogicalResult analyzeBroadcastableConstraint( if (auto sof = llvm::dyn_cast(def)) { if (!isWithinBlock(sof, theBlock)) return failure(); tryFlagForErase(theBlock, def, toBeErased); - auto barg = sof.getArg().dyn_cast(); + auto barg = mlir::dyn_cast(sof.getArg()); if (!barg) return failure(); transitiveBcastableCstrOperands.push_back( CstrBroadcastableOperand::shapeOf(barg)); @@ -351,7 +352,7 @@ LogicalResult analyzeBlockGlobalConstraints( // For witnesses without a definition, expect it to be an argument of the // regarded block. if (def == nullptr) { - auto barg = cstr.dyn_cast(); + auto barg = mlir::dyn_cast(cstr); if (!barg || barg.getParentBlock() != theBlock) return failure(); argumentCstrs.push_back(barg); continue; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/convert_to_signless/convert_to_signless_pass.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/convert_to_signless/convert_to_signless_pass.cc index c95ea747ea1362..dffab945389ee1 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/convert_to_signless/convert_to_signless_pass.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/convert_to_signless/convert_to_signless_pass.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" @@ -86,12 +87,13 @@ class ConvertConstantToSignless arith::ConstantOp constantOp, arith::ConstantOpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { // We only care about unsigned integers - if (!adaptor.getValue().isa()) return failure(); + if (!mlir::isa(adaptor.getValue())) return failure(); - auto values = llvm::to_vector( - adaptor.getValue().cast().getValues()); + auto values = + llvm::to_vector(mlir::cast(adaptor.getValue()) + .getValues()); Type type = typeConverter->convertType(constantOp.getType()); - auto shapedType = type.dyn_cast(); + auto shapedType = mlir::dyn_cast(type); auto newValues = DenseIntElementsAttr::get( shapedType, values); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/expand_hlo_tuples/expand_hlo_tuples.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/expand_hlo_tuples/expand_hlo_tuples.cc index 6b514e720fc53b..7b4426d94704d2 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/expand_hlo_tuples/expand_hlo_tuples.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/expand_hlo_tuples/expand_hlo_tuples.cc @@ -32,6 +32,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassRegistry.h" +#include "mlir/Support/LLVM.h" namespace mlir { namespace mhlo { @@ -69,7 +70,7 @@ class ExpandHloTuplesPass func.getArguments().end()); for (auto argument : funcArguments) { auto type = argument.getType(); - auto tupleType = type.dyn_cast_or_null(); + auto tupleType = mlir::dyn_cast_or_null(type); if (!tupleType) { expandedInputTypes.push_back(type); } else { @@ -109,7 +110,7 @@ class ExpandHloTuplesPass SmallVector expandedReturnOperands; SmallVector expandedResultTypes; for (auto value : returnOp.getOperands()) { - if (auto tupleTy = value.getType().dyn_cast()) { + if (auto tupleTy = mlir::dyn_cast(value.getType())) { llvm::copy(tupleTy.getTypes(), std::back_inserter(expandedResultTypes)); for (auto [index, ty] : llvm::enumerate(tupleTy.getTypes())) { expandedReturnOperands.push_back( @@ -145,7 +146,7 @@ class ExpandHloTuplesPass while ( llvm::any_of(llvm::concat(entryFunction.getArgumentTypes(), entryFunction.getResultTypes()), - [](Type type) { return type.isa(); })) { + [](Type type) { return mlir::isa(type); })) { expandTupledTensorInReturnOp(entryFunction); } } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/expand_ops_simplifier/expand_ops_simplifier.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/expand_ops_simplifier/expand_ops_simplifier.cc index 5113736c0b3bba..b75f13e5c8417b 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/expand_ops_simplifier/expand_ops_simplifier.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/expand_ops_simplifier/expand_ops_simplifier.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -162,7 +163,7 @@ struct SelectAndScatterExpanderPattern llvm::SmallVector concatenatedIotasDims; concatenatedIotasDims.reserve( - iotaIndices.front().getType().cast().getRank()); + mlir::cast(iotaIndices.front().getType()).getRank()); concatenatedIotasDims.insert(concatenatedIotasDims.end(), broadcastedIotaDims.begin(), broadcastedIotaDims.end()); @@ -189,8 +190,8 @@ struct SelectAndScatterExpanderPattern llvm::SmallVector scatterIns; llvm::SmallVector scatterLocs; scatterIns.push_back(RankedTensorType::get( - {}, - broadcastedInitValue.getType().cast().getElementType())); + {}, mlir::cast(broadcastedInitValue.getType()) + .getElementType())); scatterIns.push_back( RankedTensorType::get({}, source.getType().getElementType())); scatterLocs.push_back(broadcastedInitValue.getLoc()); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc index d2058c0e23254f..7bbc78cd0dc598 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/group_reduction_dimensions/group_reduction_dimensions.cc @@ -31,6 +31,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Types.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -226,8 +227,9 @@ LogicalResult tryLowerTo1DOr2DReduction( auto reductionDimAttr = rewriter.getI64VectorAttr({reductionDim}); Value initVal = op.getInitValues().front(); SmallVector elementTypes{llvm::map_range( - op.getBody().front().getTerminator()->getOperands(), - [](Value v) { return v.getType().cast().getElementType(); })}; + op.getBody().front().getTerminator()->getOperands(), [](Value v) { + return mlir::cast(v.getType()).getElementType(); + })}; auto reductionOp = rewriter.create(loc, intermResult, initVal, reductionDimAttr, elementTypes); rewriter.inlineRegionBefore(op.getBody(), reductionOp.getBody(), @@ -235,7 +237,7 @@ LogicalResult tryLowerTo1DOr2DReduction( intermResult = reductionOp->getResults().front(); // Restore the expected shape by dynamic reshape, if required. - auto resultTy = op->getResultTypes().front().cast(); + auto resultTy = mlir::cast(op->getResultTypes().front()); if (requiresDynamicReshape) { assert(resultShape && "expect to have reified the result shape"); intermResult = rewriter.create( @@ -245,7 +247,7 @@ LogicalResult tryLowerTo1DOr2DReduction( // Othwerise, restore the expected shape by shape expansion, if required. int64_t resultRank = resultTy.getRank(); int64_t intermResultRank = - intermResult.getType().cast().getRank(); + mlir::cast(intermResult.getType()).getRank(); bool requiresExpand = !requiresDynamicReshape && resultRank != intermResultRank; if (requiresExpand) { @@ -276,11 +278,11 @@ struct GroupReductionDimensionsPattern : public OpRewritePattern { return failure(); Value arg = op.getInputs().front(); // Only apply to non-sparse tensors. - if (auto rtp = arg.getType().cast(); + if (auto rtp = mlir::cast(arg.getType()); rtp.getEncoding() != nullptr) return failure(); - auto argTy = arg.getType().cast(); + auto argTy = mlir::cast(arg.getType()); // Sort reduction dimensions, which is not an invariant of the op. SmallVector orderedReductionDims = diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_shape_ops_to_standard/hlo_legalize_shape_ops_to_standard.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_shape_ops_to_standard/hlo_legalize_shape_ops_to_standard.cc index 702deddaa0a51b..31a5d7592ce9b7 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_shape_ops_to_standard/hlo_legalize_shape_ops_to_standard.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_shape_ops_to_standard/hlo_legalize_shape_ops_to_standard.cc @@ -38,6 +38,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" @@ -61,7 +62,7 @@ struct ComputeReshapeShapeConversion auto indexType = rewriter.getIndexType(); auto numElements = adaptor.getOperands()[0]; auto targetShapeType = - adaptor.getOperands()[1].getType().cast(); + mlir::cast(adaptor.getOperands()[1].getType()); auto extentType = shape::getExtentTensorType(ctx, targetShapeType.getDimSize(0)); @@ -128,7 +129,7 @@ struct CstrReshapableConversion Value one = rewriter.create(loc, 1); auto numElements = adaptor.getOperands()[0]; auto targetShapeType = - adaptor.getOperands()[1].getType().cast(); + mlir::cast(adaptor.getOperands()[1].getType()); auto extentType = shape::getExtentTensorType(ctx, targetShapeType.getDimSize(0)); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc index fead7be62bf1c2..a3a9f6664f9544 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_arithmetic/hlo_legalize_to_arithmetic.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -125,7 +126,7 @@ struct ScalarHloToArithmeticPattern : public OpConversionPattern { if (filterFn && !filterFn(op)) return failure(); auto isScalar = [&](Value v) { - return v.getType().cast().getRank() == 0; + return mlir::cast(v.getType()).getRank() == 0; }; if (!llvm::all_of(adaptor.getOperands(), isScalar)) @@ -134,8 +135,8 @@ struct ScalarHloToArithmeticPattern : public OpConversionPattern { auto loc = op.getLoc(); std::optional resultTy; - resultTy = this->typeConverter->convertType(op->getResultTypes().front()) - .template dyn_cast(); + resultTy = mlir::dyn_cast( + this->typeConverter->convertType(op->getResultTypes().front())); SmallVector operands; for (auto operand : adaptor.getOperands()) { diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_memref/hlo_legalize_to_memref.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_memref/hlo_legalize_to_memref.cc index 1df0035b53e6fe..a27a4d683625bf 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_memref/hlo_legalize_to_memref.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_memref/hlo_legalize_to_memref.cc @@ -33,6 +33,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" namespace mlir { namespace mhlo { @@ -76,14 +77,14 @@ struct CustomCallOpInterface SmallVector bufferArgs; for (OpOperand &operand : customCallOp->getOpOperands()) { auto &newBuffer = bufferArgs.emplace_back(); - if (operand.get().getType().isa()) { + if (mlir::isa(operand.get().getType())) { // Remember the token for later. We need it for the return value but // it's not getting passed to LMHLO. if (tokenArgument) return failure(); tokenArgument = operand.get(); continue; } - if (!operand.get().getType().isa()) return failure(); + if (!mlir::isa(operand.get().getType())) return failure(); FailureOr operandBuffer = getBuffer(rewriter, operand.get(), options); if (failed(operandBuffer)) return failure(); @@ -93,10 +94,10 @@ struct CustomCallOpInterface // Allocate outputs. for (OpResult result : customCallOp->getOpResults()) { auto &newBuffer = bufferArgs.emplace_back(); - if (result.getType().isa()) { + if (mlir::isa(result.getType())) { continue; } - auto tensorType = result.getType().dyn_cast(); + auto tensorType = mlir::dyn_cast(result.getType()); if (!tensorType) return failure(); // TODO(springerm): Create alloc_tensor ops during TensorCopyInsertion. AnalysisState analysisState(options); @@ -185,7 +186,7 @@ struct ReshapeOpInterface const BufferizationOptions &options) const { auto reshapeOp = cast(op); auto unrankedOperandType = - reshapeOp.getOperand().getType().dyn_cast(); + mlir::dyn_cast(reshapeOp.getOperand().getType()); if (unrankedOperandType == nullptr) return success(); // The buffer still has the old (pre-reshape) type. @@ -193,7 +194,7 @@ struct ReshapeOpInterface getBuffer(rewriter, reshapeOp.getOperand(), options); if (failed(operandBuffer)) return failure(); - auto resultType = reshapeOp.getType().cast(); + auto resultType = mlir::cast(reshapeOp.getType()); auto destType = MemRefType::get(resultType.getShape(), resultType.getElementType()); replaceOpWithNewBufferizedOp(rewriter, op, destType, @@ -233,16 +234,16 @@ struct DynamicReshapeOpInterface ShapedType resultType; TensorType opResultType = reshapeOp.getType(); - if (auto rankedType = opResultType.dyn_cast()) { + if (auto rankedType = mlir::dyn_cast(opResultType)) { resultType = MemRefType::get(rankedType.getShape(), rankedType.getElementType()); } else if (auto unrankedType = - opResultType.dyn_cast()) { + mlir::dyn_cast(opResultType)) { resultType = UnrankedMemRefType::get(unrankedType.getElementType(), 0); } auto operand = *operandBuffer; // If the operand has a non-identity affine map, we will have to add a copy. - auto bufferType = operandBuffer->getType().dyn_cast(); + auto bufferType = mlir::dyn_cast(operandBuffer->getType()); if (bufferType && !bufferType.getLayout().isIdentity()) { // TODO(springerm): Create alloc_tensor ops during TensorCopyInsertion. AnalysisState analysisState(options); @@ -268,11 +269,11 @@ FailureOr insertDynamicMemrefCastOp( mhlo::DynamicBroadcastInDimOp op, Value operand, RewriterBase &rewriter, const BufferizationOptions &options) { auto loc = op.getLoc(); - auto operandType = operand.getType().cast(); + auto operandType = mlir::cast(operand.getType()); auto operandShape = operandType.getShape(); auto operandRank = operandType.getRank(); - auto resultType = op.getType().cast(); + auto resultType = mlir::cast(op.getType()); auto resultRank = resultType.getRank(); Value zero = rewriter.create(loc, 0); @@ -380,7 +381,8 @@ struct DynamicBroadcastInDimOpInterface LogicalResult bufferize(Operation *op, RewriterBase &rewriter, const BufferizationOptions &options) const { auto broadcastInDimOp = cast(op); - auto resultType = broadcastInDimOp.getType().dyn_cast(); + auto resultType = + mlir::dyn_cast(broadcastInDimOp.getType()); if (!resultType) return success(); // The buffer still has the old (pre-reshape) type. diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index a7ef5e93d1585b..5dc68ffc7762a8 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -34,6 +34,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Types.h" #include "mlir/Support/DebugStringHelper.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" @@ -81,7 +82,7 @@ bool hasPrivateFeaturesNotInStablehlo(HloOpTy hloOp) { bool hasPackedNibble(std::optional precisionConfigAttr) { if (!precisionConfigAttr) return false; return llvm::any_of(*precisionConfigAttr, [&](Attribute attr) { - auto precisionAttr = attr.cast(); + auto precisionAttr = mlir::cast(attr); return precisionAttr.getValue() == mhlo::Precision::PACKED_NIBBLE; }); } @@ -214,7 +215,7 @@ bool isDenseI64Array(mlir::StringAttr hloName) { template Attribute convertDenseArray(mlir::StringAttr hloName, Attribute hloAttr) { - auto denseInts = hloAttr.dyn_cast(); + auto denseInts = mlir::dyn_cast(hloAttr); if (!denseInts) return {}; if ((std::is_same::value || @@ -243,17 +244,17 @@ Attribute convertAttr(Attribute hloAttr) { // Handle MHLO attributes. // The logic that handles attributes from other dialects (e.g. builtin // attributes) lives below. - if (auto attr = hloAttr.dyn_cast()) { + if (auto attr = mlir::dyn_cast(hloAttr)) { return stablehlo::ChannelHandleAttr::get(attr.getContext(), attr.getHandle(), attr.getType()); } - if (auto attr = hloAttr.dyn_cast()) { + if (auto attr = mlir::dyn_cast(hloAttr)) { RETURN_CONVERTED_ENUM_ATTR(ComparisonDirection); } - if (auto attr = hloAttr.dyn_cast()) { + if (auto attr = mlir::dyn_cast(hloAttr)) { RETURN_CONVERTED_ENUM_ATTR(ComparisonType); } - if (auto attr = hloAttr.dyn_cast()) { + if (auto attr = mlir::dyn_cast(hloAttr)) { return stablehlo::ConvDimensionNumbersAttr::get( attr.getContext(), attr.getInputBatchDimension(), attr.getInputFeatureDimension(), attr.getInputSpatialDimensions(), @@ -264,44 +265,44 @@ Attribute convertAttr(Attribute hloAttr) { } // NOTE: We cannot process CustomCallApiVersionAttr here because // `dyn_cast()` succeeds for IntegerAttr too. - if (auto attr = hloAttr.dyn_cast()) { + if (auto attr = mlir::dyn_cast(hloAttr)) { return stablehlo::DotDimensionNumbersAttr::get( attr.getContext(), attr.getLhsBatchingDimensions(), attr.getRhsBatchingDimensions(), attr.getLhsContractingDimensions(), attr.getRhsContractingDimensions()); } - if (auto attr = hloAttr.dyn_cast()) { + if (auto attr = mlir::dyn_cast(hloAttr)) { RETURN_CONVERTED_ENUM_ATTR(FftType); } - if (auto attr = hloAttr.dyn_cast()) { + if (auto attr = mlir::dyn_cast(hloAttr)) { return stablehlo::GatherDimensionNumbersAttr::get( attr.getContext(), attr.getOffsetDims(), attr.getCollapsedSliceDims(), attr.getStartIndexMap(), attr.getIndexVectorDim()); } - if (auto attr = hloAttr.dyn_cast()) { + if (auto attr = mlir::dyn_cast(hloAttr)) { return stablehlo::OutputOperandAliasAttr::get( attr.getContext(), attr.getOutputTupleIndices(), attr.getOperandIndex(), attr.getOperandTupleIndices()); } - if (auto attr = hloAttr.dyn_cast()) { + if (auto attr = mlir::dyn_cast(hloAttr)) { // StableHLO Precision doesn't support PACKED_NIBBLE yet. // Proposal: https://github.com/openxla/stablehlo/issues/742. if (attr.getValue() == mhlo::Precision::PACKED_NIBBLE) return {}; RETURN_CONVERTED_ENUM_ATTR(Precision); } - if (auto attr = hloAttr.dyn_cast()) { + if (auto attr = mlir::dyn_cast(hloAttr)) { RETURN_CONVERTED_ENUM_ATTR(RngAlgorithm); } - if (auto attr = hloAttr.dyn_cast()) { + if (auto attr = mlir::dyn_cast(hloAttr)) { RETURN_CONVERTED_ENUM_ATTR(RngDistribution); } - if (auto attr = hloAttr.dyn_cast()) { + if (auto attr = mlir::dyn_cast(hloAttr)) { return stablehlo::ScatterDimensionNumbersAttr::get( attr.getContext(), attr.getUpdateWindowDims(), attr.getInsertedWindowDims(), attr.getScatterDimsToOperandDims(), attr.getIndexVectorDim()); } - if (auto attr = hloAttr.dyn_cast()) { + if (auto attr = mlir::dyn_cast(hloAttr)) { RETURN_CONVERTED_ENUM_ATTR(Transpose); } if (hloAttr.getDialect().getNamespace() == @@ -316,7 +317,7 @@ Attribute convertAttr(Attribute hloAttr) { // Handle non-MHLO attributes. // If an attribute is not defined in MHLO, then it is unchanged, // with the exception of ArrayAttr which is converted recursively. - if (auto hloAttrs = hloAttr.dyn_cast()) { + if (auto hloAttrs = mlir::dyn_cast(hloAttr)) { SmallVector stablehloAttrs; for (auto hloAttr : hloAttrs) { auto stablehloAttr = convertAttr(hloAttr); @@ -338,11 +339,11 @@ Attribute convertAttr(Attribute hloAttr) { // we can fork and modify the code of `stringifyPrecision` as needed for // compatibility. Attribute encodePrecisionConfig(Attribute hloAttrs) { - auto hloArrayAttr = hloAttrs.dyn_cast(); + auto hloArrayAttr = mlir::dyn_cast(hloAttrs); if (!hloArrayAttr) return {}; SmallVector stablehloAttrs; for (auto hloAttr : hloArrayAttr) { - auto precisionAttr = hloAttr.dyn_cast(); + auto precisionAttr = mlir::dyn_cast(hloAttr); if (!precisionAttr) return {}; StringRef precisionStr = mhlo::stringifyPrecision(precisionAttr.getValue()); if (precisionStr.empty()) return {}; diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc index d56a1da7591094..9d473b9a058936 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_control_flow/legalize_control_flow.cc @@ -64,8 +64,8 @@ void inlineMhloRegionIntoSCFRegion(PatternRewriter& rewriter, Region& mhlo, // scalar value when necessary. Value extractTensorValue(OpBuilder& b, Value tensor) { auto loc = tensor.getLoc(); - if (tensor.getType().cast().hasRank() && - tensor.getType().cast().getRank() != 0) { + if (mlir::cast(tensor.getType()).hasRank() && + mlir::cast(tensor.getType()).getRank() != 0) { tensor = b.create( loc, tensor, SmallVector()); } @@ -85,9 +85,9 @@ std::optional extractForBounds(mhlo::WhileOp op) { if (cond.getOperations().size() != 2) return std::nullopt; auto matchBbArg = [](Value v, Block& block) -> std::optional { - if (!v.isa() || v.getParentBlock() != &block) + if (!mlir::isa(v) || v.getParentBlock() != &block) return std::nullopt; - return v.cast().getArgNumber(); + return mlir::cast(v).getArgNumber(); }; auto compare = llvm::dyn_cast(cond.front()); @@ -207,10 +207,10 @@ struct CaseOpPattern : public OpConversionPattern { // Determine if the current index matches the case index. auto scalarType = idxValue.getType(); - auto shapedType = scalarType.cast(); + auto shapedType = mlir::cast(scalarType); auto constAttr = DenseElementsAttr::get( - shapedType, - {outerBuilder.getI32IntegerAttr(currentIdx).cast()}); + shapedType, {mlir::cast( + outerBuilder.getI32IntegerAttr(currentIdx))}); Value currentIdxVal = outerBuilder.create( loc, idxValue.getType(), constAttr); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_cross_replica_sum_to_all_reduce/legalize_cross_replica_sum_to_all_reduce.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_cross_replica_sum_to_all_reduce/legalize_cross_replica_sum_to_all_reduce.cc index dfc541c591a4da..0ddd409bf3f583 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_cross_replica_sum_to_all_reduce/legalize_cross_replica_sum_to_all_reduce.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_cross_replica_sum_to_all_reduce/legalize_cross_replica_sum_to_all_reduce.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -49,11 +50,10 @@ struct CrossReplicaSumToAllReducePattern /*use_global_device_ids=*/false); auto *block = rewriter.createBlock(&allReduceOp.getComputation()); - auto elementType = RankedTensorType::get({}, allReduceOp.getResults() - .front() - .getType() - .dyn_cast() - .getElementType()); + auto elementType = RankedTensorType::get( + {}, + mlir::dyn_cast(allReduceOp.getResults().front().getType()) + .getElementType()); auto location = allReduceOp.getComputation().getLoc(); block->addArguments({elementType, elementType}, {location, location}); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_general_to_dot/legalize_dot_general_to_dot.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_general_to_dot/legalize_dot_general_to_dot.cc index ea37c6104e62cd..6c2c26e3d28bf5 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_general_to_dot/legalize_dot_general_to_dot.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_dot_general_to_dot/legalize_dot_general_to_dot.cc @@ -24,6 +24,7 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -44,8 +45,8 @@ struct DotGeneralToDot : public OpRewritePattern { PatternRewriter& rewriter) const override { auto lhs = dot.getLhs(); auto rhs = dot.getRhs(); - auto lhsTy = lhs.getType().cast(); - auto rhsTy = rhs.getType().cast(); + auto lhsTy = mlir::cast(lhs.getType()); + auto rhsTy = mlir::cast(rhs.getType()); int64_t lhsRank = lhsTy.getRank(); int64_t rhsRank = rhsTy.getRank(); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc index 0dc4495cdbbc6c..f8c0f9eafd7c83 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_einsum_to_dot_general/legalize_einsum_to_dot_general.cc @@ -25,6 +25,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -68,8 +69,8 @@ struct EinsumToDotGeneralPattern : public OpRewritePattern { index++; } - auto lhsType = einsum.getLhs().getType().cast(); - auto rhsType = einsum.getRhs().getType().cast(); + auto lhsType = mlir::cast(einsum.getLhs().getType()); + auto rhsType = mlir::cast(einsum.getRhs().getType()); assert(static_cast(lhsTokens.size()) == lhsType.getRank()); assert(static_cast(rhsTokens.size()) == rhsType.getRank()); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc index aba642b45cdf09..34837951b24eef 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_gather_to_torch_index_select/legalize_gather_to_torch_index_select.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -39,13 +40,13 @@ struct GatherIsTorchIndexSelect : public OpRewritePattern { LogicalResult matchAndRewrite(GatherOp gather, PatternRewriter &rewriter) const override { auto startIndices = gather.getStartIndices(); - auto startIndicesTy = startIndices.getType().cast(); + auto startIndicesTy = mlir::cast(startIndices.getType()); if (!startIndicesTy.hasRank()) { return rewriter.notifyMatchFailure(gather, "unranked start_indices"); } auto operand = gather.getOperand(); - auto operandTy = operand.getType().cast(); + auto operandTy = mlir::cast(operand.getType()); if (!operandTy.hasRank()) { return rewriter.notifyMatchFailure(gather, "unranked operand"); } @@ -73,7 +74,8 @@ struct GatherIsTorchIndexSelect : public OpRewritePattern { return rewriter.notifyMatchFailure(gather, "start_index_map != [0]"); } - auto resultTy = gather.getResult().getType().dyn_cast(); + auto resultTy = + mlir::dyn_cast(gather.getResult().getType()); if (!resultTy) { return rewriter.notifyMatchFailure(gather, "unranked result"); } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc index 7a8feb69caf9cb..935465d2dabaa6 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_shape_computations/legalize_shape_computations.cc @@ -46,6 +46,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -62,7 +63,7 @@ namespace { bool opIsShapeComputation(Operation *op) { bool foundFromElements = false; for (auto operand : op->getOperands()) { - auto shapedTy = operand.getType().template cast(); + auto shapedTy = mlir::cast(operand.getType()); if (!shapedTy.hasRank() || shapedTy.getRank() > 1) return false; if (auto fromElements = operand.template getDefiningOp()) { @@ -82,14 +83,14 @@ class MhloElementwiseConverter : public OpRewritePattern { PatternRewriter &rewriter) const final { if (!opIsShapeComputation(op)) return failure(); - auto resultTy = op.getType().template cast(); + auto resultTy = mlir::cast(op.getType()); Location loc = op.getLoc(); SmallVector operands; for (int i = 0, s = resultTy.getNumElements(); i < s; i++) { SmallVector extracts; for (auto operand : op->getOperands()) { - ShapedType operandTy = operand.getType().template cast(); + ShapedType operandTy = mlir::cast(operand.getType()); if (operandTy.getRank() == 0) { Value extract = rewriter.create(loc, operand, ValueRange({})); @@ -121,12 +122,12 @@ class ConcatenateConverter : public OpRewritePattern { if (!opIsShapeComputation(op)) return failure(); Location loc = op.getLoc(); - auto resultTy = op.getType().cast(); + auto resultTy = mlir::cast(op.getType()); llvm::SmallVector elements; elements.reserve(resultTy.getNumElements()); for (auto operand : op->getOperands()) { - ShapedType operandTy = operand.getType().template cast(); + ShapedType operandTy = mlir::cast(operand.getType()); if (operandTy.getRank() == 0) { Value extract = rewriter.create(loc, operand, ValueRange({})); @@ -174,10 +175,10 @@ class ReshapeConverter : public OpRewritePattern { LogicalResult matchAndRewrite(mhlo::ReshapeOp op, PatternRewriter &rewriter) const final { auto operand = op.getOperand(); - auto shapedTy = operand.getType().template cast(); + auto shapedTy = mlir::cast(operand.getType()); if (!shapedTy.hasRank() || shapedTy.getRank() > 1) return failure(); - auto resultTy = op.getType().cast(); + auto resultTy = mlir::cast(op.getType()); auto fromElements = op.getOperand().getDefiningOp(); if (!fromElements) return failure(); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_sort/legalize_sort.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_sort/legalize_sort.cc index 6446995a7e91fb..8ba9de9a1a10f4 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_sort/legalize_sort.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_sort/legalize_sort.cc @@ -139,7 +139,7 @@ SmallVector loadTensorElements(ImplicitLocOpBuilder& b, SmallVector loadMemrefElements(ImplicitLocOpBuilder& b, ValueRange memrefs, Value index) { return llvm::to_vector(llvm::map_range(memrefs, [&](Value memref) -> Value { - Type type = memref.getType().cast().getElementType(); + Type type = mlir::cast(memref.getType()).getElementType(); return b.create(type, memref, index); })); } @@ -414,25 +414,24 @@ struct Slicer { } MemRefType toSlicedType(MemRefType sourceType) { - return memref::SubViewOp::inferRankReducedResultType( - {ShapedType::kDynamic} /*1D output*/, sourceType, offsets, sizes, - strides) - .cast(); + return mlir::cast(memref::SubViewOp::inferRankReducedResultType( + {ShapedType::kDynamic} /*1D output*/, sourceType, offsets, sizes, + strides)); } template Value slice(ImplicitLocOpBuilder& b, Value input) { - Ty ty = input.getType().cast(); + Ty ty = mlir::cast(input.getType()); return b.create(toSlicedType(ty), input, offsets, sizes, strides) .getResult(); } Value apply(ImplicitLocOpBuilder& b, Value input) { Type inTy = input.getType(); - if (inTy.isa()) { + if (mlir::isa(inTy)) { return slice(b, input); } - assert(inTy.isa()); + assert(mlir::isa(inTy)); return slice(b, input); } @@ -470,7 +469,7 @@ struct SortOpPattern : public OpRewritePattern { SmallVector scratchMemrefs; Value firstOperand = op.getOperands().front(); - auto firstOperandType = firstOperand.getType().cast(); + auto firstOperandType = mlir::cast(firstOperand.getType()); int64_t inputRank = firstOperandType.getRank(); Value sortDimSize = b.createOrFold( @@ -489,7 +488,7 @@ struct SortOpPattern : public OpRewritePattern { // statically known to be <= kInsertionSortSize, `scratchMemrefs` are unused // and will be cleaned up later. for (auto input : op.getOperands()) { - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); auto memRefType = MemRefType::get(inputType.getShape(), inputType.getElementType()); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc index 2c470ccf66d48d..7c373cbdb2fbb7 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_linalg/legalize_to_linalg.cc @@ -73,22 +73,22 @@ namespace { Value getResultValue(Operation* op) { return op->getResult(0); } ShapedType getHloOpResultType(Operation* op) { - return getResultValue(op).getType().cast(); + return mlir::cast(getResultValue(op).getType()); } bool verifyHloOpBufferOrTensorSemantics(Operation* op) { auto verifyType = [&](Value val) -> bool { - return val.getType().isa(); + return mlir::isa(val.getType()); }; if (!llvm::all_of(op->getOperands(), verifyType)) return false; return llvm::all_of(op->getResults(), verifyType); } Value fillTensorWithZeros(OpBuilder& builder, Location loc, Value tensor) { - auto type = tensor.getType().cast(); + auto type = mlir::cast(tensor.getType()); Value zero; // Complex numbers are a special case. - if (auto complexType = type.getElementType().dyn_cast()) { + if (auto complexType = mlir::dyn_cast(type.getElementType())) { auto zeroElement = builder.getZeroAttr(complexType.getElementType()); auto zeroAttr = builder.getArrayAttr({zeroElement, zeroElement}); zero = builder.create(loc, complexType, zeroAttr); @@ -233,15 +233,15 @@ struct RngUniformConversion : public OpConversionPattern { return failure(); } // TODO(raikonenfnu): Handle other element types as well. - auto minTy = adaptor.getOperands()[0].getType().dyn_cast(); - auto maxTy = adaptor.getOperands()[0].getType().dyn_cast(); - if (!minTy.getElementType().dyn_cast() || - !maxTy.getElementType().dyn_cast()) { + auto minTy = mlir::dyn_cast(adaptor.getOperands()[0].getType()); + auto maxTy = mlir::dyn_cast(adaptor.getOperands()[0].getType()); + if (!mlir::dyn_cast(minTy.getElementType()) || + !mlir::dyn_cast(maxTy.getElementType())) { return rewriter.notifyMatchFailure( op, "expected min/max for rng op to be FloatType"); } - auto targetTy = this->typeConverter->convertType(op.getResult().getType()) - .cast(); + auto targetTy = mlir::cast( + this->typeConverter->convertType(op.getResult().getType())); if (!targetTy) { return rewriter.notifyMatchFailure( op, "expected target shape of rng op to be ShapedType"); @@ -339,14 +339,16 @@ SmallVector extractDynamicEinsumSizes( if (dimIndIt != lhsLoopVec.end()) { // Query from lhs vars. auto dimIndPos = dimIndIt - lhsLoopVec.begin(); - auto lhsShape = lhs.getType().dyn_cast().getShape(); + auto lhsShape = + mlir::dyn_cast(lhs.getType()).getShape(); if (lhsShape[dimIndPos] != ShapedType::kDynamic) continue; dimSize = b.create(loc, lhs, dimIndPos); } else { // query from rhs vars. dimIndIt = std::find(rhsLoopVec.begin(), rhsLoopVec.end(), dimInd); auto dimIndPos = dimIndIt - rhsLoopVec.begin(); - auto rhsShape = rhs.getType().dyn_cast().getShape(); + auto rhsShape = + mlir::dyn_cast(rhs.getType()).getShape(); if (rhsShape[dimIndPos] != ShapedType::kDynamic) continue; dimSize = b.create(loc, rhs, dimIndPos); } @@ -407,7 +409,7 @@ class EinsumToLinalgConverter : public OpConversionPattern { mhlo::EinsumOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { auto getRank = [](Value v) { - return v.getType().cast().getRank(); + return mlir::cast(v.getType()).getRank(); }; auto einsumConfig = op.getEinsumConfig(); @@ -433,8 +435,8 @@ class EinsumToLinalgConverter : public OpConversionPattern { } // Find result type, if on tensors. - auto resultTy = this->typeConverter->convertType(getHloOpResultType(op)) - .dyn_cast(); + auto resultTy = mlir::dyn_cast( + this->typeConverter->convertType(getHloOpResultType(op))); // Check result type compatibility. if (!resultTy || !(resultTy.getElementType().isSignlessIntOrFloat())) { @@ -632,8 +634,8 @@ class DataMovementOpConverter : public OpConversionPattern { ConversionPatternRewriter& rewriter) const final { if (!verifyHloOpBufferOrTensorSemantics(op)) return failure(); auto resultType = getHloOpResultType(op); - resultType = this->typeConverter->convertType(resultType) - .template cast(); + resultType = + mlir::cast(this->typeConverter->convertType(resultType)); SmallVector indexingMaps = Derived::getIndexingMaps(op, &rewriter); @@ -671,7 +673,7 @@ class BroadcastConverter static SmallVector getIndexingMaps(OpTy broadcastOp, Builder* b) { ShapedType inputType = - broadcastOp.getOperand().getType().template cast(); + mlir::cast(broadcastOp.getOperand().getType()); unsigned inputRank = inputType.getRank(); unsigned nloops = getHloOpResultType(broadcastOp).getRank(); @@ -704,7 +706,8 @@ class BroadcastOpToBroadcastConverter LogicalResult matchAndRewrite( mhlo::BroadcastOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - auto resultTy = typeConverter->convertType(op.getType()).cast(); + auto resultTy = + mlir::cast(typeConverter->convertType(op.getType())); int64_t numPrependedDims = op.getBroadcastSizes().size(); SmallVector dimensions = @@ -733,7 +736,7 @@ class HloBroadcastInDimConverter mhlo::BroadcastInDimOp broadcastOp, Builder* b) { auto resultType = getHloOpResultType(broadcastOp); auto operandType = - broadcastOp.getOperand().getType().template cast(); + mlir::cast(broadcastOp.getOperand().getType()); unsigned nloops = resultType.getRank(); // The input is a scalar, i.e. this is a scalar broadcast op. @@ -765,7 +768,7 @@ class HloBroadcastInDimConverter Value collapseExpandingDims(PatternRewriter& rewriter, Location loc, Value operand, SmallVector& dimensions, llvm::function_ref isExpandingDim) { - auto operandTy = operand.getType().cast(); + auto operandTy = mlir::cast(operand.getType()); SmallVector reassociationMap; ReassociationIndices currentIndices; @@ -816,7 +819,7 @@ Value transposeBroadcastOperand(PatternRewriter& rewriter, Location loc, return dimensions[lhs] < dimensions[rhs]; }); - auto operandTy = operand.getType().cast(); + auto operandTy = mlir::cast(operand.getType()); ArrayRef operandShape = operandTy.getShape(); SmallVector transposedOperandShape, transposedDimensions; @@ -846,8 +849,9 @@ class BroadcastInDimOpToBroadcastConverter llvm::to_vector(op.getBroadcastDimensions().getValues()); Value operand = adaptor.getOperand(); - auto operandTy = operand.getType().cast(); - auto resultTy = typeConverter->convertType(op.getType()).cast(); + auto operandTy = mlir::cast(operand.getType()); + auto resultTy = + mlir::cast(typeConverter->convertType(op.getType())); ArrayRef operandShape = operandTy.getShape(); ArrayRef resultShape = resultTy.getShape(); @@ -895,10 +899,10 @@ class HloDynamicBroadcastInDimConverter mhlo::DynamicBroadcastInDimOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { Value operand = adaptor.getOperand(); - auto operandType = operand.getType().dyn_cast(); + auto operandType = mlir::dyn_cast(operand.getType()); if (!operandType) return failure(); - auto resultType = - typeConverter->convertType(op.getType()).dyn_cast(); + auto resultType = mlir::dyn_cast( + typeConverter->convertType(op.getType())); if (!resultType) return failure(); // Determine dimension expressions based on whether the dimension is @@ -971,10 +975,10 @@ class DynamicBroadcastInDimOpToBroadcastConverter Location loc = op.getLoc(); Value operand = adaptor.getOperand(); - auto operandTy = operand.getType().dyn_cast(); + auto operandTy = mlir::dyn_cast(operand.getType()); if (!operandTy) return failure(); - auto resultTy = - typeConverter->convertType(op.getType()).dyn_cast(); + auto resultTy = mlir::dyn_cast( + typeConverter->convertType(op.getType())); if (!resultTy) return failure(); SmallVector broadcastDimensions = @@ -1049,7 +1053,7 @@ class DynamicBroadcastInDimOpToBroadcastConverter static Value getBroadcastOperand( PatternRewriter& rewriter, Location loc, Value operand, llvm::function_ref isExpandingDim) { - auto operandTy = operand.getType().dyn_cast(); + auto operandTy = mlir::dyn_cast(operand.getType()); SmallVector updatedOperandShape = llvm::to_vector(operandTy.getShape()); @@ -1070,7 +1074,8 @@ class DynamicBroadcastInDimOpToBroadcastConverter static ShapedType getBroadcastResultType( Value operand, RankedTensorType resultTy, ArrayRef dimensions, llvm::function_ref isExpandingDim) { - auto operandShape = operand.getType().cast().getShape(); + auto operandShape = + mlir::cast(operand.getType()).getShape(); auto broadcastResultShape = llvm::to_vector(resultTy.getShape()); for (const auto& [operandIndex, resultIndex] : @@ -1091,7 +1096,7 @@ class TransposeConverter using DataMovementOpConverter, OpTy>::DataMovementOpConverter; static SmallVector getIndexingMaps(OpTy op, Builder* b) { - auto resultType = getHloOpResultType(op).template cast(); + auto resultType = mlir::cast(getHloOpResultType(op)); auto nloops = resultType.getRank(); SmallVector inputExprs; inputExprs.resize(resultType.getRank()); @@ -1112,7 +1117,8 @@ class TransposeOpToTransposeConverter LogicalResult matchAndRewrite( mhlo::TransposeOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - auto resultTy = typeConverter->convertType(op.getType()).cast(); + auto resultTy = + mlir::cast(typeConverter->convertType(op.getType())); auto loc = op.getLoc(); Value emptyTensor = @@ -1137,9 +1143,10 @@ class BitcastConvertConverter ConversionPatternRewriter& rewriter) const final { if (!verifyHloOpBufferOrTensorSemantics(op)) return failure(); - auto inputType = adaptor.getOperand().getType().cast(); + auto inputType = + mlir::cast(adaptor.getOperand().getType()); auto outputType = - typeConverter->convertType(op.getType()).cast(); + mlir::cast(typeConverter->convertType(op.getType())); auto loc = op.getLoc(); // Fallback to pointwise conversion if the tensor dimensions are not @@ -1251,7 +1258,7 @@ class RealDynamicSliceConverter mhlo::RealDynamicSliceOp realDynamicSliceOp, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { Location loc = realDynamicSliceOp.getLoc(); - auto argType = adaptor.getOperand().getType().dyn_cast(); + auto argType = mlir::dyn_cast(adaptor.getOperand().getType()); if (!argType || !argType.hasRank()) { return rewriter.notifyMatchFailure(realDynamicSliceOp, "require known-rank args"); @@ -1268,9 +1275,8 @@ class RealDynamicSliceConverter dimElementType.isIndex() ? rewriter.getI64Type() : dimElementType; Type indexType = rewriter.getIndexType(); - auto resultType = - this->typeConverter->convertType(realDynamicSliceOp.getType()) - .cast(); + auto resultType = mlir::cast( + this->typeConverter->convertType(realDynamicSliceOp.getType())); Value zero = rewriter.create(loc, 0); SmallVector offsets, sizes, strides; SmallVector clampType(3, arithType); @@ -1340,9 +1346,9 @@ class ReshapeOpConverter : public OpConversionPattern { ConversionPatternRewriter& rewriter) const final { if (!verifyHloOpBufferOrTensorSemantics(reshapeOp)) return failure(); auto operand = adaptor.getOperand(); - auto operandType = operand.getType().cast(); + auto operandType = mlir::cast(operand.getType()); auto elemType = operandType.getElementType(); - auto resultType = reshapeOp.getType().cast(); + auto resultType = mlir::cast(reshapeOp.getType()); if (!resultType.hasStaticShape()) return failure(); @@ -1354,7 +1360,7 @@ class ReshapeOpConverter : public OpConversionPattern { return success(); } - resultType = typeConverter->convertType(resultType).cast(); + resultType = mlir::cast(typeConverter->convertType(resultType)); // Special case where the result is a scalar. if (resultType.getRank() == 0 && !operandType.hasStaticShape()) { @@ -1472,8 +1478,8 @@ class IotaConverter : public OpConversionPattern { ShapedType resultShapedType = getHloOpResultType(iotaOp); if (!resultShapedType) return failure(); Type targetElementType = resultShapedType.getElementType(); - resultShapedType = this->typeConverter->convertType(resultShapedType) - .template dyn_cast(); + resultShapedType = mlir::dyn_cast( + this->typeConverter->convertType(resultShapedType)); Type resultElementType = resultShapedType.getElementType(); @@ -1497,7 +1503,7 @@ class IotaConverter : public OpConversionPattern { nestedLoc, iotaOp.getIotaDimension()); Type unwrappedResultElementType = resultElementType; if (auto complexType = - unwrappedResultElementType.dyn_cast()) + mlir::dyn_cast(unwrappedResultElementType)) unwrappedResultElementType = complexType.getElementType(); Value castOp = nestedBuilder.create( nestedLoc, @@ -1526,8 +1532,8 @@ class IotaToMapConverter : public OpConversionPattern { ShapedType resultTy = getHloOpResultType(iotaOp); if (!resultTy) return failure(); Type targetElementType = resultTy.getElementType(); - resultTy = this->typeConverter->convertType(resultTy) - .template dyn_cast(); + resultTy = + mlir::dyn_cast(this->typeConverter->convertType(resultTy)); Location loc = iotaOp.getLoc(); Value empty = getEmptyTensorFor(rewriter, loc, resultTy, iotaOp, @@ -1564,8 +1570,8 @@ struct ConcatenateConverter : public OpConversionPattern { return success(); } - auto resultType = this->typeConverter->convertType(op.getResult().getType()) - .dyn_cast(); + auto resultType = mlir::dyn_cast( + this->typeConverter->convertType(op.getResult().getType())); if (!resultType) return failure(); uint64_t dim = op.getDimension(); @@ -1647,9 +1653,9 @@ class ConstConverterTensor : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::ConstantOp constOp, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const final { - auto valueAttr = constOp.getValue().cast(); + auto valueAttr = mlir::cast(constOp.getValue()); auto type = - typeConverter->convertType(constOp.getType()).cast(); + mlir::cast(typeConverter->convertType(constOp.getType())); if (type != constOp.getType()) { // Signedness conversion. valueAttr = valueAttr.mapValues(type.getElementType(), @@ -1668,7 +1674,7 @@ class ReverseConverter mhlo::ReverseOp>::DataMovementOpConverter; static SmallVector getIndexingMaps(mhlo::ReverseOp op, Builder* b) { - auto resultType = getHloOpResultType(op).cast(); + auto resultType = mlir::cast(getHloOpResultType(op)); auto nloops = resultType.getRank(); SmallVector inputExprs; inputExprs.reserve(nloops); @@ -1693,7 +1699,8 @@ class SliceConverter : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::SliceOp sliceOp, typename mhlo::SliceOp::Adaptor adaptor, ConversionPatternRewriter& rewriter) const final { - auto argType = adaptor.getOperands()[0].getType().dyn_cast(); + auto argType = + mlir::dyn_cast(adaptor.getOperands()[0].getType()); if (!argType || !argType.hasRank()) { return rewriter.notifyMatchFailure(sliceOp, "expects known-rank args"); } @@ -1726,7 +1733,7 @@ class DynamicSliceConverter : public OpConversionPattern { mhlo::DynamicSliceOp dynamicSliceOp, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const final { auto loc = dynamicSliceOp.getLoc(); - auto argType = adaptor.getOperand().getType().dyn_cast(); + auto argType = mlir::dyn_cast(adaptor.getOperand().getType()); if (!argType || !argType.hasRank()) { return rewriter.notifyMatchFailure(dynamicSliceOp, "require known-rank args"); @@ -1764,8 +1771,8 @@ class DynamicSliceConverter : public OpConversionPattern { int64_t rank = argType.getRank(); SmallVector strides(rank, rewriter.getI64IntegerAttr(1)); - auto resultType = this->typeConverter->convertType(dynamicSliceOp.getType()) - .cast(); + auto resultType = mlir::cast( + this->typeConverter->convertType(dynamicSliceOp.getType())); rewriter.replaceOpWithNewOp( dynamicSliceOp, resultType, adaptor.getOperand(), startIndices, sizes, @@ -1784,14 +1791,14 @@ class DynamicUpdateSliceConverter ConversionPatternRewriter& rewriter) const final { auto loc = op.getLoc(); auto operandType = - adaptor.getOperand().getType().dyn_cast(); + mlir::dyn_cast(adaptor.getOperand().getType()); if (!operandType || !operandType.hasStaticShape()) { return rewriter.notifyMatchFailure( op, "require static ranked type for operand"); } auto updateType = - adaptor.getUpdate().getType().dyn_cast(); + mlir::dyn_cast(adaptor.getUpdate().getType()); if (!updateType || !updateType.hasStaticShape()) { return rewriter.notifyMatchFailure( op, "require static ranked type for operand"); @@ -1842,9 +1849,9 @@ enum class DotOperationType { DotOperationType getDotOperationType(mhlo::DotOp dotOp) { ArrayRef lhsShape = - dotOp.getLhs().getType().cast().getShape(); + mlir::cast(dotOp.getLhs().getType()).getShape(); ArrayRef rhsShape = - dotOp.getRhs().getType().cast().getShape(); + mlir::cast(dotOp.getRhs().getType()).getShape(); auto shapeMatches = [](int64_t a, int64_t b) { return a == ShapedType::kDynamic || b == ShapedType::kDynamic || a == b; }; @@ -1873,19 +1880,19 @@ SmallVector getDotOpEmptyTensorDynSizes(OpBuilder& b, Location loc, SmallVector dynShape; switch (type) { case DotOperationType::kMatrixMatrix: { - if (lhs.getType().cast().isDynamicDim(0)) + if (mlir::cast(lhs.getType()).isDynamicDim(0)) dynShape.push_back(b.create(loc, lhs, 0)); - if (rhs.getType().cast().isDynamicDim(1)) + if (mlir::cast(rhs.getType()).isDynamicDim(1)) dynShape.push_back(b.create(loc, rhs, 1)); break; } case DotOperationType::kMatrixVector: { - if (lhs.getType().cast().isDynamicDim(0)) + if (mlir::cast(lhs.getType()).isDynamicDim(0)) dynShape.push_back(b.create(loc, lhs, 0)); break; } case DotOperationType::kVectorMatrix: { - if (rhs.getType().cast().isDynamicDim(1)) + if (mlir::cast(rhs.getType()).isDynamicDim(1)) dynShape.push_back(b.create(loc, rhs, 1)); break; } @@ -1912,7 +1919,7 @@ class DotOpConversion : public OpConversionPattern { // Convert unsigned to signed. This works because signed and unsigned // integer matmul is the same operation in two's complement. auto outputType = - typeConverter->convertType(op.getType()).cast(); + mlir::cast(typeConverter->convertType(op.getType())); SmallVector dynShape = getDotOpEmptyTensorDynSizes( rewriter, loc, adaptor.getLhs(), adaptor.getRhs(), op_type); auto emptyTensor = @@ -1938,7 +1945,7 @@ class DotGeneralBatchMatMulOpConversion if (!verifyHloOpBufferOrTensorSemantics(op)) { return failure(); } - if (op.getType().cast().getRank() != 3) { + if (mlir::cast(op.getType()).getRank() != 3) { return rewriter.notifyMatchFailure(op, "expected a batch matmul"); } @@ -1968,7 +1975,7 @@ class DotGeneralBatchMatMulOpConversion // Convert unsigned to signed. This works because signed and unsigned // integer matmul is the same operation in two's complement. auto outputType = - typeConverter->convertType(op.getType()).cast(); + mlir::cast(typeConverter->convertType(op.getType())); auto emptyTensor = getEmptyTensorFor(rewriter, loc, outputType, op, adaptor.getOperands()); Value zeroTensor = fillTensorWithZeros(rewriter, loc, emptyTensor); @@ -1992,7 +1999,7 @@ class MapOpToGenericConverter : public OpConversionPattern { if (!verifyHloOpBufferOrTensorSemantics(op)) return failure(); auto resultType = - typeConverter->convertType(op.getType()).cast(); + mlir::cast(typeConverter->convertType(op.getType())); assert(op.getDimensions().size() == resultType.getRank() && "Expected a pointwise map"); @@ -2019,7 +2026,7 @@ class MapOpToGenericConverter : public OpConversionPattern { signatureConverter.addInputs( it.index(), typeConverter->convertType( - it.value().getType().cast().getElementType())); + mlir::cast(it.value().getType()).getElementType())); } signatureConverter.addInputs(resultType.getElementType()); @@ -2039,7 +2046,7 @@ class MapOpToMapConverter : public OpConversionPattern { if (!verifyHloOpBufferOrTensorSemantics(op)) return failure(); auto resultType = - typeConverter->convertType(op.getType()).cast(); + mlir::cast(typeConverter->convertType(op.getType())); assert(op.getDimensions().size() == resultType.getRank() && "Expected a pointwise map"); @@ -2069,7 +2076,7 @@ class MapOpToMapConverter : public OpConversionPattern { signatureConverter.addInputs( it.index(), typeConverter->convertType( - it.value().getType().cast().getElementType())); + mlir::cast(it.value().getType()).getElementType())); } rewriter.applySignatureConversion(®ion, signatureConverter, @@ -2089,7 +2096,7 @@ SmallVector getReduceOpEmptyTensorDynSizes( SmallVector parallelDims; SmallVector dynShape; - int rank = arg.getType().cast().getRank(); + int rank = mlir::cast(arg.getType()).getRank(); for (int i = 0, j = 0; i < rank; ++i) { if (s.count(i)) continue; if (!resultType.isDynamicDim(j++)) continue; @@ -2111,7 +2118,7 @@ class ReduceRegionReturnOpConversion } SmallVector operands(adaptor.getOperands()); for (size_t i = 0; i < operands.size(); ++i) { - if (operands[i].getType().isa()) { + if (mlir::isa(operands[i].getType())) { auto loc = operands[i].getLoc(); operands[i] = rewriter.create(loc, operands[i]); } @@ -2132,12 +2139,12 @@ class ReduceOpToGenericConverter : public OpConversionPattern { int numOperands = static_cast(adaptor.getInputs().size()); if (llvm::any_of(adaptor.getInputs(), [](Value v) { - return !v.getType().isa(); + return !mlir::isa(v.getType()); })) { return rewriter.notifyMatchFailure(op, "expects known-rank args"); } auto srcRank = - adaptor.getInputs()[0].getType().cast().getRank(); + mlir::cast(adaptor.getInputs()[0].getType()).getRank(); SmallVector reductionDims = extract1DVector(op.getDimensions()); @@ -2205,14 +2212,14 @@ class ReduceOpToGenericConverter : public OpConversionPattern { /*origInputNo=*/idx + numOperands, // type for the new operand number 'idx'. typeConverter->convertType( - val.getType().cast().getElementType())); + mlir::cast(val.getType()).getElementType())); } for (const auto& [idx, val] : llvm::enumerate(op.getInitValues())) { signatureConverter.addInputs( /*origInputNo=*/idx, // type for the new operand number 'idx' + 'numOperands'. typeConverter->convertType( - val.getType().cast().getElementType())); + mlir::cast(val.getType()).getElementType())); } rewriter.applySignatureConversion(®ion, signatureConverter, @@ -2234,7 +2241,7 @@ struct ReduceOpToReduceConverter : public OpConversionPattern { llvm::sort(reductionDims); auto toRankedTensor = [](Value v) -> RankedTensorType { - return v.getType().dyn_cast(); + return mlir::dyn_cast(v.getType()); }; SmallVector outputs; @@ -2256,14 +2263,14 @@ struct ReduceOpToReduceConverter : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "expects known-rank operands"); operandTypes.push_back(operandType); initValue = rewriter.createOrFold(loc, initValue); - auto tensorResultType = resultType.cast(); + auto tensorResultType = mlir::cast(resultType); // For linalg.reduce, the result type's dimensions must match the input's // dimensions, whereas MHLO allows replacing static dimensions with // dynamic ones. SmallVector resultShape; SmallVector dynShape; - for (const auto& [index, dim] : - llvm::enumerate(operand.getType().cast().getShape())) { + for (const auto& [index, dim] : llvm::enumerate( + mlir::cast(operand.getType()).getShape())) { if (!llvm::is_contained(reductionDims, index)) { resultShape.push_back(dim); if (ShapedType::isDynamic(dim)) { @@ -2334,9 +2341,8 @@ class RngBitGeneratorConverter ConversionPatternRewriter& rewriter) const final { Location loc = op.getLoc(); Value state = adaptor.getInitialState(); - ShapedType resultTy = - this->typeConverter->convertType(op.getResult(1).getType()) - .cast(); + ShapedType resultTy = mlir::cast( + this->typeConverter->convertType(op.getResult(1).getType())); if (op.getRngAlgorithm() == mhlo::RngAlgorithm::THREE_FRY) { Value random; @@ -2376,10 +2382,10 @@ struct SelectAndScatterNoOverlapConverter Value operand = op.getOperand(); Value init = op.getInitValue(); - auto sourceTy = source.getType().dyn_cast(); - auto operandTy = operand.getType().dyn_cast(); - auto initTy = init.getType().dyn_cast(); - auto resultTy = op.getResult().getType().dyn_cast(); + auto sourceTy = mlir::dyn_cast(source.getType()); + auto operandTy = mlir::dyn_cast(operand.getType()); + auto initTy = mlir::dyn_cast(init.getType()); + auto resultTy = mlir::dyn_cast(op.getResult().getType()); if (!sourceTy || !operandTy || !initTy || !resultTy) return rewriter.notifyMatchFailure(op, "inputs/outputs must be ranked"); @@ -2563,7 +2569,7 @@ struct SelectAndScatterNoOverlapConverter b.setInsertionPoint(op); Value reduceIndex = reduceGeneric.getResult(1); - ShapedType reduceIndexTy = reduceIndex.getType().cast(); + ShapedType reduceIndexTy = mlir::cast(reduceIndex.getType()); // For the second generic we restricted to only cases where there are // no window overlaps. This guarantees that each source value is scattered @@ -2674,7 +2680,7 @@ struct SelectAndScatterNoOverlapConverter Value collapse = b.create( scatterGeneric.getResult(0), reassociationMap); - auto collapseTy = collapse.getType().cast(); + auto collapseTy = mlir::cast(collapse.getType()); // After collapsing it it possible that the target may need to be padded. auto zero = b.createOrFold(0); @@ -2817,7 +2823,7 @@ struct PadOpConversion : public OpConversionPattern { rewriter.create(loc, paddingVal, emptyTensor).result(); // Get sizes of the original operand. - auto operandType = adaptor.getOperand().getType().cast(); + auto operandType = mlir::cast(adaptor.getOperand().getType()); auto sizes = llvm::to_vector<4>(llvm::map_range( llvm::seq(0, operandType.getRank()), [&](int64_t dim) -> OpFoldResult { @@ -2850,7 +2856,7 @@ Value applyConvolutionPadding(Location loc, Value input, (!lhsDilation || isSplatValue(lhsDilation, 1))) return input; - auto inputType = input.getType().cast(); + auto inputType = mlir::cast(input.getType()); auto rank = inputType.getRank(); // Translate window padding into low/high padding. @@ -2931,8 +2937,8 @@ struct NormalConvolutionOpConversion Value input = adaptor.getLhs(); Value filter = adaptor.getRhs(); filter = applyConvolutionReversal(loc, rewriter, op, filter); - auto resultType = - typeConverter->convertType(op.getResult().getType()).cast(); + auto resultType = mlir::cast( + typeConverter->convertType(op.getResult().getType())); int64_t rank = resultType.getRank(); // Immediately emit an EmptyOp for output tensors with zero dimension. @@ -3039,8 +3045,8 @@ struct ConvolutionOpGeneralConversion auto loc = op.getLoc(); auto* ctx = op.getContext(); - auto resultType = - typeConverter->convertType(op.getResult().getType()).cast(); + auto resultType = mlir::cast( + typeConverter->convertType(op.getResult().getType())); auto reshapedResultShape = resultType.getShape().vec(); if (!resultType.hasStaticShape()) return failure(); @@ -3089,8 +3095,8 @@ struct ConvolutionOpGeneralConversion // Non-one values for feature or batch group counts will result in reshaped // inputs and outputs. These mappings are used to keep track of the the new // index after reshaping has possibly inserted new dimensions. - auto paddedLhsType = modifiedLhs.getType().cast(); - auto paddedRhsType = modifiedRhs.getType().cast(); + auto paddedLhsType = mlir::cast(modifiedLhs.getType()); + auto paddedRhsType = mlir::cast(modifiedRhs.getType()); SmallVector lhsIndexMapping(paddedLhsType.getRank()); std::iota(lhsIndexMapping.begin(), lhsIndexMapping.end(), 0); SmallVector rhsIndexMapping(paddedRhsType.getRank()); @@ -3318,8 +3324,8 @@ struct DepthwiseConvolutionOpConversion // Make sure that this is depthwise convolution. int64_t inputFeatureDim = dimensionNumbers.getInputFeatureDimension(); - int64_t inputFeatureCount = - op.getLhs().getType().cast().getDimSize(inputFeatureDim); + int64_t inputFeatureCount = mlir::cast(op.getLhs().getType()) + .getDimSize(inputFeatureDim); if (static_cast(op.getFeatureGroupCount()) != inputFeatureCount) { return rewriter.notifyMatchFailure(op, "not depth-wise convolution"); } @@ -3350,8 +3356,8 @@ struct DepthwiseConvolutionOpConversion Location loc = op.getLoc(); Value input = adaptor.getLhs(); Value filter = adaptor.getRhs(); - auto resultType = typeConverter->convertType(op.getResult().getType()) - .cast(); + auto resultType = mlir::cast( + typeConverter->convertType(op.getResult().getType())); if (!resultType.hasStaticShape()) { return rewriter.notifyMatchFailure(op, "expected output has static shapes"); @@ -3371,12 +3377,12 @@ struct DepthwiseConvolutionOpConversion op.getLhsDilationAttr(), spatialDimMapping, rewriter); - auto filterDims = - llvm::to_vector<4>(op.getRhs().getType().cast().getShape()); + auto filterDims = llvm::to_vector<4>( + mlir::cast(op.getRhs().getType()).getShape()); auto getReassociationIndicesToCollapseLastTwoDims = [](Value v) { SmallVector reassociations; - int64_t rank = v.getType().cast().getRank(); + int64_t rank = mlir::cast(v.getType()).getRank(); for (int64_t i = 0; i < rank - 1; ++i) reassociations.emplace_back(1, i); reassociations.back().push_back(rank - 1); return reassociations; @@ -3406,7 +3412,8 @@ struct DepthwiseConvolutionOpConversion op.getFeatureGroupCount(); auto reshapedFilterType = RankedTensorType::get( reshapedFilterDims, - op.getRhs().getType().cast().getElementType()); + mlir::cast(op.getRhs().getType()) + .getElementType()); reshapedFilter = rewriter.create(loc, reshapedFilterType, filter); @@ -3591,7 +3598,7 @@ struct ReduceWindowOpOnTensorsGenericConversion llvm::SmallVector broadcastValues; for (uint64_t i = 0, s = initValues.size(); i < s; i++) { Value initValue = initValues[i]; - auto resultTy = resultTypes[i].cast(); + auto resultTy = mlir::cast(resultTypes[i]); if (!resultTy.hasStaticShape()) return failure(); auto broadcastSizes = rewriter.getI64TensorAttr(resultTy.getShape()); @@ -3658,17 +3665,17 @@ struct ReduceWindowOpOnTensorsGenericConversion // args will correlate with the LHS and the inputs correlate with the RHS. for (const auto& [i, type] : llvm::enumerate(resultTypes)) { auto idx = inputs.size() + i - 1; - signatureConverter.addInputs(idx, - type.cast().getElementType()); + signatureConverter.addInputs( + idx, mlir::cast(type).getElementType()); } signatureConverter.addInputs( - inputs.back().getType().cast().getElementType()); + mlir::cast(inputs.back().getType()).getElementType()); for (const auto& [i, input] : llvm::enumerate(ArrayRef(inputs).drop_back())) { signatureConverter.addInputs( - i, input.getType().cast().getElementType()); + i, mlir::cast(input.getType()).getElementType()); } rewriter.applySignatureConversion(®ion, signatureConverter, @@ -3697,8 +3704,8 @@ struct ReduceWindowOpConversion static PoolingType getPoolingType(mhlo::ReduceWindowOp reduceOp, int resultIndex) { - auto rank = - reduceOp.getResultTypes()[resultIndex].cast().getRank(); + auto rank = mlir::cast(reduceOp.getResultTypes()[resultIndex]) + .getRank(); if (Operation* op = reduceOp.getReductionOp(resultIndex)) { if (isa(*op) && rank == 4) return PoolingType::k2DMin; if (isa(*op) && rank == 5) return PoolingType::k3DMin; @@ -3714,7 +3721,7 @@ struct ReduceWindowOpConversion mhlo::ReduceWindowOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto loc = op.getLoc(); - int rank = op.getResultTypes()[0].cast().getRank(); + int rank = mlir::cast(op.getResultTypes()[0]).getRank(); if (rank != 4 && rank != 5) { return rewriter.notifyMatchFailure( op, "expected NHWC/NDHWC pooling-based op"); @@ -3778,8 +3785,8 @@ struct ReduceWindowOpConversion OpResult result = std::get<0>(it); Value input = std::get<1>(it); Value initValue = std::get<2>(it); - auto resultType = result.getType().cast(); - if (!input.getType().cast().getElementType().isF32()) { + auto resultType = mlir::cast(result.getType()); + if (!mlir::cast(input.getType()).getElementType().isF32()) { return rewriter.notifyMatchFailure(op, "expected element type to be f32"); } @@ -3799,9 +3806,9 @@ struct ReduceWindowOpConversion } else { auto i = en.index() - 1; auto stride = - strides.cast().getValues()[i]; - auto dilation = - dilations.cast().getValues()[i]; + mlir::cast(strides).getValues()[i]; + auto dilation = mlir::cast(dilations) + .getValues()[i]; // let j = i * stride // output[i] = reduce( input[j, j + window_size * dilation) ) Value offset = rewriter.create( @@ -3883,15 +3890,16 @@ struct TorchIndexSelectOpConversion ConversionPatternRewriter& rewriter) const final { int axis = static_cast(op.getDim()); int batch = static_cast(op.getBatchDims()); - auto indexShapedType = adaptor.getIndex().getType().cast(); + auto indexShapedType = mlir::cast(adaptor.getIndex().getType()); int numIndices = static_cast(indexShapedType.getRank()); - auto operandShapedType = adaptor.getOperand().getType().cast(); + auto operandShapedType = + mlir::cast(adaptor.getOperand().getType()); if (axis < 0) axis += static_cast(operandShapedType.getRank()); if (batch < 0) batch += numIndices; Location loc = op.getLoc(); - auto resultType = this->typeConverter->convertType(op.getResult().getType()) - .cast(); + auto resultType = mlir::cast( + this->typeConverter->convertType(op.getResult().getType())); int rank = static_cast(resultType.getRank()); // The output shape is @@ -3968,7 +3976,7 @@ struct TorchIndexSelectOpConversion auto* block = rewriter.createBlock(region, region->end()); for (auto blockArgs : linalgOpArgs) { bodyArgTypes.push_back( - blockArgs.getType().cast().getElementType()); + mlir::cast(blockArgs.getType()).getElementType()); } block->addArguments(bodyArgTypes, SmallVector(bodyArgTypes.size(), loc)); @@ -4015,10 +4023,10 @@ struct GatherConversion : public OpConversionPattern { Value startIndices = adaptor.getStartIndices(); Value operand = adaptor.getOperand(); - auto resultType = typeConverter->convertType(gatherOp.getType()) - .dyn_cast(); + auto resultType = mlir::dyn_cast( + typeConverter->convertType(gatherOp.getType())); RankedTensorType startIndicesType = - startIndices.getType().dyn_cast(); + mlir::dyn_cast(startIndices.getType()); // We could actually deal with an unranked result by inferring the result // rank, but the current reifyReturnTypes doesn't support unranked either. if (!resultType || !startIndicesType) @@ -4168,13 +4176,13 @@ struct GatherConversion : public OpConversionPattern { indexFromOffset[i])); Value extractOperand; - if (operand.getType().isa()) { + if (mlir::isa(operand.getType())) { extractOperand = operand; } else { // Cannot extract from unranked tensors, cast to ranked first. SmallVector dims(operandRank, ShapedType::kDynamic); auto type = RankedTensorType::get( - dims, operand.getType().cast().getElementType()); + dims, mlir::cast(operand.getType()).getElementType()); extractOperand = rewriter.create(loc, type, operand); } Value element = @@ -4211,14 +4219,14 @@ class DotGeneralOpConversion : public OpConversionPattern { // Convert unsigned to signed. This works because signed and unsigned // integer matmul is the same operation in two's complement. auto outputType = - typeConverter->convertType(op.getType()).cast(); + mlir::cast(typeConverter->convertType(op.getType())); auto targetRank = outputType.getRank(); auto totalLoopCount = numContracting + targetRank; - auto lhsRank = adaptor.getLhs().getType().cast().getRank(); + auto lhsRank = mlir::cast(adaptor.getLhs().getType()).getRank(); auto lhsExtraDims = lhsRank - lhsBatchingDims.size() - lhsContractingDims.size(); - auto rhsRank = adaptor.getRhs().getType().cast().getRank(); + auto rhsRank = mlir::cast(adaptor.getRhs().getType()).getRank(); Location loc = op.getLoc(); auto emptyTensor = @@ -4302,13 +4310,13 @@ class PointwiseToLinalgMapConverter : public OpConversionPattern { // Find result type, if on tensors. std::optional resultTy; - resultTy = this->typeConverter->convertType(op->getResultTypes().front()) - .template dyn_cast(); + resultTy = mlir::dyn_cast( + this->typeConverter->convertType(op->getResultTypes().front())); // Check result type compatibility. if (!resultTy || !resultTy->hasRank() || resultTy->getRank() != maxRank || !(resultTy->getElementType().isSignlessIntOrFloat() || - resultTy->getElementType().isa())) { + mlir::isa(resultTy->getElementType()))) { return rewriter.notifyMatchFailure( op, "mismatched operand/result types or iterator count"); } @@ -4352,7 +4360,7 @@ class PointwiseToLinalgMapConverter : public OpConversionPattern { protected: int64_t getRank(Value v) const { - return v.getType().cast().getRank(); + return mlir::cast(v.getType()).getRank(); } int64_t getMaxRank(typename OpTy::Adaptor adaptor) const { @@ -4392,7 +4400,8 @@ class SetDimensionSizeConverter // regular dynamic shape. Note that the bounds annotation is still around // but may be no longer valid depending on choices made by bufferization. Location loc = setDimensionSizeOp.getLoc(); - auto resultType = setDimensionSizeOp.getType().cast(); + auto resultType = + mlir::cast(setDimensionSizeOp.getType()); SmallVector offsets(resultType.getRank(), rewriter.getIndexAttr(0)); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc index 2b8b4a051ffa53..be752397f72fcb 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_to_standard/legalize_to_standard.cc @@ -27,6 +27,7 @@ limitations under the License. #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -48,8 +49,8 @@ class CompareIConvert : public OpRewritePattern { PatternRewriter &rewriter) const override { auto lhs = op.getLhs(); auto rhs = op.getRhs(); - auto lhsType = lhs.getType().cast(); - auto rhsType = rhs.getType().cast(); + auto lhsType = mlir::cast(lhs.getType()); + auto rhsType = mlir::cast(rhs.getType()); // Broadcasting not supported by this rewrite. if (lhsType.getShape() != rhsType.getShape()) return failure(); @@ -96,14 +97,14 @@ class CompareFConvert : public OpRewritePattern { PatternRewriter &rewriter) const override { auto lhs = op.getLhs(); auto rhs = op.getRhs(); - auto lhsType = lhs.getType().cast(); - auto rhsType = rhs.getType().cast(); + auto lhsType = mlir::cast(lhs.getType()); + auto rhsType = mlir::cast(rhs.getType()); // Broadcasting not supported by this rewrite. if (lhsType.getShape() != rhsType.getShape()) return failure(); - if (!lhsType.getElementType().isa() || - !rhsType.getElementType().isa()) + if (!mlir::isa(lhsType.getElementType()) || + !mlir::isa(rhsType.getElementType())) return failure(); std::optional comparePredicate = std::nullopt; @@ -146,7 +147,7 @@ class ConvertIotaOp : public OpRewritePattern { LogicalResult matchAndRewrite(mhlo::IotaOp op, PatternRewriter &rewriter) const override { - auto outputType = op.getType().cast(); + auto outputType = mlir::cast(op.getType()); auto outputSize = outputType.getNumElements(); auto dimension = op.getIotaDimension(); auto maxDimSize = outputType.getDimSize(dimension); @@ -154,7 +155,7 @@ class ConvertIotaOp : public OpRewritePattern { auto elementType = outputType.getElementType(); int bitwidth; - auto complexTy = elementType.dyn_cast(); + auto complexTy = mlir::dyn_cast(elementType); Type intOrFloatTy = elementType; if (complexTy) intOrFloatTy = complexTy.getElementType(); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc index 432c7733544772..3fa7e61dd7096b 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/legalize_torch_index_select_to_gather/legalize_torch_index_select_to_gather.cc @@ -30,6 +30,7 @@ limitations under the License. #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -68,7 +69,7 @@ struct TorchIndexSelectIsGather : public OpRewritePattern { int64_t indexVectorDim = index.getType().getRank(); auto indexTy = index.getType(); - auto indexElementTy = indexTy.getElementType().dyn_cast(); + auto indexElementTy = mlir::dyn_cast(indexTy.getElementType()); if (!indexElementTy) { return rewriter.notifyMatchFailure( op, "index must have integer element type"); diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc index ee672671a68592..f4141c5dbcbaa1 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/lower_general_dot/lower_general_dot.cc @@ -35,6 +35,7 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { @@ -76,9 +77,8 @@ Value transposeReshape(Value arg, Location loc, rewriter.getIntegerType(64)); auto transposePermutationAttr = - DenseIntElementsAttr::get(transposePermutationType, - llvm::ArrayRef(transposePermutation)) - .cast(); + mlir::cast(DenseIntElementsAttr::get( + transposePermutationType, llvm::ArrayRef(transposePermutation))); // Compute the resulting shape. llvm::SmallVector transposedShape; @@ -144,7 +144,7 @@ Value transposeReshape(Value arg, Location loc, Value processDotArg(Value arg, Location loc, ArrayRef contractDimsAttr, bool outerDimsFirst, PatternRewriter &rewriter) { - auto shape = arg.getType().cast().getShape(); + auto shape = mlir::cast(arg.getType()).getShape(); llvm::SmallVector isOuterDim; isOuterDim.resize(shape.size(), true); @@ -197,7 +197,7 @@ struct GeneralDotConvert : public OpRewritePattern { auto opPrecisionConfig = op.getPrecisionConfig(); if (opPrecisionConfig.has_value()) precisionConfig = *opPrecisionConfig; - auto resultTy = op.getType().cast(); + auto resultTy = mlir::cast(op.getType()); auto lhsContractingDims = dotNumbers.getLhsContractingDimensions(); auto rhsContractingDims = dotNumbers.getRhsContractingDimensions(); @@ -205,8 +205,8 @@ struct GeneralDotConvert : public OpRewritePattern { auto lhs = op.getLhs(); auto rhs = op.getRhs(); - RankedTensorType lhsTy = lhs.getType().dyn_cast(); - RankedTensorType rhsTy = rhs.getType().dyn_cast(); + RankedTensorType lhsTy = mlir::dyn_cast(lhs.getType()); + RankedTensorType rhsTy = mlir::dyn_cast(rhs.getType()); if (!lhsTy || !rhsTy) return failure(); // The MHLO dot operator directly supports a vector dot product @@ -264,8 +264,8 @@ struct GeneralDotConvert : public OpRewritePattern { rhs, loc, rhsContractingDims, /*outerDimsFirst=*/false, rewriter)); // Accept only static shaped types. - auto lhsShapeType = lhs.getType().dyn_cast_or_null(); - auto rhsShapeType = rhs.getType().dyn_cast_or_null(); + auto lhsShapeType = mlir::dyn_cast_or_null(lhs.getType()); + auto rhsShapeType = mlir::dyn_cast_or_null(rhs.getType()); if (!lhsShapeType || !rhsShapeType) return failure(); // Generate new dot operator on expanded types. @@ -293,7 +293,7 @@ struct GeneralDotConvert : public OpRewritePattern { auto getDynamicDims = [&](Value arg, llvm::ArrayRef contractingDims) { - RankedTensorType ty = arg.getType().cast(); + RankedTensorType ty = mlir::cast(arg.getType()); int index = 0; for (auto contractingDim : contractingDims) { for (; index < contractingDim; index++) { diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h b/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h index 03eec3cd8312a6..d7d94a7659048c 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LLVM.h" namespace mlir { namespace mhlo { @@ -231,13 +232,13 @@ struct MapMhloOpToScalarOpImpl { }; struct IsAnyIntegerType { - bool operator()(Type t) { return t.isa(); } + bool operator()(Type t) { return mlir::isa(t); } }; struct IsSignedIntegerType { bool operator()(Type t) { // Pretend that signless is signed. This will change eventually. - return t.isa() && !t.isUnsignedInteger() && + return mlir::isa(t) && !t.isUnsignedInteger() && !t.isSignlessInteger(1); } }; @@ -249,11 +250,11 @@ struct IsUnsignedIntegerType { }; struct IsFloatType { - bool operator()(Type t) { return t.isa(); } + bool operator()(Type t) { return mlir::isa(t); } }; struct IsComplexType { - bool operator()(Type t) { return t.isa(); } + bool operator()(Type t) { return mlir::isa(t); } }; template