Skip to content

Commit

Permalink
improvement: Use java.lang.StringBuilderfor optimized concatation o…
Browse files Browse the repository at this point in the history
…f Strings in NIR codegen (#3640)
  • Loading branch information
WojciechMazur committed Jan 2, 2024
1 parent 51dc81e commit f2929b1
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,21 @@ trait NirDefinitions {
lazy val JavaProperties = getRequiredClass("java.util.Properties")

lazy val StringConcatMethod = getMember(StringClass, TermName("concat"))
lazy val String_valueOf_Object =
getMember(StringModule, nme.valueOf).filter(sym =>
sym.info.paramTypes match {
case List(pt) => pt.typeSymbol == ObjectClass
case _ => false
}
)
lazy val jlStringBuilderRef = getRequiredClass("java.lang.StringBuilder")
lazy val jlStringBuilderType = jlStringBuilderRef.toType
lazy val jlStringBuilderAppendAlts =
getMemberMethod(jlStringBuilderRef, TermName("append")).alternatives
lazy val jlStringBufferRef = getRequiredClass("java.lang.StringBuffer")
lazy val jlStringBufferType = jlStringBufferRef.toType
lazy val jlCharSequenceRef = getRequiredClass("java.lang.CharSequence")
lazy val jlCharSequenceType = jlCharSequenceRef.toType

lazy val BoxMethod = Map[Char, Symbol](
'B' -> getDecl(BoxesRunTimeModule, TermName("boxToBoolean")),
Expand Down
160 changes: 127 additions & 33 deletions nscplugin/src/main/scala-2/scala/scalanative/nscplugin/NirGenExpr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1207,7 +1207,7 @@ trait NirGenExpr[G <: nsc.Global with Singleton] { self: NirGenPhase[G] =>
val sym = app.symbol
val code = scalaPrimitives.getPrimitive(sym, receiver.tpe)
(code: @switch) match {
case CONCAT => genStringConcat(receiver, args.head)
case CONCAT => genStringConcat(app)
case HASH => genHashCode(args.head)
case CFUNCPTR_APPLY => genCFuncPtrApply(app, code)
case CFUNCPTR_FROM_FUNCTION => genCFuncFromScalaFunction(app)
Expand Down Expand Up @@ -1810,43 +1810,137 @@ trait NirGenExpr[G <: nsc.Global with Singleton] { self: NirGenPhase[G] =>
abort(s"can't perform binary operation between $lty and $rty")
}

def genStringConcat(leftp: Tree, rightp: Tree): nir.Val = {
def stringify(sym: Symbol, value: nir.Val)(implicit
pos: nir.Position
): nir.Val = {
val cond = ContTree { () =>
buf.comp(nir.Comp.Ieq, nir.Rt.Object, value, nir.Val.Null, unwind)
}
val thenp = ContTree { () => nir.Val.String("null") }
val elsep = ContTree { () =>
if (sym == StringClass) {
value
} else {
val meth = Object_toString
genApplyMethod(meth, statically = false, value, Seq.empty)
}
/*
* Returns a list of trees that each should be concatenated, from left to right.
* It turns a chained call like "a".+("b").+("c") into a list of arguments.
*/
def liftStringConcat(tree: Tree): List[Tree] = {
val result = collection.mutable.ListBuffer[Tree]()
def loop(tree: Tree): Unit = {
tree match {
case Apply(fun @ Select(larg, method), rarg :: Nil)
if (scalaPrimitives.isPrimitive(fun.symbol) &&
scalaPrimitives.getPrimitive(fun.symbol) ==
scalaPrimitives.CONCAT) =>
loop(larg)
loop(rarg)
case _ =>
result += tree
}
genIf(nir.Rt.String, cond, thenp, elsep)
}
loop(tree)
result.toList
}

/* Issue a call to `StringBuilder#append` for the right element type */
private final def genStringBuilderAppend(
stringBuilder: nir.Val.Local,
tree: Tree
): Unit = {
implicit val nirPos: nir.Position = tree.pos

val tpe = tree.tpe
val argType =
if (tpe <:< defn.StringTpe) nir.Rt.String
else if (tpe <:< nirDefinitions.jlStringBufferType)
genType(nirDefinitions.jlStringBufferRef)
else if (tpe <:< nirDefinitions.jlCharSequenceType)
genType(nirDefinitions.jlCharSequenceRef)
// Don't match for `Array(Char)`, even though StringBuilder has such an overload:
// `"a" + Array('b')` should NOT be "ab", but "a[C@...".
else if (tpe <:< defn.ObjectTpe) nir.Rt.Object
else genType(tpe)

val value = genExpr(tree)
val (adaptedValue, targetType) = argType match {
// jlStringBuilder does not have overloads for byte and short, but we can just use the int version
case nir.Type.Byte | nir.Type.Short =>
genCoercion(value, value.ty, nir.Type.Int) -> nir.Type.Int
case nirType => value -> nirType
}

val (appendFunction, appendSig) =
jlStringBuilderAppendForSymbol(targetType)
buf.call(
appendSig,
appendFunction,
Seq(stringBuilder, adaptedValue),
unwind
)
}

val left = {
implicit val pos: nir.Position = leftp.pos

val typesym = leftp.tpe.typeSymbol
val unboxed = genExpr(leftp)
val boxed = boxValue(typesym, unboxed)
stringify(typesym, boxed)
}
private lazy val jlStringBuilderRef =
nir.Type.Ref(genTypeName(nirDefinitions.jlStringBuilderRef))
private lazy val jlStringBuilderCtor =
jlStringBuilderRef.name.member(nir.Sig.Ctor(Seq(nir.Type.Int)))
private lazy val jlStringBuilderCtorSig = nir.Type.Function(
Seq(jlStringBuilderRef, nir.Type.Int),
nir.Type.Unit
)
private lazy val jlStringBuilderToString =
jlStringBuilderRef.name.member(
nir.Sig.Method("toString", Seq(nir.Rt.String))
)
private lazy val jlStringBuilderToStringSig = nir.Type.Function(
Seq(jlStringBuilderRef),
nir.Rt.String
)

private def genStringConcat(tree: Apply): nir.Val = {
implicit val nirPos: nir.Position = tree.pos
liftStringConcat(tree) match {
// Optimization for expressions of the form "" + x
case List(Literal(Constant("")), arg) =>
genApplyStaticMethod(
nirDefinitions.String_valueOf_Object,
defn.StringClass,
Seq(arg)
)

val right = {
val typesym = rightp.tpe.typeSymbol
val boxed = genExpr(rightp)
stringify(typesym, boxed)(rightp.pos)
case concatenations =>
val concatArguments = concatenations.view
.filter {
// empty strings are no-ops in concatenation
case Literal(Constant("")) => false
case _ => true
}
.map {
// Eliminate boxing of primitive values. Boxing is introduced by erasure because
// there's only a single synthetic `+` method "added" to the string class.
case Apply(boxOp, value :: Nil)
// TODO: SN specific boxing
if currentRun.runDefinitions.isBox(boxOp.symbol) =>
value
case other => other
}
.toList
// Estimate capacity needed for the string builder
val approxBuilderSize = concatArguments.view.map {
case Literal(Constant(s: String)) => s.length
case Literal(c @ Constant(_)) if c.isNonUnitAnyVal =>
String.valueOf(c).length
case _ => 0
}.sum

// new StringBuidler(approxBuilderSize)
val stringBuilder =
buf.classalloc(jlStringBuilderRef.name, unwind, None)
buf.call(
jlStringBuilderCtorSig,
nir.Val.Global(jlStringBuilderCtor, nir.Type.Ptr),
Seq(stringBuilder, nir.Val.Int(approxBuilderSize)),
unwind
)
// concat substrings
concatArguments.foreach(genStringBuilderAppend(stringBuilder, _))
// stringBuilder.toString
buf.call(
jlStringBuilderToStringSig,
nir.Val.Global(jlStringBuilderToString, nir.Type.Ptr),
Seq(stringBuilder),
unwind
)
}

genApplyMethod(String_+, statically = true, left, Seq(ValTree(right)))(
leftp.pos
)
}

def genHashCode(argp: Tree)(implicit pos: nir.Position): nir.Val = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,4 +260,19 @@ trait NirGenType[G <: Global with Singleton] { self: NirGenPhase[G] =>
}
}
}

lazy val jlStringBuilderAppendForSymbol =
nirDefinitions.jlStringBuilderAppendAlts.flatMap { sym =>
val sig = genMethodSig(sym)
def name = genMethodName(sym)
sig match {
case nir.Type.Function(Seq(_, arg), _)
if sym.owner == nirDefinitions.jlStringBuilderRef =>
Some(
nir.Type.normalize(arg) -> (nir.Val.Global(name, nir.Type.Ptr), sig)
)
case _ => None
}
}.toMap

}
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,18 @@ final class NirDefinitions()(using ctx: Context) {
@tu lazy val JavaUtilServiceLoaderLoad = JavaUtilServiceLoader.alternatives("load")
@tu lazy val JavaUtilServiceLoaderLoadInstalled = JavaUtilServiceLoader.requiredMethod("loadInstalled")
@tu lazy val LinktimeIntrinsics = JavaUtilServiceLoaderLoad ++ Seq(JavaUtilServiceLoaderLoadInstalled)

@tu lazy val jlStringBuilderRef = requiredClass("java.lang.StringBuilder")
@tu lazy val jlStringBuilderType = jlStringBuilderRef.typeRef
@tu lazy val jlStringBuilderAppendAlts = jlStringBuilderRef.info
.decl(termName("append"))
.alternatives
.map(_.symbol)
@tu lazy val jlStringBufferRef = requiredClass("java.lang.StringBuffer")
@tu lazy val jlStringBufferType = jlStringBufferRef.typeRef
@tu lazy val jlCharSequenceRef = requiredClass("java.lang.CharSequence")
@tu lazy val jlCharSequenceType = jlCharSequenceRef.typeRef

// Scala library & runtime
@tu lazy val InlineClass = requiredClass("scala.inline")
@tu lazy val NoInlineClass = requiredClass("scala.noinline")
Expand Down

0 comments on commit f2929b1

Please sign in to comment.