Skip to content

Commit ff64467

Browse files
committed
[AutoDiff] Update @differentiable attribute, refactor lookupFuncDecl method. (#17246)
* [AutoDiff] Update @differentiable attribute, refactor `lookupFuncDecl` method. - Add adjoint access level check to @differentiable attribute. - If the original function is "exported" (public or @_versioned), then the adjoint must also be exported. - Generalize `lookupFuncDecl` function and add it as a method to `TypeChecker` so that it can be used for type-checking #adjoint expression in a later commit. * Refactoring/clean up. Addresses comments from @rxwei. * Add/fix @differentiable attribute diagnostics, clean up. - Add @differentiable access control check to primal as well as adjoint. - Add @differentiable access control test. - Add diagnostic for differentiating void functions. * Preemptively rename @_versioned to @usableFromInline.
1 parent ff5d851 commit ff64467

File tree

8 files changed

+264
-144
lines changed

8 files changed

+264
-144
lines changed

include/swift/AST/DiagnosticsParse.def

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1454,13 +1454,7 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
14541454
// SWIFT_ENABLE_TENSORFLOW
14551455
// differentiable
14561456
ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken,
1457-
"expected a qualified function to differentiate", ())
1458-
1459-
ERROR(attr_differentiable_missing_lsquare,PointsToFirstBadToken,
1460-
"missing '[' for parameter index list", ())
1461-
1462-
ERROR(attr_differentiable_missing_rsquare,PointsToFirstBadToken,
1463-
"missing ']' for parameter index list", ())
1457+
"expected a qualified %0 function", (StringRef))
14641458

14651459
ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
14661460
"expected a list of parameters to differentiate with respect to, e.g. (.0, w, b)", ())

include/swift/AST/DiagnosticsSema.def

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2486,6 +2486,8 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
24862486
// SWIFT_ENABLE_TENSORFLOW
24872487
ERROR(differentiable_attr_no_parameters,none,
24882488
"%0 has no parameters to differentiate with respect to", (DeclName))
2489+
ERROR(differentiable_attr_void_result,none,
2490+
"cannot differentiate void function %0", (DeclName))
24892491
ERROR(differentiable_attr_primal_overload_not_found,none,
24902492
"%0 does not have expected parameters' type %1", (DeclName, Type))
24912493
ERROR(differentiable_attr_adjoint_overload_not_found,none,
@@ -2504,13 +2506,17 @@ ERROR(differentiable_attr_cannot_diff_wrt_objects_or_existentials,none,
25042506
ERROR(differentiable_attr_function_not_same_type_context,none,
25052507
"%0 is not defined in the current type context", (DeclName))
25062508
ERROR(differentiable_attr_specified_not_function,none,
2507-
"%0 is not a function to be used as %select{primal|adjoint}1",
2509+
"%0 is not a function to be used as %select{adjoint|primal}1",
25082510
(DeclName, bool))
25092511
ERROR(differentiable_attr_ambiguous_function_identifier,none,
25102512
"ambiguous or overloaded identifier %0 cannot be used in @differentiable "
25112513
"attribute", (DeclName))
25122514
ERROR(differentiable_attr_forward_mode_unsupported,none,
25132515
"forward-mode automatic differentiation is not supported yet", ())
2516+
ERROR(differentiable_attr_invalid_access,none,
2517+
"%select{adjoint|primal}2 %0 is required to either be public or "
2518+
"@usableFromInline because the original function %1 is public or "
2519+
"@usableFromInline", (DeclName, DeclName, bool))
25142520

25152521
ERROR(compiler_evaluable_bad_context,none,
25162522
"@compilerEvaluable functions not allowed here", ())

lib/Parse/ParseDecl.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -875,12 +875,13 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {
875875
parseToken(tok::colon,
876876
diag::attr_differentiable_expected_colon_after_label, label))
877877
return true;
878-
// Parse the name of the adjoint function.
878+
// Parse the name of the function.
879+
Diagnostic funcDiag(diag::attr_differentiable_expected_function_name.ID,
880+
{ label });
879881
result.Name =
880882
parseUnqualifiedDeclName(/*afterDot=*/false, result.Loc,
881-
diag::attr_implements_expected_member_name,
882-
/*allowOperators=*/true,
883-
/*allowZeroArgCompoundNames=*/true);
883+
funcDiag, /*allowOperators=*/true,
884+
/*allowZeroArgCompoundNames=*/true);
884885
return !result.Name;
885886
};
886887

lib/Sema/TypeCheckAttr.cpp

Lines changed: 93 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -2231,98 +2231,6 @@ void AttributeChecker::visitFrozenAttr(FrozenAttr *attr) {
22312231
}
22322232
}
22332233

2234-
// SWIFT_ENABLE_TENSORFLOW
2235-
// Returns the function declaration corresponding to the given function name and
2236-
// lookup context. If the function declaration cannot be resolved, emits a
2237-
// diagnostic and returns nullptr.
2238-
static FuncDecl *getResolvedFuncDecl(
2239-
DeclName funcName, SourceLoc funcNameLoc, TypeChecker &TC,
2240-
DeclContext *lookupContext,
2241-
const std::function<bool(FuncDecl *)> &isValidFuncDecl,
2242-
const std::function<bool(FuncDecl *)> &hasValidTypeContext,
2243-
const std::function<void()> &overloadDiagnostic,
2244-
const std::function<void()> &ambiguousDiagnostic,
2245-
const std::function<void()> &notFunctionDiagnostic) {
2246-
2247-
FuncDecl *resolvedFuncDecl = nullptr;
2248-
2249-
// Initialize error flags.
2250-
bool notAFuncDecl = false;
2251-
bool wrongTypeContext = false;
2252-
bool overloadNotFound = false;
2253-
2254-
// Perform lookup, ignoring access control.
2255-
auto options = defaultUnqualifiedLookupOptions |
2256-
NameLookupFlags::IgnoreAccessControl;
2257-
auto results =
2258-
TC.lookupUnqualified(lookupContext, funcName, funcNameLoc, options);
2259-
2260-
// Note: static methods are omitted from `TypeChecker.lookupUnqualified` in
2261-
// Swift 3. The code below is a workaround for resolving them.
2262-
//
2263-
// This is necessary because the stdlib is compiled with `-swift-version 3`
2264-
// for Swift 3 compatibility, and floating point types use the
2265-
// `@differentiable` attribute with static adjoint methods (such as
2266-
// `_adjointAdd`).
2267-
if (lookupContext->getASTContext().isSwiftVersion3() && results.empty() &&
2268-
lookupContext->isTypeContext()) {
2269-
auto tmp = TC.lookupMember(lookupContext,
2270-
lookupContext->getSelfTypeInContext(), funcName);
2271-
for (auto choice : tmp) {
2272-
auto decl = choice.getValueDecl();
2273-
if (!decl) continue;
2274-
auto funcDecl = dyn_cast<FuncDecl>(decl);
2275-
if (!funcDecl) continue;
2276-
results.add(LookupResultEntry(funcDecl));
2277-
}
2278-
}
2279-
2280-
for (auto choice : results) {
2281-
auto decl = choice.getValueDecl();
2282-
if (!decl) continue;
2283-
2284-
auto funcDecl = dyn_cast<FuncDecl>(decl);
2285-
if (!funcDecl) {
2286-
notAFuncDecl = true;
2287-
continue;
2288-
}
2289-
if (!hasValidTypeContext(funcDecl)) {
2290-
wrongTypeContext = true;
2291-
continue;
2292-
}
2293-
if (!isValidFuncDecl(funcDecl)) {
2294-
overloadNotFound = true;
2295-
continue;
2296-
}
2297-
if (resolvedFuncDecl) {
2298-
ambiguousDiagnostic();
2299-
resolvedFuncDecl = nullptr;
2300-
break;
2301-
}
2302-
resolvedFuncDecl = funcDecl;
2303-
}
2304-
// If function declaration could not be resolved, emit the appropriate
2305-
// diagnostic.
2306-
if (!resolvedFuncDecl) {
2307-
if (results.empty()) {
2308-
TC.diagnose(funcNameLoc, diag::use_unresolved_identifier, funcName,
2309-
funcName.isOperator());
2310-
} else if (wrongTypeContext) {
2311-
TC.diagnose(funcNameLoc,
2312-
diag::differentiable_attr_function_not_same_type_context,
2313-
funcName);
2314-
} else if (overloadNotFound) {
2315-
overloadDiagnostic();
2316-
} else {
2317-
assert(notAFuncDecl && "Expected 'not a function' error");
2318-
notFunctionDiagnostic();
2319-
}
2320-
}
2321-
2322-
return resolvedFuncDecl;
2323-
}
2324-
2325-
// SWIFT_ENABLE_TENSORFLOW
23262234
void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
23272235
// Forward mode is unsupported.
23282236
if (attr->getMode() == AutoDiffMode::Forward) {
@@ -2335,40 +2243,74 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
23352243
auto *original = cast<FuncDecl>(D);
23362244
auto isInstanceMethod = original->isInstanceMember();
23372245
auto selfDecl = original->getImplicitSelfDecl();
2246+
auto &ctx = original->getASTContext();
23382247

2339-
// If the original function has no parameters, there's nothing to
2340-
// differentiate with respect to.
2248+
// If the original function has no parameters or returns the empty tuple
2249+
// type, there's nothing to differentiate from or with-respect-to.
23412250
auto &originalParams = *original->getParameterList(selfDecl ? 1 : 0);
23422251
if (!isInstanceMethod && originalParams.size() == 0) {
23432252
TC.diagnose(attr->getLocation(), diag::differentiable_attr_no_parameters,
23442253
original->getName())
23452254
.highlight(original->getSourceRange());
23462255
return;
23472256
}
2257+
auto originalResultTy = original->getResultInterfaceType();
2258+
if (originalResultTy->isEqual(ctx.TheEmptyTupleType)) {
2259+
TC.diagnose(attr->getLocation(), diag::differentiable_attr_void_result,
2260+
original->getName())
2261+
.highlight(original->getSourceRange());
2262+
return;
2263+
}
23482264

2349-
auto originalParamsTy =
2350-
originalParams.getInterfaceType(original->getASTContext());
2265+
auto originalParamsTy = originalParams.getInterfaceType(ctx);
23512266

23522267
// If the original function and the primal/adjoint have different parents, or
2353-
// if they both have no type context and are in different modules, then
2354-
// it's an error.
2355-
auto hasValidTypeContext = [&](FuncDecl *decl) {
2268+
// if they both have no type context and are in different modules, then it's
2269+
// an error.
2270+
// Returns true on error.
2271+
std::function<bool(FuncDecl *)> hasValidTypeContext
2272+
= [&](FuncDecl *func) {
23562273
if (!original->getInnermostTypeContext() &&
2357-
!decl->getInnermostTypeContext() &&
2358-
original->getParentModule() == decl->getParentModule())
2274+
!func->getInnermostTypeContext() &&
2275+
original->getParentModule() == func->getParentModule())
23592276
return true;
23602277
if (auto typeCtx1 = original->getInnermostTypeContext()) {
2361-
if (auto typeCtx2 = decl->getInnermostTypeContext()) {
2278+
if (auto typeCtx2 = func->getInnermostTypeContext()) {
23622279
auto type1 = typeCtx1->getDeclaredInterfaceType();
23632280
auto type2 = typeCtx2->getDeclaredInterfaceType();
23642281
return type1->isEqual(type2);
23652282
}
23662283
}
2367-
return original->getParent() == decl->getParent();
2284+
return original->getParent() == func->getParent();
23682285
};
23692286

2287+
// If the original function is exported (i.e. it is public or
2288+
// @usableFromInline), then the primal/adjoint must also be exported.
2289+
// Returns true on error.
2290+
using FuncSpecifier = DifferentiableAttr::FunctionSpecifier;
2291+
auto checkAccessControl = [&](FuncDecl *func, FuncSpecifier funcSpec,
2292+
bool isPrimal) {
2293+
auto originalAccess =
2294+
original->getFormalAccess(/*useDC*/ nullptr,
2295+
/*treatUsableFromInlineAsPublic*/ true);
2296+
if (originalAccess < AccessLevel::Public) return false;
2297+
auto funcAccess =
2298+
func->getFormalAccess(/*useDC*/ nullptr,
2299+
/*treatUsableFromInlineAsPublic*/ true);
2300+
if (funcAccess >= AccessLevel::Public) return false;
2301+
TC.diagnose(funcSpec.Loc.getBaseNameLoc(),
2302+
diag::differentiable_attr_invalid_access,
2303+
funcSpec.Name, original->getFullName(), isPrimal);
2304+
attr->setInvalid();
2305+
return true;
2306+
};
2307+
2308+
// Set lookup options.
2309+
auto lookupOptions = defaultMemberLookupOptions
2310+
| NameLookupFlags::IgnoreAccessControl;
2311+
23702312
// Resolve the primal declaration, if it exists.
2371-
FuncDecl *resolvedPrimal = nullptr;
2313+
FuncDecl *primal = nullptr;
23722314
if (attr->getPrimal()) {
23732315
auto primalSpecifier = attr->getPrimal().getValue();
23742316
auto primalNameLoc = primalSpecifier.Loc.getBaseNameLoc();
@@ -2391,6 +2333,11 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
23912333
diag::differentiable_attr_specified_not_function,
23922334
primalSpecifier.Name, /*isPrimal*/ true);
23932335
};
2336+
std::function<void()> primalInvalidTypeContextDiagnostic = [&]() {
2337+
TC.diagnose(primalNameLoc,
2338+
diag::differentiable_attr_function_not_same_type_context,
2339+
primalSpecifier.Name);
2340+
};
23942341

23952342
auto isValidPrimal = [&](FuncDecl *primalCandidate) {
23962343
// Returns true if the primal candidate
@@ -2401,8 +2348,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
24012348
auto primalSelfDecl = primalCandidate->getImplicitSelfDecl();
24022349
auto primalParams =
24032350
primalCandidate->getParameterList(primalSelfDecl ? 1 : 0);
2404-
auto primalParamsTy =
2405-
primalParams->getInterfaceType(original->getASTContext());
2351+
auto primalParamsTy = primalParams->getInterfaceType(ctx);
24062352
if (!primalParamsTy->isEqual(originalParamsTy))
24072353
return false;
24082354
auto originalCanGenSig = original->getGenericSignature()
@@ -2423,15 +2369,20 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
24232369
return true;
24242370
};
24252371

2426-
resolvedPrimal =
2427-
getResolvedFuncDecl(primalSpecifier.Name, primalNameLoc,
2428-
TC, primalTypeCtx, isValidPrimal, hasValidTypeContext,
2429-
primalOverloadDiagnostic, primalAmbiguousDiagnostic,
2430-
primalNotFunctionDiagnostic);
2372+
primal = TC.lookupFuncDecl(
2373+
primalSpecifier.Name, primalNameLoc, primalTypeCtx, isValidPrimal,
2374+
primalOverloadDiagnostic, primalAmbiguousDiagnostic,
2375+
primalNotFunctionDiagnostic, lookupOptions, hasValidTypeContext,
2376+
primalInvalidTypeContextDiagnostic);
24312377

2432-
if (!resolvedPrimal) return;
2378+
if (!primal) {
2379+
attr->setInvalid();
2380+
return;
2381+
}
2382+
// Check primal access control.
2383+
if (checkAccessControl(primal, primalSpecifier, /*isPrimal*/ true)) return;
24332384
// Memorize the primal reference in the attribute.
2434-
attr->setPrimalFunction(resolvedPrimal);
2385+
attr->setPrimalFunction(primal);
24352386
}
24362387

24372388
// Compute the return type of the adjoint function.
@@ -2445,7 +2396,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
24452396
SmallVector<TupleTypeElt, 8> retElts;
24462397
for (auto *param : originalParams)
24472398
retElts.push_back(param->getInterfaceType());
2448-
retTy = TupleType::get(retElts, original->getASTContext());
2399+
retTy = TupleType::get(retElts, ctx);
24492400
} else {
24502401
retTy = originalParams[0]->getInterfaceType();
24512402
}
@@ -2520,7 +2471,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
25202471
// type.
25212472
assert(retElts.size() > 0 && "There should be at least one return type");
25222473
retTy = retElts.size() > 1
2523-
? TupleType::get(retElts, original->getASTContext())
2474+
? TupleType::get(retElts, ctx)
25242475
: retElts[0].getType();
25252476
}
25262477

@@ -2534,9 +2485,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
25342485
// original result, and the seed.
25352486
//
25362487
// If the primal exists, the checkpoints type is the primal result type.
2537-
if (attr->getPrimal()) {
2538-
auto *primResultTy =
2539-
resolvedPrimal->getResultInterfaceType()->getAs<TupleType>();
2488+
if (primal) {
2489+
auto *primResultTy = primal->getResultInterfaceType()->getAs<TupleType>();
25402490
auto checkpointsTy = primResultTy->getElement(0).getType();
25412491
paramTypes.push_back(FunctionType::Param(checkpointsTy));
25422492
}
@@ -2569,7 +2519,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
25692519
}
25702520

25712521
// Resolve the adjoint declaration.
2572-
FuncDecl *resolvedAdjoint = nullptr;
2522+
FuncDecl *adjoint = nullptr;
25732523
auto adjointSpecifier = attr->getAdjoint();
25742524
auto adjointNameLoc = adjointSpecifier.Loc.getBaseNameLoc();
25752525

@@ -2580,35 +2530,51 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
25802530
TC.diagnose(adjointNameLoc,
25812531
diag::differentiable_attr_adjoint_overload_not_found,
25822532
adjointSpecifier.Name, expectedAdjointFnTy);
2533+
attr->setInvalid();
25832534
};
25842535
auto adjointAmbiguousDiagnostic = [&]() {
25852536
TC.diagnose(adjointNameLoc,
25862537
diag::differentiable_attr_ambiguous_function_identifier,
25872538
adjointSpecifier.Name);
2539+
attr->setInvalid();
25882540
};
25892541
auto adjointNotFunctionDiagnostic = [&]() {
25902542
TC.diagnose(adjointNameLoc,
25912543
diag::differentiable_attr_specified_not_function,
25922544
adjointSpecifier.Name, /*isPrimal*/ false);
2545+
attr->setInvalid();
2546+
};
2547+
std::function<void()> adjointInvalidTypeContextDiagnostic = [&]() {
2548+
TC.diagnose(adjointNameLoc,
2549+
diag::differentiable_attr_function_not_same_type_context,
2550+
adjointSpecifier.Name);
2551+
attr->setInvalid();
25932552
};
25942553

25952554
auto isValidAdjoint = [&](FuncDecl *adjointCandidate) {
25962555
// Returns true if adjoint candidate has the expected type.
25972556
auto adjointType = adjointCandidate->getInterfaceType()
2598-
->getUnlabeledType(original->getASTContext());
2557+
->getUnlabeledType(ctx);
25992558
return adjointType->isEqual(expectedAdjointFnTy);
26002559
};
26012560

2602-
resolvedAdjoint =
2603-
getResolvedFuncDecl(adjointSpecifier.Name, adjointNameLoc,
2604-
TC, adjointTypeCtx, isValidAdjoint, hasValidTypeContext,
2605-
adjointOverloadDiagnostic, adjointAmbiguousDiagnostic,
2606-
adjointNotFunctionDiagnostic);
2561+
adjoint =
2562+
TC.lookupFuncDecl(adjointSpecifier.Name, adjointNameLoc, adjointTypeCtx,
2563+
isValidAdjoint, adjointOverloadDiagnostic,
2564+
adjointAmbiguousDiagnostic, adjointNotFunctionDiagnostic,
2565+
lookupOptions, hasValidTypeContext,
2566+
adjointInvalidTypeContextDiagnostic);
26072567

2608-
if (!resolvedAdjoint) return;
2568+
if (!adjoint) {
2569+
attr->setInvalid();
2570+
return;
2571+
}
2572+
// Check adjoint access control.
2573+
if (checkAccessControl(adjoint, adjointSpecifier, /*isPrimal*/ false))
2574+
return;
26092575
// Done checking @differentiable attribute.
26102576
// Memorize the adjoint reference in the attribute.
2611-
attr->setAdjointFunction(resolvedAdjoint);
2577+
attr->setAdjointFunction(adjoint);
26122578
}
26132579

26142580
static bool

0 commit comments

Comments
 (0)