In [7]:
import torch
import torch.nn.functional as F

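## CL-DRD baseline

The CL-DRD formulation of `lambda_mrr_loss`: a pairwise logistic loss weighted by reciprocal-rank differences. It evaluates `log(1 + exp(-Δ))` literally, so in float32 it overflows to `inf` when a mis-ordered pair's score gap is large (see the last output below).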

In [29]:
def lambda_mrr_loss(y_pred, y_true, eps=1e-10, padded_value_indicator=-1, reduction="mean", sigma=1.):
    """
    Pairwise logistic ranking loss weighted by reciprocal-rank differences (CL-DRD version).

    Items are sorted by predicted score. Every pair of positions (i, j) in that order
    whose labels satisfy y_true[i] > y_true[j] contributes
        log(1 + exp(-sigma * (s_i - s_j))) * |1/(i+1) - 1/(j+1)|,
    where s are the sorted predictions and i, j are 0-based positions.

    Args:
        y_pred: FloatTensor [bz, topk], predicted scores.
        y_true: FloatTensor [bz, topk], labels; entries equal to
            ``padded_value_indicator`` are treated as padding and excluded from all pairs.
        eps: unused; kept for signature compatibility.
        padded_value_indicator: label value that marks a padded slot.
        reduction: "sum" or "mean" over the contributing pairs.
        sigma: scale applied to the predicted score differences (1. = unscaled).

    Returns:
        Scalar tensor. If no pair contributes (e.g. all labels in every row are
        equal), returns 0 rather than the NaN of a mean over an empty selection.

    Raises:
        ValueError: if ``reduction`` is neither "sum" nor "mean".

    Note:
        ``log(1 + exp(.))`` is evaluated literally, as in the CL-DRD formulation, so in
        float32 the result (and its gradient) become non-finite once a pair's score
        gap exceeds roughly 88.7 (exp overflows).
    """
    device = y_pred.device
    y_pred = y_pred.clone()
    y_true = y_true.clone()

    # Padded slots get -inf so they sort last and any pair touching them has a
    # non-finite label difference.
    padded_mask = y_true == padded_value_indicator
    y_pred[padded_mask] = float("-inf")
    y_true[padded_mask] = float("-inf")

    y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)

    # Labels re-ordered by predicted rank. Keep pairs (i, j) where both items are real
    # (finite difference) and item i is strictly more relevant than item j.
    true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
    true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
    padded_pairs_mask = torch.isfinite(true_diffs) & (true_diffs > 0)

    # Pair weight: difference between the reciprocal ranks of positions i and j.
    inv_pos_idxs = 1. / torch.arange(1, y_pred.shape[1] + 1, device=device, dtype=y_pred.dtype)
    weights = torch.abs(inv_pos_idxs.view(1, -1, 1) - inv_pos_idxs.view(1, 1, -1))  # [1, topk, topk]

    scores_diffs = sigma * (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :])
    # Pairs touching a padded item have +-inf/NaN differences. They are excluded by
    # padded_pairs_mask, but left as-is their non-finite derivatives turn the zero
    # upstream gradient into NaN for the real items they are paired with.
    # (masked_fill is out-of-place, so the result must be reassigned.)
    scores_diffs = scores_diffs.masked_fill(~torch.isfinite(scores_diffs), 0.)
    losses = torch.log(1. + torch.exp(-scores_diffs)) * weights  # [bz, topk, topk]

    selected = losses[padded_pairs_mask]
    if reduction == "sum":
        return selected.sum()
    if reduction == "mean":
        # torch.mean of an empty tensor is NaN; report zero loss instead.
        return selected.mean() if selected.numel() > 0 else selected.sum()
    raise ValueError("Reduction method can be either sum or mean")


if __name__ == "__main__":

    # Shared graded labels. The trailing -1. equals the default padded_value_indicator,
    # so that slot is treated as padding.
    y_true = torch.FloatTensor([[1., 1./2, 1./3, 0., 0., -1/2., -1.]])
    y_pred_1 = torch.FloatTensor([[2.3, 1.2, 1.1, 0.5, 0.23, 0.21, 40]])
    y_pred_2 = torch.FloatTensor([[0.5, 0.23, 2.3, 1.2, 1.1, 5, 20]])

    # The two single-query examples stacked into one batch of size 2.
    y_true_batch = y_true.repeat(2, 1)
    y_pred_batch = torch.cat([y_pred_1, y_pred_2], dim=0)

    # Large score gaps (up to ~90) between mis-ordered items: a numerical stress case.
    y_pred = torch.FloatTensor([[20.3, 10.2, 10.1, 50, 100, 21, 40]])

    for example_pred, example_true in (
        (y_pred_1, y_true),
        (y_pred_2, y_true),
        (y_pred_batch, y_true_batch),
        (y_pred, y_true),
    ):
        print(lambda_mrr_loss(example_pred, example_true))
    

tensor(0.0890)
tensor(1.1732)
tensor(0.6311)
tensor(inf)


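## Ours: overflow-safe variant

The same loss, but `log(1 + exp(-Δ))` is computed in a numerically stable way, so the last example yields a finite value (17.3262) instead of `inf`.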

In [30]:
def lambda_mrr_loss(y_pred, y_true, eps=1e-10, padded_value_indicator=-1, reduction="mean", sigma=1.):
    """
    Overflow-safe pairwise logistic ranking loss weighted by reciprocal-rank differences.

    Items are sorted by predicted score. Every pair of positions (i, j) in that order
    whose labels satisfy y_true[i] > y_true[j] contributes
        softplus(-sigma * (s_i - s_j)) * |1/(i+1) - 1/(j+1)|,
    where softplus(x) = log(1 + exp(x)), s are the sorted predictions and i, j are
    0-based positions. softplus is evaluated stably, so large score gaps do not
    overflow to inf.

    Args:
        y_pred: FloatTensor [bz, topk], predicted scores.
        y_true: FloatTensor [bz, topk], labels; entries equal to
            ``padded_value_indicator`` are treated as padding and excluded from all pairs.
        eps: unused; kept for signature compatibility.
        padded_value_indicator: label value that marks a padded slot.
        reduction: "sum" or "mean" over the contributing pairs.
        sigma: scale applied to the predicted score differences (1. = unscaled).

    Returns:
        Scalar tensor. If no pair contributes (e.g. all labels in every row are
        equal), returns 0 rather than the NaN of a mean over an empty selection.

    Raises:
        ValueError: if ``reduction`` is neither "sum" nor "mean".
    """
    device = y_pred.device
    y_pred = y_pred.clone()
    y_true = y_true.clone()

    # Padded slots get -inf so they sort last and any pair touching them has a
    # non-finite label difference.
    padded_mask = y_true == padded_value_indicator
    y_pred[padded_mask] = float("-inf")
    y_true[padded_mask] = float("-inf")

    y_pred_sorted, indices_pred = y_pred.sort(descending=True, dim=-1)

    # Labels re-ordered by predicted rank. Keep pairs (i, j) where both items are real
    # (finite difference) and item i is strictly more relevant than item j.
    true_sorted_by_preds = torch.gather(y_true, dim=1, index=indices_pred)
    true_diffs = true_sorted_by_preds[:, :, None] - true_sorted_by_preds[:, None, :]
    padded_pairs_mask = torch.isfinite(true_diffs) & (true_diffs > 0)

    # Pair weight: difference between the reciprocal ranks of positions i and j.
    inv_pos_idxs = 1. / torch.arange(1, y_pred.shape[1] + 1, device=device, dtype=y_pred.dtype)
    weights = torch.abs(inv_pos_idxs.view(1, -1, 1) - inv_pos_idxs.view(1, 1, -1))  # [1, topk, topk]

    scores_diffs = sigma * (y_pred_sorted[:, :, None] - y_pred_sorted[:, None, :])
    # Pairs touching a padded item have +-inf/NaN differences. They are excluded by
    # padded_pairs_mask, but left as-is their non-finite derivatives turn the zero
    # upstream gradient into NaN for the real items they are paired with.
    scores_diffs = scores_diffs.masked_fill(~torch.isfinite(scores_diffs), 0.)
    # softplus(x) == logsumexp([x, 0]) == log(1 + exp(x)), computed without overflow.
    losses = F.softplus(-scores_diffs) * weights  # [bz, topk, topk]

    selected = losses[padded_pairs_mask]
    if reduction == "sum":
        return selected.sum()
    if reduction == "mean":
        # torch.mean of an empty tensor is NaN; report zero loss instead.
        return selected.mean() if selected.numel() > 0 else selected.sum()
    raise ValueError("Reduction method can be either sum or mean")


if __name__ == "__main__":

    # Shared graded labels. The trailing -1. equals the default padded_value_indicator,
    # so that slot is treated as padding.
    y_true = torch.FloatTensor([[1., 1./2, 1./3, 0., 0., -1/2., -1.]])
    y_pred_1 = torch.FloatTensor([[2.3, 1.2, 1.1, 0.5, 0.23, 0.21, 40]])
    y_pred_2 = torch.FloatTensor([[0.5, 0.23, 2.3, 1.2, 1.1, 5, 20]])

    # The two single-query examples stacked into one batch of size 2.
    y_true_batch = y_true.repeat(2, 1)
    y_pred_batch = torch.cat([y_pred_1, y_pred_2], dim=0)

    # Large score gaps (up to ~90) between mis-ordered items: a numerical stress case.
    y_pred = torch.FloatTensor([[20.3, 10.2, 10.1, 50, 100, 21, 40]])

    for example_pred, example_true in (
        (y_pred_1, y_true),
        (y_pred_2, y_true),
        (y_pred_batch, y_true_batch),
        (y_pred, y_true),
    ):
        print(lambda_mrr_loss(example_pred, example_true))
    

tensor(0.0890)
tensor(1.1732)
tensor(0.6311)
tensor(17.3262)
