Skip to content

Commit

Permalink
Merge pull request #21 from EmanueleGhelfi/losses_and_metrics
Browse files Browse the repository at this point in the history
Losses and metrics
  • Loading branch information
galeone committed Sep 5, 2019
2 parents 2929277 + 456ef0f commit 3675bcc
Show file tree
Hide file tree
Showing 22 changed files with 1,230 additions and 161 deletions.
10 changes: 3 additions & 7 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ dist: xenial
language: python
addons:
artifacts: true
apt:
update: true
python:
- '3.7'
stages:
Expand All @@ -11,9 +13,6 @@ stages:
- Deploy
before_install:
- sudo apt-get install -y --no-install-recommends bc
addons:
apt:
update: true
install: &requirements
- pip install -r dev-requirements.txt
- pip install -e .
Expand All @@ -25,10 +24,7 @@ jobs:
- pip install codecov
- pip --no-cache-dir install --upgrade git+https://github.com/thisch/pytest-sphinx.git pytest
- pip --no-cache-dir install pytest-cov
- module=""
- for d in $(ls -d */); do if [ -f "$d"__init__.py ]; then module=${d::-1}; fi
done
- pytest -x -s -vvv --doctest-modules $module --cov=$module
- pytest -x -s -vvv --doctest-modules ashpy tests --cov=ashpy
after_success:
- codecov
- stage: Black
Expand Down
2 changes: 1 addition & 1 deletion ashpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""ASHPY Package."""

from .modes import LogEvalMode
from . import contexts
from . import datasets
from . import layers
Expand All @@ -23,7 +24,6 @@
from . import trainers
from . import ashtypes
from . import keras
from .modes import LogEvalMode

__version__ = "1.0.2"
__url__ = "https://github.com/zurutech/ashpy"
Expand Down
2 changes: 1 addition & 1 deletion ashpy/contexts/base_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def measure_metrics(self) -> None:
def model_selection(self) -> None:
"""Use the metrics to perform model selection."""
for metric in self._metrics:
metric.model_selection(self._ckpt)
metric.model_selection(self._ckpt, self._global_step)

@property
def log_eval_mode(self) -> LogEvalMode:
Expand Down
2 changes: 1 addition & 1 deletion ashpy/contexts/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
import tensorflow as tf # pylint: disable=import-error

from ashpy.contexts.base_context import BaseContext
from ashpy.metrics import Metric
from ashpy.modes import LogEvalMode

if TYPE_CHECKING:
from ashpy.losses.executor import Executor
from ashpy.metrics import Metric


class ClassifierContext(BaseContext):
Expand Down
2 changes: 1 addition & 1 deletion ashpy/contexts/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
import tensorflow as tf

from ashpy.contexts.base_context import BaseContext
from ashpy.metrics import Metric
from ashpy.modes import LogEvalMode

if TYPE_CHECKING:
from ashpy.losses.executor import Executor
from ashpy.metrics import Metric


class GANContext(BaseContext):
Expand Down
84 changes: 84 additions & 0 deletions ashpy/keras/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,87 @@ def call(self, d_real: tf.Tensor, d_fake: tf.Tensor) -> tf.Tensor:
self._positive_mse(tf.ones_like(d_real), d_real)
+ self._negative_mse(tf.zeros_like(d_fake), d_fake)
)


class DHingeLoss(tf.keras.losses.Loss):
r"""
Discriminator Hinge Loss as Keras Metric.
See Geometric GAN [1]_ for more details.
.. [1] https://arxiv.org/abs/1705.02894
"""

def __init__(self) -> None:
"""Initialize the Loss."""
self._hinge_loss_real = tf.keras.losses.Hinge(
reduction=tf.keras.losses.Reduction.NONE
)
self._hinge_loss_fake = tf.keras.losses.Hinge(
reduction=tf.keras.losses.Reduction.NONE
)
super().__init__()

@property
def reduction(self) -> tf.keras.losses.Reduction:
"""Return the current `reduction` for this type of loss."""
return self._hinge_loss_fake.reduction

@reduction.setter
def reduction(self, value: tf.keras.losses.Reduction) -> None:
"""
Set the `reduction`.
Args:
value (:py:class:`tf.keras.losses.Reduction`): Reduction to use for the loss.
"""
self._hinge_loss_fake.reduction = value
self._hinge_loss_real.reduction = value

def call(self, d_real: tf.Tensor, d_fake: tf.Tensor) -> tf.Tensor:
"""Compute the hinge loss"""
real_loss = self._hinge_loss_real(tf.ones_like(d_real), d_real)
fake_loss = self._hinge_loss_fake(
tf.math.negative(tf.ones_like(d_fake)), d_fake
)

loss = real_loss + fake_loss # shape: (batch_size, 1)

return loss


class GHingeLoss(tf.keras.losses.Loss):
r"""
Generator Hinge Loss as Keras Metric.
See Geometric GAN [1]_ for more details.
.. [1] https://arxiv.org/abs/1705.02894
"""

def __init__(self) -> None:
"""Initialize the Loss."""
super().__init__()
self._reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE

@property
def reduction(self) -> tf.keras.losses.Reduction:
"""Return the current `reduction` for this type of loss."""
return self._reduction

@reduction.setter
def reduction(self, value: tf.keras.losses.Reduction) -> None:
"""
Set the `reduction`.
Args:
value (:py:class:`tf.keras.losses.Reduction`): Reduction to use for the loss.
"""
self._reduction = value

def call(self, d_real: tf.Tensor, d_fake: tf.Tensor) -> tf.Tensor:
"""Computes the hinge loss"""
fake_loss = -tf.nn.relu(d_fake)

return fake_loss
17 changes: 10 additions & 7 deletions ashpy/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,18 +46,20 @@
gan.GANExecutor
gan.AdversarialLossType
gan.AdversarialLossG
gan.AdversarialLossD
gan.GeneratorAdversarialLoss
gan.DiscriminatorAdversarialLoss
gan.GeneratorBCE
gan.GeneratorLSGAN
gan.GeneratorL1
gan.GeneratorHingeLoss
gan.FeatureMatchingLoss
gan.CategoricalCrossEntropy
gan.Pix2PixLoss
gan.Pix2PixLossSemantic
gan.EncoderBCE
gan.DiscriminatorMinMax
gan.DiscriminatorLSGAN
gan.DiscriminatorHingeLoss
gan.get_adversarial_loss_discriminator
gan.get_adversarial_loss_generator
Expand All @@ -79,8 +81,8 @@
from ashpy.losses.classifier import ClassifierLoss
from ashpy.losses.executor import Executor, SumExecutor
from ashpy.losses.gan import (
AdversarialLossD,
AdversarialLossG,
DiscriminatorAdversarialLoss,
GeneratorAdversarialLoss,
AdversarialLossType,
CategoricalCrossEntropy,
DiscriminatorLSGAN,
Expand All @@ -98,22 +100,23 @@
)

__ALL__ = [
"AdversarialLossD",
"AdversarialLossD",
"AdversarialLossG",
"DiscriminatorAdversarialLoss",
"GeneratorAdversarialLoss",
"AdversarialLossType",
"CategoricalCrossEntropy",
"ClassifierLoss",
"ClassifierLoss",
"DiscriminatorLSGAN",
"DiscriminatorMinMax",
"DiscriminatorHingeLoss",
"EncoderBCE",
"Executor",
"FeatureMatchingLoss",
"GANExecutor",
"GeneratorBCE",
"GeneratorL1",
"GeneratorLSGAN",
"GeneratorHingeLoss",
"get_adversarial_loss_discriminator",
"get_adversarial_loss_generator",
"Pix2PixLoss",
Expand Down

0 comments on commit 3675bcc

Please sign in to comment.