Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1508,6 +1508,19 @@ class DifferentiableAttr final
static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Differentiable;
}

bool parametersMatch(const DifferentiableAttr &other) const {
auto a = getParsedParameters();
auto b = other.getParsedParameters();
if (a.size() != b.size())
return false;

for (unsigned i = 0, n = b.size(); i < n; ++i) {
if (!a[i].isEqual(b[i]))
return false;
}
return true;
}
};

/// \brief Attributes that may be applied to declarations.
Expand Down
2 changes: 2 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2449,6 +2449,8 @@ ERROR(broken_differentiable_requirement,none,
WARNING(differentiable_implicit_noderivative_fixit,none,
"stored property has no derivative because it does not conform to "
"'Differentiable'; add '@noDerivative' to make it explicit", ())
NOTE(protocol_witness_missing_differentiable_attr,none,
"candidate is missing attribute '%0'", (StringRef))

NOTE(codable_extraneous_codingkey_case_here,none,
"CodingKey case %0 does not match any stored properties", (Identifier))
Expand Down
27 changes: 27 additions & 0 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,22 @@ swift::matchWitness(
cast<AbstractFunctionDecl>(witness)->hasThrows())
return RequirementMatch(witness, MatchKind::RethrowsConflict);

// SWIFT_ENABLE_TENSORFLOW
// Differentiation attributes must match completely or the generated
// functions will have the wrong signature.
{
auto *reqDifferentiationAttr =
reqAttrs.getAttribute<DifferentiableAttr>(/*AllowInvalid*/ true);
auto *witnessDifferentiationAttr =
witnessAttrs.getAttribute<DifferentiableAttr>(
/*AllowInvalid*/ true);
if (reqDifferentiationAttr &&
(!witnessDifferentiationAttr ||
!witnessDifferentiationAttr->parametersMatch(
*reqDifferentiationAttr)))
return RequirementMatch(witness, MatchKind::DifferentiableConflict);
}

// We want to decompose the parameters to handle them separately.
decomposeFunctionType = true;
} else if (auto *witnessASD = dyn_cast<AbstractStorageDecl>(witness)) {
Expand Down Expand Up @@ -2212,6 +2228,17 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
case MatchKind::NonObjC:
diags.diagnose(match.Witness, diag::protocol_witness_not_objc);
break;
// SWIFT_ENABLE_TENSORFLOW
case MatchKind::DifferentiableConflict:
std::string diffAttrReq;
{
llvm::raw_string_ostream stream(diffAttrReq);
req->getAttrs().getAttribute<DifferentiableAttr>()->print(stream, req);
}
diags.diagnose(match.Witness,
diag::protocol_witness_missing_differentiable_attr,
diffAttrReq);
break;
}
}

Expand Down
8 changes: 8 additions & 0 deletions lib/Sema/TypeCheckProtocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ enum class MatchKind : uint8_t {

/// The witness is explicitly @nonobjc but the requirement is @objc.
NonObjC,

// SWIFT_ENABLE_TENSORFLOW
/// The @differentiable attribute does not match.
DifferentiableConflict,
};

/// Describes the kind of optional adjustment performed when
Expand Down Expand Up @@ -418,6 +422,8 @@ struct RequirementMatch {
case MatchKind::RethrowsConflict:
case MatchKind::ThrowsConflict:
case MatchKind::NonObjC:
// SWIFT_ENABLE_TENSORFLOW
case MatchKind::DifferentiableConflict:
return false;
}

Expand Down Expand Up @@ -446,6 +452,8 @@ struct RequirementMatch {
case MatchKind::RethrowsConflict:
case MatchKind::ThrowsConflict:
case MatchKind::NonObjC:
// SWIFT_ENABLE_TENSORFLOW
case MatchKind::DifferentiableConflict:
return false;
}

Expand Down
25 changes: 25 additions & 0 deletions test/AutoDiff/differentiable_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -404,3 +404,28 @@ func invalidRequirementConformance<Scalar>(x: Scalar) -> Scalar {
func invalidRequirementLayout<Scalar>(x: Scalar) -> Scalar {
return x
}


protocol DiffReq : Differentiable {
// expected-note @+2 {{protocol requires function 'f1'}}
@differentiable(wrt: (self, x))
func f1(_ x: Float) -> Float

// expected-note @+2 {{protocol requires function 'f2'}}
@differentiable(wrt: (self, x, y))
func f2(_ x: Float, _ y: Float) -> Float
}

// expected-error @+1 {{does not conform to protocol}}
struct ConformingWithErrors : DiffReq {
// expected-note @+1 {{@differentiable(wrt: (x, self))}}
func f1(_ x: Float) -> Float {
return x
}

// expected-note @+2 {{@differentiable(wrt: (x, y, self))}}
@differentiable(wrt: (self, x))
func f2(_ x: Float, _ y: Float) -> Float {
return x + y
}
}
1 change: 1 addition & 0 deletions test/AutoDiff/protocol_requirement_autodiff.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ struct Quadratic : DiffReq, Equatable {
self.c = c
}

@differentiable(wrt: (self, x))
func f(_ x: Float) -> Float {
return a * x * x + b * x + c
}
Expand Down
3 changes: 3 additions & 0 deletions test/AutoDiff/witness_table_silgen.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ struct S : Proto, VectorNumeric {
return (p, { dp in S(p: dp) })
}

@differentiable()
func function1(_ x: Float, _ y: Float) -> Float {
return x + y + p
}
Expand All @@ -48,6 +49,7 @@ struct S : Proto, VectorNumeric {
// CHECK: apply [[VJP1]]
// CHECK: } // end sil function 'AD__{{.*}}function1{{.*}}_vjp_SSU'

@differentiable(wrt: (self, x, y))
func function2(_ x: Float, _ y: Float) -> Float {
return x + y + p
}
Expand All @@ -68,6 +70,7 @@ struct S : Proto, VectorNumeric {
// CHECK: apply [[VJP2]]
// CHECK: } // end sil function 'AD__{{.*}}function2{{.*}}_vjp_SSS'

@differentiable(wrt: (y))
func function3(_ x: Float, _ y: Float) -> Float {
return x + y + p
}
Expand Down