From bc8d3c73f7ba06315d021afc5906bdc3d8df468b Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Wed, 12 Jun 2019 21:20:39 -0700 Subject: [PATCH 1/3] [AutoDiff] Change `Differentiable.moved(along:)` to `move(along:)`. Change `Differentiable.moved(along:)` to `mutating func move(along:)`. This is important for upcoming `Differentiable` class support. Update `Differentiable` derived conformances (logic and diagnostics). Update tests. --- include/swift/AST/DiagnosticsSema.def | 15 +- include/swift/AST/KnownIdentifiers.def | 2 +- lib/Sema/DerivedConformanceDifferentiable.cpp | 192 +++++++----------- lib/Sema/DerivedConformances.cpp | 4 +- stdlib/public/core/Array.swift | 13 +- stdlib/public/core/AutoDiff.swift | 38 ++-- .../public/core/FloatingPointTypes.swift.gyb | 4 + test/AutoDiff/anyderivative.swift | 13 +- test/AutoDiff/array.swift | 4 +- test/AutoDiff/autodiff_diagnostics.swift | 8 +- .../derived_differentiable_properties.swift | 4 +- .../differentiable_attr_type_checking.swift | 6 +- .../e2e_differentiable_property.swift | 19 +- test/AutoDiff/separate_cotangent_type.swift | 48 ----- test/Sema/struct_differentiable.swift | 55 +++-- 15 files changed, 172 insertions(+), 253 deletions(-) delete mode 100644 test/AutoDiff/separate_cotangent_type.swift diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 4c6d59e76e3e8..d439f8239b9c6 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2520,12 +2520,15 @@ ERROR(broken_differentiable_requirement,none, "Differentiable protocol is broken: unexpected requirement", ()) WARNING(differentiable_nondiff_type_implicit_noderivative_fixit,none, "stored property %0 has no derivative because it does not conform to " - "'Differentiable'; add '@noDerivative' to make it explicit", - (Identifier)) -WARNING(differentiable_constant_property_implicit_noderivative_fixit,none, - "'let' properties with a default value do not have a derivative; add " - "'@noDerivative' to make it explicit, or change it to 'var' to allow " - "derivatives", ()) + "'Differentiable'; add an explicit '@noDerivative' attribute" + "%select{|, or conform %1 to 'AdditiveArithmetic'}2", + (Identifier, Identifier, bool)) +WARNING(differentiable_let_property_implicit_noderivative_fixit,none, + "synthesis of the 'Differentiable.move(along:)' requirement for %1 " + "requires all stored properties to be mutable; use 'var' instead, or add " + "an explicit '@noDerivative' attribute" + "%select{|, or conform %1 to 'AdditiveArithmetic'}2", + (Identifier, Identifier, bool)) NOTE(codable_extraneous_codingkey_case_here,none, "CodingKey case %0 does not match any stored properties", (Identifier)) diff --git a/include/swift/AST/KnownIdentifiers.def b/include/swift/AST/KnownIdentifiers.def index 9fea573c29da0..4d740bd8f5abc 100644 --- a/include/swift/AST/KnownIdentifiers.def +++ b/include/swift/AST/KnownIdentifiers.def @@ -139,7 +139,7 @@ IDENTIFIER(scaled) IDENTIFIER(AllDifferentiableVariables) IDENTIFIER(TangentVector) IDENTIFIER(allDifferentiableVariables) -IDENTIFIER(moved) +IDENTIFIER(move) // Kinds of layout constraints IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout") diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 6d07c8dc83e45..fd1502ebd7519 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -57,7 +57,7 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, for (auto *vd : nominal->getStoredProperties()) { if (vd->getAttrs().hasAttribute()) continue; - if (vd->isLet() && vd->hasInitialValue()) + if (vd->isLet()) continue; if (!vd->hasInterfaceType()) C.getLazyResolver()->resolveDeclSignature(vd); @@ -74,9 +74,9 @@ getStoredPropertiesForDifferentiation(NominalTypeDecl *nominal, // Convert the given `ValueDecl` to a `StructDecl` if it is a `StructDecl` or a // `TypeDecl` with an underlying struct type. Otherwise, return `nullptr`. static StructDecl *convertToStructDecl(ValueDecl *v) { - if (auto structDecl = dyn_cast(v)) + if (auto *structDecl = dyn_cast(v)) return structDecl; - auto typeDecl = dyn_cast(v); + auto *typeDecl = dyn_cast(v); if (!typeDecl) return nullptr; return dyn_cast_or_null( @@ -113,7 +113,7 @@ static StructDecl *getAssociatedStructDecl(DeclContext *DC, Identifier id) { assert(conf && "Nominal must conform to `Differentiable`"); Type assocType = conf->getTypeWitnessByName(DC->getSelfTypeInContext(), id); assert(assocType && "`Differentiable` protocol associated type not found"); - auto structDecl = dyn_cast(assocType->getAnyNominal()); + auto *structDecl = dyn_cast(assocType->getAnyNominal()); assert(structDecl && "Associated type must be a struct type"); return structDecl; } @@ -141,7 +141,7 @@ bool DerivedConformance::canDeriveDifferentiable(NominalTypeDecl *nominal, auto isValidAssocTypeCandidate = [&](ValueDecl *v, bool checkAdditiveArithmetic = false) -> StructDecl * { // Valid candidate must be a struct or a typealias to a struct. - auto structDecl = convertToStructDecl(v); + auto *structDecl = convertToStructDecl(v); if (!structDecl) return nullptr; // Valid candidate must either: @@ -219,20 +219,6 @@ static void deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, auto *nominal = parentDC->getSelfNominalTypeDecl(); auto &C = nominal->getASTContext(); - // Create memberwise initializer for the returned nominal type: - // `Nominal.init(...)`. - auto retNominalInterfaceType = - funcDecl->getMethodInterfaceType()->getAs()->getResult(); - auto *retNominal = retNominalInterfaceType->getAnyNominal(); - auto retNominalType = funcDecl->mapTypeIntoContext(retNominalInterfaceType); - auto *retNominalTypeExpr = TypeExpr::createImplicit(retNominalType, C); - auto *memberwiseInitDecl = retNominal->getEffectiveMemberwiseInitializer(); - assert(memberwiseInitDecl && "Memberwise initializer must exist"); - auto *initDRE = - new (C) DeclRefExpr(memberwiseInitDecl, DeclNameLoc(), /*Implicit*/ true); - initDRE->setFunctionRefKind(FunctionRefKind::SingleApply); - auto *initExpr = new (C) ConstructorRefCallExpr(initDRE, retNominalTypeExpr); - // Get method protocol requirement. auto *diffProto = C.getProtocol(KnownProtocolKind::Differentiable); auto *methodReq = getProtocolRequirement(diffProto, methodName); @@ -245,36 +231,20 @@ static void deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, auto *paramDRE = new (C) DeclRefExpr(paramDecl, DeclNameLoc(), /*Implicit*/ true); - // Hash properties for differentiation into a set for fast lookup. - SmallVector diffProps; - getStoredPropertiesForDifferentiation(nominal, parentDC, diffProps); - SmallPtrSet diffPropsSet(diffProps.begin(), diffProps.end()); + SmallVector diffProperties; + getStoredPropertiesForDifferentiation(nominal, parentDC, diffProperties); // Create call expression applying a member method to a parameter member. // Format: `.method(.)`. - // Example: `x.moved(along: direction.x)`. - auto createMemberMethodCallExpr = [&](VarDecl *retNominalMember) -> Expr * { - // Find `Self` member corresponding to member from returned nominal type. - VarDecl *selfMember = nullptr; - for (auto candidate : nominal->getStoredProperties()) { - if (candidate->getName() == retNominalMember->getName()) { - selfMember = candidate; - break; - } - } - assert(selfMember && "Could not find corresponding self member"); - // If member is not for differentiation, create direct reference to member. - if (!diffPropsSet.count(selfMember)) - return new (C) MemberRefExpr(selfDRE, SourceLoc(), selfMember, - DeclNameLoc(), /*Implicit*/ true); - // Otherwise, construct member method call. - auto module = nominal->getModuleContext(); - auto selfMemberType = - parentDC->mapTypeIntoContext(selfMember->getValueInterfaceType()); - auto confRef = module->lookupConformance(selfMemberType, diffProto); + // Example: `x.move(along: direction.x)`. + auto createMemberMethodCallExpr = [&](VarDecl *member) -> Expr * { + auto *module = nominal->getModuleContext(); + auto memberType = + parentDC->mapTypeIntoContext(member->getValueInterfaceType()); + auto confRef = module->lookupConformance(memberType, diffProto); assert(confRef && "Member does not conform to `Differentiable`"); - // Get member type's method, e.g. `Member.moved(along:)`. + // Get member type's method, e.g. `Member.move(along:)`. // Use protocol requirement declaration for the method by default: this // will be dynamically dispatched. ValueDecl *memberMethodDecl = methodReq; @@ -288,57 +258,49 @@ static void deriveBodyDifferentiable_method(AbstractFunctionDecl *funcDecl, new (C) DeclRefExpr(memberMethodDecl, DeclNameLoc(), /*Implicit*/ true); memberMethodDRE->setFunctionRefKind(FunctionRefKind::SingleApply); - // Create reference to member method: `x.moved(along:)`. + // Create reference to member method: `x.move(along:)`. auto memberExpr = - new (C) MemberRefExpr(selfDRE, SourceLoc(), selfMember, DeclNameLoc(), + new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true); auto memberMethodExpr = new (C) DotSyntaxCallExpr(memberMethodDRE, SourceLoc(), memberExpr); // Create reference to parameter member: `direction.x`. VarDecl *paramMember = nullptr; - auto paramNominal = paramDecl->getType()->getAnyNominal(); + auto *paramNominal = paramDecl->getType()->getAnyNominal(); assert(paramNominal && "Parameter should have a nominal type"); // Find parameter member corresponding to returned nominal member. - for (auto candidate : paramNominal->getStoredProperties()) { - if (candidate->getName() == retNominalMember->getName()) { + for (auto *candidate : paramNominal->getStoredProperties()) { + if (candidate->getName() == member->getName()) { paramMember = candidate; break; } } assert(paramMember && "Could not find corresponding parameter member"); - auto paramMemberExpr = + auto *paramMemberExpr = new (C) MemberRefExpr(paramDRE, SourceLoc(), paramMember, DeclNameLoc(), /*Implicit*/ true); - // Create expression: `x.moved(along: direction.x)`. + // Create expression: `x.move(along: direction.x)`. return CallExpr::createImplicit(C, memberMethodExpr, {paramMemberExpr}, {methodParamLabel}); }; // Create array of member method call expressions. - llvm::SmallVector memberMethodCallExprs; + llvm::SmallVector memberMethodCallExprs; llvm::SmallVector memberNames; - for (auto *member : retNominal->getStoredProperties()) { - // Initialized `let` properties don't get an argument in memberwise - // initializers. - if (member->isLet() && member->getParentInitializer()) - continue; + for (auto *member : diffProperties) { memberMethodCallExprs.push_back(createMemberMethodCallExpr(member)); memberNames.push_back(member->getName()); } - // Call memberwise initializer with member method call expressions. - auto *callExpr = - CallExpr::createImplicit(C, initExpr, memberMethodCallExprs, memberNames); - ASTNode returnStmt = new (C) ReturnStmt(SourceLoc(), callExpr, true); - funcDecl->setBody( - BraceStmt::create(C, SourceLoc(), returnStmt, SourceLoc(), true)); + funcDecl->setBody(BraceStmt::create( + C, SourceLoc(), memberMethodCallExprs, SourceLoc(), true)); } -// Synthesize body for `moved(along:)`. -static void deriveBodyDifferentiable_moved(AbstractFunctionDecl *funcDecl, +// Synthesize body for `move(along:)`. +static void deriveBodyDifferentiable_move(AbstractFunctionDecl *funcDecl, void *) { auto &C = funcDecl->getASTContext(); - deriveBodyDifferentiable_method(funcDecl, C.Id_moved, + deriveBodyDifferentiable_method(funcDecl, C.Id_move, C.getIdentifier("along")); } @@ -347,10 +309,9 @@ static ValueDecl *deriveDifferentiable_method( DerivedConformance &derived, Identifier methodName, Identifier argumentName, Identifier parameterName, Type parameterType, Type returnType, AbstractFunctionDecl::BodySynthesizer bodySynthesizer) { - auto nominal = derived.Nominal; - auto &TC = derived.TC; + auto *nominal = derived.Nominal; auto &C = derived.TC.Context; - auto parentDC = derived.getConformanceContext(); + auto *parentDC = derived.getConformanceContext(); auto *param = new (C) ParamDecl(VarDecl::Specifier::Default, SourceLoc(), SourceLoc(), @@ -359,15 +320,16 @@ static ValueDecl *deriveDifferentiable_method( ParameterList *params = ParameterList::create(C, {param}); DeclName declName(C, methodName, params); - auto funcDecl = FuncDecl::create(C, SourceLoc(), StaticSpellingKind::None, - SourceLoc(), declName, SourceLoc(), - /*Throws*/ false, SourceLoc(), - /*GenericParams=*/nullptr, params, - TypeLoc::withoutLoc(returnType), parentDC); + auto *funcDecl = FuncDecl::create(C, SourceLoc(), StaticSpellingKind::None, + SourceLoc(), declName, SourceLoc(), + /*Throws*/ false, SourceLoc(), + /*GenericParams=*/nullptr, params, + TypeLoc::withoutLoc(returnType), parentDC); + funcDecl->setSelfAccessKind(SelfAccessKind::Mutating); funcDecl->setImplicit(); funcDecl->setBodySynthesizer(bodySynthesizer.Fn, bodySynthesizer.Context); - if (auto env = parentDC->getGenericEnvironmentOfContext()) + if (auto *env = parentDC->getGenericEnvironmentOfContext()) funcDecl->setGenericEnvironment(env); funcDecl->computeType(); funcDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); @@ -376,36 +338,21 @@ static ValueDecl *deriveDifferentiable_method( derived.addMembersToConformanceContext({funcDecl}); C.addSynthesizedDecl(funcDecl); - // Returned nominal type must define a memberwise initializer. - // Add memberwise initializer if necessary. - auto returnNominal = returnType->getAnyNominal(); - assert(returnNominal && "Return type must be a nominal type"); - if (!returnNominal->getEffectiveMemberwiseInitializer()) { - // The implicit memberwise constructor must be explicitly created so that - // it can called in `Differentiable` methods. Normally, the memberwise - // constructor is synthesized during SILGen, which is too late. - auto *initDecl = createImplicitConstructor( - TC, returnNominal, ImplicitConstructorKind::Memberwise); - returnNominal->addMember(initDecl); - C.addSynthesizedDecl(initDecl); - } - return funcDecl; } -// Synthesize the `moved(along:)` function declaration. -static ValueDecl *deriveDifferentiable_moved(DerivedConformance &derived) { +// Synthesize the `move(along:)` function declaration. +static ValueDecl *deriveDifferentiable_move(DerivedConformance &derived) { auto &C = derived.TC.Context; - auto parentDC = derived.getConformanceContext(); - auto selfInterfaceType = parentDC->getDeclaredInterfaceType(); + auto *parentDC = derived.getConformanceContext(); auto *tangentDecl = getAssociatedStructDecl(parentDC, C.Id_TangentVector); auto tangentType = tangentDecl->getDeclaredInterfaceType(); return deriveDifferentiable_method( - derived, C.Id_moved, C.getIdentifier("along"), - C.getIdentifier("direction"), tangentType, selfInterfaceType, - {deriveBodyDifferentiable_moved, nullptr}); + derived, C.Id_move, C.getIdentifier("along"), + C.getIdentifier("direction"), tangentType, C.TheEmptyTupleType, + {deriveBodyDifferentiable_move, nullptr}); } // Return the underlying `allDifferentiableVariables` of a VarDecl `x`. @@ -416,7 +363,7 @@ static ValueDecl *getUnderlyingAllDiffableVariables(DeclContext *DC, auto *module = DC->getParentModule(); auto &C = module->getASTContext(); auto *diffableProto = C.getProtocol(KnownProtocolKind::Differentiable); - auto allDiffableVarsReq = + auto *allDiffableVarsReq = getProtocolRequirement(diffableProto, C.Id_allDifferentiableVariables); if (!varDecl->hasInterfaceType()) C.getLazyResolver()->resolveDeclSignature(varDecl); @@ -476,7 +423,7 @@ static void derivedBody_allDifferentiableVariablesGetter( Expr *memberExpr = new (C) MemberRefExpr(selfDRE, SourceLoc(), member, DeclNameLoc(), /*Implicit*/ true); member->setInterfaceType(member->getValueInterfaceType()); - auto memberAllDiffableVarsDecl = + auto *memberAllDiffableVarsDecl = getUnderlyingAllDiffableVariables(parentDC, member); if (member != memberAllDiffableVarsDecl) { memberExpr = new (C) MemberRefExpr(memberExpr, SourceLoc(), @@ -513,7 +460,7 @@ static void derivedBody_allDifferentiableVariablesSetter( // Map `AllDifferentiableVariables` struct members to their names for // efficient lookup. llvm::DenseMap diffPropertyMap; - for (auto member : allDiffableVarsStruct->getStoredProperties()) + for (auto *member : allDiffableVarsStruct->getStoredProperties()) diffPropertyMap[member->getName()] = member; SmallVector assignExprs; @@ -524,7 +471,7 @@ static void derivedBody_allDifferentiableVariablesSetter( if (member->isLet()) continue; // Create lhs: either `self.x` or `self.x.allDifferentiableVariables`. - auto lhsAllDiffableVars = + auto *lhsAllDiffableVars = getUnderlyingAllDiffableVariables(parentDC, member); Expr *lhs; if (member == lhsAllDiffableVars) { @@ -558,7 +505,7 @@ deriveDifferentiable_allDifferentiableVariables(DerivedConformance &derived) { auto &C = TC.Context; // Get `AllDifferentiableVariables` struct. - auto allDiffableVarsStruct = + auto *allDiffableVarsStruct = getAssociatedStructDecl(parentDC, C.Id_AllDifferentiableVariables); auto returnInterfaceTy = allDiffableVarsStruct->getDeclaredInterfaceType(); @@ -599,8 +546,8 @@ static std::pair getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, Identifier id) { auto &TC = derived.TC; - auto parentDC = derived.getConformanceContext(); - auto nominal = derived.Nominal; + auto *parentDC = derived.getConformanceContext(); + auto *nominal = derived.Nominal; auto &C = nominal->getASTContext(); assert(id == C.Id_TangentVector || id == C.Id_AllDifferentiableVariables); @@ -694,7 +641,7 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, for (auto *member : diffProperties) { // Add this member's corresponding associated type to the parent's // associated struct. - auto newMember = new (C) VarDecl( + auto *newMember = new (C) VarDecl( member->isStatic(), member->getSpecifier(), member->isCaptureList(), /*NameLoc*/ SourceLoc(), member->getName(), structDecl); // NOTE: `newMember` is not marked as implicit here, because that affects @@ -776,7 +723,7 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, C.addSynthesizedDecl(initDecl); // After memberwise initializer is synthesized, mark members as implicit. - for (auto member : structDecl->getStoredProperties()) + for (auto *member : structDecl->getStoredProperties()) member->setImplicit(); derived.addMembersToConformanceContext({structDecl}); @@ -792,7 +739,7 @@ static void addAssociatedTypeAliasDecl(Identifier name, StructDecl *target, TypeChecker &TC) { auto &C = TC.Context; - auto nominal = sourceDC->getSelfNominalTypeDecl(); + auto *nominal = sourceDC->getSelfNominalTypeDecl(); assert(nominal && "Expected `DeclContext` to be a nominal type"); auto lookup = nominal->lookupDirect(name); assert(lookup.size() < 2 && @@ -806,7 +753,7 @@ static void addAssociatedTypeAliasDecl(Identifier name, return; } // Otherwise, create a new typealias. - auto aliasDecl = new (C) + auto *aliasDecl = new (C) TypeAliasDecl(SourceLoc(), SourceLoc(), name, SourceLoc(), {}, sourceDC); aliasDecl->setUnderlyingType(target->getDeclaredInterfaceType()); aliasDecl->setImplicit(); @@ -822,7 +769,7 @@ static void addAssociatedTypeAliasDecl(Identifier name, // Diagnose stored properties in the nominal that do not have an explicit // `@noDerivative` attribute, but either: // - Do not conform to `Differentiable`. -// - Are a `let` stored property with an initial value. +// - Are a `let` stored property. // Emit a warning and a fixit so that users will make the attribute explicit. static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC, NominalTypeDecl *nominal, @@ -841,9 +788,8 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC, bool conformsToDifferentiable = TC.conformsToProtocol(varType, diffableProto, nominal, None).hasValue(); - bool isConstantProperty = vd->isLet() && vd->hasInitialValue(); // If stored property should not be diagnosed, continue. - if (conformsToDifferentiable && !isConstantProperty) + if (conformsToDifferentiable && !vd->isLet()) continue; // Otherwise, add an implicit `@noDerivative` attribute. vd->getAttrs().add( @@ -851,18 +797,26 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC, auto loc = vd->getLoc().isValid() ? vd->getLoc() : DC->getAsDecl()->getLoc(); assert(loc.isValid() && "Expected valid source location"); - // Diagnose stored property with fixit. + // If nominal type can conform to `AdditiveArithmetic`, suggest conforming + // adding a conformance to `AdditiveArithmetic`. + // `Differentiable` protocol requirements all have default implementations + // when `Self` conforms to `AdditiveArithmetic`, so `Differentiable` + // derived conformances will no longer be necessary. + bool nominalCanDeriveAdditiveArithmetic = + DerivedConformance::canDeriveAdditiveArithmetic(nominal, DC); if (!conformsToDifferentiable) { TC.diagnose(loc, diag::differentiable_nondiff_type_implicit_noderivative_fixit, - vd->getName()) + vd->getName(), nominal->getName(), + nominalCanDeriveAdditiveArithmetic) .fixItInsert(vd->getAttributeInsertionLoc(/*forModifier*/ false), "@noDerivative "); continue; } - TC.diagnose( - loc, - diag::differentiable_constant_property_implicit_noderivative_fixit) + TC.diagnose(loc, + diag::differentiable_let_property_implicit_noderivative_fixit, + vd->getName(), nominal->getName(), + nominalCanDeriveAdditiveArithmetic) .fixItInsert(vd->getAttributeInsertionLoc(/*forModifier*/ false), "@noDerivative "); @@ -947,8 +901,8 @@ static Type deriveDifferentiable_AssociatedStruct(DerivedConformance &derived, Identifier id) { auto &TC = derived.TC; - auto parentDC = derived.getConformanceContext(); - auto nominal = derived.Nominal; + auto *parentDC = derived.getConformanceContext(); + auto *nominal = derived.Nominal; auto &C = nominal->getASTContext(); // Get all stored properties for differentation. @@ -1054,8 +1008,8 @@ ValueDecl *DerivedConformance::deriveDifferentiable(ValueDecl *requirement) { // Diagnose conformances in disallowed contexts. if (checkAndDiagnoseDisallowedContext(requirement)) return nullptr; - if (requirement->getBaseName() == TC.Context.Id_moved) - return deriveDifferentiable_moved(*this); + if (requirement->getBaseName() == TC.Context.Id_move) + return deriveDifferentiable_move(*this); if (requirement->getBaseName() == TC.Context.Id_allDifferentiableVariables) return deriveDifferentiable_allDifferentiableVariables(*this); TC.diagnose(requirement->getLoc(), diag::broken_differentiable_requirement); diff --git a/lib/Sema/DerivedConformances.cpp b/lib/Sema/DerivedConformances.cpp index 4ac0662b280f2..e226b12f694d5 100644 --- a/lib/Sema/DerivedConformances.cpp +++ b/lib/Sema/DerivedConformances.cpp @@ -299,9 +299,9 @@ ValueDecl *DerivedConformance::getDerivableRequirement(TypeChecker &tc, } // SWIFT_ENABLE_TENSORFLOW - // Differentiable.moved(along:) + // Differentiable.move(along:) if (name.isCompoundName() && - name.getBaseName() == ctx.Id_moved) { + name.getBaseName() == ctx.Id_move) { auto argumentNames = name.getArgumentNames(); if (argumentNames.size() == 1 && argumentNames[0] == ctx.getIdentifier("along")) { diff --git a/stdlib/public/core/Array.swift b/stdlib/public/core/Array.swift index c16ec3332f37d..05debca429f94 100644 --- a/stdlib/public/core/Array.swift +++ b/stdlib/public/core/Array.swift @@ -1973,13 +1973,14 @@ extension Array where Element : Differentiable { } } - public func moved(along direction: TangentVector) -> DifferentiableView { + public mutating func move(along direction: TangentVector) { precondition( base.count == direction.base.count, "cannot move Array.DifferentiableView with count \(base.count) along " + "direction with different count \(direction.base.count)") - return DifferentiableView( - zip(base, direction.base).map { $0.moved(along: $1) }) + for i in base.indices { + base[i].move(along: direction.base[i]) + } } } } @@ -2072,8 +2073,10 @@ extension Array : Differentiable where Element : Differentiable { } } - public func moved(along direction: TangentVector) -> Array { - return DifferentiableView(self).moved(along: direction).base + public mutating func move(along direction: TangentVector) { + var view = DifferentiableView(self) + view.move(along: direction) + self = view.base } } diff --git a/stdlib/public/core/AutoDiff.swift b/stdlib/public/core/AutoDiff.swift index a8ce774a9ec7e..d12acf95d8a2d 100644 --- a/stdlib/public/core/AutoDiff.swift +++ b/stdlib/public/core/AutoDiff.swift @@ -65,10 +65,9 @@ public protocol Differentiable { /// All differentiable variables of this value. var allDifferentiableVariables: AllDifferentiableVariables { get set } - /// Returns `self` moved along the value space towards the given tangent - /// vector. In Riemannian geometry (mathematics), this represents an - /// exponential map. - func moved(along direction: TangentVector) -> Self + /// Moves `self` along the value space towards the given tangent vector. In + /// Riemannian geometry (mathematics), this represents an exponential map. + mutating func move(along direction: TangentVector) @available(*, deprecated, message: "'CotangentVector' is now equal to 'TangentVector' and will be removed") @@ -82,13 +81,9 @@ public extension Differentiable where AllDifferentiableVariables == Self { } } -// FIXME: The `Self : AdditiveArithmetic` constraint should be implied by -// `TangentVector == Self`, but the type checker errors out when it does not -// exist. -public extension Differentiable - where TangentVector == Self, Self : AdditiveArithmetic { - func moved(along direction: TangentVector) -> Self { - return self + direction +public extension Differentiable where TangentVector == Self { + mutating func move(along direction: TangentVector) { + self += direction } } @@ -451,7 +446,7 @@ internal protocol _AnyDerivativeBox { // `Differentiable` requirements. var _allDifferentiableVariables: _AnyDerivativeBox { get } - func _moved(along direction: _AnyDerivativeBox) -> _AnyDerivativeBox + mutating func _move(along direction: _AnyDerivativeBox) /// The underlying base value, type-erased to `Any`. var _typeErasedBase: Any { get } @@ -555,18 +550,17 @@ internal struct _ConcreteDerivativeBox : _AnyDerivativeBox return _ConcreteDerivativeBox(_base.allDifferentiableVariables) } - func _moved(along direction: _AnyDerivativeBox) -> _AnyDerivativeBox { - if _isOpaqueZero() { - return direction - } + mutating func _move(along direction: _AnyDerivativeBox) { if direction._isOpaqueZero() { - return self + return } + // The case where `self._isOpaqueZero()` returns true is handled in + // `AnyDerivative.move(along:)`. guard let directionBase = direction._unboxed(to: T.TangentVector.self) else { _derivativeTypeMismatch(T.self, type(of: direction._typeErasedBase)) } - return _ConcreteDerivativeBox(_base.moved(along: directionBase)) + _base.move(along: directionBase) } } @@ -661,7 +655,11 @@ public struct AnyDerivative : Differentiable & AdditiveArithmetic { get { return AnyDerivative(_box: _box._allDifferentiableVariables) } // set { _box._allDifferentiableVariables = newValue._box } } - public func moved(along direction: TangentVector) -> AnyDerivative { - return AnyDerivative(_box: _box._moved(along: direction._box)) + public mutating func move(along direction: TangentVector) { + if _box._isOpaqueZero() { + _box = direction._box + return + } + _box._move(along: direction._box) } } diff --git a/stdlib/public/core/FloatingPointTypes.swift.gyb b/stdlib/public/core/FloatingPointTypes.swift.gyb index 0c471e36d3572..6ba7a71bd7841 100644 --- a/stdlib/public/core/FloatingPointTypes.swift.gyb +++ b/stdlib/public/core/FloatingPointTypes.swift.gyb @@ -1887,6 +1887,10 @@ extension ${Self} : VectorProtocol { extension ${Self} : Differentiable { public typealias TangentVector = ${Self} public typealias AllDifferentiableVariables = ${Self} + + public mutating func move(along direction: TangentVector) { + self += direction + } } //===----------------------------------------------------------------------===// diff --git a/test/AutoDiff/anyderivative.swift b/test/AutoDiff/anyderivative.swift index d0fde4db62b69..d5e8b9e749037 100644 --- a/test/AutoDiff/anyderivative.swift +++ b/test/AutoDiff/anyderivative.swift @@ -6,10 +6,19 @@ import StdlibUnittest var AnyDerivativeTests = TestSuite("AnyDerivative") struct Vector : Differentiable { - let x, y: Float + var x, y: Float } struct Generic : Differentiable { - let x: T + var x: T +} + +extension AnyDerivative { + // This exists only to faciliate testing. + func moved(along direction: TangentVector) -> Self { + var result = self + result.move(along: direction) + return result + } } AnyDerivativeTests.test("Vector") { diff --git a/test/AutoDiff/array.swift b/test/AutoDiff/array.swift index 8e01eaf9466ce..d4dea963db865 100644 --- a/test/AutoDiff/array.swift +++ b/test/AutoDiff/array.swift @@ -29,8 +29,8 @@ ArrayAutodiffTests.test("ArraySubscript") { ArrayAutodiffTests.test("ArrayConcat") { struct TwoArrays : Differentiable { - let a: [Float] - let b: [Float] + var a: [Float] + var b: [Float] } func sumFirstThreeConcatted(_ arrs: TwoArrays) -> Float { diff --git a/test/AutoDiff/autodiff_diagnostics.swift b/test/AutoDiff/autodiff_diagnostics.swift index 92e5d85caa1c7..34fb22cb1f6f5 100644 --- a/test/AutoDiff/autodiff_diagnostics.swift +++ b/test/AutoDiff/autodiff_diagnostics.swift @@ -25,10 +25,10 @@ _ = gradient(at: 0, in: one_to_one_0) // okay! //===----------------------------------------------------------------------===// struct S { - let p: Float + var p: Float } - extension S : Differentiable, VectorProtocol { + // Test custom `TangentVector` type with non-matching stored property name. struct TangentVector: Differentiable, VectorProtocol { var dp: Float } @@ -39,8 +39,8 @@ extension S : Differentiable, VectorProtocol { static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) } static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) } - func moved(along direction: TangentVector) -> S { - return S(p: p + direction.dp) + mutating func move(along direction: TangentVector) { + p.move(along: direction.dp) } } diff --git a/test/AutoDiff/derived_differentiable_properties.swift b/test/AutoDiff/derived_differentiable_properties.swift index cb3cfb220832e..e0450010c2363 100644 --- a/test/AutoDiff/derived_differentiable_properties.swift +++ b/test/AutoDiff/derived_differentiable_properties.swift @@ -77,12 +77,12 @@ struct GenericTanMember : Differentiable, AdditiveArithmetic // CHECK-AST: @_implements(Equatable, ==(_:_:)) internal static func __derived_struct_equals(_ a: GenericTanMember, _ b: GenericTanMember) -> Bool public struct ConditionallyDifferentiable { - public let x: T + public var x: T } extension ConditionallyDifferentiable : Differentiable where T : Differentiable {} // CHECK-AST-LABEL: public struct ConditionallyDifferentiable { // CHECK-AST: @differentiable(wrt: self where T : Differentiable) -// CHECK-AST: public let x: T +// CHECK-AST: public var x: T // CHECK-AST: internal init(x: T) // CHECK-AST: } diff --git a/test/AutoDiff/differentiable_attr_type_checking.swift b/test/AutoDiff/differentiable_attr_type_checking.swift index 486b76264f043..b8c5284d962ff 100644 --- a/test/AutoDiff/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/differentiable_attr_type_checking.swift @@ -548,11 +548,7 @@ struct ResultLabelTest { } struct Tensor : AdditiveArithmetic {} -extension Tensor : Differentiable where Scalar : Differentiable { - typealias TangentVector = Tensor - typealias AllDifferentiableVariables = Tensor - func moved(along direction: Tensor) -> Tensor { return self } -} +extension Tensor : Differentiable where Scalar : Differentiable {} @differentiable(where Scalar : Differentiable) func where2(x: Tensor) -> Tensor { return x diff --git a/test/AutoDiff/e2e_differentiable_property.swift b/test/AutoDiff/e2e_differentiable_property.swift index 6138ddcf48cf8..1f181beae3e6a 100644 --- a/test/AutoDiff/e2e_differentiable_property.swift +++ b/test/AutoDiff/e2e_differentiable_property.swift @@ -20,16 +20,15 @@ struct Space { /// `x` is a computed property with a custom vjp. var x: Float { @differentiable(vjp: vjpX) - get { - return storedX - } + get { storedX } + set { storedX = newValue } } func vjpX() -> (Float, (Float) -> TangentSpace) { return (x, { v in TangentSpace(x: v, y: 0) } ) } - private let storedX: Float + private var storedX: Float @differentiable var y: Float @@ -42,8 +41,9 @@ struct Space { extension Space : Differentiable { typealias TangentVector = TangentSpace - func moved(along: TangentSpace) -> Space { - return Space(x: x + along.x, y: y + along.y) + mutating func move(along direction: TangentSpace) { + x.move(along: direction.x) + y.move(along: direction.y) } } @@ -106,13 +106,14 @@ extension ProductSpaceOtherTangentTangentSpace : Differentiable { } struct ProductSpaceOtherTangent { - let x, y: Float + var x, y: Float } extension ProductSpaceOtherTangent : Differentiable { typealias TangentVector = ProductSpaceOtherTangentTangentSpace - func moved(along: ProductSpaceOtherTangentTangentSpace) -> ProductSpaceOtherTangent { - return ProductSpaceOtherTangent(x: x + along.x, y: y + along.y) + mutating func move(along direction: ProductSpaceOtherTangentTangentSpace) { + x.move(along: direction.x) + y.move(along: direction.y) } } diff --git a/test/AutoDiff/separate_cotangent_type.swift b/test/AutoDiff/separate_cotangent_type.swift deleted file mode 100644 index 68bb5f818c853..0000000000000 --- a/test/AutoDiff/separate_cotangent_type.swift +++ /dev/null @@ -1,48 +0,0 @@ -// RUN: %target-run-simple-swift -// REQUIRES: executable_test - -import StdlibUnittest -#if os(macOS) -import Darwin.C -#else -import Glibc -#endif - -var SeparateTangentTypeTests = TestSuite("SeparateTangentType") - -struct DifferentiableSubset : Differentiable { - @differentiable(wrt: self) - var w: Float - @differentiable(wrt: self) - var b: Float - @noDerivative var flag: Bool - - struct TangentVector : Differentiable, VectorProtocol { - typealias TangentVector = DifferentiableSubset.TangentVector - var w: Float - var b: Float - } - func moved(along v: TangentVector) -> DifferentiableSubset { - return DifferentiableSubset(w: w.moved(along: v.w), b: b.moved(along: v.b), flag: flag) - } -} - -SeparateTangentTypeTests.test("Trivial") { - let x = DifferentiableSubset(w: 0, b: 1, flag: false) - let pb = pullback(at: x) { x in x } - expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) -} - -SeparateTangentTypeTests.test("Initialization") { - let x = DifferentiableSubset(w: 0, b: 1, flag: false) - let pb = pullback(at: x) { x in DifferentiableSubset(w: 1, b: 2, flag: true) } - expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) -} - -SeparateTangentTypeTests.test("SomeArithmetics") { - let x = DifferentiableSubset(w: 0, b: 1, flag: false) - let pb = pullback(at: x) { x in DifferentiableSubset(w: x.w * x.w, b: x.b * x.b, flag: true) } - expectEqual(pb(DifferentiableSubset.TangentVector.zero), DifferentiableSubset.TangentVector.zero) -} - -runAllTests() diff --git a/test/Sema/struct_differentiable.swift b/test/Sema/struct_differentiable.swift index e10f02b2dccfe..5ab2573b0e11e 100644 --- a/test/Sema/struct_differentiable.swift +++ b/test/Sema/struct_differentiable.swift @@ -20,20 +20,28 @@ func testEmpty() { // Previously, this crashed due to duplicate memberwise initializer synthesis. struct EmptyAdditiveArithmetic : AdditiveArithmetic, Differentiable {} -// Test structs whose stored properties all have a default value. -struct AllLetStoredPropertiesHaveInitialValue : Differentiable { - // expected-warning @+1 {{'let' properties with a default value do not have a derivative; add '@noDerivative' to make it explicit, or change it to 'var' to allow derivatives}} {{3-3=@noDerivative }} - let x = Float(1) - // expected-warning @+1 {{'let' properties with a default value do not have a derivative; add '@noDerivative' to make it explicit, or change it to 'var' to allow derivatives}} {{3-3=@noDerivative }} - let y = Float(1) -} -struct AllVarStoredPropertiesHaveInitialValue : Differentiable { +// Test structs with `let` stored properties. +// Derived conformances fail because `mutating func move` requires all stored +// properties to be mutable. +struct ImmutableStoredProperties : Differentiable { + var okay: Float + + // expected-warning @+1 {{stored property 'nondiff' has no derivative because it does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic'}} {{3-3=@noDerivative }} + let nondiff: Int + + // expected-warning @+1 {{synthesis of the 'Differentiable.move(along:)' requirement for 'ImmutableStoredProperties' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute, or conform 'ImmutableStoredProperties' to 'AdditiveArithmetic'}} {{3-3=@noDerivative }} + let diff: Float +} +func testImmutableStoredProperties() { + _ = ImmutableStoredProperties.TangentVector(okay: 1) +} +struct MutableStoredPropertiesWithInitialValue : Differentiable { var x = Float(1) - var y = Float(1) + var y = Double(1) } // Test struct with both an empty constructor and memberwise initializer. struct AllMixedStoredPropertiesHaveInitialValue : Differentiable { - let x = Float(1) // expected-warning {{'let' properties with a default value do not have a derivative}} {{3-3=@noDerivative }} + let x = Float(1) // expected-warning {{synthesis of the 'Differentiable.move(along:)' requirement for 'AllMixedStoredPropertiesHaveInitialValue' requires all stored properties to be mutable; use 'var' instead, or add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} var y = Float(1) // Memberwise initializer should be `init(y:)` since `x` is immutable. static func testMemberwiseInitializer() { @@ -54,7 +62,7 @@ struct Simple : AdditiveArithmetic, Differentiable { func testSimple() { var simple = Simple(w: 1, b: 1) simple.allDifferentiableVariables = simple + simple - assert(simple.moved(along: simple) == simple + simple) + simple.move(along: simple) } // Test type with mixed members. @@ -65,7 +73,7 @@ struct Mixed : AdditiveArithmetic, Differentiable { func testMixed(_ simple: Simple) { var mixed = Mixed(simple: simple, float: 1) mixed.allDifferentiableVariables = Mixed(simple: simple, float: 2) - assert(mixed.moved(along: mixed) == mixed + mixed) + mixed.move(along: mixed) } // Test type with manual definition of vector space types to `Self`. @@ -87,7 +95,7 @@ struct GenericVectorSpacesEqualSelf : AdditiveArithmetic, Differentiable func testGenericVectorSpacesEqualSelf() { var genericSame = GenericVectorSpacesEqualSelf(w: 1, b: 1) genericSame.allDifferentiableVariables = genericSame + genericSame - assert(genericSame.moved(along: genericSame) == genericSame + genericSame) + genericSame.move(along: genericSame) } // Test nested type. @@ -100,8 +108,8 @@ func testNested( _ simple: Simple, _ mixed: Mixed, _ genericSame: GenericVectorSpacesEqualSelf ) { - let nested = Nested(simple: simple, mixed: mixed, generic: genericSame) - assert(nested.moved(along: nested) == nested + nested) + var nested = Nested(simple: simple, mixed: mixed, generic: genericSame) + nested.move(along: nested) _ = pullback(at: nested) { model in model.simple + model.simple @@ -137,15 +145,6 @@ func testAllMembersVectorProtocol() { assertConformsToVectorProtocol(AllMembersVectorProtocol.TangentVector.self) } -// Test type with immutable, differentiable stored property. -struct ImmutableStoredProperty : Differentiable { - var w: Float - let fixedBias: Float = .pi // expected-warning {{'let' properties with a default value do not have a derivative}} {{3-3=@noDerivative }} -} -func testImmutableStoredProperty() { - _ = ImmutableStoredProperty.TangentVector(w: 1) -} - // Test type whose properties are not all differentiable. struct DifferentiableSubset : Differentiable { var w: Float @@ -200,14 +199,14 @@ func testKeyPathIterable(x: TestKeyPathIterable) { // Test type with user-defined memberwise initializer. struct TF_25: Differentiable { - public let bar: Float + public var bar: Float public init(bar: Float) { self.bar = bar } } // Test user-defined memberwise initializer. struct TF_25_Generic: Differentiable { - public let bar: T + public var bar: T public init(bar: T) { self.bar = bar } @@ -318,12 +317,12 @@ struct StaticMembersShouldNotAffectAnything : AdditiveArithmetic, Differentiable struct ImplicitNoDerivative : Differentiable { var a: Float - var b: Bool // expected-warning {{stored property 'b' has no derivative because it does not conform to 'Differentiable'; add '@noDerivative' to make it explicit}} + var b: Bool // expected-warning {{stored property 'b' has no derivative because it does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} } struct ImplicitNoDerivativeWithSeparateTangent : Differentiable { var x: DifferentiableSubset - var b: Bool // expected-warning {{stored property 'b' has no derivative because it does not conform to 'Differentiable'; add '@noDerivative' to make it explicit}} {{3-3=@noDerivative }} + var b: Bool // expected-warning {{stored property 'b' has no derivative because it does not conform to 'Differentiable'; add an explicit '@noDerivative' attribute}} {{3-3=@noDerivative }} } // TF-265: Test invalid initializer (that uses a non-existent type). From 8fc3cd148eb4f66f17bafff21bf28e0fff391567 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 13 Jun 2019 15:07:13 -0700 Subject: [PATCH 2/3] Fix `@noDerivative` warning location. The `@noDerivative` fixit location was correct, but the location of the warning marker `^` was incorrect. Before: ``` let bad2: Int ^ @noDerivative ``` After: ``` let bad2: Int ^ @noDerivative ``` --- lib/Sema/DerivedConformanceDifferentiable.cpp | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index fd1502ebd7519..ce3093d952d52 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -794,8 +794,7 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC, // Otherwise, add an implicit `@noDerivative` attribute. vd->getAttrs().add( new (TC.Context) NoDerivativeAttr(/*Implicit*/ true)); - auto loc = - vd->getLoc().isValid() ? vd->getLoc() : DC->getAsDecl()->getLoc(); + auto loc = vd->getAttributeInsertionLoc(/*forModifier*/ false); assert(loc.isValid() && "Expected valid source location"); // If nominal type can conform to `AdditiveArithmetic`, suggest conforming // adding a conformance to `AdditiveArithmetic`. @@ -809,16 +808,14 @@ static void checkAndDiagnoseImplicitNoDerivative(TypeChecker &TC, diag::differentiable_nondiff_type_implicit_noderivative_fixit, vd->getName(), nominal->getName(), nominalCanDeriveAdditiveArithmetic) - .fixItInsert(vd->getAttributeInsertionLoc(/*forModifier*/ false), - "@noDerivative "); + .fixItInsert(loc, "@noDerivative "); continue; } TC.diagnose(loc, diag::differentiable_let_property_implicit_noderivative_fixit, vd->getName(), nominal->getName(), nominalCanDeriveAdditiveArithmetic) - .fixItInsert(vd->getAttributeInsertionLoc(/*forModifier*/ false), - "@noDerivative "); + .fixItInsert(loc, "@noDerivative "); } } From bc3ab07f691abbd3b7f2b3d98dc3cc2f60f7eba8 Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Thu, 13 Jun 2019 15:13:04 -0700 Subject: [PATCH 3/3] Update checkout for tensorflow-swift-apis. --- utils/update_checkout/update-checkout-config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/update_checkout/update-checkout-config.json b/utils/update_checkout/update-checkout-config.json index cb2ca84d4ba97..dc4e9277c8ac7 100644 --- a/utils/update_checkout/update-checkout-config.json +++ b/utils/update_checkout/update-checkout-config.json @@ -349,7 +349,7 @@ "clang-tools-extra": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a", "libcxx": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a", "tensorflow": "ebc41609e27dcf0998d8970e77a2e1f53e13ac86", - "tensorflow-swift-apis": "fb3135cc8113639d74070e019105a129dcb74e7b", + "tensorflow-swift-apis": "5d3ef57b501781f1bab3b4ca85e8b8fc91671c14", "indexstore-db": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a", "sourcekit-lsp": "swift-DEVELOPMENT-SNAPSHOT-2019-06-02-a" }