-
Notifications
You must be signed in to change notification settings - Fork 74.1k
/
options.py
154 lines (122 loc) · 5.1 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for tf.data options."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from absl import logging
def _internal_attr_name(name):
  """Builds the underscore-prefixed form of `name`.

  Args:
    name: The string to prefix.

  Returns:
    The string "_" followed by `name`.
  """
  prefix = "_"
  return prefix + name
class OptionsBase(object):
  """Base class for representing a set of tf.data options.

  Option values are kept in the `_options` dictionary. Assignment is limited
  to attribute names that already resolve on the instance (see
  `__setattr__`), so assigning to an unknown name raises instead of creating
  a new attribute.

  Attributes:
    _options: Stores the option values.
  """

  def __init__(self):
    # NOTE: A plain `self._options = {}` would be routed through the
    # overridden `__setattr__`, which rejects names that do not exist yet, so
    # the base implementation is called directly.
    object.__setattr__(self, "_options", {})

  def __eq__(self, other):
    """Compares every option stored on either side via attribute access.

    Args:
      other: The object to compare against.

    Returns:
      `NotImplemented` if `other` is not an instance of this object's class;
      otherwise `True` when no compared option differs, else `False`.
    """
    if not isinstance(other, self.__class__):
      return NotImplemented
    # pylint: disable=protected-access
    names = set(self._options) | set(other._options)
    # pylint: enable=protected-access
    # `any` stops at the first differing option.
    return not any(
        getattr(self, name) != getattr(other, name) for name in names)

  def __ne__(self, other):
    """Returns the negation of `__eq__` for instances of this class.

    Args:
      other: The object to compare against.

    Returns:
      `NotImplemented` if `other` is not an instance of this object's class;
      otherwise the boolean opposite of `self.__eq__(other)`.
    """
    if not isinstance(other, self.__class__):
      return NotImplemented
    return not self.__eq__(other)

  def __setattr__(self, name, value):
    """Assigns `value` to `name` only if `name` already resolves on `self`.

    Args:
      name: The attribute name to assign.
      value: The value to store.

    Raises:
      AttributeError: if `hasattr(self, name)` is false.
    """
    if not hasattr(self, name):
      raise AttributeError(
          "Cannot set the property %s on %s." % (name, type(self).__name__))
    object.__setattr__(self, name, value)
def graph_rewrites():
  """Creates a namedtuple type for optimization graph rewrites settings.

  Returns:
    A newly created `GraphRewrites` namedtuple class with three fields:
    `enabled`, `disabled` and `default`.
  """
  field_names = ["enabled", "disabled", "default"]
  return collections.namedtuple("GraphRewrites", field_names)
def create_option(name, ty, docstring, default_factory=lambda: None):
  """Creates a type-checked property.

  The returned property keeps its value in the `_options` dictionary of the
  object it is accessed on, under the key `name`.

  Args:
    name: The name to use; also the key used in `_options`.
    ty: The type to use. The type of the property will be validated with
      `isinstance` when it is set.
    docstring: The docstring to use.
    default_factory: A callable that takes no arguments and returns a default
      value to use if not set.

  Returns:
    A type-checked property.
  """

  def _read(option):
    store = option._options  # pylint: disable=protected-access
    # The default is produced on first read and then stored, so the factory
    # runs at most once per object unless the key is removed again.
    if name not in store:
      store[name] = default_factory()
    return store[name]

  def _write(option, value):
    if isinstance(value, ty):
      option._options[name] = value  # pylint: disable=protected-access
      return
    raise TypeError("Property \"%s\" must be of type %s, got: %r (type: %r)" %
                    (name, ty, value, type(value)))

  return property(fget=_read, fset=_write, fdel=None, doc=docstring)
def merge_options(*options_list):
  """Merges the given options, returning the result as a new options object.

  The input arguments are expected to have a matching type that derives from
  `tf.data.OptionsBase` (and thus each represent a set of options). The method
  outputs an object of the same type created by merging the sets of options
  represented by the input arguments.

  Options whose value equals the default of a freshly constructed object are
  skipped. If an option is set to different values by different options
  objects, the result will match the setting of the options object that
  appears in the input list last, and a warning is logged.

  If an option is an instance of `tf.data.OptionsBase` itself, then this method
  is applied recursively to the set of options represented by this option.

  Args:
    *options_list: options to merge

  Raises:
    ValueError: if no options objects are given
    TypeError: if the input arguments are incompatible or not derived from
      `tf.data.OptionsBase`

  Returns:
    A new options object which is the result of merging the given options.
  """
  if not options_list:
    raise ValueError("At least one options should be provided")

  first = options_list[0]
  result_type = type(first)

  # Every argument has to be an instance of the first argument's type.
  for candidate in options_list:
    if not isinstance(candidate, result_type):
      raise TypeError("Incompatible options type: %r vs %r" % (type(candidate),
                                                               result_type))

  if not isinstance(first, OptionsBase):
    raise TypeError("The inputs should inherit from `OptionsBase`")

  # `defaults` is never modified by the merge; it only supplies the default
  # value of each option for comparison.
  defaults = result_type()
  merged = result_type()
  for candidate in options_list:
    # Iterate over all set options and merge them into the result.
    for name in candidate._options:  # pylint: disable=protected-access
      current = getattr(merged, name)
      incoming = getattr(candidate, name)
      baseline = getattr(defaults, name)

      if incoming == baseline:
        # The input leaves this option at its default; nothing to merge.
        continue

      if current == baseline:
        # Nothing set in the result yet, so take the incoming value as is.
        setattr(merged, name, incoming)
        continue

      if isinstance(current, OptionsBase):
        # Nested options objects are merged rather than replaced.
        setattr(merged, name, merge_options(current, incoming))
        continue

      if current != incoming:
        # Conflicting values: the later input wins.
        logging.warning("Changing the value of option %s from %r to %r.", name,
                        current, incoming)
        setattr(merged, name, incoming)
  return merged