Multi label margin loss #50007

Closed
wants to merge 35 commits into from
Changes from all commits
35 commits
a1885d1
working on supplying a good test
v0dro Oct 23, 2020
01f5073
Merge branch 'master' of github.com:pytorch/pytorch into multi-label-…
v0dro Oct 28, 2020
8e2a16f
Change MultiLabelMarginLoss to accept 0-dim batch sizes
v0dro Oct 28, 2020
0c36491
move shape checking to separate function and reduce code duplication
v0dro Oct 28, 2020
02bcff1
accept 0 dim tensor in MultiLabelMargin
v0dro Oct 29, 2020
f6626a5
update NN modules
v0dro Oct 29, 2020
89ede33
Merge branch 'master' of github.com:pytorch/pytorch into multi-label-…
v0dro Nov 13, 2020
8116fd0
Merge branch 'master' of github.com:pytorch/pytorch into multi-label-…
v0dro Nov 20, 2020
16036d5
Merge branch 'master' of github.com:pytorch/pytorch into multi-label-…
v0dro Nov 27, 2020
6cda8b0
remove third party apps from previous commit
v0dro Nov 27, 2020
c7dd5df
update multi label margin loss to work with JIT
v0dro Nov 27, 2020
6ebf017
remove include iostream
v0dro Nov 27, 2020
830bde7
some extra tests
v0dro Dec 3, 2020
909c27f
Merge branch 'master' of github.com:pytorch/pytorch into multi-label-…
v0dro Dec 3, 2020
02b587a
update test for MultiLabel loss checks
v0dro Dec 3, 2020
7ab03e1
change checks for GPU
v0dro Dec 3, 2020
1339074
update CUDA
v0dro Dec 4, 2020
51526c5
update test for flake8
v0dro Dec 4, 2020
1124e64
Merge branch 'master' of github.com:pytorch/pytorch into multi-label-…
v0dro Dec 18, 2020
d44d86e
update testing and conditionals for testing
v0dro Dec 18, 2020
0ab92ad
update tests
v0dro Dec 18, 2020
1966220
Merge branch 'master' of github.com:pytorch/pytorch into multi-label-…
v0dro Dec 18, 2020
5762783
change include headers to use angle brackets
v0dro Dec 18, 2020
b562139
revert use of angle brackets
v0dro Dec 18, 2020
3be59d9
update CUDA tests
v0dro Dec 19, 2020
947aaf6
multi margin loss shape check for CUDA
v0dro Dec 19, 2020
19aa539
Merge branch 'master' of github.com:pytorch/pytorch into multi-label-…
v0dro Jan 3, 2021
e1a7c2a
Merge branch 'multi-label-margin-loss' of github.com:v0dro/pytorch in…
v0dro Jan 3, 2021
4ef5707
Merge branch 'master' of github.com:pytorch/pytorch into multi-label-…
v0dro Jan 7, 2021
76aa2ab
return if input numel is 0
v0dro Jan 7, 2021
4696ada
Use standard include method for LossMulti.h
v0dro Jan 15, 2021
57edc91
Use standard include method for LossMulti.h
v0dro Jan 15, 2021
c8ea9fc
Update LossMultiLabelMargin.cpp
v0dro Jan 15, 2021
12e078a
Update LossMultiMargin.cpp
v0dro Jan 15, 2021
f4072f4
Merge branch 'master' of github.com:pytorch/pytorch into multi-label-…
v0dro Jan 15, 2021
72 changes: 72 additions & 0 deletions aten/src/ATen/native/LossMulti.h
@@ -0,0 +1,72 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/AccumulateType.h>

#pragma once

namespace at { namespace native {
namespace {
static void multilabel_margin_loss_shape_check(
int64_t& nframe,
int64_t& dim,
const int64_t& ndims,
TensorArg& target_arg,
const Tensor& input,
const Tensor& target) {
bool valid_inputs = (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0;
TORCH_CHECK(
valid_inputs,
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
input.sizes());

if (ndims <= 1) {
nframe = 1;
dim = ndims == 0 ? 1 : input.size(0);
TORCH_CHECK(
valid_inputs && target.dim() <= 1 && target.numel() == dim,
"inconsistent size ",
target.sizes(),
" for ",
target_arg);
} else {
nframe = input.size(0);
dim = input.size(1);
TORCH_CHECK(
valid_inputs && target.dim() == 2 && target.size(0) == nframe &&
target.size(1) == dim,
"inconsistent size ",
target.sizes(),
" for ",
target_arg);
}
}

static void multi_margin_loss_shape_check(
int64_t& nframe,
int64_t& dim,
const int64_t& ndims,
TensorArg& target_arg,
const Tensor& input,
const Tensor& target) {
bool valid_inputs = (ndims == 2 && input.size(1) != 0) || (ndims == 1 && input.size(0) != 0) || ndims == 0;
if (ndims <= 1) {
nframe = 1;
dim = ndims == 0 ? 1 : input.size(0);
} else {
nframe = input.size(0);
dim = input.size(1);
}

TORCH_CHECK(
valid_inputs,
"Expected non-empty vector or matrix with optional 0-dim batch size, but got: ",
input.sizes());
TORCH_CHECK(
valid_inputs && target.dim() <= 1 && target.numel() == nframe,
"inconsistent target size, got: ",
target.sizes());
}


} // anonymous namespace
}} // namespace at::native
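
For orientation, a Python-level sketch of the shape combinations the two helpers above are written to accept (a hedged illustration via the public modules; tensor values are arbitrary):

import torch
import torch.nn as nn

ml_loss = nn.MultiLabelMarginLoss()

# multilabel_margin_loss_shape_check, 2-D case: input (N, C) and target (N, C);
# targets are class indices, padded with -1.
ml_loss(torch.randn(3, 4), torch.tensor([[0, 2, -1, -1],
                                         [1, -1, -1, -1],
                                         [3, 0, 1, -1]]))

# 1-D case: input (C,) and target (C,).
ml_loss(torch.randn(4), torch.tensor([1, 3, -1, -1]))

# multi_margin_loss_shape_check instead expects a 1-D target with one class index per row.
mm_loss = nn.MultiMarginLoss()
mm_loss(torch.randn(3, 4), torch.tensor([0, 2, 1]))

# A mismatched target (e.g. a wrong second dimension) trips the
# "inconsistent size ... for target" check above.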
99 changes: 33 additions & 66 deletions aten/src/ATen/native/LossMultiLabelMargin.cpp
@@ -2,6 +2,7 @@
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/LossMulti.h>

namespace at {
namespace native {
@@ -39,6 +40,7 @@ inline scalar_t multilabel_margin_loss_forward_inner_sum_cpu(
}
}
}

return sum;
}

@@ -100,34 +102,32 @@ static void multilabel_margin_loss_forward_out_cpu_template(
Tensor& is_target,
int64_t reduction) {
auto target_arg = TensorArg(target, "target", 2);

const auto ndims = input.dim();

TORCH_CHECK(
input.numel() > 0 && ndims <= 2,
"non-empty vector or matrix expected, got size: ",
input.sizes());

int64_t nframe, dim;
const int64_t ndims = input.dim();
if (ndims <= 1) {
nframe = 1;
dim = ndims == 0 ? 1 : input.size(0);
TORCH_CHECK(
target.numel() > 0 && target.dim() <= 1 && target.numel() == dim,
"inconsistent size ",
target.sizes(),
" for ",
target_arg);
} else {
}
else {
nframe = input.size(0);
dim = input.size(1);
TORCH_CHECK(
target.numel() > 0 && target.dim() == 2 && target.size(0) == nframe &&
target.size(1) == dim,
"inconsistent size ",
target.sizes(),
" for ",
target_arg);
}
multilabel_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target);

// special case target.dim() <= 1: produce scalar output for scalar inputs
// even if reduction == Reduction::None
if (reduction != Reduction::None || target.dim() <= 1) {
output.resize_({});
} else {
output.resize_({nframe});
}

is_target.resize_as_(target);
TORCH_CHECK(is_target.is_contiguous(), "is_target must be contiguous");
is_target.zero_();

if (input.numel() == 0) {
return;
}

TORCH_CHECK(
@@ -138,18 +138,6 @@ static void multilabel_margin_loss_forward_out_cpu_template(
auto input_contiguous = input.contiguous();
auto target_contiguous = target.contiguous();

is_target.resize_as_(target);
TORCH_CHECK(is_target.is_contiguous(), "is_target must be contiguous");
is_target.zero_();

// special case target.dim() <= 1: produce scalar output for scalar inputs
// even if reduction == Reduction::None
if (reduction != Reduction::None || target.dim() <= 1) {
output.resize_({});
} else {
output.resize_({nframe});
}

AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "multilabel_margin_loss_forward_out_frame", [&] {
multilabel_margin_loss_forward_out_frame<scalar_t>(
@@ -232,39 +220,22 @@ static void multilabel_margin_loss_backward_out_cpu_template(
const Tensor& target,
int64_t reduction,
const Tensor& is_target) {
int64_t nframe, dim;
CheckedFrom c = "multilabel_margin_loss_backward_cpu_template";
auto target_arg = TensorArg(target, "target", 3);
auto is_target_arg = TensorArg(is_target, "is_target", 5);
const int64_t ndims = input.dim();

const auto ndims = input.dim();

TORCH_CHECK(
input.numel() > 0 && ndims <= 2,
"non-empty vector or matrix expected, got size: ",
input.sizes());

int64_t nframe, dim;
if (ndims <= 1) {
nframe = 1;
dim = ndims == 0 ? 1 : input.size(0);
TORCH_CHECK(
target.numel() > 0 && target.dim() <= 1 && target.numel() == dim,
"inconsistent size ",
target.sizes(),
" for ",
target_arg);
} else {
nframe = input.size(0);
dim = input.size(1);
TORCH_CHECK(
target.numel() > 0 && target.dim() == 2 && target.size(0) == nframe &&
target.size(1) == dim,
"inconsistent size ",
target.sizes(),
" for ",
target_arg);
}
multilabel_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target);
checkSameSize(c, target_arg, is_target_arg);

grad_input.resize_as_(input);
if (grad_input.numel() == 0) {
return;
}

TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
grad_input.zero_();

TORCH_CHECK(
target.min().item<int64_t>() >= -1, target_arg, " is out of range");
@@ -275,10 +246,6 @@ static void multilabel_margin_loss_backward_out_cpu_template(
auto target_contiguous = target.contiguous();
auto is_target_contiguous = is_target.contiguous();

grad_input.resize_as_(input);
TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
grad_input.zero_();

AT_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "multilabel_margin_loss_backward_out_frame", [&] {
multilabel_margin_loss_backward_out_frame<scalar_t>(
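
The backward template above now resizes grad_input to match the input and returns early when it is empty. A hedged Python-level sketch of the end-to-end effect, assuming the usual autograd path dispatches to this kernel:

import torch
import torch.nn as nn

input = torch.randn(0, 5, requires_grad=True)
target = torch.empty(0, 5, dtype=torch.long)

out = nn.MultiLabelMarginLoss(reduction='sum')(input, target)
out.backward()
print(input.grad.shape)  # expected: torch.Size([0, 5]), an empty gradient rather than an error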
51 changes: 15 additions & 36 deletions aten/src/ATen/native/LossMultiMargin.cpp
@@ -1,6 +1,7 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/AccumulateType.h>
#include <ATen/native/LossMulti.h>

namespace at {
namespace native {
@@ -93,34 +94,23 @@ void multi_margin_loss_out_cpu_template(
Scalar margin,
const Tensor& weight,
int64_t reduction) {
int64_t nframe, dim;
const auto ndims = input.dim();
TORCH_CHECK(
input.numel() > 0 && ndims <= 2,
"non-empty vector or matrix expected, got size: ",
input.sizes());
auto target_arg = TensorArg(target, "target", 2);

TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");

int64_t nframe, dim;
if (ndims <= 1) {
nframe = 1;
dim = ndims == 0 ? 1 : input.size(0);
} else {
nframe = input.size(0);
dim = input.size(1);
}

TORCH_CHECK(
target.numel() > 0 && target.dim() <= 1 && target.numel() == nframe,
"inconsistent target size, got: ",
target.sizes());
multi_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target);

// produce a scalar output for 1d input
if (reduction == Reduction::None && target.dim() > 0) {
output.resize_({nframe});
} else {
output.resize_({});
}
if (input.numel() == 0) {
return;
}

auto input_contiguous = input.contiguous();
auto target_contiguous = target.contiguous();
@@ -212,31 +202,20 @@ void multi_margin_loss_backward_out_cpu_template(
Scalar margin,
const Tensor& weight,
int64_t reduction) {
int64_t nframe, dim;
auto target_arg = TensorArg(target, "target", 2);
const auto ndims = input.dim();
TORCH_CHECK(
input.numel() > 0 && ndims <= 2,
"non-empty vector or matrix expected, got size: ",
input.sizes());


TORCH_CHECK(p == 1 || p == 2, "only p == 1 and p == 2 supported");

int64_t nframe, dim;
if (ndims <= 1) {
nframe = 1;
dim = ndims == 0 ? 1 : input.size(0);
} else {
nframe = input.size(0);
dim = input.size(1);
}

TORCH_CHECK(
target.numel() > 0 && target.dim() <= 1 && target.numel() == nframe,
"inconsistent target size, got: ",
target.sizes());

multi_margin_loss_shape_check(nframe, dim, ndims, target_arg, input, target);
grad_input.resize_as_(input);
TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");

if (input.numel() == 0) {
return;
}

auto input_contiguous = input.contiguous();
auto target_contiguous = target.contiguous();
auto weight_contiguous = weight.contiguous();
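
LossMultiMargin.cpp now routes through the same shared shape check, so the 0-dim batch case for MultiMarginLoss follows the same pattern (again a hedged sketch via the public API):

import torch
import torch.nn as nn

input = torch.randn(0, 5)                  # empty batch, 5 classes
target = torch.empty(0, dtype=torch.long)  # one class index per (absent) sample

out = nn.MultiMarginLoss(reduction='none')(input, target)
print(out.shape)  # expected: torch.Size([0])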