NFC: Use the free function variants for dyn_cast/cast/isa/....
The member functions in `Type/Attribute/Value/Location/AffineExpr` are deprecated and will go away.

PiperOrigin-RevId: 628090109
chsigg authored and tensorflower-gardener committed Apr 25, 2024
1 parent fb8e059 commit 70a5f10
Showing 522 changed files with 4,833 additions and 4,274 deletions.
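The pattern repeated across the changed files is the same mechanical rewrite: casts spelled as member functions on `Type`, `Attribute`, `Value`, and friends become calls to the free functions declared in `mlir/Support/LLVM.h`, which is also why the hunks below add `@llvm-project//mlir:Support` BUILD dependencies and `#include "mlir/Support/LLVM.h"` lines alongside the cast changes. A minimal sketch of the migration, not taken from the diff (the helper `GetElementBitWidthIfRanked` is hypothetical, for illustration only, and assumes an MLIR build environment):

// Sketch of the cast-API migration this commit applies.
#include <cstdint>
#include <optional>

#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Value.h"         // from @llvm-project
#include "mlir/Support/LLVM.h"     // from @llvm-project

std::optional<int64_t> GetElementBitWidthIfRanked(mlir::Value value) {
  // Before (deprecated member-function form):
  //   auto ranked = value.getType().dyn_cast_or_null<mlir::RankedTensorType>();
  // After (free-function form used throughout this commit):
  auto ranked = mlir::dyn_cast_or_null<mlir::RankedTensorType>(value.getType());
  if (!ranked) return std::nullopt;
  return ranked.getElementTypeBitWidth();
}

The same substitution applies to `isa`, `cast`, and `cast_or_null`, for example `mlir::isa<NoneType>(v.getType())` in place of `v.getType().isa<NoneType>()`.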
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/BUILD
@@ -37,6 +37,7 @@ cc_library(
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)

4 changes: 4 additions & 0 deletions tensorflow/compiler/mlir/lite/BUILD
@@ -310,6 +310,7 @@ cc_library(
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)

@@ -911,6 +912,7 @@ cc_library(
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)

@@ -1037,6 +1039,7 @@ cc_library(
"@llvm-project//llvm:Analysis",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@local_tsl//tsl/platform:status",
"@local_xla//xla:statusor",
@@ -1210,6 +1213,7 @@ cc_library(
"//tensorflow/core/platform:errors",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@local_xla//xla:statusor",
],
)
@@ -44,7 +44,7 @@ namespace common {

bool IsConstantOrNone(Operation* op) {
return (op->getNumResults() == 1 &&
-          op->getResult(0).getType().isa<NoneType>()) ||
+          mlir::isa<NoneType>(op->getResult(0).getType())) ||
matchPattern(op, m_Constant()) || isa<QConstOp>(op);
}

2 changes: 2 additions & 0 deletions tensorflow/compiler/mlir/lite/experimental/tac/BUILD
@@ -42,6 +42,7 @@ cc_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)

@@ -88,6 +89,7 @@ cc_library(
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SideEffectInterfaces",
"@llvm-project//mlir:Support",
],
)

@@ -25,6 +25,7 @@ limitations under the License.
#include "mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/utils/utils.h"
@@ -52,7 +53,7 @@ bool NotTFLQuantDequantizeOp(Operation* op);

// Returns true if it is a shaped type of f32 elements.
inline bool IsF32ShapedType(Type t) {
-  if (auto shaped_type = t.dyn_cast_or_null<ShapedType>()) {
+  if (auto shaped_type = mlir::dyn_cast_or_null<ShapedType>(t)) {
return shaped_type.getElementType().isF32();
}
return false;
@@ -29,6 +29,7 @@
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Region.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/hardwares/target_hardware.h"
#include "tensorflow/compiler/mlir/lite/experimental/tac/runtime_metadata_generated.h"
@@ -82,8 +83,7 @@ std::optional<std::vector<float>> GetPerDeviceCosts(
for (const auto& kv : hardware_map) {
auto cost_attr = device_costs_attr.getNamed(kv.first);
if (!cost_attr.has_value()) return std::nullopt;
-    float cost = cost_attr->getValue()
-                     .dyn_cast_or_null<mlir::FloatAttr>()
+    float cost = mlir::dyn_cast_or_null<mlir::FloatAttr>(cost_attr->getValue())
.getValueAsDouble();
device_costs[kv.second] = cost;
}
@@ -59,13 +59,14 @@ int64_t GetTransferredTensorBytes(func::CallOp from_graph,
for (auto input : to_graph.getOperands()) {
Operation* input_op = input.getDefiningOp();
if (input_op && input_op == from_graph.getOperation()) {
-      auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
+      auto input_type =
+          mlir::dyn_cast_or_null<RankedTensorType>(input.getType());
if (input_type == nullptr || !input_type.hasStaticShape()) continue;
// Quantized type does not support getSizeInBits.
if (IsQUI8Type(input_type) || IsQI8Type(input_type)) {
total_size_transferred += input_type.getNumElements() * 8;
} else {
-        auto s_type = input_type.cast<ShapedType>();
+        auto s_type = mlir::cast<ShapedType>(input_type);
total_size_transferred +=
s_type.getNumElements() * s_type.getElementTypeBitWidth();
}
@@ -81,7 +82,8 @@ int64_t GetTransferredElementCount(func::CallOp from_graph,
for (auto input : to_graph.getOperands()) {
Operation* input_op = input.getDefiningOp();
if (input_op && input_op == from_graph.getOperation()) {
-      auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
+      auto input_type =
+          mlir::dyn_cast_or_null<RankedTensorType>(input.getType());
if (input_type == nullptr || !input_type.hasStaticShape()) continue;
total_element_count += input_type.getNumElements();
}
@@ -156,13 +156,13 @@ struct FoldQuantizedI32ToFloat : public OpRewritePattern<TFL::DequantizeOp> {
if (!IsQI32Type(input_dequant.getType())) return failure();

auto output_type =
-        dequant_op.getOutput().getType().dyn_cast_or_null<ShapedType>();
+        mlir::dyn_cast_or_null<ShapedType>(dequant_op.getOutput().getType());
if (!output_type || !output_type.getElementType().isF32()) return failure();

-    auto input_type = input_dequant.getType().dyn_cast<ShapedType>();
+    auto input_type = mlir::dyn_cast<ShapedType>(input_dequant.getType());
// TODO(renjieliu): support UniformQuantizedPerAxisType.
-    auto q_type = input_type.getElementType()
-                      .dyn_cast_or_null<quant::UniformQuantizedType>();
+    auto q_type = mlir::dyn_cast_or_null<quant::UniformQuantizedType>(
+        input_type.getElementType());
if (!q_type) return failure();

const float scale = q_type.getScale();
@@ -183,9 +183,9 @@ struct FoldQuantizedI32ToFloat : public OpRewritePattern<TFL::DequantizeOp> {
};

auto dequant_values =
-        input_values.cast<DenseIntOrFPElementsAttr>().mapValues(
-            FloatType::getF32(rewriter.getContext()),
-            llvm::function_ref<DequantizeFuncType>(dequantize_func));
+        mlir::cast<DenseIntOrFPElementsAttr>(input_values)
+            .mapValues(FloatType::getF32(rewriter.getContext()),
+                       llvm::function_ref<DequantizeFuncType>(dequantize_func));
rewriter.replaceOpWithNewOp<TFL::ConstOp>(dequant_op, dequant_op.getType(),
dequant_values);

@@ -96,11 +96,11 @@ LogicalResult EnsureBias(Operation* op, int bias_idx,
PatternRewriter& rewriter) {
auto bias = op->getOperand(bias_idx);

-  if (!bias.getType().isa<NoneType>()) return failure();
+  if (!mlir::isa<NoneType>(bias.getType())) return failure();

// Proceed to create a zero bias.
auto output = op->getResult(0);
-  auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
+  auto output_type = mlir::dyn_cast_or_null<RankedTensorType>(output.getType());
if (!output_type) return failure();

// bias should be a vector sized of the last output dim.
@@ -163,7 +163,7 @@ SmallVector<Value, 4> SliceOutputs(Operation* split_op, Value input,
SmallVector<int32_t, 4> slice_size;
auto current_output = split_op->getResult(i);
auto current_output_type =
-        current_output.getType().cast<RankedTensorType>();
+        mlir::cast<RankedTensorType>(current_output.getType());
for (int d = 0; d < input_type.getRank(); ++d) {
if (d == split_dim) {
// Split dimension.
@@ -208,7 +208,7 @@ LogicalResult LowerPackIntoConcatReshape::matchAndRewrite(
TFL::PackOp pack_op, PatternRewriter& rewriter) const {
// Pack op should have same shape type.
SmallVector<Value, 5> pack_inputs(pack_op.getValues());
-  auto input_type = pack_inputs[0].getType().dyn_cast<RankedTensorType>();
+  auto input_type = mlir::dyn_cast<RankedTensorType>(pack_inputs[0].getType());
if (!input_type) return failure();

// Figure out output shapes.
@@ -266,8 +266,8 @@ LogicalResult SquaredDifference::matchAndRewrite(
TFL::SquaredDifferenceOp squared_diff_op, PatternRewriter& rewriter) const {
auto x = squared_diff_op.getLhs();
auto y = squared_diff_op.getRhs();
-  auto x_type = x.getType().dyn_cast<RankedTensorType>();
-  auto y_type = y.getType().dyn_cast<RankedTensorType>();
+  auto x_type = mlir::dyn_cast<RankedTensorType>(x.getType());
+  auto y_type = mlir::dyn_cast<RankedTensorType>(y.getType());
if (!x_type || !y_type) return failure();
if (x_type.getShape() != y_type.getShape()) return failure();

@@ -290,16 +290,16 @@ LogicalResult UnrollSplit::matchAndRewrite(TFL::SplitOp split_op,
PatternRewriter& rewriter) const {
auto num_splits = split_op.getNumSplits();
auto input = split_op.getValue();
-  auto input_type = input.getType().dyn_cast<RankedTensorType>();
+  auto input_type = mlir::dyn_cast<RankedTensorType>(input.getType());
if (input_type == nullptr || !input_type.hasStaticShape()) return failure();

for (auto result : split_op.getResults()) {
-    auto result_type = result.getType().dyn_cast<RankedTensorType>();
+    auto result_type = mlir::dyn_cast<RankedTensorType>(result.getType());
if (result_type == nullptr) return failure();
}

auto output = split_op.getResult(0);
-  auto output_type = output.getType().cast<RankedTensorType>();
+  auto output_type = mlir::cast<RankedTensorType>(output.getType());

// TODO(renjieliu): change to use split_dim when we raise the constants
// as well.
@@ -330,11 +330,11 @@ LogicalResult UnrollSplitV::matchAndRewrite(TFL::SplitVOp splitv_op,
return failure();

auto input = splitv_op.getValue();
-  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
+  auto input_type = mlir::dyn_cast_or_null<RankedTensorType>(input.getType());
if (!input_type || !input_type.hasRank()) return failure();

for (auto result : splitv_op.getResults()) {
-    auto result_type = result.getType().dyn_cast<RankedTensorType>();
+    auto result_type = mlir::dyn_cast<RankedTensorType>(result.getType());
if (result_type == nullptr) return failure();
}

@@ -371,20 +371,21 @@ LogicalResult PadSlice::matchAndRewrite(TFL::SliceOp slice_op,
// We have to know the shape of the input, as well as the begin/size.
// also, begin and size have to be constants.
auto input = slice_op.getInput();
-  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
+  auto input_type = mlir::dyn_cast_or_null<RankedTensorType>(input.getType());
if (!input_type || !input_type.hasStaticShape()) return failure();

if (input_type.getRank() >= 4) return failure();

auto begin = slice_op.getBegin();
-  auto begin_type = begin.getType().dyn_cast_or_null<RankedTensorType>();
+  auto begin_type = mlir::dyn_cast_or_null<RankedTensorType>(begin.getType());
if (!begin_type || !begin_type.hasStaticShape()) return failure();

auto size = slice_op.getSize();
-  auto size_type = size.getType().dyn_cast_or_null<RankedTensorType>();
+  auto size_type = mlir::dyn_cast_or_null<RankedTensorType>(size.getType());
if (!size_type || !size_type.hasStaticShape()) return failure();

-  auto output_type = slice_op.getType().dyn_cast_or_null<RankedTensorType>();
+  auto output_type =
+      mlir::dyn_cast_or_null<RankedTensorType>(slice_op.getType());
if (!output_type || !output_type.hasStaticShape()) return failure();

// Pad 0s in front of the begin.
@@ -472,17 +473,17 @@ LogicalResult FullyConnectedToConv::matchAndRewrite(
TFL::FullyConnectedOp fc_op, PatternRewriter& rewriter) const {
// We have to know the shape of the input.
auto input = fc_op.getInput();
-  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
+  auto input_type = mlir::dyn_cast_or_null<RankedTensorType>(input.getType());
if (!input_type || !input_type.hasStaticShape()) return failure();

// We have to know the shape of the weight.
auto weight = fc_op.getFilter();
-  auto weight_type = weight.getType().dyn_cast_or_null<RankedTensorType>();
+  auto weight_type = mlir::dyn_cast_or_null<RankedTensorType>(weight.getType());
if (!weight_type || !weight_type.hasStaticShape()) return failure();

// We have to know the shape of the output as well.
auto output = fc_op.getResult(0);
-  auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
+  auto output_type = mlir::dyn_cast_or_null<RankedTensorType>(output.getType());
if (!output_type || !output_type.hasStaticShape()) return failure();

// Insert a reshape after the input.
@@ -532,13 +533,14 @@ LogicalResult PadConcat::matchAndRewrite(TFL::ConcatenationOp concat_op,
PatternRewriter& rewriter) const {
int rank = -1;
for (auto input : concat_op.getValues()) {
-    auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
+    auto input_type = mlir::dyn_cast_or_null<RankedTensorType>(input.getType());
if (!input_type || !input_type.hasStaticShape()) return failure();

rank = input_type.getRank();
}

-  auto output_type = concat_op.getType().dyn_cast_or_null<RankedTensorType>();
+  auto output_type =
+      mlir::dyn_cast_or_null<RankedTensorType>(concat_op.getType());
if (!output_type || !output_type.hasStaticShape()) return failure();

if (rank >= 4) return failure();
@@ -547,7 +549,7 @@ LogicalResult PadConcat::matchAndRewrite(TFL::ConcatenationOp concat_op,
// We will insert a reshape op after every input.
SmallVector<Value, 4> reshape_ops;
for (auto input : concat_op.getValues()) {
-    auto input_type = input.getType().cast<RankedTensorType>();
+    auto input_type = mlir::cast<RankedTensorType>(input.getType());
// Get the new shape.
SmallVector<int64_t, 4> new_shape;
for (int i = 0; i < 4 - rank; ++i) {
@@ -603,7 +605,7 @@ LogicalResult PadConcat::matchAndRewrite(TFL::ConcatenationOp concat_op,
LogicalResult ReduceMeanToAvgPool::matchAndRewrite(
TFL::MeanOp mean_op, PatternRewriter& rewriter) const {
auto input = mean_op.getInput();
-  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
+  auto input_type = mlir::dyn_cast_or_null<RankedTensorType>(input.getType());
// Only 4d is supported here.
if (!input_type || input_type.getRank() != 4) return failure();

@@ -619,7 +621,7 @@ LogicalResult ReduceMeanToAvgPool::matchAndRewrite(
}

auto output = mean_op.getOutput();
-  auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
+  auto output_type = mlir::dyn_cast_or_null<RankedTensorType>(output.getType());
if (!output_type) return failure();

auto input_quantized_type =
@@ -669,7 +671,7 @@ LogicalResult ReduceMeanToAvgPool::matchAndRewrite(
LogicalResult InsertRequantForReduceMean::matchAndRewrite(
TFL::MeanOp mean_op, PatternRewriter& rewriter) const {
auto input = mean_op.getInput();
-  auto input_type = input.getType().dyn_cast_or_null<ShapedType>();
+  auto input_type = mlir::dyn_cast_or_null<ShapedType>(input.getType());
if (!input_type) return failure();

// Only need to do this for quantized input.
Expand All @@ -678,7 +680,7 @@ LogicalResult InsertRequantForReduceMean::matchAndRewrite(
if (!input_quantized_type) return failure();

auto output = mean_op.getOutput();
-  auto output_type = output.getType().dyn_cast_or_null<ShapedType>();
+  auto output_type = mlir::dyn_cast_or_null<ShapedType>(output.getType());
if (!output_type) return failure();
auto output_quantized_type =
quant::QuantizedType::getQuantizedElementType(output_type);
@@ -107,11 +107,12 @@ bool IsConstOrQConstInt(Operation* op) {

if (auto arith_const_op = dyn_cast_or_null<arith::ConstantOp>(op)) {
// arith ConstOp path.
-    auto type = arith_const_op.getType().cast<ShapedType>().getElementType();
+    auto type =
+        mlir::cast<ShapedType>(arith_const_op.getType()).getElementType();
if (!type.isInteger(32) && !type.isInteger(64)) return false;
} else if (auto const_op = dyn_cast_or_null<TFL::ConstOp>(op)) {
// ConstOp path.
-    auto type = const_op.getType().cast<ShapedType>().getElementType();
+    auto type = mlir::cast<ShapedType>(const_op.getType()).getElementType();
if (!type.isInteger(32) && !type.isInteger(64)) return false;
} else {
// QConstOp path.
@@ -113,18 +113,11 @@ void AddAttrs(OpsAdded& ops_added, OpBuilder& builder, int func_count) {
added_func_op->setAttr(kInterfaceNameAttr, interface_name);
added_call_op->setAttr(kInterfaceNameAttr, interface_name);

-  StringAttr device = added_func_op->getRegion(0)
-                          .getBlocks()
-                          .front()
-                          .front()
-                          .getAttr(kDevice)
-                          .cast<StringAttr>();
-  StringAttr inference_type = added_func_op->getRegion(0)
-                                  .getBlocks()
-                                  .front()
-                                  .front()
-                                  .getAttr(kInferenceType)
-                                  .cast<StringAttr>();
+  StringAttr device = mlir::cast<StringAttr>(
+      added_func_op->getRegion(0).getBlocks().front().front().getAttr(kDevice));
+  StringAttr inference_type = mlir::cast<StringAttr>(
+      added_func_op->getRegion(0).getBlocks().front().front().getAttr(
+          kInferenceType));
added_call_op->setAttr(kDevice, device);
added_call_op->setAttr(kInferenceType, inference_type);
added_func_op->setAttr(kDevice, device);
@@ -110,7 +110,7 @@ void ApplyTacFilter(ModuleOp module, const TacFilter& tac_filter,

llvm::Regex op_regex(tac_filter.op_filter().op_name_pattern());
module.walk([&](Operation* op) {
-    auto named_loc = op->getLoc().dyn_cast<NameLoc>();
+    auto named_loc = mlir::dyn_cast<NameLoc>(op->getLoc());
if (!named_loc) {
return;
}
