From 540644753f2c35e9531fa9ea802e1c402adebefe Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 27 Feb 2019 18:33:43 -0800 Subject: [PATCH 1/2] [TF API] Make `reshaped(to:)` and `reshaped(like:)` differentiable. --- stdlib/public/TensorFlow/Tensor.swift | 2 ++ 1 file changed, 2 insertions(+) diff --git a/stdlib/public/TensorFlow/Tensor.swift b/stdlib/public/TensorFlow/Tensor.swift index 54a23e20234d9..34589227e7174 100644 --- a/stdlib/public/TensorFlow/Tensor.swift +++ b/stdlib/public/TensorFlow/Tensor.swift @@ -632,6 +632,7 @@ public extension Tensor { /// Reshape to the shape of the specified `Tensor`. /// - Precondition: The number of scalars matches the new shape. @inlinable @inline(__always) + @differentiable(wrt: self) func reshaped(like other: Tensor) -> Tensor { return reshaped(toShape: other.shapeTensor) } @@ -639,6 +640,7 @@ public extension Tensor { /// Reshape to the specified shape. /// - Precondition: The number of scalars matches the new shape. @inlinable @inline(__always) + @differentiable func reshaped(to newShape: TensorShape) -> Tensor { return reshaped(toShape: Tensor(newShape.dimensions)) } From 0ead422e4381918c78914922c90ec4f41ed63fb7 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 27 Feb 2019 20:37:30 -0800 Subject: [PATCH 2/2] Forgot about generic requirements. --- stdlib/public/TensorFlow/Tensor.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stdlib/public/TensorFlow/Tensor.swift b/stdlib/public/TensorFlow/Tensor.swift index 34589227e7174..75b0b4d5f670e 100644 --- a/stdlib/public/TensorFlow/Tensor.swift +++ b/stdlib/public/TensorFlow/Tensor.swift @@ -632,7 +632,7 @@ public extension Tensor { /// Reshape to the shape of the specified `Tensor`. /// - Precondition: The number of scalars matches the new shape. @inlinable @inline(__always) - @differentiable(wrt: self) + @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint) func reshaped(like other: Tensor) -> Tensor { return reshaped(toShape: other.shapeTensor) } @@ -640,7 +640,7 @@ public extension Tensor { /// Reshape to the specified shape. /// - Precondition: The number of scalars matches the new shape. @inlinable @inline(__always) - @differentiable + @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint) func reshaped(to newShape: TensorShape) -> Tensor { return reshaped(toShape: Tensor(newShape.dimensions)) }