In [1]:
import tensorflow as tf

In [2]:
def _get_anchor_positive_triplet_mask(labels):
    """Build the boolean mask of valid (anchor, positive) pairs.

    Entry ``mask[a, p]`` is True exactly when ``a`` and ``p`` are different
    positions in the batch and ``labels[a] == labels[p]``.

    Args:
        labels: tf.int32 `Tensor` with shape [batch_size]

    Returns:
        mask: tf.bool `Tensor` with shape [batch_size, batch_size]
    """
    batch_size = tf.shape(labels)[0]

    # Pairwise label comparison: a (batch_size, 1) column broadcast against a
    # (1, batch_size) row yields a symmetric [batch_size, batch_size] matrix.
    same_label = tf.equal(tf.expand_dims(labels, 1), tf.expand_dims(labels, 0))

    # True everywhere except on the diagonal, i.e. wherever a != p.
    on_diagonal = tf.cast(tf.eye(batch_size), tf.bool)
    off_diagonal = tf.logical_not(on_diagonal)

    # A pair is valid only if it satisfies both conditions.
    return tf.logical_and(off_diagonal, same_label)

In [3]:
# Example batch: only positions 0 and 3 share a label (1), so the mask is True
# solely at [0, 3] and [3, 0]; the diagonal stays False.
labels = tf.constant([1, 2, 3, 1])
_get_anchor_positive_triplet_mask(labels)

<tf.Tensor: shape=(4, 4), dtype=bool, numpy=
array([[False, False, False,  True],
       [False, False, False, False],
       [False, False, False, False],
       [ True, False, False, False]])>

In [4]:
def _get_anchor_negative_triplet_mask(labels):
    """Build the boolean mask of valid (anchor, negative) pairs.

    Entry ``mask[a, n]`` is True exactly when ``labels[a] != labels[n]``.

    Args:
        labels: tf.int32 `Tensor` with shape [batch_size]

    Returns:
        mask: tf.bool `Tensor` with shape [batch_size, batch_size]
    """
    # Broadcast a (1, batch_size) row against a (batch_size, 1) column so that
    # every label is compared with every other label.
    as_row = tf.expand_dims(labels, 0)
    as_column = tf.expand_dims(labels, 1)

    return tf.not_equal(as_row, as_column)

In [5]:
# `labels` is the tensor defined in an earlier cell.
# Show the two operands that are broadcast against each other:
# a (1, batch_size) row and a (batch_size, 1) column.
print(tf.expand_dims(labels, 0))
print(tf.expand_dims(labels, 1))

tf.Tensor([[1 2 3 1]], shape=(1, 4), dtype=int32)
tf.Tensor(
[[1]
 [2]
 [3]
 [1]], shape=(4, 1), dtype=int32)


In [6]:
# Broadcasting the row against the column gives the full pairwise
# label-equality matrix (the diagonal is True because each label equals itself).
print(tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1)))

tf.Tensor(
[[ True False False  True]
 [False  True False False]
 [False False  True False]
 [ True False False  True]], shape=(4, 4), dtype=bool)


In [7]:
# Same example batch: the mask is True wherever the two labels differ, so it is
# False on the diagonal and at [0, 3] / [3, 0] (both positions carry label 1).
labels = tf.constant([1, 2, 3, 1])
_get_anchor_negative_triplet_mask(labels)

<tf.Tensor: shape=(4, 4), dtype=bool, numpy=
array([[False,  True,  True, False],
       [ True, False,  True,  True],
       [ True,  True, False,  True],
       [False,  True,  True, False]])>

In [8]:
def _get_triplet_mask(labels):
    """Build the 3D boolean mask of valid (anchor, positive, negative) triplets.

    ``mask[a, p, n]`` is True exactly when:
        - a, p and n are three different positions in the batch, and
        - labels[a] == labels[p] and labels[a] != labels[n]

    Args:
        labels: tf.int32 `Tensor` with shape [batch_size]

    Returns:
        mask: tf.bool `Tensor` with shape [batch_size, batch_size, batch_size]
    """
    batch_size = tf.shape(labels)[0]

    # --- Index constraint: a, p, n pairwise distinct -------------------------
    # 2D "positions differ" matrix, then lifted into 3D by inserting the
    # missing axis so each copy constrains a different pair of indices.
    positions_differ = tf.logical_not(tf.cast(tf.eye(batch_size), tf.bool))
    a_differs_p = tf.expand_dims(positions_differ, 2)  # [B, B, 1]
    a_differs_n = tf.expand_dims(positions_differ, 1)  # [B, 1, B]
    p_differs_n = tf.expand_dims(positions_differ, 0)  # [1, B, B]
    all_distinct = tf.logical_and(
        tf.logical_and(a_differs_p, a_differs_n), p_differs_n
    )

    # --- Label constraint: same label for (a, p), different for (a, n) -------
    same_label = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
    a_matches_p = tf.expand_dims(same_label, 2)  # [B, B, 1]
    a_matches_n = tf.expand_dims(same_label, 1)  # [B, 1, B]
    labels_ok = tf.logical_and(a_matches_p, tf.logical_not(a_matches_n))

    # A triplet is valid only when both constraints hold.
    return tf.logical_and(all_distinct, labels_ok)

In [10]:
# Batch with two repeated labels: label 1 at positions 0 and 3, label 2 at
# positions 1 and 4. Position 2 (label 3) has no positive, so it can never be
# an anchor.
labels = tf.constant([1, 2, 3, 1, 2])

_get_triplet_mask(labels)

<tf.Tensor: shape=(5, 5, 5), dtype=bool, numpy=
array([[[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False,  True,  True, False,  True],
        [False, False, False, False, False]],

       [[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [ True, False,  True,  True, False]],

       [[False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]],

       [[False,  True,  True, False,  True],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]],

       [[False, False, False, False, False],

In [6]:
# Smaller batch: only positions 0 and 3 share a label, so the valid triplets
# are (0, 3, n) and (3, 0, n) with n in {1, 2}.
labels = tf.constant([1, 2, 3, 1])
_get_triplet_mask(labels)

<tf.Tensor: shape=(4, 4, 4), dtype=bool, numpy=
array([[[False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False,  True,  True, False]],

       [[False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False]],

       [[False, False, False, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False]],

       [[False,  True,  True, False],
        [False, False, False, False],
        [False, False, False, False],
        [False, False, False, False]]])>

In [13]:
# Compare the ways of reading the batch size: tf.shape(labels) is a 1-D int32
# tensor, indexing it with [0] gives a scalar tensor, and labels.shape prints
# as a plain shape tuple.
print(tf.shape(labels))
print(tf.shape(labels)[0])
print(labels.shape)

tf.Tensor([4], shape=(1,), dtype=int32)
tf.Tensor(4, shape=(), dtype=int32)
(4,)


In [14]:
# Identity matrix cast to bool: True exactly where row index == column index.
tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool)

<tf.Tensor: shape=(4, 4), dtype=bool, numpy=
array([[ True, False, False, False],
       [False,  True, False, False],
       [False, False,  True, False],
       [False, False, False,  True]])>

In [15]:
# Step-by-step walk-through of the distinct-index part of _get_triplet_mask,
# using the `labels` tensor defined in an earlier cell.

# True on the diagonal only (i == j).
indices_equal = tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool)
print(indices_equal)

# Inverted: True wherever the two indices differ.
indices_not_equal = tf.logical_not(indices_equal)
print(indices_not_equal)

# Insert a size-1 axis at a different position each time so the same 2D matrix
# constrains a different index pair once broadcast to 3D:
# shapes (4, 4, 1), (4, 1, 4) and (1, 4, 4) respectively.
i_not_equal_j = tf.expand_dims(indices_not_equal, 2)
i_not_equal_k = tf.expand_dims(indices_not_equal, 1)
j_not_equal_k = tf.expand_dims(indices_not_equal, 0)
print(i_not_equal_j)
print(i_not_equal_k)
print(j_not_equal_k)

tf.Tensor(
[[ True False False False]
 [False  True False False]
 [False False  True False]
 [False False False  True]], shape=(4, 4), dtype=bool)
tf.Tensor(
[[False  True  True  True]
 [ True False  True  True]
 [ True  True False  True]
 [ True  True  True False]], shape=(4, 4), dtype=bool)
tf.Tensor(
[[[False]
  [ True]
  [ True]
  [ True]]

 [[ True]
  [False]
  [ True]
  [ True]]

 [[ True]
  [ True]
  [False]
  [ True]]

 [[ True]
  [ True]
  [ True]
  [False]]], shape=(4, 4, 1), dtype=bool)
tf.Tensor(
[[[False  True  True  True]]

 [[ True False  True  True]]

 [[ True  True False  True]]

 [[ True  True  True False]]], shape=(4, 1, 4), dtype=bool)
tf.Tensor(
[[[False  True  True  True]
  [ True False  True  True]
  [ True  True False  True]
  [ True  True  True False]]], shape=(1, 4, 4), dtype=bool)
