Update on "[pytorch] Add triplet margin loss with custom distance"
Summary: As discussed [here](#43342), this adds a Python-only implementation
of the triplet margin loss that takes a custom distance function. Whether this
needs to be added to PyTorch Core is still under discussion.
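
As a rough illustration of the intended usage, here is a minimal Python sketch of a triplet margin loss that accepts a custom distance function. The function name, signature, and defaults below are illustrative only, not the final API under discussion:

import torch
import torch.nn.functional as F

def triplet_margin_loss_with_distance(anchor, positive, negative,
                                      distance_fn=None, margin=1.0):
    # Hypothetical helper; names and defaults are illustrative only.
    if distance_fn is None:
        # Fall back to the Euclidean distance used by the existing loss.
        distance_fn = lambda x, y: F.pairwise_distance(x, y, p=2)
    d_pos = distance_fn(anchor, positive)
    d_neg = distance_fn(anchor, negative)
    return torch.clamp(d_pos - d_neg + margin, min=0).mean()

# Example with a custom (cosine) distance:
# loss = triplet_margin_loss_with_distance(
#     a, p, n, distance_fn=lambda x, y: 1.0 - F.cosine_similarity(x, y))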

Test Plan: python test/run_tests.py

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D23363898](https://our.internmc.facebook.com/intern/diff/D23363898)

[ghstack-poisoned]
ethch18 committed Sep 2, 2020
2 parents 5080482 + 14ebb2c commit 9fb04f9
Showing 38 changed files with 438 additions and 198 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -136,6 +136,7 @@ cmake_dependent_option(
CAFFE2_USE_MSVC_STATIC_RUNTIME "Using MSVC static runtime libraries" ON
"NOT BUILD_SHARED_LIBS" OFF)
option(BUILD_TEST "Build C++ test binaries (need gtest and gbenchmark)" OFF)
option(BUILD_STATIC_RUNTIME_BENCHMARK "Build C++ binaries for static runtime benchmarks (need gbenchmark)" OFF)
option(BUILD_MOBILE_BENCHMARKS "Build C++ test binaries for mobile (ARM) targets(need gtest and gbenchmark)" OFF)
option(BUILD_MOBILE_TEST "Build C++ test binaries for mobile (ARM) targets(need gtest and gbenchmark)" OFF)
option(BUILD_JNI "Build JNI bindings" OFF)
2 changes: 0 additions & 2 deletions aten/src/ATen/native/Activation.cpp
@@ -799,7 +799,6 @@ Tensor log_sigmoid(const Tensor & self) {
Tensor log_sigmoid_backward_cpu(const Tensor& grad_output, const Tensor& input, const Tensor& buffer) {
Tensor grad_input;
auto iter = at::TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(grad_input)
.add_input(input)
.add_input(buffer)
@@ -815,7 +814,6 @@ Tensor& log_sigmoid_backward_out_cpu(
const Tensor& input,
const Tensor& buffer) {
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(grad_input)
.add_input(input)
.add_input(buffer)
1 change: 0 additions & 1 deletion aten/src/ATen/native/BinaryOps.cpp
@@ -172,7 +172,6 @@ Tensor& remainder_(Tensor& self, const Tensor& other) {

Tensor& true_divide_out(Tensor& result, const Tensor& self, const Tensor& divisor) {
TensorIterator iter = TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(result)
.add_input(self)
.add_input(divisor)
1 change: 0 additions & 1 deletion aten/src/ATen/native/Copy.cpp
@@ -134,7 +134,6 @@ static Tensor & copy_impl(Tensor & self, const Tensor & src, bool non_blocking)
}

auto iter = TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(self)
.add_input(src)
.resize_outputs(false)
3 changes: 3 additions & 0 deletions aten/src/ATen/native/DistributionTemplates.h
@@ -3,6 +3,7 @@
#include <ATen/Dispatch.h>
#include <ATen/Generator.h>
#include <ATen/Tensor.h>
#include <ATen/MemoryOverlap.h>
#include <ATen/native/TensorIterator.h>
#include <c10/util/Optional.h>
#include <limits>
@@ -340,13 +341,15 @@ Tensor& cauchy_impl_(Tensor& self, double median, double sigma, c10::optional<Ge
template<template<typename> class bernoulli_tensor_kernel, typename RNG>
Tensor& bernoulli_impl_(Tensor& self, const Tensor& p_, c10::optional<Generator> gen) {
NoNamesGuard guard;
at::assert_no_internal_overlap(self);
bernoulli_tensor_kernel<RNG>()(self, p_, gen);
return self;
}

template<template<typename> class bernoulli_scalar_kernel, typename RNG>
Tensor& bernoulli_impl_(Tensor& self, double p, c10::optional<Generator> gen) {
TORCH_CHECK(0 <= p && p <= 1, "bernoulli_ expects p to be in [0, 1], but got p=", p);
at::assert_no_internal_overlap(self);
bernoulli_scalar_kernel<RNG>()(self, p, gen);
return self;
}
1 change: 1 addition & 0 deletions aten/src/ATen/native/FunctionOfAMatrixUtils.cpp
@@ -81,6 +81,7 @@ Tensor& _compute_linear_combination_out(Tensor& output, const Tensor& input, con
);

auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false) // Output is intentionally 0 strided above
.check_all_same_dtype(false)
.resize_outputs(false)
.add_output(output_restrided)
1 change: 0 additions & 1 deletion aten/src/ATen/native/GatedLinearUnit.cpp
@@ -55,7 +55,6 @@ Tensor& glu_backward_out(Tensor& grad_input,
at::sigmoid_out(gradInputfirstHalf, secondHalf);
// for second gradinput half, can get a better performance by fusion
auto iter = at::TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(gradInputsecondHalf)
.add_input(gradInputfirstHalf)
.add_input(firstHalf)
2 changes: 2 additions & 0 deletions aten/src/ATen/native/LegacyDefinitions.cpp
@@ -3,12 +3,14 @@
#include <ATen/LegacyTHFunctionsCPU.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/ExpandUtils.h>
#include <ATen/MemoryOverlap.h>

namespace at { namespace native {

// Methods

Tensor & masked_scatter__cpu(Tensor& self, const Tensor & mask, const Tensor & source) {
at::assert_no_internal_overlap(self);
Tensor b_mask;
std::tie(b_mask) = expand_inplace(self, mask, "masked_scatter_");
// As we dispatch on self and TH is type-checked, we need different definitions.
2 changes: 0 additions & 2 deletions aten/src/ATen/native/Loss.cpp
@@ -315,7 +315,6 @@ Tensor& smooth_l1_loss_out(Tensor& result, const Tensor& input, const Tensor& ta
Tensor& smooth_l1_loss_backward_out(Tensor& grad_input, const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction) {
auto norm = reduction == Reduction::Mean ? 1. / input.numel() : 1.;
auto iter = at::TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(grad_input)
.add_input(input)
.add_input(target)
@@ -363,7 +362,6 @@ Tensor& mse_loss_backward_out(Tensor& grad_input, const Tensor& grad_output,
const Tensor& input, const Tensor& target, int64_t reduction) {
auto norm = reduction == Reduction::Mean ? 2. / input.numel() : 2.;
auto iter = at::TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(grad_input)
.add_input(input)
.add_input(target)
2 changes: 0 additions & 2 deletions aten/src/ATen/native/PointwiseOps.cpp
@@ -36,7 +36,6 @@ Tensor& addcmul_out(
Scalar value) {
checkBackend("addcmul_cpu", result, self.options().backend());
auto iter = at::TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(result)
.add_input(self)
.add_input(tensor1)
@@ -82,7 +81,6 @@ Tensor& addcdiv_out(
}
checkBackend("addcdiv_cpu", result, self.options().backend());
auto iter = at::TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(result)
.add_input(self)
.add_input(tensor1)
35 changes: 33 additions & 2 deletions aten/src/ATen/native/TensorAdvancedIndexing.cpp
@@ -223,6 +223,8 @@ static TensorIterator make_index_put_iterator(const AdvancedIndex& info, const T
"got ", info.src.scalar_type(), " for the destination "
"and ", value.scalar_type(), " for the source.");
TensorIteratorConfig config;
// info.src is restrided by restride_src with 0 strided dimensions
config.set_check_mem_overlap(false);
config.resize_outputs(false);
config.check_all_same_dtype(false);
config.add_output(info.src);
@@ -235,7 +237,8 @@

static TensorIterator make_index_iterator(const AdvancedIndex& info) {
TensorIteratorConfig config;
config.check_all_same_dtype(false)
config.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.declare_static_dtype_and_device(info.src.scalar_type(), info.src.device())
.add_output(Tensor())
.add_input(info.src);
@@ -247,7 +250,9 @@

static TensorIterator make_index_out_iterator(const AdvancedIndex& info, Tensor& result) {
TensorIteratorConfig config;
config.check_all_same_dtype(false)
// info.src is a restrided view of result
config.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.add_output(result)
.add_input(info.src);
for (auto& index : info.indices) {
Expand All @@ -267,6 +272,7 @@ Tensor index(const Tensor & self, TensorList indices) {

Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) {
TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
at::assert_no_internal_overlap(result);

auto info = make_info(self, indices);
auto iter = make_index_out_iterator(info, result);
@@ -286,6 +292,15 @@ Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & valu
index_put_accum_stub(self.device().type(), self, indices, value, unsafe);
return self;
}

if (at::has_internal_overlap(self) == MemOverlap::YES) {
TORCH_WARN(
"Use of index_put_ on expanded tensors is deprecated. "
"Please clone() the tensor before performing this operation. "
"This also applies to advanced indexing e.g. tensor[indices] = tensor");
}
at::assert_no_partial_overlap(self, value);

auto info = make_info(self, indices);
auto iter = make_index_put_iterator(info, value);
index_put_stub(iter.device_type(), iter, info.indexed_sizes, info.indexed_strides, accumulate);
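
The warning added above targets in-place indexed assignment into tensors whose elements alias each other (for example, expanded views). Below is a small Python sketch of the deprecated pattern and the suggested fix; the exact runtime behavior shown is an assumption based on the warning text:

import torch

base = torch.zeros(1, 3)
expanded = base.expand(4, 3)        # all four rows view the same memory
idx = torch.tensor([0, 2])

# Deprecated: indexed write through an internally overlapping view,
# which dispatches to index_put_.
# expanded[idx] = torch.ones(2, 3)

# Suggested: clone() first so the write targets non-overlapping memory.
safe = expanded.clone()
safe[idx] = torch.ones(2, 3)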
@@ -424,6 +439,7 @@ Tensor & index_select_out_cpu_(Tensor & result, const Tensor & self, int64_t dim
"index_select(): self and result must have the same scalar type");
TORCH_CHECK(dim == 0 || dim < self.dim(),
"index_select(): Indexing dim ", dim, " is out of bounds of tensor");
at::assert_no_internal_overlap(result);

auto result_size = self.sizes().vec();
if (self.dim() > 0) {
@@ -629,7 +645,16 @@ static Tensor & masked_fill_impl_cpu(Tensor & self, const Tensor & mask, Scalar
"please use a mask with dtype torch.bool instead.");
}

if (at::has_internal_overlap(self) == MemOverlap::YES) {
TORCH_WARN(
"Use of masked_fill_ on expanded tensors is deprecated. "
"Please clone() the tensor before performing this operation. "
"This also applies to advanced indexing e.g. tensor[mask] = scalar");
}
at::assert_no_partial_overlap(self, mask);

auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false) // deprecated, but not a hard error
.check_all_same_dtype(false)
.resize_outputs(false)
.add_output(self)
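
masked_fill_ gets the same deprecation treatment for expanded tensors; a parallel sketch (again only an illustration of the pattern, not the actual warning output):

import torch

t = torch.zeros(1, 4).expand(3, 4)              # internally overlapping view
mask = torch.tensor([True, False, True, False])

# Deprecated: t.masked_fill_(mask, 1.0)
# Suggested:
t = t.clone()
t.masked_fill_(mask, 1.0)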
@@ -694,6 +719,10 @@ static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self,
TORCH_CHECK(self.scalar_type() == result.scalar_type(),
"masked_select(): self and result must have the same scalar type");

at::assert_no_internal_overlap(result);
at::assert_no_partial_overlap(result, self);
at::assert_no_partial_overlap(result, mask);

if (mask.dtype() == at::ScalarType::Byte) {
TORCH_WARN("masked_select received a mask with dtype torch.uint8, this behavior is now deprecated," \
"please use a mask with dtype torch.bool instead.");
@@ -723,6 +752,7 @@ static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self,
_self.is_contiguous() && _mask.is_contiguous();
if (use_serial_kernel) {
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false) // result is intentionally zero-strided above
.check_all_same_dtype(false)
.resize_outputs(false)
.add_output(result_strided)
Expand All @@ -746,6 +776,7 @@ static Tensor & masked_select_out_impl_cpu(Tensor & result, const Tensor & self,
std::partial_sum(mask_long_data, mask_long_data + mask_long.numel(), mask_prefix_sum_data);

auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false) // result is intentionally zero-strided above
.check_all_same_dtype(false)
.resize_outputs(false)
.add_output(result_strided)
3 changes: 0 additions & 3 deletions aten/src/ATen/native/TensorCompare.cpp
@@ -124,7 +124,6 @@ Tensor& isposinf_out(Tensor& result, const Tensor& self) {
} else {
auto iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.set_check_mem_overlap(true)
.add_output(result)
.add_input(self)
.build();
@@ -149,7 +148,6 @@ Tensor& isneginf_out(Tensor& result, const Tensor& self) {
} else {
auto iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.set_check_mem_overlap(true)
.add_output(result)
.add_input(self)
.build();
@@ -247,7 +245,6 @@ Tensor _s_where(const Tensor& condition, const Tensor& self, const Tensor& other
Tensor ret = at::empty(self.sizes(), self.options());
auto iter = at::TensorIteratorConfig()
.check_all_same_dtype(false)
.set_check_mem_overlap(true)
.add_output(ret)
.add_input(condition)
.add_input(self)
2 changes: 0 additions & 2 deletions aten/src/ATen/native/TensorFactories.cpp
@@ -127,7 +127,6 @@ void complex_check_dtype(
Tensor& complex_out(Tensor& result, const Tensor& real, const Tensor& imag) {
complex_check_dtype(result, real, imag);
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(result)
.add_input(real)
.add_input(imag)
@@ -148,7 +147,6 @@ Tensor complex(const Tensor& real, const Tensor& imag) {
Tensor& polar_out(Tensor& result, const Tensor& abs, const Tensor& angle) {
complex_check_dtype(result, abs, angle);
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(result)
.add_input(abs)
.add_input(angle)
2 changes: 2 additions & 0 deletions aten/src/ATen/native/TensorIterator.cpp
@@ -865,6 +865,7 @@ TensorIterator TensorIterator::nullary_op(Tensor& out) {
TensorIterator TensorIterator::reduce_op(Tensor& out, const Tensor& a) {
TORCH_INTERNAL_ASSERT(out.defined());
return TensorIteratorConfig()
.set_check_mem_overlap(false)
.add_output(out)
.add_input(a)
.resize_outputs(false)
@@ -887,6 +888,7 @@ TensorIterator TensorIterator::reduce_op(Tensor& out1, Tensor& out2, const Tenso
TORCH_CHECK(out1.strides() == out2.strides(), "reduce_op(): expected both outputs to have same strides, but output1 has ", out1.strides(),
" and output2 has ", out2.strides());
return TensorIteratorConfig()
.set_check_mem_overlap(false)
.add_output(out1)
.add_output(out2)
.add_input(a)
12 changes: 11 additions & 1 deletion aten/src/ATen/native/TensorIterator.h
@@ -411,6 +411,16 @@ class CAFFE2_API TensorIteratorConfig final {
/// Construction
TensorIteratorConfig& add_output(const Tensor& output);
TensorIteratorConfig& add_input(const Tensor& input);

// Sets the check_mem_overlap_ flag, which is true by default.
// If true, inputs are checked for partial overlap with the outputs and
// outputs are checked for internal overlap (e.g. broadcasted views). An error
// is raised if unacceptable overlap is detected.
// If you're migrating an existing operator to using TensorIterator, please
// consider if the previous implementation checked memory overlap. If it did
// not, and if the operator is idempotent (for example, Tensor.fill_(0)), then
// checking memory overlap is BC-breaking. Please don't check memory overlap
// in that case.
TensorIteratorConfig& set_check_mem_overlap(bool check_mem_overlap);

// Sets the check_all_same_dtype_ flag, which is true by default
@@ -476,7 +486,7 @@ class CAFFE2_API TensorIteratorConfig final {

c10::optional<DimVector> static_shape_ = c10::nullopt;
c10::optional<std::pair<ScalarType, Device>> static_dtype_and_device_ = c10::nullopt;
bool check_mem_overlap_ = false;
bool check_mem_overlap_ = true;
bool allow_cpu_scalars_ = false;
bool is_reduction_ = false;
bool resize_outputs_ = true;
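
To make the two checks described in the new comment concrete, here is a rough Python-level sketch of the memory layouts they guard against. The expectation that such ops now raise under the default-on check is an assumption about user-visible behavior; exact messages may differ:

import torch

# Internal overlap: an output whose elements alias each other,
# e.g. an expanded view with a zero stride.
out = torch.zeros(1, 3).expand(4, 3)

# Partial overlap: an input and an output sharing some but not all
# of their storage, e.g. shifted slices of the same buffer.
buf = torch.arange(10.)
a, b = buf[:-1], buf[1:]

# With check_mem_overlap now defaulting to true, writing results into
# `out`, or using `a` as input while `b` is the out= tensor
# (e.g. torch.add(a, 1, out=b)), is expected to raise instead of
# silently computing over aliased memory.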
3 changes: 3 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
@@ -118,6 +118,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
"unsupported operation: the input tensors cannot refer to any of the "
"output memory locations. Found overlap in input tensor ", i);
}
at::assert_no_internal_overlap(result);

auto should_skip = [](const Tensor& t) { return t.numel() == 0 && t.dim() == 1; };
for (auto const &tensor : tensors) {
@@ -196,6 +197,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
auto result_stride_bytes = result.stride(dim) * elementSize(result.scalar_type());

auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false) // Already checked above
.resize_outputs(false)
.add_output(result_slice)
.add_input(source_slice)
@@ -222,6 +224,7 @@ Tensor & _cat_out_cpu(Tensor& result, TensorList tensors, int64_t dim) {
auto result_slice = result.narrow(dim, offset, slice_dim_size);

auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false) // Already checked above
.resize_outputs(false)
.add_output(result_slice)
.add_input(tensor)
2 changes: 0 additions & 2 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -403,7 +403,6 @@ Tensor& logical_not_(Tensor& self) {
Tensor& logical_not_out(Tensor& result, const Tensor& self) {
TensorIterator iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.set_check_mem_overlap(true)
.add_output(result)
.add_input(self)
.build();
@@ -421,7 +420,6 @@ Tensor& signbit_out(Tensor& result, const Tensor& self) {
} else {
TensorIterator iter = TensorIteratorConfig()
.check_all_same_dtype(false)
.set_check_mem_overlap(true)
.add_output(result)
.add_input(self)
.build();
2 changes: 2 additions & 0 deletions aten/src/ATen/native/UnfoldBackward.h
@@ -92,6 +92,7 @@ static TensorIterator _make_unfold_backward_iter_over_grad_out(
/* } */

auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.add_output(grad_out_restrided)
@@ -163,6 +164,7 @@ static TensorIterator _make_unfold_backward_iter_over_grad_in(
/* } */

auto iter = TensorIteratorConfig()
.set_check_mem_overlap(false)
.check_all_same_dtype(false)
.resize_outputs(false)
.add_output(grad_out_restrided)
1 change: 0 additions & 1 deletion aten/src/ATen/native/cpu/LerpKernel.cpp
@@ -37,7 +37,6 @@ static void lerp_kernel_tensor(
TORCH_CHECK(self.dtype() == end.dtype(), "expected dtype ", self.dtype(), " for `end` but got dtype ", end.dtype());
TORCH_CHECK(self.dtype() == weights.dtype(), "expected dtype ", self.dtype(), " for `weights` but got dtype ", weights.dtype());
auto iter = TensorIteratorConfig()
.set_check_mem_overlap(true)
.add_output(ret)
.add_input(self)
.add_input(end)
1 change: 1 addition & 0 deletions aten/src/ATen/native/cuda/Indexing.cu
@@ -813,6 +813,7 @@ Tensor& index_select_out_cuda(Tensor& out, const Tensor& self, int64_t dim,

TORCH_CHECK(at::cuda::check_device({out, self, index}),
"Input, output and indices must be on the current device");
at::assert_no_internal_overlap(out);

dim = at::maybe_wrap_dim(dim, self);
TORCH_CHECK(self.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING);
