Skip to content

Commit

Permalink
Merge pull request #48337 from lgeiger:keras-cond-where
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 367658364
Change-Id: I395007608ec051935dabbc4a6bdbb77fefbb8a0b
  • Loading branch information
tensorflower-gardener committed Apr 9, 2021
2 parents 88be8f3 + 2b482a4 commit 108b794
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 18 deletions.
1 change: 0 additions & 1 deletion tensorflow/python/keras/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ py_library(
"//tensorflow/python:check_ops",
"//tensorflow/python:confusion_matrix",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
Expand Down
14 changes: 3 additions & 11 deletions tensorflow/python/keras/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unused-import
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-doc-return-or-yield
"""Built-in metrics."""

import abc
import math
import types
import warnings

Expand All @@ -33,7 +31,6 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import activations
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
Expand All @@ -56,21 +53,18 @@
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import losses_utils
from tensorflow.python.keras.utils import metrics_utils
from tensorflow.python.keras.utils import tf_inspect
from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
from tensorflow.python.keras.utils.generic_utils import to_list
from tensorflow.python.keras.utils.tf_utils import is_tensor_or_variable
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import dispatch
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
Expand Down Expand Up @@ -1575,13 +1569,11 @@ def _find_max_under_constraint(self, constrained, dependent, predicate):
Returns maximal dependent value, if no value satiesfies the constraint 0.0.
"""
feasible = array_ops.where(predicate(constrained, self.value))
feasible = array_ops.where_v2(predicate(constrained, self.value))
feasible_exists = math_ops.greater(array_ops.size(feasible), 0)
max_dependent = math_ops.reduce_max(array_ops.gather(dependent, feasible))

def get_max():
return math_ops.reduce_max(array_ops.gather(dependent, feasible))

return control_flow_ops.cond(feasible_exists, get_max, lambda: 0.0)
return array_ops.where_v2(feasible_exists, max_dependent, 0.0)


@keras_export('keras.metrics.SensitivityAtSpecificity')
Expand Down
1 change: 1 addition & 0 deletions tensorflow/python/keras/optimizer_v2/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ py_library(
],
srcs_version = "PY3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
Expand Down Expand Up @@ -416,9 +417,9 @@ def __call__(self, step):
if self.cycle:
# Find the first multiple of decay_steps that is bigger than
# global_step. If global_step is zero set the multiplier to 1
multiplier = control_flow_ops.cond(
math_ops.equal(global_step_recomp, 0), lambda: 1.0,
lambda: math_ops.ceil(global_step_recomp / self.decay_steps))
multiplier = array_ops.where_v2(
math_ops.equal(global_step_recomp, 0), 1.0,
math_ops.ceil(global_step_recomp / self.decay_steps))
decay_steps_recomp = math_ops.multiply(decay_steps_recomp, multiplier)
else:
# Make sure that the global_step used is not bigger than decay_steps.
Expand Down
5 changes: 2 additions & 3 deletions tensorflow/python/keras/utils/metrics_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,8 @@ def update_confusion_matrix_variables(variables_to_update,
num_labels = 1
else:
num_labels = gen_math_ops.Prod(input=pred_shape[1:], axis=0)
thresh_label_tile = control_flow_ops.cond(
one_thresh, lambda: num_labels,
lambda: math_ops.cast(1, dtype=dtypes.int32))
thresh_label_tile = array_ops.where_v2(one_thresh, num_labels,
array_ops.ones([], dtype=dtypes.int32))

# Reshape predictions and labels, adding a dim for thresholding.
if multi_label:
Expand Down

0 comments on commit 108b794

Please sign in to comment.