-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdevice_util.py
159 lines (125 loc) · 5.43 KB
/
device_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Device-related support functions."""
from tensorflow.python.eager import context
from tensorflow.python.framework import config
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
def canonicalize(d, default=None):
  """Canonicalize device string.

  Every component missing from `d` is filled in, first from the built-in
  defaults '/replica:0/task:0/device:CPU:0', then overridden by `default`
  (when given). Components present in `d` always win. For example:
    d='/cpu:0', default='/job:worker/task:1' gives
      '/job:worker/replica:0/task:1/device:CPU:0'.
    d='/cpu:0', default='/job:worker' gives
      '/job:worker/replica:0/task:0/device:CPU:0'.
    d='/gpu:0', default=None gives (in graph mode)
      '/replica:0/task:0/device:GPU:0'.

  Note: when executing eagerly, the job/replica/task of the first logical CPU
  device are merged into the defaults if that device names a job; otherwise
  the default job is "localhost".

  Args:
    d: a device string or tf.config.LogicalDevice
    default: a string for default device if d doesn't have all components.

  Returns:
    a canonicalized device string.
  """
  device_name = d.name if isinstance(d, context.LogicalDevice) else d
  spec = tf_device.DeviceSpec.from_string(device_name)
  device_type = spec.device_type
  assert device_type is None or device_type == device_type.upper(), (
      "Device type '%s' must be all-caps." % (device_type,))

  # Baseline values for any component that nothing else provides.
  merged = tf_device.DeviceSpec(
      replica=0, task=0, device_type="CPU", device_index=0)

  if ops.executing_eagerly_outside_functions():
    # Try to deduce job, replica and task in case it's in a multi worker setup.
    # TODO(b/151452748): Using list_logical_devices is not always safe since it
    # may return remote devices as well, but we're already doing this elsewhere.
    first_cpu = config.list_logical_devices("CPU")[0]
    host_cpu = tf_device.DeviceSpec.from_string(first_cpu.name)
    if host_cpu.job:
      merged = merged.make_merged_spec(host_cpu)
    else:
      merged = merged.replace(job="localhost")

  if default:
    # Components of `default` override the baseline values above.
    default_spec = tf_device.DeviceSpec.from_string(default)
    merged = merged.make_merged_spec(default_spec)

  # `spec` is merged last so that its components take precedence over all
  # of the defaults.
  return merged.make_merged_spec(spec).to_string()
def canonicalize_without_job_and_task(d):
  """Partially canonicalize device string.

  `d` is first fully canonicalized with `canonicalize(d)` (no `default`
  device is applied). The job and task components are then removed and the
  replica is forced to 0, so only the replica, device type and device index
  remain. This is most useful for parameter server strategy, where the
  device strings are generated on the chief but executed on workers.

  For example:
    If d = '/cpu:0', it returns '/replica:0/device:CPU:0'.
    If d = '/gpu:0', it returns '/replica:0/device:GPU:0'.

  A device type or index missing from `d` is filled in the same way
  `canonicalize` does it (CPU:0 by default).

  Args:
    d: a device string or tf.config.LogicalDevice

  Returns:
    a partially canonicalized device string.
  """
  canonicalized_device = canonicalize(d)
  spec = tf_device.DeviceSpec.from_string(canonicalized_device)
  # Drop the host-specific components: the string is built on one task (e.g.
  # the chief) but must be valid on whichever task executes it.
  spec = spec.replace(job=None, task=None, replica=0)
  return spec.to_string()
def resolve(d):
  """Canonicalize `d`, taking missing components from the current device.

  Args:
    d: a device string or tf.config.LogicalDevice.

  Returns:
    a canonicalized device string.
  """
  current_device = current()
  return canonicalize(d, default=current_device)
class _FakeNodeDef(object):
  """Minimal NodeDef stand-in carried by a `_FakeOperation`.

  Only the `op` and `name` fields exist (enforced by `__slots__`); both start
  out as empty strings.
  """

  __slots__ = ["op", "name"]

  def __init__(self):
    self.op, self.name = "", ""
class _FakeOperation(object):
  """Placeholder Operation handed to graph device functions.

  `current()` runs the default graph's device functions on an instance of
  this class and then reads back the `device` attribute they assigned.
  """

  def __init__(self):
    # All string fields start out empty; device functions may overwrite them.
    self.device = self.type = self.name = ""
    self.node_def = _FakeNodeDef()

  def _set_device(self, device):
    # NOTE(review): `ops._device_string` presumably turns a DeviceSpec into
    # its string form and passes strings through unchanged -- confirm in ops.
    device_string = ops._device_string(device)  # pylint: disable=protected-access
    self.device = device_string

  def _set_device_from_string(self, device_str):
    # Stores the string verbatim, without any normalization.
    self.device = device_str
def current():
  """Return the current device as a string (not canonicalized).

  When executing eagerly outside functions, this is the eager context's
  `device_name`. Otherwise a throwaway `_FakeOperation` is passed through the
  default graph's device functions and the device they assign is returned.
  """
  # TODO(josh11b): Work out how this function interacts with ops.colocate_with.
  if not ops.executing_eagerly_outside_functions():
    probe = _FakeOperation()
    graph = ops.get_default_graph()
    graph._apply_device_functions(probe)  # pylint: disable=protected-access
    return probe.device
  return context.context().device_name
def get_host_for_device(device):
  """Return the host CPU device string for `device`.

  Args:
    device: a device string.

  Returns:
    a device string with the same job, replica and task as `device`, but with
    device type CPU and device index 0.
  """
  parsed = tf_device.DeviceSpec.from_string(device)
  host_spec = tf_device.DeviceSpec(
      job=parsed.job,
      replica=parsed.replica,
      task=parsed.task,
      device_type="CPU",
      device_index=0)
  return host_spec.to_string()
def local_devices_from_num_gpus(num_gpus):
  """Return local device strings: one per GPU, or the CPU if there are none.

  Args:
    num_gpus: number of local GPUs to enumerate.

  Returns:
    ("/device:GPU:0", ..., "/device:GPU:<num_gpus-1>") when `num_gpus` is
    positive, otherwise ("/device:CPU:0",).
  """
  gpu_devices = tuple("/device:GPU:%d" % index for index in range(num_gpus))
  if gpu_devices:
    return gpu_devices
  # No GPUs requested: fall back to the single local CPU.
  return ("/device:CPU:0",)