Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PluggableDevice] Kernel C API enhancement for retrieving attributes #44017

171 changes: 168 additions & 3 deletions tensorflow/c/kernels.cc
Expand Up @@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/stream_executor/stream.h"
#endif // !defined(IS_MOBILE_PLATFORM) && !defined(IS_SLIM_BUILD)

using tensorflow::errors::InvalidArgument;
// This file forms the basis of a stable ABI for third-party kernel
// implementations. It is crucial that changes to this file are made cautiously
// and with a focus on maintaining both source and binary compatibility.
Expand Down Expand Up @@ -87,9 +88,25 @@ void AddTypeConstraint(TF_KernelBuilder* kernel_builder, const char* attr_name,
TF_SetStatus(status, TF_OK, "");
}
#undef CASE

} // namespace
} // namespace tensorflow

namespace {
const tensorflow::AttrValue* GetAttrValue(TF_OpKernelConstruction* ctx,
const char* attr_name,
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
const tensorflow::AttrValue* attr =
::tensorflow::AttrSlice(cc_ctx->def()).Find(attr_name);
if (attr == nullptr) {
status->status = InvalidArgument("Operation '", cc_ctx->def().name(),
"' has no attr named '", attr_name, "'.");
}
return attr;
}
} // namespace

void TF_KernelBuilder_TypeConstraint(TF_KernelBuilder* kernel_builder,
const char* attr_name,
const TF_DataType type,
Expand Down Expand Up @@ -257,7 +274,81 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
cc_ctx->CtxFailure(s);
}

#define DEFINE_TF_GETATTR(func, c_type, cc_type) \
void TF_OpKernelConstruction_GetAttrSize(TF_OpKernelConstruction* ctx,
const char* attr_name,
int32_t* list_size,
int32_t* total_size,
TF_Status* status) {
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status);
if (!status->status.ok()) {
*list_size = -1;
*total_size = -1;
return;
}
switch (attr->value_case()) {
#define SINGLE_CASE(kK, attr_type, size_expr) \
case tensorflow::AttrValue::kK: \
*list_size = -1; \
*total_size = size_expr; \
break;

SINGLE_CASE(kS, TF_ATTR_STRING, attr->s().length());
SINGLE_CASE(kI, TF_ATTR_INT, -1);
SINGLE_CASE(kF, TF_ATTR_FLOAT, -1);
SINGLE_CASE(kB, TF_ATTR_BOOL, -1);
SINGLE_CASE(kType, TF_ATTR_TYPE, -1);
SINGLE_CASE(kShape, TF_ATTR_SHAPE,
attr->shape().unknown_rank() ? -1 : attr->shape().dim_size());
SINGLE_CASE(kTensor, TF_ATTR_TENSOR, -1);
#undef SINGLE_CASE

case tensorflow::AttrValue::kList:
*list_size = 0;
*total_size = -1;
#define LIST_CASE(field, attr_type, ...) \
if (attr->list().field##_size() > 0) { \
*list_size = attr->list().field##_size(); \
__VA_ARGS__; \
break; \
}

LIST_CASE(
s, TF_ATTR_STRING, *total_size = 0;
for (int i = 0; i < attr->list().s_size();
++i) { *total_size += attr->list().s(i).size(); });
LIST_CASE(i, TF_ATTR_INT);
LIST_CASE(f, TF_ATTR_FLOAT);
LIST_CASE(b, TF_ATTR_BOOL);
LIST_CASE(type, TF_ATTR_TYPE);
LIST_CASE(
shape, TF_ATTR_SHAPE, *total_size = 0;
for (int i = 0; i < attr->list().shape_size(); ++i) {
const auto& s = attr->list().shape(i);
*total_size += s.unknown_rank() ? 0 : s.dim_size();
});
LIST_CASE(tensor, TF_ATTR_TENSOR);
LIST_CASE(tensor, TF_ATTR_FUNC);
#undef LIST_CASE
break;

case tensorflow::AttrValue::kPlaceholder:
*list_size = -1;
*total_size = -1;
break;

case tensorflow::AttrValue::kFunc:
*list_size = -1;
*total_size = -1;
break;

case tensorflow::AttrValue::VALUE_NOT_SET:
status->status =
InvalidArgument("Attribute '", attr_name, "' has no value set");
break;
}
}

#define DEFINE_TF_GETATTR(func, c_type, cc_type, attr_type, list_field) \
void TF_OpKernelConstruction_GetAttr##func(TF_OpKernelConstruction* ctx, \
const char* attr_name, \
c_type* val, TF_Status* status) { \
Expand All @@ -269,10 +360,84 @@ void TF_OpKernelContext_Failure(TF_OpKernelContext* ctx, TF_Status* status) {
if (s.ok()) { \
*val = static_cast<c_type>(v); \
} \
} \
void TF_OpKernelConstruction_GetAttr##func##List( \
TF_OpKernelConstruction* ctx, const char* attr_name, c_type* vals, \
int max_vals, TF_Status* status) { \
TF_SetStatus(status, TF_OK, ""); \
const tensorflow::AttrValue* attr = GetAttrValue(ctx, attr_name, status); \
if (!status->status.ok()) return; \
if (attr->value_case() != tensorflow::AttrValue::kList) { \
status->status = \
InvalidArgument("Value for '", attr_name, "' is not a list."); \
return; \
} \
annarev marked this conversation as resolved.
Show resolved Hide resolved
status->status = \
tensorflow::AttrValueHasType(*attr, "list(" attr_type ")"); \
if (!status->status.ok()) return; \
const auto len = std::min(max_vals, attr->list().list_field##_size()); \
for (int i = 0; i < len; ++i) { \
vals[i] = static_cast<c_type>(attr->list().list_field(i)); \
} \
}

DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType)
DEFINE_TF_GETATTR(Int32, tensorflow::int32, int32_t)
DEFINE_TF_GETATTR(Type, TF_DataType, tensorflow::DataType, "type", type)
DEFINE_TF_GETATTR(Int32, int32_t, tensorflow::int32, "int", i)
DEFINE_TF_GETATTR(Int64, int64_t, tensorflow::int64, "int", i)
DEFINE_TF_GETATTR(Float, float, float, "float", f)
DEFINE_TF_GETATTR(Bool, TF_Bool, bool, "bool", b)

void TF_OpKernelConstruction_GetAttrString(TF_OpKernelConstruction* ctx,
const char* attr_name, char* value,
size_t max_length,
TF_Status* status) {
std::string v;
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
::tensorflow::Set_TF_Status_from_Status(status, s);

if (!status->status.ok()) return;

if (max_length <= 0) {
return;
}
std::memcpy(value, v.data(), std::min<size_t>(v.length(), max_length));
annarev marked this conversation as resolved.
Show resolved Hide resolved
}

void TF_OpKernelConstruction_GetAttrStringList(TF_OpKernelConstruction* ctx,
const char* attr_name,
char** values, size_t* lengths,
int max_values, void* storage,
size_t storage_size,
TF_Status* status) {
std::vector<std::string> v;
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
::tensorflow::Status s = cc_ctx->GetAttr(attr_name, &v);
::tensorflow::Set_TF_Status_from_Status(status, s);

if (!status->status.ok()) return;

const auto len = std::min(max_values, static_cast<int>(v.size()));
char* p = static_cast<char*>(storage);
for (int i = 0; i < len; ++i) {
const std::string& s = v[i];
values[i] = p;
lengths[i] = s.size();
if ((p + s.size()) > (static_cast<char*>(storage) + storage_size)) {
status->status = InvalidArgument(
"Not enough storage to hold the requested list of strings");
return;
}
memcpy(values[i], s.data(), s.size());
p += s.size();
}
}

bool TF_OpKernelConstruction_HasAttr(TF_OpKernelConstruction* ctx,
const char* attr_name, TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelConstruction*>(ctx);
return cc_ctx->HasAttr(attr_name);
}

TF_StringView TF_OpKernelConstruction_GetName(TF_OpKernelConstruction* ctx) {
auto* cc_ctx = reinterpret_cast<tensorflow::OpKernelConstruction*>(ctx);
Expand Down
124 changes: 124 additions & 0 deletions tensorflow/c/kernels.h
Expand Up @@ -184,6 +184,24 @@ TF_CAPI_EXPORT extern TF_DataType TF_ExpectedOutputDataType(
// Returns the step ID of the given context.
TF_CAPI_EXPORT extern int64_t TF_StepId(TF_OpKernelContext* ctx);

// Get the list_size and total_size of the attribute `attr_name` of `oper`.
// list_size - the length of the list.
// total_size - total size of the list.
// (1) If attr_type == TF_ATTR_STRING
// then total_size is the cumulative byte size
// of all the strings in the list.
// (3) If attr_type == TF_ATTR_SHAPE
// then total_size is the number of dimensions
// of the shape valued attribute, or -1
// if its rank is unknown.
// (4) If attr_type == TF_ATTR_SHAPE
// then total_size is the cumulative number
// of dimensions of all shapes in the list.
// (5) Otherwise, total_size is undefined.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrSize(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* list_size,
int32_t* total_size, TF_Status* status);

// Interprets the named kernel construction attribute as a TF_DataType and
// places it into *val. *status is set to TF_OK.
//
Expand All @@ -202,6 +220,112 @@ TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* val,
TF_Status* status);

// Interprets the named kernel construction attribute as int64_t and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// int64, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64(
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* val,
TF_Status* status);

// Interprets the named kernel construction attribute as float and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// float, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloat(
TF_OpKernelConstruction* ctx, const char* attr_name, float* val,
TF_Status* status);

// Interprets the named kernel construction attribute as bool and
// places it into *val. *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// bool, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBool(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* val,
TF_Status* status);

// Interprets the named kernel construction attribute as string and
// places it into *val. `val` must
// point to an array of length at least `max_length` (ideally set to
// total_size from TF_OpKernelConstruction_GetAttrSize(ctx,
// attr_name, list_size, total_size)). *status is set to TF_OK.
//
// If the attribute could not be found or could not be interpreted as
// string, *status is populated with an error.
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrString(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the current return type? Should it be TString?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fchollet thanks for the reminder, I fixed the typo "bool"->"string" in the description. return value is a string type, which will be stored into value and max_length.

Copy link
Contributor

@annarev annarev Nov 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fchollet As I understand TString is for representing string Tensors, so char* is fine here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So should the type of val be char* val instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TF_OpKernelConstruction* ctx, const char* attr_name, char* val,
size_t max_length, TF_Status* status);

// Interprets the named kernel construction attribute as a TF_DataType array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrTypeList(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_DataType* vals,
int max_vals, TF_Status* status);

// Interprets the named kernel construction attribute as int32_t array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt32List(
TF_OpKernelConstruction* ctx, const char* attr_name, int32_t* vals,
int max_vals, TF_Status* status);

// Interprets the named kernel construction attribute as int64_t array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrInt64List(
TF_OpKernelConstruction* ctx, const char* attr_name, int64_t* vals,
int max_vals, TF_Status* status);

// Interprets the named kernel construction attribute as float array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrFloatList(
TF_OpKernelConstruction* ctx, const char* attr_name, float* vals,
int max_vals, TF_Status* status);

// Interprets the named kernel construction attribute as bool array and
// places it into *vals. *status is set to TF_OK.
// `vals` must point to an array of length at least `max_values` (ideally set
// to list_size from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size)).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrBoolList(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Bool* vals,
int max_vals, TF_Status* status);

// Interprets the named kernel construction attribute as string array and fills
// in `vals` and `lengths`, each of which must point to an array of length at
// least `max_values`. *status is set to TF_OK. The elements of values will
// point to addresses in `storage` which must be at least `storage_size` bytes
// in length. Ideally, max_values would be set to list_size and `storage` would
// be at least total_size, obtained from
// TF_OpKernelConstruction_GetAttrSize(ctx, attr_name, list_size,
// total_size).
TF_CAPI_EXPORT extern void TF_OpKernelConstruction_GetAttrStringList(
TF_OpKernelConstruction* ctx, const char* attr_name, char** vals,
size_t* lengths, int max_values, void* storage, size_t storage_size,
TF_Status* status);

// Return true if the kernel construction has the attr_name
TF_CAPI_EXPORT extern bool TF_OpKernelConstruction_HasAttr(
TF_OpKernelConstruction* ctx, const char* attr_name, TF_Status* status);

// Returns the unique operation name for this OpKernel.
TF_CAPI_EXPORT extern TF_StringView TF_OpKernelConstruction_GetName(
TF_OpKernelConstruction* ctx);
Expand Down