diff --git a/include/swift/AST/Attr.h b/include/swift/AST/Attr.h index 850ef46bcf7c3..30f610483f21c 100644 --- a/include/swift/AST/Attr.h +++ b/include/swift/AST/Attr.h @@ -1679,9 +1679,9 @@ class DifferentiatingAttr final friend TrailingObjects; /// The original function name. - DeclNameWithLoc Original; - /// The original function, resolved by the type checker. - FuncDecl *OriginalFunction = nullptr; + DeclNameWithLoc OriginalFunctionName; + /// The original function declaration, resolved by the type checker. + AbstractFunctionDecl *OriginalFunction = nullptr; /// The number of parsed parameters specified in 'wrt:'. unsigned NumParsedParameters = 0; /// The differentiation parameters' indices, resolved by the type checker. @@ -1706,9 +1706,15 @@ class DifferentiatingAttr final DeclNameWithLoc original, IndexSubset *indices); - DeclNameWithLoc getOriginal() const { return Original; } - FuncDecl *getOriginalFunction() const { return OriginalFunction; } - void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; } + DeclNameWithLoc getOriginalFunctionName() const { + return OriginalFunctionName; + } + AbstractFunctionDecl *getOriginalFunction() const { + return OriginalFunction; + } + void setOriginalFunction(AbstractFunctionDecl *decl) { + OriginalFunction = decl; + } /// The parsed differentiation parameters, i.e. the list of parameters /// specified in 'wrt:'. @@ -1750,9 +1756,9 @@ class TransposingAttr final /// is an instance/static method). TypeRepr *BaseType; /// The original function name. - DeclNameWithLoc Original; - /// The original function, resolved by the type checker. - FuncDecl *OriginalFunction = nullptr; + DeclNameWithLoc OriginalFunctionName; + /// The original function declaration, resolved by the type checker. + AbstractFunctionDecl *OriginalFunction = nullptr; /// The number of parsed parameters specified in 'wrt:'. unsigned NumParsedParameters = 0; /// The differentiation parameters' indices, resolved by the type checker. @@ -1779,10 +1785,15 @@ class TransposingAttr final IndexSubset *indices); TypeRepr *getBaseType() const { return BaseType; } - DeclNameWithLoc getOriginal() const { return Original; } - - FuncDecl *getOriginalFunction() const { return OriginalFunction; } - void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; } + DeclNameWithLoc getOriginalFunctionName() const { + return OriginalFunctionName; + } + AbstractFunctionDecl *getOriginalFunction() const { + return OriginalFunction; + } + void setOriginalFunction(AbstractFunctionDecl *decl) { + OriginalFunction = decl; + } /// The parsed transposing parameters, i.e. the list of parameters /// specified in 'wrt:'. diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index c29abac83f44d..62bb2f68395b7 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2954,9 +2954,8 @@ NOTE(differentiable_attr_duplicate_note,none, "other attribute declared here", ()) ERROR(differentiable_attr_function_not_same_type_context,none, "%0 is not defined in the current type context", (DeclName)) -ERROR(differentiable_attr_specified_not_function,none, - "%0 is not a function to be used as derivative function", - (DeclName)) +ERROR(differentiable_attr_derivative_not_function,none, + "registered derivative %0 must be a 'func' declaration", (DeclName)) ERROR(differentiable_attr_class_derivative_not_final,none, "class member derivative must be final", ()) ERROR(differentiable_attr_ambiguous_function_identifier,none, @@ -3020,6 +3019,8 @@ ERROR(differentiating_attr_overload_not_found,none, "could not find function %0 with expected type %1", (DeclName, Type)) ERROR(differentiating_attr_not_in_same_file_as_original,none, "derivative not in the same file as the original function", ()) +ERROR(differentiating_attr_original_stored_property_unsupported,none, + "cannot register derivative for stored property %0", (DeclName)) ERROR(differentiating_attr_original_already_has_derivative,none, "a derivative already exists for %0", (DeclName)) @@ -3033,7 +3034,7 @@ ERROR(transposing_attr_cannot_use_named_wrt_params,none, "cannot use named 'wrt' parameters in '@transposing' attribute, found %0", (Identifier)) ERROR(transposing_attr_result_value_not_differentiable,none, - "'@transposing' attribute requires original function result to " + "'@transposing' attribute requires original function result %0 to " "conform to 'Differentiable'", (Type)) // differentiation `wrt` parameters clause diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index be805321137ca..048522a719cea 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -926,7 +926,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options, Printer << '('; auto *attr = cast(this); auto *derivative = cast(D); - Printer << attr->getOriginal().Name; + Printer << attr->getOriginalFunctionName().Name; auto diffParamsString = getDifferentiationParametersClauseString( derivative, attr->getParameterIndices(), attr->getParsedParameters()); if (!diffParamsString.empty()) @@ -941,7 +941,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options, Printer << '('; auto *attr = cast(this); auto *transpose = cast(D); - Printer << attr->getOriginal().Name; + Printer << attr->getOriginalFunctionName().Name; auto transParamsString = getTransposedParametersClauseString( transpose, attr->getParameterIndices(), attr->getParsedParameters()); if (!transParamsString.empty()) @@ -1570,19 +1570,21 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D, // SWIFT_ENABLE_TENSORFLOW DifferentiatingAttr::DifferentiatingAttr( bool implicit, SourceLoc atLoc, SourceRange baseRange, - DeclNameWithLoc original, ArrayRef params) + DeclNameWithLoc originalName, ArrayRef params) : DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit), - Original(std::move(original)), NumParsedParameters(params.size()) { + OriginalFunctionName(std::move(originalName)), + NumParsedParameters(params.size()) { std::copy(params.begin(), params.end(), getTrailingObjects()); } DifferentiatingAttr::DifferentiatingAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange, - DeclNameWithLoc original, + DeclNameWithLoc originalName, IndexSubset *indices) : DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit), - Original(std::move(original)), ParameterIndices(indices) {} + OriginalFunctionName(std::move(originalName)), ParameterIndices(indices) { +} DifferentiatingAttr * DifferentiatingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc, @@ -1607,10 +1609,10 @@ DifferentiatingAttr *DifferentiatingAttr::create(ASTContext &context, TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange, TypeRepr *baseType, - DeclNameWithLoc original, + DeclNameWithLoc originalName, ArrayRef params) : DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit), - BaseType(baseType), Original(std::move(original)), + BaseType(baseType), OriginalFunctionName(std::move(originalName)), NumParsedParameters(params.size()) { std::uninitialized_copy(params.begin(), params.end(), getTrailingObjects()); @@ -1618,9 +1620,10 @@ TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc, TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange, TypeRepr *baseType, - DeclNameWithLoc original, IndexSubset *indices) + DeclNameWithLoc originalName, + IndexSubset *indices) : DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit), - BaseType(baseType), Original(std::move(original)), + BaseType(baseType), OriginalFunctionName(std::move(originalName)), ParameterIndices(indices) {} TransposingAttr * diff --git a/lib/Parse/ParseDecl.cpp b/lib/Parse/ParseDecl.cpp index 0e47cb88b03ee..a19501d292815 100644 --- a/lib/Parse/ParseDecl.cpp +++ b/lib/Parse/ParseDecl.cpp @@ -1069,7 +1069,7 @@ bool Parser::parseDifferentiableAttributeArguments( SyntaxParsingContext FuncDeclNameContext( SyntaxContext, SyntaxKind::FunctionDeclName); Diagnostic funcDiag(diag::attr_differentiable_expected_function_name.ID, - { label }); + {label}); result.Name = parseUnqualifiedDeclName(/*afterDot=*/false, result.Loc, funcDiag, /*allowOperators=*/true, @@ -1165,11 +1165,14 @@ Parser::parseDifferentiatingAttribute(SourceLoc atLoc, SourceLoc loc) { // Parse the name of the function. SyntaxParsingContext FuncDeclNameContext( SyntaxContext, SyntaxKind::FunctionDeclName); + // NOTE: Use `afterDot = true` and `allowDeinitAndSubscript = true` to + // enable, e.g. `@differentiating(init)` and + // `@differentiating(subscript)`. original.Name = parseUnqualifiedDeclName( - /*afterDot*/ false, original.Loc, + /*afterDot*/ true, original.Loc, diag::attr_differentiating_expected_original_name, - /*allowOperators*/ true, /*allowZeroArgCompoundNames*/ true); - + /*allowOperators*/ true, /*allowZeroArgCompoundNames*/ true, + /*allowDeinitAndSubscript*/ true); if (consumeIfTrailingComma()) return makeParserError(); } @@ -1228,19 +1231,13 @@ bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError, if (parseBaseTypeForQualifiedDeclName(P, baseType)) return true; - // If base type was parsed and has at least one component, then there was a - // dot before the current token. - bool afterDot = false; - if (baseType) { - if (auto ident = dyn_cast(baseType)) { - auto components = ident->getComponentRange(); - afterDot = std::distance(components.begin(), components.end()) > 0; - } - } + // NOTE: Use `afterDot = true` and `allowDeinitAndSubscript = true` to enable + // initializer and subscript lookup. original.Name = - P.parseUnqualifiedDeclName(afterDot, original.Loc, nameParseError, - /*allowOperators*/ true, - /*allowZeroArgCompoundNames*/ true); + P.parseUnqualifiedDeclName(/*afterDot*/ true, original.Loc, + nameParseError, /*allowOperators*/ true, + /*allowZeroArgCompoundNames*/ true, + /*allowDeinitAndSubscript*/ true); // The base type is optional, but the final unqualified decl name is not. // If name could not be parsed, return true for error. @@ -1285,7 +1282,6 @@ ParserResult Parser::parseTransposingAttribute(SourceLoc atLoc, diag::attr_transposing_expected_original_name, baseType, original)) return makeParserError(); - if (consumeIfTrailingComma()) return makeParserError(); } diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 1e2c5572855f0..a328286bc0058 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -2694,9 +2694,115 @@ TypeChecker::inferDifferentiableParameters( } // SWIFT_ENABLE_TENSORFLOW -static FuncDecl *resolveAutoDiffDerivativeFunction( - DeclNameWithLoc specifier, AbstractFunctionDecl *original, - Type expectedTy, std::function isValid) { +// Returns the function declaration corresponding to the given function name and +// lookup context. If the base type of the function is specified, member lookup +// is performed. Otherwise, unqualified lookup is performed. +// If the function declaration cannot be resolved, emits a diagnostic and +// returns nullptr. +static AbstractFunctionDecl *findAbstractFunctionDecl( + DeclName funcName, SourceLoc funcNameLoc, Type baseType, + DeclContext *lookupContext, + const std::function &isValidCandidate, + const std::function &overloadDiagnostic, + const std::function &ambiguousDiagnostic, + const std::function ¬FunctionDiagnostic, + NameLookupOptions lookupOptions, + const Optional> + &hasValidTypeCtx, + const Optional> &invalidTypeCtxDiagnostic) { + auto &ctx = lookupContext->getASTContext(); + AbstractFunctionDecl *resolvedCandidate = nullptr; + + // Perform lookup. + LookupResult results; + if (baseType) { + results = TypeChecker::lookupMember(lookupContext, baseType, funcName); + } else { + results = TypeChecker::lookupUnqualified(lookupContext, funcName, + funcNameLoc, lookupOptions); + + // If looking up an operator within a type context, look specifically within + // the type context. + // This tries to resolve unqualified operators, like `+`. + if (funcName.isOperator() && lookupContext->isTypeContext()) { + if (auto tmp = TypeChecker::lookupMember( + lookupContext, lookupContext->getSelfTypeInContext(), funcName)) + results = tmp; + } + } + + // Initialize error flags. + bool notFunction = false; + bool wrongTypeContext = false; + bool ambiguousFuncDecl = false; + bool overloadNotFound = false; + + // Filter lookup results. + for (auto choice : results) { + auto decl = choice.getValueDecl(); + if (!decl) + continue; + // Cast the candidate to an `AbstractFunctionDecl`. + auto *candidate = dyn_cast(decl); + // If the candidate is an `AbstractStorageDecl`, use its getter as the + // candidate. + if (auto *asd = dyn_cast(decl)) + candidate = asd->getAccessor(AccessorKind::Get); + if (!candidate) { + notFunction = true; + continue; + } + if (hasValidTypeCtx && !(*hasValidTypeCtx)(candidate)) { + wrongTypeContext = true; + continue; + } + if (!isValidCandidate(candidate)) { + overloadNotFound = true; + continue; + } + if (resolvedCandidate) { + ambiguousFuncDecl = true; + resolvedCandidate = nullptr; + break; + } + resolvedCandidate = candidate; + } + // If function declaration was resolved, return it. + if (resolvedCandidate) + return resolvedCandidate; + + // Otherwise, emit the appropriate diagnostic and return nullptr. + if (results.empty()) { + ctx.Diags.diagnose(funcNameLoc, diag::use_unresolved_identifier, funcName, + funcName.isOperator()); + return nullptr; + } + if (ambiguousFuncDecl) { + ambiguousDiagnostic(); + return nullptr; + } + if (wrongTypeContext) { + assert(invalidTypeCtxDiagnostic && + "Type context diagnostic should've been specified"); + (*invalidTypeCtxDiagnostic)(); + return nullptr; + } + if (overloadNotFound) { + overloadDiagnostic(); + return nullptr; + } + assert(notFunction && "Expected 'not a function' error"); + notFunctionDiagnostic(); + return nullptr; +} + +// SWIFT_ENABLE_TENSORFLOW +// Finds a derivative function declaration using the given function specifier, +// original function declaration, expected type, and "is valid" predicate. If no +// valid derivative function is found, emits diagnostics and returns false. +static FuncDecl *findAutoDiffDerivativeFunction( + DeclNameWithLoc specifier, AbstractFunctionDecl *original, Type expectedTy, + std::function isValid) { auto &ctx = original->getASTContext(); auto &diags = ctx.Diags; auto nameLoc = specifier.Loc.getBaseNameLoc(); @@ -2710,7 +2816,7 @@ static FuncDecl *resolveAutoDiffDerivativeFunction( specifier.Name); }; auto notFunctionDiagnostic = [&]() { - diags.diagnose(nameLoc, diag::differentiable_attr_specified_not_function, + diags.diagnose(nameLoc, diag::differentiable_attr_derivative_not_function, specifier.Name); }; std::function invalidTypeContextDiagnostic = [&]() { @@ -2723,19 +2829,20 @@ static FuncDecl *resolveAutoDiffDerivativeFunction( // defined in compatible type contexts. If the original function and the // derivative function have different parents, or if they both have no type // context and are in different modules, return false. - std::function hasValidTypeContext = [&](FuncDecl *func) { - // Check if both functions are top-level. - if (!original->getInnermostTypeContext() && - !func->getInnermostTypeContext() && - original->getParentModule() == func->getParentModule()) - return true; - // Check if both functions are defined in the same type context. - if (auto typeCtx1 = original->getInnermostTypeContext()) - if (auto typeCtx2 = func->getInnermostTypeContext()) - return typeCtx1->getSelfNominalTypeDecl() == - typeCtx2->getSelfNominalTypeDecl(); - return original->getParent() == func->getParent(); - }; + std::function hasValidTypeContext = + [&](AbstractFunctionDecl *func) { + // Check if both functions are top-level. + if (!original->getInnermostTypeContext() && + !func->getInnermostTypeContext() && + original->getParentModule() == func->getParentModule()) + return true; + // Check if both functions are defined in the same type context. + if (auto typeCtx1 = original->getInnermostTypeContext()) + if (auto typeCtx2 = func->getInnermostTypeContext()) + return typeCtx1->getSelfNominalTypeDecl() == + typeCtx2->getSelfNominalTypeDecl(); + return original->getParent() == func->getParent(); + }; auto isABIPublic = [&](AbstractFunctionDecl *func) { return func->getFormalAccess() >= AccessLevel::Public || @@ -2744,9 +2851,9 @@ static FuncDecl *resolveAutoDiffDerivativeFunction( }; // If the original function is exported (i.e. it is public or - // @usableFromInline), then the derivative functions must also be exported. + // `@usableFromInline`), then the derivative functions must also be exported. // Returns true on error. - auto checkAccessControl = [&](FuncDecl *func) { + auto checkAccessControl = [&](AbstractFunctionDecl *func) { if (!isABIPublic(original)) return false; if (isABIPublic(func)) @@ -2764,17 +2871,21 @@ static FuncDecl *resolveAutoDiffDerivativeFunction( auto lookupOptions = defaultMemberLookupOptions | NameLookupFlags::IgnoreAccessControl; - auto candidate = TypeChecker::lookupFuncDecl( + auto *candidate = findAbstractFunctionDecl( specifier.Name, nameLoc, /*baseType*/ Type(), originalTypeCtx, isValid, overloadDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions, hasValidTypeContext, invalidTypeContextDiagnostic); - if (!candidate) return nullptr; - + // Reject non-`func` registered derivatives. JVPs and VJPs must be `func` + // declarations. + if (isa(candidate)) { + diags.diagnose(nameLoc, diag::differentiable_attr_derivative_not_function, + specifier.Name); + return nullptr; + } if (checkAccessControl(candidate)) return nullptr; - // Derivatives of class members must be final. if (original->getDeclContext()->getSelfClassDecl() && !candidate->isFinal()) { @@ -2782,8 +2893,9 @@ static FuncDecl *resolveAutoDiffDerivativeFunction( diag::differentiable_attr_class_derivative_not_final); return nullptr; } - - return candidate; + assert(isa(candidate)); + auto *funcDecl = cast(candidate); + return funcDecl; } // SWIFT_ENABLE_TENSORFLOW @@ -3447,20 +3559,19 @@ DifferentiableAttributeParameterIndicesRequest::evaluate( AutoDiffDerivativeFunctionKind::JVP, lookupConformance, whereClauseGenSig, /*makeSelfParamFirst*/ true); - auto isValidJVP = [&](FuncDecl *jvpCandidate) { + auto isValidJVP = [&](AbstractFunctionDecl *jvpCandidate) -> bool { return checkFunctionSignature( cast(expectedJVPFnTy->getCanonicalType()), jvpCandidate->getInterfaceType()->getCanonicalType()); }; - FuncDecl *jvp = resolveAutoDiffDerivativeFunction( + FuncDecl *jvp = findAutoDiffDerivativeFunction( attr->getJVP().getValue(), original, expectedJVPFnTy, isValidJVP); - if (!jvp) { attr->setInvalid(); return nullptr; } - // Memorize the jvp reference in the attribute. + // Set the JVP declaration in the attribute. attr->setJVPFunction(jvp); } @@ -3472,20 +3583,19 @@ DifferentiableAttributeParameterIndicesRequest::evaluate( AutoDiffDerivativeFunctionKind::VJP, lookupConformance, whereClauseGenSig, /*makeSelfParamFirst*/ true); - auto isValidVJP = [&](FuncDecl *vjpCandidate) { + auto isValidVJP = [&](AbstractFunctionDecl *vjpCandidate) -> bool { return checkFunctionSignature( cast(expectedVJPFnTy->getCanonicalType()), vjpCandidate->getInterfaceType()->getCanonicalType()); }; - FuncDecl *vjp = resolveAutoDiffDerivativeFunction( + FuncDecl *vjp = findAutoDiffDerivativeFunction( attr->getVJP().getValue(), original, expectedVJPFnTy, isValidVJP); - if (!vjp) { attr->setInvalid(); return nullptr; } - // Memorize the vjp reference in the attribute. + // Set the VJP declaration in the attribute. attr->setVJPFunction(vjp); } @@ -3532,10 +3642,10 @@ DifferentiableAttributeParameterIndicesRequest::evaluate( // SWIFT_ENABLE_TENSORFLOW void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { - FuncDecl *derivative = dyn_cast(D); + FuncDecl *derivative = cast(D); auto lookupConformance = LookUpConformanceInModule(D->getDeclContext()->getParentModule()); - auto original = attr->getOriginal(); + auto originalName = attr->getOriginalFunctionName(); auto *derivativeInterfaceType = derivative->getInterfaceType() ->castTo(); @@ -3607,8 +3717,8 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { return false; // Check if target's requirements are satisfied by source. return TypeChecker::checkGenericArguments( - derivative, original.Loc.getBaseNameLoc(), - original.Loc.getBaseNameLoc(), Type(), + derivative, originalName.Loc.getBaseNameLoc(), + originalName.Loc.getBaseNameLoc(), Type(), source->getGenericParams(), target->getRequirements(), [](SubstitutableType *dependentType) { return Type(dependentType); @@ -3616,50 +3726,53 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { lookupConformance, None) == RequirementCheckResult::Success; }; - auto isValidOriginal = [&](FuncDecl *originalCandidate) { + auto isValidOriginal = [&](AbstractFunctionDecl *originalCandidate) { return checkFunctionSignature( cast(originalFnType->getCanonicalType()), originalCandidate->getInterfaceType()->getCanonicalType(), checkGenericSignatureSatisfied); }; - // TODO: Do not reuse incompatible `@differentiable` attribute diagnostics. - // Rename compatible diagnostics so that they're not attribute-specific. + // TODO(TF-998): Do not reuse incompatible `@differentiable` attribute + // diagnostics. Rename compatible diagnostics so that they're not + // attribute-specific. auto overloadDiagnostic = [&]() { - diagnose(original.Loc, diag::differentiating_attr_overload_not_found, - original.Name, originalFnType); + diagnose(originalName.Loc, diag::differentiating_attr_overload_not_found, + originalName.Name, originalFnType); }; auto ambiguousDiagnostic = [&]() { - diagnose(original.Loc, + diagnose(originalName.Loc, diag::differentiable_attr_ambiguous_function_identifier, - original.Name); + originalName.Name); }; auto notFunctionDiagnostic = [&]() { - diagnose(original.Loc, diag::differentiable_attr_specified_not_function, - original.Name); + diagnose(originalName.Loc, + diag::differentiable_attr_derivative_not_function, + originalName.Name); }; std::function invalidTypeContextDiagnostic = [&]() { - diagnose(original.Loc, + diagnose(originalName.Loc, diag::differentiable_attr_function_not_same_type_context, - original.Name); + originalName.Name); }; // Returns true if the derivative function and original function candidate are // defined in compatible type contexts. If the derivative function and the // original function candidate have different parents, return false. - std::function hasValidTypeContext = [&](FuncDecl *func) { - // Check if both functions are top-level. - if (!derivative->getInnermostTypeContext() && - !func->getInnermostTypeContext()) - return true; - // Check if both functions are defined in the same type context. - if (auto typeCtx1 = derivative->getInnermostTypeContext()) - if (auto typeCtx2 = func->getInnermostTypeContext()) { - return typeCtx1->getSelfNominalTypeDecl() == - typeCtx2->getSelfNominalTypeDecl(); - } - return derivative->getParent() == func->getParent(); - }; + std::function hasValidTypeContext = + [&](AbstractFunctionDecl *func) { + // Check if both functions are top-level. + if (!derivative->getInnermostTypeContext() && + !func->getInnermostTypeContext()) + return true; + // Check if both functions are defined in the same type context. + if (auto typeCtx1 = derivative->getInnermostTypeContext()) + if (auto typeCtx2 = func->getInnermostTypeContext()) { + return typeCtx1->getSelfNominalTypeDecl() == + typeCtx2->getSelfNominalTypeDecl(); + } + return derivative->getParent() == func->getParent(); + }; auto lookupOptions = defaultMemberLookupOptions | NameLookupFlags::IgnoreAccessControl; @@ -3668,16 +3781,30 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { assert(derivativeTypeCtx); // Look up original function. - auto *originalFn = TypeChecker::lookupFuncDecl( - original.Name, original.Loc.getBaseNameLoc(), /*baseType*/ Type(), + auto *originalAFD = findAbstractFunctionDecl( + originalName.Name, originalName.Loc.getBaseNameLoc(), /*baseType*/ Type(), derivativeTypeCtx, isValidOriginal, overloadDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions, hasValidTypeContext, invalidTypeContextDiagnostic); - if (!originalFn) { + if (!originalAFD) { attr->setInvalid(); return; } - attr->setOriginalFunction(originalFn); + // Diagnose original stored properties. Stored properties cannot have custom + // registered derivatives. + if (auto *accessorDecl = dyn_cast(originalAFD)) { + auto *asd = accessorDecl->getStorage(); + if (asd->hasStorage()) { + diagnose(originalName.Loc, + diag::differentiating_attr_original_stored_property_unsupported, + originalName.Name); + diagnose(originalAFD->getLoc(), diag::decl_declared_here, + asd->getFullName()); + attr->setInvalid(); + return; + } + } + attr->setOriginalFunction(originalAFD); // Get checked wrt param indices. auto *checkedWrtParamIndices = attr->getParameterIndices(); @@ -3700,7 +3827,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { // Check if differentiation parameter indices are valid. if (checkDifferentiationParameters( - originalFn, checkedWrtParamIndices, originalFnType, + originalAFD, checkedWrtParamIndices, originalFnType, derivative->getGenericEnvironment(), derivative->getModuleContext(), parsedWrtParams, attr->getLocation())) { attr->setInvalid(); @@ -3753,7 +3880,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { // Emit differential/pullback type mismatch error on attribute. diagnose(attr->getLocation(), diag::differentiating_attr_result_func_type_mismatch, - funcResultElt.getName(), originalFn->getFullName()); + funcResultElt.getName(), originalAFD->getFullName()); // Emit note with expected differential/pullback type on actual type // location. auto *tupleReturnTypeRepr = @@ -3764,10 +3891,10 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { funcResultElt.getName(), expectedFuncEltType) .highlight(funcEltTypeRepr->getSourceRange()); // Emit note showing original function location, if possible. - if (originalFn->getLoc().isValid()) - diagnose(originalFn->getLoc(), + if (originalAFD->getLoc().isValid()) + diagnose(originalAFD->getLoc(), diag::differentiating_attr_result_func_original_note, - originalFn->getFullName()); + originalAFD->getFullName()); attr->setInvalid(); return; } @@ -3776,7 +3903,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { // TODO(TF-136): Full support for cross-file/cross-module retroactive // differentiability will require SIL differentiability witnesses and lots of // plumbing. - if (originalFn->getParentSourceFile() != derivative->getParentSourceFile()) { + if (originalAFD->getParentSourceFile() != derivative->getParentSourceFile()) { diagnoseAndRemoveAttr( attr, diag::differentiating_attr_not_in_same_file_as_original); return; @@ -3785,14 +3912,14 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { // Try to find a `@differentiable` attribute on the original function with the // same differentiation parameters. DifferentiableAttr *da = nullptr; - for (auto *cda : originalFn->getAttrs().getAttributes()) + for (auto *cda : originalAFD->getAttrs().getAttributes()) if (checkedWrtParamIndices == cda->getParameterIndices()) da = const_cast(cda); // If the original function does not have a `@differentiable` attribute with // the same differentiation parameters, create one. if (!da) { da = DifferentiableAttr::create( - originalFn, /*implicit*/ true, attr->AtLoc, attr->getRange(), + originalAFD, /*implicit*/ true, attr->AtLoc, attr->getRange(), /*linear*/ false, checkedWrtParamIndices, /*jvp*/ None, /*vjp*/ None, derivative->getGenericSignature()); switch (kind) { @@ -3804,7 +3931,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { break; } auto insertion = Ctx.DifferentiableAttrs.try_emplace( - {originalFn, checkedWrtParamIndices}, da); + {originalAFD, checkedWrtParamIndices}, da); // Valid `@differentiable` attributes are uniqued by their parameter // indices. Reject duplicate attributes for the same decl and parameter // indices pair. @@ -3814,7 +3941,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { diag::differentiable_attr_duplicate_note); return; } - originalFn->getAttrs().add(da); + originalAFD->getAttrs().add(da); return; } // If the original function has a `@differentiable` attribute with the same @@ -3830,7 +3957,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { (da->getJVPFunction() && da->getJVPFunction() != derivative)) { diagnoseAndRemoveAttr( attr, diag::differentiating_attr_original_already_has_derivative, - originalFn->getFullName()); + originalAFD->getFullName()); return; } da->setJVPFunction(derivative); @@ -3842,7 +3969,7 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { (da->getVJPFunction() && da->getVJPFunction() != derivative)) { diagnoseAndRemoveAttr( attr, diag::differentiating_attr_original_already_has_derivative, - originalFn->getFullName()); + originalAFD->getFullName()); return; } da->setVJPFunction(derivative); @@ -3851,10 +3978,10 @@ void AttributeChecker::visitDifferentiatingAttr(DifferentiatingAttr *attr) { } void AttributeChecker::visitTransposingAttr(TransposingAttr *attr) { - auto *transpose = dyn_cast(D); + auto *transpose = cast(D); auto lookupConformance = LookUpConformanceInModule(D->getDeclContext()->getParentModule()); - auto original = attr->getOriginal(); + auto originalName = attr->getOriginalFunctionName(); auto *transposeInterfaceType = transpose->getInterfaceType()->castTo(); @@ -3913,7 +4040,7 @@ void AttributeChecker::visitTransposingAttr(TransposingAttr *attr) { if (!valueResultConf) { diagnose(attr->getLocation(), diag::transposing_attr_result_value_not_differentiable, - expectedOriginalFnType); + expectedOriginalResultType); D->getAttrs().removeAttribute(attr); attr->setInvalid(); return; @@ -3932,49 +4059,50 @@ void AttributeChecker::visitTransposingAttr(TransposingAttr *attr) { return false; // Check if target's requirements are satisfied by source. return TypeChecker::checkGenericArguments( - transpose, original.Loc.getBaseNameLoc(), - original.Loc.getBaseNameLoc(), Type(), + transpose, originalName.Loc.getBaseNameLoc(), + originalName.Loc.getBaseNameLoc(), Type(), source->getGenericParams(), target->getRequirements(), [](SubstitutableType *dependentType) { return Type(dependentType); }, lookupConformance, None) == RequirementCheckResult::Success; }; - - auto isValidOriginal = [&](FuncDecl *originalCandidate) { + + auto isValidOriginal = [&](AbstractFunctionDecl *originalCandidate) { return checkFunctionSignature( cast(expectedOriginalFnType->getCanonicalType()), originalCandidate->getInterfaceType()->getCanonicalType(), checkGenericSignatureSatisfied); }; - - // TODO: Do not reuse incompatible `@differentiable` attribute diagnostics. - // Rename compatible diagnostics so that they're not attribute-specific. + + // TODO(TF-998): Do not reuse incompatible `@differentiable` attribute + // diagnostics. Rename compatible diagnostics so that they're not + // attribute-specific. auto overloadDiagnostic = [&]() { - diagnose(original.Loc, diag::differentiating_attr_overload_not_found, - original.Name, expectedOriginalFnType); + diagnose(originalName.Loc, diag::differentiating_attr_overload_not_found, + originalName.Name, expectedOriginalFnType); }; auto ambiguousDiagnostic = [&]() { - diagnose(original.Loc, + diagnose(originalName.Loc, diag::differentiable_attr_ambiguous_function_identifier, - original.Name); + originalName.Name); }; auto notFunctionDiagnostic = [&]() { - diagnose(original.Loc, diag::differentiable_attr_specified_not_function, - original.Name); + diagnose(originalName.Loc, + diag::differentiable_attr_derivative_not_function, + originalName.Name); }; std::function invalidTypeContextDiagnostic = [&]() { - diagnose(original.Loc, + diagnose(originalName.Loc, diag::differentiable_attr_function_not_same_type_context, - original.Name); + originalName.Name); }; // Returns true if the derivative function and original function candidate are // defined in compatible type contexts. If the derivative function and the // original function candidate have different parents, return false. - std::function hasValidTypeContext = [&](FuncDecl *func) { - return true; - }; + std::function hasValidTypeContext = + [&](AbstractFunctionDecl *decl) { return true; }; auto typeRes = TypeResolution::forContextual(transpose->getDeclContext()); auto baseType = Type(); @@ -3988,21 +4116,21 @@ void AttributeChecker::visitTransposingAttr(TransposingAttr *attr) { assert(transposeTypeCtx); // Look up original function. - auto funcLoc = original.Loc.getBaseNameLoc(); + auto funcLoc = originalName.Loc.getBaseNameLoc(); if (attr->getBaseType()) funcLoc = attr->getBaseType()->getLoc(); - auto *originalFn = TypeChecker::lookupFuncDecl( - original.Name, funcLoc, baseType, transposeTypeCtx, isValidOriginal, + auto *originalAFD = findAbstractFunctionDecl( + originalName.Name, funcLoc, baseType, transposeTypeCtx, isValidOriginal, overloadDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions, hasValidTypeContext, invalidTypeContextDiagnostic); - if (!originalFn) { + if (!originalAFD) { D->getAttrs().removeAttribute(attr); attr->setInvalid(); return; } - attr->setOriginalFunction(originalFn); + attr->setOriginalFunction(originalAFD); // Gather differentiation parameters. // Differentiation parameters are with respect to the original function. @@ -4011,7 +4139,7 @@ void AttributeChecker::visitTransposingAttr(TransposingAttr *attr) { wrtParamTypes); // Check if differentiation parameter indices are valid. - if (checkTransposingParameters(originalFn, wrtParamTypes, + if (checkTransposingParameters(originalAFD, wrtParamTypes, transpose->getGenericEnvironment(), transpose->getModuleContext(), parsedWrtParams, attr->getLocation())) { diff --git a/lib/Sema/TypeCheckExpr.cpp b/lib/Sema/TypeCheckExpr.cpp index a2549676454ed..09aa57e0749a1 100644 --- a/lib/Sema/TypeCheckExpr.cpp +++ b/lib/Sema/TypeCheckExpr.cpp @@ -750,102 +750,3 @@ Expr *TypeChecker::foldSequence(SequenceExpr *expr, DeclContext *dc) { return Result; } - -// SWIFT_ENABLE_TENSORFLOW -// Returns the function declaration corresponding to the given function name and -// lookup context. If the base type of the function is specified, member lookup -// is performed. Otherwise, unqualified lookup is performed. -// If the function declaration cannot be resolved, emits a diagnostic and -// returns nullptr. -FuncDecl * -TypeChecker::lookupFuncDecl( - DeclName funcName, SourceLoc funcNameLoc, Type baseType, - DeclContext *lookupContext, - const std::function &isValidFuncDecl, - const std::function &overloadDiagnostic, - const std::function &ambiguousDiagnostic, - const std::function ¬FunctionDiagnostic, - NameLookupOptions lookupOptions, - const Optional> &hasValidTypeCtx, - const Optional> &invalidTypeCtxDiagnostic) { - auto &ctx = lookupContext->getASTContext(); - FuncDecl *resolvedFuncDecl = nullptr; - - // Perform lookup. - LookupResult results; - if (baseType) { - results = TypeChecker::lookupMember(lookupContext, baseType, funcName); - } else { - results = TypeChecker::lookupUnqualified( - lookupContext, funcName, funcNameLoc, lookupOptions); - - // If looking up an operator within a type context, look specifically within - // the type context. - // This tries to resolve unqualified operators, like `+`. - if (funcName.isOperator() && lookupContext->isTypeContext()) { - if (auto tmp = - TypeChecker::lookupMember(lookupContext, - lookupContext->getSelfTypeInContext(), - funcName)) - results = tmp; - } - } - - // Initialize error flags. - bool notAFuncDecl = false; - bool wrongTypeContext = false; - bool ambiguousFuncDecl = false; - bool overloadNotFound = false; - - // Filter lookup results. - for (auto choice : results) { - auto decl = choice.getValueDecl(); - if (!decl) continue; - - auto funcDecl = dyn_cast(decl); - if (!funcDecl) { - notAFuncDecl = true; - continue; - } - if (hasValidTypeCtx && !(*hasValidTypeCtx)(funcDecl)) { - wrongTypeContext = true; - continue; - } - if (!isValidFuncDecl(funcDecl)) { - overloadNotFound = true; - continue; - } - if (resolvedFuncDecl) { - ambiguousFuncDecl = true; - resolvedFuncDecl = nullptr; - break; - } - resolvedFuncDecl = funcDecl; - } - // If function declaration was resolved, return it. - if (resolvedFuncDecl) return resolvedFuncDecl; - - // Otherwise, emit the appropriate diagnostic and return nullptr. - if (results.empty()) { - ctx.Diags.diagnose(funcNameLoc, diag::use_unresolved_identifier, funcName, - funcName.isOperator()); - return nullptr; - } - if (ambiguousFuncDecl) { - ambiguousDiagnostic(); - return nullptr; - } - if (wrongTypeContext) { - assert(invalidTypeCtxDiagnostic && - "Type context diagnostic should've been specified"); - (*invalidTypeCtxDiagnostic)(); - return nullptr; - } - if (overloadNotFound) { - overloadDiagnostic(); - return nullptr; - } - assert(notAFuncDecl && "Expected 'not a function' error"); - notFunctionDiagnostic(); - return nullptr; -} diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 079edf7c7d127..b94501da3f216 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -1908,21 +1908,6 @@ class TypeChecker final { static DeclTypeCheckingSemantics getDeclTypeCheckingSemantics(ValueDecl *decl); - /// SWIFT_ENABLE_TENSORFLOW - // Returns the function declaration corresponding to the given function name - // and lookup context. If the function declaration cannot be resolved, emits a - // diagnostic and returns nullptr. - static FuncDecl *lookupFuncDecl( - DeclName funcName, SourceLoc funcNameLoc, Type baseType, - DeclContext *lookupContext, - const std::function &isValidFuncDecl, - const std::function &overloadDiagnostic, - const std::function &ambiguousDiagnostic, - const std::function ¬FunctionDiagnostic, - NameLookupOptions lookupOptions = defaultMemberLookupOptions, - const Optional> &hasValidTypeCtx = None, - const Optional> &invalidTypeCtxDiagnostic = None); - /// SWIFT_ENABLE_TENSORFLOW /// Creates an `IndexSubset` for the given function type, representing /// all inferred differentiation parameters. diff --git a/test/AutoDiff/derivative_registration.swift b/test/AutoDiff/derivative_registration.swift index 39f160adb2431..1016577ec5b36 100644 --- a/test/AutoDiff/derivative_registration.swift +++ b/test/AutoDiff/derivative_registration.swift @@ -35,6 +35,25 @@ struct Wrapper : Differentiable { var float: Tracked } +extension Wrapper { + @_semantics("autodiff.opaque") + init(_ x: Tracked, _ y: Tracked) { + self.float = x * y + } + + @differentiating(init(_:_:)) + static func _vjpInit(_ x: Tracked, _ y: Tracked) + -> (value: Self, pullback: (TangentVector) -> (Tracked, Tracked)) { + return (.init(x, y), { v in (v.float * y, v.float * x) }) + } +} +DerivativeRegistrationTests.testWithLeakChecking("Initializer") { + let v = Wrapper.TangentVector(float: 1) + let (𝛁x, 𝛁y) = pullback(at: 3, 4, in: { x, y in Wrapper(x, y) })(v) + expectEqual(4, 𝛁x) + expectEqual(3, 𝛁y) +} + extension Wrapper { @_semantics("autodiff.opaque") static func multiply(_ x: Tracked, _ y: Tracked) -> Tracked { @@ -61,16 +80,60 @@ extension Wrapper { func _vjpMultiply(_ x: Tracked) -> (value: Tracked, pullback: (Tracked) -> (Wrapper.TangentVector, Tracked)) { return (float * x, { v in - (Wrapper.TangentVector(float: v * x), v * self.float) + (TangentVector(float: v * x), v * self.float) }) } } DerivativeRegistrationTests.testWithLeakChecking("InstanceMethod") { let x: Tracked = 2 let wrapper = Wrapper(float: 3) - let (𝛁wrapper, 𝛁x) = wrapper.gradient(at: x) { wrapper, x in wrapper.multiply(x) } + let (𝛁wrapper, 𝛁x) = gradient(at: wrapper, x) { wrapper, x in wrapper.multiply(x) } + expectEqual(Wrapper.TangentVector(float: 2), 𝛁wrapper) + expectEqual(3, 𝛁x) +} + +extension Wrapper { + subscript(_ x: Tracked) -> Tracked { + @_semantics("autodiff.opaque") + get { float * x } + set {} + } + + @differentiating(subscript(_:)) + func _vjpSubscript(_ x: Tracked) + -> (value: Tracked, pullback: (Tracked) -> (Wrapper.TangentVector, Tracked)) { + return (self[x], { v in + (TangentVector(float: v * x), v * self.float) + }) + } +} +DerivativeRegistrationTests.testWithLeakChecking("Subscript") { + let x: Tracked = 2 + let wrapper = Wrapper(float: 3) + let (𝛁wrapper, 𝛁x) = gradient(at: wrapper, x) { wrapper, x in wrapper[x] } expectEqual(Wrapper.TangentVector(float: 2), 𝛁wrapper) expectEqual(3, 𝛁x) } +extension Wrapper { + var computedProperty: Tracked { + @_semantics("autodiff.opaque") + get { float * float } + set {} + } + + @differentiating(computedProperty) + func _vjpComputedProperty() + -> (value: Tracked, pullback: (Tracked) -> Wrapper.TangentVector) { + return (computedProperty, { [f = self.float] v in + TangentVector(float: v * (f + f)) + }) + } +} +DerivativeRegistrationTests.testWithLeakChecking("ComputedProperty") { + let wrapper = Wrapper(float: 3) + let 𝛁wrapper = gradient(at: wrapper) { wrapper in wrapper.computedProperty } + expectEqual(Wrapper.TangentVector(float: 6), 𝛁wrapper) +} + runAllTests() diff --git a/test/AutoDiff/differentiable_attr_type_checking.swift b/test/AutoDiff/differentiable_attr_type_checking.swift index a2befa0bd0f1a..b6f8b0056138b 100644 --- a/test/AutoDiff/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/differentiable_attr_type_checking.swift @@ -826,6 +826,39 @@ func slope3(_ x: Float) -> Float { return 3 * x } +// Check that `@differentiable` attribute rejects stored properties. +struct StoredProperty : Differentiable { + // expected-error @+1 {{'@differentiable' attribute on stored property cannot specify 'jvp:' or 'vjp:'}} + @differentiable(vjp: vjpStored) + var stored: Float + + func vjpStored() -> (Float, (Float) -> TangentVector) { + (stored, { _ in .zero }) + } +} + +// Check that `@differentiable` attribute rejects non-`func` derivatives. +struct Struct: Differentiable { + // expected-error @+1 {{registered derivative 'computedPropertyVJP' must be a 'func' declaration}} + @differentiable(vjp: computedPropertyVJP) + func testComputedProperty() -> Float { 1 } + var computedPropertyVJP: (Float, (Float) -> TangentVector) { + (1, { _ in .zero }) + } + + // expected-error @+1 {{expected a vjp function name}} + @differentiable(vjp: init) + func testInitializer() -> Struct { self } + init(_ x: Struct) {} + + // expected-error @+1 {{expected a vjp function name}} + @differentiable(vjp: subscript) + func testSubscript() -> Float { 1 } + subscript() -> (Float, (Float) -> TangentVector) { + (1, { _ in .zero }) + } +} + // Index based 'wrt:' struct NumberWrtStruct: Differentiable { diff --git a/test/AutoDiff/differentiating_attr_type_checking.swift b/test/AutoDiff/differentiating_attr_type_checking.swift index 0b34f931fce29..98efc37199894 100644 --- a/test/AutoDiff/differentiating_attr_type_checking.swift +++ b/test/AutoDiff/differentiating_attr_type_checking.swift @@ -296,26 +296,6 @@ func vjpConsistent(_ x: Float) -> (value: Float, pullback: (Float) -> Float) { return (x, { $0 }) } -// Test usage of `@differentiable` on a stored property -struct PropertyDiff : Differentiable & AdditiveArithmetic { - // expected-error @+1 {{'@differentiable' attribute on stored property cannot specify 'jvp:' or 'vjp:'}} - @differentiable(vjp: vjpPropertyA) - var a: Float = 1 - typealias TangentVector = PropertyDiff - typealias AllDifferentiableVariables = PropertyDiff - func vjpPropertyA() -> (Float, (Float) -> PropertyDiff) { - (.zero, { _ in .zero }) - } -} - -@differentiable -func f(_ x: PropertyDiff) -> Float { - return x.a -} - -let a = gradient(at: PropertyDiff(), in: f) -print(a) - // Index based 'wrt:' func add2(x: Float, y: Float) -> Float { @@ -370,3 +350,79 @@ class Sub : Super { return (foo(x), { v in v }) } } + +// Test non-`func` original declarations. + +struct Struct { + var x: T +} +extension Struct: Equatable where T: Equatable {} +extension Struct: Differentiable & AdditiveArithmetic +where T: Differentiable & AdditiveArithmetic {} + +// Test computed properties. +extension Struct { + var computedProperty: T { x } +} +extension Struct where T: Differentiable & AdditiveArithmetic { + @differentiating(computedProperty) + func vjpProperty() -> (value: T, pullback: (T.TangentVector) -> TangentVector) { + (x, { v in .init(x: v) }) + } +} + +// Test initializers. +extension Struct { + init(_ x: Float) {} + init(_ x: T, y: Float) {} +} +extension Struct where T: Differentiable & AdditiveArithmetic { + @differentiating(init) + static func vjpInit(_ x: Float) -> (value: Struct, pullback: (TangentVector) -> Float) { + (.init(x), { _ in .zero }) + } + + @differentiating(init(_:y:)) + static func vjpInit2(_ x: T, _ y: Float) -> (value: Struct, pullback: (TangentVector) -> (T.TangentVector, Float)) { + (.init(x, y: y), { _ in (.zero, .zero) }) + } +} + +// Test subscripts. +extension Struct { + subscript() -> Float { + get { 1 } + set {} + } + subscript(float float: Float) -> Float { 1 } + subscript(x: T) -> T { x } +} + +extension Struct where T: Differentiable & AdditiveArithmetic { + @differentiating(subscript) + func vjpSubscript() -> (value: Float, pullback: (Float) -> TangentVector) { + (1, { _ in .zero }) + } + + @differentiating(subscript(float:), wrt: self) + func vjpSubscriptLabelled(float: Float) -> (value: Float, pullback: (Float) -> TangentVector) { + (1, { _ in .zero }) + } + + @differentiating(subscript(_:), wrt: self) + func vjpSubscriptGeneric(x: T) -> (value: T, pullback: (T.TangentVector) -> TangentVector) { + (x, { _ in .zero }) + } +} + +// Check that `@differentiating` attribute rejects stored property original declarations. + +struct StoredProperty: Differentiable { + // expected-note @+1 {{'stored' declared here}} + var stored: Float + // expected-error @+1 {{cannot register derivative for stored property 'stored'}} + @differentiating(stored) + func vjpStored() -> (value: Float, pullback: (Float) -> TangentVector) { + (stored, { _ in .zero }) + } +} diff --git a/test/AutoDiff/transposing_attr_type_checking.swift b/test/AutoDiff/transposing_attr_type_checking.swift index b1838b6da1422..cacc0f3ee4de8 100644 --- a/test/AutoDiff/transposing_attr_type_checking.swift +++ b/test/AutoDiff/transposing_attr_type_checking.swift @@ -155,7 +155,7 @@ func missingDiffSelfRequirement(x: T) -> T { return x } -// expected-error @+1 {{'@transposing' attribute requires original function result to conform to 'Differentiable'}} +// expected-error @+1 {{'@transposing' attribute requires original function result 'T' to conform to 'Differentiable'}} @transposing(missingDiffSelfRequirement, wrt: 0) func missingDiffSelfRequirementT(x: T) -> T { return x @@ -202,7 +202,7 @@ func transposingIntT1(x: Float, t: Float) -> Int { return Int(x) } -// expected-error @+1 {{'@transposing' attribute requires original function result to conform to 'Differentiable'}} +// expected-error @+1 {{'@transposing' attribute requires original function result 'Int' to conform to 'Differentiable'}} @transposing(transposingInt, wrt: 0) func tangentNotLast(t: Float, y: Int) -> Float { return t @@ -462,3 +462,85 @@ extension Float { return (1, T(1), 1) } } + +// Test non-`func` original declarations. + +struct Struct {} +extension Struct: Equatable where T: Equatable {} +extension Struct: Differentiable & AdditiveArithmetic +where T: Differentiable & AdditiveArithmetic {} + +// Test computed properties. +extension Struct { + var computedProperty: Struct { self } +} +extension Struct where T: Differentiable & AdditiveArithmetic { + @transposing(computedProperty, wrt: self) + func transposeProperty() -> Self { + self + } +} + +// Test initializers. +extension Struct { + init(_ x: Float) {} + init(_ x: T, y: Float) {} +} + +extension Struct where T: Differentiable & AdditiveArithmetic { + // TODO(TF-997): Support `@transposing` attribute with initializer original declaration. + // expected-error @+1 {{'@transposing' attribute requires original function result 'Struct.Type' to conform to 'Differentiable'}} + @transposing(init, wrt: 0) + static func vjpInit(_ x: Self) -> Float { + fatalError() + } + + // TODO(TF-997): Support `@transposing` attribute with initializer original declaration. + // expected-error @+1 {{'@transposing' attribute requires original function result 'Struct.Type' to conform to 'Differentiable'}} + @transposing(init(_:y:), wrt: (0, 1)) + static func vjpInit2(_ x: Self) -> (T, Float) { + fatalError() + } +} + +// Test subscripts. +extension Struct { + subscript() -> Self { + get { self } + set {} + } + subscript(float float: Float) -> Self { self } + subscript(x: U) -> Self { self } +} + +extension Struct where T: Differentiable & AdditiveArithmetic { + @transposing(subscript, wrt: self) + func vjpSubscript() -> Self { + self + } + + @transposing(subscript(float:), wrt: self) + func vjpSubscriptLabelled(_ float: Float) -> Self { + self + } + + @transposing(subscript(_:), wrt: self) + func vjpSubscriptGeneric(x: U) -> Self { + self + } +} + +// Check that `@transposing` attribute rejects stored property original declarations. + +struct StoredProperty: Differentiable { + var stored: Float + + // Note: `@transposing` support for instance members is currently too limited + // to properly register a transpose for a non-`Self`-typed member. + + // expected-error @+1 {{could not find function 'stored' with expected type '(StoredProperty) -> () -> StoredProperty'}} + @transposing(stored, wrt: self) + func vjpStored() -> Self { + fatalError() + } +}