/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/model_loader.h"
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/lite/core/model_builder.h"
#include "tensorflow/lite/minimal_logging.h"
namespace tflite {
namespace tools {
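// Initializes the model. Returns true immediately if the model was already
// built successfully; otherwise delegates to the subclass's InitInternal() and
// then verifies that a fully initialized FlatBufferModel was produced.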
bool ModelLoader::Init() {
  if (model_ && model_->initialized()) {
    // Already done.
    return true;
  }
  if (!InitInternal()) {
    return false;
  }
  if (!model_ || !model_->initialized()) {
    return false;
  }
  return true;
}
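// Verifies and builds the model from the file at model_path_.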
bool PathModelLoader::InitInternal() {
  if (model_path_.empty()) {
    TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "model_path is empty.");
    return false;
  }
  model_ = FlatBufferModel::VerifyAndBuildFromFile(model_path_.c_str());
  return true;
}
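// Verifies and builds the model from a caller-owned buffer. The buffer is not
// copied, so it must outlive this loader and the model built from it.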
bool BufferModelLoader::InitInternal() {
  if (!caller_owned_buffer_ || model_size_ <= 0) {
    TFLITE_LOG_PROD(TFLITE_LOG_ERROR,
                    "Failed to create BufferModelLoader: caller_owned_buffer "
                    "is %s; model_size: %zu",
                    caller_owned_buffer_ ? "not null" : "null", model_size_);
    return false;
  }
  model_ = FlatBufferModel::VerifyAndBuildFromBuffer(caller_owned_buffer_,
                                                     model_size_);
  return true;
}
#ifndef _WIN32
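// Verifies and builds the model by mmap-ing model_size_ bytes of file
// descriptor model_fd_ starting at model_offset_. On big-endian hosts the
// flatbuffer is byte-swapped after loading.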
bool MmapModelLoader::InitInternal() {
  if (model_fd_ < 0 || model_offset_ < 0 || model_size_ < 0) {
    TFLITE_LOG_PROD(
        TFLITE_LOG_ERROR,
        "Invalid model file descriptor. file descriptor: %d model_offset: "
        "%zu model_size: %zu",
        model_fd_, model_offset_, model_size_);
    return false;
  }
  if (!MMAPAllocation::IsSupported()) {
    TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "MMAPAllocation is not supported.");
    return false;
  }
  auto allocation = std::make_unique<MMAPAllocation>(
      model_fd_, model_offset_, model_size_, tflite::DefaultErrorReporter());
  if (!allocation->valid()) {
    TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "MMAPAllocation is not valid.");
    return false;
  }
  model_ = FlatBufferModel::VerifyAndBuildFromAllocation(std::move(allocation));
#if FLATBUFFERS_LITTLEENDIAN == 0
  model_ = FlatBufferModel::ByteConvertModel(std::move(model_));
#endif
  return true;
}
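// Reads exactly model_size_ bytes from pipe_fd_ into a malloc-ed buffer,
// closes the read end of the pipe, and builds the model from that buffer.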
bool PipeModelLoader::InitInternal() {
  if (pipe_fd_ < 0) {
    TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "Invalid pipe file descriptor %d",
                    pipe_fd_);
    return false;
  }
  std::free(model_buffer_);
  model_buffer_ = reinterpret_cast<uint8_t*>(std::malloc(model_size_));
  int read_bytes = 0;
  int remaining_bytes = model_size_;
  uint8_t* buffer = model_buffer_;
  while (remaining_bytes > 0 &&
         (read_bytes = read(pipe_fd_, buffer, remaining_bytes)) > 0) {
    remaining_bytes -= read_bytes;
    buffer += read_bytes;
  }
  // Close the read end of the pipe.
  close(pipe_fd_);
  if (read_bytes < 0 || remaining_bytes != 0) {
    TFLITE_LOG_PROD(
        TFLITE_LOG_ERROR,
        "Read Model from pipe failed: %s. Expect to read %zu bytes, "
        "%d bytes missing.",
        std::strerror(errno), model_size_, remaining_bytes);
    // read() failed, or the number of bytes read did not match model_size_.
    return false;
  }
  model_ = FlatBufferModel::VerifyAndBuildFromBuffer(
      reinterpret_cast<const char*>(model_buffer_), model_size_);
  return true;
}
#endif // !_WIN32
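// Creates a ModelLoader from a path descriptor. Supported formats:
//   "fd:<fd>:<offset>:<size>"          -> MmapModelLoader (non-Windows only)
//   "pipe:<read_fd>:<write_fd>:<size>" -> PipeModelLoader (non-Windows only;
//                                         closes <write_fd> here if >= 0)
//   "buffer:<pointer>:<size>"          -> BufferModelLoader
//   anything else                      -> PathModelLoader (treated as a file
//                                         path)
// Returns nullptr if the descriptor cannot be parsed.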
std::unique_ptr<ModelLoader> CreateModelLoaderFromPath(absl::string_view path) {
  std::vector<absl::string_view> parts = absl::StrSplit(path, ':');
  if (parts.empty()) {
    return nullptr;
  }
#ifndef _WIN32
  if (parts[0] == "fd") {
    int model_fd;
    size_t model_offset, model_size;
    if (parts.size() != 4 || !absl::SimpleAtoi(parts[1], &model_fd) ||
        !absl::SimpleAtoi(parts[2], &model_offset) ||
        !absl::SimpleAtoi(parts[3], &model_size)) {
      TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "Failed to parse model path: %s",
                      std::string(path).c_str());
      return nullptr;
    }
    return std::make_unique<MmapModelLoader>(model_fd, model_offset,
                                             model_size);
  }
  if (parts[0] == "pipe") {
    int read_fd, write_fd;
    size_t model_size;
    if (parts.size() != 4 || !absl::SimpleAtoi(parts[1], &read_fd) ||
        !absl::SimpleAtoi(parts[2], &write_fd) ||
        !absl::SimpleAtoi(parts[3], &model_size)) {
      TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "Failed to parse model path: %s",
                      std::string(path).c_str());
      return nullptr;
    }
    // If a write fd was passed along, close it: this process only reads from
    // the pipe.
    if (write_fd >= 0) {
      close(write_fd);
    }
    return std::make_unique<PipeModelLoader>(read_fd, model_size);
  }
#endif  // !_WIN32
  if (parts[0] == "buffer") {
    int64_t buffer_handle;
    size_t model_size;
    if (parts.size() != 3 || !absl::SimpleAtoi(parts[1], &buffer_handle) ||
        !absl::SimpleAtoi(parts[2], &model_size)) {
      TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "Failed to parse model path: %s",
                      std::string(path).c_str());
      return nullptr;
    }
    return std::make_unique<BufferModelLoader>(
        reinterpret_cast<const char*>(buffer_handle), model_size);
  }
  return std::make_unique<PathModelLoader>(path);
}
} // namespace tools
} // namespace tflite