Skip to content

Commit

Permalink
Fixed cases where tf.TensorShape was constructed with float dimensions
Browse files Browse the repository at this point in the history
This is a prerequisite for making TensorShape and Dimension more strict
about the types of their arguments.

PiperOrigin-RevId: 274700832
  • Loading branch information
superbobry authored and tensorflower-gardener committed Oct 15, 2019
1 parent e2eb7e3 commit 2f245bd
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tensorflow/python/distribute/all_reduce_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ def testBuildRingGatherPassStructure(self):
for otl in output_tensors:
self.assertEqual(len(otl), num_chunks)
for ot in otl:
self.assertEqual(ot.shape, [tlen/num_chunks])
self.assertEqual(ot.shape, [tlen//num_chunks])

def _buildInitialVars(self, shape, dev_list):
values = []
num_devices = len(dev_list)
dim = np.prod(shape) if shape else 1
dim = np.prod(shape, dtype=int) if shape else 1
for d in range(0, num_devices):
with ops.device(dev_list[d]):
npt = np.zeros(shape).astype(np.float32)
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/tpu/tpu_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def get_sharded_shape(self, shape, shard_index=None):
raise ValueError("shape %s cannot be sharded %d ways along dimension %d" %
(shape.as_list(), self._number_of_shards,
self._shard_dimension))
dims[self._shard_dimension] /= self._number_of_shards
dims[self._shard_dimension] //= self._number_of_shards
return tensor_shape.as_shape(dims)

def _unshard_shape(self, shape):
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/training/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def __init__(self, sparse, map_op, rank):
"""
self._sparse = sparse
self._map_op = map_op
self._rank = tensor_shape.Dimension(rank)
self._rank = tensor_shape.as_dimension(rank)

def __eq__(self, other):
if self.sparse != other.sparse:
Expand Down

0 comments on commit 2f245bd

Please sign in to comment.