Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the docstring of @tf.exports methods in tf.nest. #36186

Merged
merged 15 commits into from Feb 19, 2020
43 changes: 31 additions & 12 deletions tensorflow/python/util/nest.py
Expand Up @@ -16,8 +16,16 @@
"""## Functions for working with arbitrarily nested sequences of elements.

This module can perform operations on nested structures. A nested structure is a
Python sequence, tuple (including `namedtuple`), or dict that can contain
further sequences, tuples, and dicts.
Python collection that can contain further collections as well as other objects
called atoms. Note that numpy arrays are considered atoms.

nest recognizes the following types of collections:
1.tuple
2.namedtuple
3.dict
4.orderedDict
5.MutableMapping
6.attr.s

attr.s decorated classes (http://www.attrs.org) are also supported, in the
same way as `namedtuple`.
Expand All @@ -42,6 +50,7 @@
from tensorflow.python import _pywrap_utils
from tensorflow.python.util.compat import collections_abc as _collections_abc
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.platform import tf_logging


_SHALLOW_TREE_HAS_INVALID_KEYS = (
Expand Down Expand Up @@ -109,11 +118,12 @@ def _is_namedtuple(instance, strict=False):


# See the swig file (util.i) for documentation.
_is_mapping = _pywrap_utils.IsMapping
_is_mapping_view = _pywrap_utils.IsMappingView
_is_attrs = _pywrap_utils.IsAttrs
_is_composite_tensor = _pywrap_utils.IsCompositeTensor
_is_type_spec = _pywrap_utils.IsTypeSpec
_is_mutable_mapping = _pywrap_utils.IsMutableMapping
_is_mapping = _pywrap_utils.IsMapping


def _sequence_like(instance, args):
Expand All @@ -128,7 +138,7 @@ def _sequence_like(instance, args):
Returns:
`args` with the type of `instance`.
"""
if _is_mapping(instance):
if _is_mutable_mapping(instance):
punndcoder28 marked this conversation as resolved.
Show resolved Hide resolved
# Pack dictionaries in a deterministic order by sorting the keys.
# Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
Expand All @@ -138,11 +148,19 @@ def _sequence_like(instance, args):
instance_type = type(instance)
if instance_type == _collections.defaultdict:
punndcoder28 marked this conversation as resolved.
Show resolved Hide resolved
d = _collections.defaultdict(instance.default_factory)
for key in instance:
d[key] = result[key]
return d
else:
return instance_type((key, result[key]) for key in instance)
d = instance_type()
for key in instance:
d[key] = result[key]
return d
elif _is_mapping(instance):
result = dict(zip(_sorted(instance), args))
instance_type = type(instance)
tf_logging.log_first_n(
tf_logging.WARN, "Mapping types may not work well with tf.nest. Prefer using"
"MutableMapping for {}".format(instance_type), 1
)
return instance_type((key, result[key]) for key in instance)
elif _is_mapping_view(instance):
# We can't directly construct mapping views, so we create a list instead
return list(args)
Expand Down Expand Up @@ -243,7 +261,7 @@ def is_nested(seq):
def flatten(structure, expand_composites=False):
"""Returns a flat list from a given nested structure.

If nest is not a sequence, tuple (or a namedtuple), dict, or an attrs class,
If nest is not a structure , tuple (or a namedtuple), dict, or an attrs class,
then returns a single-element list:
[nest].

Expand All @@ -260,8 +278,8 @@ def flatten(structure, expand_composites=False):
running.

Args:
structure: an arbitrarily nested structure or a scalar object. Note, numpy
arrays are considered scalars.
structure: an arbitrarily nested structure. Note, numpy arrays are considered
atoms and are not flattened.
expand_composites: If true, then composite tensors such as tf.SparseTensor
and tf.RaggedTensor are expanded into their component tensors.

Expand Down Expand Up @@ -514,7 +532,7 @@ def map_structure(func, *structure, **kwargs):

Args:
func: A callable that accepts as many arguments as there are structures.
*structure: scalar, or tuple or list of constructed scalars and/or other
*structure: scalar, or tuple or dict or list of constructed scalars and/or other
tuples/lists, or scalars. Note: numpy arrays are considered as scalars.
**kwargs: Valid keyword args are:

Expand Down Expand Up @@ -1354,6 +1372,7 @@ def sequence_fn(instance, args):


_pywrap_utils.RegisterType("Mapping", _collections_abc.Mapping)
_pywrap_utils.RegisterType("MutableMapping", _collections_abc.MutableMapping)
_pywrap_utils.RegisterType("Sequence", _collections_abc.Sequence)
_pywrap_utils.RegisterType("MappingView", _collections_abc.MappingView)
_pywrap_utils.RegisterType("ObjectProxy", _wrapt.ObjectProxy)
12 changes: 12 additions & 0 deletions tensorflow/python/util/util.cc
Expand Up @@ -221,6 +221,17 @@ int IsMappingHelper(PyObject* o) {
return check_cache->CachedLookup(o);
}

// Returns 1 if `o` is considered a mutable mapping for the purposes of Flatten().
// Returns 0 otherwise.
// Returns -1 if an error occurred.
int IsMutableMappingHelper(PyObject* o) {
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
return IsInstanceOfRegisteredType(to_check, "MutableMapping");
});
if (PyDict_Check(o)) return true;
return check_cache->CachedLookup(o);
}

// Returns 1 if `o` is considered a mapping view for the purposes of Flatten().
// Returns 0 otherwise.
// Returns -1 if an error occurred.
Expand Down Expand Up @@ -877,6 +888,7 @@ bool AssertSameStructureHelper(

bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
bool IsMutableMapping(PyObject* o){ return IsMutableMappingHelper(o) == 1; }
bool IsMappingView(PyObject* o) { return IsMappingViewHelper(o) == 1; }
bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/python/util/util.h
Expand Up @@ -86,6 +86,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict);
// True if the sequence subclasses mapping.
bool IsMapping(PyObject* o);

// Returns a true if its input is a collections.MutableMapping.
//
// Args:
// seq: the input to be checked.
//
// Returns:
// True if the sequence subclasses mapping.
bool IsMutableMapping(PyObject* o);

// Returns a true if its input is a (possibly wrapped) tuple.
//
// Args:
Expand Down
18 changes: 18 additions & 0 deletions tensorflow/python/util/util_wrapper.cc
Expand Up @@ -140,6 +140,24 @@ PYBIND11_MODULE(_pywrap_utils, m) {
Returns:
True if `instance` is a `collections.Mapping`.
)pbdoc");
m.def(
"IsMutableMapping",
[](const py::handle& o) {
bool result = tensorflow::swig::IsMutableMapping(o.ptr());
if (PyErr_Occurred()) {
throw py::error_already_set();
}
return result;
},
R"pbdoc(
Returns True if `instance` is a `collections.MutableMapping`.

Args:
instance: An instance of a Python object.

Returns:
True if `instance` is a `collections.MutableMapping`.
)pbdoc");
m.def(
"IsMappingView",
[](const py::handle& o) {
Expand Down
1 change: 1 addition & 0 deletions tensorflow/tools/def_file_filter/symbols_pybind.txt
Expand Up @@ -5,6 +5,7 @@ tensorflow::swig::IsCompositeTensor
tensorflow::swig::IsTypeSpec
tensorflow::swig::IsNamedtuple
tensorflow::swig::IsMapping
tensorflow::swig::IsMutableMapping
tensorflow::swig::IsMappingView
tensorflow::swig::IsAttrs
tensorflow::swig::IsTensor
Expand Down