Skip to content

Commit 00adac8

Browse files
committed
[WIP]
Try to make this pass: ``` import Darwin print(gradient(at: Double(3), in: tan)) print(gradient(at: Double(3), in: Darwin.tan)) print(gradient(at: Double(3), in: Darwin.cos)) ``` ``` Undefined symbols for architecture x86_64: "_AD__$s6Darwin3cosyS2dF_PSRS", referenced from: _main in tgmath-e5045b.o ld: symbol(s) not found for architecture x86_64 <unknown>:0: error: link command failed with exit code 1 (use -v to see invocation) ```
1 parent fe33547 commit 00adac8

File tree

14 files changed

+275
-68
lines changed

14 files changed

+275
-68
lines changed

include/swift/SIL/SILDifferentiabilityWitness.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class SILDifferentiabilityWitness
9898
SILDifferentiabilityWitnessKey getKey() const;
9999
SILModule &getModule() const { return Module; }
100100
SILLinkage getLinkage() const { return Linkage; }
101+
void setLinkage(SILLinkage linkage) { Linkage = linkage; }
101102
SILFunction *getOriginalFunction() const { return OriginalFunction; }
102103
const AutoDiffConfig &getConfig() const { return Config; }
103104
IndexSubset *getParameterIndices() const {

lib/IRGen/GenDecl.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,9 +1078,6 @@ void IRGenerator::emitGlobalTopLevel(llvm::StringSet<> *linkerDirectives) {
10781078
// Emit differentiability witnesses.
10791079
for (auto &dw :
10801080
PrimaryIGM->getSILModule().getDifferentiabilityWitnessList()) {
1081-
if (dw.isDeclaration())
1082-
continue;
1083-
10841081
// Emit into same IRGenModule as the original function.
10851082
// NOTE(TF-894): Investigate whether `getGenModule(dw.getVJP())` is
10861083
// significant/desirable; `getGenModule` seems relevant for multi-threaded
@@ -4487,7 +4484,7 @@ IRGenModule::getAddrOfWitnessTablePattern(const NormalProtocolConformance *conf,
44874484
}
44884485

44894486
// SWIFT_ENABLE_TENSORFLOW
4490-
/// Look up the address of a witness table.
4487+
/// Look up the address of a differentiability witness.
44914488
llvm::Constant *IRGenModule::getAddrOfDifferentiabilityWitness(
44924489
const SILDifferentiabilityWitness *witness, ConstantInit definition) {
44934490
auto entity = LinkEntity::forDifferentiabilityWitness(witness);

lib/IRGen/GenDiffWitness.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,25 @@ void IRGenModule::emitSILDifferentiabilityWitness(
2929
SILDifferentiabilityWitness *dw) {
3030
PrettyStackTraceDifferentiabilityWitness _st(
3131
"emitting differentiability witness for", dw->getKey());
32-
32+
llvm::errs() << "IRGEN SIL DIFF WITNESS, CLANG DECL? "
33+
<< dw->getOriginalFunction()->getClangDecl() << ", C REFERENCES: "
34+
<< dw->getOriginalFunction()->hasCReferences() << ", WEAK IMPORTED: "
35+
<< dw->getOriginalFunction()->isWeakImported() << ", IS EXTERNAL DECL: "
36+
<< dw->getOriginalFunction()->isExternalDeclaration() << ", IS DECL: "
37+
<< dw->isDeclaration()
38+
<< "\n";
39+
dw->dump();
3340
// Don't emit declarations.
3441
if (dw->isDeclaration())
3542
return;
3643

37-
// Don't emit public_external witnesses.
38-
if (hasPublicVisibility(dw->getLinkage()) &&
39-
isAvailableExternally(dw->getLinkage()))
40-
return;
44+
// Don't emit `public_external` witnesses.
45+
// Make an exception for Clang-imported functions.
46+
if (!dw->getOriginalFunction()->getClangDecl()) {
47+
if (hasPublicVisibility(dw->getLinkage()) &&
48+
isAvailableExternally(dw->getLinkage()))
49+
return;
50+
}
4151

4252
ConstantInitBuilder builder(*this);
4353
auto diffWitnessContents = builder.beginStruct();
@@ -52,4 +62,6 @@ void IRGenModule::emitSILDifferentiabilityWitness(
5262

5363
getAddrOfDifferentiabilityWitness(
5464
dw, diffWitnessContents.finishAndCreateFuture());
65+
66+
llvm::errs() << "IRGEN'd DIFF WITNESS!\n";
5567
}

lib/SIL/SILFunctionType.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,13 @@ CanSILFunctionType SILFunctionType::getAutoDiffDerivativeFunctionType(
423423
if (getSubstGenericSignature() && derivativeFnGenSig &&
424424
!derivativeFnGenSig->areAllParamsConcrete())
425425
canGenSig = derivativeFnGenSig;
426-
return SILFunctionType::get(canGenSig, getExtInfo(), getCoroutineKind(),
426+
// If original function is `@convention(c)`, the derivative function should
427+
// have `@convention(thin)`. IRGen does not support `@convention(c)` functions
428+
// with multiple results.
429+
auto extInfo = getExtInfo();
430+
if (getLanguage() == SILFunctionLanguage::C)
431+
extInfo = extInfo.withRepresentation(SILFunctionTypeRepresentation::Thin);
432+
return SILFunctionType::get(canGenSig, extInfo, getCoroutineKind(),
427433
getCalleeConvention(), newParameters, getYields(),
428434
newResults, getOptionalErrorResult(),
429435
getSubstitutions(), isGenericSignatureImplied(),

lib/SILGen/SILGen.cpp

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -755,11 +755,8 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
755755
F->print(llvm::dbgs()));
756756

757757
// SWIFT_ENABLE_TENSORFLOW
758-
// Visit `@differentiable` attributes and generate SIL differentiability
759-
// witnesses.
760-
// TODO(TF-835): Visit `@derivative` attributes when type-checking no longer
761-
// generates implicit `@differentiable` attributes. See TF-835 for replacement
762-
// code.
758+
// Visit `@differentiable` amd `@derivative` attributes and generate SIL
759+
// differentiability witnesses.
763760
// Skip if the SILDeclRef is a:
764761
// - Default argument generator function.
765762
// - Thunk.
@@ -782,6 +779,10 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
782779
"all original SIL functions with generic signatures");
783780
AutoDiffConfig config(diffAttr->getParameterIndices(), resultIndices,
784781
diffAttr->getDerivativeGenericSignature());
782+
llvm::errs() << "SILGEN DIFF ATTR, ORIGINAL: " << AFD->getEffectiveFullName()
783+
<< ", F2N: " << constant.isForeignToNativeThunk()
784+
<< ", N2F: " << constant.isNativeToForeignThunk() << "\n";
785+
785786
emitDifferentiabilityWitness(AFD, F, config, jvp, vjp, diffAttr);
786787
}
787788
for (auto *derivAttr : Attrs.getAttributes<DerivativeAttr>()) {
@@ -798,6 +799,15 @@ void SILGenModule::postEmitFunction(SILDeclRef constant,
798799
auto *origAFD = derivAttr->getOriginalFunction();
799800
auto origConstant =
800801
SILDeclRef(origAFD).asForeign(requiresForeignEntryPoint(origAFD));
802+
#if 0
803+
auto origConstant = SILDeclRef(origAFD);
804+
#endif
805+
llvm::errs() << "SILGEN DERIV ATTR, ORIGINAL: " << origAFD->getEffectiveFullName()
806+
<< ", F2N: " << origConstant.isForeignToNativeThunk()
807+
<< ", N2F: " << origConstant.isNativeToForeignThunk()
808+
<< ", HAS CLANG: " << origAFD->hasClangNode()
809+
<< ", FOREIGN: " << requiresForeignEntryPoint(origAFD)
810+
<< "\n";
801811
auto *origFn = getFunction(origConstant, NotForDefinition);
802812
auto derivativeGenSig = AFD->getGenericSignature();
803813
auto *resultIndices = IndexSubset::get(getASTContext(), 1, {0});
@@ -863,6 +873,19 @@ void SILGenModule::emitDifferentiabilityWitness(
863873
/*isSerialized*/ hasPublicVisibility(originalFunction->getLinkage()),
864874
attr);
865875
}
876+
llvm::errs() << "SILGEN DIFF WITNESS, ORIG LINKAGE: "
877+
<< (unsigned)originalFunction->getLinkage() << "\n";
878+
// Strip external from linkage of Clang-imported functions.
879+
if (originalFunction->hasClangNode()) {
880+
llvm::errs() << "SILGEN DIFF WITNESS, STRIP EXTERNAL FROM LINKAGE\n";
881+
diffWitness->setLinkage(stripExternalFromLinkage(diffWitness->getLinkage()));
882+
}
883+
attr->print(llvm::errs(), originalAFD);
884+
llvm::errs() << "\n";
885+
diffWitness->dump();
886+
// #if 0
887+
originalFunction->dump();
888+
// #endif
866889

867890
// Set derivative function in differentiability witness.
868891
auto setDerivativeInDifferentiabilityWitness =

lib/SILOptimizer/Utils/Differentiation/Common.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -343,21 +343,32 @@ SILDifferentiabilityWitness *getOrCreateMinimalASTDifferentiabilityWitness(
343343
if (!minimalConfig)
344344
return nullptr;
345345

346+
std::string originalName = original->getName();
347+
// If original function requires a foreign entry point, look up the
348+
// differentiability witness for the foreign function.
349+
if (requiresForeignEntryPoint(originalAFD)) {
350+
originalName = SILDeclRef(originalAFD).asForeign().mangle();
351+
original = module.lookUpFunction(SILDeclRef(originalAFD).asForeign());
352+
}
353+
346354
auto *existingWitness = module.lookUpDifferentiabilityWitness(
347-
{original->getName(), *minimalConfig});
348-
if (existingWitness)
355+
{originalName, *minimalConfig});
356+
if (existingWitness) {
357+
llvm::errs() << "FOUND EXISTING WITNESS\n";
358+
existingWitness->dump();
349359
return existingWitness;
360+
}
350361

351362
assert(original->isExternalDeclaration() &&
352363
"SILGen should create differentiability witnesses for all function "
353364
"definitions with explicit differentiable attributes");
365+
llvm::errs() << "EXTERNAL WITNESS FOR: " << original->getName() << "\n";
354366

355367
return SILDifferentiabilityWitness::createDeclaration(
356368
module, SILLinkage::PublicExternal, original,
357369
minimalConfig->parameterIndices, minimalConfig->resultIndices,
358370
minimalConfig->derivativeGenericSignature);
359371
}
360372

361-
362373
} // end namespace autodiff
363374
} // end namespace swift

lib/SILOptimizer/Utils/Differentiation/LinearMapInfo.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,6 @@ bool LinearMapInfo::shouldDifferentiateApplySite(FullApplySite applySite) {
472472

473473
// TODO: Pattern match to make sure there is at least one `store` to the
474474
// array's active buffer.
475-
// if (isArrayLiteralIntrinsic(ai) && hasActiveResults)
476475
if (isArrayLiteralIntrinsic(applySite) && hasActiveResults)
477476
return true;
478477

lib/TBDGen/TBDGen.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ void TBDGenVisitor::addSymbolInternal(StringRef name,
6565
if (StringSymbols && kind == SymbolKind::GlobalSymbol) {
6666
auto isNewValue = StringSymbols->insert(name).second;
6767
(void)isNewValue;
68+
if (!isNewValue)
69+
llvm::errs() << "DUPLICATE: " << name << "\n";
6870
assert(isNewValue && "symbol appears twice");
6971
}
7072
}
@@ -276,25 +278,55 @@ void TBDGenVisitor::addAutoDiffDerivativeFunction(
276278
void TBDGenVisitor::addDifferentiabilityWitness(
277279
AbstractFunctionDecl *original, IndexSubset *astParameterIndices,
278280
IndexSubset *resultIndices, GenericSignature derivativeGenericSignature) {
279-
if (SILDeclRef(original).getLinkage(ForDefinition) != SILLinkage::Public)
281+
bool foreign = requiresForeignEntryPoint(original);
282+
llvm::errs() << "TBDGenVisitor::addDifferentiabilityWitness: "
283+
<< original->getEffectiveFullName() << ", HAS CLANG? " << original->hasClangNode() << ", FOREIGN: " << foreign << "\n";
284+
if (!foreign) {
285+
if (SILDeclRef(original).asForeign(foreign).getLinkage(ForDefinition) != SILLinkage::Public)
280286
return;
287+
}
288+
// original->dump();
289+
290+
#if 0
291+
if (SILDeclRef(original).asForeign(requiresForeignEntryPoint(original)).getLinkage(ForDefinition) != SILLinkage::Public)
292+
return;
293+
#endif
281294

282295
auto *silParamIndices = autodiff::getLoweredParameterIndices(
283296
astParameterIndices,
284297
original->getInterfaceType()->castTo<AnyFunctionType>());
285298

299+
std::string originalMangledName = SILDeclRef(original).asForeign(requiresForeignEntryPoint(original)).mangle();
300+
#if 0
286301
std::string originalMangledName = SILDeclRef(original).mangle();
302+
#endif
287303
AutoDiffConfig config{silParamIndices, resultIndices,
288304
derivativeGenericSignature};
289305
SILDifferentiabilityWitnessKey key(originalMangledName, config);
290306

291307
Mangle::ASTMangler mangler;
292308
std::string mangledName = mangler.mangleSILDifferentiabilityWitnessKey(key);
293309
addSymbol(mangledName);
310+
#if 0
311+
if (foreign) {
312+
std::string originalMangledName = SILDeclRef(original).asForeign(true).mangle();
313+
AutoDiffConfig config{silParamIndices, resultIndices,
314+
derivativeGenericSignature};
315+
SILDifferentiabilityWitnessKey key(originalMangledName, config);
316+
317+
Mangle::ASTMangler mangler;
318+
std::string mangledName = mangler.mangleSILDifferentiabilityWitnessKey(key);
319+
addSymbol(mangledName);
320+
}
321+
#endif
294322
}
295323

296324
void TBDGenVisitor::addDerivativeConfiguration(AbstractFunctionDecl *original,
297325
AutoDiffConfig config) {
326+
auto inserted = AddedDerivatives.insert({original, config});
327+
if (!inserted.second)
328+
return;
329+
298330
addAutoDiffLinearMapFunction(original, config,
299331
AutoDiffLinearMapKind::Differential);
300332
addAutoDiffLinearMapFunction(original, config,
@@ -374,9 +406,25 @@ void TBDGenVisitor::visitAbstractFunctionDecl(AbstractFunctionDecl *AFD) {
374406
}
375407

376408
// SWIFT_ENABLE_TENSORFLOW
409+
#if 0
377410
for (auto derivativeConfig : AFD->getDerivativeFunctionConfigurations()) {
378411
addDerivativeConfiguration(AFD, derivativeConfig);
379412
}
413+
#endif
414+
for (const auto *differentiableAttr :
415+
AFD->getAttrs().getAttributes<DifferentiableAttr>())
416+
addDerivativeConfiguration(
417+
AFD,
418+
AutoDiffConfig(differentiableAttr->getParameterIndices(),
419+
IndexSubset::get(AFD->getASTContext(), 1, {0}),
420+
differentiableAttr->getDerivativeGenericSignature()));
421+
for (const auto *derivativeAttr :
422+
AFD->getAttrs().getAttributes<DerivativeAttr>())
423+
addDerivativeConfiguration(
424+
derivativeAttr->getOriginalFunction(),
425+
AutoDiffConfig(derivativeAttr->getParameterIndices(),
426+
IndexSubset::get(AFD->getASTContext(), 1, {0}),
427+
AFD->getGenericSignature()));
380428

381429
visitDefaultArguments(AFD, AFD->getParameters());
382430
}
@@ -430,6 +478,15 @@ void TBDGenVisitor::visitAbstractStorageDecl(AbstractStorageDecl *ASD) {
430478
ASD->visitEmittedAccessors([&](AccessorDecl *accessor) {
431479
visitFuncDecl(accessor);
432480
});
481+
482+
// SWIFT_ENABLE_TENSORFLOW
483+
for (const auto *differentiableAttr :
484+
ASD->getAttrs().getAttributes<DifferentiableAttr>())
485+
addDerivativeConfiguration(
486+
ASD->getAccessor(AccessorKind::Get),
487+
AutoDiffConfig(differentiableAttr->getParameterIndices(),
488+
IndexSubset::get(ASD->getASTContext(), 1, {0}),
489+
differentiableAttr->getDerivativeGenericSignature()));
433490
}
434491

435492
void TBDGenVisitor::visitVarDecl(VarDecl *VD) {

lib/TBDGen/TBDGenVisitor.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,15 @@ class TBDGenVisitor : public ASTVisitor<TBDGenVisitor> {
5656
const TBDGenOptions &Opts;
5757
Decl* TopLevelDecl = nullptr;
5858

59+
// SWIFT_ENABLE_TENSORFLOW
60+
/// Tracks derivatives that have been added to the TBD.
61+
///
62+
/// Different attributes trigger emission of the same derivatives (e.g.
63+
/// `@differentiable` and `@derivative(of:)`), so we use this to deduplicate
64+
/// the symbols associated with the derivatives in the TBD.
65+
llvm::DenseSet<std::pair<AbstractFunctionDecl *, AutoDiffConfig>>
66+
AddedDerivatives;
67+
5968
private:
6069
void addSymbolInternal(StringRef name, llvm::MachO::SymbolKind kind,
6170
bool isLinkerDirective = false);

0 commit comments

Comments
 (0)