Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,12 @@ class AutoDiffParameterIndicesBuilder {
/// `AutoDiffParameterIndices::parameters` for documentation about the order.
void setParameter(unsigned parameterIndex);

/// Sets the parameters at indices in the specified range.
void setParameters(unsigned lowerBound, unsigned upperBound);

/// Sets all parameters.
void setAllParameters();

/// Returns the number of parameters.
unsigned size() { return parameters.size(); }
};
Expand Down
4 changes: 2 additions & 2 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2446,8 +2446,6 @@ WARNING(differentiable_implicit_noderivative_fixit,none,
"stored property %0 has no derivative because it does not conform to "
"'Differentiable'; add '@noDerivative' to make it explicit",
(Identifier))
NOTE(protocol_witness_missing_differentiable_attr,none,
"candidate is missing attribute '%0'", (StringRef))

NOTE(codable_extraneous_codingkey_case_here,none,
"CodingKey case %0 does not match any stored properties", (Identifier))
Expand Down Expand Up @@ -2728,6 +2726,8 @@ ERROR(differentiable_attr_unsupported_req_kind,none,
"layout requirement are not supported by '@differentiable' attribute", ())
ERROR(differentiable_attr_class_unsupported,none,
"class members cannot be marked with '@differentiable'", ())
NOTE(protocol_witness_missing_specific_differentiable_attr,none,
"candidate is missing attribute '%0'", (StringRef))

// @differentiang
ERROR(differentiating_attr_expected_result_tuple,none,
Expand Down
53 changes: 22 additions & 31 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -557,15 +557,17 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
// SWIFT_ENABLE_TENSORFLOW
case DAK_Differentiable: {
Printer.printAttrName("@differentiable");
Printer << '(';
auto *attr = cast<DifferentiableAttr>(this);
auto parsedParams = attr->getParsedParameters();

// If no attribute parameter is specified, do not print parentheses at all.
if (parsedParams.empty() && !attr->getJVP() && !attr->getVJP() &&
!attr->getWhereClause())
break;
Printer << '(';
// Get original function.
auto *original = dyn_cast_or_null<AbstractFunctionDecl>(D);
if (auto *varDecl = dyn_cast_or_null<VarDecl>(D))
original = varDecl->getGetter();
bool isMethod = original && original->hasImplicitSelfDecl();

// Print comma if not leading clause.
bool isLeadingClause = true;
Expand All @@ -579,35 +581,24 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,

// Print differentiation parameters, if any.
if (auto indices = attr->getParameterIndices()) {
printCommaIfNecessary();
Printer << "wrt: (";
SmallBitVector parameters(indices->parameters);
// Check if differentiating wrt `self`. If so, manually print it first.
if (isMethod && parameters.test(parameters.size() - 1)) {
parameters.reset(parameters.size() - 1);
Printer << "self";
if (parameters.any())
Printer << ", ";
if (!parsedParams.empty()) {
printCommaIfNecessary();
Printer << "wrt: ";
if (parsedParams.size() > 1)
Printer << '(';
interleave(parsedParams, [&](const ParsedAutoDiffParameter &param) {
switch (param.getKind()) {
case ParsedAutoDiffParameter::Kind::Named:
Printer << param.getName();
break;
case ParsedAutoDiffParameter::Kind::Self:
Printer << "self";
break;
}
}, [&]{ Printer << ", "; });
if (parsedParams.size() > 1)
Printer << ')';
}
// Print remaining differentiation parameters.
interleave(parameters.set_bits(), [&](unsigned index) {
Printer << original->getParameters()->get(index)->getName().str();
}, [&] { Printer << ", "; });
Printer << ")";
} else if (!parsedParams.empty()) {
printCommaIfNecessary();
Printer << "wrt: (";
interleave(parsedParams, [&](const ParsedAutoDiffParameter &param) {
switch (param.getKind()) {
case ParsedAutoDiffParameter::Kind::Named:
Printer << param.getName();
break;
case ParsedAutoDiffParameter::Kind::Self:
Printer << "self";
break;
}
}, [&] { Printer << ", "; });
Printer << ")";
}
// Print jvp function name.
if (auto jvp = attr->getJVP()) {
Expand Down
13 changes: 11 additions & 2 deletions lib/AST/AutoDiff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ static unsigned getNumAutoDiffParameterIndices(AnyFunctionType *fnTy) {
}

AutoDiffParameterIndicesBuilder::AutoDiffParameterIndicesBuilder(
AnyFunctionType *functionType, bool setAllParams) :
parameters(getNumAutoDiffParameterIndices(functionType), setAllParams) {
AnyFunctionType *functionType, bool setAllParams)
: parameters(getNumAutoDiffParameterIndices(functionType), setAllParams) {
}

AutoDiffParameterIndices *
Expand All @@ -276,6 +276,15 @@ void AutoDiffParameterIndicesBuilder::setParameter(unsigned paramIndex) {
parameters.set(paramIndex);
}

void AutoDiffParameterIndicesBuilder::setParameters(unsigned lowerBound,
unsigned upperBound) {
parameters.set(lowerBound, upperBound);
}

void AutoDiffParameterIndicesBuilder::setAllParameters() {
parameters.set();
}

Type VectorSpace::getType() const {
switch (kind) {
case Kind::Vector:
Expand Down
20 changes: 7 additions & 13 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2462,15 +2462,12 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
LookUpConformanceInModule(D->getDeclContext()->getParentModule());

AbstractFunctionDecl *original = nullptr;
bool isProperty = false;
if (auto *vd = dyn_cast<VarDecl>(D)) {
// When used on a storage decl, @differentiable refers to its getter.
original = vd->getGetter();
isProperty = true;
} else if (auto *afd = dyn_cast<AbstractFunctionDecl>(D)) {
original = afd;
if (auto *accessor = dyn_cast<AccessorDecl>(afd)) {
isProperty = true;
// We do not support setters yet because inout is not supported yet.
if (accessor->isSetter())
original = nullptr;
Expand Down Expand Up @@ -2609,16 +2606,13 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
AutoDiffParameterIndicesBuilder autoDiffParameterIndicesBuilder(
originalFnTy);
if (parsedWrtParams.empty()) {
if (isProperty)
autoDiffParameterIndicesBuilder.setParameter(0);
else {
// If 'wrt:' is not specified, the wrt parameters are all the parameters
// in the main parameter group. Self is intentionally excluded except
// when it's a property.
unsigned numNonSelfParameters = autoDiffParameterIndicesBuilder.size() -
(isMethod ? 1 : 0);
for (unsigned i : range(numNonSelfParameters))
autoDiffParameterIndicesBuilder.setParameter(i);
if (original->isStatic() || isa<ConstructorDecl>(original)) {
auto *methodTy =
original->getMethodInterfaceType()->castTo<AnyFunctionType>();
autoDiffParameterIndicesBuilder
.setParameters(0, methodTy->getNumParams());
} else {
autoDiffParameterIndicesBuilder.setAllParameters();
}
} else {
// 'wrt:' is specified. Validate and collect the selected parameters.
Expand Down
54 changes: 26 additions & 28 deletions lib/Sema/TypeCheckProtocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,24 +661,22 @@ swift::matchWitness(
}

// SWIFT_ENABLE_TENSORFLOW
// Differentiation attributes must match completely or the generated
// functions will have the wrong signature.
// TODO(TF-285): Handle multiple `@differentiable` attributes on protocol
// requirements. Only missing attributes should be diagnosed.
auto *reqDiffAttr =
reqAttrs.getAttribute<DifferentiableAttr>(/*AllowInvalid*/ true);
auto *witnessDiffAttr =
witnessAttrs.getAttribute<DifferentiableAttr>(/*AllowInvalid*/ true);
if (reqDiffAttr && (!reqDiffAttr->getParameterIndices() ||
!witnessDiffAttr ||
!witnessDiffAttr->getParameterIndices() ||
!witnessDiffAttr->parametersMatch(*reqDiffAttr))) {
if (auto *vdWitness = dyn_cast<VarDecl>(witness))
return RequirementMatch(
getStandinForAccessor(vdWitness, AccessorKind::Get),
MatchKind::DifferentiableConflict);
else
return RequirementMatch(witness, MatchKind::DifferentiableConflict);
// '@differentiable' attributes must match completely.
for (auto *reqDiffAttr : reqAttrs.getAttributes<DifferentiableAttr>()) {
auto witnessDiffAttrs =
witnessAttrs.getAttributes<DifferentiableAttr, /*AllowInvalid*/ true>();
bool reqDiffAttrMatch = llvm::any_of(witnessDiffAttrs,
[&](const DifferentiableAttr *witnessDiffAttr) {
return witnessDiffAttr->parametersMatch(*reqDiffAttr);
});
if (!reqDiffAttrMatch) {
if (auto *vdWitness = dyn_cast<VarDecl>(witness))
return RequirementMatch(
getStandinForAccessor(vdWitness, AccessorKind::Get),
MatchKind::DifferentiableConflict);
else
return RequirementMatch(witness, MatchKind::DifferentiableConflict);
}
}

// Now finalize the match.
Expand Down Expand Up @@ -2244,20 +2242,20 @@ diagnoseMatch(ModuleDecl *module, NormalProtocolConformance *conformance,
diags.diagnose(match.Witness, diag::protocol_witness_not_objc);
break;
// SWIFT_ENABLE_TENSORFLOW
case MatchKind::DifferentiableConflict:
std::string diffAttrReq;
{
case MatchKind::DifferentiableConflict: {
for (auto *da : req->getAttrs()
.getAttributes<DifferentiableAttr, /*allowInvalid*/ true>()) {
assert(da);
std::string diffAttrReq;
llvm::raw_string_ostream stream(diffAttrReq);
// TODO(TF-285): Handle multiple `@differentiable` attributes on protocol
// requirements. Only missing attributes should be diagnosed.
req->getAttrs().getAttribute<DifferentiableAttr>()->print(stream, req);
diffAttrReq = StringRef(stream.str()).trim();
da->print(stream, req);
diags.diagnose(match.Witness,
diag::protocol_witness_missing_specific_differentiable_attr,
StringRef(stream.str()).trim());
}
diags.diagnose(match.Witness,
diag::protocol_witness_missing_differentiable_attr,
diffAttrReq);
break;
}
}
}

ConformanceChecker::ConformanceChecker(
Expand Down
6 changes: 3 additions & 3 deletions test/AutoDiff/autodiff_diagnostics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,14 @@ _ = gradient(at: 0) { x in if_else(0, true) }

var a: Float = 3.0
protocol P {
@differentiable
@differentiable(wrt: x)
func foo(x: Float) -> Float
}

enum T : P {
// expected-note @+2 {{when differentiating this function definition}}
// expected-error @+1 {{function is not differentiable}}
@differentiable func foo(x: Float) -> Float {
@differentiable(wrt: x) func foo(x: Float) -> Float {
// expected-note @+1 {{cannot differentiate writes to global variables}}
a = a + x
return a
Expand All @@ -127,7 +127,7 @@ enum T : P {

// expected-note @+2 {{when differentiating this function definition}}
// expected-error @+1 {{function is not differentiable}}
@differentiable func foo(x: Float) -> Float {
@differentiable(wrt: x) func foo(x: Float) -> Float {
// expected-note @+1 {{cannot differentiate writes to global variables}}
a = a + x
return a
Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/derived_differentiable_properties.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ public struct Foo : Differentiable {
}

// CHECK-AST-LABEL: @_fieldwiseDifferentiable public struct Foo : Differentiable {
// CHECK-AST: @differentiable(wrt: (self))
// CHECK-AST: @differentiable
// CHECK-AST: public var a: Float
// CHECK-AST: internal init(a: Float)
// CHECK-AST: @_fieldwiseDifferentiable public struct AllDifferentiableVariables
Expand Down
46 changes: 39 additions & 7 deletions test/AutoDiff/differentiable_attr_type_checking.swift
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ extension JVPStruct : Differentiable {
}

extension JVPStruct {
@differentiable(jvp: wrtAllNonSelfJVP)
@differentiable(wrt: x, jvp: wrtAllNonSelfJVP)
func wrtAllNonSelf(x: Float) -> Float {
return x + p
}
Expand Down Expand Up @@ -318,7 +318,7 @@ extension VJPStruct : Differentiable {
}

extension VJPStruct {
@differentiable(vjp: wrtAllNonSelfVJP)
@differentiable(wrt: x, vjp: wrtAllNonSelfVJP)
func wrtAllNonSelf(x: Float) -> Float {
return x + p
}
Expand Down Expand Up @@ -422,7 +422,7 @@ func vjpWhere2<Scalar : Numeric & Differentiable>(x: Tensor<Scalar>) -> (Tensor<

struct A<T> {
struct B<U, V> {
@differentiable(where T : Differentiable, V : Differentiable, V.TangentVector == V)
@differentiable(wrt: x where T : Differentiable, V : Differentiable, V.TangentVector == V)
func whereInGenericContext<T>(x: T) -> T {
return x
}
Expand Down Expand Up @@ -510,18 +510,50 @@ struct DifferentiableInitStruct : DifferentiableInit {
var y: Float

// FIXME(TF-284): Fix unexpected diagnostic.
// expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: (x, y))'}}
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x))'}}
// expected-note @+2 {{candidate is missing attribute '@differentiable'}}
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: x)'}}
init(x: Float, y: Float) {
self.x = x
self.y = y
}

// FIXME(TF-284): Fix unexpected diagnostic.
// expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: (x))'}}
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))'}}
// expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)'}}
// expected-note @+1 {{candidate is missing attribute '@differentiable'}}
init(x: Float, y: Int) {
self.x = x
self.y = Float(y)
}
}


protocol NotRefiningDiffable {
@differentiable(wrt: x)
// expected-note @+1 {{protocol requires function 'a' with type '(Float) -> Float'; do you want to add a stub?}}
func a(_ x: Float) -> Float
}

// expected-error @+1 {{type 'CertainlyNotDiffableWrtSelf' does not conform to protocol 'NotRefiningDiffable'}}
struct CertainlyNotDiffableWrtSelf : NotRefiningDiffable {
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: x)'}}
func a(_ x: Float) -> Float { return x * 5.0 }
}


protocol TF285 : Differentiable {
@differentiable(wrt: (x, y))
@differentiable(wrt: x)
// expected-note @+1 {{protocol requires function 'foo(x:y:)' with type '(Float, Float) -> Float'; do you want to add a stub?}}
func foo(x: Float, y: Float) -> Float
}

// expected-error @+1 {{type 'TF285MissingOneDiffAttr' does not conform to protocol 'TF285'}}
struct TF285MissingOneDiffAttr : TF285 {
// Requirement is missing an attribute.
@differentiable(wrt: x)
// expected-note @+2 {{candidate is missing attribute '@differentiable(wrt: x)}}
// expected-note @+1 {{candidate is missing attribute '@differentiable(wrt: (x, y))}}
func foo(x: Float, y: Float) -> Float {
return x
}
}
6 changes: 3 additions & 3 deletions test/AutoDiff/existential.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ import StdlibUnittest
var ExistentialTests = TestSuite("Existential")

protocol A {
@differentiable
func a(_: Float) -> Float
@differentiable(wrt: x)
func a(_ x: Float) -> Float
}
func b(g: A) -> Float { return (3.0 as Float).gradient() { x in g.a(x) } }

struct B : A {
@differentiable
@differentiable(wrt: x)
func a(_ x: Float) -> Float { return x * 5.0 }
}

Expand Down
2 changes: 1 addition & 1 deletion test/AutoDiff/sildeclref_parse.sil
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import Swift

protocol Proto {
@differentiable()
@differentiable(wrt: (x, y))
func f(_ x: Float, _ y: Float) -> Float
}

Expand Down
Loading