From 881d954502c2149aed40e73c6eba0d136b2dceda Mon Sep 17 00:00:00 2001 From: James Bradbury Date: Wed, 16 Jan 2019 18:53:25 -0800 Subject: [PATCH] [AutoDiff] [API] Fix bug in Tensor.reshaped VJP Also adds test. --- stdlib/public/TensorFlow/Gradients.swift | 2 +- test/TensorFlowRuntime/tensor_autodiff_runtime.swift | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/stdlib/public/TensorFlow/Gradients.swift b/stdlib/public/TensorFlow/Gradients.swift index e919966bdbc43..18ae46b7d5203 100644 --- a/stdlib/public/TensorFlow/Gradients.swift +++ b/stdlib/public/TensorFlow/Gradients.swift @@ -518,7 +518,7 @@ extension Tensor where Scalar : Differentiable & FloatingPoint { ) -> (Tensor, (Tensor) -> Tensor) { let value = reshaped(toShape: newShape) return (value, { v in - return v.reshaped(toShape: newShape) + return v.reshaped(toShape: self.shapeTensor) }) } diff --git a/test/TensorFlowRuntime/tensor_autodiff_runtime.swift b/test/TensorFlowRuntime/tensor_autodiff_runtime.swift index 50c1b652c6f3c..9b5045a9ef67f 100644 --- a/test/TensorFlowRuntime/tensor_autodiff_runtime.swift +++ b/test/TensorFlowRuntime/tensor_autodiff_runtime.swift @@ -92,16 +92,24 @@ TensorADTests.testAllBackends("mean") { expectTrue(meanGradAlongAxes(input) == expected) } +TensorADTests.testAllBackends("reshaped") { + let shapeTensor = Tensor([2, 2, 2]) + let input = Tensor(ones: [2, 4]) + let reshapedPullback = pullback(at: input) { (a: Tensor) in a.reshaped(toShape: shapeTensor) } + let reshaped = Tensor(ones: [2, 2, 2]) + expectTrue(reshapedPullback(reshaped) == input) +} + TensorADTests.testAllBackends("transposed") { let input = Tensor(ones: [2, 3]) let transposed = Tensor(ones: [3, 2]) let transposedPullback = pullback(at: input) { (a: Tensor) in a.transposed() } let transposedPermutationsPullback = pullback(at: input) { (a: Tensor) in a.transposed(withPermutations: [1, 0]) } - let transposedVariadiicsPullback = pullback(at: input) { (a: Tensor) in a.transposed(withPermutations: 1, 0) } + let transposedVariadicsPullback = pullback(at: input) { (a: Tensor) in a.transposed(withPermutations: 1, 0) } expectTrue(transposedPullback(transposed) == input) expectTrue(transposedPermutationsPullback(transposed) == input) - expectTrue(transposedVariadiicsPullback(transposed) == input) + expectTrue(transposedVariadicsPullback(transposed) == input) } TensorADTests.testAllBackends("relu") {