Skip to content

Commit

Permalink
Merge pull request #46099 from dwyatte:metrics_class_id
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 363762456
Change-Id: Ibd6a375cb14509e9c90af422995f32505781f658
  • Loading branch information
tensorflower-gardener committed Mar 18, 2021
2 parents fb15002 + d4c5597 commit b7908a0
Show file tree
Hide file tree
Showing 14 changed files with 176 additions and 34 deletions.
88 changes: 81 additions & 7 deletions tensorflow/python/keras/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1468,11 +1468,17 @@ class SensitivitySpecificityBase(Metric):
[the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
"""

def __init__(self, value, num_thresholds=200, name=None, dtype=None):
def __init__(self,
value,
num_thresholds=200,
class_id=None,
name=None,
dtype=None):
super(SensitivitySpecificityBase, self).__init__(name=name, dtype=dtype)
if num_thresholds <= 0:
raise ValueError('`num_thresholds` must be > 0.')
self.value = value
self.class_id = class_id
self.true_positives = self.add_weight(
'true_positives',
shape=(num_thresholds,),
Expand Down Expand Up @@ -1521,13 +1527,19 @@ def update_state(self, y_true, y_pred, sample_weight=None):
y_true,
y_pred,
thresholds=self.thresholds,
class_id=self.class_id,
sample_weight=sample_weight)

def reset_states(self):
  """Resets all of the metric's threshold-indexed state variables to zero."""
  n = len(self.thresholds)
  assignments = []
  for variable in self.variables:
    assignments.append((variable, np.zeros((n,))))
  # Single batched assignment avoids one session/graph round-trip per variable.
  K.batch_set_value(assignments)

def get_config(self):
  """Returns the serializable config of this metric (includes `class_id`)."""
  base_config = super(SensitivitySpecificityBase, self).get_config()
  merged = dict(base_config)
  # Subclass-specific key goes last, overriding the base on any collision.
  merged['class_id'] = self.class_id
  return merged

def _find_max_under_constraint(self, constrained, dependent, predicate):
"""Returns the maximum of dependent_statistic that satisfies the constraint.
Expand Down Expand Up @@ -1571,13 +1583,21 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase):
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `class_id` is specified, we calculate sensitivity by considering only the
entries in the batch for which `class_id` is in the label, and computing the
fraction of them for which `class_id` is above the threshold predictions.
For additional information about specificity and sensitivity, see
[the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
Args:
specificity: A scalar value in range `[0, 1]`.
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
use for matching the given specificity.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Expand All @@ -1604,13 +1624,22 @@ class SensitivityAtSpecificity(SensitivitySpecificityBase):
```
"""

def __init__(self,
             specificity,
             num_thresholds=200,
             class_id=None,
             name=None,
             dtype=None):
  """Creates a `SensitivityAtSpecificity` instance.

  Args:
    specificity: A scalar value in range `[0, 1]`.
    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
      use for matching the given specificity.
    class_id: (Optional) Integer class ID for which we want binary metrics.
      This must be in the half-open interval `[0, num_classes)`, where
      `num_classes` is the last dimension of predictions.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Raises:
    ValueError: If `specificity` is not in the range `[0, 1]`.
  """
  # Validate eagerly so the error points at the offending constructor argument.
  if specificity < 0 or specificity > 1:
    raise ValueError('`specificity` must be in the range [0, 1].')
  self.specificity = specificity
  self.num_thresholds = num_thresholds
  super(SensitivityAtSpecificity, self).__init__(
      specificity,
      num_thresholds=num_thresholds,
      class_id=class_id,
      name=name,
      dtype=dtype)

def result(self):
specificities = math_ops.div_no_nan(
Expand Down Expand Up @@ -1646,13 +1675,21 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase):
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `class_id` is specified, the metric is computed for that class alone: only
the `class_id` entry of each prediction and label is used, treating the
problem as a binary classification for that class.
For additional information about specificity and sensitivity, see
[the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).
Args:
sensitivity: A scalar value in range `[0, 1]`.
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
use for matching the given sensitivity.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Expand All @@ -1679,13 +1716,22 @@ class SpecificityAtSensitivity(SensitivitySpecificityBase):
```
"""

def __init__(self,
             sensitivity,
             num_thresholds=200,
             class_id=None,
             name=None,
             dtype=None):
  """Creates a `SpecificityAtSensitivity` instance.

  Args:
    sensitivity: A scalar value in range `[0, 1]`.
    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
      use for matching the given sensitivity.
    class_id: (Optional) Integer class ID for which we want binary metrics.
      This must be in the half-open interval `[0, num_classes)`, where
      `num_classes` is the last dimension of predictions.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Raises:
    ValueError: If `sensitivity` is not in the range `[0, 1]`.
  """
  # Validate eagerly so the error points at the offending constructor argument.
  if sensitivity < 0 or sensitivity > 1:
    raise ValueError('`sensitivity` must be in the range [0, 1].')
  self.sensitivity = sensitivity
  self.num_thresholds = num_thresholds
  super(SpecificityAtSensitivity, self).__init__(
      sensitivity,
      num_thresholds=num_thresholds,
      class_id=class_id,
      name=name,
      dtype=dtype)

def result(self):
sensitivities = math_ops.div_no_nan(
Expand Down Expand Up @@ -1716,10 +1762,18 @@ class PrecisionAtRecall(SensitivitySpecificityBase):
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `class_id` is specified, we calculate precision by considering only the
entries in the batch for which `class_id` is above the threshold predictions,
and computing the fraction of them for which `class_id` is indeed a correct
label.
Args:
recall: A scalar value in range `[0, 1]`.
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
use for matching the given recall.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Expand All @@ -1746,14 +1800,20 @@ class PrecisionAtRecall(SensitivitySpecificityBase):
```
"""

def __init__(self,
             recall,
             num_thresholds=200,
             class_id=None,
             name=None,
             dtype=None):
  """Creates a `PrecisionAtRecall` instance.

  Args:
    recall: A scalar value in range `[0, 1]`.
    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
      use for matching the given recall.
    class_id: (Optional) Integer class ID for which we want binary metrics.
      This must be in the half-open interval `[0, num_classes)`, where
      `num_classes` is the last dimension of predictions.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Raises:
    ValueError: If `recall` is not in the range `[0, 1]`.
  """
  # Validate eagerly so the error points at the offending constructor argument.
  if recall < 0 or recall > 1:
    raise ValueError('`recall` must be in the range [0, 1].')
  self.recall = recall
  self.num_thresholds = num_thresholds
  super(PrecisionAtRecall, self).__init__(
      value=recall,
      num_thresholds=num_thresholds,
      class_id=class_id,
      name=name,
      dtype=dtype)

Expand Down Expand Up @@ -1786,10 +1846,18 @@ class RecallAtPrecision(SensitivitySpecificityBase):
If `sample_weight` is `None`, weights default to 1.
Use `sample_weight` of 0 to mask values.
If `class_id` is specified, we calculate recall by considering only the
entries in the batch for which `class_id` is in the label, and computing the
fraction of them for which `class_id` is above the threshold predictions.
Args:
precision: A scalar value in range `[0, 1]`.
num_thresholds: (Optional) Defaults to 200. The number of thresholds to
use for matching the given precision.
class_id: (Optional) Integer class ID for which we want binary metrics.
This must be in the half-open interval `[0, num_classes)`, where
`num_classes` is the last dimension of predictions.
name: (Optional) string name of the metric instance.
dtype: (Optional) data type of the metric result.
Expand All @@ -1816,14 +1884,20 @@ class RecallAtPrecision(SensitivitySpecificityBase):
```
"""

def __init__(self,
             precision,
             num_thresholds=200,
             class_id=None,
             name=None,
             dtype=None):
  """Creates a `RecallAtPrecision` instance.

  Args:
    precision: A scalar value in range `[0, 1]`.
    num_thresholds: (Optional) Defaults to 200. The number of thresholds to
      use for matching the given precision.
    class_id: (Optional) Integer class ID for which we want binary metrics.
      This must be in the half-open interval `[0, num_classes)`, where
      `num_classes` is the last dimension of predictions.
    name: (Optional) string name of the metric instance.
    dtype: (Optional) data type of the metric result.

  Raises:
    ValueError: If `precision` is not in the range `[0, 1]`.
  """
  # Validate eagerly so the error points at the offending constructor argument.
  if precision < 0 or precision > 1:
    raise ValueError('`precision` must be in the range [0, 1].')
  self.precision = precision
  self.num_thresholds = num_thresholds
  super(RecallAtPrecision, self).__init__(
      value=precision,
      num_thresholds=num_thresholds,
      class_id=class_id,
      name=name,
      dtype=dtype)

Expand Down

0 comments on commit b7908a0

Please sign in to comment.