From e3a2deb276ec5f420cd20442d25e998cf84f1d86 Mon Sep 17 00:00:00 2001 From: Daniel Zheng Date: Fri, 15 Jun 2018 10:26:11 -0700 Subject: [PATCH 1/4] [AutoDiff] Update @differentiable attribute, refactor `lookupFuncDecl` method. - Add adjoint access level check to @differentiable attribute. - If the original function is "exported" (public or @_versioned), then the adjoint must also be exported. - Generalize `lookupFuncDecl` function and add it as a method to `TypeChecker` so that it can be used for type-checking #adjoint expression in a later commit. --- include/swift/AST/DiagnosticsParse.def | 8 +- include/swift/AST/DiagnosticsSema.def | 3 + lib/Parse/ParseDecl.cpp | 9 +- lib/Sema/TypeCheckAttr.cpp | 173 +++++++++---------------- lib/Sema/TypeCheckExpr.cpp | 104 +++++++++++++++ lib/Sema/TypeChecker.h | 17 ++- 6 files changed, 190 insertions(+), 124 deletions(-) diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index eeab7038a6d3a..d518f3be3863a 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1439,13 +1439,7 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken, // SWIFT_ENABLE_TENSORFLOW // differentiable ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken, - "expected a qualified function to differentiate", ()) - -ERROR(attr_differentiable_missing_lsquare,PointsToFirstBadToken, - "missing '[' for parameter index list", ()) - -ERROR(attr_differentiable_missing_rsquare,PointsToFirstBadToken, - "missing ']' for parameter index list", ()) + "expected a qualified %0 function", (StringRef)) ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken, "expected a list of parameters to differentiate with respect to, e.g. (.0, w, b)", ()) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 1c228a7a71125..e3e1f62261ef9 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2435,6 +2435,9 @@ ERROR(differentiable_attr_ambiguous_function_identifier,none, "attribute", (DeclName)) ERROR(differentiable_attr_forward_mode_unsupported,none, "forward-mode automatic differentiation is not supported yet", ()) +ERROR(differentiable_attr_adjoint_invalid_access,none, + "adjoint %0 of public or @_versioned function %1 is required to either " + "be public or @_versioned", (DeclName, DeclName)) ERROR(compiler_evaluable_bad_context,none, "@compilerEvaluable functions not allowed here", ()) diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 1ecd840ce97bc..554bbb85c3901 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -612,12 +612,13 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) { parseToken(tok::colon, diag::attr_differentiable_expected_colon_after_label, label)) return true; - // Parse the name of the adjoint function. + // Parse the name of the function. + Diagnostic funcDiag(diag::attr_differentiable_expected_function_name.ID, + { label }); result.Name = parseUnqualifiedDeclName(/*afterDot=*/false, result.Loc, - diag::attr_implements_expected_member_name, - /*allowOperators=*/true, - /*allowZeroArgCompoundNames=*/true); + funcDiag, /*allowOperators=*/true, + /*allowZeroArgCompoundNames=*/true); return !result.Name; }; diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 97cc3e8a3490f..bf32e822d3162 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -2215,97 +2215,6 @@ void AttributeChecker::visitImplementsAttr(ImplementsAttr *attr) { } } -// SWIFT_ENABLE_TENSORFLOW -// Returns the function declaration corresponding to the given function name and -// lookup context. If the function declaration cannot be resolved, emits a -// diagnostic and returns nullptr. -static FuncDecl *getResolvedFuncDecl( - DeclName funcName, SourceLoc funcNameLoc, TypeChecker &TC, - DeclContext *lookupContext, - const std::function &isValidFuncDecl, - const std::function &hasValidTypeContext, - const std::function &overloadDiagnostic, - const std::function &ambiguousDiagnostic, - const std::function ¬FunctionDiagnostic) { - - FuncDecl *resolvedFuncDecl = nullptr; - - // Initialize error flags. - bool notAFuncDecl = false; - bool wrongTypeContext = false; - bool overloadNotFound = false; - - // Perform lookup, ignoring access control. - auto options = defaultUnqualifiedLookupOptions | - NameLookupFlags::IgnoreAccessControl; - auto results = - TC.lookupUnqualified(lookupContext, funcName, funcNameLoc, options); - - // Note: static methods are omitted from `TypeChecker.lookupUnqualified` in - // Swift 3. The code below is a workaround for resolving them. - // - // This is necessary because the stdlib is compiled with `-swift-version 3` - // for Swift 3 compatibility, and floating point types use the - // `@differentiable` attribute with static adjoint methods (such as - // `_adjointAdd`). - if (lookupContext->getASTContext().isSwiftVersion3() && results.empty() && - lookupContext->isTypeContext()) { - auto tmp = TC.lookupMember(lookupContext, - lookupContext->getSelfTypeInContext(), funcName); - for (auto choice : tmp) { - auto decl = choice.getValueDecl(); - if (!decl) continue; - auto funcDecl = dyn_cast(decl); - if (!funcDecl) continue; - results.add(LookupResultEntry(funcDecl)); - } - } - - for (auto choice : results) { - auto decl = choice.getValueDecl(); - if (!decl) continue; - - auto funcDecl = dyn_cast(decl); - if (!funcDecl) { - notAFuncDecl = true; - continue; - } - if (!hasValidTypeContext(funcDecl)) { - wrongTypeContext = true; - continue; - } - if (!isValidFuncDecl(funcDecl)) { - overloadNotFound = true; - continue; - } - if (resolvedFuncDecl) { - ambiguousDiagnostic(); - resolvedFuncDecl = nullptr; - break; - } - resolvedFuncDecl = funcDecl; - } - // If function declaration could not be resolved, emit the appropriate - // diagnostic. - if (!resolvedFuncDecl) { - if (results.empty()) { - TC.diagnose(funcNameLoc, diag::use_unresolved_identifier, funcName, - funcName.isOperator()); - } else if (wrongTypeContext) { - TC.diagnose(funcNameLoc, - diag::differentiable_attr_function_not_same_type_context, - funcName); - } else if (overloadNotFound) { - overloadDiagnostic(); - } else { - assert(notAFuncDecl && "Expected 'not a function' error"); - notFunctionDiagnostic(); - } - } - - return resolvedFuncDecl; -} - void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { // Forward mode is unsupported. if (attr->getMode() == AutoDiffMode::Forward) { @@ -2335,7 +2244,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { // If the original function and the primal/adjoint have different parents, or // if they both have no type context and are in different modules, then // it's an error. - auto hasValidTypeContext = [&](FuncDecl *decl) { + std::function hasValidTypeContext + = [&](FuncDecl *decl) { if (!original->getInnermostTypeContext() && !decl->getInnermostTypeContext() && original->getParentModule() == decl->getParentModule()) @@ -2350,8 +2260,12 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { return original->getParent() == decl->getParent(); }; + // Set lookup options. + auto lookupOptions = defaultMemberLookupOptions + | NameLookupFlags::IgnoreAccessControl; + // Resolve the primal declaration, if it exists. - FuncDecl *resolvedPrimal = nullptr; + FuncDecl *primal = nullptr; if (attr->getPrimal()) { auto primalSpecifier = attr->getPrimal().getValue(); auto primalNameLoc = primalSpecifier.Loc.getBaseNameLoc(); @@ -2374,6 +2288,11 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { diag::differentiable_attr_specified_not_function, primalSpecifier.Name, /*isPrimal*/ true); }; + std::function primalInvalidTypeContextDiagnostic = [&]() { + TC.diagnose(primalNameLoc, + diag::differentiable_attr_function_not_same_type_context, + primalSpecifier.Name); + }; auto isValidPrimal = [&](FuncDecl *primalCandidate) { // Returns true if the primal candidate @@ -2406,15 +2325,18 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { return true; }; - resolvedPrimal = - getResolvedFuncDecl(primalSpecifier.Name, primalNameLoc, - TC, primalTypeCtx, isValidPrimal, hasValidTypeContext, - primalOverloadDiagnostic, primalAmbiguousDiagnostic, - primalNotFunctionDiagnostic); + primal = TC.lookupFuncDecl( + primalSpecifier.Name, primalNameLoc, primalTypeCtx, isValidPrimal, + primalOverloadDiagnostic, primalAmbiguousDiagnostic, + primalNotFunctionDiagnostic, lookupOptions, hasValidTypeContext, + primalInvalidTypeContextDiagnostic); - if (!resolvedPrimal) return; + if (!primal) { + attr->setInvalid(); + return; + } // Memorize the primal reference in the attribute. - attr->setPrimalFunction(resolvedPrimal); + attr->setPrimalFunction(primal); } // Compute the return type of the adjoint function. @@ -2517,9 +2439,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { // original result, and the seed. // // If the primal exists, the checkpoints type is the primal result type. - if (attr->getPrimal()) { - auto *primResultTy = - resolvedPrimal->getResultInterfaceType()->getAs(); + if (primal) { + auto *primResultTy = primal->getResultInterfaceType()->getAs(); auto checkpointsTy = primResultTy->getElement(0).getType(); paramTypes.push_back(FunctionType::Param(checkpointsTy)); } @@ -2552,7 +2473,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { } // Resolve the adjoint declaration. - FuncDecl *resolvedAdjoint = nullptr; + FuncDecl *adjoint = nullptr; auto adjointSpecifier = attr->getAdjoint(); auto adjointNameLoc = adjointSpecifier.Loc.getBaseNameLoc(); @@ -2563,16 +2484,25 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { TC.diagnose(adjointNameLoc, diag::differentiable_attr_adjoint_overload_not_found, adjointSpecifier.Name, expectedAdjointFnTy); + attr->setInvalid(); }; auto adjointAmbiguousDiagnostic = [&]() { TC.diagnose(adjointNameLoc, diag::differentiable_attr_ambiguous_function_identifier, adjointSpecifier.Name); + attr->setInvalid(); }; auto adjointNotFunctionDiagnostic = [&]() { TC.diagnose(adjointNameLoc, diag::differentiable_attr_specified_not_function, adjointSpecifier.Name, /*isPrimal*/ false); + attr->setInvalid(); + }; + std::function adjointInvalidTypeContextDiagnostic = [&]() { + TC.diagnose(adjointNameLoc, + diag::differentiable_attr_function_not_same_type_context, + adjointSpecifier.Name); + attr->setInvalid(); }; auto isValidAdjoint = [&](FuncDecl *adjointCandidate) { @@ -2582,16 +2512,37 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { return adjointType->isEqual(expectedAdjointFnTy); }; - resolvedAdjoint = - getResolvedFuncDecl(adjointSpecifier.Name, adjointNameLoc, - TC, adjointTypeCtx, isValidAdjoint, hasValidTypeContext, - adjointOverloadDiagnostic, adjointAmbiguousDiagnostic, - adjointNotFunctionDiagnostic); + adjoint = + TC.lookupFuncDecl(adjointSpecifier.Name, adjointNameLoc, adjointTypeCtx, + isValidAdjoint, adjointOverloadDiagnostic, + adjointAmbiguousDiagnostic, adjointNotFunctionDiagnostic, + lookupOptions, hasValidTypeContext, + adjointInvalidTypeContextDiagnostic); - if (!resolvedAdjoint) return; + if (!adjoint) { + attr->setInvalid(); + return; + } // Done checking @differentiable attribute. // Memorize the adjoint reference in the attribute. - attr->setAdjointFunction(resolvedAdjoint); + attr->setAdjointFunction(adjoint); + + // Check access control. + // If the original function is exported (i.e. it is public or @_versioned), + // then the adjoint must also be exported. + auto originalAccess = original->getFormalAccess(/*useDC*/ nullptr, + /*respectVersioned*/ true); + if (originalAccess >= AccessLevel::Public) { + auto adjointAccess = adjoint->getFormalAccess(/*useDC*/ nullptr, + /*respectVersioned*/ true); + if (adjointAccess < AccessLevel::Public) { + TC.diagnose(adjointNameLoc, + diag::differentiable_attr_adjoint_invalid_access, + adjointSpecifier.Name, original->getFullName()); + attr->setInvalid(); + return; + } + } } void AttributeChecker::visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr) { diff --git a/lib/Sema/TypeCheckExpr.cpp b/lib/Sema/TypeCheckExpr.cpp index 3a77c902905fa..e01b348f0694c 100644 --- a/lib/Sema/TypeCheckExpr.cpp +++ b/lib/Sema/TypeCheckExpr.cpp @@ -709,3 +709,107 @@ bool TypeChecker::isCompatibleWithVectorAutoDiff(Type type, DeclContext *DC) { } return false; } + +// SWIFT_ENABLE_TENSORFLOW +// Returns the function declaration corresponding to the given function name and +// lookup context. If the function declaration cannot be resolved, emits a +// diagnostic and returns nullptr. +FuncDecl * +TypeChecker::lookupFuncDecl( + DeclName funcName, SourceLoc funcNameLoc, DeclContext *lookupContext, + const std::function &isValidFuncDecl, + const std::function &overloadDiagnostic, + const std::function &ambiguousDiagnostic, + const std::function ¬FunctionDiagnostic, + NameLookupOptions lookupOptions, + const Optional> &hasValidTypeCtx, + const Optional> &invalidTypeCtxDiagnostic) { + + FuncDecl *resolvedFuncDecl = nullptr; + + // Perform lookup. + auto results = + lookupUnqualified(lookupContext, funcName, funcNameLoc, lookupOptions); + + // If looking up an operator within a type context, look specifically within + // the type context. + // This works around the fact that operators cannot be specified with a + // qualified name (i.e. only `(+)` works, `Float.(+)` doesn't). + if (funcName.isOperator() && lookupContext->isTypeContext()) { + auto tmp = lookupMember(lookupContext, + lookupContext->getSelfTypeInContext(), funcName); + if (!tmp.empty()) + results = tmp; + } + // Note: static methods are omitted from `TypeChecker.lookupUnqualified` in + // Swift 3. The code below is a workaround for resolving them. + // + // This is necessary because the stdlib is compiled with `-swift-version 3` + // for Swift 3 compatibility, and floating point types use the + // `@differentiable` attribute with static adjoint methods (such as + // `_adjointAdd`). + else if (lookupContext->getASTContext().isSwiftVersion3() + && results.empty() && lookupContext->isTypeContext()) { + results = lookupMember(lookupContext, lookupContext->getSelfTypeInContext(), + funcName); + } + + // Initialize error flags. + bool notAFuncDecl = false; + bool wrongTypeContext = false; + bool ambiguousFuncDecl = false; + bool overloadNotFound = false; + + // Filter lookup results. + for (auto choice : results) { + auto decl = choice.getValueDecl(); + if (!decl) continue; + + auto funcDecl = dyn_cast(decl); + if (!funcDecl) { + notAFuncDecl = true; + continue; + } + if (hasValidTypeCtx.hasValue()) { + auto hasValidTypeContext = hasValidTypeCtx.getValue(); + if (!hasValidTypeContext(funcDecl)) { + wrongTypeContext = true; + continue; + } + } + if (!isValidFuncDecl(funcDecl)) { + overloadNotFound = true; + continue; + } + if (resolvedFuncDecl) { + ambiguousFuncDecl = true; + resolvedFuncDecl = nullptr; + break; + } + resolvedFuncDecl = funcDecl; + } + // If function declaration could not be resolved, emit the appropriate + // diagnostic. + if (!resolvedFuncDecl) { + if (results.empty()) { + diagnose(funcNameLoc, diag::use_unresolved_identifier, funcName, + funcName.isOperator()); + } else if (ambiguousFuncDecl) { + ambiguousDiagnostic(); + } else if (wrongTypeContext) { + assert(invalidTypeCtxDiagnostic && + "Type context diagnostic should've been specified"); + diagnose(funcNameLoc, + diag::differentiable_attr_function_not_same_type_context, + funcName); + invalidTypeCtxDiagnostic.getValue()(); + } else if (overloadNotFound) { + overloadDiagnostic(); + } else { + assert(notAFuncDecl && "Expected 'not a function' error"); + notFunctionDiagnostic(); + } + } + + return resolvedFuncDecl; +} diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 8b301961b9649..922f5815e1158 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -2534,12 +2534,25 @@ class TypeChecker final : public LazyResolver { /// We say that a type supports scalar AD when it conforms to /// `FloatingPoint`. bool isCompatibleWithScalarAutoDiff(Type type, DeclContext *DC); - - + /// Determines whether the specified type supports vector differentiation. /// We say that a type supports vector AD when it conforms to /// `VectorNumeric` while its `ScalarElement` supports scalar AD. bool isCompatibleWithVectorAutoDiff(Type type, DeclContext *DC); + + /// SWIFT_ENABLE_TENSORFLOW + // Returns the function declaration corresponding to the given function name and + // lookup context. If the function declaration cannot be resolved, emits a + // diagnostic and returns nullptr. + FuncDecl *lookupFuncDecl( + DeclName funcName, SourceLoc funcNameLoc, DeclContext *lookupContext, + const std::function &isValidFuncDecl, + const std::function &overloadDiagnostic, + const std::function &ambiguousDiagnostic, + const std::function ¬FunctionDiagnostic, + NameLookupOptions lookupOptions = defaultMemberLookupOptions, + const Optional> &hasValidTypeCtx = None, + const Optional> &invalidTypeCtxDiagnostic = None); }; /// \brief RAII object that cleans up the given expression if not explicitly From 12d3a985d3411caee3a4c48bc8ac013c449aa5a9 Mon Sep 17 00:00:00 2001 From: Daniel Zheng Date: Fri, 15 Jun 2018 11:26:57 -0700 Subject: [PATCH 2/4] Refactoring/clean up. Addresses comments from @rxwei. --- lib/Sema/TypeCheckExpr.cpp | 63 ++++++++++--------- .../differentiable_attr_type_checking.swift | 4 +- 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/lib/Sema/TypeCheckExpr.cpp b/lib/Sema/TypeCheckExpr.cpp index e01b348f0694c..a1794308d26d5 100644 --- a/lib/Sema/TypeCheckExpr.cpp +++ b/lib/Sema/TypeCheckExpr.cpp @@ -735,12 +735,11 @@ TypeChecker::lookupFuncDecl( // the type context. // This works around the fact that operators cannot be specified with a // qualified name (i.e. only `(+)` works, `Float.(+)` doesn't). - if (funcName.isOperator() && lookupContext->isTypeContext()) { - auto tmp = lookupMember(lookupContext, - lookupContext->getSelfTypeInContext(), funcName); - if (!tmp.empty()) + if (funcName.isOperator() && lookupContext->isTypeContext()) + if (auto tmp = lookupMember(lookupContext, + lookupContext->getSelfTypeInContext(), + funcName)) results = tmp; - } // Note: static methods are omitted from `TypeChecker.lookupUnqualified` in // Swift 3. The code below is a workaround for resolving them. // @@ -771,8 +770,7 @@ TypeChecker::lookupFuncDecl( continue; } if (hasValidTypeCtx.hasValue()) { - auto hasValidTypeContext = hasValidTypeCtx.getValue(); - if (!hasValidTypeContext(funcDecl)) { + if (hasValidTypeCtx && !(*hasValidTypeCtx)(funcDecl)) { wrongTypeContext = true; continue; } @@ -788,28 +786,33 @@ TypeChecker::lookupFuncDecl( } resolvedFuncDecl = funcDecl; } - // If function declaration could not be resolved, emit the appropriate - // diagnostic. - if (!resolvedFuncDecl) { - if (results.empty()) { - diagnose(funcNameLoc, diag::use_unresolved_identifier, funcName, - funcName.isOperator()); - } else if (ambiguousFuncDecl) { - ambiguousDiagnostic(); - } else if (wrongTypeContext) { - assert(invalidTypeCtxDiagnostic && - "Type context diagnostic should've been specified"); - diagnose(funcNameLoc, - diag::differentiable_attr_function_not_same_type_context, - funcName); - invalidTypeCtxDiagnostic.getValue()(); - } else if (overloadNotFound) { - overloadDiagnostic(); - } else { - assert(notAFuncDecl && "Expected 'not a function' error"); - notFunctionDiagnostic(); - } - } + // If function declaration was resolved, return it. + if (resolvedFuncDecl) return resolvedFuncDecl; - return resolvedFuncDecl; + // Otherwise, emit the appropriate diagnostic and return nullptr. + if (results.empty()) { + diagnose(funcNameLoc, diag::use_unresolved_identifier, funcName, + funcName.isOperator()); + return nullptr; + } + if (ambiguousFuncDecl) { + ambiguousDiagnostic(); + return nullptr; + } + if (wrongTypeContext) { + assert(invalidTypeCtxDiagnostic && + "Type context diagnostic should've been specified"); + diagnose(funcNameLoc, + diag::differentiable_attr_function_not_same_type_context, + funcName); + invalidTypeCtxDiagnostic.getValue()(); + return nullptr; + } + if (overloadNotFound) { + overloadDiagnostic(); + return nullptr; + } + assert(notAFuncDecl && "Expected 'not a function' error"); + notFunctionDiagnostic(); + return nullptr; } diff --git a/test/AutoDiff/differentiable_attr_type_checking.swift b/test/AutoDiff/differentiable_attr_type_checking.swift index ddf7ac065bf11..85b6a6963c666 100644 --- a/test/AutoDiff/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/differentiable_attr_type_checking.swift @@ -16,9 +16,7 @@ func foo(_ x: Float) -> Float { } // Primal returns custom checkpoints type. -struct CheckpointsFoo { -} - +struct CheckpointsFoo {} func pfoo(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) { return (CheckpointsFoo(), x * x) } From 47f49f6da1880a587e25481d8d9aa48403c4b1e4 Mon Sep 17 00:00:00 2001 From: Daniel Zheng Date: Fri, 15 Jun 2018 12:32:45 -0700 Subject: [PATCH 3/4] Add/fix @differentiable attribute diagnostics, clean up. - Add @differentiable access control check to primal as well as adjoint. - Add @differentiable access control test. - Add diagnostic for differentiating void functions. --- include/swift/AST/DiagnosticsSema.def | 11 ++- lib/Sema/TypeCheckAttr.cpp | 76 +++++++++++-------- lib/Sema/TypeCheckExpr.cpp | 12 ++- .../differentiable_attr_access_control.swift | 32 ++++++++ .../differentiable_attr_type_checking.swift | 5 ++ 5 files changed, 94 insertions(+), 42 deletions(-) create mode 100644 test/AutoDiff/differentiable_attr_access_control.swift diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index e3e1f62261ef9..3444ff468baec 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2410,6 +2410,8 @@ ERROR(implements_attr_protocol_not_conformed_to,none, // SWIFT_ENABLE_TENSORFLOW ERROR(differentiable_attr_no_parameters,none, "%0 has no parameters to differentiate with respect to", (DeclName)) +ERROR(differentiable_attr_void_result,none, + "cannot differentiate void function %0", (DeclName)) ERROR(differentiable_attr_primal_overload_not_found,none, "%0 does not have expected parameters' type %1", (DeclName, Type)) ERROR(differentiable_attr_adjoint_overload_not_found,none, @@ -2428,16 +2430,17 @@ ERROR(differentiable_attr_cannot_diff_wrt_objects_or_existentials,none, ERROR(differentiable_attr_function_not_same_type_context,none, "%0 is not defined in the current type context", (DeclName)) ERROR(differentiable_attr_specified_not_function,none, - "%0 is not a function to be used as %select{primal|adjoint}1", + "%0 is not a function to be used as %select{adjoint|primal}1", (DeclName, bool)) ERROR(differentiable_attr_ambiguous_function_identifier,none, "ambiguous or overloaded identifier %0 cannot be used in @differentiable " "attribute", (DeclName)) ERROR(differentiable_attr_forward_mode_unsupported,none, "forward-mode automatic differentiation is not supported yet", ()) -ERROR(differentiable_attr_adjoint_invalid_access,none, - "adjoint %0 of public or @_versioned function %1 is required to either " - "be public or @_versioned", (DeclName, DeclName)) +ERROR(differentiable_attr_invalid_access,none, + "%select{adjoint|primal}2 %0 is required to either be public or " + "@_versioned because the original function %1 is public or @_versioned", + (DeclName, DeclName, bool)) ERROR(compiler_evaluable_bad_context,none, "@compilerEvaluable functions not allowed here", ()) diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index bf32e822d3162..37f114e4ed382 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -2227,9 +2227,10 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { auto *original = cast(D); auto isInstanceMethod = original->isInstanceMember(); auto selfDecl = original->getImplicitSelfDecl(); + auto &ctx = original->getASTContext(); - // If the original function has no parameters, there's nothing to - // differentiate with respect to. + // If the original function has no parameters or returns the empty tuple + // type, there's nothing to differentiate from or with-respect-to. auto &originalParams = *original->getParameterList(selfDecl ? 1 : 0); if (!isInstanceMethod && originalParams.size() == 0) { TC.diagnose(attr->getLocation(), diag::differentiable_attr_no_parameters, @@ -2237,27 +2238,53 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { .highlight(original->getSourceRange()); return; } + auto originalResultTy = original->getResultInterfaceType(); + if (originalResultTy->isEqual(ctx.TheEmptyTupleType)) { + TC.diagnose(attr->getLocation(), diag::differentiable_attr_void_result, + original->getName()) + .highlight(original->getSourceRange()); + return; + } - auto originalParamsTy = - originalParams.getInterfaceType(original->getASTContext()); + auto originalParamsTy = originalParams.getInterfaceType(ctx); // If the original function and the primal/adjoint have different parents, or // if they both have no type context and are in different modules, then // it's an error. + // Returns true on error. std::function hasValidTypeContext - = [&](FuncDecl *decl) { + = [&](FuncDecl *func) { if (!original->getInnermostTypeContext() && - !decl->getInnermostTypeContext() && - original->getParentModule() == decl->getParentModule()) + !func->getInnermostTypeContext() && + original->getParentModule() == func->getParentModule()) return true; if (auto typeCtx1 = original->getInnermostTypeContext()) { - if (auto typeCtx2 = decl->getInnermostTypeContext()) { + if (auto typeCtx2 = func->getInnermostTypeContext()) { auto type1 = typeCtx1->getDeclaredInterfaceType(); auto type2 = typeCtx2->getDeclaredInterfaceType(); return type1->isEqual(type2); } } - return original->getParent() == decl->getParent(); + return original->getParent() == func->getParent(); + }; + + // If the original function is exported (i.e. it is public or @_versioned), + // then the primal/adjoint must also be exported. + // Returns true on error. + using FuncSpecifier = DifferentiableAttr::FunctionSpecifier; + auto checkAccessControl = [&](FuncDecl *func, FuncSpecifier funcSpec, + bool isPrimal) { + auto originalAccess = original->getFormalAccess(/*useDC*/ nullptr, + /*respectVersioned*/ true); + if (originalAccess < AccessLevel::Public) return false; + auto funcAccess = func->getFormalAccess(/*useDC*/ nullptr, + /*respectVersioned*/ true); + if (funcAccess >= AccessLevel::Public) return false; + TC.diagnose(funcSpec.Loc.getBaseNameLoc(), + diag::differentiable_attr_invalid_access, + funcSpec.Name, original->getFullName(), isPrimal); + attr->setInvalid(); + return true; }; // Set lookup options. @@ -2303,8 +2330,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { auto primalSelfDecl = primalCandidate->getImplicitSelfDecl(); auto primalParams = primalCandidate->getParameterList(primalSelfDecl ? 1 : 0); - auto primalParamsTy = - primalParams->getInterfaceType(original->getASTContext()); + auto primalParamsTy = primalParams->getInterfaceType(ctx); if (!primalParamsTy->isEqual(originalParamsTy)) return false; auto originalCanGenSig = original->getGenericSignature() @@ -2335,6 +2361,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { attr->setInvalid(); return; } + // Check primal access control. + if (checkAccessControl(primal, primalSpecifier, /*isPrimal*/ true)) return; // Memorize the primal reference in the attribute. attr->setPrimalFunction(primal); } @@ -2350,7 +2378,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { SmallVector retElts; for (auto *param : originalParams) retElts.push_back(param->getInterfaceType()); - retTy = TupleType::get(retElts, original->getASTContext()); + retTy = TupleType::get(retElts, ctx); } else { retTy = originalParams[0]->getInterfaceType(); } @@ -2425,7 +2453,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { // type. assert(retElts.size() > 0 && "There should be at least one return type"); retTy = retElts.size() > 1 - ? TupleType::get(retElts, original->getASTContext()) + ? TupleType::get(retElts, ctx) : retElts[0].getType(); } @@ -2508,7 +2536,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { auto isValidAdjoint = [&](FuncDecl *adjointCandidate) { // Returns true if adjoint candidate has the expected type. auto adjointType = adjointCandidate->getInterfaceType() - ->getUnlabeledType(original->getASTContext()); + ->getUnlabeledType(ctx); return adjointType->isEqual(expectedAdjointFnTy); }; @@ -2523,26 +2551,12 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { attr->setInvalid(); return; } + // Check adjoint access control. + if (checkAccessControl(adjoint, adjointSpecifier, /*isPrimal*/ false)) + return; // Done checking @differentiable attribute. // Memorize the adjoint reference in the attribute. attr->setAdjointFunction(adjoint); - - // Check access control. - // If the original function is exported (i.e. it is public or @_versioned), - // then the adjoint must also be exported. - auto originalAccess = original->getFormalAccess(/*useDC*/ nullptr, - /*respectVersioned*/ true); - if (originalAccess >= AccessLevel::Public) { - auto adjointAccess = adjoint->getFormalAccess(/*useDC*/ nullptr, - /*respectVersioned*/ true); - if (adjointAccess < AccessLevel::Public) { - TC.diagnose(adjointNameLoc, - diag::differentiable_attr_adjoint_invalid_access, - adjointSpecifier.Name, original->getFullName()); - attr->setInvalid(); - return; - } - } } void AttributeChecker::visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr) { diff --git a/lib/Sema/TypeCheckExpr.cpp b/lib/Sema/TypeCheckExpr.cpp index a1794308d26d5..e59fa5f693589 100644 --- a/lib/Sema/TypeCheckExpr.cpp +++ b/lib/Sema/TypeCheckExpr.cpp @@ -735,11 +735,12 @@ TypeChecker::lookupFuncDecl( // the type context. // This works around the fact that operators cannot be specified with a // qualified name (i.e. only `(+)` works, `Float.(+)` doesn't). - if (funcName.isOperator() && lookupContext->isTypeContext()) + if (funcName.isOperator() && lookupContext->isTypeContext()) { if (auto tmp = lookupMember(lookupContext, lookupContext->getSelfTypeInContext(), funcName)) results = tmp; + } // Note: static methods are omitted from `TypeChecker.lookupUnqualified` in // Swift 3. The code below is a workaround for resolving them. // @@ -747,8 +748,8 @@ TypeChecker::lookupFuncDecl( // for Swift 3 compatibility, and floating point types use the // `@differentiable` attribute with static adjoint methods (such as // `_adjointAdd`). - else if (lookupContext->getASTContext().isSwiftVersion3() - && results.empty() && lookupContext->isTypeContext()) { + else if (lookupContext->getASTContext().isSwiftVersion3() && + results.empty() && lookupContext->isTypeContext()) { results = lookupMember(lookupContext, lookupContext->getSelfTypeInContext(), funcName); } @@ -802,10 +803,7 @@ TypeChecker::lookupFuncDecl( if (wrongTypeContext) { assert(invalidTypeCtxDiagnostic && "Type context diagnostic should've been specified"); - diagnose(funcNameLoc, - diag::differentiable_attr_function_not_same_type_context, - funcName); - invalidTypeCtxDiagnostic.getValue()(); + (*invalidTypeCtxDiagnostic)(); return nullptr; } if (overloadNotFound) { diff --git a/test/AutoDiff/differentiable_attr_access_control.swift b/test/AutoDiff/differentiable_attr_access_control.swift new file mode 100644 index 0000000000000..aa241378f8080 --- /dev/null +++ b/test/AutoDiff/differentiable_attr_access_control.swift @@ -0,0 +1,32 @@ +// RUN: %target-swift-frontend -typecheck -verify %s + +// If the original function is "exported" (public or @_versioned), then its +// primal/adjoint must also be exported. + +// Ok: all public. +@differentiable(reverse, adjoint: dfoo1(_:primal:seed:)) +public func foo1(_ x: Float) -> Float { return 1 } +public func dfoo1(_ x: Float, primal: Float, seed: Float) -> Float { return 1 } + +// Ok: all internal. +struct CheckpointsFoo {} +@differentiable(reverse, primal: pfoo2(_:), adjoint: dfoo2(_:checkpoints:originalValue:seed:)) +func foo2(_ x: Float) -> Float { return 1 } +func pfoo2(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) { return (CheckpointsFoo(), 1) } +func dfoo2(_ x: Float, checkpoints: CheckpointsFoo, originalValue: Float, seed: Float) -> Float { return 1 } + +// Ok: all private. +@differentiable(reverse, adjoint: dfoo3(_:primal:seed:)) +private func foo3(_ x: Float) -> Float { return 1 } +private func dfoo3(_ x: Float, primal: Float, seed: Float) -> Float { return 1 } + +// Error: adjoint not exported. +@differentiable(reverse, adjoint: dbar1(_:primal:seed:)) // expected-error {{adjoint 'dbar1(_:primal:seed:)' is required to either be public or @_versioned because the original function 'bar1' is public or @_versioned}} +public func bar1(_ x: Float) -> Float { return 1 } +private func dbar1(_ x: Float, primal: Float, seed: Float) -> Float { return 1 } + +// Error: primal not exported. +@differentiable(reverse, primal: pbar2(_:), adjoint: dbar2(_:checkpoints:originalValue:seed:)) // expected-error {{primal 'pbar2' is required to either be public or @_versioned because the original function 'bar2' is public or @_versioned}} +@_versioned func bar2(_ x: Float) -> Float { return 1 } +func pbar2(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) { return (CheckpointsFoo(), 1) } +func dbar2(_ x: Float, checkpoints: CheckpointsFoo, originalValue: Float, seed: Float) -> Float { return 1 } diff --git a/test/AutoDiff/differentiable_attr_type_checking.swift b/test/AutoDiff/differentiable_attr_type_checking.swift index 85b6a6963c666..29aae941e4caa 100644 --- a/test/AutoDiff/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/differentiable_attr_type_checking.swift @@ -15,6 +15,11 @@ func foo(_ x: Float) -> Float { return x * x } +// Original function must return non-Void type. +@differentiable(reverse, adjoint: dvoid) // expected-error {{cannot differentiate void function 'void'}} +func void(_ a: Float) {} +func dvoid(_ a: Float, _ x: (), _ y: ()) -> Float { return 1 } + // Primal returns custom checkpoints type. struct CheckpointsFoo {} func pfoo(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) { From aac118b6db70cc9bbfa704f3d2565a80151bf11c Mon Sep 17 00:00:00 2001 From: Daniel Zheng Date: Fri, 15 Jun 2018 14:09:36 -0700 Subject: [PATCH 4/4] Preemptively rename @_versioned to @usableFromInline. --- include/swift/AST/DiagnosticsSema.def | 4 ++-- lib/Sema/TypeCheckAttr.cpp | 18 ++++++++++-------- .../differentiable_attr_access_control.swift | 8 ++++---- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 3444ff468baec..3927a093054e3 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2439,8 +2439,8 @@ ERROR(differentiable_attr_forward_mode_unsupported,none, "forward-mode automatic differentiation is not supported yet", ()) ERROR(differentiable_attr_invalid_access,none, "%select{adjoint|primal}2 %0 is required to either be public or " - "@_versioned because the original function %1 is public or @_versioned", - (DeclName, DeclName, bool)) + "@usableFromInline because the original function %1 is public or " + "@usableFromInline", (DeclName, DeclName, bool)) ERROR(compiler_evaluable_bad_context,none, "@compilerEvaluable functions not allowed here", ()) diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 37f114e4ed382..616149bb5ea8b 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -2249,8 +2249,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { auto originalParamsTy = originalParams.getInterfaceType(ctx); // If the original function and the primal/adjoint have different parents, or - // if they both have no type context and are in different modules, then - // it's an error. + // if they both have no type context and are in different modules, then it's + // an error. // Returns true on error. std::function hasValidTypeContext = [&](FuncDecl *func) { @@ -2268,17 +2268,19 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { return original->getParent() == func->getParent(); }; - // If the original function is exported (i.e. it is public or @_versioned), - // then the primal/adjoint must also be exported. + // If the original function is exported (i.e. it is public or + // @usableFromInline), then the primal/adjoint must also be exported. // Returns true on error. using FuncSpecifier = DifferentiableAttr::FunctionSpecifier; auto checkAccessControl = [&](FuncDecl *func, FuncSpecifier funcSpec, bool isPrimal) { - auto originalAccess = original->getFormalAccess(/*useDC*/ nullptr, - /*respectVersioned*/ true); + auto originalAccess = + original->getFormalAccess(/*useDC*/ nullptr, + /*treatUsableFromInlineAsPublic*/ true); if (originalAccess < AccessLevel::Public) return false; - auto funcAccess = func->getFormalAccess(/*useDC*/ nullptr, - /*respectVersioned*/ true); + auto funcAccess = + func->getFormalAccess(/*useDC*/ nullptr, + /*treatUsableFromInlineAsPublic*/ true); if (funcAccess >= AccessLevel::Public) return false; TC.diagnose(funcSpec.Loc.getBaseNameLoc(), diag::differentiable_attr_invalid_access, diff --git a/test/AutoDiff/differentiable_attr_access_control.swift b/test/AutoDiff/differentiable_attr_access_control.swift index aa241378f8080..52a7c8f56bc01 100644 --- a/test/AutoDiff/differentiable_attr_access_control.swift +++ b/test/AutoDiff/differentiable_attr_access_control.swift @@ -1,7 +1,7 @@ // RUN: %target-swift-frontend -typecheck -verify %s -// If the original function is "exported" (public or @_versioned), then its -// primal/adjoint must also be exported. +// If the original function is "exported" (public or @usableFromInline), then +// its primal/adjoint must also be exported. // Ok: all public. @differentiable(reverse, adjoint: dfoo1(_:primal:seed:)) @@ -21,12 +21,12 @@ private func foo3(_ x: Float) -> Float { return 1 } private func dfoo3(_ x: Float, primal: Float, seed: Float) -> Float { return 1 } // Error: adjoint not exported. -@differentiable(reverse, adjoint: dbar1(_:primal:seed:)) // expected-error {{adjoint 'dbar1(_:primal:seed:)' is required to either be public or @_versioned because the original function 'bar1' is public or @_versioned}} +@differentiable(reverse, adjoint: dbar1(_:primal:seed:)) // expected-error {{adjoint 'dbar1(_:primal:seed:)' is required to either be public or @usableFromInline because the original function 'bar1' is public or @usableFromInline}} public func bar1(_ x: Float) -> Float { return 1 } private func dbar1(_ x: Float, primal: Float, seed: Float) -> Float { return 1 } // Error: primal not exported. -@differentiable(reverse, primal: pbar2(_:), adjoint: dbar2(_:checkpoints:originalValue:seed:)) // expected-error {{primal 'pbar2' is required to either be public or @_versioned because the original function 'bar2' is public or @_versioned}} +@differentiable(reverse, primal: pbar2(_:), adjoint: dbar2(_:checkpoints:originalValue:seed:)) // expected-error {{primal 'pbar2' is required to either be public or @usableFromInline because the original function 'bar2' is public or @usableFromInline}} @_versioned func bar2(_ x: Float) -> Float { return 1 } func pbar2(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) { return (CheckpointsFoo(), 1) } func dbar2(_ x: Float, checkpoints: CheckpointsFoo, originalValue: Float, seed: Float) -> Float { return 1 }