105 changes: 0 additions & 105 deletions aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp
@@ -741,111 +741,6 @@ std::tuple<Tensor,Tensor> _thnn_nll_loss_forward(const Tensor & self, const Tens
}
return std::tuple<Tensor, Tensor>(output, total_weight);
}
Tensor & _thnn_nll_loss_backward_out(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, const Tensor & total_weight, Tensor & grad_input) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;

const OptionalDeviceGuard device_guard(device_of(self));
auto dispatch_scalar_type = infer_scalar_type(self);

switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, ScalarType::Long);
auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss_backward_out", true, DeviceType::CUDA, dispatch_scalar_type);
auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 7, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaDoubleClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
break;
}
case ScalarType::Float: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, ScalarType::Long);
auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss_backward_out", true, DeviceType::CUDA, dispatch_scalar_type);
auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 7, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
break;
}
case ScalarType::Half: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, ScalarType::Long);
auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss_backward_out", true, DeviceType::CUDA, dispatch_scalar_type);
auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 7, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaHalfClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
break;
}
case ScalarType::BFloat16: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, ScalarType::Long);
auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss_backward_out", true, DeviceType::CUDA, dispatch_scalar_type);
auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
auto grad_input_ = checked_dense_tensor_unwrap(grad_input, "grad_input", 7, "_thnn_nll_loss_backward_out", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaBFloat16ClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
break;
}
default:
AT_ERROR("_thnn_nll_loss_backward_out not supported on CUDAType for ", dispatch_scalar_type);
}
return grad_input;
}
Tensor _thnn_nll_loss_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, const Tensor & total_weight) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;

const OptionalDeviceGuard device_guard(device_of(self));
auto dispatch_scalar_type = infer_scalar_type(self);
auto grad_input_ = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(c10::Storage(c10::Storage::use_byte_size_t(), 0, allocator(), true),DispatchKey::CUDA, scalarTypeToTypeMeta(dispatch_scalar_type)).release();
auto grad_input = Tensor(c10::intrusive_ptr<TensorImpl, UndefinedTensorImpl>::reclaim(grad_input_));
switch (dispatch_scalar_type) {
case ScalarType::Double: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss_backward", false, DeviceType::CUDA, ScalarType::Long);
auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss_backward", true, DeviceType::CUDA, dispatch_scalar_type);
auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss_backward", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaDoubleClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
break;
}
case ScalarType::Float: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss_backward", false, DeviceType::CUDA, ScalarType::Long);
auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss_backward", true, DeviceType::CUDA, dispatch_scalar_type);
auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss_backward", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
break;
}
case ScalarType::Half: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss_backward", false, DeviceType::CUDA, ScalarType::Long);
auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss_backward", true, DeviceType::CUDA, dispatch_scalar_type);
auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss_backward", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaHalfClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
break;
}
case ScalarType::BFloat16: {
auto grad_output_ = checked_dense_tensor_unwrap(grad_output, "grad_output", 1, "_thnn_nll_loss_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto self_ = checked_dense_tensor_unwrap(self, "self", 2, "_thnn_nll_loss_backward", false, DeviceType::CUDA, dispatch_scalar_type);
auto target_ = checked_dense_tensor_unwrap(target, "target", 3, "_thnn_nll_loss_backward", false, DeviceType::CUDA, ScalarType::Long);
auto weight_ = checked_dense_tensor_unwrap(weight, "weight", 4, "_thnn_nll_loss_backward", true, DeviceType::CUDA, dispatch_scalar_type);
auto total_weight_ = checked_dense_tensor_unwrap(total_weight, "total_weight", 7, "_thnn_nll_loss_backward", false, DeviceType::CUDA, dispatch_scalar_type);
THNN_CudaBFloat16ClassNLLCriterion_updateGradInput(globalContext().getTHCState(), self_, target_, grad_output_, grad_input_, reduction, weight_ ? weight_ : NULL, total_weight_, ignore_index);
break;
}
default:
AT_ERROR("_thnn_nll_loss_backward not supported on CUDAType for ", dispatch_scalar_type);
}
return grad_input;
}
std::tuple<Tensor &,Tensor &> _thnn_nll_loss2d_forward_out(const Tensor & self, const Tensor & target, const c10::optional<Tensor>& weight_opt, int64_t reduction, int64_t ignore_index, Tensor & output, Tensor & total_weight) {
// See [Note: hacky wrapper removal for optional tensor]
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
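Both the legacy code above and its replacement below rely on the "[Note: hacky wrapper removal for optional tensor]" idiom: at::borrow_from_optional_tensor resolves a c10::optional<Tensor> into a c10::MaybeOwned<Tensor> that borrows the tensor when one is present and otherwise owns a default-constructed (undefined) Tensor, so the rest of the function can work with a plain const Tensor&. A minimal sketch of the idiom (the scale_by_optional helper is hypothetical, for illustration only):

#include <ATen/ATen.h>

at::Tensor scale_by_optional(const at::Tensor& x,
                             const c10::optional<at::Tensor>& w_opt) {
  // Borrows w_opt's tensor when present; owns an undefined Tensor otherwise.
  c10::MaybeOwned<at::Tensor> w_owned = at::borrow_from_optional_tensor(w_opt);
  const at::Tensor& w = *w_owned;
  return w.defined() ? x * w : x;  // an undefined weight means "no scaling"
}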
242 changes: 242 additions & 0 deletions aten/src/ATen/native/cuda/Loss.cu
@@ -3,6 +3,8 @@
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/native/cuda/Loops.cuh>

constexpr float EPSILON = 1e-12;
Expand Down Expand Up @@ -153,4 +155,244 @@ Tensor& binary_cross_entropy_backward_out_cuda(const Tensor& grad, const Tensor&
return grad_input;
}

// -----------------------------------
// nll_loss backward
// -----------------------------------
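// Sketch of the gradient the kernels below compute: for input of shape (N, C),
// target t, and optional class weights w (grad_input is zeroed by the callers):
//   reduction == None:  grad_input[i][t_i] = -w[t_i] * grad_output[i]
//   reduction == Sum:   grad_input[i][t_i] = -w[t_i] * grad_output
//   reduction == Mean:  grad_input[i][t_i] = -w[t_i] * grad_output / total_weight
// All other entries stay zero, and rows with t_i == ignore_index are skipped.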
namespace {

const int NLL_LOSS_THREADS = 32;

template <typename scalar_t>
__global__ void nll_loss_backward_no_reduce_cuda_kernel(
int batch_size,
int64_t *target,
PackedTensorAccessor64<scalar_t, 1> grad_output,
PackedTensorAccessor64<scalar_t, 2> grad_input,
scalar_t *weights,
int n_classes,
int ignore_index) {

CUDA_KERNEL_LOOP(index, batch_size) {
int cur_target = target[index];
if (cur_target == ignore_index) {
continue;
}
CUDA_KERNEL_ASSERT(cur_target >= 0 && cur_target < n_classes);
scalar_t weight = weights != nullptr ? weights[cur_target] : static_cast<scalar_t>(1);
grad_input[index][cur_target] = -weight * grad_output[index];
}
}

template <typename scalar_t>
__global__ void nll_loss_backward_reduce_cuda_kernel_1d(
scalar_t *grad_input,
scalar_t *grad_output,
scalar_t *weights,
int64_t *target,
scalar_t *total_weight,
bool size_average,
int n_classes,
int64_t ignore_index
) {
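// Launched with a single thread (<<<1, 1>>>): after reduction, grad_output and
// total_weight are one-element tensors.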
if (*total_weight <= 0) {
return;
}
scalar_t norm = size_average ? (static_cast<scalar_t>(1) / *total_weight) : static_cast<scalar_t>(1);
int t = static_cast<int>(*target);
if (t != static_cast<int>(ignore_index)) {
CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
grad_input[t] = -(weights != nullptr ? weights[t] : static_cast<scalar_t>(1)) * norm * grad_output[0];
}
}

template <typename scalar_t>
__global__ void nll_loss_backward_reduce_cuda_kernel_2d(
scalar_t* grad_input,
scalar_t* grad_output,
int64_t* target,
scalar_t* weights,
scalar_t* total_weight,
bool size_average,
int nframe,
int ndim,
int n_classes,
int64_t ignore_index) {
if (*total_weight <= 0) {
return;
}
scalar_t norm = size_average ? (static_cast<scalar_t>(1) / *total_weight) : static_cast<scalar_t>(1);

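// A single block of NLL_LOSS_THREADS threads strides over the batch rows.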
for (int i = threadIdx.x; i < nframe; i += NLL_LOSS_THREADS) {
int t = target[i];
if (t != static_cast<int>(ignore_index)) {
CUDA_KERNEL_ASSERT(t >= 0 && t < n_classes);
grad_input[i * ndim + t] = -(weights != nullptr ? weights[t] : static_cast<scalar_t>(1)) * norm * grad_output[0];
}
}
}

void nll_loss_backward_out_cuda_template(
Tensor& grad_input,
const Tensor& grad_output,
const Tensor& input,
const Tensor& target,
const Tensor& total_weight,
const c10::optional<Tensor>& weight_opt,
int64_t reduction,
int64_t ignore_index) {
c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
const Tensor& weight = *weight_maybe_owned;

TORCH_CHECK(
target.dim() == 1,
"1D target tensor expected, multi-target not supported");
int64_t n_dims = input.dim();
TORCH_CHECK(
n_dims > 0 && n_dims <= 2, "input tensor should be 1D or 2D");

int64_t n_classes = input.size(-1);
int64_t batch_size = n_dims == 1 ? 1 : input.size(0);
int64_t num_targets = target.size(0);

TORCH_CHECK(
batch_size == num_targets,
"size mismatch (got input: ",
input.sizes(),
", target: ",
target.sizes(),
")");
TORCH_CHECK(
!weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
"weight tensor should be defined either for all or no classes");

TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
auto weight_ = weight.defined() ? weight.contiguous() : weight;

if (reduction == at::Reduction::None && n_dims == 2) {
check_dim_size(grad_output, 1, 0, batch_size);
if (batch_size == 0) {
// This guards against unnecessary work and launching a CUDA kernel with zero blocks.
return;
}
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"nll_loss_backward_no_reduce_cuda_kernel",
[&] {
nll_loss_backward_no_reduce_cuda_kernel<scalar_t>
<<<at::cuda::detail::GET_BLOCKS(batch_size),
at::cuda::detail::CUDA_NUM_THREADS,
0,
at::cuda::getCurrentCUDAStream()>>>(
batch_size,
target.data_ptr<int64_t>(),
grad_output.packed_accessor64<scalar_t, 1>(),
grad_input.packed_accessor64<scalar_t, 2>(),
weight.defined() ? weight_.data_ptr<scalar_t>() : nullptr,
n_classes,
ignore_index);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});

return;
}

auto target_ = target.contiguous();
TORCH_CHECK(grad_output.numel() == 1, "grad_output should be a single-element tensor when reducing");

if (n_dims == 1) {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"nll_loss_backward_reduce_cuda_kernel_1d",
[&] {
nll_loss_backward_reduce_cuda_kernel_1d<scalar_t>
<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input.data_ptr<scalar_t>(),
grad_output.data_ptr<scalar_t>(),
weight.defined() ? weight_.data_ptr<scalar_t>() : nullptr,
target_.data_ptr<int64_t>(),
total_weight.data_ptr<scalar_t>(),
reduction == at::Reduction::Mean,
n_classes,
ignore_index
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
} else {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"nll_loss_backward_reduce_cuda_kernel_2d",
[&] {
nll_loss_backward_reduce_cuda_kernel_2d<scalar_t>
<<<1, NLL_LOSS_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
grad_input.data_ptr<scalar_t>(),
grad_output.data_ptr<scalar_t>(),
target_.data_ptr<int64_t>(),
weight.defined() ? weight_.data_ptr<scalar_t>() : nullptr,
total_weight.data_ptr<scalar_t>(),
reduction == at::Reduction::Mean,
input.size(0),
input.size(1),
n_classes,
ignore_index);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
}
} // namespace

Tensor& nll_loss_backward_out_cuda(const Tensor& grad_output,
const Tensor& self,
const Tensor& target,
const c10::optional<Tensor>& weight_opt,
int64_t reduction,
int64_t ignore_index,
const Tensor& total_weight,
Tensor& grad_input) {

grad_input.resize_as_(self);
grad_input.zero_();
nll_loss_backward_out_cuda_template(
grad_input,
grad_output,
self,
target,
total_weight,
weight_opt,
reduction,
ignore_index);
return grad_input;
}

Tensor nll_loss_backward_cuda(
const Tensor& grad_output,
const Tensor& self,
const Tensor& target, const c10::optional<Tensor>& weight_opt,
int64_t reduction,
int64_t ignore_index,
const Tensor& total_weight) {

auto grad_input = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
nll_loss_backward_out_cuda_template(
grad_input,
grad_output,
self,
target,
total_weight,
weight_opt,
reduction,
ignore_index);
return grad_input;
}

}} // namespace at::native
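For an end-to-end sanity check, the new kernels can be reached through the regular ATen ops once the dispatch entries below point at them. A minimal sketch, assuming a CUDA-enabled ATen build (the op signatures are those of the generated C++ API and may differ across versions):

#include <ATen/ATen.h>
#include <iostream>
#include <tuple>

int main() {
  auto x = at::randn({4, 10}, at::kCUDA).log_softmax(1);  // log-probabilities
  auto t = at::randint(0, 10, {4}, at::device(at::kCUDA).dtype(at::kLong));
  // The forward returns (loss, total_weight); the backward needs total_weight.
  auto fwd = at::nll_loss_forward(x, t, /*weight=*/{}, at::Reduction::Mean,
                                  /*ignore_index=*/-100);
  auto grad = at::nll_loss_backward(
      at::ones({}, x.options()),  // grad_output is a scalar under Mean reduction
      x, t, /*weight=*/{}, at::Reduction::Mean, /*ignore_index=*/-100,
      std::get<1>(fwd));
  // grad has x's shape; row i is nonzero only at column t[i].
  std::cout << grad.sizes() << std::endl;
  return 0;
}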
4 changes: 2 additions & 2 deletions aten/src/ATen/native/native_functions.yaml
@@ -8073,13 +8073,13 @@
python_module: nn
dispatch:
CPU: nll_loss_backward_out_cpu
CUDA: legacy::cuda::_thnn_nll_loss_backward_out
CUDA: nll_loss_backward_out_cuda

- func: nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, int ignore_index, Tensor total_weight) -> Tensor
python_module: nn
dispatch:
CPU: nll_loss_backward_cpu
CUDA: legacy::cuda::_thnn_nll_loss_backward
CUDA: nll_loss_backward_cuda

- func: nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, int ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)
python_module: nn