From b93ada414624fcaf161bedd7df6dd32e9d4abb39 Mon Sep 17 00:00:00 2001 From: Roy Li Date: Fri, 26 Jan 2018 13:49:14 -0800 Subject: [PATCH 1/4] add reduce=True argument to MultiLabelMarginLoss --- aten/src/ATen/nn.yaml | 4 +- aten/src/THCUNN/MultiLabelMarginCriterion.cu | 16 ++- .../generic/MultiLabelMarginCriterion.cu | 50 +++++++--- aten/src/THCUNN/generic/THCUNN.h | 7 +- .../THNN/generic/MultiLabelMarginCriterion.c | 99 ++++++++++++++++--- aten/src/THNN/generic/THNN.h | 7 +- test/common_nn.py | 56 +++++++++++ test/test_nn.py | 30 ++++++ tools/autograd/derivatives.yaml | 4 +- torch/legacy/nn/MultiLabelMarginCriterion.py | 8 +- torch/nn/functional.py | 2 +- torch/nn/modules/loss.py | 24 ++++- 12 files changed, 269 insertions(+), 38 deletions(-) diff --git a/aten/src/ATen/nn.yaml b/aten/src/ATen/nn.yaml index 339ec210f7a25..13865f7f144f0 100644 --- a/aten/src/ATen/nn.yaml +++ b/aten/src/ATen/nn.yaml @@ -25,11 +25,11 @@ scalar_check: output: 'true' -- name: multilabel_margin_loss(Tensor self, LongTensor target, bool size_average=true) +- name: multilabel_margin_loss(Tensor self, LongTensor target, bool size_average=true, bool reduce=true) cname: MultiLabelMarginCriterion buffers: [is_target] scalar_check: - output: 'true' + output: reduce || self_->isScalar() is_target: target_->isScalar() - name: nll_loss(Tensor self, LongTensor target, Tensor weight={}, bool size_average=true, int64_t ignore_index=-100, bool reduce=True) diff --git a/aten/src/THCUNN/MultiLabelMarginCriterion.cu b/aten/src/THCUNN/MultiLabelMarginCriterion.cu index 72fc486bcaca7..460c42355a0f6 100644 --- a/aten/src/THCUNN/MultiLabelMarginCriterion.cu +++ b/aten/src/THCUNN/MultiLabelMarginCriterion.cu @@ -77,12 +77,14 @@ __global__ void cunn_MultiLabelMarginCriterion_updateOutput_kernel(Dtype *output template __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gradInput, + Dtype *gradOutput, Dtype *input, THCIndex_t *target, Dtype *istarget, int nframe, int dim, - int 
sizeaverage) + int sizeaverage, + int reduce) { // Temporary sums (for mapreduce) __shared__ Acctype sums[MULTILABELMARGIN_THREADS]; @@ -93,9 +95,14 @@ __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gra Dtype *gradInput_k = gradInput + k*dim; THCIndex_t *target_k = target + k*dim; Dtype *istarget_k = istarget + k*dim; + + Dtype *gradOutput_k = gradOutput; + if (!reduce) { + gradOutput_k += k; + } // gain: - Dtype g = ScalarConvert::to( sizeaverage ? 1./((Acctype)(nframe*dim)) : 1./((Acctype)dim) ); + Dtype g = ScalarConvert::to( sizeaverage && reduce ? 1./((Acctype)(nframe*dim)) : 1./((Acctype)dim) ); // zero gradients: for (int d = threadIdx.x; d < dim; d += blockDim.x) { @@ -133,6 +140,11 @@ __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gra } __syncthreads(); } + + for (int d = threadIdx.x; d < dim; d += blockDim.x) { + gradInput_k[d] *= *gradOutput_k; + } + __syncthreads(); } #include "generic/MultiLabelMarginCriterion.cu" diff --git a/aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu b/aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu index bc6b35b51df56..6b530e1ec67bd 100644 --- a/aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu +++ b/aten/src/THCUNN/generic/MultiLabelMarginCriterion.cu @@ -9,19 +9,20 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)( THCIndexTensor *target, THCTensor *output, THCTensor *istarget, - bool sizeaverage) + bool sizeaverage, + bool reduce) { input = THCTensor_(newContiguous)(state, input); target = THCIndexTensor_(newContiguous)(state, target); istarget = THCTensor_(newContiguous)(state, istarget); THCTensor_(resizeAs)(state, istarget, input); - THCTensor_(resize1d)(state, output, 1); if(input->nDimension == 1) { int dim = input->size[0]; THArgCheck((target->nDimension == 1) && (target->size[0] == dim), 3, "inconsistent target size"); + THCTensor_(resize1d)(state, output, 1); dim3 blocks(1); dim3 threads(MULTILABELMARGIN_THREADS); @@ -43,43 +44,65 @@ void 
THNN_(MultiLabelMarginCriterion_updateOutput)( int dim = input->size[1]; THArgCheck((target->nDimension == 2) && (target->size[0] == nframe) && (target->size[1] == dim), 3, "inconsistent target size"); - THCTensor *output_tmp = THCTensor_(newWithSize1d)(state, input->size[0]); dim3 blocks(input->size[0]); dim3 threads(MULTILABELMARGIN_THREADS); + if (reduce) + { + THCTensor *output_tmp = THCTensor_(newWithSize1d)(state, input->size[0]); + THCTensor_(resize1d)(state, output, 1); + + cunn_MultiLabelMarginCriterion_updateOutput_kernel + <<>>( + THCTensor_(data)(state, output_tmp), + THCTensor_(data)(state, input), + THCIndexTensor_(data)(state, target), + THCTensor_(data)(state, istarget), + nframe, dim, + sizeaverage + ); + THCudaCheck(cudaGetLastError()); + THCTensor_(set1d)(state, output, 0, ScalarConvert::to(THCTensor_(sumall)(state, output_tmp))); + THCTensor_(free)(state, output_tmp); + } + else + { + THCTensor_(resize1d)(state, output, input->size[0]); + cunn_MultiLabelMarginCriterion_updateOutput_kernel <<>>( - THCTensor_(data)(state, output_tmp), + THCTensor_(data)(state, output), THCTensor_(data)(state, input), THCIndexTensor_(data)(state, target), THCTensor_(data)(state, istarget), nframe, dim, - sizeaverage + false ); THCudaCheck(cudaGetLastError()); - THCTensor_(set1d)(state, output, 0, ScalarConvert::to(THCTensor_(sumall)(state, output_tmp))); - THCTensor_(free)(state, output_tmp); + } } else THError("vector or matrix expected"); THCTensor_(free)(state, input); THCIndexTensor_(free)(state, target); - THCTensor_(free)(state, istarget); } void THNN_(MultiLabelMarginCriterion_updateGradInput)( THCState *state, THCTensor *input, THCIndexTensor *target, + THCTensor *gradOutput, THCTensor *gradInput, THCTensor *istarget, - bool sizeaverage) + bool sizeaverage, + bool reduce) { input = THCTensor_(newContiguous)(state, input); target = THCIndexTensor_(newContiguous)(state, target); istarget = THCTensor_(newContiguous)(state, istarget); + gradOutput = 
THCTensor_(newContiguous)(state, gradOutput); THCTensor_(resizeAs)(state, gradInput, input); if(gradInput->nDimension == 1) @@ -95,11 +118,13 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)( cunn_MultiLabelMarginCriterion_updateGradInput_kernel <<>>( THCTensor_(data)(state, gradInput), + THCTensor_(data)(state, gradOutput), THCTensor_(data)(state, input), THCIndexTensor_(data)(state, target), THCTensor_(data)(state, istarget), 1, gradInput->size[0], - sizeaverage); + sizeaverage, + reduce); } else if(gradInput->nDimension == 2) @@ -116,11 +141,13 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)( cunn_MultiLabelMarginCriterion_updateGradInput_kernel <<>>( THCTensor_(data)(state, gradInput), + THCTensor_(data)(state, gradOutput), THCTensor_(data)(state, input), THCIndexTensor_(data)(state, target), THCTensor_(data)(state, istarget), gradInput->size[0], gradInput->size[1], - sizeaverage); + sizeaverage, + reduce); } else THError("vector or matrix expected"); @@ -130,6 +157,7 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)( THCTensor_(free)(state, input); THCIndexTensor_(free)(state, target); THCTensor_(free)(state, istarget); + THCTensor_(free)(state, gradOutput); } #endif diff --git a/aten/src/THCUNN/generic/THCUNN.h b/aten/src/THCUNN/generic/THCUNN.h index f696d087ec17c..4107181652cd3 100644 --- a/aten/src/THCUNN/generic/THCUNN.h +++ b/aten/src/THCUNN/generic/THCUNN.h @@ -396,15 +396,18 @@ TH_API void THNN_(MultiLabelMarginCriterion_updateOutput)( THCIndexTensor *target, THCTensor *output, THCTensor *istarget, - bool sizeaverage); + bool sizeaverage, + bool reduce); TH_API void THNN_(MultiLabelMarginCriterion_updateGradInput)( THCState *state, THCTensor *input, THCIndexTensor *target, + THCTensor *gradOutput, THCTensor *gradInput, THCTensor *istarget, - bool sizeaverage); + bool sizeaverage, + bool reduce); TH_API void THNN_(MultiMarginCriterion_updateOutput)( THCState *state, diff --git a/aten/src/THNN/generic/MultiLabelMarginCriterion.c 
b/aten/src/THNN/generic/MultiLabelMarginCriterion.c index dfef8c944ebf0..221dc5ad6d6c6 100644 --- a/aten/src/THNN/generic/MultiLabelMarginCriterion.c +++ b/aten/src/THNN/generic/MultiLabelMarginCriterion.c @@ -9,9 +9,10 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)( THIndexTensor *target, THTensor *output, THTensor *isTarget, - bool sizeAverage) + bool sizeAverage, + bool reduce) { - real *input_data, *isTarget_data; + real *input_data, *output_data, *isTarget_data; THIndex_t *target_data; int64_t nframe, dim; int64_t t, d, dt, ddt; @@ -19,7 +20,6 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)( THArgCheck((input->nDimension == 1) || (input->nDimension == 2), 2, "vector or matrix expected"); - THTensor_(resize1d)(output, 1); if (input->nDimension == 1) { @@ -48,7 +48,55 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)( THTensor_(zero)(isTarget); isTarget_data = THTensor_(data)(isTarget); - sum = 0; + if (reduce) + { + THTensor_(resize1d)(output, 1); + + sum = 0; + for (t = 0; t < nframe; t++) + { + for (ddt = 0; ddt < dim; ddt++) + { + THIndex_t target_idx = target_data[ddt] - TH_INDEX_BASE; + if (target_idx < 0) + break; + isTarget_data[target_idx] = 1; + } + for (dt = 0; dt < dim; dt++) + { + THIndex_t target_idx = target_data[dt] - TH_INDEX_BASE; + real input_target; + if (target_idx < 0) + break; + + input_target = input_data[target_idx]; + for (d = 0; d < dim; d++) + { + if (!isTarget_data[d]) + { + real z = 1 - input_target + input_data[d]; + if (z > 0) + sum += z; + } + } + } + input_data += dim; + target_data += dim; + isTarget_data += dim; + } + + sum /= dim; + if (sizeAverage) + sum /= nframe; + THTensor_(set1d)(output, 0, sum); + + THTensor_(free)(input); + THIndexTensor_(free)(target); + return; + } + + THTensor_(resize1d)(output, nframe); + for (t = 0; t < nframe; t++) { for (ddt = 0; ddt < dim; ddt++) @@ -58,6 +106,8 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)( break; isTarget_data[target_idx] = 1; } + + sum = 0; for (dt = 
0; dt < dim; dt++) { THIndex_t target_idx = target_data[dt] - TH_INDEX_BASE; @@ -76,17 +126,15 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)( } } } + + sum /= dim; + THTensor_(set1d)(output, t, sum); + input_data += dim; target_data += dim; isTarget_data += dim; } - sum /= dim; - if (sizeAverage) - sum /= nframe; - - THTensor_(set1d)(output, 0, sum); - THTensor_(free)(input); THIndexTensor_(free)(target); } @@ -95,9 +143,11 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)( THNNState *state, THTensor *input, THIndexTensor *target, + THTensor *gradOutput, THTensor *gradInput, THTensor *isTarget, - bool sizeAverage) + bool sizeAverage, + bool reduce) { real *input_data; real *gradInput_data; @@ -142,12 +192,12 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)( target_data = THIndexTensor_(data)(target); isTarget_data = THTensor_(data)(isTarget); - g = sizeAverage ? ( 1./((real)(nframe*dim)) ) : ( 1./((real)dim) ); - THTensor_(resizeAs)(gradInput, input); THTensor_(zero)(gradInput); gradInput_data = THTensor_(data)(gradInput); + g = sizeAverage && reduce ? 
(1./((real)(nframe*dim))) : (1./((real)dim)); + for (t = 0; t < nframe; t++) { for (dt = 0; dt < dim; dt++) @@ -176,10 +226,33 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)( isTarget_data += dim; gradInput_data += dim; } + gradInput_data -= nframe*dim; + + if (reduce) + { + THNN_CHECK_DIM_SIZE(gradOutput, 1, 0, 1); + for (t = 0; t < nframe*dim; t++) + { + gradInput_data[t] *= THTensor_fastGet1d(gradOutput, 0); + } + } + else + { + THNN_CHECK_DIM_SIZE(gradOutput, 1, 0, nframe); + gradOutput = THTensor_(newContiguous)(gradOutput); + for (t = 0; t < nframe; t++) + { + for (d = 0; d < dim; d++) + { + gradInput_data[t * dim + d] *= THTensor_fastGet1d(gradOutput, t); + } + } + } THTensor_(free)(input); THIndexTensor_(free)(target); THTensor_(free)(isTarget); + THTensor_(free)(gradOutput); } #endif diff --git a/aten/src/THNN/generic/THNN.h b/aten/src/THNN/generic/THNN.h index c46211e98230a..6314fa3b8ca8d 100644 --- a/aten/src/THNN/generic/THNN.h +++ b/aten/src/THNN/generic/THNN.h @@ -358,14 +358,17 @@ TH_API void THNN_(MultiLabelMarginCriterion_updateOutput)( THIndexTensor *target, THTensor *output, THTensor *isTarget, - bool sizeAverage); + bool sizeAverage, + bool reduce); TH_API void THNN_(MultiLabelMarginCriterion_updateGradInput)( THNNState *state, THTensor *input, THIndexTensor *target, + THTensor *gradOutput, THTensor *gradInput, THTensor *isTarget, - bool sizeAverage); + bool sizeAverage, + bool reduce); TH_API void THNN_(MultiMarginCriterion_updateOutput)( THNNState *state, diff --git a/test/common_nn.py b/test/common_nn.py index 6b8c64be24963..1be2a2fb74117 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -322,11 +322,55 @@ def smoothl1loss_reference(input, target, size_average=True, reduce=True): return output +def _multilabelmarginloss_reference(input, target, is_target): + sum = 0 + for i in range(0, target.size()[0]): + target_index = target[i] + if (target_index < 0): + break + is_target[target_index] = 1 + for i in range(0, 
target.size()[0]): + target_index = target[i] + if (target_index < 0): + break + + for j in range (0, target.size()[0]): + if not is_target[j]: + z = 1 - input[target_index] + input[j] + if z > 0: + sum += z + + return sum + + +def multilabelmarginloss_reference(input, target, size_average=True, reduce=True): + is_target = torch.LongTensor(input.size()).zero_() + + if input.dim() == 1: + n = 1 + dim = input.size()[0] + output = torch.Tensor(n).zero_() + output[0] = _multilabelmarginloss_reference(input, target, is_target) + else: + n = input.size()[0] + dim = input.size()[1] + output = torch.Tensor(n).zero_() + for i in range(0, n): + output[i] = _multilabelmarginloss_reference(input[i], target[i], is_target[i]) + + if reduce and size_average: + return output.mean()/dim + elif reduce: + return output.sum()/dim + return output/dim + + loss_reference_fns = { 'KLDivLoss': kldivloss_reference, 'NLLLoss': nllloss_reference, 'NLLLossNd': nlllossNd_reference, 'SmoothL1Loss': smoothl1loss_reference, + 'MultiLabelMarginLoss': multilabelmarginloss_reference, } @@ -439,10 +483,22 @@ def smoothl1loss_reference(input, target, size_average=True, reduce=True): desc='margin', check_no_size_average=True, ), + dict( + module_name='MultiLabelMarginLoss', + input_size=(10,), + target_fn=lambda: torch.rand(10).mul(10).floor().long(), + reference_fn=lambda i, t, m: + multilabelmarginloss_reference(i, t, size_average=get_size_average(m)), + desc="1d", + check_no_size_average=True, + check_gradgrad=False, + ), dict( module_name='MultiLabelMarginLoss', input_size=(5, 10), target_fn=lambda: torch.rand(5, 10).mul(10).floor().long(), + reference_fn=lambda i, t, m: + multilabelmarginloss_reference(i, t, size_average=get_size_average(m)), check_no_size_average=True, check_gradgrad=False, ), diff --git a/test/test_nn.py b/test/test_nn.py index 6ac7d113c4431..2ea8f4f8de275 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4557,6 +4557,34 @@ def smoothl1loss_no_reduce_test(): pickle=False) 
+def multilabelmarginloss_1d_no_reduce_test(): + t = Variable(torch.rand(10).mul(10).floor().long()) + return dict( + fullname='MultiLabelMarginLoss_1d_no_reduce', + constructor=wrap_functional( + lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduce=False)), + input_fn=lambda: torch.randn(10), + reference_fn=lambda i, _: + loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduce=False), + check_no_size_average=True, + check_gradgrad=False, + pickle=False) + + +def multilabelmarginloss_no_reduce_test(): + t = Variable(torch.rand(5, 10).mul(10).floor().long()) + return dict( + fullname='MultiLabelMarginLoss_no_reduce', + constructor=wrap_functional( + lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduce=False)), + input_fn=lambda: torch.randn(5, 10), + reference_fn=lambda i, _: + loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduce=False), + check_no_size_average=True, + check_gradgrad=False, + pickle=False) + + new_module_tests = [ poissonnllloss_no_reduce_test(), bceloss_no_reduce_test(), @@ -4577,6 +4605,8 @@ def smoothl1loss_no_reduce_test(): nlllossNd_no_reduce_weights_test(), nlllossNd_no_reduce_ignore_index_test(), smoothl1loss_no_reduce_test(), + multilabelmarginloss_1d_no_reduce_test(), + multilabelmarginloss_no_reduce_test(), dict( module_name='BatchNorm1d', constructor_args=(10,), diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 328ccf204a535..d2d3c5a5ebaa3 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -673,8 +673,8 @@ - name: multi_margin_loss_forward(Tensor self, Tensor target, Scalar p, Scalar margin, Tensor weight, bool size_average) self: multi_margin_loss_backward(self, target, p, margin, weight, size_average).mul_(grad) -- name: multilabel_margin_loss_forward(Tensor self, Tensor target, bool size_average) - self: multilabel_margin_loss_backward(self, target, size_average, is_target).mul_(grad) +- name:
multilabel_margin_loss_forward(Tensor self, Tensor target, bool size_average, bool reduce) + self: multilabel_margin_loss_backward(grad, self, target, size_average, reduce, is_target) - name: nll_loss_forward(Tensor self, Tensor target, Tensor weight, bool size_average, int64_t ignore_index, bool reduce) self: nll_loss_backward(grad, self, target, weight, size_average, ignore_index, reduce, total_weight) diff --git a/torch/legacy/nn/MultiLabelMarginCriterion.py b/torch/legacy/nn/MultiLabelMarginCriterion.py index 42d6f7ac91119..7f49e6da10841 100644 --- a/torch/legacy/nn/MultiLabelMarginCriterion.py +++ b/torch/legacy/nn/MultiLabelMarginCriterion.py @@ -20,19 +20,23 @@ def updateOutput(self, input, target): target, self.output_tensor, self.isTarget, - self.sizeAverage + self.sizeAverage, + True, # reduce ) self.output = self.output_tensor[0] return self.output def updateGradInput(self, input, target): target = target.long() + implicit_gradOutput = torch.ones(1).type_as(input) self._backend.MultiLabelMarginCriterion_updateGradInput( self._backend.library_state, input, target, + implicit_gradOutput, self.gradInput, self.isTarget, - self.sizeAverage + self.sizeAverage, + True, # reduce ) return self.gradInput diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 8c20eb2782083..9e1451c06bb41 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -1476,7 +1476,7 @@ def hinge_embedding_loss(input, target, margin=1.0, size_average=True): multilabel_margin_loss = _add_docstr(torch._C._nn.multilabel_margin_loss, r""" -multilabel_margin_loss(input, target, size_average=True) -> Variable +multilabel_margin_loss(input, target, size_average=True, reduce=True) -> Variable See :class:`~torch.nn.MultiLabelMarginLoss` for details. 
""") diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index e91f3823b839d..68dc23b7c2008 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -565,10 +565,32 @@ class MultiLabelMarginLoss(_Loss): The criterion only considers the first non-negative `y[j]` targets. This allows for different samples to have variable amounts of target classes + + Args: + size_average (bool, optional): By default, the losses are averaged + over observations for each minibatch. However, if the field + size_average is set to ``False``, the losses are instead summed for + each minibatch. Default: ``True`` + reduce (bool, optional): By default, the losses are averaged or summed over + observations for each minibatch depending on size_average. When reduce + is False, returns a loss per batch element instead and ignores + size_average. Default: True + + Shape: + - Input: :math:`(N)` or :math:`(N, *)` where `*` means, any number of additional + dimensions + - Target: :math:`(N)` or :math:`(N, *)`, same shape as the input + - Output: scalar. If `reduce` is False, then `(N)` or `(N, *)`, same shape as + input. 
""" + def __init__(self, size_average=True, reduce=True): + super(MultiLabelMarginLoss, self).__init__(size_average) + self.reduce = reduce + def forward(self, input, target): _assert_no_grad(target) - return F.multilabel_margin_loss(input, target, size_average=self.size_average) + return F.multilabel_margin_loss(input, target, size_average=self.size_average, + reduce=self.reduce) class SmoothL1Loss(_Loss): From 212c331437c4772e27ebffebc35fc56a0f7e9144 Mon Sep 17 00:00:00 2001 From: Roy Li Date: Mon, 29 Jan 2018 17:23:33 -0800 Subject: [PATCH 2/4] Fix lint --- test/common_nn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/common_nn.py b/test/common_nn.py index 1be2a2fb74117..a83018a31d1da 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -334,7 +334,7 @@ def _multilabelmarginloss_reference(input, target, is_target): if (target_index < 0): break - for j in range (0, target.size()[0]): + for j in range(0, target.size()[0]): if not is_target[j]: z = 1 - input[target_index] + input[j] if z > 0: @@ -359,10 +359,10 @@ def multilabelmarginloss_reference(input, target, size_average=True, reduce=True output[i] = _multilabelmarginloss_reference(input[i], target[i], is_target[i]) if reduce and size_average: - return output.mean()/dim + return output.mean() / dim elif reduce: - return output.sum()/dim - return output/dim + return output.sum() / dim + return output / dim loss_reference_fns = { From 25ff9538916f18366d1b7f8a69bea9074555a228 Mon Sep 17 00:00:00 2001 From: Roy Li Date: Wed, 31 Jan 2018 15:51:05 -0800 Subject: [PATCH 3/4] Addressed comments --- .../THNN/generic/MultiLabelMarginCriterion.c | 10 ++--- test/common_nn.py | 41 ++++++++----------- test/test_nn.py | 15 +++++++ torch/nn/modules/loss.py | 12 +++--- 4 files changed, 43 insertions(+), 35 deletions(-) diff --git a/aten/src/THNN/generic/MultiLabelMarginCriterion.c b/aten/src/THNN/generic/MultiLabelMarginCriterion.c index 221dc5ad6d6c6..8e9a3b5182cac 100644 --- 
a/aten/src/THNN/generic/MultiLabelMarginCriterion.c +++ b/aten/src/THNN/generic/MultiLabelMarginCriterion.c @@ -88,7 +88,7 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)( sum /= dim; if (sizeAverage) sum /= nframe; - THTensor_(set1d)(output, 0, sum); + THTensor_fastSet1d(output, 0, sum); THTensor_(free)(input); THIndexTensor_(free)(target); @@ -128,7 +128,7 @@ void THNN_(MultiLabelMarginCriterion_updateOutput)( } sum /= dim; - THTensor_(set1d)(output, t, sum); + THTensor_fastSet1d(output, t, sum); input_data += dim; target_data += dim; @@ -193,6 +193,7 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)( isTarget_data = THTensor_(data)(isTarget); THTensor_(resizeAs)(gradInput, input); + gradInput = THTensor_(newContiguous)(gradInput); THTensor_(zero)(gradInput); gradInput_data = THTensor_(data)(gradInput); @@ -226,7 +227,7 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)( isTarget_data += dim; gradInput_data += dim; } - gradInput_data -= nframe*dim; + gradInput_data = THTensor_(data)(gradInput); if (reduce) { @@ -239,7 +240,6 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)( else { THNN_CHECK_DIM_SIZE(gradOutput, 1, 0, nframe); - gradOutput = THTensor_(newContiguous)(gradOutput); for (t = 0; t < nframe; t++) { for (d = 0; d < dim; d++) @@ -252,7 +252,7 @@ void THNN_(MultiLabelMarginCriterion_updateGradInput)( THTensor_(free)(input); THIndexTensor_(free)(target); THTensor_(free)(isTarget); - THTensor_(free)(gradOutput); + THTensor_(free)(gradInput); } #endif diff --git a/test/common_nn.py b/test/common_nn.py index a83018a31d1da..345723f36a2b6 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -322,41 +322,34 @@ def smoothl1loss_reference(input, target, size_average=True, reduce=True): return output -def _multilabelmarginloss_reference(input, target, is_target): - sum = 0 - for i in range(0, target.size()[0]): - target_index = target[i] - if (target_index < 0): - break - is_target[target_index] = 1 - for i in range(0, 
target.size()[0]): - target_index = target[i] - if (target_index < 0): +def _multilabelmarginloss_reference(input, target): + targets = [] + for target_index in target: + if target_index < 0: break + targets.append(target_index) - for j in range(0, target.size()[0]): - if not is_target[j]: - z = 1 - input[target_index] + input[j] - if z > 0: - sum += z + sum = 0 + for target_index in targets: + for i in range(0, len(input)): + if i not in targets: + sum += max(0, 1 - input[target_index] + input[i]) return sum def multilabelmarginloss_reference(input, target, size_average=True, reduce=True): - is_target = torch.LongTensor(input.size()).zero_() - if input.dim() == 1: n = 1 - dim = input.size()[0] - output = torch.Tensor(n).zero_() - output[0] = _multilabelmarginloss_reference(input, target, is_target) + dim = input.size(0) + output = torch.zeros(n) + output[0] = _multilabelmarginloss_reference(input, target) else: - n = input.size()[0] - dim = input.size()[1] - output = torch.Tensor(n).zero_() + n = input.size(0) + dim = input.size(1) + output = torch.zeros(n) for i in range(0, n): - output[i] = _multilabelmarginloss_reference(input[i], target[i], is_target[i]) + output[i] = _multilabelmarginloss_reference(input[i], target[i]) if reduce and size_average: return output.mean() / dim diff --git a/test/test_nn.py b/test/test_nn.py index 2ea8f4f8de275..99ba5eb71a9ae 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4571,6 +4571,20 @@ def multilabelmarginloss_1d_no_reduce_test(): pickle=False) +def multilabelmarginloss_index_neg_test(): + t = Variable(torch.clamp(torch.rand(5, 10).add(-.5).mul(20).floor().long(), min=-1)) + return dict( + fullname='MultiLabelMarginLoss_index_neg', + constructor=wrap_functional( + lambda i: F.multilabel_margin_loss(i, t.type_as(i).long(), reduce=False)), + input_fn=lambda: torch.randn(5, 10), + reference_fn=lambda i, _: + loss_reference_fns['MultiLabelMarginLoss'](i, t.data.type_as(i).long(), reduce=False), + 
check_no_size_average=True, + check_gradgrad=False, + pickle=False) + + def multilabelmarginloss_no_reduce_test(): t = Variable(torch.rand(5, 10).mul(10).floor().long()) return dict( @@ -4606,6 +4620,7 @@ def multilabelmarginloss_no_reduce_test(): nlllossNd_no_reduce_ignore_index_test(), smoothl1loss_no_reduce_test(), multilabelmarginloss_1d_no_reduce_test(), + multilabelmarginloss_index_neg_test(), multilabelmarginloss_no_reduce_test(), dict( module_name='BatchNorm1d', diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 68dc23b7c2008..cd18b6ae4a312 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -562,7 +562,8 @@ class MultiLabelMarginLoss(_Loss): `y` and `x` must have the same size. - The criterion only considers the first non-negative `y[j]` targets. + The criterion only considers a contiguous block of non-negative targets that + starts at the front. This allows for different samples to have variable amounts of target classes @@ -577,11 +578,10 @@ class MultiLabelMarginLoss(_Loss): size_average. Default: True Shape: - - Input: :math:`(N)` or :math:`(N, *)` where `*` means, any number of additional - dimensions - - Target: :math:`(N)` or :math:`(N, *)`, same shape as the input - - Output: scalar. If `reduce` is False, then `(N)` or `(N, *)`, same shape as - input. + - Input: :math:`(C)` or :math:`(N, C)` where `N` is the batch size and `C` + is the number of classes. + - Target: :math:`(C)` or :math:`(N, C)`, same shape as the input. + - Output: scalar. If `reduce` is False, then `(N)`. 
""" def __init__(self, size_average=True, reduce=True): super(MultiLabelMarginLoss, self).__init__(size_average) From 11e85a0a4513d0b497815cb24917ff2558371be0 Mon Sep 17 00:00:00 2001 From: Roy Li Date: Fri, 2 Feb 2018 13:28:45 -0800 Subject: [PATCH 4/4] Remove unneeded syncthreads calls --- aten/src/THCUNN/MultiLabelMarginCriterion.cu | 2 -- torch/nn/modules/loss.py | 14 +++++++------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/aten/src/THCUNN/MultiLabelMarginCriterion.cu b/aten/src/THCUNN/MultiLabelMarginCriterion.cu index 460c42355a0f6..a8dc15e21137c 100644 --- a/aten/src/THCUNN/MultiLabelMarginCriterion.cu +++ b/aten/src/THCUNN/MultiLabelMarginCriterion.cu @@ -138,13 +138,11 @@ __global__ void cunn_MultiLabelMarginCriterion_updateGradInput_kernel(Dtype *gra if (threadIdx.x == 0) { gradInput_k[target_idx] += ScalarConvert::to(totalSum); } - __syncthreads(); } for (int d = threadIdx.x; d < dim; d += blockDim.x) { gradInput_k[d] *= *gradOutput_k; } - __syncthreads(); } #include "generic/MultiLabelMarginCriterion.cu" diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index cd18b6ae4a312..bb9faf3eee44f 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -568,14 +568,14 @@ class MultiLabelMarginLoss(_Loss): This allows for different samples to have variable amounts of target classes Args: - size_average (bool, optional): By default, the losses are averaged - over observations for each minibatch. However, if the field - size_average is set to ``False``, the losses are instead summed for - each minibatch. Default: ``True`` + size_average (bool, optional): By default, the losses are averaged over + observations for each minibatch. However, if the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. + Default: ``True`` reduce (bool, optional): By default, the losses are averaged or summed over - observations for each minibatch depending on size_average. 
When reduce - is False, returns a loss per batch element instead and ignores - size_average. Default: True + observations for each minibatch depending on :attr:`size_average`. When + :attr:`reduce` is ``False``, returns a loss per batch element instead and + ignores :attr:`size_average`. Default: ``True`` Shape: - Input: :math:`(C)` or :math:`(N, C)` where `N` is the batch size and `C`