diff --git a/include/swift/AST/DiagnosticsParse.def b/include/swift/AST/DiagnosticsParse.def index eeab7038a6d3a..d518f3be3863a 100644 --- a/include/swift/AST/DiagnosticsParse.def +++ b/include/swift/AST/DiagnosticsParse.def @@ -1439,13 +1439,7 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken, // SWIFT_ENABLE_TENSORFLOW // differentiable ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken, - "expected a qualified function to differentiate", ()) - -ERROR(attr_differentiable_missing_lsquare,PointsToFirstBadToken, - "missing '[' for parameter index list", ()) - -ERROR(attr_differentiable_missing_rsquare,PointsToFirstBadToken, - "missing ']' for parameter index list", ()) + "expected a qualified %0 function", (StringRef)) ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken, "expected a list of parameters to differentiate with respect to, e.g. (.0, w, b)", ()) diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 1c228a7a71125..3927a093054e3 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2410,6 +2410,8 @@ ERROR(implements_attr_protocol_not_conformed_to,none, // SWIFT_ENABLE_TENSORFLOW ERROR(differentiable_attr_no_parameters,none, "%0 has no parameters to differentiate with respect to", (DeclName)) +ERROR(differentiable_attr_void_result,none, + "cannot differentiate void function %0", (DeclName)) ERROR(differentiable_attr_primal_overload_not_found,none, "%0 does not have expected parameters' type %1", (DeclName, Type)) ERROR(differentiable_attr_adjoint_overload_not_found,none, @@ -2428,13 +2430,17 @@ ERROR(differentiable_attr_cannot_diff_wrt_objects_or_existentials,none, ERROR(differentiable_attr_function_not_same_type_context,none, "%0 is not defined in the current type context", (DeclName)) ERROR(differentiable_attr_specified_not_function,none, - "%0 is not a function to be used as %select{primal|adjoint}1", + "%0 is not a function to be used as %select{adjoint|primal}1", (DeclName, bool)) ERROR(differentiable_attr_ambiguous_function_identifier,none, "ambiguous or overloaded identifier %0 cannot be used in @differentiable " "attribute", (DeclName)) ERROR(differentiable_attr_forward_mode_unsupported,none, "forward-mode automatic differentiation is not supported yet", ()) +ERROR(differentiable_attr_invalid_access,none, + "%select{adjoint|primal}2 %0 is required to either be public or " + "@usableFromInline because the original function %1 is public or " + "@usableFromInline", (DeclName, DeclName, bool)) ERROR(compiler_evaluable_bad_context,none, "@compilerEvaluable functions not allowed here", ()) diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 1ecd840ce97bc..554bbb85c3901 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -612,12 +612,13 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) { parseToken(tok::colon, diag::attr_differentiable_expected_colon_after_label, label)) return true; - // Parse the name of the adjoint function. + // Parse the name of the function. + Diagnostic funcDiag(diag::attr_differentiable_expected_function_name.ID, + { label }); result.Name = parseUnqualifiedDeclName(/*afterDot=*/false, result.Loc, - diag::attr_implements_expected_member_name, - /*allowOperators=*/true, - /*allowZeroArgCompoundNames=*/true); + funcDiag, /*allowOperators=*/true, + /*allowZeroArgCompoundNames=*/true); return !result.Name; }; diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 97cc3e8a3490f..616149bb5ea8b 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -2215,97 +2215,6 @@ void AttributeChecker::visitImplementsAttr(ImplementsAttr *attr) { } } -// SWIFT_ENABLE_TENSORFLOW -// Returns the function declaration corresponding to the given function name and -// lookup context. If the function declaration cannot be resolved, emits a -// diagnostic and returns nullptr. -static FuncDecl *getResolvedFuncDecl( - DeclName funcName, SourceLoc funcNameLoc, TypeChecker &TC, - DeclContext *lookupContext, - const std::function &isValidFuncDecl, - const std::function &hasValidTypeContext, - const std::function &overloadDiagnostic, - const std::function &ambiguousDiagnostic, - const std::function ¬FunctionDiagnostic) { - - FuncDecl *resolvedFuncDecl = nullptr; - - // Initialize error flags. - bool notAFuncDecl = false; - bool wrongTypeContext = false; - bool overloadNotFound = false; - - // Perform lookup, ignoring access control. - auto options = defaultUnqualifiedLookupOptions | - NameLookupFlags::IgnoreAccessControl; - auto results = - TC.lookupUnqualified(lookupContext, funcName, funcNameLoc, options); - - // Note: static methods are omitted from `TypeChecker.lookupUnqualified` in - // Swift 3. The code below is a workaround for resolving them. - // - // This is necessary because the stdlib is compiled with `-swift-version 3` - // for Swift 3 compatibility, and floating point types use the - // `@differentiable` attribute with static adjoint methods (such as - // `_adjointAdd`). - if (lookupContext->getASTContext().isSwiftVersion3() && results.empty() && - lookupContext->isTypeContext()) { - auto tmp = TC.lookupMember(lookupContext, - lookupContext->getSelfTypeInContext(), funcName); - for (auto choice : tmp) { - auto decl = choice.getValueDecl(); - if (!decl) continue; - auto funcDecl = dyn_cast(decl); - if (!funcDecl) continue; - results.add(LookupResultEntry(funcDecl)); - } - } - - for (auto choice : results) { - auto decl = choice.getValueDecl(); - if (!decl) continue; - - auto funcDecl = dyn_cast(decl); - if (!funcDecl) { - notAFuncDecl = true; - continue; - } - if (!hasValidTypeContext(funcDecl)) { - wrongTypeContext = true; - continue; - } - if (!isValidFuncDecl(funcDecl)) { - overloadNotFound = true; - continue; - } - if (resolvedFuncDecl) { - ambiguousDiagnostic(); - resolvedFuncDecl = nullptr; - break; - } - resolvedFuncDecl = funcDecl; - } - // If function declaration could not be resolved, emit the appropriate - // diagnostic. - if (!resolvedFuncDecl) { - if (results.empty()) { - TC.diagnose(funcNameLoc, diag::use_unresolved_identifier, funcName, - funcName.isOperator()); - } else if (wrongTypeContext) { - TC.diagnose(funcNameLoc, - diag::differentiable_attr_function_not_same_type_context, - funcName); - } else if (overloadNotFound) { - overloadDiagnostic(); - } else { - assert(notAFuncDecl && "Expected 'not a function' error"); - notFunctionDiagnostic(); - } - } - - return resolvedFuncDecl; -} - void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { // Forward mode is unsupported. if (attr->getMode() == AutoDiffMode::Forward) { @@ -2318,9 +2227,10 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { auto *original = cast(D); auto isInstanceMethod = original->isInstanceMember(); auto selfDecl = original->getImplicitSelfDecl(); + auto &ctx = original->getASTContext(); - // If the original function has no parameters, there's nothing to - // differentiate with respect to. + // If the original function has no parameters or returns the empty tuple + // type, there's nothing to differentiate from or with-respect-to. auto &originalParams = *original->getParameterList(selfDecl ? 1 : 0); if (!isInstanceMethod && originalParams.size() == 0) { TC.diagnose(attr->getLocation(), diag::differentiable_attr_no_parameters, @@ -2328,30 +2238,63 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { .highlight(original->getSourceRange()); return; } + auto originalResultTy = original->getResultInterfaceType(); + if (originalResultTy->isEqual(ctx.TheEmptyTupleType)) { + TC.diagnose(attr->getLocation(), diag::differentiable_attr_void_result, + original->getName()) + .highlight(original->getSourceRange()); + return; + } - auto originalParamsTy = - originalParams.getInterfaceType(original->getASTContext()); + auto originalParamsTy = originalParams.getInterfaceType(ctx); // If the original function and the primal/adjoint have different parents, or - // if they both have no type context and are in different modules, then - // it's an error. - auto hasValidTypeContext = [&](FuncDecl *decl) { + // if they both have no type context and are in different modules, then it's + // an error. + // Returns true on error. + std::function hasValidTypeContext + = [&](FuncDecl *func) { if (!original->getInnermostTypeContext() && - !decl->getInnermostTypeContext() && - original->getParentModule() == decl->getParentModule()) + !func->getInnermostTypeContext() && + original->getParentModule() == func->getParentModule()) return true; if (auto typeCtx1 = original->getInnermostTypeContext()) { - if (auto typeCtx2 = decl->getInnermostTypeContext()) { + if (auto typeCtx2 = func->getInnermostTypeContext()) { auto type1 = typeCtx1->getDeclaredInterfaceType(); auto type2 = typeCtx2->getDeclaredInterfaceType(); return type1->isEqual(type2); } } - return original->getParent() == decl->getParent(); + return original->getParent() == func->getParent(); }; + // If the original function is exported (i.e. it is public or + // @usableFromInline), then the primal/adjoint must also be exported. + // Returns true on error. + using FuncSpecifier = DifferentiableAttr::FunctionSpecifier; + auto checkAccessControl = [&](FuncDecl *func, FuncSpecifier funcSpec, + bool isPrimal) { + auto originalAccess = + original->getFormalAccess(/*useDC*/ nullptr, + /*treatUsableFromInlineAsPublic*/ true); + if (originalAccess < AccessLevel::Public) return false; + auto funcAccess = + func->getFormalAccess(/*useDC*/ nullptr, + /*treatUsableFromInlineAsPublic*/ true); + if (funcAccess >= AccessLevel::Public) return false; + TC.diagnose(funcSpec.Loc.getBaseNameLoc(), + diag::differentiable_attr_invalid_access, + funcSpec.Name, original->getFullName(), isPrimal); + attr->setInvalid(); + return true; + }; + + // Set lookup options. + auto lookupOptions = defaultMemberLookupOptions + | NameLookupFlags::IgnoreAccessControl; + // Resolve the primal declaration, if it exists. - FuncDecl *resolvedPrimal = nullptr; + FuncDecl *primal = nullptr; if (attr->getPrimal()) { auto primalSpecifier = attr->getPrimal().getValue(); auto primalNameLoc = primalSpecifier.Loc.getBaseNameLoc(); @@ -2374,6 +2317,11 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { diag::differentiable_attr_specified_not_function, primalSpecifier.Name, /*isPrimal*/ true); }; + std::function primalInvalidTypeContextDiagnostic = [&]() { + TC.diagnose(primalNameLoc, + diag::differentiable_attr_function_not_same_type_context, + primalSpecifier.Name); + }; auto isValidPrimal = [&](FuncDecl *primalCandidate) { // Returns true if the primal candidate @@ -2384,8 +2332,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { auto primalSelfDecl = primalCandidate->getImplicitSelfDecl(); auto primalParams = primalCandidate->getParameterList(primalSelfDecl ? 1 : 0); - auto primalParamsTy = - primalParams->getInterfaceType(original->getASTContext()); + auto primalParamsTy = primalParams->getInterfaceType(ctx); if (!primalParamsTy->isEqual(originalParamsTy)) return false; auto originalCanGenSig = original->getGenericSignature() @@ -2406,15 +2353,20 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { return true; }; - resolvedPrimal = - getResolvedFuncDecl(primalSpecifier.Name, primalNameLoc, - TC, primalTypeCtx, isValidPrimal, hasValidTypeContext, - primalOverloadDiagnostic, primalAmbiguousDiagnostic, - primalNotFunctionDiagnostic); + primal = TC.lookupFuncDecl( + primalSpecifier.Name, primalNameLoc, primalTypeCtx, isValidPrimal, + primalOverloadDiagnostic, primalAmbiguousDiagnostic, + primalNotFunctionDiagnostic, lookupOptions, hasValidTypeContext, + primalInvalidTypeContextDiagnostic); - if (!resolvedPrimal) return; + if (!primal) { + attr->setInvalid(); + return; + } + // Check primal access control. + if (checkAccessControl(primal, primalSpecifier, /*isPrimal*/ true)) return; // Memorize the primal reference in the attribute. - attr->setPrimalFunction(resolvedPrimal); + attr->setPrimalFunction(primal); } // Compute the return type of the adjoint function. @@ -2428,7 +2380,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { SmallVector retElts; for (auto *param : originalParams) retElts.push_back(param->getInterfaceType()); - retTy = TupleType::get(retElts, original->getASTContext()); + retTy = TupleType::get(retElts, ctx); } else { retTy = originalParams[0]->getInterfaceType(); } @@ -2503,7 +2455,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { // type. assert(retElts.size() > 0 && "There should be at least one return type"); retTy = retElts.size() > 1 - ? TupleType::get(retElts, original->getASTContext()) + ? TupleType::get(retElts, ctx) : retElts[0].getType(); } @@ -2517,9 +2469,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { // original result, and the seed. // // If the primal exists, the checkpoints type is the primal result type. - if (attr->getPrimal()) { - auto *primResultTy = - resolvedPrimal->getResultInterfaceType()->getAs(); + if (primal) { + auto *primResultTy = primal->getResultInterfaceType()->getAs(); auto checkpointsTy = primResultTy->getElement(0).getType(); paramTypes.push_back(FunctionType::Param(checkpointsTy)); } @@ -2552,7 +2503,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { } // Resolve the adjoint declaration. - FuncDecl *resolvedAdjoint = nullptr; + FuncDecl *adjoint = nullptr; auto adjointSpecifier = attr->getAdjoint(); auto adjointNameLoc = adjointSpecifier.Loc.getBaseNameLoc(); @@ -2563,35 +2514,51 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { TC.diagnose(adjointNameLoc, diag::differentiable_attr_adjoint_overload_not_found, adjointSpecifier.Name, expectedAdjointFnTy); + attr->setInvalid(); }; auto adjointAmbiguousDiagnostic = [&]() { TC.diagnose(adjointNameLoc, diag::differentiable_attr_ambiguous_function_identifier, adjointSpecifier.Name); + attr->setInvalid(); }; auto adjointNotFunctionDiagnostic = [&]() { TC.diagnose(adjointNameLoc, diag::differentiable_attr_specified_not_function, adjointSpecifier.Name, /*isPrimal*/ false); + attr->setInvalid(); + }; + std::function adjointInvalidTypeContextDiagnostic = [&]() { + TC.diagnose(adjointNameLoc, + diag::differentiable_attr_function_not_same_type_context, + adjointSpecifier.Name); + attr->setInvalid(); }; auto isValidAdjoint = [&](FuncDecl *adjointCandidate) { // Returns true if adjoint candidate has the expected type. auto adjointType = adjointCandidate->getInterfaceType() - ->getUnlabeledType(original->getASTContext()); + ->getUnlabeledType(ctx); return adjointType->isEqual(expectedAdjointFnTy); }; - resolvedAdjoint = - getResolvedFuncDecl(adjointSpecifier.Name, adjointNameLoc, - TC, adjointTypeCtx, isValidAdjoint, hasValidTypeContext, - adjointOverloadDiagnostic, adjointAmbiguousDiagnostic, - adjointNotFunctionDiagnostic); + adjoint = + TC.lookupFuncDecl(adjointSpecifier.Name, adjointNameLoc, adjointTypeCtx, + isValidAdjoint, adjointOverloadDiagnostic, + adjointAmbiguousDiagnostic, adjointNotFunctionDiagnostic, + lookupOptions, hasValidTypeContext, + adjointInvalidTypeContextDiagnostic); - if (!resolvedAdjoint) return; + if (!adjoint) { + attr->setInvalid(); + return; + } + // Check adjoint access control. + if (checkAccessControl(adjoint, adjointSpecifier, /*isPrimal*/ false)) + return; // Done checking @differentiable attribute. // Memorize the adjoint reference in the attribute. - attr->setAdjointFunction(resolvedAdjoint); + attr->setAdjointFunction(adjoint); } void AttributeChecker::visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr) { diff --git a/lib/Sema/TypeCheckExpr.cpp b/lib/Sema/TypeCheckExpr.cpp index 3a77c902905fa..e59fa5f693589 100644 --- a/lib/Sema/TypeCheckExpr.cpp +++ b/lib/Sema/TypeCheckExpr.cpp @@ -709,3 +709,108 @@ bool TypeChecker::isCompatibleWithVectorAutoDiff(Type type, DeclContext *DC) { } return false; } + +// SWIFT_ENABLE_TENSORFLOW +// Returns the function declaration corresponding to the given function name and +// lookup context. If the function declaration cannot be resolved, emits a +// diagnostic and returns nullptr. +FuncDecl * +TypeChecker::lookupFuncDecl( + DeclName funcName, SourceLoc funcNameLoc, DeclContext *lookupContext, + const std::function &isValidFuncDecl, + const std::function &overloadDiagnostic, + const std::function &ambiguousDiagnostic, + const std::function ¬FunctionDiagnostic, + NameLookupOptions lookupOptions, + const Optional> &hasValidTypeCtx, + const Optional> &invalidTypeCtxDiagnostic) { + + FuncDecl *resolvedFuncDecl = nullptr; + + // Perform lookup. + auto results = + lookupUnqualified(lookupContext, funcName, funcNameLoc, lookupOptions); + + // If looking up an operator within a type context, look specifically within + // the type context. + // This works around the fact that operators cannot be specified with a + // qualified name (i.e. only `(+)` works, `Float.(+)` doesn't). + if (funcName.isOperator() && lookupContext->isTypeContext()) { + if (auto tmp = lookupMember(lookupContext, + lookupContext->getSelfTypeInContext(), + funcName)) + results = tmp; + } + // Note: static methods are omitted from `TypeChecker.lookupUnqualified` in + // Swift 3. The code below is a workaround for resolving them. + // + // This is necessary because the stdlib is compiled with `-swift-version 3` + // for Swift 3 compatibility, and floating point types use the + // `@differentiable` attribute with static adjoint methods (such as + // `_adjointAdd`). + else if (lookupContext->getASTContext().isSwiftVersion3() && + results.empty() && lookupContext->isTypeContext()) { + results = lookupMember(lookupContext, lookupContext->getSelfTypeInContext(), + funcName); + } + + // Initialize error flags. + bool notAFuncDecl = false; + bool wrongTypeContext = false; + bool ambiguousFuncDecl = false; + bool overloadNotFound = false; + + // Filter lookup results. + for (auto choice : results) { + auto decl = choice.getValueDecl(); + if (!decl) continue; + + auto funcDecl = dyn_cast(decl); + if (!funcDecl) { + notAFuncDecl = true; + continue; + } + if (hasValidTypeCtx.hasValue()) { + if (hasValidTypeCtx && !(*hasValidTypeCtx)(funcDecl)) { + wrongTypeContext = true; + continue; + } + } + if (!isValidFuncDecl(funcDecl)) { + overloadNotFound = true; + continue; + } + if (resolvedFuncDecl) { + ambiguousFuncDecl = true; + resolvedFuncDecl = nullptr; + break; + } + resolvedFuncDecl = funcDecl; + } + // If function declaration was resolved, return it. + if (resolvedFuncDecl) return resolvedFuncDecl; + + // Otherwise, emit the appropriate diagnostic and return nullptr. + if (results.empty()) { + diagnose(funcNameLoc, diag::use_unresolved_identifier, funcName, + funcName.isOperator()); + return nullptr; + } + if (ambiguousFuncDecl) { + ambiguousDiagnostic(); + return nullptr; + } + if (wrongTypeContext) { + assert(invalidTypeCtxDiagnostic && + "Type context diagnostic should've been specified"); + (*invalidTypeCtxDiagnostic)(); + return nullptr; + } + if (overloadNotFound) { + overloadDiagnostic(); + return nullptr; + } + assert(notAFuncDecl && "Expected 'not a function' error"); + notFunctionDiagnostic(); + return nullptr; +} diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 8b301961b9649..922f5815e1158 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -2534,12 +2534,25 @@ class TypeChecker final : public LazyResolver { /// We say that a type supports scalar AD when it conforms to /// `FloatingPoint`. bool isCompatibleWithScalarAutoDiff(Type type, DeclContext *DC); - - + /// Determines whether the specified type supports vector differentiation. /// We say that a type supports vector AD when it conforms to /// `VectorNumeric` while its `ScalarElement` supports scalar AD. bool isCompatibleWithVectorAutoDiff(Type type, DeclContext *DC); + + /// SWIFT_ENABLE_TENSORFLOW + // Returns the function declaration corresponding to the given function name and + // lookup context. If the function declaration cannot be resolved, emits a + // diagnostic and returns nullptr. + FuncDecl *lookupFuncDecl( + DeclName funcName, SourceLoc funcNameLoc, DeclContext *lookupContext, + const std::function &isValidFuncDecl, + const std::function &overloadDiagnostic, + const std::function &ambiguousDiagnostic, + const std::function ¬FunctionDiagnostic, + NameLookupOptions lookupOptions = defaultMemberLookupOptions, + const Optional> &hasValidTypeCtx = None, + const Optional> &invalidTypeCtxDiagnostic = None); }; /// \brief RAII object that cleans up the given expression if not explicitly diff --git a/test/AutoDiff/differentiable_attr_access_control.swift b/test/AutoDiff/differentiable_attr_access_control.swift new file mode 100644 index 0000000000000..52a7c8f56bc01 --- /dev/null +++ b/test/AutoDiff/differentiable_attr_access_control.swift @@ -0,0 +1,32 @@ +// RUN: %target-swift-frontend -typecheck -verify %s + +// If the original function is "exported" (public or @usableFromInline), then +// its primal/adjoint must also be exported. + +// Ok: all public. +@differentiable(reverse, adjoint: dfoo1(_:primal:seed:)) +public func foo1(_ x: Float) -> Float { return 1 } +public func dfoo1(_ x: Float, primal: Float, seed: Float) -> Float { return 1 } + +// Ok: all internal. +struct CheckpointsFoo {} +@differentiable(reverse, primal: pfoo2(_:), adjoint: dfoo2(_:checkpoints:originalValue:seed:)) +func foo2(_ x: Float) -> Float { return 1 } +func pfoo2(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) { return (CheckpointsFoo(), 1) } +func dfoo2(_ x: Float, checkpoints: CheckpointsFoo, originalValue: Float, seed: Float) -> Float { return 1 } + +// Ok: all private. +@differentiable(reverse, adjoint: dfoo3(_:primal:seed:)) +private func foo3(_ x: Float) -> Float { return 1 } +private func dfoo3(_ x: Float, primal: Float, seed: Float) -> Float { return 1 } + +// Error: adjoint not exported. +@differentiable(reverse, adjoint: dbar1(_:primal:seed:)) // expected-error {{adjoint 'dbar1(_:primal:seed:)' is required to either be public or @usableFromInline because the original function 'bar1' is public or @usableFromInline}} +public func bar1(_ x: Float) -> Float { return 1 } +private func dbar1(_ x: Float, primal: Float, seed: Float) -> Float { return 1 } + +// Error: primal not exported. +@differentiable(reverse, primal: pbar2(_:), adjoint: dbar2(_:checkpoints:originalValue:seed:)) // expected-error {{primal 'pbar2' is required to either be public or @usableFromInline because the original function 'bar2' is public or @usableFromInline}} +@_versioned func bar2(_ x: Float) -> Float { return 1 } +func pbar2(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) { return (CheckpointsFoo(), 1) } +func dbar2(_ x: Float, checkpoints: CheckpointsFoo, originalValue: Float, seed: Float) -> Float { return 1 } diff --git a/test/AutoDiff/differentiable_attr_type_checking.swift b/test/AutoDiff/differentiable_attr_type_checking.swift index ddf7ac065bf11..29aae941e4caa 100644 --- a/test/AutoDiff/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/differentiable_attr_type_checking.swift @@ -15,10 +15,13 @@ func foo(_ x: Float) -> Float { return x * x } -// Primal returns custom checkpoints type. -struct CheckpointsFoo { -} +// Original function must return non-Void type. +@differentiable(reverse, adjoint: dvoid) // expected-error {{cannot differentiate void function 'void'}} +func void(_ a: Float) {} +func dvoid(_ a: Float, _ x: (), _ y: ()) -> Float { return 1 } +// Primal returns custom checkpoints type. +struct CheckpointsFoo {} func pfoo(_ x: Float) -> (checkpoints: CheckpointsFoo, originalValue: Float) { return (CheckpointsFoo(), x * x) }