Applicative desugaring in for comprehensions #5819

Closed
wants to merge 3 commits into from
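
For orientation, here is a hand-written sketch of the desugaring this pull request appears to target: generators joined with the `with` keyword are combined through a single `product` call followed by one `map`, rather than the nested `flatMap`/`map` chain produced for `;`-separated generators. The `Box` carrier, the names `fa`/`fb`, and the expanded expressions below are illustrative assumptions, not code or tests from this pull request.

// A minimal applicative-like carrier, defined only for this sketch.
final case class Box[A](value: A) {
  def map[B](f: A => B): Box[B] = Box(f(value))
  def flatMap[B](f: A => Box[B]): Box[B] = f(value)
  def product[B](that: Box[B]): Box[(A, B)] = Box((value, that.value))
}

object WithDesugaringSketch extends App {
  val fa = Box(1)
  val fb = Box("x")

  // What `for (a <- fa with b <- fb) yield (a, b)` is expected to expand to:
  val applicative = fa.product(fb).map { case (a, b) => (a, b) }

  // What `for (a <- fa; b <- fb) yield (a, b)` already expands to today:
  val monadic = fa.flatMap(a => fb.map(b => (a, b)))

  println(applicative) // Box((1,x))
  println(monadic)     // Box((1,x))
}
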
15 changes: 9 additions & 6 deletions src/compiler/scala/reflect/quasiquotes/Parsers.scala
@@ -175,12 +175,15 @@ trait Parsers { self: Quasiquotes =>
stats
}

override def enumerator(isFirst: Boolean, allowNestedIf: Boolean = true) =
if (isHole && lookingAhead { in.token == EOF || in.token == RPAREN || isStatSep }) {
val res = ForEnumPlaceholder(in.name) :: Nil
override def enumerator(isFirst: Boolean, allowNested: Boolean = true) =
if (isHole && lookingAhead { in.token == EOF || in.token == RPAREN || in.token == WITH || isStatSep }) {
val tree = ForEnumPlaceholder(in.name)
in.nextToken()
res
} else super.enumerator(isFirst, allowNestedIf)
if(in.token == WITH) {
in.nextToken()
gen.With(tree) :: Nil
} else tree :: Nil
} else super.enumerator(isFirst, allowNested)
}
}

@@ -219,7 +222,7 @@ trait Parsers { self: Quasiquotes =>

object ForEnumeratorParser extends Parser {
def entryPoint = { parser =>
val enums = parser.enumerator(isFirst = false, allowNestedIf = false)
val enums = parser.enumerator(isFirst = false, allowNested = false)
assert(enums.length == 1)
implodePatDefs(enums.head)
}
2 changes: 2 additions & 0 deletions src/compiler/scala/reflect/quasiquotes/Reifiers.scala
@@ -178,6 +178,8 @@ trait Reifiers { self: Quasiquotes =>
reifyBuildCall(nme.SyntacticValEq, pat, rhs)
case SyntacticFilter(cond) =>
reifyBuildCall(nme.SyntacticFilter, cond)
case SyntacticWith(enum) =>
reifyBuildCall(nme.SyntacticWith, enum)
case SyntacticFor(enums, body) =>
reifyBuildCall(nme.SyntacticFor, enums, body)
case SyntacticForYield(enums, body) =>
28 changes: 19 additions & 9 deletions src/compiler/scala/tools/nsc/ast/parser/Parsers.scala
@@ -1831,7 +1831,7 @@ self =>
* Generator ::= Pattern1 (`<-' | `=') Expr [Guard]
* }}}
*/
def generator(eqOK: Boolean, allowNestedIf: Boolean = true): List[Tree] = {
def generator(eqOK: Boolean, allowNested: Boolean = true): List[Tree] = {
val start = in.offset
val hasVal = in.token == VAL
if (hasVal)
@@ -1850,18 +1850,28 @@ self =>
else accept(LARROW)
val rhs = expr()

def loop(): List[Tree] =
// why max? IDE stress tests have shown that lastOffset could be less than start,
// I guess this happens if instead of a for-expression we sit on a closing paren.
val genPos = r2p(start, point, in.lastOffset max start)
val genr = gen.mkGenerator(genPos, pat, hasEq, rhs)

val hasWith = !hasEq && in.token == WITH
val head = if(hasWith) {
val offsetOfWith = in.offset
in.nextToken()
gen.With(genr).setPos(r2p(start, offsetOfWith, in.lastOffset))
} else genr

def nestedFilters(): List[Tree] =
if (in.token != IF) Nil
else makeFilter(in.offset, guard()) :: loop()
else makeFilter(in.offset, guard()) :: nestedFilters()

val tail =
if (allowNestedIf) loop()
else Nil
if(!allowNested) Nil
else if(hasWith) generator(eqOK = false)
else nestedFilters()

// why max? IDE stress tests have shown that lastOffset could be less than start,
// I guess this happens if instead of a for-expression we sit on a closing paren.
val genPos = r2p(start, point, in.lastOffset max start)
gen.mkGenerator(genPos, pat, hasEq, rhs) :: tail
head :: tail
}

def makeFilter(start: Offset, tree: Tree) = gen.Filter(tree).setPos(r2p(start, tree.pos.point, tree.pos.end))
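
To see what the changed `generator` above produces end to end, one hedged option is a toolbox parse: in scalac the for-comprehension desugaring happens inside the parser, so the parsed tree should already contain the `product`/`withFilter`/`map` calls built by `gen.mkFor`. This sketch assumes scala-compiler and scala-reflect built from this branch are on the classpath; with a stock compiler the string below is simply a syntax error.

import scala.reflect.runtime.currentMirror
import scala.tools.reflect.ToolBox

object ParseWithSketch extends App {
  val tb = currentMirror.mkToolBox()
  // `a <- as with b <- bs` is parsed as a single enumerator step; the guard stays a Filter.
  val tree = tb.parse("for (a <- as with b <- bs; if a > b) yield a + b")
  println(tree)
}
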
7 changes: 7 additions & 0 deletions src/reflect/scala/reflect/api/Internals.scala
@@ -737,6 +737,13 @@ trait Internals { self: Universe =>
def unapply(tree: Tree): Option[(Tree)]
}

val SyntacticWith: SyntacticWithExtractor

trait SyntacticWithExtractor {
def apply(enum: Tree): Tree
def unapply(tree: Tree): Option[Tree]
}

val SyntacticEmptyTypeTree: SyntacticEmptyTypeTreeExtractor

trait SyntacticEmptyTypeTreeExtractor {
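
If `SyntacticWith` lands as declared above, it should be reachable through the standard `internal.reificationSupport` entry point like the other `Syntactic*` extractors. A speculative sketch of building and matching a `with`-wrapped generator (the object and value names here are invented):

import scala.reflect.runtime.universe._

object SyntacticWithSketch extends App {
  val rs = internal.reificationSupport
  import rs._

  val enumTree = SyntacticValFrom(pq"x", q"xs")  // ordinary generator: x <- xs
  val wrapped  = SyntacticWith(enumTree)         // same generator, marked for `with`-joining

  wrapped match {
    case SyntacticWith(inner) => println(showRaw(inner))
    case other                => println(s"not a `with` enumerator: $other")
  }
}
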
156 changes: 114 additions & 42 deletions src/reflect/scala/reflect/internal/ReificationSupport.scala
@@ -624,6 +624,11 @@ trait ReificationSupport { self: SymbolTable =>
def unapply(tree: Tree): Option[Tree] = gen.Filter.unapply(tree)
}

object SyntacticWith extends SyntacticWithExtractor {
def apply(tree: Tree): Tree = gen.With(tree)
def unapply(tree: Tree): Option[Tree] = gen.With.unapply(tree)
}

// If a tree in type position isn't provided by the user (e.g. `tpt` fields of
// `ValDef` and `DefDef`, function params etc), then it's going to be parsed as
// TypeTree with empty original and empty tpe. This extractor matches such trees
@@ -744,18 +749,6 @@ trait ReificationSupport { self: SymbolTable =>
}
}

// transform a chain of withFilter calls into a sequence of for filters
protected object UnFilter {
def unapply(tree: Tree): Some[(Tree, List[Tree])] = tree match {
case UnCheckIfRefutable(_, _) =>
Some((tree, Nil))
case FilterCall(UnFilter(rhs, rest), UnClosure(_, test)) =>
Some((rhs, rest :+ SyntacticFilter(test)))
case _ =>
Some((tree, Nil))
}
}

// undo gen.mkCheckIfRefutable
protected object UnCheckIfRefutable {
def unapply(tree: Tree): Option[(Tree, Tree)] = tree match {
@@ -768,59 +761,138 @@
}
}

// undo gen.mkFor:makeCombination accounting for possible extra implicit argument
// undo gen.mkFor:makeCombination/makeProduct accounting for possible extra implicit argument
protected class UnForCombination(name: TermName) {
def unapply(tree: Tree) = tree match {
case SyntacticApplied(SyntacticTypeApplied(sel @ Select(lhs, meth), _), (f :: Nil) :: Nil)
case SyntacticApplied(SyntacticTypeApplied(sel @ Select(lhs, meth), _), (rhs :: Nil) :: Nil)
if name == meth && sel.hasAttachment[ForAttachment.type] =>
Some((lhs, f))
case SyntacticApplied(SyntacticTypeApplied(sel @ Select(lhs, meth), _), (f :: Nil) :: _ :: Nil)
Some((lhs, rhs))
case SyntacticApplied(SyntacticTypeApplied(sel @ Select(lhs, meth), _), (rhs :: Nil) :: _ :: Nil)
if name == meth && sel.hasAttachment[ForAttachment.type] =>
Some((lhs, f))
Some((lhs, rhs))
case _ => None
}
}
protected object UnMap extends UnForCombination(nme.map)
protected object UnForeach extends UnForCombination(nme.foreach)
protected object UnFlatMap extends UnForCombination(nme.flatMap)
protected object UnProduct extends UnForCombination(nme.product)

// transform a chain of withFilter calls into a sequence of for filters
protected object UnFilter {
def unapply(tree: Tree): Option[(Tree, List[Tree])] = tree match {
case UnCheckIfRefutable(_, _) =>
None
case FilterCall(rhs, UnClosure(_, test)) => rhs match {
case UnFilter(finalRhs, filters) => Some((finalRhs, filters :+ test))
case _ => Some((rhs, test :: Nil))
}
case _ =>
None
}
}

protected object UnFilterWithPat {
def unapply(patRhs: (Tree, Tree)): Option[((Tree, Tree), List[Tree])] = patRhs match {
case (pat, UnFilter(nrhs, tests)) => Some((pat, nrhs), tests)
case _ => None
}
}

protected object MaybeBind {
def unapply(tree: Tree): Option[Tree] = tree match {
case Bind(_, body) => Some(body)
case _ => Some(tree)
}
}

protected object UnProductWithPat {
def unapply(patRhs: (Tree, Tree)): Option[((Tree, Tree), (Tree, Tree))] = patRhs match {
case (MaybeBind(SyntacticTuple(List(lpat, rpat))), UnProduct(lhs, rhs)) =>
Some((lpat, lhs), (rpat, rhs))
case _ => None
}
}

protected object UnValEqs {
def unapply(patRhs: (Tree, Tree)): Option[((Tree, Tree), List[(Tree, Tree)])] = patRhs match {
case (MaybeBind(SyntacticTuple(_)), UnMap(lhs, UnClosure(pat, UnPatSeqWithRes(pats, _)))) =>
Some((pat, lhs), pats)
case _ => None
}
}

protected abstract class AbstractUnProducts(inner: Boolean) {
private def maybeWith(tree: Tree) = if(inner) SyntacticWith(tree) else tree

def unapply(patRhs: (Tree, Tree)): Option[List[Tree]] = patRhs match {
case UnProductWithPat(UnProductsInner(rest), (pat, rhs)) =>
Some(rest :+ maybeWith(SyntacticValFrom(pat, rhs)))
case (pat, rhs) =>
Some(maybeWith(SyntacticValFrom(pat, rhs)) :: Nil)
}
}

protected object UnProducts extends AbstractUnProducts(inner = false)
protected object UnProductsInner extends AbstractUnProducts(inner = true)

// extractor of sequence of for comprehension enumerators which start with arrow generators
// joined using `with` keyword followed by any combination of filters and assignments,
// i.e. there are no `flatMap`s or `foreach`es in a sequence matched by this extractor
protected object UnForStep {
def unapply(patRhs: (Tree, Tree)): Option[List[Tree]] = patRhs match {
case UnValEqs(UnForStep(rest), pats) =>
val valeqs = pats.map { case (pat, rhs) => SyntacticValEq(pat, rhs) }
Some(rest ::: valeqs)
case UnFilterWithPat(UnForStep(rest), tests) =>
val filters = tests.map(SyntacticFilter(_))
Some(rest ::: filters)
case UnProducts(generators) =>
Some(generators)
}
}

// undo desugaring done in gen.mkFor
protected object UnFor {
def unapply(tree: Tree): Option[(List[Tree], Tree)] = {
val interm = tree match {
case UnFlatMap(UnFilter(rhs, filters), UnClosure(pat, UnFor(rest, body))) =>
Some(((pat, rhs), filters ::: rest, body))
case UnForeach(UnFilter(rhs, filters), UnClosure(pat, UnFor(rest, body))) =>
Some(((pat, rhs), filters ::: rest, body))
case UnMap(UnFilter(rhs, filters), UnClosure(pat, cbody)) =>
Some(((pat, rhs), filters, gen.Yield(cbody)))
case UnForeach(UnFilter(rhs, filters), UnClosure(pat, cbody)) =>
Some(((pat, rhs), filters, cbody))
case UnFlatMap(rhs, UnClosure(pat, UnFor(rest, body))) =>
Some(((pat, rhs), rest, body))
case UnForeach(rhs, UnClosure(pat, UnFor(rest, body))) =>
Some(((pat, rhs), rest, body))
case UnMap(rhs, UnClosure(pat, cbody)) =>
Some(((pat, rhs), Nil, gen.Yield(cbody)))
case UnForeach(rhs, UnClosure(pat, cbody)) =>
Some(((pat, rhs), Nil, cbody))
case _ => None
}
interm.flatMap {
case ((Bind(_, SyntacticTuple(_)) | SyntacticTuple(_),
UnFor(SyntacticValFrom(pat, rhs) :: innerRest, gen.Yield(UnPatSeqWithRes(pats, elems2)))),
outerRest, fbody) =>
val valeqs = pats.map { case (pat, rhs) => SyntacticValEq(pat, rhs) }
Some((SyntacticValFrom(pat, rhs) :: innerRest ::: valeqs ::: outerRest, fbody))
case ((pat, rhs), filters, body) =>
Some((SyntacticValFrom(pat, rhs) :: filters, body))
interm.map {
case (UnForStep(lrest), rest, body) =>
(lrest ::: rest, body)
}
}
}

// check that enumerators are valid
protected def mkEnumerators(enums: List[Tree]): List[Tree] = {
require(enums.nonEmpty, "enumerators can't be empty")
enums.head match {
case SyntacticValFrom(_, _) =>
case t => throw new IllegalArgumentException(s"$t is not a valid first enumerator of for loop")
}
enums.tail.foreach {
case SyntacticValEq(_, _) | SyntacticValFrom(_, _) | SyntacticFilter(_) =>
case t => throw new IllegalArgumentException(s"$t is not a valid representation of a for loop enumerator")
}
def validate(enums: List[Tree], first: Boolean = false, afterWith: Boolean = false): Unit = enums match {
case SyntacticValFrom(_, _) :: rest =>
validate(rest)
case SyntacticWith(SyntacticValFrom(_, _)) :: rest =>
validate(rest, afterWith = true)
case (SyntacticValEq(_, _) | SyntacticFilter(_)) :: rest if !first && !afterWith =>
validate(rest)
case t :: _ if first =>
throw new IllegalArgumentException(s"$t is not a valid first enumerator of for loop")
case t :: _ if afterWith =>
throw new IllegalArgumentException(s"$t is not a valid enumerator of for loop to follow `with` keyword")
case t :: _ =>
throw new IllegalArgumentException(s"$t is not a valid representation of a for loop enumerator")
case Nil if first =>
throw new IllegalArgumentException("enumerators can't be empty")
case Nil =>
}
validate(enums, first = true)
enums
}

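
A speculative end-to-end sketch of what the extractors above are for: constructing a for-yield tree from a `with`-wrapped enumerator list (which exercises the new `mkEnumerators` validation and the product-based `mkFor`), then matching it back (which goes through `UnFor`/`UnForStep`). Whether the round trip recovers exactly the original enumerator list is an assumption of this sketch, not something the diff asserts; all names below other than the extractors are invented.

import scala.reflect.runtime.universe._

object ForRoundTripSketch extends App {
  val rs = internal.reificationSupport
  import rs._

  val enums = List(
    SyntacticWith(SyntacticValFrom(pq"x", q"xs")),  // x <- xs with ...
    SyntacticValFrom(pq"y", q"ys"),                 // ... y <- ys
    SyntacticFilter(q"x > y"))                      // if x > y

  val tree = SyntacticForYield(enums, q"x + y")     // validated by mkEnumerators, desugared by mkFor
  println(showCode(tree))                           // expected shape: xs.product(ys).withFilter(...).map(...)

  tree match {
    case SyntacticForYield(recovered, _) =>         // driven by UnFor / UnForStep
      recovered.foreach(e => println(showRaw(e)))
    case _ =>
      println("round trip did not recover a for-yield")
  }
}
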
2 changes: 2 additions & 0 deletions src/reflect/scala/reflect/internal/StdNames.scala
@@ -748,6 +748,7 @@ trait StdNames {
val null_ : NameType = "null"
val pendingSuperCall: NameType = "pendingSuperCall"
val prefix : NameType = "prefix"
val product: NameType = "product"
val productArity: NameType = "productArity"
val productElement: NameType = "productElement"
val productIterator: NameType = "productIterator"
@@ -852,6 +853,7 @@ trait StdNames {
val SyntacticValEq: NameType = "SyntacticValEq"
val SyntacticValFrom: NameType = "SyntacticValFrom"
val SyntacticVarDef: NameType = "SyntacticVarDef"
val SyntacticWith: NameType = "SyntacticWith"

// unencoded operators
object raw {
48 changes: 44 additions & 4 deletions src/reflect/scala/reflect/internal/TreeGen.scala
@@ -577,6 +577,25 @@ abstract class TreeGen {
}
}

object With {
def apply(tree: Tree) =
Select(tree, nme.WITHkw).updateAttachment(ForAttachment)
def unapply(tree: Tree): Option[Tree] = tree match {
case Select(enum, nme.WITHkw)
if tree.hasAttachment[ForAttachment.type] => Some(enum)
case _ => None
}
}

object MaybeWith {
def apply(tree: Tree, prod: Boolean) =
if (prod) With(tree) else tree
def unapply(tree: Tree): Option[(Tree, Boolean)] = tree match {
case With(t) => Some((t, true))
case t => Some((t, false))
}
}

/** Encode/decode body of for yield loop as q"`yield`($tree)" */
object Yield {
def apply(tree: Tree): Tree =
@@ -692,14 +711,35 @@ abstract class TreeGen {
rangePos(genpos.source, genpos.start, genpos.point, end)
}

def makeProduct(posOfWith: Position, lhs: Tree, rhs: Tree): Tree = {
val selectPos =
if (posOfWith == NoPosition) NoPosition
else if(lhs.pos == NoPosition) posOfWith
else rangePos(posOfWith.source, lhs.pos.start, posOfWith.point, posOfWith.end)
// there's no way to make these positions non-overlapping with patterns so they must be transparent
Apply(
Select(lhs, nme.product).setPos(selectPos.makeTransparent).updateAttachment(ForAttachment),
List(rhs)
).setPos(wrappingPos(posOfWith, List(lhs, rhs)).makeTransparent)
}

def maybeMakeWith(tree: Tree, posOfWith: Position, wrap: Boolean) =
if(wrap) With(tree).setPos(posOfWith union tree.pos) else tree

enums match {
case (t @ ValFrom(pat, rhs)) :: Nil =>
makeCombination(closurePos(t.pos), mapName, rhs, pat, body)
case (t @ ValFrom(pat, rhs)) :: (rest @ (ValFrom(_, _) :: _)) =>
makeCombination(closurePos(t.pos), flatMapName, rhs, pat,
mkFor(rest, sugarBody))
case (wt1 @ With(t1 @ ValFrom(pat1, rhs1))) :: (wt2 @ MaybeWith(t2 @ ValFrom(pat2, rhs2), hasWith)) :: rest =>
val pat = atPos((pat1.pos union pat2.pos).makeTransparent) { mkTuple(List(pat1, pat2)) }
val rhs = makeProduct(wt1.pos, rhs1, rhs2)
val combined = maybeMakeWith(ValFrom(pat, rhs).setPos(t1.pos union t2.pos), wt2.pos, hasWith)
mkFor(combined :: rest, sugarBody)
case (t @ ValFrom(pat, rhs)) :: (rest @ MaybeWith(ValFrom(_, _), _) :: _) =>
makeCombination(closurePos(t.pos), flatMapName, rhs, pat, mkFor(rest, sugarBody))
case (t @ ValFrom(pat, rhs)) :: Filter(test) :: rest =>
mkFor(ValFrom(pat, makeCombination(rhs.pos union test.pos, nme.withFilter, rhs, pat.duplicate, test)).setPos(t.pos) :: rest, sugarBody)
val filteredRhs = makeCombination(rhs.pos union test.pos, nme.withFilter, rhs, pat.duplicate, test)
val filtered = ValFrom(pat, filteredRhs).setPos(t.pos union test.pos)
mkFor(filtered :: rest, sugarBody)
case (t @ ValFrom(pat, rhs)) :: rest =>
val valeqs = rest.take(definitions.MaxTupleArity - 1).takeWhile { ValEq.unapply(_).nonEmpty }
assert(!valeqs.isEmpty)
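
The `With`/`MaybeWith` pairing in `mkFor` above combines `with`-joined generators two at a time, left to right, so both the `product` calls and the tuple patterns nest to the left. A hand-expanded sketch of that shape for three generators (an assumption about the intended expansion, not output copied from this branch; `Id` and the value names are invented):

final case class Id[A](run: A) {
  def map[B](f: A => B): Id[B] = Id(f(run))
  def product[B](that: Id[B]): Id[(A, B)] = Id((run, that.run))
}

object TripleWithSketch extends App {
  val (fa, fb, fc) = (Id(1), Id(2), Id(3))

  // for (a <- fa with b <- fb with c <- fc) yield a + b + c
  // is expected to pair fa with fb first, then pair that result with fc:
  val expanded = fa.product(fb).product(fc).map { case ((a, b), c) => a + b + c }

  println(expanded) // Id(6)
}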