Skip to content

Commit

Permalink
Update generated tensor_shape file in the Numpy backend.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 427174249
  • Loading branch information
jburnim authored and tensorflower-gardener committed Feb 8, 2022
1 parent 0672630 commit 08a3bb1
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 1 deletion.
Expand Up @@ -30,6 +30,7 @@
'from tensorflow.python.eager import monitoring',
('from tensorflow.python.platform '
'import tf_logging as logging'),
'from tensorflow.python.types import trace',
'from tensorflow.python.util.tf_export import tf_export',
]

Expand All @@ -41,6 +42,11 @@ class Monitoring(object):
def __getattr__(self, name):
return lambda *args, **kwargs: ()
monitoring = Monitoring()
class Trace(object):
TraceType = object
trace = Trace()
def tf_export(*args, **kwargs):
return lambda f: f
"""
Expand Down
Expand Up @@ -20,6 +20,11 @@ class Monitoring(object):
def __getattr__(self, name):
return lambda *args, **kwargs: ()
monitoring = Monitoring()

class Trace(object):
TraceType = object
trace = Trace()

def tf_export(*args, **kwargs):
return lambda f: f

Expand All @@ -40,12 +45,15 @@ def tf_export(*args, **kwargs):
"""Helper classes for tensor shape inference."""
import functools
import operator
from typing import Optional, Sequence

import six

# from tensorflow.core.framework import tensor_shape_pb2
# from tensorflow.python import tf2
# from tensorflow.python.eager import monitoring
# from tensorflow.python.platform import tf_logging as logging
# from tensorflow.python.types import trace
# from tensorflow.python.util.tf_export import tf_export

_TENSORSHAPE_V2_OVERRIDE = None
Expand Down Expand Up @@ -761,7 +769,7 @@ def as_dimension(value):


@tf_export("TensorShape")
class TensorShape(object):
class TensorShape(trace.TraceType):
"""Represents the shape of a `Tensor`.
A `TensorShape` represents a possibly-partial shape specification for a
Expand Down Expand Up @@ -1132,6 +1140,123 @@ def with_rank_at_most(self, rank):
else:
return self

def is_subtype_of(self, other: trace.TraceType) -> bool:
"""Returns True iff `self` is subtype of `other`.
Shape A is a subtype of shape B if shape B can successfully represent it:
* A `TensorShape` of any rank is a subtype of `TensorShape(None)`.
* TensorShapes of equal ranks are covariant, i.e.
`TensorShape([A1, A2, ..])` is a subtype of
`TensorShape([B1, B2, ..])` iff An is a subtype of Bn.
An is subtype of Bn iff An == Bn or Bn is None.
* TensorShapes of different defined ranks have no subtyping relation.
The subtyping relation is reflexive and transitive, but not symmetric.
Some examples:
* `TensorShape([32, 784])` is a subtype of `TensorShape(None)`, and
`TensorShape([4, 4])` is also a subtype of `TensorShape(None)` but
`TensorShape([32, 784])` and `TensorShape([4, 4])` are not subtypes of
each other.
* All two-dimensional shapes are subtypes of `TensorShape([None, None])`,
such as `TensorShape([32, 784])`. There is no subtype relationship with,
for example, `TensorShape([None])` or `TensorShape([None, None, None])`.
* `TensorShape([32, None])` is also a subtype of `TensorShape([None, None])`
and `TensorShape(None)`. It is not a subtype of, for example,
`TensorShape([32])`, `TensorShape([32, None, 1])`,
`TensorShape([64, None])` or `TensorShape([None, 32])`.
* `TensorShape([32, 784])` is a subtype of itself, and also
`TensorShape([32, None])`, `TensorShape([None, 784])`,
`TensorShape([None, None])` and `TensorShape(None)`.
It has no subtype relation with, for example, `TensorShape([32, 1, 784])`
or `TensorShape([None])`.
Args:
other: Another `TensorShape`.
Returns:
True iff `self` is subtype of `other`.
"""
if not isinstance(other, TensorShape):
return False

# All Tensors are subtypes of a Tensor with no shape.
if other.rank is None:
return True

# Tensor with a defined shape can only be subtype of another with a defined
# shape if they have the same number of dimensions.
if self.rank != other.rank:
return False

# A Tensor is a subtype if each corresponding dimension is a subtype.
return all(o is None or s == o for s, o in zip(self._dims, other._dims)) # pylint: disable=protected-access

def most_specific_common_supertype(
self, others: Sequence[trace.TraceType]) -> Optional["TensorShape"]:
"""Returns the most specific supertype `TensorShape` of self and others.
* `TensorShape([None, 1])` is the most specific `TensorShape` supertyping
both `TensorShape([2, 1])` and `TensorShape([5, 1])`. Note that
`TensorShape(None)` is also a supertype but it is not "most specific".
* `TensorShape([1, 2, 3])` is the most specific `TensorShape` supertyping
both `TensorShape([1, 2, 3])` and `TensorShape([1, 2, 3]`). There are
other less specific TensorShapes that supertype above mentioned
TensorShapes, e.g. `TensorShape([1, 2, None])`, `TensorShape(None)`.
* `TensorShape([None, None])` is the most specific `TensorShape`
supertyping both `TensorShape([2, None])` and `TensorShape([None, 3])`.
As always, `TensorShape(None)` is also a supertype but not the most
specific one.
* `TensorShape(None`) is the only `TensorShape` supertyping both
`TensorShape([1, 2, 3])` and `TensorShape([1, 2])`. In general, any two
shapes that have different ranks will only have `TensorShape(None)`
as a common supertype.
* `TensorShape(None)` is the only `TensorShape` supertyping both
`TensorShape([1, 2, 3])` and `TensorShape(None)`. In general, the common
supertype of any shape with `TensorShape(None)` is `TensorShape(None)`.
Args:
others: Sequence of `TensorShape`.
Returns:
A `TensorShape` which is the most specific supertype shape of `self`
and `others`. None if it does not exist.
"""
if any(not isinstance(other, TensorShape) for other in others):
return None

# A Rankless TensorShape is already a global supertype so we return another
# instance of it.
if self.rank is None:
return unknown_shape()

# A Rankless TensorShape is the most specific supertype for shapes whose
# ranks do not match.
if any(other.dims is None or self.rank != other.rank for other in others):
return unknown_shape()

# Retain the integer dimension if it is the same across all others, else
# use an undefined dimension.
dims = [
dim if all(dim == other._dims[i]
for other in others) else None
for i, dim in enumerate(self._dims)
]
return TensorShape(dims)

# TODO(b/216206374): Consider deprecation at TraceType release.
def is_compatible_with(self, other):
"""Returns True iff `self` is compatible with `other`.
Expand Down

0 comments on commit 08a3bb1

Please sign in to comment.