Skip to content

Commit

Permalink
Merge pull request #646 from adpi2/null-arg
Browse files Browse the repository at this point in the history
[Runtime Evaluation] add support for null arg
  • Loading branch information
adpi2 committed Feb 7, 2024
2 parents e927223 + 4c7baa3 commit f1d864f
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ private[internal] class JdiClassLoader(
Safe(mirrorOfAnyVal(value.asInstanceOf[AnyVal]))
case value: String => mirrorOf(value)
case () => Safe(mirrorOfVoid())
case null => Safe(JdiValue(null, thread))
case _ => Safe.failed(new IllegalArgumentException(s"Unsupported literal $value"))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ object RuntimePrimitiveOps {
op match {
case "==" => Valid(Eq)
case "!=" => Valid(Neq)
case _ if lhs == null || rhs == null => Recoverable(s"The $op operator is not defined on null value")
case _ if !isPrimitive(lhs) || !isPrimitive(rhs) => notDefined
case "&&" if isBoolean(lhs) && isBoolean(rhs) => Valid(And)
case "||" if isBoolean(lhs) && isBoolean(rhs) => Valid(Or)
Expand Down Expand Up @@ -102,7 +103,7 @@ object RuntimePrimitiveOps {
}

private def referenceTypeCheck(lhs: Type, rhs: Type): Option[Type] = {
(lhs.name(), rhs.name()) match {
(lhs.name, rhs.name) match {
case ("java.lang.Double", _) => Some(lhs)
case (_, "java.lang.Double") => Some(rhs)
case ("java.lang.Float", _) => Some(lhs)
Expand Down Expand Up @@ -284,6 +285,7 @@ object RuntimePrimitiveOps {
object UnaryOp {
def apply(rhs: Type, op: String): Validation[UnaryOp] =
op match {
case _ if rhs == null => Recoverable(s"$op is not defined on null")
case "unary_+" if isNumeric(rhs) => Valid(UnaryPlus)
case "unary_-" if isNumeric(rhs) => Valid(UnaryMinus)
case "unary_~" if isIntegral(rhs) => Valid(UnaryBitwiseNot)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,25 +110,6 @@ object RuntimeEvaluationTree {
}
}

object ArrayElem {
def apply(tree: RuntimeEvaluationTree, index: Seq[RuntimeEvaluationTree]): Validation[ArrayElem] = {
val integerTypes = Seq("java.lang.Integer", "java.lang.Short", "java.lang.Byte", "java.lang.Character")
if (index.size < 1 || index.size > 1) Recoverable("Array accessor must have one argument")
else
(tree, tree.`type`) match {
case (tree: RuntimeEvaluationTree, arr: jdi.ArrayType) =>
index.head.`type` match {
case idx @ (_: jdi.IntegerType | _: jdi.ShortType | _: jdi.ByteType | _: jdi.CharType) =>
Valid(new ArrayElem(tree, index.head, arr.componentType()))
case ref: jdi.ReferenceType if integerTypes.contains(ref.name) =>
Valid(new ArrayElem(tree, index.head, arr.componentType()))
case _ => Recoverable("Array index must be an integer")
}
case _ => Recoverable("Not an array accessor")
}
}
}

case class CallUnaryOp(
rhs: RuntimeEvaluationTree,
op: UnaryOp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
private def validateLiteral(lit: Lit): Validation[RuntimeEvaluationTree] =
classLoader.map { loader =>
val value = loader.mirrorOfLiteral(lit.value)
val tpe = loader.mirrorOfLiteral(lit.value).map(_.value.`type`).extract.get
val tpe = if (lit.value == null) null else value.map(_.value.`type`).extract.get
Value(value, tpe)
}

Expand Down Expand Up @@ -217,7 +217,7 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
candidates
.validateSingle(s"Cannot find top-level class $name")
.flatMap(loadClass)
.map(StaticOrTopLevelClass)
.map(StaticOrTopLevelClass.apply)
}

private def getAllFullyQualifiedClassNames(name: String): Seq[String] = {
Expand Down Expand Up @@ -297,7 +297,7 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
qualifier: RuntimeEvaluationTree,
args: Seq[RuntimeEvaluationTree]
): Validation[RuntimeEvaluationTree] =
findMethodBySignedName(qualifier, "apply", args).orElse(ArrayElem(qualifier, args))
findMethodBySignedName(qualifier, "apply", args).orElse(asArrayElem(qualifier, args))

private def findMethodInThisOrOuter(
thisOrOuter: RuntimeEvaluationTree,
Expand Down Expand Up @@ -370,7 +370,7 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
}

private def isBoolean(tpe: jdi.Type): Boolean =
tpe.isInstanceOf[jdi.BooleanType] || tpe.name == "java.lang.Boolean"
tpe != null && (tpe.isInstanceOf[jdi.BooleanType] || tpe.name == "java.lang.Boolean")

private def validateAssign(tree: Term.Assign): Validation[RuntimeEvaluationTree] = {
val lhs =
Expand All @@ -397,7 +397,7 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
rhs <- validateAsValue(tree.rhs)
.filter(
rhs => isAssignableFrom(rhs.`type`, lhs.`type`),
rhs => s"Cannot assign ${rhs.`type`.name} to ${lhs.`type`.name}"
rhs => s"Cannot assign ${nameOrNull(rhs.`type`)} to ${nameOrNull(lhs.`type`)}"
)
unit <- unitTree
} yield Assign(lhs, rhs, unit.`type`)
Expand All @@ -420,18 +420,15 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
private def moreSpecificThan(m1: jdi.Method, m2: jdi.Method): Boolean = {
m1.argumentTypes()
.asScala
.zip(m2.argumentTypes().asScala)
.zip(m2.argumentTypes.asScala)
.forall {
case (t1, t2) if t1.name == t2.name => true
case (t1, t2) if nameOrNull(t1) == nameOrNull(t2) => true
case (_: jdi.PrimitiveType, _) => true
case (_, _: jdi.PrimitiveType) => true
case (r1: jdi.ReferenceType, r2: jdi.ReferenceType) => isAssignableFrom(r1, r2)
}
}

private def argsMatch(method: jdi.Method, args: Seq[jdi.Type], boxing: Boolean): Boolean =
method.argumentTypeNames().size() == args.size && areAssignableFrom(method, args, boxing)

/**
* @see <a href="https://docs.oracle.com/javase/specs/jls/se20/html/jls-15.html#jls-15.12.2.5">JLS#15.12.2.5. Choosing the most specific method</a>
*
Expand Down Expand Up @@ -468,9 +465,9 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
args: Seq[jdi.Type]
): Validation[jdi.Method] = {
val candidates = findMethodsByName(ref, name)
val unboxedCandidates = candidates.filter(argsMatch(_, args, boxing = false))
val unboxedCandidates = candidates.filter(matchArguments(_, args, boxing = false))
val boxedCandidates = unboxedCandidates.size match {
case 0 => candidates.filter(argsMatch(_, args, boxing = true))
case 0 => candidates.filter(matchArguments(_, args, boxing = true))
case _ => unboxedCandidates
}
val withoutBridges = boxedCandidates.size match {
Expand All @@ -481,10 +478,9 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
case 0 | 1 => withoutBridges
case _ => filterMostPreciseMethod(withoutBridges)
}
def formatArgs = args.map(nameOrNull).mkString("[", ", ", "]")
finalCandidates
.validateSingle(
s"Cannot find method $name with arguments of types ${args.map(_.name).mkString("[", ", ", "]")} in ${ref.name}"
)
.validateSingle(s"Cannot find method $name with arguments of types $formatArgs in ${ref.name}")
.map(loadClassOnNeed)
}

Expand Down Expand Up @@ -579,22 +575,20 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source

(got, expected) match {
case (g: jdi.ArrayType, at: jdi.ArrayType) =>
checkClassStatus(at.componentType())
g.componentType().equals(at.componentType())
checkClassStatus(at.componentType)
g.componentType.equals(at.componentType)
case (g: jdi.PrimitiveType, pt: jdi.PrimitiveType) => got.equals(pt)
case (g: jdi.ReferenceType, ref: jdi.ReferenceType) => referenceTypesMatch(g, ref)
case (_: jdi.VoidType, _: jdi.VoidType) => true
case (g: jdi.ReferenceType, pt: jdi.PrimitiveType) =>
isAssignableFrom(g, frame.getPrimitiveBoxedClass(pt))
case (g: jdi.PrimitiveType, ct: jdi.ReferenceType) =>
isAssignableFrom(frame.getPrimitiveBoxedClass(g), ct)

case (g: jdi.ReferenceType, pt: jdi.PrimitiveType) => isAssignableFrom(g, frame.getPrimitiveBoxedClass(pt))
case (g: jdi.PrimitiveType, ct: jdi.ReferenceType) => isAssignableFrom(frame.getPrimitiveBoxedClass(g), ct)
case (null, _: jdi.ReferenceType) => true
case _ => false
}
}

private def areAssignableFrom(method: jdi.Method, args: Seq[jdi.Type], boxing: Boolean): Boolean =
method.argumentTypes.size == args.size &&
private def matchArguments(method: jdi.Method, args: Seq[jdi.Type], boxing: Boolean): Boolean =
method.argumentTypeNames.size == args.size &&
method.argumentTypes.asScala.zip(args).forall {
case (_: jdi.PrimitiveType, _: jdi.ReferenceType) if !boxing => false
case (_: jdi.ReferenceType, _: jdi.PrimitiveType) if !boxing => false
Expand Down Expand Up @@ -652,18 +646,18 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
}

private def getCommonSuperClass(tpe1: jdi.Type, tpe2: jdi.Type): Validation[jdi.Type] = {
def getSuperClasses(of: jdi.Type): Array[jdi.ClassType] =
of match {
def getSuperClasses(tpe: jdi.Type): Array[jdi.ClassType] =
tpe match {
case cls: jdi.ClassType =>
Iterator.iterate(cls)(cls => cls.superclass()).takeWhile(_ != null).toArray
Iterator.iterate(cls)(cls => cls.superclass).takeWhile(_ != null).toArray
case _ => Array()
}

val superClasses1 = getSuperClasses(tpe1)
val superClasses2 = getSuperClasses(tpe2)
Validation.fromOption(
superClasses1.find(superClasses2.contains),
s"${tpe1.name} and ${tpe2.name} do not have any common super class"
s"${nameOrNull(tpe1)} and ${nameOrNull(tpe2)} do not have any common super class"
)
}

Expand Down Expand Up @@ -701,7 +695,7 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source

private def asStaticModule(tpe: jdi.Type): Validation[RuntimeEvaluationTree] =
if (isStaticModule(tpe)) Valid(preEvaluate(StaticModule(tpe.asInstanceOf[jdi.ClassType])))
else Recoverable(s"${tpe.name} is not a static module")
else Recoverable(s"${nameOrNull(tpe)} is not a static module")

private def asModule(tpe: jdi.ReferenceType, qualifier: RuntimeEvaluationTree): Validation[RuntimeEvaluationTree] =
if (isStaticModule(tpe)) asStaticModule(tpe)
Expand Down Expand Up @@ -746,6 +740,31 @@ private[evaluator] class RuntimeValidation(frame: JdiFrame, sourceLookUp: Source
if (method.isStatic) Valid(CallStaticMethod(method, args, qualifier))
else Recoverable(s"Cannot access instance method ${method.name} from static context")

private def asArrayElem(
array: RuntimeEvaluationTree,
args: Seq[RuntimeEvaluationTree]
): Validation[RuntimeEvaluationTree] = {
val integerTypes = Seq("java.lang.Integer", "java.lang.Short", "java.lang.Byte", "java.lang.Character")
if (args.size != 1) Recoverable("Array accessor must have one argument")
else {
val index = args.head
array.`type` match {
case arrayTpe: jdi.ArrayType =>
index.`type` match {
case (_: jdi.IntegerType | _: jdi.ShortType | _: jdi.ByteType | _: jdi.CharType) =>
Valid(preEvaluate(new ArrayElem(array, index, arrayTpe.componentType)))
case ref: jdi.ReferenceType if integerTypes.contains(ref.name) =>
Valid(preEvaluate(new ArrayElem(array, index, arrayTpe.componentType)))
case tpe => Recoverable(s"Array index must be an integer, found ${tpe.name}")
}
case tpe => Recoverable(s"${tpe.name} is not an array")
}
}
}

private def nameOrNull(tpe: jdi.Type): String =
if (tpe == null) "null" else tpe.name

private implicit class IterableExtensions[A](iter: Iterable[A]) {
def validateSingle(message: String): Validation[A] =
iter.size match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ object Safe {

def apply[A](a: Try[A]): Safe[A] = {
a match {
case null => new Safe(Success(null).asInstanceOf, () => ())
case Success(value) =>
value match {
case obj: ObjectReference =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,16 +371,6 @@ object RuntimeEvaluatorEnvironments {
| x()
|}
|""".stripMargin
val localVarTestSource =
"""|package example
|
|object Main {
| def main(args: Array[String]): Unit = {
| val name = "world"
| println(name)
| }
|}
|""".stripMargin

val fieldSource =
"""|package example
Expand Down Expand Up @@ -598,25 +588,6 @@ object RuntimeEvaluatorEnvironments {
|}
|""".stripMargin

val arraysSource =
"""|package example
|
|object Main {
| def main(args: Array[String]): Unit = {
| val arr = Array(1, 2, 3)
| val sh: Short = 2
| val ch: Char = 2
| val by: Byte = 2
| println("ok")
| }
|
| def test(arr: Array[Int]): String = arr.mkString(",")
|
| def test(arr: Array[Test]): String = arr.map(_.i).mkString(",")
|
| case class Test(i: Int)
|}
|""".stripMargin
val collectionSource =
"""|package example
|
Expand Down Expand Up @@ -733,8 +704,6 @@ object RuntimeEvaluatorEnvironments {
}

abstract class RuntimeEvaluatorTests(val scalaVersion: ScalaVersion) extends DebugTestSuite {
lazy val localVar =
TestingDebuggee.mainClass(RuntimeEvaluatorEnvironments.localVarTestSource, "example.Main", scalaVersion)
lazy val field = TestingDebuggee.mainClass(RuntimeEvaluatorEnvironments.fieldSource, "example.Main", scalaVersion)
lazy val method = TestingDebuggee.mainClass(RuntimeEvaluatorEnvironments.methodSource, "example.Main", scalaVersion)
lazy val overloads =
Expand All @@ -743,8 +712,6 @@ abstract class RuntimeEvaluatorTests(val scalaVersion: ScalaVersion) extends Deb
lazy val cls = TestingDebuggee.mainClass(RuntimeEvaluatorEnvironments.cls, "example.Main", scalaVersion)
lazy val boxingOverloads =
TestingDebuggee.mainClass(RuntimeEvaluatorEnvironments.boxingOverloads, "example.Main", scalaVersion)
lazy val arrays =
TestingDebuggee.mainClass(RuntimeEvaluatorEnvironments.arraysSource, "example.Main", scalaVersion)
lazy val collections =
TestingDebuggee.mainClass(RuntimeEvaluatorEnvironments.collectionSource, "example.Main", scalaVersion)
lazy val inners =
Expand All @@ -771,13 +738,22 @@ abstract class RuntimeEvaluatorTests(val scalaVersion: ScalaVersion) extends Deb
}

test("local variable") {
implicit val debuggee = localVar
val source =
"""|package example
|
|object Main {
| def main(args: Array[String]): Unit = {
| val name = "world"
| println(name)
| }
|}
|""".stripMargin
implicit val debuggee: TestingDebuggee = TestingDebuggee.mainClass(source, "example.Main", scalaVersion)
check(
Breakpoint(6),
Evaluation.success("name", "world"),
Evaluation.failed("unknown", "unknown is not a local variable")
)

}

test("instance fields") {
Expand Down Expand Up @@ -916,8 +892,24 @@ abstract class RuntimeEvaluatorTests(val scalaVersion: ScalaVersion) extends Deb
)
}

test("Should work on arrays") {
implicit val debuggee: TestingDebuggee = arrays
test("arrays") {
val source =
"""|package example
|
|object Main {
| def main(args: Array[String]): Unit = {
| val arr = Array(1, 2, 3)
| val sh: Short = 2
| val ch: Char = 2
| val by: Byte = 2
| println("ok")
| }
| def test(arr: Array[Int]): String = arr.mkString(",")
| def test(arr: Array[Test]): String = arr.map(_.i).mkString(",")
| case class Test(i: Int)
|}
|""".stripMargin
implicit val debuggee: TestingDebuggee = TestingDebuggee.mainClass(source, "example.Main", scalaVersion)
check(
Breakpoint(9),
DebugStepAssert.inParallel(
Expand Down Expand Up @@ -1394,4 +1386,20 @@ abstract class RuntimeEvaluatorTests(val scalaVersion: ScalaVersion) extends Deb
Evaluation.success("Main.test.t.a", 42)
)
}

test("accept null as an argument") {
val source =
"""|package example
|
|object Main {
| def main(args: Array[String]): Unit = {
| println("Hello, World!")
| }
|
| def m(x: String): String = x
|}
|""".stripMargin
implicit val debuggee: TestingDebuggee = TestingDebuggee.mainClass(source, "example.Main", scalaVersion)
check(Breakpoint(5), Evaluation.success("m(null)", null))
}
}

0 comments on commit f1d864f

Please sign in to comment.