Skip to content

Commit

Permalink
sparse categorical crossentropy should check bounds (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-g committed Dec 13, 2020
1 parent 2907c81 commit 5549de5
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

# elegy.losses.mean_absolute_percentage_error

::: elegy.losses.mean_absolute_percentage_error.mean_absolute_percentage_error
selection:
inherited_members: true
10 changes: 10 additions & 0 deletions docs/api/metrics/F1.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

# elegy.metrics.F1

::: elegy.metrics.f1.F1
selection:
inherited_members: true
members:
- __init__
- call

6 changes: 6 additions & 0 deletions docs/api/metrics/f1.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

# elegy.metrics.f1

::: elegy.metrics.f1.f1
selection:
inherited_members: true
2 changes: 1 addition & 1 deletion elegy/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"MeanAbsoluteError",
"mean_absolute_error",
"MeanAbsolutePercentageError",
"mean_percentage_absolute_error",
"mean_absolute_percentage_error",
"MeanSquaredError",
"mean_squared_error",
"MeanSquaredLogarithmicError",
Expand Down
34 changes: 27 additions & 7 deletions elegy/losses/sparse_categorical_crossentropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,32 @@


def sparse_categorical_crossentropy(
y_true: jnp.ndarray, y_pred: jnp.ndarray, from_logits: bool = False
y_true: jnp.ndarray,
y_pred: jnp.ndarray,
from_logits: bool = False,
check_bounds: bool = True,
) -> jnp.ndarray:

n_classes = y_pred.shape[-1]

if from_logits:
y_pred = jax.nn.log_softmax(y_pred)
return -jnp.take_along_axis(y_pred, y_true[..., None], axis=-1)[..., 0]

loss = -jnp.take_along_axis(y_pred, y_true[..., None], axis=-1)[..., 0]
else:
# select output value
y_pred = jnp.take_along_axis(y_pred, y_true[..., None], axis=-1)[..., 0]

# calculate log
y_pred = jnp.maximum(y_pred, utils.EPSILON)
y_pred = jnp.log(y_pred)
return -y_pred
loss = -y_pred

if check_bounds:
# set NaN where y_true is negative or larger/equal to the number of y_pred channels
loss = jnp.where(y_true < 0, jnp.nan, loss)
loss = jnp.where(y_true >= n_classes, jnp.nan, loss)

return loss


class SparseCategoricalCrossentropy(Loss):
Expand Down Expand Up @@ -75,8 +86,8 @@ class SparseCategoricalCrossentropy(Loss):
```python
model = elegy.Model(
module_fn,
loss=lelegy.losses.SparseCategoricalCrossentropy(),
metrics=lelegy.metrics.Accuracy(),
loss=elegy.losses.SparseCategoricalCrossentropy(),
metrics=elegy.metrics.Accuracy(),
optimizer=optax.adam(1e-3),
)
Expand All @@ -89,6 +100,7 @@ def __init__(
reduction: tp.Optional[Reduction] = None,
weight: tp.Optional[float] = None,
on: tp.Optional[types.IndexLike] = None,
check_bounds: tp.Optional[bool] = True,
**kwargs
):
"""
Expand All @@ -109,10 +121,15 @@ def __init__(
the structures will be indexed iteratively, for example if `on = ["a", 0, "b"]`
then `y_true = y_true["a"][0]["b"]`, same for `y_pred`. For more information
check out [Keras-like behavior](https://poets-ai.github.io/elegy/guides/modules-losses-metrics/#keras-like-behavior).
check_bounds: If `True` (default), checks `y_true` for negative values and values
larger or equal than the number of channels in `y_pred`. Sets loss to NaN
if this is the case. If `False`, the check is disabled and the loss may contain
incorrect values.
"""
super().__init__(reduction=reduction, weight=weight, on=on, **kwargs)

self._from_logits = from_logits
self._check_bounds = check_bounds

def call(
self, y_true, y_pred, sample_weight: tp.Optional[jnp.ndarray] = None
Expand All @@ -138,5 +155,8 @@ def call(
"""

return sparse_categorical_crossentropy(
y_true, y_pred, from_logits=self._from_logits
y_true,
y_pred,
from_logits=self._from_logits,
check_bounds=self._check_bounds,
)
15 changes: 15 additions & 0 deletions elegy/losses/sparse_categorical_crossentropy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,18 @@ def test_basic():
)
result = scce(y_true, y_pred) # [0.0513, 2.303]
assert jnp.all(jnp.isclose(result, [0.0513, 2.303], rtol=0.01))


def test_scce_out_of_bounds():
ypred = jnp.zeros([4, 10])
ytrue0 = jnp.array([0, 0, -1, 0])
ytrue1 = jnp.array([0, 0, 10, 0])

scce = elegy.losses.SparseCategoricalCrossentropy()

assert jnp.isnan(scce(ytrue0, ypred)).any()
assert jnp.isnan(scce(ytrue1, ypred)).any()

scce = elegy.losses.SparseCategoricalCrossentropy(check_bounds=False)
assert not jnp.isnan(scce(ytrue0, ypred)).any()
assert not jnp.isnan(scce(ytrue1, ypred)).any()
4 changes: 3 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ nav:
binary_crossentropy: api/losses/binary_crossentropy.md
cosine_similarity: api/losses/cosine_similarity.md
mean_absolute_error: api/losses/mean_absolute_error.md
mean_percentage_absolute_error: api/losses/mean_percentage_absolute_error.md
mean_absolute_percentage_error: api/losses/mean_absolute_percentage_error.md
mean_squared_error: api/losses/mean_squared_error.md
mean_squared_logarithmic_error: api/losses/mean_squared_logarithmic_error.md
sparse_categorical_crossentropy: api/losses/sparse_categorical_crossentropy.md
Expand All @@ -75,6 +75,7 @@ nav:
BinaryAccuracy: api/metrics/BinaryAccuracy.md
BinaryCrossentropy: api/metrics/BinaryCrossentropy.md
CategoricalAccuracy: api/metrics/CategoricalAccuracy.md
F1: api/metrics/F1.md
Mean: api/metrics/Mean.md
MeanAbsoluteError: api/metrics/MeanAbsoluteError.md
MeanSquaredError: api/metrics/MeanSquaredError.md
Expand All @@ -88,6 +89,7 @@ nav:
binary_accuracy: api/metrics/binary_accuracy.md
binary_crossentropy: api/metrics/binary_crossentropy.md
categorical_accuracy: api/metrics/categorical_accuracy.md
f1: api/metrics/f1.md
mean_absolute_error: api/metrics/mean_absolute_error.md
mean_squared_error: api/metrics/mean_squared_error.md
precision: api/metrics/precision.md
Expand Down

0 comments on commit 5549de5

Please sign in to comment.