Skip to content

Commit

Permalink
Merge pull request #24 from EmanueleGhelfi/fix_hinge_loss
Browse files Browse the repository at this point in the history
Fix hinge loss + add gan documentation
  • Loading branch information
galeone committed Sep 9, 2019
2 parents f053532 + c103d2f commit b1925ac
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 12 deletions.
46 changes: 41 additions & 5 deletions ashpy/keras/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def __init__(self) -> None:
where c is the condition and x are real samples.
.. [1] https://arxiv.org/abs/1611.04076
.. [1] Least Squares Generative Adversarial Networks https://arxiv.org/abs/1611.04076
"""
self._positive_mse = tf.keras.losses.MeanSquaredError(
Expand Down Expand Up @@ -189,7 +189,23 @@ class DHingeLoss(tf.keras.losses.Loss):
Discriminator Hinge Loss as Keras Metric.
See Geometric GAN [1]_ for more details.
.. [1] https://arxiv.org/abs/1705.02894
The Discriminator Hinge loss is the hinge version
of the adversarial loss.
The Hinge loss is defined as:
.. math::
L_{\text{hinge}} = \max(0, 1 -t y)
where y is the Discriminator output
and t is the target class (+1 or -1 in the case of binary classification).
For the case of GANs:
.. math::
L_{D_{\text{hinge}}} = - \mathbb{E}_{(x,y) \sim p_data} [ \min(0, -1+D(x,y)) ] -
\mathbb{E}_{x \sim p_x, y \sim p_data} [ \min(0, -1 - D(G(z),y)) ]
.. [1] Geometric GAN https://arxiv.org/abs/1705.02894
"""

def __init__(self) -> None:
Expand Down Expand Up @@ -236,7 +252,28 @@ class GHingeLoss(tf.keras.losses.Loss):
Generator Hinge Loss as Keras Metric.
See Geometric GAN [1]_ for more details.
.. [1] https://arxiv.org/abs/1705.02894
The Generator Hinge loss is the hinge version
of the adversarial loss.
The Hinge loss is defined as:
.. math::
L_{\text{hinge}} = \max(0, 1 - t y)
where y is the Discriminator output
and t is the target class (+1 or -1 in the case of binary classification).
The target class of the generated images is +1.
For the case of GANs
.. math::
L_{G_{\text{hinge}}} = - \mathbb{E}_{(x \sim p_x, y \sim p_data} [ \min(0, -1+D(G(x),y)) ]
This can be simply approximated as:
.. math::
L_{G_{\text{hinge}}} = - \mathbb{E}_{(x \sim p_x, y \sim p_data} [ D(G(x),y) ]
.. [1] Geometric GAN https://arxiv.org/abs/1705.02894
"""

Expand All @@ -263,6 +300,5 @@ def reduction(self, value: tf.keras.losses.Reduction) -> None:

def call(self, d_real: tf.Tensor, d_fake: tf.Tensor) -> tf.Tensor:
"""Computes the hinge loss"""
fake_loss = -tf.nn.relu(d_fake)

return fake_loss
return -d_fake
8 changes: 4 additions & 4 deletions ashpy/losses/gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum
from enum import Enum, auto
from typing import TYPE_CHECKING, List, Type, Union

import tensorflow as tf
Expand All @@ -32,9 +32,9 @@
class AdversarialLossType(Enum):
"""Enumeration for Adversarial Losses. Implemented: GAN and LSGAN."""

GAN = 0 # classical gan loss (minmax)
LSGAN = 1 # Least Square GAN
HINGE_LOSS = 2 # Hinge loss
GAN = auto() # classical gan loss (minmax)
LSGAN = auto() # Least Square GAN
HINGE_LOSS = auto() # Hinge loss


class GANExecutor(Executor, ABC):
Expand Down
9 changes: 6 additions & 3 deletions ashpy/models/gans.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
:nosignatures:
:toctree: models
Generator
ConvGenerator
DenseGenerator
----
Expand All @@ -33,7 +34,8 @@
:nosignatures:
:toctree: models
Discriminator
ConvDiscriminator
DenseDiscriminator
----
Expand All @@ -43,7 +45,8 @@
:nosignatures:
:toctree: models
Encoder
ConvEncoder
DenseEncoder
"""
from ashpy.models.convolutional.decoders import BaseDecoder as BaseConvDecoder
Expand Down

0 comments on commit b1925ac

Please sign in to comment.