diff --git a/lib/SIL/SILPrinter.cpp b/lib/SIL/SILPrinter.cpp index 0f8ce5b8caacc..331fc3afd25be 100644 --- a/lib/SIL/SILPrinter.cpp +++ b/lib/SIL/SILPrinter.cpp @@ -3085,6 +3085,51 @@ void SILDefaultWitnessTable::dump() const { print(llvm::errs()); } +// TODO(TF-893): Use this helper to dedupe the same logic in +// `SILFunction::print`. +static void printSILFunctionNameAndType( + llvm::raw_ostream &OS, SILFunction *function) { + function->printName(OS); + OS << " : $"; + llvm::DenseMap Aliases; + llvm::DenseSet UsedNames; + auto sig = function->getLoweredFunctionType()->getGenericSignature(); + auto *env = function->getGenericEnvironment(); + if (sig && env) { + llvm::SmallString<16> disambiguatedNameBuf; + unsigned disambiguatedNameCounter = 1; + for (auto *paramTy : sig->getGenericParams()) { + auto sugaredTy = env->getSugaredType(paramTy); + Identifier name = sugaredTy->getName(); + while (!UsedNames.insert(name).second) { + disambiguatedNameBuf.clear(); + { + llvm::raw_svector_ostream names(disambiguatedNameBuf); + names << sugaredTy->getName() << disambiguatedNameCounter++; + } + name = function->getASTContext().getIdentifier(disambiguatedNameBuf); + } + if (name != sugaredTy->getName()) { + Aliases[paramTy->getCanonicalType()] = name; + + // Also for the archetype + auto archetypeTy = env->mapTypeIntoContext(paramTy) + ->getAs(); + if (archetypeTy) + Aliases[archetypeTy->getCanonicalType()] = name; + } + } + } + + { + PrintOptions withGenericEnvironment = PrintOptions::printSIL(); + withGenericEnvironment.GenericEnv = env; + withGenericEnvironment.AlternativeTypeNames = + Aliases.empty() ? nullptr : &Aliases; + function->getLoweredFunctionType()->print(OS, withGenericEnvironment); + } +} + // SWIFT_ENABLE_TENSORFLOW void SILDifferentiabilityWitness::print( llvm::raw_ostream &OS, bool verbose) const { @@ -3107,7 +3152,7 @@ void SILDifferentiabilityWitness::print( interleave(getResultIndices()->getIndices(), [&](unsigned index) { OS << index; }, [&] { OS << ' '; }); - OS << ']'; + OS << "] "; // ([where ...])? if (auto *derivativeGenSig = getDerivativeGenericSignature()) { ArrayRef requirements; @@ -3123,28 +3168,34 @@ void SILDifferentiabilityWitness::print( } } if (!requirements.empty()) { - OS << " [where "; + OS << "[where "; auto subPrinter = PrintOptions::printSIL(); + subPrinter.GenericEnv = origGenEnv; interleave(requirements, [&](Requirement req) { req.print(OS, subPrinter); }, [&] { OS << ", "; }); - OS << ']'; + OS << "] "; } } // @original-function-name : $original-sil-type - OS << " @" << originalFunction->getName() << " : " - << originalFunction->getLoweredType(); + printSILFunctionNameAndType(OS, originalFunction); // { // jvp: @jvp-function-name : $jvp-sil-type // vjp: @vjp-function-name : $vjp-sil-type // } OS << " {\n"; - if (jvp) - OS << " jvp: @" << jvp->getName() << " : " << jvp->getLoweredType() << '\n'; - if (vjp) - OS << " vjp: @" << vjp->getName() << " : " << vjp->getLoweredType() << '\n'; + if (jvp) { + OS << " jvp: "; + printSILFunctionNameAndType(OS, jvp); + OS << '\n'; + } + if (vjp) { + OS << " vjp: "; + printSILFunctionNameAndType(OS, vjp); + OS << '\n'; + } OS << "}\n\n"; } diff --git a/lib/SILGen/SILGen.cpp b/lib/SILGen/SILGen.cpp index f5b9fdf4411d5..c634da2abc9ef 100644 --- a/lib/SILGen/SILGen.cpp +++ b/lib/SILGen/SILGen.cpp @@ -752,87 +752,135 @@ void SILGenModule::postEmitFunction(SILDeclRef constant, F->print(llvm::dbgs())); // SWIFT_ENABLE_TENSORFLOW - // Create self-reordering thunks for JVPs/VJPs of `@differentiable` methods. - if (constant.hasDecl() && constant.getAbstractFunctionDecl()) { + // Visit `@differentiable` attributes and generate SIL differentiability + // witnesses. + // TODO(TF-835): Visit `@differentiating` attributes when type-checking no + // longer generates implicit `@differentiable` attributes. See TF-835 for + // replacement code. + // Skip if the SILDeclRef is a: + // - Default argument generator function. + // - Thunk. + if (constant.hasDecl() && constant.getAbstractFunctionDecl() && + constant.kind != SILDeclRef::Kind::DefaultArgGenerator && + !constant.isThunk()) { auto *AFD = constant.getAbstractFunctionDecl(); - auto origFnType = AFD->getInterfaceType()->castTo(); - auto origSilFnType = F->getLoweredFunctionType(); - // Jointly iterate over AST `@differentiable` attributes and SIL - // `[differentiable]` attributes. - auto diffAttrs = AFD->getAttrs().getAttributes(); - auto silDiffAttrs = F->getDifferentiableAttrs(); - for (auto pair : llvm::zip(diffAttrs, silDiffAttrs)) { - auto *diffAttr = const_cast(std::get<0>(pair)); - auto *silDiffAttr = std::get<1>(pair); - // Compute lowered parameter indices. - auto *paramIndices = diffAttr->getParameterIndices(); - auto *loweredParamIndices = autodiff::getLoweredParameterIndices( - paramIndices, origFnType); - SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices); - assert(silDiffAttr->getIndices() == indices && - "Expected matching @differentiable and [differentiable] indices"); - - auto lookUpConformance = LookUpConformanceInModule(M.getSwiftModule()); - auto expectedJVPType = origSilFnType->getAutoDiffDerivativeFunctionType( - indices.parameters, indices.source, - AutoDiffDerivativeFunctionKind::JVP, Types, lookUpConformance); - auto expectedVJPType = origSilFnType->getAutoDiffDerivativeFunctionType( - indices.parameters, indices.source, - AutoDiffDerivativeFunctionKind::VJP, Types, lookUpConformance); - - // Self reordering is necessary if wrt at least two parameters, including - // self. - auto shouldReorderSelf = [&]() { - if (!F->hasSelfParam()) - return false; - auto selfParamIndex = origSilFnType->getNumParameters() - 1; - if (!indices.isWrtParameter(selfParamIndex)) - return false; - return indices.parameters->getNumIndices() > 1; - }; - bool reorderSelf = shouldReorderSelf(); - - // Thunk JVP method, if it is defined. - if (auto *jvpDecl = diffAttr->getJVPFunction()) { - SILFunction *jvpThunk; - auto *jvpFn = getFunction(SILDeclRef(jvpDecl), NotForDefinition); - if (reorderSelf || jvpFn->getLoweredFunctionType() != expectedJVPType) { - jvpThunk = getOrCreateAutoDiffDerivativeFunctionThunk( - F, indices, jvpFn, AutoDiffDerivativeFunctionKind::JVP, - reorderSelf); - } else { - auto *id = AutoDiffDerivativeFunctionIdentifier::get( - AutoDiffDerivativeFunctionKind::JVP, - diffAttr->getParameterIndices(), AFD->getASTContext()); - jvpThunk = getOrCreateAutoDiffThunk( - constant.asAutoDiffDerivativeFunction(id), jvpFn, - expectedJVPType); - } - silDiffAttr->setJVPName(jvpThunk->getName()); - } - // Thunk VJP method, if it is defined. - if (auto *vjpDecl = diffAttr->getVJPFunction()) { - SILFunction *vjpThunk; - auto *vjpFn = getFunction(SILDeclRef(vjpDecl), NotForDefinition); - if (reorderSelf || vjpFn->getLoweredFunctionType() != expectedVJPType) { - vjpThunk = getOrCreateAutoDiffDerivativeFunctionThunk( - F, indices, vjpFn, AutoDiffDerivativeFunctionKind::VJP, - reorderSelf); - } else { - auto *id = AutoDiffDerivativeFunctionIdentifier::get( - AutoDiffDerivativeFunctionKind::VJP, - diffAttr->getParameterIndices(), AFD->getASTContext()); - vjpThunk = getOrCreateAutoDiffThunk( - constant.asAutoDiffDerivativeFunction(id), vjpFn, - expectedVJPType); - } - silDiffAttr->setVJPName(vjpThunk->getName()); - } + // Visit all `@differentiable` attributes. + for (auto *diffAttr : AFD->getAttrs().getAttributes()) { + SILFunction *jvp = nullptr; + SILFunction *vjp = nullptr; + if (auto *jvpDecl = diffAttr->getJVPFunction()) + jvp = getFunction(SILDeclRef(jvpDecl), NotForDefinition); + if (auto *vjpDecl = diffAttr->getVJPFunction()) + vjp = getFunction(SILDeclRef(vjpDecl), NotForDefinition); + auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0}); + AutoDiffConfig config{diffAttr->getParameterIndices(), resultIndices, + diffAttr->getDerivativeGenericSignature()}; + emitDifferentiabilityWitness(AFD, F, config, jvp, vjp); } } F->verify(); } +void SILGenModule::emitDifferentiabilityWitness( + AbstractFunctionDecl *originalAFD, SILFunction *originalFunction, + const AutoDiffConfig &config, SILFunction *jvp, SILFunction *vjp) { + auto *origFnType = originalAFD->getInterfaceType()->castTo(); + auto origSilFnType = originalFunction->getLoweredFunctionType(); + auto *loweredParamIndices = autodiff::getLoweredParameterIndices( + config.parameterIndices, origFnType); + // NOTE(TF-893): Extending capacity is necessary when `origSilFnType` has + // parameters corresponding to captured variables. These parameters do not + // appear in the type of `origFnType`. + // TODO: If posssible, change `autodiff::getLoweredParameterIndices` to + // take `CaptureInfo` into account. + if (origSilFnType->getNumParameters() > loweredParamIndices->getCapacity()) + loweredParamIndices = loweredParamIndices->extendingCapacity( + getASTContext(), origSilFnType->getNumParameters()); + // TODO(TF-913): Replace usages of `SILAutoDiffIndices` with `AutoDiffConfig`. + SILAutoDiffIndices indices(/*source*/ 0, loweredParamIndices); + + // Self reordering thunk is necessary if wrt at least two parameters, + // including self. + auto shouldReorderSelf = [&]() { + if (!originalFunction->hasSelfParam()) + return false; + auto selfParamIndex = origSilFnType->getNumParameters() - 1; + if (!indices.isWrtParameter(selfParamIndex)) + return false; + return indices.parameters->getNumIndices() > 1; + }; + bool reorderSelf = shouldReorderSelf(); + + CanGenericSignature derivativeCanGenSig; + if (auto *derivativeGenSig = config.derivativeGenericSignature) + derivativeCanGenSig = derivativeGenSig->getCanonicalSignature(); + // TODO(TF-835): Use simpler derivative generic signature logic below when + // type-checking no longer generates implicit `@differentiable` attributes. + // See TF-835 for replacement code. + if (jvp) { + auto jvpCanGenSig = jvp->getLoweredFunctionType()->getGenericSignature(); + if (!derivativeCanGenSig && jvpCanGenSig) + derivativeCanGenSig = jvpCanGenSig; + assert(derivativeCanGenSig == jvpCanGenSig); + } + if (vjp) { + auto vjpCanGenSig = vjp->getLoweredFunctionType()->getGenericSignature(); + if (!derivativeCanGenSig && vjpCanGenSig) + derivativeCanGenSig = vjpCanGenSig; + assert(derivativeCanGenSig == vjpCanGenSig); + } + // Create new SIL differentiability witness. + // Witness JVP and VJP are set below. + // TODO(TF-919): Explore creating serialized differentiability witnesses. + // Currently, differentiability witnesses are never serialized to avoid + // deserialization issues where JVP/VJP functions cannot be found. + auto *diffWitness = SILDifferentiabilityWitness::create( + M, originalFunction->getLinkage(), originalFunction, + loweredParamIndices, config.resultIndices, derivativeCanGenSig, + /*jvp*/ nullptr, /*vjp*/ nullptr, /*isSerialized*/ false); + + // Set derivative function in differentiability witness. + auto setDerivativeInDifferentiabilityWitness = + [&](AutoDiffDerivativeFunctionKind kind, SILFunction *derivative) { + auto expectedDerivativeType = + origSilFnType->getAutoDiffDerivativeFunctionType( + indices.parameters, indices.source, kind, Types, + LookUpConformanceInModule(M.getSwiftModule())); + // Thunk derivative function. + SILFunction *derivativeThunk; + if (reorderSelf || + derivative->getLoweredFunctionType() != expectedDerivativeType) { + derivativeThunk = getOrCreateAutoDiffDerivativeFunctionThunk( + originalFunction, indices, derivative, kind, reorderSelf); + } else { + // Note: `AutoDiffDerivativeFunctionIdentifier` must be constructed with + // the AST-level parameter indices, not the SIL-level ones. + auto *id = AutoDiffDerivativeFunctionIdentifier::get( + kind, config.parameterIndices, getASTContext()); + derivativeThunk = getOrCreateAutoDiffThunk( + SILDeclRef(originalAFD).asAutoDiffDerivativeFunction(id), derivative, + expectedDerivativeType); + } + // Check for existing same derivative. + // TODO(TF-835): Remove condition below and simplify assertion to + // `!diffWitness->getDerivative(kind)` after `@differentiating` attribute + // type-checking no longer generates implicit `@differentiable` attributes. + auto *existingDerivative = diffWitness->getDerivative(kind); + if (existingDerivative && existingDerivative == derivativeThunk) + return; + assert(!existingDerivative && + "SIL differentiability witness already has a different existing " + "derivative"); + diffWitness->setDerivative(kind, derivativeThunk); + }; + if (jvp) + setDerivativeInDifferentiabilityWitness(AutoDiffDerivativeFunctionKind::JVP, + jvp); + if (vjp) + setDerivativeInDifferentiabilityWitness(AutoDiffDerivativeFunctionKind::VJP, + vjp); +} + void SILGenModule:: emitMarkFunctionEscapeForTopLevelCodeGlobals(SILLocation loc, const CaptureInfo &captureInfo) { diff --git a/lib/SILGen/SILGen.h b/lib/SILGen/SILGen.h index c69aabcdba131..879fa64b974ba 100644 --- a/lib/SILGen/SILGen.h +++ b/lib/SILGen/SILGen.h @@ -318,6 +318,16 @@ class LLVM_LIBRARY_VISIBILITY SILGenModule : public ASTVisitor { /// Emit the self-conformance witness table for a protocol. void emitSelfConformanceWitnessTable(ProtocolDecl *protocol); + // SWIFT_ENABLE_TENSORFLOW + /// Emit the differentiability witness for the given original function + /// declaration and SIL function, autodiff configuration, and JVP and VJP + /// functions (null if undefined). + void emitDifferentiabilityWitness(AbstractFunctionDecl *originalAFD, + SILFunction *originalFunction, + const AutoDiffConfig &config, + SILFunction *jvp, SILFunction *vjp); + // SWIFT_ENABLE_TENSORFLOW END + /// Emit the lazy initializer function for a global pattern binding /// declaration. SILFunction *emitLazyGlobalInitializer(StringRef funcName, diff --git a/lib/Serialization/DeserializeSIL.cpp b/lib/Serialization/DeserializeSIL.cpp index 1228cb02f7916..625e91ee11093 100644 --- a/lib/Serialization/DeserializeSIL.cpp +++ b/lib/Serialization/DeserializeSIL.cpp @@ -3016,6 +3016,9 @@ void SILDeserializer::readWitnessTableEntries( // Another record means the end of this WitnessTable. while (kind != SIL_WITNESS_TABLE && kind != SIL_DEFAULT_WITNESS_TABLE && + // SWIFT_ENABLE_TENSORFLOW + kind != SIL_DIFFERENTIABILITY_WITNESS && + // SWIFT_ENABLE_TENSORFLOW END kind != SIL_FUNCTION) { if (kind == SIL_DEFAULT_WITNESS_TABLE_NO_ENTRY) { witnessEntries.push_back(SILDefaultWitnessTable::Entry()); diff --git a/lib/Serialization/Serialization.cpp b/lib/Serialization/Serialization.cpp index 900657239e6f2..4697c25c39c4b 100644 --- a/lib/Serialization/Serialization.cpp +++ b/lib/Serialization/Serialization.cpp @@ -789,6 +789,7 @@ void Serializer::writeBlockInfoBlock() { BLOCK_RECORD(sil_block, SIL_INST_LINEAR_FUNCTION); BLOCK_RECORD(sil_block, SIL_INST_DIFFERENTIABLE_FUNCTION_EXTRACT); BLOCK_RECORD(sil_block, SIL_INST_LINEAR_FUNCTION_EXTRACT); + BLOCK_RECORD(sil_block, SIL_DIFFERENTIABILITY_WITNESS); // SWIFT_ENABLE_TENSORFLOW END // These layouts can exist in both decl blocks and sil blocks. @@ -829,6 +830,7 @@ void Serializer::writeBlockInfoBlock() { BLOCK_RECORD(sil_index_block, SIL_DEFAULT_WITNESS_TABLE_OFFSETS); BLOCK_RECORD(sil_index_block, SIL_PROPERTY_OFFSETS); // SWIFT_ENABLE_TENSORFLOW + BLOCK_RECORD(sil_index_block, SIL_DIFFERENTIABILITY_WITNESS_NAMES); BLOCK_RECORD(sil_index_block, SIL_DIFFERENTIABILITY_WITNESS_OFFSETS); // SWIFT_ENABLE_TENSORFLOW END diff --git a/lib/Serialization/SerializeSIL.cpp b/lib/Serialization/SerializeSIL.cpp index 82c2558f9c865..e029ec3e0cc17 100644 --- a/lib/Serialization/SerializeSIL.cpp +++ b/lib/Serialization/SerializeSIL.cpp @@ -2322,7 +2322,7 @@ void SILSerializer::writeIndexTables() { } // SWIFT_ENABLE_TENSORFLOW - if (!DifferentiabilityWitnessOffset.empty()) { + if (!DifferentiabilityWitnessList.empty()) { writeIndexTable(S, List, sil_index_block::SIL_DIFFERENTIABILITY_WITNESS_NAMES, DifferentiabilityWitnessList); @@ -2542,17 +2542,12 @@ writeSILDifferentiabilityWitness(const SILDifferentiabilityWitness &dw) { DifferentiabilityWitnessOffset.push_back(Out.GetCurrentBitNo()); auto *original = dw.getOriginalFunction(); - addReferencedSILFunction(original, /*DeclOnly*/ true); IdentifierID jvpID = 0; IdentifierID vjpID = 0; - if (auto *jvp = dw.getJVP()) { - addReferencedSILFunction(jvp, /*DeclOnly*/ true); - jvpID = S.addUniquedStringRef(jvp->getName()); - } - if (auto *vjp = dw.getVJP()) { - addReferencedSILFunction(vjp, /*DeclOnly*/ true); - vjpID = S.addUniquedStringRef(vjp->getName()); - } + if (auto *jvp = dw.getJVP()) + jvpID = addSILFunctionRef(jvp); + if (auto *vjp = dw.getVJP()) + vjpID = addSILFunctionRef(vjp); SmallVector parameterAndResultIndices( dw.getParameterIndices()->begin(), dw.getParameterIndices()->end()); parameterAndResultIndices.append(dw.getResultIndices()->begin(), @@ -2569,7 +2564,7 @@ writeSILDifferentiabilityWitness(const SILDifferentiabilityWitness &dw) { DifferentiabilityWitnessLayout::emitRecord( Out, ScratchRecord, SILAbbrCodes[DifferentiabilityWitnessLayout::Code], - S.addUniquedStringRef(original->getName()), + addSILFunctionRef(original), toStableSILLinkage(dw.getLinkage()), dw.isSerialized(), S.addGenericSignatureRef(dw.getDerivativeGenericSignature()), diff --git a/test/AutoDiff/sil_differentiability_witness.sil b/test/AutoDiff/sil_differentiability_witness.sil index 8f56d2480dcef..94700b3e3fb2b 100644 --- a/test/AutoDiff/sil_differentiability_witness.sil +++ b/test/AutoDiff/sil_differentiability_witness.sil @@ -74,7 +74,7 @@ sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 } // CHECK-LABEL: // differentiability witness for generic -// CHECK: sil_differentiability_witness hidden [parameters 0 1] [results 0] [where τ_0_0 : _Differentiable] @generic : $@convention(thin) <τ_0_0> (@in_guaranteed τ_0_0, Float) -> @out τ_0_0 { +// CHECK: sil_differentiability_witness hidden [parameters 0 1] [results 0] [where T : _Differentiable] @generic : $@convention(thin) (@in_guaranteed T, Float) -> @out T { // CHECK: jvp: @AD__generic__jvp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector, Float) -> @out τ_0_0.TangentVector) // CHECK: vjp: @AD__generic__vjp_src_0_wrt_0_1 : $@convention(thin) <τ_0_0 where τ_0_0 : _Differentiable> (@in_guaranteed τ_0_0, Float) -> (@out τ_0_0, @owned @callee_guaranteed (@in_guaranteed τ_0_0.TangentVector) -> (@out τ_0_0.TangentVector, Float)) // CHECK: }