/
load_library.py
220 lines (181 loc) · 7.28 KB
/
load_library.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Function for loading TensorFlow plugins."""
import errno
import hashlib
import importlib
import os
import platform
import sys
from tensorflow.python.client import pywrap_tf_session as py_tf
from tensorflow.python.eager import context
from tensorflow.python.framework import _pywrap_python_op_gen
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@tf_export('load_op_library')
def load_op_library(library_filename):
  """Loads a TensorFlow plugin, containing custom ops and kernels.

  Pass "library_filename" to a platform-specific mechanism for dynamically
  loading a library. The rules for determining the exact location of the
  library are platform-specific and are not documented here. When the
  library is loaded, ops and kernels registered in the library via the
  `REGISTER_*` macros are made available in the TensorFlow process. Note
  that ops with the same name as an existing op are rejected and not
  registered with the process.

  The generated Python wrapper source is executed in a fresh module. That
  module is cached in `sys.modules` under the SHA-1 hex digest of the wrapper
  source. A later call that produces identical wrapper source therefore
  returns the already-created module instead of building a new one.

  Loading a plugin loads a native library and `exec`s the wrapper source
  generated from it, so only load libraries you trust.

  Args:
    library_filename: Path to the plugin.
      Relative or absolute filesystem path to a dynamic library file.

  Returns:
    A python module containing the Python wrappers for Ops defined in
    the plugin.

  Raises:
    RuntimeError: when unable to load the library or get the python wrappers.
  """
  # `import importlib` alone does not guarantee that the `machinery` and
  # `util` submodules are loaded. Import them explicitly rather than relying
  # on some other module having imported them first.
  import importlib.machinery  # pylint: disable=g-import-not-at-top
  import importlib.util  # pylint: disable=g-import-not-at-top

  lib_handle = py_tf.TF_LoadLibrary(library_filename)
  try:
    wrappers = _pywrap_python_op_gen.GetPythonWrappers(
        py_tf.TF_GetOpList(lib_handle))
  finally:
    # Delete the library handle to release any memory held in C
    # that is no longer needed.
    py_tf.TF_DeleteLibraryHandle(lib_handle)

  # Get a unique name for the module; identical wrappers map to the same name.
  module_name = hashlib.sha1(wrappers).hexdigest()
  if module_name in sys.modules:
    return sys.modules[module_name]
  module_spec = importlib.machinery.ModuleSpec(module_name, None)
  module = importlib.util.module_from_spec(module_spec)
  # pylint: disable=exec-used
  exec(wrappers, module.__dict__)
  # Allow this to be recognized by AutoGraph.
  setattr(module, '_IS_TENSORFLOW_PLUGIN', True)
  sys.modules[module_name] = module
  return module
@deprecation.deprecated(
    date=None, instructions='Use `tf.load_library` instead.')
@tf_export(v1=['load_file_system_library'])
def load_file_system_library(library_filename):
  """Loads a TensorFlow plugin that contains a file system implementation.

  `library_filename` is handed to a platform-specific mechanism that loads
  dynamic libraries. How that mechanism resolves the exact location of the
  library is platform-specific and is not documented here.

  This endpoint is deprecated; use `tf.load_library` instead.

  Args:
    library_filename: Path to the plugin. Relative or absolute filesystem path
      to a dynamic library file.

  Returns:
    None.

  Raises:
    RuntimeError: when unable to load the library.
  """
  # The handle returned by TF_LoadLibrary is not used here; the function
  # returns None.
  py_tf.TF_LoadLibrary(library_filename)
def _is_shared_object(filename):
  """Returns whether `filename` looks like a shared library, by name alone.

  Only the name is inspected; the file is never opened. Which suffixes count
  depends on the host OS, as reported by `platform.system()`:

  * Linux: names ending in `.so`, or versioned names such as `libfoo.so.1`
    where a decimal digit immediately follows the last `.so.`.
  * Darwin (macOS): names ending in `.dylib`.
  * Windows: names ending in `.dll`.
  * Any other system: never.

  Args:
    filename: A file name (or path) string.

  Returns:
    True if the name carries a shared-object extension for this platform.
  """
  system = platform.system()
  if system == 'Linux':
    if filename.endswith('.so'):
      return True
    index = filename.rfind('.so.')
    if index == -1:
      return False
    # A shared object with the API version in its name, e.g. `libfoo.so.1`.
    # A name that ends in `.so.` has an empty version and used to raise
    # IndexError here; treat it as "not a shared object" instead.
    version = filename[index + 4:]
    return bool(version) and version[0].isdecimal()
  if system == 'Darwin':
    return filename.endswith('.dylib')
  if system == 'Windows':
    return filename.endswith('.dll')
  return False
@tf_export('load_library')
def load_library(library_location):
  """Loads a TensorFlow plugin.

  "library_location" can be a path to a specific shared object, or a folder.
  If it is a folder, every entry directly inside it whose name has a
  shared-object extension for the current platform (`.so` / `.so.N` on Linux,
  `.dylib` on macOS, `.dll` on Windows) is loaded, in sorted name order.
  Subfolders are not recursed into. When the library is loaded, kernels
  registered in the library via the `REGISTER_*` macros are made available
  in the TensorFlow process.

  Args:
    library_location: Path to the plugin or the folder of plugins.
      Relative or absolute filesystem path to a dynamic library file or folder.

  Returns:
    None

  Raises:
    OSError: When the file to be loaded is not found.
    RuntimeError: when unable to load the library.
  """
  if not os.path.exists(library_location):
    raise OSError(
        errno.ENOENT,
        'The file or folder to load kernel libraries from does not exist.',
        library_location)

  if os.path.isdir(library_location):
    # os.listdir returns entries in arbitrary order; sort them so plugins are
    # always loaded in the same, reproducible order.
    kernel_libraries = [
        os.path.join(library_location, f)
        for f in sorted(os.listdir(library_location))
        if _is_shared_object(f)
    ]
  else:
    kernel_libraries = [library_location]

  for lib in kernel_libraries:
    py_tf.TF_LoadLibrary(lib)
def load_pluggable_device_library(library_location):
  """Loads a TensorFlow PluggableDevice plugin.

  "library_location" can be a path to a specific shared object, or a folder.
  If it is a folder, every entry directly inside it whose name has a
  shared-object extension for the current platform is loaded, in sorted name
  order. Subfolders are not recursed into. When the library is loaded,
  devices/kernels registered in the library via the StreamExecutor C API and
  the Kernel/Op Registration C API are made available in the TensorFlow
  process.

  After loading, the eager context's list of physical devices is
  reinitialized so that newly registered devices are picked up.

  Args:
    library_location: Path to the plugin or folder of plugins. Relative or
      absolute filesystem path to a dynamic library file or folder.

  Raises:
    OSError: When the file to be loaded is not found.
    RuntimeError: when unable to load the library.
  """
  if not os.path.exists(library_location):
    raise OSError(
        errno.ENOENT,
        'The file or folder to load pluggable device libraries from does not '
        'exist.', library_location)

  if os.path.isdir(library_location):
    # os.listdir returns entries in arbitrary order; sort them so plugins are
    # always loaded in the same, reproducible order.
    pluggable_device_libraries = [
        os.path.join(library_location, f)
        for f in sorted(os.listdir(library_location))
        if _is_shared_object(f)
    ]
  else:
    pluggable_device_libraries = [library_location]

  for lib in pluggable_device_libraries:
    py_tf.TF_LoadPluggableDeviceLibrary(lib)

  # Reinitialize the physical devices list after plugin registration.
  context.context().reinitialize_physical_devices()
@tf_export('experimental.register_filesystem_plugin')
def register_filesystem_plugin(plugin_location):
  """Loads a TensorFlow FileSystem plugin.

  The path is checked for existence before it is passed to TensorFlow's
  filesystem plugin registration.

  Args:
    plugin_location: Path to the plugin. Relative or absolute filesystem plugin
      path to a dynamic library file.

  Returns:
    None

  Raises:
    OSError: When the file to be loaded is not found.
    RuntimeError: when unable to load the library.
  """
  # Fail fast with a standard ENOENT error when nothing exists at the path.
  if not os.path.exists(plugin_location):
    raise OSError(errno.ENOENT,
                  'The file to load file system plugin from does not exist.',
                  plugin_location)
  py_tf.TF_RegisterFilesystemPlugin(plugin_location)