Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Concatenate strings if not formatting [ci: last-only] #10364

Merged
merged 9 commits into from Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
81 changes: 49 additions & 32 deletions src/compiler/scala/tools/reflect/FastStringInterpolator.scala
Expand Up @@ -70,40 +70,20 @@ trait FastStringInterpolator extends FormatInterpolator {
case iue: StringContext.InvalidUnicodeEscapeException => c.abort(parts.head.pos.withShift(iue.index), iue.getMessage)
}

val argsIndexed = args.toVector
val concatArgs = collection.mutable.ListBuffer[Tree]()
val numLits = parts.length
foreachWithIndex(treated.tail) { (lit, i) =>
val treatedContents = lit.asInstanceOf[Literal].value.stringValue
val emptyLit = treatedContents.isEmpty
if (i < numLits - 1) {
concatArgs += argsIndexed(i)
if (!emptyLit) concatArgs += lit
} else if (!emptyLit) {
concatArgs += lit
if (args.forall(treeInfo.isLiteralString)) {
val it1 = treated.iterator
val it2 = args.iterator
val res = new StringBuilder
def add(t: Tree) = res.append(t.asInstanceOf[Literal].value.value)
add(it1.next())
while (it2.hasNext) {
add(it2.next())
add(it1.next())
}
val k = Constant(res.toString)
Literal(k).setType(ConstantType(k))
}
def mkConcat(pos: Position, lhs: Tree, rhs: Tree): Tree =
atPos(pos)(gen.mkMethodCall(gen.mkAttributedSelect(lhs, definitions.String_+), rhs :: Nil)).setType(definitions.StringTpe)

var result: Tree = treated.head
val chunkSize = 32
if (concatArgs.lengthCompare(chunkSize) <= 0) {
concatArgs.foreach { t =>
result = mkConcat(t.pos, result, t)
}
} else {
concatArgs.toList.grouped(chunkSize).foreach {
case group =>
var chunkResult: Tree = Literal(Constant("")).setType(definitions.StringTpe)
group.foreach { t =>
chunkResult = mkConcat(t.pos, chunkResult, t)
}
result = mkConcat(chunkResult.pos, result, chunkResult)
}
}

result
else concatenate(treated, args)

// Fallback -- inline the original implementation of the `s` or `raw` interpolator.
case t@Apply(Select(someStringContext, _interpol), args) =>
Expand All @@ -116,4 +96,41 @@ trait FastStringInterpolator extends FormatInterpolator {
}"""
case x => throw new MatchError(x)
}

/** Builds a tree that joins `parts` and `args` with `String_+`, interleaved as
 *  `parts(0) + args(0) + parts(1) + args(1) + ...`.
 *
 *  Literal parts after the first are omitted when their string value is empty;
 *  the first part is always kept as the left-most operand.
 *
 *  @param parts the literal pieces; every element after the head is cast to `Literal`
 *  @param args  the spliced arguments; `args(i)` is emitted immediately before `parts(i + 1)`
 *  @return `parts.head` when there is nothing to append, otherwise nested `String_+`
 *          calls, each typed as `definitions.StringTpe`
 */
def concatenate(parts: List[Tree], args: List[Tree]): Tree = {
  val argsIndexed = args.toVector
  val concatArgs = collection.mutable.ListBuffer[Tree]()
  // Every part after the first is preceded by its argument.
  // NOTE(review): assumes foreachWithIndex counts from 0; the indices over `parts.tail`
  // then stay below `parts.length - 1`, so the former `else` branch guarded by
  // `i < numLits - 1` could never run and has been dropped.
  foreachWithIndex(parts.tail) { (lit, i) =>
    concatArgs += argsIndexed(i)
    if (!lit.asInstanceOf[Literal].value.stringValue.isEmpty) concatArgs += lit
  }

  def mkConcat(pos: Position, lhs: Tree, rhs: Tree): Tree =
    atPos(pos)(gen.mkMethodCall(gen.mkAttributedSelect(lhs, definitions.String_+), rhs :: Nil)).setType(definitions.StringTpe)

  // Left-associated chain `seed + t1 + t2 + ...`, each call positioned at its right operand.
  def chain(seed: Tree, operands: Iterable[Tree]): Tree =
    operands.foldLeft(seed)((acc, t) => mkConcat(t.pos, acc, t))

  val chunkSize = 32
  if (concatArgs.lengthCompare(chunkSize) <= 0) chain(parts.head, concatArgs)
  else
    // Longer operand lists are concatenated in groups of `chunkSize`, each group seeded
    // with an empty string literal and then appended to the running result
    // (presumably to limit how deeply the generated calls nest -- confirm).
    concatArgs.toList.grouped(chunkSize).foldLeft(parts.head) { (acc, group) =>
      val seed: Tree = Literal(Constant("")).setType(definitions.StringTpe)
      val chunk = chain(seed, group)
      mkConcat(chunk.pos, acc, chunk)
    }
}
}
50 changes: 35 additions & 15 deletions src/compiler/scala/tools/reflect/FormatInterpolator.scala
Expand Up @@ -21,7 +21,6 @@ import scala.util.matching.Regex.Match
import java.util.Formattable

abstract class FormatInterpolator {

import FormatInterpolator._
import SpecifierGroups.{Value => SpecGroup, _}

Expand All @@ -34,6 +33,8 @@ abstract class FormatInterpolator {

private def bail(msg: String) = global.abort(msg)

def concatenate(parts: List[Tree], args: List[Tree]): Tree

def interpolateF: Tree = c.macroApplication match {
//case q"$_(..$parts).f(..$args)" =>
case Applied(Select(Apply(_, parts), _), _, argss) =>
Expand Down Expand Up @@ -81,6 +82,9 @@ abstract class FormatInterpolator {
val actuals = ListBuffer.empty[Tree]
val convert = ListBuffer.empty[Conversion]

// whether this format does more than concatenate strings
var formatting = false

def argType(argi: Int, types: Type*): Type = {
val tpe = argTypes(argi)
types.find(t => argConformsTo(argi, tpe, t))
Expand All @@ -94,6 +98,7 @@ abstract class FormatInterpolator {
else all.head + all.tail.map { case req(what) => what case _ => "?" }.mkString(", ", ", ", "")
}
c.error(args(argi).pos, msg)
reported = true
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could reported be a local variable next to formatting? It's not obvious when instances of FormatInterpolator are created.

There are also some invocations of c.error in escapeHatch - does that matter?

Copy link
Contributor Author

@som-snytt som-snytt Jul 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cached FastTrack mechanism results in a ctx => (new { val c = ctx } with FastStringInterpolator).interpolateF, so a new interpolator per invocation. Not obvious.

reported is also set by Conversion#errorAt, which is why it's not local. The flag is needed so that it doesn't attempt a bad compile-time format, so the bad escapes in escapeHatch should matter. I'll make a test.

Edit: there is already a test; it took me a while to realize that format() doesn't complain about bad escapes in a format string, but Scala does.

actuals += args(argi)
types.head
}
Expand All @@ -112,8 +117,8 @@ abstract class FormatInterpolator {
}

// Append the nth part to the string builder, possibly prepending an omitted %s first.
// Sanity-check the % fields in this part.
def loop(remaining: List[Tree], n: Int): Unit = {
// Check the % fields in this part.
def loop(remaining: List[Tree], n: Int): Unit =
remaining match {
case part0 :: more =>
val part1 = part0 match {
Expand All @@ -139,6 +144,8 @@ abstract class FormatInterpolator {
else if (!matches.hasNext) insertStringConversion()
else {
val cv = Conversion(matches.next(), part0.pos, argc)
if (cv.kind != Kind.StringXn || cv.cc.isUpper || cv.width.nonEmpty || cv.flags.nonEmpty)
formatting = true
if (cv.isLiteral) insertStringConversion()
else if (cv.isIndexed) {
if (cv.index.getOrElse(-1) == n) accept(cv)
Expand All @@ -155,16 +162,23 @@ abstract class FormatInterpolator {
val cv = Conversion(matches.next(), part0.pos, argc)
if (n == 0 && cv.hasFlag('<')) cv.badFlag('<', "No last arg")
else if (!cv.isLiteral && !cv.isIndexed) errorLeading(cv)
formatting = true
}
loop(more, n = n + 1)
case Nil =>
}
}
loop(parts, n = 0)

def constantly(s: String) = {
val k = Constant(s)
Literal(k).setType(ConstantType(k))
}

//q"{..$evals; new StringOps(${fstring.toString}).format(..$ids)}"
val format = amended.mkString
if (actuals.isEmpty && !format.contains("%")) Literal(Constant(format))
if (actuals.isEmpty && !formatting) constantly(format)
else if (!reported && actuals.forall(treeInfo.isLiteralString)) constantly(format.format(actuals.map(_.asInstanceOf[Literal].value.value).toIndexedSeq: _*))
else if (!formatting) concatenate(amended.map(p => constantly(p.stripPrefix("%s"))).toList, actuals.toList)
else {
val scalaPackage = Select(Ident(nme.ROOTPKG), TermName("scala"))
val newStringOps = Select(
Expand Down Expand Up @@ -225,12 +239,13 @@ abstract class FormatInterpolator {
val badFlags = flags.filterNot { case '-' | '<' => true case _ => false }
badFlags.isEmpty or badFlag(badFlags(0), s"Only '-' allowed for $msg")
}
def goodFlags = {
def goodFlags = flags.isEmpty || {
for (dupe <- flags.diff(flags.distinct).distinct) errorAt(Flags, flags.lastIndexOf(dupe))(s"Duplicate flag '$dupe'")
val badFlags = flags.filterNot(okFlags.contains(_))
for (f <- badFlags) badFlag(f, s"Illegal flag '$f'")
badFlags.isEmpty
}
def goodIndex = {
def goodIndex = !isIndexed || {
if (index.nonEmpty && hasFlag('<')) warningAt(Index)("Argument index ignored if '<' flag is present")
val okRange = index.map(i => i > 0 && i <= argc).getOrElse(true)
okRange || hasFlag('<') or errorAt(Index)("Argument index out of range")
Expand Down Expand Up @@ -389,19 +404,24 @@ object FormatInterpolator {
}
val suggest = {
val r = "([0-7]{1,3}).*".r
(s0 drop e.index + 1) match {
case r(n) => altOf { n.foldLeft(0){ case (a, o) => (8 * a) + (o - '0') } }
s0.drop(e.index + 1) match {
case r(n) => altOf(n.foldLeft(0) { case (a, o) => (8 * a) + (o - '0') })
case _ => ""
}
}
val txt =
if ("" == suggest) ""
else s"use $suggest instead"
txt
if (suggest.isEmpty) ""
else s"use $suggest instead"
}
def control(ctl: Char, i: Int, name: String) =
c.error(errPoint, s"\\$ctl is not supported, but for $name use \\u${f"$i%04x"};\n${e.getMessage}")
if (e.index == s0.length - 1) c.error(errPoint, """Trailing '\' escapes nothing.""")
else if (octalOf(s0(e.index + 1)) >= 0) c.error(errPoint, s"octal escape literals are unsupported: $alt")
else c.error(errPoint, e.getMessage)
else s0(e.index + 1) match {
case 'a' => control('a', 0x7, "alert or BEL")
case 'v' => control('v', 0xB, "vertical tab")
case 'e' => control('e', 0x1B, "escape")
case i if octalOf(i) >= 0 => c.error(errPoint, s"octal escape literals are unsupported: $alt")
case _ => c.error(errPoint, e.getMessage)
}
s0
}
}
4 changes: 2 additions & 2 deletions test/files/jvm/t7181/Foo_1.scala
Expand Up @@ -11,9 +11,9 @@ class Foo_1 {
} finally {
// this should be the only copy of the magic constant 3
// making it easy to detect copies of this finally block
println(s"finally ${3}")
println("finally " + 3)
}
println(s"normal flow")
println("normal flow")
}
}

Expand Down
43 changes: 23 additions & 20 deletions test/files/neg/stringinterpolation_macro-neg.check
Expand Up @@ -109,64 +109,67 @@ stringinterpolation_macro-neg.scala:43: error: '(' not allowed for a, A
stringinterpolation_macro-neg.scala:44: error: Only '-' allowed for date/time conversions
f"$t%#+ 0,(tT"
^
stringinterpolation_macro-neg.scala:47: error: precision not allowed
stringinterpolation_macro-neg.scala:45: error: Duplicate flag ','
f"$d%,,d"
^
stringinterpolation_macro-neg.scala:48: error: precision not allowed
f"$c%.2c"
^
stringinterpolation_macro-neg.scala:48: error: precision not allowed
stringinterpolation_macro-neg.scala:49: error: precision not allowed
f"$d%.2d"
^
stringinterpolation_macro-neg.scala:49: error: precision not allowed
stringinterpolation_macro-neg.scala:50: error: precision not allowed
f"%.2%"
^
stringinterpolation_macro-neg.scala:50: error: precision not allowed
stringinterpolation_macro-neg.scala:51: error: precision not allowed
f"%.2n"
^
stringinterpolation_macro-neg.scala:51: error: precision not allowed
stringinterpolation_macro-neg.scala:52: error: precision not allowed
f"$f%.2a"
^
stringinterpolation_macro-neg.scala:52: error: precision not allowed
stringinterpolation_macro-neg.scala:53: error: precision not allowed
f"$t%.2tT"
^
stringinterpolation_macro-neg.scala:55: error: No last arg
stringinterpolation_macro-neg.scala:56: error: No last arg
f"%<s"
^
stringinterpolation_macro-neg.scala:56: error: No last arg
stringinterpolation_macro-neg.scala:57: error: No last arg
f"%<c"
^
stringinterpolation_macro-neg.scala:57: error: No last arg
stringinterpolation_macro-neg.scala:58: error: No last arg
f"%<tT"
^
stringinterpolation_macro-neg.scala:58: error: Argument index out of range
stringinterpolation_macro-neg.scala:59: error: Argument index out of range
f"${8}%d ${9}%d%3$$d"
^
stringinterpolation_macro-neg.scala:59: error: Argument index out of range
stringinterpolation_macro-neg.scala:60: error: Argument index out of range
f"${8}%d ${9}%d%0$$d"
^
stringinterpolation_macro-neg.scala:67: error: type mismatch;
stringinterpolation_macro-neg.scala:68: error: type mismatch;
found : String
required: java.util.Formattable
f"$s%#s"
^
stringinterpolation_macro-neg.scala:70: error: 'G' doesn't seem to be a date or time conversion
stringinterpolation_macro-neg.scala:75: error: 'G' doesn't seem to be a date or time conversion
f"$t%tG"
^
stringinterpolation_macro-neg.scala:71: error: Date/time conversion must have two characters
stringinterpolation_macro-neg.scala:76: error: Date/time conversion must have two characters
f"$t%t"
^
stringinterpolation_macro-neg.scala:72: error: Missing conversion operator in '%10.5'; use %% for literal %, %n for newline
stringinterpolation_macro-neg.scala:77: error: Missing conversion operator in '%10.5'; use %% for literal %, %n for newline
f"$s%10.5"
^
stringinterpolation_macro-neg.scala:75: error: conversions must follow a splice; use %% for literal %, %n for newline
stringinterpolation_macro-neg.scala:80: error: conversions must follow a splice; use %% for literal %, %n for newline
f"${d}random-leading-junk%d"
^
stringinterpolation_macro-neg.scala:62: warning: Index is not this arg
stringinterpolation_macro-neg.scala:63: warning: Index is not this arg
f"${8}%d ${9}%1$$d"
^
stringinterpolation_macro-neg.scala:63: warning: Argument index ignored if '<' flag is present
stringinterpolation_macro-neg.scala:64: warning: Argument index ignored if '<' flag is present
f"$s%s $s%s %1$$<s"
^
stringinterpolation_macro-neg.scala:64: warning: Index is not this arg
stringinterpolation_macro-neg.scala:65: warning: Index is not this arg
f"$s%s $s%1$$s"
^
3 warnings
45 errors
46 errors
5 changes: 5 additions & 0 deletions test/files/neg/stringinterpolation_macro-neg.scala
Expand Up @@ -42,6 +42,7 @@ object Test extends App {
f"$d%+ (x"
f"$f%,(a"
f"$t%#+ 0,(tT"
f"$d%,,d"

// 4) bad precisions
f"$c%.2c"
Expand All @@ -65,6 +66,10 @@ object Test extends App {

// 6) bad arg types
f"$s%#s"
f"$f%f %<d"
f"%1$$d $f%f"
f"${null}%s %<#s"
// add tests from https://github.com/scala/scala/pull/4316/files

// 7) misunderstood conversions
f"$t%tG"
Expand Down
11 changes: 8 additions & 3 deletions test/files/neg/t8266-invalid-interp.check
Expand Up @@ -4,7 +4,12 @@ t8266-invalid-interp.scala:4: error: Trailing '\' escapes nothing.
t8266-invalid-interp.scala:5: error: invalid escape '\x' not one of [\b, \t, \n, \f, \r, \\, \", \', \uxxxx] at index 1 in "a\xc". Use \\ for literal \.
f"a\xc",
^
t8266-invalid-interp.scala:7: error: invalid escape '\v' not one of [\b, \t, \n, \f, \r, \\, \", \', \uxxxx] at index 1 in "a\vc". Use \\ for literal \.
f"a\vc"
t8266-invalid-interp.scala:7: error: \v is not supported, but for vertical tab use \u000b;
invalid escape '\v' not one of [\b, \t, \n, \f, \r, \\, \", \', \uxxxx] at index 1 in "a\vc". Use \\ for literal \.
f"a\vc",
^
3 errors
t8266-invalid-interp.scala:8: error: \v is not supported, but for vertical tab use \u000b;
invalid escape '\v' not one of [\b, \t, \n, \f, \r, \\, \", \', \uxxxx] at index 0 in "\v". Use \\ for literal \.
f"\v$x%.4s, Fred",
^
4 errors
5 changes: 3 additions & 2 deletions test/files/neg/t8266-invalid-interp.scala
@@ -1,9 +1,10 @@

trait X {
final val x = "hello, world"
def f = Seq(
f"""a\""",
f"a\xc",
// following could suggest \u000b for vertical tab, similar for \a alert
f"a\vc"
f"a\vc",
f"\v$x%.4s, Fred",
)
}
9 changes: 9 additions & 0 deletions test/files/neg/t8650.check
@@ -0,0 +1,9 @@
t8650.scala:19: warning: method s in class C is deprecated (since MyLib 17): hello world
def t = s
^
t8650.scala:21: warning: method f in class C is deprecated (since MyLib 17): hello world
def g = f
^
error: No warnings can be incurred under -Werror.
2 warnings
1 error