Skip to content

Commit

Permalink
Merge pull request #68653 from hamishknight/nestless-5.9
Browse files Browse the repository at this point in the history
[5.9] [Sema] Catch invalid if/switch exprs in more places
  • Loading branch information
bnbarham committed Sep 21, 2023
2 parents abe28dc + 5e7b70b commit 49fe7f7
Show file tree
Hide file tree
Showing 14 changed files with 428 additions and 92 deletions.
57 changes: 37 additions & 20 deletions lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2519,30 +2519,47 @@ SingleValueStmtExpr *SingleValueStmtExpr::createWithWrappedBranches(

SingleValueStmtExpr *
SingleValueStmtExpr::tryDigOutSingleValueStmtExpr(Expr *E) {
while (true) {
// Look through implicit conversions.
if (auto *ICE = dyn_cast<ImplicitConversionExpr>(E)) {
E = ICE->getSubExpr();
continue;
class SVEFinder final : public ASTWalker {
public:
SingleValueStmtExpr *FoundSVE = nullptr;

PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
if (auto *SVE = dyn_cast<SingleValueStmtExpr>(E)) {
FoundSVE = SVE;
return Action::Stop();
}

// Look through implicit exprs.
if (E->isImplicit())
return Action::Continue(E);

// Look through coercions.
if (isa<CoerceExpr>(E))
return Action::Continue(E);

// Look through try/await (this is invalid, but we'll error on it in
// effect checking).
if (isa<AnyTryExpr>(E) || isa<AwaitExpr>(E))
return Action::Continue(E);

return Action::Stop();
}
// Look through coercions.
if (auto *CE = dyn_cast<CoerceExpr>(E)) {
E = CE->getSubExpr();
continue;
PreWalkResult<Stmt *> walkToStmtPre(Stmt *S) override {
return Action::Stop();
}
// Look through try/await (this is invalid, but we'll error on it in
// effect checking).
if (auto *TE = dyn_cast<AnyTryExpr>(E)) {
E = TE->getSubExpr();
continue;
PreWalkAction walkToDeclPre(Decl *D) override {
return Action::Stop();
}
if (auto *AE = dyn_cast<AwaitExpr>(E)) {
E = AE->getSubExpr();
continue;
PreWalkResult<Pattern *> walkToPatternPre(Pattern *P) override {
return Action::Stop();
}
break;
}
return dyn_cast<SingleValueStmtExpr>(E);
PreWalkAction walkToTypeReprPre(TypeRepr *T) override {
return Action::Stop();
}
};
SVEFinder finder;
E->walk(finder);
return finder.FoundSVE;
}

SourceRange SingleValueStmtExpr::getSourceRange() const {
Expand Down
71 changes: 51 additions & 20 deletions lib/Sema/MiscDiagnostics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3850,7 +3850,20 @@ class SingleValueStmtUsageChecker final : public ASTWalker {
llvm::DenseSet<SingleValueStmtExpr *> ValidSingleValueStmtExprs;

public:
SingleValueStmtUsageChecker(ASTContext &ctx) : Ctx(ctx), Diags(ctx.Diags) {}
SingleValueStmtUsageChecker(
ASTContext &ctx, ASTNode root,
llvm::Optional<ContextualTypePurpose> contextualPurpose)
: Ctx(ctx), Diags(ctx.Diags) {
assert(!root.is<Expr *>() || contextualPurpose &&
"Must provide contextual purpose for expr");

// If we have a contextual purpose, this is for an expression. Check if it's
// an expression in a valid position.
if (contextualPurpose) {
markAnyValidTopLevelSingleValueStmt(root.get<Expr *>(),
*contextualPurpose);
}
}

private:
/// Mark a given expression as a valid position for a SingleValueStmtExpr.
Expand All @@ -3862,8 +3875,23 @@ class SingleValueStmtUsageChecker final : public ASTWalker {
ValidSingleValueStmtExprs.insert(SVE);
}

/// Mark a valid top-level expression with a given contextual purpose.
void markAnyValidTopLevelSingleValueStmt(Expr *E, ContextualTypePurpose ctp) {
// Allowed in returns, throws, and bindings.
switch (ctp) {
case CTP_ReturnStmt:
case CTP_ReturnSingleExpr:
case CTP_ThrowStmt:
case CTP_Initialization:
markValidSingleValueStmt(E);
break;
default:
break;
}
}

MacroWalking getMacroWalkingBehavior() const override {
return MacroWalking::Expansion;
return MacroWalking::ArgumentsAndExpansion;
}

AssignExpr *findAssignment(Expr *E) const {
Expand Down Expand Up @@ -3989,28 +4017,33 @@ class SingleValueStmtUsageChecker final : public ASTWalker {
if (auto *PBD = dyn_cast<PatternBindingDecl>(D)) {
for (auto idx : range(PBD->getNumPatternEntries()))
markValidSingleValueStmt(PBD->getInit(idx));

return Action::Continue();
}
// Valid as a single expression body of a function. This is needed in
// addition to ReturnStmt checking, as we will remove the return if the
// expression is inferred to be Never.
if (auto *AFD = dyn_cast<AbstractFunctionDecl>(D)) {
if (AFD->hasSingleExpressionBody())
markValidSingleValueStmt(AFD->getSingleExpressionBody());
}
return Action::Continue();
// We don't want to walk into any other decl, we will visit them as part of
// typeCheckDecl.
return Action::SkipChildren();
}
};
} // end anonymous namespace

void swift::diagnoseOutOfPlaceExprs(
ASTContext &ctx, ASTNode root,
llvm::Optional<ContextualTypePurpose> contextualPurpose) {
// TODO: We ought to consider moving this into pre-checking such that we can
// still diagnose on invalid code, and don't have to traverse over implicit
// exprs. We need to first separate out SequenceExpr folding though.
SingleValueStmtUsageChecker sveChecker(ctx, root, contextualPurpose);
root.walk(sveChecker);
}

/// Apply the warnings managed by VarDeclUsageChecker to the top level
/// code declarations that haven't been checked yet.
void swift::
performTopLevelDeclDiagnostics(TopLevelCodeDecl *TLCD) {
auto &ctx = TLCD->getDeclContext()->getASTContext();
VarDeclUsageChecker checker(TLCD, ctx.Diags);
TLCD->walk(checker);
SingleValueStmtUsageChecker sveChecker(ctx);
TLCD->walk(sveChecker);
}

/// Perform diagnostics for func/init/deinit declarations.
Expand All @@ -4026,10 +4059,6 @@ void swift::performAbstractFuncDeclDiagnostics(AbstractFunctionDecl *AFD) {
auto &ctx = AFD->getDeclContext()->getASTContext();
VarDeclUsageChecker checker(AFD, ctx.Diags);
AFD->walk(checker);

// Do a similar walk to check for out of place SingleValueStmtExprs.
SingleValueStmtUsageChecker sveChecker(ctx);
AFD->walk(sveChecker);
}

auto *body = AFD->getBody();
Expand Down Expand Up @@ -5864,10 +5893,10 @@ diagnoseDictionaryLiteralDuplicateKeyEntries(const Expr *E,
//===----------------------------------------------------------------------===//

/// Emit diagnostics for syntactic restrictions on a given expression.
void swift::performSyntacticExprDiagnostics(const Expr *E,
const DeclContext *DC,
bool isExprStmt,
bool disableExprAvailabilityChecking) {
void swift::performSyntacticExprDiagnostics(
const Expr *E, const DeclContext *DC,
llvm::Optional<ContextualTypePurpose> contextualPurpose, bool isExprStmt,
bool disableExprAvailabilityChecking, bool disableOutOfPlaceExprChecking) {
auto &ctx = DC->getASTContext();
TypeChecker::diagnoseSelfAssignment(E);
diagSyntacticUseRestrictions(E, DC, isExprStmt);
Expand All @@ -5886,6 +5915,8 @@ void swift::performSyntacticExprDiagnostics(const Expr *E,
diagnoseConstantArgumentRequirement(E, DC);
diagUnqualifiedAccessToMethodNamedSelf(E, DC);
diagnoseDictionaryLiteralDuplicateKeyEntries(E, DC);
if (!disableOutOfPlaceExprChecking)
diagnoseOutOfPlaceExprs(ctx, const_cast<Expr *>(E), contextualPurpose);
}

void swift::performStmtDiagnostics(const Stmt *S, DeclContext *DC) {
Expand Down
15 changes: 14 additions & 1 deletion lib/Sema/MiscDiagnostics.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ namespace swift {
class ApplyExpr;
class CallExpr;
class ClosureExpr;
enum ContextualTypePurpose : uint8_t;
class DeclContext;
class Decl;
class Expr;
Expand All @@ -37,10 +38,22 @@ namespace swift {
class ValueDecl;
class ForEachStmt;

/// Diagnose any expressions that appear in an unsupported position. If visiting
/// an expression directly, its \p contextualPurpose should be provided to
/// evaluate its position.
void diagnoseOutOfPlaceExprs(
ASTContext &ctx, ASTNode root,
llvm::Optional<ContextualTypePurpose> contextualPurpose);

/// Emit diagnostics for syntactic restrictions on a given expression.
///
/// Note: \p contextualPurpose must be non-nil, unless
/// \p disableOutOfPlaceExprChecking is set to \c true.
void performSyntacticExprDiagnostics(
const Expr *E, const DeclContext *DC,
bool isExprStmt, bool disableExprAvailabilityChecking = false);
llvm::Optional<ContextualTypePurpose> contextualPurpose,
bool isExprStmt, bool disableExprAvailabilityChecking = false,
bool disableOutOfPlaceExprChecking = false);

/// Emit diagnostics for a given statement.
void performStmtDiagnostics(const Stmt *S, DeclContext *DC);
Expand Down
25 changes: 19 additions & 6 deletions lib/Sema/TypeCheckConstraints.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,11 @@ class FunctionSyntacticDiagnosticWalker : public ASTWalker {
}

PreWalkResult<Expr *> walkToExprPre(Expr *expr) override {
performSyntacticExprDiagnostics(expr, dcStack.back(), /*isExprStmt=*/false);
// We skip out-of-place expr checking here since we've already performed it.
performSyntacticExprDiagnostics(expr, dcStack.back(), /*ctp*/ llvm::None,
/*isExprStmt=*/false,
/*disableAvailabilityChecking*/ false,
/*disableOutOfPlaceExprChecking*/ true);

if (auto closure = dyn_cast<ClosureExpr>(expr)) {
if (closure->isSeparatelyTypeChecked()) {
Expand Down Expand Up @@ -346,8 +350,9 @@ void constraints::performSyntacticDiagnosticsForTarget(
switch (target.kind) {
case SyntacticElementTarget::Kind::expression: {
// First emit diagnostics for the main expression.
performSyntacticExprDiagnostics(target.getAsExpr(), dc,
isExprStmt, disableExprAvailabilityChecking);
performSyntacticExprDiagnostics(
target.getAsExpr(), dc, target.getExprContextualTypePurpose(),
isExprStmt, disableExprAvailabilityChecking);
return;
}

Expand All @@ -356,17 +361,25 @@ void constraints::performSyntacticDiagnosticsForTarget(

// First emit diagnostics for the main expression.
performSyntacticExprDiagnostics(stmt->getTypeCheckedSequence(), dc,
isExprStmt,
CTP_ForEachSequence, isExprStmt,
disableExprAvailabilityChecking);

if (auto *whereExpr = stmt->getWhere())
performSyntacticExprDiagnostics(whereExpr, dc, /*isExprStmt*/ false);
performSyntacticExprDiagnostics(whereExpr, dc, CTP_Condition,
/*isExprStmt*/ false);
return;
}

case SyntacticElementTarget::Kind::function: {
// Check for out of place expressions. This needs to be done on the entire
// function body rather than on individual expressions since we need the
// context of the parent nodes.
auto *body = target.getFunctionBody();
diagnoseOutOfPlaceExprs(dc->getASTContext(), body,
/*contextualPurpose*/ llvm::None);

FunctionSyntacticDiagnosticWalker walker(dc);
target.getFunctionBody()->walk(walker);
body->walk(walker);
return;
}
case SyntacticElementTarget::Kind::closure:
Expand Down
8 changes: 3 additions & 5 deletions test/Constraints/closures.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1149,11 +1149,9 @@ struct R_76250381<Result, Failure: Error> {
// rdar://77022842 - crash due to a missing argument to a ternary operator
func rdar77022842(argA: Bool? = nil, argB: Bool? = nil) {
if let a = argA ?? false, if let b = argB ?? {
// expected-error@-1 {{'if' may only be used as expression in return, throw, or as the source of an assignment}}
// expected-error@-2 {{initializer for conditional binding must have Optional type, not 'Bool'}}
// expected-error@-3 {{cannot convert value of type '() -> ()' to expected argument type 'Bool?'}}
// expected-error@-4 {{cannot convert value of type 'Void' to expected condition type 'Bool'}}
// expected-error@-5 {{'if' must have an unconditional 'else' to be used as expression}}
// expected-error@-1 {{initializer for conditional binding must have Optional type, not 'Bool'}}
// expected-error@-2 {{cannot convert value of type '() -> ()' to expected argument type 'Bool?'}}
// expected-error@-3 {{cannot convert value of type 'Void' to expected condition type 'Bool'}}
} // expected-error {{expected '{' after 'if' condition}}
}

Expand Down
42 changes: 39 additions & 3 deletions test/Constraints/if_expr.swift
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,6 @@ func testReturnMismatch() {
let _ = if .random() {
return 1 // expected-error {{unexpected non-void return value in void function}}
// expected-note@-1 {{did you mean to add a return type?}}
// expected-error@-2 {{cannot 'return' in 'if' when used as expression}}
} else {
0
}
Expand Down Expand Up @@ -651,9 +650,46 @@ func builderWithBinding() -> Either<String, Int> {
}
}

@Builder
func builderWithInvalidBinding() -> Either<String, Int> {
let str = (if .random() { "a" } else { "b" })
// expected-error@-1 {{'if' may only be used as expression in return, throw, or as the source of an assignment}}
if .random() {
str
} else {
1
}
}

func takesBuilder(@Builder _ fn: () -> Either<String, Int>) {}

func builderClosureWithBinding() {
takesBuilder {
// Make sure the binding gets type-checked as an if expression, but the
// other if block gets type-checked as a stmt.
let str = if .random() { "a" } else { "b" }
if .random() {
str
} else {
1
}
}
}

func builderClosureWithInvalidBinding() {
takesBuilder {
let str = (if .random() { "a" } else { "b" })
// expected-error@-1 {{'if' may only be used as expression in return, throw, or as the source of an assignment}}
if .random() {
str
} else {
1
}
}
}

func builderInClosure() {
func build(@Builder _ fn: () -> Either<String, Int>) {}
build {
takesBuilder {
if .random() {
""
} else {
Expand Down
Loading

0 comments on commit 49fe7f7

Please sign in to comment.