Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions include/swift/IDE/SourceEntityWalker.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define SWIFT_IDE_SOURCE_ENTITY_WALKER_H

#include "swift/AST/ASTWalker.h"
#include "swift/Basic/Defer.h"
#include "swift/Basic/LLVM.h"
#include "swift/Basic/SourceLoc.h"
#include "llvm/ADT/PointerUnion.h"
Expand Down Expand Up @@ -176,6 +177,24 @@ class SourceEntityWalker {
virtual ~SourceEntityWalker() {}

virtual void anchor();

/// Retrieve the current ASTWalker being used to traverse the AST.
const ASTWalker &getWalker() const {
assert(Walker && "Not walking!");
return *Walker;
}

private:
ASTWalker *Walker = nullptr;

/// Utility that lets us keep track of an ASTWalker when walking.
bool performWalk(ASTWalker &W, llvm::function_ref<bool(void)> DoWalk) {
Walker = &W;
SWIFT_DEFER {
Walker = nullptr;
};
return DoWalk();
}
};

} // namespace swift
Expand Down
38 changes: 28 additions & 10 deletions lib/IDE/Refactoring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4623,16 +4623,24 @@ class NodesToPrint {
!Nodes.back().isImplicit();
}

/// If the last recorded node is an explicit return or break statement, drop
/// it from the list.
void dropTrailingReturnOrBreak() {
/// If the last recorded node is an explicit return or break statement that
/// can be safely dropped, drop it from the list.
void dropTrailingReturnOrBreakIfPossible() {
if (!hasTrailingReturnOrBreak())
return;

auto *Node = Nodes.back().get<Stmt *>();

// If this is a return statement with return expression, let's preserve it.
if (auto *RS = dyn_cast<ReturnStmt>(Node)) {
if (RS->hasResult())
return;
}

// Remove the node from the list, but make sure to add it as a possible
// comment loc to preserve any of its attached comments.
auto Node = Nodes.pop_back_val();
addPossibleCommentLoc(Node.getStartLoc());
Nodes.pop_back();
addPossibleCommentLoc(Node->getStartLoc());
}

/// Returns a list of nodes to print in a brace statement. This picks up the
Expand Down Expand Up @@ -4895,7 +4903,7 @@ struct CallbackClassifier {
NodesToPrint Nodes) {
if (Nodes.hasTrailingReturnOrBreak()) {
CurrentBlock = OtherBlock;
Nodes.dropTrailingReturnOrBreak();
Nodes.dropTrailingReturnOrBreakIfPossible();
Block->addAllNodes(std::move(Nodes));
} else {
Block->addAllNodes(std::move(Nodes));
Expand Down Expand Up @@ -5600,12 +5608,14 @@ class AsyncConverter : private SourceEntityWalker {
// it would have been lifted out of the switch statement.
if (auto *SS = dyn_cast<SwitchStmt>(BS->getTarget())) {
if (HandledSwitches.contains(SS))
replaceRangeWithPlaceholder(S->getSourceRange());
return replaceRangeWithPlaceholder(S->getSourceRange());
}
} else if (isa<ReturnStmt>(S) && NestedExprCount == 0) {
// For a return, if it's not nested inside another closure or function,
// turn it into a placeholder, as it will be lifted out of the callback.
replaceRangeWithPlaceholder(S->getSourceRange());
// Note that we only turn the 'return' token into a placeholder as we
// still want to be able to apply transforms to the argument.
replaceRangeWithPlaceholder(S->getStartLoc());
}
}
return true;
Expand Down Expand Up @@ -5734,15 +5744,23 @@ class AsyncConverter : private SourceEntityWalker {
void addHandlerCall(const CallExpr *CE) {
auto Exprs = TopHandler.extractResultArgs(CE);

bool AddedReturnOrThrow = true;
if (!Exprs.isError()) {
OS << tok::kw_return;
// It's possible the user has already written an explicit return statement
// for the completion handler call, e.g 'return completion(args...)'. In
// that case, be sure not to add another return.
auto *parent = getWalker().Parent.getAsStmt();
AddedReturnOrThrow = !(parent && isa<ReturnStmt>(parent));
if (AddedReturnOrThrow)
OS << tok::kw_return;
} else {
OS << tok::kw_throw;
}

ArrayRef<Expr *> Args = Exprs.args();
if (!Args.empty()) {
OS << " ";
if (AddedReturnOrThrow)
OS << " ";
if (Args.size() > 1)
OS << tok::l_paren;
for (size_t I = 0, E = Args.size(); I < E; ++I) {
Expand Down
12 changes: 6 additions & 6 deletions lib/IDE/SourceEntityWalker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -806,32 +806,32 @@ bool SemaAnnotator::shouldIgnore(Decl *D) {

bool SourceEntityWalker::walk(SourceFile &SrcFile) {
SemaAnnotator Annotator(*this);
return SrcFile.walk(Annotator);
return performWalk(Annotator, [&]() { return SrcFile.walk(Annotator); });
}

bool SourceEntityWalker::walk(ModuleDecl &Mod) {
SemaAnnotator Annotator(*this);
return Mod.walk(Annotator);
return performWalk(Annotator, [&]() { return Mod.walk(Annotator); });
}

bool SourceEntityWalker::walk(Stmt *S) {
SemaAnnotator Annotator(*this);
return S->walk(Annotator);
return performWalk(Annotator, [&]() { return S->walk(Annotator); });
}

bool SourceEntityWalker::walk(Expr *E) {
SemaAnnotator Annotator(*this);
return E->walk(Annotator);
return performWalk(Annotator, [&]() { return E->walk(Annotator); });
}

bool SourceEntityWalker::walk(Decl *D) {
SemaAnnotator Annotator(*this);
return D->walk(Annotator);
return performWalk(Annotator, [&]() { return D->walk(Annotator); });
}

bool SourceEntityWalker::walk(DeclContext *DC) {
SemaAnnotator Annotator(*this);
return DC->walkContext(Annotator);
return performWalk(Annotator, [&]() { return DC->walkContext(Annotator); });
}

bool SourceEntityWalker::walk(ASTNode N) {
Expand Down
42 changes: 42 additions & 0 deletions test/refactoring/ConvertAsync/convert_function.swift
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,45 @@ func voidResultCompletion(completion: (Result<Void, Error>) -> Void) {
// RUN: %refactor -add-async-alternative -dump-text -source-filename %s -pos=%(line+1):1 | %FileCheck -check-prefix=NON-COMPLETION-HANDLER %s
func functionWithSomeHandler(handler: (String) -> Void) {}
// NON-COMPLETION-HANDLER: func functionWithSomeHandler() async -> String {}

// rdar://77789360 Make sure we don't print a double return statement.
// RUN: %refactor -add-async-alternative -dump-text -source-filename %s -pos=%(line+1):1 | %FileCheck -check-prefix=RETURN-HANDLING %s
func testReturnHandling(_ completion: (String?, Error?) -> Void) {
return completion("", nil)
}
// RETURN-HANDLING: func testReturnHandling() async throws -> String {
// RETURN-HANDLING-NEXT: {{^}} return ""{{$}}
// RETURN-HANDLING-NEXT: }

// rdar://77789360 Make sure we don't print a double return statement and don't
// completely drop completion(a).
// RUN: %refactor -add-async-alternative -dump-text -source-filename %s -pos=%(line+1):1 | %FileCheck -check-prefix=RETURN-HANDLING2 %s
func testReturnHandling2(completion: @escaping (String) -> ()) {
testReturnHandling { x, err in
guard let x = x else {
let a = ""
return completion(a)
}
let b = ""
return completion(b)
}
}
// RETURN-HANDLING2: func testReturnHandling2() async -> String {
// RETURN-HANDLING2-NEXT: do {
// RETURN-HANDLING2-NEXT: let x = try await testReturnHandling()
// RETURN-HANDLING2-NEXT: let b = ""
// RETURN-HANDLING2-NEXT: {{^}}<#return#> b{{$}}
// RETURN-HANDLING2-NEXT: } catch let err {
// RETURN-HANDLING2-NEXT: let a = ""
// RETURN-HANDLING2-NEXT: {{^}}<#return#> a{{$}}
// RETURN-HANDLING2-NEXT: }
// RETURN-HANDLING2-NEXT: }

// FIXME: We should arguably be able to handle transforming this completion handler call (rdar://78011350).
// RUN: %refactor -add-async-alternative -dump-text -source-filename %s -pos=%(line+1):1 | %FileCheck -check-prefix=RETURN-HANDLING3 %s
func testReturnHandling3(_ completion: (String?, Error?) -> Void) {
return (completion("", nil))
}
// RETURN-HANDLING3: func testReturnHandling3() async throws -> String {
// RETURN-HANDLING3-NEXT: {{^}} return (<#completion#>("", nil)){{$}}
// RETURN-HANDLING3-NEXT: }