From b150396ea59fadda5358115d53a939e00087b7a6 Mon Sep 17 00:00:00 2001 From: Marc Rasi Date: Tue, 10 Dec 2019 09:51:37 -0800 Subject: [PATCH] [AutoDiff] TF-1046: put gen sig in more places --- lib/SIL/SILDeclRef.cpp | 21 ++++-- lib/SILGen/SILGen.cpp | 3 + .../Mandatory/Differentiation.cpp | 16 ++++- lib/Sema/DerivedConformanceDifferentiable.cpp | 8 ++- lib/Sema/TypeCheckAttr.cpp | 20 +++--- lib/Sema/TypeCheckProtocol.cpp | 50 +++++++++----- test/AutoDiff/derivative_registration.swift | 5 +- .../loadable_by_address_cross_module.swift | 9 +-- test/AutoDiff/nonvaried_result.swift | 4 +- .../protocol_requirement_autodiff_diags.swift | 2 +- ...sil_differentiability_witness_silgen.swift | 65 ++++++++++++++++--- test/AutoDiff/silgen_thunking/main.swift | 6 +- test/AutoDiff/witness_table_irgen.sil | 4 +- test/AutoDiff/witness_table_sil.swift | 12 ++-- 14 files changed, 159 insertions(+), 66 deletions(-) diff --git a/lib/SIL/SILDeclRef.cpp b/lib/SIL/SILDeclRef.cpp index d0fd83c1ded36..78ccd8b9360fe 100644 --- a/lib/SIL/SILDeclRef.cpp +++ b/lib/SIL/SILDeclRef.cpp @@ -949,13 +949,22 @@ SILDeclRef SILDeclRef::getNextOverriddenVTableEntry() const { if (autoDiffDerivativeFunctionIdentifier) { auto overriddenAttrs = overridden.getDecl()->getAttrs().getAttributes(); - if (llvm::none_of(overriddenAttrs, [&](const DifferentiableAttr *attr) { - return attr->getParameterIndices() == - autoDiffDerivativeFunctionIdentifier->getParameterIndices(); - })) { - return SILDeclRef(); + for (const auto *attr : overriddenAttrs) { + if (attr->getParameterIndices() != + autoDiffDerivativeFunctionIdentifier->getParameterIndices()) + continue; + + // TODO(TF-1056): Do we need to check generic signature requirements? + + auto dfi = overridden.autoDiffDerivativeFunctionIdentifier; + overridden.autoDiffDerivativeFunctionIdentifier = + AutoDiffDerivativeFunctionIdentifier::get( + dfi->getKind(), dfi->getParameterIndices(), + attr->getDerivativeGenericSignature(), + getDecl()->getASTContext()); + return overridden; } - return overridden; + return SILDeclRef(); } // SWIFT_ENABLE_TENSORFLOW END return overridden; diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index ed2a8efa1448c..5a9c71638198d 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -774,6 +774,9 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, if (auto *vjpDecl = diffAttr->getVJPFunction()) vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition); auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); + assert((!AFD->getGenericSignature() || diffAttr->getDerivativeGenericSignature()) && + "type-checking should resolve derivative generic signatures for " + "all functions with generic signatures"); AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices, diffAttr->getDerivativeGenericSignature()); emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr); diff --git a/lib/SILOptimizer/Mandatory/Differentiation.cpp b/lib/SILOptimizer/Mandatory/Differentiation.cpp index 5562e8e2b5fbf..585bd0404748e 100644 --- a/lib/SILOptimizer/Mandatory/Differentiation.cpp +++ b/lib/SILOptimizer/Mandatory/Differentiation.cpp @@ -237,10 +237,20 @@ static bool diagnoseUnsupportedControlFlow(ADContext &context, /// derivative generic signature (containing requirements), and substitution /// map. Returns true if error is emitted. static bool diagnoseUnsatisfiedRequirements(ADContext &context, + CanSILFunctionType origFnTy, GenericSignature derivativeGenSig, SubstitutionMap substMap, DifferentiationInvoker invoker, SourceLoc loc) { + // If the original function is polymorphic and its generic signature is the + // same as the derivative generic signature, then the requirements are + // satisfied. This check is necessary because the subsequent logic does not + // correctly handle polymorphic original functions. + // TODO(TF-1055): Can be removed after we have a robust solution for TF-1055. + if (origFnTy->getInvocationGenericSignature() && derivativeGenSig && + origFnTy->getInvocationGenericSignature()->isEqual(derivativeGenSig)) + return false; + // If there are no derivative requirements, return false. if (!derivativeGenSig) return false; @@ -528,6 +538,7 @@ emitDerivativeFunctionReference( peerThroughFunctionConversions(original)) { auto loc = originalFRI->getLoc(); auto *originalFn = originalFRI->getReferencedFunctionOrNull(); + assert(originalFn); auto originalFnTy = originalFn->getLoweredFunctionType(); auto *desiredResultIndices = IndexSubset::get(context.getASTContext(), originalFnTy->getNumResults(), @@ -636,8 +647,9 @@ emitDerivativeFunctionReference( substMap = ai->getSubstitutionMap(); } if (diagnoseUnsatisfiedRequirements( - context, minimalWitness->getDerivativeGenericSignature(), substMap, - invoker, original.getLoc().getSourceLoc())) + context, original->getType().castTo(), + minimalWitness->getDerivativeGenericSignature(), substMap, invoker, + original.getLoc().getSourceLoc())) return None; DifferentiabilityWitnessFunctionKind witnessKind; switch (kind) { diff --git a/lib/Sema/DerivedConformanceDifferentiable.cpp b/lib/Sema/DerivedConformanceDifferentiable.cpp index 593bdcc9ecc2d..48f610f39360f 100644 --- a/lib/Sema/DerivedConformanceDifferentiable.cpp +++ b/lib/Sema/DerivedConformanceDifferentiable.cpp @@ -614,16 +614,18 @@ getOrSynthesizeTangentVectorStruct(DerivedConformance &derived, Identifier id) { if (member->getAttrs().hasAttribute() || getter->getAttrs().hasAttribute()) continue; - GenericSignature derivativeGenSig = GenericSignature(); + GenericSignature derivativeGenericSignature = + getter->getGenericSignature(); // If the parent declaration context is an extension, the nominal type may // conditionally conform to `Differentiable`. Use the extension generic // requirements in getter `@differentiable` attributes. if (auto *extDecl = dyn_cast(parentDC->getAsDecl())) - derivativeGenSig = extDecl->getGenericSignature(); + if (auto extGenSig = extDecl->getGenericSignature()) + derivativeGenericSignature = extGenSig; auto *diffableAttr = DifferentiableAttr::create( getter, /*implicit*/ true, SourceLoc(), SourceLoc(), /*linear*/ false, /*parameterIndices*/ IndexSubset::get(C, 1, {0}), - /*jvp*/ None, /*vjp*/ None, derivativeGenSig); + /*jvp*/ None, /*vjp*/ None, derivativeGenericSignature); member->getAttrs().add(diffableAttr); } } diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 1a4c098c0edee..e9757abe0f538 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -3429,13 +3429,16 @@ DifferentiableAttributeParameterIndicesRequest::evaluate( // Start type-checking the arguments of the @differentiable attribute. This // covers 'wrt:', 'jvp:', 'vjp:', and 'where', all of which are optional. + // Note: If there is a 'where' clause, then the generic signature from that + // overwrites this. + GenericSignature derivativeGenSig = original->getGenericSignature(); + // Handle 'where' clause, if it exists. // - Resolve attribute where clause requirements and store in the attribute // for serialization. // - Compute generic signature for autodiff derivative functions based on // the original function's generate signature and the attribute's where // clause requirements. - GenericSignature whereClauseGenSig = GenericSignature(); GenericEnvironment *whereClauseGenEnv = nullptr; if (auto *whereClause = attr->getWhereClause()) { // `@differentiable` attributes on protocol requirements do not support @@ -3507,13 +3510,14 @@ DifferentiableAttributeParameterIndicesRequest::evaluate( // Compute generic signature and environment for autodiff associated // functions. - whereClauseGenSig = std::move(builder).computeGenericSignature( + derivativeGenSig = std::move(builder).computeGenericSignature( attr->getLocation(), /*allowConcreteGenericParams=*/true); - whereClauseGenEnv = whereClauseGenSig->getGenericEnvironment(); - // Store the resolved derivative generic signature in the attribute. - attr->setDerivativeGenericSignature(whereClauseGenSig); + whereClauseGenEnv = derivativeGenSig->getGenericEnvironment(); } + // Store the resolved derivative generic signature in the attribute. + attr->setDerivativeGenericSignature(derivativeGenSig); + // Validate the 'wrt:' parameters. // Get the parsed wrt param indices, which have not yet been checked. @@ -3572,7 +3576,7 @@ DifferentiableAttributeParameterIndicesRequest::evaluate( originalFnTy->getAutoDiffDerivativeFunctionType( checkedWrtParamIndices, /*resultIndex*/ 0, AutoDiffDerivativeFunctionKind::JVP, lookupConformance, - whereClauseGenSig, /*makeSelfParamFirst*/ true); + derivativeGenSig, /*makeSelfParamFirst*/ true); auto isValidJVP = [&](AbstractFunctionDecl *jvpCandidate) -> bool { return checkFunctionSignature( @@ -3596,7 +3600,7 @@ DifferentiableAttributeParameterIndicesRequest::evaluate( originalFnTy->getAutoDiffDerivativeFunctionType( checkedWrtParamIndices, /*resultIndex*/ 0, AutoDiffDerivativeFunctionKind::VJP, lookupConformance, - whereClauseGenSig, /*makeSelfParamFirst*/ true); + derivativeGenSig, /*makeSelfParamFirst*/ true); auto isValidVJP = [&](AbstractFunctionDecl *vjpCandidate) -> bool { return checkFunctionSignature( @@ -3652,7 +3656,7 @@ DifferentiableAttributeParameterIndicesRequest::evaluate( // Register derivative function configuration. auto *resultIndices = IndexSubset::get(ctx, 1, {0}); original->addDerivativeFunctionConfiguration( - {checkedWrtParamIndices, resultIndices, whereClauseGenSig}); + {checkedWrtParamIndices, resultIndices, derivativeGenSig}); return checkedWrtParamIndices; } diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 40f0321f597e2..28782cd669a2b 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -556,29 +556,47 @@ swift::matchWitness( (void)reqDiffAttr->getParameterIndices(); } for (auto *reqDiffAttr : reqAttrs.getAttributes()) { - bool foundExactAttr = false; - bool foundSupersetAttr = false; + bool foundExactConfig = false; + Optional supersetConfig = None; for (auto witnessConfig : witnessAFD->getDerivativeFunctionConfigurations()) { - // We can't use witnesses that have generic signatures not satisfied by - // the requirement's generic signature. - if (witnessConfig.derivativeGenericSignature && - !witnessConfig.derivativeGenericSignature - ->requirementsNotSatisfiedBy( - reqDiffAttr->getDerivativeGenericSignature()) - .empty()) - continue; + // All the witness's derivative generic requirements must be satisfied + // by the requirement's derivative generic requirements OR by the + // conditional conformance requirements. + if (witnessConfig.derivativeGenericSignature) { + bool genericRequirementsSatisfied = true; + auto reqDiffGenSig = reqDiffAttr->getDerivativeGenericSignature(); + auto conformanceGenSig = dc->getGenericSignatureOfContext(); + for (const auto &req : + witnessConfig.derivativeGenericSignature->getRequirements()) { + auto substReq = req.subst(result.WitnessSubstitutions); + bool reqDiffGenSigSatisfies = + reqDiffGenSig && substReq && + reqDiffGenSig->isRequirementSatisfied(*substReq); + bool conformanceGenSigSatisfies = + conformanceGenSig && + conformanceGenSig->isRequirementSatisfied(req); + if (!reqDiffGenSigSatisfies && !conformanceGenSigSatisfies) { + genericRequirementsSatisfied = false; + break; + } + } + if (!genericRequirementsSatisfied) + continue; + } if (witnessConfig.parameterIndices == - reqDiffAttr->getParameterIndices()) - foundExactAttr = true; + reqDiffAttr->getParameterIndices()) { + foundExactConfig = true; + break; + } if (witnessConfig.parameterIndices->isSupersetOf( reqDiffAttr->getParameterIndices())) - foundSupersetAttr = true; + supersetConfig = witnessConfig; } - if (!foundExactAttr) { + if (!foundExactConfig) { bool success = false; - if (foundSupersetAttr) { + if (supersetConfig) { // If the witness has a "superset" derivative configuration, create an // implicit `@differentiable` attribute with the exact requirement // `@differentiable` attribute parameter indices. @@ -586,7 +604,7 @@ swift::matchWitness( witnessAFD, /*implicit*/ true, reqDiffAttr->AtLoc, reqDiffAttr->getRange(), reqDiffAttr->isLinear(), reqDiffAttr->getParameterIndices(), /*jvp*/ None, - /*vjp*/ None, reqDiffAttr->getDerivativeGenericSignature()); + /*vjp*/ None, supersetConfig->derivativeGenericSignature); auto insertion = ctx.DifferentiableAttrs.try_emplace( {witnessAFD, newAttr->getParameterIndices()}, newAttr); // Valid `@differentiable` attributes are uniqued by original function diff --git a/test/AutoDiff/derivative_registration.swift b/test/AutoDiff/derivative_registration.swift index 31409bfee5a50..19464df06b925 100644 --- a/test/AutoDiff/derivative_registration.swift +++ b/test/AutoDiff/derivative_registration.swift @@ -153,10 +153,7 @@ DerivativeRegistrationTests.testWithLeakChecking("DerivativeGenericSignature") { let generic = Generic() let x: Tracked = 3 let dx = gradient(at: x) { x in generic.instanceMethod(x) } - // NOTE(TF-1046): `gradient(at:in:)` calls the generated derivative for - // `Generic.instanceMethod` is used, not the registered derivative. This - // behavior is likely not expected by users; TF-1046 will fix this. - expectEqual(1, dx) + expectEqual(1000, dx) } runAllTests() diff --git a/test/AutoDiff/loadable_by_address_cross_module.swift b/test/AutoDiff/loadable_by_address_cross_module.swift index 7941309b598bb..75a7e6dfdc90f 100644 --- a/test/AutoDiff/loadable_by_address_cross_module.swift +++ b/test/AutoDiff/loadable_by_address_cross_module.swift @@ -14,14 +14,15 @@ // Next, check that differentiability_witness_functions in the client get // correctly modified by LBA. +// RUN: %target-swift-frontend -emit-sil -I%t %s // RUN: %target-swift-frontend -emit-sil -I%t %s | %FileCheck %s -check-prefix=CHECK-CLIENT-PRE-LBA // RUN: %target-swift-frontend -c -I%t %s -Xllvm -sil-print-after=loadable-address 2>&1 | %FileCheck %s -check-prefix=CHECK-CLIENT-POST-LBA -// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float -// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float +// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float +// CHECK-CLIENT-PRE-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, LargeLoadableType<τ_0_0>) -> Float -// CHECK-CLIENT-POST-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float) -// CHECK-CLIENT-POST-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed (Float) -> (Float, LargeLoadableType<τ_0_0>)) +// CHECK-CLIENT-POST-LBA: differentiability_witness_function [jvp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float) +// CHECK-CLIENT-POST-LBA: differentiability_witness_function [vjp] [parameters 0 1] [results 0] @${{.*}}LBAModifiedFunction{{.*}} : $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> Float as $@convention(method) <τ_0_0> (Float, @in_constant LargeLoadableType<τ_0_0>) -> (Float, @owned @callee_guaranteed (Float) -> (Float, LargeLoadableType<τ_0_0>)) // Finally, execute the test. diff --git a/test/AutoDiff/nonvaried_result.swift b/test/AutoDiff/nonvaried_result.swift index 18a5fbe0a6655..8d3805f0bbb5a 100644 --- a/test/AutoDiff/nonvaried_result.swift +++ b/test/AutoDiff/nonvaried_result.swift @@ -44,7 +44,7 @@ NonVariedResultTests.testWithLeakChecking("SingleBasicBlockGeneric") { expectEqual((0, 0, 0), gradient(at: 3, 4, 5) { simpleGeneric($0, $1, $2) }) } -// CHECK-LABEL: sil private [ossa] @AD__${{.*}}simpleGeneric{{.*}}pullback_src_0_wrt_0_1_2 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 == τ_0_0.TangentVector> (@in_guaranteed τ_0_0.TangentVector, @owned _AD__$s4nullyycfU0_13simpleGenericL_yxx_x23DifferentiationUnittest7TrackedVySfGts14DifferentiableRz13TangentVectorsAGPQzRszlF_bb0__PB__src_0_wrt_0_1_2<τ_0_0>) -> (@out τ_0_0.TangentVector, @out τ_0_0.TangentVector, @owned Tracked) { +// CHECK-LABEL: sil private [ossa] @AD__${{.*}}simpleGeneric{{.*}}pullback_src_0_wrt_0_1_2{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 == τ_0_0.TangentVector> (@in_guaranteed τ_0_0.TangentVector, @owned _AD__$s4nullyycfU0_13simpleGenericL_yxx_x23DifferentiationUnittest7TrackedVySfGts14DifferentiableRz13TangentVectorsAGPQzRszlF_bb0__PB__src_0_wrt_0_1_2<τ_0_0>) -> (@out τ_0_0.TangentVector, @out τ_0_0.TangentVector, @owned Tracked) { // CHECK: bb0([[DX:%.*]] : $*τ_0_0, [[DY:%.*]] : $*τ_0_0, [[SEED:%.*]] : $*τ_0_0, [[PB_STRUCT:%.*]] : [[PB_STRUCT_TYPE:.*]]): // CHECK: [[ZERO_FN_X:%.*]] = witness_method $τ_0_0, #AdditiveArithmetic.zero!getter.1 : (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 // CHECK: [[METATYPE_X:%.*]] = metatype $@thick τ_0_0.Type @@ -150,7 +150,7 @@ NonVariedResultTests.testWithLeakChecking("ComplexGeneric") { expectEqual(0, pullback(at: Tracked(3)) { complexGeneric(10, $0) }(1)) } -// CHECK-LABEL: sil private [ossa] @AD__${{.*}}complexGeneric{{.*}}pullback_src_0_wrt_1 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, @owned _AD__$s4nullyycfU4_14complexGenericL_yxx_xts14DifferentiableRzlF_bb9__PB__src_0_wrt_1<τ_0_0>) -> @out τ_0_0.TangentVector { +// CHECK-LABEL: sil private [ossa] @AD__${{.*}}complexGeneric{{.*}}pullback_src_0_wrt_1{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0.TangentVector, @owned _AD__$s4nullyycfU4_14complexGenericL_yxx_xts14DifferentiableRzlF_bb9__PB__src_0_wrt_1<τ_0_0>) -> @out τ_0_0.TangentVector { // CHECK: bb0([[DY:%.*]] : $*τ_0_0.TangentVector, [[SEED:%.*]] : $*τ_0_0.TangentVector, [[PB_STRUCT:%.*]] : @owned [[PB_STRUCT_TYPE:.*]]): // CHECK: destroy_value [[PB_STRUCT]] : [[PB_STRUCT_TYPE]] // CHECK: [[ZERO_FN:%.*]] = witness_method $τ_0_0.TangentVector, #AdditiveArithmetic.zero!getter.1 : (Self.Type) -> () -> Self : $@convention(witness_method: AdditiveArithmetic) <τ_0_0 where τ_0_0 : AdditiveArithmetic> (@thick τ_0_0.Type) -> @out τ_0_0 diff --git a/test/AutoDiff/protocol_requirement_autodiff_diags.swift b/test/AutoDiff/protocol_requirement_autodiff_diags.swift index ec1fc3d508add..4dd714a5bf18f 100644 --- a/test/AutoDiff/protocol_requirement_autodiff_diags.swift +++ b/test/AutoDiff/protocol_requirement_autodiff_diags.swift @@ -13,6 +13,6 @@ struct AttemptsToSatisfyRequirement: HasRequirement { // This does not satisfy the requirement because the differentiable attribute is more // constrained than the requirement's differentiable attribute. @differentiable(where T: P) - // expected-note @+1 {{candidate is missing attribute '@differentiable'}} + // expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))'}} func requirement(_ x: T, _ y: T) -> T { x } } diff --git a/test/AutoDiff/sil_differentiability_witness_silgen.swift b/test/AutoDiff/sil_differentiability_witness_silgen.swift index b8a917d31a6b3..7c34e914fe0bb 100644 --- a/test/AutoDiff/sil_differentiability_witness_silgen.swift +++ b/test/AutoDiff/sil_differentiability_witness_silgen.swift @@ -162,11 +162,6 @@ public func wrt_subset_vjp_wrt_x_y(_ tup: (Int, Int), _ x: Float, _ y: Float) -> // CHECK-NEXT: } // Test original function with `@differentiable` and `@derivative` attributes. -// NOTE(TF-1046): The `@differentiable` and `@derivative` attribute currently -// have different derivative generic signatures, causing two differentiability -// witnesses to be created. This behavior is unexpected for users; TF-1046 will -// resolve this issue so that only one differentiability witness will be -// created. protocol P1: Differentiable {} extension P1 { @@ -180,11 +175,63 @@ extension P1 { } } -// CHECK-LABEL: // differentiability witness for P1.foo() -// CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen2P1PAAE3fooSfyF : $@convention(method) (@in_guaranteed Self) -> Float { -// CHECK-NEXT: } - // CHECK-LABEL: // differentiability witness for P1.foo() // CHECK-NEXT: sil_differentiability_witness hidden [parameters 0] [results 0] <τ_0_0 where τ_0_0 : P1> @$s36sil_differentiability_witness_silgen2P1PAAE3fooSfyF : $@convention(method) (@in_guaranteed Self) -> Float { // CHECK-NEXT: vjp: @AD__$s36sil_differentiability_witness_silgen2P1PAAE3fooSfyF__vjp_src_0_wrt_0_36sil_differentiability_witness_silgen2P1Rzl : $@convention(method) <τ_0_0 where τ_0_0 : P1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed (Float) -> @out τ_0_0.TangentVector) // CHECK-NEXT: } + +// Test custom derivatives of functions with generic signatures and `@differentiable` attributes. + +@differentiable +@_silgen_name("genericWithDiffAttr") +public func genericWithDiffAttr(_ x: T) -> T { fatalError() } + +@derivative(of: genericWithDiffAttr) +public func vjpGenericWithDiffAttr(_ x: T) + -> (value: T, pullback: (T.TangentVector) -> T.TangentVector) +{ + fatalError() +} + +// CHECK-LABEL: // differentiability witness for genericWithDiffAttr +// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @genericWithDiffAttr : $@convention(thin) (@in_guaranteed T) -> @out T { +// CHECK-NEXT: vjp +// CHECK-NEXT: } + +// CHECK-NOT: // differentiability witness for genericWithDiffAttr + +@differentiable(where T: Differentiable) +@_silgen_name("genericWithConstrainedDifferentiable") +public func genericWithConstrainedDifferentiable(_ x: T) -> T { fatalError() } + +@derivative(of: genericWithConstrainedDifferentiable) +public func vjpGenericWithConstrainedDifferentiable(_ x: T) + -> (value: T, pullback: (T.TangentVector) -> T.TangentVector) +{ + fatalError() +} + +// CHECK-LABEL: // differentiability witness for genericWithConstrainedDifferentiable +// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @genericWithConstrainedDifferentiable : $@convention(thin) (@in_guaranteed T) -> @out T { +// CHECK-NEXT: vjp +// CHECK-NEXT: } + +// CHECK-NOT: // differentiability witness for genericWithConstrainedDifferentiable + +public extension Differentiable { + @differentiable + @_silgen_name("protocolExtensionWithDiffAttr") + func protocolExtensionWithDiffAttr() -> Self { self } + + @derivative(of: protocolExtensionWithDiffAttr) + func protocolExtensionWithDiffAttr() -> (value: Self, pullback: (TangentVector) -> TangentVector) { + fatalError("unimplemented") + } +} + +// CHECK-LABEL: // differentiability witness for protocolExtensionWithDiffAttr +// CHECK-NEXT: sil_differentiability_witness [serialized] [parameters 0] [results 0] <τ_0_0 where τ_0_0 : Differentiable> @protocolExtensionWithDiffAttr : $@convention(method) (@in_guaranteed Self) -> @out Self { +// CHECK-NEXT: vjp +// CHECK-NEXT: } + +// CHECK-NOT: // differentiability witness for protocolExtensionWithDiffAttr diff --git a/test/AutoDiff/silgen_thunking/main.swift b/test/AutoDiff/silgen_thunking/main.swift index 3236b8698a907..8aaa033456544 100644 --- a/test/AutoDiff/silgen_thunking/main.swift +++ b/test/AutoDiff/silgen_thunking/main.swift @@ -19,7 +19,7 @@ func vjpNoReabstraction(_ x: T) -> (T, (T.TangentVector) -> T return (x, { $0 }) } // Find the non-`[transparent]` SILGen thunk. -// CHECK-LABEL: sil hidden [thunk] [always_inline] [ossa] @AD__$s4main15noReabstractionyxxs14DifferentiableRzlF__vjp_src_0_wrt_0 : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector) +// CHECK-LABEL: sil hidden [thunk] [always_inline] [ossa] @AD__$s4main15noReabstractionyxxs14DifferentiableRzlF__vjp_src_0_wrt_0{{.*}} : $@convention(thin) <τ_0_0 where τ_0_0 : Differentiable> (@in_guaranteed τ_0_0) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> @out τ_0_0.TangentVector) var DerivativeSILGenThunkTests = TestSuite("DerivativeSILGenThunks") @@ -115,7 +115,7 @@ where Dummy: Differentiable & ExpressibleByIntegerLiteral { return (value, { v in (v, 2.0, 3.0) }) } -// CHECK-LABEL: sil hidden [always_inline] [ossa] @AD__$s4main21SelfReorderingGenericV20threeParameterMethodyACyxGqd___qd_0_ts14DifferentiableRd__sAFRd_0_s25ExpressibleByFloatLiteral13TangentVectorRpd__sAgHRpd_0_r0_lF__jvp_src_0_wrt_0_1_2 : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : ExpressibleByIntegerLiteral><τ_1_0, τ_1_1 where τ_1_0 : Differentiable, τ_1_1 : Differentiable, τ_1_0.TangentVector : ExpressibleByFloatLiteral, τ_1_1.TangentVector : ExpressibleByFloatLiteral> (@in_guaranteed τ_1_0, @in_guaranteed τ_1_1, @in_guaranteed SelfReorderingGeneric<τ_0_0>) -> (@out SelfReorderingGeneric<τ_0_0>, @owned @callee_guaranteed (@in_guaranteed τ_1_0.TangentVector, @in_guaranteed τ_1_1.TangentVector, @in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector) -> @out SelfReorderingGeneric<τ_0_0>.TangentVector) { +// CHECK-LABEL: sil hidden [always_inline] [ossa] @AD__$s4main21SelfReorderingGenericV20threeParameterMethodyACyxGqd___qd_0_ts14DifferentiableRd__sAFRd_0_s25ExpressibleByFloatLiteral13TangentVectorRpd__sAgHRpd_0_r0_lF__jvp_src_0_wrt_0_1_2{{.*}} : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : ExpressibleByIntegerLiteral><τ_1_0, τ_1_1 where τ_1_0 : Differentiable, τ_1_1 : Differentiable, τ_1_0.TangentVector : ExpressibleByFloatLiteral, τ_1_1.TangentVector : ExpressibleByFloatLiteral> (@in_guaranteed τ_1_0, @in_guaranteed τ_1_1, @in_guaranteed SelfReorderingGeneric<τ_0_0>) -> (@out SelfReorderingGeneric<τ_0_0>, @owned @callee_guaranteed (@in_guaranteed τ_1_0.TangentVector, @in_guaranteed τ_1_1.TangentVector, @in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector) -> @out SelfReorderingGeneric<τ_0_0>.TangentVector) { // CHECK: bb0([[JVP_RESULT:%.*]] : $*SelfReorderingGeneric<τ_0_0>, [[X:%.*]] : $*τ_1_0, [[Y:%.*]] : $*τ_1_1, [[SELF:%.*]] : $*SelfReorderingGeneric<τ_0_0>): // CHECK: [[JVP:%.*]] = function_ref @$s4main21SelfReorderingGenericV23jvpThreeParameterMethodyACyxG_AC13TangentVectorVyx_GAH_AFQyd__AFQyd_0_tctqd___qd_0_ts14DifferentiableRd__sAKRd_0_s25ExpressibleByFloatLiteralAIRQsAlJRQr0_lF // CHECK: [[DF:%.*]] = apply [[JVP]]<τ_0_0, τ_1_0, τ_1_1>([[JVP_RESULT]], [[X]], [[Y]], [[SELF]]) @@ -129,7 +129,7 @@ where Dummy: Differentiable & ExpressibleByIntegerLiteral { // CHECK: [[VOID:%.*]] = tuple () // CHECK: return [[VOID]] -// CHECK-LABEL: sil hidden [always_inline] [ossa] @AD__$s4main21SelfReorderingGenericV20threeParameterMethodyACyxGqd___qd_0_ts14DifferentiableRd__sAFRd_0_s25ExpressibleByFloatLiteral13TangentVectorRpd__sAgHRpd_0_r0_lF__vjp_src_0_wrt_0_1_2 : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : ExpressibleByIntegerLiteral><τ_1_0, τ_1_1 where τ_1_0 : Differentiable, τ_1_1 : Differentiable, τ_1_0.TangentVector : ExpressibleByFloatLiteral, τ_1_1.TangentVector : ExpressibleByFloatLiteral> (@in_guaranteed τ_1_0, @in_guaranteed τ_1_1, @in_guaranteed SelfReorderingGeneric<τ_0_0>) -> (@out SelfReorderingGeneric<τ_0_0>, @owned @callee_guaranteed (@in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector) -> (@out τ_1_0.TangentVector, @out τ_1_1.TangentVector, @out SelfReorderingGeneric<τ_0_0>.TangentVector)) { +// CHECK-LABEL: sil hidden [always_inline] [ossa] @AD__$s4main21SelfReorderingGenericV20threeParameterMethodyACyxGqd___qd_0_ts14DifferentiableRd__sAFRd_0_s25ExpressibleByFloatLiteral13TangentVectorRpd__sAgHRpd_0_r0_lF__vjp_src_0_wrt_0_1_2{{.*}} : $@convention(method) <τ_0_0 where τ_0_0 : Differentiable, τ_0_0 : ExpressibleByIntegerLiteral><τ_1_0, τ_1_1 where τ_1_0 : Differentiable, τ_1_1 : Differentiable, τ_1_0.TangentVector : ExpressibleByFloatLiteral, τ_1_1.TangentVector : ExpressibleByFloatLiteral> (@in_guaranteed τ_1_0, @in_guaranteed τ_1_1, @in_guaranteed SelfReorderingGeneric<τ_0_0>) -> (@out SelfReorderingGeneric<τ_0_0>, @owned @callee_guaranteed (@in_guaranteed SelfReorderingGeneric<τ_0_0>.TangentVector) -> (@out τ_1_0.TangentVector, @out τ_1_1.TangentVector, @out SelfReorderingGeneric<τ_0_0>.TangentVector)) { // CHECK: bb0([[VJP_RESULT:%.*]] : $*SelfReorderingGeneric<τ_0_0>, [[X:%.*]] : $*τ_1_0, [[Y:%.*]] : $*τ_1_1, [[SELF:%.*]] : $*SelfReorderingGeneric<τ_0_0>): // CHECK: [[VJP:%.*]] = function_ref @$s4main21SelfReorderingGenericV23vjpThreeParameterMethodyACyxG_AC13TangentVectorVyx_G_AFQyd__AFQyd_0_tAHctqd___qd_0_ts14DifferentiableRd__sAKRd_0_s25ExpressibleByFloatLiteralAIRQsAlJRQr0_lF // CHECK: [[PB:%.*]] = apply [[VJP]]<τ_0_0, τ_1_0, τ_1_1>([[VJP_RESULT]], [[X]], [[Y]], [[SELF]]) diff --git a/test/AutoDiff/witness_table_irgen.sil b/test/AutoDiff/witness_table_irgen.sil index ad78bde96def2..5d3eb786ca518 100644 --- a/test/AutoDiff/witness_table_irgen.sil +++ b/test/AutoDiff/witness_table_irgen.sil @@ -65,8 +65,8 @@ bb0(%0 : $Float, %1 : $AD__$s23witness_tables_autodiff25DifferentiableConformanc sil_witness_table hidden DifferentiableConformance: DifferentiableRequirement module witness_tables_autodiff { method #DifferentiableRequirement.f!1: (Self) -> (Float) -> Float : @$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW // protocol witness for DifferentiableRequirement.f(_:) in conformance DifferentiableConformance - method #DifferentiableRequirement.f!1.jvp.SU: (Self) -> (Float) -> Float : @AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_jvp_SU // AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_jvp_SU - method #DifferentiableRequirement.f!1.vjp.SU: (Self) -> (Float) -> Float : @AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_vjp_SU // AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_vjp_SU + method #DifferentiableRequirement.f!1.jvp.SU.: (Self) -> (Float) -> Float : @AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_jvp_SU // AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_jvp_SU + method #DifferentiableRequirement.f!1.vjp.SU.: (Self) -> (Float) -> Float : @AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_vjp_SU // AD__$s23witness_tables_autodiff25DifferentiableConformanceVAA0D11RequirementA2aDP1fyS2fFTW_vjp_SU } // CHECK: @"$s19witness_table_irgen25DifferentiableConformanceVAA0D11RequirementAAWP" = hidden constant [4 x i8*] [ diff --git a/test/AutoDiff/witness_table_sil.swift b/test/AutoDiff/witness_table_sil.swift index cefe097cc0e87..a2bc1ad040068 100644 --- a/test/AutoDiff/witness_table_sil.swift +++ b/test/AutoDiff/witness_table_sil.swift @@ -85,12 +85,12 @@ struct S : Proto, AdditiveArithmetic { // CHECK-LABEL: sil_witness_table hidden S: Proto module witness_table_sil { // CHECK-NEXT: base_protocol Differentiable: S: Differentiable module witness_table_sil // CHECK-NEXT: method #Proto.function1!1: (Self) -> (Float, Double) -> Float : @{{.*}}function1 -// CHECK-NEXT: method #Proto.function1!1.jvp.SSU: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function1{{.*}}_jvp_SSU -// CHECK-NEXT: method #Proto.function1!1.vjp.SSU: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function1{{.*}}_vjp_SSU +// CHECK-NEXT: method #Proto.function1!1.jvp.SSU.: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function1{{.*}}_jvp_SSU +// CHECK-NEXT: method #Proto.function1!1.vjp.SSU.: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function1{{.*}}_vjp_SSU // CHECK-NEXT: method #Proto.function2!1: (Self) -> (Float, Double) -> Float : @{{.*}}function2 -// CHECK-NEXT: method #Proto.function2!1.jvp.SSS: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function2{{.*}}_jvp_SSS -// CHECK-NEXT: method #Proto.function2!1.vjp.SSS: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function2{{.*}}_vjp_SSS +// CHECK-NEXT: method #Proto.function2!1.jvp.SSS.: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function2{{.*}}_jvp_SSS +// CHECK-NEXT: method #Proto.function2!1.vjp.SSS.: (Self) -> (Float, Double) -> Float : @AD__{{.*}}function2{{.*}}_vjp_SSS // CHECK-NEXT: method #Proto.function3!1: (Self) -> (Float, Double) -> Double : @{{.*}}function3 -// CHECK-NEXT: method #Proto.function3!1.jvp.USU: (Self) -> (Float, Double) -> Double : @AD__{{.*}}function3{{.*}}_jvp_USU -// CHECK-NEXT: method #Proto.function3!1.vjp.USU: (Self) -> (Float, Double) -> Double : @AD__{{.*}}function3{{.*}}_vjp_USU +// CHECK-NEXT: method #Proto.function3!1.jvp.USU.: (Self) -> (Float, Double) -> Double : @AD__{{.*}}function3{{.*}}_jvp_USU +// CHECK-NEXT: method #Proto.function3!1.vjp.USU.: (Self) -> (Float, Double) -> Double : @AD__{{.*}}function3{{.*}}_vjp_USU // CHECK-NEXT:}