diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index 1ddcbfcf4a17e..ca9fb1d591a60 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -560,6 +560,15 @@ swift::matchWitness( bool foundSupersetAttr = false; for (auto witnessConfig : witnessAFD->getDerivativeFunctionConfigurations()) { + // We can't use witnesses that have generic signatures not satisfied by + // the requirement's generic signature. + if (witnessConfig.derivativeGenericSignature && + !witnessConfig.derivativeGenericSignature + ->requirementsNotSatisfiedBy( + reqDiffAttr->getDerivativeGenericSignature()) + .empty()) + continue; + if (witnessConfig.parameterIndices == reqDiffAttr->getParameterIndices()) foundExactAttr = true; diff --git a/test/AutoDiff/protocol_requirement_autodiff.swift b/test/AutoDiff/protocol_requirement_autodiff.swift index d3d8b73f86661..e5c49e1c6ce8d 100644 --- a/test/AutoDiff/protocol_requirement_autodiff.swift +++ b/test/AutoDiff/protocol_requirement_autodiff.swift @@ -163,6 +163,8 @@ func blah2(_ x: T, _ value: T.Value) -> Tra x.logProbability(of: value) } +// Satisfying the requirement with more wrt indices than are necessary. + protocol DifferentiableFoo { associatedtype T: Differentiable @differentiable(wrt: x) @@ -181,4 +183,18 @@ struct MoreDifferentiableFooStruct: MoreDifferentiableFoo { } } +// Satisfiying the requirement with a less-constrained derivative than is necessary. + +protocol ExtraDerivativeConstraint {} + +protocol HasExtraConstrainedDerivative { + @differentiable + func requirement(_ x: T) -> T +} + +struct SatisfiesDerivativeWithLessConstraint: HasExtraConstrainedDerivative { + @differentiable + func requirement(_ x: T) -> T { x } +} + runAllTests() diff --git a/test/AutoDiff/protocol_requirement_autodiff_diags.swift b/test/AutoDiff/protocol_requirement_autodiff_diags.swift new file mode 100644 index 0000000000000..ec1fc3d508add --- /dev/null +++ b/test/AutoDiff/protocol_requirement_autodiff_diags.swift @@ -0,0 +1,18 @@ +// RUN: %target-swift-frontend -typecheck -verify %s + +protocol P {} + +protocol HasRequirement { + @differentiable + // expected-note @+1 {{protocol requires function 'requirement' with type ' (T, T) -> T'; do you want to add a stub?}} + func requirement(_ x: T, _ y: T) -> T +} + +// expected-error @+1 {{type 'AttemptsToSatisfyRequirement' does not conform to protocol 'HasRequirement'}} +struct AttemptsToSatisfyRequirement: HasRequirement { + // This does not satisfy the requirement because the differentiable attribute is more + // constrained than the requirement's differentiable attribute. + @differentiable(where T: P) + // expected-note @+1 {{candidate is missing attribute '@differentiable'}} + func requirement(_ x: T, _ y: T) -> T { x } +}