Separating ashpy Executors and custom Keras losses
galeone committed Aug 22, 2019
1 parent 489848a commit 64b912f
Showing 3 changed files with 191 additions and 143 deletions.
Empty file added ashpy/keras/__init__.py
182 changes: 182 additions & 0 deletions ashpy/keras/losses.py
@@ -0,0 +1,182 @@
# Copyright 2019 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Custom Keras losses, used by the AshPy executors."""

import tensorflow as tf


class L1(tf.keras.losses.Loss):
"""L1 Loss implementation as :py:class:`tf.keras.losses.Loss`."""

def __init__(self) -> None:
"""Initialize the Loss."""
super().__init__()
self._reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE

@property
def reduction(self) -> tf.keras.losses.Reduction:
"""Return the current `reduction` for this type of loss."""
return self._reduction

@reduction.setter
def reduction(self, value: tf.keras.losses.Reduction) -> None:
"""
Set the `reduction`.
Args:
value (:py:class:`tf.keras.losses.Reduction`): Reduction to use for the loss.
"""
self._reduction = value

def call(self, x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
"""Compute the mean of the l1 between x and y."""
if self._reduction == tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE:
axis = None
elif self._reduction == tf.keras.losses.Reduction.NONE:
axis = (1, 2, 3)
else:
raise ValueError("L1Loss: unhandled reduction type")

return tf.reduce_mean(tf.abs(x - y), axis=axis)


class DMinMax(tf.keras.losses.Loss):
r"""Implementation of MinMax Discriminator loss as :py:class:`tf.keras.losses.Loss`.
.. math::
L_{D} = - \frac{1}{2} E [\log(D(x)) + \log (1 - D(G(z))]
"""

def __init__(self, from_logits: bool = True, label_smoothing: float = 0.0) -> None:
"""Initialize the loss."""
self._positive_bce = tf.keras.losses.BinaryCrossentropy(
from_logits=from_logits,
label_smoothing=label_smoothing,
reduction=tf.keras.losses.Reduction.NONE,
)

self._negative_bce = tf.keras.losses.BinaryCrossentropy(
from_logits=from_logits,
label_smoothing=0.0,
reduction=tf.keras.losses.Reduction.NONE,
)
super().__init__()

@property
def reduction(self) -> tf.keras.losses.Reduction:
"""
Return the reduction type of this loss.
Returns:
:py:classes:`tf.keras.losses.Reduction`: Reduction.
"""
return self._positive_bce.reduction

@reduction.setter
def reduction(self, value: tf.keras.losses.Reduction) -> None:
self._positive_bce.reduction = value
self._negative_bce.reduction = value

def call(self, d_real: tf.Tensor, d_fake: tf.Tensor) -> tf.Tensor:
"""
Compute the MinMax Loss.
Play the DiscriminatorMinMax game between the discriminator
computed in real and the discriminator compute with fake inputs.
Args:
d_real (:py:class:`tf.Tensor`): Real data.
d_fake (:py:class:`tf.Tensor`): Fake (generated) data.
Returns:
:py:class:`tf.Tensor`: Output Tensor.
"""
return 0.5 * (
self._positive_bce(tf.ones_like(d_real), d_real)
+ self._negative_bce(tf.zeros_like(d_fake), d_fake)
)


class DLeastSquare(tf.keras.losses.Loss):
"""Discriminator Least Square Loss as :py:class:`tf.keras.losses.Loss`."""

def __init__(self) -> None:
"""Least square Loss for Discriminator.
Reference: Least Squares Generative Adversarial Networks [1]_ .
Basically the Mean Squared Error between
the discriminator output when evaluated in fake samples and 0
and the discriminator output when evaluated in real samples and 1:
For the unconditioned case this is:
.. math::
L_{D} = \frac{1}{2} E[(D(x) - 1)^2 + (0 - D(G(z))^2]
where x are real samples and z is the latent vector.
For the conditioned case this is:
.. math::
L_{D} = \frac{1}{2} E[(D(x, c) - 1)^2 + (0 - D(G(c), c)^2]
where c is the condition and x are real samples.
.. [1] https://arxiv.org/abs/1611.04076
"""
self._positive_mse = tf.keras.losses.MeanSquaredError(
reduction=tf.keras.losses.Reduction.NONE
)
self._negative_mse = tf.keras.losses.MeanSquaredError(
reduction=tf.keras.losses.Reduction.NONE
)
super().__init__()

@property
def reduction(self) -> tf.keras.losses.Reduction:
"""
Return the reduction type for this loss.
Returns:
:py:class:`tf.keras.losses.Reduction`: Reduction.
"""
return self._positive_mse.reduction

@reduction.setter
def reduction(self, value) -> None:
self._positive_mse.reduction = value
self._negative_mse.reduction = value

def call(self, d_real: tf.Tensor, d_fake: tf.Tensor) -> tf.Tensor:
"""
Compute the Least Square Loss.
Args:
d_real (:py:class:`tf.Tensor`): Discriminator evaluated in real samples.
d_fake (:py:class:`tf.Tensor`): Discriminator evaluated in fake samples.
Returns:
:py:class:`tf.Tensor`: Loss.
"""
return 0.5 * (
self._positive_mse(tf.ones_like(d_real), d_real)
+ self._negative_mse(tf.zeros_like(d_fake), d_fake)
)
152 changes: 9 additions & 143 deletions ashpy/losses/gan.py
@@ -23,6 +23,7 @@

from ashpy.contexts import GANContext, GANEncoderContext
from ashpy.losses.executor import Executor, SumExecutor
from ashpy.keras.losses import DLeastSquare, DMinMax, L1

if TYPE_CHECKING:
from ashpy.ashtypes import TWeight
@@ -223,44 +224,9 @@ class GeneratorL1(GANExecutor):
"""

class L1Loss(tf.keras.losses.Loss):
"""L1 Loss implementation as :py:class:`tf.keras.losses.Loss`."""

def __init__(self) -> None:
"""Initialize the Loss."""
super().__init__()
self._reduction = tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE

@property
def reduction(self) -> tf.keras.losses.Reduction:
"""Return the current `reduction` for this type of loss."""
return self._reduction

@reduction.setter
def reduction(self, value: tf.keras.losses.Reduction) -> None:
"""
Set the `reduction`.
Args:
value (:py:class:`tf.keras.losses.Reduction`): Reduction to use for the loss.
"""
self._reduction = value

def call(self, x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
"""Compute the mean of the l1 between x and y."""
if self._reduction == tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE:
axis = None
elif self._reduction == tf.keras.losses.Reduction.NONE:
axis = (1, 2, 3)
else:
raise ValueError("L1Loss: unhandled reduction type")

return tf.reduce_mean(tf.abs(x - y), axis=axis)

def __init__(self) -> None:
"""Initialize the Executor."""
super().__init__(GeneratorL1.L1Loss())
super().__init__(L1())

@Executor.reduce_loss
def call(self, context: GANContext, *, fake: tf.Tensor, real: tf.Tensor, **kwargs):
@@ -280,7 +246,7 @@ def call(self, context: GANContext, *, fake: tf.Tensor, real: tf.Tensor, **kwargs):
return mae


class FeatureMatchingLoss(GeneratorL1):
class FeatureMatchingLoss(GANExecutor):
r"""
Conditional GAN Feature matching loss.
@@ -308,6 +274,10 @@ class FeatureMatchingLoss(GeneratorL1):
over the axis 1,2,3.
"""

def __init__(self) -> None:
"""Initialize the Executor."""
super().__init__(L1())

@Executor.reduce_loss
def call(
self,
@@ -604,68 +574,10 @@ class DiscriminatorMinMax(AdversarialLossD):
"""

class GANLoss(tf.keras.losses.Loss):
"""Implementation of MinMax loss as :py:class:`tf.keras.losses.Loss`."""

def __init__(
self, from_logits: bool = True, label_smoothing: float = 0.0
) -> None:
"""Initialize the loss."""
self._positive_bce = tf.keras.losses.BinaryCrossentropy(
from_logits=from_logits,
label_smoothing=label_smoothing,
reduction=tf.keras.losses.Reduction.NONE,
)

self._negative_bce = tf.keras.losses.BinaryCrossentropy(
from_logits=from_logits,
label_smoothing=0.0,
reduction=tf.keras.losses.Reduction.NONE,
)
super().__init__()

@property
def reduction(self) -> tf.keras.losses.Reduction:
"""
Return the reduction type of this loss.
Returns:
:py:classes:`tf.keras.losses.Reduction`: Reduction.
"""
return self._positive_bce.reduction

@reduction.setter
def reduction(self, value: tf.keras.losses.Reduction) -> None:
self._positive_bce.reduction = value
self._negative_bce.reduction = value

def call(self, d_real: tf.Tensor, d_fake: tf.Tensor) -> tf.Tensor:
"""
Compute the MinMax Loss.
Play the DiscriminatorMinMax game between the discriminator
computed in real and the discriminator compute with fake inputs.
Args:
d_real (:py:class:`tf.Tensor`): Real data.
d_fake (:py:class:`tf.Tensor`): Fake (generated) data.
Returns:
:py:class:`tf.Tensor`: Output Tensor.
"""
return 0.5 * (
self._positive_bce(tf.ones_like(d_real), d_real)
+ self._negative_bce(tf.zeros_like(d_fake), d_fake)
)

def __init__(self, from_logits=True, label_smoothing=0.0):
"""Initialize Loss."""
super().__init__(
DiscriminatorMinMax.GANLoss(
from_logits=from_logits, label_smoothing=label_smoothing
)
DMinMax(from_logits=from_logits, label_smoothing=label_smoothing)
)


@@ -696,55 +608,9 @@ class DiscriminatorLSGAN(AdversarialLossD):
"""

class LeastSquareLoss(tf.keras.losses.Loss):
"""Least Square Loss as :py:class:`tf.keras.losses.Loss`."""

def __init__(self) -> None:
"""Initialize the Loss."""
self._positive_mse = tf.keras.losses.MeanSquaredError(
reduction=tf.keras.losses.Reduction.NONE
)
self._negative_mse = tf.keras.losses.MeanSquaredError(
reduction=tf.keras.losses.Reduction.NONE
)
super().__init__()

@property
def reduction(self) -> tf.keras.losses.Reduction:
"""
Return the reduction type for this loss.
Returns:
:py:class:`tf.keras.losses.Reduction`: Reduction.
"""
return self._positive_mse.reduction

@reduction.setter
def reduction(self, value) -> None:
self._positive_mse.reduction = value
self._negative_mse.reduction = value

def call(self, d_real: tf.Tensor, d_fake: tf.Tensor) -> tf.Tensor:
"""
Compute the Least Square Loss.
Args:
d_real (:py:class:`tf.Tensor`): Discriminator evaluated in real samples.
d_fake (:py:class:`tf.Tensor`): Discriminator evaluated in fake samples.
Returns:
:py:class:`tf.Tensor`: Loss.
"""
return 0.5 * (
self._positive_mse(tf.ones_like(d_real), d_real)
+ self._negative_mse(tf.zeros_like(d_fake), d_fake)
)

def __init__(self) -> None:
"""Initialize loss."""
super().__init__(DiscriminatorLSGAN.LeastSquareLoss())
super().__init__(DLeastSquare())
self.name = "DiscriminatorLSGAN"


