-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
225 lines (189 loc) · 7.1 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Contains layer utilities for input validation and format conversion."""
from tensorflow.python.framework import smart_cond as smart_module
from tensorflow.python.ops import cond
from tensorflow.python.ops import variables
def convert_data_format(data_format, ndim):
  """Returns the layout string for a data format and an input rank.

  Args:
    data_format: Either 'channels_last' or 'channels_first'.
    ndim: Rank of the input; must compare equal to 3, 4 or 5.

  Returns:
    'NWC', 'NHWC' or 'NDHWC' for 'channels_last', and 'NCW', 'NCHW' or
    'NCDHW' for 'channels_first', for ranks 3, 4 and 5 respectively.

  Raises:
    ValueError: If `data_format` is not one of the two supported values, or
      if `ndim` is not 3, 4 or 5.
  """
  # Pick the per-rank layouts first so that an unsupported data format is
  # reported before the rank is looked at.
  if data_format == 'channels_last':
    layouts_by_rank = ((3, 'NWC'), (4, 'NHWC'), (5, 'NDHWC'))
  elif data_format == 'channels_first':
    layouts_by_rank = ((3, 'NCW'), (4, 'NCHW'), (5, 'NCDHW'))
  else:
    raise ValueError(f'Invalid data_format: {data_format}. We only support '
                     '"channels_first" or "channels_last"')

  # Plain equality tests (rather than a dict lookup) so that `ndim` does not
  # need to be hashable.
  for rank, layout in layouts_by_rank:
    if ndim == rank:
      return layout
  raise ValueError(f'Input rank: {ndim} not supported. We only support '
                   'input rank 3, 4 or 5.')
def normalize_tuple(value, n, name):
  """Transforms a single integer or iterable of integers into a tuple.

  Args:
    value: The value to validate and convert. Could be an int, or any iterable
      of ints.
    n: The size of the tuple to be returned.
    name: The name of the argument being validated, e.g. "strides" or
      "kernel_size". This is only used to format error messages.

  Returns:
    A tuple of length `n`. An int `value` is repeated `n` times; for an
    iterable, the elements are returned as given (each one is only checked to
    be convertible with `int()`, it is not converted).

  Raises:
    ValueError: If `value` is neither an int nor an iterable, if the iterable
      does not have exactly `n` elements, or if an element cannot be
      converted with `int()`. When triggered by an underlying error, that
      error is attached as `__cause__`.
  """
  if isinstance(value, int):
    return (value,) * n

  try:
    value_tuple = tuple(value)
  except TypeError as err:
    # Chain the original error so the traceback shows why iteration failed.
    raise ValueError(f'Argument `{name}` must be a tuple of {n!s} '
                     f'integers. Received: {value!s}') from err

  if len(value_tuple) != n:
    raise ValueError(f'Argument `{name}` must be a tuple of {n!s} '
                     f'integers. Received: {value!s}')

  for single_value in value_tuple:
    try:
      # Validation only: the converted value is intentionally discarded.
      int(single_value)
    except (ValueError, TypeError) as err:
      raise ValueError(f'Argument `{name}` must be a tuple of {n!s} '
                       f'integers. Received: {value!s} including element '
                       f'{single_value!s} of type '
                       f'{type(single_value)!s}') from err
  return value_tuple
def normalize_data_format(value):
  """Lower-cases and validates a `data_format` argument.

  Args:
    value: The data format to normalize; it must provide a `lower()` method
      (e.g. a string).

  Returns:
    The lower-cased value, which is either 'channels_first' or
    'channels_last'.

  Raises:
    ValueError: If the lower-cased value is not one of the two supported
      formats. The message reports `value` as it was passed in.
  """
  lowered = value.lower()
  if lowered in {'channels_first', 'channels_last'}:
    return lowered
  raise ValueError('The `data_format` argument must be one of '
                   '"channels_first", "channels_last". Received: '
                   f'{str(value)}.')
def normalize_padding(value):
  """Lower-cases and validates a `padding` argument.

  Args:
    value: The padding mode to normalize; it must provide a `lower()` method
      (e.g. a string).

  Returns:
    The lower-cased value, which is either 'valid' or 'same'.

  Raises:
    ValueError: If the lower-cased value is neither 'valid' nor 'same'. The
      message reports `value` exactly as it was passed in.
  """
  padding = value.lower()
  if padding not in {'valid', 'same'}:
    # Report what the caller actually passed, not the lower-cased copy, so the
    # message matches the offending call site.
    raise ValueError('The `padding` argument must be one of "valid", "same". '
                     f'Received: {value!s}.')
  return padding
def conv_output_length(input_length, filter_size, padding, stride, dilation=1):
  """Computes the output length of a convolution along one dimension.

  Args:
    input_length: Integer length of the input, or None.
    filter_size: Integer size of the filter.
    padding: One of "same", "valid", "full".
    stride: Integer stride.
    dilation: Integer dilation rate.

  Returns:
    The output length (integer), or None when `input_length` is None.
  """
  if input_length is None:
    return None
  assert padding in {'same', 'valid', 'full'}

  # Extent spanned by the filter once (dilation - 1) gaps are inserted between
  # each pair of adjacent taps.
  effective_filter_size = filter_size + (filter_size - 1) * (dilation - 1)

  if padding == 'same':
    unstrided_length = input_length
  elif padding == 'valid':
    unstrided_length = input_length - effective_filter_size + 1
  elif padding == 'full':
    unstrided_length = input_length + effective_filter_size - 1

  # Divide by the stride, rounding up (for a positive stride).
  return (unstrided_length + stride - 1) // stride
def conv_input_length(output_length, filter_size, padding, stride):
  """Computes the input length of a convolution from its output length.

  Args:
    output_length: Integer length of the output, or None.
    filter_size: Integer size of the filter.
    padding: One of "same", "valid", "full".
    stride: Integer stride.

  Returns:
    The input length (integer), or None when `output_length` is None.
  """
  if output_length is None:
    return None
  assert padding in {'same', 'valid', 'full'}

  # Padding amount implied by the mode; it is subtracted twice below.
  if padding == 'same':
    pad_amount = filter_size // 2
  elif padding == 'valid':
    pad_amount = 0
  elif padding == 'full':
    pad_amount = filter_size - 1

  strided_span = (output_length - 1) * stride
  return strided_span - 2 * pad_amount + filter_size
def deconv_output_length(input_length, filter_size, padding, stride):
  """Determines output length of a transposed convolution given input length.

  Args:
    input_length: Integer length of the input, or None.
    filter_size: Integer size of the filter.
    padding: Padding mode. "valid" and "full" adjust the result as shown
      below; any other value (including "same") applies no adjustment.
    stride: Integer stride.

  Returns:
    The output length (integer), or None when `input_length` is None:
      * `input_length * stride + max(filter_size - stride, 0)` for "valid";
      * `input_length * stride - (stride + filter_size - 2)` for "full";
      * `input_length * stride` otherwise.
  """
  if input_length is None:
    return None

  # Build the result in a fresh local instead of using augmented assignment on
  # the parameter: `*=` / `+=` would modify a mutable argument (for example a
  # numpy array) in place and leak the change back to the caller.
  output_length = input_length * stride
  if padding == 'valid':
    output_length = output_length + max(filter_size - stride, 0)
  elif padding == 'full':
    output_length = output_length - (stride + filter_size - 2)
  return output_length
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Returns the result of `true_fn()` or `false_fn()` depending on `pred`.

  When `pred` is a `variables.Variable`, the call is forwarded to `cond.cond`.
  Any other `pred` is forwarded to `smart_module.smart_cond`, which (per its
  own contract, not visible in this file) evaluates the branch directly when
  `pred` is a bool or has a constant value and otherwise falls back to
  `tf.cond`.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Whatever the selected conditional implementation returns for the chosen
    branch.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable (raised by the
      underlying conditional implementation).
  """
  # Select the implementation first, then make a single forwarding call with
  # identical keyword arguments.
  if isinstance(pred, variables.Variable):
    cond_impl = cond.cond
  else:
    cond_impl = smart_module.smart_cond
  return cond_impl(pred, true_fn=true_fn, false_fn=false_fn, name=name)
def constant_value(pred):
  """Returns the bool value of `pred`, or None if it has a dynamic value.

  Args:
    pred: A scalar, either a Python bool or a TensorFlow boolean variable
      or tensor, or the Python integer 1 or 0.

  Returns:
    None when `pred` is a `variables.Variable`; otherwise the result of
    `smart_module.smart_constant_value(pred)`, called after the integers 1
    and 0 have been replaced by True and False.

  Raises:
    TypeError: If `pred` is not a Variable, Tensor or bool, or Python
      integer 1 or 0 (raised by `smart_module.smart_constant_value`; that
      behavior is not visible in this file).
  """
  # Allow integer booleans: 1 becomes True, 0 becomes False, and any other
  # int is passed through unchanged.
  if isinstance(pred, int):
    pred = True if pred == 1 else (False if pred == 0 else pred)

  # Variables are always reported as having no constant value.
  if isinstance(pred, variables.Variable):
    return None
  return smart_module.smart_constant_value(pred)