Skip to content

Commit 48dcb8d

Browse files
committed
Use correct IR name for AutoDiff mangling
1 parent 95b1426 commit 48dcb8d

File tree

3 files changed

+121
-101
lines changed

3 files changed

+121
-101
lines changed

lib/AST/ASTMangler.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,19 @@ void ASTMangler::beginManglingWithAutoDiffOriginalFunction(
669669
beginManglingClangDecl(typedefType->getDecl());
670670
return;
671671
}
672+
673+
if (auto *EA = ExternAttr::find(afd->getAttrs(), ExternKind::C)) {
674+
beginManglingWithoutPrefix();
675+
appendOperator(EA->getCName(afd));
676+
return;
677+
}
678+
679+
if (afd->getAttrs().hasAttribute<CDeclAttr>()) {
680+
beginManglingWithoutPrefix();
681+
appendOperator(afd->getCDeclName());
682+
return;
683+
}
684+
672685
beginMangling();
673686
if (auto *cd = dyn_cast<ConstructorDecl>(afd))
674687
appendConstructorEntity(cd, /*isAllocating*/ !cd->isConvenienceInit());

lib/SILOptimizer/Mandatory/Differentiation.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -983,14 +983,21 @@ DifferentiationTransformer::emitDerivativeFunctionReference(
983983
// `SILDifferentiabilityWitness` processing
984984
//===----------------------------------------------------------------------===//
985985

986+
static StringRef getIRName(SILFunction *F) {
987+
if (!F->asmName().empty())
988+
return F->asmName();
989+
990+
return F->getName();
991+
}
992+
986993
static SILFunction *createEmptyVJP(ADContext &context,
987994
SILDifferentiabilityWitness *witness,
988995
SerializedKind_t isSerialized) {
989996
auto original = witness->getOriginalFunction();
990997
auto config = witness->getConfig();
991998
LLVM_DEBUG({
992999
auto &s = getADDebugStream();
993-
s << "Creating VJP for " << original->getName() << ":\n\t";
1000+
s << "Creating VJP for " << getIRName(original) << ":\n\t";
9941001
s << "Original type: " << original->getLoweredFunctionType() << "\n\t";
9951002
s << "Config: " << config << "\n\t";
9961003
});
@@ -1001,7 +1008,7 @@ static SILFunction *createEmptyVJP(ADContext &context,
10011008
// === Create an empty VJP. ===
10021009
Mangle::DifferentiationMangler mangler(context.getASTContext());
10031010
auto vjpName = mangler.mangleDerivativeFunction(
1004-
original->getName(), AutoDiffDerivativeFunctionKind::VJP, config);
1011+
getIRName(original), AutoDiffDerivativeFunctionKind::VJP, config);
10051012
auto vjpCanGenSig = witness->getDerivativeGenericSignature().getCanonicalSignature();
10061013
GenericEnvironment *vjpGenericEnv = nullptr;
10071014
if (vjpCanGenSig && !vjpCanGenSig->areAllParamsConcrete())
@@ -1035,7 +1042,7 @@ static SILFunction *createEmptyJVP(ADContext &context,
10351042
auto config = witness->getConfig();
10361043
LLVM_DEBUG({
10371044
auto &s = getADDebugStream();
1038-
s << "Creating JVP for " << original->getName() << ":\n\t";
1045+
s << "Creating JVP for " << getIRName(original) << ":\n\t";
10391046
s << "Original type: " << original->getLoweredFunctionType() << "\n\t";
10401047
s << "Config: " << config << "\n\t";
10411048
});
@@ -1045,7 +1052,7 @@ static SILFunction *createEmptyJVP(ADContext &context,
10451052

10461053
Mangle::DifferentiationMangler mangler(context.getASTContext());
10471054
auto jvpName = mangler.mangleDerivativeFunction(
1048-
original->getName(), AutoDiffDerivativeFunctionKind::JVP, config);
1055+
getIRName(original), AutoDiffDerivativeFunctionKind::JVP, config);
10491056
auto jvpCanGenSig = witness->getDerivativeGenericSignature().getCanonicalSignature();
10501057
GenericEnvironment *jvpGenericEnv = nullptr;
10511058
if (jvpCanGenSig && !jvpCanGenSig->areAllParamsConcrete())
@@ -1174,7 +1181,7 @@ bool DifferentiationTransformer::canonicalizeDifferentiabilityWitness(
11741181
"_fatalErrorForwardModeDifferentiationDisabled");
11751182
LLVM_DEBUG(getADDebugStream()
11761183
<< "Generated empty JVP for "
1177-
<< orig->getName() << ":\n"
1184+
<< getIRName(orig) << ":\n"
11781185
<< *jvp);
11791186
}
11801187
}
@@ -1245,7 +1252,7 @@ static SILValue promoteCurryThunkApplicationToDifferentiableFunction(
12451252
// Create a new curry thunk.
12461253
AutoDiffConfig desiredConfig(parameterIndices, resultIndices);
12471254
// TODO(TF-685): Use more principled mangling for thunks.
1248-
auto newThunkName = "AD__" + thunk->getName().str() +
1255+
auto newThunkName = "AD__" + getIRName(thunk).str() +
12491256
"__differentiable_curry_thunk_" + desiredConfig.mangle();
12501257

12511258
// Construct new curry thunk type with `@differentiable` function

0 commit comments

Comments
 (0)