diff --git a/include/swift/AST/Expr.h b/include/swift/AST/Expr.h index 08980a8f85989..bb370c0b332c8 100644 --- a/include/swift/AST/Expr.h +++ b/include/swift/AST/Expr.h @@ -6498,16 +6498,20 @@ class KeyPathDotExpr : public Expr { } }; +struct ForCollectionInit { + VarDecl *ForAccumulatorDecl; + PatternBindingDecl *ForAccumulatorBinding; +}; + /// An expression that may wrap a statement which produces a single value. class SingleValueStmtExpr : public Expr { public: - enum class Kind { - If, Switch, Do, DoCatch - }; + enum class Kind { If, Switch, Do, DoCatch, For }; private: Stmt *S; DeclContext *DC; + std::optional ForExpressionPreamble; SingleValueStmtExpr(Stmt *S, DeclContext *DC) : Expr(ExprKind::SingleValueStmt, /*isImplicit*/ true), S(S), DC(DC) {} @@ -6572,6 +6576,14 @@ class SingleValueStmtExpr : public Expr { SourceRange getSourceRange() const; + std::optional getForExpressionPreamble() const { + return this->ForExpressionPreamble; + } + + void setForExpressionPreamble(ForCollectionInit newPreamble) { + this->ForExpressionPreamble = newPreamble; + } + static bool classof(const Expr *E) { return E->getKind() == ExprKind::SingleValueStmt; } diff --git a/include/swift/AST/KnownIdentifiers.def b/include/swift/AST/KnownIdentifiers.def index a05145cbc9fa1..02ef359f568e3 100644 --- a/include/swift/AST/KnownIdentifiers.def +++ b/include/swift/AST/KnownIdentifiers.def @@ -29,6 +29,7 @@ IDENTIFIER(alloc) IDENTIFIER(allocWithZone) IDENTIFIER(allZeros) IDENTIFIER(accumulated) +IDENTIFIER(append) IDENTIFIER(ActorType) IDENTIFIER(Any) IDENTIFIER(ArrayLiteralElement) diff --git a/include/swift/Basic/Features.def b/include/swift/Basic/Features.def index cc2665001e14c..cc66d0f279c4a 100644 --- a/include/swift/Basic/Features.def +++ b/include/swift/Basic/Features.def @@ -424,6 +424,9 @@ EXPERIMENTAL_FEATURE(ThenStatements, false) /// Enable 'do' expressions. EXPERIMENTAL_FEATURE(DoExpressions, false) +/// Enable 'for' expressions. +EXPERIMENTAL_FEATURE(ForExpressions, false) + /// Enable implicitly treating the last expression in a function, closure, /// and 'if'/'switch' expression as the result. EXPERIMENTAL_FEATURE(ImplicitLastExprResults, false) diff --git a/include/swift/Sema/SyntacticElementTarget.h b/include/swift/Sema/SyntacticElementTarget.h index 674a429ccf7b8..49b14df4d94af 100644 --- a/include/swift/Sema/SyntacticElementTarget.h +++ b/include/swift/Sema/SyntacticElementTarget.h @@ -10,7 +10,8 @@ // //===----------------------------------------------------------------------===// // -// This file defines the SyntacticElementTarget class. +// This file defines the SyntacticElementTarget class (a unit of +// type-checking). // //===----------------------------------------------------------------------===// @@ -59,8 +60,8 @@ struct PackIterationInfo { /// within the constraint system. using ForEachStmtInfo = TaggedUnion; -/// Describes the target to which a constraint system's solution can be -/// applied. +/// Describes the target (a unit of type-checking) to which a constraint +/// system's solution can be applied. class SyntacticElementTarget { public: enum class Kind { diff --git a/lib/AST/ASTDumper.cpp b/lib/AST/ASTDumper.cpp index b79b88e91d879..c8bbe0dd5afd0 100644 --- a/lib/AST/ASTDumper.cpp +++ b/lib/AST/ASTDumper.cpp @@ -4496,6 +4496,12 @@ class PrintExpr : public ExprVisitor, void visitSingleValueStmtExpr(SingleValueStmtExpr *E, Label label) { printCommon(E, "single_value_stmt_expr", label); printDeclContext(E); + if (auto preamble = E->getForExpressionPreamble()) { + printRec(preamble->ForAccumulatorDecl, + Label::optional("for_preamble_accumulator_decl")); + printRec(preamble->ForAccumulatorBinding, + Label::optional("for_preamble_accumulator_binding")); + } printRec(E->getStmt(), &E->getDeclContext()->getASTContext(), Label::optional("stmt")); printFoot(); diff --git a/lib/AST/ASTWalker.cpp b/lib/AST/ASTWalker.cpp index 5e7580cdfabe2..c72bdd2ee06ad 100644 --- a/lib/AST/ASTWalker.cpp +++ b/lib/AST/ASTWalker.cpp @@ -1385,6 +1385,16 @@ class Traversal : public ASTVisitorgetForExpressionPreamble()) { + if (doIt(preamble->ForAccumulatorDecl)) { + return nullptr; + } + + if (doIt(preamble->ForAccumulatorBinding)) { + return nullptr; + } + } + if (auto *S = doIt(E->getStmt())) { E->setStmt(S); } else { diff --git a/lib/AST/Expr.cpp b/lib/AST/Expr.cpp index 881e269d348c0..8aff969050f5b 100644 --- a/lib/AST/Expr.cpp +++ b/lib/AST/Expr.cpp @@ -2763,6 +2763,8 @@ SingleValueStmtExpr::Kind SingleValueStmtExpr::getStmtKind() const { return Kind::Do; case StmtKind::DoCatch: return Kind::DoCatch; + case StmtKind::ForEach: + return Kind::For; default: llvm_unreachable("Unhandled kind!"); } @@ -2781,6 +2783,9 @@ SingleValueStmtExpr::getBranches(SmallVectorImpl &scratch) const { return scratch; case Kind::DoCatch: return cast(getStmt())->getBranches(scratch); + case Kind::For: + scratch.push_back(cast(getStmt())->getBody()); + return scratch; } llvm_unreachable("Unhandled case in switch!"); } diff --git a/lib/AST/FeatureSet.cpp b/lib/AST/FeatureSet.cpp index 0184fe7e940b2..8ff646b1922a5 100644 --- a/lib/AST/FeatureSet.cpp +++ b/lib/AST/FeatureSet.cpp @@ -116,6 +116,7 @@ UNINTERESTING_FEATURE(RegionBasedIsolation) UNINTERESTING_FEATURE(PlaygroundExtendedCallbacks) UNINTERESTING_FEATURE(ThenStatements) UNINTERESTING_FEATURE(DoExpressions) +UNINTERESTING_FEATURE(ForExpressions) UNINTERESTING_FEATURE(ImplicitLastExprResults) UNINTERESTING_FEATURE(RawLayout) UNINTERESTING_FEATURE(Embedded) diff --git a/lib/SILGen/SILGenExpr.cpp b/lib/SILGen/SILGenExpr.cpp index 9f20526949ef0..699c883c3f46d 100644 --- a/lib/SILGen/SILGenExpr.cpp +++ b/lib/SILGen/SILGenExpr.cpp @@ -2515,6 +2515,17 @@ RValue RValueEmitter::visitEnumIsCaseExpr(EnumIsCaseExpr *E, RValue RValueEmitter::visitSingleValueStmtExpr(SingleValueStmtExpr *E, SGFContext C) { + if (E->getStmtKind() == SingleValueStmtExpr::Kind::For) { + auto *decl = E->getForExpressionPreamble()->ForAccumulatorDecl; + auto *binding = E->getForExpressionPreamble()->ForAccumulatorBinding; + SGF.visit(decl); + SGF.visit(binding); + SGF.emitStmt(E->getStmt()); + + return SGF.emitRValueForDecl(E, ConcreteDeclRef(decl), E->getType(), + AccessSemantics::Ordinary); + } + auto emitStmt = [&]() { SGF.emitStmt(E->getStmt()); diff --git a/lib/Sema/CSSyntacticElement.cpp b/lib/Sema/CSSyntacticElement.cpp index 0180f5a19507d..9ee3f0eb62702 100644 --- a/lib/Sema/CSSyntacticElement.cpp +++ b/lib/Sema/CSSyntacticElement.cpp @@ -1292,8 +1292,12 @@ class SyntacticElementConstraintGenerator // First check to make sure the ThenStmt is in a valid position. SmallVector validThenStmts; - if (auto SVE = context.getAsSingleValueStmtExpr()) + if (auto SVE = context.getAsSingleValueStmtExpr()) { (void)SVE.get()->getThenStmts(validThenStmts); + if (SVE.get()->getStmtKind() == SingleValueStmtExpr::Kind::For) { + contextInfo = std::nullopt; + } + } if (!llvm::is_contained(validThenStmts, thenStmt)) { auto *thenLoc = cs.getConstraintLocator(thenStmt); @@ -1488,8 +1492,37 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) { auto &ctx = getASTContext(); auto *loc = getConstraintLocator(E); - Type resultTy = createTypeVariable(loc, /*options*/ 0); - setType(E, resultTy); + Type resultType = createTypeVariable(loc, /*options*/ 0); + setType(E, resultType); + + if (E->getStmtKind() == SingleValueStmtExpr::Kind::For) { + auto *rrcProtocol = + ctx.getProtocol(KnownProtocolKind::RangeReplaceableCollection); + auto *sequenceProtocol = ctx.getProtocol(KnownProtocolKind::Sequence); + + addConstraint(ConstraintKind::ConformsTo, resultType, + rrcProtocol->getDeclaredInterfaceType(), loc); + Type elementTypeVar = createTypeVariable(loc, /*options*/ 0); + Type elementType = DependentMemberType::get( + resultType, sequenceProtocol->getAssociatedType(ctx.Id_Element)); + + addConstraint(ConstraintKind::Bind, elementTypeVar, elementType, loc); + addConstraint(ConstraintKind::Defaultable, resultType, + ArraySliceType::get(elementTypeVar), loc); + + auto *binding = E->getForExpressionPreamble()->ForAccumulatorBinding; + + auto *initializer = binding->getInit(0); + auto target = SyntacticElementTarget::forInitialization(initializer, Type(), + binding, 0, false); + setTargetFor({binding, 0}, target); + + if (generateConstraints(target)) { + return true; + } + + addConstraint(ConstraintKind::Bind, getType(initializer), resultType, loc); + } // Propagate the implied result kind from the if/switch expression itself // into the branches. @@ -1513,21 +1546,24 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) { auto *loc = getConstraintLocator( E, {LocatorPathElt::SingleValueStmtResult(idx), ctpElt}); - ContextualTypeInfo info(resultTy, CTP_SingleValueStmtBranch, loc); + ContextualTypeInfo info(resultType, CTP_SingleValueStmtBranch, loc); setContextualInfo(result, info); } TypeJoinExpr *join = nullptr; - if (branches.empty()) { - // If we only have statement branches, the expression is typed as Void. This - // should only be the case for 'if' and 'switch' statements that must be - // expressions that have branches that all end in a throw, and we'll warn - // that we've inferred Void. - addConstraint(ConstraintKind::Bind, resultTy, ctx.getVoidType(), loc); - } else { - // Otherwise, we join the result types for each of the branches. - join = TypeJoinExpr::forBranchesOfSingleValueStmtExpr( - ctx, resultTy, E, AllocationArena::ConstraintSolver); + + if (E->getStmtKind() != SingleValueStmtExpr::Kind::For) { + if (branches.empty()) { + // If we only have statement branches, the expression is typed as Void. + // This should only be the case for 'if' and 'switch' statements that must + // be expressions that have branches that all end in a throw, and we'll + // warn that we've inferred Void. + addConstraint(ConstraintKind::Bind, resultType, ctx.getVoidType(), loc); + } else { + // Otherwise, we join the result types for each of the branches. + join = TypeJoinExpr::forBranchesOfSingleValueStmtExpr( + ctx, resultType, E, AllocationArena::ConstraintSolver); + } } // If this is an implied return in a closure, we need to account for the fact @@ -1568,11 +1604,11 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) { if (auto *closureTy = getClosureTypeIfAvailable(CE)) { auto closureResultTy = closureTy->getResult(); auto *bindToClosure = Constraint::create( - *this, ConstraintKind::Bind, resultTy, closureResultTy, loc); + *this, ConstraintKind::Bind, resultType, closureResultTy, loc); bindToClosure->setFavored(); - auto *bindToVoid = Constraint::create(*this, ConstraintKind::Bind, - resultTy, ctx.getVoidType(), loc); + auto *bindToVoid = Constraint::create( + *this, ConstraintKind::Bind, resultType, ctx.getVoidType(), loc); addDisjunctionConstraint({bindToClosure, bindToVoid}, loc); } @@ -2221,7 +2257,9 @@ class SyntacticElementSolutionApplication // not the branch result type. This is necessary as there may be // an additional conversion required for the branch. auto target = solution.getTargetFor(thenStmt->getResult()); - target->setExprConversionType(ty); + if (SVE.get()->getStmtKind() != SingleValueStmtExpr::Kind::For) { + target->setExprConversionType(ty); + } auto *resultExpr = thenStmt->getResult(); if (auto newResultTarget = rewriter.rewriteTarget(*target)) @@ -2663,6 +2701,18 @@ bool ConstraintSystem::applySolutionToSingleValueStmt( if (!stmt || application.hadError) return true; + if (SVE->getStmtKind() == SingleValueStmtExpr::Kind::For) { + auto *binding = SVE->getForExpressionPreamble()->ForAccumulatorBinding; + auto target = getTargetFor({binding, 0}).value(); + + auto newTarget = rewriter.rewriteTarget(target); + if (!newTarget) { + return true; + } + + binding->setInit(0, newTarget->getAsExpr()); + } + SVE->setStmt(stmt); return false; } diff --git a/lib/Sema/PreCheckTarget.cpp b/lib/Sema/PreCheckTarget.cpp index 6f297a7a2f1c5..0c598db8155b2 100644 --- a/lib/Sema/PreCheckTarget.cpp +++ b/lib/Sema/PreCheckTarget.cpp @@ -2,7 +2,7 @@ // // This source file is part of the Swift.org open source project // -// Copyright (c) 2014 - 2018 Apple Inc. and the Swift project authors +// Copyright (c) 2014 - 2025 Apple Inc. and the Swift project authors // Licensed under Apache License v2.0 with Runtime Library Exception // // See https://swift.org/LICENSE.txt for license information @@ -11,7 +11,8 @@ //===----------------------------------------------------------------------===// // // Pre-checking resolves unqualified name references, type expressions and -// operators. +// operators. Target in this context refers to `SyntacticElementTarget`, which +// is a unit of type-checking. // //===----------------------------------------------------------------------===// @@ -1190,6 +1191,11 @@ class PreCheckTarget final : public ASTWalker { /// For the given statement, mark any valid SingleValueStmtExpr children. void markAnyValidSingleValueStmts(Stmt *S); + /// For the given single value expr that's a `for`-expression, run the + /// desugaring transformation. For all other kinds of single value statement + /// expressions do nothing. + void transformForExpression(SingleValueStmtExpr *E); + PreCheckTarget(DeclContext *dc) : Ctx(dc->getASTContext()), DC(dc) {} public: @@ -1386,8 +1392,10 @@ class PreCheckTarget final : public ASTWalker { if (auto *assignment = dyn_cast(expr)) markAcceptableDiscardExprs(assignment->getDest()); - if (auto *SVE = dyn_cast(expr)) + if (auto *SVE = dyn_cast(expr)) { checkSingleValueStmtExpr(SVE); + transformForExpression(SVE); + } return finish(true, expr); } @@ -1700,6 +1708,59 @@ void PreCheckTarget::checkSingleValueStmtExpr(SingleValueStmtExpr *SVE) { } } +void PreCheckTarget::transformForExpression(SingleValueStmtExpr *SVE) { + if (SVE->getStmtKind() != SingleValueStmtExpr::Kind::For) { + return; + } + + auto *declCtx = SVE->getDeclContext(); + auto &astCtx = declCtx->getASTContext(); + + auto sveLoc = SVE->getLoc(); + + auto *varDecl = new (astCtx) + VarDecl(false, VarDecl::Introducer::Var, sveLoc, + astCtx.getIdentifier("$forExpressionResult"), declCtx); + + auto namedPattern = NamedPattern::createImplicit(astCtx, varDecl); + + auto *initFunc = new (astCtx) UnresolvedMemberExpr( + sveLoc, DeclNameLoc(), DeclNameRef(DeclBaseName::createConstructor()), + true); + auto *callExpr = CallExpr::createImplicitEmpty(astCtx, initFunc); + auto *initExpr = + new (astCtx) UnresolvedMemberChainResultExpr(callExpr, initFunc); + + auto *bindingDecl = PatternBindingDecl::createImplicit( + astCtx, StaticSpellingKind::None, namedPattern, initExpr, declCtx); + + SVE->setForExpressionPreamble({varDecl, bindingDecl}); + + // For-expressions always have a single branch. + SmallVector scratch; + for (auto *branch : SVE->getBranches(scratch)) { + auto *BS = dyn_cast(branch); + if (!BS) + continue; + + auto &result = BS->getElements().back(); + if (auto stmt = result.dyn_cast()) { + if (auto *then = dyn_cast(stmt)) { + auto *declRefExpr = + new (astCtx) DeclRefExpr(varDecl, DeclNameLoc(), true); + auto *dotExpr = new (astCtx) UnresolvedDotExpr( + declRefExpr, SourceLoc(), + DeclNameRef(DeclBaseName(astCtx.Id_append)), DeclNameLoc(), true); + auto *argumentList = ArgumentList::createImplicit( + astCtx, {Argument::unlabeled(then->getResult())}); + auto *callExpr = + CallExpr::createImplicit(astCtx, dotExpr, argumentList); + then->setResult(callExpr); + } + } + } +} + void PreCheckTarget::markAnyValidSingleValueStmts(Expr *E) { auto findAssignment = [&]() -> AssignExpr * { // Don't consider assignments if we have a parent expression (as otherwise diff --git a/lib/Sema/TypeCheckEffects.cpp b/lib/Sema/TypeCheckEffects.cpp index c23bd67e24db1..1cf491b90686f 100644 --- a/lib/Sema/TypeCheckEffects.cpp +++ b/lib/Sema/TypeCheckEffects.cpp @@ -4036,6 +4036,12 @@ class CheckEffectsCoverage : public EffectsHandlingWalker ContextScope scope(*this, /*newContext*/ std::nullopt); scope.setCoverageForSingleValueStmtExpr(); SVE->getStmt()->walk(*this); + + if (auto preamble = SVE->getForExpressionPreamble()) { + preamble->ForAccumulatorDecl->walk(*this); + preamble->ForAccumulatorBinding->walk(*this); + } + scope.preserveCoverageFromSingleValueStmtExpr(); return ShouldNotRecurse; } diff --git a/lib/Sema/TypeCheckStmt.cpp b/lib/Sema/TypeCheckStmt.cpp index b1fe8cf1e426c..430d699fe3e94 100644 --- a/lib/Sema/TypeCheckStmt.cpp +++ b/lib/Sema/TypeCheckStmt.cpp @@ -3224,6 +3224,12 @@ IsSingleValueStmtRequest::evaluate(Evaluator &eval, const Stmt *S, return areBranchesValidForSingleValueStmt(ctx, DCS, DCS->getBranches(scratch)); } + if (auto *FS = dyn_cast(S)) { + if (!ctx.LangOpts.hasFeature(Feature::ForExpressions)) + return IsSingleValueStmtResult::unhandledStmt(); + + return areBranchesValidForSingleValueStmt(ctx, FS, FS->getBody()); + } return IsSingleValueStmtResult::unhandledStmt(); } diff --git a/test/stmt/for-expr.swift b/test/stmt/for-expr.swift new file mode 100644 index 0000000000000..b87b46e809b19 --- /dev/null +++ b/test/stmt/for-expr.swift @@ -0,0 +1,15 @@ +// RUN: %target-run-simple-swift(-enable-experimental-feature ForExpressions) | %FileCheck %s + +// REQUIRES: swift_feature_ForExpressions + +func f() -> String { + for (i, x) in "hello".enumerated() { + if i % 2 == 0 { + x.uppercased() + } else { + "*skip*" + + } + } +} +print(f()) // CHECK: H*skip*L*skip*O