-
Notifications
You must be signed in to change notification settings - Fork 125
/
universal_sentence_encoder_utils.h
65 lines (55 loc) · 2.87 KB
/
universal_sentence_encoder_utils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_UTILS_UNIVERSAL_SENTENCE_ENCODER_UTILS_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_UTILS_UNIVERSAL_SENTENCE_ENCODER_UTILS_H_
#include <vector>

#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
namespace tflite {
namespace task {
namespace text {

// Returns the input tensor indices for a Universal Sentence Encoder QA model in
// this order: query text, response context, response text.
//
// The model is expected to contain input tensors with names:
//
// Tensor           | Metadata Name      | Tensor Name
// ---------------- | ------------------ | -------------------------------
// Query text       | "inp_text"         | "ParseExample/ParseExampleV2:1"
// Response context | "res_context"      | "ParseExample/ParseExampleV2:2"
// Response text    | "res_text"         | "ParseExample/ParseExampleV2:3"
//
// Tensors are matched by first checking the metadata tensor name and then the
// model tensor name. If no matching tensor name is found, the first three input
// tensors are used for query text, response context and response text,
// respectively. Other input tensors are ignored.
//
// Parameters:
//   engine: the engine whose model's input tensors are inspected.
//     NOTE(review): whether a null engine is rejected, and whether the model
//     must already be built, is decided by the implementation - confirm there.
//
// Returns:
//   One tensor index per row of the table above, in the order listed, or a
//   non-OK status on failure. The failure conditions are implementation
//   defined; presumably they include a model with too few input tensors.
//   TODO confirm.
tflite::support::StatusOr<std::vector<int>>
GetUniversalSentenceEncoderInputTensorIndices(
    tflite::task::core::TfLiteEngine* engine);

// Returns the output tensor indices for a Universal Sentence Encoder QA model
// in this order: query encoding, response encoding.
//
// The model is expected to contain output tensors with names:
//
// Tensor            | Metadata Name       | Tensor Name
// ----------------- | ------------------- | ------------------------
// Query encoding    | "query_encoding"    | "Final/EncodeQuery/mul"
// Response encoding | "response_encoding" | "Final/EncodeResult/mul"
//
// Tensors are matched by first checking the metadata tensor name and then the
// model tensor name. If no matching tensor name is found, the first two output
// tensors are used for query encoding and response encoding, respectively.
// Other output tensors are ignored.
//
// Parameters:
//   engine: the engine whose model's output tensors are inspected.
//     NOTE(review): whether a null engine is rejected, and whether the model
//     must already be built, is decided by the implementation - confirm there.
//
// Returns:
//   One tensor index per row of the table above, in the order listed, or a
//   non-OK status on failure. The failure conditions are implementation
//   defined; presumably they include a model with too few output tensors.
//   TODO confirm.
tflite::support::StatusOr<std::vector<int>>
GetUniversalSentenceEncoderOutputTensorIndices(
    tflite::task::core::TfLiteEngine* engine);

}  // namespace text
}  // namespace task
}  // namespace tflite
#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_UTILS_UNIVERSAL_SENTENCE_ENCODER_UTILS_H_