# In [5]:
import torch
import torch.nn as nn

# In [None]:
def pairwise_distance(embeddings, squared=False):
    """Return the matrix of Euclidean distances between every pair of rows.

    The squared distance is expanded as ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2,
    so the whole matrix comes from a single Gram-matrix product. The squared norms
    are read off the Gram diagonal, which makes the diagonal of the result 0.

    Args:
        embeddings: tensor of shape (batch_size, embed_dim)
        squared: Boolean. If true, return pairwise squared euclidean distances.
                 If false, return pairwise euclidean distances.

    Returns:
        float32 tensor of shape (batch_size, batch_size)
    """
    gram = embeddings @ embeddings.transpose(0, 1)
    sq_norms = gram.diagonal()

    sq_dists = sq_norms.unsqueeze(0) - 2.0 * gram + sq_norms.unsqueeze(1)
    # Floating-point cancellation can yield tiny negatives; clip them to 0.
    sq_dists = torch.maximum(sq_dists, torch.zeros_like(sq_dists)).float()

    if squared:
        return sq_dists

    # d/dx sqrt(x) is infinite at 0, so nudge exact zeros by a tiny epsilon
    # before the sqrt, then force those entries back to exactly 0 afterwards.
    zero_mask = (sq_dists == 0.0).float()
    dists = torch.sqrt(sq_dists + zero_mask * 1e-16)
    return dists * (1.0 - zero_mask)

# In [3]:
def batch_semihard_triplet_loss(labels, embeddings, margin=0.001, squared=False):
    """Build the batch-hard triplet loss over a batch of embeddings.

    For each anchor, take its hardest positive (the farthest sample with the same
    label) and its hardest negative (the closest sample with a different label).
    Then apply the hinge max(d(a, p) - d(a, n) + margin, 0). Despite the function
    name, this is the batch-*hard* strategy, not semi-hard mining.

    Args:
        labels: labels of the batch, of size (batch_size,)
        embeddings: tensor of shape (batch_size, embed_dim)
        margin: margin for triplet loss
        squared: Boolean. If true, use pairwise squared euclidean distances.
                 If false, use pairwise euclidean distances.

    Returns:
        triplet_loss: scalar tensor, the mean hinge loss over all anchors
    """
    # shape (batch_size, batch_size)
    pairwise_dist = pairwise_distance(embeddings, squared=squared)
    # Masks are cast to the distance tensor's dtype/device so the products below line up.
    dist_kwargs = dict(dtype=pairwise_dist.dtype, device=pairwise_dist.device)

    # labels_equal[a, b] is True iff labels[a] == labels[b]
    labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    not_self = ~torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)

    # Valid positives: same label and a != p. An elementwise product (not a matrix
    # product) zeroes the distances of invalid pairs.
    mask_anchor_positive = (labels_equal & not_self).to(**dist_kwargs)
    anchor_positive_dist = pairwise_dist * mask_anchor_positive
    # shape (batch_size, 1)
    hardest_positive_dist, _ = torch.max(anchor_positive_dist, dim=1, keepdim=True)

    # Valid negatives: different label.
    mask_anchor_negative = (~labels_equal).to(**dist_kwargs)
    # Add the row maximum to invalid negatives so they can never be the row minimum.
    max_anchor_negative_dist, _ = torch.max(pairwise_dist, dim=1, keepdim=True)
    anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
    # shape (batch_size, 1)
    hardest_negative_dist, _ = torch.min(anchor_negative_dist, dim=1, keepdim=True)

    # Elementwise hinge: max(d(a, p) - d(a, n) + margin, 0)
    triplet_loss = torch.relu(hardest_positive_dist - hardest_negative_dist + margin)

    return torch.mean(triplet_loss)

# In [None]:
class HardTripletLoss(nn.Module):
    """Hard/Hardest Triplet Loss
    (pytorch implementation of https://omoindrot.github.io/triplet-loss)

    With ``hardest=True``, each anchor contributes one triplet made from its hardest
    (farthest) positive and hardest (closest) negative, and the hinge losses are
    averaged over anchors.
    With ``hardest=False``, every valid triplet in the batch is evaluated, and the
    loss is averaged over the triplets whose hinge loss is positive.
    """

    def __init__(self, margin=0.1, hardest=False, squared=False):
        """
        Args:
            margin: margin for triplet loss
            hardest: If true, loss is considered only hardest triplets.
            squared: If true, output is the pairwise squared euclidean distance matrix.
                If false, output is the pairwise euclidean distance matrix.
        """
        super(HardTripletLoss, self).__init__()
        self.margin = margin
        self.hardest = hardest
        self.squared = squared

    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: tensor of shape (batch_size, embed_dim)
            labels: identity labels of shape (batch_size,) or (batch_size, 1)
        Returns:
            triplet_loss: scalar tensor containing the triplet loss
        """
        # Make sure that labels only have identity and do not contain mask info.
        assert len(labels.shape) == 1 or (len(labels.shape) == 2 and labels.shape[-1] == 1)
        # A (batch_size, 1) column would broadcast into 3D/4D "masks" in the helpers;
        # flatten it to (batch_size,).
        labels = labels.reshape(-1)

        pairwise_dist = self._pairwise_distance(embeddings, squared=self.squared)

        if self.hardest:
            # Hardest positive: largest distance among same-label, non-self pairs.
            mask_anchor_positive = self._get_anchor_positive_triplet_mask(labels).to(pairwise_dist)
            valid_positive_dist = pairwise_dist * mask_anchor_positive
            hardest_positive_dist, _ = torch.max(valid_positive_dist, dim=1, keepdim=True)

            # Hardest negative: add the row maximum to same-label entries so the row
            # minimum is taken over different-label pairs only.
            mask_anchor_negative = self._get_anchor_negative_triplet_mask(labels).to(pairwise_dist)
            max_anchor_negative_dist, _ = torch.max(pairwise_dist, dim=1, keepdim=True)
            anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (
                    1.0 - mask_anchor_negative)
            hardest_negative_dist, _ = torch.min(anchor_negative_dist, dim=1, keepdim=True)

            # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
            triplet_loss = torch.relu(hardest_positive_dist - hardest_negative_dist + self.margin)
            triplet_loss = torch.mean(triplet_loss)
        else:
            anc_pos_dist = pairwise_dist.unsqueeze(dim=2)
            anc_neg_dist = pairwise_dist.unsqueeze(dim=1)

            # Compute a 3D tensor of size (batch_size, batch_size, batch_size)
            # triplet_loss[i, j, k] will contain the triplet loss of anc=i, pos=j, neg=k
            # Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
            # and the 2nd (batch_size, 1, batch_size)
            loss = anc_pos_dist - anc_neg_dist + self.margin

            mask = self._get_triplet_mask(labels).to(pairwise_dist)
            triplet_loss = loss * mask

            # Remove negative losses (i.e. the easy triplets)
            triplet_loss = torch.relu(triplet_loss)

            # Count number of hard triplets (where triplet_loss > 0)
            hard_triplets = torch.gt(triplet_loss, 1e-16).float()
            num_hard_triplets = torch.sum(hard_triplets)

            # The epsilon avoids 0/0 when no triplet is hard.
            triplet_loss = torch.sum(triplet_loss) / (num_hard_triplets + 1e-16)

        return triplet_loss

    def _pairwise_distance(self, x, squared=False, eps=1e-16):
        """Compute the 2D matrix of (squared) euclidean distances between all rows of x."""
        cor_mat = torch.matmul(x, x.t())
        norm_mat = cor_mat.diag()
        distances = norm_mat.unsqueeze(1) - 2 * cor_mat + norm_mat.unsqueeze(0)
        # Clip small negatives caused by floating-point cancellation.
        distances = torch.relu(distances)

        if not squared:
            # sqrt has an infinite gradient at 0: add eps to exact zeros, then re-zero them.
            mask = torch.eq(distances, 0.0).float()
            distances = distances + mask * eps
            distances = torch.sqrt(distances)
            distances = distances * (1.0 - mask)

        return distances

    def _get_anchor_positive_triplet_mask(self, labels):
        """Return a 2D bool mask where mask[a, p] is True iff a and p are distinct and have same label."""
        # Build on the labels' own device instead of a hard-coded "cuda:0".
        indices_not_equal = ~torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)

        # Check if labels[i] == labels[j]
        labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)

        return indices_not_equal & labels_equal

    def _get_anchor_negative_triplet_mask(self, labels):
        """Return a 2D bool mask where mask[a, n] is True iff a and n have distinct labels."""
        # Check if labels[i] != labels[k]
        labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)
        return ~labels_equal

    def _get_triplet_mask(self, labels):
        """Return a 3D bool mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.

        A triplet (i, j, k) is valid if:
          - i, j, k are distinct
          - labels[i] == labels[j] and labels[i] != labels[k]
        """
        # Check that i, j and k are distinct
        indices_not_same = ~torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)
        i_not_equal_j = torch.unsqueeze(indices_not_same, 2)
        i_not_equal_k = torch.unsqueeze(indices_not_same, 1)
        j_not_equal_k = torch.unsqueeze(indices_not_same, 0)
        distinct_indices = i_not_equal_j & i_not_equal_k & j_not_equal_k

        # Check if labels[i] == labels[j] and labels[i] != labels[k]
        label_equal = torch.eq(torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1))
        i_equal_j = torch.unsqueeze(label_equal, 2)
        i_equal_k = torch.unsqueeze(label_equal, 1)
        valid_labels = i_equal_j & ~i_equal_k

        # Combine the two masks
        return distinct_indices & valid_labels

# In [None]:
class HardTripletLossWithMask(nn.Module):
    """Hard/Hardest Triplet Loss for batches that mix clear and masked faces.
    (pytorch implementation based on https://omoindrot.github.io/triplet-loss)

    ``labels`` has shape (batch_size, 2): column 0 is the identity and column 1
    is the masked-face flag.
    With ``hardest=True``, only identities are used (batch-hard, as in the
    reference). With ``hardest=False``, a clear-face triplet and a masked-face
    triplet that share the anchor and the negative are combined into one loss term.
    """

    def __init__(self, margin=0.1, hardest=False, squared=False):
        """
        Args:
            margin: margin for triplet loss
            hardest: If true, loss is considered only hardest triplets.
            squared: If true, output is the pairwise squared euclidean distance matrix.
                If false, output is the pairwise euclidean distance matrix.
        """
        super(HardTripletLossWithMask, self).__init__()
        self.margin = margin
        self.hardest = hardest
        self.squared = squared

    def forward(self, embeddings, labels):
        """
        Args:
            embeddings: tensor of shape (batch_size, embed_dim)
            labels: tensor of shape (batch_size, 2) holding (identity, masked_face)
        Returns:
            loss: scalar tensor containing the triplet loss
        """
        # Make sure that labels contain both identity and mask info.
        assert len(labels.shape) == 2 and labels.shape[-1] == 2
        labels, masked_faces = labels[:, 0], labels[:, 1]

        pairwise_dist = self._pairwise_distance(embeddings, squared=self.squared)

        if self.hardest:
            # Hardest positive: largest distance among same-identity, non-self pairs.
            mask_anchor_positive = self._get_anchor_positive_triplet_mask(labels).to(pairwise_dist)
            valid_positive_dist = pairwise_dist * mask_anchor_positive
            hardest_positive_dist, _ = torch.max(valid_positive_dist, dim=1, keepdim=True)

            # Hardest negative: add the row maximum to same-identity entries so the
            # row minimum is taken over different-identity pairs only.
            mask_anchor_negative = self._get_anchor_negative_triplet_mask(labels).to(pairwise_dist)
            max_anchor_negative_dist, _ = torch.max(pairwise_dist, dim=1, keepdim=True)
            anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (
                    1.0 - mask_anchor_negative)
            hardest_negative_dist, _ = torch.min(anchor_negative_dist, dim=1, keepdim=True)

            # Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
            loss = torch.relu(hardest_positive_dist - hardest_negative_dist + self.margin)
            loss = torch.mean(loss)
        else:
            anc_pos_dist = pairwise_dist.unsqueeze(dim=2)
            anc_neg_dist = pairwise_dist.unsqueeze(dim=1)

            # triplet_loss[i, j, k] = d(i, j) - d(i, k) for anc=i, pos=j, neg=k.
            # Shape (batch_size, batch_size, batch_size) via broadcasting
            # (batch_size, batch_size, 1) - (batch_size, 1, batch_size).
            triplet_loss = anc_pos_dist - anc_neg_dist

            # mask[i, j, k] is 1 iff {i, j, k} is a valid clear / masked triplet
            # where i: anchor, j: positive, k: negative.
            clear_face_mask, masked_face_mask = self._get_triplet_mask(labels, masked_faces)
            clear_face_mask = clear_face_mask.to(pairwise_dist)
            masked_face_mask = masked_face_mask.to(pairwise_dist)

            triplet_loss_clear = triplet_loss * clear_face_mask
            triplet_loss_masked = triplet_loss * masked_face_mask

            # 4D broadcast (batch_size ** 4):
            #   loss[i, j, l, k] = tl_clear[i, j, k] + tl_masked[i, l, k] + margin
            # i.e. a clear triplet and a masked triplet sharing anchor i and negative k.
            # NOTE(review): an earlier comment described combining triplets with
            # *different* negatives (e.g. tl_clear(1,2,3) + tl_masked(1,4,5)); this
            # broadcast shares k. Confirm which pairing is intended.
            loss = triplet_loss_clear.unsqueeze(2) + triplet_loss_masked.unsqueeze(1) + self.margin

            # Keep only combinations where both component triplets are valid.
            # Without this, every invalid combination evaluates to `margin` and is
            # counted as a "hard" triplet.
            pair_mask = clear_face_mask.unsqueeze(2) * masked_face_mask.unsqueeze(1)
            loss = loss * pair_mask

            # Remove negative losses (i.e. the easy triplets)
            loss = torch.relu(loss)

            # Count number of hard triplets (where loss > 0)
            hard_triplets = torch.gt(loss, 1e-16).float()
            num_hard_triplets = torch.sum(hard_triplets)

            # The epsilon avoids 0/0 when no combination is hard.
            loss = torch.sum(loss) / (num_hard_triplets + 1e-16)

        return loss

    def _pairwise_distance(self, x, squared=False, eps=1e-16):
        """Compute the 2D matrix of (squared) euclidean distances between all rows of x."""
        cor_mat = torch.matmul(x, x.t())
        norm_mat = cor_mat.diag()
        distances = norm_mat.unsqueeze(1) - 2 * cor_mat + norm_mat.unsqueeze(0)
        # Clip small negatives caused by floating-point cancellation.
        distances = torch.relu(distances)

        if not squared:
            # sqrt has an infinite gradient at 0: add eps to exact zeros, then re-zero them.
            mask = torch.eq(distances, 0.0).float()
            distances = distances + mask * eps
            distances = torch.sqrt(distances)
            distances = distances * (1.0 - mask)

        return distances

    def _get_anchor_positive_triplet_mask(self, labels):
        """Return a 2D bool mask where mask[a, p] is True iff a and p are distinct and have same label."""
        # Build on the labels' own device instead of a hard-coded "cuda:0".
        indices_not_equal = ~torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)

        # Check if labels[i] == labels[j]
        labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)

        return indices_not_equal & labels_equal

    def _get_anchor_positive_masked_triplet_mask(self, the_labels):
        """Same as _get_anchor_positive_triplet_mask, using column 0 of a (batch_size, 2) label tensor.

        NOTE(review): not called anywhere in the visible code; the masked-face
        column is ignored here.
        """
        labels = the_labels.t()[0]
        indices_not_equal = ~torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)

        # Check if labels[i] == labels[j]
        labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)

        return indices_not_equal & labels_equal

    def _get_anchor_negative_triplet_mask(self, labels):
        """Return a 2D bool mask where mask[a, n] is True iff a and n have distinct labels."""
        # Check if labels[i] != labels[k]
        labels_equal = torch.unsqueeze(labels, 0) == torch.unsqueeze(labels, 1)
        return ~labels_equal

    def _get_triplet_mask(self, labels, masked_faces):
        """Return two 3D bool masks (clear_face_mask, masked_face_mask).

        Both require, for (i, j, k) = (anchor, positive, negative):
          - i, j, k are distinct
          - idens[i] == idens[j] and idens[i] != idens[k]
        Additionally:
          - masked_face_mask requires masked_faces[i], [j] and [k] to all be set
          - clear_face_mask is the complement: at least one of i, j, k is not masked

        NOTE(review): the original docstring described the extra conditions on
        j and k only (both clear / both masked), but the code conditions on i, j
        and k as stated above. Confirm the intended definition.
        """
        # Check that i, j and k are distinct
        indices_not_same = ~torch.eye(labels.shape[0], dtype=torch.bool, device=labels.device)
        i_not_equal_j = torch.unsqueeze(indices_not_same, 2)
        i_not_equal_k = torch.unsqueeze(indices_not_same, 1)
        j_not_equal_k = torch.unsqueeze(indices_not_same, 0)
        distinct_indices = i_not_equal_j & i_not_equal_k & j_not_equal_k

        # Check if labels[i] == labels[j] and labels[i] != labels[k]
        label_equal = torch.eq(torch.unsqueeze(labels, 0), torch.unsqueeze(labels, 1))
        i_equal_j = torch.unsqueeze(label_equal, 2)
        i_equal_k = torch.unsqueeze(label_equal, 1)
        valid_labels = i_equal_j & ~i_equal_k

        # masked[i, j, k] = masked_faces[i] and masked_faces[j] and masked_faces[k]
        m = masked_faces.bool()
        masked = m.view(-1, 1, 1) & m.view(1, -1, 1) & m.view(1, 1, -1)
        clear = ~masked

        # Combine the 3 masks
        clear_face_mask = distinct_indices & valid_labels & clear
        masked_face_mask = distinct_indices & valid_labels & masked

        return clear_face_mask, masked_face_mask
    
    