Skip to content

Commit

Permalink
Merge pull request #532 from adpi2/fix-eval-static-field
Browse files Browse the repository at this point in the history
Evaluation of static members
  • Loading branch information
adpi2 committed Jul 28, 2023
2 parents 2941a74 + 42335ac commit ec02508
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 48 deletions.
26 changes: 16 additions & 10 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ lazy val core = projectMatrix
buildInfoKeys := Seq[BuildInfoKey](
BuildInfoKey.action("organization")(organization.value),
BuildInfoKey.action("version")(version.value),
BuildInfoKey.action("expressionCompilerName")((expressionCompiler212 / name).value),
BuildInfoKey.action("expressionCompilerName")((LocalProject("expressionCompiler2_12") / name).value),
BuildInfoKey.action("unpicklerName")((LocalProject("unpickler3") / name).value),
BuildInfoKey.action("scala212")(Dependencies.scala212),
BuildInfoKey.action("scala213")(Dependencies.scala213),
Expand All @@ -102,6 +102,7 @@ lazy val core = projectMatrix
)

lazy val tests212 = tests.jvm(Dependencies.scala212)
lazy val tests213 = tests.jvm(Dependencies.scala213)
lazy val tests3 = tests.jvm(Dependencies.scala31Plus)
lazy val tests = projectMatrix
.in(file("modules/tests"))
Expand All @@ -123,11 +124,11 @@ lazy val tests = projectMatrix
Test / testOptions += Tests.Argument(TestFrameworks.MUnit, "+l"),
Test / testOptions := (Test / testOptions)
.dependsOn(
expressionCompiler212 / publishLocal,
expressionCompiler213 / publishLocal,
expressionCompiler30 / publishLocal,
expressionCompiler3 / publishLocal,
// break cyclic reference
LocalProject("expressionCompiler2_12") / publishLocal,
LocalProject("expressionCompiler2_13") / publishLocal,
LocalProject("expressionCompiler3_0") / publishLocal,
LocalProject("expressionCompiler3") / publishLocal,
LocalProject("unpickler3") / publishLocal
)
.value
Expand Down Expand Up @@ -158,10 +159,10 @@ lazy val expressionCompiler30 = expressionCompiler.finder(scala30Axis)(true)
lazy val expressionCompiler3 = expressionCompiler.finder(scala3Axis)(true)
lazy val expressionCompiler = projectMatrix
.in(file("modules/expression-compiler"))
.customRow(true, Seq(scala212Axis, VirtualAxis.jvm), identity[Project] _)
.customRow(true, Seq(scala213Axis, VirtualAxis.jvm), identity[Project] _)
.customRow(true, Seq(scala30Axis, VirtualAxis.jvm), identity[Project] _)
.customRow(true, Seq(scala3Axis, VirtualAxis.jvm), identity[Project] _)
.customRow(true, Seq(scala212Axis, VirtualAxis.jvm), p => p.dependsOn(tests212 % Test))
.customRow(true, Seq(scala213Axis, VirtualAxis.jvm), p => p.dependsOn(tests213 % Test))
.customRow(true, Seq(scala30Axis, VirtualAxis.jvm), p => p)
.customRow(true, Seq(scala3Axis, VirtualAxis.jvm), p => p.dependsOn(tests3 % Test))
.settings(
name := "scala-expression-compiler",
crossScalaVersions ++= CrossVersion
Expand All @@ -175,7 +176,7 @@ lazy val expressionCompiler = projectMatrix
.toSeq
.flatten,
crossScalaVersions := crossScalaVersions.value.distinct,
libraryDependencies ++= onScalaVersion(
libraryDependencies ++= Seq(Dependencies.munit % Test) ++ onScalaVersion(
scala212 = Some(Dependencies.scalaCollectionCompat),
scala213 = None,
scala3 = None
Expand All @@ -189,6 +190,11 @@ lazy val expressionCompiler = projectMatrix
case (3, minor) => sourceDir / s"scala-3.1+"
}
},
Test / unmanagedSourceDirectories ++= {
val sourceDir = (Test / sourceDirectory).value
if (scalaVersion.value == Dependencies.scala31Plus) Seq(sourceDir / "scala-3.latest")
else Seq.empty
},
Compile / doc / sources := Seq.empty,
libraryDependencies += Dependencies.scalaCompiler(scalaVersion.value),
scalacOptionsSetting
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ch.epfl.scala.debugadapter

import java.nio.file.Path
import java.io.File

trait Debuggee {
def name: String
Expand All @@ -15,4 +16,5 @@ trait Debuggee {
def classPathEntries: Seq[ClassPathEntry] = managedEntries ++ unmanagedEntries
def classPath: Seq[Path] = classPathEntries.map(_.absolutePath)
def classEntries: Seq[ClassEntry] = classPathEntries ++ javaRuntime
def classPathString: String = classPath.mkString(File.pathSeparator)
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,8 @@ class ExtractExpression(using exprCtx: ExpressionContext) extends MacroTransform
case tree: ImportOrExport => tree

case tree if tree.symbol.is(Inline) =>
tree.symbol.info match
case tpe: ConstantType => cpy.Literal(tree)(tpe.value)
case _ =>
report.error(s"Cannot evaluate inlined expression with non constant type", tree.srcPos)
tree
val tpe = tree.symbol.info.asInstanceOf[ConstantType]
cpy.Literal(tree)(tpe.value)

// static object
case tree: (Ident | Select) if isStaticObject(tree.symbol) =>
Expand Down Expand Up @@ -108,16 +105,20 @@ class ExtractExpression(using exprCtx: ExpressionContext) extends MacroTransform

// inaccessible fields
case tree: Select if isInaccessibleField(tree) =>
val qualifier = getTransformedQualifier(tree)
getField(tree)(qualifier, tree.symbol.asTerm)
if tree.symbol.is(JavaStatic) then getField(tree)(nullLiteral, tree.symbol.asTerm)
else
val qualifier = getTransformedQualifier(tree)
getField(tree)(qualifier, tree.symbol.asTerm)

// assignment to inaccessible fields
case tree @ Assign(lhs, rhs) if isInaccessibleField(lhs) =>
val qualifier = getTransformedQualifier(lhs)
setField(tree)(qualifier, lhs.symbol.asTerm, transform(rhs))
if lhs.symbol.is(JavaStatic) then setField(tree)(nullLiteral, lhs.symbol.asTerm, transform(rhs))
else
val qualifier = getTransformedQualifier(lhs)
setField(tree)(qualifier, lhs.symbol.asTerm, transform(rhs))

// this or outer this
case tree @ This(Ident(name)) if !isOwnedByExpression(tree.symbol) =>
case tree @ This(Ident(name)) if !tree.symbol.is(Package) && !isOwnedByExpression(tree.symbol) =>
thisOrOuterValue(tree)(tree.symbol.enclosingClass.asClass)

// inaccessible constructors
Expand All @@ -129,8 +130,10 @@ class ExtractExpression(using exprCtx: ExpressionContext) extends MacroTransform
// inaccessible methods
case tree: (Ident | Select | Apply | TypeApply) if isInaccessibleMethod(tree) =>
val args = getTransformedArgs(tree)
val qualifier = getTransformedQualifier(tree)
callMethod(tree)(qualifier, tree.symbol.asTerm, args)
if tree.symbol.is(JavaStatic) then callMethod(tree)(nullLiteral, tree.symbol.asTerm, args)
else
val qualifier = getTransformedQualifier(tree)
callMethod(tree)(qualifier, tree.symbol.asTerm, args)

case Typed(tree, tpt) if tpt.symbol.isType && !isTypeAccessible(tpt.symbol.asType) =>
transform(tree)
Expand All @@ -143,19 +146,19 @@ class ExtractExpression(using exprCtx: ExpressionContext) extends MacroTransform
*/
private def isInaccessibleField(tree: Tree)(using Context): Boolean =
val symbol = tree.symbol
symbol.isField &&
symbol.owner.isType &&
!isTermAccessible(symbol.asTerm, getQualifierTypeSymbol(tree))
symbol.isField
&& symbol.owner.isType
&& !isTermAccessible(symbol.asTerm, getQualifierTypeSymbol(tree))

/**
* The symbol is a real method and the expression class cannot access it
* either because it is private or it belongs to an inaccessible type
*/
private def isInaccessibleMethod(tree: Tree)(using Context): Boolean =
val symbol = tree.symbol
!isOwnedByExpression(symbol) &&
symbol.isRealMethod &&
(!symbol.owner.isType || !isTermAccessible(symbol.asTerm, getQualifierTypeSymbol(tree)))
!isOwnedByExpression(symbol)
&& symbol.isRealMethod
&& (!symbol.owner.isType || !isTermAccessible(symbol.asTerm, getQualifierTypeSymbol(tree)))

/**
* The symbol is a constructor and the expression class cannot access it
Expand Down Expand Up @@ -329,11 +332,7 @@ class ExtractExpression(using exprCtx: ExpressionContext) extends MacroTransform
val strategy = EvaluationStrategy.StaticObject(obj.asClass)
reflectEval(tree)(nullLiteral, strategy, List.empty, obj.typeRef)

private def getField(
tree: Tree
)(qualifier: Tree, field: TermSymbol)(using
Context
): Tree =
private def getField(tree: Tree)(qualifier: Tree, field: TermSymbol)(using Context): Tree =
reportErrorIfLocalInsideValueClass(field, tree.srcPos)
val byName = isByNameParam(field.info)
val strategy = EvaluationStrategy.Field(field, byName)
Expand Down Expand Up @@ -433,9 +432,8 @@ class ExtractExpression(using exprCtx: ExpressionContext) extends MacroTransform
private def isTermAccessible(symbol: TermSymbol, owner: TypeSymbol)(using
Context
): Boolean =
isOwnedByExpression(symbol) || (
!symbol.isPrivate && isTypeAccessible(owner)
)
isOwnedByExpression(symbol)
|| (!symbol.isPrivate && !symbol.is(Protected) && isTypeAccessible(owner))

// Check if a type is accessible from the expression class
private def isTypeAccessible(symbol: TypeSymbol)(using Context): Boolean =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,7 @@ class ResolveReflectEval(using exprCtx: ExpressionContext) extends MiniPhase:
// if the field is lazy, if it is private in a value class or a trait
// then we must call the getter method
val fieldValue =
if field.is(Lazy) ||
field.owner.isValueClass ||
field.owner.is(Trait)
if field.is(Lazy) || field.owner.isValueClass || field.owner.is(Trait)
then gen.callMethod(qualifier, field.getter.asTerm, Nil)
else
val rawValue = gen.getField(qualifier, field)
Expand Down Expand Up @@ -236,24 +234,22 @@ class ResolveReflectEval(using exprCtx: ExpressionContext) extends MiniPhase:
)

def getField(qualifier: Tree, field: TermSymbol): Tree =
val fieldName = JavaEncoding.encode(field.name)
Apply(
Select(expressionThis, termName("getField")),
List(
qualifier,
Literal(Constant(JavaEncoding.encode(field.owner.asType))),
Literal(Constant(fieldName))
Literal(Constant(JavaEncoding.encode(field.name)))
)
)

def setField(qualifier: Tree, field: TermSymbol, value: Tree): Tree =
val fieldName = JavaEncoding.encode(field.name)
Apply(
Select(expressionThis, termName("setField")),
List(
qualifier,
Literal(Constant(JavaEncoding.encode(field.owner.asType))),
Literal(Constant(fieldName)),
Literal(Constant(JavaEncoding.encode(field.name))),
value
)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package ch.epfl.scala.debugadapter

import ch.epfl.scala.debugadapter.testfmk.TestingDebuggee
import ch.epfl.scala.debugadapter.ScalaVersion
import dotty.tools.dotc.ExpressionCompilerBridge
import java.nio.file.Files
import scala.collection.mutable.Buffer
import java.{util => ju}
import scala.jdk.CollectionConverters.*
import scala.concurrent.duration.*

/**
* This class is used to enter the expression compiler with a debugger
* It is not meant to be run in the CI
*/
class ExpressionCompilerDebug extends munit.FunSuite:
val scalaVersion = ScalaVersion.`3.1+`
val compiler = new ExpressionCompilerBridge

override def munitTimeout: Duration = 1.hour

test("debug expression compiler".ignore) {
val javaSource =
"""|package example;
|
|class A {
| protected static String x = "x";
| protected static String m() {
| return "m";
| }
|}
|""".stripMargin
val javaModule = TestingDebuggee.fromJavaSource(javaSource, "example.A", scalaVersion)
val scalaSource =
"""|package example
|
|object Main extends A {
| def main(args: Array[String]): Unit = {
| println("Hello, World!")
| }
|}
|""".stripMargin
implicit val debuggee: TestingDebuggee =
TestingDebuggee.mainClass(scalaSource, "example.Main", scalaVersion, Seq.empty, Seq(javaModule.mainModule))
evaluate(5, "A.x")
evaluate(5, "A.m")
}

def evaluate(line: Int, expression: String, localVariables: Set[String] = Set.empty)(using
debuggee: TestingDebuggee
): Unit =
val out = debuggee.tempDir.resolve("expr-classes")
if Files.notExists(out) then Files.createDirectory(out)
val errors = Buffer.empty[String]
compiler.run(
out,
"Expression",
debuggee.classPathString,
debuggee.mainModule.scalacOptions.toArray,
debuggee.mainSource,
line,
expression,
localVariables.asJava,
"example",
error => {
println(Console.RED + error + Console.RESET)
errors += error
},
testMode = true
)
if errors.nonEmpty then throw new Exception("Evaluation failed")
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import scala.util.Properties
import scala.util.control.NonFatal

case class TestingDebuggee(
tempDir: Path,
scalaVersion: ScalaVersion,
sourceFiles: Seq[Path],
mainModule: Module,
Expand All @@ -45,7 +46,7 @@ case class TestingDebuggee(
override def libraries: Seq[Library] = dependencies.collect { case m: Library => m }
override def unmanagedEntries: Seq[UnmanagedEntry] = Seq.empty
override def run(listener: DebuggeeListener): CancelableFuture[Unit] = {
val cmd = Seq("java", DebugInterface, "-cp", classPath.mkString(File.pathSeparator), mainClass)
val cmd = Seq("java", DebugInterface, "-cp", classPathString, mainClass)
val builder = new ProcessBuilder(cmd: _*)
val process = builder.start()
new MainProcess(process, listener)
Expand Down Expand Up @@ -147,7 +148,7 @@ object TestingDebuggee {
): TestingDebuggee = {
val className = mainClassName.split('.').last
val sourceName = s"$className.scala"
mainClass(Seq(sourceName -> source), mainClassName, scalaVersion, Seq.empty, dependencies)
mainClass(Seq(sourceName -> source), mainClassName, scalaVersion, scalacOptions, dependencies)
}

def mainClass(
Expand Down Expand Up @@ -197,7 +198,7 @@ object TestingDebuggee {
}

val mainModule = Module(mainClassName, Some(scalaVersion), scalacOptions, classDir, sourceEntries)
TestingDebuggee(scalaVersion, sourceFiles, mainModule, allDependencies, mainClassName, javaRuntime)
TestingDebuggee(tempDir, scalaVersion, sourceFiles, mainModule, allDependencies, mainClassName, javaRuntime)
}

def munitTestSuite(
Expand Down Expand Up @@ -246,7 +247,7 @@ object TestingDebuggee {

val sourceEntry = SourceDirectory(srcDir)
val mainModule = Module(testSuite, Some(scalaVersion), Seq.empty, classDir, Seq(sourceEntry))
TestingDebuggee(scalaVersion, Seq(sourceFile), mainModule, classPath, "TestRunner", getRuntime())
TestingDebuggee(tempDir, scalaVersion, Seq(sourceFile), mainModule, classPath, "TestRunner", getRuntime())
}

private def getResource(name: String): Path =
Expand Down Expand Up @@ -288,7 +289,7 @@ object TestingDebuggee {

val sourceEntry = SourceDirectory(srcDir)
val mainModule = Module(mainClassName, None, Seq.empty, classDir, Seq(sourceEntry))
TestingDebuggee(scalaVersion, Seq(srcFile), mainModule, Seq.empty, mainClassName, getRuntime())
TestingDebuggee(tempDir, scalaVersion, Seq(srcFile), mainModule, Seq.empty, mainClassName, getRuntime())
}

private def startCrawling(input: InputStream)(f: String => Unit): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2117,6 +2117,38 @@ abstract class ScalaEvaluationTests(scalaVersion: ScalaVersion) extends DebugTes
)
}
}

test("java static members") {
val javaSource =
"""|package example;
|
|class A {
| protected static String x = "x";
| protected static String m() {
| return "m";
| }
|}
|""".stripMargin
val javaModule = TestingDebuggee.fromJavaSource(javaSource, "example.A", scalaVersion)
val scalaSource =
"""|package example
|
|object Main extends A {
| def main(args: Array[String]): Unit = {
| println("Hello, World!")
| }
|}
|""".stripMargin
implicit val debuggee: TestingDebuggee =
TestingDebuggee.mainClass(scalaSource, "example.Main", scalaVersion, Seq.empty, Seq(javaModule.mainModule))
check(
Breakpoint(5),
Evaluation.successOrIgnore("A.x", "x", isScala2),
Evaluation.successOrIgnore("A.x = \"y\"", (), isScala2),
Evaluation.successOrIgnore("A.x", "y", isScala2),
Evaluation.successOrIgnore("A.m()", "m", isScala2)
)
}
}

abstract class Scala2EvaluationTests(val scalaVersion: ScalaVersion) extends ScalaEvaluationTests(scalaVersion) {
Expand Down

0 comments on commit ec02508

Please sign in to comment.