Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/DifferentiableProgramming.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Differentiable Programming Manifesto

* Authors: [Richard Wei], [Dan Zheng], [Marc Rasi], [Bart Chrzaszcz]
* Status: Partially implemented on master, feature gated under `import _Differentiation`
* Status:
* Partially implemented on main, feature gated under `import _Differentiation`
* Initial proposal [pitched](https://forums.swift.org/t/differentiable-programming-for-gradient-based-machine-learning/42147) with a significantly scoped-down subset of features. Please refer to the linked pitch thread for the latest design discussions and changes.

## Table of contents

Expand Down
8 changes: 7 additions & 1 deletion include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -4473,13 +4473,19 @@ ERROR(differentiable_function_type_invalid_parameter,none,
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
"function type is '@differentiable%select{|(linear)}1'"
"%select{|; did you want to add '@noDerivative' to this parameter?}2",
(StringRef, /*tangentVectorEqualsSelf*/ bool,
(StringRef, /*isLinear*/ bool,
/*hasValidDifferentiabilityParameter*/ bool))
ERROR(differentiable_function_type_invalid_result,none,
"result type '%0' does not conform to 'Differentiable'"
"%select{| and satisfy '%0 == %0.TangentVector'}1, but the enclosing "
"function type is '@differentiable%select{|(linear)}1'",
(StringRef, bool))
ERROR(differentiable_function_type_no_differentiability_parameters,
none,
"'@differentiable' function type requires at least one differentiability "
"parameter, i.e. a non-'@noDerivative' parameter whose type conforms to "
"'Differentiable'%select{| with its 'TangentVector' equal to itself}0",
(/*isLinear*/ bool))

// SIL
ERROR(opened_non_protocol,none,
Expand Down
13 changes: 12 additions & 1 deletion lib/Sema/TypeCheckType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2610,22 +2610,33 @@ TypeResolver::resolveASTFunctionTypeParams(TupleTypeRepr *inputRepr,
return isDifferentiable(param.getPlainType(),
/*tangentVectorEqualsSelf*/ isLinear);
}) != elements.end();
bool alreadyDiagnosedOneParam = false;
for (unsigned i = 0, end = inputRepr->getNumElements(); i != end; ++i) {
auto *eltTypeRepr = inputRepr->getElementType(i);
auto param = elements[i];
if (param.isNoDerivative())
continue;
auto paramType = param.getPlainType();
if (isDifferentiable(paramType, /*tangentVectorEqualsSelf*/ isLinear))
if (isDifferentiable(paramType, isLinear))
continue;
auto paramTypeString = paramType->getString();
auto diagnostic =
diagnose(eltTypeRepr->getLoc(),
diag::differentiable_function_type_invalid_parameter,
paramTypeString, isLinear, hasValidDifferentiabilityParam);
alreadyDiagnosedOneParam = true;
if (hasValidDifferentiabilityParam)
diagnostic.fixItInsert(eltTypeRepr->getLoc(), "@noDerivative ");
}
// Reject the case where all parameters have '@noDerivative'.
if (!alreadyDiagnosedOneParam && !hasValidDifferentiabilityParam) {
diagnose(
inputRepr->getLoc(),
diag::
differentiable_function_type_no_differentiability_parameters,
isLinear)
.highlight(inputRepr->getSourceRange());
}
}

return elements;
Expand Down
12 changes: 12 additions & 0 deletions test/AutoDiff/Sema/differentiable_func_type.swift
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,18 @@ let _: (@noDerivative Float, Float) -> Float

let _: @differentiable (Float, @noDerivative Float) -> Float // okay

// expected-error @+1 {{'@differentiable' function type requires at least one differentiability parameter, i.e. a non-'@noDerivative' parameter whose type conforms to 'Differentiable'}}
let _: @differentiable (@noDerivative Float) -> Float

// expected-error @+1 {{'@differentiable' function type requires at least one differentiability parameter, i.e. a non-'@noDerivative' parameter whose type conforms to 'Differentiable'}}
let _: @differentiable (@noDerivative Float, @noDerivative Int) -> Float

// expected-error @+1 {{'@differentiable' function type requires at least one differentiability parameter, i.e. a non-'@noDerivative' parameter whose type conforms to 'Differentiable'}}
let _: @differentiable (@noDerivative Float, @noDerivative Float) -> Float

// expected-error @+1 {{parameter type 'Int' does not conform to 'Differentiable' and satisfy 'Int == Int.TangentVector', but the enclosing function type is '@differentiable(linear)'}}
let _: @differentiable(linear) (@noDerivative Float, Int) -> Float

// expected-error @+1 {{'@noDerivative' may only be used on parameters of '@differentiable' function types}}
let _: (Float) -> @noDerivative Float

Expand Down