Skip to content

Commit

Permalink
surface (fix): Reduce the bytecode size of Surface.methodsOf for Scal…
Browse files Browse the repository at this point in the history
…a 3 (#3149)

This fixes byte code too large error when using RxRouter.of[X] or
Surface.methodsOf[X] when class X has many methods and repeated
occurrences of the same parameters.

- Bind surface to local lazy val __s000, __s001, ... and reference them
when building other Surfaces and MethodSurfaces
  - Add a special handling for managing lazy surface
- Fixes #3131 as well 
- Reuse object methods when generating method accessor and object
factory.

Other ideas
- [x] Use externally defined methods for field setter and object builder
- [x] Reuse Surface inside OptionSurface, GenericSurface by analyzing
dependencies of Surface fixes #3150
  • Loading branch information
xerial committed Aug 21, 2023
1 parent d9e3642 commit 0b94a88
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ object HttpRequestMapperTest extends AirSpec {
def rpc5(p1: Option[String]): Unit = {}
def rpc6(p1: Option[NestedRequest]): Unit = {}
def rpc7(
request: HttpMessage.Request,
request: Request,
context: HttpContext[Request, Response, Future],
req: HttpRequest[Request]
): Unit = {}
Expand All @@ -64,8 +64,10 @@ object HttpRequestMapperTest extends AirSpec {
def endpoint4(p1: Option[Seq[String]]): Unit = {}
}

private val api = new MyApi {}
private val router = Router.add[MyApi].add[MyApi2]
private val api = new MyApi {}
private val router = Router
.add[MyApi]
.add[MyApi2]

private val mockContext = HttpContext.mockContext
private def mapArgs(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package wvlet.airframe.surface
import scala.quoted._
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.immutable.ListMap
import scala.quoted.*

private[surface] object CompileTimeSurfaceFactory {

Expand Down Expand Up @@ -76,27 +78,36 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
surfaceOf(TypeRepr.of(using tpe))
}

private val seen = scala.collection.mutable.Set[TypeRepr]()
private val memo = scala.collection.mutable.Map[TypeRepr, Expr[Surface]]()
private val lazySurface = scala.collection.mutable.Set[TypeRepr]()
private var observedSurfaceCount = new AtomicInteger(0)
private var seen = ListMap[TypeRepr, Int]()
private val memo = scala.collection.mutable.Map[TypeRepr, Expr[Surface]]()
private val lazySurface = scala.collection.mutable.Set[TypeRepr]()

private def surfaceOf(t: TypeRepr): Expr[Surface] = {
if (surfaceToVar.contains(t)) {
// println(s"==== ${t} is already cached")
Ref(surfaceToVar(t)).asExprOf[Surface]
private def surfaceOf(t: TypeRepr, useVarRef: Boolean = true): Expr[Surface] = {
def buildLazySurface: Expr[Surface] = {
'{ LazySurface(${ clsOf(t) }, ${ Expr(fullTypeNameOf(t)) }) }
}

if (useVarRef && surfaceToVar.contains(t)) {
if (lazySurface.contains(t)) {
buildLazySurface
} else {
Ref(surfaceToVar(t)).asExprOf[Surface]
}
} else if (seen.contains(t)) {
if (memo.contains(t)) {
memo(t)
} else {
lazySurface += t
'{ LazySurface(${ clsOf(t) }, ${ Expr(fullTypeNameOf(t)) }) }
buildLazySurface
}
} else {
seen += t
seen += t -> observedSurfaceCount.getAndIncrement()
// For debugging
// println(s"[${typeNameOf(t)}]\n ${t}\nfull type name: ${fullTypeNameOf(t)}\nclass: ${t.getClass}")
val generator = factory.andThen { expr =>
if (!lazySurface.contains(t)) {
// Generate the surface code without using the cache
expr
} else {
// Need to cache the recursive Surface to be referenced in a LazySurface
Expand All @@ -115,16 +126,12 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
} else {
fullTypeNameOf(t)
}
val key = Literal(StringConstant(cacheKey)).asExprOf[String]
'{
val key = ${
Expr(cacheKey)
if (!wvlet.airframe.surface.surfaceCache.contains(${ key })) {
wvlet.airframe.surface.surfaceCache += ${ key } -> ${ expr }
}
if (!wvlet.airframe.surface.surfaceCache.contains(key)) {
wvlet.airframe.surface.surfaceCache += key -> ${
expr
}
}
wvlet.airframe.surface.surfaceCache(key)
wvlet.airframe.surface.surfaceCache.apply(${ key })
}
}
}
Expand Down Expand Up @@ -386,9 +393,8 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
// args(i+1)
val extracted = Select.unique(args, "apply").appliedTo(Literal(IntConstant(index)))
index += 1
// args(i+1).asInstanceOf[A]
// TODO: Cast primitive values to target types
Select.unique(extracted, "asInstanceOf").appliedToType(a.tpe)
// classOf[A].cast(args(i+1))
clsCast(extracted, a.tpe)
}
Apply(prev, argExtractors.toList)
}
Expand All @@ -397,9 +403,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
}
)
val expr = '{
new wvlet.airframe.surface.ObjectFactory {
override def newInstance(args: Seq[Any]): Any = { ${ newClassFn.asExprOf[Seq[Any] => Any] }(args) }
}
ObjectFactory.newFactory(${ newClassFn.asExprOf[Seq[Any] => Any] })
}
expr
}
Expand Down Expand Up @@ -442,7 +446,9 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
),
rhsFn = (sym: Symbol, paramRefs: List[Tree]) => {
val strVarRef = paramRefs(1).asExprOf[String].asTerm
Select.unique(Apply(m, List(strVarRef)), "asInstanceOf").appliedToType(TypeRepr.of[Option[Any]])
val expr = Select.unique(Apply(m, List(strVarRef)), "asInstanceOf").appliedToType(TypeRepr.of[Option[Any]])
expr.changeOwner(sym)

}
)
'{
Expand Down Expand Up @@ -539,7 +545,8 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
case other =>
other
}
a.appliedTo(resolvedTypeArgs)
// Need to use the base type of the applied type to replace the type parameters
a.tycon.appliedTo(resolvedTypeArgs)
case TypeRef(_, name) if typeArgTable.contains(name) =>
typeArgTable(name)
case other =>
Expand Down Expand Up @@ -582,7 +589,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
clsOf(arg.tpe.dealias)
}
val isConstructor = t.typeSymbol.primaryConstructor == method
val constructorRef = '{
val constructorRef: Expr[MethodRef] = '{
MethodRef(
owner = ${ clsOf(t) },
name = ${ Expr(methodName) },
Expand Down Expand Up @@ -618,20 +625,40 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
// println(s"${paramName} ${paramIsAccessible}")

val accessor: Expr[Option[Any => Any]] = if (method.isClassConstructor && paramIsAccessible) {
// MethodParameter.accessor[(owner type), (parameter type]]
val accessorMethod: Symbol = TypeRepr.of[MethodParameter.type].typeSymbol.methodMember("accessor").head
val objRef = Ref(TypeRepr.of[MethodParameter].typeSymbol.companionModule)

def resolveType(tpe: TypeRepr): TypeRepr = tpe match {
case b: TypeBounds =>
TypeRepr.of[Any]
case _ =>
tpe
}

val t1 = resolveType(t)
val t2 = resolveType(paramType)

val typedAccessor = objRef.select(accessorMethod).appliedToTypes(List(t1, t2))
val methodCall = typedAccessor.appliedToArgs(List(Literal(ClassOfConstant(t1))))

val lambda = Lambda(
owner = Symbol.spliceOwner,
tpe = MethodType(List("x"))(_ => List(TypeRepr.of[Any]), _ => TypeRepr.of[Any]),
tpe = MethodType(List("x"))(_ => List(t1), _ => t2),
rhsFn = (sym, params) => {
val x = params.head.asInstanceOf[Term]
val expr = Select.unique(Select.unique(x, "asInstanceOf").appliedToType(t), paramName)
val expr = Select.unique(x, paramName)
expr.changeOwner(sym)
}
)
// println(t.typeSymbol)
// println(paramType.typeSymbol.flags.show)
// println(lambda.show)
// println(lambda.show(using Printer.TreeStructure))
'{ Some(${ lambda.asExprOf[Any => Any] }) }
val accMethod = methodCall.appliedToArgs(List(lambda))
// println(s"=== ${accMethod.show}")

// Generate code like :
// {{{
// MethodParameter.accessor[t1, t2](classOf[t1]){(x:t1) => x.(field name) }
// }}}
'{ Some(${ accMethod.asExprOf[Any => Any] }) }
} else {
'{ None }
}
Expand All @@ -645,7 +672,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
tpe = MethodType(List("x"))(_ => List(TypeRepr.of[Any]), _ => TypeRepr.of[Any]),
rhsFn = (sym, params) => {
val x = params.head.asInstanceOf[Term]
val expr = Select.unique(x, "asInstanceOf").appliedToType(t).select(m)
val expr = clsCast(x, t).select(m)
expr.changeOwner(sym)
}
)
Expand Down Expand Up @@ -683,35 +710,53 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
}

// To reduce the byte code size, we need to memoize the generated surface bound to a variable
private val surfaceToVar = scala.collection.mutable.Map[TypeRepr, Symbol]()
private var surfaceToVar = ListMap.empty[TypeRepr, Symbol]

private def methodsOf(t: TypeRepr): Expr[Seq[MethodSurface]] = {
// Run just for collecting known surfaces. seen variable will be updated
methodsOfInternal(t)

var count = 0
// Bind the observed surfaces to local variables __s0, __s1, ...
seen.foreach { s =>
// Update the cache so that the next call of surfaceOf method will use the local varaible reference
surfaceToVar += s -> Symbol.newVal(
Symbol.spliceOwner,
s"__s${count}",
TypeRepr.of[Surface],
Flags.EmptyFlags,
Symbol.noSymbol
)
count += 1
}
val surfaceDefs: List[ValDef] = surfaceToVar.map { x =>
val sym = x._2
ValDef(sym, Some(memo(x._1).asTerm))
}.toList
// Create a var def table for replacing surfaceOf[xxx] to __s0, __s1, ...
var surfaceVarCount = 0
seen
// Exclude primitive type surface
.toSeq
// Exclude primitive surfaces as it is already defined in Primitive object
.filterNot(x => primitiveTypeFactory.isDefinedAt(x._1))
.sortBy(_._2)
.reverse
.map { case (tpe, order) =>
// Update the cache so that the next call of surfaceOf method will use the local varaible reference
surfaceToVar += tpe -> Symbol.newVal(
Symbol.spliceOwner,
// Use alphabetically ordered variable names
f"__s${surfaceVarCount}%03X",
TypeRepr.of[Surface],
if (lazySurface.contains(tpe)) {
// If the surface itself is lazy, we need to eagerly initialize it to update the surface cache
Flags.EmptyFlags
} else {
// Use lazy val to avoid forward reference error
Flags.Lazy
},
Symbol.noSymbol
)
surfaceVarCount += 1
}

// Clear method observation cache
// Clear surface cache
memo.clear()
seen = ListMap.empty
seenMethodParent.clear()

val surfaceDefs: List[ValDef] = surfaceToVar.toSeq.map { case (tpe, sym) =>
ValDef(sym, Some(surfaceOf(tpe, useVarRef = false).asTerm))
}.toList

/**
* Generate a code like this: {{ val __s0 = Surface.of[A] val __s1 = Surface.of[B] ...
* Generate a code like this:
*
* {{ lazy val __s000 = Surface.of[A]; lazy val __s001 = Surface.of[B] ...
*
* ClassMethodSurface( .... ) }}
*/
Expand Down Expand Up @@ -758,6 +803,10 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
}
}

private def clsCast(term: Term, t: TypeRepr): Term = {
Select.unique(term, "asInstanceOf").appliedToType(t)
}

private def createMethodCaller(
objectType: TypeRepr,
m: Symbol,
Expand All @@ -779,13 +828,12 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
rhsFn = (sym, params) => {
val x = params(0).asInstanceOf[Term]
val args = params(1).asInstanceOf[Term]
val expr = Select.unique(x, "asInstanceOf").appliedToType(objectType).select(m)
val expr = clsCast(x, objectType).select(m)
val argList = methodArgs.zipWithIndex.collect {
// If the arg is implicit, no need to explicitly bind it
case (arg, i) if !arg.isImplicit =>
// args(i).asInstanceOf[ArgType]
val extracted = Select.unique(args, "apply").appliedTo(Literal(IntConstant(i)))
Select.unique(extracted, "asInstanceOf").appliedToType(arg.tpe)
clsCast(extracted, arg.tpe)
}
if (argList.isEmpty) {
val newExpr = m.tree match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ trait ObjectFactory extends Serializable {
def newInstance(args: Seq[Any]): Any
}

object ObjectFactory {

/**
* Used internally for creating a new ObjectFactory instance from a given generic function
* @param f
* @return
*/
def newFactory(f: Seq[Any] => Any): ObjectFactory = new ObjectFactory {
override def newInstance(args: Seq[Any]): Any = f(args)
}
}

case class MethodRef(owner: Class[_], name: String, paramTypes: Seq[Class[_]], isConstructor: Boolean)

trait MethodParameter extends Parameter {
Expand All @@ -75,6 +87,12 @@ trait MethodParameter extends Parameter {
def getMethodArgDefaultValue(methodOwner: Any): Option[Any] = getDefaultValue
}

object MethodParameter {
def accessor[A, B](cl: Class[A])(body: A => B): Any => B = { (x: Any) =>
body(cl.cast(x))
}
}

trait MethodSurface extends ParameterBase {
def mod: Int
def owner: Surface
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.surface

object RecursiveMethodParamTest {
case class Node(parent: Option[Node])

trait MyRecursiveApi {
def find(node: Node): Unit = {}
}
}

class RecursiveMethodParamTest extends munit.FunSuite {
import RecursiveMethodParamTest._

// ....
test("Compile method surfaces with recursive method param") {
Surface.methodsOf[MyRecursiveApi]
}
}
2 changes: 2 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ val buildSettings = Seq[Setting[_]](
scalacOptions ++= Seq(
"-feature",
"-deprecation"
// Use this for debugging Macros
// "-Xcheck-macros"
) ++ {
if (scalaVersion.value.startsWith("3.")) {
Seq.empty
Expand Down

0 comments on commit 0b94a88

Please sign in to comment.