diff --git a/tensorflow_gnn/runner/tasks/classification.py b/tensorflow_gnn/runner/tasks/classification.py
index fde60de3..01ed6f92 100644
--- a/tensorflow_gnn/runner/tasks/classification.py
+++ b/tensorflow_gnn/runner/tasks/classification.py
@@ -8,28 +8,47 @@ Tensor = Union[tf.Tensor, tf.RaggedTensor]
 
-class _FromLogitsMixIn(tf.keras.metrics.Metric):
-  """Mixin for `tf.keras.metrics.Metric` with a from_logits option."""
+class _GnnMetricMixIn(tf.keras.metrics.Metric):
+  """Mixin for `tf.keras.metrics.Metric` with GNN-specific options."""
 
-  def __init__(self, from_logits: bool, *args, **kwargs) -> None:
+  def __init__(self,
+               from_logits: bool,
+               *args,
+               class_id: int = -1,
+               **kwargs) -> None:
+    """Initializes the metric mixin.
+
+    Args:
+      from_logits: Whether the predictions are logits.
+      *args: Positional arguments forwarded to the base metric.
+      class_id: Categorical class id for per-class precision/recall.
+        Class ids are assumed to start at 0.
+      **kwargs: Keyword arguments forwarded to the base metric.
+    """
     super().__init__(*args, **kwargs)
     self._from_logits = from_logits
+    self._class_id = class_id
 
   def update_state(self,
                    y_true: tf.Tensor,
                    y_pred: tf.Tensor,
                    sample_weight: Optional[tf.Tensor] = None) -> None:
-    return super().update_state(
-        y_true,
-        tf.nn.sigmoid(y_pred) if self._from_logits else y_pred,
-        sample_weight)
-
-
-class _Precision(_FromLogitsMixIn, tf.keras.metrics.Precision):
+    if self._from_logits:
+      if self._class_id >= 0:
+        # Multi-class classification: reduce labels and predictions to a
+        # binary indicator for the tracked class.
+        y_true = (y_true == self._class_id)
+        y_pred = (tf.argmax(y_pred, -1) == self._class_id)
+      else:
+        # Binary classification.
+        y_pred = tf.nn.sigmoid(y_pred)
+    return super().update_state(y_true, y_pred, sample_weight)
+
+
+class _Precision(_GnnMetricMixIn, tf.keras.metrics.Precision):
   pass
 
 
-class _Recall(_FromLogitsMixIn, tf.keras.metrics.Recall):
+class _Recall(_GnnMetricMixIn, tf.keras.metrics.Recall):
   pass
 
 
@@ -98,7 +117,20 @@ def metrics(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
 class _MulticlassClassification(_Classification):
   """Multiclass classification."""
 
-  def __init__(self, num_classes: int, *args, **kwargs):  # pylint: disable=useless-super-delegation
+  def __init__(self,
+               num_classes: int,
+               *args,
+               class_names: Optional[Sequence[str]] = None,
+               per_class_statistics: bool = False,
+               **kwargs):
+    if class_names is not None and len(class_names) != num_classes:
+      raise ValueError(f"Expected {num_classes} classes, got "
+                       f"{len(class_names)} class names.")
+    if class_names is None:
+      self._class_names = [f"class_{i}" for i in range(num_classes)]
+    else:
+      self._class_names = class_names
+    self._per_class_statistics = per_class_statistics
     super().__init__(num_classes, *args, **kwargs)
 
   def losses(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
@@ -107,8 +139,21 @@ def losses(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
 
   def metrics(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
     """Sparse categorical metrics."""
-    return (tf.keras.metrics.SparseCategoricalAccuracy(),
-            tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True))
+    metric_objs = [
+        tf.keras.metrics.SparseCategoricalAccuracy(),
+        tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)]
+
+    if self._per_class_statistics:
+      for i, class_name in enumerate(self._class_names):
+        metric_objs.append(_Precision(
+            from_logits=True,
+            class_id=i,
+            name=f"precision_for_{class_name}"))
+        metric_objs.append(_Recall(
+            from_logits=True,
+            class_id=i,
+            name=f"recall_for_{class_name}"))
+    return metric_objs
 
 
 class _GraphClassification(_Classification):
diff --git a/tensorflow_gnn/runner/tasks/classification_test.py b/tensorflow_gnn/runner/tasks/classification_test.py
index 76a7aba1..73c7d174 100644
--- a/tensorflow_gnn/runner/tasks/classification_test.py
+++ b/tensorflow_gnn/runner/tasks/classification_test.py
@@ -161,6 +161,34 @@ def test_fit(self,
   def test_protocol(self, klass: object):
     self.assertIsInstance(klass, orchestration.Task)
 
+  def test_default_per_class_metrics(self):
+    task = classification.GraphMulticlassClassification(
+        num_classes=5, node_set_name="nodes", per_class_statistics=True)
+    metric_names = [metric.name for metric in task.metrics()]
+    self.assertContainsSubset([
+        "precision_for_class_0", "precision_for_class_1",
+        "precision_for_class_2", "precision_for_class_3",
+        "precision_for_class_4", "recall_for_class_0", "recall_for_class_1",
+        "recall_for_class_2", "recall_for_class_3", "recall_for_class_4"
+    ], metric_names)
+
+  def test_custom_per_class_metrics(self):
+    task = classification.RootNodeMulticlassClassification(
+        num_classes=3,
+        node_set_name="nodes",
+        per_class_statistics=True,
+        class_names=["foo", "bar", "baz"])
+    metric_names = [metric.name for metric in task.metrics()]
+    self.assertContainsSubset([
+        "precision_for_foo", "precision_for_bar", "precision_for_baz",
+        "recall_for_foo", "recall_for_bar", "recall_for_baz"
+    ], metric_names)
+
+  def test_invalid_number_of_class_names(self):
+    with self.assertRaises(ValueError):
+      classification.GraphMulticlassClassification(
+          num_classes=5, node_set_name="nodes", class_names=["foo", "bar"])
+
 if __name__ == "__main__":
   tf.test.main()
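A minimal usage sketch of the new options, mirroring the tests above. It
assumes the test file's import path (from tensorflow_gnn.runner.tasks
import classification); the task classes and the num_classes,
node_set_name, class_names, and per_class_statistics parameters come from
this change, while everything else is illustrative.

    from tensorflow_gnn.runner.tasks import classification

    # Default class names: per-class metrics are named
    # "precision_for_class_0", "recall_for_class_0", ..., up to
    # num_classes - 1.
    task = classification.GraphMulticlassClassification(
        num_classes=5, node_set_name="nodes", per_class_statistics=True)

    # Custom class names are substituted into the metric names instead,
    # e.g. "precision_for_foo" and "recall_for_baz".
    task = classification.RootNodeMulticlassClassification(
        num_classes=3,
        node_set_name="nodes",
        per_class_statistics=True,
        class_names=["foo", "bar", "baz"])

    # metrics() returns the two sparse categorical metrics plus one
    # precision and one recall metric per class.
    print([metric.name for metric in task.metrics()])

    # A class_names list whose length disagrees with num_classes raises
    # ValueError at task construction time.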