Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 19 additions & 42 deletions include/swift/AST/Stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,9 @@ class alignas(8) Stmt : public ASTAllocated<Stmt> {
NumElements : 32
);

SWIFT_INLINE_BITFIELD_FULL(CaseStmt, Stmt, 32,
SWIFT_INLINE_BITFIELD_FULL(CaseStmt, Stmt, 16+32,
: NumPadBits,
NumCaseBodyVars : 16,
NumPatterns : 32
);

Expand Down Expand Up @@ -1210,8 +1211,8 @@ enum CaseParentKind { Switch, DoCatch };
///
class CaseStmt final
: public Stmt,
private llvm::TrailingObjects<CaseStmt, FallthroughStmt *,
CaseLabelItem> {
private llvm::TrailingObjects<CaseStmt, FallthroughStmt *, CaseLabelItem,
VarDecl *> {
friend TrailingObjects;

Stmt *ParentStmt = nullptr;
Expand All @@ -1222,15 +1223,17 @@ class CaseStmt final

llvm::PointerIntPair<BraceStmt *, 1, bool> BodyAndHasFallthrough;

std::optional<MutableArrayRef<VarDecl *>> CaseBodyVariables;

CaseStmt(CaseParentKind ParentKind, SourceLoc ItemIntroducerLoc,
ArrayRef<CaseLabelItem> CaseLabelItems, SourceLoc UnknownAttrLoc,
SourceLoc ItemTerminatorLoc, BraceStmt *Body,
std::optional<MutableArrayRef<VarDecl *>> CaseBodyVariables,
std::optional<bool> Implicit,
ArrayRef<VarDecl *> CaseBodyVariables, std::optional<bool> Implicit,
NullablePtr<FallthroughStmt> fallthroughStmt);

MutableArrayRef<VarDecl *> getCaseBodyVariablesBuffer() {
return {getTrailingObjects<VarDecl *>(),
static_cast<size_t>(Bits.CaseStmt.NumCaseBodyVars)};
}

public:
/// Create a parsed 'case'/'default' for 'switch' statement.
static CaseStmt *
Expand All @@ -1244,13 +1247,17 @@ class CaseStmt final
ArrayRef<CaseLabelItem> CaseLabelItems,
BraceStmt *Body);

static CaseStmt *
createImplicit(ASTContext &ctx, CaseParentKind parentKind,
ArrayRef<CaseLabelItem> caseLabelItems, BraceStmt *body,
NullablePtr<FallthroughStmt> fallthroughStmt = nullptr);

static CaseStmt *
create(ASTContext &C, CaseParentKind ParentKind, SourceLoc ItemIntroducerLoc,
ArrayRef<CaseLabelItem> CaseLabelItems, SourceLoc UnknownAttrLoc,
SourceLoc ItemTerminatorLoc, BraceStmt *Body,
std::optional<MutableArrayRef<VarDecl *>> CaseBodyVariables,
std::optional<bool> Implicit = std::nullopt,
NullablePtr<FallthroughStmt> fallthroughStmt = nullptr);
ArrayRef<VarDecl *> CaseBodyVariables, std::optional<bool> Implicit,
NullablePtr<FallthroughStmt> fallthroughStmt);

CaseParentKind getParentKind() const { return ParentKind; }

Expand Down Expand Up @@ -1293,7 +1300,7 @@ class CaseStmt final
void setBody(BraceStmt *body) { BodyAndHasFallthrough.setPointer(body); }

/// True if the case block declares any patterns with local variable bindings.
bool hasBoundDecls() const { return CaseBodyVariables.has_value(); }
bool hasCaseBodyVariables() const { return !getCaseBodyVariables().empty(); }

/// Get the source location of the 'case', 'default', or 'catch' of the first
/// label.
Expand Down Expand Up @@ -1345,38 +1352,8 @@ class CaseStmt final
}

/// Return an ArrayRef containing the case body variables of this CaseStmt.
///
/// Asserts if case body variables was not explicitly initialized. In contexts
/// where one wants a non-asserting version, \see
/// getCaseBodyVariablesOrEmptyArray.
ArrayRef<VarDecl *> getCaseBodyVariables() const {
ArrayRef<VarDecl *> a = *CaseBodyVariables;
return a;
}

bool hasCaseBodyVariables() const { return CaseBodyVariables.has_value(); }

/// Return an MutableArrayRef containing the case body variables of this
/// CaseStmt.
///
/// Asserts if case body variables was not explicitly initialized. In contexts
/// where one wants a non-asserting version, \see
/// getCaseBodyVariablesOrEmptyArray.
MutableArrayRef<VarDecl *> getCaseBodyVariables() {
return *CaseBodyVariables;
}

ArrayRef<VarDecl *> getCaseBodyVariablesOrEmptyArray() const {
if (!CaseBodyVariables)
return ArrayRef<VarDecl *>();
ArrayRef<VarDecl *> a = *CaseBodyVariables;
return a;
}

MutableArrayRef<VarDecl *> getCaseBodyVariablesOrEmptyArray() {
if (!CaseBodyVariables)
return MutableArrayRef<VarDecl *>();
return *CaseBodyVariables;
return const_cast<CaseStmt *>(this)->getCaseBodyVariablesBuffer();
}

/// Find the next case statement within the same 'switch' or 'do-catch',
Expand Down
6 changes: 3 additions & 3 deletions lib/AST/ASTScopeLookup.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,10 @@ bool CaseLabelItemScope::lookupLocalsOrMembers(DeclConsumer consumer) const {
}

bool CaseStmtBodyScope::lookupLocalsOrMembers(DeclConsumer consumer) const {
for (auto *var : stmt->getCaseBodyVariablesOrEmptyArray())
for (auto *var : stmt->getCaseBodyVariables()) {
if (consumer.consume({var}))
return true;

return true;
}
return false;
}

Expand Down
2 changes: 1 addition & 1 deletion lib/AST/ASTVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2795,7 +2795,7 @@ class Verifier : public ASTWalker {
// guarantee that all case label items bind corresponding patterns and
// the case body var decls of a case stmt are created from the var decls
// of the first case label items.
if (!caseStmt->hasBoundDecls()) {
if (!caseStmt->hasCaseBodyVariables()) {
Out << "parent CaseStmt of VarDecl does not have any case body "
"decls?!\n";
abort();
Expand Down
94 changes: 25 additions & 69 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8159,50 +8159,6 @@ SourceRange AbstractStorageDecl::getTypeSourceRangeForDiagnostics() const {
return SourceRange();
}

static std::optional<std::pair<CaseStmt *, Pattern *>>
findParentPatternCaseStmtAndPattern(const VarDecl *inputVD) {
auto getMatchingPattern = [&](CaseStmt *cs) -> Pattern * {
// Check if inputVD is in our case body var decls if we have any. If we do,
// treat its pattern as our first case label item pattern.
for (auto *vd : cs->getCaseBodyVariablesOrEmptyArray()) {
if (vd == inputVD) {
return cs->getMutableCaseLabelItems().front().getPattern();
}
}

// Then check the rest of our case label items.
for (auto &item : cs->getMutableCaseLabelItems()) {
if (item.getPattern()->containsVarDecl(inputVD)) {
return item.getPattern();
}
}

// Otherwise return false if we do not find anything.
return nullptr;
};

// First find our canonical var decl. This is the VarDecl corresponding to the
// first case label item of the first case block in the fallthrough chain that
// our case block is within. Grab the case stmt associated with that var decl
// and start traveling down the fallthrough chain looking for the case
// statement that the input VD belongs to by using getMatchingPattern().
auto *canonicalVD = inputVD->getCanonicalVarDecl();
auto *caseStmt =
dyn_cast_or_null<CaseStmt>(canonicalVD->getParentPatternStmt());
if (!caseStmt)
return std::nullopt;

if (auto *p = getMatchingPattern(caseStmt))
return std::make_pair(caseStmt, p);

while ((caseStmt = caseStmt->getFallthroughDest().getPtrOrNull())) {
if (auto *p = getMatchingPattern(caseStmt))
return std::make_pair(caseStmt, p);
}

return std::nullopt;
}

VarDecl *VarDecl::getCanonicalVarDecl() const {
// Any var decl without a parent var decl is canonical. This means that before
// type checking, all var decls are canonical.
Expand All @@ -8227,16 +8183,7 @@ VarDecl *VarDecl::getCanonicalVarDecl() const {
}

Stmt *VarDecl::getRecursiveParentPatternStmt() const {
// If our parent is already a pattern stmt, just return that.
if (auto *stmt = getParentPatternStmt())
return stmt;

// Otherwise, see if we have a parent var decl. If we do not, then return
// nullptr. Otherwise, return the case stmt that we found.
auto result = findParentPatternCaseStmtAndPattern(this);
if (!result.has_value())
return nullptr;
return result->first;
return getCanonicalVarDecl()->getParentPatternStmt();
}

/// Return the Pattern involved in initializing this VarDecl. Recall that the
Expand All @@ -8256,17 +8203,34 @@ Pattern *VarDecl::getParentPattern() const {
}

// If this is a statement parent, dig the pattern out of it.
if (auto *stmt = getParentPatternStmt()) {
const auto *canonicalVD = getCanonicalVarDecl();
if (auto *stmt = canonicalVD->getParentPatternStmt()) {
if (auto *FES = dyn_cast<ForEachStmt>(stmt))
return FES->getPattern();

if (auto *cs = dyn_cast<CaseStmt>(stmt)) {
// In a case statement, search for the pattern that contains it. This is
// a bit silly, because you can't have something like "case x, y:" anyway.
for (auto items : cs->getCaseLabelItems()) {
if (items.getPattern()->containsVarDecl(this))
return items.getPattern();
// In a case statement, search for the pattern that contains it.
auto findPattern = [](CaseStmt *cs, const VarDecl *VD) -> Pattern * {
for (auto items : cs->getCaseLabelItems()) {
if (items.getPattern()->containsVarDecl(VD))
return items.getPattern();
}
return nullptr;
};
if (auto *P = findPattern(cs, this))
return P;

// If it's not in the CaseStmt, check its fallthrough destination.
if (auto fallthrough = cs->getFallthroughDest()) {
if (auto *P = findPattern(fallthrough.get(), this))
return P;
}

// Finally, check the canonical variable, this is necessary to correctly
// handle case body vars, we just want to take the first pattern that
// declares it in that case.
if (auto *P = findPattern(cs, canonicalVD))
return P;
}

if (auto *LCS = dyn_cast<LabeledConditionalStmt>(stmt)) {
Expand All @@ -8277,14 +8241,6 @@ Pattern *VarDecl::getParentPattern() const {
}
}

// Otherwise, check if we have to walk our case stmt's var decl list to find
// the pattern.
if (auto caseStmtPatternPair = findParentPatternCaseStmtAndPattern(this)) {
return caseStmtPatternPair->second;
}

// Otherwise, this is a case we do not know or understand. Return nullptr to
// signal we do not have any information.
return nullptr;
}

Expand Down Expand Up @@ -8345,7 +8301,7 @@ bool VarDecl::isCaseBodyVariable() const {
auto *caseStmt = dyn_cast_or_null<CaseStmt>(getRecursiveParentPatternStmt());
if (!caseStmt)
return false;
return llvm::any_of(caseStmt->getCaseBodyVariablesOrEmptyArray(),
return llvm::any_of(caseStmt->getCaseBodyVariables(),
[&](VarDecl *vd) { return vd == this; });
}

Expand Down
10 changes: 7 additions & 3 deletions lib/AST/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,14 @@ namespace {
return Action::Continue(P);
}

// Only walk into an expression insofar as it doesn't open a new scope -
// that is, don't walk into a closure body.
PreWalkResult<Expr *> walkToExprPre(Expr *E) override {
if (isa<ClosureExpr>(E)) {
// Only walk into an expression insofar as it doesn't open a new scope -
// that is, don't walk into a closure body, TapExpr, or
// SingleValueStmtExpr. Also don't walk into key paths since any nested
// VarDecls are invalid there, and after being diagnosed by key path
// resolution the ASTWalker won't visit them.
if (isa<ClosureExpr>(E) || isa<TapExpr>(E) ||
isa<SingleValueStmtExpr>(E) || isa<KeyPathExpr>(E)) {
return Action::SkipNode(E);
}
return Action::Continue(E);
Expand Down
Loading