// accelerator.cpp (forked from microsoft/onnxruntime)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "accelerator.h"

#include <iostream>
#include <memory>
#include <string>

#include <torch/csrc/jit/passes/onnx.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/torch.h>

#include "bridge.h"
#include "core/common/logging/sinks/clog_sink.h"
#include "core/framework/session_options.h"
#include "core/session/environment.h"
#include "python/onnxruntime_pybind_state_common.h"

namespace onnxruntime {
namespace lazytensor {

namespace py = pybind11;
namespace aten = torch::jit::aten;
namespace prim = torch::jit::prim;

// Static state shared when creating inference and training sessions:
// a single ORT environment named "LTC".
const static std::string env_name = std::string("LTC");
static std::unique_ptr<onnxruntime::Environment> ltc_env;
onnxruntime::Environment& GetLtcEnv() {
  if (!ltc_env) {
    ORT_THROW_IF_ERROR(
        onnxruntime::Environment::Create(
            std::make_unique<onnxruntime::logging::LoggingManager>(
                std::make_unique<onnxruntime::logging::CLogSink>(),
                onnxruntime::logging::Severity::kWARNING,
                false,
                onnxruntime::logging::LoggingManager::InstanceType::Temporal,
                &env_name),
            ltc_env));
  }
  return *ltc_env;
}
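
// Note: the lazy initialization in GetLtcEnv() is unsynchronized. If it can
// be reached from multiple threads, a guarded variant (a sketch only; this
// file does not use it, and the nullptr logging manager is illustrative)
// could look like:
//
//   static std::once_flag ltc_env_once;
//   onnxruntime::Environment& GetLtcEnvOnce() {
//     std::call_once(ltc_env_once, []() {
//       ORT_THROW_IF_ERROR(onnxruntime::Environment::Create(
//           /*logging_manager=*/nullptr, ltc_env));
//     });
//     return *ltc_env;
//   }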
bool Accelerator::Supported(const torch::jit::Node* node) {
  if (!node) {
    return false;
  }
  switch (node->kind()) {
    case aten::relu:
    case aten::mul:
    case aten::add:
    case aten::sub:
    case aten::div:
    case aten::gt:
    case aten::lt:
    case aten::eq:
    case prim::Constant:
    case aten::sqrt:
    // case aten::size:
    // case aten::addcmul:
    case aten::permute:
    case aten::mm:
    case aten::ne:
    // case aten::max_pool2d_with_indices:
    // case aten::threshold_backward:
      std::cout << "[accelerator.cpp] Supported " << *node;
      return true;
    default:
      std::cout << "[accelerator.cpp] Unsupported " << *node;
      return false;
  }
}
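
// Example (a minimal sketch, not part of this file): a fusion pass could use
// Supported() to probe which nodes of a captured graph can be lowered to ORT,
// assuming a torch::jit::Graph named `g`:
//
//   for (const torch::jit::Node* n : g->nodes()) {
//     if (Accelerator::Supported(n)) {
//       // Candidate for inclusion in the ORT-backed subgraph.
//     }
//   }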
void Accelerator::Run(torch::jit::Stack& stack) {
  // Get the number of expected inputs to the graph we are compiling.
  const at::ArrayRef<torch::jit::Value*>& graph_inputs = subgraph_->inputs();
  const auto num_inputs = graph_inputs.size();
  // Pop these inputs from the stack.
  at::ArrayRef<c10::IValue> inputs = torch::jit::last(stack, num_inputs);
  // If we haven't compiled for the shape/device of these inputs before,
  // do so now.
  torch::jit::CompleteArgumentSpec spec{false, at::ArrayRef<c10::IValue>(inputs)};
  if (cache_.find(spec) == cache_.end()) {
    cache_.emplace(spec, Compile(spec, inputs));
  }
  // Run the compiled function!
  auto outputs = cache_[spec].code(inputs);
  torch::jit::drop(stack, num_inputs);
  for (auto& output : outputs) {
    stack.push_back(output);
  }
}
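
// Note on caching: CompleteArgumentSpec records complete tensor information
// (dtype, device, sizes) for every input, so Compile() runs once per distinct
// input signature and later calls reuse the cached CompiledObject. A sketch
// (an `accelerator` instance is assumed, not defined in this file):
//
//   torch::jit::Stack stack;
//   stack.emplace_back(torch::randn({2, 3}));
//   accelerator.Run(stack);  // cache miss: Compile() runs, entry cached
//   stack.clear();
//   stack.emplace_back(torch::randn({4, 5}));
//   accelerator.Run(stack);  // new shape: a second entry is compiled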
void Accelerator::CheckArgs(
    const at::ArrayRef<c10::IValue>& inputs) {
  // TODO: remove this check.
  TORCH_CHECK(inputs.size(), "Need at least one input.");
  for (const auto& input : inputs) {
    TORCH_CHECK(input.isTensor() || input.isScalar(),
                "Compiler can only handle Tensor or Scalar inputs.");
  }
}
// Store input types in the sub-graph so that the
// ONNX exporter can use them. Input types are
// required when executing the ONNX model in ORT.
// TODO: Allow ORT to accept models without
// input types. Then, we can remove this function.
void Accelerator::PropagateArgTypes(
    const at::ArrayRef<c10::IValue>& inputs) {
  TORCH_CHECK(subgraph_->inputs().size() == inputs.size(),
              "Number of provided inputs must match the captured sub-graph's schema.");
  const auto num_inputs = subgraph_->inputs().size();
  for (size_t i = 0; i < num_inputs; ++i) {
    auto input_symbol = subgraph_->inputs()[i];
    auto input_value = inputs[i];
    input_symbol->setType(input_value.type());
  }
  std::cout << "JIT sub-graph: " << std::endl;
  std::cout << *subgraph_ << std::endl;
  torch::jit::PropagateInputShapes(subgraph_);
  std::cout << "JIT sub-graph with shapes: " << std::endl;
  std::cout << *subgraph_ << std::endl;
}
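
// For illustration: before this pass a graph input may print as
//   %x : Tensor
// and afterwards it carries a complete type such as
//   %x : Float(2, 3)
// which gives the ONNX exporter the element type and shape it needs.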
// The ONNX exporter is written in Python, so
// this function may call some Python functions.
// Be aware of GIL issues.
// The returned value is the path to the exported
// ONNX file.
static std::string ExportToOnnx(std::shared_ptr<torch::jit::Graph> graph) {
  // Acquire the GIL before touching any Python state.
  pybind11::gil_scoped_acquire guard{};
  // Retrieve the Python exporter function.
  pybind11::function export_to_onnx =
      pybind11::reinterpret_borrow<pybind11::function>(
          pybind11::module::import("torch.onnx.utils").attr("_optimize_graph_1"));
  // Execute the Python function.
  auto result = export_to_onnx(graph, ::torch::onnx::OperatorExportTypes::ONNX);
  return result.cast<std::string>();
}
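
// Note: `_optimize_graph_1` does not appear in stock PyTorch (upstream
// exposes `torch.onnx.utils._optimize_graph`), so this import presumably
// resolves to a helper patched into the PyTorch build used by this fork.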
// Create an empty session object.
// Models will be loaded later.
static std::unique_ptr<onnxruntime::InferenceSession> CreateSession() {
  // Environment shared by all sessions.
  static onnxruntime::Environment& pybind_default_env = GetLtcEnv();
  // All sessions use the same config.
  static onnxruntime::SessionOptions sess_opts;
  return std::make_unique<onnxruntime::InferenceSession>(sess_opts, pybind_default_env);
}
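
// If per-session tuning were needed, the shared defaults above could be
// replaced with explicitly configured options (a sketch; the fields are
// standard onnxruntime::SessionOptions members, but the values here are
// illustrative only):
//
//   onnxruntime::SessionOptions sess_opts;
//   sess_opts.session_logid = "LTC";
//   sess_opts.enable_mem_pattern = false;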
static OrtDevice CheckAndGetTensorDevice(at::ArrayRef<c10::IValue>& values) {
  // This memory info must be shared by all tensors;
  // for example, all tensors on CPU or all on a specific GPU.
  // When no value is a tensor, we assume the CPU device.
  // c10::Device's index defaults to -1.
  c10::Device unique_tensor_device(c10::DeviceType::CPU);
  bool assigned = false;
  for (auto value : values) {
    if (!value.isTensor()) {
      continue;
    }
    auto tensor = value.toTensor();
    if (assigned) {
      // A device has been recorded, so we compare
      // it with the current tensor's device.
      TORCH_CHECK(unique_tensor_device == tensor.device(),
                  "All tensors must be on the same device.");
    } else {
      // Record the first tensor's device.
      unique_tensor_device = tensor.device();
      assigned = true;
    }
  }
  return CreateOrtDevice(unique_tensor_device);
}
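
// For example, passing one CPU tensor and one CUDA tensor through this check
// trips the TORCH_CHECK above; the caller is expected to keep all inputs of a
// fused sub-graph on a single device.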
std::string GetC10TypeString(c10::IValue& value) {
  // Return the printable name of the IValue's type.
  return value.type()->str();
}
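
// Usage sketch: GetC10TypeString(v) yields strings such as "Tensor" for a
// tensor IValue or "float" for a double scalar, which is handy for debug logs.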
CompiledObject Accelerator::Compile(
    torch::jit::CompleteArgumentSpec spec, at::ArrayRef<c10::IValue>& args) {
  CompiledObject compiled;
  // Assign an empty session.
  compiled.sess = CreateSession();
  // Let's get the empty session and initialize it.
  onnxruntime::InferenceSession& sess = *compiled.sess;
  OrtCUDAProviderOptions provider_options{};
  provider_options.do_copy_in_default_stream = true;
  auto factory = onnxruntime::CreateExecutionProviderFactory_Cuda(&provider_options);
  ORT_THROW_IF_ERROR(sess.RegisterExecutionProvider(factory->CreateProvider()));
  // Export from PyTorch and load the ONNX model into the session.
  CheckArgs(args);
  PropagateArgTypes(args);
  std::string model_path = ExportToOnnx(subgraph_);
  ORT_THROW_IF_ERROR(sess.Load(model_path));
  ORT_THROW_IF_ERROR(sess.Initialize());
  onnxruntime::RunOptions run_options;
  std::vector<std::string> feed_names;
  std::vector<std::string> fetch_names;
  for (auto node_arg : *sess.GetModelInputs().second) {
    feed_names.push_back(node_arg->Name());
  }
  for (auto node_arg : *sess.GetModelOutputs().second) {
    fetch_names.push_back(node_arg->Name());
  }
  // Memory info for all tensors.
  // Assume all inputs and outputs are on the same device.
  OrtDevice shared_device = CheckAndGetTensorDevice(args);
  // Duplicate device info for each output tensor.
  // TODO: Force scalar to be on CPU since at::Scalar is a CPU value.
  std::vector<OrtDevice> fetches_device_info(fetch_names.size(), shared_device);
  // The lambda below captures `sess` by reference. This stays valid because
  // `compiled.sess` owns the session and the session object itself does not
  // move, even when the CompiledObject is moved into the cache.
  auto code = [this, spec, run_options,
               feed_names, fetch_names,
               fetches_device_info, &sess, model_path](at::ArrayRef<c10::IValue>& args) {
    // Inputs of the ORT session.
    std::vector<OrtValue> feeds;
    // Outputs of the ORT session.
    std::vector<OrtValue> fetches;
    std::cout << "Execute ONNX model " << model_path << std::endl;
    // Prepare inputs.
    const auto num_inputs = subgraph_->inputs().size();
    for (size_t i = 0; i < num_inputs; ++i) {
      // The value can be either a tensor or a scalar.
      // A scalar is a tensor with an empty shape vector.
      if (args.at(i).isScalar()) {
        // Scalar.
        // ORT_ENFORCE(subgraph_->inputs().at(i)->type()->kind() == c10::TypeKind::TensorType);
        feeds.push_back(CreateOrtScalarValue(args.at(i).toScalar()));
      } else if (args.at(i).isTensor()) {
        // Tensor. Create an ORT tensor from the PyTorch tensor without copying.
        ORT_ENFORCE(subgraph_->inputs().at(i)->type()->kind() == c10::TypeKind::TensorType);
        feeds.push_back(CreateOrtTensorValue(args.at(i).toTensor()));
      } else {
        // LTC appears to pass only scalars and tensors into the backend,
        // so other types are rejected for now.
        ORT_THROW("Only scalar and tensor inputs are supported.");
      }
    }
    std::cout << "Run" << std::endl;
    // Inputs are ready. Let's run ORT.
    ORT_THROW_IF_ERROR(sess.Run(
        run_options,
        feed_names, feeds,
        fetch_names, &fetches, &fetches_device_info));
    std::cout << "Run done" << std::endl;
    // Convert ORT outputs to PyTorch format.
    std::vector<c10::IValue> outputs;
    for (auto value : fetches) {
      if (value.IsTensor()) {
        onnxruntime::Tensor* tensor = value.GetMutable<onnxruntime::Tensor>();
        const onnxruntime::TensorShape& tensor_shape = tensor->Shape();
        if (tensor_shape.NumDimensions() > 0) {
          // Create a PyTorch tensor from the ORT tensor without copying.
          outputs.push_back(CreateC10IvalueTensor(value));
        } else if (tensor_shape.NumDimensions() == 0) {
          // A rank-0 tensor maps back to an at::Scalar.
          outputs.push_back(CreateC10IvalueScalar(value));
        } else {
          ORT_THROW("Unsupported tensor shape.");
        }
      } else {
        ORT_THROW("Output must be a tensor or a scalar.");
      }
    }
    std::cout << "Execute ONNX model done " << model_path << std::endl;
    return outputs;
  };
  compiled.code = code;
  return compiled;
}
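
// End-to-end flow (a sketch of how the pieces above fit together; the
// `accelerator` instance and its stack contents are assumed, not defined
// in this file):
//
//   torch::jit::Stack stack;
//   stack.emplace_back(torch::randn({2, 3}, torch::kCUDA));
//   accelerator.Run(stack);
//
// Run() calls Compile() on a cache miss: CheckArgs/PropagateArgTypes
// annotate the JIT sub-graph, ExportToOnnx writes the ONNX file via Python,
// CreateSession + Load + Initialize build the ORT session, and the cached
// `code` lambda then feeds tensors in and pushes results back onto the stack.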
} // namespace lazytensor
} // namespace onnxruntime