From 26b1459cce9ac01b963e476b86b603e0362f12e3 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Tue, 11 Jun 2019 19:05:25 -0700 Subject: [PATCH 1/3] [stdlib] Rename `VectorProtocol.Scalar` to `VectorSpaceScalar`. A step towards machine learning optimizer revamp. The next step is to make `Tensor.VectorSpaceScalar = Float`. This makes `Tensor` scalar multiplication always work with `Float` instead of `Scalar`. Renaming `VectorProtocol.Scalar` is necessary. Otherwise, `Scalar` is ambiguous within `Tensor`'s type context: it may refer to either: - The witness type of `VectorProtocol.Scalar`, or - The `Scalar` generic parameter. --- include/swift/AST/KnownIdentifiers.def | 2 +- ...rmanceAdditiveArithmeticVectorProtocol.cpp | 52 ++++++++++--------- lib/Sema/DerivedConformanceDifferentiable.cpp | 15 +++--- lib/Sema/DerivedConformances.cpp | 4 +- stdlib/public/core/AutoDiff.swift | 10 ++-- .../public/core/FloatingPointTypes.swift.gyb | 2 +- 6 files changed, 45 insertions(+), 40 deletions(-) diff --git a/include/swift/AST/KnownIdentifiers.def b/include/swift/AST/KnownIdentifiers.def index 9b1bd92d702af..c287162b5cc15 100644 --- a/include/swift/AST/KnownIdentifiers.def +++ b/include/swift/AST/KnownIdentifiers.def @@ -133,7 +133,7 @@ IDENTIFIER_(tensorHandleCount) IDENTIFIER_(typeList) // AdditiveArithmetic, VectorProtocol IDENTIFIER(zero) -IDENTIFIER(Scalar) +IDENTIFIER(VectorSpaceScalar) // Differentiable IDENTIFIER(AllDifferentiableVariables) IDENTIFIER(TangentVector) diff --git a/lib/Sema/DerivedConformanceAdditiveArithmeticVectorProtocol.cpp b/lib/Sema/DerivedConformanceAdditiveArithmeticVectorProtocol.cpp index 36eefdc0c7cd3..fde9d09a55109 100644 --- a/lib/Sema/DerivedConformanceAdditiveArithmeticVectorProtocol.cpp +++ b/lib/Sema/DerivedConformanceAdditiveArithmeticVectorProtocol.cpp @@ -36,7 +36,7 @@ enum MathOperator { Add, // `-(Self, Self)`, `AdditiveArithmetic` requirement Subtract, - // `*(Scalar, Self)`, `VectorProtocol` requirement + // `*(VectorSpaceScalar, Self)`, `VectorProtocol` requirement ScalarMultiply }; @@ -76,52 +76,54 @@ static ValueDecl *getProtocolRequirement(ProtocolDecl *proto, Identifier name) { return lookup.front(); } -// Return the `Scalar` associated type for the given `ValueDecl` if it conforms -// to `VectorProtocol` in the given context. Otherwise, return `nullptr`. -static Type getVectorProtocolScalarAssocType(VarDecl *varDecl, DeclContext *DC) { +// Return the `VectorSpaceScalar` associated type for the given `ValueDecl` if +// it conforms to `VectorProtocol` in the given context. Otherwise, return +// `nullptr`. +static Type getVectorProtocolVectorSpaceScalarAssocType( + VarDecl *varDecl, DeclContext *DC) { auto &C = varDecl->getASTContext(); - auto *vectorNumericProto = C.getProtocol(KnownProtocolKind::VectorProtocol); + auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol); if (!varDecl->hasInterfaceType()) C.getLazyResolver()->resolveDeclSignature(varDecl); if (!varDecl->hasInterfaceType()) return nullptr; auto varType = DC->mapTypeIntoContext(varDecl->getValueInterfaceType()); - auto conf = TypeChecker::conformsToProtocol(varType, vectorNumericProto, DC, - None); + auto conf = TypeChecker::conformsToProtocol(varType, vectorProto, DC, None); if (!conf) return nullptr; - Type scalarType = conf->getTypeWitnessByName(varType, C.Id_Scalar); - assert(scalarType && "'Scalar' associated type not found"); - return scalarType; + conf->dump(); + return conf->getTypeWitnessByName(varType, C.Id_VectorSpaceScalar); } -// Return the `Scalar` associated type for the given nominal type in the given -// context, or `nullptr` if `Scalar` cannot be derived. -static Type deriveVectorProtocol_Scalar(NominalTypeDecl *nominal, - DeclContext *DC) { +// Return the `VectorSpaceScalar` associated type for the given nominal type in +// the given context, or `nullptr` if `VectorSpaceScalar` cannot be derived. +static Type deriveVectorProtocol_VectorSpaceScalar(NominalTypeDecl *nominal, + DeclContext *DC) { auto &C = DC->getASTContext(); // Nominal type must be a struct. (Zero stored properties is okay.) if (!isa(nominal)) return nullptr; // If all stored properties conform to `VectorProtocol` and have the same - // `Scalar` associated type, return that `Scalar` associated type. - // Otherwise, the `Scalar` type cannot be derived. + // `VectorSpaceScalar` associated type, return that `VectorSpaceScalar` + // associated type. Otherwise, the `VectorSpaceScalar` type cannot be derived. Type sameScalarType; for (auto member : nominal->getStoredProperties()) { if (!member->hasInterfaceType()) C.getLazyResolver()->resolveDeclSignature(member); if (!member->hasInterfaceType()) return nullptr; - auto scalarType = getVectorProtocolScalarAssocType(member, DC); + auto scalarType = getVectorProtocolVectorSpaceScalarAssocType(member, DC); // If stored property does not conform to `VectorProtocol`, return nullptr. if (!scalarType) return nullptr; - // If same `Scalar` type has not been set, set it for the first time. + // If same `VectorSpaceScalar` type has not been set, set it for the first + // time. if (!sameScalarType) { sameScalarType = scalarType; continue; } - // If stored property `Scalar` types do not match, return nullptr. + // If stored property `VectorSpaceScalar` types do not match, return + // nullptr. if (!scalarType->isEqual(sameScalarType)) return nullptr; } @@ -169,8 +171,8 @@ bool DerivedConformance::canDeriveVectorProtocol(NominalTypeDecl *nominal, // value information. if (hasLetStoredPropertyWithInitialValue(nominal)) return false; - // Must be able to derive `Scalar` associated type. - return bool(deriveVectorProtocol_Scalar(nominal, DC)); + // Must be able to derive `VectorSpaceScalar` associated type. + return bool(deriveVectorProtocol_VectorSpaceScalar(nominal, DC)); } // Synthesize body for the given math operator. @@ -303,7 +305,8 @@ static ValueDecl *deriveMathOperator(DerivedConformance &derived, return std::make_pair(selfInterfaceType, selfInterfaceType); case ScalarMultiply: return std::make_pair( - deriveVectorProtocol_Scalar(nominal, parentDC)->mapTypeOutOfContext(), + deriveVectorProtocol_VectorSpaceScalar(nominal, parentDC) + ->mapTypeOutOfContext(), selfInterfaceType); } }; @@ -480,8 +483,9 @@ Type DerivedConformance::deriveVectorProtocol(AssociatedTypeDecl *requirement) { // Diagnose conformances in disallowed contexts. if (checkAndDiagnoseDisallowedContext(requirement)) return nullptr; - if (requirement->getBaseName() == TC.Context.Id_Scalar) - return deriveVectorProtocol_Scalar(Nominal, getConformanceContext()); + if (requirement->getBaseName() == TC.Context.Id_VectorSpaceScalar) + return deriveVectorProtocol_VectorSpaceScalar( + Nominal, getConformanceContext()); TC.diagnose(requirement->getLoc(), diag::broken_vector_protocol_requirement); return nullptr; } diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 467f5efe0130d..0781408d2e6a4 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -625,8 +625,8 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, auto diffableType = TypeLoc::withoutLoc(diffableProto->getDeclaredType()); auto *addArithProto = C.getProtocol(KnownProtocolKind::AdditiveArithmetic); auto addArithType = TypeLoc::withoutLoc(addArithProto->getDeclaredType()); - auto *vecNumProto = C.getProtocol(KnownProtocolKind::VectorProtocol); - auto vecNumType = TypeLoc::withoutLoc(vecNumProto->getDeclaredType()); + auto *vectorProto = C.getProtocol(KnownProtocolKind::VectorProtocol); + auto vectorType = TypeLoc::withoutLoc(vectorProto->getDeclaredType()); auto *kpIterableProto = C.getProtocol(KnownProtocolKind::KeyPathIterable); auto kpIterableType = TypeLoc::withoutLoc(kpIterableProto->getDeclaredType()); @@ -651,17 +651,18 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, }); // Associated struct can derive `VectorProtocol` if the associated types of - // all members conform to `VectorProtocol` and share the same scalar type. + // all members conform to `VectorProtocol` and share the same + // `VectorSpaceScalar` type. Type sameScalarType; bool canDeriveVectorProtocol = canDeriveAdditiveArithmetic && !diffProperties.empty() && llvm::all_of(diffProperties, [&](VarDecl *vd) { auto conf = TC.conformsToProtocol(getAssociatedType(vd, parentDC, id), - vecNumProto, nominal, - None); + vectorProto, nominal, None); if (!conf) return false; - Type scalarType = conf->getTypeWitnessByName(vd->getType(), C.Id_Scalar); + auto scalarType = + conf->getTypeWitnessByName(vd->getType(), C.Id_VectorSpaceScalar); if (!sameScalarType) { sameScalarType = scalarType; return true; @@ -685,7 +686,7 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, // type, make the associated struct conform to `VectorProtocol` instead of // just `AdditiveArithmetic`. if (canDeriveVectorProtocol) - inherited.push_back(vecNumType); + inherited.push_back(vectorType); auto *structDecl = new (C) StructDecl(SourceLoc(), id, SourceLoc(), /*Inherited*/ C.AllocateCopy(inherited), diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index 778dd3b61509a..6ca2e619e0c14 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -369,8 +369,8 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc, return getRequirement(KnownProtocolKind::Differentiable); // SWIFT_ENABLE_TENSORFLOW - // VectorProtocol.Scalar - if (name.isSimpleName(ctx.Id_Scalar)) + // VectorProtocol.VectorSpaceScalar + if (name.isSimpleName(ctx.Id_VectorSpaceScalar)) return getRequirement(KnownProtocolKind::VectorProtocol); return nullptr; diff --git a/stdlib/public/core/AutoDiff.swift b/stdlib/public/core/AutoDiff.swift index d111b61e4211c..7e73d631e26d0 100644 --- a/stdlib/public/core/AutoDiff.swift +++ b/stdlib/public/core/AutoDiff.swift @@ -24,18 +24,18 @@ /// elements in this vector space and have either no shape or a static shape. public protocol VectorProtocol : AdditiveArithmetic { /// The type of scalars in the vector space. - associatedtype Scalar : AdditiveArithmetic + associatedtype VectorSpaceScalar : AdditiveArithmetic - static func * (lhs: Scalar, rhs: Self) -> Self - static func *= (lhs: inout Self, rhs: Scalar) + static func * (lhs: VectorSpaceScalar, rhs: Self) -> Self + static func *= (lhs: inout Self, rhs: VectorSpaceScalar) } public extension VectorProtocol { - static func * (lhs: Self, rhs: Scalar) -> Self { + static func * (lhs: Self, rhs: VectorSpaceScalar) -> Self { return rhs * lhs } - static func *= (lhs: inout Self, rhs: Scalar) { + static func *= (lhs: inout Self, rhs: VectorSpaceScalar) { lhs = rhs * lhs } } diff --git a/stdlib/public/core/FloatingPointTypes.swift.gyb b/stdlib/public/core/FloatingPointTypes.swift.gyb index 6b886154cff9a..5edbb7eb0e8a0 100644 --- a/stdlib/public/core/FloatingPointTypes.swift.gyb +++ b/stdlib/public/core/FloatingPointTypes.swift.gyb @@ -1877,7 +1877,7 @@ extension ${Self} : Strideable { //===----------------------------------------------------------------------===// extension ${Self} : VectorProtocol { - public typealias Scalar = ${Self} + public typealias VectorSpaceScalar = ${Self} } extension ${Self} : Differentiable { From 7f7d5abe60b35c6f59f146cb7322984652578188 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Tue, 11 Jun 2019 19:31:29 -0700 Subject: [PATCH 2/3] Update checkout for tensorflow-swift-apis. --- utils/update_checkout/update-checkout-config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/update_checkout/update-checkout-config.json b/utils/update_checkout/update-checkout-config.json index b2ebb4d21b460..2fa4992280d83 100644 --- a/utils/update_checkout/update-checkout-config.json +++ b/utils/update_checkout/update-checkout-config.json @@ -349,7 +349,7 @@ "clang-tools-extra": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a", "libcxx": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a", "tensorflow": "ebc41609e27dcf0998d8970e77a2e1f53e13ac86", - "tensorflow-swift-apis": "b9bcebc7dfd1497e324c24f4c0d8c0b212580d1f", + "tensorflow-swift-apis": "55ddd7ccedf46a3a480ded0affa067eadd91bae0", "indexstore-db": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a", "sourcekit-lsp": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a" } From e09e918ec6f50d6ca710af5c60f34c39cde3a2f4 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Tue, 11 Jun 2019 20:22:26 -0700 Subject: [PATCH 3/3] Update swift-apis checkout --- utils/update_checkout/update-checkout-config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/update_checkout/update-checkout-config.json b/utils/update_checkout/update-checkout-config.json index 2fa4992280d83..c37f12b8faf34 100644 --- a/utils/update_checkout/update-checkout-config.json +++ b/utils/update_checkout/update-checkout-config.json @@ -349,7 +349,7 @@ "clang-tools-extra": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a", "libcxx": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a", "tensorflow": "ebc41609e27dcf0998d8970e77a2e1f53e13ac86", - "tensorflow-swift-apis": "55ddd7ccedf46a3a480ded0affa067eadd91bae0", + "tensorflow-swift-apis": "eea90b71a7bf01ad031e11bb4a753c40e19e55d9", "indexstore-db": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a", "sourcekit-lsp": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a" }