Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

binary precision and recall metrics #86

Merged
merged 22 commits into from
Sep 19, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions elegy/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .accuracy import Accuracy, accuracy
from .precision import Precision, precision
from .recall import Recall, recall
from .binary_crossentropy import BinaryCrossentropy, binary_crossentropy
from .categorical_accuracy import CategoricalAccuracy, categorical_accuracy
from .mean import Mean
Expand All @@ -15,6 +17,10 @@
__all__ = [
"Accuracy",
"accuracy",
"Precision",
"precision",
"Recall",
"recall",
"BinaryCrossentropy",
"binary_crossentropy",
"CategoricalAccuracy",
Expand Down
95 changes: 95 additions & 0 deletions elegy/metrics/precision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from elegy import types
from elegy import utils
import typing as tp

import jax.numpy as jnp

from elegy.metrics.mean import Mean


def precision(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Return a float32 array with one entry per predicted-positive element:
    1.0 where the label agrees with the prediction, 0.0 otherwise. Averaging
    these values (e.g. via `Mean`) yields the binary precision.

    Arguments:
        y_true: Ground truth labels (binary 0/1 values).
        y_pred: Predicted labels (binary 0/1 values).
    """
    # Align dtypes so the element-wise comparison below is well-defined.
    if y_pred.dtype != y_true.dtype:
        y_pred = y_pred.astype(y_true.dtype)

    # Keep only the positions the model predicted as positive.
    positive_mask = y_pred == 1
    matches = y_true[positive_mask] == y_pred[positive_mask]
    return matches.astype(jnp.float32)


class Precision(Mean):
    """
    Measures how often predictions match labels over the elements predicted as
    positive (class one). Two accumulators, `total` and `count`, inherited from
    `Mean`, track the running statistics; the reported value is the idempotent
    ratio `total / count`, i.e. the binary precision. With `sample_weight=None`
    every element weighs 1; a weight of 0 masks an element out.

    ```python
    precision = elegy.metrics.Precision()

    result = precision(
        y_true=jnp.array([0, 1, 1, 1]), y_pred=jnp.array([1, 0, 1, 1])
    )
    assert result == 0.6666667  # 2 / 3

    result = precision(
        y_true=jnp.array([1, 1, 1, 1]), y_pred=jnp.array([1, 1, 0, 0])
    )
    assert result == 0.8  # 4 / 5
    ```

    Usage with elegy API:

    ```python
    model = elegy.Model(
        module_fn,
        loss=elegy.losses.CategoricalCrossentropy(),
        metrics=elegy.metrics.Precision(),
        optimizer=optix.adam(1e-3),
    )
    ```
    """

    def __init__(self, on: tp.Optional[types.IndexLike] = None, **kwargs):
        """
        Creates a `Precision` instance.

        Arguments:
            on: A string or integer, or iterable of string or integers, that
                indicate how to index/filter the `y_true` and `y_pred`
                arguments before passing them to `call`. For example if `on = "a"` then
                `y_true = y_true["a"]`. If `on` is an iterable
                the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
                then `y_true = y_true["a"][0]["b"]`, same for `y_pred`. For more information
                check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
            kwargs: Additional keyword arguments passed to Module.
        """
        super().__init__(on=on, **kwargs)

    def call(
        self,
        y_true: jnp.ndarray,
        y_pred: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """
        Accumulates metric statistics. `y_true` and `y_pred` should have the same shape.

        Arguments:
            y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
            y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
            sample_weight: Optional `sample_weight` acts as a
                coefficient for the metric. If a scalar is provided, then the metric is
                simply scaled by the given value. If `sample_weight` is a tensor of size
                `[batch_size]`, then the metric for each sample of the batch is rescaled
                by the corresponding element in the `sample_weight` vector. If the shape
                of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted
                to this shape), then each metric element of `y_pred` is scaled by the
                corresponding value of `sample_weight`. (Note on `dN-1`: all metric
                functions reduce by 1 dimension, usually the last axis (-1)).
        Returns:
            Array with the cumulative precision.
        """
        # Per-element match scores over predicted positives; Mean folds them
        # into the running total / count.
        values = precision(y_true=y_true, y_pred=y_pred)
        return super().call(values=values, sample_weight=sample_weight)
22 changes: 22 additions & 0 deletions elegy/metrics/precision_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import unittest
import elegy

from elegy.testing_utils import transform_and_run
import jax.numpy as jnp


class PrecisionTest(unittest.TestCase):
    @transform_and_run
    def test_basic(self):
        # Precision accumulates across calls: total matches / total predicted positives.
        metric = elegy.metrics.Precision()

        first = metric(
            y_true=jnp.array([0, 1, 1, 1]), y_pred=jnp.array([1, 0, 1, 1])
        )
        assert first == 0.6666667  # 2 correct of 3 predicted positives

        second = metric(
            y_true=jnp.array([1, 1, 1, 1]), y_pred=jnp.array([1, 1, 0, 0])
        )
        assert second == 0.8  # cumulative: (2 + 2) / (3 + 2)
95 changes: 95 additions & 0 deletions elegy/metrics/recall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from elegy import types
from elegy import utils
import typing as tp

import jax.numpy as jnp

from elegy.metrics.mean import Mean


def recall(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:
    """
    Return a float32 array with one entry per actual-positive element:
    1.0 where the prediction agrees with the label, 0.0 otherwise. Averaging
    these values (e.g. via `Mean`) yields the binary recall.

    Arguments:
        y_true: Ground truth labels (binary 0/1 values).
        y_pred: Predicted labels (binary 0/1 values).
    """
    # Align dtypes so the element-wise comparison below is well-defined.
    if y_pred.dtype != y_true.dtype:
        y_pred = y_pred.astype(y_true.dtype)

    # Keep only the positions whose true label is positive.
    positive_mask = y_true == 1
    matches = y_true[positive_mask] == y_pred[positive_mask]
    return matches.astype(jnp.float32)


class Recall(Mean):
    """
    Measures how often predictions match labels over the elements whose true
    class is one. Two accumulators, `total` and `count`, inherited from `Mean`,
    track the running statistics; the reported value is the idempotent ratio
    `total / count`, i.e. the binary recall. With `sample_weight=None` every
    element weighs 1; a weight of 0 masks an element out.

    ```python
    recall = elegy.metrics.Recall()

    result = recall(
        y_true=jnp.array([0, 1, 1, 1]), y_pred=jnp.array([1, 0, 1, 1])
    )
    assert result == 0.6666667  # 2 / 3

    result = recall(
        y_true=jnp.array([1, 1, 1, 1]), y_pred=jnp.array([1, 0, 0, 0])
    )
    assert result == 0.42857143  # 3 / 7
    ```

    Usage with elegy API:

    ```python
    model = elegy.Model(
        module_fn,
        loss=elegy.losses.CategoricalCrossentropy(),
        metrics=elegy.metrics.Recall(),
        optimizer=optix.adam(1e-3),
    )
    ```
    """

    def __init__(self, on: tp.Optional[types.IndexLike] = None, **kwargs):
        """
        Creates a `Recall` instance.

        Arguments:
            on: A string or integer, or iterable of string or integers, that
                indicate how to index/filter the `y_true` and `y_pred`
                arguments before passing them to `call`. For example if `on = "a"` then
                `y_true = y_true["a"]`. If `on` is an iterable
                the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
                then `y_true = y_true["a"][0]["b"]`, same for `y_pred`. For more information
                check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
            kwargs: Additional keyword arguments passed to Module.
        """
        super().__init__(on=on, **kwargs)

    def call(
        self,
        y_true: jnp.ndarray,
        y_pred: jnp.ndarray,
        sample_weight: tp.Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """
        Accumulates metric statistics. `y_true` and `y_pred` should have the same shape.

        Arguments:
            y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.
            y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.
            sample_weight: Optional `sample_weight` acts as a
                coefficient for the metric. If a scalar is provided, then the metric is
                simply scaled by the given value. If `sample_weight` is a tensor of size
                `[batch_size]`, then the metric for each sample of the batch is rescaled
                by the corresponding element in the `sample_weight` vector. If the shape
                of `sample_weight` is `[batch_size, d0, .. dN-1]` (or can be broadcasted
                to this shape), then each metric element of `y_pred` is scaled by the
                corresponding value of `sample_weight`. (Note on `dN-1`: all metric
                functions reduce by 1 dimension, usually the last axis (-1)).
        Returns:
            Array with the cumulative recall.
        """
        # Per-element match scores over true positives; Mean folds them into
        # the running total / count.
        values = recall(y_true=y_true, y_pred=y_pred)
        return super().call(values=values, sample_weight=sample_weight)
18 changes: 18 additions & 0 deletions elegy/metrics/recall_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import unittest
import elegy

from elegy.testing_utils import transform_and_run
import jax.numpy as jnp


class RecallTest(unittest.TestCase):
    @transform_and_run
    def test_basic(self):
        # Recall accumulates across calls: total matches / total true positives.
        metric = elegy.metrics.Recall()

        first = metric(y_true=jnp.array([0, 1, 1, 1]), y_pred=jnp.array([1, 0, 1, 1]))
        assert first == 0.6666667  # 2 correct of 3 true positives

        second = metric(y_true=jnp.array([1, 1, 1, 1]), y_pred=jnp.array([1, 0, 0, 0]))
        assert second == 0.42857143  # cumulative: (2 + 1) / (3 + 4)