Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 351169539
  • Loading branch information
xuanhuiwang authored and ramakumar1729 committed Jan 29, 2021
1 parent 1f1c2da commit 8693c77
Showing 1 changed file with 13 additions and 37 deletions.
50 changes: 13 additions & 37 deletions tensorflow_ranking/python/losses_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def ndcg(labels, ranks=None, perm_mat=None):
return normalized_dcg


class _LambdaWeight(object):
class _LambdaWeight(object, metaclass=abc.ABCMeta):
"""Interface for ranking metric optimization.
This class wraps weights used in the LambdaLoss framework for ranking metric
Expand All @@ -158,8 +158,6 @@ class _LambdaWeight(object):
together with standard loss such as logistic loss and softmax loss.
"""

__metaclass__ = abc.ABCMeta

@abc.abstractmethod
def pair_weights(self, labels, ranks):
"""Returns the weight adjustment `Tensor` for example pairs.
Expand Down Expand Up @@ -514,11 +512,9 @@ def _sample_gumbel(shape, eps=1e-20, seed=None):
return -tf.math.log(-tf.math.log(u + eps) + eps)


class _RankingLoss(object):
class _RankingLoss(object, metaclass=abc.ABCMeta):
"""Interface for ranking loss."""

__metaclass__ = abc.ABCMeta

def __init__(self, name, lambda_weight=None, temperature=1.0):
"""Constructor.
Expand Down Expand Up @@ -627,26 +623,6 @@ def compute_per_list(self, labels, logits, weights):
"""
raise NotImplementedError('Calling an abstract method.')

def eval_metric_unreduced(self, labels, logits, weights):
"""Computes the unreduced eval metric for the loss.
Args:
labels: A `Tensor` of the same shape as `logits` representing graded
relevance.
logits: A `Tensor` with shape [batch_size, list_size]. Each value is the
ranking score of the corresponding item.
weights: A scalar, a `Tensor` with shape [batch_size, 1] for list-wise
weights, or a `Tensor` with shape [batch_size, list_size] for item-wise
weights.
Returns:
A pair of `Tensor` objects containing losses and weights for use in
a weighted average.
"""
losses, loss_weights = self.compute_unreduced_loss(labels, logits)
weights = tf.multiply(self.normalize_weights(labels, weights), loss_weights)
return losses, weights

def eval_metric(self, labels, logits, weights):
"""Computes the eval metric for the loss in tf.estimator (not tf.keras).
Expand All @@ -664,15 +640,14 @@ def eval_metric(self, labels, logits, weights):
Returns:
A metric op.
"""
losses, weights = self.eval_metric_unreduced(labels, logits, weights)
losses, loss_weights = self.compute_unreduced_loss(labels, logits)
weights = tf.multiply(self.normalize_weights(labels, weights), loss_weights)
return tf.compat.v1.metrics.mean(losses, weights)


class _PairwiseLoss(_RankingLoss):
class _PairwiseLoss(_RankingLoss, metaclass=abc.ABCMeta):
"""Interface for pairwise ranking loss."""

__metaclass__ = abc.ABCMeta

@abc.abstractmethod
def _pairwise_loss(self, pairwise_logits):
"""The loss of pairwise logits with l_i > l_j."""
Expand Down Expand Up @@ -838,11 +813,12 @@ def compute(self, labels, logits, weights, reduction):
return tf.compat.v1.losses.compute_weighted_loss(
losses, weights, reduction=reduction)

def eval_metric_unreduced(self, labels, logits, weights):
def eval_metric(self, labels, logits, weights):
"""See `_RankingLoss`."""
logits = self.get_logits(logits)
labels, logits = self.precompute(labels, logits, weights)
return self.compute_unreduced_loss(labels, logits)
losses, weights = self.compute_unreduced_loss(labels, logits)
return tf.compat.v1.metrics.mean(losses, weights)

def compute_per_list(self, labels, logits, weights):
"""See `_RankingLoss`."""
Expand Down Expand Up @@ -933,7 +909,7 @@ class ClickEMLoss(_PointwiseLoss):
"""

def __init__(self, name, temperature=1.0):
super(ClickEMLoss, self).__init__(name, None, temperature)
super().__init__(name, None, temperature)

def _compute_latent_prob(self, clicks, exam_logits, rel_logits):
"""Computes the probability of latent variables in EM.
Expand Down Expand Up @@ -1003,7 +979,7 @@ def __init__(self, name, temperature=1.0):
name: A string used as the name for this loss.
temperature: A float number to modify the logits=logits/temperature.
"""
super(SigmoidCrossEntropyLoss, self).__init__(name, None, temperature)
super().__init__(name, None, temperature)

def compute_unreduced_loss(self, labels, logits):
"""See `_RankingLoss`."""
Expand All @@ -1026,7 +1002,7 @@ def __init__(self, name):
name: A string used as the name for this loss.
"""
# temperature is not used in this loss.
super(MeanSquaredLoss, self).__init__(name, None, temperature=1.0)
super().__init__(name, None, temperature=1.0)

def compute_unreduced_loss(self, labels, logits):
"""See `_RankingLoss`."""
Expand Down Expand Up @@ -1080,7 +1056,7 @@ class ApproxNDCGLoss(_ListwiseLoss):
# Use a different default temperature.
def __init__(self, name, lambda_weight=None, temperature=0.1):
"""See `_ListwiseLoss`."""
super(ApproxNDCGLoss, self).__init__(name, lambda_weight, temperature)
super().__init__(name, lambda_weight, temperature)

def compute_unreduced_loss(self, labels, logits):
"""See `_RankingLoss`."""
Expand All @@ -1106,7 +1082,7 @@ class ApproxMRRLoss(_ListwiseLoss):
# Use a different default temperature.
def __init__(self, name, lambda_weight=None, temperature=0.1):
"""See `_ListwiseLoss`."""
super(ApproxMRRLoss, self).__init__(name, lambda_weight, temperature)
super().__init__(name, lambda_weight, temperature)

def compute_unreduced_loss(self, labels, logits):
"""See `_RankingLoss`."""
Expand Down

0 comments on commit 8693c77

Please sign in to comment.