diff --git a/stdlib/public/TensorFlow/Gradients.swift b/stdlib/public/TensorFlow/Gradients.swift index 7d120926d130a..06b9a777a2c55 100644 --- a/stdlib/public/TensorFlow/Gradients.swift +++ b/stdlib/public/TensorFlow/Gradients.swift @@ -571,7 +571,7 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { @inlinable func _vjpMean() -> (Tensor, (Tensor) -> Tensor) { return (mean(), { [shape = shapeTensor, count = scalarCountTensor] in - $0.broadcast(toShape: shape) / Tensor(count) + ($0 / Tensor(count)).broadcast(toShape: shape) }) }