@@ -2231,98 +2231,6 @@ void AttributeChecker::visitFrozenAttr(FrozenAttr *attr) {
22312231 }
22322232}
22332233
2234- // SWIFT_ENABLE_TENSORFLOW
2235- // Returns the function declaration corresponding to the given function name and
2236- // lookup context. If the function declaration cannot be resolved, emits a
2237- // diagnostic and returns nullptr.
2238- static FuncDecl *getResolvedFuncDecl (
2239- DeclName funcName, SourceLoc funcNameLoc, TypeChecker &TC,
2240- DeclContext *lookupContext,
2241- const std::function<bool (FuncDecl *)> &isValidFuncDecl,
2242- const std::function<bool(FuncDecl *)> &hasValidTypeContext,
2243- const std::function<void()> &overloadDiagnostic,
2244- const std::function<void()> &ambiguousDiagnostic,
2245- const std::function<void()> ¬FunctionDiagnostic) {
2246-
2247- FuncDecl *resolvedFuncDecl = nullptr ;
2248-
2249- // Initialize error flags.
2250- bool notAFuncDecl = false ;
2251- bool wrongTypeContext = false ;
2252- bool overloadNotFound = false ;
2253-
2254- // Perform lookup, ignoring access control.
2255- auto options = defaultUnqualifiedLookupOptions |
2256- NameLookupFlags::IgnoreAccessControl;
2257- auto results =
2258- TC.lookupUnqualified (lookupContext, funcName, funcNameLoc, options);
2259-
2260- // Note: static methods are omitted from `TypeChecker.lookupUnqualified` in
2261- // Swift 3. The code below is a workaround for resolving them.
2262- //
2263- // This is necessary because the stdlib is compiled with `-swift-version 3`
2264- // for Swift 3 compatibility, and floating point types use the
2265- // `@differentiable` attribute with static adjoint methods (such as
2266- // `_adjointAdd`).
2267- if (lookupContext->getASTContext ().isSwiftVersion3 () && results.empty () &&
2268- lookupContext->isTypeContext ()) {
2269- auto tmp = TC.lookupMember (lookupContext,
2270- lookupContext->getSelfTypeInContext (), funcName);
2271- for (auto choice : tmp) {
2272- auto decl = choice.getValueDecl ();
2273- if (!decl) continue ;
2274- auto funcDecl = dyn_cast<FuncDecl>(decl);
2275- if (!funcDecl) continue ;
2276- results.add (LookupResultEntry (funcDecl));
2277- }
2278- }
2279-
2280- for (auto choice : results) {
2281- auto decl = choice.getValueDecl ();
2282- if (!decl) continue ;
2283-
2284- auto funcDecl = dyn_cast<FuncDecl>(decl);
2285- if (!funcDecl) {
2286- notAFuncDecl = true ;
2287- continue ;
2288- }
2289- if (!hasValidTypeContext (funcDecl)) {
2290- wrongTypeContext = true ;
2291- continue ;
2292- }
2293- if (!isValidFuncDecl (funcDecl)) {
2294- overloadNotFound = true ;
2295- continue ;
2296- }
2297- if (resolvedFuncDecl) {
2298- ambiguousDiagnostic ();
2299- resolvedFuncDecl = nullptr ;
2300- break ;
2301- }
2302- resolvedFuncDecl = funcDecl;
2303- }
2304- // If function declaration could not be resolved, emit the appropriate
2305- // diagnostic.
2306- if (!resolvedFuncDecl) {
2307- if (results.empty ()) {
2308- TC.diagnose (funcNameLoc, diag::use_unresolved_identifier, funcName,
2309- funcName.isOperator ());
2310- } else if (wrongTypeContext) {
2311- TC.diagnose (funcNameLoc,
2312- diag::differentiable_attr_function_not_same_type_context,
2313- funcName);
2314- } else if (overloadNotFound) {
2315- overloadDiagnostic ();
2316- } else {
2317- assert (notAFuncDecl && " Expected 'not a function' error" );
2318- notFunctionDiagnostic ();
2319- }
2320- }
2321-
2322- return resolvedFuncDecl;
2323- }
2324-
2325- // SWIFT_ENABLE_TENSORFLOW
23262234void AttributeChecker::visitDifferentiableAttr (DifferentiableAttr *attr) {
23272235 // Forward mode is unsupported.
23282236 if (attr->getMode () == AutoDiffMode::Forward) {
@@ -2335,40 +2243,74 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
23352243 auto *original = cast<FuncDecl>(D);
23362244 auto isInstanceMethod = original->isInstanceMember ();
23372245 auto selfDecl = original->getImplicitSelfDecl ();
2246+ auto &ctx = original->getASTContext ();
23382247
2339- // If the original function has no parameters, there's nothing to
2340- // differentiate with respect to.
2248+ // If the original function has no parameters or returns the empty tuple
2249+ // type, there's nothing to differentiate from or with-respect- to.
23412250 auto &originalParams = *original->getParameterList (selfDecl ? 1 : 0 );
23422251 if (!isInstanceMethod && originalParams.size () == 0 ) {
23432252 TC.diagnose (attr->getLocation (), diag::differentiable_attr_no_parameters,
23442253 original->getName ())
23452254 .highlight (original->getSourceRange ());
23462255 return ;
23472256 }
2257+ auto originalResultTy = original->getResultInterfaceType ();
2258+ if (originalResultTy->isEqual (ctx.TheEmptyTupleType )) {
2259+ TC.diagnose (attr->getLocation (), diag::differentiable_attr_void_result,
2260+ original->getName ())
2261+ .highlight (original->getSourceRange ());
2262+ return ;
2263+ }
23482264
2349- auto originalParamsTy =
2350- originalParams.getInterfaceType (original->getASTContext ());
2265+ auto originalParamsTy = originalParams.getInterfaceType (ctx);
23512266
23522267 // If the original function and the primal/adjoint have different parents, or
2353- // if they both have no type context and are in different modules, then
2354- // it's an error.
2355- auto hasValidTypeContext = [&](FuncDecl *decl) {
2268+ // if they both have no type context and are in different modules, then it's
2269+ // an error.
2270+ // Returns true on error.
2271+ std::function<bool (FuncDecl *)> hasValidTypeContext
2272+ = [&](FuncDecl *func) {
23562273 if (!original->getInnermostTypeContext () &&
2357- !decl ->getInnermostTypeContext () &&
2358- original->getParentModule () == decl ->getParentModule ())
2274+ !func ->getInnermostTypeContext () &&
2275+ original->getParentModule () == func ->getParentModule ())
23592276 return true ;
23602277 if (auto typeCtx1 = original->getInnermostTypeContext ()) {
2361- if (auto typeCtx2 = decl ->getInnermostTypeContext ()) {
2278+ if (auto typeCtx2 = func ->getInnermostTypeContext ()) {
23622279 auto type1 = typeCtx1->getDeclaredInterfaceType ();
23632280 auto type2 = typeCtx2->getDeclaredInterfaceType ();
23642281 return type1->isEqual (type2);
23652282 }
23662283 }
2367- return original->getParent () == decl ->getParent ();
2284+ return original->getParent () == func ->getParent ();
23682285 };
23692286
2287+ // If the original function is exported (i.e. it is public or
2288+ // @usableFromInline), then the primal/adjoint must also be exported.
2289+ // Returns true on error.
2290+ using FuncSpecifier = DifferentiableAttr::FunctionSpecifier;
2291+ auto checkAccessControl = [&](FuncDecl *func, FuncSpecifier funcSpec,
2292+ bool isPrimal) {
2293+ auto originalAccess =
2294+ original->getFormalAccess (/* useDC*/ nullptr ,
2295+ /* treatUsableFromInlineAsPublic*/ true );
2296+ if (originalAccess < AccessLevel::Public) return false ;
2297+ auto funcAccess =
2298+ func->getFormalAccess (/* useDC*/ nullptr ,
2299+ /* treatUsableFromInlineAsPublic*/ true );
2300+ if (funcAccess >= AccessLevel::Public) return false ;
2301+ TC.diagnose (funcSpec.Loc .getBaseNameLoc (),
2302+ diag::differentiable_attr_invalid_access,
2303+ funcSpec.Name , original->getFullName (), isPrimal);
2304+ attr->setInvalid ();
2305+ return true ;
2306+ };
2307+
2308+ // Set lookup options.
2309+ auto lookupOptions = defaultMemberLookupOptions
2310+ | NameLookupFlags::IgnoreAccessControl;
2311+
23702312 // Resolve the primal declaration, if it exists.
2371- FuncDecl *resolvedPrimal = nullptr ;
2313+ FuncDecl *primal = nullptr ;
23722314 if (attr->getPrimal ()) {
23732315 auto primalSpecifier = attr->getPrimal ().getValue ();
23742316 auto primalNameLoc = primalSpecifier.Loc .getBaseNameLoc ();
@@ -2391,6 +2333,11 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
23912333 diag::differentiable_attr_specified_not_function,
23922334 primalSpecifier.Name , /* isPrimal*/ true );
23932335 };
2336+ std::function<void ()> primalInvalidTypeContextDiagnostic = [&]() {
2337+ TC.diagnose (primalNameLoc,
2338+ diag::differentiable_attr_function_not_same_type_context,
2339+ primalSpecifier.Name );
2340+ };
23942341
23952342 auto isValidPrimal = [&](FuncDecl *primalCandidate) {
23962343 // Returns true if the primal candidate
@@ -2401,8 +2348,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
24012348 auto primalSelfDecl = primalCandidate->getImplicitSelfDecl ();
24022349 auto primalParams =
24032350 primalCandidate->getParameterList (primalSelfDecl ? 1 : 0 );
2404- auto primalParamsTy =
2405- primalParams->getInterfaceType (original->getASTContext ());
2351+ auto primalParamsTy = primalParams->getInterfaceType (ctx);
24062352 if (!primalParamsTy->isEqual (originalParamsTy))
24072353 return false ;
24082354 auto originalCanGenSig = original->getGenericSignature ()
@@ -2423,15 +2369,20 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
24232369 return true ;
24242370 };
24252371
2426- resolvedPrimal =
2427- getResolvedFuncDecl ( primalSpecifier.Name , primalNameLoc,
2428- TC, primalTypeCtx, isValidPrimal, hasValidTypeContext ,
2429- primalOverloadDiagnostic, primalAmbiguousDiagnostic ,
2430- primalNotFunctionDiagnostic );
2372+ primal = TC. lookupFuncDecl (
2373+ primalSpecifier.Name , primalNameLoc, primalTypeCtx, isValidPrimal ,
2374+ primalOverloadDiagnostic, primalAmbiguousDiagnostic ,
2375+ primalNotFunctionDiagnostic, lookupOptions, hasValidTypeContext ,
2376+ primalInvalidTypeContextDiagnostic );
24312377
2432- if (!resolvedPrimal) return ;
2378+ if (!primal) {
2379+ attr->setInvalid ();
2380+ return ;
2381+ }
2382+ // Check primal access control.
2383+ if (checkAccessControl (primal, primalSpecifier, /* isPrimal*/ true )) return ;
24332384 // Memorize the primal reference in the attribute.
2434- attr->setPrimalFunction (resolvedPrimal );
2385+ attr->setPrimalFunction (primal );
24352386 }
24362387
24372388 // Compute the return type of the adjoint function.
@@ -2445,7 +2396,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
24452396 SmallVector<TupleTypeElt, 8 > retElts;
24462397 for (auto *param : originalParams)
24472398 retElts.push_back (param->getInterfaceType ());
2448- retTy = TupleType::get (retElts, original-> getASTContext () );
2399+ retTy = TupleType::get (retElts, ctx );
24492400 } else {
24502401 retTy = originalParams[0 ]->getInterfaceType ();
24512402 }
@@ -2520,7 +2471,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
25202471 // type.
25212472 assert (retElts.size () > 0 && " There should be at least one return type" );
25222473 retTy = retElts.size () > 1
2523- ? TupleType::get (retElts, original-> getASTContext () )
2474+ ? TupleType::get (retElts, ctx )
25242475 : retElts[0 ].getType ();
25252476 }
25262477
@@ -2534,9 +2485,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
25342485 // original result, and the seed.
25352486 //
25362487 // If the primal exists, the checkpoints type is the primal result type.
2537- if (attr->getPrimal ()) {
2538- auto *primResultTy =
2539- resolvedPrimal->getResultInterfaceType ()->getAs <TupleType>();
2488+ if (primal) {
2489+ auto *primResultTy = primal->getResultInterfaceType ()->getAs <TupleType>();
25402490 auto checkpointsTy = primResultTy->getElement (0 ).getType ();
25412491 paramTypes.push_back (FunctionType::Param (checkpointsTy));
25422492 }
@@ -2569,7 +2519,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
25692519 }
25702520
25712521 // Resolve the adjoint declaration.
2572- FuncDecl *resolvedAdjoint = nullptr ;
2522+ FuncDecl *adjoint = nullptr ;
25732523 auto adjointSpecifier = attr->getAdjoint ();
25742524 auto adjointNameLoc = adjointSpecifier.Loc .getBaseNameLoc ();
25752525
@@ -2580,35 +2530,51 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
25802530 TC.diagnose (adjointNameLoc,
25812531 diag::differentiable_attr_adjoint_overload_not_found,
25822532 adjointSpecifier.Name , expectedAdjointFnTy);
2533+ attr->setInvalid ();
25832534 };
25842535 auto adjointAmbiguousDiagnostic = [&]() {
25852536 TC.diagnose (adjointNameLoc,
25862537 diag::differentiable_attr_ambiguous_function_identifier,
25872538 adjointSpecifier.Name );
2539+ attr->setInvalid ();
25882540 };
25892541 auto adjointNotFunctionDiagnostic = [&]() {
25902542 TC.diagnose (adjointNameLoc,
25912543 diag::differentiable_attr_specified_not_function,
25922544 adjointSpecifier.Name , /* isPrimal*/ false );
2545+ attr->setInvalid ();
2546+ };
2547+ std::function<void ()> adjointInvalidTypeContextDiagnostic = [&]() {
2548+ TC.diagnose (adjointNameLoc,
2549+ diag::differentiable_attr_function_not_same_type_context,
2550+ adjointSpecifier.Name );
2551+ attr->setInvalid ();
25932552 };
25942553
25952554 auto isValidAdjoint = [&](FuncDecl *adjointCandidate) {
25962555 // Returns true if adjoint candidate has the expected type.
25972556 auto adjointType = adjointCandidate->getInterfaceType ()
2598- ->getUnlabeledType (original-> getASTContext () );
2557+ ->getUnlabeledType (ctx );
25992558 return adjointType->isEqual (expectedAdjointFnTy);
26002559 };
26012560
2602- resolvedAdjoint =
2603- getResolvedFuncDecl (adjointSpecifier.Name , adjointNameLoc,
2604- TC, adjointTypeCtx, isValidAdjoint, hasValidTypeContext,
2605- adjointOverloadDiagnostic, adjointAmbiguousDiagnostic,
2606- adjointNotFunctionDiagnostic);
2561+ adjoint =
2562+ TC.lookupFuncDecl (adjointSpecifier.Name , adjointNameLoc, adjointTypeCtx,
2563+ isValidAdjoint, adjointOverloadDiagnostic,
2564+ adjointAmbiguousDiagnostic, adjointNotFunctionDiagnostic,
2565+ lookupOptions, hasValidTypeContext,
2566+ adjointInvalidTypeContextDiagnostic);
26072567
2608- if (!resolvedAdjoint) return ;
2568+ if (!adjoint) {
2569+ attr->setInvalid ();
2570+ return ;
2571+ }
2572+ // Check adjoint access control.
2573+ if (checkAccessControl (adjoint, adjointSpecifier, /* isPrimal*/ false ))
2574+ return ;
26092575 // Done checking @differentiable attribute.
26102576 // Memorize the adjoint reference in the attribute.
2611- attr->setAdjointFunction (resolvedAdjoint );
2577+ attr->setAdjointFunction (adjoint );
26122578}
26132579
26142580static bool
0 commit comments