Skip to content

Commit

Permalink
CLN: rename to triplet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
facaiy committed Feb 7, 2019
1 parent f2f7f13 commit 30cede4
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 18 deletions.
12 changes: 6 additions & 6 deletions tensorflow_addons/losses/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,24 @@ licenses(["notice"]) # Apache 2.0
package(default_visibility = ["//visibility:public"])

py_library(
name = "metric_loss_ops_py",
name = "losses_py",
srcs = [
"__init__.py",
"python/__init__.py",
"python/metric_loss_ops.py",
"python/triplet.py",
],
srcs_version = "PY2AND3",
)

py_test(
name = "metric_loss_ops_py_test",
name = "triplet_py_test",
size = "small",
srcs = [
"python/metric_loss_ops_test.py",
"python/triplet_test.py",
],
main = "python/metric_loss_ops_test.py",
main = "python/triplet_test.py",
deps = [
":metric_loss_ops_py",
":losses_py",
],
srcs_version = "PY2AND3",
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements various metric learning losses."""
"""Implements triplet loss."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Expand All @@ -25,6 +25,12 @@
from tensorflow.python.ops.losses import losses_impl


# TODO: move to utils module?
def register_keras_custom_object(cls):
generic_utils._GLOBAL_CUSTOM_OBJECTS[cls.__name__] = cls
return cls


def pairwise_distance(feature, squared=False):
"""Computes the pairwise distance matrix with numerical stability.
Expand Down Expand Up @@ -116,6 +122,7 @@ def masked_minimum(data, mask, dim=1):
return masked_minimums


@register_keras_custom_object
def triplet_semihard_loss(y_true, y_pred, margin=1.0):
"""Computes the triplet loss with semi-hard negative mining.
Expand Down Expand Up @@ -189,6 +196,7 @@ def triplet_semihard_loss(y_true, y_pred, margin=1.0):
return triplet_loss


@register_keras_custom_object
class TripletSemiHardLoss(losses.LossFunctionWrapper):
"""Computes the triplet loss with semi-hard negative mining.
Expand All @@ -200,7 +208,7 @@ class TripletSemiHardLoss(losses.LossFunctionWrapper):
See: https://arxiv.org/abs/1503.03832.
We expect labels `y_true` to be provided as 1-D tf.int32 `Tensor` with shape
[batch_size] of multiclass integer labels. And embeddings `y_pred` must be 2-D
[batch_size] of multi-class integer labels. And embeddings `y_pred` must be 2-D
float `Tensor` of l2 normalized embedding vectors.
Args:
Expand All @@ -214,9 +222,3 @@ def __init__(self, margin=1.0, name=None):
name=name,
reduction=losses_impl.ReductionV2.NONE,
margin=margin)


generic_utils._GLOBAL_CUSTOM_OBJECTS[
'TripletSemiHardLoss'] = TripletSemiHardLoss
generic_utils._GLOBAL_CUSTOM_OBJECTS[
'triplet_semihard_loss'] = triplet_semihard_loss
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for metric_loss_ops."""
"""Tests for triplet loss."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
Expand All @@ -22,7 +22,7 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow_addons.losses.python import metric_loss_ops
from tensorflow_addons.losses.python import triplet


def pairwise_distance_np(feature, squared=False):
Expand Down Expand Up @@ -54,7 +54,7 @@ def pairwise_distance_np(feature, squared=False):

@test_util.run_all_in_graph_and_eager_modes
class TripletSemiHardLossTest(test.TestCase):
def test_all_correct_unweighted(self):
def test_unweighted(self):
num_data = 10
feat_dim = 6
margin = 1.0
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_all_correct_unweighted(self):
# Compute the loss in TF.
y_true = constant_op.constant(labels)
y_pred = constant_op.constant(embedding)
cce_obj = metric_loss_ops.TripletSemiHardLoss()
cce_obj = triplet.TripletSemiHardLoss()
loss = cce_obj(y_true, y_pred)
self.assertAlmostEqual(self.evaluate(loss), loss_np, 3)

Expand Down

0 comments on commit 30cede4

Please sign in to comment.