Commit

Move V2 loss reduction to the keras module as the distribute lib dependency on it has been removed.

PiperOrigin-RevId: 233004569
pavithrasv authored and tensorflower-gardener committed Feb 8, 2019
1 parent 20bb92c commit 4604d1e
Showing 12 changed files with 100 additions and 143 deletions.
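
For callers, the practical effect is an import move: ReductionV2 and the weighted-loss helpers previously reached through tensorflow.python.ops.losses.losses_impl now live in tensorflow.python.keras.utils.losses_utils. A minimal sketch of the updated construction path, using only names that appear in the diffs below (illustrative, not part of the commit):

from tensorflow.python.keras import losses
from tensorflow.python.keras.utils import losses_utils

# The default reduction now comes from the keras-side losses_utils module.
mse = losses.MeanSquaredError(
    reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE)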
10 changes: 5 additions & 5 deletions tensorflow/python/keras/engine/training.py
@@ -44,10 +44,9 @@
 from tensorflow.python.keras.optimizer_v2 import optimizer_v2
 from tensorflow.python.keras.saving import saving_utils
 from tensorflow.python.keras.utils import data_utils
+from tensorflow.python.keras.utils import losses_utils
 from tensorflow.python.keras.utils.generic_utils import slice_arrays
-from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops.losses import losses_impl
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import optimizer as tf_optimizer_module
 from tensorflow.python.training.checkpointable import base as checkpointable
@@ -476,8 +475,9 @@ def compile(self,
           sample_weight = mask
         else:
           # Update dimensions of weights to match with mask if possible.
-          mask, _, sample_weight = squeeze_or_expand_dimensions(
-              mask, None, sample_weight)
+          mask, _, sample_weight = (
+              losses_utils.squeeze_or_expand_dimensions(
+                  mask, None, sample_weight))
           sample_weight *= mask

         output_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)
@@ -489,7 +489,7 @@ def compile(self,

         # Keep track of stateful result tensor and function for the loss.
         # Reset reduction here as metric wrapper will take care of that.
-        loss_fn.reduction = losses_impl.ReductionV2.NONE
+        loss_fn.reduction = losses_utils.ReductionV2.NONE
         output_loss_metric = metrics_module.SumOverBatchSizeMetricWrapper(
             loss_fn, name=loss_fn.name)
         result_tensor = self._call_metric_fn(output_loss_metric, y_true,
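The compile() hunk above now reaches squeeze_or_expand_dimensions through the losses_utils module object rather than a directly imported symbol. A small, self-contained sketch of what the helper does — shapes and the positional-argument semantics are assumptions inferred from the call in the diff, not stated by the commit:

import tensorflow as tf
from tensorflow.python.keras.utils import losses_utils

y_pred = tf.ones((8, 1))       # rank-2 predictions
y_true = tf.zeros((8, 1))      # rank-2 targets
sample_weight = tf.ones((8,))  # rank-1 per-sample weights

# Squeezes or expands trailing dimensions so the weights broadcast
# cleanly against predictions and targets, as in the mask handling above.
y_pred, y_true, sample_weight = losses_utils.squeeze_or_expand_dimensions(
    y_pred, y_true, sample_weight)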
40 changes: 20 additions & 20 deletions tensorflow/python/keras/losses.py
@@ -27,9 +27,9 @@
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import smart_cond
 from tensorflow.python.keras import backend as K
+from tensorflow.python.keras.utils import losses_utils
 from tensorflow.python.keras.utils.generic_utils import deserialize_keras_object
 from tensorflow.python.keras.utils.generic_utils import serialize_keras_object
-from tensorflow.python.keras.utils.losses_utils import compute_weighted_loss
 from tensorflow.python.keras.utils.tf_utils import is_tensor_or_variable
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
@@ -63,7 +63,7 @@ def call(self, y_true, y_pred):
   """

   def __init__(self,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name=None):
     self.reduction = reduction
     self.name = name
@@ -97,7 +97,7 @@ def __call__(self, y_true, y_pred, sample_weight=None):
     with ops.name_scope(scope_name, format(self.__class__.__name__),
                         (y_pred, y_true, sample_weight)):
       losses = self.call(y_true, y_pred)
-      return compute_weighted_loss(
+      return losses_utils.compute_weighted_loss(
           losses, sample_weight, reduction=self.reduction)

   @classmethod
@@ -141,7 +141,7 @@ class LossFunctionWrapper(Loss):

   def __init__(self,
                fn,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name=None,
                **kwargs):
     super(LossFunctionWrapper, self).__init__(reduction=reduction, name=name)
@@ -192,7 +192,7 @@ class MeanSquaredError(LossFunctionWrapper):
   """

   def __init__(self,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name='mean_squared_error'):
     super(MeanSquaredError, self).__init__(
         mean_squared_error, name=name, reduction=reduction)
@@ -222,7 +222,7 @@ class MeanAbsoluteError(LossFunctionWrapper):
   """

   def __init__(self,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name='mean_absolute_error'):
     super(MeanAbsoluteError, self).__init__(
         mean_absolute_error, name=name, reduction=reduction)
@@ -252,7 +252,7 @@ class MeanAbsolutePercentageError(LossFunctionWrapper):
   """

   def __init__(self,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name='mean_absolute_percentage_error'):
     super(MeanAbsolutePercentageError, self).__init__(
         mean_absolute_percentage_error, name=name, reduction=reduction)
@@ -282,7 +282,7 @@ class MeanSquaredLogarithmicError(LossFunctionWrapper):
   """

   def __init__(self,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name='mean_squared_logarithmic_error'):
     super(MeanSquaredLogarithmicError, self).__init__(
         mean_squared_logarithmic_error, name=name, reduction=reduction)
@@ -326,7 +326,7 @@ class BinaryCrossentropy(LossFunctionWrapper):
   def __init__(self,
                from_logits=False,
                label_smoothing=0,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name='binary_crossentropy'):
     super(BinaryCrossentropy, self).__init__(
         binary_crossentropy,
@@ -382,7 +382,7 @@ class CategoricalCrossentropy(LossFunctionWrapper):
   def __init__(self,
                from_logits=False,
                label_smoothing=0,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name='categorical_crossentropy'):
     super(CategoricalCrossentropy, self).__init__(
         categorical_crossentropy,
@@ -434,7 +434,7 @@ class SparseCategoricalCrossentropy(LossFunctionWrapper):

   def __init__(self,
                from_logits=False,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name=None):
     super(SparseCategoricalCrossentropy, self).__init__(
         sparse_categorical_crossentropy,
@@ -470,7 +470,7 @@ class Hinge(LossFunctionWrapper):
   """

   def __init__(self,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name=None):
     super(Hinge, self).__init__(hinge, name=name, reduction=reduction)

@@ -502,7 +502,7 @@ class SquaredHinge(LossFunctionWrapper):
   """

   def __init__(self,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name='squared_hinge'):
     super(SquaredHinge, self).__init__(
         squared_hinge, name=name, reduction=reduction)
@@ -529,7 +529,7 @@ class CategoricalHinge(LossFunctionWrapper):
   """

   def __init__(self,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name='categorical_hinge'):
     super(CategoricalHinge, self).__init__(
         categorical_hinge, name=name, reduction=reduction)
@@ -558,7 +558,7 @@ class LogLoss(LossFunctionWrapper):
   """

   def __init__(self,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name='logloss'):
     super(LogLoss, self).__init__(logloss, name=name, reduction=reduction)

@@ -586,7 +586,7 @@ class Poisson(LossFunctionWrapper):
   """

   def __init__(self,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name='poisson'):
     super(Poisson, self).__init__(poisson, name=name, reduction=reduction)

@@ -614,7 +614,7 @@ class LogCosh(LossFunctionWrapper):
   """

   def __init__(self,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name='logcosh'):
     super(LogCosh, self).__init__(logcosh, name=name, reduction=reduction)

@@ -642,7 +642,7 @@ class KLDivergence(LossFunctionWrapper):
   """

   def __init__(self,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name='kullback_leibler_divergence'):
     super(KLDivergence, self).__init__(
         kullback_leibler_divergence, name=name, reduction=reduction)
@@ -685,7 +685,7 @@ class Huber(LossFunctionWrapper):

   def __init__(self,
                delta=1.0,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name='huber_loss'):
     super(Huber, self).__init__(
         huber_loss, name=name, reduction=reduction, delta=delta)
@@ -985,7 +985,7 @@ class CosineProximity(Loss):

   def __init__(self,
                axis=-1,
-               reduction=losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE,
+               reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE,
                name=None):
     super(CosineProximity, self).__init__(reduction=reduction, name=name)
     self.axis = axis
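As the __call__ hunk above shows, sample weighting and reduction are applied centrally through losses_utils.compute_weighted_loss, so a Loss subclass only needs to implement call(). A hypothetical subclass illustrating that contract — the class name and error formula are invented for this sketch:

from tensorflow.python.keras import losses
from tensorflow.python.ops import math_ops

class MeanQuarticError(losses.Loss):
  # Hypothetical loss: only call() is defined; Loss.__call__ applies
  # sample weighting and reduction via losses_utils.compute_weighted_loss.

  def call(self, y_true, y_pred):
    # Per-sample loss values; reduction happens in the base class.
    return math_ops.reduce_mean(
        math_ops.squared_difference(y_pred, y_true) ** 2, axis=-1)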