Skip to content

Commit

Permalink
airframe-surface: Use GenericSurface for Scala 3 + Scala.js (#2345)
Browse files Browse the repository at this point in the history
* airframe-surface: Use GenericSurface for Scala 3 + Scala.js
* airframe-surface: Build ObjectFactory for Scala.js + Scala 3
* Fix InnerClassTest
* Exclude type param list from params
* Make xxxUnit constructor accessible
  • Loading branch information
xerial committed Jul 26, 2022
1 parent 75193e4 commit 2fb5f8d
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 59 deletions.
Expand Up @@ -75,7 +75,7 @@ object Count {
val units = List(ONE, THOUSAND, MILLION, BILLION, TRILLION, QUADRILLION)
private val unitTable = units.map(x => x.unitString -> x).toMap[String, CountUnit]

sealed class CountUnit private[metrics] (val factor: Long, val unitString: String) {
sealed abstract class CountUnit(val factor: Long, val unitString: String) {
override def toString = unitString
}
case object ONE extends CountUnit(1L, "")
Expand Down
Expand Up @@ -97,7 +97,7 @@ object DataSize {

private val dataSizePattern = """^\s*(?<num>\d+(?:\.\d+)?)\s*(?<unit>[a-zA-Z]+)\s*$""".r

sealed class DataSizeUnit private[metrics] (val factor: Long, val unitString: String)
sealed abstract class DataSizeUnit(val factor: Long, val unitString: String)
case object BYTE extends DataSizeUnit(1L, "B")
case object KILOBYTE extends DataSizeUnit(1L << 10, "kB")
case object MEGABYTE extends DataSizeUnit(1L << 20, "MB")
Expand Down
Expand Up @@ -13,16 +13,13 @@
*/
package wvlet.airframe.surface

import wvlet.airframe.surface.reflect.RuntimeGenericSurface

/**
*/
class InnerClassTest extends munit.FunSuite {
case class A(id: Int, name: String)

test("pass inner class context to Surface") {
val s = Surface.of[A]
println(s.asInstanceOf[RuntimeGenericSurface].outer.get.getClass())
val a = s.objectFactory.map { x => x.newInstance(Seq(1, "leo")) }
assertEquals(a, Some(A(1, "leo")))
}
Expand Down
Expand Up @@ -287,16 +287,118 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
.exists(p => p.exists && !p.flags.is(Flags.Private) && p.paramSymss.nonEmpty) =>
val typeArgs = typeArgsOf(t.simplified).map(surfaceOf(_))
val methodParams = constructorParametersOf(t)
val isStatic = !t.typeSymbol.flags.is(Flags.Local)
// TODO: This code doesn't work for Scala.js + Scala 3.0.0
// val isStatic = !t.typeSymbol.flags.is(Flags.Local)
val factory = createObjectFactoryOf(t) match {
case Some(x) => '{ Some(${ x }) }
case None => '{ None }
}

'{
new wvlet.airframe.surface.reflect.RuntimeGenericSurface(
new wvlet.airframe.surface.GenericSurface(
${ clsOf(t) },
${ Expr.ofSeq(typeArgs) }.toIndexedSeq,
params = ${ methodParams },
isStatic = ${ Expr(isStatic) }
objectFactory = ${ factory }
)
}
}

private def typeMappingTable(t: TypeRepr, method: Symbol): Map[String, TypeRepr] = {
val classTypeParams: List[TypeRepr] = t match {
case a: AppliedType => a.args
case _ => List.empty[TypeRepr]
}

// Build a table for resolving type parameters, e.g., class MyClass[A, B] -> Map("A" -> TypeRepr, "B" -> TypeRepr)
method.paramSymss match {
// tpeArgs for case fields, methodArgs for method arguments
case tpeArgs :: tail if t.typeSymbol.typeMembers.nonEmpty =>
val typeArgTable = tpeArgs
.map(_.tree).zipWithIndex.collect {
case (td: TypeDef, i: Int) if i < classTypeParams.size =>
td.name -> classTypeParams(i)
}.toMap[String, TypeRepr]
// pri ntln(s"type args: ${typeArgTable}")
typeArgTable
case _ =>
Map.empty
}
}

// Get a constructor with its generic types are resolved
private def getResolvedConstructorOf(t: TypeRepr): Option[Term] = {
val ts = t.typeSymbol
ts.primaryConstructor match {
case pc if pc == Symbol.noSymbol =>
None
case pc =>
// val cstr = Select.apply(New(TypeIdent(ts)), "<init>")
val cstr = New(Inferred(t)).select(pc)
if (ts.typeMembers.isEmpty) {
Some(cstr)
} else {
val lookupTable = typeMappingTable(t, pc)
// println(s"--- ${lookupTable}")
val typeArgs = pc.paramSymss.headOption.getOrElse(List.empty).map(_.tree).collect { case t: TypeDef =>
lookupTable.getOrElse(t.name, TypeRepr.of[AnyRef])
}
Some(cstr.appliedToTypes(typeArgs))
}
}
}

private def createObjectFactoryOf(targetType: TypeRepr): Option[Expr[ObjectFactory]] = {
val ts = targetType.typeSymbol
val flags = ts.flags
if (
flags.is(Flags.Abstract) || flags.is(Flags.Module) || hasAbstractMethods(targetType) || isPathDependentType(
targetType
)
) {
None
} else {
getResolvedConstructorOf(targetType).map { cstr =>
val argListList = methodArgsOf(targetType, ts.primaryConstructor)
val newClassFn = Lambda(
owner = Symbol.spliceOwner,
tpe = MethodType(List("args"))(_ => List(TypeRepr.of[Seq[Any]]), _ => TypeRepr.of[Any]),
rhsFn = (sym: Symbol, paramRefs: List[Tree]) => {
val args = paramRefs.head.asExprOf[Seq[Any]].asTerm
var index = 0
val fn = argListList.foldLeft[Term](cstr) { (prev, argList) =>
val argExtractors = argList.map { a =>
// args(i+1)
val extracted = Select.unique(args, "apply").appliedTo(Literal(IntConstant(index)))
index += 1
// args(i+1).asInstanceOf[A]
// TODO: Cast primitive values to target types
Select.unique(extracted, "asInstanceOf").appliedToType(a.tpe)
}
Apply(prev, argExtractors.toList)
}
// println(s"== ${fn.show}")
fn.changeOwner(sym)
}
)
val expr = '{
new wvlet.airframe.surface.ObjectFactory {
override def newInstance(args: Seq[Any]): Any = { ${ newClassFn.asExprOf[Seq[Any] => Any] }(args) }
}
}
expr
}
}
}

private def hasAbstractMethods(t: TypeRepr): Boolean = {
t.typeSymbol.methodMembers.exists(_.flags.is(Flags.Abstract))
}

private def isPathDependentType(t: TypeRepr): Boolean = {
!t.typeSymbol.flags.is(Flags.Static) && (t match {
case t: TypeBounds => true
case _ => false
})
}

private def typeParameterFactory: Factory = {
Expand Down Expand Up @@ -359,7 +461,6 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
}
}

// TODO add defaultValue
private case class MethodArg(
name: String,
tpe: TypeRepr,
Expand All @@ -369,63 +470,54 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {
isSecret: Boolean
)

private def methodArgsOf(t: TypeRepr, method: Symbol): List[MethodArg] = {
val classTypeParams: List[TypeRepr] = t match {
case a: AppliedType =>
a.args
case _ =>
List.empty[TypeRepr]
}
private def methodArgsOf(t: TypeRepr, method: Symbol): List[List[MethodArg]] = {
// println(s"==== method args of ${fullTypeNameOf(t)}")

val defaultValueMethods = t.typeSymbol.companionClass.declaredMethods.filter { m =>
m.name.startsWith("apply$default$") || m.name.startsWith("$lessinit$greater$default$")
}

// println(s"==== method args of ${fullTypeNameOf(t)}")

// Build a table for resolving type parameters, e.g., class MyClass[A, B] -> Map("A" -> TypeRepr, "B" -> TypeRepr)
val typeArgTable: Map[String, TypeRepr] = method.paramSymss match {
// tpeArgs for case fields, methodArgs for method arguments
case List(tpeArgs, methodArgs) =>
val typeArgTable = tpeArgs
.map(_.tree).zipWithIndex.collect {
case (td: TypeDef, i: Int) if i < classTypeParams.size =>
td.name -> classTypeParams(i)
}.toMap[String, TypeRepr]
// println(s"type args: ${typeArgTable}")
typeArgTable
case _ =>
Map.empty
}

method.paramSymss.flatten.zipWithIndex
.map((x, i) => (x, i + 1, x.tree))
.collect { case (s: Symbol, i: Int, v: ValDef) =>
// E.g. case class Foo(a: String)(implicit b: Int)
// Substitue type param to actual types
val resolved: TypeRepr = v.tpt.tpe match {
case a: AppliedType =>
val resolvedTypeArgs = a.args.map {
case p if p.typeSymbol.isTypeParam && typeArgTable.contains(p.typeSymbol.name) =>
typeArgTable(p.typeSymbol.name)
case other => other
val typeArgTable: Map[String, TypeRepr] = typeMappingTable(t, method)

val origParamSymss = method.paramSymss
val paramss =
if (origParamSymss.nonEmpty && t.typeSymbol.typeMembers.nonEmpty) origParamSymss.tail else origParamSymss

paramss.map { params =>
params.zipWithIndex
.map((x, i) => (x, i + 1, x.tree))
.collect { case (s: Symbol, i: Int, v: ValDef) =>
// E.g. case class Foo(a: String)(implicit b: Int)
// Substitue type param to actual types
val resolved: TypeRepr = v.tpt.tpe match {
case a: AppliedType =>
val resolvedTypeArgs = a.args.map {
case p if p.typeSymbol.isTypeParam && typeArgTable.contains(p.typeSymbol.name) =>
typeArgTable(p.typeSymbol.name)
case other =>
other
}
a.appliedTo(resolvedTypeArgs)
case TypeRef(_, name) if typeArgTable.contains(name) =>
typeArgTable(name)
case other =>
other
}
val isSecret = hasSecretAnnotation(s)
val isRequired = hasRequiredAnnotation(s)
val defaultValueGetter = defaultValueMethods.find(m => m.name.endsWith(s"$$${i}"))

val defaultMethodArgGetter = {
val targetMethodName = method.name + "$default$" + i
t.typeSymbol.declaredMethods.find { m =>
// println(s"=== target: ${m.name}, ${m.owner.name}")
m.name == targetMethodName
}
a.appliedTo(resolvedTypeArgs)
case other => other
}
val isSecret = hasSecretAnnotation(s)
val isRequired = hasRequiredAnnotation(s)
val defaultValueGetter = defaultValueMethods.find(m => m.name.endsWith(s"$$${i}"))

val defaultMethodArgGetter = {
val targetMethodName = method.name + "$default$" + i
t.typeSymbol.declaredMethods.find { m =>
// println(s"=== target: ${m.name}, ${m.owner.name}")
m.name == targetMethodName
}
MethodArg(v.name, resolved, defaultValueGetter, defaultMethodArgGetter, isRequired, isSecret)
}
MethodArg(v.name, resolved, defaultValueGetter, defaultMethodArgGetter, isRequired, isSecret)
}
}
}

private def hasSecretAnnotation(s: Symbol): Boolean = {
Expand All @@ -443,7 +535,7 @@ private[surface] class CompileTimeSurfaceFactory[Q <: Quotes](using quotes: Q) {

private def methodParametersOf(t: TypeRepr, method: Symbol): Expr[Seq[MethodParameter]] = {
val methodName = method.name
val methodArgs = methodArgsOf(t, method)
val methodArgs = methodArgsOf(t, method).flatten
val argClasses = methodArgs.map { arg =>
clsOf(arg.tpe.dealias)
}
Expand Down

0 comments on commit 2fb5f8d

Please sign in to comment.