Skip to content

Commit

Permalink
Merge pull request #14702 from griggt/fix-14473
Browse files Browse the repository at this point in the history
  • Loading branch information
dwijnand committed Mar 18, 2022
2 parents 17e46ad + 0b0f626 commit 5587767
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 21 deletions.
20 changes: 13 additions & 7 deletions compiler/src/dotty/tools/repl/Rendering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,25 +129,31 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None) {
infoDiagnostic(d.symbol.showUser, d)

/** Render value definition result */
def renderVal(d: Denotation)(using Context): Option[Diagnostic] =
def renderVal(d: Denotation)(using Context): Either[InvocationTargetException, Option[Diagnostic]] =
val dcl = d.symbol.showUser
def msg(s: String) = infoDiagnostic(s, d)
try
if (d.symbol.is(Flags.Lazy)) Some(msg(dcl))
else valueOf(d.symbol).map(value => msg(s"$dcl = $value"))
catch case e: InvocationTargetException => Some(msg(renderError(e, d)))
Right(
if d.symbol.is(Flags.Lazy) then Some(msg(dcl))
else valueOf(d.symbol).map(value => msg(s"$dcl = $value"))
)
catch case e: InvocationTargetException => Left(e)
end renderVal

/** Force module initialization in the absence of members. */
def forceModule(sym: Symbol)(using Context): Seq[Diagnostic] =
import scala.util.control.NonFatal
def load() =
val objectName = sym.fullName.encode.toString
Class.forName(objectName, true, classLoader())
Nil
try load() catch case e: ExceptionInInitializerError => List(infoDiagnostic(renderError(e, sym.denot), sym.denot))
try load()
catch
case e: ExceptionInInitializerError => List(renderError(e, sym.denot))
case NonFatal(e) => List(renderError(InvocationTargetException(e), sym.denot))

/** Render the stack trace of the underlying exception. */
private def renderError(ite: InvocationTargetException | ExceptionInInitializerError, d: Denotation)(using Context): String =
def renderError(ite: InvocationTargetException | ExceptionInInitializerError, d: Denotation)(using Context): Diagnostic =
import dotty.tools.dotc.util.StackTraceOps._
val cause = ite.getCause match
case e: ExceptionInInitializerError => e.getCause
Expand All @@ -159,7 +165,7 @@ private[repl] class Rendering(parentClassLoader: Option[ClassLoader] = None) {
ste.getClassName.startsWith(REPL_WRAPPER_NAME_PREFIX) // d.symbol.owner.name.show is simple name
&& (ste.getMethodName == nme.STATIC_CONSTRUCTOR.show || ste.getMethodName == nme.CONSTRUCTOR.show)

cause.formatStackTracePrefix(!isWrapperInitialization(_))
infoDiagnostic(cause.formatStackTracePrefix(!isWrapperInitialization(_)), d)
end renderError

private def infoDiagnostic(msg: String, d: Denotation)(using Context): Diagnostic =
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/repl/ReplCompiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class ReplCompiler extends Compiler {
val rootCtx = super.rootContext.fresh
.setOwner(defn.EmptyPackageClass)
.withRootImports
(1 to state.objectIndex).foldLeft(rootCtx)((ctx, id) =>
(state.validObjectIndexes).foldLeft(rootCtx)((ctx, id) =>
importPreviousRun(id)(using ctx))
}
}
Expand Down
51 changes: 38 additions & 13 deletions compiler/src/dotty/tools/repl/ReplDriver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import dotty.tools.runner.ScalaClassLoader.*
import org.jline.reader._

import scala.annotation.tailrec
import scala.collection.mutable
import scala.collection.JavaConverters._
import scala.util.Using

Expand All @@ -55,12 +56,15 @@ import scala.util.Using
* @param objectIndex the index of the next wrapper
* @param valIndex the index of next value binding for free expressions
* @param imports a map from object index to the list of user defined imports
* @param invalidObjectIndexes the set of object indexes that failed to initialize
* @param context the latest compiler context
*/
case class State(objectIndex: Int,
valIndex: Int,
imports: Map[Int, List[tpd.Import]],
context: Context)
invalidObjectIndexes: Set[Int],
context: Context):
def validObjectIndexes = (1 to objectIndex).filterNot(invalidObjectIndexes.contains(_))

/** Main REPL instance, orchestrating input, compilation and presentation */
class ReplDriver(settings: Array[String],
Expand Down Expand Up @@ -94,7 +98,7 @@ class ReplDriver(settings: Array[String],
}

/** the initial, empty state of the REPL session */
final def initialState: State = State(0, 0, Map.empty, rootCtx)
final def initialState: State = State(0, 0, Map.empty, Set.empty, rootCtx)

/** Reset state of repl to the initial state
*
Expand Down Expand Up @@ -237,7 +241,7 @@ class ReplDriver(settings: Array[String],
completions.map(_.label).distinct.map(makeCandidate)
}
.getOrElse(Nil)
end completions
end completions

private def interpret(res: ParseResult)(implicit state: State): State = {
res match {
Expand Down Expand Up @@ -353,14 +357,33 @@ class ReplDriver(settings: Array[String],
val typeAliases =
info.bounds.hi.typeMembers.filter(_.symbol.info.isTypeAlias)

val formattedMembers =
typeAliases.map(rendering.renderTypeAlias) ++
defs.map(rendering.renderMethod) ++
vals.flatMap(rendering.renderVal)

val diagnostics = if formattedMembers.isEmpty then rendering.forceModule(symbol) else formattedMembers

(state.copy(valIndex = state.valIndex - vals.count(resAndUnit)), diagnostics)
// The wrapper object may fail to initialize if the rhs of a ValDef throws.
// In that case, don't attempt to render any subsequent vals, and mark this
// wrapper object index as invalid.
var failedInit = false
val renderedVals =
val buf = mutable.ListBuffer[Diagnostic]()
for d <- vals do if !failedInit then rendering.renderVal(d) match
case Right(Some(v)) =>
buf += v
case Left(e) =>
buf += rendering.renderError(e, d)
failedInit = true
case _ =>
buf.toList

if failedInit then
// We limit the returned diagnostics here to `renderedVals`, which will contain the rendered error
// for the val which failed to initialize. Since any other defs, aliases, imports, etc. from this
// input line will be inaccessible, we avoid rendering those so as not to confuse the user.
(state.copy(invalidObjectIndexes = state.invalidObjectIndexes + state.objectIndex), renderedVals)
else
val formattedMembers =
typeAliases.map(rendering.renderTypeAlias)
++ defs.map(rendering.renderMethod)
++ renderedVals
val diagnostics = if formattedMembers.isEmpty then rendering.forceModule(symbol) else formattedMembers
(state.copy(valIndex = state.valIndex - vals.count(resAndUnit)), diagnostics)
}
else (state, Seq.empty)

Expand All @@ -378,8 +401,10 @@ class ReplDriver(settings: Array[String],
tree.symbol.info.memberClasses
.find(_.symbol.name == newestWrapper.moduleClassName)
.map { wrapperModule =>
val formattedTypeDefs = typeDefs(wrapperModule.symbol)
val (newState, formattedMembers) = extractAndFormatMembers(wrapperModule.symbol)
val formattedTypeDefs = // don't render type defs if wrapper initialization failed
if newState.invalidObjectIndexes.contains(state.objectIndex) then Seq.empty
else typeDefs(wrapperModule.symbol)
val highlighted = (formattedTypeDefs ++ formattedMembers)
.map(d => new Diagnostic(d.msg.mapMsg(SyntaxHighlighting.highlight), d.pos, d.level))
(newState, highlighted)
Expand Down Expand Up @@ -420,7 +445,7 @@ class ReplDriver(settings: Array[String],

case Imports =>
for {
objectIndex <- 1 to state.objectIndex
objectIndex <- state.validObjectIndexes
imp <- state.imports.getOrElse(objectIndex, Nil)
} out.println(imp.show(using state.context))
state
Expand Down
104 changes: 104 additions & 0 deletions compiler/test/dotty/tools/repl/ReplCompilerTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,110 @@ class ReplCompilerTests extends ReplTest:
assertEquals(List("// defined class C"), lines())
}

def assertNotFoundError(id: String): Unit =
val lines = storedOutput().linesIterator
assert(lines.next().startsWith("-- [E006] Not Found Error:"))
assert(lines.drop(2).next().trim().endsWith(s"Not found: $id"))

@Test def i4416 = initially {
val state = run("val x = 1 / 0")
val all = lines()
assertEquals(2, all.length)
assert(all.head.startsWith("java.lang.ArithmeticException:"))
state
} andThen {
val state = run("def foo = x")
assertNotFoundError("x")
state
} andThen {
run("x")
assertNotFoundError("x")
}

@Test def i4416b = initially {
val state = run("val a = 1234")
val _ = storedOutput() // discard output
state
} andThen {
val state = run("val a = 1; val x = ???; val y = x")
val all = lines()
assertEquals(3, all.length)
assertEquals("scala.NotImplementedError: an implementation is missing", all.head)
state
} andThen {
val state = run("x")
assertNotFoundError("x")
state
} andThen {
val state = run("y")
assertNotFoundError("y")
state
} andThen {
run("a") // `a` should retain its original binding
assertEquals("val res0: Int = 1234", storedOutput().trim)
}

@Test def i4416_imports = initially {
run("import scala.collection.mutable")
} andThen {
val state = run("import scala.util.Try; val x = ???")
val _ = storedOutput() // discard output
state
} andThen {
run(":imports") // scala.util.Try should not be imported
assertEquals("import scala.collection.mutable", storedOutput().trim)
}

@Test def i4416_types_defs_aliases = initially {
val state =
run("""|type Foo = String
|trait Bar
|def bar: Bar = ???
|val x = ???
|""".stripMargin)
val all = lines()
assertEquals(3, all.length)
assertEquals("scala.NotImplementedError: an implementation is missing", all.head)
assert("type alias in failed wrapper should not be rendered",
!all.exists(_.startsWith("// defined alias type Foo = String")))
assert("type definitions in failed wrapper should not be rendered",
!all.exists(_.startsWith("// defined trait Bar")))
assert("defs in failed wrapper should not be rendered",
!all.exists(_.startsWith("def bar: Bar")))
state
} andThen {
val state = run("def foo: Foo = ???")
assertNotFoundError("type Foo")
state
} andThen {
val state = run("type B = Bar")
assertNotFoundError("type Bar")
state
} andThen {
run("bar")
assertNotFoundError("bar")
}

@Test def i14473 = initially {
run("""val (x,y) = if true then "hi" else (42,17)""")
val all = lines()
assertEquals(2, all.length)
assertEquals("scala.MatchError: hi (of class java.lang.String)", all.head)
}

@Test def i14701 = initially {
val state = run("val _ = ???")
val all = lines()
assertEquals(3, all.length)
assertEquals("scala.NotImplementedError: an implementation is missing", all.head)
state
} andThen {
run("val _ = assert(false)")
val all = lines()
assertEquals(3, all.length)
assertEquals("java.lang.AssertionError: assertion failed", all.head)
}

@Test def i14491 =
initially {
run("import language.experimental.fewerBraces")
Expand Down

0 comments on commit 5587767

Please sign in to comment.