diff --git a/stdlib/public/core/Array.swift b/stdlib/public/core/Array.swift index 534c1e7bf004e..02b8ce8e58c53 100644 --- a/stdlib/public/core/Array.swift +++ b/stdlib/public/core/Array.swift @@ -1875,7 +1875,7 @@ extension Array where Element : Differentiable { /// The view of an array as the differentiable product manifold of `Element` /// multiplied with itself `count` times. @_fixed_layout - public struct DifferentiableView : Differentiable & KeyPathIterable { + public struct DifferentiableView : Differentiable { private var _base: [Element] /// The viewed array. diff --git a/stdlib/public/core/KeyPathIterable.swift b/stdlib/public/core/KeyPathIterable.swift index 2394d3a29f0b9..bd674fcd32f25 100644 --- a/stdlib/public/core/KeyPathIterable.swift +++ b/stdlib/public/core/KeyPathIterable.swift @@ -104,6 +104,13 @@ extension Array : KeyPathIterable { } } +extension Array.DifferentiableView : KeyPathIterable { + public typealias AllKeyPaths = [PartialKeyPath] + public var allKeyPaths: [PartialKeyPath] { + return [\Array.DifferentiableView.base] + } +} + extension Dictionary : KeyPathIterable { public typealias AllKeyPaths = [PartialKeyPath] public var allKeyPaths: [PartialKeyPath] { diff --git a/test/TensorFlowRuntime/key_path_iterable.swift b/test/TensorFlowRuntime/key_path_iterable.swift index 91192d39f5f64..852e2c8c4353c 100644 --- a/test/TensorFlowRuntime/key_path_iterable.swift +++ b/test/TensorFlowRuntime/key_path_iterable.swift @@ -6,7 +6,6 @@ // // `KeyPathIterable` tests. -import TensorFlow import StdlibUnittest var KeyPathIterableTests = TestSuite("KeyPathIterable") @@ -38,7 +37,10 @@ struct ComplexNested : KeyPathIterable, Equatable { } // TF-123: Test type with `@differentiable` function stored property. -struct TF_123 : KeyPathIterable { +struct Tensor: Differentiable { + var value: Scalar +} +struct TF_123 : KeyPathIterable { let activation1: @differentiable (Float) -> Float let activation2: @differentiable (Tensor) -> Tensor } @@ -179,4 +181,13 @@ KeyPathIterableTests.test("ComplexNested") { expectEqual(expected, x) } +// Verify that `Array.DifferentiableView.allKeyPaths` uses public property +// `base` instead of private property `_base`. +KeyPathIterableTests.test("Array.DifferentiableView") { + let view = [Float].DifferentiableView([1, 2, 3]) + let keyPaths1 = (0..<3).map { i in \Array.DifferentiableView.base[i] } + let keyPaths2 = view.allDifferentiableVariables.recursivelyAllWritableKeyPaths(to: Float.self) + expectEqual(keyPaths1, keyPaths2) +} + runAllTests()