diff --git a/kernels/portable/cpu/op_convolution_backward.cpp b/kernels/portable/cpu/op_convolution_backward.cpp
index 848b66ec559..7884ea0c44c 100644
--- a/kernels/portable/cpu/op_convolution_backward.cpp
+++ b/kernels/portable/cpu/op_convolution_backward.cpp
@@ -34,7 +34,7 @@ bool check_convolution_backward_args(
     bool transposed,
     IntArrayRef output_padding,
     int64_t groups,
-    ET_UNUSED executorch::aten::ArrayRef<bool> output_mask,
+    executorch::aten::ArrayRef<bool> output_mask,
     Tensor& grad_input,
     Tensor& grad_weight,
     Tensor& grad_bias) {
@@ -45,9 +45,18 @@ bool check_convolution_backward_args(
 
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(weight, input));
   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_output, input));
-  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_input, input));
-  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_weight, input));
-  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_bias, input));
+
+  if (output_mask[0]) {
+    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_input, input));
+  }
+
+  if (output_mask[1]) {
+    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_weight, input));
+  }
+
+  if (output_mask[2]) {
+    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_bias, input));
+  }
 
   ET_LOG_MSG_AND_RETURN_IF_FALSE(
       check_convolution_args(
@@ -267,19 +276,23 @@ std::tuple<Tensor&, Tensor&, Tensor&> convolution_backward_out(
       InvalidArgument,
       ret_val);
 
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_tensor(grad_input, input.sizes()) == Error::Ok,
-      InvalidArgument,
-      ret_val);
+  if (output_mask[0]) {
+    ET_KERNEL_CHECK(
+        ctx,
+        resize_tensor(grad_input, input.sizes()) == Error::Ok,
+        InvalidArgument,
+        ret_val);
+  }
 
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_tensor(grad_weight, weight.sizes()) == Error::Ok,
-      InvalidArgument,
-      ret_val);
+  if (output_mask[1]) {
+    ET_KERNEL_CHECK(
+        ctx,
+        resize_tensor(grad_weight, weight.sizes()) == Error::Ok,
+        InvalidArgument,
+        ret_val);
+  }
 
-  if (bias_sizes_opt.has_value()) {
+  if (bias_sizes_opt.has_value() && output_mask[2]) {
     ET_KERNEL_CHECK(
         ctx,
         resize_tensor(grad_bias, bias_sizes_opt.value()) == Error::Ok,
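
For context, the diff makes every per-output dtype check and resize conditional on the corresponding output_mask entry, so a caller that does not request a given gradient can pass an unresized placeholder tensor for it. Below is a minimal standalone C++ sketch of that gating pattern. It is not the ExecuTorch API: FakeTensor, check_backward_args, and the dtype field are illustrative assumptions standing in for the real Tensor type and tensors_have_same_dtype() helper.

    // Standalone sketch (not the ExecuTorch API) of the pattern this diff
    // adopts: each dtype check runs only when the caller requested that
    // gradient via output_mask, so unrequested grad tensors may be empty
    // placeholders. All names here are illustrative assumptions.
    #include <array>
    #include <cstdio>

    struct FakeTensor {
      int dtype;  // stand-in for a real scalar-type enum
    };

    bool check_backward_args(
        const FakeTensor& input,
        std::array<bool, 3> output_mask,
        const FakeTensor& grad_input,
        const FakeTensor& grad_weight,
        const FakeTensor& grad_bias) {
      // Gate each validation on its mask slot, mirroring the gated
      // tensors_have_same_dtype() calls in the diff above.
      if (output_mask[0] && grad_input.dtype != input.dtype) {
        return false;
      }
      if (output_mask[1] && grad_weight.dtype != input.dtype) {
        return false;
      }
      if (output_mask[2] && grad_bias.dtype != input.dtype) {
        return false;
      }
      return true;
    }

    int main() {
      FakeTensor input{/*dtype=*/1};
      FakeTensor grad_input{1};
      FakeTensor grad_weight{1};
      FakeTensor grad_bias{0};  // mismatched dtype, but bias grad not requested
      // With output_mask[2] == false, grad_bias is never inspected, so the
      // checks pass even though its dtype differs from input's.
      bool ok = check_backward_args(
          input, {true, true, false}, grad_input, grad_weight, grad_bias);
      std::printf("checks %s\n", ok ? "passed" : "failed");
      return ok ? 0 : 1;
    }

Gating the resize_tensor() calls the same way follows the same reasoning: an output the caller masked out should never be resized or validated, since it may be an arbitrary placeholder rather than a real destination tensor.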