From 1c6f6103de2b0f5e500dc05a1c7d0a73b6c56d90 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Fri, 14 Jun 2019 17:18:58 -0700 Subject: [PATCH 1/2] Gardening. - Remove unused SWIFT_ENABLE_TENSORFLOW known protocols. - Readd AutoDiff test, previously accidentally removed. --- include/swift/AST/KnownProtocols.def | 2 - lib/IRGen/GenMeta.cpp | 3 -- .../Mandatory/Differentiation.cpp | 24 ++------- test/AutoDiff/separate_tangent_type.swift | 49 +++++++++++++++++++ 4 files changed, 54 insertions(+), 24 deletions(-) create mode 100644 test/AutoDiff/separate_tangent_type.swift diff --git a/include/swift/AST/KnownProtocols.def b/include/swift/AST/KnownProtocols.def index 5256cd16f0ee7..7d6fd183feb07 100644 --- a/include/swift/AST/KnownProtocols.def +++ b/include/swift/AST/KnownProtocols.def @@ -78,8 +78,6 @@ PROTOCOL(Encodable) PROTOCOL(Decodable) // SWIFT_ENABLE_TENSORFLOW PROTOCOL(AdditiveArithmetic) -PROTOCOL(Numeric) -PROTOCOL(FloatingPoint) PROTOCOL(KeyPathIterable) PROTOCOL(TensorArrayProtocol) PROTOCOL(TensorGroup) diff --git a/lib/IRGen/GenMeta.cpp b/lib/IRGen/GenMeta.cpp index 1296c6d544536..4194cc7da6e7e 100644 --- a/lib/IRGen/GenMeta.cpp +++ b/lib/IRGen/GenMeta.cpp @@ -4189,7 +4189,6 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) { case KnownProtocolKind::ExpressibleByColorLiteral: case KnownProtocolKind::ExpressibleByImageLiteral: case KnownProtocolKind::ExpressibleByFileReferenceLiteral: - // SWIFT_ENABLE_TENSORFLOW case KnownProtocolKind::ExpressibleByBuiltinBooleanLiteral: case KnownProtocolKind::ExpressibleByBuiltinExtendedGraphemeClusterLiteral: case KnownProtocolKind::ExpressibleByBuiltinFloatLiteral: @@ -4206,9 +4205,7 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) { case KnownProtocolKind::Decodable: case KnownProtocolKind::StringInterpolationProtocol: // SWIFT_ENABLE_TENSORFLOW - case KnownProtocolKind::FloatingPoint: case KnownProtocolKind::AdditiveArithmetic: - case KnownProtocolKind::Numeric: case KnownProtocolKind::KeyPathIterable: case KnownProtocolKind::TensorArrayProtocol: case KnownProtocolKind::TensorGroup: diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 4c0ceb764091d..5000136f40f92 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -865,18 +865,12 @@ class ADContext { /// Saved for deletion during cleanup. SmallVector generatedAssociatedFunctionReferences; + /// The AdditiveArithmetic protocol in the standard library. + ProtocolDecl *additiveArithmeticProtocol = + astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic); /// The VectorProtocol protocol in the standard library. ProtocolDecl *vectorProtocolProtocol = astCtx.getProtocol(KnownProtocolKind::VectorProtocol); - /// The Numeric protocol in the standard library. - ProtocolDecl *numericProtocol = - astCtx.getProtocol(KnownProtocolKind::Numeric); - /// The AdditiveArithmetic protocol in the standard library. - ProtocolDecl *additiveArithmeticProtocol = - astCtx.getProtocol(KnownProtocolKind::AdditiveArithmetic); - /// The FloatingPoint protocol in the stanard library. - ProtocolDecl *floatingPointProtocol = - astCtx.getProtocol(KnownProtocolKind::FloatingPoint); /// `AdditiveArithmetic.+` declaration. mutable FuncDecl *cachedPlusFn = nullptr; @@ -926,20 +920,12 @@ class ADContext { return generatedAssociatedFunctionReferences; } - ProtocolDecl *getVectorProtocolProtocol() const { - return vectorProtocolProtocol; - } - - ProtocolDecl *getNumericProtocol() const { - return numericProtocol; - } - ProtocolDecl *getAdditiveArithmeticProtocol() const { return additiveArithmeticProtocol; } - ProtocolDecl *getFloatingPointProtocol() const { - return floatingPointProtocol; + ProtocolDecl *getVectorProtocolProtocol() const { + return vectorProtocolProtocol; } FuncDecl *getPlusDecl() const { diff --git a/test/AutoDiff/separate_tangent_type.swift b/test/AutoDiff/separate_tangent_type.swift new file mode 100644 index 0000000000000..5676ec640ffaa --- /dev/null +++ b/test/AutoDiff/separate_tangent_type.swift @@ -0,0 +1,49 @@ +// RUN: %target-run-simple-swift +// REQUIRES: executable_test + +import StdlibUnittest +#if os(macOS) +import Darwin.C +#else +import Glibc +#endif + +var SeparateTangentTypeTests = TestSuite("SeparateTangentType") + +struct DifferentiableSubset : Differentiable { + @differentiable(wrt: self) + var w: Float + @differentiable(wrt: self) + var b: Float + @noDerivative var flag: Bool + + struct TangentVector : Differentiable, VectorProtocol { + typealias TangentVector = DifferentiableSubset.TangentVector + var w: Float + var b: Float + } + mutating func move(along v: TangentVector) { + w.move(along: v.w) + b.move(along: v.b) + } +} + +SeparateTangentTypeTests.test("Trivial") { + let x = DifferentiableSubset(w: 0, b: 1, flag: false) + let pb = pullback(at: x) { x in x } + expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) +} + +SeparateTangentTypeTests.test("Initialization") { + let x = DifferentiableSubset(w: 0, b: 1, flag: false) + let pb = pullback(at: x) { x in DifferentiableSubset(w: 1, b: 2, flag: true) } + expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) +} + +SeparateTangentTypeTests.test("SomeArithmetics") { + let x = DifferentiableSubset(w: 0, b: 1, flag: false) + let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) } + expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) +} + +runAllTests() From 59ed07f3e4d265317e62328d4d045c3978fa94cc Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Sat, 15 Jun 2019 05:12:33 +0000 Subject: [PATCH 2/2] Derived conformances gardening. - Use concise `llvm::erase_if` API. - Use consistent TODO comments. --- .../DerivedConformanceAdditiveArithmetic.cpp | 15 +++++----- lib/Sema/DerivedConformanceVectorProtocol.cpp | 29 ++++++++++--------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp b/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp index e5d38c28d892c..ff4ad890217c4 100644 --- a/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp +++ b/lib/Sema/DerivedConformanceAdditiveArithmetic.cpp @@ -48,20 +48,21 @@ static StringRef getMathOperatorName(MathOperator op) { } // Return the protocol requirement with the specified name. +// TODO: Move function to shared place for use with other derived conformances. static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) { auto lookup = proto->lookupDirect(name); - lookup.erase(std::remove_if(lookup.begin(), lookup.end(), - [](ValueDecl *v) { - return !isa( - v->getDeclContext()) || - !v->isProtocolRequirement(); - }), - lookup.end()); + // Erase declarations that are not protocol requirements. + // This is important for removing default implementations of the same name. + llvm::erase_if(lookup, [](ValueDecl *v) { + return !isa(v->getDeclContext()) || + !v->isProtocolRequirement(); + }); assert(lookup.size() == 1 && "Ambiguous protocol requirement"); return lookup.front(); } // Return true if given nominal type has a `let` stored with an initial value. +// TODO: Move function to shared place for use with other derived conformances. static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) { return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) { return v->isLet() && v->hasInitialValue(); diff --git a/lib/Sema/DerivedConformanceVectorProtocol.cpp b/lib/Sema/DerivedConformanceVectorProtocol.cpp index de953e6c52a9f..677809e5921bd 100644 --- a/lib/Sema/DerivedConformanceVectorProtocol.cpp +++ b/lib/Sema/DerivedConformanceVectorProtocol.cpp @@ -31,19 +31,27 @@ using namespace swift; // Return the protocol requirement with the specified name. +// TODO: Move function to shared place for use with other derived conformances. static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) { auto lookup = proto->lookupDirect(name); - lookup.erase(std::remove_if(lookup.begin(), lookup.end(), - [](ValueDecl *v) { - return !isa( - v->getDeclContext()) || - !v->isProtocolRequirement(); - }), - lookup.end()); + // Erase declarations that are not protocol requirements. + // This is important for removing default implementations of the same name. + llvm::erase_if(lookup, [](ValueDecl *v) { + return !isa(v->getDeclContext()) || + !v->isProtocolRequirement(); + }); assert(lookup.size() == 1 && "Ambiguous protocol requirement"); return lookup.front(); } +// Return true if given nominal type has a `let` stored with an initial value. +// TODO: Move function to shared place for use with other derived conformances. +static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) { + return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) { + return v->isLet() && v->hasInitialValue(); + }); +} + // Return the `VectorSpaceScalar` associated type for the given `ValueDecl` if // it conforms to `VectorProtocol` in the given context. Otherwise, return // `nullptr`. @@ -97,13 +105,6 @@ static Type deriveVectorProtocol_VectorSpaceScalar(NominalTypeDecl *nominal, return sameScalarType; } -// Return true if given nominal type has a `let` stored with an initial value. -static bool hasLetStoredPropertyWithInitialValue(NominalTypeDecl *nominal) { - return llvm::any_of(nominal->getStoredProperties(), [&](VarDecl *v) { - return v->isLet() && v->hasInitialValue(); - }); -} - bool DerivedConformance::canDeriveVectorProtocol(NominalTypeDecl *nominal, DeclContext *DC) { // Must not have any `let` stored properties with an initial value.