diff --git a/include/swift/AST/AutoDiff.h b/include/swift/AST/AutoDiff.h index b447fdf3f7773..be9f14696dc4f 100644 --- a/include/swift/AST/AutoDiff.h +++ b/include/swift/AST/AutoDiff.h @@ -199,6 +199,12 @@ class AutoDiffParameterIndicesBuilder { /// `AutoDiffParameterIndices::parameters` for documentation about the order. void setParameter(unsigned parameterIndex); + /// Sets the parameters at indices in the specified range. + void setParameters(unsigned lowerBound, unsigned upperBound); + + /// Sets all parameters. + void setAllParameters(); + /// Returns the number of parameters. unsigned size() { return parameters.size(); } }; diff --git a/include/swift/AST/DiagnosticsSema.def b/include/swift/AST/DiagnosticsSema.def index 8ffce7b33d122..c45e7983204c7 100644 --- a/include/swift/AST/DiagnosticsSema.def +++ b/include/swift/AST/DiagnosticsSema.def @@ -2446,8 +2446,6 @@ WARNING(differentiable_implicit_noderivative_fixit,none, "stored property %0 has no derivative because it does not conform to " "'Differentiable'; add '@noDerivative' to make it explicit", (Identifier)) -NOTE(protocol_witness_missing_differentiable_attr,none, - "candidate is missing attribute '%0'", (StringRef)) NOTE(codable_extraneous_codingkey_case_here,none, "CodingKey case %0 does not match any stored properties", (Identifier)) @@ -2728,6 +2726,8 @@ ERROR(differentiable_attr_unsupported_req_kind,none, "layout requirement are not supported by '@differentiable' attribute", ()) ERROR(differentiable_attr_class_unsupported,none, "class members cannot be marked with '@differentiable'", ()) +NOTE(protocol_witness_missing_specific_differentiable_attr,none, + "candidate is missing attribute '%0'", (StringRef)) // @differentiang ERROR(differentiating_attr_expected_result_tuple,none, diff --git a/lib/AST/Attr.cpp b/lib/AST/Attr.cpp index 363e956ec430f..097892407e731 100644 --- a/lib/AST/Attr.cpp +++ b/lib/AST/Attr.cpp @@ -557,15 +557,17 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options, // SWIFT_ENABLE_TENSORFLOW case DAK_Differentiable: { Printer.printAttrName("@differentiable"); - Printer << '('; auto *attr = cast(this); auto parsedParams = attr->getParsedParameters(); - + // If no attribute parameter is specified, do not print parentheses at all. + if (parsedParams.empty() && !attr->getJVP() && !attr->getVJP() && + !attr->getWhereClause()) + break; + Printer << '('; // Get original function. auto *original = dyn_cast_or_null(D); if (auto *varDecl = dyn_cast_or_null(D)) original = varDecl->getGetter(); - bool isMethod = original && original->hasImplicitSelfDecl(); // Print comma if not leading clause. bool isLeadingClause = true; @@ -579,35 +581,24 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options, // Print differentiation parameters, if any. if (auto indices = attr->getParameterIndices()) { - printCommaIfNecessary(); - Printer << "wrt: ("; - SmallBitVector parameters(indices->parameters); - // Check if differentiating wrt `self`. If so, manually print it first. - if (isMethod && parameters.test(parameters.size() - 1)) { - parameters.reset(parameters.size() - 1); - Printer << "self"; - if (parameters.any()) - Printer << ", "; + if (!parsedParams.empty()) { + printCommaIfNecessary(); + Printer << "wrt: "; + if (parsedParams.size() > 1) + Printer << '('; + interleave(parsedParams, [&](const ParsedAutoDiffParameter ¶m) { + switch (param.getKind()) { + case ParsedAutoDiffParameter::Kind::Named: + Printer << param.getName(); + break; + case ParsedAutoDiffParameter::Kind::Self: + Printer << "self"; + break; + } + }, [&]{ Printer << ", "; }); + if (parsedParams.size() > 1) + Printer << ')'; } - // Print remaining differentiation parameters. - interleave(parameters.set_bits(), [&](unsigned index) { - Printer << original->getParameters()->get(index)->getName().str(); - }, [&] { Printer << ", "; }); - Printer << ")"; - } else if (!parsedParams.empty()) { - printCommaIfNecessary(); - Printer << "wrt: ("; - interleave(parsedParams, [&](const ParsedAutoDiffParameter ¶m) { - switch (param.getKind()) { - case ParsedAutoDiffParameter::Kind::Named: - Printer << param.getName(); - break; - case ParsedAutoDiffParameter::Kind::Self: - Printer << "self"; - break; - } - }, [&] { Printer << ", "; }); - Printer << ")"; } // Print jvp function name. if (auto jvp = attr->getJVP()) { diff --git a/lib/AST/AutoDiff.cpp b/lib/AST/AutoDiff.cpp index 20aed486aa97d..18e1256a13bfd 100644 --- a/lib/AST/AutoDiff.cpp +++ b/lib/AST/AutoDiff.cpp @@ -262,8 +262,8 @@ static unsigned getNumAutoDiffParameterIndices(AnyFunctionType *fnTy) { } AutoDiffParameterIndicesBuilder::AutoDiffParameterIndicesBuilder( - AnyFunctionType *functionType, bool setAllParams) : - parameters(getNumAutoDiffParameterIndices(functionType), setAllParams) { + AnyFunctionType *functionType, bool setAllParams) + : parameters(getNumAutoDiffParameterIndices(functionType), setAllParams) { } AutoDiffParameterIndices * @@ -276,6 +276,15 @@ void AutoDiffParameterIndicesBuilder::setParameter(unsigned paramIndex) { parameters.set(paramIndex); } +void AutoDiffParameterIndicesBuilder::setParameters(unsigned lowerBound, + unsigned upperBound) { + parameters.set(lowerBound, upperBound); +} + +void AutoDiffParameterIndicesBuilder::setAllParameters() { + parameters.set(); +} + Type VectorSpace::getType() const { switch (kind) { case Kind::Vector: diff --git a/lib/Sema/TypeCheckAttr.cpp b/lib/Sema/TypeCheckAttr.cpp index 86812c1f8b45c..39c1eaae6afa2 100644 --- a/lib/Sema/TypeCheckAttr.cpp +++ b/lib/Sema/TypeCheckAttr.cpp @@ -2462,15 +2462,12 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { LookUpConformanceInModule(D->getDeclContext()->getParentModule()); AbstractFunctionDecl *original = nullptr; - bool isProperty = false; if (auto *vd = dyn_cast(D)) { // When used on a storage decl, @differentiable refers to its getter. original = vd->getGetter(); - isProperty = true; } else if (auto *afd = dyn_cast(D)) { original = afd; if (auto *accessor = dyn_cast(afd)) { - isProperty = true; // We do not support setters yet because inout is not supported yet. if (accessor->isSetter()) original = nullptr; @@ -2609,16 +2606,13 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) { AutoDiffParameterIndicesBuilder autoDiffParameterIndicesBuilder( originalFnTy); if (parsedWrtParams.empty()) { - if (isProperty) - autoDiffParameterIndicesBuilder.setParameter(0); - else { - // If 'wrt:' is not specified, the wrt parameters are all the parameters - // in the main parameter group. Self is intentionally excluded except - // when it's a property. - unsigned numNonSelfParameters = autoDiffParameterIndicesBuilder.size() - - (isMethod ? 1 : 0); - for (unsigned i : range(numNonSelfParameters)) - autoDiffParameterIndicesBuilder.setParameter(i); + if (original->isStatic() || isa(original)) { + auto *methodTy = + original->getMethodInterfaceType()->castTo(); + autoDiffParameterIndicesBuilder + .setParameters(0, methodTy->getNumParams()); + } else { + autoDiffParameterIndicesBuilder.setAllParameters(); } } else { // 'wrt:' is specified. Validate and collect the selected parameters. diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index e3b922abdb124..1c54830614286 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -661,24 +661,22 @@ swift::matchWitness( } // SWIFT_ENABLE_TENSORFLOW - // Differentiation attributes must match completely or the generated - // functions will have the wrong signature. - // TODO(TF-285): Handle multiple `@differentiable` attributes on protocol - // requirements. Only missing attributes should be diagnosed. - auto *reqDiffAttr = - reqAttrs.getAttribute(/*AllowInvalid*/ true); - auto *witnessDiffAttr = - witnessAttrs.getAttribute(/*AllowInvalid*/ true); - if (reqDiffAttr && (!reqDiffAttr->getParameterIndices() || - !witnessDiffAttr || - !witnessDiffAttr->getParameterIndices() || - !witnessDiffAttr->parametersMatch(*reqDiffAttr))) { - if (auto *vdWitness = dyn_cast(witness)) - return RequirementMatch( - getStandinForAccessor(vdWitness, AccessorKind::Get), - MatchKind::DifferentiableConflict); - else - return RequirementMatch(witness, MatchKind::DifferentiableConflict); + // '@differentiable' attributes must match completely. + for (auto *reqDiffAttr : reqAttrs.getAttributes()) { + auto witnessDiffAttrs = + witnessAttrs.getAttributes(); + bool reqDiffAttrMatch = llvm::any_of(witnessDiffAttrs, + [&](const DifferentiableAttr *witnessDiffAttr) { + return witnessDiffAttr->parametersMatch(*reqDiffAttr); + }); + if (!reqDiffAttrMatch) { + if (auto *vdWitness = dyn_cast(witness)) + return RequirementMatch( + getStandinForAccessor(vdWitness, AccessorKind::Get), + MatchKind::DifferentiableConflict); + else + return RequirementMatch(witness, MatchKind::DifferentiableConflict); + } } // Now finalize the match. @@ -2244,20 +2242,20 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance, diags.diagnose(match.Witness, diag::protocol_witness_not_objc); break; // SWIFT_ENABLE_TENSORFLOW - case MatchKind::DifferentiableConflict: - std::string diffAttrReq; - { + case MatchKind::DifferentiableConflict: { + for (auto *da : req->getAttrs() + .getAttributes()) { + assert(da); + std::string diffAttrReq; llvm::raw_string_ostream stream(diffAttrReq); - // TODO(TF-285): Handle multiple `@differentiable` attributes on protocol - // requirements. Only missing attributes should be diagnosed. - req->getAttrs().getAttribute()->print(stream, req); - diffAttrReq = StringRef(stream.str()).trim(); + da->print(stream, req); + diags.diagnose(match.Witness, + diag::protocol_witness_missing_specific_differentiable_attr, + StringRef(stream.str()).trim()); } - diags.diagnose(match.Witness, - diag::protocol_witness_missing_differentiable_attr, - diffAttrReq); break; } + } } ConformanceChecker::ConformanceChecker( diff --git a/test/AutoDiff/autodiff_diagnostics.swift b/test/AutoDiff/autodiff_diagnostics.swift index 097d88de71d67..77856c20e1e1e 100644 --- a/test/AutoDiff/autodiff_diagnostics.swift +++ b/test/AutoDiff/autodiff_diagnostics.swift @@ -111,14 +111,14 @@ _ = gradient(at: 0) { x in if_else(0, true) } var a: Float = 3.0 protocol P { - @differentiable + @differentiable(wrt: x) func foo(x: Float) -> Float } enum T : P { // expected-note @+2 {{when differentiating this function definition}} // expected-error @+1 {{function is not differentiable}} - @differentiable func foo(x: Float) -> Float { + @differentiable(wrt: x) func foo(x: Float) -> Float { // expected-note @+1 {{cannot differentiate writes to global variables}} a = a + x return a @@ -127,7 +127,7 @@ enum T : P { // expected-note @+2 {{when differentiating this function definition}} // expected-error @+1 {{function is not differentiable}} -@differentiable func foo(x: Float) -> Float { +@differentiable(wrt: x) func foo(x: Float) -> Float { // expected-note @+1 {{cannot differentiate writes to global variables}} a = a + x return a diff --git a/test/AutoDiff/derived_differentiable_properties.swift b/test/AutoDiff/derived_differentiable_properties.swift index ec00a8bbc207d..5ed547189db83 100644 --- a/test/AutoDiff/derived_differentiable_properties.swift +++ b/test/AutoDiff/derived_differentiable_properties.swift @@ -7,7 +7,7 @@ public struct Foo : Differentiable { } // CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct Foo : Differentiable { -// CHECK-AST: @differentiable(wrt: (self)) +// CHECK-AST: @differentiable // CHECK-AST: public var a: Float // CHECK-AST: internal init(a: Float) // CHECK-AST: @_fieldwiseDifferentiable public struct AllDifferentiableVariables diff --git a/test/AutoDiff/differentiable_attr_type_checking.swift b/test/AutoDiff/differentiable_attr_type_checking.swift index d28c5163bd91c..5a9ab9f06d286 100644 --- a/test/AutoDiff/differentiable_attr_type_checking.swift +++ b/test/AutoDiff/differentiable_attr_type_checking.swift @@ -161,7 +161,7 @@ extension JVPStruct : Differentiable { } extension JVPStruct { - @differentiable(jvp: wrtAllNonSelfJVP) + @differentiable(wrt: x, jvp: wrtAllNonSelfJVP) func wrtAllNonSelf(x: Float) -> Float { return x + p } @@ -318,7 +318,7 @@ extension VJPStruct : Differentiable { } extension VJPStruct { - @differentiable(vjp: wrtAllNonSelfVJP) + @differentiable(wrt: x, vjp: wrtAllNonSelfVJP) func wrtAllNonSelf(x: Float) -> Float { return x + p } @@ -422,7 +422,7 @@ func vjpWhere2(x: Tensor) -> (Tensor< struct A { struct B { - @differentiable(where T : Differentiable, V : Differentiable, V.TangentVector == V) + @differentiable(wrt: x where T : Differentiable, V : Differentiable, V.TangentVector == V) func whereInGenericContext(x: T) -> T { return x } @@ -510,18 +510,50 @@ struct DifferentiableInitStruct : DifferentiableInit { var y: Float // FIXME(TF-284): Fix unexpected diagnostic. - // expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: (x, y))'}} - // expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x))'}} + // expected-note @+2 {{candidate is missing attribute '@differentiable'}} + // expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: x)'}} init(x: Float, y: Float) { self.x = x self.y = y } // FIXME(TF-284): Fix unexpected diagnostic. - // expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: (x))'}} - // expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))'}} + // expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)'}} + // expected-note @+1 {{candidate is missing attribute '@differentiable'}} init(x: Float, y: Int) { self.x = x self.y = Float(y) } } + + +protocol NotRefiningDiffable { + @differentiable(wrt: x) + // expected-note @+1 {{protocol requires function 'a' with type '(Float) -> Float'; do you want to add a stub?}} + func a(_ x: Float) -> Float +} + +// expected-error @+1 {{type 'CertainlyNotDiffableWrtSelf' does not conform to protocol 'NotRefiningDiffable'}} +struct CertainlyNotDiffableWrtSelf : NotRefiningDiffable { + // expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: x)'}} + func a(_ x: Float) -> Float { return x * 5.0 } +} + + +protocol TF285 : Differentiable { + @differentiable(wrt: (x, y)) + @differentiable(wrt: x) + // expected-note @+1 {{protocol requires function 'foo(x:y:)' with type '(Float, Float) -> Float'; do you want to add a stub?}} + func foo(x: Float, y: Float) -> Float +} + +// expected-error @+1 {{type 'TF285MissingOneDiffAttr' does not conform to protocol 'TF285'}} +struct TF285MissingOneDiffAttr : TF285 { + // Requirement is missing an attribute. + @differentiable(wrt: x) + // expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)}} + // expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))}} + func foo(x: Float, y: Float) -> Float { + return x + } +} diff --git a/test/AutoDiff/existential.swift b/test/AutoDiff/existential.swift index 590aa103f8e60..f9bc0d940b349 100644 --- a/test/AutoDiff/existential.swift +++ b/test/AutoDiff/existential.swift @@ -6,13 +6,13 @@ import StdlibUnittest var ExistentialTests = TestSuite("Existential") protocol A { - @differentiable - func a(_: Float) -> Float + @differentiable(wrt: x) + func a(_ x: Float) -> Float } func b(g: A) -> Float { return (3.0 as Float).gradient() { x in g.a(x) } } struct B : A { - @differentiable + @differentiable(wrt: x) func a(_ x: Float) -> Float { return x * 5.0 } } diff --git a/test/AutoDiff/sildeclref_parse.sil b/test/AutoDiff/sildeclref_parse.sil index 2c266d75fc5db..70d15a075410e 100644 --- a/test/AutoDiff/sildeclref_parse.sil +++ b/test/AutoDiff/sildeclref_parse.sil @@ -3,7 +3,7 @@ import Swift protocol Proto { - @differentiable() + @differentiable(wrt: (x, y)) func f(_ x: Float, _ y: Float) -> Float } diff --git a/test/AutoDiff/witness_table_irgen.sil b/test/AutoDiff/witness_table_irgen.sil index f1a91d327d458..371705f414ac0 100644 --- a/test/AutoDiff/witness_table_irgen.sil +++ b/test/AutoDiff/witness_table_irgen.sil @@ -7,12 +7,12 @@ import Swift import SwiftShims protocol DifferentiableRequirement { - @differentiable() + @differentiable(wrt: x) func f(_ x: Float) -> Float } struct DifferentiableConformance : DifferentiableRequirement { - @differentiable(jvp: df, vjp: pf) + @differentiable(wrt: x, jvp: df, vjp: pf) func f(_ x: Float) -> Float func df(_ x: Float) -> (Float, (Float) -> Float) func pf(_ x: Float) -> (Float, (Float) -> Float) diff --git a/test/AutoDiff/witness_table_silgen.swift b/test/AutoDiff/witness_table_silgen.swift index 5687fefb0bdd0..25697f6b4dabb 100644 --- a/test/AutoDiff/witness_table_silgen.swift +++ b/test/AutoDiff/witness_table_silgen.swift @@ -1,13 +1,13 @@ // RUN: %target-swift-frontend -emit-sil -verify %s | %FileCheck %s protocol Proto : Differentiable { - @differentiable() + @differentiable(wrt: (x, y)) func function1(_ x: Float, _ y: Float) -> Float @differentiable(wrt: (self, x, y)) func function2(_ x: Float, _ y: Float) -> Float - @differentiable(wrt: (y)) + @differentiable(wrt: y) func function3(_ x: Float, _ y: Float) -> Float } @@ -28,7 +28,7 @@ struct S : Proto, VectorNumeric { return (p, { dp in S(p: dp) }) } - @differentiable() + @differentiable(wrt: (x, y)) func function1(_ x: Float, _ y: Float) -> Float { return x + y + p } diff --git a/test/Serialization/differentiable_attr.swift b/test/Serialization/differentiable_attr.swift index 9a58a8557cca3..62a2b5aae9f1f 100644 --- a/test/Serialization/differentiable_attr.swift +++ b/test/Serialization/differentiable_attr.swift @@ -12,7 +12,7 @@ func jvpSimple(x: Float) -> Float { return x } -// CHECK-DAG: @differentiable(wrt: (x), jvp: jvpSimpleJVP) +// CHECK-DAG: @differentiable(jvp: jvpSimpleJVP) // CHECK-DAG: func jvpSimpleJVP(x: Float) -> (Float, (Float) -> Float) func jvpSimpleJVP(x: Float) -> (Float, (Float) -> Float) { return (x, { v in v }) @@ -23,13 +23,13 @@ func vjpSimple(x: Float) -> Float { return x } -// CHECK-DAG: @differentiable(wrt: (x), vjp: vjpSimpleVJP) +// CHECK-DAG: @differentiable(vjp: vjpSimpleVJP) // CHECK-DAG: func vjpSimpleVJP(x: Float) -> (Float, (Float) -> Float) func vjpSimpleVJP(x: Float) -> (Float, (Float) -> Float) { return (x, { v in v }) } -// CHECK-DAG: @differentiable(wrt: (x), vjp: vjpTestWhereClause where T : Differentiable, T : Numeric) +// CHECK-DAG: @differentiable(vjp: vjpTestWhereClause where T : Differentiable, T : Numeric) // CHECK-DAG: func testWhereClause(x: T) -> T where T : Numeric @differentiable(vjp: vjpTestWhereClause where T : Differentiable) func testWhereClause(x: T) -> T { @@ -43,9 +43,9 @@ func vjpTestWhereClause(x: T) -> (T, (T.CotangentVector) -> T.CotangentVector protocol P {} extension P { - // CHECK-DAG: @differentiable(wrt: (self), vjp: vjpTestWhereClause where Self : Differentiable, Self : P) + // CHECK-DAG: @differentiable(vjp: vjpTestWhereClause where Self : Differentiable, Self : P) // CHECK-DAG: func testWhereClause() -> Self - @differentiable(wrt: (self), vjp: vjpTestWhereClause where Self : Differentiable) + @differentiable(wrt: self, vjp: vjpTestWhereClause where Self : Differentiable) func testWhereClause() -> Self { return self } @@ -59,7 +59,7 @@ extension P where Self : Differentiable { // NOTE: The failing tests involve where clauses with member type constraints. // They pass type-checking but crash during serialization. -// CHECK-DAG: @differentiable(wrt: (x), vjp: vjpTestWhereClauseMemberTypeConstraint where T : Differentiable, T : Numeric, T == T.CotangentVector) +// CHECK-DAG: @differentiable(vjp: vjpTestWhereClauseMemberTypeConstraint where T : Differentiable, T : Numeric, T == T.CotangentVector) // CHECK-DAG: func testWhereClauseMemberTypeConstraint(x: T) -> T where T : Numeric @differentiable(vjp: vjpTestWhereClauseMemberTypeConstraint where T : Differentiable, T == T.CotangentVector) func testWhereClauseMemberTypeConstraint(x: T) -> T { @@ -72,9 +72,9 @@ func vjpTestWhereClauseMemberTypeConstraint(x: T) -> (T, (T) -> T) } extension P { - // CHECK-DAG: @differentiable(wrt: (self), vjp: vjpTestWhereClauseMemberTypeConstraint where Self : Differentiable, Self : P, Self == Self.CotangentVector) + // CHECK-DAG: @differentiable(vjp: vjpTestWhereClauseMemberTypeConstraint where Self : Differentiable, Self : P, Self == Self.CotangentVector) // CHECK-DAG: func testWhereClauseMemberTypeConstraint() -> Self - @differentiable(wrt: (self), vjp: vjpTestWhereClauseMemberTypeConstraint where Self.CotangentVector == Self, Self : Differentiable) + @differentiable(wrt: self, vjp: vjpTestWhereClauseMemberTypeConstraint where Self.CotangentVector == Self, Self : Differentiable) func testWhereClauseMemberTypeConstraint() -> Self { return self } diff --git a/test/TensorFlowRuntime/tensor_autodiff_indirect.swift b/test/TensorFlowRuntime/tensor_autodiff_indirect.swift index 0bcc07b00ed88..2234d4b8e484d 100644 --- a/test/TensorFlowRuntime/tensor_autodiff_indirect.swift +++ b/test/TensorFlowRuntime/tensor_autodiff_indirect.swift @@ -33,7 +33,7 @@ TensorADTests.testAllBackends("Concrete") { } extension Tensor where Scalar : Differentiable & FloatingPoint { - @differentiable(vjp: vjpFoo) + @differentiable(wrt: x, vjp: vjpFoo) func foo(_ x: Scalar) -> Scalar { return x }