Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improvement: Use java.lang.StringBuilder for optimized concatenation of Strings in NIR CodeGen #3640

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,21 @@ trait NirDefinitions {
lazy val JavaProperties = getRequiredClass("java.util.Properties")

lazy val StringConcatMethod = getMember(StringClass, TermName("concat"))
// The `valueOf` member of `StringModule`, narrowed to the alternatives that
// take exactly one parameter whose type symbol is `ObjectClass`.
lazy val String_valueOf_Object = {
  def takesSingleObject(alt: Symbol): Boolean = {
    val params = alt.info.paramTypes
    params.lengthCompare(1) == 0 && params.head.typeSymbol == ObjectClass
  }
  getMember(StringModule, nme.valueOf).filter(takesSingleObject)
}
// Class symbol of `java.lang.StringBuilder` and its corresponding type.
lazy val jlStringBuilderRef = getRequiredClass("java.lang.StringBuilder")
lazy val jlStringBuilderType = jlStringBuilderRef.toType
// Every overloaded alternative of the `append` method looked up on
// `java.lang.StringBuilder`.
lazy val jlStringBuilderAppendAlts =
getMemberMethod(jlStringBuilderRef, TermName("append")).alternatives
// Class symbol of `java.lang.StringBuffer` and its corresponding type.
lazy val jlStringBufferRef = getRequiredClass("java.lang.StringBuffer")
lazy val jlStringBufferType = jlStringBufferRef.toType
// Class symbol of `java.lang.CharSequence` and its corresponding type.
lazy val jlCharSequenceRef = getRequiredClass("java.lang.CharSequence")
lazy val jlCharSequenceType = jlCharSequenceRef.toType

lazy val BoxMethod = Map[Char, Symbol](
'B' -> getDecl(BoxesRunTimeModule, TermName("boxToBoolean")),
Expand Down
160 changes: 127 additions & 33 deletions nscplugin/src/main/scala-2/scala/scalanative/nscplugin/NirGenExpr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1207,7 +1207,7 @@ trait NirGenExpr[G <: nsc.Global with Singleton] { self: NirGenPhase[G] =>
val sym = app.symbol
val code = scalaPrimitives.getPrimitive(sym, receiver.tpe)
(code: @switch) match {
case CONCAT => genStringConcat(receiver, args.head)
case CONCAT => genStringConcat(app)
case HASH => genHashCode(args.head)
case CFUNCPTR_APPLY => genCFuncPtrApply(app, code)
case CFUNCPTR_FROM_FUNCTION => genCFuncFromScalaFunction(app)
Expand Down Expand Up @@ -1810,43 +1810,137 @@ trait NirGenExpr[G <: nsc.Global with Singleton] { self: NirGenPhase[G] =>
abort(s"can't perform binary operation between $lty and $rty")
}

def genStringConcat(leftp: Tree, rightp: Tree): nir.Val = {
def stringify(sym: Symbol, value: nir.Val)(implicit
pos: nir.Position
): nir.Val = {
val cond = ContTree { () =>
buf.comp(nir.Comp.Ieq, nir.Rt.Object, value, nir.Val.Null, unwind)
}
val thenp = ContTree { () => nir.Val.String("null") }
val elsep = ContTree { () =>
if (sym == StringClass) {
value
} else {
val meth = Object_toString
genApplyMethod(meth, statically = false, value, Seq.empty)
}
/*
* Returns a list of trees that each should be concatenated, from left to right.
* It turns a chained call like "a".+("b").+("c") into a list of arguments.
*/
def liftStringConcat(tree: Tree): List[Tree] = {
val result = collection.mutable.ListBuffer[Tree]()
def loop(tree: Tree): Unit = {
tree match {
case Apply(fun @ Select(larg, method), rarg :: Nil)
if (scalaPrimitives.isPrimitive(fun.symbol) &&
scalaPrimitives.getPrimitive(fun.symbol) ==
scalaPrimitives.CONCAT) =>
loop(larg)
loop(rarg)
case _ =>
result += tree
}
genIf(nir.Rt.String, cond, thenp, elsep)
}
loop(tree)
result.toList
}

/** Emits a call to the `StringBuilder#append` overload that matches the
 *  static type of `tree`, passing the value generated for `tree`.
 *
 *  @param stringBuilder the builder instance passed as receiver of `append`
 *  @param tree          the expression whose generated value is appended
 */
private final def genStringBuilderAppend(
    stringBuilder: nir.Val.Local,
    tree: Tree
): Unit = {
  implicit val nirPos: nir.Position = tree.pos

  val treeTpe = tree.tpe
  // Select the parameter type used to look up the `append` overload.
  // The order of these subtype checks is significant.
  val argType =
    if (treeTpe <:< defn.StringTpe) nir.Rt.String
    else if (treeTpe <:< nirDefinitions.jlStringBufferType)
      genType(nirDefinitions.jlStringBufferRef)
    else if (treeTpe <:< nirDefinitions.jlCharSequenceType)
      genType(nirDefinitions.jlCharSequenceRef)
    // Don't match for `Array(Char)`, even though StringBuilder has such an overload:
    // `"a" + Array('b')` should NOT be "ab", but "a[C@...".
    else if (treeTpe <:< defn.ObjectTpe) nir.Rt.Object
    else genType(treeTpe)

  // The argument expression is generated only after the overload type is known.
  val generated = genExpr(tree)

  // jlStringBuilder does not have overloads for byte and short, but we can
  // just use the int version after widening the value.
  val widenToInt = nir.Type.Byte == argType || nir.Type.Short == argType
  val appendArg =
    if (widenToInt) genCoercion(generated, generated.ty, nir.Type.Int)
    else generated
  val overloadKey: nir.Type = if (widenToInt) nir.Type.Int else argType

  val (appendFunction, appendSig) = jlStringBuilderAppendForSymbol(overloadKey)
  buf.call(appendSig, appendFunction, Seq(stringBuilder, appendArg), unwind)
}

val left = {
implicit val pos: nir.Position = leftp.pos

val typesym = leftp.tpe.typeSymbol
val unboxed = genExpr(leftp)
val boxed = boxValue(typesym, unboxed)
stringify(typesym, boxed)
}
// NIR reference type for `java.lang.StringBuilder`, built from the NIR type
// name of the class symbol held in `nirDefinitions`.
private lazy val jlStringBuilderRef =
nir.Type.Ref(genTypeName(nirDefinitions.jlStringBuilderRef))
// Member name of the StringBuilder constructor taking a single `Int`.
private lazy val jlStringBuilderCtor =
jlStringBuilderRef.name.member(nir.Sig.Ctor(Seq(nir.Type.Int)))
// Function type used when calling that constructor: the StringBuilder
// instance comes first, followed by the `Int` argument; result is `Unit`.
private lazy val jlStringBuilderCtorSig = nir.Type.Function(
Seq(jlStringBuilderRef, nir.Type.Int),
nir.Type.Unit
)
// Member name of StringBuilder's `toString` method.
private lazy val jlStringBuilderToString =
jlStringBuilderRef.name.member(
nir.Sig.Method("toString", Seq(nir.Rt.String))
)
// Function type used when calling `toString`: takes the StringBuilder
// instance and yields `nir.Rt.String`.
private lazy val jlStringBuilderToStringSig = nir.Type.Function(
Seq(jlStringBuilderRef),
nir.Rt.String
)

private def genStringConcat(tree: Apply): nir.Val = {
implicit val nirPos: nir.Position = tree.pos
liftStringConcat(tree) match {
// Optimization for expressions of the form "" + x
case List(Literal(Constant("")), arg) =>
genApplyStaticMethod(
nirDefinitions.String_valueOf_Object,
defn.StringClass,
Seq(arg)
)

val right = {
val typesym = rightp.tpe.typeSymbol
val boxed = genExpr(rightp)
stringify(typesym, boxed)(rightp.pos)
case concatenations =>
val concatArguments = concatenations.view
.filter {
// empty strings are no-ops in concatenation
case Literal(Constant("")) => false
case _ => true
}
.map {
// Eliminate boxing of primitive values. Boxing is introduced by erasure because
// there's only a single synthetic `+` method "added" to the string class.
case Apply(boxOp, value :: Nil)
// TODO: SN specific boxing
if currentRun.runDefinitions.isBox(boxOp.symbol) =>
value
case other => other
}
.toList
// Estimate capacity needed for the string builder
val approxBuilderSize = concatArguments.view.map {
case Literal(Constant(s: String)) => s.length
case Literal(c @ Constant(_)) if c.isNonUnitAnyVal =>
String.valueOf(c).length
case _ => 0
}.sum

// new StringBuilder(approxBuilderSize)
val stringBuilder =
buf.classalloc(jlStringBuilderRef.name, unwind, None)
buf.call(
jlStringBuilderCtorSig,
nir.Val.Global(jlStringBuilderCtor, nir.Type.Ptr),
Seq(stringBuilder, nir.Val.Int(approxBuilderSize)),
unwind
)
// concat substrings
concatArguments.foreach(genStringBuilderAppend(stringBuilder, _))
// stringBuilder.toString
buf.call(
jlStringBuilderToStringSig,
nir.Val.Global(jlStringBuilderToString, nir.Type.Ptr),
Seq(stringBuilder),
unwind
)
}

genApplyMethod(String_+, statically = true, left, Seq(ValTree(right)))(
leftp.pos
)
}

def genHashCode(argp: Tree)(implicit pos: nir.Position): nir.Val = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -260,4 +260,19 @@ trait NirGenType[G <: Global with Singleton] { self: NirGenPhase[G] =>
}
}
}

// Lookup table from the normalized NIR type of an `append` argument to the
// pair (global value referring to the method, method signature). Only
// alternatives owned by the StringBuilder class itself and whose signature
// has exactly two parameters (the first is ignored here) are included.
lazy val jlStringBuilderAppendForSymbol = {
  val entries = nirDefinitions.jlStringBuilderAppendAlts.flatMap { alt =>
    val methodSig = genMethodSig(alt)
    methodSig match {
      case nir.Type.Function(Seq(_, paramTy), _)
          if alt.owner == nirDefinitions.jlStringBuilderRef =>
        val target = nir.Val.Global(genMethodName(alt), nir.Type.Ptr)
        List(nir.Type.normalize(paramTy) -> ((target, methodSig)))
      case _ =>
        Nil
    }
  }
  entries.toMap
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,18 @@ final class NirDefinitions()(using ctx: Context) {
@tu lazy val JavaUtilServiceLoaderLoad = JavaUtilServiceLoader.alternatives("load")
@tu lazy val JavaUtilServiceLoaderLoadInstalled = JavaUtilServiceLoader.requiredMethod("loadInstalled")
@tu lazy val LinktimeIntrinsics = JavaUtilServiceLoaderLoad ++ Seq(JavaUtilServiceLoaderLoadInstalled)

// Class symbol of `java.lang.StringBuilder` and a type reference to it.
@tu lazy val jlStringBuilderRef = requiredClass("java.lang.StringBuilder")
@tu lazy val jlStringBuilderType = jlStringBuilderRef.typeRef
// Symbols of every alternative of `append` found via `decl` on the
// StringBuilder class info.
@tu lazy val jlStringBuilderAppendAlts = jlStringBuilderRef.info
.decl(termName("append"))
.alternatives
.map(_.symbol)
// Class symbol of `java.lang.StringBuffer` and a type reference to it.
@tu lazy val jlStringBufferRef = requiredClass("java.lang.StringBuffer")
@tu lazy val jlStringBufferType = jlStringBufferRef.typeRef
// Class symbol of `java.lang.CharSequence` and a type reference to it.
@tu lazy val jlCharSequenceRef = requiredClass("java.lang.CharSequence")
@tu lazy val jlCharSequenceType = jlCharSequenceRef.typeRef

// Scala library & runtime
@tu lazy val InlineClass = requiredClass("scala.inline")
@tu lazy val NoInlineClass = requiredClass("scala.noinline")
Expand Down