Commit

forexpander

mdr committed Mar 8, 2011
1 parent 9e14ea0 commit a37d3b7
Showing 7 changed files with 104 additions and 16 deletions.
@@ -0,0 +1,62 @@
package scalariform.forexpander

import scalariform.parser._
import scalariform.lexer.Tokens._
import scalariform.lexer._

sealed trait NscEnumerator
case class ValFrom(pat: Expr, rhs: Expr) extends NscEnumerator
case class ValEq(pat: Expr, rhs: Expr) extends NscEnumerator
case class Filter(test: Expr) extends NscEnumerator
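
// For example, given the enumerators of
//   for (x <- 1 to 10; y <- 1 to 10 if y > x) yield x * y
// makeEnumerators (below) produces, roughly,
//   List(ValFrom(x, 1 to 10), ValFrom(y, 1 to 10), Filter(y > x)),
// with the patterns and right-hand sides held as Expr AST nodes.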

object ForExpander {
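  // Expands a for expression into the nested map/flatMap/withFilter (or foreach)
  // calls prescribed by the standard for-comprehension desugaring.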

  def expandFor(forExpr: ForExpr): CallExpr = {
    val (mapName, flatMapName) = if (forExpr.yieldOption.isDefined) ("map", "flatMap") else ("foreach", "foreach")
    makeFor(mapName, flatMapName, makeEnumerators(forExpr.enumerators), forExpr.body)
  }

  def makeEnumerators(enumerators: Enumerators): List[NscEnumerator] = {

    def makeEnumerators(enumerator: Enumerator): List[NscEnumerator] = enumerator match {
      case OldForGuard(expr) ⇒ List(Filter(expr))
      case guard: Guard ⇒ List(makeFilter(guard))
      case generator: Generator ⇒ enumeratorsFromGenerator(generator)
    }

    def enumeratorsFromGenerator(generator: Generator): List[NscEnumerator] = {
      val Generator(_, pattern, equalsOrArrowToken, expr, guards) = generator
      (generator.equalsOrArrowToken.tokenType match {
        case EQUALS ⇒ ValEq(pattern, expr)
        case LARROW ⇒ ValFrom(pattern, expr)
      }) :: guards.map(makeFilter)
    }

    makeEnumerators(enumerators.initialGenerator) ::: (enumerators.rest flatMap { case (_, enumerator) ⇒ makeEnumerators(enumerator) })
  }

  private def makeFilter(guard: Guard) = Filter(guard.expr)

  private def token(tokenType: TokenType, text: String) =
    Token(tokenType, text, 0, text.length - 1)

  private def makeFor(mapName: String, flatMapName: String, enums: List[NscEnumerator], body: Expr): CallExpr = {

    def makeCombination(meth: String, qual: Expr, pat: Expr, body: Expr) = {
      val argument = Argument(Expr(List(AnonymousFunction(List(pat), token(ARROW, "=>"), List(body)))))
      val arguments = ParenArgumentExprs(token(LPAREN, "("), List(argument), token(RPAREN, ")"))
      CallExpr(Some(List(ParenExpr(token(LPAREN, "("), List(qual), token(RPAREN, ")"))), token(DOT, ".")), token(VARID, meth), newLineOptsAndArgumentExprss = List((None, arguments)))
    }
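    // makeCombination(meth, qual, pat, body) builds the AST for: (qual).meth(pat => body)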

    enums match {
      case ValFrom(pat, rhs) :: Nil ⇒
        makeCombination(mapName, rhs, pat, body)
      case ValFrom(pat, rhs) :: (rest@(ValFrom(_, _) :: _)) ⇒
        makeCombination(flatMapName, rhs, pat, Expr(List(makeFor(mapName, flatMapName, rest, body))))
      case ValFrom(pat, rhs) :: Filter(test) :: rest ⇒
        makeFor(mapName, flatMapName, ValFrom(pat, Expr(List(makeCombination("withFilter", rhs, pat, test)))) :: rest, body)
    }

  }

}
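
For reference, a rough sketch (not part of this commit) of the library-level rewrite that expandFor reproduces, applied to the expression used in ForExpanderTest below; the expander emits an equivalent token stream rather than this exact text:

    // for (x <- 1 to 10; y <- 1 to 10 if y > x) yield x * y
    (1 to 10).flatMap(x => (1 to 10).withFilter(y => y > x).map(y => x * y))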
@@ -471,13 +471,12 @@ trait ExprFormatter { self: HasFormattingPreferences with AnnotationFormatter wi
formatResult
}

private def format(enumerator: Enumerator)(implicit formatterState: FormatterState): FormatResult = {
private def format(enumerator: Enumerator)(implicit formatterState: FormatterState): FormatResult =
enumerator match {
case expr@Expr(_) ⇒ format(expr)
case expr@OldForGuard(subexpr) ⇒ format(subexpr)
case generator@Generator(_, _, _, _, _) ⇒ format(generator)
case guard@Guard(_, _) ⇒ format(guard: Guard)
}
}

private def format(generator: Generator)(implicit formatterState: FormatterState): FormatResult = {
val Generator(valOption: Option[Token], pattern: Expr, equalsOrArrowToken: Token, expr: Expr, guards: List[Guard]) = generator
@@ -15,9 +15,9 @@ class ScalaLexerReader(val tokens: List[Token]) extends Reader[Token] {

private class ScalaLexerPosition(token: Token) extends Position {

def line: Int = token.getLine
def line: Int = -1

def column: Int = token.getCharPositionInLine
def column: Int = -1

protected def lineContents: String = token.getText

2 changes: 0 additions & 2 deletions scalariform/src/main/scala/scalariform/lexer/Token.scala
@@ -7,8 +7,6 @@ case class Token(tokenType: TokenType, text: String, startIndex: Int, stopIndex:
require(tokenType == Tokens.EOF || stopIndex - startIndex + 1 == text.length)
lazy val getText = text // Delete me?
lazy val getType = tokenType // Delete me?
lazy val getLine = -1 // TODO
lazy val getCharPositionInLine = -1 // TODO
lazy val getStartIndex = startIndex // Delete me?
lazy val getStopIndex = stopIndex // Delete me?
def length = stopIndex - startIndex + 1
13 changes: 5 additions & 8 deletions scalariform/src/main/scala/scalariform/parser/AstNodes.scala
@@ -109,7 +109,7 @@ case class CallByNameTypeElement(arrow: Token) extends AstNode with TypeElement

sealed trait ExprElement extends AstNode

case class Expr(contents: List[ExprElement]) extends AstNode with ExprElement with Stat with Enumerator with XmlContents with ImportExpr {
case class Expr(contents: List[ExprElement]) extends AstNode with ExprElement with Stat with XmlContents with ImportExpr {
lazy val tokens = flatten(contents)
}

@@ -205,15 +205,12 @@ case class Enumerators(initialGenerator: Generator, rest: List[(Token, Enumerato
lazy val tokens = flatten(initialGenerator, rest)
}

case class Generator(
valOption: Option[Token],
pattern: Expr,
equalsOrArrowToken: Token,
expr: Expr,
guards: List[Guard]) extends AstNode with Enumerator {

case class Generator(valOption: Option[Token], pattern: Expr, equalsOrArrowToken: Token, expr: Expr, guards: List[Guard]) extends AstNode with Enumerator {
lazy val tokens = flatten(valOption, pattern, equalsOrArrowToken, expr, guards)
}

case class OldForGuard(expr: Expr) extends AstNode with Enumerator {
lazy val tokens = flatten(expr)
}

case class Guard(ifToken: Token, expr: Expr) extends AstNode with Enumerator {
@@ -874,7 +874,7 @@ class ScalaParser(tokens: Array[Token]) {
else generator(eqOK = true)
} else {
if (VAL) generator(eqOK = true)
else expr()
else OldForGuard(expr())
}
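// A bare expression in enumerator position is wrapped as OldForGuard so that
// ForExpander can treat it as a filter.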
otherEnumerators += ((statSep, enumerator))
}
@@ -0,0 +1,32 @@
package scalariform.forexpander

import scalariform.lexer._
import scalariform.parser._
import scalariform.formatter._

import org.scalatest._
import org.scalatest.matchers._

import scalariform.utils.Range

// format: OFF
class ForExpanderTest extends FlatSpec with ShouldMatchers {

  private val source = "for (x <- 1 to 10; y <- 1 to 10 if y > x) yield x * y"
  private val (hiddenTokenInfo, tokens) = ScalaLexer.tokeniseFull(source)
  private val parser = new ScalaParser(tokens.toArray)
  private val Expr(List(forExpr: ForExpr)) = parser.safeParse(parser.expr).get

  val expandedFor = ForExpander.expandFor(forExpr)

  println(expandedFor)
  println()

  expandedFor.tokens foreach println

  val rawText = expandedFor.tokens.map(_.text).mkString(" ")
  val result = ScalaFormatter.format(rawText)
  println(source)
  println(result)

}
