Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

surface (fix): Reduce the bytecode size of Surface.methodsOf for Scala 3 #3149

Merged
merged 20 commits into from
Aug 21, 2023
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ object HttpRequestMapperTest extends AirSpec {
def rpc5(p1: Option[String]): Unit = {}
def rpc6(p1: Option[NestedRequest]): Unit = {}
def rpc7(
request: HttpMessage.Request,
request: Request,
context: HttpContext[Request, Response, Future],
req: HttpRequest[Request]
): Unit = {}
Expand All @@ -64,8 +64,10 @@ object HttpRequestMapperTest extends AirSpec {
def endpoint4(p1: Option[Seq[String]]): Unit = {}
}

private val api = new MyApi {}
private val router = Router.add[MyApi].add[MyApi2]
private val api = new MyApi {}
private val router = Router
.add[MyApi]
.add[MyApi2]

private val mockContext = HttpContext.mockContext
private def mapArgs(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package wvlet.airframe.surface
import scala.quoted._
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.immutable.ListMap
import scala.quoted.*

private[surface] object CompileTimeSurfaceFactory {

Expand Down Expand Up @@ -76,27 +78,36 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
surfaceOf(TypeRepr.of(using tpe))
}

private val seen = scala.collection.mutable.Set[TypeRepr]()
private val memo = scala.collection.mutable.Map[TypeRepr, Expr[Surface]]()
private val lazySurface = scala.collection.mutable.Set[TypeRepr]()
private var observedSurfaceCount = new AtomicInteger(0)
private var seen = ListMap[TypeRepr, Int]()
private val memo = scala.collection.mutable.Map[TypeRepr, Expr[Surface]]()
private val lazySurface = scala.collection.mutable.Set[TypeRepr]()

private def surfaceOf(t: TypeRepr): Expr[Surface] = {
if (surfaceToVar.contains(t)) {
// println(s"==== ${t} is already cached")
Ref(surfaceToVar(t)).asExprOf[Surface]
private def surfaceOf(t: TypeRepr, useVarRef: Boolean = true): Expr[Surface] = {
def buildLazySurface: Expr[Surface] = {
'{ LazySurface(${ clsOf(t) }, ${ Expr(fullTypeNameOf(t)) }) }
}

if (useVarRef && surfaceToVar.contains(t)) {
if (lazySurface.contains(t)) {
buildLazySurface
} else {
Ref(surfaceToVar(t)).asExprOf[Surface]
}
} else if (seen.contains(t)) {
if (memo.contains(t)) {
memo(t)
} else {
lazySurface += t
'{ LazySurface(${ clsOf(t) }, ${ Expr(fullTypeNameOf(t)) }) }
buildLazySurface
}
} else {
seen += t
seen += t -> observedSurfaceCount.getAndIncrement()
// For debugging
// println(s"[${typeNameOf(t)}]\n ${t}\nfull type name: ${fullTypeNameOf(t)}\nclass: ${t.getClass}")
val generator = factory.andThen { expr =>
if (!lazySurface.contains(t)) {
// Generate the surface code without using the cache
expr
} else {
// Need to cache the recursive Surface to be referenced in a LazySurface
Expand All @@ -115,16 +126,12 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
} else {
fullTypeNameOf(t)
}
val key = Literal(StringConstant(cacheKey)).asExprOf[String]
'{
val key = ${
Expr(cacheKey)
if (!wvlet.airframe.surface.surfaceCache.contains(${ key })) {
wvlet.airframe.surface.surfaceCache += ${ key } -> ${ expr }
}
if (!wvlet.airframe.surface.surfaceCache.contains(key)) {
wvlet.airframe.surface.surfaceCache += key -> ${
expr
}
}
wvlet.airframe.surface.surfaceCache(key)
wvlet.airframe.surface.surfaceCache.apply(${ key })
}
}
}
Expand Down Expand Up @@ -386,9 +393,8 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
// args(i+1)
val extracted = Select.unique(args, "apply").appliedTo(Literal(IntConstant(index)))
index += 1
// args(i+1).asInstanceOf[A]
// TODO: Cast primitive values to target types
Select.unique(extracted, "asInstanceOf").appliedToType(a.tpe)
// classOf[A].cast(args(i+1))
clsCast(extracted, a.tpe)
}
Apply(prev, argExtractors.toList)
}
Expand All @@ -397,9 +403,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
}
)
val expr = '{
new wvlet.airframe.surface.ObjectFactory {
override def newInstance(args: Seq[Any]): Any = { ${ newClassFn.asExprOf[Seq[Any] => Any] }(args) }
}
ObjectFactory.newFactory(${ newClassFn.asExprOf[Seq[Any] => Any] })
}
expr
}
Expand Down Expand Up @@ -442,7 +446,9 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
),
rhsFn = (sym: Symbol, paramRefs: List[Tree]) => {
val strVarRef = paramRefs(1).asExprOf[String].asTerm
Select.unique(Apply(m, List(strVarRef)), "asInstanceOf").appliedToType(TypeRepr.of[Option[Any]])
val expr = Select.unique(Apply(m, List(strVarRef)), "asInstanceOf").appliedToType(TypeRepr.of[Option[Any]])
expr.changeOwner(sym)

}
)
'{
Expand Down Expand Up @@ -539,7 +545,8 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
case other =>
other
}
a.appliedTo(resolvedTypeArgs)
// Need to use the base type of the applied type to replace the type parameters
a.tycon.appliedTo(resolvedTypeArgs)
case TypeRef(_, name) if typeArgTable.contains(name) =>
typeArgTable(name)
case other =>
Expand Down Expand Up @@ -582,7 +589,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
clsOf(arg.tpe.dealias)
}
val isConstructor = t.typeSymbol.primaryConstructor == method
val constructorRef = '{
val constructorRef: Expr[MethodRef] = '{
MethodRef(
owner = ${ clsOf(t) },
name = ${ Expr(methodName) },
Expand Down Expand Up @@ -618,20 +625,40 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
// println(s"${paramName} ${paramIsAccessible}")

val accessor: Expr[Option[Any => Any]] = if (method.isClassConstructor && paramIsAccessible) {
// MethodParameter.accessor[(owner type), (parameter type]]
val accessorMethod: Symbol = TypeRepr.of[MethodParameter.type].typeSymbol.methodMember("accessor").head
val objRef = Ref(TypeRepr.of[MethodParameter].typeSymbol.companionModule)

def resolveType(tpe: TypeRepr): TypeRepr = tpe match {
case b: TypeBounds =>
TypeRepr.of[Any]
case _ =>
tpe
}

val t1 = resolveType(t)
val t2 = resolveType(paramType)

val typedAccessor = objRef.select(accessorMethod).appliedToTypes(List(t1, t2))
val methodCall = typedAccessor.appliedToArgs(List(Literal(ClassOfConstant(t1))))

val lambda = Lambda(
owner = Symbol.spliceOwner,
tpe = MethodType(List("x"))(_ => List(TypeRepr.of[Any]), _ => TypeRepr.of[Any]),
tpe = MethodType(List("x"))(_ => List(t1), _ => t2),
rhsFn = (sym, params) => {
val x = params.head.asInstanceOf[Term]
val expr = Select.unique(Select.unique(x, "asInstanceOf").appliedToType(t), paramName)
val expr = Select.unique(x, paramName)
expr.changeOwner(sym)
}
)
// println(t.typeSymbol)
// println(paramType.typeSymbol.flags.show)
// println(lambda.show)
// println(lambda.show(using Printer.TreeStructure))
'{ Some(${ lambda.asExprOf[Any => Any] }) }
val accMethod = methodCall.appliedToArgs(List(lambda))
// println(s"=== ${accMethod.show}")

// Generate code like :
// {{{
// MethodParameter.accessor[t1, t2](classOf[t1]){(x:t1) => x.(field name) }
// }}}
'{ Some(${ accMethod.asExprOf[Any => Any] }) }
} else {
'{ None }
}
Expand All @@ -645,7 +672,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
tpe = MethodType(List("x"))(_ => List(TypeRepr.of[Any]), _ => TypeRepr.of[Any]),
rhsFn = (sym, params) => {
val x = params.head.asInstanceOf[Term]
val expr = Select.unique(x, "asInstanceOf").appliedToType(t).select(m)
val expr = clsCast(x, t).select(m)
expr.changeOwner(sym)
}
)
Expand Down Expand Up @@ -683,35 +710,53 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
}

// To reduce the byte code size, we need to memoize the generated surface bound to a variable
private val surfaceToVar = scala.collection.mutable.Map[TypeRepr, Symbol]()
private var surfaceToVar = ListMap.empty[TypeRepr, Symbol]

private def methodsOf(t: TypeRepr): Expr[Seq[MethodSurface]] = {
// Run just for collecting known surfaces. seen variable will be updated
methodsOfInternal(t)

var count = 0
// Bind the observed surfaces to local variables __s0, __s1, ...
seen.foreach { s =>
// Update the cache so that the next call of surfaceOf method will use the local varaible reference
surfaceToVar += s -> Symbol.newVal(
Symbol.spliceOwner,
s"__s${count}",
TypeRepr.of[Surface],
Flags.EmptyFlags,
Symbol.noSymbol
)
count += 1
}
val surfaceDefs: List[ValDef] = surfaceToVar.map { x =>
val sym = x._2
ValDef(sym, Some(memo(x._1).asTerm))
}.toList
// Create a var def table for replacing surfaceOf[xxx] to __s0, __s1, ...
var surfaceVarCount = 0
seen
// Exclude primitive type surface
.toSeq
// Exclude primitive surfaces as it is already defined in Primitive object
.filterNot(x => primitiveTypeFactory.isDefinedAt(x._1))
.sortBy(_._2)
.reverse
.map { case (tpe, order) =>
// Update the cache so that the next call of surfaceOf method will use the local varaible reference
surfaceToVar += tpe -> Symbol.newVal(
Symbol.spliceOwner,
// Use alphabetically ordered variable names
f"__s${surfaceVarCount}%03X",
TypeRepr.of[Surface],
if (lazySurface.contains(tpe)) {
// If the surface itself is lazy, we need to eagerly initialize it to update the surface cache
Flags.EmptyFlags
} else {
// Use lazy val to avoid forward reference error
Flags.Lazy
},
Symbol.noSymbol
)
surfaceVarCount += 1
}

// Clear method observation cache
// Clear surface cache
memo.clear()
seen = ListMap.empty
seenMethodParent.clear()

val surfaceDefs: List[ValDef] = surfaceToVar.toSeq.map { case (tpe, sym) =>
ValDef(sym, Some(surfaceOf(tpe, useVarRef = false).asTerm))
}.toList

/**
* Generate a code like this: {{ val __s0 = Surface.of[A] val __s1 = Surface.of[B] ...
* Generate a code like this:
*
* {{ lazy val __s000 = Surface.of[A]; lazy val __s001 = Surface.of[B] ...
*
* ClassMethodSurface( .... ) }}
*/
Expand Down Expand Up @@ -758,6 +803,10 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
}
}

private def clsCast(term: Term, t: TypeRepr): Term = {
Select.unique(term, "asInstanceOf").appliedToType(t)
}

private def createMethodCaller(
objectType: TypeRepr,
m: Symbol,
Expand All @@ -779,13 +828,12 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
rhsFn = (sym, params) => {
val x = params(0).asInstanceOf[Term]
val args = params(1).asInstanceOf[Term]
val expr = Select.unique(x, "asInstanceOf").appliedToType(objectType).select(m)
val expr = clsCast(x, objectType).select(m)
val argList = methodArgs.zipWithIndex.collect {
// If the arg is implicit, no need to explicitly bind it
case (arg, i) if !arg.isImplicit =>
// args(i).asInstanceOf[ArgType]
val extracted = Select.unique(args, "apply").appliedTo(Literal(IntConstant(i)))
Select.unique(extracted, "asInstanceOf").appliedToType(arg.tpe)
clsCast(extracted, arg.tpe)
}
if (argList.isEmpty) {
val newExpr = m.tree match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,18 @@ trait ObjectFactory extends Serializable {
def newInstance(args: Seq[Any]): Any
}

object ObjectFactory {

/**
* Used internally for creating a new ObjectFactory instance from a given generic function
* @param f
* @return
*/
def newFactory(f: Seq[Any] => Any): ObjectFactory = new ObjectFactory {
override def newInstance(args: Seq[Any]): Any = f(args)
}
}

case class MethodRef(owner: Class[_], name: String, paramTypes: Seq[Class[_]], isConstructor: Boolean)

trait MethodParameter extends Parameter {
Expand All @@ -75,6 +87,12 @@ trait MethodParameter extends Parameter {
def getMethodArgDefaultValue(methodOwner: Any): Option[Any] = getDefaultValue
}

object MethodParameter {
def accessor[A, B](cl: Class[A])(body: A => B): Any => B = { (x: Any) =>
body(cl.cast(x))
}
}

trait MethodSurface extends ParameterBase {
def mod: Int
def owner: Surface
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.surface

object RecursiveMethodParamTest {
case class Node(parent: Option[Node])

trait MyRecursiveApi {
def find(node: Node): Unit = {}
}
}

class RecursiveMethodParamTest extends munit.FunSuite {
import RecursiveMethodParamTest._

// ....
test("Compile method surfaces with recursive method param") {
Surface.methodsOf[MyRecursiveApi]
}
}
2 changes: 2 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ val buildSettings = Seq[Setting[_]](
scalacOptions ++= Seq(
"-feature",
"-deprecation"
// Use this for debugging Macros
// "-Xcheck-macros"
) ++ {
if (scalaVersion.value.startsWith("3.")) {
Seq.empty
Expand Down
Loading