Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
104 additions
and
16 deletions.
There are no files selected for viewing
62 changes: 62 additions & 0 deletions
62
scalariform/src/main/scala/scalariform/forexpander/ForExpander.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
package scalariform.forexpander | ||
|
||
import scalariform.parser._ | ||
import scalariform.lexer.Tokens._ | ||
import scalariform.lexer._ | ||
|
||
sealed trait NscEnumerator | ||
case class ValFrom(pat: Expr, rhs: Expr) extends NscEnumerator | ||
case class ValEq(pat: Expr, rhs: Expr) extends NscEnumerator | ||
case class Filter(test: Expr) extends NscEnumerator | ||
|
||
object ForExpander { | ||
|
||
def expandFor(forExpr: ForExpr): CallExpr = { | ||
val (mapName, flatMapName) = if (forExpr.yieldOption.isDefined) ("map", "flatMap") else ("foreach", "foreach") | ||
makeFor(mapName, flatMapName, makeEnumerators(forExpr.enumerators), forExpr.body) | ||
} | ||
|
||
def makeEnumerators(enumerators: Enumerators): List[NscEnumerator] = { | ||
|
||
def makeEnumerators(enumerator: Enumerator): List[NscEnumerator] = enumerator match { | ||
case OldForGuard(expr) ⇒ List(Filter(expr)) | ||
case guard: Guard ⇒ List(makeFilter(guard)) | ||
case generator: Generator ⇒ enumeratorsFromGenerator(generator) | ||
} | ||
|
||
def enumeratorsFromGenerator(generator: Generator): List[NscEnumerator] = { | ||
val Generator(_, pattern, equalsOrArrowToken, expr, guards) = generator | ||
(generator.equalsOrArrowToken.tokenType match { | ||
case EQUALS ⇒ ValEq(pattern, expr) | ||
case LARROW ⇒ ValFrom(pattern, expr) | ||
}) :: guards.map(makeFilter) | ||
} | ||
|
||
makeEnumerators(enumerators.initialGenerator) ::: (enumerators.rest flatMap { case (_, enumerator) ⇒ makeEnumerators(enumerator) }) | ||
} | ||
|
||
private def makeFilter(guard: Guard) = Filter(guard.expr) | ||
|
||
private def token(tokenType: TokenType, text: String) = | ||
Token(tokenType, text, 0, text.length - 1) | ||
|
||
private def makeFor(mapName: String, flatMapName: String, enums: List[NscEnumerator], body: Expr): CallExpr = { | ||
|
||
def makeCombination(meth: String, qual: Expr, pat: Expr, body: Expr) = { | ||
val argument = Argument(Expr(List(AnonymousFunction(List(pat), token(ARROW, "=>"), List(body))))) | ||
val arguments = ParenArgumentExprs(token(LPAREN, "("), List(argument), token(RPAREN, ")")) | ||
CallExpr(Some(List(ParenExpr(token(LPAREN, "("), List(qual), token(RPAREN, ")"))), token(DOT, ".")), token(VARID, meth), newLineOptsAndArgumentExprss = List((None, arguments))) | ||
} | ||
|
||
enums match { | ||
case ValFrom(pat, rhs) :: Nil ⇒ | ||
makeCombination(mapName, rhs, pat, body) | ||
case ValFrom(pat, rhs) :: (rest@(ValFrom(_, _) :: _)) ⇒ | ||
makeCombination(flatMapName, rhs, pat, Expr(List(makeFor(mapName, flatMapName, rest, body)))) | ||
case ValFrom(pat, rhs) :: Filter(test) :: rest ⇒ | ||
makeFor(mapName, flatMapName, ValFrom(pat, Expr(List(makeCombination("withFilter", rhs, pat, test)))) :: rest, body) | ||
} | ||
|
||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
32 changes: 32 additions & 0 deletions
32
scalariform/src/test/scala/scalariform/forexpander/ForExpanderTest.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
package scalariform.forexpander | ||
|
||
import scalariform.lexer._ | ||
import scalariform.parser._ | ||
import scalariform.formatter._ | ||
|
||
import org.scalatest._ | ||
import org.scalatest.matchers._ | ||
|
||
import scalariform.utils.Range | ||
|
||
// format: OAF | ||
class ForExpanderTest extends FlatSpec with ShouldMatchers { | ||
|
||
private val source = "for (x <- 1 to 10; y <- 1 to 10 if y > x) yield x * y" | ||
private val (hiddenTokenInfo, tokens) = ScalaLexer.tokeniseFull(source) | ||
private val parser = new ScalaParser(tokens.toArray) | ||
private val Expr(List(forExpr: ForExpr)) = parser.safeParse(parser.expr).get | ||
|
||
val expandedFor = ForExpander.expandFor(forExpr) | ||
|
||
println(expandedFor) | ||
println() | ||
|
||
expandedFor.tokens foreach println | ||
|
||
val rawText = expandedFor.tokens.map(_.text).mkString(" ") | ||
val result = ScalaFormatter.format(rawText) | ||
println(source) | ||
println(result) | ||
|
||
} |