diff --git a/stdlib/public/TensorFlow/Tensor.swift b/stdlib/public/TensorFlow/Tensor.swift index 54a23e20234d9..75b0b4d5f670e 100644 --- a/stdlib/public/TensorFlow/Tensor.swift +++ b/stdlib/public/TensorFlow/Tensor.swift @@ -632,6 +632,7 @@ public extension Tensor { /// Reshape to the shape of the specified `Tensor`. /// - Precondition: The number of scalars matches the new shape. @inlinable @inline(__always) + @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint) func reshaped(like other: Tensor) -> Tensor { return reshaped(toShape: other.shapeTensor) } @@ -639,6 +640,7 @@ public extension Tensor { /// Reshape to the specified shape. /// - Precondition: The number of scalars matches the new shape. @inlinable @inline(__always) + @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint) func reshaped(to newShape: TensorShape) -> Tensor { return reshaped(toShape: Tensor(newShape.dimensions)) }