From 151d371bf2192d7d3fb9aab4a5dec8806845e803 Mon Sep 17 00:00:00 2001 From: Matt Russell Date: Wed, 9 Mar 2011 08:35:59 +0000 Subject: [PATCH] wip --- .../scalariform/forexpander/ForExpander.scala | 59 ++++++++++++++++--- .../scala/scalariform/parser/AstNodes.scala | 4 ++ .../scalariform/parser/ScalaParser.scala | 19 +++--- .../forexpander/ForExpanderTest.scala | 4 +- 4 files changed, 67 insertions(+), 19 deletions(-) diff --git a/scalariform/src/main/scala/scalariform/forexpander/ForExpander.scala b/scalariform/src/main/scala/scalariform/forexpander/ForExpander.scala index 6c21f456..67888797 100644 --- a/scalariform/src/main/scala/scalariform/forexpander/ForExpander.scala +++ b/scalariform/src/main/scala/scalariform/forexpander/ForExpander.scala @@ -42,6 +42,44 @@ object ForExpander { private def token(tokenType: TokenType, text: String) = Token(tokenType, text, 0, text.length - 1) + private def makeBind(expr: Expr): Bind = expr match { + case Expr(List(bind@Bind(_, _, _))) ⇒ bind + case _ ⇒ Bind(token(VARID, freshName()), token(AT, "@"), expr.contents) + } + + private var n: Int = 1 + private def freshName(): String = { + val name = "freshName" + n + n += 1 + name + } + + private def intersperse[A](sep: ⇒ A, c: List[A]): List[A] = c match { + case Nil ⇒ Nil + case List(x) ⇒ List(x) + case x :: xs ⇒ x :: sep :: intersperse(sep, xs) + } + + private def comma = token(COMMA, ",") + + private def makeTuple(exprs: List[Expr]): Expr = { + val contents = intersperse(GeneralTokens(List(comma)), exprs) + Expr(List(ParenExpr(token(LPAREN, "("), contents, token(RPAREN, ")")))) + } + + private def makeTupleTerm(uscoreOrIds: List[Token]): Expr = Expr(List(uscoreOrIds match { + case List(uscoreOrId) ⇒ + GeneralTokens(List(uscoreOrId)) + case _ ⇒ + val contents = List(GeneralTokens(intersperse(comma, uscoreOrIds))) + ParenExpr(token(LPAREN, "("), contents, token(RPAREN, ")")) + })) + + private def makePatDef(pat: ExprElement, rhs: Expr): List[FullDefOrDcl] = + List(FullDefOrDcl(Nil, Nil, PatDefOrDcl(token(VAL, "val"), pattern = Expr(List(pat)), otherPatterns = Nil, typedOpt = None, equalsClauseOption = Some(token(EQUALS, "="), rhs)))) + + private def makeValue(bind: Bind): Token = bind.uscoreOrId + private def makeFor(mapName: String, flatMapName: String, enums: List[NscEnumerator], body: Expr): CallExpr = { def makeCombination(meth: String, qual: Expr, pat: Expr, body: Expr) = { @@ -57,23 +95,26 @@ object ForExpander { makeCombination(flatMapName, rhs, pat, Expr(List(makeFor(mapName, flatMapName, rest, body)))) case ValFrom(pat, rhs) :: Filter(test) :: rest ⇒ makeFor(mapName, flatMapName, ValFrom(pat, Expr(List(makeCombination("withFilter", rhs, pat, test)))) :: rest, body) - case ValFrom(pat , rhs) :: rest ⇒ + case ValFrom(pat, rhs) :: rest ⇒ val valeqs = rest.take(MaxTupleArity - 1).takeWhile(_.isInstanceOf[ValEq]) val rest1 = rest.drop(valeqs.length) val pats = valeqs map { case ValEq(pat, _) ⇒ pat } val rhss = valeqs map { case ValEq(_, rhs) ⇒ rhs } - + val defpat1 = makeBind(pat) val defpats = pats map makeBind - val pdefs = (defpats, rhss).zipped flatMap makePatDef + val pdefs: List[Stat] = (defpats, rhss).zipped flatMap makePatDef val ids = (defpat1 :: defpats) map makeValue - val rhs1 = makeForYield( - List(ValFrom(pos, defpat1, rhs)), - Block(pdefs, atPos(wrappingPos(ids)) { makeTupleTerm(ids, true) }) setPos wrappingPos(pdefs)) - val allpats = (pat :: pats) map (_.duplicate) - val vfrom1 = ValFrom(r2p(pos.startOrPoint, pos.point, rhs1.pos.endOrPoint), atPos(wrappingPos(allpats)) { makeTuple(allpats, false) }, rhs1) + + val (firstStat :: otherStats) = pdefs :+ makeTupleTerm(ids) + val statSeq = Right(StatSeq(None, Some(firstStat), otherStats map { stat ⇒ (token(SEMI, ";"), Some(stat)) })) + val rhs1 = makeFor("map", "flatMap", + List(ValFrom(Expr(List(defpat1)), rhs)), + Expr(List(BlockExpr(token(LBRACE, "{"), statSeq, token(RBRACE, "}"))))) + + val allpats = pat :: pats + val vfrom1 = ValFrom(makeTuple(allpats), Expr(List(rhs1))) makeFor(mapName, flatMapName, vfrom1 :: rest1, body) - //throw new UnsupportedOperationException } } diff --git a/scalariform/src/main/scala/scalariform/parser/AstNodes.scala b/scalariform/src/main/scala/scalariform/parser/AstNodes.scala index ce07c5d0..8b78421f 100644 --- a/scalariform/src/main/scala/scalariform/parser/AstNodes.scala +++ b/scalariform/src/main/scala/scalariform/parser/AstNodes.scala @@ -271,6 +271,10 @@ case class PatDefOrDcl(valOrVarToken: Token, } +case class Bind(uscoreOrId: Token, at: Token, rhs: List[ExprElement]) extends ExprElement { + lazy val tokens = flatten(uscoreOrId, at, rhs) +} + sealed trait FunBody extends AstNode case class ProcFunBody(newlineOpt: Option[Token], bodyBlock: BlockExpr) extends FunBody { lazy val tokens = flatten(newlineOpt, bodyBlock) diff --git a/scalariform/src/main/scala/scalariform/parser/ScalaParser.scala b/scalariform/src/main/scala/scalariform/parser/ScalaParser.scala index dee2e69c..e35a5250 100644 --- a/scalariform/src/main/scala/scalariform/parser/ScalaParser.scala +++ b/scalariform/src/main/scala/scalariform/parser/ScalaParser.scala @@ -937,16 +937,17 @@ class ScalaParser(tokens: Array[Token]) { def pattern2(): Expr = { val firstPattern = pattern3() - val atOtherOpt = if (AT) { - // TODO: Compare Parsers.scala - optional { - val atToken = nextToken() - val otherPattern = pattern3() - (atToken, otherPattern) + if (AT) // TODO: Compare Parsers.scala + firstPattern match { + case List(GeneralTokens(List(lhsToken))) if lhsToken.tokenType == USCORE || lhsToken.tokenType.isId ⇒ + val atToken = nextToken() + val otherPattern = pattern3() + makeExpr(Bind(lhsToken, atToken, otherPattern)) + case _ ⇒ + makeExpr(firstPattern) } - } else - None - makeExpr(firstPattern, atOtherOpt) + else + makeExpr(firstPattern) } def pattern3(): List[ExprElement] = { diff --git a/scalariform/src/test/scala/scalariform/forexpander/ForExpanderTest.scala b/scalariform/src/test/scala/scalariform/forexpander/ForExpanderTest.scala index f7936823..5b0ec1c1 100644 --- a/scalariform/src/test/scala/scalariform/forexpander/ForExpanderTest.scala +++ b/scalariform/src/test/scala/scalariform/forexpander/ForExpanderTest.scala @@ -12,7 +12,8 @@ import scalariform.utils.Range // format: OAF class ForExpanderTest extends FlatSpec with ShouldMatchers { - private val source = "for (x <- 1 to 10; y <- 1 to 10 if y > x) yield x * y" +// private val source = "for (x <- 1 to 10; y <- 1 to 10 if y > x) yield x * y" + private val source = "for (x <- 1 to 10; z = x * x) yield z" private val (hiddenTokenInfo, tokens) = ScalaLexer.tokeniseFull(source) private val parser = new ScalaParser(tokens.toArray) private val Expr(List(forExpr: ForExpr)) = parser.safeParse(parser.expr).get @@ -25,6 +26,7 @@ class ForExpanderTest extends FlatSpec with ShouldMatchers { expandedFor.tokens foreach println val rawText = expandedFor.tokens.map(_.text).mkString(" ") + println(rawText) val result = ScalaFormatter.format(rawText) println(source) println(result)