Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/IDE/RefactoringKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ RANGE_REFACTORING(ConvertGuardExprToIfLetExpr, "Convert To IfLet Expression", co

RANGE_REFACTORING(ConvertToComputedProperty, "Convert To Computed Property", convert.to.computed.property)

RANGE_REFACTORING(ConvertToSwitchStmt, "Convert To Switch Statement", convert.switch.stmt)

// These internal refactorings are designed to be helpful for working on
// the compiler/standard library, etc., but are likely to be just confusing and
// noise for general development.
Expand Down
266 changes: 266 additions & 0 deletions lib/IDE/Refactoring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2243,6 +2243,272 @@ bool RefactoringActionConvertGuardExprToIfLetExpr::performChange() {
return false;
}

bool RefactoringActionConvertToSwitchStmt::
isApplicable(ResolvedRangeInfo Info, DiagnosticEngine &Diag) {

class ConditionalChecker : public ASTWalker {
public:
bool ParamsUseSameVars = true;
bool ConditionUseOnlyAllowedFunctions = false;
StringRef ExpectName;

Expr *walkToExprPost(Expr *E) {
if (E->getKind() != ExprKind::DeclRef)
return E;
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
if (D->getKind() == DeclKind::Var || D->getKind() == DeclKind::Param)
ParamsUseSameVars = checkName(dyn_cast<VarDecl>(D));
if (D->getKind() == DeclKind::Func)
ConditionUseOnlyAllowedFunctions = checkName(dyn_cast<FuncDecl>(D));
if (allCheckPassed())
return E;
return nullptr;
}

bool allCheckPassed() {
return ParamsUseSameVars && ConditionUseOnlyAllowedFunctions;
}

private:
bool checkName(VarDecl *VD) {
auto Name = VD->getName().str();
if (ExpectName.empty())
ExpectName = Name;
return Name == ExpectName;
}

bool checkName(FuncDecl *FD) {
auto Name = FD->getName().str();
return Name == "~="
|| Name == "=="
|| Name == "__derived_enum_equals"
|| Name == "__derived_struct_equals"
|| Name == "||"
|| Name == "...";
}
};

class SwitchConvertable {
public:
SwitchConvertable(ResolvedRangeInfo Info) {
this->Info = Info;
}

bool isApplicable() {
if (Info.Kind != RangeKind::SingleStatement)
return false;
if (!findIfStmt())
return false;
return checkEachCondition();
}

private:
ResolvedRangeInfo Info;
IfStmt *If = nullptr;
ConditionalChecker checker;

bool findIfStmt() {
if (Info.ContainedNodes.size() != 1)
return false;
if (auto S = Info.ContainedNodes.front().dyn_cast<Stmt*>())
If = dyn_cast<IfStmt>(S);
return If != nullptr;
}

bool checkEachCondition() {
checker = ConditionalChecker();
do {
if (!checkEachElement())
return false;
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
return true;
}

bool checkEachElement() {
bool result = true;
auto ConditionalList = If->getCond();
for (auto Element : ConditionalList) {
result &= check(Element);
}
return result;
}

bool check(StmtConditionElement ConditionElement) {
if (ConditionElement.getKind() == StmtConditionElement::CK_Availability)
return false;
if (ConditionElement.getKind() == StmtConditionElement::CK_PatternBinding)
checker.ConditionUseOnlyAllowedFunctions = true;
ConditionElement.walk(checker);
return checker.allCheckPassed();
}
};
return SwitchConvertable(Info).isApplicable();
}

bool RefactoringActionConvertToSwitchStmt::performChange() {

class VarNameFinder : public ASTWalker {
public:
std::string VarName;

Expr *walkToExprPost(Expr *E) {
if (E->getKind() != ExprKind::DeclRef)
return E;
auto D = dyn_cast<DeclRefExpr>(E)->getDecl();
if (D->getKind() != DeclKind::Var && D->getKind() != DeclKind::Param)
return E;
VarName = dyn_cast<VarDecl>(D)->getName().str().str();
return nullptr;
}
};

class ConditionalPatternFinder : public ASTWalker {
public:
ConditionalPatternFinder(SourceManager &SM) : SM(SM) {}

SmallString<64> ConditionalPattern = SmallString<64>();

Expr *walkToExprPost(Expr *E) {
if (E->getKind() != ExprKind::Binary)
return E;
auto BE = dyn_cast<BinaryExpr>(E);
if (isFunctionNameAllowed(BE))
appendPattern(dyn_cast<BinaryExpr>(E)->getArg());
return E;
}

std::pair<bool, Pattern*> walkToPatternPre(Pattern *P) {
ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, P->getSourceRange()).str());
if (P->getKind() == PatternKind::OptionalSome)
ConditionalPattern.append("?");
return { true, nullptr };
}

private:

SourceManager &SM;

bool isFunctionNameAllowed(BinaryExpr *E) {
auto FunctionBody = dyn_cast<DotSyntaxCallExpr>(E->getFn())->getFn();
auto FunctionDeclaration = dyn_cast<DeclRefExpr>(FunctionBody)->getDecl();
auto FunctionName = dyn_cast<FuncDecl>(FunctionDeclaration)->getName().str();
return FunctionName == "~="
|| FunctionName == "=="
|| FunctionName == "__derived_enum_equals"
|| FunctionName == "__derived_struct_equals";
}

void appendPattern(TupleExpr *Tuple) {
auto PatternArgument = Tuple->getElements().back();
if (PatternArgument->getKind() == ExprKind::DeclRef)
PatternArgument = Tuple->getElements().front();
if (ConditionalPattern.size() > 0)
ConditionalPattern.append(", ");
ConditionalPattern.append(Lexer::getCharSourceRangeFromSourceRange(SM, PatternArgument->getSourceRange()).str());
}
};

class ConverterToSwitch {
public:
ConverterToSwitch(ResolvedRangeInfo Info, SourceManager &SM) : SM(SM) {
this->Info = Info;
}

void performConvert(SmallString<64> &Out) {
If = findIf();
OptionalLabel = If->getLabelInfo().Name.str().str();
ControlExpression = findControlExpression();
findPatternsAndBodies(PatternsAndBodies);
DefaultStatements = findDefaultStatements();
makeSwitchStatement(Out);
}

private:
ResolvedRangeInfo Info;
SourceManager &SM;

IfStmt *If;
IfStmt *PreviousIf;

std::string OptionalLabel;
std::string ControlExpression;
SmallVector<std::pair<std::string, std::string>, 16> PatternsAndBodies;
std::string DefaultStatements;

IfStmt *findIf() {
auto S = Info.ContainedNodes[0].dyn_cast<Stmt*>();
return dyn_cast<IfStmt>(S);
}

std::string findControlExpression() {
auto ConditionElement = If->getCond().front();
auto Finder = VarNameFinder();
ConditionElement.walk(Finder);
return Finder.VarName;
}

void findPatternsAndBodies(SmallVectorImpl<std::pair<std::string, std::string>> &Out) {
do {
auto pattern = findPattern();
auto body = findBodyStatements();
Out.push_back(std::make_pair(pattern, body));
PreviousIf = If;
} while ((If = dyn_cast_or_null<IfStmt>(If->getElseStmt())));
}

std::string findPattern() {
auto ConditionElement = If->getCond().front();
auto Finder = ConditionalPatternFinder(SM);
ConditionElement.walk(Finder);
return Finder.ConditionalPattern.str().str();
}

std::string findBodyStatements() {
return findBodyWithoutBraces(If->getThenStmt());
}

std::string findDefaultStatements() {
auto ElseBody = dyn_cast_or_null<BraceStmt>(PreviousIf->getElseStmt());
if (!ElseBody)
return getTokenText(tok::kw_break);
return findBodyWithoutBraces(ElseBody);
}

std::string findBodyWithoutBraces(Stmt *body) {
auto BS = dyn_cast<BraceStmt>(body);
if (!BS)
return Lexer::getCharSourceRangeFromSourceRange(SM, body->getSourceRange()).str().str();
if (BS->getElements().empty())
return getTokenText(tok::kw_break);
SourceRange BodyRange = BS->getElements().front().getSourceRange();
BodyRange.widen(BS->getElements().back().getSourceRange());
return Lexer::getCharSourceRangeFromSourceRange(SM, BodyRange).str().str();
}

void makeSwitchStatement(SmallString<64> &Out) {
StringRef Space = " ";
StringRef NewLine = "\n";
llvm::raw_svector_ostream OS(Out);
if (OptionalLabel.size() > 0)
OS << OptionalLabel << ":" << Space;
OS << tok::kw_switch << Space << ControlExpression << Space << tok::l_brace << NewLine;
for (auto &pair : PatternsAndBodies) {
OS << tok::kw_case << Space << pair.first << tok::colon << NewLine;
OS << pair.second << NewLine;
}
OS << tok::kw_default << tok::colon << NewLine;
OS << DefaultStatements << NewLine;
OS << tok::r_brace;
}

};

SmallString<64> result;
ConverterToSwitch(RangeInfo, SM).performConvert(result);
EditConsumer.accept(SM, RangeInfo.ContentRange, result.str());
return false;
}

/// Struct containing info about an IfStmt that can be converted into an IfExpr.
struct ConvertToTernaryExprInfo {
ConvertToTernaryExprInfo() {}
Expand Down
Loading