diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index a79a25062cded..a786b3f2e5f6a 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -767,19 +767,25 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, constant.kind != SILDeclRef::Kind::DefaultArgGenerator && !constant.isThunk()) { auto *AFD = constant.getAbstractFunctionDecl(); - // Visit all `@differentiable` attributes. - for (auto *diffAttr : AFD->getAttrs().getAttributes()) { - SILFunction *jvp = nullptr; - SILFunction *vjp = nullptr; - if (auto *jvpDecl = diffAttr->getJVPFunction()) - jvp = getFunction(SILDeclRef(jvpDecl), NotForDefinition); - if (auto *vjpDecl = diffAttr->getVJPFunction()) - vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition); - auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); - AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices, - diffAttr->getDerivativeGenericSignature().getPointer()); - emitDifferentiabilityWitness(AFD, F, config, jvp, vjp); - } + auto emitWitnesses = [&](DeclAttributes &Attrs) { + for (auto *diffAttr : Attrs.getAttributes()) { + SILFunction *jvp = nullptr; + SILFunction *vjp = nullptr; + if (auto *jvpDecl = diffAttr->getJVPFunction()) + jvp = getFunction(SILDeclRef(jvpDecl), NotForDefinition); + if (auto *vjpDecl = diffAttr->getVJPFunction()) + vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition); + auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); + AutoDiffConfig config( + diffAttr->getParameterIndices(), resultIndices, + diffAttr->getDerivativeGenericSignature().getPointer()); + emitDifferentiabilityWitness(AFD, F, config, jvp, vjp); + } + }; + if (auto *accessor = dyn_cast(AFD)) + if (accessor->isGetter()) + emitWitnesses(accessor->getStorage()->getAttrs()); + emitWitnesses(AFD->getAttrs()); } F->verify(); } @@ -817,21 +823,6 @@ void SILGenModule::emitDifferentiabilityWitness( CanGenericSignature derivativeCanGenSig; if (auto derivativeGenSig = config.derivativeGenericSignature) derivativeCanGenSig = derivativeGenSig->getCanonicalSignature(); - // TODO(TF-835): Use simpler derivative generic signature logic below when - // type-checking no longer generates implicit `@differentiable` attributes. - // See TF-835 for replacement code. - if (jvp) { - auto jvpCanGenSig = jvp->getLoweredFunctionType()->getSubstGenericSignature(); - if (!derivativeCanGenSig && jvpCanGenSig) - derivativeCanGenSig = jvpCanGenSig; - assert(derivativeCanGenSig == jvpCanGenSig); - } - if (vjp) { - auto vjpCanGenSig = vjp->getLoweredFunctionType()->getSubstGenericSignature(); - if (!derivativeCanGenSig && vjpCanGenSig) - derivativeCanGenSig = vjpCanGenSig; - assert(derivativeCanGenSig == vjpCanGenSig); - } // Create new SIL differentiability witness. // Witness JVP and VJP are set below. // TODO(TF-919): Explore creating serialized differentiability witnesses. diff --git a/lib/TBDGen/TBDGen.cpp b/lib/TBDGen/TBDGen.cpp index bba081a5928fa..7cb9040527d2e 100644 --- a/lib/TBDGen/TBDGen.cpp +++ b/lib/TBDGen/TBDGen.cpp @@ -195,20 +195,9 @@ void TBDGenVisitor::addDifferentiabilityWitness( attr->getParameterIndices(), original->getInterfaceType()->castTo()); - GenericSignature genericSignature = attr->getDerivativeGenericSignature(); - if (auto *jvpDecl = attr->getJVPFunction()) { - assert(!genericSignature || - jvpDecl->getGenericSignature()->isEqual(genericSignature)); - genericSignature = jvpDecl->getGenericSignature(); - } - if (auto *vjpDecl = attr->getVJPFunction()) { - assert(!genericSignature || - vjpDecl->getGenericSignature()->isEqual(genericSignature)); - genericSignature = vjpDecl->getGenericSignature(); - } - std::string originalMangledName = SILDeclRef(original).mangle(); - AutoDiffConfig config{loweredParamIndices, resultIndices, genericSignature}; + AutoDiffConfig config{loweredParamIndices, resultIndices, + attr->getDerivativeGenericSignature()}; SILDifferentiabilityWitnessKey key(originalMangledName, config); Mangle::ASTMangler mangle; diff --git a/test/AutoDiff/sil_differentiability_witness_silgen.swift b/test/AutoDiff/sil_differentiability_witness_silgen.swift index 5d784fb5ac960..22d763f6bdee9 100644 --- a/test/AutoDiff/sil_differentiability_witness_silgen.swift +++ b/test/AutoDiff/sil_differentiability_witness_silgen.swift @@ -74,6 +74,10 @@ func generic_vjp(_ x: T, _ y: Float) -> ( public struct Foo: Differentiable { public var x: Float +// CHECK-LABEL: // differentiability witness for Foo.x.getter +// CHECK-NEXT: sil_differentiability_witness [parameters 0] [results 0] @$s36sil_differentiability_witness_silgen3FooV1xSfvg : $@convention(method) (Foo) -> Float { +// CHECK-NEXT: } + @differentiable public init(_ x: Float) { self.x = x