diff --git a/references/similarity/loss.py b/references/similarity/loss.py index 3e467b74c52..1fa4a89c762 100644 --- a/references/similarity/loss.py +++ b/references/similarity/loss.py @@ -77,7 +77,7 @@ def batch_all_triplet_loss(labels, embeddings, margin, p): def _get_triplet_mask(labels): # Check that i, j and k are distinct - indices_equal = torch.eye(labels.size(0), dtype=torch.uint8, device=labels.device) + indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device) indices_not_equal = ~indices_equal i_not_equal_j = indices_not_equal.unsqueeze(2) i_not_equal_k = indices_not_equal.unsqueeze(1) @@ -96,7 +96,7 @@ def _get_triplet_mask(labels): def _get_anchor_positive_triplet_mask(labels): # Check that i and j are distinct - indices_equal = torch.eye(labels.size(0), dtype=torch.uint8, device=labels.device) + indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device) indices_not_equal = ~indices_equal # Check if labels[i] == labels[j]