Merge pull request #47826 from lgeiger:ones-zeros
PiperOrigin-RevId: 363674543
Change-Id: I33ffc4b75c4dbc2955bdab4e9745e32328be12dc
tensorflower-gardener committed Mar 18, 2021
2 parents 11bc035 + 87bdd51 commit da467fc
Showing 7 changed files with 16 additions and 18 deletions.
7 changes: 2 additions & 5 deletions tensorflow/python/keras/layers/normalization.py
@@ -1270,15 +1270,12 @@ def _broadcast(v):
 
     inputs = array_ops.reshape(inputs, squeezed_shape)
 
-    def _set_const_tensor(val, dtype, shape):
-      return array_ops.fill(shape, constant_op.constant(val, dtype=dtype))
-
     # self.gamma and self.beta have the wrong shape for fused_batch_norm, so
     # we cannot pass them as the scale and offset parameters. Therefore, we
     # create two constant tensors in correct shapes for fused_batch_norm and
     # later construct a separate calculation on the scale and offset.
-    scale = _set_const_tensor(1.0, self.dtype, [pre_dim])
-    offset = _set_const_tensor(0.0, self.dtype, [pre_dim])
+    scale = array_ops.ones([pre_dim], dtype=self.dtype)
+    offset = array_ops.zeros([pre_dim], dtype=self.dtype)
 
     # Compute layer normalization using the fused_batch_norm function.
     outputs, _, _ = nn.fused_batch_norm(
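A note on the equivalence (a minimal sketch using the public tf.* API rather than the internal array_ops/constant_op modules; the pre_dim value here is hypothetical):

import tensorflow as tf

pre_dim = 6  # hypothetical flattened dimension, for illustration only

# Old pattern: fill a shape with a scalar constant.
scale_old = tf.fill([pre_dim], tf.constant(1.0, dtype=tf.float32))

# New pattern: the dedicated ones op produces an identical tensor.
scale_new = tf.ones([pre_dim], dtype=tf.float32)

tf.debugging.assert_equal(scale_old, scale_new)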
7 changes: 4 additions & 3 deletions tensorflow/python/ops/array_grad.py
@@ -77,10 +77,11 @@ def _CreateDenseMaskAndBegin(sizes, concat_dim):
     # with 0's everywhere and 1 in the concat dim position.
     # Note: Can't use sparse_to_dense since it isn't GPU-capable (for now)
     mask = array_ops.concat([
-        array_ops.fill(array_ops.expand_dims(concat_dim, 0), 0), [1],
-        array_ops.fill(shape_of_shape - concat_dim - 1, 0)
+        array_ops.zeros(
+            array_ops.expand_dims(concat_dim, 0), dtype=dtypes.int32), [1],
+        array_ops.zeros(shape_of_shape - concat_dim - 1, dtype=dtypes.int32)
     ], 0)
-    begin = array_ops.fill(shape_of_shape, 0)
+    begin = array_ops.zeros(shape_of_shape, dtype=dtypes.int32)
     return mask, begin
 
   def _ExtractInputShapes(inputs):
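For context, the mask built here is a one-hot-style int32 vector marking the concat axis. A rough sketch with the public API and made-up values (concat_dim and shape_of_shape are hypothetical):

import tensorflow as tf

concat_dim = tf.constant(1)        # hypothetical concat axis
shape_of_shape = tf.constant([4])  # shape of a shape vector, i.e. [rank]

# zeros(...) needs an explicit dtype to keep the int32 behavior that
# fill(..., 0) previously inferred from the Python int 0.
mask = tf.concat([
    tf.zeros(tf.expand_dims(concat_dim, 0), dtype=tf.int32), [1],
    tf.zeros(shape_of_shape - concat_dim - 1, dtype=tf.int32)
], 0)
begin = tf.zeros(shape_of_shape, dtype=tf.int32)
print(mask.numpy())   # [0 1 0 0]
print(begin.numpy())  # [0 0 0 0]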
4 changes: 2 additions & 2 deletions tensorflow/python/ops/embedding_ops.py
@@ -525,8 +525,8 @@ def embedding_lookup_sparse(params,
       embeddings = array_ops.gather(embeddings, idx)
 
       # Reshape weights to allow broadcast
-      ones = array_ops.fill(
-          array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
+      ones_shape = array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0)
+      ones = array_ops.ones(ones_shape, dtype=dtypes.int32)
       bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
                                              0)
 
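The reshape trick here appends (rank - 1) trailing 1s to the weights' shape so they broadcast across the embedding dimensions. A minimal sketch with hypothetical shapes:

import tensorflow as tf

embeddings = tf.ones([5, 3, 4])              # hypothetical gathered rows
weights = tf.constant([1., 2., 3., 4., 5.])  # one weight per row

# Append (rank - 1) ones to the weights' shape: [5] -> [5, 1, 1].
ones_shape = tf.expand_dims(tf.rank(embeddings) - 1, 0)
ones = tf.ones(ones_shape, dtype=tf.int32)
bcast_weights_shape = tf.concat([tf.shape(weights), ones], 0)
weights = tf.reshape(weights, bcast_weights_shape)

print((weights * embeddings).shape)  # (5, 3, 4)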
6 changes: 2 additions & 4 deletions tensorflow/python/ops/gradients_util.py
@@ -28,7 +28,6 @@
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import backprop_util
 from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import function as framework_function
 from tensorflow.python.framework import ops
@@ -172,9 +171,8 @@ def _DefaultGradYs(grad_ys,
             "Gradients of complex tensors must set grad_ys (y.dtype = %r)" %
             y.dtype)
       new_grad_ys.append(
-          array_ops.fill(
-              array_ops.shape(y),
-              constant_op.constant(1, dtype=y.dtype, name="grad_ys_%d" % i)))
+          array_ops.ones(
+              array_ops.shape(y), dtype=y.dtype, name="grad_ys_%d" % i))
       continue
     if y.dtype.is_floating or y.dtype.is_integer:
       if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
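When the caller passes no grad_ys, each output y is seeded with a gradient of all ones in y's shape and dtype. A minimal sketch of that default (y and i are made up here):

import tensorflow as tf

y = tf.random.normal([2, 3])  # hypothetical forward output
i = 0                         # index of y, used only in the op name

grad_y = tf.ones(tf.shape(y), dtype=y.dtype, name="grad_ys_%d" % i)
tf.debugging.assert_equal(grad_y, tf.ones_like(y))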
3 changes: 2 additions & 1 deletion tensorflow/python/ops/image_ops_impl.py
@@ -5548,7 +5548,8 @@ def suppression_loop_body(boxes, iou_threshold, output_size, idx):
         array_ops.gather(array_ops.reshape(sorted_indices, [-1]),
                          gather_idx),
         [batch_size, -1])
-    invalid_index = array_ops.fill([batch_size, max_output_size], 0)
+    invalid_index = array_ops.zeros([batch_size, max_output_size],
+                                    dtype=dtypes.int32)
     idx_index = array_ops.expand_dims(math_ops.range(max_output_size), 0)
     num_valid_expanded = array_ops.expand_dims(num_valid, 1)
     idx = array_ops.where(idx_index < num_valid_expanded,
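The zeros tensor serves as padding: positions past each row's valid count are overwritten with index 0. A small sketch with made-up indices and counts:

import tensorflow as tf

batch_size, max_output_size = 2, 4
idx = tf.constant([[7, 3, 9, 5], [2, 8, 1, 6]])  # hypothetical indices
num_valid = tf.constant([3, 2])                  # valid entries per row

invalid_index = tf.zeros([batch_size, max_output_size], dtype=tf.int32)
idx_index = tf.expand_dims(tf.range(max_output_size), 0)
num_valid_expanded = tf.expand_dims(num_valid, 1)
idx = tf.where(idx_index < num_valid_expanded, idx, invalid_index)
print(idx.numpy())
# [[7 3 9 0]
#  [2 8 0 0]]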
5 changes: 3 additions & 2 deletions tensorflow/python/ops/math_grad.py
@@ -330,9 +330,10 @@ def _SegmentMeanGrad(op, grad):
   input_rank = array_ops.rank(op.inputs[0])
   ones_shape = array_ops.concat([
       array_ops.shape(op.inputs[1]),
-      array_ops.fill(array_ops.expand_dims(input_rank - 1, 0), 1)
+      array_ops.ones(
+          array_ops.expand_dims(input_rank - 1, 0), dtype=dtypes.int32)
   ], 0)
-  ones = array_ops.fill(ones_shape, constant_op.constant(1, dtype=grad.dtype))
+  ones = array_ops.ones(ones_shape, dtype=grad.dtype)
   scaled_grad = math_ops.divide(grad, math_ops.segment_sum(ones, op.inputs[1]))
   return array_ops.gather(scaled_grad, op.inputs[1]), None

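The ones tensor here has shape [len(segment_ids), 1, ..., 1], so segment_sum counts the members of each segment, and the upstream gradient is divided by those counts. A rough standalone sketch with toy values:

import tensorflow as tf

data = tf.constant([[1., 2.], [3., 4.], [5., 6.]])  # op.inputs[0]
segment_ids = tf.constant([0, 0, 1])                # op.inputs[1]
grad = tf.ones([2, 2])                              # upstream gradient

input_rank = tf.rank(data)
ones_shape = tf.concat([
    tf.shape(segment_ids),
    tf.ones(tf.expand_dims(input_rank - 1, 0), dtype=tf.int32)
], 0)                                               # -> [3, 1]
ones = tf.ones(ones_shape, dtype=grad.dtype)
# segment_sum of ones counts the rows per segment: [[2.], [1.]]
scaled_grad = tf.math.divide(grad, tf.math.segment_sum(ones, segment_ids))
print(tf.gather(scaled_grad, segment_ids).numpy())
# [[0.5 0.5]
#  [0.5 0.5]
#  [1.  1. ]]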
2 changes: 1 addition & 1 deletion tensorflow/python/ops/math_ops.py
@@ -4169,7 +4169,7 @@ def reduced_shape(input_shape, axes):
       ], # [1, 2]
       [
           input_shape, # [2, 3, 5, 7]
-          array_ops.fill(axes_shape, 1)
+          array_ops.ones(axes_shape, dtype=dtypes.int32)
       ]) # [1, 1]


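This snippet sits inside reduced_shape, which computes the keepdims=True output shape of a reduction by overwriting the reduced axes with 1. A minimal sketch of that idea, assuming a dynamic_stitch-style merge (where later entries win on duplicate indices) and reusing the example values from the inline comments:

import tensorflow as tf

input_shape = tf.constant([2, 3, 5, 7])
axes = tf.constant([1, 2])
axes_shape = tf.shape(axes)  # [2]

reduced = tf.dynamic_stitch(
    [tf.range(tf.size(input_shape)), axes],              # [0 1 2 3], [1 2]
    [input_shape, tf.ones(axes_shape, dtype=tf.int32)])  # values, then 1s
print(reduced.numpy())  # [2 1 1 7]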
