Skip to content

Commit

Permalink
Fix #8715: enforce syntax for _* (alternate version)
Browse files Browse the repository at this point in the history
Previously we didn't check that _* is indeed the last argument,
checking for ")" is not enough, as ")" may be the closing parenthesis
of a nested pattern.
  • Loading branch information
odersky committed Apr 20, 2020
1 parent b3f2aee commit 4639991
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 73 deletions.
2 changes: 1 addition & 1 deletion community-build/community-projects/stdLib213
83 changes: 46 additions & 37 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,13 @@ object Parsers {
def nonePositive: Boolean = parCounts forall (_ <= 0)
}

@sharable object Location extends Enumeration {
val InParens, InBlock, InPattern, ElseWhere: Value = Value
}
enum Location(val inParens: Boolean, val inPattern: Boolean, val inArgs: Boolean):
case InParens extends Location(true, false, false)
case InArgs extends Location(true, false, true)
case InPattern extends Location(false, true, false)
case InPatternArgs extends Location(false, true, true) // InParens not true, since it might be an alternative
case InBlock extends Location(false, false, false)
case ElseWhere extends Location(false, false, false)

@sharable object ParamOwner extends Enumeration {
val Class, Type, TypeParam, Def: Value = Value
Expand Down Expand Up @@ -1754,9 +1758,9 @@ object Parsers {
else TypeTree().withSpan(Span(in.lastOffset))
}

def typeDependingOn(location: Location.Value): Tree =
if (location == Location.InParens) typ()
else if (location == Location.InPattern) refinedType()
def typeDependingOn(location: Location): Tree =
if location.inParens then typ()
else if location.inPattern then refinedType()
else infixType()

/* ----------- EXPRESSIONS ------------------------------------------------ */
Expand Down Expand Up @@ -1843,7 +1847,7 @@ object Parsers {

def subExpr() = subPart(expr)

def expr(location: Location.Value): Tree = {
def expr(location: Location): Tree = {
val start = in.offset
def isSpecialClosureStart =
val lookahead = in.LookaheadScanner()
Expand Down Expand Up @@ -1876,7 +1880,7 @@ object Parsers {
}
}

def expr1(location: Location.Value = Location.ElseWhere): Tree = in.token match
def expr1(location: Location = Location.ElseWhere): Tree = in.token match
case IF =>
in.endMarkerScope(IF) { ifExpr(in.offset, If) }
case WHILE =>
Expand Down Expand Up @@ -1989,11 +1993,13 @@ object Parsers {
else expr1Rest(postfixExpr(), location)
end expr1

def expr1Rest(t: Tree, location: Location.Value): Tree = in.token match
def expr1Rest(t: Tree, location: Location): Tree = in.token match
case EQUALS =>
t match
case Ident(_) | Select(_, _) | Apply(_, _) =>
atSpan(startOffset(t), in.skipToken()) { Assign(t, subExpr()) }
atSpan(startOffset(t), in.skipToken()) {
Assign(t, subPart(() => expr(location)))
}
case _ =>
t
case COLON =>
Expand All @@ -2003,24 +2009,29 @@ object Parsers {
t
end expr1Rest

def ascription(t: Tree, location: Location.Value): Tree = atSpan(startOffset(t)) {
def ascription(t: Tree, location: Location): Tree = atSpan(startOffset(t)) {
in.token match {
case USCORE =>
val uscoreStart = in.skipToken()
if (isIdent(nme.raw.STAR)) {
if isIdent(nme.raw.STAR) then
in.nextToken()
if (in.token != RPAREN) syntaxError(SeqWildcardPatternPos(), uscoreStart)
if !(location.inArgs && in.token == RPAREN) then
if opStack.nonEmpty
ctx.errorOrMigrationWarning(
em"""`_*` can be used only for last argument of method application.
|It is no longer allowed in operands of infix operations.""",
in.sourcePos(uscoreStart))
else
syntaxError(SeqWildcardPatternPos(), uscoreStart)
Typed(t, atSpan(uscoreStart) { Ident(tpnme.WILDCARD_STAR) })
}
else {
else
syntaxErrorOrIncomplete(IncorrectRepeatedParameterSyntax())
t
}
case AT if location != Location.InPattern =>
case AT if !location.inPattern =>
annotations().foldLeft(t)(Annotated)
case _ =>
val tpt = typeDependingOn(location)
if (isWildcard(t) && location != Location.InPattern) {
if (isWildcard(t) && !location.inPattern) {
val vd :: rest = placeholderParams
placeholderParams =
cpy.ValDef(vd)(tpt = tpt).withSpan(vd.span.union(tpt.span)) :: rest
Expand Down Expand Up @@ -2063,7 +2074,7 @@ object Parsers {
* | `_'
* Bindings ::= `(' [[‘using’] [‘erased’] Binding {`,' Binding}] `)'
*/
def funParams(mods: Modifiers, location: Location.Value): List[Tree] =
def funParams(mods: Modifiers, location: Location): List[Tree] =
if in.token == LPAREN then
in.nextToken()
if in.token == RPAREN then
Expand Down Expand Up @@ -2117,10 +2128,10 @@ object Parsers {
/** Expr ::= [‘implicit’] FunParams `=>' Expr
* BlockResult ::= implicit id [`:' InfixType] `=>' Block // Scala2 only
*/
def closure(start: Int, location: Location.Value, implicitMods: Modifiers): Tree =
def closure(start: Int, location: Location, implicitMods: Modifiers): Tree =
closureRest(start, location, funParams(implicitMods, location))

def closureRest(start: Int, location: Location.Value, params: List[Tree]): Tree =
def closureRest(start: Int, location: Location, params: List[Tree]): Tree =
atSpan(start, in.offset) {
if in.token == CTXARROW then in.nextToken() else accept(ARROW)
Function(params, if (location == Location.InBlock) block() else expr())
Expand Down Expand Up @@ -2295,10 +2306,9 @@ object Parsers {
if args._2 then res.setUsingApply()
res

val argumentExpr: () => Tree = () => exprInParens() match {
val argumentExpr: () => Tree = () => expr(Location.InArgs) match
case arg @ Assign(Ident(id), rhs) => cpy.NamedArg(arg)(id, rhs)
case arg => arg
}

/** ArgumentExprss ::= {ArgumentExprs}
*/
Expand Down Expand Up @@ -2535,21 +2545,20 @@ object Parsers {

/** Pattern ::= Pattern1 { `|' Pattern1 }
*/
val pattern: () => Tree = () => {
val pat = pattern1()
def pattern(location: Location = Location.InPattern): Tree =
val pat = pattern1(location)
if (isIdent(nme.raw.BAR))
atSpan(startOffset(pat)) { Alternative(pat :: patternAlts()) }
atSpan(startOffset(pat)) { Alternative(pat :: patternAlts(location)) }
else pat
}

def patternAlts(): List[Tree] =
if (isIdent(nme.raw.BAR)) { in.nextToken(); pattern1() :: patternAlts() }
def patternAlts(location: Location): List[Tree] =
if (isIdent(nme.raw.BAR)) { in.nextToken(); pattern1(location) :: patternAlts(location) }
else Nil

/** Pattern1 ::= Pattern2 [Ascription]
* | ‘given’ PatVar ‘:’ RefinedType
*/
def pattern1(): Tree =
def pattern1(location: Location = Location.InPattern): Tree =
if (in.token == GIVEN) {
val givenMod = atSpan(in.skipToken())(Mod.Given())
atSpan(in.offset) {
Expand All @@ -2558,7 +2567,7 @@ object Parsers {
val name = in.name
in.nextToken()
accept(COLON)
val typed = ascription(Ident(nme.WILDCARD), Location.InPattern)
val typed = ascription(Ident(nme.WILDCARD), location)
Bind(name, typed).withMods(addMod(Modifiers(), givenMod))
case _ =>
syntaxErrorOrIncomplete("pattern variable expected")
Expand All @@ -2570,7 +2579,7 @@ object Parsers {
val p = pattern2()
if (in.token == COLON) {
in.nextToken()
ascription(p, Location.InPattern)
ascription(p, location)
}
else p
}
Expand Down Expand Up @@ -2661,17 +2670,17 @@ object Parsers {

/** Patterns ::= Pattern [`,' Pattern]
*/
def patterns(): List[Tree] = commaSeparated(pattern)

def patternsOpt(): List[Tree] =
if (in.token == RPAREN) Nil else patterns()
def patterns(location: Location = Location.InPattern): List[Tree] =
commaSeparated(() => pattern(location))

def patternsOpt(location: Location = Location.InPattern): List[Tree] =
if (in.token == RPAREN) Nil else patterns(location)

/** ArgumentPatterns ::= ‘(’ [Patterns] ‘)’
* | ‘(’ [Patterns ‘,’] Pattern2 ‘:’ ‘_’ ‘*’ ‘)’
*/
def argumentPatterns(): List[Tree] =
inParens(patternsOpt())
inParens(patternsOpt(Location.InPatternArgs))

/* -------- MODIFIERS and ANNOTATIONS ------------------------------------------- */

Expand Down
2 changes: 1 addition & 1 deletion compiler/test/dotty/tools/vulpix/ParallelTesting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ trait ParallelTesting extends RunnerOrchestration { self =>
lazy val actualErrors = reporters.foldLeft(0)(_ + _.errorCount)
def hasMissingAnnotations = getMissingExpectedErrors(errorMap, reporters.iterator.flatMap(_.errors))
def showErrors = "-> following the errors:\n" +
reporters.flatMap(_.allErrors.map(e => e.pos.toString + ": " + e.message)).mkString(start = "at ", sep = "\n at ", end = "")
reporters.flatMap(_.allErrors.map(e => e.pos.line.toString + ": " + e.message)).mkString(start = "at ", sep = "\n at ", end = "")

if (compilerCrashed) Some(s"Compiler crashed when compiling: ${testSource.title}")
else if (actualErrors == 0) Some(s"\nNo errors found when compiling neg test $testSource")
Expand Down
10 changes: 5 additions & 5 deletions tests/neg/i7972.scala
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
object O {
def m1(a: Int*) = (a: _*) // error: Cannot return repeated parameter type Int*
def m1(a: Int*) = (a: _*) // error // error: Cannot return repeated parameter type Int*
def m2(a: Int*) = { // error: Cannot return repeated parameter type Int*
val b = (a: _*) // error: Cannot return repeated parameter type Int*
val b = (a: _*) // error // error: Cannot return repeated parameter type Int*
b
}
def m3(a: Int*): Any = {
val b = (a: _*) // error: Cannot return repeated parameter type Int*
val b = (a: _*) // error // error: Cannot return repeated parameter type Int*
b
}
def m4(a: 2*) = (a: _*) // error: Cannot return repeated parameter type Int*
def m4(a: 2*) = (a: _*) // error // error: Cannot return repeated parameter type Int*

}

class O(a: Int*) {
val m = (a: _*) // error: Cannot return repeated parameter type Int*
val m = (a: _*) // error // error: Cannot return repeated parameter type Int*
}

2 changes: 2 additions & 0 deletions tests/neg/i8715.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
@main
def Test = List(42) match { case List(xs @ (ys: _*)) => xs } // error
29 changes: 29 additions & 0 deletions tests/neg/t5702-neg-bad-and-wild.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@

object Test {
case class K(i: Int)

def main(args: Array[String]) = {
val k = new K(9)
val is = List(1,2,3)

is match {
case List(1, _*,) => // error // error // error: bad use of _* (a sequence pattern must be the last pattern)
// illegal start of simple pattern
case List(1, _*3,) => // error // error: illegal start of simple pattern
//case List(1, _*3:) => // poor recovery by parens
case List(1, x*) => // error: use _* to match a sequence
case List(x*, 1) => // error: trailing * is not a valid pattern
case (1, x*) => // error: trailing * is not a valid pattern
case (1, x: _*) => // error: bad use of _* (sequence pattern not allowed)
}

// good syntax, bad semantics, detected by typer
//gowild.scala:14: error: star patterns must correspond with varargs parameters
val K(x @ _*) = k
val K(ns @ _*, x) = k // error: bad use of _* (a sequence pattern must be the last pattern)
val (b, _ : _* ) = (5,6) // error: bad use of _* (sequence pattern not allowed)
// no longer complains
//bad-and-wild.scala:15: error: ')' expected but '}' found.
}
}

4 changes: 4 additions & 0 deletions tests/pos-scala2/i8715b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
// from stdlib
class Test {
def printf(text: String, args: Any*): Unit = { System.out.print(text format (args : _*)) }
}
29 changes: 0 additions & 29 deletions tests/untried/neg/t5702-neg-bad-and-wild.scala

This file was deleted.

0 comments on commit 4639991

Please sign in to comment.