multi_margin_loss: check weight shape, make contiguous on CPU, add tests

ghstack-source-id: deae4cd19cdde371632b0f2c00addca770458579
Pull Request resolved: #104852
nkaretnikov committed Jul 9, 2023
1 parent 4d09c7a commit ca2c137
Showing 5 changed files with 51 additions and 23 deletions.
aten/src/ATen/native/LossMulti.h (9 changes: 8 additions & 1 deletion)
@@ -46,7 +46,8 @@ namespace {
int64_t& dim,
const int64_t& ndims,
const Tensor& input,
const Tensor& target) {
const Tensor& target,
const c10::optional<Tensor>& weight) {
TORCH_CHECK(
(ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
@@ -64,6 +65,12 @@ namespace {
target.dim() <= 1 && target.numel() == nframe,
"inconsistent target size, expected ", nframe, " but got ",
target.sizes());
if (weight && weight->defined()) {
TORCH_CHECK(
weight->dim() <= 1 && weight->numel() == dim,
"inconsistent weight size, expected ", dim, " but got ",
weight->sizes());
}
}


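For illustration, a minimal Python sketch (not part of this diff) of what the new TORCH_CHECK enforces; the tensor shapes are arbitrary and the error text is the message added above:

```python
import torch
import torch.nn.functional as F

inp = torch.randn(5, 4)             # nframe = 5, dim = 4 classes
target = torch.randint(0, 4, (5,))  # one class index per row

# A weight vector must have exactly `dim` entries (one per class).
out = F.multi_margin_loss(inp, target, weight=torch.ones(4))

# A mismatched weight now fails the shape check:
# RuntimeError: inconsistent weight size, expected 4 but got [5]
try:
    F.multi_margin_loss(inp, target, weight=torch.ones(5))
except RuntimeError as e:
    print(e)
```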
aten/src/ATen/native/LossMultiMargin.cpp (26 changes: 12 additions & 14 deletions)
@@ -103,15 +103,15 @@ void multi_margin_loss_out_cpu_template(
const Tensor& target,
int p,
const Scalar& margin,
const Tensor& weight,
const c10::optional<Tensor>& weight,
int64_t reduction) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t nframe, dim;
const auto ndims = input.dim();

TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");

multi_margin_loss_shape_check(nframe, dim, ndims, input, target);
multi_margin_loss_shape_check(nframe, dim, ndims, input, target, weight);

// produce a scalar output for 1d input
if (reduction == Reduction::None && target.dim() > 0) {
@@ -125,13 +125,17 @@ void multi_margin_loss_out_cpu_template(

auto input_contiguous = input.contiguous();
auto target_contiguous = target.contiguous();
Tensor weight_contiguous;
if (weight && weight->defined()) {
weight_contiguous = weight->contiguous();
}

AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "multi_margin_loss_cpu_kernel", [&] {
auto input_data = input_contiguous.data_ptr<scalar_t>();
auto target_data = target_contiguous.data_ptr<int64_t>();
auto weight_data =
weight.defined() ? weight.data_ptr<scalar_t>() : nullptr;
weight_contiguous.defined() ? weight_contiguous.data_ptr<scalar_t>() : nullptr;
multi_margin_loss_cpu_kernel<scalar_t>(
output,
input_data,
@@ -219,7 +223,7 @@ void multi_margin_loss_backward_out_cpu_template(

TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");

multi_margin_loss_shape_check(nframe, dim, ndims, input, target);
multi_margin_loss_shape_check(nframe, dim, ndims, input, target, weight);
grad_input.resize_as_(input);
TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");

@@ -262,12 +266,9 @@ Tensor multi_margin_loss_cpu(
const Tensor& input,
const Tensor& target,
const Scalar& p,
const Scalar& margin, const c10::optional<Tensor>& weight_opt,
const Scalar& margin,
const c10::optional<Tensor>& weight,
int64_t reduction) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;

auto output = at::empty({0}, input.options());
multi_margin_loss_out_cpu_template(
output, input, target, p.toInt(), margin, weight, reduction);
@@ -277,13 +278,10 @@ Tensor multi_margin_loss_cpu(
Tensor& multi_margin_loss_cpu_out(const Tensor& input,
const Tensor& target,
const Scalar& p,
const Scalar& margin, const c10::optional<Tensor>& weight_opt,
const Scalar& margin,
const c10::optional<Tensor>& weight,
int64_t reduction,
Tensor& output) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;

multi_margin_loss_out_cpu_template(
output, input, target, p.toInt(), margin, weight, reduction);
return output;
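The CPU kernel reads the weight through a raw data_ptr, so the change above makes the optional weight contiguous before dispatch. A hedged sketch, assuming a build that includes this commit, showing that a non-contiguous weight view gives the same result as its contiguous copy:

```python
import torch
import torch.nn.functional as F

inp = torch.randn(3, 4)
target = torch.randint(0, 4, (3,))

w_contig = torch.rand(4)
# Build a non-contiguous view holding the same values (every other element of a larger buffer).
buf = torch.empty(8)
buf[::2] = w_contig
w_noncontig = buf[::2]
assert not w_noncontig.is_contiguous()

out_a = F.multi_margin_loss(inp, target, weight=w_contig)
out_b = F.multi_margin_loss(inp, target, weight=w_noncontig)
print(torch.allclose(out_a, out_b))  # expected: True, since the kernel now sees a contiguous copy
```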
aten/src/ATen/native/cuda/MultiMarginLoss.cu (13 changes: 10 additions & 3 deletions)
@@ -132,7 +132,8 @@ void multi_margin_loss_shape_check(
int64_t& dim,
const int64_t& ndims,
const Tensor& input,
const Tensor& target) {
const Tensor& target,
const c10::optional<Tensor>& weight) {
TORCH_CHECK(
(ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
@@ -150,6 +151,12 @@ void multi_margin_loss_shape_check(
target.dim() <= 1 && target.numel() == nframe,
"inconsistent target size, expected ", nframe, " but got ",
target.sizes());
if (weight && weight->defined()) {
TORCH_CHECK(
weight->dim() <= 1 && weight->numel() == dim,
"inconsistent weight size, expected ", dim, " but got ",
weight->sizes());
}
}

} // namespace (anonymous)
@@ -163,7 +170,7 @@ Tensor& multi_margin_loss_cuda_out(

TORCH_CHECK(p == 1 || p == 2, "multi_margin_loss: Invalid p, expected 1 or 2 but got ", p);

multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_);
multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_, weights_);

// produce a scalar output for 1d input
if (reduction == Reduction::None && target_.dim() > 0) {
@@ -318,7 +325,7 @@ Tensor& multi_margin_loss_cuda_backward_out(
TORCH_CHECK(p == 1 || p == 2,
"multi_margin_loss_backward: Invalid p, expected 1 or 2 but got ", p);

multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_);
multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_, weights_);
resize_output(grad_input_, input_.sizes());

if (input_.numel() == 0) {
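The CUDA forward and backward paths now run the same shape check as the CPU and meta paths. A small sketch, assuming a CUDA device is available; the expected message mirrors the CPU example above:

```python
import torch
import torch.nn.functional as F

if torch.cuda.is_available():
    inp = torch.randn(5, 4, device="cuda")
    target = torch.randint(0, 4, (5,), device="cuda")
    bad_weight = torch.ones(5, device="cuda")  # expected length is 4 (the class dimension)
    try:
        F.multi_margin_loss(inp, target, weight=bad_weight)
    except RuntimeError as e:
        print(e)  # inconsistent weight size, expected 4 but got [5]
```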
torch/_meta_registrations.py (11 changes: 8 additions & 3 deletions)
@@ -318,7 +318,7 @@ def meta_index_select_out(self, dim, index, out):
return out.copy_(torch.index_select(self, dim, index))


def _multi_margin_loss_shape_check(ndims, input, target):
def _multi_margin_loss_shape_check(ndims, input, target, weight):
torch._check(
(ndims == 2 and input.size(1) != 0)
or (ndims == 1 and input.size(0) != 0)
@@ -337,6 +337,11 @@ def _multi_margin_loss_shape_check(ndims, input, target):
target.dim() <= 1 and target.numel() == nframe,
lambda: f"inconsistent target size, expected {nframe} but got {target.shape}",
)
if weight is not None:
torch._check(
weight.ndim <= 1 and weight.numel() == dim,
lambda: f"inconsistent weight size, expected {dim} but got {weight.shape}",
)

return nframe, dim

@@ -353,7 +358,7 @@ def meta_multi_margin_loss(
) -> Tensor:
ndims = input.ndim
torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported")
nframe, _ = _multi_margin_loss_shape_check(ndims, input, target)
nframe, _ = _multi_margin_loss_shape_check(ndims, input, target, weight)
if reduction == Reduction.NONE.value and target.ndim > 0:
return input.new_empty(nframe)
else:
@@ -373,7 +378,7 @@ def meta_multi_margin_loss_backward(
) -> Tensor:
ndims = input.ndim
torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported")
_multi_margin_loss_shape_check(ndims, input, target)
_multi_margin_loss_shape_check(ndims, input, target, weight)
return input.new_empty(input.shape)


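The meta registration lets shape inference (for example under FakeTensor tracing) validate the weight and compute output shapes without real data. A sketch using meta tensors; treat it as illustrative rather than a guaranteed dispatch path:

```python
import torch
import torch.nn.functional as F

inp = torch.empty(5, 4, device="meta")
target = torch.empty(5, dtype=torch.long, device="meta")
weight = torch.empty(4, device="meta")

# Shapes are inferred without touching data; reduction="none" keeps one value per sample.
out = F.multi_margin_loss(inp, target, weight=weight, reduction="none")
print(out.shape, out.device)  # expected: torch.Size([5]) meta

# A wrong-sized weight should fail the torch._check added above.
try:
    F.multi_margin_loss(inp, target, weight=torch.empty(5, device="meta"))
except RuntimeError as e:
    print(e)  # inconsistent weight size, expected 4 but got torch.Size([5])
```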
torch/testing/_internal/common_methods_invocations.py (15 changes: 13 additions & 2 deletions)
@@ -1397,6 +1397,7 @@ def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs):
def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
_make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
make_weight = partial(_make_tensor, requires_grad=False)

inputs = (
((), make_target([], low=0, high=1), {}),
@@ -1405,6 +1406,7 @@ def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
((S, M), make_target([S], low=0, high=M), {"margin": 1.0}),
((S, M), make_target([S], low=0, high=M), {"margin": -3.14}),
((M, S), make_target([M], low=0, high=S), {"weight": None}),
((M, S), make_target([M], low=0, high=S), {"weight": make_weight([S], low=-10., high=10.)}),
((M, S), make_target([M], low=0, high=S), {"reduction": "none"}),
((M, S), make_target([M], low=0, high=S), {"reduction": "mean"}),
((M, S), make_target([M], low=0, high=S), {"reduction": "sum"}),
@@ -1418,6 +1420,7 @@ def reference_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
yield from sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs)
_make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
make_weight = partial(_make_tensor, requires_grad=False)

inputs = (
((), make_target([], low=0, high=1)),
@@ -1427,10 +1430,11 @@ def reference_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
)
ps = (1, 2)
margins = (0, 7, -3.14)
weights = (None, make_weight([S], low=-10., high=10.))
reductions = (None, "none", "mean", "sum")

for (input_shape, target), p, margin, reduction in product(inputs, ps, margins, reductions):
kwargs = {"p": p, "margin": margin}
for (input_shape, target), p, margin, weight, reduction in product(inputs, ps, margins, weights, reductions):
kwargs = {"p": p, "margin": margin, "weight": weight}
if reduction is not None:
kwargs["reduction"] = reduction
yield SampleInput(_make_tensor(input_shape), args=(target,), kwargs=kwargs)
@@ -1454,6 +1458,13 @@ def error_inputs_multi_margin_loss(op, device, **kwargs):
# invalid target dtype
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={}),
error_type=RuntimeError, error_regex='expected scalar type Long but found Float')
# invalid weight
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(())}),
error_type=ValueError, error_regex='weight must be one-dimensional')
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5, 4)}),
error_type=ValueError, error_regex='weight must be one-dimensional')
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5,)}),
error_type=RuntimeError, error_regex=r'inconsistent weight size, expected 4 but got \[5\]')
# invalid p
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'p': 3}),
error_type=ValueError, error_regex='only p == 1 and p == 2 supported')
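The new SampleInput/ErrorInput entries above are consumed by the generic OpInfo test harness. The standalone sketch below reproduces the same three weight error cases directly; it is a hand-rolled check, not the harness itself:

```python
import torch
import torch.nn.functional as F

inp = torch.randn(5, 4)
target = torch.randint(0, 4, (5,))

# Mirrors the new error_inputs cases above.
cases = [
    (torch.randn(()), ValueError, "weight must be one-dimensional"),    # 0-dim weight
    (torch.randn(5, 4), ValueError, "weight must be one-dimensional"),  # 2-dim weight
    (torch.randn(5), RuntimeError, "inconsistent weight size"),         # right ndim, wrong length
]
for weight, err_type, msg in cases:
    try:
        F.multi_margin_loss(inp, target, weight=weight)
    except err_type as e:
        assert msg in str(e), str(e)
    else:
        raise AssertionError(f"expected an error for weight of shape {tuple(weight.shape)}")
```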
