Skip to content

Commit

Permalink
Merge pull request #705 from adpi2/fix-casting
Browse files Browse the repository at this point in the history
[Expr Compiler] Fix casting qualifier of accessible member
  • Loading branch information
adpi2 committed May 6, 2024
2 parents ecdad2c + 59f52a9 commit e335c75
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 226 deletions.
Expand Up @@ -27,7 +27,7 @@ final class ExpressionCompilerBridge {
"-classpath",
classPath
// Debugging: Print the tree after phases of the debugger
// "-Xprint:insert-expression,extract-expression,resolve-reflect-eval",
// "-Xprint:extract-expression,resolve-reflect-eval",
// "-Vdebug"
) ++ options :+ sourceFile.toString

Expand Down
Expand Up @@ -42,12 +42,22 @@ class ExtractExpression(override val global: ExpressionGlobal)
override def transform(tree: Tree): Tree = tree match {
case tree: Import => tree

// store all local symbol in the expression val to later update their owner
case tree @ (_: DefTree | _: Function) if tree.symbol.owner == expressionVal =>
localDefs += tree.symbol
if (tree.symbol.isModule) localDefs += tree.symbol.moduleClass
super.transform(tree)

case tree if isLocalToExpression(tree.symbol) => super.transform(tree)

case _: Ident | _: Select | _: This if isStaticObject(tree.symbol) =>
getStaticObject(tree)(tree.symbol)

case _: This | _: Apply if isInaccessibleNonStaticObject(tree.symbol) =>
val qualifier = getTransformedQualifier(tree)
callMethod(tree)(qualifier, tree.symbol.asTerm, List.empty)
case _: This if !tree.symbol.hasPackageFlag =>
thisOrOuterValue(tree)(tree.symbol.enclClass.asClass)

case _: Apply if isNonStaticObject(tree.symbol) =>
callMethod(tree)(getTransformedQualifier(tree), tree.symbol.asTerm, List.empty)

case tree: Ident if isLocalVariable(tree.symbol) =>
getCapturer(tree.symbol.asTerm) match {
Expand All @@ -67,84 +77,43 @@ class ExtractExpression(override val global: ExpressionGlobal)
case None => setLocalValue(tree)(variable, rhs)
}

case tree: Select if isInaccessibleField(tree) =>
case _: Select if tree.symbol.isField && !isAccessibleMember(tree) =>
if (isJavaStatic(tree.symbol)) getField(tree)(mkNullLiteral, tree.symbol.asTerm)
else {
val qualifier = getTransformedQualifier(tree)
getField(tree)(qualifier, tree.symbol.asTerm)
}
else getField(tree)(getTransformedQualifier(tree), tree.symbol.asTerm)

case tree @ Assign(lhs, rhs) if isInaccessibleField(lhs) =>
case Assign(lhs, rhs) if lhs.symbol.isField && !isAccessibleMember(lhs) =>
if (isJavaStatic(lhs.symbol)) setField(tree)(mkNullLiteral, lhs.symbol.asTerm, transform(rhs))
else {
val qualifier = getTransformedQualifier(lhs)
setField(tree)(qualifier, lhs.symbol.asTerm, transform(rhs))
}
else setField(tree)(getTransformedQualifier(lhs), lhs.symbol.asTerm, transform(rhs))

case This(name) if !tree.symbol.hasPackageFlag && !isOwnedByExpression(tree.symbol) =>
thisOrOuterValue(tree)(tree.symbol.enclClass.asClass)
case _: Select | _: Apply | _: TypeApply
if tree.symbol.isConstructor && (!tree.symbol.owner.isStatic || !isAccessibleMember(tree)) =>
callConstructor(tree)(getTransformedQualifierOfNew(tree), tree.symbol.asTerm, getTransformedArgs(tree))

case _: Select | _: Apply | _: TypeApply if isInaccessibleConstructor(tree) =>
val args = getTransformedArgs(tree)
val qualifier = getTransformedQualifierOfNew(tree)
callConstructor(tree)(qualifier, tree.symbol.asTerm, args)

case _: Ident | _: Select | _: Apply | _: TypeApply if isInaccessibleMethod(tree) =>
case _: Ident | _: Select | _: Apply | _: TypeApply if isRealMethod(tree.symbol) && !isAccessibleMember(tree) =>
val args = getTransformedArgs(tree)
if (isJavaStatic(tree.symbol)) callMethod(tree)(mkNullLiteral, tree.symbol.asTerm, args)
else {
val qualifier = getTransformedQualifier(tree)
callMethod(tree)(qualifier, tree.symbol.asTerm, args)
else callMethod(tree)(getTransformedQualifier(tree), tree.symbol.asTerm, args)

// accessible members
case tree @ (_: Ident | _: Select) if !tree.symbol.isStatic =>
val qualifier = getTransformedQualifier(tree)
val qualifierType = widenDealiasQualifierType(tree)
val castQualifier =
if (qualifier.tpe <:< qualifierType) qualifier
else gen.mkAttributedCast(qualifier, qualifierType)
val name = tree match {
case Ident(name) => name
case Select(_, name) => name
}
Select(castQualifier, name).copyAttrs(tree)

case Typed(tree, tpt) if tpt.symbol.isType && !isTypeAccessible(tpt.symbol.asType) =>
case Typed(tree, tpt) if tpt.symbol.isType && !isTypeAccessible(tpt.tpe) =>
transform(tree)

case tree @ (_: DefTree | _: Function) if tree.symbol.owner == expressionVal =>
localDefs += tree.symbol
if (tree.symbol.isModule) localDefs += tree.symbol.moduleClass
super.transform(tree)

case tree =>
super.transform(tree)
}

/**
* The symbol is a field and the expression class cannot access it
* either because it is private or it belongs to an inaccessible type
*/
private def isInaccessibleField(tree: Tree): Boolean = {
val symbol = tree.symbol
symbol.isField &&
symbol.owner.isType &&
!isTermAccessible(symbol.asTerm, getQualifierTypeSymbol(tree))
}

/**
* The symbol is a real method and the expression class cannot access it
* either because it is private or it belongs to an inaccessible type
*/
private def isInaccessibleMethod(tree: Tree): Boolean = {
val symbol = tree.symbol
!isOwnedByExpression(symbol) &&
isRealMethod(symbol) &&
(!symbol.owner.isType || !isTermAccessible(symbol.asTerm, getQualifierTypeSymbol(tree)))
}

private def isRealMethod(symbol: Symbol): Boolean =
symbol.isMethod && !symbol.isAnonymousFunction

/**
* The symbol is a constructor and the expression class cannot access it
* either because it is an inaccessible method or it belong to a nested type (not static)
*/
private def isInaccessibleConstructor(tree: Tree): Boolean = {
val symbol = tree.symbol
!isOwnedByExpression(symbol) &&
symbol.isConstructor &&
(isInaccessibleMethod(tree) || !symbol.owner.isStatic)
}

private def getCapturer(variable: TermSymbol): Option[Symbol] = {
// a local variable can be captured by a class or method
val candidates = expressionVal.ownersIterator
Expand All @@ -162,13 +131,6 @@ class ExtractExpression(override val global: ExpressionGlobal)
case TypeApply(fun, _) => getTransformedArgs(fun)
}

private def getQualifierTypeSymbol(tree: Tree): TypeSymbol = tree match {
case Ident(_) => tree.symbol.enclClass.asClass
case Select(qualifier, _) => qualifier.tpe.dealiasWiden.typeSymbol.asType
case Apply(fun, _) => getQualifierTypeSymbol(fun)
case TypeApply(fun, _) => getQualifierTypeSymbol(fun)
}

private def getTransformedQualifier(tree: Tree): Tree = tree match {
case Ident(_) =>
// it is a local method, it captures its outer value
Expand Down Expand Up @@ -294,18 +256,14 @@ class ExtractExpression(override val global: ExpressionGlobal)
tree: Tree
)(qualifier: Tree, strategy: EvaluationStrategy, args: List[Tree], tpe: Type): Tree = {
val attachments = tree.attachments.addElement(strategy)
val methodCall = gen
gen
.mkMethodCall(
reflectEval,
List(qualifier, mkStringLiteral(strategy.toString), mkObjectArray(args))
)
.setType(definitions.AnyTpe)
.setPos(tree.pos)
.setAttachments(attachments)
val widenDealiasTpe = tpe.dealiasWiden
if (isTypeAccessible(widenDealiasTpe.typeSymbol.asType)) {
gen.mkCast(methodCall, widenDealiasTpe).setType(widenDealiasTpe)
} else methodCall
}
}

Expand All @@ -316,24 +274,43 @@ class ExtractExpression(override val global: ExpressionGlobal)
private def isStaticObject(symbol: Symbol): Boolean =
symbol.isModuleOrModuleClass && symbol.isStatic && !symbol.isJava && !symbol.isRoot

private def isInaccessibleNonStaticObject(symbol: Symbol): Boolean =
symbol.isModuleOrModuleClass && !symbol.isStatic && !symbol.isRoot && !isOwnedByExpression(symbol)
private def isNonStaticObject(symbol: Symbol): Boolean =
symbol.isModuleOrModuleClass && !symbol.isStatic && !symbol.isRoot

private def isLocalVariable(symbol: Symbol): Boolean =
!symbol.isMethod && symbol.isLocalToBlock && !isOwnedByExpression(symbol)
private def isRealMethod(symbol: Symbol): Boolean = symbol.isMethod && !symbol.isAnonymousFunction

private def isLocalVariable(symbol: Symbol): Boolean = !symbol.isMethod && symbol.isLocalToBlock

// Check if a term is accessible from the expression class
private def isTermAccessible(symbol: TermSymbol, owner: TypeSymbol): Boolean =
isOwnedByExpression(symbol) ||
(!symbol.isPrivate && !symbol.isProtected && isTypeAccessible(owner))
private def isAccessibleMember(tree: Tree): Boolean = {
val symbol = tree.symbol
symbol.owner.isType && !symbol.isPrivate && !symbol.isProtected && isTypeAccessible(widenDealiasQualifierType(tree))
}

private def widenDealiasQualifierType(tree: Tree): Type = tree match {
case Ident(_) => tree.symbol.enclClass.thisType.dealiasWiden
case Select(qualifier, _) => qualifier.tpe.dealiasWiden
case Apply(fun, _) => widenDealiasQualifierType(fun)
case TypeApply(fun, _) => widenDealiasQualifierType(fun)
}

// Check if a type is accessible from the expression class
private def isTypeAccessible(symbol: TypeSymbol): Boolean =
isOwnedByExpression(symbol) || (
!symbol.isLocalToBlock &&
symbol.ownersIterator.forall(s => s.isPublic || s.privateWithin.isPackageClass)
)
private def isTypeAccessible(tpe: Type): Boolean = {
def isPublic(sym: Symbol): Boolean = !sym.isLocalToBlock && (sym.isPublic || sym.privateWithin.isPackageClass)
val parts = Buffer.empty[Symbol]
tpe.foreach { part =>
parts += part.typeSymbol
parts += part.termSymbol
}
parts.forall {
case NoSymbol => true
case sym => isLocalToExpression(sym) || isPublic(sym)
}
}

private def isOwnedByExpression(symbol: Symbol): Boolean =
symbol.ownersIterator.exists(_ == expressionVal)
private def isLocalToExpression(symbol: Symbol): Boolean =
symbol != null && (
symbol.owner == NoSymbol || // param of a LabelDef
symbol.ownersIterator.exists(_ == expressionVal)
)
}
Expand Up @@ -48,7 +48,8 @@ class InsertExpression(override val global: ExpressionGlobal) extends Transform
| }
| .getOrElse(throw new NoSuchMethodException(methodName))
| method.setAccessible(true)
| unwrapException(method.invoke(obj, args: _*))
| val res = unwrapException(method.invoke(obj, args: _*))
| if (returnTypeName == "void") () else res
| }
|
| def callConstructor(className: String, paramTypesNames: Array[String], args: Array[Object]): Any = {
Expand All @@ -67,7 +68,7 @@ class InsertExpression(override val global: ExpressionGlobal) extends Transform
| field.get(obj)
| }
|
| def setField(obj: Any, className: String, fieldName: String, value: Any): Unit = {
| def setField(obj: Any, className: String, fieldName: String, value: Any): Any = {
| val clazz = classLoader.loadClass(className)
| val field = clazz.getDeclaredField(fieldName)
| field.setAccessible(true)
Expand Down

0 comments on commit e335c75

Please sign in to comment.