From 74de6d57c9930b3639ed0c65b51bc01b3807cad3 Mon Sep 17 00:00:00 2001 From: Som Snytt Date: Sun, 4 Jul 2021 02:39:02 -0700 Subject: [PATCH] Accept supplementary characters --- build.sbt | 1 + .../scala/tools/nsc/ast/parser/Scanners.scala | 120 +++++++++++++++--- .../symtab/classfile/AbstractFileReader.scala | 10 +- src/partest/scala/tools/partest/package.scala | 13 -- .../scala/tools/testkit/AssertUtil.scala | 19 +++ test/files/neg/surrogates.check | 4 + test/files/neg/surrogates.scala | 4 + test/files/pos/surrogates.scala | 12 ++ test/files/run/t12276.scala | 3 +- test/files/run/t9915/Test_2.scala | 14 +- .../scala/tools/testkit/AssertUtilTest.scala | 6 + 11 files changed, 162 insertions(+), 44 deletions(-) create mode 100644 test/files/neg/surrogates.check create mode 100644 test/files/neg/surrogates.scala create mode 100644 test/files/pos/surrogates.scala diff --git a/build.sbt b/build.sbt index 82208895c7ad..cf8270840c5c 100644 --- a/build.sbt +++ b/build.sbt @@ -717,6 +717,7 @@ lazy val junit = project.in(file("test") / "junit") "-feature", "-Xlint:-valpattern,_", "-Wconf:msg=match may not be exhaustive:s", // if we missed a case, all that happens is the test fails + "-Wconf:cat=lint-nullary-unit:s", // normal unit test style "-Ypatmat-exhaust-depth", "40", // despite not caring about patmat exhaustiveness, we still get warnings for this ), Compile / javacOptions ++= Seq("-Xlint"), diff --git a/src/compiler/scala/tools/nsc/ast/parser/Scanners.scala b/src/compiler/scala/tools/nsc/ast/parser/Scanners.scala index 17b46da9191c..6c8a0cf3abf0 100644 --- a/src/compiler/scala/tools/nsc/ast/parser/Scanners.scala +++ b/src/compiler/scala/tools/nsc/ast/parser/Scanners.scala @@ -700,7 +700,10 @@ trait Scanners extends ScannersCommon { else if (!isAtEnd && (ch != SU && ch != CR && ch != LF)) { val isEmptyCharLit = (ch == '\'') getLitChar() - if (ch == '\'') { + if (Character.isHighSurrogate(cbuf.charAt(0))) { + syntaxError("illegal codepoint in Char constant") + if (ch == 
'\'') nextChar() + } else if (ch == '\'') { if (isEmptyCharLit) syntaxError("empty character literal (use '\\'' for single quote)") else { @@ -749,22 +752,40 @@ trait Scanners extends ScannersCommon { } case _ => def fetchOther() = { + import Character.{isHighSurrogate, isLowSurrogate, isUnicodeIdentifierStart, isValidCodePoint, toCodePoint} if (ch == '\u21D2') { deprecationWarning("The unicode arrow `⇒` is deprecated, use `=>` instead. If you still wish to display it as one character, consider using a font with programming ligatures such as Fira Code.", "2.13.0") nextChar(); token = ARROW } else if (ch == '\u2190') { deprecationWarning("The unicode arrow `←` is deprecated, use `<-` instead. If you still wish to display it as one character, consider using a font with programming ligatures such as Fira Code.", "2.13.0") nextChar(); token = LARROW - } else if (Character.isUnicodeIdentifierStart(ch)) { + } else if (isUnicodeIdentifierStart(ch)) { putChar(ch) nextChar() getIdentRest() + } else if (isHighSurrogate(ch)) { + val high = ch + nextChar() + if (isLowSurrogate(ch)) { + val low = ch + nextChar() + val codepoint = toCodePoint(high, low) + if (isValidCodePoint(codepoint) && isUnicodeIdentifierStart(codepoint)) { + putChar(high) + putChar(low) + getIdentRest() + } else { + syntaxError(f"illegal character '\\u${high.toInt}%04x\\u${low.toInt}%04x'") + } + } else { + syntaxError(f"illegal character '\\u${high.toInt}%04x' missing low surrogate") + } } else if (isSpecial(ch)) { putChar(ch) nextChar() getOperatorRest() } else { - syntaxError("illegal character '" + ("" + '\\' + 'u' + "%04x".format(ch.toInt)) + "'") + syntaxError(f"illegal character '\\u${ch.toInt}%04x'") nextChar() } } @@ -831,10 +852,28 @@ trait Scanners extends ScannersCommon { case SU => // strangely enough, Character.isUnicodeIdentifierPart(SU) returns true! 
finishNamed() case _ => - if (Character.isUnicodeIdentifierPart(ch)) { + import Character.{isHighSurrogate, isLowSurrogate, isUnicodeIdentifierPart, isValidCodePoint, toCodePoint} + if (isUnicodeIdentifierPart(ch)) { putChar(ch) nextChar() getIdentRest() + } else if (isHighSurrogate(ch)) { + val high = ch + nextChar() + if (isLowSurrogate(ch)) { + val low = ch + nextChar() + val codepoint = toCodePoint(high, low) + if (isValidCodePoint(codepoint) && isUnicodeIdentifierPart(codepoint)) { + putChar(high) + putChar(low) + getIdentRest() + } else { + syntaxError(f"illegal character '\\u${high.toInt}%04x\\u${low.toInt}%04x'") + } + } else { + syntaxError(f"illegal character '\\u${high.toInt}%04x' missing low surrogate") + } } else { finishNamed() } @@ -955,6 +994,38 @@ trait Scanners extends ScannersCommon { } getStringPart(multiLine, seenEscapedQuote || q) } else if (ch == '$') { + import Character.{isHighSurrogate, isLowSurrogate, isUnicodeIdentifierPart, isUnicodeIdentifierStart, isValidCodePoint, toCodePoint} + def isUnicodeSurrogate(ch: Char, f: Int => Boolean): Boolean = + isHighSurrogate(ch) && { + val hi = ch + val r = lookaheadReader + r.nextRawChar() + val lo = r.ch + isLowSurrogate(lo) && { + val codepoint = toCodePoint(hi, lo) + isValidCodePoint(codepoint) && f(codepoint) + } + } + @tailrec def getInterpolatedIdentRest(): Unit = + if (ch != SU && isUnicodeIdentifierPart(ch)) { + putChar(ch) + nextRawChar() + getInterpolatedIdentRest() + } else if (isUnicodeSurrogate(ch, isUnicodeIdentifierPart)) { + putChar(ch) + nextRawChar() + putChar(ch) + nextRawChar() + getInterpolatedIdentRest() + } else { + next.token = IDENTIFIER + next.name = newTermName(cbuf.toString) + cbuf.clear() + val idx = next.name.start - kwOffset + if (idx >= 0 && idx < kwArray.length) { + next.token = kwArray(idx) + } + } nextRawChar() if (ch == '$' || ch == '"') { putChar(ch) @@ -968,19 +1039,18 @@ trait Scanners extends ScannersCommon { finishStringPart() nextRawChar() next.token = 
USCORE - } else if (Character.isUnicodeIdentifierStart(ch)) { + } else if (isUnicodeIdentifierStart(ch)) { finishStringPart() - do { - putChar(ch) - nextRawChar() - } while (ch != SU && Character.isUnicodeIdentifierPart(ch)) - next.token = IDENTIFIER - next.name = newTermName(cbuf.toString) - cbuf.clear() - val idx = next.name.start - kwOffset - if (idx >= 0 && idx < kwArray.length) { - next.token = kwArray(idx) - } + putChar(ch) + nextRawChar() + getInterpolatedIdentRest() + } else if (isUnicodeSurrogate(ch, isUnicodeIdentifierStart)) { + finishStringPart() + putChar(ch) + nextRawChar() + putChar(ch) + nextRawChar() + getInterpolatedIdentRest() } else { val expectations = "$$, $\", $identifier or ${expression}" syntaxError(s"invalid string interpolation $$$ch, expected: $expectations") @@ -1068,7 +1138,23 @@ trait Scanners extends ScannersCommon { nextChar() } } - } else { + } else if (Character.isHighSurrogate(ch)) { + val high = ch + nextChar() + if (Character.isLowSurrogate(ch)) { + val low = ch + nextChar() + val codepoint = Character.toCodePoint(high, low) + if (Character.isValidCodePoint(codepoint)) { + putChar(high) + putChar(low) + } else { + syntaxError(f"illegal character '\\u${high.toInt}%04x\\u${low.toInt}%04x'") + } + } else { + syntaxError(f"illegal character '\\u${high.toInt}%04x' missing low surrogate") + } + } else { putChar(ch) nextChar() } diff --git a/src/compiler/scala/tools/nsc/symtab/classfile/AbstractFileReader.scala b/src/compiler/scala/tools/nsc/symtab/classfile/AbstractFileReader.scala index ca1378e6c87e..faf69d5769e3 100644 --- a/src/compiler/scala/tools/nsc/symtab/classfile/AbstractFileReader.scala +++ b/src/compiler/scala/tools/nsc/symtab/classfile/AbstractFileReader.scala @@ -27,9 +27,7 @@ import scala.tools.nsc.io.AbstractFile */ final class AbstractFileReader(val buf: Array[Byte]) extends DataReader { @deprecated("Use other constructor", "2.13.0") - def this(file: AbstractFile) = { - this(file.toByteArray) - } + def this(file: 
AbstractFile) = this(file.toByteArray) /** the current input pointer */ @@ -67,9 +65,8 @@ final class AbstractFileReader(val buf: Array[Byte]) extends DataReader { def getByte(mybp: Int): Byte = buf(mybp) - def getBytes(mybp: Int, bytes: Array[Byte]): Unit = { + def getBytes(mybp: Int, bytes: Array[Byte]): Unit = System.arraycopy(buf, mybp, bytes, 0, bytes.length) - } /** extract a character at position bp from buf */ @@ -95,9 +92,8 @@ final class AbstractFileReader(val buf: Array[Byte]) extends DataReader { */ def getDouble(mybp: Int): Double = longBitsToDouble(getLong(mybp)) - def getUTF(mybp: Int, len: Int): String = { + def getUTF(mybp: Int, len: Int): String = new DataInputStream(new ByteArrayInputStream(buf, mybp, len)).readUTF - } /** skip next 'n' bytes */ diff --git a/src/partest/scala/tools/partest/package.scala b/src/partest/scala/tools/partest/package.scala index d3e5f070eed9..0599766c141a 100644 --- a/src/partest/scala/tools/partest/package.scala +++ b/src/partest/scala/tools/partest/package.scala @@ -180,17 +180,4 @@ package object partest { def isDebug = sys.props.contains("partest.debug") || sys.env.contains("PARTEST_DEBUG") def debugSettings = sys.props.getOrElse("partest.debug.settings", "") def log(msg: => Any): Unit = if (isDebug) Console.err.println(msg) - - private val printable = raw"\p{Print}".r - - def hexdump(s: String): Iterator[String] = { - var offset = 0 - def hex(bytes: Array[Byte]) = bytes.map(b => f"$b%02x").mkString(" ") - def charFor(byte: Byte): Char = byte.toChar match { case c @ printable() => c ; case _ => '.' 
} - def ascii(bytes: Array[Byte]) = bytes.map(charFor).mkString - def format(bytes: Array[Byte]): String = - f"$offset%08x ${hex(bytes.slice(0, 8))}%-24s ${hex(bytes.slice(8, 16))}%-24s |${ascii(bytes)}|" - .tap(_ => offset += bytes.length) - s.getBytes(codec.charSet).grouped(16).map(format) - } } diff --git a/src/testkit/scala/tools/testkit/AssertUtil.scala b/src/testkit/scala/tools/testkit/AssertUtil.scala index 4b7083d83e2c..f6087c22d258 100644 --- a/src/testkit/scala/tools/testkit/AssertUtil.scala +++ b/src/testkit/scala/tools/testkit/AssertUtil.scala @@ -51,6 +51,25 @@ object AssertUtil { // junit fail is Unit def fail(message: String): Nothing = throw new AssertionError(message) + private val printable = raw"\p{Print}".r + + def hexdump(s: String): Iterator[String] = { + import scala.io.Codec + val codec: Codec = Codec.UTF8 + var offset = 0 + def hex(bytes: Array[Byte]) = bytes.map(b => f"$b%02x").mkString(" ") + def charFor(byte: Byte): Char = byte.toChar match { case c @ printable() => c ; case _ => '.' 
} + def ascii(bytes: Array[Byte]) = bytes.map(charFor).mkString + def format(bytes: Array[Byte]): String = + f"$offset%08x ${hex(bytes.slice(0, 8))}%-24s ${hex(bytes.slice(8, 16))}%-24s |${ascii(bytes)}|" + .tap(_ => offset += bytes.length) + s.getBytes(codec.charSet).grouped(16).map(format) + } + + private def dump(s: String) = hexdump(s).mkString("\n") + def assertEqualStrings(expected: String)(actual: String) = + assert(expected == actual, s"Expected:\n${dump(expected)}\nActual:\n${dump(actual)}") + private final val timeout = 60 * 1000L // wait a minute private implicit class `ref helper`[A](val r: Reference[A]) extends AnyVal { diff --git a/test/files/neg/surrogates.check b/test/files/neg/surrogates.check new file mode 100644 index 000000000000..ed95ed3eca5c --- /dev/null +++ b/test/files/neg/surrogates.check @@ -0,0 +1,4 @@ +surrogates.scala:3: error: illegal codepoint in Char constant + def c = '𐐀' + ^ +1 error diff --git a/test/files/neg/surrogates.scala b/test/files/neg/surrogates.scala new file mode 100644 index 000000000000..f9c91438627f --- /dev/null +++ b/test/files/neg/surrogates.scala @@ -0,0 +1,4 @@ + +class C { + def c = '𐐀' +} diff --git a/test/files/pos/surrogates.scala b/test/files/pos/surrogates.scala new file mode 100644 index 000000000000..24db6cf86215 --- /dev/null +++ b/test/files/pos/surrogates.scala @@ -0,0 +1,12 @@ + +class 𐐀 { + def 𐐀 = 42 + def x = "𐐀" + def y = s"$𐐀" +} + +case class 𐐀𐐀(n: Int) { + def 𐐀𐐀 = n +} + +// was: error: illegal character '\ud801', '\udc00' diff --git a/test/files/run/t12276.scala b/test/files/run/t12276.scala index 50ef6b0edc5e..36fbbbc6c558 100644 --- a/test/files/run/t12276.scala +++ b/test/files/run/t12276.scala @@ -1,6 +1,7 @@ import scala.tools.nsc.Settings import scala.tools.nsc.interpreter.shell.{ILoop, ShellConfig} -import scala.tools.partest.{hexdump, ReplTest} +import scala.tools.partest.ReplTest +import scala.tools.testkit.AssertUtil.hexdump object Test extends ReplTest { def code = s""" diff 
--git a/test/files/run/t9915/Test_2.scala b/test/files/run/t9915/Test_2.scala index afed667cc6e5..f26f1c1a3d91 100644 --- a/test/files/run/t9915/Test_2.scala +++ b/test/files/run/t9915/Test_2.scala @@ -1,12 +1,14 @@ +import scala.tools.testkit.AssertUtil.assertEqualStrings + object Test extends App { val c = new C_1 - assert(c.nulled == "X\u0000ABC") // "X\000ABC" - assert(c.supped == "𐒈𐒝𐒑𐒛𐒐𐒘𐒕𐒖") + assert(C_1.NULLED.length == "XYABC".length) + assert(C_1.SUPPED.codePointCount(0, C_1.SUPPED.length) == 8) - assert(C_1.NULLED == "X\u0000ABC") // "X\000ABC" - assert(C_1.SUPPED == "𐒈𐒝𐒑𐒛𐒐𐒘𐒕𐒖") + assertEqualStrings("X\u0000ABC")(c.nulled) // "X\000ABC" in java source + assertEqualStrings("𐒈𐒝𐒑𐒛𐒐𐒘𐒕𐒖")(c.supped) - assert(C_1.NULLED.size == "XYABC".size) - assert(C_1.SUPPED.codePointCount(0, C_1.SUPPED.length) == 8) + assertEqualStrings("X\u0000ABC")(C_1.NULLED) // "X\000ABC" in java source + assertEqualStrings("𐒈𐒝𐒑𐒛𐒐𐒘𐒕𐒖")(C_1.SUPPED) } diff --git a/test/junit/scala/tools/testkit/AssertUtilTest.scala b/test/junit/scala/tools/testkit/AssertUtilTest.scala index 98e2c0308553..90e98e1598e3 100644 --- a/test/junit/scala/tools/testkit/AssertUtilTest.scala +++ b/test/junit/scala/tools/testkit/AssertUtilTest.scala @@ -110,4 +110,10 @@ class AssertUtilTest { assertEquals(1, sut.errors.size) assertEquals(0, sut.errors.head._2.getSuppressed.length) } + + /** TODO + @Test def `hexdump is supplementary-aware`: Unit = { + assertEquals("00000000 f0 90 90 80 |𐐀.|", hexdump("\ud801\udc00").next()) + } + */ }