diff --git a/torch_xla/csrc/tensor_ops.cpp b/torch_xla/csrc/tensor_ops.cpp index 518803530f0..0fd48a39996 100644 --- a/torch_xla/csrc/tensor_ops.cpp +++ b/torch_xla/csrc/tensor_ops.cpp @@ -182,10 +182,10 @@ XLATensor Softplus(const XLATensor& input, const at::Scalar& beta, XLATensor SoftplusBackward(const XLATensor& grad_output, const XLATensor& input, const at::Scalar& beta, const at::Scalar& threshold, const XLATensor& output) { - XLATensor scaled_output = XLATensor::mul(output, beta); - XLATensor z = XLATensor::exp(scaled_output); + XLATensor scaled_input = XLATensor::mul(input, beta); + XLATensor z = XLATensor::exp(XLATensor::mul(output, beta)); return XLATensor::where( - XLATensor::gt(scaled_output, threshold), grad_output, + XLATensor::gt(scaled_input, threshold), grad_output, XLATensor::mul(grad_output, XLATensor::div(XLATensor::sub(z, 1, 1), z))); }