aten/src/ATen/native/cuda/Loss.cu — 86 changes: 43 additions & 43 deletions

This diff const-qualifies the read-only pointer parameters of the CUDA NLL-loss kernels and updates their launch sites to fetch inputs with const_data_ptr<T>() and outputs with mutable_data_ptr<T>() instead of the unqualified data_ptr<T>().

--- a/aten/src/ATen/native/cuda/Loss.cu
+++ b/aten/src/ATen/native/cuda/Loss.cu
@@ -163,9 +163,9 @@ template <typename scalar_t, typename index_t>
 __global__ void nll_loss_forward_no_reduce_cuda_kernel(
     int64_t batch_size,
     PackedTensorAccessor64<scalar_t, 2> input,
-    index_t* target,
+    const index_t* target,
     scalar_t* output,
-    scalar_t* weights,
+    const scalar_t* weights,
     int64_t n_classes,
     int64_t ignore_index) {
   CUDA_KERNEL_LOOP(index, batch_size) {
@@ -185,9 +185,9 @@ template <typename scalar_t, typename index_t>
 __global__ void nll_loss_forward_reduce_cuda_kernel_1d(
     scalar_t* output,
     scalar_t* total_weight,
-    scalar_t* input,
-    index_t* target,
-    scalar_t* weights,
+    const scalar_t* input,
+    const index_t* target,
+    const scalar_t* weights,
     bool size_average,
     int64_t n_classes,
     int64_t ignore_index) {
@@ -221,9 +221,9 @@ template <typename scalar_t, typename accscalar_t, typename index_t>
 __global__ void nll_loss_forward_reduce_cuda_kernel_2d(
     scalar_t* output,
     scalar_t* total_weight,
-    scalar_t* input,
-    index_t* target,
-    scalar_t* weights,
+    const scalar_t* input,
+    const index_t* target,
+    const scalar_t* weights,
     bool size_average,
     int64_t nframe,
     int64_t ndim,
@@ -307,9 +307,9 @@ void nll_loss_forward_out_cuda_template(
            at::cuda::getCurrentCUDAStream()>>>(
             batch_size,
             input.packed_accessor64<scalar_t, 2>(),
-            target.data_ptr<index_t>(),
-            output.data_ptr<scalar_t>(),
-            weight_.defined() ? weight_.data_ptr<scalar_t>()
+            target.const_data_ptr<index_t>(),
+            output.mutable_data_ptr<scalar_t>(),
+            weight_.defined() ? weight_.const_data_ptr<scalar_t>()
                               : nullptr,
             n_classes,
             ignore_index);
@@ -349,11 +349,11 @@ void nll_loss_forward_out_cuda_template(
         [&] {
           nll_loss_forward_reduce_cuda_kernel_1d<scalar_t, index_t>
               <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  output.data_ptr<scalar_t>(),
-                  total_weight.data_ptr<scalar_t>(),
-                  input.data_ptr<scalar_t>(),
-                  target.data_ptr<index_t>(),
-                  weight_.defined() ? weight_.data_ptr<scalar_t>()
+                  output.mutable_data_ptr<scalar_t>(),
+                  total_weight.mutable_data_ptr<scalar_t>(),
+                  input.const_data_ptr<scalar_t>(),
+                  target.const_data_ptr<index_t>(),
+                  weight_.defined() ? weight_.const_data_ptr<scalar_t>()
                                     : nullptr,
                   reduction == at::Reduction::Mean,
                   n_classes,
@@ -378,11 +378,11 @@ void nll_loss_forward_out_cuda_template(
                  NLL_LOSS_THREADS,
                  0,
                  at::cuda::getCurrentCUDAStream()>>>(
-                  output.data_ptr<scalar_t>(),
-                  total_weight.data_ptr<scalar_t>(),
-                  input.data_ptr<scalar_t>(),
-                  target.data_ptr<index_t>(),
-                  weight_.defined() ? weight_.data_ptr<scalar_t>()
+                  output.mutable_data_ptr<scalar_t>(),
+                  total_weight.mutable_data_ptr<scalar_t>(),
+                  input.const_data_ptr<scalar_t>(),
+                  target.const_data_ptr<index_t>(),
+                  weight_.defined() ? weight_.const_data_ptr<scalar_t>()
                                     : nullptr,
                   reduction == at::Reduction::Mean,
                   input.size(0),
@@ -398,10 +398,10 @@ void nll_loss_forward_out_cuda_template(
 template <typename scalar_t, typename index_t>
 __global__ void nll_loss_backward_no_reduce_cuda_kernel(
     int batch_size,
-    index_t *target,
+    const index_t *target,
     PackedTensorAccessor64<scalar_t, 1> grad_output,
     PackedTensorAccessor64<scalar_t, 2> grad_input,
-    scalar_t *weights,
+    const scalar_t *weights,
     int64_t n_classes,
     int64_t ignore_index) {
 
@@ -419,10 +419,10 @@ __global__ void nll_loss_backward_no_reduce_cuda_kernel(
 template <typename scalar_t, typename index_t>
 __global__ void nll_loss_backward_reduce_cuda_kernel_1d(
     scalar_t *grad_input,
-    scalar_t *grad_output,
-    scalar_t *weights,
-    index_t *target,
-    scalar_t *total_weight,
+    const scalar_t *grad_output,
+    const scalar_t *weights,
+    const index_t *target,
+    const scalar_t *total_weight,
     bool size_average,
     int64_t n_classes,
     int64_t ignore_index
@@ -442,10 +442,10 @@ template<> struct bwd_index_type<int64_t> { using type = uint64_t; };
 template <typename scalar_t, typename index_t>
 __global__ void nll_loss_backward_reduce_cuda_kernel_2d(
     scalar_t* grad_input,
-    scalar_t* grad_output,
-    index_t* target,
-    scalar_t* weights,
-    scalar_t* total_weight,
+    const scalar_t* grad_output,
+    const index_t* target,
+    const scalar_t* weights,
+    const scalar_t* total_weight,
     bool size_average,
     int nframe,
     int ndim,
@@ -508,10 +508,10 @@ void nll_loss_backward_out_cuda_template(
            0,
            at::cuda::getCurrentCUDAStream()>>>(
             batch_size,
-            target.data_ptr<index_t>(),
+            target.const_data_ptr<index_t>(),
             grad_output.packed_accessor64<scalar_t, 1>(),
             grad_input.packed_accessor64<scalar_t, 2>(),
-            weight.defined() ? weight_.data_ptr<scalar_t>() : nullptr,
+            weight.defined() ? weight_.const_data_ptr<scalar_t>() : nullptr,
             n_classes,
             ignore_index);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -533,12 +533,12 @@ void nll_loss_backward_out_cuda_template(
         [&] {
           nll_loss_backward_reduce_cuda_kernel_1d<scalar_t, index_t>
               <<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  grad_input.data_ptr<scalar_t>(),
-                  grad_output.data_ptr<scalar_t>(),
-                  weight.defined() ? weight_.data_ptr<scalar_t>()
+                  grad_input.mutable_data_ptr<scalar_t>(),
+                  grad_output.const_data_ptr<scalar_t>(),
+                  weight.defined() ? weight_.const_data_ptr<scalar_t>()
                                    : nullptr,
-                  target.data_ptr<index_t>(),
-                  total_weight.data_ptr<scalar_t>(),
+                  target.const_data_ptr<index_t>(),
+                  total_weight.const_data_ptr<scalar_t>(),
                   reduction == at::Reduction::Mean,
                   n_classes,
                   ignore_index);
@@ -558,11 +558,11 @@ void nll_loss_backward_out_cuda_template(
         [&] {
           nll_loss_backward_reduce_cuda_kernel_2d<scalar_t, index_t>
               <<<1, NLL_LOSS_THREADS, 0, at::cuda::getCurrentCUDAStream()>>>(
-                  grad_input.data_ptr<scalar_t>(),
-                  grad_output.data_ptr<scalar_t>(),
-                  target.data_ptr<index_t>(),
-                  weight.defined() ? weight_.data_ptr<scalar_t>() : nullptr,
-                  total_weight.data_ptr<scalar_t>(),
+                  grad_input.mutable_data_ptr<scalar_t>(),
+                  grad_output.const_data_ptr<scalar_t>(),
+                  target.const_data_ptr<index_t>(),
+                  weight.defined() ? weight_.const_data_ptr<scalar_t>() : nullptr,
+                  total_weight.const_data_ptr<scalar_t>(),
                   reduction == at::Reduction::Mean,
                   input.size(0),
                   input.size(1),
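Taken together, the changes apply one const-correctness pattern throughout: every buffer a kernel only reads becomes a const pointer in the kernel signature, and each launch site replaces the unqualified data_ptr<T>() with const_data_ptr<T>() for read-only tensors and mutable_data_ptr<T>() for the tensors the kernel actually writes. The sketch below illustrates that pattern in isolation; the kernel, wrapper, and all names (gather_class_weights*) are hypothetical and are not code from Loss.cu.

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

// Toy kernel following the PR's convention: read-only arguments are
// const-qualified; only the output buffer is a mutable pointer.
template <typename scalar_t, typename index_t>
__global__ void gather_class_weights_kernel(
    int64_t batch_size,
    const index_t* target,    // read-only input
    const scalar_t* weights,  // read-only input, may be nullptr (unit weight)
    scalar_t* output) {       // the only buffer this kernel writes
  const int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < batch_size) {
    const index_t cls = target[i];
    output[i] = weights != nullptr ? weights[cls] : scalar_t(1);
  }
}

// Host-side launch, mirroring the call-site changes in the diff.
void gather_class_weights(
    const at::Tensor& target,
    const at::Tensor& weight,  // may be undefined, like weight_ above
    at::Tensor& output) {
  const int64_t batch_size = target.numel();
  constexpr int threads = 128;
  const int blocks = static_cast<int>((batch_size + threads - 1) / threads);
  gather_class_weights_kernel<float, int64_t>
      <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
          batch_size,
          target.const_data_ptr<int64_t>(),   // const T* for read-only args
          weight.defined() ? weight.const_data_ptr<float>() : nullptr,
          output.mutable_data_ptr<float>());  // T* only where we write
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

The generated machine code is unchanged; the payoff is that the compiler now rejects accidental writes through input pointers, and the data-pointer accessors document which tensors a kernel may mutate.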