Skip to content

Commit

Permalink
Merge pull request #47873 from zenogantner:keras-losses-python3
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 365854152
Change-Id: Ib6a4f0beb6a49903496186416ffc316f12723549
  • Loading branch information
tensorflower-gardener committed Mar 30, 2021
2 parents 6472268 + 0062afb commit 9fde795
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 28 deletions.
47 changes: 20 additions & 27 deletions tensorflow/python/keras/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def __init__(self,
name: (Optional) name for the loss.
**kwargs: The keyword arguments that are passed on to `fn`.
"""
super(LossFunctionWrapper, self).__init__(reduction=reduction, name=name)
super().__init__(reduction=reduction, name=name)
self.fn = fn
self._fn_kwargs = kwargs

Expand All @@ -262,7 +262,7 @@ def get_config(self):
config = {}
for k, v in self._fn_kwargs.items():
config[k] = backend.eval(v) if tf_utils.is_tensor_or_variable(v) else v
base_config = super(LossFunctionWrapper, self).get_config()
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))


Expand Down Expand Up @@ -321,8 +321,7 @@ def __init__(self,
more details.
name: Optional name for the op. Defaults to 'mean_squared_error'.
"""
super(MeanSquaredError, self).__init__(
mean_squared_error, name=name, reduction=reduction)
super().__init__(mean_squared_error, name=name, reduction=reduction)


@keras_export('keras.losses.MeanAbsoluteError')
Expand Down Expand Up @@ -380,8 +379,7 @@ def __init__(self,
more details.
name: Optional name for the op. Defaults to 'mean_absolute_error'.
"""
super(MeanAbsoluteError, self).__init__(
mean_absolute_error, name=name, reduction=reduction)
super().__init__(mean_absolute_error, name=name, reduction=reduction)


@keras_export('keras.losses.MeanAbsolutePercentageError')
Expand Down Expand Up @@ -441,7 +439,7 @@ def __init__(self,
name: Optional name for the op. Defaults to
'mean_absolute_percentage_error'.
"""
super(MeanAbsolutePercentageError, self).__init__(
super().__init__(
mean_absolute_percentage_error, name=name, reduction=reduction)


Expand Down Expand Up @@ -502,7 +500,7 @@ def __init__(self,
name: Optional name for the op. Defaults to
'mean_squared_logarithmic_error'.
"""
super(MeanSquaredLogarithmicError, self).__init__(
super().__init__(
mean_squared_logarithmic_error, name=name, reduction=reduction)


Expand Down Expand Up @@ -596,7 +594,7 @@ def __init__(self,
more details.
name: (Optional) Name for the op. Defaults to 'binary_crossentropy'.
"""
super(BinaryCrossentropy, self).__init__(
super().__init__(
binary_crossentropy,
name=name,
reduction=reduction,
Expand Down Expand Up @@ -675,7 +673,7 @@ def __init__(self,
more details.
name: Optional name for the op. Defaults to 'categorical_crossentropy'.
"""
super(CategoricalCrossentropy, self).__init__(
super().__init__(
categorical_crossentropy,
name=name,
reduction=reduction,
Expand Down Expand Up @@ -752,7 +750,7 @@ def __init__(self,
name: Optional name for the op. Defaults to
'sparse_categorical_crossentropy'.
"""
super(SparseCategoricalCrossentropy, self).__init__(
super().__init__(
sparse_categorical_crossentropy,
name=name,
reduction=reduction,
Expand Down Expand Up @@ -815,7 +813,7 @@ def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='hinge'):
more details.
name: Optional name for the op. Defaults to 'hinge'.
"""
super(Hinge, self).__init__(hinge, name=name, reduction=reduction)
super().__init__(hinge, name=name, reduction=reduction)


@keras_export('keras.losses.SquaredHinge')
Expand Down Expand Up @@ -876,8 +874,7 @@ def __init__(self,
more details.
name: Optional name for the op. Defaults to 'squared_hinge'.
"""
super(SquaredHinge, self).__init__(
squared_hinge, name=name, reduction=reduction)
super().__init__(squared_hinge, name=name, reduction=reduction)


@keras_export('keras.losses.CategoricalHinge')
Expand Down Expand Up @@ -936,8 +933,7 @@ def __init__(self,
more details.
name: Optional name for the op. Defaults to 'categorical_hinge'.
"""
super(CategoricalHinge, self).__init__(
categorical_hinge, name=name, reduction=reduction)
super().__init__(categorical_hinge, name=name, reduction=reduction)


@keras_export('keras.losses.Poisson')
Expand Down Expand Up @@ -993,7 +989,7 @@ def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='poisson'):
more details.
name: Optional name for the op. Defaults to 'poisson'.
"""
super(Poisson, self).__init__(poisson, name=name, reduction=reduction)
super().__init__(poisson, name=name, reduction=reduction)


@keras_export('keras.losses.LogCosh')
Expand Down Expand Up @@ -1050,7 +1046,7 @@ def __init__(self, reduction=losses_utils.ReductionV2.AUTO, name='log_cosh'):
more details.
name: Optional name for the op. Defaults to 'log_cosh'.
"""
super(LogCosh, self).__init__(log_cosh, name=name, reduction=reduction)
super().__init__(log_cosh, name=name, reduction=reduction)


@keras_export('keras.losses.KLDivergence')
Expand Down Expand Up @@ -1110,8 +1106,7 @@ def __init__(self,
more details.
name: Optional name for the op. Defaults to 'kl_divergence'.
"""
super(KLDivergence, self).__init__(
kl_divergence, name=name, reduction=reduction)
super().__init__(kl_divergence, name=name, reduction=reduction)


@keras_export('keras.losses.Huber')
Expand Down Expand Up @@ -1178,8 +1173,7 @@ def __init__(self,
more details.
name: Optional name for the op. Defaults to 'huber_loss'.
"""
super(Huber, self).__init__(
huber, name=name, reduction=reduction, delta=delta)
super().__init__(huber, name=name, reduction=reduction, delta=delta)


@keras_export('keras.metrics.mean_squared_error', 'keras.metrics.mse',
Expand Down Expand Up @@ -1979,7 +1973,7 @@ def __init__(self,
axis=-1,
reduction=losses_utils.ReductionV2.AUTO,
name='cosine_similarity'):
super(CosineSimilarity, self).__init__(
super().__init__(
cosine_similarity, reduction=reduction, name=name, axis=axis)


Expand Down Expand Up @@ -2078,11 +2072,10 @@ def get(identifier):
return deserialize(identifier)
if isinstance(identifier, dict):
return deserialize(identifier)
elif callable(identifier):
if callable(identifier):
return identifier
else:
raise ValueError(
'Could not interpret loss function identifier: {}'.format(identifier))
raise ValueError(
f'Could not interpret loss function identifier: {identifier}')


LABEL_DTYPES_FOR_LOSSES = {
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/python/keras/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1783,7 +1783,7 @@ def test_loss_with_non_default_dtype(self):
class BinaryTruePositivesViaControlFlow(losses.Loss):

def __init__(self, reduction=losses_utils.ReductionV2.AUTO):
super(BinaryTruePositivesViaControlFlow, self).__init__(reduction=reduction)
super().__init__(reduction=reduction)

def call(self, y_true, y_pred):
y_true = math_ops.cast(y_true, dtypes.bool)
Expand Down

0 comments on commit 9fde795

Please sign in to comment.