forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
onnxifi_graph_info.h
114 lines (98 loc) · 3.46 KB
/
onnxifi_graph_info.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#pragma once
#include <functional>
#include <memory>
#include <mutex>
#include <unordered_map>
#include "caffe2/core/logging.h"
#include "caffe2/opt/shape_info.h"
#include "foxi/onnxifi_loader.h"
namespace caffe2 {
namespace onnx {
struct BackendGraphInfo {
onnxBackendID backend_id;
onnxBackend backend;
onnxGraph graph;
onnxifi_library* lib{nullptr};
std::unordered_map<std::string, ShapeInfo> weight_shape_info;
BackendGraphInfo(
onnxBackendID backend_id,
onnxBackend backend,
onnxGraph graph,
onnxifi_library* lib,
std::unordered_map<std::string, ShapeInfo>&& s)
: backend_id(backend_id),
backend(backend),
graph(graph),
lib(lib),
weight_shape_info(std::move(s)) {}
BackendGraphInfo(const BackendGraphInfo& other) = delete;
BackendGraphInfo& operator=(const BackendGraphInfo& other) = delete;
BackendGraphInfo(BackendGraphInfo&& other) noexcept {
backend_id = other.backend_id;
backend = other.backend;
graph = other.graph;
lib = other.lib;
weight_shape_info = std::move(other.weight_shape_info);
other.backend_id = other.backend = other.graph = other.lib = nullptr;
}
BackendGraphInfo& operator=(BackendGraphInfo&& other) {
backend_id = other.backend_id;
backend = other.backend;
graph = other.graph;
lib = other.lib;
weight_shape_info = std::move(other.weight_shape_info);
other.backend_id = other.backend = other.graph = other.lib = nullptr;
return *this;
}
~BackendGraphInfo() {
if (lib) {
onnxStatus err;
if (graph) {
err = lib->onnxReleaseGraph(graph);
if (err != ONNXIFI_STATUS_SUCCESS) {
LOG(ERROR) << "Error when calling onnxReleaseGraph";
}
}
if (backend) {
err = lib->onnxReleaseBackend(backend);
if (err != ONNXIFI_STATUS_SUCCESS) {
LOG(ERROR) << "Error when calling onnxReleaseBackend";
}
}
if (backend_id) {
err = lib->onnxReleaseBackendID(backend_id);
if (err != ONNXIFI_STATUS_SUCCESS) {
LOG(ERROR) << "Error when calling onnxReleaseBackendID";
}
}
}
}
};
using SharedPtrBackendGraphInfo = std::shared_ptr<BackendGraphInfo>;
// This class maintains a map of already created graph for nets+ops
// Thread-safe cache mapping a net/op key to its compiled BackendGraphInfo, so
// that identical models are compiled on the ONNXIFI backend only once.
class OnnxBackendGraphMap {
 public:
  OnnxBackendGraphMap() {}
  // Make class noncopyable and nonmovable.
  // BUG FIX: deleted assignment operators previously returned by value
  // (OnnxBackendGraphMap) instead of by reference — harmless only because
  // they are deleted, but nonstandard; corrected to the canonical signature.
  OnnxBackendGraphMap(const OnnxBackendGraphMap&) = delete;
  OnnxBackendGraphMap(OnnxBackendGraphMap&&) = delete;
  OnnxBackendGraphMap& operator=(const OnnxBackendGraphMap&) = delete;
  OnnxBackendGraphMap& operator=(OnnxBackendGraphMap&&) = delete;

  // Returns the cached entry for `key`, or an empty shared_ptr if absent.
  SharedPtrBackendGraphInfo lookup(const std::string& key);

  // If corresponding BackendGraphInfo already exists, return it directly.
  // Otherwise we use creator to create the BackendGraphInfo shared_ptr and
  // insert it into the map and return it. The whole process should be guarded
  // by a lock. Note that since it will create the backend while holding the
  // lock, expect latency during initialization phase when there are lots of
  // models to compile.
  SharedPtrBackendGraphInfo insert(
      const std::string& key,
      std::function<SharedPtrBackendGraphInfo()> creator);

  // Erases the entry for `key` (if any) from the cache.
  void remove(const std::string& key);

 private:
  // Guards all access to backend_graph_map_.
  std::mutex backend_graph_map_lock_;
  std::unordered_map<std::string, SharedPtrBackendGraphInfo> backend_graph_map_;
};
OnnxBackendGraphMap* getOnnxBackendGraphMap();
} // namespace onnx
} // namespace caffe2