Skip to content

Commit

Permalink
Merge pull request #36186 from punndcoder28:punndCoder28
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 296058257
Change-Id: I837a5c6149e97b68cc5d85fe02c97a455486fb93
  • Loading branch information
tensorflower-gardener committed Feb 19, 2020
2 parents b18833b + 54457e5 commit 053680b
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 13 deletions.
45 changes: 32 additions & 13 deletions tensorflow/python/util/nest.py
Expand Up @@ -16,8 +16,16 @@
"""## Functions for working with arbitrarily nested sequences of elements.
This module can perform operations on nested structures. A nested structure is a
Python sequence, tuple (including `namedtuple`), or dict that can contain
further sequences, tuples, and dicts.
Python collection that can contain further collections as well as other objects
called atoms. Note that numpy arrays are considered atoms.
nest recognizes the following types of collections:
1.tuple
2.namedtuple
3.dict
4.orderedDict
5.MutableMapping
6.attr.s
attr.s decorated classes (http://www.attrs.org) are also supported, in the
same way as `namedtuple`.
Expand All @@ -42,6 +50,7 @@
from tensorflow.python import _pywrap_utils
from tensorflow.python.util.compat import collections_abc as _collections_abc
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.platform import tf_logging


_SHALLOW_TREE_HAS_INVALID_KEYS = (
Expand Down Expand Up @@ -109,11 +118,12 @@ def _is_namedtuple(instance, strict=False):


# See the swig file (util.i) for documentation.
_is_mapping = _pywrap_utils.IsMapping
_is_mapping_view = _pywrap_utils.IsMappingView
_is_attrs = _pywrap_utils.IsAttrs
_is_composite_tensor = _pywrap_utils.IsCompositeTensor
_is_type_spec = _pywrap_utils.IsTypeSpec
_is_mutable_mapping = _pywrap_utils.IsMutableMapping
_is_mapping = _pywrap_utils.IsMapping


def _sequence_like(instance, args):
Expand All @@ -128,7 +138,7 @@ def _sequence_like(instance, args):
Returns:
`args` with the type of `instance`.
"""
if _is_mapping(instance):
if _is_mutable_mapping(instance):
# Pack dictionaries in a deterministic order by sorting the keys.
# Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
Expand All @@ -138,11 +148,18 @@ def _sequence_like(instance, args):
instance_type = type(instance)
if instance_type == _collections.defaultdict:
d = _collections.defaultdict(instance.default_factory)
for key in instance:
d[key] = result[key]
return d
else:
return instance_type((key, result[key]) for key in instance)
d = instance_type()
for key in instance:
d[key] = result[key]
return d
elif _is_mapping(instance):
result = dict(zip(_sorted(instance), args))
instance_type = type(instance)
tf_logging.log_first_n(
tf_logging.WARN, "Mapping types may not work well with tf.nest. Prefer"
"using MutableMapping for {}".format(instance_type), 1)
return instance_type((key, result[key]) for key in instance)
elif _is_mapping_view(instance):
# We can't directly construct mapping views, so we create a list instead
return list(args)
Expand Down Expand Up @@ -243,7 +260,7 @@ def is_nested(seq):
def flatten(structure, expand_composites=False):
"""Returns a flat list from a given nested structure.
If nest is not a sequence, tuple (or a namedtuple), dict, or an attrs class,
If nest is not a structure , tuple (or a namedtuple), dict, or an attrs class,
then returns a single-element list:
[nest].
Expand Down Expand Up @@ -291,8 +308,8 @@ def flatten(structure, expand_composites=False):
[7., 8., 9.]], dtype=float32)>]
Args:
structure: an arbitrarily nested structure or a scalar object. Note, numpy
arrays are considered scalars.
structure: an arbitrarily nested structure. Note, numpy arrays are
considered atoms and are not flattened.
expand_composites: If true, then composite tensors such as tf.SparseTensor
and tf.RaggedTensor are expanded into their component tensors.
Expand Down Expand Up @@ -545,8 +562,9 @@ def map_structure(func, *structure, **kwargs):
Args:
func: A callable that accepts as many arguments as there are structures.
*structure: scalar, or tuple or list of constructed scalars and/or other
tuples/lists, or scalars. Note: numpy arrays are considered as scalars.
*structure: scalar, or tuple or dict or list of constructed scalars and/or
other tuples/lists, or scalars. Note: numpy arrays are considered as
scalars.
**kwargs: Valid keyword args are:
* `check_types`: If set to `True` (default) the types of
Expand Down Expand Up @@ -1385,6 +1403,7 @@ def sequence_fn(instance, args):


_pywrap_utils.RegisterType("Mapping", _collections_abc.Mapping)
_pywrap_utils.RegisterType("MutableMapping", _collections_abc.MutableMapping)
_pywrap_utils.RegisterType("Sequence", _collections_abc.Sequence)
_pywrap_utils.RegisterType("MappingView", _collections_abc.MappingView)
_pywrap_utils.RegisterType("ObjectProxy", _wrapt.ObjectProxy)
11 changes: 11 additions & 0 deletions tensorflow/python/util/util.cc
Expand Up @@ -221,6 +221,16 @@ int IsMappingHelper(PyObject* o) {
return check_cache->CachedLookup(o);
}

// Returns 1 if `o` is considered a mutable mapping for the purposes of
// Flatten(). Returns 0 otherwise. Returns -1 if an error occurred.
int IsMutableMappingHelper(PyObject* o) {
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
return IsInstanceOfRegisteredType(to_check, "MutableMapping");
});
if (PyDict_Check(o)) return true;
return check_cache->CachedLookup(o);
}

// Returns 1 if `o` is considered a mapping view for the purposes of Flatten().
// Returns 0 otherwise.
// Returns -1 if an error occurred.
Expand Down Expand Up @@ -877,6 +887,7 @@ bool AssertSameStructureHelper(

bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
bool IsMutableMapping(PyObject* o) { return IsMutableMappingHelper(o) == 1; }
bool IsMappingView(PyObject* o) { return IsMappingViewHelper(o) == 1; }
bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
Expand Down
9 changes: 9 additions & 0 deletions tensorflow/python/util/util.h
Expand Up @@ -86,6 +86,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict);
// True if the sequence subclasses mapping.
bool IsMapping(PyObject* o);

// Returns a true if its input is a collections.MutableMapping.
//
// Args:
// seq: the input to be checked.
//
// Returns:
// True if the sequence subclasses mapping.
bool IsMutableMapping(PyObject* o);

// Returns a true if its input is a (possibly wrapped) tuple.
//
// Args:
Expand Down
18 changes: 18 additions & 0 deletions tensorflow/python/util/util_wrapper.cc
Expand Up @@ -140,6 +140,24 @@ PYBIND11_MODULE(_pywrap_utils, m) {
Returns:
True if `instance` is a `collections.Mapping`.
)pbdoc");
m.def(
"IsMutableMapping",
[](const py::handle& o) {
bool result = tensorflow::swig::IsMutableMapping(o.ptr());
if (PyErr_Occurred()) {
throw py::error_already_set();
}
return result;
},
R"pbdoc(
Returns True if `instance` is a `collections.MutableMapping`.
Args:
instance: An instance of a Python object.
Returns:
True if `instance` is a `collections.MutableMapping`.
)pbdoc");
m.def(
"IsMappingView",
[](const py::handle& o) {
Expand Down
1 change: 1 addition & 0 deletions tensorflow/tools/def_file_filter/symbols_pybind.txt
Expand Up @@ -5,6 +5,7 @@ tensorflow::swig::IsCompositeTensor
tensorflow::swig::IsTypeSpec
tensorflow::swig::IsNamedtuple
tensorflow::swig::IsMapping
tensorflow::swig::IsMutableMapping
tensorflow::swig::IsMappingView
tensorflow::swig::IsAttrs
tensorflow::swig::IsTensor
Expand Down

0 comments on commit 053680b

Please sign in to comment.