-
Notifications
You must be signed in to change notification settings - Fork 953
Description
template
TfLiteStatus copyToTensor(TfLiteContext* context, const FromT* in,
TfLiteEvalTensor* out, int num_elements) {
switch (out->type) {
case kTfLiteInt8:
copyCast(in, out->data.int8, num_elements);
break;
case kTfLiteInt16:
copyCast(in, out->data.i16, num_elements);
break;
case kTfLiteInt32:
copyCast(in, out->data.i32, num_elements);
break;
case kTfLiteUInt32:
copyCast(in, out->data.u32, num_elements);
break;
case kTfLiteFloat32:
copyCast(in, tflite::micro::GetTensorData(out), num_elements);
break;
default:
// Unsupported type.
MicroPrintf("Output type %s (%d) not supported.",
TfLiteTypeGetName(out->type), out->type);
}
return kTfLiteOk;
}
// Eval entry point for the CAST kernel: converts every element of the node's
// input tensor to the element type of its output tensor.
//
// The element count comes from MatchingFlatSize over the input and output
// shapes (assumed to require equal flat sizes — TODO confirm against
// MatchingFlatSize). The switch below picks the C++ source element type from
// the input tensor's type. copyToTensor then dispatches on the output type.
//
// Returns:
//   kTfLiteError (after logging) when the input type is not handled here;
//   otherwise whatever status copyToTensor reports for the output type.
TfLiteStatus CastEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteEvalTensor* input =
      tflite::micro::GetEvalInput(context, node, kInputTensor);
  TfLiteEvalTensor* output =
      tflite::micro::GetEvalOutput(context, node, kOutputTensor);
  int num_elements = MatchingFlatSize(tflite::micro::GetTensorShape(input),
                                      tflite::micro::GetTensorShape(output));

  switch (input->type) {
    case kTfLiteInt8:
      return copyToTensor(context, tflite::micro::GetTensorData<int8_t>(input),
                          output, num_elements);
    case kTfLiteInt16:
      return copyToTensor(context, tflite::micro::GetTensorData<int16_t>(input),
                          output, num_elements);
    case kTfLiteInt32:
      return copyToTensor(context, tflite::micro::GetTensorData<int32_t>(input),
                          output, num_elements);
    case kTfLiteUInt32:
      return copyToTensor(context,
                          tflite::micro::GetTensorData<uint32_t>(input), output,
                          num_elements);
    case kTfLiteFloat32:
      return copyToTensor(context, tflite::micro::GetTensorData<float>(input),
                          output, num_elements);
    case kTfLiteBool:
      return copyToTensor(context, tflite::micro::GetTensorData<bool>(input),
                          output, num_elements);
    default:
      // Unsupported type. Nothing was written to the output, so signal
      // failure rather than reporting success to the interpreter.
      MicroPrintf("Input type %s (%d) not supported.",
                  TfLiteTypeGetName(input->type), input->type);
      return kTfLiteError;
  }
}
} // namespace