diff --git a/include/swift/AST/Attr.def b/include/swift/AST/Attr.def index b645a373ba216..83a91530cb7ff 100644 --- a/include/swift/AST/Attr.def +++ b/include/swift/AST/Attr.def @@ -395,8 +395,8 @@ SIMPLE_DECL_ATTR(TensorFlowGraph, TensorFlowGraph, OnFunc, 82) SIMPLE_DECL_ATTR(TFParameter, TFParameter, OnVar, 83) -SIMPLE_DECL_ATTR(_fieldwiseProductSpace, FieldwiseProductSpace, - OnTypeAlias | OnNominalType | UserInaccessible, 84) +SIMPLE_DECL_ATTR(_fieldwiseDifferentiable, FieldwiseDifferentiable, + OnNominalType | UserInaccessible, 84) SIMPLE_DECL_ATTR(noDerivative, NoDerivative, OnVar, 85) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index ae13fb9047896..9998d9a8cac89 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2742,6 +2742,11 @@ ERROR(noderivative_only_on_stored_properties_in_differentiable_structs,none, "@noDerivative is only allowed on stored properties in structure types " "that declare a conformance to 'Differentiable'", ()) +// @_fieldwiseDifferentiable attribute +ERROR(fieldwise_differentiable_only_on_differentiable_structs,none, + "@_fieldwiseDifferentiable is only allowed on structure types that " + "conform to 'Differentiable'", ()) + //------------------------------------------------------------------------------ // MARK: Type Check Expressions //------------------------------------------------------------------------------ @@ -3655,11 +3660,15 @@ ERROR(unreferenced_generic_parameter,none, // SWIFT_ENABLE_TENSORFLOW // Function differentiability ERROR(autodiff_attr_argument_not_differentiable,none, - "argument is not differentiable, but the enclosing function type is marked '@autodiff'; did you want to add '@nondiff' to this argument?", ()) + "argument is not differentiable, but the enclosing function type is " + "marked '@autodiff'; did you want to add '@nondiff' to this argument?", + ()) ERROR(autodiff_attr_result_not_differentiable,none, - "result is not differentiable, but the function type is marked '@autodiff'", ()) + "result is not differentiable, but the function type is marked " + "'@autodiff'", ()) ERROR(nondiff_attr_invalid_on_nondifferentiable_function,none, - "'nondiff' cannot be applied to arguments of a non-differentiable function", ()) + "'nondiff' cannot be applied to arguments of a non-differentiable " + "function", ()) // SIL ERROR(opened_non_protocol,none, diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 1bc71f9ec0d49..6bb1a0116a64a 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -99,16 +99,6 @@ static void createEntryArguments(SILFunction *f) { } } -/// Looks up a function in the current module. If it exists, returns it. -/// Otherwise, attempt to link it from imported modules. Returns null if such -/// function name does not exist. -static SILFunction *lookUpOrLinkFunction(StringRef name, SILModule &module) { - assert(!name.empty()); - if (auto *localFn = module.lookUpFunction(name)) - return localFn; - return module.findFunction(name, SILLinkage::PublicExternal); -} - /// Computes the correct linkage for functions generated by the AD pass /// associated with a function with linkage `originalLinkage`. static SILLinkage getAutoDiffFunctionLinkage(SILLinkage originalLinkage) { @@ -528,7 +518,7 @@ enum class StructExtractDifferentiationStrategy { // that is zero except along the direction of the corresponding field. // // Fields correspond by matching name. - FieldwiseProductSpace, + Fieldwise, // Differentiate the `struct_extract` by looking up the corresponding getter // and using its VJP. @@ -1291,6 +1281,7 @@ void DifferentiableActivityInfo::analyze(DominanceInfo *di, if (isVaried(cai->getSrc(), i)) recursivelySetVariedIfDifferentiable(cai->getDest(), i); } + // Handle `struct_extract`. else if (auto *sei = dyn_cast(&inst)) { if (isVaried(sei->getOperand(), i)) { auto hasNoDeriv = sei->getField()->getAttrs() @@ -2091,46 +2082,30 @@ class PrimalGenCloner final : public SILClonerWithScopes { } void visitStructExtractInst(StructExtractInst *sei) { - auto &astCtx = getContext().getASTContext(); - auto &structExtractDifferentiationStrategies = + auto &strategies = getDifferentiationTask()->getStructExtractDifferentiationStrategies(); - // Special handling logic only applies when the `struct_extract` is active. // If not, just do standard cloning. if (!activityInfo.isActive(sei, synthesis.indices)) { LLVM_DEBUG(getADDebugStream() << "Not active:\n" << *sei << '\n'); - structExtractDifferentiationStrategies.insert( + strategies.insert( {sei, StructExtractDifferentiationStrategy::Inactive}); SILClonerWithScopes::visitStructExtractInst(sei); return; } - // This instruction is active. Determine the appropriate differentiation // strategy, and use it. - - // Use the FieldwiseProductSpace strategy, if appropriate. auto *structDecl = sei->getStructDecl(); - auto cotangentDeclLookup = - structDecl->lookupDirect(astCtx.Id_CotangentVector); - if (cotangentDeclLookup.size() >= 1) { - assert(cotangentDeclLookup.size() == 1); - auto cotangentTypeDecl = cotangentDeclLookup.front(); - assert(isa(cotangentTypeDecl) || - isa(cotangentTypeDecl)); - if (cotangentTypeDecl->getAttrs() - .hasAttribute()) { - structExtractDifferentiationStrategies.insert( - {sei, StructExtractDifferentiationStrategy::FieldwiseProductSpace}); - SILClonerWithScopes::visitStructExtractInst(sei); - return; - } + if (structDecl->getAttrs().hasAttribute()) { + strategies.insert( + {sei, StructExtractDifferentiationStrategy::Fieldwise}); + SILClonerWithScopes::visitStructExtractInst(sei); + return; } - // The FieldwiseProductSpace strategy is not appropriate, so use the Getter // strategy. - structExtractDifferentiationStrategies.insert( + strategies.insert( {sei, StructExtractDifferentiationStrategy::Getter}); - // Find the corresponding getter and its VJP. auto *getterDecl = sei->getField()->getGetter(); assert(getterDecl); @@ -2142,42 +2117,29 @@ class PrimalGenCloner final : public SILClonerWithScopes { errorOccurred = true; return; } - auto getterDiffAttrs = getterFn->getDifferentiableAttrs(); - if (getterDiffAttrs.size() < 1) { - getContext().emitNondifferentiabilityError( - sei, synthesis.task, diag::autodiff_property_not_differentiable); - errorOccurred = true; - return; - } - auto *getterDiffAttr = getterDiffAttrs[0]; - if (!getterDiffAttr->hasVJP()) { + SILAutoDiffIndices indices(/*source*/ 0, /*parameters*/ {0}); + auto *task = getContext().lookUpDifferentiationTask(getterFn, indices); + if (!task) { getContext().emitNondifferentiabilityError( sei, synthesis.task, diag::autodiff_property_not_differentiable); errorOccurred = true; return; } - assert(getterDiffAttr->getIndices() == - SILAutoDiffIndices(/*source*/ 0, /*parameters*/{0})); - auto *getterVJP = lookUpOrLinkFunction(getterDiffAttr->getVJPName(), - getContext().getModule()); - // Reference and apply the VJP. auto loc = sei->getLoc(); - auto *getterVJPRef = getBuilder().createFunctionRef(loc, getterVJP); + auto *getterVJPRef = getBuilder().createFunctionRef(loc, task->getVJP()); auto *getterVJPApply = getBuilder().createApply( loc, getterVJPRef, /*substitutionMap*/ {}, /*args*/ {getMappedValue(sei->getOperand())}, /*isNonThrowing*/ false); SmallVector vjpDirectResults; extractAllElements(getterVJPApply, getBuilder(), vjpDirectResults); - ArrayRef originalDirectResults = - ArrayRef(vjpDirectResults).drop_back(1); - // Map original results. + auto originalDirectResults = + ArrayRef(vjpDirectResults).drop_back(1); SILValue originalDirectResult = joinElements(originalDirectResults, getBuilder(), getterVJPApply->getLoc()); mapValue(sei, originalDirectResult); - // Checkpoint the pullback. SILValue pullback = vjpDirectResults.back(); getPrimalInfo().addPullbackDecl(sei, pullback->getType().getASTType()); @@ -3079,60 +3041,41 @@ class AdjointEmitter final : public SILInstructionVisitor { auto loc = remapLocation(sei->getLoc()); auto &differentiationStrategies = getDifferentiationTask()->getStructExtractDifferentiationStrategies(); - auto differentiationStrategyLookUp = differentiationStrategies.find(sei); - assert(differentiationStrategyLookUp != differentiationStrategies.end()); - auto differentiationStrategy = differentiationStrategyLookUp->second; - - if (differentiationStrategy == - StructExtractDifferentiationStrategy::Inactive) { + auto strategy = differentiationStrategies.lookup(sei); + switch (strategy) { + case StructExtractDifferentiationStrategy::Inactive: assert(!activityInfo.isActive(sei, synthesis.indices)); return; - } - - if (differentiationStrategy == - StructExtractDifferentiationStrategy::FieldwiseProductSpace) { + case StructExtractDifferentiationStrategy::Fieldwise: { // Compute adjoint as follows: // y = struct_extract , x // adj[x] = struct (0, ..., key': adj[y], ..., 0) // where `key'` is the field in the cotangent space corresponding to // `key`. - - // Find the decl of the cotangent space type. auto structTy = sei->getOperand()->getType().getASTType(); auto cotangentVectorTy = structTy->getAutoDiffAssociatedVectorSpace( AutoDiffAssociatedVectorSpaceKind::Cotangent, LookUpConformanceInModule(getModule().getSwiftModule())) - ->getType()->getCanonicalType(); - assert(!getModule() - .Types.getTypeLowering(cotangentVectorTy) - .isAddressOnly()); + ->getType()->getCanonicalType(); + assert(!getModule().Types.getTypeLowering(cotangentVectorTy) + .isAddressOnly()); auto cotangentVectorSILTy = SILType::getPrimitiveObjectType(cotangentVectorTy); auto *cotangentVectorDecl = cotangentVectorTy->getStructOrBoundGenericStruct(); assert(cotangentVectorDecl); - // Find the corresponding field in the cotangent space. VarDecl *correspondingField = nullptr; + // If the cotangent space is the original sapce, then it's the same field. if (cotangentVectorDecl == sei->getStructDecl()) correspondingField = sei->getField(); + // Otherwise we just look it up by name. else { auto correspondingFieldLookup = cotangentVectorDecl->lookupDirect(sei->getField()->getName()); assert(correspondingFieldLookup.size() == 1); - assert(isa(correspondingFieldLookup[0])); - correspondingField = cast(correspondingFieldLookup[0]); + correspondingField = cast(correspondingFieldLookup.front()); } - assert(correspondingField); - -#ifndef NDEBUG - unsigned numMatchingStoredProperties = 0; - for (auto *storedProperty : cotangentVectorDecl->getStoredProperties()) - if (storedProperty == correspondingField) - numMatchingStoredProperties += 1; - assert(numMatchingStoredProperties == 1); -#endif - // Compute adjoint. auto av = getAdjointValue(sei); switch (av.getKind()) { @@ -3148,44 +3091,41 @@ class AdjointEmitter final : public SILInstructionVisitor { eltVals.push_back(av); else eltVals.push_back(AdjointValue::getZero( - SILType::getPrimitiveObjectType(field->getType() - ->getCanonicalType()))); + SILType::getPrimitiveObjectType( + field->getType()->getCanonicalType()))); } addAdjointValue(sei->getOperand(), AdjointValue::getAggregate(cotangentVectorSILTy, eltVals, allocator)); } } - return; } - - // The only remaining strategy is the getter strategy. - // Replace the `struct_extract` with a call to its pullback. - assert(differentiationStrategy == - StructExtractDifferentiationStrategy::Getter); - - // Get the pullback. - auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei); - assert(pullbackField); - SILValue pullback = builder.createStructExtract(loc, - primalValueAggregateInAdj, - pullbackField); - - // Construct the pullback arguments. - SmallVector args; - auto seed = getAdjointValue(sei); - assert(seed.getType().isObject()); - args.push_back(materializeAdjointDirect(seed, loc)); - - // Call the pullback. - auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(), - args, /*isNonThrowing*/ false); - assert(!pullbackCall->hasIndirectResults()); - - // Set adjoint for the `struct_extract` operand. - addAdjointValue(sei->getOperand(), - AdjointValue::getMaterialized(pullbackCall)); + case StructExtractDifferentiationStrategy::Getter: { + // Get the pullback. + auto *pullbackField = getPrimalInfo().lookUpPullbackDecl(sei); + assert(pullbackField); + SILValue pullback = builder.createStructExtract(loc, + primalValueAggregateInAdj, + pullbackField); + + // Construct the pullback arguments. + SmallVector args; + auto seed = getAdjointValue(sei); + assert(seed.getType().isObject()); + args.push_back(materializeAdjointDirect(seed, loc)); + + // Call the pullback. + auto *pullbackCall = builder.createApply(loc, pullback, SubstitutionMap(), + args, /*isNonThrowing*/ false); + assert(!pullbackCall->hasIndirectResults()); + + // Set adjoint for the `struct_extract` operand. + addAdjointValue(sei->getOperand(), + AdjointValue::getMaterialized(pullbackCall)); + break; + } + } } /// Handle `tuple` instruction. @@ -4236,25 +4176,22 @@ void DifferentiationTask::createVJP() { loc, adjointRef, vjpSubstMap, partialAdjointArgs, ParameterConvention::Direct_Guaranteed); - // === Clean up the stack allocations. === + // Clean up the stack allocations. for (auto alloc : reversed(stackAllocsToCleanUp)) builder.createDeallocStack(loc, alloc); - // === Return the direct results. === - // (Note that indirect results have already been filled in by the application - // of the primal). + // Return the direct results. Note that indirect results have already been + // filled in by the application of the primal. SmallVector directResults; auto originalDirectResults = ArrayRef(primalDirectResults) .take_back(originalConv.getNumDirectSILResults()); for (auto originalDirectResult : originalDirectResults) directResults.push_back(originalDirectResult); directResults.push_back(adjointPartialApply); - if (directResults.size() > 1) { - auto tupleRet = builder.createTuple(loc, directResults); - builder.createReturn(loc, tupleRet); - } else { - builder.createReturn(loc, directResults[0]); - } + if (directResults.size() > 1) + builder.createReturn(loc, builder.createTuple(loc, directResults)); + else + builder.createReturn(loc, directResults.front()); } //===----------------------------------------------------------------------===// diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index b1e23ddc5d02a..74d1102bb80f7 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -598,7 +598,8 @@ deriveDifferentiable_allDifferentiableVariables(DerivedConformance &derived) { // `AllDifferentiableVariables` struct for a nominal type, if it exists. // If not, synthesize the struct. static StructDecl * -getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, Identifier id) { +getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, + Identifier id) { auto &TC = derived.TC; auto parentDC = derived.getConformanceContext(); auto nominal = derived.Nominal; @@ -666,10 +667,10 @@ getOrSynthesizeSingleAssociatedStruct(DerivedConformance &derived, Identifier id auto *structDecl = new (C) StructDecl(SourceLoc(), id, SourceLoc(), /*Inherited*/ C.AllocateCopy(inherited), /*GenericParams*/ {}, parentDC); + structDecl->getAttrs().add( + new (C) FieldwiseDifferentiableAttr(/*implicit*/ true)); structDecl->setImplicit(); structDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); - structDecl->getAttrs().add(new (C) - FieldwiseProductSpaceAttr(/*Implicit*/ true)); // Add members to associated struct. for (auto *member : diffProperties) { @@ -754,8 +755,6 @@ static void addAssociatedTypeAliasDecl(Identifier name, TypeAliasDecl(SourceLoc(), SourceLoc(), name, SourceLoc(), {}, source); aliasDecl->setUnderlyingType(target->getDeclaredInterfaceType()); aliasDecl->setImplicit(); - aliasDecl->getAttrs().add(new (C) - FieldwiseProductSpaceAttr(/*Implicit*/ true)); if (auto env = source->getGenericEnvironmentOfContext()) aliasDecl->setGenericEnvironment(env); source->addMember(aliasDecl); @@ -855,6 +854,11 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived, auto nominal = derived.Nominal; auto &C = nominal->getASTContext(); + // Since associated types will be derived, we make this struct a fieldwise + // differentiable type. + nominal->getAttrs().add( + new (C) FieldwiseDifferentiableAttr(/*implicit*/ true)); + // Get all stored properties for differentation. SmallVector diffProperties; getStoredPropertiesForDifferentiation(nominal, diffProperties); @@ -892,8 +896,6 @@ deriveDifferentiable_AssociatedStruct(DerivedConformance &derived, SourceLoc(), {}, nominal); aliasDecl->setUnderlyingType(selfType); aliasDecl->setImplicit(); - aliasDecl->getAttrs().add( - new (C) FieldwiseProductSpaceAttr(/*implicit*/ true)); nominal->addMember(aliasDecl); aliasDecl->copyFormalAccessFrom(nominal, /*sourceIsParentContext*/ true); aliasDecl->setValidationToChecked(); diff --git a/lib/Sema/DerivedConformanceParameterized.cpp b/lib/Sema/DerivedConformanceParameterized.cpp index 1ce0d37587dd8..c64af5609eabe 100644 --- a/lib/Sema/DerivedConformanceParameterized.cpp +++ b/lib/Sema/DerivedConformanceParameterized.cpp @@ -400,11 +400,6 @@ static Type deriveParameterized_Parameters(DerivedConformance &derived) { parent->addMember(aliasDecl); aliasDecl->copyFormalAccessFrom(parent, /*sourceIsParentContext*/ true); aliasDecl->setValidationToChecked(); - // Add `@_fieldwiseProductSpace` attribute to typealias declaration. - // This enables differentiation wrt member accesses of the `Parameterized` - // struct. - aliasDecl->getAttrs().add(new (C) - FieldwiseProductSpaceAttr(/*Implicit*/ true)); TC.validateDecl(aliasDecl); C.addSynthesizedDecl(aliasDecl); }; diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 04c19277e26f5..fb70362ca3e04 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -126,7 +126,7 @@ class AttributeEarlyChecker : public AttributeVisitor { IGNORED_ATTR(CompilerEvaluable) IGNORED_ATTR(TensorFlowGraph) IGNORED_ATTR(TFParameter) - IGNORED_ATTR(FieldwiseProductSpace) + IGNORED_ATTR(FieldwiseDifferentiable) IGNORED_ATTR(NoDerivative) #undef IGNORED_ATTR @@ -886,7 +886,7 @@ class AttributeChecker : public AttributeVisitor { void visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr); void visitTensorFlowGraphAttr(TensorFlowGraphAttr *attr); void visitTFParameterAttr(TFParameterAttr *attr); - void visitFieldwiseProductSpaceAttr(FieldwiseProductSpaceAttr *attr); + void visitFieldwiseDifferentiableAttr(FieldwiseDifferentiableAttr *attr); void visitNoDerivativeAttr(NoDerivativeAttr *attr); }; } // end anonymous namespace @@ -2744,20 +2744,22 @@ void AttributeChecker::visitTFParameterAttr(TFParameterAttr *attr) { } } -void AttributeChecker::visitFieldwiseProductSpaceAttr( - FieldwiseProductSpaceAttr *attr) { - // If we make this attribute user-facing, we'll need to do various checks. - // - check that this attribute is on a - // Tangent/Cotangent/AllDifferentiableVariables type alias - // - check that we can access the raw fields of the - // Tangent/Cotangent/AllDifferentiableVariables structs from - // this module (e.g. the Tangent can't be a public resilient struct - // defined in a different module). - // - check that the stored properties of the - // Tangent/Cotangent/AllDifferentiableVariables match - // - // If we don't make this attribute user-facing, we can avoid doing checks - // here: the assertions in Differentiation.cpp suffice. +void AttributeChecker::visitFieldwiseDifferentiableAttr( + FieldwiseDifferentiableAttr *attr) { + auto *structDecl = dyn_cast(D); + if (!structDecl) { + diagnoseAndRemoveAttr(attr, + diag::fieldwise_differentiable_only_on_differentiable_structs); + return; + } + if (!TC.conformsToProtocol( + structDecl->swift::TypeDecl::getDeclaredInterfaceType(), + TC.Context.getProtocol(KnownProtocolKind::Differentiable), + structDecl, ConformanceCheckFlags::Used)) { + diagnoseAndRemoveAttr(attr, + diag::fieldwise_differentiable_only_on_differentiable_structs); + return; + } } void AttributeChecker::visitNoDerivativeAttr(NoDerivativeAttr *attr) { diff --git a/lib/Sema/TypeCheckDeclOverride.cpp b/lib/Sema/TypeCheckDeclOverride.cpp index 8e0851793d0af..bcd547b79ab36 100644 --- a/lib/Sema/TypeCheckDeclOverride.cpp +++ b/lib/Sema/TypeCheckDeclOverride.cpp @@ -1218,7 +1218,7 @@ namespace { UNINTERESTING_ATTR(CompilerEvaluable) UNINTERESTING_ATTR(TensorFlowGraph) UNINTERESTING_ATTR(TFParameter) - UNINTERESTING_ATTR(FieldwiseProductSpace) + UNINTERESTING_ATTR(FieldwiseDifferentiable) UNINTERESTING_ATTR(NoDerivative) // These can't appear on overridable declarations. diff --git a/test/AutoDiff/derived_differentiable_properties.swift b/test/AutoDiff/derived_differentiable_properties.swift index e23ef1107ddf6..ef2f72988caaf 100644 --- a/test/AutoDiff/derived_differentiable_properties.swift +++ b/test/AutoDiff/derived_differentiable_properties.swift @@ -6,15 +6,15 @@ public struct Foo : Differentiable { public var a: Float } -// CHECK-AST-LABEL: public struct Foo : Differentiable { +// CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct Foo : Differentiable { // CHECK-AST: @sil_stored @differentiable(wrt: (self)) // CHECK-AST: public var a: Float { get set } -// CHECK-AST: @_fieldwiseProductSpace struct AllDifferentiableVariables -// CHECK-AST: @_fieldwiseProductSpace typealias AllDifferentiableVariables = Foo.AllDifferentiableVariables -// CHECK-AST: @_fieldwiseProductSpace typealias TangentVector = Foo.AllDifferentiableVariables -// CHECK-AST: @_fieldwiseProductSpace typealias CotangentVector = Foo.AllDifferentiableVariables -// CHECK-AST: @_fieldwiseProductSpace typealias TangentVector = Foo.AllDifferentiableVariables -// CHECK-AST: @_fieldwiseProductSpace typealias CotangentVector = Foo.AllDifferentiableVariables +// CHECK-AST: @_fieldwiseDifferentiable struct AllDifferentiableVariables +// CHECK-AST: typealias AllDifferentiableVariables = Foo.AllDifferentiableVariables +// CHECK-AST: typealias TangentVector = Foo.AllDifferentiableVariables +// CHECK-AST: typealias CotangentVector = Foo.AllDifferentiableVariables +// CHECK-AST: typealias TangentVector = Foo.AllDifferentiableVariables +// CHECK-AST: typealias CotangentVector = Foo.AllDifferentiableVariables // CHECK-SILGEN-LABEL: // Foo.a.getter // CHECK-SILGEN: sil [transparent] [serialized] [differentiable source 0 wrt 0] @$s33derived_differentiable_properties3FooV1aSfvg : $@convention(method) (Foo) -> Float @@ -26,11 +26,10 @@ let _: @autodiff (AdditiveTangentIsSelf) -> Float = { x in x.a + x.a } -// CHECK-AST-LABEL: struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable { +// CHECK-AST-LABEL: @_fieldwiseDifferentiable struct AdditiveTangentIsSelf : AdditiveArithmetic, Differentiable { // CHECK-AST-NOT: @differentiable -// CHECK-AST: @_fieldwiseProductSpace typealias TangentVector = AdditiveTangentIsSelf -// CHECK-AST: @_fieldwiseProductSpace typealias CotangentVector = AdditiveTangentIsSelf -// FIXME: `typealias AllDifferentiableVariables` should have `@_fieldwiseProductSpace`. +// CHECK-AST: typealias TangentVector = AdditiveTangentIsSelf +// CHECK-AST: typealias CotangentVector = AdditiveTangentIsSelf // CHECK-AST: typealias AllDifferentiableVariables = AdditiveTangentIsSelf struct TestNoDerivative : Differentiable { @@ -38,27 +37,27 @@ struct TestNoDerivative : Differentiable { @noDerivative var technicallyDifferentiable: Float } -// CHECK-AST-LABEL: struct TestNoDerivative : Differentiable { +// CHECK-AST-LABEL: @_fieldwiseDifferentiable struct TestNoDerivative : Differentiable { // CHECK-AST: @sil_stored var w: Float { get set } // CHECK-AST: @sil_stored @noDerivative var technicallyDifferentiable: Float { get set } -// CHECK-AST: @_fieldwiseProductSpace struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorNumeric -// CHECK-AST: @_fieldwiseProductSpace typealias AllDifferentiableVariables = TestNoDerivative.AllDifferentiableVariables -// CHECK-AST: @_fieldwiseProductSpace typealias TangentVector = TestNoDerivative.AllDifferentiableVariables -// CHECK-AST: @_fieldwiseProductSpace typealias CotangentVector = TestNoDerivative.AllDifferentiableVariables -// CHECK-AST: @_fieldwiseProductSpace typealias TangentVector = TestNoDerivative.AllDifferentiableVariables -// CHECK-AST: @_fieldwiseProductSpace typealias CotangentVector = TestNoDerivative.AllDifferentiableVariables +// CHECK-AST: @_fieldwiseDifferentiable struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, VectorNumeric +// CHECK-AST: typealias AllDifferentiableVariables = TestNoDerivative.AllDifferentiableVariables +// CHECK-AST: typealias TangentVector = TestNoDerivative.AllDifferentiableVariables +// CHECK-AST: typealias CotangentVector = TestNoDerivative.AllDifferentiableVariables +// CHECK-AST: typealias TangentVector = TestNoDerivative.AllDifferentiableVariables +// CHECK-AST: typealias CotangentVector = TestNoDerivative.AllDifferentiableVariables struct TestKeyPathIterable : Differentiable, KeyPathIterable { var w: Float @noDerivative var technicallyDifferentiable: Float } -// CHECK-AST-LABEL: struct TestKeyPathIterable : Differentiable, KeyPathIterable { +// CHECK-AST-LABEL: @_fieldwiseDifferentiable struct TestKeyPathIterable : Differentiable, KeyPathIterable { // CHECK-AST: @sil_stored var w: Float { get set } // CHECK-AST: @sil_stored @noDerivative var technicallyDifferentiable: Float { get set } -// CHECK-AST: @_fieldwiseProductSpace struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorNumeric -// CHECK-AST: @_fieldwiseProductSpace typealias AllDifferentiableVariables = TestKeyPathIterable.AllDifferentiableVariables -// CHECK-AST: @_fieldwiseProductSpace typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables -// CHECK-AST: @_fieldwiseProductSpace typealias CotangentVector = TestKeyPathIterable.AllDifferentiableVariables -// CHECK-AST: @_fieldwiseProductSpace typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables -// CHECK-AST: @_fieldwiseProductSpace typealias CotangentVector = TestKeyPathIterable.AllDifferentiableVariables +// CHECK-AST: @_fieldwiseDifferentiable struct AllDifferentiableVariables : Differentiable, AdditiveArithmetic, KeyPathIterable, VectorNumeric +// CHECK-AST: typealias AllDifferentiableVariables = TestKeyPathIterable.AllDifferentiableVariables +// CHECK-AST: typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables +// CHECK-AST: typealias CotangentVector = TestKeyPathIterable.AllDifferentiableVariables +// CHECK-AST: typealias TangentVector = TestKeyPathIterable.AllDifferentiableVariables +// CHECK-AST: typealias CotangentVector = TestKeyPathIterable.AllDifferentiableVariables diff --git a/test/AutoDiff/e2e_differentiable_property.swift b/test/AutoDiff/e2e_differentiable_property.swift index a9e2119150be5..bd3a97aba4797 100644 --- a/test/AutoDiff/e2e_differentiable_property.swift +++ b/test/AutoDiff/e2e_differentiable_property.swift @@ -73,14 +73,13 @@ E2EDifferentiablePropertyTests.test("stored property") { expectEqual(expectedGrad, actualGrad) } +@_fieldwiseDifferentiable struct ProductSpaceSelfTangent : VectorNumeric { let x, y: Float } extension ProductSpaceSelfTangent : Differentiable { - @_fieldwiseProductSpace typealias TangentVector = ProductSpaceSelfTangent - @_fieldwiseProductSpace typealias CotangentVector = ProductSpaceSelfTangent } @@ -101,14 +100,13 @@ extension ProductSpaceOtherTangentTangentSpace : Differentiable { typealias CotangentVector = ProductSpaceOtherTangentTangentSpace } +@_fieldwiseDifferentiable struct ProductSpaceOtherTangent { let x, y: Float } extension ProductSpaceOtherTangent : Differentiable { - @_fieldwiseProductSpace typealias TangentVector = ProductSpaceOtherTangentTangentSpace - @_fieldwiseProductSpace typealias CotangentVector = ProductSpaceOtherTangentTangentSpace func moved(along: ProductSpaceOtherTangentTangentSpace) -> ProductSpaceOtherTangent { return ProductSpaceOtherTangent(x: x + along.x, y: y + along.y) diff --git a/test/AutoDiff/separate_cotangent_type.swift b/test/AutoDiff/separate_cotangent_type.swift index 49eb077638b7c..97b9fa7a867f4 100644 --- a/test/AutoDiff/separate_cotangent_type.swift +++ b/test/AutoDiff/separate_cotangent_type.swift @@ -10,6 +10,7 @@ import Glibc var SeparateCotangentTypeTests = TestSuite("SeparateCotangentType") +@_fieldwiseDifferentiable struct DifferentiableSubset : Differentiable { @differentiable(wrt: (self)) var w: Float @@ -17,11 +18,9 @@ struct DifferentiableSubset : Differentiable { var b: Float @noDerivative var flag: Bool - // @_fieldwiseProductSpace + @_fieldwiseDifferentiable struct TangentVector : Differentiable, VectorNumeric { - @_fieldwiseProductSpace typealias TangentVector = DifferentiableSubset.TangentVector - @_fieldwiseProductSpace typealias CotangentVector = DifferentiableSubset.CotangentVector var w: Float var b: Float @@ -29,11 +28,9 @@ struct DifferentiableSubset : Differentiable { return TangentVector(w: cotan.w, b: cotan.b) } } - // @_fieldwiseProductSpace + @_fieldwiseDifferentiable struct CotangentVector : Differentiable, VectorNumeric { - @_fieldwiseProductSpace typealias TangentVector = DifferentiableSubset.CotangentVector - @_fieldwiseProductSpace typealias CotangentVector = DifferentiableSubset.TangentVector var w: Float var b: Float diff --git a/test/AutoDiff/witness_table_silgen.swift b/test/AutoDiff/witness_table_silgen.swift index c963fda361ab5..d1ed64591633d 100644 --- a/test/AutoDiff/witness_table_silgen.swift +++ b/test/AutoDiff/witness_table_silgen.swift @@ -11,6 +11,7 @@ protocol Proto : Differentiable { func function3(_ x: Float, _ y: Float) -> Float } +@_fieldwiseDifferentiable struct S : Proto, VectorNumeric { static var zero: S { return S(p: 0) } typealias Scalar = Float @@ -18,9 +19,7 @@ struct S : Proto, VectorNumeric { static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) } static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) } - @_fieldwiseProductSpace typealias TangentVector = S - @_fieldwiseProductSpace typealias CotangentVector = S @differentiable(wrt: (self), vjp: vjpP)