Skip to content

Commit

Permalink
Fix #102: Better main class detection
Browse files Browse the repository at this point in the history
Previously, the main class detection was handled by
https://github.com/sbt/zinc/blob/1.0/internal/zinc-apiinfo/src/main/scala/xsbt/api/Discovery.scala
which looks for a main method with the correct signature in the
extracted API. This is imperfect because it relies on ExtractAPI
dealiasing types (because Discovery will look for a main method with a
parameter type of `java.lang.String` and won't recognize
`scala.Predef.String`), dealiasing means that the extracted API looses
information and thus can lead to undercompilation.

This commit partially fixes this by adding a new callback to AnalysisCallback:
    void mainClass(File sourceFile, String className)
that is used to explicitly register main entry points. This way, tools
do not need to interpret the extracted API, this is much better since it
makes it easier for zinc to evolve the API representation.

This commit does not actually changes ExtractAPI to not dealias, this
can be done in a later PR.

Note that there is another usecase for xsbt.api.Discovery that this PR
does not replace: discovering tests. This is more complicated because
different test frameworks have different ways to discover tests. For
more information, grep for "Fingerprint" in https://github.com/sbt/sbt
and https://github.com/sbt/junit-interface
  • Loading branch information
smarter committed May 26, 2017
1 parent 931b57c commit f10c53c
Show file tree
Hide file tree
Showing 17 changed files with 140 additions and 18 deletions.
5 changes: 5 additions & 0 deletions internal/compiler-bridge/src/main/scala/xsbt/API.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ final class API(val global: CallbackGlobal) extends Compat with GlobalHelpers {
extractUsedNames.extractAndReport(unit)

val classApis = traverser.allNonLocalClasses
val mainClasses = traverser.mainClasses

classApis.foreach(callback.api(sourceFile, _))
mainClasses.foreach(callback.mainClass(sourceFile, _))
}
}

Expand All @@ -56,6 +58,9 @@ final class API(val global: CallbackGlobal) extends Compat with GlobalHelpers {
def allNonLocalClasses: Set[ClassLike] = {
extractApi.allExtractedNonLocalClasses
}

def mainClasses: Set[String] = extractApi.mainClasses

def `class`(c: Symbol): Unit = {
extractApi.extractAllClassesOf(c.owner, c)
}
Expand Down
12 changes: 11 additions & 1 deletion internal/compiler-bridge/src/main/scala/xsbt/ExtractAPI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ package xsbt
import java.io.File
import java.util.{ Arrays, Comparator }
import scala.tools.nsc.symtab.Flags
import scala.collection.mutable.{ HashMap, HashSet }
import scala.collection.mutable.{ HashMap, HashSet, ListBuffer }
import xsbti.api._

import scala.tools.nsc.Global
Expand Down Expand Up @@ -71,6 +71,7 @@ class ExtractAPI[GlobalType <: Global](
private[this] val emptyStringArray = new Array[String](0)

private[this] val allNonLocalClassesInSrc = new HashSet[xsbti.api.ClassLike]
private[this] val _mainClasses = new HashSet[String]

/**
* Implements a work-around for https://github.com/sbt/sbt/issues/823
Expand Down Expand Up @@ -600,6 +601,11 @@ class ExtractAPI[GlobalType <: Global](
allNonLocalClassesInSrc.toSet
}

def mainClasses: Set[String] = {
forceStructures()
_mainClasses.toSet
}

private def classLike(in: Symbol, c: Symbol): ClassLikeDef =
classLikeCache.getOrElseUpdate((in, c), mkClassLike(in, c))
private def mkClassLike(in: Symbol, c: Symbol): ClassLikeDef = {
Expand Down Expand Up @@ -641,6 +647,10 @@ class ExtractAPI[GlobalType <: Global](

allNonLocalClassesInSrc += classWithMembers

if (sym.isStatic && defType == DefinitionType.Module && definitions.hasJavaMainMethod(sym)) {
_mainClasses += name
}

val classDef = new xsbti.api.ClassLikeDef(
name,
acc,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,20 @@ void generatedNonLocalClass(File source,
*/
void api(File sourceFile, xsbti.api.ClassLike classApi);

/**
* Register a class containing an entry point coming from a given source file.
*
* A class is an entry point if its bytecode contains a method with the
* following signature:
* <pre>
* public static void main(String[] args);
* </pre>
*
* @param sourceFile Source file where <code>className</code> is defined.
* @param className A class containing an entry point.
*/
void mainClass(File sourceFile, String className);

/**
* Register the use of a <code>name</code> from a given source class name.
*
Expand Down Expand Up @@ -158,4 +172,4 @@ void problem(String what,
* phase defined by <code>xsbt-analyzer</code> should be added.
*/
boolean enabled();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,12 @@ public interface SourceInfo {
* @return The compiler reported problems.
*/
public Problem[] getUnreportedProblems();

/**
* Returns the main classes found in this compilation unit.
*
* @return The full name of the main classes, like "foo.bar.Main"
*/
public String[] getMainClasses();
}

Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ class TestCallback extends AnalysisCallback {
()
}

def mainClass(source: File, className: String): Unit = ()

override def enabled(): Boolean = true

def problem(category: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ object ClassToAPI {
def apply(c: Seq[Class[_]]): Seq[api.ClassLike] = process(c)._1

// (api, public inherited classes)
def process(classes: Seq[Class[_]]): (Seq[api.ClassLike], Set[(Class[_], Class[_])]) = {
def process(
classes: Seq[Class[_]]): (Seq[api.ClassLike], Seq[String], Set[(Class[_], Class[_])]) = {
val cmap = emptyClassMap
classes.foreach(toDefinitions(cmap)) // force recording of class definitions
cmap.lz.foreach(_.get()) // force thunks to ensure all inherited dependencies are recorded
val classApis = cmap.allNonLocalClasses.toSeq
val mainClasses = cmap.mainClasses.toSeq
val inDeps = cmap.inherited.toSet
cmap.clear()
(classApis, inDeps)
(classApis, mainClasses, inDeps)
}

// Avoiding implicit allocation.
Expand All @@ -55,7 +57,8 @@ object ClassToAPI {
private[sbt] val memo: mutable.Map[String, Seq[api.ClassLikeDef]],
private[sbt] val inherited: mutable.Set[(Class[_], Class[_])],
private[sbt] val lz: mutable.Buffer[xsbti.api.Lazy[_]],
private[sbt] val allNonLocalClasses: mutable.Set[api.ClassLike]
private[sbt] val allNonLocalClasses: mutable.Set[api.ClassLike],
private[sbt] val mainClasses: mutable.Set[String]
) {
def clear(): Unit = {
memo.clear()
Expand All @@ -67,6 +70,7 @@ object ClassToAPI {
new ClassMap(new mutable.HashMap,
new mutable.HashSet,
new mutable.ListBuffer,
new mutable.HashSet,
new mutable.HashSet)

def classCanonicalName(c: Class[_]): String =
Expand Down Expand Up @@ -115,6 +119,17 @@ object ClassToAPI {
val defsEmptyMembers = clsDef :: statDef :: Nil
cmap.memo(name) = defsEmptyMembers
cmap.allNonLocalClasses ++= defs

if (c.getMethods.exists(
meth =>
meth.getName == "main" &&
Modifier.isStatic(meth.getModifiers) &&
meth.getParameterTypes.length == 1 &&
meth.getParameterTypes.head == classOf[Array[String]] &&
meth.getReturnType == java.lang.Void.TYPE)) {
cmap.mainClasses += name
}

defsEmptyMembers
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ class ClassToAPISpecification extends UnitSpec {
def readAPI(callback: AnalysisCallback,
source: File,
classes: Seq[Class[_]]): Set[(String, String)] = {
val (apis, inherits) = ClassToAPI.process(classes)
val (apis, mainClasses, inherits) = ClassToAPI.process(classes)
apis.foreach(callback.api(source, _))
mainClasses.foreach(callback.mainClass(source, _))
inherits.map {
case (from, to) => (from.getName, to.getName)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ private final class AnalysisCallback(
private[this] val usedNames = new HashMap[String, Set[UsedName]]
private[this] val unreporteds = new HashMap[File, ListBuffer[Problem]]
private[this] val reporteds = new HashMap[File, ListBuffer[Problem]]
private[this] val mainClasses = new HashMap[File, ListBuffer[String]]
private[this] val binaryDeps = new HashMap[File, Set[File]]
// source file to set of generated (class file, binary class name); only non local classes are stored here
private[this] val nonLocalClasses = new HashMap[File, Set[(File, String)]]
Expand Down Expand Up @@ -285,6 +286,11 @@ private final class AnalysisCallback(
}
}

def mainClass(sourceFile: File, className: String): Unit = {
mainClasses.getOrElseUpdate(sourceFile, ListBuffer.empty) += className
()
}

def usedName(className: String, name: String, useScopes: util.EnumSet[UseScope]) =
add(usedNames, className, UsedName(name, useScopes))

Expand Down Expand Up @@ -346,7 +352,9 @@ private final class AnalysisCallback(
val stamp = stampReader.source(src)
val classesInSrc = classNames.getOrElse(src, Set.empty).map(_._1)
val analyzedApis = classesInSrc.map(analyzeClass)
val info = SourceInfos.makeInfo(getOrNil(reporteds, src), getOrNil(unreporteds, src))
val info = SourceInfos.makeInfo(getOrNil(reporteds, src),
getOrNil(unreporteds, src),
getOrNil(mainClasses, src))
val binaries = binaryDeps.getOrElse(src, Nil: Iterable[File])
val localProds = localClasses.getOrElse(src, Nil: Iterable[File]) map { classFile =>
val classFileStamp = stampReader.product(classFile)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ object SourceInfos {
def empty: SourceInfos = make(Map.empty)
def make(m: Map[File, SourceInfo]): SourceInfos = new MSourceInfos(m)

val emptyInfo: SourceInfo = makeInfo(Nil, Nil)
def makeInfo(reported: Seq[Problem], unreported: Seq[Problem]): SourceInfo =
new UnderlyingSourceInfo(reported, unreported)
val emptyInfo: SourceInfo = makeInfo(Nil, Nil, Nil)
def makeInfo(reported: Seq[Problem],
unreported: Seq[Problem],
mainClasses: Seq[String]): SourceInfo =
new UnderlyingSourceInfo(reported, unreported, mainClasses)
def merge(infos: Traversable[SourceInfos]): SourceInfos = (SourceInfos.empty /: infos)(_ ++ _)
}

Expand All @@ -48,8 +50,10 @@ private final class MSourceInfos(val allInfos: Map[File, SourceInfo]) extends So
}

private final class UnderlyingSourceInfo(val reportedProblems: Seq[Problem],
val unreportedProblems: Seq[Problem])
val unreportedProblems: Seq[Problem],
val mainClasses: Seq[String])
extends SourceInfo {
override def getReportedProblems: Array[Problem] = reportedProblems.toArray
override def getUnreportedProblems: Array[Problem] = unreportedProblems.toArray
override def getMainClasses: Array[String] = mainClasses.toArray
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ class TextAnalysisFormat(override val mappers: AnalysisMappers)
private implicit val analyzedClassFormat: Format[AnalyzedClass] =
AnalyzedClassFormats.analyzedClassFormat
private implicit def infoFormat: Format[SourceInfo] =
wrap[SourceInfo, (Seq[Problem], Seq[Problem])](
si => (si.getReportedProblems, si.getUnreportedProblems), {
case (a, b) => SourceInfos.makeInfo(a, b)
wrap[SourceInfo, (Seq[Problem], Seq[Problem], Seq[String])](
si => (si.getReportedProblems, si.getUnreportedProblems, si.getMainClasses), {
case (a, b, c) => SourceInfos.makeInfo(a, b, c)
})
private implicit def fileHashFormat: Format[FileHash] =
asProduct2((file: File, hash: Int) => new FileHash(file, hash))(h => (h.file, h.hash))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ trait BaseTextAnalysisFormatTest { self: Properties =>
val aClass = genClass("A").sample.get
val cClass = genClass("C").sample.get
val absent = EmptyStamp
val sourceInfos = SourceInfos.makeInfo(Nil, Nil)
val sourceInfos = SourceInfos.makeInfo(Nil, Nil, Nil)

var analysis = Analysis.empty
val products = NonLocalProduct("A", "A", f("A.class"), absent) ::
Expand Down Expand Up @@ -106,10 +106,14 @@ trait BaseTextAnalysisFormatTest { self: Properties =>
("Whole Analysis" |: left =? right)
}

private def mapInfos(a: SourceInfos): Map[File, (Seq[Problem], Seq[Problem])] =
private def mapInfos(a: SourceInfos): Map[File, (Seq[Problem], Seq[Problem], Seq[String])] =
a.allInfos.map {
case (f, infos) =>
f -> (infos.getReportedProblems.toList -> infos.getUnreportedProblems.toList)
f -> ((
infos.getReportedProblems.toList,
infos.getUnreportedProblems.toList,
infos.getMainClasses.toList
))
}

private def compareOutputs(left: Output, right: Output): Prop = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,12 @@ final class IncHandler(directory: File, scriptedLog: ManagedLogger)
p.checkClasses(i, srcFile, products)
case (p, other, _) => p.unrecognizedArguments("checkClasses", other)
},
"checkMainClasses" -> {
case (p, src :: products, i) =>
val srcFile = if (src endsWith ":") src dropRight 1 else src
p.checkMainClasses(i, srcFile, products)
case (p, other, _) => p.unrecognizedArguments("checkMainClasses", other)
},
"checkProducts" -> {
case (p, src :: products, i) =>
val srcFile = if (src endsWith ":") src dropRight 1 else src
Expand Down Expand Up @@ -336,6 +342,17 @@ case class ProjectStructure(
()
}

def checkMainClasses(i: IncInstance, src: String, expected: List[String]): Unit = {
val analysis = compile(i)
def mainClasses(src: String): Set[String] =
analysis.infos.get(baseDirectory / src).getMainClasses.toSet
def assertClasses(expected: Set[String], actual: Set[String]) =
assert(expected == actual, s"Expected $expected classes, got $actual")

assertClasses(expected.toSet, mainClasses(src))
()
}

def checkProducts(i: IncInstance, src: String, expected: List[String]): Unit = {
val analysis = compile(i)
def relativeClassDir(f: File): File = f.relativeTo(classesDir) getOrElse f
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,9 @@ final class AnalyzingJavaCompiler private[sbt] (

/** Read the API information from [[Class]] to analyze dependencies. */
def readAPI(source: File, classes: Seq[Class[_]]): Set[(String, String)] = {
val (apis, inherits) = ClassToAPI.process(classes)
val (apis, mainClasses, inherits) = ClassToAPI.process(classes)
apis.foreach(callback.api(source, _))
mainClasses.foreach(callback.mainClass(source, _))
inherits.map {
case (from, to) => (from.getName, to.getName)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package runjava;

class MainJava {
public static void main(String args[]) {
}

static public class StaticInner {
public static void main(String args[]) {
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package runjava;

class NoMainJava {
public void main(String args[]) {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package runscala

object MainScala {
def main(args: Array[String]) {}

object StaticInner {
def main(args: Array[String]) {}
}
}

class NoMainScala {
def main(args: Array[String]) {}
}
4 changes: 4 additions & 0 deletions zinc/src/sbt-test/apiinfo/main-discovery/test
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
> compile
> checkMainClasses src/main/java/runjava/MainJava.java: runjava.MainJava runjava.MainJava.StaticInner
> checkMainClasses src/main/java/runjava/oMainJava.java:
> checkMainClasses src/main/scala/Hello.scala: runscala.MainScala runscala.MainScala.StaticInner

0 comments on commit f10c53c

Please sign in to comment.