multi_margin_loss: check weight shape, make contiguous on CPU, add tests #104852

Status: Closed (wants to merge 4 commits)
9 changes: 8 additions & 1 deletion aten/src/ATen/native/LossMulti.h
@@ -46,7 +46,8 @@ namespace {
int64_t& dim,
const int64_t& ndims,
const Tensor& input,
const Tensor& target) {
const Tensor& target,
const c10::optional<Tensor>& weight) {
TORCH_CHECK(
(ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
@@ -64,6 +65,12 @@ namespace {
target.dim() <= 1 && target.numel() == nframe,
"inconsistent target size, expected ", nframe, " but got ",
target.sizes());
if (weight && weight->defined()) {
TORCH_CHECK(
weight->dim() <= 1 && weight->numel() == dim,
"inconsistent weight size, expected ", dim, " but got ",
weight->sizes());
}
}


26 changes: 12 additions & 14 deletions aten/src/ATen/native/LossMultiMargin.cpp
@@ -103,15 +103,15 @@ void multi_margin_loss_out_cpu_template(
const Tensor& target,
int p,
const Scalar& margin,
const Tensor& weight,
const c10::optional<Tensor>& weight,
Author comment: multi_margin_loss_shape_check accepts an optional Tensor because weight is optional on CUDA, so I changed this (and the other call sites) to be consistent.

int64_t reduction) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
int64_t nframe, dim;
const auto ndims = input.dim();

TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");

multi_margin_loss_shape_check(nframe, dim, ndims, input, target);
multi_margin_loss_shape_check(nframe, dim, ndims, input, target, weight);

// produce a scalar output for 1d input
if (reduction == Reduction::None && target.dim() > 0) {
@@ -125,13 +125,17 @@

auto input_contiguous = input.contiguous();
auto target_contiguous = target.contiguous();
Tensor weight_contiguous;
if (weight && weight->defined()) {
weight_contiguous = weight->contiguous();
}
Author comment: For compatibility with multi_margin_loss_cuda_out, which does:

  Tensor weights;
  if (weights_ && weights_->defined()) {
    weights = weights_->contiguous();
  }


AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "multi_margin_loss_cpu_kernel", [&] {
auto input_data = input_contiguous.data_ptr<scalar_t>();
auto target_data = target_contiguous.data_ptr<int64_t>();
auto weight_data =
weight.defined() ? weight.data_ptr<scalar_t>() : nullptr;
weight_contiguous.defined() ? weight_contiguous.data_ptr<scalar_t>() : nullptr;
multi_margin_loss_cpu_kernel<scalar_t>(
output,
input_data,
@@ -219,7 +223,7 @@ void multi_margin_loss_backward_out_cpu_template(

TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");

multi_margin_loss_shape_check(nframe, dim, ndims, input, target);
multi_margin_loss_shape_check(nframe, dim, ndims, input, target, weight);
grad_input.resize_as_(input);
TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");

@@ -262,12 +266,9 @@ Tensor multi_margin_loss_cpu(
const Tensor& input,
const Tensor& target,
const Scalar& p,
const Scalar& margin, const c10::optional<Tensor>& weight_opt,
const Scalar& margin,
const c10::optional<Tensor>& weight,
int64_t reduction) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;

Author comment: Here and elsewhere: this gives multi_margin_loss_shape_check a consistent API (it now accepts an optional weight on both CPU and CUDA).

auto output = at::empty({0}, input.options());
multi_margin_loss_out_cpu_template(
output, input, target, p.toInt(), margin, weight, reduction);
@@ -277,13 +278,10 @@
Tensor& multi_margin_loss_cpu_out(const Tensor& input,
const Tensor& target,
const Scalar& p,
const Scalar& margin, const c10::optional<Tensor>& weight_opt,
const Scalar& margin,
const c10::optional<Tensor>& weight,
int64_t reduction,
Tensor& output) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;

multi_margin_loss_out_cpu_template(
output, input, target, p.toInt(), margin, weight, reduction);
return output;
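A minimal sketch of the case the contiguity change targets, using the public functional wrapper (hypothetical shapes; the non-contiguous weight is only handled correctly once the .contiguous() call from this diff is in place):

import torch
import torch.nn.functional as F

x = torch.randn(6, 3)
y = torch.randint(0, 3, (6,))

# A 1-D weight that is not contiguous, e.g. a column slice of a larger tensor.
w = torch.ones(3, 2)[:, 0]
assert not w.is_contiguous()

# The CPU kernel reads raw data through data_ptr(); copying the weight to a
# contiguous buffer first (as the CUDA path already does) makes this safe.
print(F.multi_margin_loss(x, y, weight=w))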
13 changes: 10 additions & 3 deletions aten/src/ATen/native/cuda/MultiMarginLoss.cu
@@ -132,7 +132,8 @@ void multi_margin_loss_shape_check(
int64_t& dim,
const int64_t& ndims,
const Tensor& input,
const Tensor& target) {
const Tensor& target,
const c10::optional<Tensor>& weight) {
TORCH_CHECK(
(ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0,
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
@@ -150,6 +151,12 @@
target.dim() <= 1 && target.numel() == nframe,
"inconsistent target size, expected ", nframe, " but got ",
target.sizes());
if (weight && weight->defined()) {
TORCH_CHECK(
weight->dim() <= 1 && weight->numel() == dim,
"inconsistent weight size, expected ", dim, " but got ",
weight->sizes());
}
Author comment: From the docs: "a manual rescaling weight given to each class. If given, it has to be a Tensor of size C. Otherwise, it is treated as if having all ones." But target here can be a scalar, so could it be useful to allow scalars for weight as well? (See the sketch after this file's diff.)

https://pytorch.org/docs/stable/generated/torch.nn.MultiMarginLoss.html

}

} // namespace (anonymous)
@@ -163,7 +170,7 @@ Tensor& multi_margin_loss_cuda_out(

TORCH_CHECK(p == 1 || p == 2, "multi_margin_loss: Invalid p, expected 1 or 2 but got ", p);

multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_);
multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_, weights_);

// produce a scalar output for 1d input
if (reduction == Reduction::None && target_.dim() > 0) {
@@ -318,7 +325,7 @@ Tensor& multi_margin_loss_cuda_backward_out(
TORCH_CHECK(p == 1 || p == 2,
"multi_margin_loss_backward: Invalid p, expected 1 or 2 but got ", p);

multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_);
multi_margin_loss_shape_check(nframe, dim, ndims, input_, target_, weights_);
resize_output(grad_input_, input_.sizes());

if (input_.numel() == 0) {
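To make the size-C contract from the comment above concrete, here is a rough usage sketch against the public torch.nn API (hypothetical shapes and values, not part of this diff):

import torch
import torch.nn as nn

C = 4                                    # number of classes, i.e. input.size(1)
loss_fn = nn.MultiMarginLoss(p=2, margin=1.0, weight=torch.ones(C))

x = torch.randn(3, C)                    # (nframe, dim) batch of class scores
y = torch.tensor([0, 2, 1])              # one target class index per sample
print(loss_fn(x, y))                     # scalar, default reduction="mean"

# 1-D input is also accepted: nframe == 1 and the target may be 0-dim,
# which is the scalar-target case the comment above refers to.
x1 = torch.randn(C)
y1 = torch.tensor(2)
print(loss_fn(x1, y1))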
11 changes: 8 additions & 3 deletions torch/_meta_registrations.py
@@ -318,7 +318,7 @@ def meta_index_select_out(self, dim, index, out):
return out.copy_(torch.index_select(self, dim, index))


def _multi_margin_loss_shape_check(ndims, input, target):
def _multi_margin_loss_shape_check(ndims, input, target, weight):
torch._check(
(ndims == 2 and input.size(1) != 0)
or (ndims == 1 and input.size(0) != 0)
@@ -337,6 +337,11 @@ def _multi_margin_loss_shape_check(ndims, input, target):
target.dim() <= 1 and target.numel() == nframe,
lambda: f"inconsistent target size, expected {nframe} but got {target.shape}",
)
if weight is not None:
torch._check(
weight.ndim <= 1 and weight.numel() == dim,
lambda: f"inconsistent weight size, expected {dim} but got {weight.shape}",
)

return nframe, dim

@@ -353,7 +358,7 @@ def meta_multi_margin_loss(
) -> Tensor:
ndims = input.ndim
torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported")
nframe, _ = _multi_margin_loss_shape_check(ndims, input, target)
nframe, _ = _multi_margin_loss_shape_check(ndims, input, target, weight)
if reduction == Reduction.NONE.value and target.ndim > 0:
return input.new_empty(nframe)
else:
@@ -373,7 +378,7 @@ def meta_multi_margin_loss_backward(
) -> Tensor:
ndims = input.ndim
torch._check(p == 1 or p == 2, lambda: "only p == 1 and p == 2 supported")
_multi_margin_loss_shape_check(ndims, input, target)
_multi_margin_loss_shape_check(ndims, input, target, weight)
return input.new_empty(input.shape)


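As a rough illustration of what the meta function computes, a sketch that assumes the registration is reachable through plain device="meta" tensors (shapes hypothetical):

import torch
import torch.nn.functional as F

# Shape-only evaluation: meta tensors carry sizes and dtypes but no data.
inp = torch.empty(8, 4, device="meta")
tgt = torch.empty(8, dtype=torch.long, device="meta")
w = torch.empty(4, device="meta")

out = F.multi_margin_loss(inp, tgt, weight=w, reduction="none")
print(out.shape)   # torch.Size([8]): one loss per sample (nframe)

out = F.multi_margin_loss(inp, tgt, weight=w)
print(out.shape)   # torch.Size([]): scalar for "mean"/"sum" reductions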
15 changes: 13 additions & 2 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1397,6 +1397,7 @@ def sample_inputs_zero_(op_info, device, dtype, requires_grad, **kwargs):
def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
_make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
make_weight = partial(_make_tensor, requires_grad=False)

inputs = (
((), make_target([], low=0, high=1), {}),
@@ -1405,6 +1406,7 @@ def sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
((S, M), make_target([S], low=0, high=M), {"margin": 1.0}),
((S, M), make_target([S], low=0, high=M), {"margin": -3.14}),
((M, S), make_target([M], low=0, high=S), {"weight": None}),
((M, S), make_target([M], low=0, high=S), {"weight": make_weight([S], low=-10., high=10.)}),
Author comment: weight wasn't tested before at all.

((M, S), make_target([M], low=0, high=S), {"reduction": "none"}),
((M, S), make_target([M], low=0, high=S), {"reduction": "mean"}),
((M, S), make_target([M], low=0, high=S), {"reduction": "sum"}),
@@ -1418,6 +1420,7 @@ def reference_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
yield from sample_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs)
_make_tensor = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_target = partial(_make_tensor, dtype=torch.long, requires_grad=False)
make_weight = partial(_make_tensor, requires_grad=False)

inputs = (
((), make_target([], low=0, high=1)),
@@ -1427,10 +1430,11 @@ def reference_inputs_multi_margin_loss(op_info, device, dtype, requires_grad, **kwargs):
)
ps = (1, 2)
margins = (0, 7, -3.14)
weights = (None, make_weight([S], low=-10., high=10.))
reductions = (None, "none", "mean", "sum")

for (input_shape, target), p, margin, reduction in product(inputs, ps, margins, reductions):
kwargs = {"p": p, "margin": margin}
for (input_shape, target), p, margin, weight, reduction in product(inputs, ps, margins, weights, reductions):
kwargs = {"p": p, "margin": margin, "weight": weight}
Author comment: More tests with different combinations of arguments.

if reduction is not None:
kwargs["reduction"] = reduction
yield SampleInput(_make_tensor(input_shape), args=(target,), kwargs=kwargs)
@@ -1454,6 +1458,13 @@ def error_inputs_multi_margin_loss(op, device, **kwargs):
# invalid target dtype
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={}),
error_type=RuntimeError, error_regex='expected scalar type Long but found Float')
# invalid weight
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(())}),
error_type=ValueError, error_regex='weight must be one-dimensional')
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5, 4)}),
error_type=ValueError, error_regex='weight must be one-dimensional')
Author comment: This "one-dimensional" error is raised only on the Python side. See my comment above; maybe it is worth removing this restriction on the Python side and allowing scalars?

yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'weight': make_input(5,)}),
error_type=RuntimeError, error_regex=r'inconsistent weight size, expected 4 but got \[5\]')
Author comment: This now raises an error.

# invalid p
yield ErrorInput(SampleInput(make_input(5, 4), args=(make_input(5,),), kwargs={'p': 3}),
error_type=ValueError, error_regex='only p == 1 and p == 2 supported')
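A sketch of the error paths these new ErrorInputs exercise (hypothetical shapes; the first two errors come from the existing Python-side argument checking, while the last one only appears once the shape check from this PR is in place):

import torch
import torch.nn.functional as F

inp = torch.randn(5, 4)
tgt = torch.randint(0, 4, (5,))

# Valid: a 1-D weight with one entry per class (size C == 4).
F.multi_margin_loss(inp, tgt, weight=torch.ones(4))

# Rejected in Python before reaching the kernel: weight must be 1-D.
try:
    F.multi_margin_loss(inp, tgt, weight=torch.ones(5, 4))
except ValueError as e:
    print(e)            # weight must be one-dimensional

# Rejected by the new shape check: right rank, wrong length.
try:
    F.multi_margin_loss(inp, tgt, weight=torch.ones(5))
except RuntimeError as e:
    print(e)            # inconsistent weight size, expected 4 but got [5]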