Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,16 @@ PROTOCOL(KeyPathIterable)
PROTOCOL(TensorArrayProtocol)
PROTOCOL(TensorGroup)
PROTOCOL(VectorProtocol)
PROTOCOL(Differentiable)
PROTOCOL(EuclideanDifferentiable)
PROTOCOL(Expression)

PROTOCOL_(ObjectiveCBridgeable)
PROTOCOL_(DestructorSafeContainer)

PROTOCOL(StringInterpolationProtocol)

PROTOCOL_(Differentiable)

EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByArrayLiteral, "Array", false)
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByBooleanLiteral, "BooleanLiteralType", true)
EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByDictionaryLiteral, "Dictionary", false)
Expand All @@ -107,9 +109,9 @@ EXPRESSIBLE_BY_LITERAL_PROTOCOL(ExpressibleByUnicodeScalarLiteral, "UnicodeScala
EXPRESSIBLE_BY_LITERAL_PROTOCOL_(ExpressibleByColorLiteral, "_ColorLiteralType", true)
EXPRESSIBLE_BY_LITERAL_PROTOCOL_(ExpressibleByImageLiteral, "_ImageLiteralType", true)
EXPRESSIBLE_BY_LITERAL_PROTOCOL_(ExpressibleByFileReferenceLiteral, "_FileReferenceLiteralType", true)
// SWIFT_ENABLE_TENSORFLOW
// TODO(TF-735): Implement ExpressibleByQuoteLiteral.
// EXPRESSIBLE_BY_LITERAL_PROTOCOL_(ExpressibleByQuoteLiteral, "_QuoteLiteralType", true)
PROTOCOL(Expression)

BUILTIN_EXPRESSIBLE_BY_LITERAL_PROTOCOL_(ExpressibleByBuiltinBooleanLiteral)
BUILTIN_EXPRESSIBLE_BY_LITERAL_PROTOCOL_(ExpressibleByBuiltinExtendedGraphemeClusterLiteral)
Expand Down
7 changes: 2 additions & 5 deletions lib/AST/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1007,11 +1007,8 @@ static ValueDecl *getAutoDiffApplyAssociatedFunction(
// rethrows -> (R, (R.TangentVector) -> ...T.TangentVector)
unsigned numGenericParams = 1 + arity;
BuiltinGenericSignatureBuilder builder(Context, numGenericParams);
// Look up the Differentiable protocol.
SmallVector<ValueDecl *, 1> diffableProtoLookup;
Context.lookupInSwiftModule("Differentiable", diffableProtoLookup);
assert(diffableProtoLookup.size() == 1);
auto *diffableProto = cast<ProtocolDecl>(diffableProtoLookup.front());
// Get the `Differentiable` protocol.
auto *diffableProto = Context.getProtocol(KnownProtocolKind::Differentiable);
// Create type parameters and add conformance constraints.
auto fnResultGen = makeGenericParam(arity);
builder.addConformanceRequirement(fnResultGen, diffableProto);
Expand Down
4 changes: 2 additions & 2 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4216,7 +4216,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::Encodable:
case KnownProtocolKind::Decodable:
case KnownProtocolKind::StringInterpolationProtocol:
case KnownProtocolKind::Expression:
case KnownProtocolKind::Differentiable:
// SWIFT_ENABLE_TENSORFLOW
case KnownProtocolKind::AdditiveArithmetic:
case KnownProtocolKind::PointwiseMultiplicative:
Expand All @@ -4225,8 +4225,8 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::TensorArrayProtocol:
case KnownProtocolKind::TensorGroup:
case KnownProtocolKind::VectorProtocol:
case KnownProtocolKind::Differentiable:
case KnownProtocolKind::EuclideanDifferentiable:
case KnownProtocolKind::Expression:
return SpecialProtocol::None;
}

Expand Down
5 changes: 3 additions & 2 deletions stdlib/public/core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ set(SWIFTLIB_ESSENTIAL
ASCII.swift
Assert.swift
AssertCommon.swift
# SWIFT_ENABLE_TENSORFLOW
AutoDiff.swift
BidirectionalCollection.swift
Bitset.swift
Bool.swift
Expand Down Expand Up @@ -208,6 +206,9 @@ set(SWIFTLIB_SOURCES
CollectionDifference.swift
CollectionOfOne.swift
Diffing.swift
Differentiable.swift
# SWIFT_ENABLE_TENSORFLOW
DifferentiationSupport.swift
Mirror.swift
PlaygroundDisplay.swift
CommandLine.swift
Expand Down
66 changes: 66 additions & 0 deletions stdlib/public/core/Differentiable.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===--- Differentiable.swift ---------------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file defines the Differentiable protocol, used by the experimental
// differentiable programming project. Please see forum discussion for more
// information:
// https://forums.swift.org/t/differentiable-programming-mega-proposal/28547
//
//===----------------------------------------------------------------------===//

/// A type that mathematically represents a differentiable manifold whose
/// tangent spaces are finite-dimensional.
public protocol _Differentiable {
/// A type representing a differentiable value's derivatives.
///
/// Mathematically, this is equivalent to the tangent bundle of the
/// differentiable manifold represented by the differentiable type.
associatedtype TangentVector: _Differentiable & AdditiveArithmetic
where TangentVector.TangentVector == TangentVector

/// Moves `self` along the given direction. In Riemannian geometry, this is
/// equivalent to exponential map, which moves `self` on the geodesic surface
/// along the given tangent vector.
mutating func move(along direction: TangentVector)

// SWIFT_ENABLE_TENSORFLOW
/// A tangent vector such that `move(along: zeroTangentVector)` will not
/// modify `self`.
/// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases,
/// but types whose tangent vectors depend on instance properties of `self`
/// need to provide a different implementation. For example, the tangent
/// vector of an `Array` depends on the array's `count`.
@available(*, deprecated, message: """
`zeroTangentVector` derivation has not been implemented; do not use \
this property
""")
var zeroTangentVector: TangentVector { get }
}

public extension _Differentiable where TangentVector == Self {
mutating func move(along direction: TangentVector) {
self += direction
}
}

public extension _Differentiable {
// This is a temporary solution that allows us to add `zeroTangentVector`
// without implementing derived conformances. This property is marked
// unavailable because it will produce incorrect results when tangent vectors
// depend on instance properties of `self`.
// FIXME: Implement derived conformance and remove this default
// implementation.
var zeroTangentVector: TangentVector { .zero }
}

// SWIFT_ENABLE_TENSORFLOW
public typealias Differentiable = _Differentiable
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//===--- AutoDiff.swift ---------------------------------------*- swift -*-===//
//===--- DifferentiationSupport.swift -------------------------*- swift -*-===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
// Copyright (c) 2014 - 2019 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
Expand All @@ -12,7 +12,8 @@
//
// SWIFT_ENABLE_TENSORFLOW
//
// This file defines support for automatic differentiation.
// This file defines support for differentiable programming and deep learning
// APIs.
//
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -101,8 +102,9 @@ public extension VectorProtocol {
}
}

/* Note: These default-implemented operators will slow down type-checking
performance and break existing code.
/*
// Note: These default-implemented operators will slow down type-checking
// performance and break existing code.

public extension VectorProtocol {
static func + (lhs: Self, rhs: VectorSpaceScalar) -> Self {
Expand Down Expand Up @@ -149,50 +151,6 @@ public extension VectorProtocol where VectorSpaceScalar : SignedNumeric {
}
*/

/// A type that mathematically represents a differentiable manifold whose
/// tangent spaces are finite-dimensional.
public protocol Differentiable {
/// A type representing a differentiable value’s derivatives.
///
/// Mathematically, this is equivalent to the tangent bundle of the
/// differentiable manifold represented by the differentiable type.
associatedtype TangentVector: Differentiable & AdditiveArithmetic
where TangentVector.TangentVector == TangentVector

/// Moves `self` along the given direction. In Riemannian geometry, this is
/// equivalent to exponential map, which moves `self` on the geodesic surface
/// along the given tangent vector.
mutating func move(along direction: TangentVector)

/// A tangent vector such that `move(along: zeroTangentVector)` will not
/// modify `self`.
/// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases,
/// but types whose tangent vectors depend on instance properties of `self`
/// need to provide a different implementation. For example, the tangent
/// vector of an `Array` depends on the array’s `count`.
@available(*, deprecated, message: """
`zeroTangentVector` derivation has not been implemented; do not use \
this property
""")
var zeroTangentVector: TangentVector { get }
}

public extension Differentiable {
// This is a temporary solution that allows us to add `zeroTangentVector`
// without implementing derived conformances. This property is marked
// unavailable because it will produce incorrect results when tangent vectors
// depend on instance properties of `self`.
// FIXME: Implement derived conformance and remove this default
// implementation.
var zeroTangentVector: TangentVector { .zero }
}

public extension Differentiable where TangentVector == Self {
mutating func move(along direction: TangentVector) {
self += direction
}
}

/// A type that is differentiable in the Euclidean space.
/// The type may represent a vector space, or consist of a vector space and some
/// other non-differentiable component.
Expand Down Expand Up @@ -1077,8 +1035,9 @@ public extension Array where Element: Differentiable {
}

//===----------------------------------------------------------------------===//
// JVP Diagnostics
// JVP diagnostics
//===----------------------------------------------------------------------===//

@_silgen_name("_printJVPErrorAndExit")
public func _printJVPErrorAndExit() -> Never {
fatalError("""
Expand Down
7 changes: 4 additions & 3 deletions stdlib/public/core/GroupInfo.json
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,6 @@
"SIMDVector.swift",
"SIMDVectorTypes.swift"]}
],
"AutoDiff": [
"AutoDiff.swift",
],
"Optional": [
"Optional.swift"
],
Expand Down Expand Up @@ -231,5 +228,9 @@
],
"Result": [
"Result.swift"
],
"DifferentiableProgramming": [
"Differentiable.swift",
"DifferentiationSupport.swift"
]
}
14 changes: 7 additions & 7 deletions test/AutoDiff/derived_differentiable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct TestNoDerivative : EuclideanDifferentiable {
// CHECK-AST: var w: Float
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol
// CHECK-AST: internal struct TangentVector : _Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol
// CHECK-AST: internal typealias TangentVector = TestNoDerivative.TangentVector
// CHECK-AST: internal var differentiableVectorView: TestNoDerivative.TangentVector { get }

Expand All @@ -57,7 +57,7 @@ struct TestPointwiseMultiplicative : Differentiable {
// CHECK-AST: var w: PointwiseMultiplicativeDummy
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: PointwiseMultiplicativeDummy
// CHECK-AST: internal init(w: PointwiseMultiplicativeDummy, technicallyDifferentiable: PointwiseMultiplicativeDummy)
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, PointwiseMultiplicative
// CHECK-AST: internal struct TangentVector : _Differentiable, AdditiveArithmetic, PointwiseMultiplicative
// CHECK-AST: internal typealias TangentVector = TestPointwiseMultiplicative.TangentVector


Expand All @@ -70,14 +70,14 @@ struct TestKeyPathIterable : Differentiable, KeyPathIterable {
// CHECK-AST: var w: Float
// CHECK-AST: @noDerivative internal var technicallyDifferentiable: Float
// CHECK-AST: internal init(w: Float, technicallyDifferentiable: Float)
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol, KeyPathIterable
// CHECK-AST: internal struct TangentVector : _Differentiable, AdditiveArithmetic, ElementaryFunctions, VectorProtocol, KeyPathIterable
// CHECK-AST: internal typealias TangentVector = TestKeyPathIterable.TangentVector

struct GenericTanMember<T : Differentiable> : Differentiable, AdditiveArithmetic {
var x: T.TangentVector
}

// CHECK-AST-LABEL: internal struct GenericTanMember<T> : Differentiable, AdditiveArithmetic where T : Differentiable
// CHECK-AST-LABEL: internal struct GenericTanMember<T> : Differentiable, AdditiveArithmetic where T : _Differentiable
// CHECK-AST: internal var x: T.TangentVector
// CHECK-AST: internal init(x: T.TangentVector)
// CHECK-AST: internal typealias TangentVector = GenericTanMember<T>
Expand All @@ -92,7 +92,7 @@ public struct ConditionallyDifferentiable<T> {
extension ConditionallyDifferentiable : Differentiable where T : Differentiable {}

// CHECK-AST-LABEL: public struct ConditionallyDifferentiable<T> {
// CHECK-AST: @differentiable(wrt: self where T : Differentiable)
// CHECK-AST: @differentiable(wrt: self where T : _Differentiable)
// CHECK-AST: public var x: T
// CHECK-AST: internal init(x: T)
// CHECK-AST: }
Expand Down Expand Up @@ -121,7 +121,7 @@ final class AdditiveArithmeticClass<T : AdditiveArithmetic & Differentiable> : A
}
}

// CHECK-AST-LABEL: final internal class AdditiveArithmeticClass<T> : AdditiveArithmetic, Differentiable where T : AdditiveArithmetic, T : Differentiable {
// CHECK-AST-LABEL: final internal class AdditiveArithmeticClass<T> : AdditiveArithmetic, Differentiable where T : AdditiveArithmetic, T : _Differentiable {
// CHECK-AST: final internal var x: T, y: T
// CHECK-AST: internal struct TangentVector : Differentiable, AdditiveArithmetic
// CHECK-AST: internal struct TangentVector : _Differentiable, AdditiveArithmetic
// CHECK-AST: }
4 changes: 2 additions & 2 deletions test/AutoDiff/differentiable_attr_silgen.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public func foo_indir_ret<T: Differentiable>(_ x: Float, _ y: T) -> T {
return y
}

// CHECK-SIL-LABEL: sil [differentiable source 0 wrt 0, 1 vjp @AD__foo_indir_ret__vjp_src_0_wrt_0_1] [ossa] @foo_indir_ret : $@convention(thin) <T where T : Differentiable> (Float, @in_guaranteed T) -> @out T {
// CHECK-SIL-LABEL: sil [differentiable source 0 wrt 0, 1 vjp @AD__foo_indir_ret__vjp_src_0_wrt_0_1] [ossa] @foo_indir_ret : $@convention(thin) <T where T : _Differentiable> (Float, @in_guaranteed T) -> @out T {
// CHECK-SIL: bb0(%0 : $*T, %1 : $Float, %2 : $*T):

@_silgen_name("dfoo_indir_ret")
Expand Down Expand Up @@ -101,7 +101,7 @@ struct DiffComputedProp : Differentiable & AdditiveArithmetic {
// Check that `@differentiable` attribute is transferred from computed property
// storage declaration to getter accessor.

// CHECK-AST: struct DiffComputedProp : AdditiveArithmetic & Differentiable {
// CHECK-AST: struct DiffComputedProp : _Differentiable & AdditiveArithmetic {
// CHECK-AST-NEXT: var computedProp: Float { get }
// CHECK-AST: }

Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/differentiable_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -765,7 +765,7 @@ struct TF_521<T: FloatingPoint> {
self.imaginary = imaginary
}
}
// expected-error @+2 {{type 'TF_521<T>' does not conform to protocol 'Differentiable'}}
// expected-error @+2 {{type 'TF_521<T>' does not conform to protocol '_Differentiable'}}
// expected-note @+1 {{do you want to add protocol stubs}}
extension TF_521: Differentiable where T: Differentiable {
// expected-note @+1 {{possibly intended match 'TF_521<T>.TangentVector' does not conform to 'AdditiveArithmetic'}}
Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/differentiable_func_debuginfo.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
// Conclusion: mangling coverage is important.

// Minimal dummy compiler-known `Differentiable` protocol.
public protocol Differentiable {
public protocol _Differentiable {
associatedtype TangentVector
}

Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/differentiable_func_type_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ extension Vector: Differentiable where T: Differentiable {}
func inferredConformancesGeneric<T, U>(_: @differentiable (Vector<T>) -> Vector<U>) {}

func nondiffVectorFunc(x: Vector<Int>) -> Vector<Int> {}
// expected-error @+1 {{global function 'inferredConformancesGeneric' requires that 'Int' conform to 'Differentiable}}
// expected-error @+1 {{global function 'inferredConformancesGeneric' requires that 'Int' conform to '_Differentiable}}
inferredConformancesGeneric(nondiffVectorFunc)

func diffVectorFunc(x: Vector<Float>) -> Vector<Float> {}
Expand Down
4 changes: 2 additions & 2 deletions test/AutoDiff/differentiable_function_silgen.swift
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,5 @@ func appliesReabstraction(_ f: @escaping @differentiable (Float) -> Float) {
// CHECK-SILGEN: [[REABS_VJP:%.*]] = function_ref @$sS4fIegyd_Iegydo_S4fIegnr_Iegnro_TR : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
// CHECK-SILGEN: [[NEW_VJP:%.*]] = partial_apply [callee_guaranteed] [[REABS_VJP]]([[VJP_COPY]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)
// CHECK-SILGEN: [[NEW_DIFF_FUNC:%.*]] = differentiable_function [wrt 0] [order 1] [[NEW_ORIG]] : $@callee_guaranteed (@in_guaranteed Float) -> @out Float with {[[NEW_JVP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float), [[NEW_VJP]] : $@callee_guaranteed (@in_guaranteed Float) -> (@out Float, @owned @callee_guaranteed (@in_guaranteed Float) -> @out Float)}
// CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector
// CHECK-SILGEN: apply [[DIFF_API]]<Float, Float>({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : Differentiable, τ_0_1 : Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector
// CHECK-SILGEN: [[DIFF_API:%.*]] = function_ref @${{.*}}pullback{{.*}}at{{.*}} : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector
// CHECK-SILGEN: apply [[DIFF_API]]<Float, Float>({{.*}}, [[NEW_DIFF_FUNC]]) : $@convention(thin) <τ_0_0, τ_0_1 where τ_0_0 : _Differentiable, τ_0_1 : _Differentiable> (@in_guaranteed τ_0_0, @guaranteed @differentiable @callee_guaranteed (@in_guaranteed τ_0_0) -> @out τ_0_1) -> @owned @callee_guaranteed (@in_guaranteed τ_0_1.TangentVector) -> @out τ_0_0.TangentVector
Loading