Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion stdlib/public/core/Array.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1875,7 +1875,7 @@ extension Array where Element : Differentiable {
/// The view of an array as the differentiable product manifold of `Element`
/// multiplied with itself `count` times.
@_fixed_layout
public struct DifferentiableView : Differentiable & KeyPathIterable {
public struct DifferentiableView : Differentiable {
private var _base: [Element]

/// The viewed array.
Expand Down
7 changes: 7 additions & 0 deletions stdlib/public/core/KeyPathIterable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ extension Array : KeyPathIterable {
}
}

extension Array.DifferentiableView : KeyPathIterable {
public typealias AllKeyPaths = [PartialKeyPath<Array.DifferentiableView>]
public var allKeyPaths: [PartialKeyPath<Array.DifferentiableView>] {
return [\Array.DifferentiableView.base]
}
}

extension Dictionary : KeyPathIterable {
public typealias AllKeyPaths = [PartialKeyPath<Dictionary>]
public var allKeyPaths: [PartialKeyPath<Dictionary>] {
Expand Down
15 changes: 13 additions & 2 deletions test/TensorFlowRuntime/key_path_iterable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
//
// `KeyPathIterable` tests.

import TensorFlow
import StdlibUnittest

var KeyPathIterableTests = TestSuite("KeyPathIterable")
Expand Down Expand Up @@ -38,7 +37,10 @@ struct ComplexNested : KeyPathIterable, Equatable {
}

// TF-123: Test type with `@differentiable` function stored property.
struct TF_123<Scalar : TensorFlowScalar & Differentiable & FloatingPoint> : KeyPathIterable {
struct Tensor<Scalar : Differentiable>: Differentiable {
var value: Scalar
}
struct TF_123<Scalar : Differentiable> : KeyPathIterable {
let activation1: @differentiable (Float) -> Float
let activation2: @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
}
Expand Down Expand Up @@ -179,4 +181,13 @@ KeyPathIterableTests.test("ComplexNested") {
expectEqual(expected, x)
}

// Verify that `Array.DifferentiableView.allKeyPaths` uses public property
// `base` instead of private property `_base`.
KeyPathIterableTests.test("Array.DifferentiableView") {
let view = [Float].DifferentiableView([1, 2, 3])
let keyPaths1 = (0..<3).map { i in \Array<Float>.DifferentiableView.base[i] }
let keyPaths2 = view.allDifferentiableVariables.recursivelyAllWritableKeyPaths(to: Float.self)
expectEqual(keyPaths1, keyPaths2)
}

runAllTests()