Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal change #82

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 59 additions & 14 deletions tensorflow_gnn/runner/tasks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,47 @@
Tensor = Union[tf.Tensor, tf.RaggedTensor]


class _GnnMetricMixIn(tf.keras.metrics.Metric):
  """Mixin for `tf.keras.metrics.Metric` with gnn-specific options.

  Adds two options on top of the wrapped Keras metric:
    * `from_logits`: for binary classification, apply a sigmoid to logits
      before updating the metric state.
    * `class_id`: for multi-class classification, reduce predictions to a
      one-vs-rest binary problem for the given class, enabling per-class
      precision/recall.
  """

  def __init__(self,
               from_logits: bool,
               *args,
               class_id: int = -1,
               **kwargs) -> None:
    """Initializes metric mixin class.

    Args:
      from_logits: Whether the predictions are logits.
      *args: Variable arguments.
      class_id: Categorical class id for per-class precision/recall.
        Assume class id starts with 0. A negative value (the default)
        disables the per-class reduction.
      **kwargs: Variable key-word arguments.
    """
    super().__init__(*args, **kwargs)
    self._from_logits = from_logits
    self._class_id = class_id

  def update_state(self,
                   y_true: tf.Tensor,
                   y_pred: tf.Tensor,
                   sample_weight: Optional[tf.Tensor] = None) -> None:
    # The one-vs-rest reduction is independent of `from_logits`: argmax is
    # invariant under softmax, so it must run for probability inputs too
    # (previously it was skipped unless `from_logits` was set, feeding raw
    # multi-class scores to a binary metric).
    if self._class_id >= 0:
      # Multi-class classification: reduce to one-vs-rest for `class_id`.
      y_true = (y_true == self._class_id)
      y_pred = (tf.argmax(y_pred, -1) == self._class_id)
    elif self._from_logits:
      # Binary classification: map logits to probabilities.
      y_pred = tf.nn.sigmoid(y_pred)
    return super().update_state(y_true, y_pred, sample_weight)


class _Precision(_GnnMetricMixIn, tf.keras.metrics.Precision):
  """Precision with `from_logits` and per-class (`class_id`) support."""

class _Recall(_GnnMetricMixIn, tf.keras.metrics.Recall):
  """Recall with `from_logits` and per-class (`class_id`) support."""


Expand Down Expand Up @@ -98,7 +117,20 @@ def metrics(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
class _MulticlassClassification(_Classification):
"""Multiclass classification."""

def __init__(self,
             num_classes: int,
             *args,
             class_names: Optional[Sequence[str]] = None,
             per_class_statistics: bool = False,
             **kwargs):
  """Initializes a multiclass classification task.

  Args:
    num_classes: The number of target classes.
    *args: Positional arguments forwarded to `_Classification`.
    class_names: Optional human-readable names, one per class, used to
      label per-class metrics. Defaults to `class_0`, ..., `class_{n-1}`.
    per_class_statistics: If true, `metrics()` additionally reports
      per-class precision and recall.
    **kwargs: Keyword arguments forwarded to `_Classification`.

  Raises:
    ValueError: If `class_names` is given and its length differs from
      `num_classes`.
  """
  # NOTE: the former `pylint: disable=useless-super-delegation` pragma is
  # dropped — this override no longer merely delegates.
  if class_names is not None and len(class_names) != num_classes:
    raise ValueError(f"Expected {num_classes} classes, got "
                     f"{len(class_names)} class names.")
  # Copy the caller's sequence so later mutation cannot corrupt metric names.
  self._class_names = (
      list(class_names) if class_names is not None
      else [f"class_{i}" for i in range(num_classes)])
  self._per_class_statistics = per_class_statistics
  super().__init__(num_classes, *args, **kwargs)

def losses(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
Expand All @@ -107,8 +139,21 @@ def losses(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:

def metrics(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
  """Sparse categorical metrics.

  Returns accuracy and cross-entropy metrics, plus per-class precision and
  recall when `per_class_statistics` was requested at construction.
  """
  results = [
      tf.keras.metrics.SparseCategoricalAccuracy(),
      tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True),
  ]
  if not self._per_class_statistics:
    return results
  for class_id, name in enumerate(self._class_names):
    results.append(_Precision(
        from_logits=True,
        class_id=class_id,
        name=f"precision_for_{name}"))
    results.append(_Recall(
        from_logits=True,
        class_id=class_id,
        name=f"recall_for_{name}"))
  return results


class _GraphClassification(_Classification):
Expand Down
28 changes: 28 additions & 0 deletions tensorflow_gnn/runner/tasks/classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,34 @@ def test_fit(self,
def test_protocol(self, klass: object):
self.assertIsInstance(klass, orchestration.Task)

def test_default_per_class_metrics(self):
  """Default class names (`class_i`) label the per-class metrics."""
  task = classification.GraphMulticlassClassification(
      num_classes=5, node_set_name="nodes", per_class_statistics=True)
  actual_names = [metric.name for metric in task.metrics()]
  expected_names = [
      f"{stat}_for_class_{i}"
      for stat in ("precision", "recall")
      for i in range(5)
  ]
  self.assertContainsSubset(expected_names, actual_names)

def test_custom_per_class_metrics(self):
  """Caller-supplied class names label the per-class metrics."""
  class_names = ["foo", "bar", "baz"]
  task = classification.RootNodeMulticlassClassification(
      num_classes=3,
      node_set_name="nodes",
      per_class_statistics=True,
      class_names=class_names)
  actual_names = [metric.name for metric in task.metrics()]
  expected_names = [f"precision_for_{name}" for name in class_names]
  expected_names += [f"recall_for_{name}" for name in class_names]
  self.assertContainsSubset(expected_names, actual_names)

def test_invalid_number_of_class_names(self):
  """A `class_names`/`num_classes` length mismatch raises `ValueError`."""
  # Pin the message (not just the type) so an unrelated ValueError raised
  # earlier in construction cannot make this test pass vacuously.
  with self.assertRaisesRegex(ValueError, r"Expected 5 classes"):
    classification.GraphMulticlassClassification(
        num_classes=5, node_set_name="nodes", class_names=["foo", "bar"])


if __name__ == "__main__":
  # Delegates to TensorFlow's test runner, which discovers the
  # `tf.test.TestCase` subclasses in this module.
  tf.test.main()