Skip to content

Commit

Permalink
PR #55888: [mhlo] Clean up C-API references to CPP
Browse files Browse the repository at this point in the history
Imported from GitHub PR tensorflow/tensorflow#55888

remove std::string referencres in ```tensorflow/compiler/mlir/hlo/include/mlir-hlo-c/Attributes.h``` file.
Copybara import of the project:

--
79e4773aee3a5c09814b3d349f92324bc7898f49 by lipracer <lipracer@gmail.com>:

[mhlo] Clean up C-API references to CPP

--
7fad8c409cda3f50284d1f487cfa2da496798597 by lipracer <lipracer@gmail.com>:

[mhlo] replace const char* with MlirStringRef

Merging this change closes #55888

PiperOrigin-RevId: 447594137
  • Loading branch information
lipracer authored and TensorFlow MLIR Team committed May 9, 2022
1 parent 53711b0 commit 9d95d60
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 91 deletions.
54 changes: 26 additions & 28 deletions include/mlir-hlo-c/Attributes.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ limitations under the License.

#include <sys/types.h>

#include <string>

#include "mlir-c/IR.h"
#include "mlir-c/Support.h"

Expand Down Expand Up @@ -175,105 +173,105 @@ mlirMhloConvDimensionNumbersGetOutputSpatialDimensionsElem(MlirAttribute attr,
//
// Creates a new ComparisonDirection attribute with the given
// 'direction' string parameter.
MLIR_CAPI_EXPORTED MlirAttribute mlirMhloComparisonDirectionAttrGet(
MlirContext ctx, const std::string &direction);
MLIR_CAPI_EXPORTED MlirAttribute
mlirMhloComparisonDirectionAttrGet(MlirContext ctx, MlirStringRef direction);

// Returns true if the given attribute is a ComparisonDirection attribute.
MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsAComparisonDirectionAttr(
MlirAttribute attr);

// Returns the direction string associated with ComparisonDirection attribute.
MLIR_CAPI_EXPORTED std::string mlirMhloComparisonDirectionAttrGetDirection(
MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirStringRef
mlirMhloComparisonDirectionAttrGetDirection(MlirAttribute attr);

//
// ComparisonTypeAttr.
//
// Creates a new ComparisonType attribute with the given 'type' string
// parameter.
MLIR_CAPI_EXPORTED MlirAttribute
mlirMhloComparisonTypeAttrGet(MlirContext ctx, const std::string &type);
mlirMhloComparisonTypeAttrGet(MlirContext ctx, MlirStringRef type);

// Returns true if the given attribute is a ComparisonType attribute.
MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsAComparisonTypeAttr(
MlirAttribute attr);

// Returns the type string associated with ComparisonType attribute.
MLIR_CAPI_EXPORTED std::string mlirMhloComparisonTypeAttrGetType(
MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirStringRef
mlirMhloComparisonTypeAttrGetType(MlirAttribute attr);

//
// PrecisionAttr.
//
// Creates a new Precision attribute with the given 'type' string
// parameter.
MLIR_CAPI_EXPORTED MlirAttribute
mlirMhloPrecisionAttrGet(MlirContext ctx, const std::string &type);
MLIR_CAPI_EXPORTED MlirAttribute mlirMhloPrecisionAttrGet(MlirContext ctx,
MlirStringRef type);

// Returns true if the given attribute is a Precision attribute.
MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsAPrecisionAttr(MlirAttribute attr);

// Returns the type string associated with Precision attribute.
MLIR_CAPI_EXPORTED std::string mlirMhloPrecisionAttrGetPrecision(
MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirStringRef
mlirMhloPrecisionAttrGetPrecision(MlirAttribute attr);

//
// FftTypeAttr.
//
// Creates a new FftType attribute with the given 'type' string parameter.
MLIR_CAPI_EXPORTED MlirAttribute
mlirMhloFftTypeAttrGet(MlirContext ctx, const std::string &type);
MLIR_CAPI_EXPORTED MlirAttribute mlirMhloFftTypeAttrGet(MlirContext ctx,
MlirStringRef type);

// Returns true if the given attribute is a FftType attribute.
MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsAFftTypeAttr(MlirAttribute attr);

// Returns the type string associated with FftType attribute.
MLIR_CAPI_EXPORTED std::string mlirMhloFftTypeAttrGetFftType(
MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirStringRef
mlirMhloFftTypeAttrGetFftType(MlirAttribute attr);

//
// DequantizeModeAttr.
//
// Creates a new DequantizeMode attribute with the given 'mode' string
// parameter.
MLIR_CAPI_EXPORTED MlirAttribute
mlirMhloDequantizeModeAttrGet(MlirContext ctx, const std::string &mode);
mlirMhloDequantizeModeAttrGet(MlirContext ctx, MlirStringRef mode);

// Returns true if the given attribute is a DequantizeMode attribute.
MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsADequantizeModeAttr(
MlirAttribute attr);

// Returns the mode string associated with DequantizeMode attribute.
MLIR_CAPI_EXPORTED std::string mlirMhloDequantizeModeAttrGetDequantizeMode(
MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirStringRef
mlirMhloDequantizeModeAttrGetDequantizeMode(MlirAttribute attr);

//
// TransposeAttr.
//
// Creates a new Transpose attribute with the given 'type' string parameter.
MLIR_CAPI_EXPORTED MlirAttribute
mlirMhloTransposeAttrGet(MlirContext ctx, const std::string &type);
MLIR_CAPI_EXPORTED MlirAttribute mlirMhloTransposeAttrGet(MlirContext ctx,
MlirStringRef type);

// Returns true if the given attribute is a Transpose attribute.
MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsATransposeAttr(MlirAttribute attr);

// Returns the type string associated with Transpose attribute.
MLIR_CAPI_EXPORTED std::string mlirMhloTransposeAttrGetTranspose(
MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirStringRef
mlirMhloTransposeAttrGetTranspose(MlirAttribute attr);

//
// FusionKindAttr.
//
// Creates a new FusionKind attribute with the given 'kind' string parameter.
MLIR_CAPI_EXPORTED MlirAttribute
mlirMhloFusionKindAttrGet(MlirContext ctx, const std::string &kind);
MLIR_CAPI_EXPORTED MlirAttribute mlirMhloFusionKindAttrGet(MlirContext ctx,
MlirStringRef kind);

// Returns true if the given attribute is a FusionKind attribute.
MLIR_CAPI_EXPORTED bool mlirMhloAttributeIsAFusionKindAttr(MlirAttribute attr);

// Returns the fusion-kind string associated with FusionKind attribute.
MLIR_CAPI_EXPORTED std::string mlirMhloFusionKindAttrGetFusionKind(
MlirAttribute attr);
MLIR_CAPI_EXPORTED MlirStringRef
mlirMhloFusionKindAttrGetFusionKind(MlirAttribute attr);

//
// RngAlgorithmAttr.
Expand Down
84 changes: 35 additions & 49 deletions lib/CAPI/Attributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ limitations under the License.

#include "mlir-hlo-c/Attributes.h"

#include <string>

#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base_attrs.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
Expand Down Expand Up @@ -355,9 +353,9 @@ int64_t mlirMhloConvDimensionNumbersGetOutputSpatialDimensionsElem(
// ComparisonDirectionAttr.
//
MlirAttribute mlirMhloComparisonDirectionAttrGet(MlirContext ctx,
const std::string &direction) {
MlirStringRef direction) {
llvm::Optional<mlir::mhlo::ComparisonDirection> compare_direction =
mlir::mhlo::symbolizeComparisonDirection(direction);
mlir::mhlo::symbolizeComparisonDirection(unwrap(direction));
if (!compare_direction)
llvm_unreachable("Invalid comparison-direction specified.");
return wrap(mlir::mhlo::ComparisonDirectionAttr::get(
Expand All @@ -368,22 +366,19 @@ bool mlirMhloAttributeIsAComparisonDirectionAttr(MlirAttribute attr) {
return unwrap(attr).isa<mlir::mhlo::ComparisonDirectionAttr>();
}

std::string mlirMhloComparisonDirectionAttrGetDirection(MlirAttribute attr) {
return mlir::mhlo::stringifyComparisonDirection(
unwrap(attr)
.cast<mlir::mhlo::ComparisonDirectionAttr>()
.getValue())
.str();
MlirStringRef mlirMhloComparisonDirectionAttrGetDirection(MlirAttribute attr) {
return wrap(mlir::mhlo::stringifyComparisonDirection(
unwrap(attr).cast<mlir::mhlo::ComparisonDirectionAttr>().getValue()));
}

//
// ComparisonTypeAttr.
//

MlirAttribute mlirMhloComparisonTypeAttrGet(MlirContext ctx,
const std::string &type) {
MlirStringRef type) {
llvm::Optional<mlir::mhlo::ComparisonType> compare_type =
mlir::mhlo::symbolizeComparisonType(type);
mlir::mhlo::symbolizeComparisonType(unwrap(type));
if (!compare_type) llvm_unreachable("Invalid comparison-type specified.");
return wrap(mlir::mhlo::ComparisonTypeAttr::get(unwrap(ctx),
compare_type.getValue()));
Expand All @@ -393,20 +388,18 @@ bool mlirMhloAttributeIsAComparisonTypeAttr(MlirAttribute attr) {
return unwrap(attr).isa<mlir::mhlo::ComparisonTypeAttr>();
}

std::string mlirMhloComparisonTypeAttrGetType(MlirAttribute attr) {
return mlir::mhlo::stringifyComparisonType(
unwrap(attr).cast<mlir::mhlo::ComparisonTypeAttr>().getValue())
.str();
MlirStringRef mlirMhloComparisonTypeAttrGetType(MlirAttribute attr) {
return wrap(mlir::mhlo::stringifyComparisonType(
unwrap(attr).cast<mlir::mhlo::ComparisonTypeAttr>().getValue()));
}

//
// PrecisionAttr.
//

MlirAttribute mlirMhloPrecisionAttrGet(MlirContext ctx,
const std::string &type) {
MlirAttribute mlirMhloPrecisionAttrGet(MlirContext ctx, MlirStringRef type) {
llvm::Optional<mlir::mhlo::Precision> precision_type =
mlir::mhlo::symbolizePrecision(type);
mlir::mhlo::symbolizePrecision(unwrap(type));
if (!precision_type) llvm_unreachable("Invalid precision-type specified.");
return wrap(
mlir::mhlo::PrecisionAttr::get(unwrap(ctx), precision_type.getValue()));
Expand All @@ -416,19 +409,18 @@ bool mlirMhloAttributeIsAPrecisionAttr(MlirAttribute attr) {
return unwrap(attr).isa<mlir::mhlo::PrecisionAttr>();
}

std::string mlirMhloPrecisionAttrGetPrecision(MlirAttribute attr) {
return mlir::mhlo::stringifyPrecision(
unwrap(attr).cast<mlir::mhlo::PrecisionAttr>().getValue())
.str();
MlirStringRef mlirMhloPrecisionAttrGetPrecision(MlirAttribute attr) {
return wrap(mlir::mhlo::stringifyPrecision(
unwrap(attr).cast<mlir::mhlo::PrecisionAttr>().getValue()));
}

//
// FftTypeAttr.
//

MlirAttribute mlirMhloFftTypeAttrGet(MlirContext ctx, const std::string &type) {
MlirAttribute mlirMhloFftTypeAttrGet(MlirContext ctx, MlirStringRef type) {
llvm::Optional<mlir::mhlo::FftType> fft_type =
mlir::mhlo::symbolizeFftType(type);
mlir::mhlo::symbolizeFftType(unwrap(type));
if (!fft_type) llvm_unreachable("Invalid fft-type specified.");
return wrap(mlir::mhlo::FftTypeAttr::get(unwrap(ctx), fft_type.getValue()));
}
Expand All @@ -437,20 +429,19 @@ bool mlirMhloAttributeIsAFftTypeAttr(MlirAttribute attr) {
return unwrap(attr).isa<mlir::mhlo::FftTypeAttr>();
}

std::string mlirMhloFftTypeAttrGetFftType(MlirAttribute attr) {
return mlir::mhlo::stringifyFftType(
unwrap(attr).cast<mlir::mhlo::FftTypeAttr>().getValue())
.str();
MlirStringRef mlirMhloFftTypeAttrGetFftType(MlirAttribute attr) {
return wrap(mlir::mhlo::stringifyFftType(
unwrap(attr).cast<mlir::mhlo::FftTypeAttr>().getValue()));
}

//
// DequantizeModeAttr.
//

MlirAttribute mlirMhloDequantizeModeAttrGet(MlirContext ctx,
const std::string &mode) {
MlirStringRef mode) {
llvm::Optional<mlir::mhlo::DequantizeMode> dequantize_mode =
mlir::mhlo::symbolizeDequantizeMode(mode);
mlir::mhlo::symbolizeDequantizeMode(unwrap(mode));
if (!dequantize_mode) llvm_unreachable("Invalid dequantize-mode specified.");
return wrap(mlir::mhlo::DequantizeModeAttr::get(unwrap(ctx),
dequantize_mode.getValue()));
Expand All @@ -460,20 +451,18 @@ bool mlirMhloAttributeIsADequantizeModeAttr(MlirAttribute attr) {
return unwrap(attr).isa<mlir::mhlo::DequantizeModeAttr>();
}

std::string mlirMhloDequantizeModeAttrGetDequantizeMode(MlirAttribute attr) {
return mlir::mhlo::stringifyDequantizeMode(
unwrap(attr).cast<mlir::mhlo::DequantizeModeAttr>().getValue())
.str();
MlirStringRef mlirMhloDequantizeModeAttrGetDequantizeMode(MlirAttribute attr) {
return wrap(mlir::mhlo::stringifyDequantizeMode(
unwrap(attr).cast<mlir::mhlo::DequantizeModeAttr>().getValue()));
}

//
// TransposeAttr.
//

MlirAttribute mlirMhloTransposeAttrGet(MlirContext ctx,
const std::string &type) {
MlirAttribute mlirMhloTransposeAttrGet(MlirContext ctx, MlirStringRef type) {
llvm::Optional<mlir::mhlo::Transpose> transpose_type =
mlir::mhlo::symbolizeTranspose(type);
mlir::mhlo::symbolizeTranspose(unwrap(type));
if (!transpose_type) llvm_unreachable("Invalid transpose-type specified.");
return wrap(
mlir::mhlo::TransposeAttr::get(unwrap(ctx), transpose_type.getValue()));
Expand All @@ -483,20 +472,18 @@ bool mlirMhloAttributeIsATransposeAttr(MlirAttribute attr) {
return unwrap(attr).isa<mlir::mhlo::TransposeAttr>();
}

std::string mlirMhloTransposeAttrGetTranspose(MlirAttribute attr) {
return mlir::mhlo::stringifyTranspose(
unwrap(attr).cast<mlir::mhlo::TransposeAttr>().getValue())
.str();
MlirStringRef mlirMhloTransposeAttrGetTranspose(MlirAttribute attr) {
return wrap(mlir::mhlo::stringifyTranspose(
unwrap(attr).cast<mlir::mhlo::TransposeAttr>().getValue()));
}

//
// FusionKindAttr.
//

MlirAttribute mlirMhloFusionKindAttrGet(MlirContext ctx,
const std::string &kind) {
MlirAttribute mlirMhloFusionKindAttrGet(MlirContext ctx, MlirStringRef kind) {
llvm::Optional<mlir::mhlo::FusionKind> fusion_kind =
mlir::mhlo::symbolizeFusionKind(kind);
mlir::mhlo::symbolizeFusionKind(unwrap(kind));
if (!fusion_kind) llvm_unreachable("Invalid fusion-kind specified.");
return wrap(
mlir::mhlo::FusionKindAttr::get(unwrap(ctx), fusion_kind.getValue()));
Expand All @@ -506,10 +493,9 @@ bool mlirMhloAttributeIsAFusionKindAttr(MlirAttribute attr) {
return unwrap(attr).isa<mlir::mhlo::FusionKindAttr>();
}

std::string mlirMhloFusionKindAttrGetFusionKind(MlirAttribute attr) {
return mlir::mhlo::stringifyFusionKind(
unwrap(attr).cast<mlir::mhlo::FusionKindAttr>().getValue())
.str();
MlirStringRef mlirMhloFusionKindAttrGetFusionKind(MlirAttribute attr) {
return wrap(mlir::mhlo::stringifyFusionKind(
unwrap(attr).cast<mlir::mhlo::FusionKindAttr>().getValue()));
}

//
Expand Down
Loading

0 comments on commit 9d95d60

Please sign in to comment.