Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1439,13 +1439,7 @@ ERROR(attr_implements_expected_member_name,PointsToFirstBadToken,
// SWIFT_ENABLE_TENSORFLOW
// differentiable
ERROR(attr_differentiable_expected_function_name,PointsToFirstBadToken,
"expected a qualified function to differentiate", ())

ERROR(attr_differentiable_missing_lsquare,PointsToFirstBadToken,
"missing '[' for parameter index list", ())

ERROR(attr_differentiable_missing_rsquare,PointsToFirstBadToken,
"missing ']' for parameter index list", ())
"expected a qualified %0 function", (StringRef))

ERROR(attr_differentiable_expected_parameter_list,PointsToFirstBadToken,
"expected a list of parameters to differentiate with respect to, e.g. (.0, w, b)", ())
Expand Down
8 changes: 7 additions & 1 deletion include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2410,6 +2410,8 @@ ERROR(implements_attr_protocol_not_conformed_to,none,
// SWIFT_ENABLE_TENSORFLOW
ERROR(differentiable_attr_no_parameters,none,
"%0 has no parameters to differentiate with respect to", (DeclName))
ERROR(differentiable_attr_void_result,none,
"cannot differentiate void function %0", (DeclName))
ERROR(differentiable_attr_primal_overload_not_found,none,
"%0 does not have expected parameters' type %1", (DeclName, Type))
ERROR(differentiable_attr_adjoint_overload_not_found,none,
Expand All @@ -2428,13 +2430,17 @@ ERROR(differentiable_attr_cannot_diff_wrt_objects_or_existentials,none,
ERROR(differentiable_attr_function_not_same_type_context,none,
"%0 is not defined in the current type context", (DeclName))
ERROR(differentiable_attr_specified_not_function,none,
"%0 is not a function to be used as %select{primal|adjoint}1",
"%0 is not a function to be used as %select{adjoint|primal}1",
(DeclName, bool))
ERROR(differentiable_attr_ambiguous_function_identifier,none,
"ambiguous or overloaded identifier %0 cannot be used in @differentiable "
"attribute", (DeclName))
ERROR(differentiable_attr_forward_mode_unsupported,none,
"forward-mode automatic differentiation is not supported yet", ())
ERROR(differentiable_attr_invalid_access,none,
"%select{adjoint|primal}2 %0 is required to either be public or "
"@usableFromInline because the original function %1 is public or "
"@usableFromInline", (DeclName, DeclName, bool))

ERROR(compiler_evaluable_bad_context,none,
"@compilerEvaluable functions not allowed here", ())
Expand Down
9 changes: 5 additions & 4 deletions lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,12 +612,13 @@ Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {
parseToken(tok::colon,
diag::attr_differentiable_expected_colon_after_label, label))
return true;
// Parse the name of the adjoint function.
// Parse the name of the function.
Diagnostic funcDiag(diag::attr_differentiable_expected_function_name.ID,
{ label });
result.Name =
parseUnqualifiedDeclName(/*afterDot=*/false, result.Loc,
diag::attr_implements_expected_member_name,
/*allowOperators=*/true,
/*allowZeroArgCompoundNames=*/true);
funcDiag, /*allowOperators=*/true,
/*allowZeroArgCompoundNames=*/true);
return !result.Name;
};

Expand Down
219 changes: 93 additions & 126 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2215,97 +2215,6 @@ void AttributeChecker::visitImplementsAttr(ImplementsAttr *attr) {
}
}

// SWIFT_ENABLE_TENSORFLOW
// Returns the function declaration corresponding to the given function name and
// lookup context. If the function declaration cannot be resolved, emits a
// diagnostic and returns nullptr.
static FuncDecl *getResolvedFuncDecl(
DeclName funcName, SourceLoc funcNameLoc, TypeChecker &TC,
DeclContext *lookupContext,
const std::function<bool(FuncDecl *)> &isValidFuncDecl,
const std::function<bool(FuncDecl *)> &hasValidTypeContext,
const std::function<void()> &overloadDiagnostic,
const std::function<void()> &ambiguousDiagnostic,
const std::function<void()> &notFunctionDiagnostic) {

FuncDecl *resolvedFuncDecl = nullptr;

// Initialize error flags.
bool notAFuncDecl = false;
bool wrongTypeContext = false;
bool overloadNotFound = false;

// Perform lookup, ignoring access control.
auto options = defaultUnqualifiedLookupOptions |
NameLookupFlags::IgnoreAccessControl;
auto results =
TC.lookupUnqualified(lookupContext, funcName, funcNameLoc, options);

// Note: static methods are omitted from `TypeChecker.lookupUnqualified` in
// Swift 3. The code below is a workaround for resolving them.
//
// This is necessary because the stdlib is compiled with `-swift-version 3`
// for Swift 3 compatibility, and floating point types use the
// `@differentiable` attribute with static adjoint methods (such as
// `_adjointAdd`).
if (lookupContext->getASTContext().isSwiftVersion3() && results.empty() &&
lookupContext->isTypeContext()) {
auto tmp = TC.lookupMember(lookupContext,
lookupContext->getSelfTypeInContext(), funcName);
for (auto choice : tmp) {
auto decl = choice.getValueDecl();
if (!decl) continue;
auto funcDecl = dyn_cast<FuncDecl>(decl);
if (!funcDecl) continue;
results.add(LookupResultEntry(funcDecl));
}
}

for (auto choice : results) {
auto decl = choice.getValueDecl();
if (!decl) continue;

auto funcDecl = dyn_cast<FuncDecl>(decl);
if (!funcDecl) {
notAFuncDecl = true;
continue;
}
if (!hasValidTypeContext(funcDecl)) {
wrongTypeContext = true;
continue;
}
if (!isValidFuncDecl(funcDecl)) {
overloadNotFound = true;
continue;
}
if (resolvedFuncDecl) {
ambiguousDiagnostic();
resolvedFuncDecl = nullptr;
break;
}
resolvedFuncDecl = funcDecl;
}
// If function declaration could not be resolved, emit the appropriate
// diagnostic.
if (!resolvedFuncDecl) {
if (results.empty()) {
TC.diagnose(funcNameLoc, diag::use_unresolved_identifier, funcName,
funcName.isOperator());
} else if (wrongTypeContext) {
TC.diagnose(funcNameLoc,
diag::differentiable_attr_function_not_same_type_context,
funcName);
} else if (overloadNotFound) {
overloadDiagnostic();
} else {
assert(notAFuncDecl && "Expected 'not a function' error");
notFunctionDiagnostic();
}
}

return resolvedFuncDecl;
}

void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
// Forward mode is unsupported.
if (attr->getMode() == AutoDiffMode::Forward) {
Expand All @@ -2318,40 +2227,74 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
auto *original = cast<FuncDecl>(D);
auto isInstanceMethod = original->isInstanceMember();
auto selfDecl = original->getImplicitSelfDecl();
auto &ctx = original->getASTContext();

// If the original function has no parameters, there's nothing to
// differentiate with respect to.
// If the original function has no parameters or returns the empty tuple
// type, there's nothing to differentiate from or with-respect-to.
auto &originalParams = *original->getParameterList(selfDecl ? 1 : 0);
if (!isInstanceMethod && originalParams.size() == 0) {
TC.diagnose(attr->getLocation(), diag::differentiable_attr_no_parameters,
original->getName())
.highlight(original->getSourceRange());
return;
}
auto originalResultTy = original->getResultInterfaceType();
if (originalResultTy->isEqual(ctx.TheEmptyTupleType)) {
TC.diagnose(attr->getLocation(), diag::differentiable_attr_void_result,
original->getName())
.highlight(original->getSourceRange());
return;
}

auto originalParamsTy =
originalParams.getInterfaceType(original->getASTContext());
auto originalParamsTy = originalParams.getInterfaceType(ctx);

// If the original function and the primal/adjoint have different parents, or
// if they both have no type context and are in different modules, then
// it's an error.
auto hasValidTypeContext = [&](FuncDecl *decl) {
// if they both have no type context and are in different modules, then it's
// an error.
// Returns true on error.
std::function<bool(FuncDecl *)> hasValidTypeContext
= [&](FuncDecl *func) {
if (!original->getInnermostTypeContext() &&
!decl->getInnermostTypeContext() &&
original->getParentModule() == decl->getParentModule())
!func->getInnermostTypeContext() &&
original->getParentModule() == func->getParentModule())
return true;
if (auto typeCtx1 = original->getInnermostTypeContext()) {
if (auto typeCtx2 = decl->getInnermostTypeContext()) {
if (auto typeCtx2 = func->getInnermostTypeContext()) {
auto type1 = typeCtx1->getDeclaredInterfaceType();
auto type2 = typeCtx2->getDeclaredInterfaceType();
return type1->isEqual(type2);
}
}
return original->getParent() == decl->getParent();
return original->getParent() == func->getParent();
};

// If the original function is exported (i.e. it is public or
// @usableFromInline), then the primal/adjoint must also be exported.
// Returns true on error.
using FuncSpecifier = DifferentiableAttr::FunctionSpecifier;
auto checkAccessControl = [&](FuncDecl *func, FuncSpecifier funcSpec,
bool isPrimal) {
auto originalAccess =
original->getFormalAccess(/*useDC*/ nullptr,
/*treatUsableFromInlineAsPublic*/ true);
if (originalAccess < AccessLevel::Public) return false;
auto funcAccess =
func->getFormalAccess(/*useDC*/ nullptr,
/*treatUsableFromInlineAsPublic*/ true);
if (funcAccess >= AccessLevel::Public) return false;
TC.diagnose(funcSpec.Loc.getBaseNameLoc(),
diag::differentiable_attr_invalid_access,
funcSpec.Name, original->getFullName(), isPrimal);
attr->setInvalid();
return true;
};

// Set lookup options.
auto lookupOptions = defaultMemberLookupOptions
| NameLookupFlags::IgnoreAccessControl;

// Resolve the primal declaration, if it exists.
FuncDecl *resolvedPrimal = nullptr;
FuncDecl *primal = nullptr;
if (attr->getPrimal()) {
auto primalSpecifier = attr->getPrimal().getValue();
auto primalNameLoc = primalSpecifier.Loc.getBaseNameLoc();
Expand All @@ -2374,6 +2317,11 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
diag::differentiable_attr_specified_not_function,
primalSpecifier.Name, /*isPrimal*/ true);
};
std::function<void()> primalInvalidTypeContextDiagnostic = [&]() {
TC.diagnose(primalNameLoc,
diag::differentiable_attr_function_not_same_type_context,
primalSpecifier.Name);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need a setInvalid here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think setInvalid is handled by the block below:

primal = TC.lookupFuncDecl(...)
if (!primal) {
  attr->setInvalid();
  return;
}

If any diagnostic is emitted, primal will be nullptr and attr->setInvalid() will execute.

};

auto isValidPrimal = [&](FuncDecl *primalCandidate) {
// Returns true if the primal candidate
Expand All @@ -2384,8 +2332,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
auto primalSelfDecl = primalCandidate->getImplicitSelfDecl();
auto primalParams =
primalCandidate->getParameterList(primalSelfDecl ? 1 : 0);
auto primalParamsTy =
primalParams->getInterfaceType(original->getASTContext());
auto primalParamsTy = primalParams->getInterfaceType(ctx);
if (!primalParamsTy->isEqual(originalParamsTy))
return false;
auto originalCanGenSig = original->getGenericSignature()
Expand All @@ -2406,15 +2353,20 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
return true;
};

resolvedPrimal =
getResolvedFuncDecl(primalSpecifier.Name, primalNameLoc,
TC, primalTypeCtx, isValidPrimal, hasValidTypeContext,
primalOverloadDiagnostic, primalAmbiguousDiagnostic,
primalNotFunctionDiagnostic);
primal = TC.lookupFuncDecl(
primalSpecifier.Name, primalNameLoc, primalTypeCtx, isValidPrimal,
primalOverloadDiagnostic, primalAmbiguousDiagnostic,
primalNotFunctionDiagnostic, lookupOptions, hasValidTypeContext,
primalInvalidTypeContextDiagnostic);

if (!resolvedPrimal) return;
if (!primal) {
attr->setInvalid();
return;
}
// Check primal access control.
if (checkAccessControl(primal, primalSpecifier, /*isPrimal*/ true)) return;
// Memorize the primal reference in the attribute.
attr->setPrimalFunction(resolvedPrimal);
attr->setPrimalFunction(primal);
}

// Compute the return type of the adjoint function.
Expand All @@ -2428,7 +2380,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
SmallVector<TupleTypeElt, 8> retElts;
for (auto *param : originalParams)
retElts.push_back(param->getInterfaceType());
retTy = TupleType::get(retElts, original->getASTContext());
retTy = TupleType::get(retElts, ctx);
} else {
retTy = originalParams[0]->getInterfaceType();
}
Expand Down Expand Up @@ -2503,7 +2455,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
// type.
assert(retElts.size() > 0 && "There should be at least one return type");
retTy = retElts.size() > 1
? TupleType::get(retElts, original->getASTContext())
? TupleType::get(retElts, ctx)
: retElts[0].getType();
}

Expand All @@ -2517,9 +2469,8 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
// original result, and the seed.
//
// If the primal exists, the checkpoints type is the primal result type.
if (attr->getPrimal()) {
auto *primResultTy =
resolvedPrimal->getResultInterfaceType()->getAs<TupleType>();
if (primal) {
auto *primResultTy = primal->getResultInterfaceType()->getAs<TupleType>();
auto checkpointsTy = primResultTy->getElement(0).getType();
paramTypes.push_back(FunctionType::Param(checkpointsTy));
}
Expand Down Expand Up @@ -2552,7 +2503,7 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
}

// Resolve the adjoint declaration.
FuncDecl *resolvedAdjoint = nullptr;
FuncDecl *adjoint = nullptr;
auto adjointSpecifier = attr->getAdjoint();
auto adjointNameLoc = adjointSpecifier.Loc.getBaseNameLoc();

Expand All @@ -2563,35 +2514,51 @@ void AttributeChecker::visitDifferentiableAttr(DifferentiableAttr *attr) {
TC.diagnose(adjointNameLoc,
diag::differentiable_attr_adjoint_overload_not_found,
adjointSpecifier.Name, expectedAdjointFnTy);
attr->setInvalid();
};
auto adjointAmbiguousDiagnostic = [&]() {
TC.diagnose(adjointNameLoc,
diag::differentiable_attr_ambiguous_function_identifier,
adjointSpecifier.Name);
attr->setInvalid();
};
auto adjointNotFunctionDiagnostic = [&]() {
TC.diagnose(adjointNameLoc,
diag::differentiable_attr_specified_not_function,
adjointSpecifier.Name, /*isPrimal*/ false);
attr->setInvalid();
};
std::function<void()> adjointInvalidTypeContextDiagnostic = [&]() {
TC.diagnose(adjointNameLoc,
diag::differentiable_attr_function_not_same_type_context,
adjointSpecifier.Name);
attr->setInvalid();
};

auto isValidAdjoint = [&](FuncDecl *adjointCandidate) {
// Returns true if adjoint candidate has the expected type.
auto adjointType = adjointCandidate->getInterfaceType()
->getUnlabeledType(original->getASTContext());
->getUnlabeledType(ctx);
return adjointType->isEqual(expectedAdjointFnTy);
};

resolvedAdjoint =
getResolvedFuncDecl(adjointSpecifier.Name, adjointNameLoc,
TC, adjointTypeCtx, isValidAdjoint, hasValidTypeContext,
adjointOverloadDiagnostic, adjointAmbiguousDiagnostic,
adjointNotFunctionDiagnostic);
adjoint =
TC.lookupFuncDecl(adjointSpecifier.Name, adjointNameLoc, adjointTypeCtx,
isValidAdjoint, adjointOverloadDiagnostic,
adjointAmbiguousDiagnostic, adjointNotFunctionDiagnostic,
lookupOptions, hasValidTypeContext,
adjointInvalidTypeContextDiagnostic);

if (!resolvedAdjoint) return;
if (!adjoint) {
attr->setInvalid();
return;
}
// Check adjoint access control.
if (checkAccessControl(adjoint, adjointSpecifier, /*isPrimal*/ false))
return;
// Done checking @differentiable attribute.
// Memorize the adjoint reference in the attribute.
attr->setAdjointFunction(resolvedAdjoint);
attr->setAdjointFunction(adjoint);
}

void AttributeChecker::visitCompilerEvaluableAttr(CompilerEvaluableAttr *attr) {
Expand Down
Loading