Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions include/swift/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -6498,16 +6498,20 @@ class KeyPathDotExpr : public Expr {
}
};

struct ForCollectionInit {
VarDecl *ForAccumulatorDecl;
PatternBindingDecl *ForAccumulatorBinding;
};

/// An expression that may wrap a statement which produces a single value.
class SingleValueStmtExpr : public Expr {
public:
enum class Kind {
If, Switch, Do, DoCatch
};
enum class Kind { If, Switch, Do, DoCatch, For };

private:
Stmt *S;
DeclContext *DC;
std::optional<ForCollectionInit> ForExpressionPreamble;

SingleValueStmtExpr(Stmt *S, DeclContext *DC)
: Expr(ExprKind::SingleValueStmt, /*isImplicit*/ true), S(S), DC(DC) {}
Expand Down Expand Up @@ -6572,6 +6576,14 @@ class SingleValueStmtExpr : public Expr {

SourceRange getSourceRange() const;

std::optional<ForCollectionInit> getForExpressionPreamble() const {
return this->ForExpressionPreamble;
}

void setForExpressionPreamble(ForCollectionInit newPreamble) {
this->ForExpressionPreamble = newPreamble;
}

static bool classof(const Expr *E) {
return E->getKind() == ExprKind::SingleValueStmt;
}
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ IDENTIFIER(alloc)
IDENTIFIER(allocWithZone)
IDENTIFIER(allZeros)
IDENTIFIER(accumulated)
IDENTIFIER(append)
IDENTIFIER(ActorType)
IDENTIFIER(Any)
IDENTIFIER(ArrayLiteralElement)
Expand Down
3 changes: 3 additions & 0 deletions include/swift/Basic/Features.def
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,9 @@ EXPERIMENTAL_FEATURE(ThenStatements, false)
/// Enable 'do' expressions.
EXPERIMENTAL_FEATURE(DoExpressions, false)

/// Enable 'for' expressions.
EXPERIMENTAL_FEATURE(ForExpressions, false)

/// Enable implicitly treating the last expression in a function, closure,
/// and 'if'/'switch' expression as the result.
EXPERIMENTAL_FEATURE(ImplicitLastExprResults, false)
Expand Down
7 changes: 4 additions & 3 deletions include/swift/Sema/SyntacticElementTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
//
//===----------------------------------------------------------------------===//
//
// This file defines the SyntacticElementTarget class.
// This file defines the SyntacticElementTarget class (a unit of
// type-checking).
//
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -59,8 +60,8 @@ struct PackIterationInfo {
/// within the constraint system.
using ForEachStmtInfo = TaggedUnion<SequenceIterationInfo, PackIterationInfo>;

/// Describes the target to which a constraint system's solution can be
/// applied.
/// Describes the target (a unit of type-checking) to which a constraint
/// system's solution can be applied.
class SyntacticElementTarget {
public:
enum class Kind {
Expand Down
6 changes: 6 additions & 0 deletions lib/AST/ASTDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4496,6 +4496,12 @@ class PrintExpr : public ExprVisitor<PrintExpr, void, Label>,
void visitSingleValueStmtExpr(SingleValueStmtExpr *E, Label label) {
printCommon(E, "single_value_stmt_expr", label);
printDeclContext(E);
if (auto preamble = E->getForExpressionPreamble()) {
printRec(preamble->ForAccumulatorDecl,
Label::optional("for_preamble_accumulator_decl"));
printRec(preamble->ForAccumulatorBinding,
Label::optional("for_preamble_accumulator_binding"));
}
printRec(E->getStmt(), &E->getDeclContext()->getASTContext(),
Label::optional("stmt"));
printFoot();
Expand Down
10 changes: 10 additions & 0 deletions lib/AST/ASTWalker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,16 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,
Expr *visitKeyPathDotExpr(KeyPathDotExpr *E) { return E; }

Expr *visitSingleValueStmtExpr(SingleValueStmtExpr *E) {
if (auto preamble = E->getForExpressionPreamble()) {
if (doIt(preamble->ForAccumulatorDecl)) {
return nullptr;
}

if (doIt(preamble->ForAccumulatorBinding)) {
return nullptr;
}
}

if (auto *S = doIt(E->getStmt())) {
E->setStmt(S);
} else {
Expand Down
5 changes: 5 additions & 0 deletions lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2763,6 +2763,8 @@ SingleValueStmtExpr::Kind SingleValueStmtExpr::getStmtKind() const {
return Kind::Do;
case StmtKind::DoCatch:
return Kind::DoCatch;
case StmtKind::ForEach:
return Kind::For;
default:
llvm_unreachable("Unhandled kind!");
}
Expand All @@ -2781,6 +2783,9 @@ SingleValueStmtExpr::getBranches(SmallVectorImpl<Stmt *> &scratch) const {
return scratch;
case Kind::DoCatch:
return cast<DoCatchStmt>(getStmt())->getBranches(scratch);
case Kind::For:
scratch.push_back(cast<ForEachStmt>(getStmt())->getBody());
return scratch;
}
llvm_unreachable("Unhandled case in switch!");
}
Expand Down
1 change: 1 addition & 0 deletions lib/AST/FeatureSet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ UNINTERESTING_FEATURE(RegionBasedIsolation)
UNINTERESTING_FEATURE(PlaygroundExtendedCallbacks)
UNINTERESTING_FEATURE(ThenStatements)
UNINTERESTING_FEATURE(DoExpressions)
UNINTERESTING_FEATURE(ForExpressions)
UNINTERESTING_FEATURE(ImplicitLastExprResults)
UNINTERESTING_FEATURE(RawLayout)
UNINTERESTING_FEATURE(Embedded)
Expand Down
11 changes: 11 additions & 0 deletions lib/SILGen/SILGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2515,6 +2515,17 @@ RValue RValueEmitter::visitEnumIsCaseExpr(EnumIsCaseExpr *E,

RValue RValueEmitter::visitSingleValueStmtExpr(SingleValueStmtExpr *E,
SGFContext C) {
if (E->getStmtKind() == SingleValueStmtExpr::Kind::For) {
auto *decl = E->getForExpressionPreamble()->ForAccumulatorDecl;
auto *binding = E->getForExpressionPreamble()->ForAccumulatorBinding;
SGF.visit(decl);
SGF.visit(binding);
SGF.emitStmt(E->getStmt());

return SGF.emitRValueForDecl(E, ConcreteDeclRef(decl), E->getType(),
AccessSemantics::Ordinary);
}

auto emitStmt = [&]() {
SGF.emitStmt(E->getStmt());

Expand Down
86 changes: 68 additions & 18 deletions lib/Sema/CSSyntacticElement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1292,8 +1292,12 @@ class SyntacticElementConstraintGenerator

// First check to make sure the ThenStmt is in a valid position.
SmallVector<ThenStmt *, 4> validThenStmts;
if (auto SVE = context.getAsSingleValueStmtExpr())
if (auto SVE = context.getAsSingleValueStmtExpr()) {
(void)SVE.get()->getThenStmts(validThenStmts);
if (SVE.get()->getStmtKind() == SingleValueStmtExpr::Kind::For) {
contextInfo = std::nullopt;
}
}

if (!llvm::is_contained(validThenStmts, thenStmt)) {
auto *thenLoc = cs.getConstraintLocator(thenStmt);
Expand Down Expand Up @@ -1488,8 +1492,37 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) {
auto &ctx = getASTContext();

auto *loc = getConstraintLocator(E);
Type resultTy = createTypeVariable(loc, /*options*/ 0);
setType(E, resultTy);
Type resultType = createTypeVariable(loc, /*options*/ 0);
setType(E, resultType);

if (E->getStmtKind() == SingleValueStmtExpr::Kind::For) {
auto *rrcProtocol =
ctx.getProtocol(KnownProtocolKind::RangeReplaceableCollection);
auto *sequenceProtocol = ctx.getProtocol(KnownProtocolKind::Sequence);

addConstraint(ConstraintKind::ConformsTo, resultType,
rrcProtocol->getDeclaredInterfaceType(), loc);
Type elementTypeVar = createTypeVariable(loc, /*options*/ 0);
Type elementType = DependentMemberType::get(
resultType, sequenceProtocol->getAssociatedType(ctx.Id_Element));

addConstraint(ConstraintKind::Bind, elementTypeVar, elementType, loc);
addConstraint(ConstraintKind::Defaultable, resultType,
ArraySliceType::get(elementTypeVar), loc);

auto *binding = E->getForExpressionPreamble()->ForAccumulatorBinding;

auto *initializer = binding->getInit(0);
auto target = SyntacticElementTarget::forInitialization(initializer, Type(),
binding, 0, false);
setTargetFor({binding, 0}, target);

if (generateConstraints(target)) {
return true;
}

addConstraint(ConstraintKind::Bind, getType(initializer), resultType, loc);
}

// Propagate the implied result kind from the if/switch expression itself
// into the branches.
Expand All @@ -1513,21 +1546,24 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) {
auto *loc = getConstraintLocator(
E, {LocatorPathElt::SingleValueStmtResult(idx), ctpElt});

ContextualTypeInfo info(resultTy, CTP_SingleValueStmtBranch, loc);
ContextualTypeInfo info(resultType, CTP_SingleValueStmtBranch, loc);
setContextualInfo(result, info);
}

TypeJoinExpr *join = nullptr;
if (branches.empty()) {
// If we only have statement branches, the expression is typed as Void. This
// should only be the case for 'if' and 'switch' statements that must be
// expressions that have branches that all end in a throw, and we'll warn
// that we've inferred Void.
addConstraint(ConstraintKind::Bind, resultTy, ctx.getVoidType(), loc);
} else {
// Otherwise, we join the result types for each of the branches.
join = TypeJoinExpr::forBranchesOfSingleValueStmtExpr(
ctx, resultTy, E, AllocationArena::ConstraintSolver);

if (E->getStmtKind() != SingleValueStmtExpr::Kind::For) {
if (branches.empty()) {
// If we only have statement branches, the expression is typed as Void.
// This should only be the case for 'if' and 'switch' statements that must
// be expressions that have branches that all end in a throw, and we'll
// warn that we've inferred Void.
addConstraint(ConstraintKind::Bind, resultType, ctx.getVoidType(), loc);
} else {
// Otherwise, we join the result types for each of the branches.
join = TypeJoinExpr::forBranchesOfSingleValueStmtExpr(
ctx, resultType, E, AllocationArena::ConstraintSolver);
}
}

// If this is an implied return in a closure, we need to account for the fact
Expand Down Expand Up @@ -1568,11 +1604,11 @@ bool ConstraintSystem::generateConstraints(SingleValueStmtExpr *E) {
if (auto *closureTy = getClosureTypeIfAvailable(CE)) {
auto closureResultTy = closureTy->getResult();
auto *bindToClosure = Constraint::create(
*this, ConstraintKind::Bind, resultTy, closureResultTy, loc);
*this, ConstraintKind::Bind, resultType, closureResultTy, loc);
bindToClosure->setFavored();

auto *bindToVoid = Constraint::create(*this, ConstraintKind::Bind,
resultTy, ctx.getVoidType(), loc);
auto *bindToVoid = Constraint::create(
*this, ConstraintKind::Bind, resultType, ctx.getVoidType(), loc);

addDisjunctionConstraint({bindToClosure, bindToVoid}, loc);
}
Expand Down Expand Up @@ -2221,7 +2257,9 @@ class SyntacticElementSolutionApplication
// not the branch result type. This is necessary as there may be
// an additional conversion required for the branch.
auto target = solution.getTargetFor(thenStmt->getResult());
target->setExprConversionType(ty);
if (SVE.get()->getStmtKind() != SingleValueStmtExpr::Kind::For) {
target->setExprConversionType(ty);
}

auto *resultExpr = thenStmt->getResult();
if (auto newResultTarget = rewriter.rewriteTarget(*target))
Expand Down Expand Up @@ -2663,6 +2701,18 @@ bool ConstraintSystem::applySolutionToSingleValueStmt(
if (!stmt || application.hadError)
return true;

if (SVE->getStmtKind() == SingleValueStmtExpr::Kind::For) {
auto *binding = SVE->getForExpressionPreamble()->ForAccumulatorBinding;
auto target = getTargetFor({binding, 0}).value();

auto newTarget = rewriter.rewriteTarget(target);
if (!newTarget) {
return true;
}

binding->setInit(0, newTarget->getAsExpr());
}

SVE->setStmt(stmt);
return false;
}
Expand Down
67 changes: 64 additions & 3 deletions lib/Sema/PreCheckTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2014 - 2018 Apple Inc. and the Swift project authors
// Copyright (c) 2014 - 2025 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
Expand All @@ -11,7 +11,8 @@
//===----------------------------------------------------------------------===//
//
// Pre-checking resolves unqualified name references, type expressions and
// operators.
// operators. Target in this context refers to `SyntacticElementTarget`, which
// is a unit of type-checking.
//
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -1190,6 +1191,11 @@ class PreCheckTarget final : public ASTWalker {
/// For the given statement, mark any valid SingleValueStmtExpr children.
void markAnyValidSingleValueStmts(Stmt *S);

/// For the given single value expr that's a `for`-expression, run the
/// desugaring transformation. For all other kinds of single value statement
/// expressions do nothing.
void transformForExpression(SingleValueStmtExpr *E);

PreCheckTarget(DeclContext *dc) : Ctx(dc->getASTContext()), DC(dc) {}

public:
Expand Down Expand Up @@ -1386,8 +1392,10 @@ class PreCheckTarget final : public ASTWalker {
if (auto *assignment = dyn_cast<AssignExpr>(expr))
markAcceptableDiscardExprs(assignment->getDest());

if (auto *SVE = dyn_cast<SingleValueStmtExpr>(expr))
if (auto *SVE = dyn_cast<SingleValueStmtExpr>(expr)) {
checkSingleValueStmtExpr(SVE);
transformForExpression(SVE);
}

return finish(true, expr);
}
Expand Down Expand Up @@ -1700,6 +1708,59 @@ void PreCheckTarget::checkSingleValueStmtExpr(SingleValueStmtExpr *SVE) {
}
}

void PreCheckTarget::transformForExpression(SingleValueStmtExpr *SVE) {
if (SVE->getStmtKind() != SingleValueStmtExpr::Kind::For) {
return;
}

auto *declCtx = SVE->getDeclContext();
auto &astCtx = declCtx->getASTContext();

auto sveLoc = SVE->getLoc();

auto *varDecl = new (astCtx)
VarDecl(false, VarDecl::Introducer::Var, sveLoc,
astCtx.getIdentifier("$forExpressionResult"), declCtx);

auto namedPattern = NamedPattern::createImplicit(astCtx, varDecl);

auto *initFunc = new (astCtx) UnresolvedMemberExpr(
sveLoc, DeclNameLoc(), DeclNameRef(DeclBaseName::createConstructor()),
true);
auto *callExpr = CallExpr::createImplicitEmpty(astCtx, initFunc);
auto *initExpr =
new (astCtx) UnresolvedMemberChainResultExpr(callExpr, initFunc);

auto *bindingDecl = PatternBindingDecl::createImplicit(
astCtx, StaticSpellingKind::None, namedPattern, initExpr, declCtx);

SVE->setForExpressionPreamble({varDecl, bindingDecl});

// For-expressions always have a single branch.
SmallVector<Stmt *, 1> scratch;
for (auto *branch : SVE->getBranches(scratch)) {
auto *BS = dyn_cast<BraceStmt>(branch);
if (!BS)
continue;

auto &result = BS->getElements().back();
if (auto stmt = result.dyn_cast<Stmt *>()) {
if (auto *then = dyn_cast<ThenStmt>(stmt)) {
auto *declRefExpr =
new (astCtx) DeclRefExpr(varDecl, DeclNameLoc(), true);
auto *dotExpr = new (astCtx) UnresolvedDotExpr(
declRefExpr, SourceLoc(),
DeclNameRef(DeclBaseName(astCtx.Id_append)), DeclNameLoc(), true);
auto *argumentList = ArgumentList::createImplicit(
astCtx, {Argument::unlabeled(then->getResult())});
auto *callExpr =
CallExpr::createImplicit(astCtx, dotExpr, argumentList);
then->setResult(callExpr);
}
}
}
}

void PreCheckTarget::markAnyValidSingleValueStmts(Expr *E) {
auto findAssignment = [&]() -> AssignExpr * {
// Don't consider assignments if we have a parent expression (as otherwise
Expand Down
Loading