Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,15 @@ swift::matchWitness(
bool foundSupersetAttr = false;
for (auto witnessConfig :
witnessAFD->getDerivativeFunctionConfigurations()) {
// We can't use witnesses that have generic signatures not satisfied by
// the requirement's generic signature.
if (witnessConfig.derivativeGenericSignature &&
!witnessConfig.derivativeGenericSignature
->requirementsNotSatisfiedBy(
reqDiffAttr->getDerivativeGenericSignature())
.empty())
continue;

if (witnessConfig.parameterIndices ==
reqDiffAttr->getParameterIndices())
foundExactAttr = true;
Expand Down
16 changes: 16 additions & 0 deletions test/AutoDiff/protocol_requirement_autodiff.swift
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ func blah2<T: DoubleDifferentiableDistribution>(_ x: T, _ value: T.Value) -> Tra
x.logProbability(of: value)
}

// Satisfying the requirement with more wrt indices than are necessary.

protocol DifferentiableFoo {
associatedtype T: Differentiable
@differentiable(wrt: x)
Expand All @@ -181,4 +183,18 @@ struct MoreDifferentiableFooStruct: MoreDifferentiableFoo {
}
}

// Satisfiying the requirement with a less-constrained derivative than is necessary.

protocol ExtraDerivativeConstraint {}

protocol HasExtraConstrainedDerivative {
@differentiable
func requirement<T: Differentiable & ExtraDerivativeConstraint>(_ x: T) -> T
}

struct SatisfiesDerivativeWithLessConstraint: HasExtraConstrainedDerivative {
@differentiable
func requirement<T: Differentiable>(_ x: T) -> T { x }
}

runAllTests()
18 changes: 18 additions & 0 deletions test/AutoDiff/protocol_requirement_autodiff_diags.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: %target-swift-frontend -typecheck -verify %s

protocol P {}

protocol HasRequirement {
@differentiable
// expected-note @+1 {{protocol requires function 'requirement' with type '<T> (T, T) -> T'; do you want to add a stub?}}
func requirement<T: Differentiable>(_ x: T, _ y: T) -> T
}

// expected-error @+1 {{type 'AttemptsToSatisfyRequirement' does not conform to protocol 'HasRequirement'}}
struct AttemptsToSatisfyRequirement: HasRequirement {
// This does not satisfy the requirement because the differentiable attribute is more
// constrained than the requirement's differentiable attribute.
@differentiable(where T: P)
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
func requirement<T: Differentiable>(_ x: T, _ y: T) -> T { x }
}